In [3]:
import random
from collections import deque

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.nn import Parameter
from torch.optim import Adam

from treeqn_traj_simplest import TreeQN
import image_world

In [4]:
# Seed every RNG the notebook touches *before* loading data, so that
# Restart & Run All reproduces the printed results:
#   - Python `random` drives the per-epoch shuffles in the training cell.
#   - NumPy and torch cover model parameter initialisation and dropout.
# NOTE(review): image_world.get_data may also draw random numbers; seeding first
# covers that case too, provided it uses one of these generators. Confirm.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

train_data, valid_data = image_world.get_data(size=20)

<h1>Initialize Model</h1>

In [5]:
# Reference loss: the mean squared error of a prediction that is off by 1 in
# exactly one cell of an otherwise blank state. The shape is taken from the real
# start state, so this tracks the `size` passed to get_data
# (1/400 = 0.0025 for a 20x20 grid).
state_shape = train_data[0][0].shape
blank_state = torch.zeros(state_shape)
one_pixel_off = torch.zeros(state_shape)
one_pixel_off.view(-1)[0] = 1
F.mse_loss(blank_state, one_pixel_off)  # mean-reduced per-depth losses should end up below this

tensor(0.0025)

In [6]:
# Model hyper-parameters. The input shape comes from the first training start
# state, so the network always matches the generated grid.
input_shape = train_data[0][0].shape
num_actions, tree_depth, embedding_dim = 4, 4, 64
gamma = 1
decode_dropout = 0.5
# Transition variant: True selects the einsum transition, False the additive (+dx) one.
t1 = False

model = TreeQN(
    input_shape=input_shape,
    num_actions=num_actions,
    tree_depth=tree_depth,
    embedding_dim=embedding_dim,
    gamma=gamma,
    decode_dropout=decode_dropout,
    t1=t1,
)
optimizer = Adam(model.parameters(), lr=1e-4)

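Transition type: additive (`t1 = False`, i.e. the `+dx` transition); set `t1 = True` to use the einsum transition instead.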


<h1>Optional: Pretrain Autoencoder</h1>
(Doesn't seem necessary for training, but useful for testing whether the encoded state Z can be decoded back into the original state.)

In [7]:
#image_world.train_autoencoder(model,optimizer,train_data,valid_data,epochs=100,lambda_reg=0)

In [8]:
# #freeze encoder and decoder
# for param in model.encoder.parameters():
#     param.requires_grad = False
# for param in model.decoder.parameters():
#     param.requires_grad = False

In [9]:
# #test_autoencode ability
# def test_autoencoder(model,valid_data):
#     with torch.no_grad():
#         test_sample = random.choice(valid_data)
#         encoding = model.encoder(test_sample[0])
#         decoding = model.decoder(encoding)
#         print('Original:\n', test_sample[0].numpy()[0][0][4:-4, 4:-4])
#         print('Reconstructed:\n', np.round(decoding.numpy()[0][0][4:-4, 4:-4]))

<h1>Train Full Model</h1>

In [10]:
# Train the full model. For every trajectory the loss is the sum of:
#   * a reconstruction loss between the decoded start state and the true start
#     state. It is summed over cells rather than averaged, so it sits on a larger
#     scale than the per-depth terms.
#   * for each tree depth d, the MSE between the decoded prediction at depth d and
#     the true state d steps later. This uses either the single most probable
#     branch (greedy, the default) or every branch weighted by its transition
#     probability.
NUM_EPOCHS = 3000
VALIDATE_EVERY = 10          # epochs between validation printouts
USE_WEIGHTED_LOSS = False    # False: greedy branch only; True: probability-weighted sum over branches
EARLY_STOP_LOSS = None       # e.g. 2 * 0.0025 to stop once the average training loss drops below that
num_levels = tree_depth + 1  # decoded start state + one prediction per tree depth

n_train = len(train_data)
if n_train == 0:
    raise ValueError("train_data is empty")

# Keep only the last epoch's worth of gradient snapshots. Appending one per step
# for every epoch grows without bound, and only all_gradients[-1] is inspected later.
all_gradients = deque(maxlen=n_train)

for epoch in range(NUM_EPOCHS):
    # Re-enter training mode each epoch in case the validation below switched the
    # model to eval mode.
    # NOTE(review): confirm whether image_world.validate calls model.eval().
    model.train()
    level_loss_sums = [0.0] * num_levels  # index 0: decode loss; index d: depth-d loss
    total_loss_sum = 0.0

    # Visit every training trajectory once per epoch, in a fresh random order.
    for trajectory in random.sample(train_data, n_train):
        decoded_values, transition_probabilities = model(trajectory[0])

        # Reconstruction loss helps ground the abstract state.
        level_losses = [F.mse_loss(decoded_values[0], trajectory[0], reduction='sum')]

        for depth in range(1, num_levels):
            # Flatten this depth's branch probabilities. They then broadcast over
            # the decoded states (weighted) or give a flat branch index (greedy).
            probs = transition_probabilities[depth - 1].view(-1, 1, 1, 1)
            if USE_WEIGHTED_LOSS:
                depth_loss = (F.mse_loss(decoded_values[depth], trajectory[depth], reduction='none') * probs).sum()
            else:
                # Squeeze away batch/channel dims so both sides have the same shape.
                depth_loss = F.mse_loss(
                    decoded_values[depth][probs.argmax()].squeeze(0),
                    trajectory[depth].squeeze(0).squeeze(0),
                )
            level_losses.append(depth_loss)

        total_loss = sum(level_losses)
        if torch.isnan(total_loss):
            raise ValueError("NAN LOSS")

        for level, loss in enumerate(level_losses):
            level_loss_sums[level] += loss.item()
        total_loss_sum += total_loss.item()

        optimizer.zero_grad()
        total_loss.backward()
        # Record gradients after backward, before the optimizer step updates the weights.
        all_gradients.append(image_world.store_gradients(model))
        optimizer.step()

    if epoch % VALIDATE_EVERY == 0:
        weighted_val = image_world.validate(model, valid_data, weighted=True)
        # NOTE(review): the unweighted score is divided by the number of decoded
        # levels (5 for tree_depth=4, matching the previous hard-coded /5).
        # Confirm that validate sums over levels.
        unweighted_val = image_world.validate(model, valid_data, weighted=False) / num_levels
        print(f"Epoch {epoch + 1}, Validation Loss Weighted: {weighted_val}, | Validation Unweighted Avg: {unweighted_val}")

    level_avgs = [level_sum / n_train for level_sum in level_loss_sums]
    avg_train_loss = total_loss_sum / n_train

    depth_report = ", ".join(f"A{d}: {avg}" for d, avg in enumerate(level_avgs[1:], start=1))
    print(f"Epoch {epoch + 1}, Total Loss: {avg_train_loss}, DLoss: {level_avgs[0]}, {depth_report}")

    if EARLY_STOP_LOSS is not None and avg_train_loss < EARLY_STOP_LOSS:
        break


  first_loss = (F.mse_loss(decoded_values[1], t[1], reduction='none') * first).sum().item()
  second_loss = (F.mse_loss(decoded_values[2], t[2], reduction='none') * second).sum().item()
  third_loss = (F.mse_loss(decoded_values[3], t[3], reduction='none') * third).sum().item()
  fourth_loss = (F.mse_loss(decoded_values[4], t[4], reduction='none') * fourth).sum().item()


Epoch 1, Validation Loss Weighted: 7.4315505027771, | Validation Unweighted Avg: 0.19917109608650208
Epoch 1, Total Loss: 1.2242632088336078, DLoss: 1.2075781120495364, A1: 0.0037088934363881973, A2: 0.004093928370540115, A3: 0.0043763829797337, A4: 0.00450589331925254
Epoch 2, Total Loss: 0.7962757179682906, DLoss: 0.7811783217570999, A1: 0.003111842741385441, A2: 0.003644602153111588, A3: 0.004061055342158811, A4: 0.004279895141636106
Epoch 3, Total Loss: 0.2777337351136587, DLoss: 0.2618944889527153, A1: 0.0034190229944546114, A2: 0.003819341579748487, A3: 0.004188875748183239, A4: 0.004412004679695449
Epoch 4, Total Loss: 0.05249040217392824, DLoss: 0.03674559223846617, A1: 0.003481004252733493, A2: 0.003763037458570166, A3: 0.004125462237491526, A4: 0.004375305977141993
Epoch 5, Total Loss: 0.023803234650668772, DLoss: 0.008460839130566455, A1: 0.003413925662806088, A2: 0.003647099786692045, A3: 0.004003371396737004, A4: 0.004277998592111875
Epoch 6, Total Loss: 0.0314867511137642

KeyboardInterrupt: 

In [None]:
#save model
#torch.save(model.state_dict(), 'model.pth')
#load model
#model.load_state_dict(torch.load('model_1500.pth'))

In [11]:
# Print the per-parameter value returned by image_world.store_gradients at the
# most recent optimisation step, rounded to two decimals.
for param_name, grad_value in all_gradients[-1]:
    print(f"{param_name}:", round(grad_value, 2))

transition_fun: 0.0
decoder.fc.weight: 0.06
decoder.fc.bias: 0.0
decoder.deconv1.weight: 0.22
decoder.deconv1.bias: 0.01
decoder.deconv2.weight: 0.24
decoder.deconv2.bias: 0.05
decoder.final_conv.weight: 0.23
decoder.final_conv.bias: 0.12
encoder.cnn_encoder.conv1.weight: 0.04
encoder.cnn_encoder.conv1.bias: 0.14
encoder.cnn_encoder.bn1.weight: 0.01
encoder.cnn_encoder.bn1.bias: 0.01
encoder.cnn_encoder.conv2.weight: 0.09
encoder.cnn_encoder.conv2.bias: 0.01
encoder.cnn_encoder.bn2.weight: 0.01
encoder.cnn_encoder.bn2.bias: 0.0
encoder.cnn_encoder.conv3.weight: 0.08
encoder.cnn_encoder.conv3.bias: 0.0
encoder.cnn_encoder.bn3.weight: 0.0
encoder.cnn_encoder.bn3.bias: 0.0
encoder.cnn_encoder.residual_conv.weight: 0.01
encoder.cnn_encoder.residual_conv.bias: 0.0
encoder.linear.weight: 0.13
encoder.linear.bias: 0.0


<h1>View Predictions</h1>

In [None]:
# Decode the model's predicted next state for every action from the first
# training start state. The action list is derived from `num_actions`, so the
# viewer stays in sync with the model configuration.
# shrink=4 appears to trim 4 cells from each border of the printout (20x20 -> 12x12).
image_world.action_viewer(
    model,
    start_state=train_data[0][0],
    actions=list(range(num_actions)),
    shrink=4,
)

True Original Start State:
 tensor([[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0., -1.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]])
Max True Original Start State: (7, 8)
Decoded Next States
Action: 0
Next State:
 tensor([[ 0., -0., -0.,  0.,  0., -0.,  0., -0.,  

In [None]:
# Follow the model's greedy action path from the first training trajectory and
# compare it with the true path. The helper takes the whole trajectory, not just
# its start state.
image_world.view_greedy_path(
    model,
    start_state=train_data[0],
    shrink=4,
)

True Path
7 8
8 8
8 9
8 10
9 10
True Original Start State:
 tensor([[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0., -1.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]])
Max True Original Start State: (7, 8)
Decoded Next States
Action: 3
Next State:
 tensor([[ 0., -0.,

In [None]:
# Collect the model's chosen actions over the validation set (presumably one
# column per step of the greedy path; confirm against get_action_df). Then show
# each column's sample variance as a rough check that choices are not collapsing.
action_df = image_world.get_action_df(model, valid_data)
action_df.var(axis=0, ddof=1)

Action 1    1.322873
Action 2    0.630054
Action 3    1.450892
Action 4    1.317476
dtype: float64

In [None]:
# First ten rows of chosen actions.
action_df.iloc[:10]

Unnamed: 0,Action 1,Action 2,Action 3,Action 4
0,3,1,3,1
1,2,3,1,3
2,3,2,0,3
3,0,1,3,2
4,3,3,3,3
5,1,2,1,3
6,0,2,1,3
7,3,1,3,1
8,2,3,1,3
9,2,3,1,3


In [None]:
# How many distinct action sequences does the model produce?
# Displays (total rows, distinct rows).
unique_rows = action_df.drop_duplicates(keep="first")
action_df.shape[0], unique_rows.shape[0]

(220, 16)