In [None]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import torch.utils.data as data

import numpy as np
import random

import json
from translating_trees import *
from for_prog_dataset import ForDataset
from convert_to_tree_and_matrices import TreeANCDataset
from functools import partial

In [None]:
# Move the kernel's working directory up one level (IPython automagic for %cd).
# NOTE(review): presumably this is what makes the `neural_nets_library` and `ANC`
# imports in the next cell resolvable - confirm against the repository layout.
# Re-running this cell moves up one more level each time.
cd ..

In [None]:
from neural_nets_library import training, visualize
from ANC import Controller

In [None]:
class TreeCell(nn.Module):
    """
    LSTM cell which takes in one hidden state and one cell state per child.

    Gate layout (index into ``gates_value`` / ``gates_children``):
      0                       input gate
      1                       output gate
      2                       memory (candidate) gate
      3 .. 3+num_children-1   one forget gate per child
    """
    def __init__(self, input_size, hidden_size, num_children):
        """
        Initialize the LSTM cell.

        :param input_size: length of input vector
        :param hidden_size: length of hidden vector (and cell state)
        :param num_children: number of children = number of hidden/cell states passed in
        """
        super(TreeCell, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size

        # Gates = input, output, memory + one forget gate per child
        numGates = 3 + num_children

        self.gates_value = torch.nn.ModuleList()
        self.gates_children = torch.nn.ModuleList()
        for _ in range(numGates):
            # One linear layer (with bias) to handle the value of the node
            value_linear = nn.Linear(input_size, hidden_size, bias = True)
            # One bias-free linear layer per child hidden state
            children_linear = torch.nn.ModuleList()
            for _ in range(num_children):
                children_linear.append(nn.Linear(hidden_size, hidden_size, bias = False))
            self.gates_value.append(value_linear)
            self.gates_children.append(children_linear)

        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()

    def forward(self, input, hidden_states, cell_states):
        """
        Calculate a new hidden state and a new cell state from the LSTM gates

        :param input: The value vector of the node being encoded.
        :param hidden_states: A list of num_children hidden states.
        :param cell_states: A list of num_children cell states.
        :return A tuple containing (new hidden state, new cell state)
        """
        # Pre-activation of every gate: W_gate * input + sum_j U_gate,j * hidden_j
        data_sums = []
        for i in range(len(self.gates_value)):
            data_sum = self.gates_value[i](input)
            for j in range(len(hidden_states)):
                data_sum += self.gates_children[i][j](hidden_states[j])
            data_sums.append(data_sum)

        # First gate is the input gate
        input_val = self.sigmoid(data_sums[0])
        # Next output gate
        o = self.sigmoid(data_sums[1])
        # Next memory gate
        m = self.tanh(data_sums[2])
        # All the rest are forget gates, one per child cell state.
        # With no children this stays 0 and the new state is just input_val * m.
        forget_data = 0
        for i in range(len(cell_states)):
            forget_data += self.sigmoid(data_sums[3 + i]) * cell_states[i]

        # Put it all together!
        new_state = input_val * m + forget_data
        new_hidden = o * self.tanh(new_state)

        return new_hidden, new_state

    def initialize_forget_bias(self, bias_value):
        """
        Set the bias of every forget gate (gate indices 3 and above) to a constant.

        :param bias_value: the value written into each forget-gate bias
        """
        for i in range(3, len(self.gates_value)):
            # constant_ is the in-place initializer; the un-suffixed name is a deprecated alias
            torch.nn.init.constant_(self.gates_value[i].bias, bias_value)

In [None]:
class TreeLSTM(nn.Module):
    '''
    TreeLSTM

    Takes in a tree where each node has a value and a list of children.
    Produces a tree of the same size where the value of each node is now encoded.

    One TreeCell is kept per supported child count; a node is encoded by the
    cell whose child count matches the number of children it has.
    '''

    def __init__(self, input_size, hidden_size, valid_num_children):
        """
        Initialize tree cells we'll need later.

        :param input_size: length of each node's value vector
        :param hidden_size: length of the hidden/cell state produced per node
        :param valid_num_children: iterable of child counts that non-leaf nodes may have;
                                   0 (leaves) is always supported in addition
        """
        super(TreeLSTM, self).__init__()

        # list(...) so that tuples/ranges are accepted as well as lists
        self.valid_num_children = [0] + list(valid_num_children)
        self.lstm_list = torch.nn.ModuleList()

        for size in self.valid_num_children:
            self.lstm_list.append(TreeCell(input_size, hidden_size, size))

    def forward(self, node):
        """
        Creates a tree where each node's value is the encoded version of the original value.

        :param node: root of a (sub)tree; it must expose `value` and `children`
        :return a tuple - (root of encoded tree, cell state of that root)
        :raises ValueError: if the node has a number of children no TreeCell was built for
        """
        # Recursively encode children first: list of (encoded child node, child cell state)
        children = [self.forward(child) for child in node.children]

        # Extract the TreeCell inputs
        inputH = [child_node.value for child_node, _ in children]
        inputC = [child_cell for _, child_cell in children]

        # Pick the TreeCell built for this number of children (first match wins)
        try:
            cell_index = self.valid_num_children.index(len(children))
        except ValueError:
            raise ValueError(
                "TreeLSTM has no cell for a node with {} children; supported child counts: {}".format(
                    len(children), self.valid_num_children))

        newH, newC = self.lstm_list[cell_index](node.value, inputH, inputC)

        # Set our encoded vector as the root of the new tree
        rootNode = Node(newH)
        rootNode.children = [child_node for child_node, _ in children]
        return (rootNode, newC)

    def initialize_forget_bias(self, bias_value):
        """Set the forget-gate biases of every TreeCell to bias_value."""
        for lstm in self.lstm_list:
            lstm.initialize_forget_bias(bias_value)

In [None]:
class SeqEncoder(nn.Module):
    """
    Sequence encoder: an optional embedding followed by a multi-layer nn.LSTM.

    If you are using an end of sequence token that should be accounted for in input_size.
    """
    def __init__(self, input_size, hidden_size, num_layers, attention=True,
                 use_embedding=True, embedding_size=256):
        """
        :param input_size: vocabulary size when use_embedding is set, otherwise the feature length
        :param hidden_size: hidden size of the LSTM
        :param num_layers: number of stacked LSTM layers
        :param attention: when true, forward also returns the per-step outputs
        :param use_embedding: whether inputs are indices to embed before the LSTM
        :param embedding_size: length of the embedding vectors
        """
        super(SeqEncoder, self).__init__()

        self.attention = attention
        self.use_embedding = use_embedding

        # The LSTM consumes embeddings when they are enabled, raw inputs otherwise.
        lstm_input_size = input_size
        if use_embedding:
            self.embedding = nn.Embedding(input_size, embedding_size)
            lstm_input_size = embedding_size
        self.lstm = nn.LSTM(lstm_input_size, hidden_size, num_layers)

    def initialize_forget_bias(self, bias_val):
        """Fill the second quarter of every LSTM bias vector with bias_val."""
        bias_names = [name
                      for layer_names in self.lstm._all_weights
                      for name in layer_names
                      if "bias" in name]
        for bias_name in bias_names:
            bias = getattr(self.lstm, bias_name)
            length = bias.size(0)
            bias.data[length // 4:length // 2].fill_(bias_val)

    def forward(self, input):
        """
        Encode one sequence.

        :param input: the sequence; a singleton dimension is inserted at position 1 before the LSTM
        :return (outputs, hiddens, cell_states) when attention is enabled,
                otherwise (hiddens, cell_states); dimension 1 is squeezed from each
        """
        features = self.embedding(input) if self.use_embedding else input
        outputs, (hiddens, cell_states) = self.lstm(features.unsqueeze(1))
        squeezed = tuple(tensor.squeeze(1) for tensor in (outputs, hiddens, cell_states))

        if not self.attention:
            return squeezed[1:]
        return squeezed

In [None]:
class TreeEncoder(nn.Module):
    """
    Stack of TreeLSTM layers, optionally preceded by an embedding of the node values.

    Takes in a tree where each node has a value vector and a list of children
    and produces the root's hidden/cell state per layer (plus per-node annotations
    when attention is enabled).
    """
    def __init__(self, input_size, hidden_size, num_layers, valid_num_children,
                 attention=True, use_embedding=True, embedding_size=256):
        """
        :param input_size: vocabulary size when use_embedding is set, otherwise the node value length
        :param hidden_size: hidden size of every TreeLSTM layer
        :param num_layers: number of stacked TreeLSTM layers
        :param valid_num_children: child counts forwarded to each TreeLSTM
        :param attention: when true, forward also returns the per-node annotations
        :param use_embedding: whether node values are embedded before the first layer
        :param embedding_size: length of the embedding vectors
        """
        super(TreeEncoder, self).__init__()

        self.lstm_list = torch.nn.ModuleList()
        self.use_embedding = use_embedding
        self.attention = attention

        if use_embedding:
            self.embedding = nn.Embedding(input_size, embedding_size)
            first_layer_input = embedding_size
        else:
            first_layer_input = input_size

        # Only the first layer sees the (embedded) node values; the rest consume hidden states.
        layer_input_sizes = [first_layer_input] + [hidden_size] * (num_layers - 1)
        for layer_input in layer_input_sizes:
            self.lstm_list.append(TreeLSTM(layer_input, hidden_size, valid_num_children))

    def forward(self, tree):
        """
        Encodes nodes of a tree in the rows of a matrix.

        :param tree: a tree where each node has a value vector and a list of children
        :return (annotations, hiddens, cell_states) when attention is enabled, otherwise
                (hiddens, cell_states). hiddens/cell_states stack the root's state of every
                layer; annotations stacks whatever tree_to_list returns for the final tree
                (presumably one encoded vector per node - confirm against translating_trees).
        """
        if self.use_embedding:
            tree = map_tree(lambda node: self.embedding(node).squeeze(0), tree)

        root_hiddens, root_cells = [], []
        # Each layer re-encodes the tree produced by the previous layer.
        for layer in self.lstm_list:
            tree, root_cell = layer(tree)
            root_hiddens.append(tree.value)
            root_cells.append(root_cell)

        hiddens = torch.stack(root_hiddens)
        cell_states = torch.stack(root_cells)

        if not self.attention:
            return hiddens, cell_states

        annotations = torch.stack(tree_to_list(tree))
        return annotations, hiddens, cell_states

    def initialize_forget_bias(self, bias_value):
        """Forward the forget-gate bias initialisation to every TreeLSTM layer."""
        for layer in self.lstm_list:
            layer.initialize_forget_bias(bias_value)

In [None]:
'''
Decoder
'''
class Tree_to_Sequence_Model(nn.Module):
    """
      For the decoder this expects something like an lstm cell or a gru cell and not an lstm/gru.
      Batch size is not supported at all. More precisely the encoder expects an input that does not
      appear in batches and must also output non-batched tensors.

      Token indices: 0 .. nclass-1 are the alphabet, nclass is the trash class,
      nclass+1 is end of sequence and nclass+2 is start of sequence.
    """
    def __init__(self, encoder, decoder, hidden_size, nclass, embedding_size):
        super(Tree_to_Sequence_Model, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

        # nclass + 2 to include end of sequence and trash
        self.output_log_odds = nn.Linear(hidden_size, nclass+2)
        self.softmax = nn.Softmax(dim=0)
        self.log_softmax = nn.LogSoftmax(dim=0)

        self.register_buffer('SOS_token', torch.LongTensor([[nclass+2]]))
        self.EOS_value = nclass + 1

        # nclass + 3 to include start of sequence, end of sequence, and trash.
        # n + 2 - start of sequence, end of sequence - n + 1, trash - n.
        # The first n correspond to the alphabet in order.
        self.embedding = nn.Embedding(nclass+3, embedding_size)

        # nclass is the trash category to avoid penalties after target's EOS token
        self.loss_func = nn.CrossEntropyLoss(ignore_index=nclass)

    def forward_train(self, input, target, teacher_forcing=True):
        """
        Compute the summed cross-entropy loss of decoding `target` from `input`.

        input: Whatever the encoder accepts. The output of the encoder should be a pair:
               the hidden states and the cell states of the root, both
               [num_layers, hidden_size].
        target: The target should have dimension, seq_len, and should be a LongTensor.
        teacher_forcing: feed the target token (True) or the model's own arg-max
                         prediction (False) as the next decoder input.
        Returns the loss summed over all target positions (the int 0 for an empty target).
        """
        # root hidden state/cell state
        decoder_hiddens, decoder_cell_states = self.encoder(input) # num_layers x hidden_size
        decoder_hiddens = decoder_hiddens.unsqueeze(1)
        decoder_cell_states = decoder_cell_states.unsqueeze(1)

        target_length, = target.size()
        SOS_token = Variable(self.SOS_token)
        decoder_input = self.embedding(SOS_token).squeeze(0) # 1 x embedding_size
        loss = 0

        for i in range(target_length):
            decoder_hiddens, decoder_cell_states = self.decoder(decoder_input, (decoder_hiddens, decoder_cell_states)) # num_layers x 1 x hidden_size
            decoder_hidden = decoder_hiddens[-1] # 1 x hidden_size
            log_odds = self.output_log_odds(decoder_hidden)

            # Slice instead of index: target[i] is 0-dimensional on current PyTorch, which
            # breaks both the loss (expects a length-1 target) and unsqueeze(1) below.
            current_target = target[i:i + 1] # 1
            loss += self.loss_func(log_odds, current_target)

            if teacher_forcing:
                next_input = current_target.unsqueeze(1) # 1 x 1
            else:
                _, next_input = log_odds.topk(1) # 1 x 1

            decoder_input = self.embedding(next_input).squeeze(1) # 1 x embedding_size

        return loss

    def forward_prediction(self, input, maximum_length=20):
        """
        This is just an alias for point_wise_prediction, so that training code that assumes the
        presence of a forward_train and forward_prediction works.
        """
        return self.point_wise_prediction(input, maximum_length)

    def point_wise_prediction(self, input, maximum_length=20):
        """
        Greedy decoding: emit the arg-max token at every step.

        Returns a list of token indices; it stops after the end-of-sequence token
        (which is included) or after maximum_length tokens.
        """
        decoder_hiddens, decoder_cell_states = self.encoder(input)
        decoder_hiddens = decoder_hiddens.unsqueeze(1)
        decoder_cell_states = decoder_cell_states.unsqueeze(1)

        SOS_token = Variable(self.SOS_token)
        decoder_input = self.embedding(SOS_token).squeeze(0) # 1 x embedding_size
        output_so_far = []

        for _ in range(maximum_length):
            decoder_hiddens, decoder_cell_states = self.decoder(decoder_input, (decoder_hiddens, decoder_cell_states))
            decoder_hidden = decoder_hiddens[-1]
            log_odds = self.output_log_odds(decoder_hidden)

            _, next_input = log_odds.topk(1)
            next_token = int(next_input)
            output_so_far.append(next_token)

            if next_token == self.EOS_value:
                break

            decoder_input = self.embedding(next_input).squeeze(1) # 1 x embedding size

        return output_so_far

    def beam_search_prediction(self, input, maximum_length=20, beam_width=5):
        """
        Beam-search decoding.

        Each hypothesis is a tuple
        (summed log-probability, tokens so far, still open?, [decoder input, hiddens, cell states]).
        Returns the token list of the highest-scoring hypothesis after at most
        maximum_length steps.
        """
        decoder_hiddens, decoder_cell_states = self.encoder(input)
        decoder_hiddens = decoder_hiddens.unsqueeze(1)
        decoder_cell_states = decoder_cell_states.unsqueeze(1)

        SOS_token = Variable(self.SOS_token)
        decoder_input = self.embedding(SOS_token).squeeze(0) # 1 x embedding_size

        # Start from ONE hypothesis. Seeding the beam with beam_width identical copies makes
        # every copy expand to the same candidates, so the beam only ever holds duplicates of
        # the greedy path.
        beam = [(0, [], True, [decoder_input, decoder_hiddens, decoder_cell_states])]

        for _ in range(maximum_length):
            # Nothing left to expand once every hypothesis has produced EOS.
            if not any(hypothesis[2] for hypothesis in beam):
                break

            candidates = []

            for score, tokens, is_open, state in beam:
                if not is_open:
                    # Finished hypotheses compete unchanged with the new candidates.
                    candidates.append((score, tokens, is_open, state))
                    continue

                decoder_input, decoder_hiddens, decoder_cell_states = state
                decoder_hiddens, decoder_cell_states = self.decoder(decoder_input, (decoder_hiddens, decoder_cell_states))
                decoder_hidden = decoder_hiddens[-1]
                log_odds = self.output_log_odds(decoder_hidden).squeeze(0) # nclasses
                log_probs = self.log_softmax(log_odds)

                # topk cannot return more entries than there are classes
                width = min(beam_width, log_probs.size(0))
                log_value, next_input = log_probs.topk(width) # width, width
                decoder_input = self.embedding(next_input.unsqueeze(1)) # width x 1 x embedding size

                candidates.extend((score + float(log_value[k]), tokens + [int(next_input[k])],
                                   int(next_input[k]) != self.EOS_value,
                                   [decoder_input[k], decoder_hiddens, decoder_cell_states])
                                  for k in range(width))

            # Ascending sort on the score only; keep the beam_width best.
            beam = sorted(candidates, key=lambda hypothesis: hypothesis[0])[-beam_width:]
        return beam[-1][1]

In [None]:
class Tree_to_Sequence_Attention_Model(Tree_to_Sequence_Model):
    """
    Tree-to-sequence model whose decoder attends over the encoder's annotations.

    The encoder must return a triple:
      annotations         - number_of_nodes x hidden_size
      root hidden states  - num_layers x hidden_size
      root cell states    - num_layers x hidden_size

    align_type selects how attention scores are computed (see attention_logits):
      0             - alignment_vector(tanh(attention_context(decoder_hidden) + attention_hidden(annotations)))
      1             - dot product of decoder_hidden with attention_hidden(annotations)
      anything else - dot product of decoder_hidden with the raw annotations
    """
    def __init__(self, encoder, decoder, hidden_size, nclass, embedding_size,
                 alignment_size=50, align_type=1):
        super(Tree_to_Sequence_Attention_Model, self).__init__(encoder, decoder, hidden_size, nclass, embedding_size)

        self.attention_presoftmax = nn.Linear(2 * hidden_size, hidden_size)
        self.tanh = nn.Tanh()

        if align_type == 0:
            self.attention_hidden = nn.Linear(hidden_size, alignment_size)
            self.attention_context = nn.Linear(hidden_size, alignment_size, bias=False)
            self.attention_alignment_vector = nn.Linear(alignment_size, 1)
        elif align_type == 1:
            self.attention_hidden = nn.Linear(hidden_size, hidden_size)

        self.align_type = align_type
        # Attentional vector fed to the decoder at the first step (all zeros).
        self.register_buffer('et', torch.zeros(1, hidden_size))

    def _attention_keys(self, annotations):
        """Project the annotations once per input; align types above 1 use them unprojected."""
        # align_type 0: number_of_nodes x alignment_size, align_type 1: number_of_nodes x hidden_size
        if self.align_type <= 1:
            return self.attention_hidden(annotations)
        return annotations

    def _attention_step(self, decoder_hidden, annotations, attention_hidden_values):
        """Return (et, log_odds) for one decoder step given the top-layer decoder hidden state."""
        attention_logits = self.attention_logits(attention_hidden_values, decoder_hidden)
        attention_probs = self.softmax(attention_logits) # number_of_nodes x 1
        context_vec = (attention_probs * annotations).sum(0).unsqueeze(0) # 1 x hidden_size
        et = self.tanh(self.attention_presoftmax(torch.cat((decoder_hidden, context_vec), dim=1))) # 1 x hidden_size
        return et, self.output_log_odds(et)

    def forward_train(self, input, target, teacher_forcing=True):
        """
        Compute the summed cross-entropy loss of decoding `target` from `input`.

        input: Whatever the encoder accepts; the encoder output must be the triple
               described on the class.
        target: The target should have dimensions, seq_len, and should be a LongTensor.
        teacher_forcing: feed the target token (True) or the model's own arg-max
                         prediction (False) as the next word input.
        Returns the loss summed over all target positions (the int 0 for an empty target).
        """
        annotations, decoder_hiddens, decoder_cell_states = self.encoder(input)
        attention_hidden_values = self._attention_keys(annotations)

        decoder_hiddens = decoder_hiddens.unsqueeze(1) # num_layers x 1 x hidden_size
        decoder_cell_states = decoder_cell_states.unsqueeze(1) # num_layers x 1 x hidden_size

        target_length, = target.size()
        SOS_token = Variable(self.SOS_token)

        word_input = self.embedding(SOS_token).squeeze(0) # 1 x embedding_size
        et = Variable(self.et)
        loss = 0

        for i in range(target_length):
            decoder_input = torch.cat((word_input, et), dim=1) # 1 x embedding_size + hidden_size
            decoder_hiddens, decoder_cell_states = self.decoder(decoder_input, (decoder_hiddens, decoder_cell_states))
            decoder_hidden = decoder_hiddens[-1]

            et, log_odds = self._attention_step(decoder_hidden, annotations, attention_hidden_values)

            # Slice instead of index: target[i] is 0-dimensional on current PyTorch, which
            # breaks both the loss (expects a length-1 target) and unsqueeze(1) below.
            current_target = target[i:i + 1] # 1
            loss += self.loss_func(log_odds, current_target)

            if teacher_forcing:
                next_input = current_target.unsqueeze(1) # 1 x 1
            else:
                _, next_input = log_odds.topk(1) # 1 x 1

            word_input = self.embedding(next_input).squeeze(1) # 1 x embedding size
        return loss

    def forward_prediction(self, input, maximum_length=150):
        """
        This is just an alias for point_wise_prediction, so that training code that assumes the
        presence of a forward_train and forward_prediction works.
        """
        return self.point_wise_prediction(input, maximum_length)

    def point_wise_prediction(self, input, maximum_length=150):
        """
        Greedy decoding: emit the arg-max token at every step.

        Returns a list of token indices; it stops after the end-of-sequence token
        (which is included) or after maximum_length tokens.
        """
        annotations, decoder_hiddens, decoder_cell_states = self.encoder(input)
        attention_hidden_values = self._attention_keys(annotations)

        decoder_hiddens = decoder_hiddens.unsqueeze(1) # num_layers x 1 x hidden_size
        decoder_cell_states = decoder_cell_states.unsqueeze(1) # num_layers x 1 x hidden_size
        SOS_token = Variable(self.SOS_token)

        word_input = self.embedding(SOS_token).squeeze(0) # 1 x embedding_size
        et = Variable(self.et)
        output_so_far = []

        for _ in range(maximum_length):
            decoder_input = torch.cat((word_input, et), dim=1) # 1 x embedding_size + hidden_size
            decoder_hiddens, decoder_cell_states = self.decoder(decoder_input, (decoder_hiddens, decoder_cell_states))
            decoder_hidden = decoder_hiddens[-1]

            et, log_odds = self._attention_step(decoder_hidden, annotations, attention_hidden_values)
            _, next_input = log_odds.topk(1)

            next_token = int(next_input)
            output_so_far.append(next_token)

            if next_token == self.EOS_value:
                break

            word_input = self.embedding(next_input).squeeze(1) # 1 x embedding size

        return output_so_far

    def beam_search_prediction(self, input, maximum_length=20, beam_width=5):
        """
        Beam-search decoding.

        Each hypothesis is a tuple
        (summed log-probability, tokens so far, still open?, [decoder input, hiddens, cell states]).
        Returns the token list of the highest-scoring hypothesis after at most
        maximum_length steps.
        """
        annotations, decoder_hiddens, decoder_cell_states = self.encoder(input)
        attention_hidden_values = self._attention_keys(annotations)

        decoder_hiddens = decoder_hiddens.unsqueeze(1) # num_layers x 1 x hidden_size
        decoder_cell_states = decoder_cell_states.unsqueeze(1) # num_layers x 1 x hidden_size
        SOS_token = Variable(self.SOS_token)

        word_input = self.embedding(SOS_token).squeeze(0) # 1 x embedding_size
        et = Variable(self.et)

        decoder_input = torch.cat((word_input, et), dim=1) # 1 x embedding_size + hidden_size

        # Start from ONE hypothesis. Seeding the beam with beam_width identical copies makes
        # every copy expand to the same candidates, so the beam only ever holds duplicates of
        # the greedy path.
        beam = [(0, [], True, [decoder_input, decoder_hiddens, decoder_cell_states])]

        for _ in range(maximum_length):
            # Nothing left to expand once every hypothesis has produced EOS.
            if not any(hypothesis[2] for hypothesis in beam):
                break

            candidates = []

            for score, tokens, is_open, state in beam:
                if not is_open:
                    # Finished hypotheses compete unchanged with the new candidates.
                    candidates.append((score, tokens, is_open, state))
                    continue

                decoder_input, decoder_hiddens, decoder_cell_states = state
                decoder_hiddens, decoder_cell_states = self.decoder(decoder_input, (decoder_hiddens, decoder_cell_states))
                decoder_hidden = decoder_hiddens[-1]

                et, log_odds = self._attention_step(decoder_hidden, annotations, attention_hidden_values)
                log_probs = self.log_softmax(log_odds.squeeze(0)) # nclasses

                # topk cannot return more entries than there are classes
                width = min(beam_width, log_probs.size(0))
                log_value, next_input = log_probs.topk(width) # width, width
                word_input = self.embedding(next_input.unsqueeze(1)) # width x 1 x embedding size
                decoder_input = torch.cat((word_input, et.unsqueeze(0).repeat(width, 1, 1)), dim=2)

                # Store decoder_input[k] (embedding + attentional vector). Storing only the
                # word embedding hands the decoder an input of the wrong width on the next step.
                candidates.extend((score + float(log_value[k]), tokens + [int(next_input[k])],
                                   int(next_input[k]) != self.EOS_value,
                                   [decoder_input[k], decoder_hiddens, decoder_cell_states])
                                  for k in range(width))

            # Ascending sort on the score only; keep the beam_width best.
            beam = sorted(candidates, key=lambda hypothesis: hypothesis[0])[-beam_width:]
        return beam[-1][1]

    def attention_logits(self, attention_hidden_values, decoder_hidden):
        """
        Unnormalised attention score of every annotation for the current decoder state.

        Returns a number_of_nodes x 1 tensor.
        """
        if self.align_type == 0:
            return self.attention_alignment_vector(self.tanh(self.attention_context(decoder_hidden) + attention_hidden_values))
        else:
            return (decoder_hidden * attention_hidden_values).sum(1).unsqueeze(1)

In [None]:
class Tree_to_Sequence_Attention_ANC_Model(Tree_to_Sequence_Attention_Model):
    """
    Attention decoder that emits the parameters of an ANC Controller instead of tokens.

    The decoder is run for M steps; every step produces a vector of length N + 3*R which
    is split row-wise into instruction (N rows), first argument (R rows), second
    argument (R rows) and output (R rows) before being handed to Controller.
    """
    def __init__(self, encoder, decoder, hidden_size, embedding_size, M, R,
                 alignment_size=50, align_type=1, N=11, t_max=10):
        # The 1 is for nclasses which is not used in this model.
        super(Tree_to_Sequence_Attention_ANC_Model, self).__init__(encoder, decoder, hidden_size, 1, embedding_size,
                                                                   alignment_size=alignment_size, align_type=align_type)
        # the initial registers all have value 0 with probability 1
        prob_dist = torch.zeros(R, M)
        prob_dist[:, 0] = 1

        self.register_buffer('initial_registers', prob_dist)

        self.M = M
        self.R = R
        self.N = N
        self.t_max = t_max

        # Learned first decoder input. Start from zeros rather than torch.Tensor(...), which
        # exposes uninitialised memory (possibly NaN/inf) unless re-initialised by the caller.
        self.initial_word_input = nn.Parameter(torch.zeros(1, N + 3*R))
        # Replaces the token classifier created by the parent constructor.
        self.output_log_odds = nn.Linear(hidden_size, N + 3*R)

    def forward(self, input):
        """
        Build a Controller from the encoded input.

        input: Whatever the encoder accepts. The output of the encoder should be a triple:
               the annotations (number_of_nodes x hidden_size), the hidden
               representations of the root (num_layers x hidden_size) and the cell
               states of the root (num_layers x hidden_size).
        Returns the constructed Controller.
        """
        annotations, decoder_hiddens, decoder_cell_states = self.encoder(input)
        # align_type 0: number_of_nodes x alignment_size, align_type 1: number_of_nodes x hidden_size
        if self.align_type <= 1:
            attention_hidden_values = self.attention_hidden(annotations)
        else:
            attention_hidden_values = annotations

        decoder_hiddens = decoder_hiddens.unsqueeze(1) # num_layers x 1 x hidden_size
        decoder_cell_states = decoder_cell_states.unsqueeze(1) # num_layers x 1 x hidden_size

        word_input = self.initial_word_input # 1 x N + 3*R
        et = Variable(self.et)

        output_words = []

        for _ in range(self.M):
            decoder_input = torch.cat((word_input, et), dim=1) # 1 x N + 3*R + hidden_size
            decoder_hiddens, decoder_cell_states = self.decoder(decoder_input, (decoder_hiddens, decoder_cell_states))
            decoder_hidden = decoder_hiddens[-1]

            attention_logits = self.attention_logits(attention_hidden_values, decoder_hidden)
            attention_probs = self.softmax(attention_logits) # number_of_nodes x 1
            context_vec = (attention_probs * annotations).sum(0).unsqueeze(0) # 1 x hidden_size
            et = self.tanh(self.attention_presoftmax(torch.cat((decoder_hidden, context_vec), dim=1)))
            # The raw output of this step is also the next step's word input.
            word_input = self.output_log_odds(et)

            output_words.append(word_input)

        controller_params = torch.stack(output_words, dim=2).squeeze(0) # N + 3*R x M

        # Use the sizes this model was built with, not same-named notebook globals.
        N, R = self.N, self.R
        instruction = controller_params[0:N]
        first_arg = controller_params[N:N+R]
        second_arg = controller_params[N+R:N+2*R]
        output = controller_params[N+2*R:N + 3*R]
        controller = Controller.Controller(first_arg=first_arg, second_arg=second_arg, output=output,
                                           instruction=instruction,
                                           initial_registers=Variable(self.initial_registers),
                                           multiplier = 1, correctness_weight=1, halting_weight=1,
                                           confidence_weight=0, efficiency_weight=0, t_max=self.t_max)

        return controller

    def compute_loss(self, controller, target):
        """
        Sum the controller's training loss over every example of one program.

        controller: The controller for an ANC.
        target: Indexable with three parallel sequences: target[0] the input memories,
                target[1] the expected output memories, and target[2] the masks that
                specify the area of memory where the output is.
        Returns the summed loss (the int 0 when there are no examples).
        """
        loss = 0
        input_memories = target[0]
        output_memories = target[1]
        output_masks = target[2]

        for i in range(len(input_memories)):
            loss += controller.forward_train(input_memories[i], (output_memories[i], output_masks[i]))

        return loss

In [None]:
class MultilayerLSTMCell(nn.Module):
    """
    A stack of nn.LSTMCell layers applied to a single time step.

    Layer i consumes the new hidden state of layer i-1 (the first layer consumes the input).
    """
    def __init__(self, input_size, hidden_sizes, num_layers, bias=True):
        """
        :param input_size: length of the input vector of the first layer
        :param hidden_sizes: an int (same hidden size for every layer) or a sequence
                             holding one hidden size per layer
        :param num_layers: number of stacked cells
        :param bias: forwarded to every nn.LSTMCell
        """
        super(MultilayerLSTMCell, self).__init__()
        self.lstm_layers = nn.ModuleList()

        if isinstance(hidden_sizes, int):
            hidden_sizes = [hidden_sizes] * num_layers

        # list(...) so that tuples are accepted as well as lists
        layer_sizes = [input_size] + list(hidden_sizes)

        for i in range(num_layers):
            curr_lstm = nn.LSTMCell(layer_sizes[i], layer_sizes[i+1], bias=bias)
            self.lstm_layers.append(curr_lstm)

    def initialize_forget_bias(self, bias_value):
        """
        Fill the second quarter of both bias vectors of every layer with bias_value.

        Does nothing for layers built with bias=False, which have no bias vectors.
        """
        for lstm_cell in self.lstm_layers:
            for bias in (lstm_cell.bias_ih, lstm_cell.bias_hh):
                if bias is None:
                    continue
                n = bias.size(0)
                start, end = n//4, n//2
                # Write through .data so autograd does not record the in-place fill
                bias.data[start:end].fill_(bias_value)

    def forward(self, input, past_states):
        """
        Advance every layer by one step.

        :param input: input of the first layer
        :param past_states: pair (hiddens, cell_states); iterating each yields the
                            previous state of one layer, in layer order
        :return (stacked new hidden states, stacked new cell states), one entry per layer
        """
        hiddens, cell_states = past_states
        result_hiddens, result_cell_states = [], []
        curr_input = input

        for lstm_cell, curr_hidden, curr_cell_state in zip(self.lstm_layers, hiddens, cell_states):
            # The new hidden state of this layer is the input of the next one
            curr_input, new_cell_state = lstm_cell(curr_input, (curr_hidden, curr_cell_state))
            result_hiddens.append(curr_input)
            result_cell_states.append(new_cell_state)

        return torch.stack(result_hiddens), torch.stack(result_cell_states)

In [None]:
num_vars = 10
num_ints = 11


def _bracketed_op_table(op_names):
    """Map every name to its position in op_names, keyed as "<NAME>" in upper case."""
    return {"<" + op_name.upper() + ">": index for index, op_name in enumerate(op_names)}


# Operators of the "for" language, in index order.
for_ops = _bracketed_op_table([
    "Var", "Const", "Plus", "Minus",
    "EqualFor", "LeFor", "GeFor",
    "Assign", "If", "Seq", "For",
])

# Operators of the "lambda" language, in index order.
lambda_ops = _bracketed_op_table([
    "Var", "Const", "Plus", "Minus",
    "EqualFor", "LeFor", "GeFor",
    "If", "Let", "Unit", "Letrec", "App",
])

# Tokens of the lambda-calculus language, in index order (already bracketed).
lambda_calculus_ops = {
    token: index
    for index, token in enumerate([
        "<VARIABLE>", "<ABSTRACTION>", "<NUMBER>", "<BOOLEAN>",
        "<NIL>", "<IF>", "<CONS>", "<MATCH>",
        "<UNARYOPER>", "<BINARYOPER>", "<LET>", "<LETREC>",
        "<TRUE>", "<FALSE>", "<TINT>", "<TBOOL>",
        "<TINTLIST>", "<TFUN>", "<ARGUMENT>", "<NEG>",
        "<NOT>", "<PLUS>", "<MINUS>", "<TIMES>",
        "<DIVIDE>", "<AND>", "<OR>", "<EQUAL>",
        "<LESS>", "<APPLICATION>", "<HEAD>", "<TAIL>",
    ])
}

In [None]:
# Switches shared by the dataset and the encoder built in the cells below.
input_as_seq = False
input_eos_token = False
use_embedding = True
binarize = True
long_base_case = True

# One extra encoder input slot, only when sequence input AND an input EOS token are both enabled.
eos_bonus = int(input_eos_token and input_as_seq)

In [None]:
# The dataset below is built from the "for"-language programs, not lambda calculus.
is_lambda_calculus = False

# Build the tree -> ANC dataset from the JSON file (path is relative to the working directory).
# The flags mirror the configuration cell above; cuda=True is hard-coded here.
# NOTE(review): presumably cuda=True places the tensors on the GPU - confirm in TreeANCDataset.
for_anc_dset = TreeANCDataset("ANC/Easy-arbitraryForListWithOutput.json", is_lambda_calculus, binarize=binarize, input_eos_token=input_eos_token, 
                              use_embedding=use_embedding, long_base_case=long_base_case, 
                              input_as_seq=input_as_seq, cuda=True)

In [None]:
def reset_all_parameters_uniform(model, stdev):
    """
    Re-initialise every parameter of `model` in place from the uniform distribution
    on [-stdev, stdev].

    :param model: the module whose parameters are overwritten
    :param stdev: half-width of the sampling interval (despite the name, a bound and
                  not a standard deviation)
    """
    for param in model.parameters():
        # uniform_ is the in-place initializer; the un-suffixed name is a deprecated alias
        nn.init.uniform_(param, -stdev, stdev)

In [None]:
embedding_size = 30
hidden_size = 30
num_layers = 1
alignment_size = 50
align_type = 1
# M decoder steps are generated per program; R and N size the per-step output (N + 3*R values).
M, R = 10, 3
N = 11
encoder_input_size = num_vars + num_ints + len(for_ops) + eos_bonus

if input_as_seq:
    encoder = SeqEncoder(encoder_input_size, hidden_size, num_layers, attention=True, use_embedding=use_embedding)
else:
    encoder = TreeEncoder(encoder_input_size, hidden_size, num_layers, [1, 2], attention=True, use_embedding=use_embedding)

# The decoder consumes the previous step's output (N + 3*R values) concatenated with the
# attentional vector (hidden_size values).
decoder = MultilayerLSTMCell(N + 3*R + hidden_size, hidden_size, num_layers)
# Pass N explicitly: the decoder width above depends on it, so relying on the model's
# default would silently break as soon as N is changed in this cell.
program_model = Tree_to_Sequence_Attention_ANC_Model(encoder, decoder, hidden_size, embedding_size, M, R,
                                                     alignment_size=alignment_size, align_type=align_type,
                                                     N=N)

# Uniform re-initialisation first, then the forget-gate biases so they are not overwritten.
reset_all_parameters_uniform(program_model, 0.1)
encoder.initialize_forget_bias(3)
decoder.initialize_forget_bias(3)

In [None]:
# Move the model's parameters and buffers to the GPU (requires a CUDA device).
program_model = program_model.cuda()

In [None]:
# Adam over every model parameter. The plateau scheduler multiplies the learning rate by 0.8
# once the monitored value has gone more than 100 scheduler steps without improving.
optimizer = optim.Adam(program_model.parameters(), lr=0.5)
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    factor=0.8,
    patience=100,
    verbose=True,
)

In [None]:
# Gradient sanity check on the first dataset item only (note the `break` at the end):
# build a controller, backpropagate its loss and print every parameter's gradient.
for prog, target in for_anc_dset:
    controller = program_model(prog)
    controller.cuda()
    loss = program_model.compute_loss(controller, target)
    loss.backward()
    
    # A gradient of None means the parameter did not take part in computing the loss.
    for name, param in program_model.named_parameters():
        print(name)
        print(param.grad)
    # Clear the gradients again so this check does not leak into the real training run.
    optimizer.zero_grad()
    break

In [None]:
# Development aid: re-execute the already-imported `training` module so that edits made
# to it on disk take effect without restarting the kernel.
import importlib
importlib.reload(training)

In [None]:
# Run the training loop for 10 epochs; the result is bound to `_` so the cell displays nothing.
_ = training.train_model_tree_to_anc(
    program_model,
    for_anc_dset,
    optimizer,
    lr_scheduler=lr_scheduler,
    num_epochs=10,
    batch_size=1,
    cuda=True,
    plateau_lr=True,
    print_every=100,
)

In [None]:
# Display the gradient currently stored on the encoder's embedding weights.
# Only valid when the encoder was built with use_embedding=True (otherwise it has no `embedding`).
program_model.encoder.embedding.weight.grad