In [1]:
import os

In [2]:
import math

import time
import random
import pickle

from collections import deque

%matplotlib notebook
import matplotlib as mpl
# Import pyplot explicitly: `mpl.pyplot` is only an attribute once something
# has imported the submodule (here, a side effect of the `%matplotlib` magic),
# so the old `plt = mpl.pyplot` alias raised AttributeError without it.
import matplotlib.pyplot as plt
# mpl.pylab.rcParams['figure.figsize'] = (12, 9)

import numpy as np
import torch as tch
import torch.nn.functional as F  # same module object as tch.nn.functional

In [3]:
# Is a CUDA device visible? When it is, np2var and the model-setup cells move data/modules to the GPU.
tch.cuda.is_available()

True

In [4]:
# PyTorch helper functions
def np2var(input):
    """Wrap a NumPy array in a volatile (inference-only) Variable.

    The array is converted with ``tch.from_numpy`` (memory is shared while the
    result stays on the CPU); the Variable is moved to the GPU whenever CUDA
    is available.
    """
    variable = tch.autograd.Variable(tch.from_numpy(input), volatile=True)
    return variable.cuda() if tch.cuda.is_available() else variable

def np2var_cpu(input):
    """Like ``np2var`` but the returned volatile Variable always stays on the CPU."""
    return tch.autograd.Variable(tch.from_numpy(input), volatile=True)


# tanh2sigmoid = lambda x: (x + 1) / 2

def postprocess(tch_img):
    """Turn an image tensor (values expected in [0, 1]) into a squeezed uint8 NumPy array.

    Values are multiplied by 255 and truncated toward zero by ``.byte()``
    (not rounded); the result is copied to the CPU before conversion.
    """
    scaled = tch_img.mul(255).byte()
    return scaled.data.cpu().numpy().squeeze()

