In [None]:
import os

# Work from the repository root so the `math_expressions/...` data paths and the
# project package imports below resolve. Guarded so that re-running this cell does not
# keep climbing the directory tree (a bare `cd ..` moves up one level on every run).
if not os.path.isdir('math_expressions'):
    os.chdir('..')

In [None]:
import torch.optim as optim
import torch.nn as nn
import torch
import matplotlib.pyplot as plt
import numpy as np
import random
import pickle
import cv2

from neural_nets_library import training
from tree_to_sequence.tree_decoder import TreeDecoder
from tree_to_sequence.tree_to_tree_attention import TreeToTreeAttention
from tree_to_sequence.program_datasets import *
from functools import partial
from math_expressions.translating_math_trees import math_tokens
from tree_to_sequence.translating_trees import pretty_print_tree

In [None]:
# torch.cuda.set_device(1)

In [None]:
use_cuda = False
# Images are cropped to at most image_height x image_width pixels by process_img.
# With the CNN's conv stack, a 32-pixel-high input ends up exactly 1 row high.
image_width = 40
image_height = 32
# NOTE(review): one_hot, binarize_output, long_base_case, output_as_seq and num_samples are
# not read by any cell visible here (leftovers from an earlier dataset configuration);
# kept in case other code relies on them.
one_hot = False
binarize_output = False
eos_token = False
long_base_case = True
output_as_seq = False
num_samples = 100
seed = 0

# Seed every RNG in play (weight init and any other random/numpy/torch sampling) so runs
# are reproducible.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
def find_sketchy_list(tree):
    """Report raw Python lists that ended up where tree nodes were expected.

    Walks ``tree`` recursively through ``.children``. When a list is found it is printed,
    and then the ``value`` of every node on the path back up to the root is printed too
    (deepest first).

    Args:
        tree: a node exposing ``.value`` and ``.children``, or a list.

    Returns:
        True if ``tree`` is, or contains, a list; False otherwise.
    """
    if isinstance(tree, list):
        print("something is weird here")
        print(tree)
        return True
    # Build the full list (no short-circuit) so every bad subtree gets reported.
    child_is_weird = any([find_sketchy_list(child) for child in tree.children])
    if child_is_weird:
        print(tree.value)
    # Propagate upward so ancestors print their values as well.
    return child_is_weird
        
# Debug call. NOTE(review): no cell above defines `tree`, so on a fresh kernel this raises
# NameError. It picks up the loop variable left behind by the train_dset scan in the next
# cell; run that cell (and the data-loading cell) first.
find_sketchy_list(tree)

In [None]:
# Sanity check: every processed training target should be a tree supporting .size().
# NOTE(review): train_dset is only built by the data-loading cell further down; run that
# cell first on a fresh kernel.
print(train_dset[0][1])
pretty_print_tree(train_dset[0][1])
print(train_dset[0][1].size())
for i, (img, tree) in enumerate(train_dset):
    try:
        tree.size()
    except Exception:
        # Report the offending example by its position in train_dset, then keep scanning.
        # (Exception, not a bare except, so the loop can still be interrupted.)
        print("I is ", i)
        print("tree?", tree)
        pretty_print_tree(tree)
        
        


In [None]:

# Load the pickled (image, expression) pairs.
# pickle.load can execute arbitrary code; only load files you generated or trust.
with open('math_expressions/data_new/train_data.pkl', 'rb') as f:
    train_raw = pickle.load(f)
# NOTE(review): test_data is loaded but not processed or used by any cell visible here.
with open('math_expressions/data_new/test_data.pkl', 'rb') as f:
    test_data = pickle.load(f)


def process_img(img):
    """Turn one image into a (1, 1, h, w) float tensor for the CNN encoder.

    3-D (colour) images are converted from BGR to grayscale. 2-D images are assumed to
    be grayscale already and are left as-is. The result is cropped to at most
    image_height x image_width pixels.

    Args:
        img: image array as used by OpenCV (presumably a numpy array, HxWxC or HxW).

    Returns:
        torch.FloatTensor of shape (1, 1, h, w) with h <= image_height, w <= image_width.
    """
    if img.ndim == 3:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    img = img[:image_height, :image_width]
    # Add the batch and channel dimensions expected by nn.Conv2d.
    return torch.FloatTensor(img).unsqueeze(0).unsqueeze(0)


# Build a new list instead of overwriting the loaded one, so re-running this cell does
# not try to process already-processed examples.
# NOTE(review): encode_program is (re)defined in a later cell; on a fresh kernel this line
# relies on a definition already being in scope; confirm where it comes from.
train_dset = [(process_img(img), encode_program(tree, math_tokens)) for img, tree in train_raw]

