In [1]:
import torch
import torch.nn as nn 
import torch.nn.functional as F
import numpy as np
from torch.nn import Parameter

from torch.optim import Adam
from torch.optim import RMSprop

from treeQN.treeqn_traj_simple import TreeQN
import random

import pandas as pd

In [2]:

# Create the tensor
tensor = torch.ones(16,1, 4, 4)

# Flatten dimensions 1 and 2

#reverse process

def flattener(tensor):
    return torch.flatten(tensor, start_dim=2, end_dim=3).squeeze(1)
def unflattener(tensor):
    return tensor.view(-1, 1, 4, 4)

flat = flattener(tensor)
print(flat.shape)  # Output: torch.Size([4, 16])
unflat = unflattener(flat)
print(unflat.shape)  # Output: torch.Size([4, 1, 4, 4])

torch.Size([16, 16])
torch.Size([16, 1, 4, 4])


In [3]:
def image_world(tensor,max_val): 
    assert tensor.shape[0] == tensor.shape[1]
    val = tensor.max() + 1
    state = torch.zeros_like(tensor).unsqueeze(0).unsqueeze(0) #to match treeqn input size
    new_state = torch.ones_like(state)
    middle = int(tensor.shape[0] / 2)
    # Create transitions by modifying slices of new_state
    new_state[:,:,:middle, :middle] += val
    transition_one = new_state.clone()
    new_state[:,:,middle:, :middle] += val
    transition_two = new_state.clone()
    new_state[:,:,:middle, middle:] += val
    transition_three = new_state.clone()
    new_state[:,:,middle:, middle:] += val #transition 4
    return [transition_one/max_val, transition_two/max_val, transition_three/max_val, new_state/max_val]
def image_world_samples(size_tensor,samples,max_val=1,x_input=-1):
    return_data = []
    for i in range(samples):
        loc_tensor = torch.zeros_like(size_tensor)+0.1
        x = int(random.random()*size_tensor.shape[0])
        if x_input != -1:
            x = x_input
        loc_tensor[x] += x
        result = image_world(loc_tensor,max_val)
        loc_tensor = loc_tensor.unsqueeze(0).unsqueeze(0)/max_val
        return_data.append([loc_tensor,result])
    return return_data

size_tensor = torch.zeros(20,20)
train_data = image_world_samples(size_tensor,1000,size_tensor.shape[0],-1)

test_data = [] #simply test once on all 20 possible initial states.
for i in range(size_tensor.shape[0]):
    test_data.append(image_world_samples(size_tensor,1,size_tensor.shape[0],i)[0])


In [4]:
input_shape = torch.zeros(1, 1,20, 20).shape# minimum size #train_data[0][0].shape
num_actions = 4
tree_depth = 4
embedding_dim = 400
td_lambda = 0.8
gamma = 1    #0.99
model = TreeQN(input_shape=input_shape, num_actions=num_actions, tree_depth=tree_depth, embedding_dim=embedding_dim, td_lambda=td_lambda,gamma=gamma)
optimizer = Adam(model.parameters(), lr=1e-4)
#optimizer = RMSprop(model.parameters(), lr=1e-4,alpha =0.99, eps = 1e-5) | loss from treeqn paper

  b_init(module.bias, b_scale)


