In [None]:
# Run from the repository root: the relative "math_expressions/..." data paths
# used below (and presumably the project package imports) resolve against the
# working directory. The guard keeps re-runs of this cell from walking further
# up the directory tree.
import os

if not os.path.isdir("math_expressions"):
    os.chdir("..")

In [None]:
import torch.optim as optim
import torch.nn as nn
import torch
import matplotlib.pyplot as plt
import numpy as np
import random
import pickle
import cv2

from neural_nets_library import training
from tree_to_sequence.tree_decoder_batch import TreeDecoderBatch
from tree_to_sequence.tree_to_tree_attention import TreeToTreeAttention
from tree_to_sequence.tree_to_tree_attention_batch import TreeToTreeAttentionBatch
from tree_to_sequence.program_datasets import *
from functools import partial
from math_expressions.translating_math_trees import math_tokens_short as math_tokens
from tree_to_sequence.translating_trees import pretty_print_tree

In [None]:
# Seed every RNG this notebook draws from -- Python's `random`, NumPy, and torch
# (whose RNG drives the nn.init.uniform_ parameter reset below) -- so that
# Restart & Run All reproduces the same initialization.
SEED = 3
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Prefer GPU 1 (the original setup), but leave the default device selected
# when CUDA is unavailable or the machine has fewer GPUs, instead of failing
# with an invalid-device error.
GPU_INDEX = 1
if torch.cuda.is_available() and GPU_INDEX < torch.cuda.device_count():
    torch.cuda.set_device(GPU_INDEX)
else:
    import warnings
    warnings.warn("CUDA device {} is not available; leaving the default device selected".format(GPU_INDEX))

In [None]:
# --- Input images ---
image_width = 40
image_height = 32
normalize_input = True

# --- Output trees ---
one_hot = False
binarize_output = True
eos_token = True
long_base_case = True
output_as_seq = False
# Depends on binarize_output, so it must come after it.
max_num_children = 2 if binarize_output else 3

# --- Data loading / runtime ---
use_cuda = True
num_samples = None
batch_size = 5
# NOTE(review): reassigned to 1 in the model-construction cell before any use.
num_layers = 16

In [None]:
def split_dataset(data, split):
    """Split (image, tree) pairs into two sets whose target trees are disjoint.

    Unique trees are collected in first-appearance order (compared with ``==``).
    The first ``int(len(unique_trees) * split)`` of them go to the first set and
    the rest to the second. Every pair follows its tree, so no tree appears in
    both sets, and the relative order of pairs within each set is preserved.

    Args:
        data: iterable of ``(img, tree)`` pairs (a generator is fine).
        split: fraction in [0, 1] of the unique trees assigned to the first set.

    Returns:
        ``(first_data, second_data)`` -- two lists of ``(img, tree)`` tuples.

    Raises:
        ValueError: if ``split`` is outside [0, 1].
    """
    if not 0 <= split <= 1:
        # A negative fraction would otherwise slice from the end of the tree
        # list, and a fraction > 1 would silently put everything in one set.
        raise ValueError("split must be in [0, 1], got {}".format(split))

    # Materialize once: the pairs are walked twice below.
    pairs = list(data)

    # Trees may be unhashable, so group them by linear equality lookup, but do
    # it only once instead of re-scanning for every output pair.
    unique_trees = []
    group_ids = []
    for _img, tree in pairs:
        try:
            group_id = unique_trees.index(tree)
        except ValueError:
            group_id = len(unique_trees)
            unique_trees.append(tree)
        group_ids.append(group_id)

    split_cutoff = int(len(unique_trees) * split)
    first_data, second_data = [], []
    for (img, tree), group_id in zip(pairs, group_ids):
        destination = first_data if group_id < split_cutoff else second_data
        destination.append((img, tree))
    return first_data, second_data


def make_dset(data):
    """Wrap (image, tree) pairs in a MathExprDatasetBatched built from the config cell.

    Reads the notebook globals batch_size, binarize_output, max_num_children,
    eos_token, normalize_input and num_samples. validation_set is always False.
    NOTE(review): this is also used to build val_dset; confirm that
    validation_set=False is intended there.
    """
    dataset_kwargs = dict(
        program_pairs=data,
        batch_size=batch_size,
        binarize_output=binarize_output,
        validation_set=False,
        max_children_output=max_num_children,
        eos_token=eos_token,
        normalize=normalize_input,
        num_samples=num_samples,
    )
    return MathExprDatasetBatched(**dataset_kwargs)

