In [None]:
# Move the kernel's working directory up one level (IPython automagic form of `%cd ..`).
# NOTE(review): presumably needed so the relative "ANC/..." paths and the project
# imports below resolve -- confirm. Re-running this cell moves up one more level
# each time, so it is not idempotent.
cd ..

In [None]:
import torch.optim as optim
import torch.nn as nn
import torch
import math
import matplotlib.pyplot as plt

from neural_nets_library import training
from tree_to_sequence.tree_encoder import TreeEncoder
from tree_to_sequence.tree_decoder import TreeDecoder
from tree_to_sequence.grammar_tree_decoder import GrammarTreeDecoder
from tree_to_sequence.program_datasets import IdentityTreeToTreeDataset
from tree_to_sequence.program_datasets import ForLambdaDataset
from tree_to_sequence.program_datasets import Const0
from tree_to_sequence.translating_trees import pretty_print_tree
from tree_to_sequence.translating_trees import parent_to_category_FOR
from tree_to_sequence.translating_trees import category_to_child_FOR
from tree_to_sequence.translating_trees import category_to_child_LAMBDA
from tree_to_sequence.translating_trees import parent_to_category_LAMBDA
from tree_to_sequence.translating_trees import ForGrammar
from tree_to_sequence.translating_trees import For
from tree_to_sequence.translating_trees import LambdaGrammar
from tree_to_sequence.translating_trees import Lambda



from tree_to_sequence.tree_to_tree import TreeToTree

In [None]:
# Vocabulary sizes; both are used below to size the encoder/decoder vocabularies
# and are passed to the grammar helpers.
num_vars = 10
num_ints = 11

# Operator tokens of the source "For" language, keyed as "<UPPERCASE>" and
# numbered in the order they are listed here.
for_ops = {
    "<{}>".format(op_name.upper()): op_index
    for op_index, op_name in enumerate([
        "Var",
        "Const",
        "Plus",
        "Minus",
        "EqualFor",
        "LeFor",
        "GeFor",
        "Assign",
        "If",
        "Seq",
        "For",
    ])
}

# Token table for the full lambda-calculus vocabulary; ids follow list order.
# NOTE(review): this table is not referenced by any other code visible in this
# notebook -- confirm whether it is still needed.
lambda_calculus_ops = {
    token: token_id
    for token_id, token in enumerate([
        "<VARIABLE>",
        "<ABSTRACTION>",
        "<NUMBER>",
        "<BOOLEAN>",
        "<NIL>",
        "<IF>",
        "<CONS>",
        "<MATCH>",
        "<UNARYOPER>",
        "<BINARYOPER>",
        "<LET>",
        "<LETREC>",
        "<TRUE>",
        "<FALSE>",
        "<TINT>",
        "<TBOOL>",
        "<TINTLIST>",
        "<TFUN>",
        "<ARGUMENT>",
        "<NEG>",
        "<NOT>",
        "<PLUS>",
        "<MINUS>",
        "<TIMES>",
        "<DIVIDE>",
        "<AND>",
        "<OR>",
        "<EQUAL>",
        "<LESS>",
        "<APPLICATION>",
        "<HEAD>",
        "<TAIL>",
    ])
}

# Operator tokens of the target lambda language; ids follow list order.
# Only its length is read elsewhere in this notebook (to size the output vocabulary).
lambda_ops = {
    token: token_id
    for token_id, token in enumerate([
        "<VAR>",
        "<CONST>",
        "<PLUS>",
        "<MINUS>",
        "<EQUAL>",
        "<LE>",
        "<GE>",
        "<IF>",
        "<LET>",
        "<UNIT>",
        "<LETREC>",
        "<APP>",
    ])
}

In [None]:
# Switches shared by the dataset and the model. All four are forwarded as keyword
# arguments to ForLambdaDataset; use_embedding is also given to TreeEncoder,
# binarize to GrammarTreeDecoder (as `binarized`), and long_base_case to make_tree.
use_embedding = True
binarize = True
eos_tokens = True
long_base_case = True

