In [None]:
cd .. 

In [None]:
import csv
import json
from functools import partial

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim

from neural_nets_library import training
from tree_to_sequence.tree_encoder import TreeEncoder
from tree_to_sequence.tree_decoder import TreeDecoder
from tree_to_sequence.grammar_tree_decoder import GrammarTreeDecoder
# NOTE(review): names used below such as JsCoffeeDataset, make_tree_javascript,
# pretty_print_tree, javascript_ops and coffee_ops are not imported explicitly and
# presumably come from this star import - TODO confirm and make them explicit.
from tree_to_sequence.program_datasets import *
from tree_to_sequence.translating_trees import (
    parent_to_category_FOR, category_to_child_FOR,
    category_to_child_LAMBDA, parent_to_category_LAMBDA,
    ForGrammar, For, LambdaGrammar, Lambda,
)
from tree_to_sequence.tree_to_tree_attention import TreeToTreeAttention

In [None]:
# Select the second GPU only when it exists. On CPU-only or single-GPU machines an
# unconditional torch.cuda.set_device(1) raises and aborts Restart & Run All.
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    torch.cuda.set_device(1)

In [None]:
# --- Experiment configuration ---------------------------------------------
use_cuda = False
num_vars = 10
num_ints = 11

# Representation flags; most are forwarded to JsCoffeeDataset in the next cell
# (eos_token also changes the encoder input size).
one_hot = False
binarize_input = True
binarize_output = False
eos_token = False
long_base_case = True
input_as_seq = False
output_as_seq = False

# Fix torch's RNG so parameter initialisation (and therefore training) is
# reproducible across Restart & Run All.
seed = 0
torch.manual_seed(seed)

In [None]:
# Alternative For -> Lambda dataset, currently disabled. The grammar-check cell at
# the end of the notebook still refers to `for_lambda_dset`.
# for_lambda_dset = ForLambdaDataset("ANC/validation_For.json", binarize_input=binarize_input,
#                                    binarize_output=binarize_output, eos_token=eos_token, one_hot=one_hot,
#                                    num_ints=num_ints, num_vars=num_vars,
#                                    long_base_case=long_base_case, input_as_seq=input_as_seq,
#                                    output_as_seq=output_as_seq)

# Paired CoffeeScript / JavaScript program files.
# NOTE(review): both are the *test* split, and this dataset is what the training
# cell trains on - confirm that is intentional.
coffee_js_dir = "ANC/MainProgramDatasets/CoffeeJavascript"
coffee_path = coffee_js_dir + "/test_CS.json"
javascript_path = coffee_js_dir + "/test_JS.json"

js_cs_dset = JsCoffeeDataset(coffee_path, javascript_path,
                             binarize_input=binarize_input,
                             binarize_output=binarize_output, eos_token=eos_token, one_hot=one_hot,
                             num_ints=num_ints, num_vars=num_vars,
                             long_base_case=long_base_case, input_as_seq=input_as_seq,
                             output_as_seq=output_as_seq)

# Size of the largest element-1 tree (presumably the target program) in the
# dataset; 0 for an empty dataset instead of a ValueError.
max_size = max((pair[1].size() for pair in js_cs_dset), default=0)

In [None]:
tokens = set()


def list_tokens(node, token_set=None):
    """Collect ``int(value)`` for every node in a tree whose value is not None.

    Args:
        node: root of the tree; every node exposes ``.value`` and ``.children``.
        token_set: set to add the values to. Defaults to the module-level
            ``tokens`` set, which keeps existing ``list_tokens(tree)`` calls working.

    Returns:
        The set that was updated, so callers need not rely on the global.
    """
    if token_set is None:
        token_set = tokens
    # Explicit stack instead of recursion, so deep program trees cannot hit
    # Python's recursion limit.
    pending = [node]
    while pending:
        current = pending.pop()
        if current.value is not None:
            token_set.add(int(current.value))
        pending.extend(current.children)
    return token_set

# for prog in js_cs_dset[:100]:
#     list_tokens(prog[0])
    
# print(sorted(tokens))
# print(encoder_input_size)




def find_none(tree):
    """Return True if the root or any descendant of ``tree`` has ``value is None``.

    Children are searched depth-first, left to right, stopping at the first hit.
    """
    return tree.value is None or any(find_none(child) for child in tree.children)
    
    
def check_all(dataset=None):
    """Report the first program whose element-0 tree contains a node with value None.

    Args:
        dataset: iterable of items whose element 0 is a tree; defaults to ``js_cs_dset``.

    Returns:
        Index of the first offending program (after printing the index and the
        tree), or None if no tree contains a None value.
    """
    if dataset is None:
        dataset = js_cs_dset
    for i, prog in enumerate(dataset):
        if find_none(prog[0]):
            print(i)
            pretty_print_tree(prog[0])
            return i
    return None
        