In [None]:
# Load the held-out test pairs and the pool that is split into train/val.
test_data = load_shuffled_data("math_expressions/test_data_short.pkl")
all_data = load_shuffled_data("math_expressions/train_data_short.pkl")

# split_dataset splits by unique target tree, so train and val share no tree.
train_cutoff = 0.7
train_data, val_data = split_dataset(all_data, train_cutoff)

for label, subset in (("Train set size: ", train_data),
                      ("Val set size: ", val_data),
                      ("Test set size: ", test_data)):
    print(label, len(subset))

train_dset = make_dset(train_data)
val_dset = make_dset(val_data)
# test_dset = make_dset(test_data)

# Largest target tree across the training batches; passed to the model as max_size.
max_size = max(tree.size() for batch in train_dset for tree in batch[1])
print("max size", max_size)
print(len(train_dset))

In [None]:
def display_normally(pic, title=None):
    """Show a single-channel image as an RGB picture.

    Args:
        pic: 2-D array-like of pixel values (e.g. a NumPy array or CPU tensor).
            Values are truncated to integers and multiplied by 255, so this
            presumably expects a 0/1 (binarized) image -- TODO confirm.
        title: optional plot title.
    """
    if title is not None:
        plt.title(title)
    # np.int0 (an alias of np.intp) was removed in NumPy 2.0; cast explicitly.
    gray = np.asarray(pic).astype(np.intp) * 255
    rgb = np.repeat(gray[:, :, np.newaxis], 3, axis=2)
    plt.imshow(rgb)
    plt.show()

In [None]:
# # Print the dset
# for batched_img, batched_trees in train_dset:
#     for i in range(len(batched_img)):
#         img = batched_img[i]
#         tree = batched_trees[i]
#         display_normally(img[0])
#         pretty_print_tree(tree, math_tokens)
#         pretty_print_tree(tree)
 

In [None]:
class ImageEncoder(nn.Module):
    """Encode an image batch with a CNN feature extractor and a bidirectional LSTM.

    CNN_Sequence_Extractor turns each image into a sequence of feature columns,
    and a bidirectional LSTM then reads that sequence. The final forward and
    backward states of each LSTM layer are concatenated along the feature
    dimension to form the returned hidden and cell states.
    """

    def __init__(self, nchannels, nhidden, num_layers, num_cnn_layers, attention=True):
        """
        Args:
            nchannels: number of input image channels.
            nhidden: hidden size of each LSTM direction.
            num_layers: number of stacked LSTM layers.
            num_cnn_layers: forwarded to CNN_Sequence_Extractor.
            attention: if True, forward returns (outputs, hiddens, cell_state);
                otherwise it returns only the backward-direction final hidden states.
        """
        super(ImageEncoder, self).__init__()
        # The LSTM input size (512) must match the channel count of the CNN's
        # last convolution.
        self.core = nn.Sequential(CNN_Sequence_Extractor(nchannels, num_cnn_layers),
                                  nn.LSTM(512, nhidden, num_layers, bidirectional=True))
        # h_n / c_n of a bidirectional LSTM are ordered
        # [layer0_fwd, layer0_bwd, layer1_fwd, ...]: odd indices hold the backward
        # direction and (odd - 1) the forward direction of the same layer.
        self.register_buffer('reverse_indices', torch.LongTensor(range(1, num_layers*2, 2)))

        self.attention = attention

    def forward(self, input, widths=None, training=True):
        """Encode a batch of images.

        Args:
            input: input to the CNN, presumably (batch, channels, height, width)
                -- TODO confirm.
            widths: currently unused (nn.Sequential forwards a single argument);
                kept for interface compatibility.
            training: with attention, True returns the LSTM outputs as is
                (batched path); False squeezes dim 1 (unbatched test path).

        Returns:
            With attention: (outputs, hiddens, cell_state), where hiddens and
            cell_state concatenate forward and backward final states on the last
            dim and have dim 0 squeezed. Without attention: the backward final
            hidden states of every layer.
        """
        # TODO: widths are not forwarded to the CNN, so it never packs the sequence here.
        output, (all_hiddens, all_cell_state) = self.core(input)

        # pad_packed_sequence returns (padded, lengths) and only accepts a
        # PackedSequence, so unpack only when the LSTM actually produced one.
        if isinstance(output, nn.utils.rnn.PackedSequence):
            output, _ = nn.utils.rnn.pad_packed_sequence(output)

        forward_hiddens = all_hiddens.index_select(0, self.reverse_indices - 1)
        reverse_hiddens = all_hiddens.index_select(0, self.reverse_indices)  # TODO: does this need a gradient
        hiddens = torch.cat([forward_hiddens, reverse_hiddens], 2)

        forward_cell_state = all_cell_state.index_select(0, self.reverse_indices - 1)
        reverse_cell_state = all_cell_state.index_select(0, self.reverse_indices)  # TODO: does this need a gradient
        cell_state = torch.cat([forward_cell_state, reverse_cell_state], 2)

        if not self.attention:
            return reverse_hiddens
        if training:
            # Training is batched but testing currently is not; someday unify the two paths.
            return output, hiddens.squeeze(0), cell_state.squeeze(0)
        return output.squeeze(1), hiddens.squeeze(0), cell_state.squeeze(0)
            
            