print("img shape", train_dset[0][0].shape)

# Upper bound on decoded tree size (also set again in the model-setup cell).
max_size = 100
# max_size = max([x[1].size() for x in train_dset])
# print("max size", max_size)

In [None]:
def vectorize(val, ops):
    """
    Map a single token value to its class index, wrapped in a 1-element LongTensor.

    Index layout (checked in this order):
        * int values                    -> val % 10                    (0..9)
        * values not in ``ops`` (letters) -> 10 + alphabet position      (10..35)
        * keys of ``ops``               -> 10 + 26 + ops[val]

    Vectorization should happen before any eos tokens are added; ``val`` should not be
    the eos token itself. A value of None is returned as None.

    Args:
        val: an int, a single lowercase variable letter, a key of ``ops``, or None.
        ops: mapping from operator token to its offset within the operator range.

    Returns:
        torch.LongTensor of shape (1,), or None when ``val`` is None.

    Raises:
        ValueError: if ``val`` is a string that is neither a key of ``ops`` nor a
            lowercase letter (TypeError for other non-int, non-operator values).
    """
    if val is None:
        return None

    num_ints = 10
    # Must be exactly a..z: index positions define the variable class ids.
    alphabet = "abcdefghijklmnopqrstuvwxyz"
    num_vars = len(alphabet)

    if type(val) is int:
        index = val % num_ints
    elif val not in ops:
        index = alphabet.index(val) + num_ints
    else:
        index = num_ints + num_vars + ops[val]

    return torch.LongTensor([index])




def encode_program(program, ops):
    """Vectorize every token of a program with ``vectorize``.

    Args:
        program: either a ``Node`` tree or a flat iterable of token values.
        ops: operator-to-offset mapping forwarded to ``vectorize``.

    Returns:
        For a ``Node``, whatever ``map_tree`` returns when applied with ``vectorize``
        (presumably a same-shaped tree of index tensors). Otherwise, a LongTensor built from
        the list of per-token results.
    """
    def encode_token(token):
        return vectorize(token, ops)

    if not isinstance(program, Node):
        return torch.LongTensor([encode_token(token) for token in program])
    return map_tree(encode_token, program)

In [None]:
class ImageEncoder(nn.Module):
    """Image encoder: CNN column features followed by a bidirectional LSTM.

    Args:
        nchannels: number of channels of the input image.
        nhidden: hidden size of each LSTM direction.
        num_layers: number of stacked LSTM layers.
        attention: selects what ``forward`` returns (see ``forward``).
    """

    def __init__(self, nchannels, nhidden, num_layers, attention=True):
        super(ImageEncoder, self).__init__()
        # core[0] turns the image into a (width, batch, 512) sequence; core[1] reads it.
        # Kept as a Sequential so parameter names (core.0.*, core.1.*) are unchanged.
        self.core = nn.Sequential(CNN_Sequence_Extractor(nchannels),
                                  nn.LSTM(512, nhidden, num_layers, bidirectional=True))
        # Odd indices 1, 3, ... into h_n / c_n. A bidirectional LSTM orders its states
        # [layer0 fwd, layer0 bwd, layer1 fwd, ...], so these select the backward states
        # and "index - 1" the matching forward ones.
        self.register_buffer('reverse_indices', torch.LongTensor(range(1, num_layers * 2, 2)))

        self.attention = attention

    def forward(self, input, widths=None):
        """Encode an image batch.

        Args:
            input: image tensor of shape (batch, nchannels, height, width).
            widths: not supported; must be None.

        Returns:
            If ``attention`` is True: ``(outputs, hiddens, cell_state)``.
                ``outputs`` are the per-column LSTM outputs, (seq_len, 2 * nhidden) for
                batch size 1. ``hiddens`` and ``cell_state`` concatenate the forward and
                backward final states on the feature axis, giving (batch, 2 * nhidden)
                when num_layers == 1. The ``squeeze`` calls only drop size-1 axes, so other
                batch sizes or layer counts keep those axes.
            Otherwise: the final backward-direction hidden states,
                (num_layers, batch, nhidden).

        Raises:
            NotImplementedError: if ``widths`` is given. The Sequential core cannot
                forward ``widths`` to the CNN, so the LSTM output is never a packed
                sequence; previously this path crashed inside pad_packed_sequence.
        """
        if widths is not None:
            raise NotImplementedError("ImageEncoder does not support variable-width inputs (widths)")

        output, (all_hiddens, all_cell_state) = self.core(input)

        forward_indices = self.reverse_indices - 1

        forward_hiddens = all_hiddens.index_select(0, forward_indices)
        reverse_hiddens = all_hiddens.index_select(0, self.reverse_indices)
        hiddens = torch.cat([forward_hiddens, reverse_hiddens], 2)

        forward_cell_state = all_cell_state.index_select(0, forward_indices)
        reverse_cell_state = all_cell_state.index_select(0, self.reverse_indices)
        cell_state = torch.cat([forward_cell_state, reverse_cell_state], 2)

        if self.attention:
            return output.squeeze(1), hiddens.squeeze(0), cell_state.squeeze(0)
        else:
            return reverse_hiddens
            
            
