In [1]:
cd ..

In [2]:
import torch.optim as optim
import torch.nn as nn
import torch
import matplotlib.pyplot as plt
import datetime

from neural_nets_library import training
from tree_to_sequence.tree_encoder import TreeEncoder
from tree_to_sequence.tree_decoder import TreeDecoder
from tree_to_sequence.program_datasets import *
from tree_to_sequence.translating_trees import *
from tree_to_sequence.tree_to_tree_attention import TreeToTreeAttention
from functools import partial

torch.manual_seed(0)

<torch._C.Generator at 0x7fcb34145d30>

In [3]:
# Vocabulary sizing: both counts are summed into the encoder input size and the
# decoder class count when the model is built later in the notebook.
num_vars = 10
num_ints = 11
# Dataset construction flags. Each one is forwarded as a keyword argument to
# ForLambdaDataset; one_hot is also passed to TreeEncoder, and binarize_output
# additionally selects the decoder's max_num_children.
one_hot = False
binarize_input = False
binarize_output = False
eos_token = False
long_base_case = True
input_as_seq = False
output_as_seq = False

In [4]:
# Build the dataset from the JSON file, forwarding the configuration flags
# defined in the previous cell and limiting it to 10 samples.
for_lambda_dset = ForLambdaDataset(
    "ANC/AdditionalForDatasets/ForWithLevels/Easy-arbitraryForList.json",
    binarize_input=binarize_input,
    binarize_output=binarize_output,
    eos_token=eos_token,
    one_hot=one_hot,
    long_base_case=long_base_case,
    input_as_seq=input_as_seq,
    output_as_seq=output_as_seq,
    num_samples=10,
)
# Largest size() over the second element of every sample
# (presumably the target tree — confirm).
max_size = max(sample[1].size() for sample in for_lambda_dset)

In [5]:
# Show the value stored on the third child of the second tree of sample 0
# (presumably the target program — confirm).
print(for_lambda_dset[0][1].children[2].value)

tensor([30])


In [6]:
# Render the first element of sample 0 as a tree
# (presumably the input program — confirm).
pretty_print_tree(for_lambda_dset[0][0])

   ┌21┐
   │  └12
 28┤
   └22┐
      └9


In [7]:
# Render the second element of sample 0 as a tree
# (presumably the target program — confirm).
pretty_print_tree(for_lambda_dset[0][1])

   ┌30
 29┤
   ├22┐
   │  └9
   └21┐
      └12


In [8]:
def reset_all_parameters_uniform(model, stdev):
    """Re-initialise every parameter of ``model`` in place from a uniform distribution.

    Args:
        model: Object exposing ``parameters()`` (e.g. an ``nn.Module``).
        stdev: Half-width of the sampling interval; values are drawn from
            ``U(-stdev, stdev)``. Despite the name, this is the bound of the
            interval, not a standard deviation.

    Returns:
        None. The parameters are modified in place.
    """
    lower, upper = -stdev, stdev
    for weight in model.parameters():
        nn.init.uniform_(weight, a=lower, b=upper)

In [9]:
# Model hyperparameters.
embedding_size = 100 #... 256 is from the paper, but 100 is WAY faster
# NOTE(review): the trailing "#256" suggests 3 is a temporary debugging value —
# confirm before a real training run.
hidden_size = 3#256
num_layers = 1
alignment_size = 50
align_type = 1
# Encoder input size: number of variables + integers + len(for_ops).
encoder_input_size = num_vars + num_ints + len(for_ops)
annotation_method = pre_order
# NOTE(review): this comment previously read "Changed to randomize=true", but the
# call below passes randomize_hiddens=False — confirm which setting is intended.
encoder = TreeEncoder(encoder_input_size, hidden_size, num_layers, [1, 2, 3, 4, 5], attention=True, one_hot=one_hot, binary_tree_lstm_cell=True, annotation_method=annotation_method, randomize_hiddens=False)
# Decoder class count: number of variables + integers + len(lambda_ops).
nclass = num_vars + num_ints + len(lambda_ops)
# NOTE(review): the plotting cell scales its x-axis by this value, while the
# training calls pass their own plot_every (200 and 1) — confirm they should agree.
plot_every = 100
max_num_children = 2 if binarize_output else 4

decoder = TreeDecoder(embedding_size, hidden_size, max_num_children, nclass=nclass)
program_model = TreeToTreeAttention(encoder, decoder, hidden_size, embedding_size, nclass=nclass, max_size=max_size,
                                    alignment_size=alignment_size, align_type=align_type)
    
# Re-initialise all parameters uniformly in [-0.1, 0.1] first, then call
# initialize_forget_bias(3) on both halves so that it runs after the reset
# (presumably it sets the LSTM forget-gate biases — confirm).
reset_all_parameters_uniform(program_model, 0.1)
encoder.initialize_forget_bias(3)
decoder.initialize_forget_bias(3)

In [10]:
# Move the model to the GPU; the training calls below pass use_cuda=True.
# This requires a CUDA-capable device to be available.
program_model = program_model.cuda()

