In [3]:
import torch
import torch.nn.functional as F
import numpy as np
from torch.nn import Parameter

from torch.optim import Adam
from torch.optim import RMSprop

from treeQN.treeqn_traj import TreeQN
import random

import pandas as pd

In [4]:
def get_start(size):
    starting_point = (random.randint(0,size), random.randint(0,size))
    goal_point = -1
    while True:
        goal_point = (random.randint(0,size), random.randint(0,size))
        if goal_point != starting_point:
            break
    return starting_point, goal_point

In [5]:
def hard_policy(state,goal_point):
    goal_x, goal_y = goal_point
    x,y = state
    x_right = goal_x > x # if goal is right
    x_left = goal_x < x # if goal is left
    y_up = goal_y > y # if goal is above
    y_down = goal_y < y # if goal is below
    possible_next_states = []
    if x_right:
        possible_next_states.append((x+1,y))
    if x_left:
        possible_next_states.append((x-1,y))
    if y_up:
        possible_next_states.append((x,y+1))
    if y_down:
        possible_next_states.append((x,y-1))
    if len(possible_next_states) == 0:
        return -1
    return random.choice(possible_next_states)



In [6]:
def point_to_tensor(point,goal,size):
    x,y = point
    x_goal, y_goal = goal
    tensor = torch.zeros(size+2,size+2)
    tensor[x][y] = 1
    tensor[x+1][y] = 1
    tensor[x][y+1] = 1
    tensor[x+1][y+1] = 1
    tensor[x_goal][y_goal] = -1
    tensor[x_goal+1][y_goal] = -1
    tensor[x_goal][y_goal+1] = -1
    tensor[x_goal+1][y_goal+1] = -1
    return tensor

In [7]:
def get_trajectory(size = 20):
    trajectory = []
    start, goal = get_start(size)
    trajectory.append(start)
    while start != goal:
        start = hard_policy(start,goal)
        trajectory.append(start)
    if len(trajectory) != 21:
        return get_trajectory(size)
    return [point_to_tensor(p,goal,size).unsqueeze(0).unsqueeze(0) for p in trajectory]

In [8]:
def get_trim_trajectory(size = 20,trimming = [0,5,10,15,20]):
    trajectory = get_trajectory(size)
    return [trajectory[i] for i in trimming]

In [9]:
train_data = [get_trim_trajectory(18) for _ in range(1000)]

In [10]:
len(train_data[0])

5

In [11]:
input_shape = torch.zeros(1, 1,20, 20).shape# minimum size #train_data[0][0].shape
num_actions = 4
tree_depth = 4
embedding_dim = 256
td_lambda = 0.8
gamma = 1    #0.99
model = TreeQN(input_shape=input_shape, num_actions=num_actions, tree_depth=tree_depth, embedding_dim=embedding_dim, td_lambda=td_lambda,gamma=gamma)
optimizer = Adam(model.parameters(), lr=1e-4)
#optimizer = RMSprop(model.parameters(), lr=1e-4,alpha =0.99, eps = 1e-5) | loss from treeqn paper

  b_init(module.bias, b_scale)


