In [2]:
import torch
import torch.nn as nn 
import torch.nn.functional as F
import numpy as np
from torch.nn import Parameter


In [3]:
# A single sample keeps the shape walkthrough easy to follow.
batch_size = 1
# Input layout assumed throughout: (batch_size, in_channels, height, width).
tensor = torch.rand((batch_size, 3, 100, 100))

print('initial shape:', tensor.shape)
in_channels = tensor.size(1)

# ---- State embedding: X -> Z, one row of embedding_dim features per sample ----
embedding_dim = 10

#Create CNN encoder and use it first, then flatten the CNN output to complete final embedding transformation
class CNN_Encoder(nn.Module):
    """Two-layer convolutional feature extractor run before the flat embedding.

    An 8x8 / stride-4 convolution to 16 channels is followed by a 4x4 /
    stride-2 convolution to 32 channels; each is passed through an in-place
    ReLU.

    NOTE(review): the original author remarks that a "w2 division" step from
    the reference implementation is left out here; output shapes are said to
    be unaffected.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 16, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=4, stride=2)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return the ReLU-activated feature map after both convolutions."""
        for conv in (self.conv1, self.conv2):
            x = self.relu(conv(x))
        return x
# Probe encoder: push a zero batch shaped like the real input through it and
# read off how many features each sample has once the conv output is flattened.
cnn_encoder = CNN_Encoder(in_channels)
flat_conv_dim = cnn_encoder(torch.zeros(tensor.shape)).flatten(start_dim=1).size(1)

class Embed(nn.Module):
    """Image batch -> flat state embedding (CNN features, flatten, Linear, ReLU).

    The linear layer's input width is read from the module-level
    ``flat_conv_dim``, so that name must already be defined when an instance
    is constructed.
    """

    def __init__(self, in_channels, embedding_dim):
        super().__init__()
        self.cnn_encoder = CNN_Encoder(in_channels)
        self.linear = nn.Linear(flat_conv_dim, embedding_dim)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Encode ``x`` and return one ``embedding_dim``-wide row per sample."""
        features = self.cnn_encoder(x)
        # Collapse every non-batch axis into a single feature axis.
        flat = features.view(features.shape[0], -1)
        return self.relu(self.linear(flat))
    
embed = Embed(in_channels=in_channels, embedding_dim=embedding_dim)
# NOTE: `tensor` is rebound here from the raw image batch to its embedding.
tensor = embed(tensor)

print('post embedding shape:', tensor.shape)

# ---- Embedding done; from here on the state is branched once per action ----
num_actions = 4
# Next come the reward, transition and value functions.
# The transition branches every state once per action, so each tree level
# multiplies the number of rows by num_actions; every level is kept so it can
# be backed up later.
# Reward head: one reward per action for a given state.
class MLPRewardFn(nn.Module):
    """Two-layer MLP producing one reward per action for each embedded state."""

    def __init__(self, embed_dim, num_actions):
        super().__init__()
        self.embedding_dim = embed_dim
        self.num_actions = num_actions
        hidden = nn.Linear(embed_dim, 64)
        head = nn.Linear(64, num_actions)
        self.mlp = nn.Sequential(hidden, nn.ReLU(inplace=True), head)

    def forward(self, x):
        """Flatten ``x`` to rows of ``embedding_dim`` and return (rows, num_actions)."""
        states = x.view(-1, self.embedding_dim)
        rewards = self.mlp(states)
        return rewards.view(-1, self.num_actions)
reward_fun = MLPRewardFn(embed_dim=embedding_dim, num_actions=num_actions)

# Transition weights of shape (embedding_dim, embedding_dim, num_actions),
# allocated as zeros and then Xavier-normal initialised in place.
transition_fun = nn.init.xavier_normal_(
    Parameter(torch.zeros(embedding_dim, embedding_dim, num_actions))
)

def tree_transition(tensor, transition_weights=None):
    """Apply every action's transition to every state in a single einsum.

    Args:
        tensor: 2-D batch of state embeddings, (n_states, embedding_dim).
        transition_weights: optional 3-D weight tensor laid out like the
            module-level ``transition_fun``, i.e.
            (embedding_dim, embedding_dim, num_actions). When omitted, the
            module-level ``transition_fun`` is used, which is the original
            behaviour.

    Returns:
        A contiguous tensor of shape (n_states, num_actions, embedding_dim)
        holding the tanh-squashed next-state embeddings.
    """
    if transition_weights is None:
        transition_weights = transition_fun
    # "ij,jab->iba": contract the incoming embedding axis j; the result puts
    # the action axis b ahead of the new embedding axis a.
    next_state = torch.tanh(torch.einsum("ij,jab->iba", tensor, transition_weights))
    # einsum may hand back a permuted (non-contiguous) result, and the caller
    # reshapes it with .view(), which requires contiguous memory.
    return next_state.contiguous()
# State-value head: a single linear layer (original note: the reference
# version is the same apart from a w_scale).
value_fn = nn.Linear(embedding_dim, 1)

# Key hyperparameter: number of tree levels to expand.
tree_depth = 10

# Record the rewards, embeddings and values of every level so the
# intermediate branches can be backed up afterwards.
reward_list = []
transition_list = [tensor]
value_list = [value_fn(tensor)]
for level in range(tree_depth):
    reward = reward_fun(tensor)
    reward_list.append(reward.view(-1, 1))
    # Branch every state once per action, then fold the action axis into rows.
    tensor = tree_transition(tensor).view(-1, embedding_dim)
    transition_list.append(tensor)
    # Original note: the reference code seems to assume that returning
    # intermediate values is enabled, so they are always stored here.
    value_list.append(value_fn(tensor))

tree_result = {
    "embeddings": transition_list,
    "values": value_list,
    "rewards": reward_list,
}

td_lambda = 0.3

# ---- Backup ----
# Inline form of the reference `tree_backup(tree_result, batch_size)` call.
# Start from the deepest level's values and fold rewards back toward the root.
backup_values = tree_result["values"][-1]
for i in range(1, tree_depth + 1):
    # A discount factor (gamma) would scale backup_values here; none is applied.
    one_step_backup = tree_result['rewards'][-i] + backup_values
    if i >= tree_depth:
        # Root level: keep one value per action rather than aggregating.
        backup_values = one_step_backup
        continue
    # Regroup rows so the last axis runs over the actions of one parent.
    one_step_backup = one_step_backup.view(batch_size, -1, num_actions)
    # Soft maximum: softmax-weighted sum over the action axis.
    action_weights = F.softmax(one_step_backup, dim=2)
    max_backup = (one_step_backup * action_weights).sum(dim=2)
    # Blend the parent's own value with the backed-up target. The stored
    # intermediate values are required here even though the reference code
    # presents them as optional.
    parent_values = tree_result['values'][-i - 1]
    backup_values = (1 - td_lambda) * parent_values + td_lambda * max_backup.view(-1, 1)
backup_values = backup_values.view(batch_size, num_actions)
# TODO: softmax the backup values and pick the "max action".
# TODO: look into the V values.
print('q values (backup values shape):', backup_values.shape)
print('first transition', tree_result['embeddings'][1].shape)
print('last transition', tree_result['embeddings'][-1].shape)


initial shape: torch.Size([1, 3, 100, 100])
post embedding shape: torch.Size([1, 10])
q values (backup values shape): torch.Size([1, 4])
first transition torch.Size([4, 10])
last transition torch.Size([1048576, 10])