# Screen slicing: presumably the crop applied to the raw game screen (pixels).
top = 86
left = 20 + 100
height = 300
width = 500 - 200
gamescreen_slice = (slice(top,  top + height), slice(left, left + width))  # (row slice, column slice)
game_center = (top + height // 2, left + width // 2) # y, x
game_center = game_center[::-1]  # reversed: game_center is now (x, y)

# average pooling for downscaling
s = downscale_size = 2
pool2d_downscale = tch.nn.AvgPool2d((s, s), stride=(s, s))  # each application halves H and W (floor)
downscale_n = 1  # presumably the number of times pool2d_downscale is applied; its use is not visible here
downscale = downscale_size ** downscale_n  # overall downscale factor
# (H, W, 1) zero image; the encoder cell runs it through the encoder to record layer shapes.
dummy_img = np.zeros((height // downscale, width // downscale, 1))

In [6]:
class CNN_encoder(tch.nn.Module):
    """Convolutional encoder: six Conv2d layers followed by a (by default empty) Linear stack.

    Every layer output is passed through ``act``. For a 1 x 150 x 150 input
    the last convolution yields 512 x 1 x 1, flattened to a 512-wide code.
    ``forward`` records in ``self.conv_sizes`` the input shape of every conv
    layer plus the final conv output shape; ``CNN_decoder`` reads these.

    Note: ``n_convs`` and ``out_ch`` are stored but do not affect the layers,
    which are fixed below.
    """

    def __init__(self, n_convs=5, out_ch=128, act=F.selu):

        super(CNN_encoder, self).__init__()

        self.n_convs = n_convs
        self.kernel_size = (4, 4)
        self.out_ch = out_ch
        self.act = act

        k = self.kernel_size
        # (in_channels, out_channels, kernel, stride, padding), in forward order.
        conv_specs = (
            (1, 16, k, (2, 2), (2, 2)),
            (16, 16, k, (2, 2), (1, 1)),
            (16, 32, k, (2, 2), (2, 2)),
            (32, 64, k, (2, 2), (1, 1)),
            (64, 128, k, (2, 2), (1, 1)),
            # 5x5 "valid" conv: collapses the 5x5 map (for a 150x150 input) to 1x1.
            (128, 512, (5, 5), (1, 1), (0, 0)),
        )
        self.conv_layers = tch.nn.ModuleList(
            [tch.nn.Conv2d(c_in, c_out, ksize, stride=step, padding=pad)
             for c_in, c_out, ksize, step, pad in conv_specs])

        self.last_conv_unrolled_size = 512
        self.fc_sizes = []
        self.fc_layers = tch.nn.ModuleList()
        widths = [self.last_conv_unrolled_size] + self.fc_sizes
        for n_in, n_out in zip(widths, widths[1:]):
            self.fc_layers.append(tch.nn.Linear(n_in, n_out))

        self.conv_sizes = []

    def forward(self, input):
        """Encode a (batch, 1, H, W) input; returns (batch, last_conv_unrolled_size) when fc_sizes is empty."""

        self.conv_sizes = []

        x = input
        for conv in self.conv_layers:
            self.conv_sizes.append(tuple(x.size()))
            x = self.act(conv(x))
        self.conv_sizes.append(tuple(x.size()))

        x = x.view(-1, self.last_conv_unrolled_size)

        for fc in self.fc_layers:
            x = self.act(fc(x))

        return x

# Build the encoder and re-initialise every parameter uniformly in
# [-2/sqrt(n), 2/sqrt(n)].
encoder = CNN_encoder()

for param in encoder.parameters():
    # n = product of the last three dims: in_ch * kh * kw for a Conv2d weight
    # (out_ch, in_ch, kh, kw); for a 1-D bias it is simply its length.
    n = np.prod(list(param.size())[-3:]) # valid for 2D convolutions
    stdv = 2. / math.sqrt(n)
    param.data.uniform_(-stdv, stdv)

if tch.cuda.is_available():
    encoder = encoder.cuda()

# One dummy forward pass: CNN_encoder.forward records the per-layer shapes in
# `encoder.conv_sizes`, which CNN_decoder reads to mirror the architecture.
dummy_img_T = dummy_img.transpose(2, 0, 1)  # (H, W, 1) -> (1, H, W)
dummy_img_T = np2var(dummy_img_T).float().unsqueeze(0)  # add batch axis -> (1, 1, H, W)
_ = encoder(dummy_img_T)

In [7]:
class CNN_decoder(tch.nn.Module):
    """Mirror image of a ``CNN_encoder``: optional Linear stack, then transposed convolutions.

    The architecture is derived from ``encoder``: its ``fc_sizes`` are walked
    backwards for the Linear layers, and each encoder Conv2d becomes a
    ConvTranspose2d with the same kernel, stride and padding that maps back to
    the channel count recorded in ``encoder.conv_sizes``. ``encoder.forward``
    must therefore have run at least once before construction. The final
    layer uses a sigmoid; all others use ``encoder.act``.
    """

    def __init__(self, encoder):

        super(CNN_decoder, self).__init__()

        self.act = encoder.act

        has_fc = len(encoder.fc_sizes) > 0
        # Decoder input width, and the Linear widths leading back to the flattened conv output.
        self.input_size = encoder.fc_sizes[-1] if has_fc else None
        self.fc_sizes = (encoder.fc_sizes[-2::-1] + [encoder.last_conv_unrolled_size]) if has_fc else []

        # Shape the flat code is reshaped to before the transposed convs: (-1, C, H, W).
        self.last_conv_shape = (-1, ) + encoder.conv_sizes[-1][1:]

        self.fc_layers = tch.nn.ModuleList()
        widths = [self.input_size] + list(self.fc_sizes)
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            self.fc_layers.append(tch.nn.Linear(n_in, n_out))

        self.convtrans_layers = tch.nn.ModuleList()
        # conv_sizes[k] is the input shape of encoder conv k (the last entry is the
        # final output), so walking both lists backwards pairs each conv with the
        # shape it consumed.
        channels = encoder.conv_sizes[-1][1]
        mirrored = zip(encoder.conv_sizes[-2::-1], tuple(encoder.conv_layers)[::-1])
        for target_shape, conv in mirrored:
            self.convtrans_layers.append(tch.nn.ConvTranspose2d(channels, target_shape[1],
                                                                conv.kernel_size,
                                                                conv.stride,
                                                                conv.padding))
            channels = target_shape[1]

    def forward(self, input):
        """Decode a batch of codes into images with values in (0, 1)."""

        x = input
        for fc in self.fc_layers:
            x = self.act(fc(x))

        x = x.view(*self.last_conv_shape)

        layers = list(self.convtrans_layers)
        for layer in layers[:-1]:
            x = self.act(layer(x))
        return F.sigmoid(layers[-1](x))
        

# Build the decoder as the mirror image of `encoder` (requires encoder.forward
# to have run at least once so that `encoder.conv_sizes` is populated), then
# re-initialise every parameter uniformly in [-2/sqrt(n), 2/sqrt(n)].
decoder = CNN_decoder(encoder)

for param in decoder.parameters():
    # ConvTranspose2d weight is (in_ch, out_ch, kh, kw): n = in_ch * kh * kw.
    # For a 1-D bias, size()[-2:] is just (length,), so n = length.
    # NOTE(review): a 2-D Linear weight (out, in) would give n = out * out * in;
    # no Linear layers exist while encoder.fc_sizes is empty.
    n = 1 if len(param.size()) == 1 else param.size()[0]
    n *= np.prod(list(param.size())[-2:]) # valid for 2D convolutions
    stdv = 2. / math.sqrt(n)
    param.data.uniform_(-stdv, stdv)

if tch.cuda.is_available():
    decoder = decoder.cuda()

In [8]:
class LSTM_predictor(tch.nn.Module):
    """One-step LSTM that maps a code vector to a predicted code vector.

    Each call advances an ``LSTMCell`` by one step and projects the new hidden
    state ``h`` back to ``size`` features with a Linear layer followed by
    ``self.act``.
    """

    def __init__(self, size, hidden_size, act=None):
        """
        Args:
            size: width of the input code and of the prediction.
            hidden_size: number of LSTM hidden units.
            act: activation applied to the output projection. When None (the
                default) it falls back to the global ``encoder.act``, exactly
                as before, so existing calls are unchanged; pass it explicitly
                to avoid the hidden dependency on the notebook global.
        """

        super(LSTM_predictor, self).__init__()

        self.cell = tch.nn.LSTMCell(size, hidden_size)

        self.output_layer = tch.nn.Linear(hidden_size, size)

        # Presumably shares the encoder's activation so predicted codes have the
        # same range as encoded ones.
        self.act = act if act is not None else encoder.act

    def forward(self, hidden, input):
        """Advance one step.

        Args:
            hidden: (h, c) tuple, e.g. from ``zero_hidden`` or a previous call.
            input: (batch, size) code batch.
        Returns:
            (new_hidden, output), where new_hidden = (h', c') and
            output = act(output_layer(h')) has shape (batch, size).
        """

        hidden = self.cell(input, hidden)

        output = self.act(self.output_layer(hidden[0]))

        return hidden, output

    def zero_hidden(self, batch_size=1):
        """Return an all-zero (h_0, c_0) pair of shape (batch_size, hidden_size).

        The tensors are allocated with the same type/device as the module's own
        parameters, so they always match the module (the previous version moved
        them to CUDA whenever CUDA existed, even if the module was on the CPU).
        """

        weight = self.output_layer.weight.data
        hidden = tuple(tch.autograd.Variable(weight.new(batch_size, self.cell.hidden_size).zero_())
                       for _ in range(2))

        return hidden # h_0, c_0
    
predictor = LSTM_predictor(512, 512)

if tch.cuda.is_available():
    predictor = predictor.cuda()

# Leftover diagnostic plot. It used `torch.cat`, but torch is imported as `tch`,
# and `record` is not defined in any visible cell, so on a fresh kernel this
# cell raised NameError. Only plot when a `record` sequence of tensors has been
# created interactively.
if 'record' in globals():
    plt.figure()
    plt.plot(tch.cat(record).data.cpu().numpy())

In [27]:
# Model ids (file-name prefixes) to (re-)evaluate in the cells below. Uncomment
# an id to include it again; results for ids left out keep whatever value the
# loaded result pickles contain.
load_ids = [
    'slither-AE-LSTM-final-e2e-sim',
    # 'slither-AE-LSTM-final-e2e-d2d',
    # 'slither-AE-LSTM-final-e2e-c2d',
    # 'slither-AE-LSTM-final-e2e',
    # 'slither-AE-LSTM-final-bl',
    # 'slither-AE-LSTM-final-n',
    'slither-AE-LSTM-final-e2e-aim',
]

In [28]:
# Reload previously saved results so the evaluation cells below only overwrite
# the entries for the ids in `load_ids`. `with` closes each file (the bare
# open() calls were never closed). pickle.load can execute arbitrary code:
# only open result files you produced yourself.
with open("./final_figures/train_losses.pkl", 'rb') as results_file:
    train_data = pickle.load(results_file)
with open("./final_figures/test_losses.pkl", 'rb') as results_file:
    test_data = pickle.load(results_file)
with open("./final_figures/code_errors.pkl", 'rb') as results_file:
    code_error_data = pickle.load(results_file)

In [29]:
# Binary cross-entropy between decoded frames (sigmoid outputs) and target frames scaled to [0, 1]; averaged over all elements by default.
bce_criterion = tch.nn.BCELoss()

In [30]:
# test_data = {}

# Evaluate each model in `load_ids` on the held-out episodes and store its mean
# per-frame next-frame BCE in `test_data`. Entries for ids that are not in
# `load_ids` keep the values loaded from disk.
#
# How the next frame is predicted, per model id:
#   * -bl               : decoder(encoder(frame_t)); no predictor
#   * -e2e, -n          : decoder(predictor output for code_t)
#   * -e2e-c2d, -e2e-d2d: decoder(code_t + predictor output)
#   * -e2e-sim, -e2e-aim: frame by frame; the previous predicted frame is
#                         combined with the current frame (difference / sum)
#                         before encoding
# The prediction made at step t is always scored against frame t + 1.
BASELINE_ID = 'slither-AE-LSTM-final-bl'
DIRECT_CODE_IDS = ('slither-AE-LSTM-final-e2e', 'slither-AE-LSTM-final-n')
RESIDUAL_CODE_IDS = ('slither-AE-LSTM-final-e2e-c2d', 'slither-AE-LSTM-final-e2e-d2d')
FEEDBACK_IDS = ('slither-AE-LSTM-final-e2e-sim', 'slither-AE-LSTM-final-e2e-aim')
KNOWN_IDS = (BASELINE_ID,) + DIRECT_CODE_IDS + RESIDUAL_CODE_IDS + FEEDBACK_IDS
TEST_DIR = './game_dataset_test_data/'

for load_id in load_ids:

    # Previously an unknown id fell through to `print('else???!!!!!')` and reset
    # `total_steps = 1` for every file, silently storing a meaningless value.
    if load_id not in KNOWN_IDS:
        raise ValueError('unknown model id: %r' % load_id)
    uses_predictor = load_id != BASELINE_ID

    # Copy the saved weights into the live modules. The .pkl files hold whole
    # pickled modules: tch.load can run arbitrary code, only load trusted files.
    targets = [(encoder, 'encoder.pkl'), (decoder, 'decoder.pkl')]
    if uses_predictor:
        targets.append((predictor, 'predictor.pkl'))
    for module, suffix in targets:
        params = list(module.parameters())
        params_load = list(tch.load(load_id + suffix).parameters())
        # zip() would silently drop unmatched parameters; fail loudly instead.
        if len(params) != len(params_load):
            raise ValueError('%s%s: expected %d parameters, found %d'
                             % (load_id, suffix, len(params), len(params_load)))
        for param, param_load in zip(params, params_load):
            if param.size() != param_load.size():
                raise ValueError('%s%s: saved shape %s does not match %s'
                                 % (load_id, suffix, tuple(param_load.size()), tuple(param.size())))
            param.data = param_load.data

    total_loss = 0
    total_steps = 0

    for filename in os.listdir(TEST_DIR):

        if uses_predictor:
            hidden = predictor.zero_hidden()

        # Each episode file is a pickled sequence of (gamescreen, _) pairs.
        # pickle.load can execute arbitrary code: only use trusted data.
        with open(TEST_DIR + filename, 'rb') as gamescreen_file:
            episode = pickle.load(gamescreen_file)

        if load_id in FEEDBACK_IDS:

            prev_pred_gamescreen = None
            for gamescreen, _ in episode:

                gamescreen = np2var(gamescreen).float() / 255
                gamescreen = gamescreen.unsqueeze(0).unsqueeze(0)

                if prev_pred_gamescreen is None:
                    # First frame: nothing to score yet; start from a blank prediction.
                    prev_pred_gamescreen = 0 * gamescreen
                else:
                    total_steps += 1
                    total_loss += bce_criterion(prev_pred_gamescreen, gamescreen).data[0]

                if load_id == 'slither-AE-LSTM-final-e2e-aim':
                    code = encoder(gamescreen + prev_pred_gamescreen)
                else:  # 'slither-AE-LSTM-final-e2e-sim'
                    code = encoder(prev_pred_gamescreen - gamescreen)
                hidden, pred_code = predictor(hidden, code)
                prev_pred_gamescreen = decoder(pred_code)

        else:

            gamescreen_list, _ = zip(*episode)
            gamescreens = np2var(np.stack(gamescreen_list)).float() / 255
            # Channel axis for the encoder's Conv2d(1, ...) (assumes 2-D grayscale frames).
            gamescreens = gamescreens.unsqueeze(1)

            codes = encoder(gamescreens)
            if load_id == BASELINE_ID:
                pred_codes = codes
            else:
                step_outputs = []
                for code in codes:
                    hidden, step_output = predictor(hidden, code.unsqueeze(0))
                    step_outputs.append(step_output)
                pred_codes = tch.cat(step_outputs, dim=0)
                if load_id in RESIDUAL_CODE_IDS:
                    # The predictor outputs a delta on top of the current code.
                    # NOTE(review): -e2e-d2d is scored exactly like -e2e-c2d (the
                    # old code also carried an unused `prev_pred_code`); confirm
                    # this matches how d2d was trained.
                    pred_codes = pred_codes + codes
            decoded_pred_gamescreens = decoder(pred_codes)

            episode_steps = len(gamescreen_list) - 1
            total_steps += episode_steps
            total_loss += bce_criterion(decoded_pred_gamescreens[:-1], gamescreens[1:]).data[0] * episode_steps

    if total_steps == 0:
        raise ValueError('no frames to score for %r in %s' % (load_id, TEST_DIR))
    test_data[load_id] = total_loss / total_steps

In [31]:
# Mean per-frame next-frame BCE on the test episodes, keyed by model id.
test_data

{'slither-AE-LSTM-final-bl': 0.3959071990795154,
 'slither-AE-LSTM-final-e2e': 0.39887651646925165,
 'slither-AE-LSTM-final-e2e-aim': 0.409412894215518,
 'slither-AE-LSTM-final-e2e-c2d': 0.3961973212188162,
 'slither-AE-LSTM-final-e2e-d2d': 0.41119116567964503,
 'slither-AE-LSTM-final-e2e-sim': 0.40508294184889276,
 'slither-AE-LSTM-final-n': 0.3973316244720789}

In [32]:
# Persist the updated test losses; `with` guarantees the file is flushed and
# closed (the bare open() relied on garbage collection to do that).
with open("./final_figures/test_losses.pkl", 'wb') as results_file:
    pickle.dump(test_data, results_file)

In [33]:
# train_data = {}

# Evaluate each model in `load_ids` on the training episodes and store its mean
# per-frame next-frame BCE in `train_data`. Entries for ids that are not in
# `load_ids` keep the values loaded from disk.
#
# How the next frame is predicted, per model id:
#   * -bl               : decoder(encoder(frame_t)); no predictor
#   * -e2e, -n          : decoder(predictor output for code_t)
#   * -e2e-c2d, -e2e-d2d: decoder(code_t + predictor output)
#   * -e2e-sim, -e2e-aim: frame by frame; the previous predicted frame is
#                         combined with the current frame (difference / sum)
#                         before encoding
# The prediction made at step t is always scored against frame t + 1.
BASELINE_ID = 'slither-AE-LSTM-final-bl'
DIRECT_CODE_IDS = ('slither-AE-LSTM-final-e2e', 'slither-AE-LSTM-final-n')
RESIDUAL_CODE_IDS = ('slither-AE-LSTM-final-e2e-c2d', 'slither-AE-LSTM-final-e2e-d2d')
FEEDBACK_IDS = ('slither-AE-LSTM-final-e2e-sim', 'slither-AE-LSTM-final-e2e-aim')
KNOWN_IDS = (BASELINE_ID,) + DIRECT_CODE_IDS + RESIDUAL_CODE_IDS + FEEDBACK_IDS
TRAIN_DIR = './game_dataset/'

for load_id in load_ids:

    # Previously an unknown id fell through to `print('else???!!!!!')` and reset
    # `total_steps = 1` for every file, silently storing a meaningless value.
    if load_id not in KNOWN_IDS:
        raise ValueError('unknown model id: %r' % load_id)
    uses_predictor = load_id != BASELINE_ID

    # Copy the saved weights into the live modules. The .pkl files hold whole
    # pickled modules: tch.load can run arbitrary code, only load trusted files.
    targets = [(encoder, 'encoder.pkl'), (decoder, 'decoder.pkl')]
    if uses_predictor:
        targets.append((predictor, 'predictor.pkl'))
    for module, suffix in targets:
        params = list(module.parameters())
        params_load = list(tch.load(load_id + suffix).parameters())
        # zip() would silently drop unmatched parameters; fail loudly instead.
        if len(params) != len(params_load):
            raise ValueError('%s%s: expected %d parameters, found %d'
                             % (load_id, suffix, len(params), len(params_load)))
        for param, param_load in zip(params, params_load):
            if param.size() != param_load.size():
                raise ValueError('%s%s: saved shape %s does not match %s'
                                 % (load_id, suffix, tuple(param_load.size()), tuple(param.size())))
            param.data = param_load.data

    total_loss = 0
    total_steps = 0

    for filename in os.listdir(TRAIN_DIR):

        if uses_predictor:
            hidden = predictor.zero_hidden()

        # Each episode file is a pickled sequence of (gamescreen, _) pairs.
        # pickle.load can execute arbitrary code: only use trusted data.
        with open(TRAIN_DIR + filename, 'rb') as gamescreen_file:
            episode = pickle.load(gamescreen_file)

        if load_id in FEEDBACK_IDS:

            prev_pred_gamescreen = None
            for gamescreen, _ in episode:

                gamescreen = np2var(gamescreen).float() / 255
                gamescreen = gamescreen.unsqueeze(0).unsqueeze(0)

                if prev_pred_gamescreen is None:
                    # First frame: nothing to score yet; start from a blank prediction.
                    prev_pred_gamescreen = 0 * gamescreen
                else:
                    total_steps += 1
                    total_loss += bce_criterion(prev_pred_gamescreen, gamescreen).data[0]

                if load_id == 'slither-AE-LSTM-final-e2e-aim':
                    code = encoder(gamescreen + prev_pred_gamescreen)
                else:  # 'slither-AE-LSTM-final-e2e-sim'
                    code = encoder(prev_pred_gamescreen - gamescreen)
                hidden, pred_code = predictor(hidden, code)
                prev_pred_gamescreen = decoder(pred_code)

        else:

            gamescreen_list, _ = zip(*episode)
            gamescreens = np2var(np.stack(gamescreen_list)).float() / 255
            # Channel axis for the encoder's Conv2d(1, ...) (assumes 2-D grayscale frames).
            gamescreens = gamescreens.unsqueeze(1)

            codes = encoder(gamescreens)
            if load_id == BASELINE_ID:
                pred_codes = codes
            else:
                step_outputs = []
                for code in codes:
                    hidden, step_output = predictor(hidden, code.unsqueeze(0))
                    step_outputs.append(step_output)
                pred_codes = tch.cat(step_outputs, dim=0)
                if load_id in RESIDUAL_CODE_IDS:
                    # The predictor outputs a delta on top of the current code.
                    # NOTE(review): -e2e-d2d is scored exactly like -e2e-c2d (the
                    # old code also carried an unused `prev_pred_code`); confirm
                    # this matches how d2d was trained.
                    pred_codes = pred_codes + codes
            decoded_pred_gamescreens = decoder(pred_codes)

            episode_steps = len(gamescreen_list) - 1
            total_steps += episode_steps
            total_loss += bce_criterion(decoded_pred_gamescreens[:-1], gamescreens[1:]).data[0] * episode_steps

    if total_steps == 0:
        raise ValueError('no frames to score for %r in %s' % (load_id, TRAIN_DIR))
    train_data[load_id] = total_loss / total_steps

In [34]:
# Mean per-frame next-frame BCE on the training episodes, keyed by model id.
train_data

{'slither-AE-LSTM-final-bl': 0.3941200897363571,
 'slither-AE-LSTM-final-e2e': 0.3936181576691321,
 'slither-AE-LSTM-final-e2e-aim': 0.4017716719374208,
 'slither-AE-LSTM-final-e2e-c2d': 0.3931514325426246,
 'slither-AE-LSTM-final-e2e-d2d': 0.41267399215583467,
 'slither-AE-LSTM-final-e2e-sim': 0.4008163459408871,
 'slither-AE-LSTM-final-n': 0.3962164528108782}

In [35]:
# Persist the updated training losses; `with` guarantees the file is flushed
# and closed (the bare open() relied on garbage collection to do that).
with open("./final_figures/train_losses.pkl", 'wb') as results_file:
    pickle.dump(train_data, results_file)

In [36]:
# Mean squared error between predicted and actual codes, used for the code-error table (averaged over all elements by default).
criterion = tch.nn.MSELoss()

In [37]:
# code_error_data = {}

# For each model in `load_ids`, measure on the held-out episodes how far the
# predicted next code is from the code actually produced at the next step
# (mean per-step MSE via `criterion`) and store it in `code_error_data`.
# Entries for ids not in `load_ids` keep the values loaded from disk.
# The baseline (-bl) has no predictor, so its code error is recorded as 0.
BASELINE_ID = 'slither-AE-LSTM-final-bl'
DIRECT_CODE_IDS = ('slither-AE-LSTM-final-e2e', 'slither-AE-LSTM-final-n')
RESIDUAL_CODE_IDS = ('slither-AE-LSTM-final-e2e-c2d', 'slither-AE-LSTM-final-e2e-d2d')
FEEDBACK_IDS = ('slither-AE-LSTM-final-e2e-sim', 'slither-AE-LSTM-final-e2e-aim')
KNOWN_IDS = (BASELINE_ID,) + DIRECT_CODE_IDS + RESIDUAL_CODE_IDS + FEEDBACK_IDS
TEST_DIR = './game_dataset_test_data/'

for load_id in load_ids:

    # Previously an unknown id fell through to `print('else???!!!!!')` and reset
    # `total_steps = 1` for every file, silently storing a meaningless value.
    if load_id not in KNOWN_IDS:
        raise ValueError('unknown model id: %r' % load_id)
    uses_predictor = load_id != BASELINE_ID

    # Copy the saved weights into the live modules. The .pkl files hold whole
    # pickled modules: tch.load can run arbitrary code, only load trusted files.
    targets = [(encoder, 'encoder.pkl'), (decoder, 'decoder.pkl')]
    if uses_predictor:
        targets.append((predictor, 'predictor.pkl'))
    for module, suffix in targets:
        params = list(module.parameters())
        params_load = list(tch.load(load_id + suffix).parameters())
        # zip() would silently drop unmatched parameters; fail loudly instead.
        if len(params) != len(params_load):
            raise ValueError('%s%s: expected %d parameters, found %d'
                             % (load_id, suffix, len(params), len(params_load)))
        for param, param_load in zip(params, params_load):
            if param.size() != param_load.size():
                raise ValueError('%s%s: saved shape %s does not match %s'
                                 % (load_id, suffix, tuple(param_load.size()), tuple(param.size())))
            param.data = param_load.data

    total_loss = 0
    total_steps = 0

    for filename in os.listdir(TEST_DIR):

        if uses_predictor:
            hidden = predictor.zero_hidden()

        # Each episode file is a pickled sequence of (gamescreen, _) pairs.
        # pickle.load can execute arbitrary code: only use trusted data.
        with open(TEST_DIR + filename, 'rb') as gamescreen_file:
            episode = pickle.load(gamescreen_file)

        if load_id == BASELINE_ID:
            # No predictor -> zero code error, but count this episode's steps.
            # (Previously the episode load was commented out, so the step count
            # came from a stale or undefined `gamescreen_list`.)
            total_steps += len(episode) - 1

        elif load_id in FEEDBACK_IDS:

            prev_pred_gamescreen = None
            prev_pred_code = None
            for gamescreen, _ in episode:

                gamescreen = np2var(gamescreen).float() / 255
                gamescreen = gamescreen.unsqueeze(0).unsqueeze(0)

                if prev_pred_gamescreen is None:
                    # First frame: start from a blank prediction.
                    prev_pred_gamescreen = 0 * gamescreen

                if load_id == 'slither-AE-LSTM-final-e2e-aim':
                    code = encoder(gamescreen + prev_pred_gamescreen)
                else:  # 'slither-AE-LSTM-final-e2e-sim'
                    code = encoder(prev_pred_gamescreen - gamescreen)

                # Score the code predicted at the previous step against this step's code.
                if prev_pred_code is not None:
                    total_steps += 1
                    total_loss += criterion(prev_pred_code, code).data[0]

                hidden, pred_code = predictor(hidden, code)
                prev_pred_gamescreen = decoder(pred_code)
                prev_pred_code = pred_code

        else:

            gamescreen_list, _ = zip(*episode)
            gamescreens = np2var(np.stack(gamescreen_list)).float() / 255
            # Channel axis for the encoder's Conv2d(1, ...) (assumes 2-D grayscale frames).
            gamescreens = gamescreens.unsqueeze(1)

            codes = encoder(gamescreens)
            step_outputs = []
            for code in codes:
                hidden, step_output = predictor(hidden, code.unsqueeze(0))
                step_outputs.append(step_output)
            pred_codes = tch.cat(step_outputs, dim=0)
            if load_id in RESIDUAL_CODE_IDS:
                # The predictor outputs a delta on top of the current code.
                # NOTE(review): -e2e-d2d is scored exactly like -e2e-c2d (the old
                # code also carried an unused `prev_pred_code`); confirm this
                # matches how d2d was trained.
                pred_codes = pred_codes + codes

            # The code predicted at step t is scored against the code of frame t + 1.
            episode_steps = len(gamescreen_list) - 1
            total_steps += episode_steps
            total_loss += criterion(pred_codes[:-1], codes[1:]).data[0] * episode_steps

    if total_steps == 0:
        raise ValueError('no steps to score for %r in %s' % (load_id, TEST_DIR))
    code_error_data[load_id] = total_loss / total_steps

In [38]:
# Mean per-step MSE between predicted and actual next codes on the test episodes (0.0 for the baseline), keyed by model id.
code_error_data

{'slither-AE-LSTM-final-bl': 0.0,
 'slither-AE-LSTM-final-e2e': 2.3281322056942195,
 'slither-AE-LSTM-final-e2e-aim': 5.342667237906245,
 'slither-AE-LSTM-final-e2e-c2d': 1.0177571401219678,
 'slither-AE-LSTM-final-e2e-d2d': 1.9066890256437754,
 'slither-AE-LSTM-final-e2e-sim': 3.097211682314653,
 'slither-AE-LSTM-final-n': 0.11728683915653515}

In [39]:
# Persist the updated code errors; `with` guarantees the file is flushed and
# closed (the bare open() relied on garbage collection to do that).
with open("./final_figures/code_errors.pkl", 'wb') as results_file:
    pickle.dump(code_error_data, results_file)