In [None]:
# Run from the parent (repository root) directory so the `tree_to_sequence` /
# `neural_nets_library` imports and the relative "ANC/..." data path below resolve.
# Guarded so that re-running this cell does not climb one more directory each time
# (a bare `cd ..` is not idempotent and also relies on IPython automagic).
# Assumes the repo root is the directory that contains the `tree_to_sequence` package.
import os

if not os.path.isdir("tree_to_sequence"):
    os.chdir("..")
print(os.getcwd())

In [None]:
import torch.optim as optim
import torch.nn as nn
import torch
import math
import matplotlib.pyplot as plt

from neural_nets_library import training
from tree_to_sequence.tree_encoder import TreeEncoder
from tree_to_sequence.tree_decoder import TreeDecoder
from tree_to_sequence.program_datasets import IdentityTreeToTreeDataset
from tree_to_sequence.program_datasets import ForLambdaDataset
from tree_to_sequence.program_datasets import Const0
from tree_to_sequence.translating_trees import pretty_print_tree


from tree_to_sequence.tree_to_tree import TreeToTree

In [None]:
# Vocabulary sizes for variables and integer constants (they feed encoder_input_size).
num_vars = 10
num_ints = 11

# For-language operator tokens: "<" + NAME.upper() + ">" -> its position in this list.
# The list order defines the ids, so only append new operators at the end.
for_ops = {"<" + name.upper() + ">": index for index, name in enumerate([
    "Var", "Const", "Plus", "Minus",
    "EqualFor", "LeFor", "GeFor", "Assign",
    "If", "Seq", "For",
])}

# Lambda-calculus tokens -> consecutive ids (list order defines the ids).
lambda_calculus_ops = {token: index for index, token in enumerate([
    "<VARIABLE>", "<ABSTRACTION>", "<NUMBER>", "<BOOLEAN>",
    "<NIL>", "<IF>", "<CONS>", "<MATCH>",
    "<UNARYOPER>", "<BINARYOPER>", "<LET>", "<LETREC>",
    "<TRUE>", "<FALSE>", "<TINT>", "<TBOOL>",
    "<TINTLIST>", "<TFUN>", "<ARGUMENT>", "<NEG>",
    "<NOT>", "<PLUS>", "<MINUS>", "<TIMES>",
    "<DIVIDE>", "<AND>", "<OR>", "<EQUAL>",
    "<LESS>", "<APPLICATION>", "<HEAD>", "<TAIL>",
])}

# Lambda-language tokens -> consecutive ids (list order defines the ids).
lambda_ops = {token: index for index, token in enumerate([
    "<VAR>", "<CONST>", "<PLUS>", "<MINUS>",
    "<EQUAL>", "<LE>", "<GE>", "<IF>",
    "<LET>", "<UNIT>", "<LETREC>", "<APP>",
])}

In [None]:
# Feature switches shared by the dataset and the model cells below (all enabled).
use_embedding = binarize = eos_tokens = long_base_case = True

In [None]:
# Alternative identity-task dataset (currently disabled):
# identity_dset = IdentityTreeToTreeDataset("ANC/Easy-arbitraryForListWithOutput.json",
#                                       binarize=binarize, is_lambda_calculus=False,
#                                       num_ints=num_ints, num_vars=num_vars,
#                                       use_embedding=use_embedding, cuda=False)
# print(len(identity_dset))

# Tree-in / tree-out pairs (neither side flattened to a sequence).
for_lambda_dset = ForLambdaDataset(
    "ANC/Easy-arbitraryForList.json",
    binarize=binarize,
    eos_tokens=eos_tokens,
    use_embedding=use_embedding,
    long_base_case=long_base_case,
    input_as_seq=False,
    output_as_seq=False,
)
print(len(for_lambda_dset))

# Largest target-tree size in the dataset; passed to TreeToTree as max_size so
# the limit is guaranteed to be large enough for every example.
max_size = max(pair[1].size() for pair in for_lambda_dset)
print(max_size)