class CNN_Sequence_Extractor(nn.Module):
    """Convolutional front end that turns an image into a left-to-right feature sequence.

    Seven convolution (+ optional batch norm) + activation layers shrink the image height
    to a single row (a 32-pixel-high input ends at height 1). Every remaining column then
    becomes one time step carrying 512 channels.

    Args:
        nchannels: number of channels in the input image.
        leakyRelu: if True, use LeakyReLU(0.2) instead of ReLU after each convolution.
    """

    def __init__(self, nchannels, leakyRelu=False):
        super(CNN_Sequence_Extractor, self).__init__()

        # Kernel size (square) of each convolutional layer.
        ks = [3, 3, 3, 3, 3, 3, 2]
        # Amount of zero-padding for each convolutional layer.
        ps = [1, 1, 1, 1, 1, 1, 0]
        # Stride of each convolutional layer, as (height stride, width stride).
        ss = [(2, 2), (2, 2), (1, 1), (2, 1), (1, 1), (2, 1), (1, 1)]
        # Number of output channels of each convolutional layer.
        nm = [64, 128, 256, 256, 512, 512, 512]

        # Container for the layers of the network, in execution order.
        cnn = nn.Sequential()

        def convRelu(i, batchNormalization=False):
            """Append conv layer ``i`` (plus optional BatchNorm2d) and its activation to ``cnn``."""
            nIn = nchannels if i == 0 else nm[i - 1]
            nOut = nm[i]
            cnn.add_module('conv{0}'.format(i),
                           nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))
            if batchNormalization:
                cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))
            if leakyRelu:
                cnn.add_module('leaky_relu{0}'.format(i),
                               nn.LeakyReLU(0.2, inplace=True))
            else:
                cnn.add_module('relu{0}'.format(i), nn.ReLU(True))

        # The 7 convolutional layers; batch norm follows layers 2, 4 and 6.
        for i in range(len(nm)):
            convRelu(i, batchNormalization=i in (2, 4, 6))

        self.cnn = cnn

    def _output_widths(self, widths):
        """Map input pixel widths to the number of columns the conv stack produces."""
        lengths = widths
        for layer in self.cnn:
            if isinstance(layer, nn.Conv2d):
                # Standard Conv2d output-size formula, applied on the width axis.
                lengths = (lengths + 2 * layer.padding[1]
                           - layer.dilation[1] * (layer.kernel_size[1] - 1) - 1) // layer.stride[1] + 1
        return lengths

    def forward(self, input, widths=None):
        """Run the CNN and lay the result out as a sequence.

        Args:
            input: image batch of shape (batch, nchannels, height, width).
            widths: optional 1-D integer tensor of per-image pixel widths (before padding).

        Returns:
            A (seq_len, batch, 512) tensor. When ``widths`` is given, a PackedSequence
            instead, with the batch reordered by descending width (the permutation is
            not undone).
        """
        output = self.cnn(input)
        _, _, h, _ = output.size()
        assert h == 1, "the height of conv must be 1"
        output = output.squeeze(2)  # [b, c, w]
        output = output.permute(2, 0, 1)  # [w, b, c]

        if widths is not None:
            sorted_widths, idx = widths.sort(descending=True)
            output = output.index_select(1, idx)
            # Exact per-image column counts (e.g. 40 px -> 9 columns), as integers.
            lengths = self._output_widths(sorted_widths).cpu()
            output = nn.utils.rnn.pack_padded_sequence(output, lengths)

        return output

In [None]:
def reset_all_parameters_uniform(model, stdev):
    """Re-initialise every parameter of ``model`` in place from U(-stdev, stdev).

    Note: despite its name, ``stdev`` is the half-width of the uniform range, not a
    standard deviation.
    """
    bound = stdev
    for weight in model.parameters():
        nn.init.uniform_(weight, a=-bound, b=bound)

