In [None]:
cd ..

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

# NOTE(review): the star import presumably supplies ForLambdaDataset, lambda_ops
# and for_ops used below; consider importing those names explicitly.
from tree_to_sequence.program_datasets import *
from tree_to_sequence.tree_to_sequence import TreeToSequence
from tree_to_sequence.tree_to_sequence_attention import TreeToSequenceAttention
from tree_to_sequence.tree_encoder import TreeEncoder
from tree_to_sequence.sequence_encoder import SequenceEncoder
from tree_to_sequence.multilayer_lstm_cell import MultilayerLSTMCell
from neural_nets_library import training

In [None]:
# Token counts for variables and integers; both are added into the encoder
# input size and into the decoder class count when the model is built.
num_vars, num_ints = 10, 11

In [None]:
# Representation flags passed to the dataset and encoders.
# input_as_seq selects SequenceEncoder (True) vs TreeEncoder (False) below.
input_as_seq, output_as_seq = False, True
one_hot, binarize = False, True
eos_token = True
# 1 when an end-of-sequence token is enabled, else 0.
# NOTE(review): not referenced by any cell visible here.
eos_bonus = int(eos_token)
long_base_case = True

In [None]:
# Load the for-loop -> lambda program pairs. The path is relative to the
# current working directory; the flags come from the configuration cell above.
for_lambda_dset = ForLambdaDataset(
    "ANC/VeryHard-arbitraryForList.json",
    input_as_seq=input_as_seq,
    output_as_seq=output_as_seq,
    binarize=binarize,
    one_hot=one_hot,
    eos_token=eos_token,
    long_base_case=long_base_case,
)

In [None]:
def reset_all_parameters_uniform(model, stdev):
    """Re-initialise every parameter of ``model`` in place from U(-stdev, stdev).

    Args:
        model: any ``nn.Module``; all tensors yielded by ``model.parameters()``
            are overwritten.
        stdev: half-width of the uniform interval. Despite the name this is
            not the standard deviation of the resulting distribution (that
            would be ``stdev / sqrt(3)``).

    Returns:
        None. The model is modified in place.
    """
    low, high = -stdev, stdev
    for weight in model.parameters():
        nn.init.uniform_(weight, a=low, b=high)

In [None]:
# Model hyper-parameters.
embedding_size = 256
hidden_size = 256
# Decoder output classes: variables + integers + one class per key of
# `lambda_ops` (presumably the target-language operators).
# NOTE(review): `eos_bonus` is not added here; confirm the EOS token is already
# counted in one of these terms.
nclass = num_vars + num_ints + len(lambda_ops.keys())
num_layers = 1
attention = True
alignment_size = 50
# NOTE(review): meaning of this integer is defined by TreeToSequenceAttention — confirm.
align_type = 1
# Encoder input vocabulary: variables + integers + one entry per key of `for_ops`.
encoder_input_size = num_vars + num_ints + len(for_ops.keys())

# Encode the source program either as a flat token sequence or as a tree.
if input_as_seq:
    encoder = SequenceEncoder(encoder_input_size, hidden_size, num_layers, attention=attention, one_hot=one_hot)
else:
    # NOTE(review): the positional `[1, 2]` argument is opaque from here
    # (presumably the allowed numbers of children per node) — confirm in TreeEncoder.
    encoder = TreeEncoder(encoder_input_size, hidden_size, num_layers, [1, 2], attention=attention, one_hot=one_hot)
    
if attention:
    # With attention the decoder input width is embedding_size + hidden_size:
    # the token embedding plus a hidden_size-wide vector (presumably the
    # attention context) — confirm in TreeToSequenceAttention.
    decoder = MultilayerLSTMCell(embedding_size + hidden_size, hidden_size, num_layers)
    program_model = TreeToSequenceAttention(encoder, decoder, hidden_size, nclass, embedding_size, alignment_size=alignment_size, align_type=align_type)
else:
    decoder = MultilayerLSTMCell(embedding_size, hidden_size, num_layers)
    program_model = TreeToSequence(encoder, decoder, hidden_size, nclass, embedding_size)
    