In [None]:
# (Disabled) identity tree-to-tree dataset, kept for reference:
# identity_dset = IdentityTreeToTreeDataset("ANC/Easy-arbitraryForListWithOutput.json",
#                                       binarize=binarize, is_lambda_calculus=False,
#                                       num_ints=num_ints, num_vars=num_vars,
#                                       use_embedding=use_embedding, cuda=False)
# print(len(identity_dset))

# Load the For/Lambda program pairs with neither side flattened to a sequence.
for_lambda_dset = ForLambdaDataset(
    "ANC/VeryHard-arbitraryForList.json",
    binarize=binarize,
    eos_tokens=eos_tokens,
    use_embedding=use_embedding,
    long_base_case=long_base_case,
    input_as_seq=False,
    output_as_seq=False,
)
print(len(for_lambda_dset))

# Largest size() over the second element of every pair; passed to TreeToTree as
# max_size so the limit is known to cover every example in this dataset.
max_size = max(pair[1].size() for pair in for_lambda_dset)
print(max_size)

In [None]:
def reset_all_parameters_uniform(model, stdev):
    """Re-initialise every parameter of ``model`` in place from U(-stdev, stdev).

    Args:
        model: Object exposing ``parameters()`` (e.g. an ``nn.Module``).
        stdev: Bound of the sampling interval. Despite its name it is used as
            the half-width of the uniform range, not as a standard deviation.
    """
    # ``nn.init.uniform`` is the deprecated spelling of ``nn.init.uniform_``.
    # Use the in-place name when this PyTorch provides it and fall back to the
    # old alias otherwise, so older installations keep working.
    init_uniform = getattr(nn.init, "uniform_", None) or nn.init.uniform
    for param in model.parameters():
        init_uniform(param, -stdev, stdev)

In [None]:
# Hyper-parameters and model assembly: a TreeEncoder and a grammar-constrained
# tree decoder are wrapped in a TreeToTree model, then re-initialised.
embedding_size = 256 #... 256 is from the paper, but 100 is WAY faster
hidden_size = 256
num_layers = 1
alignment_size = 50
align_type = 1
# Input vocabulary: variables + integers + For-language operators, plus one extra slot.
# NOTE(review): the purpose of the "+ 1" is not visible here (possibly an EOS token) -- confirm.
encoder_input_size = num_vars + num_ints + len(for_ops) + 1
# NOTE(review): the meaning of the [1, 2] argument is defined by TreeEncoder, not shown here.
encoder = TreeEncoder(encoder_input_size, hidden_size, num_layers, [1, 2], attention=True, use_embedding=use_embedding, dropout=0.5)
# Output vocabulary: variables + integers + lambda-language operators.
nclass = num_vars + num_ints + len(lambda_ops)
# Also reused later as the x-axis spacing when the loss curves are plotted.
plot_every = 100
max_num_children = 2

# decoder = TreeDecoder(embedding_size, hidden_size, max_num_children)
# Grammar-aware decoder: it receives the LAMBDA grammar helpers and the number of
# grammar categories (len(LambdaGrammar)).
decoder = GrammarTreeDecoder(embedding_size, hidden_size, max_num_children, 
                             parent_to_category_LAMBDA, len(LambdaGrammar), category_to_child_LAMBDA,
                            num_vars, num_ints, binarized=binarize)

program_model = TreeToTree(encoder, decoder, hidden_size, embedding_size, nclass=nclass,
                           alignment_size=alignment_size, align_type=align_type, max_size=max_size)
    
# Order matters: the uniform re-init touches every parameter of the model, so the
# forget-bias initialisation is applied afterwards to avoid being overwritten.
reset_all_parameters_uniform(program_model, 0.1)
encoder.initialize_forget_bias(3)
decoder.initialize_forget_bias(3)

In [None]:
# Move the model to the GPU only when one is available; an unconditional .cuda()
# call fails on CPU-only machines and aborts the notebook.
# NOTE(review): the training cell passes use_cuda=False -- confirm the two settings agree.
if torch.cuda.is_available():
    program_model = program_model.cuda()

In [None]:
# Adam over every parameter of the model.
optimizer = optim.Adam(program_model.parameters(), lr=0.005)
# Plateau scheduler: scales the learning rate by `factor` once the monitored value
# has not improved for `patience` scheduler steps.
# NOTE(review): the training call in this notebook passes lr_scheduler=None, so this
# object is not handed to training there -- confirm whether that is intentional.
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=True, patience=500, factor=0.8)