In [None]:
def reset_all_parameters_uniform(model, stdev):
    """Re-initialise every parameter of ``model`` in place from U(-stdev, stdev).

    Note: despite its name, ``stdev`` is the half-width of the uniform range,
    not its standard deviation (that would be ``stdev / sqrt(3)``).

    Args:
        model: object exposing ``parameters()`` (e.g. an ``nn.Module``); each
            parameter tensor is overwritten in place.
        stdev: bound of the symmetric uniform distribution.
    """
    for param in model.parameters():
        # uniform_ is the in-place initialiser; the un-suffixed nn.init.uniform
        # is a deprecated alias.
        nn.init.uniform_(param, -stdev, stdev)

In [None]:
# --- Hyperparameters --------------------------------------------------------
embedding_size = 100  # alternative tried: 256
hidden_size = 100  # alternative tried: 256
num_layers = 1
alignment_size = 50
align_type = 1
max_num_children = 2
plot_every = 100
# Output vocabulary size, hard-coded. The formula num_vars + num_ints + len(lambda_ops)
# would give 33 here. NOTE(review): confirm what the 2 extra classes are.
nclass = 35

# Encoder vocabulary: variables + integer constants + For-language ops, plus one extra token.
encoder_input_size = num_vars + num_ints + len(for_ops) + 1

# --- Model ------------------------------------------------------------------
# Construction order (encoder, decoder, wrapper) is kept so parameter
# initialisation consumes the RNG in the same sequence.
encoder = TreeEncoder(
    encoder_input_size, hidden_size, num_layers, [1, 2],
    attention=True, use_embedding=use_embedding, dropout=0.5,
)
decoder = TreeDecoder(
    embedding_size, hidden_size, max_num_children,
    nclass=nclass, binarized=binarize,
)
program_model = TreeToTree(
    encoder, decoder, hidden_size, embedding_size,
    nclass=nclass, alignment_size=alignment_size,
    align_type=align_type, max_size=max_size,
)

# Re-draw every weight uniformly in [-0.1, 0.1], then set the forget biases.
# The bias init stays after the uniform reset so it is not overwritten
# (assuming encoder/decoder are submodules of program_model).
reset_all_parameters_uniform(program_model, 0.1)
encoder.initialize_forget_bias(3)
decoder.initialize_forget_bias(3)


In [None]:
#program_model = program_model.cuda()

In [None]:
# Adam, with the learning rate multiplied by 0.8 after 500 scheduler steps
# without improvement. (Gentler alternative tried: patience=1000, factor=0.9999.)
optimizer = optim.Adam(program_model.parameters(), lr=0.01)
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    factor=0.8,
    patience=500,
    verbose=True,
)

In [None]:
# Token accuracy of the program
def count_matches(prediction, target):
    """Token-level score: how many node values of ``prediction`` agree with ``target``.

    The two trees are walked in parallel. Each position where both trees have a
    node contributes 1 when the values are equal (compared as ``int``). At every
    node, the absolute difference between the two child counts is subtracted as
    a penalty for missing/extra children, so the score can be negative.

    Args:
        prediction: predicted tree node with ``.value`` and ``.children``.
        target: reference tree node with ``.value`` and ``.children``.

    Returns:
        int: matching nodes minus the accumulated child-count differences.
    """
    matches = int(int(prediction.value) == int(target.value))
    # zip stops at the shorter child list, i.e. only positions present in both trees.
    for pred_child, target_child in zip(prediction.children, target.children):
        matches += count_matches(pred_child, target_child)
    # abs() on ints keeps the score an int (math.fabs would promote it to float).
    size_diff = abs(len(target.children) - len(prediction.children))
    return matches - size_diff


# Program accuracy (1 if completely correct, 0 otherwise)
def all_matches(prediction, target):
    """Program-level accuracy: 1 if the two trees are identical, otherwise 0.

    Trees are identical when every pair of corresponding nodes has the same
    value (compared as ``int``) and the same number of children.
    """
    if int(prediction.value) != int(target.value):
        return 0
    if len(prediction.children) != len(target.children):
        return 0
    # all() stops at the first mismatching child, like an early return.
    return int(all(
        all_matches(pred_child, target_child)
        for pred_child, target_child in zip(prediction.children, target.children)
    ))

