In [1]:
import torch
import torch.nn.functional as F
import numpy as np
from torch.nn import Parameter

from torch.optim import Adam
from torch.optim import RMSprop

from treeQN.treeqn_traj_sr import TreeQN
import random

import pandas as pd

In [2]:
def get_start(size):
    goal_point = (size//2, size//2)
    start_point = -1
    while True:
        start_point = (random.randint(0,size), random.randint(0,size))
        if goal_point != start_point:
            break
    return start_point, goal_point

In [3]:
def hard_policy(state,goal_point):
    goal_x, goal_y = goal_point
    x,y = state
    x_right = goal_x > x # if goal is right
    x_left = goal_x < x # if goal is left
    y_up = goal_y > y # if goal is above
    y_down = goal_y < y # if goal is below
    possible_next_states = []
    if x_right:
        possible_next_states.append((x+1,y))
    if x_left:
        possible_next_states.append((x-1,y))
    if y_up:
        possible_next_states.append((x,y+1))
    if y_down:
        possible_next_states.append((x,y-1))
    if len(possible_next_states) == 0:
        return -1
    return random.choice(possible_next_states)



In [4]:
def point_to_tensor(point,goal,size):
    x,y = point
    x_goal, y_goal = goal
    tensor = torch.zeros(size+2,size+2)
    scale = 1
    tensor[x][y] = 1 * scale
    tensor[x+1][y] = 1 * scale
    tensor[x][y+1] = 1 * scale
    tensor[x+1][y+1] = 1
    tensor[x_goal][y_goal] = -1 * scale
    tensor[x_goal+1][y_goal] = -1 * scale
    tensor[x_goal][y_goal+1] = -1 * scale
    tensor[x_goal+1][y_goal+1] = -1 * scale
    return tensor

In [5]:
def get_trajectory(size = 18,start_point = None, goal_point = None):
    trajectory = []
    if start_point is None:
        start, goal = get_start(size)
    else:
        start, goal = start_point, goal_point
    trajectory.append(start)
    while start != goal:
        start = hard_policy(start,goal)
        trajectory.append(start)
    if len(trajectory) != 5:
        return get_trajectory(size)
    return [point_to_tensor(p,goal,size).unsqueeze(0).unsqueeze(0) for p in trajectory]

In [6]:
def max_starting_points(size = 18):
    start_points = []
    for i in range(10000):
        start_points.append(get_start(18)[0])
    start_points = set(start_points)
    goal_point = (size//2, size//2)
    return start_points, goal_point

In [7]:
s, goal_point = max_starting_points()
start_points = list(s)
train_start_points = start_points[:len(start_points)//2]
test_start_points = start_points[len(start_points)//2:]

train_data = [get_trajectory(18,start_point,goal_point) for start_point in train_start_points]
valid_data = [get_trajectory(18,start_point,goal_point) for start_point in test_start_points]

In [8]:
input_shape = torch.zeros(1, 1,20, 20).shape# minimum size #train_data[0][0].shape
num_actions = 2
tree_depth = 4
embedding_dim = 4
td_lambda = 1
gamma = 1    #0.99
model = TreeQN(input_shape=input_shape, num_actions=num_actions, tree_depth=tree_depth, embedding_dim=embedding_dim, td_lambda=td_lambda,gamma=gamma)
optimizer = Adam(model.parameters(), lr=1e-4)
#optimizer = RMSprop(model.parameters(), lr=1e-4,alpha =0.99, eps = 1e-5) | loss from treeqn paper
# Collect all encoder and decoder parameters

In [10]:
def validate(model, valid_data):
    model.eval()  # Set the model to evaluation mode
    total_loss = 0
    with torch.no_grad():  # Disable gradient calculation
        for t in valid_data:
            decoded_values, all_policies = model(t[0])
            decode_loss = F.mse_loss(decoded_values[0], t[0], reduction='sum')

            first_policy = all_policies[0]
            second_policy = all_policies[1].view(num_actions, -1)
            third_policy = all_policies[2].view(num_actions, num_actions, -1)
            fourth_policy = all_policies[3].view(num_actions, num_actions, num_actions, -1)

            second_layer_probs = first_policy * second_policy
            third_layer_probs = second_layer_probs * third_policy
            fourth_layer_probs = third_layer_probs * fourth_policy

            first = torch.flatten(first_policy).view(num_actions, 1, 1, 1)
            second = torch.flatten(second_layer_probs).view(num_actions**2, 1, 1, 1)
            third = torch.flatten(third_layer_probs).view(num_actions**3, 1, 1, 1)
            fourth = torch.flatten(fourth_layer_probs).view(num_actions**4, 1, 1, 1)

            first_loss = (F.mse_loss(decoded_values[1], t[1], reduction='none') * first).sum()
            second_loss = (F.mse_loss(decoded_values[2], t[2], reduction='none') * second).sum()
            third_loss = (F.mse_loss(decoded_values[3], t[3], reduction='none') * third).sum()
            fourth_loss = (F.mse_loss(decoded_values[4], t[4], reduction='none') * fourth).sum()

            l2w, l3w, l4w = 1, 1, 1
            total_loss += first_loss + second_loss * l2w + third_loss * l3w + fourth_loss * l4w + decode_loss

    return total_loss / len(valid_data)

In [11]:
def store_gradients(model):
    gradients = []
    for name, param in model.named_parameters():
        if param.grad is not None:
            gradients.append([name, param.grad.norm().item()])
    return gradients

In [12]:
all_gradients = []
raw_losses = []
for epoch in range(3000):  # epochs
    model.train()  # Set the model to training mode
    avg_loss = 0
    temp_loss = 0
    temp_raw_loss = 0
    sample_count = 0

    avg_raw_loss = 0
    raw_gradients = []
    for t in random.sample(train_data, len(train_data)):  # sample through all data in random order each epoch
        # Get reconstruction loss to help ground abstract state
        decoded_values, all_policies = model(t[0])
        decode_loss = F.mse_loss(decoded_values[0], t[0], reduction='sum')

        # Get transition probabilities for each state
        first_policy = all_policies[0]
        second_policy = all_policies[1].view(num_actions, -1)
        third_policy = all_policies[2].view(num_actions, num_actions, -1)
        fourth_policy = all_policies[3].view(num_actions, num_actions, num_actions, -1)

        # These should all add to 1 (in testing there seems to be some small rounding error)
        second_layer_probs = first_policy * second_policy
        third_layer_probs = second_layer_probs * third_policy
        fourth_layer_probs = third_layer_probs * fourth_policy

        # Flatten transition probabilities to then weigh with loss of each predicted state at each layer
        first = torch.flatten(first_policy).view(num_actions, 1, 1, 1)
        second = torch.flatten(second_layer_probs).view(num_actions**2, 1, 1, 1)
        third = torch.flatten(third_layer_probs).view(num_actions**3, 1, 1, 1)
        fourth = torch.flatten(fourth_layer_probs).view(num_actions**4, 1, 1, 1)

        first_loss = (F.mse_loss(decoded_values[1], t[1], reduction='none') * first).sum()
        second_loss = (F.mse_loss(decoded_values[2], t[2], reduction='none') * second).sum()
        third_loss = (F.mse_loss(decoded_values[3], t[3], reduction='none') * third).sum()
        fourth_loss = (F.mse_loss(decoded_values[4], t[4], reduction='none') * fourth).sum()

        # For experimenting with different weights on different layers
        raw_loss = (first_loss + second_loss + third_loss + fourth_loss).detach().item()
        raw_losses.append(raw_loss)
        l2w, l3w, l4w = 1, 1, 1
        total_loss = first_loss + second_loss * l2w + third_loss * l3w + fourth_loss * l4w + decode_loss

        # break if total loss is nan
        if torch.isnan(total_loss):
            raise ValueError("NAN LOSS")

        temp_loss += total_loss
        temp_raw_loss += raw_loss
        sample_count += 1

        if sample_count % 1 == 0:
            optimizer.zero_grad()
            temp_loss.backward()

            # Monitor gradients before clipping and stepping
            all_gradients.append(store_gradients(model))

            # Uncomment if you want to use gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1)

            optimizer.step()
            avg_loss += temp_loss.item()
            avg_raw_loss += temp_raw_loss
            temp_loss = 0
            temp_raw_loss = 0

    # To handle the case where the number of samples is not a multiple of 1
    if sample_count % 1 != 0:
        optimizer.zero_grad()
        temp_loss.backward()
        
        # Monitor gradients before clipping and stepping
        all_gradients.append(store_gradients(model))

        # Uncomment if you want to use gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1)

        optimizer.step()
        avg_loss += temp_loss.item()
        avg_raw_loss += temp_raw_loss

    avg_train_loss = avg_loss / len(train_data)
    avg_train_raw_loss = avg_raw_loss / len(train_data)

    # Perform validation
    avg_valid_loss = validate(model, valid_data)
    print(f"Epoch {epoch + 1}, Train Loss: {avg_train_loss}, Train Raw Loss: {avg_train_raw_loss}, Validation Loss: {avg_valid_loss}")


  first_loss = (F.mse_loss(decoded_values[1], t[1], reduction='none') * first).sum()
  second_loss = (F.mse_loss(decoded_values[2], t[2], reduction='none') * second).sum()
  third_loss = (F.mse_loss(decoded_values[3], t[3], reduction='none') * third).sum()
  fourth_loss = (F.mse_loss(decoded_values[4], t[4], reduction='none') * fourth).sum()
  first_loss = (F.mse_loss(decoded_values[1], t[1], reduction='none') * first).sum()
  second_loss = (F.mse_loss(decoded_values[2], t[2], reduction='none') * second).sum()
  third_loss = (F.mse_loss(decoded_values[3], t[3], reduction='none') * third).sum()
  fourth_loss = (F.mse_loss(decoded_values[4], t[4], reduction='none') * fourth).sum()


Epoch 1, Train Loss: 37.3575232717726, Train Raw Loss: 28.553085062238907, Validation Loss: 27.97074317932129
Epoch 2, Train Loss: 20.647692426045737, Train Raw Loss: 15.359329748153687, Validation Loss: 15.841936111450195
Epoch 3, Train Loss: 14.232015228271484, Train Raw Loss: 10.077617655860053, Validation Loss: 13.414351463317871
Epoch 4, Train Loss: 13.064902273813884, Train Raw Loss: 9.128142494625516, Validation Loss: 12.954946517944336
Epoch 5, Train Loss: 12.781253062354194, Train Raw Loss: 8.919904963175457, Validation Loss: 12.785567283630371
Epoch 6, Train Loss: 12.660044214460585, Train Raw Loss: 8.837732474009195, Validation Loss: 12.714303016662598
Epoch 7, Train Loss: 12.586459716161093, Train Raw Loss: 8.796415419048733, Validation Loss: 12.624241828918457
Epoch 8, Train Loss: 12.519942347208659, Train Raw Loss: 8.767472653918796, Validation Loss: 12.570167541503906
Epoch 9, Train Loss: 12.446549293729994, Train Raw Loss: 8.743793487548828, Validation Loss: 12.51102161

KeyboardInterrupt: 

In [15]:
#4.46, 5.3 at 2000
all_gradients[-1]

[['transition_fun', 0.7972447276115417],
 ['decoder.fc.weight', 1.0146783590316772],
 ['decoder.fc.bias', 1.225597858428955],
 ['decoder.deconv1.weight', 4.208570957183838],
 ['decoder.deconv1.bias', 0.8115442991256714],
 ['decoder.deconv2.weight', 0.8129922151565552],
 ['decoder.deconv2.bias', 0.6843503713607788],
 ['encoder.cnn_encoder.conv1.weight', 0.5368334054946899],
 ['encoder.cnn_encoder.conv1.bias', 0.45351430773735046],
 ['encoder.cnn_encoder.conv2.weight', 1.2930923700332642],
 ['encoder.cnn_encoder.conv2.bias', 0.7256703972816467],
 ['encoder.linear.weight', 2.945633888244629],
 ['encoder.linear.bias', 1.5791183710098267]]

In [25]:
d,q = model(valid_data[0][0]) ##Compare plateau in loss with just 1 action

In [27]:
def get_best_path(decoded, all_policies):
    best_first_action = all_policies[0].argmax().item()
    best_second_action = all_policies[1].view(num_actions,-1)[best_first_action].argmax().item() 
    best_third_action = all_policies[2].view(num_actions,num_actions,-1)[best_first_action][best_second_action].argmax().item()
    best_fourth_action = all_policies[3].view(num_actions,num_actions,num_actions,-1)[best_first_action][best_second_action][best_third_action].argmax().item()
    print(best_first_action,best_second_action,best_third_action,best_fourth_action)
    first = decoded[0]
    second = decoded[1][best_first_action].unsqueeze(0)
    third = decoded[2][best_first_action**2 + best_second_action].unsqueeze(0)
    fourth = decoded[3][best_first_action**3 + best_second_action**2 + best_third_action].unsqueeze(0)
    fifth = decoded[4][best_first_action**4 + best_second_action**3 + best_third_action**2 + best_fourth_action].unsqueeze(0)
    return [first,second,third,fourth,fifth]

In [28]:
i = get_best_path(d,q)
for a in i:
    print(a.shape)

0 0 1 1
torch.Size([1, 1, 20, 20])
torch.Size([1, 1, 20, 20])
torch.Size([1, 1, 20, 20])
torch.Size([1, 1, 20, 20])
torch.Size([1, 1, 20, 20])


In [29]:
#try backward in succession?
for i in valid_data:
    a,b = model(i[0])
    get_best_path(a,b)

0 0 1 1
0 0 0 0
0 0 1 1
0 0 0 1
0 0 0 1
0 0 1 1
0 0 1 1


0 0 1 1
0 0 0 1
0 0 0 0
1 1 1 1
0 0 0 1
0 0 0 1
0 0 0 1
0 0 0 0
0 0 1 1
0 0 1 1
0 0 0 1
0 0 1 1
0 0 0 0
1 0 1 1
0 0 1 1
0 0 1 1
0 0 0 0
0 0 0 1
1 1 1 1
0 0 1 1
0 0 0 1
0 0 0 1
0 0 1 1
0 0 0 1
0 0 0 1
0 0 1 1
1 0 0 0
0 0 0 0
1 0 0 0
0 0 1 1
0 0 1 1
0 0 1 1
1 0 1 1
0 0 0 1
1 0 0 0
1 0 0 0
0 0 0 1
0 0 0 1
1 0 0 0
0 0 0 0
0 0 0 1
0 0 0 0
1 0 0 0
0 0 1 1
0 0 0 1
1 0 0 0
0 0 1 1
0 0 0 1
0 0 1 1
1 0 0 0
1 0 1 1
0 0 1 1
1 0 0 0
0 0 0 0
0 0 0 1
0 0 0 0
0 0 1 1
0 0 0 0
0 0 0 0
0 0 1 1
0 0 0 0
0 0 0 1
0 0 1 1
0 0 0 1
0 0 1 1
1 0 1 1
0 0 1 1
0 0 0 1
0 0 1 1
0 0 0 1
0 0 0 1
1 0 0 0
0 0 1 1
0 0 1 1
0 0 0 0
0 0 1 1
0 0 1 1
0 0 0 0
0 0 0 1
0 0 0 0
1 0 0 0
0 0 1 1
0 0 1 1
0 0 1 1
0 0 0 0
0 0 1 1
0 0 0 0
0 0 1 1
0 0 1 1
0 0 0 0
1 0 0 0
1 0 0 0
0 0 0 0
0 0 0 0
0 0 0 1
0 0 0 1
0 0 1 1
0 0 1 1
0 0 1 1
1 0 0 0
0 0 1 1
1 0 0 0
0 0 0 0
0 0 0 0
0 0 1 1
0 0 0 1
0 0 0 1
0 0 1 1
0 0 0 0
1 0 0 0
0 0 0 1
0 0 0 0
0 0 0 0
0 0 0 0
1 0 0 0
0 0 0 1
0 0 1 1
0 0 1 1
0 0 0 1
0 0 1 1
0 0 0 1
0 0 1 1
0 0 0 1
1 0 1 1
0 0 0 0


In [32]:
def viewer(tensor):
    tensor = tensor.squeeze(0).squeeze(0)
    zoomed_tensor = tensor[5:-5, 5:-5]
    return torch.round(zoomed_tensor)

def print_movie(tensor_list):
    for tensor,name in zip(tensor_list,["Start","Action 1","Action 2","Action 3","Action"]):
        print(name)
        print(viewer(tensor))
        print("\n")

print_movie(get_best_path(d,q))

0 0 1 1
Start
tensor([[-0.,  0., -0., -0.,  0.,  0., -0., -0.,  0., -0.],
        [-0.,  0.,  0.,  0., -0., -0., -0.,  0.,  0., -0.],
        [-0., -0.,  0.,  0., -0.,  0.,  0.,  0.,  0., -0.],
        [ 0., -0.,  0., -0., -0.,  0.,  0.,  1.,  1.,  0.],
        [ 0.,  0., -0., -0., -1., -1.,  0.,  1.,  1.,  0.],
        [-0.,  0.,  0., -0., -1., -1.,  0., -0.,  0., -0.],
        [-0.,  0.,  0., -0., -0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0., -0., -0.],
        [ 0.,  0., -0.,  0., -0., -0., -0.,  0., -0.,  0.],
        [-0., -0., -0., -0., -0., -0., -0., -0.,  0.,  0.]],
       grad_fn=<RoundBackward0>)


Action 1
tensor([[ 0.,  0., -0.,  0.,  0., -0.,  0., -0.,  0., -0.],
        [-0.,  0.,  0., -0.,  0., -0.,  0.,  0., -0., -0.],
        [-0.,  0.,  0., -0.,  0.,  0.,  0.,  0., -0., -0.],
        [-0., -0., -0., -0.,  0.,  0.,  0.,  1.,  0., -0.],
        [ 0.,  0.,  0., -0., -1., -1.,  0.,  1.,  0., -0.],
        [-0., -0., -0., -0., -1., -1.,  0

In [None]:
#View Action Weights (This hasn't been informative yet)
dec, all_policies = model(train_data[0][0]) 
# dot = make_dot((dec[0],dec[1],dec[2],dec[3],dec[4],all_policies[0],all_policies[1],all_policies[2],all_policies[3]),params=dict(model.named_parameters()))
# dot.render('model', format='png')
print(f"Action Weight Sums { torch.round(model.transition_fun.data,decimals=3).sum(dim=0).sum(dim=0)}")  #might be summing the wrong way, or just not interesting

In [None]:
best_first_action = all_policies[0].argmax()
best_second_action = all_policies[1].view(num_actions,-1)[best_first_action].argmax() 
best_third_action = all_policies[2].view(num_actions,num_actions,-1)[best_first_action][best_second_action].argmax()
best_fourth_action = all_policies[3].view(num_actions,num_actions,num_actions,-1)[best_first_action][best_second_action][best_third_action].argmax() 
# print(torch.round(all_q[0],decimals=3).detach(), f"Argmax {all_q[0].argmax().item()}")
# print(torch.round(all_q[1],decimals=3).view(4,-1).detach(),f"Argmax {all_q[1].view(4,-1)[1].argmax().item()}")
# print(torch.round(all_q[2],decimals=3).view(4,4,-1)[0].detach(),f"Argmax {all_q[2].view(4,4,-1)[1][0].argmax().item()}")
# print(torch.round(all_q[3],decimals=3).view(4,4,4,-1)[0][0].detach(),f"Argmax {all_q[3].view(4,4,4,-1)[1][1][0].argmax().item()}")
print(f"Best Actions: {best_first_action.item()} {best_second_action.item()} {best_third_action.item()} {best_fourth_action.item()}")

In [None]:
#View Action Weights (This hasn't been informative yet)
for i in range(500):
    dec, all_policies = model(train_data[i][0]) 

    best_first_action = all_policies[0].argmax()
    best_second_action = all_policies[1].view(4,-1)[best_first_action].argmax() 
    best_third_action = all_policies[2].view(4,4,-1)[best_first_action][best_second_action].argmax()
    best_fourth_action = all_policies[3].view(4,4,4,-1)[best_first_action][best_second_action][best_third_action].argmax() 

    print(f"Best Actions: {best_first_action.item()} {best_second_action.item()} {best_third_action.item()} {best_fourth_action.item()}")

In [None]:
train_data[0][0]

In [None]:
train_data[0][1]