class CNN_Sequence_Extractor(nn.Module):
    """Convolutional feature extractor that turns an image batch into a sequence.

    Seven Conv2d (+ optional BatchNorm2d) + (Leaky)ReLU stages must reduce the
    feature-map height to 1; the remaining width axis becomes the sequence axis.
    """

    def __init__(self, nchannels, num_layers, leakyRelu=False):
        """
        Args:
            nchannels: number of input image channels.
            num_layers: length of the per-layer configuration lists below.
                NOTE(review): exactly 7 convolutions are built regardless, so
                values above 7 add no layers. With num_layers == 7 the seventh
                convolution gets the 2x2 / no-padding kernel; with larger values
                it is 3x3 / padding 1.
            leakyRelu: use LeakyReLU(0.2) instead of ReLU after each convolution.

        Raises:
            ValueError: if num_layers < 7 (the configuration lists would be too
                short for the 7 convolutions that are built).
        """
        super(CNN_Sequence_Extractor, self).__init__()

        # Reference configuration of the original model size:
        #   ks = [3, 3, 3, 3, 3, 3, 2]
        #   ps = [1, 1, 1, 1, 1, 1, 0]
        #   ss = [(2,2), (3,2), (2,1), (3,1), (2,1), (3,1), (2,1)]
        #   nm = [64, 128, 256, 256, 512, 512, 512]

        num_built_layers = 7
        if num_layers < num_built_layers:
            raise ValueError(
                "num_layers must be at least {}, got {}".format(num_built_layers, num_layers))

        # Kernel size (image filter) of each convolution.
        ks = [3] * (num_layers - 1) + [2]
        # Zero-padding of each convolution.
        ps = [1] * (num_layers - 1) + [0]
        # Stride of each convolution, as (height stride, width stride).
        ss = [(2, 2), (3, 2)] + [(2, 1) if i % 2 else (3, 1) for i in range(num_layers - 2)]
        # Output channels of each convolution.
        # NOTE(review): 245 looks like a typo for 256 (cf. the reference config);
        # left unchanged so previously saved checkpoints keep loading.
        nm = [64, 128, 245, 256] + [512] * (num_layers - 4)

        # Container for the convolutional stages.
        cnn = nn.Sequential()

        def convRelu(i, batchNormalization=False):
            """Append conv{i}, an optional batchnorm{i} and the activation for layer i to `cnn`."""
            nIn = nchannels if i == 0 else nm[i - 1]
            nOut = nm[i]
            cnn.add_module('conv{0}'.format(i),
                           nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))
            if batchNormalization:
                cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))
            if leakyRelu:
                cnn.add_module('leaky_relu{0}'.format(i),
                               nn.LeakyReLU(0.2, inplace=True))
            else:
                cnn.add_module('relu{0}'.format(i), nn.ReLU(True))

        # Batch normalization follows the 3rd, 5th and 7th convolutions. Modules
        # are created in the same order and with the same names as before, so
        # state_dict keys are unchanged.
        batch_norm_layers = (2, 4, 6)
        for i in range(num_built_layers):
            convRelu(i, i in batch_norm_layers)

        self.cnn = cnn

    def forward(self, input, widths=None):
        """Run the CNN and return its output as a width-major sequence.

        Args:
            input: image batch, presumably (batch, channels, height, width) -- TODO confirm.
            widths: optional 1-D tensor of per-image widths in pixels. When given,
                the sequence is packed with lengths floor(widths / 4), matching the
                two width-stride-2 convolutions. The batch order is preserved.

        Returns:
            A (width, batch, channels) tensor, or a PackedSequence when widths is given.

        Raises:
            ValueError: if the CNN output height is not 1.
        """
        output = self.cnn(input)
        _, _, h, _ = output.size()
        if h != 1:
            raise ValueError("the height of conv must be 1, got {}".format(h))
        output = output.squeeze(2)  # [b, c, w]
        output = output.permute(2, 0, 1)  # [w, b, c]

        if widths is not None:
            # pack_padded_sequence needs int64 lengths on the CPU. enforce_sorted=False
            # lets PyTorch sort internally and restore the original batch order on
            # unpacking, instead of leaving the batch permuted.
            lengths = torch.div(widths, 4, rounding_mode='floor').cpu()
            output = nn.utils.rnn.pack_padded_sequence(output, lengths, enforce_sorted=False)

        return output