In [5]:
#Main training loop
#Looking at difference between detaching at each transition or not in treeqn file. This is with detach (so far seems to make no diff)
raw_losses = []
for epoch in range(3000):  # epochs
    avg_loss = 0
    temp_loss = 0
    temp_raw_loss = 0
    sample_count = 0

    avg_raw_loss = 0

    for t in random.sample(train_data, len(train_data)): #sample through all data in random order each epoch
        #Get reconstruction loss to help ground abstract state
        decoded_values, all_policies = model(t[0])
        decode_loss = F.mse_loss(decoded_values[0], t[0], reduction='sum')

        #Get transition probabilities for each state
        first_policy = all_policies[0]
        second_policy = all_policies[1].view(4, -1)
        third_policy = all_policies[2].view(4, 4, -1)
        fourth_policy = all_policies[3].view(4, 4, 4, -1)

        #These should all add to 1 (in testing there seems to be some small rounding error)
        second_layer_probs = first_policy * second_policy   
        third_layer_probs = second_layer_probs * third_policy
        fourth_layer_probs = third_layer_probs * fourth_policy
        
        #Flatten transition probabilities to then weigh with loss of each predicted state at each layer
        first = torch.flatten(first_policy).view(4, 1, 1, 1)
        second = torch.flatten(second_layer_probs).view(16, 1, 1, 1)
        third = torch.flatten(third_layer_probs).view(64, 1, 1, 1)
        fourth = torch.flatten(fourth_layer_probs).view(256, 1, 1, 1) 
        
        first_loss = (F.mse_loss(decoded_values[1], t[1][0], reduction='none') * first).sum() 
        second_loss = (F.mse_loss(decoded_values[2], t[1][1], reduction='none') * second).sum() 
        third_loss = (F.mse_loss(decoded_values[3], t[1][2], reduction='none') * third).sum() 
        fourth_loss = (F.mse_loss(decoded_values[4], t[1][3], reduction='none') * fourth).sum() 


        #For experimenting with different weights on different layers
        raw_loss = (decode_loss + first_loss + second_loss + third_loss + fourth_loss).detach().item()
        raw_losses.append(raw_loss)
        l2w , l3w ,l4w = 1,1,1
        total_loss = decode_loss + first_loss + second_loss*l2w + third_loss*l3w + fourth_loss*l4w

        temp_loss += total_loss
        temp_raw_loss += raw_loss
        sample_count += 1

        if sample_count % 1 == 0:
            optimizer.zero_grad()
            temp_loss.backward()

            #torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
            optimizer.step()
            avg_loss += temp_loss.item()
            avg_raw_loss += temp_raw_loss
            temp_loss = 0
            temp_raw_loss = 0

    # To handle the case where the number of samples is not a multiple of 10
    if sample_count % 1 != 0:
        optimizer.zero_grad()
        temp_loss.backward()
        
        #torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
        optimizer.step()
        avg_loss += temp_loss.item()
        avg_raw_loss += temp_raw_loss


    print(f"Epoch {epoch + 1}, Average Loss: {avg_loss / len(train_data)}, Average Raw Loss: {avg_raw_loss / len(train_data)}")


  first_loss = (F.mse_loss(decoded_values[1], t[1][0], reduction='none') * first).sum()
  second_loss = (F.mse_loss(decoded_values[2], t[1][1], reduction='none') * second).sum()
  third_loss = (F.mse_loss(decoded_values[3], t[1][2], reduction='none') * third).sum()
  fourth_loss = (F.mse_loss(decoded_values[4], t[1][3], reduction='none') * fourth).sum()


Epoch 1, Average Loss: 424.0273838863373, Average Raw Loss: 424.0273838863373
Epoch 2, Average Loss: 423.89982308959964, Average Raw Loss: 423.89982308959964
Epoch 3, Average Loss: 423.71750062179564, Average Raw Loss: 423.71750062179564
Epoch 4, Average Loss: 422.95883463954925, Average Raw Loss: 422.95883463954925
Epoch 5, Average Loss: 421.4079816145897, Average Raw Loss: 421.4079816145897
Epoch 6, Average Loss: 419.27885300350187, Average Raw Loss: 419.27885300350187
Epoch 7, Average Loss: 416.78679663562775, Average Raw Loss: 416.78679663562775
Epoch 8, Average Loss: 413.98007303905484, Average Raw Loss: 413.98007303905484
Epoch 9, Average Loss: 410.93938877391815, Average Raw Loss: 410.93938877391815
Epoch 10, Average Loss: 407.7604294695854, Average Raw Loss: 407.7604294695854
Epoch 11, Average Loss: 404.44913383197786, Average Raw Loss: 404.44913383197786
Epoch 12, Average Loss: 401.0200046405792, Average Raw Loss: 401.0200046405792
Epoch 13, Average Loss: 397.51603592729566, A

In [None]:
#View Action Weights (This hasn't been informative yet)
dec, all_policies = model(train_data[0][0]) 
# dot = make_dot((dec[0],dec[1],dec[2],dec[3],dec[4],all_policies[0],all_policies[1],all_policies[2],all_policies[3]),params=dict(model.named_parameters()))
# dot.render('model', format='png')
print(f"Action Weight Sums { torch.round(model.transition_fun.data,decimals=3).sum(dim=0).sum(dim=0)}")  #might be summing the wrong way, or just not interesting

Action Weight Sums tensor([-4.0040, -5.2650, -5.1780,  5.3700])


In [None]:
best_first_action = all_policies[0].argmax()
best_second_action = all_policies[1].view(4,-1)[best_first_action].argmax() 
best_third_action = all_policies[2].view(4,4,-1)[best_first_action][best_second_action].argmax()
best_fourth_action = all_policies[3].view(4,4,4,-1)[best_first_action][best_second_action][best_third_action].argmax() 
# print(torch.round(all_q[0],decimals=3).detach(), f"Argmax {all_q[0].argmax().item()}")
# print(torch.round(all_q[1],decimals=3).view(4,-1).detach(),f"Argmax {all_q[1].view(4,-1)[1].argmax().item()}")
# print(torch.round(all_q[2],decimals=3).view(4,4,-1)[0].detach(),f"Argmax {all_q[2].view(4,4,-1)[1][0].argmax().item()}")
# print(torch.round(all_q[3],decimals=3).view(4,4,4,-1)[0][0].detach(),f"Argmax {all_q[3].view(4,4,4,-1)[1][1][0].argmax().item()}")
print(f"Best Actions: {best_first_action.item()} {best_second_action.item()} {best_third_action.item()} {best_fourth_action.item()}")

Best Actions: 2 2 3 0