# Calc validation accuracy (this could either be program or token accuracy)
def validation_criterion(prediction, target):
    """Validation score for one example: whole-program accuracy (1 or 0).

    Swap in ``count_matches`` here to validate on token accuracy instead.
    """
    program_correct = all_matches(prediction, target)
    return program_correct

In [None]:
# To inspect the current learning rate instead:
# for param_group in optimizer.param_groups:
#     print(param_group["lr"])

# Sanity check: output vocabulary size, then the decoder (via the wrapper and directly).
print(nclass, program_model.decoder, decoder, sep="\n")

In [None]:
# %%pixie_debugger

# Train; the result is unpacked as the best model plus the recorded
# training / validation curves used by the plotting cell.
# TODO: originally trained for num_epochs=3.
best_model, train_plot_losses, validation_plot_losses = training.train_model_tree_to_tree(
    program_model,
    for_lambda_dset,
    optimizer,
    lr_scheduler=lr_scheduler,
    plateau_lr=True,
    num_epochs=1,
    batch_size=100,
    plot_every=plot_every,
    print_every=200,
    validation_criterion=validation_criterion,
    use_cuda=False,
)





In [None]:
def plot_recorded_curve(values, every, title, ylabel):
    """Plot a curve that was recorded once every ``every`` training steps.

    Args:
        values: sequence of recorded values; entry k is plotted at x = k * every.
        every: spacing between recordings, in steps.
        title: figure title.
        ylabel: y-axis label.
    """
    _, ax = plt.subplots()
    ax.plot([k * every for k in range(len(values))], values)
    ax.set(title=title, xlabel="step", ylabel=ylabel)
    plt.show()


plot_recorded_curve(train_plot_losses, plot_every, "Training loss", "loss")
# NOTE(review): validation_criterion returns program accuracy (1/0), so this
# curve may be an accuracy rather than a loss. Confirm against
# train_model_tree_to_tree.
plot_recorded_curve(validation_plot_losses, plot_every, "Validation", "validation value")

In [None]:
# Accumulate the loss curves across repeated training runs in this kernel session.
# On a fresh kernel the *_old accumulators do not exist yet, so start them empty
# instead of raising NameError.
# NOTE: every execution appends again, so run this cell exactly once per training run.
train_plot_losses_old = globals().get("train_plot_losses_old", []) + train_plot_losses
validation_plot_losses_old = globals().get("validation_plot_losses_old", []) + validation_plot_losses

In [None]:
# from tree_to_sequence.translating_trees import pretty_print_attention_t2t