In [None]:
def reset_all_parameters_uniform(model, stdev):
    """Re-initialize, in place, every parameter of ``model`` from U(-stdev, stdev).

    Covers all parameters of all submodules (including BatchNorm weights and
    biases); buffers are left untouched.
    """
    lower, upper = -stdev, stdev
    for tensor in model.parameters():
        nn.init.uniform_(tensor, a=lower, b=upper)

In [None]:
# Model hyperparameters and construction.
# NOTE(review): eos_bonus is not referenced in any other visible cell.
eos_bonus = 1 if eos_token else 0
# Output vocabulary size. NOTE(review): 26*2 + 10 presumably adds upper/lower-case
# letters and digits on top of math_tokens -- TODO confirm against the dataset.
nclass = len(math_tokens) + 26*2 + 10 # TODO: FIGURE THIS OUT
# Loss/accuracy curves are plotted with x positions scaled by this value.
plot_every = 100
# Per-direction LSTM hidden size of the encoder; the encoder concatenates forward
# and backward states, hence hidden_size*2 for the decoder and attention model.
hidden_size = 256
embedding_size = 100
alignment_size = 50
# Number of input image channels.
n_channels = 1
# Encoder LSTM layer count; overrides num_layers = 16 from the config cell.
num_layers = 1 # TODO: Later consider making this work for num_layers > 1
align_type = 1
# NOTE(review): CNN_Sequence_Extractor builds 7 conv layers regardless; this
# value only sizes its per-layer configuration lists.
num_cnn_layers = 16
    
encoder = ImageEncoder(n_channels, hidden_size, num_layers, num_cnn_layers, attention=True)
decoder = TreeDecoderBatch(embedding_size, hidden_size*2, max_num_children, nclass=nclass)
program_model = TreeToTreeAttentionBatch(encoder, decoder, hidden_size * 2, embedding_size, nclass=nclass, max_size=max_size,
                                    alignment_size=alignment_size, align_type=align_type, use_cuda=use_cuda)    
    
# Statement order matters: the uniform reset overwrites every parameter, so
# initialize_forget_bias (presumably setting the decoder LSTM forget-gate bias
# -- TODO confirm) has to run afterwards for its values to survive.
reset_all_parameters_uniform(program_model, 0.1)
decoder.initialize_forget_bias(3)


In [None]:
# NOTE(review): magic number -- presumably the class index the decoder treats as
# end-of-sequence; confirm it agrees with nclass / math_tokens.
program_model.decoder.EOS_value = 101

In [None]:
# Move the model to the current CUDA device only when use_cuda is set.
program_model = program_model.cuda() if use_cuda else program_model

In [None]:
# Adam at lr 1e-3, plus a plateau scheduler (factor 0.8, patience 500).
# NOTE(review): the training call passes lr_scheduler=None, so this scheduler
# is currently not used.
optimizer = optim.Adam(program_model.parameters(), lr=1e-3)
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    factor=0.8,
    patience=500,
    verbose=True,
)