# check_all()
# print("done")

# Parse the first two JavaScript programs directly and print the first one next to
# the dataset's first element-0 tree.
javascript_path = "ANC/MainProgramDatasets/CoffeeJavascript/test_JS.json"
# `with` closes the file; json.load(open(...)) leaked the handle.
with open(javascript_path) as javascript_file:
    first_programs = json.load(javascript_file)[:2]
javascript_progs = [make_tree_javascript(prog, long_base_case=long_base_case)
                    for prog in first_programs]

pretty_print_tree(javascript_progs[0])

pretty_print_tree(js_cs_dset[0][0])



In [None]:
# Number of programs in the dataset (shown as the cell's output).
len(js_cs_dset)

In [None]:
def reset_all_parameters_uniform(model, stdev):
    """Re-initialise every parameter of ``model`` in place from U(-stdev, stdev).

    Note: despite its name, ``stdev`` is the half-width of the uniform range,
    not a standard deviation.
    """
    bound = stdev
    for weight in model.parameters():
        nn.init.uniform_(weight, a=-bound, b=bound)

In [None]:
# --- Model hyperparameters -------------------------------------------------
embedding_size = 256  # 256 is from the paper, but 100 is WAY faster
hidden_size = 256
num_layers = 1
alignment_size = 50
align_type = 1
plot_every = 100

# Encoder input size: variables + integers + JavaScript operators (+1 with an EOS token).
eos_bonus = 1 if eos_token else 0
encoder_input_size = num_vars + num_ints + len(javascript_ops) + eos_bonus
# NOTE(review): the meaning of [1, 2, 3, 4, 5] depends on TreeEncoder's signature - TODO document.
encoder = TreeEncoder(encoder_input_size, hidden_size, num_layers, [1, 2, 3, 4, 5],
                      attention=True, one_hot=one_hot)

# Decoder output classes: variables + integers + CoffeeScript operators.
nclass = num_vars + num_ints + len(coffee_ops)
num_categories = len(CoffeeGrammar)
num_possible_parents = len(Coffee)

# The decoder must be GrammarTreeDecoder (imported at the top of the notebook);
# the name `gr` is not defined anywhere and raised NameError.
decoder = GrammarTreeDecoder(embedding_size, hidden_size, num_categories, num_possible_parents,
                             partial(parent_to_category_coffee, num_vars, num_ints),
                             partial(category_to_child_coffee, num_vars, num_ints),
                             share_linear=True, share_lstm_cell=True,
                             num_ints_vars=num_ints + num_vars)
program_model = TreeToTreeAttention(encoder, decoder, hidden_size, embedding_size,
                                    nclass=nclass, root_value=nclass,
                                    alignment_size=alignment_size, align_type=align_type)

reset_all_parameters_uniform(program_model, 0.1)
encoder.initialize_forget_bias(3)
decoder.initialize_forget_bias(3)

In [None]:
# Move the model to the GPU only when CUDA is enabled in the config cell.
program_model = program_model.cuda() if use_cuda else program_model

In [None]:
# Adam over all model parameters, plus a plateau-based LR scheduler
# (factor 0.8, patience 500).
# NOTE(review): `verbose` is deprecated in recent PyTorch releases, and the
# training call in this notebook passes lr_scheduler=None - confirm intent.
optimizer = optim.Adam(params=program_model.parameters(), lr=0.005)
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    factor=0.8,
    patience=500,
    verbose=True,
)

In [None]:
def count_matches(prediction, target):
    """Count aligned node positions where ``prediction`` and ``target`` agree.

    Values are compared as ints. Children are aligned by index, and only indices
    present in both trees are visited; extra children on either side are ignored.
    """
    here = 1 if int(prediction.value) == int(target.value) else 0
    below = sum(count_matches(pred_child, target_child)
                for pred_child, target_child in zip(prediction.children, target.children))
    return here + below

def program_accuracy(prediction, target):
    """Return 1 if ``prediction`` reproduces ``target`` exactly, otherwise 0.

    "Exactly" means every predicted node matches its aligned target node
    (``count_matches`` equals the prediction's size) and both trees have the
    same number of nodes.
    """
    predicted_size = prediction.size()
    exact = (predicted_size == count_matches(prediction, target)
             and predicted_size == target.size())
    return 1 if exact else 0


def validation_criterion(prediction, target):
    """Validation metric passed to training; currently whole-program accuracy.

    Swap the body for a token-level metric if desired.
    """
    return program_accuracy(prediction, target)