In [12]:
#Main training loop
#Looking at difference between detaching at each transition or not in treeqn file. This is with detach (so far seems to make no diff)
raw_losses = []
for epoch in range(3000):  # epochs
    avg_loss = 0
    temp_loss = 0
    temp_raw_loss = 0
    sample_count = 0

    avg_raw_loss = 0

    for t in random.sample(train_data, len(train_data)): #sample through all data in random order each epoch
        #Get reconstruction loss to help ground abstract state
        decoded_values, all_policies = model(t[0])
        decode_loss = F.mse_loss(decoded_values[0], t[0], reduction='sum')

        #Get transition probabilities for each state
        first_policy = all_policies[0]
        second_policy = all_policies[1].view(4, -1)
        third_policy = all_policies[2].view(4, 4, -1)
        fourth_policy = all_policies[3].view(4, 4, 4, -1)

        #These should all add to 1 (in testing there seems to be some small rounding error)
        second_layer_probs = first_policy * second_policy   
        third_layer_probs = second_layer_probs * third_policy
        fourth_layer_probs = third_layer_probs * fourth_policy
        
        #Flatten transition probabilities to then weigh with loss of each predicted state at each layer
        first = torch.flatten(first_policy).view(4, 1, 1, 1)
        second = torch.flatten(second_layer_probs).view(16, 1, 1, 1)
        third = torch.flatten(third_layer_probs).view(64, 1, 1, 1)
        fourth = torch.flatten(fourth_layer_probs).view(256, 1, 1, 1) 
        
        first_loss = (F.mse_loss(decoded_values[1], t[1], reduction='none') * first).sum() 
        second_loss = (F.mse_loss(decoded_values[2], t[2], reduction='none') * second).sum() 
        third_loss = (F.mse_loss(decoded_values[3], t[3], reduction='none') * third).sum() 
        fourth_loss = (F.mse_loss(decoded_values[4], t[4], reduction='none') * fourth).sum() 


        #For experimenting with different weights on different layers
        raw_loss = (decode_loss + first_loss + second_loss + third_loss + fourth_loss).detach().item()
        raw_losses.append(raw_loss)
        l2w , l3w ,l4w = 1,1,1
        total_loss = decode_loss + first_loss + second_loss*l2w + third_loss*l3w + fourth_loss*l4w

        temp_loss += total_loss
        temp_raw_loss += raw_loss
        sample_count += 1

        if sample_count % 1 == 0:
            optimizer.zero_grad()
            temp_loss.backward()

            #torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
            optimizer.step()
            avg_loss += temp_loss.item()
            avg_raw_loss += temp_raw_loss
            temp_loss = 0
            temp_raw_loss = 0

    # To handle the case where the number of samples is not a multiple of 10
    if sample_count % 1 != 0:
        optimizer.zero_grad()
        temp_loss.backward()
        
        #torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
        optimizer.step()
        avg_loss += temp_loss.item()
        avg_raw_loss += temp_raw_loss


    print(f"Epoch {epoch + 1}, Average Loss: {avg_loss / len(train_data)}, Average Raw Loss: {avg_raw_loss / len(train_data)}")


  first_loss = (F.mse_loss(decoded_values[1], t[1], reduction='none') * first).sum()
  second_loss = (F.mse_loss(decoded_values[2], t[2], reduction='none') * second).sum()
  third_loss = (F.mse_loss(decoded_values[3], t[3], reduction='none') * third).sum()
  fourth_loss = (F.mse_loss(decoded_values[4], t[4], reduction='none') * fourth).sum()


Epoch 1, Average Loss: 37.9537025680542, Average Raw Loss: 37.9537025680542
Epoch 2, Average Loss: 35.49645195770264, Average Raw Loss: 35.49645195770264
Epoch 3, Average Loss: 34.651431438446046, Average Raw Loss: 34.651431438446046
Epoch 4, Average Loss: 33.87087203788757, Average Raw Loss: 33.87087203788757
Epoch 5, Average Loss: 33.065388671875, Average Raw Loss: 33.065388671875
Epoch 6, Average Loss: 32.201125743865965, Average Raw Loss: 32.201125743865965
Epoch 7, Average Loss: 31.37285554122925, Average Raw Loss: 31.37285554122925
Epoch 8, Average Loss: 30.531128957748415, Average Raw Loss: 30.531128957748415
Epoch 9, Average Loss: 29.667922466278075, Average Raw Loss: 29.667922466278075
Epoch 10, Average Loss: 28.82611954307556, Average Raw Loss: 28.82611954307556
Epoch 11, Average Loss: 28.010008836746216, Average Raw Loss: 28.010008836746216
Epoch 12, Average Loss: 27.19715474510193, Average Raw Loss: 27.19715474510193
Epoch 13, Average Loss: 26.35925427055359, Average Raw Lo

In [13]:
train_data[0]