In [None]:
# Counts the number of matches between the prediction and target.
def count_matches(prediction, target):
    """Count positions where the predicted and target trees hold the same value.

    Nodes are paired by position: the two roots, then recursively the i-th
    children for i up to the smaller child count. Values are compared after
    int() conversion.
    """
    root_match = 1 if int(prediction.value) == int(target.value) else 0
    child_matches = sum(count_matches(pred_child, target_child)
                        for pred_child, target_child in zip(prediction.children, target.children))
    return root_match + child_matches

# Program accuracy (1 if completely correct, 0 otherwise)
def program_accuracy(prediction, target):
    """Return 1 if ``prediction`` exactly matches ``target[0]``, else 0.

    Only the first element of ``target`` is compared. Exact match means every
    predicted node agrees with its positional counterpart (count_matches) and
    both trees have the same size.
    """
    reference = target[0]
    predicted_size = prediction.size()
    exact = (predicted_size == count_matches(prediction, reference)
             and predicted_size == reference.size())
    return 1 if exact else 0

# Calculate validation accuracy (this could either be program or token accuracy)
def validation_criterion(prediction, target):
    """Metric handed to the training/testing helpers: exact-program accuracy (1 or 0)."""
    score = program_accuracy(prediction, target)
    return score

In [None]:
# program_model = torch.load("math_expressions/models/hopefully_fixed_model_sorted_correctly_model")


In [None]:
# Print the model's predictions for a few training batches.
# Eval mode is set once for the whole loop and training mode is restored
# afterwards, so the training cell does not inherit eval mode (the CNN contains
# BatchNorm layers). Tensors are only moved to the GPU when use_cuda is set.
program_model.eval()
for input_tree, target_tree in train_dset[:20]:
    if use_cuda:
        input_tree = input_tree.cuda()
        target_tree = [actual_tree.cuda() for actual_tree in target_tree]
    program_model.print_img_tree_example(input_tree, target_tree, math_tokens)
program_model.train()
                

In [None]:
# Train the model. batch_size comes from the config cell so it stays in sync
# with the batch size make_dset used to build the datasets.
# NOTE(review): the ReduceLROnPlateau scheduler created earlier is not passed
# (lr_scheduler=None) -- confirm this is intentional.
best_model, train_plot_losses, train_plot_accuracies, _, _ = training.train_model_tree_to_tree(
    program_model,
    train_dset,
    optimizer,
    lr_scheduler=None,
    num_epochs=100,
    plot_every=plot_every,
    batch_size=batch_size,
    print_every=20,
    validation_criterion=validation_criterion,
    validation_dset=val_dset,
    save_folder="math_expressions/models",
    save_file="independent_val",
    use_cuda=use_cuda,
    skip_output_cuda=False,
    tokens=math_tokens,
    save_current_only=True,
    input_tree_form=False)
    
    

In [None]:
# One figure per recorded curve; x positions are the point index scaled by plot_every.
for curve, curve_title in ((train_plot_losses, "Loss"), (train_plot_accuracies, "Accuracy")):
    plt.plot([point * plot_every for point in range(len(curve))], curve)
    plt.title(curve_title)
    plt.show()

In [None]:
def cudafy_pair(pair):
    """Return ``(pair[0].cuda(), [t.cuda() for t in pair[1]])`` as a new tuple.

    The image batch is moved first, then each tree, collected into a new list.
    """
    images, trees = pair[0], pair[1]
    return images.cuda(), list(map(lambda tree: tree.cuda(), trees))

In [None]:

# Exact-program accuracy on the validation set. Batches are moved to the GPU
# here (only when use_cuda is set), and use_cuda=False is passed to the helper --
# presumably so it does not move them again.
val_dset_cuda = [cudafy_pair(pair) for pair in val_dset] if use_cuda else list(val_dset)

program_model.eval()
acc = training.test_model_tree_to_tree(program_model, val_dset_cuda, validation_criterion, use_cuda=False)
print("accuracy", acc)

In [None]:
# Exact-program accuracy on the training set. Batches are moved to the GPU
# here (only when use_cuda is set), and use_cuda=False is passed to the helper --
# presumably so it does not move them again.
train_dset_cuda = [cudafy_pair(pair) for pair in train_dset] if use_cuda else list(train_dset)

program_model.eval()
acc = training.test_model_tree_to_tree(program_model, train_dset_cuda, validation_criterion, use_cuda=False)
print("accuracy", acc)