In [1]:
import torch
import torch.nn.functional as F
import numpy as np
from torch.nn import Parameter

from torch.optim import Adam

from treeQN.treeqn_traj_simplest import TreeQN
import image_world
import random

import pandas as pd

In [2]:
# Seed every RNG this notebook relies on so Restart & Run All is reproducible:
# Python `random` drives the per-epoch shuffling in the training loop, torch drives
# TreeQN weight initialisation, and numpy/random may also be used inside
# image_world.get_data (not visible here -- TODO confirm).
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Trajectories of frames; size=20 appears to match the 20x20 frames used below (TODO confirm).
train_data, valid_data = image_world.get_data(size = 20)

<h1>Initialize Model</h1>

In [3]:
# Reference point: mean-reduced MSE when exactly one pixel of an all-zero 20x20
# frame is off by 1 (1/400 = 0.0025). The per-step losses (A1..A4) use the default
# mean reduction and should end up below this. The reconstruction loss (DLoss)
# uses reduction='sum', so its equivalent one-pixel threshold is 1.0.
a = torch.zeros(20, 20)
b = a.clone()
b[0, 0] = 1
F.mse_loss(a, b)  # target loss should be at least less than this

tensor(0.0025)

In [4]:
# TreeQN hyper-parameters.
input_shape = train_data[0][0].shape  # shape of a single start frame
num_actions = 4                       # presumably the 4 grid moves -- confirm against image_world
tree_depth = 4
embedding_dim = 64
gamma = 1
decode_dropout = 0
t1 = False  # transition type: True -> einsum transition, False -> additive (+dx) transition

treeqn_kwargs = dict(
    input_shape=input_shape,
    num_actions=num_actions,
    tree_depth=tree_depth,
    embedding_dim=embedding_dim,
    gamma=gamma,
    decode_dropout=decode_dropout,
    t1=t1,
)
model = TreeQN(**treeqn_kwargs)
optimizer = Adam(model.parameters(), lr=1e-4)

Transition type: additive (`t1=False`, the `+dx` variant) rather than the einsum transition.