[tensor([[[[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
             0.,  0.,  0.,  0.,  0.,  0.],
           [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
             0.,  0.,  0.,  0.,  0.,  0.],
           [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
             0.,  0.,  0.,  0.,  0.,  0.],
           [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
             0.,  0.,  0.,  0.,  0.,  0.],
           [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
             0.,  0.,  0.,  0.,  0.,  0.],
           [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
             0.,  0.,  0.,  0.,  0.,  0.],
           [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
             0.,  0.,  0.,  0.,  0.,  0.],
           [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
             0.,  0.,  0.,  0.,  0.,  0.],


In [17]:
#View Action Weights (This hasn't been informative yet)
dec, all_policies = model(train_data[0][0]) 
# dot = make_dot((dec[0],dec[1],dec[2],dec[3],dec[4],all_policies[0],all_policies[1],all_policies[2],all_policies[3]),params=dict(model.named_parameters()))
# dot.render('model', format='png')
print(f"Action Weight Sums { torch.round(model.transition_fun.data,decimals=3).sum(dim=0).sum(dim=0)}")  #might be summing the wrong way, or just not interesting

Action Weight Sums tensor([ 3.4890, -4.8890, -5.8250, 12.0320])


In [18]:
best_first_action = all_policies[0].argmax()
best_second_action = all_policies[1].view(4,-1)[best_first_action].argmax() 
best_third_action = all_policies[2].view(4,4,-1)[best_first_action][best_second_action].argmax()
best_fourth_action = all_policies[3].view(4,4,4,-1)[best_first_action][best_second_action][best_third_action].argmax() 
# print(torch.round(all_q[0],decimals=3).detach(), f"Argmax {all_q[0].argmax().item()}")
# print(torch.round(all_q[1],decimals=3).view(4,-1).detach(),f"Argmax {all_q[1].view(4,-1)[1].argmax().item()}")
# print(torch.round(all_q[2],decimals=3).view(4,4,-1)[0].detach(),f"Argmax {all_q[2].view(4,4,-1)[1][0].argmax().item()}")
# print(torch.round(all_q[3],decimals=3).view(4,4,4,-1)[0][0].detach(),f"Argmax {all_q[3].view(4,4,4,-1)[1][1][0].argmax().item()}")
print(f"Best Actions: {best_first_action.item()} {best_second_action.item()} {best_third_action.item()} {best_fourth_action.item()}")

Best Actions: 1 1 0 0


In [19]:
#View Action Weights (This hasn't been informative yet)
for i in range(500):
    dec, all_policies = model(train_data[i][0]) 

    best_first_action = all_policies[0].argmax()
    best_second_action = all_policies[1].view(4,-1)[best_first_action].argmax() 
    best_third_action = all_policies[2].view(4,4,-1)[best_first_action][best_second_action].argmax()
    best_fourth_action = all_policies[3].view(4,4,4,-1)[best_first_action][best_second_action][best_third_action].argmax() 

    print(f"Best Actions: {best_first_action.item()} {best_second_action.item()} {best_third_action.item()} {best_fourth_action.item()}")

Best Actions: 1 1 0 0
Best Actions: 1 1 0 0
Best Actions: 1 1 0 0
Best Actions: 1 1 0 0
Best Actions: 1 1 0 0
Best Actions: 1 1 0 0
Best Actions: 1 1 0 0
Best Actions: 1 1 0 0
Best Actions: 1 1 0 0
Best Actions: 1 1 0 0
Best Actions: 1 1 0 0
Best Actions: 1 1 0 0
Best Actions: 1 1 0 0
Best Actions: 1 1 0 0
Best Actions: 1 1 0 0
Best Actions: 1 1 0 0
Best Actions: 1 1 0 0
Best Actions: 1 1 0 0
Best Actions: 1 1 0 0
Best Actions: 1 1 0 0
Best Actions: 1 1 0 0
Best Actions: 1 1 0 0
Best Actions: 1 1 0 0
Best Actions: 1 1 0 0
Best Actions: 1 1 0 0
Best Actions: 1 1 0 0
Best Actions: 1 1 0 0
Best Actions: 1 1 0 0
Best Actions: 1 1 0 0
Best Actions: 1 1 0 0
Best Actions: 1 1 0 0
Best Actions: 1 1 0 0
Best Actions: 1 1 0 0
Best Actions: 1 1 0 0
Best Actions: 1 1 0 0
Best Actions: 1 1 0 0
Best Actions: 1 1 0 0
Best Actions: 1 1 0 0
Best Actions: 1 1 0 0
Best Actions: 1 1 0 0
Best Actions: 1 1 0 0
Best Actions: 1 1 0 0
Best Actions: 1 1 0 0
Best Actions: 1 1 0 0
Best Actions: 1 1 0 0
Best Actio

In [20]:
train_data[0][0]

tensor([[[[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
            0.,  0.,  0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
            0.,  0.,  0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
            0.,  0.,  0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
            0.,  0.,  0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
            0.,  0.,  0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
            0.,  0.,  0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
            0.,  0.,  0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
            0.,  0.,  0.,  0.,  0.,  0.],
          [ 0., 

In [21]:
train_data[0][1]

tensor([[[[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
            0.,  0.,  0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
            0.,  0.,  0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
            0.,  0.,  0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
            0.,  0.,  0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
            0.,  0.,  0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
            0.,  0.,  0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
            0.,  0.,  0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
            0.,  0.,  0.,  0.,  0.,  0.],
          [ 0., 