In [None]:
# Token accuracy of the program
def count_matches(prediction, target):
    """Score how many aligned nodes of two trees carry the same value.

    Nodes are compared position by position: one point when the integer values
    agree, plus the score of each pair of children that exists in both trees.
    Every child present in only one of the two nodes costs one point.

    Returns:
        float: the score (always a float because of ``math.fabs``).
    """
    score = 1 if int(prediction.value) == int(target.value) else 0
    # zip stops at the shorter child list, i.e. only shared positions are compared.
    for predicted_child, target_child in zip(prediction.children, target.children):
        score += count_matches(predicted_child, target_child)
    # Penalty for the unmatched (missing or extra) children at this node.
    return score - math.fabs(len(target.children) - len(prediction.children))


# Program accuracy (1 if completely correct, 0 otherwise)
def all_matches(prediction, target):
    """Return 1 when the two trees are identical in shape and values, else 0."""
    # Root values must agree ...
    if int(prediction.value) != int(target.value):
        return 0
    # ... the nodes must have equally many children ...
    if len(prediction.children) != len(target.children):
        return 0
    # ... and every pair of corresponding subtrees must match exactly.
    for predicted_child, target_child in zip(prediction.children, target.children):
        if not all_matches(predicted_child, target_child):
            return 0
    return 1

# Calc validation accuracy (this could either be program or token accuracy)
def validation_criterion(prediction, target):
    """Validation metric handed to the training routine.

    Currently delegates to ``all_matches``, i.e. whole-program accuracy:
    1 if ``prediction`` equals ``target`` exactly, otherwise 0.
    """
    return all_matches(prediction, target)

In [None]:

# Train the model for one epoch and keep the returned model plus the two loss histories.
# NOTE(review): lr_scheduler=None means the ReduceLROnPlateau scheduler built earlier is
# not passed in, and use_cuda=False is passed although an earlier cell calls
# program_model.cuda() -- confirm both are intentional.
best_model, train_plot_losses, validation_plot_losses = training.train_model_tree_to_tree(program_model, for_lambda_dset, 
                                 optimizer, lr_scheduler=None, num_epochs=1, plot_every=plot_every,
                                 batch_size=100, print_every=200, validation_criterion=validation_criterion,
                                 use_cuda=False) #TODO:  epochs=3



In [None]:
# Keep references to the current loss histories under *_old names; the plotting
# cell concatenates *_old with the current histories.
# NOTE(review): these are aliases, not copies. Unless the training cell is re-run
# after this one, *_old and the current lists are the same object.
train_plot_losses_old = train_plot_losses
validation_plot_losses_old = validation_plot_losses

In [None]:
# Plot the saved (*_old) history followed by the current one for each curve.
# Point i is drawn at x = i * plot_every. Each concatenation is built once, and
# every figure gets a title and axis labels so it can be read on its own.
# NOTE(review): if the *_old lists alias the current lists (as they do right after
# the snapshot cell runs), every point appears twice -- confirm the intended workflow.
train_curve = train_plot_losses_old + train_plot_losses
plt.plot([x * plot_every for x in range(len(train_curve))], train_curve)
plt.title("Training loss")
plt.xlabel("Iteration (index * plot_every)")
plt.ylabel("Loss")
plt.show()

validation_curve = validation_plot_losses_old + validation_plot_losses
plt.plot([x * plot_every for x in range(len(validation_curve))], validation_curve)
plt.title("Validation curve")
plt.xlabel("Iteration (index * plot_every)")
plt.ylabel("Validation value")
plt.show()

In [None]:
# Combined count of integer and variable tokens.
# NOTE(review): this module-level `n` is not read by any code visible in this
# notebook -- confirm it is still needed.
n = num_ints + num_vars