def pseudo_forward_train(input_tree, target_tree, teacher_forcing=True):
    """
    Verbose, step-by-step decoding pass used to debug the model's training loss.

    Uses the module-level ``program_model`` (not an argument). Encodes
    ``input_tree``, then walks ``target_tree`` with an explicit stack (last in,
    first out). For each visited node it prints the target value, the attention
    probabilities and the output log-odds, and adds
    ``program_model.loss_func(log_odds, node.value)`` to the running loss.
    Presumably mirrors TreeToTree's own training forward pass; not verified here.

    Args:
        input_tree: tree fed to ``program_model.encoder``.
        target_tree: desired output tree; its node values are the loss targets.
        teacher_forcing: if True, a node's children are generated from the node's
            target value; otherwise from the model's top-1 prediction.

    Returns:
        The loss summed over every visited target node. Children of nodes whose
        value equals ``program_model.EOS_value`` are not visited.

    Side effects: prints per-node diagnostics and increments ``program_model.i``.
    """
    # Encode the input tree: annotations plus the initial decoder hidden/cell
    # states (as named by this unpacking).
    annotations, decoder_hiddens, decoder_cell_states = program_model.encoder(input_tree)

    # align_type 0/1 pre-project the annotations for attention; other types use them raw.
    if program_model.align_type <= 1:
        attention_hidden_values = program_model.attention_hidden(annotations)
    else:
        attention_hidden_values = annotations

    # Running sum of per-node losses.
    loss = 0

    # Stack of nodes still to decode.
    # Tuple: (hidden_state, cell_state, desired_output, parent_value, child_index)
    unexpanded = [(decoder_hiddens, decoder_cell_states, target_tree, None, 0)]

    # Any line relating to this list was put in just for debugging purposes
    all_attention_probs = [] 

    # while stack isn't empty:
    while (len(unexpanded)) > 0:
        # Pop the most recently pushed node (children are pushed in order, so the
        # last child is decoded first).
        decoder_hiddens, decoder_cell_states, targetNode, parent_val, child_index = unexpanded.pop()
        print("GENERATING", targetNode.value)
        # Use attention and past hidden state to generate scores
        attention_logits = program_model.attention_logits(attention_hidden_values, decoder_hiddens)
        attention_probs = program_model.softmax(attention_logits) # number_of_nodes x 1
        all_attention_probs.append(attention_probs) #TODO - take out!
        print("ATTENTION PROBS", attention_probs)
        # Attention-weighted sum of annotations over the node dimension.
        context_vec = (attention_probs * annotations).sum(0).unsqueeze(0) # 1 x hidden_size
        et = program_model.tanh(program_model.attention_presoftmax(torch.cat((decoder_hiddens, context_vec), dim=1))) # 1 x hidden_size
        # Calculate loss against this node's target value
        
        log_odds = program_model.output_log_odds(et)
        print("LOG ODDS", log_odds)
        loss = loss + program_model.loss_func(log_odds, targetNode.value)
        
        
        # If we have an EOS, there are no children to generate
        if int(targetNode.value) == program_model.EOS_value:
            continue

        # Teacher forcing means we use the correct value (not the predicted value) of a node to generate its children
        if teacher_forcing:
            next_input = targetNode.value
        else:
            # NOTE(review): topk(1) returns indices with an extra trailing dim
            # compared to a 1-element target tensor. If targetNode.value has
            # shape (1,), the embedding below gains a dim and the torch.cat may
            # fail. TODO confirm, or squeeze/view the indices.
            _, next_input = log_odds.topk(1)

        # Decoder input = embedding of this node's value concatenated with et.
        decoder_input = torch.cat((program_model.embedding(next_input), et), 1)

        for i, child in enumerate(targetNode.children):
            # Parent node of a node's children is that node
            parent = next_input
            new_child_index = i
            # ... unless you're the right child. In binarized trees the right
            # child inherits this node's parent and the next child index,
            # consistent with a left-child/right-sibling encoding where the
            # right child is the next sibling.
            if program_model.binarized and i == 1:
                parent = parent_val
                new_child_index = child_index + 1

            # Get hidden state and cell state which will be used to generate this node's children
            child_hiddens, child_cell_states = program_model.decoder(new_child_index, decoder_input, decoder_hiddens, decoder_cell_states)
            unexpanded.append((child_hiddens, child_cell_states, child, parent, new_child_index))

    # Uncomment if you want to see where the attention is focusing as each node is generated
#         if program_model.i % 200 == 0:
#             pretty_print_attention_t2t(all_attention_probs, input_tree, target_tree)
    # Counter stored on the model (used by the attention printout above).
    program_model.i += 1
    return loss

In [None]:
# Spot-check the first five examples. For each wrong prediction, show the expected
# and predicted trees, then re-run the target through the verbose,
# teacher-forced pseudo_forward_train.
for prog in for_lambda_dset[:5]:
    prediction = program_model.forward_prediction(prog[0])
    if all_matches(prediction, prog[1]):
        continue
    print("expected")
    pretty_print_tree(prog[1])
    print("got")
    pretty_print_tree(prediction)
    print("stepping through...")
    pseudo_forward_train(prog[0], prog[1])

In [None]:
# print(program_model.parameters)
# Show the decoder's bound forward method as this cell's output.
getattr(program_model.decoder, "forward")