In [None]:
cd ..

In [None]:
import torch.optim as optim
import torch.nn as nn
import torch
import math
import matplotlib.pyplot as plt

from neural_nets_library import training
from tree_to_sequence.tree_encoder import TreeEncoder
from tree_to_sequence.tree_decoder import TreeDecoder
from tree_to_sequence.program_datasets import IdentityTreeToTreeDataset
from tree_to_sequence.program_datasets import ForLambdaDataset
from tree_to_sequence.program_datasets import Const0

from tree_to_sequence.tree_to_tree import TreeToTree

In [None]:
# Vocabulary-size constants; they are passed to the dataset and added into the
# encoder/decoder sizes further down (presumably the number of distinct
# variable names and integer literals -- confirm against the dataset code).
num_vars = 10
num_ints = 11

# Operator tokens of the "For" language. Each token is spelled "<NAME>" in upper
# case and its id is the position of the name in the list below (0..10).
for_ops = {
    "<" + op_name.upper() + ">": op_id
    for op_id, op_name in enumerate([
        "Var",
        "Const",
        "Plus",
        "Minus",
        "EqualFor",
        "LeFor",
        "GeFor",
        "Assign",
        "If",
        "Seq",
        "For",
    ])
}

# Token -> id table for the lambda-calculus language. The id of a token is its
# position in the list below (0..31), so the order of this list is significant.
lambda_calculus_ops = {
    token: token_id
    for token_id, token in enumerate([
        "<VARIABLE>",
        "<ABSTRACTION>",
        "<NUMBER>",
        "<BOOLEAN>",
        "<NIL>",
        "<IF>",
        "<CONS>",
        "<MATCH>",
        "<UNARYOPER>",
        "<BINARYOPER>",
        "<LET>",
        "<LETREC>",
        "<TRUE>",
        "<FALSE>",
        "<TINT>",
        "<TBOOL>",
        "<TINTLIST>",
        "<TFUN>",
        "<ARGUMENT>",
        "<NEG>",
        "<NOT>",
        "<PLUS>",
        "<MINUS>",
        "<TIMES>",
        "<DIVIDE>",
        "<AND>",
        "<OR>",
        "<EQUAL>",
        "<LESS>",
        "<APPLICATION>",
        "<HEAD>",
        "<TAIL>",
    ])
}

In [None]:
# Feature flags shared by the dataset and model cells below.
use_embedding = True  # forwarded to both the dataset and the TreeEncoder
binarize = True  # forwarded to the dataset (the old "make true" TODO is already satisfied)
input_eos_token = True  # TODO: maybe False. Only referenced by the commented-out ForLambdaDataset call.
long_base_case = True  # only referenced by the commented-out ForLambdaDataset call

In [None]:
# Build the dataset that the training cell consumes; everything stays on the CPU
# (cuda=False). The printed number is the dataset length.
identity_dset = IdentityTreeToTreeDataset(
    "ANC/Easy-arbitraryForListWithOutput.json",
    binarize=binarize,
    is_lambda_calculus=False,
    num_ints=num_ints,
    num_vars=num_vars,
    use_embedding=use_embedding,
    cuda=False,
)
print(len(identity_dset))

# Disabled alternative: the For -> lambda-calculus dataset. It is the only user of
# `input_eos_token` and `long_base_case`.
# for_lambda_dset = ForLambdaDataset("ANC/Easy-arbitraryForList.json", binarize=binarize,
#                                    input_eos_token=input_eos_token, use_embedding=use_embedding,
#                                    long_base_case=long_base_case, input_as_seq=False)
# print(len(for_lambda_dset))

In [None]:
def reset_all_parameters_uniform(model, stdev):
    """Re-initialise every parameter of ``model`` in place from U(-stdev, stdev).

    Args:
        model: object exposing ``parameters()`` (e.g. an ``nn.Module``); each
            yielded tensor is overwritten in place.
        stdev: half-width of the sampling interval. Despite the name, this is
            the bound of the uniform distribution, not its standard deviation.

    Returns:
        None. The parameters are modified in place.
    """
    for param in model.parameters():
        # `nn.init.uniform` is the deprecated alias; `uniform_` is the supported
        # in-place initialiser with the same (tensor, a, b) signature.
        nn.init.uniform_(param, -stdev, stdev)

In [None]:
# --- Hyperparameters ---------------------------------------------------------
embedding_size = 256
hidden_size = 256
num_layers = 1
alignment_size = 50
align_type = 1
max_num_children = 2
plot_every = 100  # also used as the x-axis spacing when the loss curves are plotted

# --- Vocabulary sizes --------------------------------------------------------
# The decoder predicts over variables + integers + operators; the encoder input
# has one extra slot on top of that.
nclass = num_vars + num_ints + len(for_ops)  # TODO: plus 1?
encoder_input_size = nclass + 1  # TODO: plus 1?