# Uniform init of every parameter in [-0.1, 0.1], then override the LSTM
# forget-gate biases (presumably set to 3). The order matters: assuming
# encoder/decoder are registered submodules of program_model, doing the reset
# afterwards would overwrite the forget biases.
reset_all_parameters_uniform(program_model, 0.1)
encoder.initialize_forget_bias(3)
decoder.initialize_forget_bias(3)

In [None]:
# Move all model parameters to the default CUDA device. A GPU is required from
# here on; the training call below also passes use_cuda=True.
program_model = program_model.cuda()

In [None]:
def program_accuracy(prediction, target):
    """Exact-match score for one example: 1 if the whole program matches, else 0.

    The last element of ``target`` is dropped before comparing. With
    ``eos_token=True`` this is presumably the end-of-sequence token — confirm.
    The remaining elements are compared to ``prediction`` with list equality.
    """
    expected_tokens = list(target.data)[:-1]
    if expected_tokens == prediction:
        return 1
    return 0

def token_accuracy(prediction, target):
    """Per-token accuracy for one example, as a float in [0, 1].

    Mirrors ``program_accuracy``: the last element of ``target`` is dropped
    (presumably the end-of-sequence token — confirm). The remaining elements are
    compared position by position with ``prediction``.

    The score is the number of matching positions divided by the longer of the
    two sequences, so missing and extra predicted tokens both count as errors.
    As a result, the score is 1.0 exactly when ``program_accuracy`` returns 1.
    """
    expected = list(target.data)[:-1]
    length = max(len(expected), len(prediction))
    if length == 0:
        # Nothing to predict and nothing predicted: a perfect (empty) match.
        return 1.0
    matches = sum(1 for want, got in zip(expected, prediction) if want == got)
    return matches / length

# Adam over every model parameter with an initial learning rate of 0.005.
# ReduceLROnPlateau multiplies the learning rate by `factor` (0.8) after more
# than `patience` (100) consecutive `step(metric)` calls without improvement.
# `verbose=True` logs each reduction; newer PyTorch releases may warn that
# `verbose` is deprecated.
optimizer = optim.Adam(params=program_model.parameters(), lr=0.005)
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    factor=0.8,
    patience=100,
    verbose=True,
)

In [None]:
# Number of decoder output classes; shown via the cell's rich output.
nclass

In [None]:
# Peek at the first (input, target) pair to sanity-check the target encoding.
# A named variable is used instead of `_`: assigning `_` in the user namespace
# shadows IPython's last-output cache. An empty dataset now raises
# StopIteration instead of silently printing nothing.
input_prog, target_prog = next(iter(for_lambda_dset))
target_prog[-1]

In [None]:
# Train with num_epochs=3, batch_size=100 on the GPU, scoring validation with
# `program_accuracy`.
# print_every is far above any plausible step count, presumably to silence
# periodic progress output. plateau_lr=True presumably tells the loop to step
# the ReduceLROnPlateau scheduler with a monitored metric — confirm in
# neural_nets_library.training.
model, train_loss, validation_loss = training.train_model_anc(
    program_model,
    for_lambda_dset,
    optimizer,
    num_epochs=3,
    batch_size=100,
    use_cuda=True,
    lr_scheduler=lr_scheduler,
    plateau_lr=True,
    validation_criterion=program_accuracy,
    print_every=99999999,
    plot_every=100,
)

In [None]:
import csv


def _write_single_column_csv(path, values):
    """Write each element of ``values`` as its own single-column CSV row."""
    with open(path, "w") as output:
        writer = csv.writer(output, lineterminator='\n')
        writer.writerows([value] for value in values)


# torch.save on a whole nn.Module pickles the object, so loading it later
# requires the same `tree_to_sequence` classes to be importable.
# NOTE(review): saving program_model.state_dict() would be more portable, but
# it changes the artifact format that downstream loaders may expect.
torch.save(program_model, "t2s-normal-vars-model")

# The loss histories are the values returned by the training cell
# (`train_loss`, `validation_loss`). The previously referenced
# `train_plot_losses` / `validation_plot_losses` are not defined by any cell
# here, so that version failed on Restart & Run All.
_write_single_column_csv("t2s-normal-vars-train.txt", train_loss)
_write_single_column_csv("t2s-normal-vars-validation.txt", validation_loss)