In [None]:
# --- Model hyperparameters and construction ---
# NOTE(review): eos_bonus is not read by any cell visible here.
eos_bonus = 1 if eos_token else 0
# Number of output classes for the decoder / attention model.
# NOTE(review): vectorize() emits indices up to 10 + 26 + (largest operator offset in
# math_tokens), which exceeds len(math_tokens) unless math_tokens also covers the digit
# and letter ranges; confirm nclass spans every index the targets can contain.
nclass = len(math_tokens) # TODO: FIGURE THIS OUT
# Passed to the training loop and used to scale the x-axis of the loss plots.
plot_every = 100
max_num_children = 4
# Size of each LSTM direction in the image encoder. Encoder states are hidden_size * 2
# wide because the encoder LSTM is bidirectional and both directions are concatenated.
hidden_size = 256
embedding_size = 25
alignment_size = 100
# One input channel: process_img converts images to grayscale.
n_channels = 1
num_layers = 1 # TODO: Later consider making this work for num_layers > 1
# NOTE(review): re-assigns the max_size set in the data-loading cell (same value).
max_size = 100 # TODO: FIGURE THIS OUT    
# NOTE(review): the meaning of align_type is defined by TreeToTreeAttention (not visible here).
align_type = 1
    
encoder = ImageEncoder(n_channels, hidden_size, num_layers, attention=True)
# Decoder and attention model consume the bidirectional encoder states (hidden_size * 2).
decoder = TreeDecoder(embedding_size, hidden_size * 2, max_num_children, nclass=nclass)
program_model = TreeToTreeAttention(encoder, decoder, hidden_size * 2, embedding_size, nclass=nclass, max_size=max_size,
                                    alignment_size=alignment_size, align_type=align_type)
    
    
# Uniform init in [-0.1, 0.1] for every parameter, then set the decoder's forget-gate bias.
# Order matters: the uniform reset overwrites all parameters (presumably including the
# decoder's, since it is passed into program_model), so it must run first.
reset_all_parameters_uniform(program_model, 0.1)
decoder.initialize_forget_bias(3)


In [None]:
# Move the model to the GPU only when CUDA is enabled in the config cell.
program_model = program_model.cuda() if use_cuda else program_model

In [None]:
# Adam with a fixed initial learning rate of 0.005.
optimizer = optim.Adam(params=program_model.parameters(), lr=0.005)
# Multiply the learning rate by 0.8 after 500 scheduler steps without improvement.
# NOTE(review): not currently used; the training cell passes lr_scheduler=None.
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.8, patience=500, verbose=True)

In [None]:
def count_matches(prediction, target):
    """Count node positions at which two trees hold equal values.

    Nodes are paired recursively by child position; surplus children on either side are
    ignored. Values are compared after ``int()`` conversion.
    """
    root_match = 1 if int(prediction.value) == int(target.value) else 0
    return root_match + sum(count_matches(pred_child, target_child)
                            for pred_child, target_child in zip(prediction.children, target.children))


def program_accuracy(prediction, target):
    """Program accuracy: 1 if ``prediction`` equals ``target`` exactly (shape and values), else 0."""
    predicted_size = prediction.size()
    exact = count_matches(prediction, target) == predicted_size and predicted_size == target.size()
    return 1 if exact else 0


def validation_criterion(prediction, target):
    """Validation metric used during training (currently whole-program accuracy)."""
    return program_accuracy(prediction, target)

In [None]:
# Train for 5 epochs with batch size 1; the last two returned values are unused here.
# NOTE(review): lr_scheduler from the optimizer cell is built but not passed
# (lr_scheduler=None), and no separate validation dataset is passed in this call.
best_model, train_plot_losses, validation_plot_losses, _, _ = training.train_model_tree_to_tree(
    program_model,
    train_dset,
    optimizer,
    lr_scheduler=None,
    num_epochs=5,
    batch_size=1,
    plot_every=plot_every,
    print_every=200,
    validation_criterion=validation_criterion,
    use_cuda=use_cuda,
)
    
    

In [None]:
# Training curve. The x positions assume one recorded point every `plot_every` iterations.
fig, ax = plt.subplots()
ax.plot([step * plot_every for step in range(len(train_plot_losses))], train_plot_losses)
ax.set(title="Training loss", xlabel="iteration", ylabel="loss")
plt.show()

# NOTE(review): validation_criterion is program accuracy, so these values may be
# accuracies rather than losses, and their spacing may differ from plot_every; confirm
# against training.train_model_tree_to_tree.
fig, ax = plt.subplots()
ax.plot([step * plot_every for step in range(len(validation_plot_losses))], validation_plot_losses)
ax.set(title="Validation criterion", xlabel="iteration", ylabel="validation criterion")
plt.show()

In [None]:
# Scratch check of index_select semantics: selecting index 1 along dim 1 of a
# (1, 100, 256) tensor keeps the other dims, giving (1, 1, 256).
r = torch.tensor(np.random.rand(1, 100, 256))
selected = r.index_select(1, torch.tensor([1]))
assert selected.shape == (1, 1, 256)
assert torch.equal(selected[0, 0], r[0, 1])
selected