# Check whether a node is syntactically valid, given its parent and index
# Then recursively do it for all the node's children
def check_valid(node, parent, child_index):
    """Validate ``node`` (and everything below it) against the LAMBDA grammar.

    Args:
        node: Tree node with ``value`` and ``children``.
        parent: Integer value of the parent node (``None`` for the root).
        child_index: Position of ``node`` among its parent's children.

    Returns:
        bool: True when every visited node's value is among the outputs the
        grammar allows for its (parent, child_index) category; False otherwise,
        after printing the offending parent/index and value/category.
    """
    grammar_category = parent_to_category_LAMBDA(parent, child_index, num_vars, num_ints)
    allowed_values = category_to_child_LAMBDA(grammar_category, num_vars, num_ints)
    node_value = int(node.value)
    if node_value not in allowed_values:
        print("parent", parent, "child_index", child_index)
        print("ERROR", node_value, grammar_category)
        return False
    if len(node.children) == 0:
        return True
    # children[0] is checked as this node's first child; children[1] is checked
    # against the SAME parent at the next index (presumably the binarized
    # left-child / right-sibling layout -- confirm).
    return (check_valid(node.children[0], node_value, 0)
            and check_valid(node.children[1], parent, child_index + 1))

# Check all the programs in a dataset for syntactic accuracy
# (this is a debugging function used to double check the accuracy of your grammar)
def check_all():
    """Run ``check_valid`` on the second element of every pair in ``for_lambda_dset``.

    Stops at the first program for which ``check_valid`` returns False, printing
    its position in the dataset followed by the tree itself. Prints nothing when
    every program passes. Always returns None.
    """
    for position, prog in enumerate(for_lambda_dset):
        if check_valid(prog[1], None, 0) is False:
            print(position)
            pretty_print_tree(prog[1])
            return
        
        
        
# Run the grammar check over the whole dataset; prints only if a program fails.
check_all() #kangaroo

In [None]:
from tree_to_sequence.translating_trees import *
# pretty_print_tree(for_lambda_dset[1][1])
import json

# Walk one raw program (index 3) through each preprocessing stage and print the
# result of every step: raw JSON -> For tree -> Lambda tree -> binarized Lambda tree.
# NOTE(review): make_tree, translate_from_for and binarize_tree are presumably
# provided by the star import above -- confirm.
# The `with` block closes the file handle, which a bare open() call left open.
with open("ANC/VeryHard-arbitraryForList.json") as progs_file:
    progsjson = json.load(progs_file)
print(progsjson[3])
for_prog = make_tree(progsjson[3], long_base_case=long_base_case)
pretty_print_tree(for_prog)
lambda_prog = translate_from_for(for_prog)
pretty_print_tree(lambda_prog)
binarized_lambda = binarize_tree(lambda_prog, add_eos=True)
pretty_print_tree(binarized_lambda)

def check_weirdness(prog):
    """Walk the tree and report every child whose value starts with "*".

    For each such child the index within its parent's ``children`` is printed.
    The root's own value is not inspected. Returns None.
    """
    for position, child in enumerate(prog.children):
        if child.value[0] == "*":
            print("FOUND A STAR at index ", position)
        check_weirdness(child)
        
            
# Scan the binarized example tree for values starting with "*".
# NOTE(review): a later cell redefines check_weirdness with a (prog, parent)
# signature; once that cell has run, this one-argument call no longer works.
check_weirdness(binarized_lambda)

In [None]:
def print_stuff(node):
    """Debug helper: for every node whose value is 29, print a chain of four values.

    The chain starts at the node's first child and then follows ``children[1]``
    three times; the four integer values are printed on one line after the
    marker line "got one!". The whole tree is visited recursively.
    """
    if int(node.value) == 29:
        print("got one!")
        chain = [node.children[0]]
        # Follow the second-child link three more times.
        for _ in range(3):
            chain.append(chain[-1].children[1])
        print(*[int(link.value) for link in chain])
    for child in node.children:
        print_stuff(child)
        
        
# Sanity prints: the integer id of Lambda.BLANK, the child values the grammar
# returns for the VARUNITBLANK category, and the two vocabulary sizes.
print(int(Lambda.BLANK))
print(category_to_child_LAMBDA(LambdaGrammar.VARUNITBLANK, num_vars, num_ints))
print(num_vars)
print(num_ints)