<h1>Optional: Pretrain Autoencoder</h1>
(Doesn't seem necessary for training, but useful for testing whether the latent state Z can be decoded.)

In [5]:
#image_world.train_autoencoder(model,optimizer,train_data,valid_data,epochs=100,lambda_reg=0)

In [6]:
# #freeze encoder and decoder
# for param in model.encoder.parameters():
#     param.requires_grad = False
# for param in model.decoder.parameters():
#     param.requires_grad = False

In [7]:
# #test_autoencode ability
# def test_autoencoder(model,valid_data):
#     with torch.no_grad():
#         test_sample = random.choice(valid_data)
#         encoding = model.encoder(test_sample[0])
#         decoding = model.decoder(encoding)
#         print('Original:\n', test_sample[0].numpy()[0][0][4:-4, 4:-4])
#         print('Reconstructed:\n', np.round(decoding.numpy()[0][0][4:-4, 4:-4]))

<h1>Train Full Model</h1>

In [8]:
# Training configuration
NUM_EPOCHS = 3000
VALIDATE_EVERY = 10    # run (and print) validation every N epochs
NUM_STEPS = 4          # future frames supervised per trajectory: t[1]..t[NUM_STEPS]
WEIGHTED_LOSS = False  # False: greedy loss on the argmax branch; True: probability-weighted over all branches


def _step_loss(predicted, target, probs, weighted=False):
    """Loss between the decoded frames at one tree depth and the true frame for that step.

    predicted: decoded frames for every branch at this depth (branch index on dim 0).
    target: ground-truth frame for this step.
    probs: transition probabilities for the branches at this depth.
    weighted: if True, weight each branch's per-pixel squared error by its
        probability and sum everything; otherwise score only the most probable
        branch with mean-reduced MSE.
    """
    # Flatten so the probabilities broadcast over each branch's frame.
    probs = probs.view(-1, 1, 1, 1)
    if weighted:
        return (F.mse_loss(predicted, target, reduction='none') * probs).sum()
    # Greedy policy (squeezing to eliminate batch and channel dimensions)
    return F.mse_loss(predicted[probs.argmax()].squeeze(0), target.squeeze(0).squeeze(0))


assert len(train_data) > 0, "train_data is empty"
n_train = len(train_data)
all_gradients = []  # one image_world.store_gradients(model) snapshot per optimizer step

for epoch in range(NUM_EPOCHS):
    # Re-enter training mode every epoch: image_world.validate (called below) may
    # switch the model to eval mode (not visible here -- TODO confirm), which would
    # otherwise leave the encoder's BatchNorm layers in eval mode for all later epochs.
    model.train()

    avg_loss = 0
    avg_decode_loss = 0
    step_loss_totals = [0] * NUM_STEPS

    for t in random.sample(train_data, n_train):  # every sample once, in random order
        # Reconstruction loss grounds the abstract state in the start frame t[0].
        decoded_values, transition_probabilities = model(t[0])
        decode_loss = F.mse_loss(decoded_values[0], t[0], reduction='sum')

        # Depth k+1 of the tree is compared with the true frame t[k+1].
        step_losses = [
            _step_loss(decoded_values[k + 1], t[k + 1], transition_probabilities[k], WEIGHTED_LOSS)
            for k in range(NUM_STEPS)
        ]
        total_loss = sum(step_losses) + decode_loss

        # Stop immediately on NaN/inf rather than corrupting the weights.
        if not torch.isfinite(total_loss):
            raise ValueError(f"NAN LOSS: non-finite total loss ({total_loss.item()}) at epoch {epoch + 1}")

        avg_decode_loss += decode_loss.item()
        for k, loss in enumerate(step_losses):
            step_loss_totals[k] += loss.item()
        avg_loss += total_loss.item()

        optimizer.zero_grad()
        total_loss.backward()
        # Record gradient statistics after backward() and before the optimizer step.
        all_gradients.append(image_world.store_gradients(model))
        optimizer.step()

    if epoch % VALIDATE_EVERY == 0:
        # The unweighted score is divided by 5 -- presumably averaging over the
        # reconstruction plus the 4 predicted frames; confirm against image_world.validate.
        print(f"Epoch {epoch + 1}, Validation Loss Weighted: {image_world.validate(model, valid_data,weighted=True)}, | Validation Unweighted Avg: {image_world.validate(model, valid_data,weighted=False)/5}")

    # Per-sample averages over the epoch
    avg_train_loss = avg_loss / n_train
    avg_decode_loss = avg_decode_loss / n_train
    avg_step_losses = [total / n_train for total in step_loss_totals]
    step_report = ", ".join(f"A{k + 1}: {value}" for k, value in enumerate(avg_step_losses))

    print(f"Epoch {epoch + 1}, Total Loss: {avg_train_loss}, DLoss: {avg_decode_loss}, {step_report}")


  first_loss = (F.mse_loss(decoded_values[1], t[1], reduction='none') * first).sum().item()
  second_loss = (F.mse_loss(decoded_values[2], t[2], reduction='none') * second).sum().item()
  third_loss = (F.mse_loss(decoded_values[3], t[3], reduction='none') * third).sum().item()
  fourth_loss = (F.mse_loss(decoded_values[4], t[4], reduction='none') * fourth).sum().item()


Epoch 1, Validation Loss Weighted: 7.272695064544678, | Validation Unweighted Avg: 0.15447017550468445
Epoch 1, Total Loss: 1.021253763274713, DLoss: 1.0046975414861332, A1: 0.0034557934427125888, A2: 0.004056082857476378, A3: 0.004442845801399513, A4: 0.004601503271524879
Epoch 2, Total Loss: 0.3693698264150457, DLoss: 0.35293838005004957, A1: 0.0034333888557739555, A2: 0.003966094620144842, A3: 0.004424838287840513, A4: 0.004607125037265095
Epoch 3, Total Loss: 0.06195463539016518, DLoss: 0.04547949661077424, A1: 0.0035973065405745398, A2: 0.003925986087415368, A3: 0.004386147867295552, A4: 0.004565698431212116
Epoch 4, Total Loss: 0.029470828302543273, DLoss: 0.013267396898432211, A1: 0.0035939741668037394, A2: 0.0038410434531132606, A3: 0.00428268563806672, A4: 0.004485728054053404
Epoch 5, Total Loss: 0.023157147030261428, DLoss: 0.007292939272222363, A1: 0.003546910475812514, A2: 0.003745620743244548, A3: 0.004178395611234009, A4: 0.004393280828794972
Epoch 6, Total Loss: 0.02703

KeyboardInterrupt: 

In [15]:
#save model
#torch.save(model.state_dict(), 'model.pth')
#load model
#model.load_state_dict(torch.load('model.pth'))

In [9]:
# Print the per-parameter gradient statistic recorded at the most recent optimizer step.
latest_step = all_gradients[-1]
for param_name, grad_stat in latest_step:
    print(param_name + ':', round(grad_stat, 2))

transition_fun: 0.0
decoder.fc.weight: 0.05
decoder.fc.bias: 0.01
decoder.deconv1.weight: 0.07
decoder.deconv1.bias: 0.02
decoder.deconv2.weight: 0.07
decoder.deconv2.bias: 0.03
decoder.final_conv.weight: 0.17
decoder.final_conv.bias: 0.04
encoder.cnn_encoder.conv1.weight: 0.01
encoder.cnn_encoder.conv1.bias: 0.01
encoder.cnn_encoder.bn1.weight: 0.0
encoder.cnn_encoder.bn1.bias: 0.0
encoder.cnn_encoder.conv2.weight: 0.02
encoder.cnn_encoder.conv2.bias: 0.04
encoder.cnn_encoder.bn2.weight: 0.01
encoder.cnn_encoder.bn2.bias: 0.01
encoder.cnn_encoder.conv3.weight: 0.04
encoder.cnn_encoder.conv3.bias: 0.02
encoder.cnn_encoder.bn3.weight: 0.0
encoder.cnn_encoder.bn3.bias: 0.01
encoder.cnn_encoder.residual_conv.weight: 0.02
encoder.cnn_encoder.residual_conv.bias: 0.01
encoder.linear.weight: 0.06
encoder.linear.bias: 0.01


<h1>View Predictions</h1>

In [16]:
# Show the first training start frame and the decoded prediction after action 0.
# shrink=4 appears to crop 4 cells from each border in the printout (20x20 -> 12x12).
viewer_start = train_data[0][0]
image_world.action_viewer(
    model,
    start_state=viewer_start,
    actions=[0],
    shrink=4,
)

Original:
 tensor([[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0., -1.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]])
Action: 0
Next State:
 tensor([[-0., -0.,  0., -0.,  0.,  0.,  0.,  0.,  0.,  0., -0.,  0.],
        [-0., -0., -0., -0., -0., -0., -0., -0., -0., -

In [18]:
# Compare the true path of the first training trajectory with the model's
# (presumably greedy, per the function name) rollout from the same start.
path_sample = train_data[0]
image_world.view_greedy_path(
    model,
    start_state=path_sample,
    shrink=4,
)

True Path
12 13
11 13
11 12
11 11
11 10
Original:
 tensor([[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0., -1.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]])
Action: 1
Next State:
 tensor([[ 0., -0.,  0.,  0., -0.,  0.,  0., -0., -0., -0., -0.,  0.],
        [-0.,  

In [12]:
# One row per validation sample with columns 'Action 1'..'Action 4'; values appear
# to be the action index chosen at each tree step (TODO confirm in image_world).
action_df = image_world.get_action_df(
    model,
    valid_data,
)
# Per-column variance: a value near zero would mean the same action is picked every time.
action_df.var()

Action 1    0.527086
Action 2    1.263159
Action 3    0.943524
Action 4    1.067746
dtype: float64

In [14]:
# Peek at the first ten rows of chosen actions.
action_df.iloc[:10]

Unnamed: 0,Action 1,Action 2,Action 3,Action 4
0,0,0,1,0
1,1,0,0,1
2,0,1,3,1
3,0,1,0,2
4,1,0,1,0
5,0,0,1,0
6,1,0,2,2
7,0,1,3,1
8,0,1,0,2
9,0,0,1,0