# --- Model -------------------------------------------------------------------
encoder = TreeEncoder(
    encoder_input_size,
    hidden_size,
    num_layers,
    [1, 2],  # presumably the child counts the encoder must handle -- confirm in TreeEncoder
    attention=True,
    use_embedding=use_embedding,
    dropout=0.5,
)
decoder = TreeDecoder(
    embedding_size,
    hidden_size,
    max_num_children,
    nclass,
    align_type=align_type,
)
program_model = TreeToTree(
    encoder,
    decoder,
    hidden_size,
    embedding_size,
    alignment_size=alignment_size,
    align_type=align_type,
)

# --- Initialisation ----------------------------------------------------------
# Keep this order: the uniform reset rewrites every parameter of the model, so
# the forget-gate bias initialisation has to run after it or it would be undone
# (assuming encoder/decoder are registered sub-modules of program_model).
reset_all_parameters_uniform(program_model, 0.1)
encoder.initialize_forget_bias(3)
decoder.initialize_forget_bias(3)

In [None]:
#program_model = program_model.cuda()

In [None]:
# Adam over every parameter of the tree-to-tree model.
optimizer = optim.Adam(
    program_model.parameters(),
    lr=0.005,
)

# Plateau scheduler: after `patience` non-improving steps the learning rate is
# multiplied by `factor`. With factor=0.9999 and patience=1000 the learning rate
# is effectively left constant.
# NOTE(review): recent PyTorch releases deprecate the `verbose` argument --
# confirm against the installed version.
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    factor=0.9999,
    patience=1000,
    verbose=True,
)

In [None]:
def count_matches(prediction, target):
    """Score how well the ``prediction`` tree lines up with the ``target`` tree.

    A node earns one point when its integer value equals the target's. Children
    are compared pairwise over the first ``k = min(#pred, #target)`` children,
    with prediction child ``i`` matched against target child ``k - 1 - i``
    (reversed order). Each node is then penalised by the absolute difference in
    child counts.

    Args:
        prediction: node with ``.value`` and an indexable ``.children``.
        target: node with the same attributes.

    Returns:
        float: matched nodes minus accumulated child-count penalties (may be
        negative).
    """
    pred_kids = prediction.children
    target_kids = target.children
    shared = min(len(target_kids), len(pred_kids))

    own_point = 1 if int(prediction.value) == int(target.value) else 0
    child_points = sum(
        count_matches(pred_kids[pos], target_kids[shared - 1 - pos])
        for pos in range(shared)
    )
    # math.fabs always yields a float, so the overall result is a float as well.
    penalty = math.fabs(len(target_kids) - len(pred_kids))
    return own_point + child_points - penalty

def partial_match_criterion(prediction, target):
    """Partial-credit score: ``count_matches(prediction, target) / target.size()``.

    This function used to be a first definition of ``validation_criterion`` that
    was silently overwritten by the exact-match definition further down, so it
    could never run. It now has its own name; to train with it, pass
    ``validation_criterion=partial_match_criterion`` to the training call.

    Args:
        prediction: predicted tree node (``.value`` / ``.children``).
        target: ground-truth tree node; must also provide ``size()``.

    Returns:
        float: the match count divided by ``target.size()``.
    """
    return 1.0 * count_matches(prediction, target) / target.size()

def all_matches(prediction, target):
    """Return 1 if ``prediction`` reproduces ``target`` exactly, else 0.

    Two nodes agree when their integer values are equal, they have the same
    number of children, and every prediction child ``i`` agrees with target
    child ``n - 1 - i``.

    Args:
        prediction: predicted tree node (``.value`` / ``.children``).
        target: ground-truth tree node with the same attributes.

    Returns:
        int: 1 for an exact match, 0 otherwise.
    """
    if int(prediction.value) != int(target.value):
        return 0
    if len(prediction.children) != len(target.children):
        return 0
    n = len(prediction.children)
    for i in range(n):
        # Compare in reverse order b/c we generate new nodes right-to-left
        if all_matches(prediction.children[i], target.children[n - i - 1]) == 0:
            return 0
    return 1

def validation_criterion(prediction, target):
    """Criterion handed to the training call: exact tree match.

    Returns:
        int: ``all_matches(prediction, target)``, i.e. 1 or 0.
    """
    return all_matches(prediction, target)

In [None]:
# %%pixie_debugger
# (Debugging aid, disabled. A cell magic only works as the very first line of a cell.)

# Train on the CPU and keep the returned model plus the two curves that the
# plotting cell draws.
best_model, train_plot_losses, validation_plot_losses = training.train_model_tree_to_tree(
    program_model,
    identity_dset,
    optimizer,
    lr_scheduler=lr_scheduler,
    num_epochs=1,  # TODO: originally epochs=3
    plot_every=plot_every,
    batch_size=100,
    plateau_lr=True,
    print_every=200,
    validation_criterion=validation_criterion,
    use_cuda=False,
)





In [None]:
#%debug
# (Post-mortem debugger magic, disabled.)

def _show_curve(values):
    """Draw ``values`` in a figure of its own.

    Point ``k`` is placed at x = ``k * plot_every``, i.e. the curve has one
    sample every ``plot_every`` units of the x axis.
    """
    x_positions = [idx * plot_every for idx in range(len(values))]
    plt.plot(x_positions, values)
    plt.show()

_show_curve(train_plot_losses)

_show_curve(validation_plot_losses)

In [None]:
# import pixiedust