# for prog in for_lambda_dset[10000:15000]:
#     l = prog[1]
#     print_stuff(l)
    
    #IDEA - blank counts as. a var
    # Letrec first child is 34,12/13,28,32 = 12, a1/a2, 7, 11 = BLANK, a1/a2, IF, APP
    # lec children: 33/21, 21/22/23/24/28/29/31,28/2930/32 = 12/0, 0,1,2,3,7,8,10,  7,8,10,12 = BLANK/VAR, VAR,CONST<PLUS<MINU<IFLETLETREC . .   if,let,letrec,blank

In [None]:
# torch.save(program_model, "/tree_to_sequence/identity_grammar_t2t")
# import importlib
# importlib.reload(ForLambdaDataset)

# Check how many vars actually used
def count_vars(prog):
    """Tally the integer value of every node of ``prog`` into the global ``var_counter``.

    The range filter that would restrict the tally to variable tokens is
    disabled, so every node value in the tree is counted.
    """
    node_value = int(prog.value)
#     if node_value in range(num_ints, num_ints + num_vars):
    var_counter[node_value] = var_counter.get(node_value, 0) + 1
    for child in prog.children:
        count_vars(child)


# Tally node values over the first element of every dataset pair, then report
# each value with its count (in first-seen order).
var_counter = {}
for prog in for_lambda_dset:
    count_vars(prog[0])

for key, occurrences in var_counter.items():
    print("VAR", key, "COUNT", occurrences)




In [None]:
# Same tally as for the inputs, but over the second element of every dataset pair.
# The counter is reset first so the two reports do not mix.
var_counter = {}
for prog in for_lambda_dset:
    count_vars(prog[1])

for key, occurrences in var_counter.items():
    print("VAR", key, "COUNT", occurrences)




In [None]:
# torch.save(program_model, "saved_models/t2t_grammar")

# Print the value that check_weirdness (defined just below) compares a node's parent against.
print("VAR SHOULD BE: ", Lambda.CONST + num_ints + num_vars)

def check_weirdness(prog, parent):
    """Return True if the tree contains a small integer value under an unexpected parent.

    A node is flagged when its integer value lies in ``range(num_ints)`` while
    the value of its parent is not ``Lambda.CONST + num_ints + num_vars``.
    The search stops at the first flagged node.

    Args:
        prog: Tree node with ``value`` and ``children``.
        parent: Integer value of the node's parent (``None`` for the root).
    """
    node_value = int(prog.value)
    if node_value in range(num_ints) and parent != Lambda.CONST + num_ints + num_vars:
        return True
    # any() stops at the first child subtree that reports a hit.
    return any(check_weirdness(child, node_value) for child in prog.children)

# Print the second element of every dataset pair that check_weirdness flags.
for prog in for_lambda_dset:
    candidate_tree = prog[1]
    if check_weirdness(candidate_tree, None):
        pretty_print_tree(candidate_tree)

In [None]:
# Check binarize
# Get a single program
# Print out the before and after
from pptree import *

class Employee:
    """Simple tree node for the printing demo: a person with a list of reports."""

    def __init__(self, fullname, function, head=None):
        """Store the name and role, start with an empty team, and join ``head``'s team if given."""
        self.fullname, self.function = fullname, function
        # Direct reports; filled in when later employees are created with this one as head.
        self.team = []
        if not head:
            return
        head.team.append(self)

    def __str__(self):
        """Return the employee's role."""
        return self.function

# Build a small example hierarchy: jean is the root, with three direct reports,
# and enzo in turn has two reports.
# NOTE(review): despite the header comments of this cell, nothing here touches the
# project's program trees; it only exercises the tree printer on Employee objects.
jean = Employee("Jean Dupont", "CEO")
isabelle = Employee("Isabelle Leblanc", "Sales", jean)
enzo = Employee("Enzo Riviera", "Technology", jean)
lola = Employee("Lola Monet", "RH", jean)
kevin = Employee("Kevin Perez", "Developer", enzo)
lydia = Employee("Lydia Petit", "Tester", enzo)

# Names of jean's direct reports, in creation order.
print([employee.fullname for employee in jean.team])
# print_tree presumably comes from the `from pptree import *` above; it is given
# the root plus the attribute names "team" and "fullname" -- confirm their roles
# (children attribute / display attribute) against the pptree documentation.
print_tree(jean, "team", "fullname")