In [11]:
# Adam over every model parameter, created after the model was moved to the GPU.
optimizer = optim.Adam(program_model.parameters(), lr=0.005)
# Scheduler that multiplies the learning rate by 0.8 after 500 steps without improvement.
# NOTE(review): both training calls in this notebook pass lr_scheduler=None, so
# this object is not used by the visible code — confirm whether that is intended.
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=True, patience=500, factor=0.8)

In [12]:
def count_matches(prediction, target):
    """Count the nodes whose values agree between two trees walked in parallel.

    The two nodes are compared by ``int(value)``; children are then paired up
    by position and compared recursively. Children that have no counterpart at
    the same position in the other tree are ignored.

    Args:
        prediction: Predicted node, exposing ``value`` and ``children``.
        target: Reference node, exposing ``value`` and ``children``.

    Returns:
        The number of position-aligned node pairs with equal integer values.
    """
    own_match = 1 if int(prediction.value) == int(target.value) else 0
    # zip stops at the shorter child list, so only positions present in both trees are visited.
    paired_children = zip(prediction.children, target.children)
    return own_match + sum(count_matches(pred_child, tgt_child)
                           for pred_child, tgt_child in paired_children)

def program_accuracy(prediction, target):
    """Score a predicted program as an exact match (1) or not (0).

    A prediction is counted as correct only when every one of its nodes
    matches the target at the same position and both trees have the same size.

    Args:
        prediction: Predicted tree, exposing ``size()``, ``value``, ``children``.
        target: Reference tree with the same interface.

    Returns:
        1 if the prediction is completely correct, 0 otherwise.
    """
    # Any mismatched node in the prediction rules out an exact match.
    if prediction.size() != count_matches(prediction, target):
        return 0
    # The target must not contain nodes the prediction lacks.
    if prediction.size() != target.size():
        return 0
    return 1

def validation_criterion(prediction, target):
    """Validation metric for a single prediction.

    Currently delegates to ``program_accuracy`` (whole-program exact match);
    swap the call here to use a token-level metric instead.

    Returns:
        Whatever ``program_accuracy`` returns for the pair (1 or 0).
    """
    score = program_accuracy(prediction, target)
    return score

In [13]:
# Pass the dataset-derived max_size to the model again; the same value was
# already supplied to the TreeToTreeAttention constructor.
program_model.update_max_size(max_size)

In [14]:
# Time both training runs together.
start = datetime.datetime.now()
# First run: 100 epochs, scored with program_accuracy directly.
# NOTE(review): thing1 and thing2 are not used anywhere in the visible cells —
# give them descriptive names or discard them.
# NOTE(review): lr_scheduler=None is passed even though a scheduler was built above.
# NOTE(review): the recorded output of this cell ends with
# "TypeError: 'Node' object does not support indexing"; the failing line is
# inside the training module and not visible here — investigate before relying
# on this cell.
best_model, train_plot_losses, validation_plot_losses, thing1, thing2 = training.train_model_tree_to_tree(program_model, for_lambda_dset, 
                                 optimizer, lr_scheduler=None, num_epochs=100, plot_every=200,
#                                  batch_size=100, print_every=200,
                                 batch_size=90, print_every=20, 
                                 validation_criterion=program_accuracy, use_cuda=True)

# Second run: 5 more epochs on the same model and optimizer with a different
# training function and no validation criterion.
# NOTE(review): this reassigns best_model and both loss lists, so the results
# of the first run are discarded — confirm that is intended.
best_model, train_plot_losses, validation_plot_losses = training.yet_another_train_func(program_model, for_lambda_dset, 
                                 optimizer, lr_scheduler=None, num_epochs=5, plot_every=1,
                                 batch_size=90, print_every=20, validation_criterion=None,
                                 use_cuda=True)
end = datetime.datetime.now()
print("TIME", end - start)

Epoch 0/99
----------
Epoch 1/99
----------
Epoch Number: 1, Batch Number: 10, Training Loss: 26.6276
Time so far is 0m 0s
Epoch Number: 1, Batch Number: 10, Validation Metric: 0.0000
Example output:
Epoch 2/99
----------
Epoch 3/99
----------
Epoch Number: 3, Batch Number: 10, Training Loss: 26.6276
Time so far is 0m 1s
Epoch Number: 3, Batch Number: 10, Validation Metric: 0.0000
Example output:
Epoch 4/99
----------
Epoch 5/99
----------
Epoch Number: 5, Batch Number: 10, Training Loss: 26.6276
Time so far is 0m 1s
Epoch Number: 5, Batch Number: 10, Validation Metric: 0.0000
Example output:
Epoch 6/99
----------
Epoch 7/99
----------
Epoch Number: 7, Batch Number: 10, Training Loss: 26.6276
Time so far is 0m 2s
Epoch Number: 7, Batch Number: 10, Validation Metric: 0.0000
Example output:
Epoch 8/99
----------
Epoch 9/99
----------
Epoch Number: 9, Batch Number: 10, Training Loss: 26.5970
Time so far is 0m 3s
Epoch Number: 9, Batch Number: 10, Validation Metric: 0.0000
Example output:


Epoch Number: 81, Batch Number: 10, Training Loss: 26.0962
Time so far is 0m 25s
Epoch Number: 81, Batch Number: 10, Validation Metric: 0.0000
Example output:
Epoch 82/99
----------
Epoch 83/99
----------
Epoch Number: 83, Batch Number: 10, Training Loss: 26.0600
Time so far is 0m 26s
Epoch Number: 83, Batch Number: 10, Validation Metric: 0.0000
Example output:
Epoch 84/99
----------
Epoch 85/99
----------
Epoch Number: 85, Batch Number: 10, Training Loss: 26.0600
Time so far is 0m 26s
Epoch Number: 85, Batch Number: 10, Validation Metric: 0.0000
Example output:
Epoch 86/99
----------
Epoch 87/99
----------
Epoch Number: 87, Batch Number: 10, Training Loss: 26.0600
Time so far is 0m 27s
Epoch Number: 87, Batch Number: 10, Validation Metric: 0.0000
Example output:
Epoch 88/99
----------
Epoch 89/99
----------
Epoch Number: 89, Batch Number: 10, Training Loss: 26.0600
Time so far is 0m 28s
Epoch Number: 89, Batch Number: 10, Validation Metric: 0.0000
Example output:
Epoch 90/99
---------

TypeError: 'Node' object does not support indexing

In [None]:
# Plot the recorded training and validation curves, one labelled figure each.
# The x position of the i-th recorded point is i * plot_every.
# NOTE(review): plot_every here is the notebook-level value, while the training
# calls passed their own plot_every arguments — confirm the x-axis scale.
for series, label in (
    (train_plot_losses, "Training loss"),
    (validation_plot_losses, "Validation metric"),
):
    fig, ax = plt.subplots()
    ax.plot([x * plot_every for x in range(len(series))], series)
    # Title and axis labels let each figure stand alone; without them the two
    # plots cannot be told apart.
    ax.set(title=label, xlabel="Recorded point index * plot_every", ylabel=label)
    plt.show()

In [None]:
# Keep references to the current loss histories under "_old" names. These are
# aliases, not copies: no new list is created. The trailing comments are the
# remains of an expression that concatenated old and new histories.
train_plot_losses_old = train_plot_losses#_old + train_plot_losses_new
validation_plot_losses_old = validation_plot_losses#_old + validation_plot_losses_new

# import csv

# torch.save(program_model, "max-big-t2t-all-vars-model")
# with open("max-big-t2t-all-vars-train.txt", "w") as output:
#     writer = csv.writer(output, lineterminator='\n')
#     for val in train_plot_losses:
#         writer.writerow([val]) 
# with open("max-big-t2t-all-vars-validation.txt", "w") as output:
#     writer = csv.writer(output, lineterminator='\n')
#     for val in validation_plot_losses:
#         writer.writerow([val]) 

In [None]:
# Combined count of integer and variable tokens.
# NOTE(review): n is not referenced by any of the visible code below — remove if unused.
n = num_ints + num_vars

def check_valid(node, parent, child_index):
    """Check that ``node`` and the nodes reachable from it are syntactically valid.

    The allowed values for ``node`` are looked up from its parent value and its
    position under that parent via the LAMBDA grammar helpers. On the first
    invalid node a diagnostic is printed and False is returned.

    Args:
        node: Node to check, exposing ``value`` and ``children``.
        parent: Integer value of the parent (None is passed for the root).
        child_index: Position of ``node`` under ``parent``.

    Returns:
        True if every visited node is valid, False otherwise.

    NOTE(review): when a node has any children, ``children[1]`` is always read,
    so a node with exactly one child raises IndexError. The second recursion
    keeps the same ``parent`` and advances ``child_index``, which looks like a
    first-child / next-sibling (binarized) layout — confirm this matches the
    trees being checked, since this notebook sets ``binarize_output = False``.
    """
    category = parent_to_category_LAMBDA(parent, child_index, num_vars, num_ints)
    allowed = category_to_child_LAMBDA(category, num_vars, num_ints)
    token = int(node.value)
    if token not in allowed:
        print("parent", parent, "child_index", child_index)
        print("ERROR", token, category)
        return False
    # Leaf: nothing further to validate.
    if len(node.children) == 0:
        return True
    # children[0] is validated as the first child of this node; children[1] is
    # validated against the same parent at the next child position.
    return (check_valid(node.children[0], token, 0)
            and check_valid(node.children[1], parent, child_index + 1))

def check_all():
    """Validate the second tree of every sample in ``for_lambda_dset``.

    Debugging aid for double-checking the grammar: stops at the first sample
    that fails ``check_valid``, printing its index and the offending tree.

    Returns:
        None.
    """
    for index, prog in enumerate(for_lambda_dset):
        if check_valid(prog[1], None, 0) is False:
            print(index)
            pretty_print_tree(prog[1])
            return
        
        
        
# Run the grammar check over every sample in for_lambda_dset.
check_all()