In [None]:
model.eval()
loss_data = []
with torch.no_grad():
    eval_loss = 0
    for i,k in zip(test_data,range(len(test_data))):
        #Get reconstruction loss to help ground abstract state
        decoded_values, all_policies = model(t[0])
        decode_loss = F.mse_loss(decoded_values[0], t[0], reduction='sum')

        #Get transition probabilities for each state
        first_policy = all_policies[0]
        second_policy = all_policies[1].view(4, -1)
        third_policy = all_policies[2].view(4, 4, -1)
        fourth_policy = all_policies[3].view(4, 4, 4, -1)

        #These should all add to 1 (in testing there seems to be some small rounding error)
        second_layer_probs = first_policy * second_policy   
        third_layer_probs = second_layer_probs * third_policy
        fourth_layer_probs = third_layer_probs * fourth_policy
        
        #Flatten transition probabilities to then weigh with loss of each predicted state at each layer
        first = torch.flatten(first_policy).view(4, 1, 1, 1)
        second = torch.flatten(second_layer_probs).view(16, 1, 1, 1)
        third = torch.flatten(third_layer_probs).view(64, 1, 1, 1)
        fourth = torch.flatten(fourth_layer_probs).view(256, 1, 1, 1) 
        
        first_loss = (F.mse_loss(decoded_values[1], t[1][0], reduction='none') * first).sum() 
        second_loss = (F.mse_loss(decoded_values[2], t[1][1], reduction='none') * second).sum() 
        third_loss = (F.mse_loss(decoded_values[3], t[1][2], reduction='none') * third).sum() 
        fourth_loss = (F.mse_loss(decoded_values[4], t[1][3], reduction='none') * fourth).sum() 

        #For experimenting with different weights on different layers
        raw_loss = (decode_loss + first_loss + second_loss + third_loss + fourth_loss).detach().item()

        loss_data.append([k,decode_loss.item(),first_loss.item(),second_loss.item(),third_loss.item(),fourth_loss.item(),raw_loss])

  first_loss = (F.mse_loss(decoded_values[1], t[1][0], reduction='none') * first).sum()
  second_loss = (F.mse_loss(decoded_values[2], t[1][1], reduction='none') * second).sum()
  third_loss = (F.mse_loss(decoded_values[3], t[1][2], reduction='none') * third).sum()
  fourth_loss = (F.mse_loss(decoded_values[4], t[1][3], reduction='none') * fourth).sum()


In [None]:
loss_df = pd.DataFrame(loss_data,columns=["Starting Input","Decode Loss","First Loss","Second Loss","Third Loss","Fourth Loss","Total Loss"]).set_index("Starting Input").round(2)
loss_df 

Unnamed: 0_level_0,Decode Loss,First Loss,Second Loss,Third Loss,Fourth Loss,Total Loss
Starting Input,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
0,0.1,0.04,0.01,0.0,0.0,0.15
1,0.1,0.04,0.01,0.0,0.0,0.15
2,0.1,0.04,0.01,0.0,0.0,0.15
3,0.1,0.04,0.01,0.0,0.0,0.15
4,0.1,0.04,0.01,0.0,0.0,0.15
5,0.1,0.04,0.01,0.0,0.0,0.15
6,0.1,0.04,0.01,0.0,0.0,0.15
7,0.1,0.04,0.01,0.0,0.0,0.15
8,0.1,0.04,0.01,0.0,0.0,0.15
9,0.1,0.04,0.01,0.0,0.0,0.15


In [None]:
#A check on if the training loop is valid
#Checking loss of each state unweighted, with both absolute difference loss and mse loss
#The earlier transitoins might be easier since it's easier to learn that a lot the environment is the same
min_losses = []
max_losses = []
for i in range(1,len(dec)):
    decoded_states = dec[i]
    true_state = train_data[0][1][i-1]
    curr_min = float('inf')
    curr_max = float('-inf')
    for state in decoded_states:
        loss = torch.abs(state-true_state).sum()
        if loss < curr_min:
            curr_min = loss.item()
        if loss > curr_max:
            curr_max = loss.item()
    min_losses.append(curr_min)
    max_losses.append(curr_max)
min_losses,max_losses

([2.005280017852783, 1.376326322555542, 91.25557708740234, 161.20518493652344],
 [154.4391326904297, 222.06932067871094, 311.039306640625, 375.0008850097656])

In [None]:
min_losses = []
max_losses = []
for i in range(1,len(dec)):
    decoded_states = dec[i]
    true_state = train_data[0][1][i-1]
    curr_min = float('inf')
    curr_max = float('-inf')
    for state in decoded_states:
        loss = F.mse_loss(state,true_state)
        if loss < curr_min:
            curr_min = loss.item()
        if loss > curr_max:
            curr_max = loss.item()
    min_losses.append(curr_min)
    max_losses.append(curr_max)
min_losses,max_losses

  loss = F.mse_loss(state,true_state)


([7.934335735626519e-05,
  2.415754170215223e-05,
  0.11819437146186829,
  0.2685106098651886],
 [0.3638641834259033,
  0.5465582013130188,
  0.7721645832061768,
  0.9390590786933899])