In [None]:
# Train for 5 epochs, logging a loss point every `plot_every` steps.
# NOTE(review): `lr_scheduler` is built in the optimizer cell but None is passed
# here, so presumably no LR reduction happens - confirm this is intentional.
best_model, train_plot_losses, validation_plot_losses = training.train_model_tree_to_tree(
    program_model,
    js_cs_dset,
    optimizer,
    lr_scheduler=None,
    num_epochs=5,
    plot_every=plot_every,
    batch_size=100,
    print_every=200,
    validation_criterion=validation_criterion,
    use_cuda=use_cuda,
)
    
    

In [None]:
# Post-mortem debugging (`%debug`) belongs in an interactive session after a
# failing cell, not in a cell that runs on every Restart & Run All.

In [None]:
# Persist the model and both loss curves (one value per CSV row).
# NOTE(review): "3-vars" in the file names does not match num_vars = 10 in the
# config cell, and this saves `program_model` rather than the `best_model`
# returned by training - confirm both are intended.
run_name = "grammar-3-vars-share-everything"


def _write_column(path, values):
    """Write ``values`` to ``path`` as a one-column CSV, one value per line."""
    # The csv docs require newline="" on the file; lineterminator pins "\n".
    with open(path, "w", newline="") as output:
        writer = csv.writer(output, lineterminator="\n")
        writer.writerows([value] for value in values)


torch.save(program_model, run_name + "-model")
_write_column(run_name + "-train.txt", train_plot_losses)
_write_column(run_name + "-validation.txt", validation_plot_losses)

In [None]:
# Intentionally empty: a stray debug print was removed from this cell.

In [None]:
def _plot_curve(values, title, ylabel):
    """Plot ``values`` against step; entry k is drawn at step k * plot_every."""
    fig, ax = plt.subplots()
    ax.plot([k * plot_every for k in range(len(values))], values)
    ax.set(title=title, xlabel="step (entry index * plot_every)", ylabel=ylabel)
    plt.show()


_plot_curve(train_plot_losses, "Training loss", "loss")
# validation_criterion returns program accuracy, so this curve is presumably accuracy.
_plot_curve(validation_plot_losses, "Validation criterion", "validation criterion")

In [None]:
n = num_vars + num_ints  # same quantity passed to the decoder as num_ints_vars
for count in (num_vars, num_ints):
    print(count)

# `check_all` for the grammar check is defined once, below `check_valid`.

def check_valid(node, parent, child_index, parent_to_category=None, category_to_child=None):
    """Recursively check that ``node`` and its subtree are allowed by the grammar.

    Args:
        node: tree node exposing ``.value`` and ``.children``.
        parent: integer value of the parent node (callers pass -1 for the root).
        child_index: position of ``node`` among its parent's children.
        parent_to_category: callable ``parent -> sequence of categories`` indexed
            by child position. Defaults to the LAMBDA grammar for the configured
            ``num_vars`` / ``num_ints``.
        category_to_child: callable ``category -> collection of allowed values``.
            Defaults to the LAMBDA grammar likewise.

    Returns:
        True if every node is valid; False at the first violation, after printing it.
    """
    if parent_to_category is None:
        parent_to_category = partial(parent_to_category_LAMBDA, num_vars, num_ints)
    if category_to_child is None:
        category_to_child = partial(category_to_child_LAMBDA, num_vars, num_ints)

    try:
        category = parent_to_category(parent)[child_index]
    except Exception:
        # A bare `except:` that fell through would leave `category` unbound and crash
        # on the next line; report the node and treat it as a grammar violation.
        print("AAA", parent, child_index, node.value)
        return False

    possible_outputs = category_to_child(category)
    if int(node.value) not in possible_outputs:
        print("parent", parent, "child_index", child_index)
        print("ERROR", int(node.value), category)
        return False
    # all() short-circuits at the first invalid child, like an early return.
    return all(check_valid(child, int(node.value), i, parent_to_category, category_to_child)
               for i, child in enumerate(node.children))

def check_all(dataset=None):
    """Grammar-check every program's element-1 tree and report the first failure.

    Debugging aid for double-checking the grammar tables used by ``check_valid``.

    Args:
        dataset: iterable of items whose element 1 is the tree to check
            (presumably the target program). Defaults to ``for_lambda_dset``,
            looked up at call time.

    Returns:
        Index of the first invalid program (after printing the index and the tree),
        or None if every program passes.
    """
    if dataset is None:
        # NOTE(review): `for_lambda_dset` is only constructed in commented-out code
        # in the dataset cell, so this default raises NameError as the notebook stands.
        dataset = for_lambda_dset
    for i, prog in enumerate(dataset):
        if check_valid(prog[1], -1, 0) is False:
            print(i)
            pretty_print_tree(prog[1])
            return i
    return None
        
        
        
# `check_all` prints the first failing program itself; only announce success
# when it did not return a failure index.
if check_all() is None:
    print("all good")