In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import torch.utils.data as data

import numpy as np
import random
import matplotlib.pyplot as plt

import json
from translating_trees import *
from for_prog_dataset import ForDataset
from functools import partial

In [2]:
cd ..

/Users/ericweiner/Documents/neural_nets_research


In [3]:
from neural_nets_library import training

In [4]:
for_lambda_dset = ForDataset('ANC/Easy-arbitraryForList.json')

In [5]:
class TreeCell(nn.Module):
    """
    LSTM Cell which takes in arbitrary numbers of hidden and cell states (one per child).
    """
    def __init__(self, input_size, hidden_size, num_children):
        """
        Initialize the LSTM cell.
        
        :param input_size: length of input vector
        :param hidden_size: length of hidden vector (and cell state)
        :param num_children: number of children = number of hidden/cell states passed in
        """
        super(TreeCell, self).__init__()
        
        self.input_size = input_size
        self.hidden_size = hidden_size
        
        # Gates = input, output, memory + one forget gate per child
        numGates = 3 + num_children
        
        self.gates_value = torch.nn.ModuleList()
        self.gates_children = torch.nn.ModuleList()
        for _ in range(numGates):
            # One linear layer to handle the value of the node
            value_linear = nn.Linear(input_size, hidden_size, bias = True)
            children_linear = torch.nn.ModuleList()
            # One per child of the node
            for _ in range(num_children):
                children_linear.append(nn.Linear(hidden_size, hidden_size, bias = False))
            self.gates_value.append(value_linear)
            self.gates_children.append(children_linear)
            
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU()
        self.reset_parameters()
        
    def reset_parameters(self):
        stdev = 0.1
        
        for gate_value in self.gates_value:
            nn.init.uniform(gate_value.weight, -stdev, stdev)
            nn.init.uniform(gate_value.bias, -stdev, stdev)
        
        for gate_child in self.gates_children:
            for gate_in_child in gate_child:
                nn.init.uniform(gate_in_child.weight, -stdev, stdev)            
    
    def forward(self, input, hidden_states, cell_states):
        """
        Calculate a new hidden state and a new cell state from the LSTM gates
        
        :param hidden_states: A list of num_children hidden states.
        :param cell_states: A list of num_children cell states.
        :return A tuple containing (new hidden state, new cell state)
        """
        

        data_sums = []

        for i in range(len(self.gates_value)):
            data_sum = self.gates_value[i](input)
            for j in range(len(hidden_states)):
                data_sum += self.gates_children[i][j](hidden_states[j])
            data_sums.append(data_sum)
        
        # First gate is the input gate
        i = self.sigmoid(data_sums[0])
        # Next output gate
        o = self.sigmoid(data_sums[1])
        # Next memory gate
        m = self.tanh(data_sums[2])
        # All the rest are forget gates
        forget_data = 0
        for i in range(len(cell_states)):
            forget_data += data_sums[3 + i] * cell_states[i]

        # Put it all together!
        new_state = i * m + forget_data
        new_hidden = o * self.tanh(new_state)
        
                
        return new_hidden, new_state

In [6]:
class TreeLSTM(nn.Module):
    '''
    TreeLSTM

    Takes in a tree where each node has a value and a list of children.
    Produces a tree of the same size where the value of each node is now encoded.

    '''

    def __init__(self, input_size, hidden_size, valid_num_children):
        """
        Initialize tree cells we'll need later.
        """
        super(TreeLSTM, self).__init__()
        
        self.valid_num_children = [0] + valid_num_children
        self.lstm_list = torch.nn.ModuleList()
        
        for size in self.valid_num_children:
            self.lstm_list.append(TreeCell(input_size, hidden_size, size))
        
    def forward(self, node):
        """
        Creates a tree where each node's value is the encoded version of the original value.
        
        :param tree: a tree where each node has a value vector and a list of children
        :return a tuple - (root of encoded tree, cell state)
        """
        
        # List of tuples: (node, cell state)
        children = []
        
        # Recursively encode children
        for child in node.children:
            encoded_child = self.forward(child)
            children.append(encoded_child)

        # Extract the TreeCell inputs
        inputH = [vec[0].value for vec in children]
        inputC = [vec[1] for vec in children]

        value = node.value

        found = False
        
        # Feed the inputs into the TreeCell with the appropriate number of children.        
        for i in range(len(self.valid_num_children)):
            if self.valid_num_children[i] == len(children):
                newH, newC = self.lstm_list[i](value, inputH, inputC)
                found = True
                break
                
        if not found:
            print("WHAAAAAT?")
            raise ValueError("Beware.  Something has gone horribly wrong.  You may not have long to live.")
        
        # Set our encoded vector as the root of the new tree
        rootNode = Node(newH)
        rootNode.children = [vec[0] for vec in children]
        return (rootNode, newC)

In [7]:
class Encoder(nn.Module):
    """
    Takes in a tree where each node has a value vector and a list of children
    Produces a sequence encoding of the tree
    """

    def __init__(self, input_size, hidden_size, num_layers, valid_num_children, attention=True):
        """
        Initialize the TreeLSTMs we'll need later (one per layer)
        """
        super(Encoder, self).__init__()
        
        self.lstm_list = torch.nn.ModuleList()
        
        # All TreeLSTMs have input of hidden_size except the first.
        self.lstm_list.append(TreeLSTM(input_size, hidden_size, valid_num_children))
        for i in range(num_layers-1):
            self.lstm_list.append(TreeLSTM(hidden_size, hidden_size, valid_num_children))
        
        self.attention = attention

    def forward(self, tree):
        """
        Encodes nodes of a tree in the rows of a matrix.
        
        :param tree: a tree where each node has a value vector and a list of children
        :return a matrix where each row represents the encoded output of a single node and also
                the hidden/cell states of the root node.
        
        """
        hiddens = []
        cell_states = []
        
        for lstm in self.lstm_list:
            tree, cell_state = lstm(tree)
            hiddens.append(tree.value)
            cell_states.append(cell_state)
        
        
        hiddens = torch.stack(hiddens)
        cell_states = torch.stack(cell_states)
        
        if self.attention:
            return torch.stack(tree_to_list(tree)), hiddens, cell_states
        else:
            return hiddens, cell_states

In [8]:
'''
Decoder
'''
class Tree_to_Sequence_Model(nn.Module):
    """
      For the decoder this expects something like an lstm cell or a gru cell and not an lstm/gru.
      Batch size is not supported at all. More precisely the encoder expects an input that does not
      appear in batches, while the decoder must work with batches, but will only be used with a batch size
      of 1.
    """
    def __init__(self, encoder, decoder, hidden_size, nclass, embedding_size,
                 use_lstm=False):
        super(Tree_to_Sequence_Model, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        
        # nclass + 2 to include end of sequence and trash
        self.output_log_odds = nn.Linear(hidden_size, nclass+2)
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=0)
        self.log_softmax = nn.LogSoftmax(dim=0)

        self.register_buffer('SOS_token', torch.LongTensor([[nclass+2]]))
        self.EOS_value = nclass + 1
        self.i = 0

        # nclass + 3 to include start of sequence, end of sequence, and trash.
        # n + 2 - start of sequence, end of sequence - n + 1, trash - n.
        # The first n correspond to the alphabet in order.
        self.embedding = nn.Embedding(nclass+3, embedding_size)

        self.use_lstm = use_lstm
        # nclass is the trash category to avoid penalties after target's EOS token
        self.loss_func = nn.CrossEntropyLoss(ignore_index=nclass)
        self.reset_parameters()

    def reset_parameters(self):
        stdev = 0.1
        
        nn.init.uniform(self.output_log_odds.weight, -stdev, stdev)
        nn.init.uniform(self.output_log_odds.bias, -stdev, stdev)
        nn.init.uniform(self.embedding.weight, -stdev, stdev)
    
    """
        input: The output of the encoder for the input should be a pair. The first part
               should correspond to the hidden state of the root. The second part
               should correspond to the cell state of the root. They both should be
               [num_layers, hidden_size].
        target: The target should have dimension, seq_len, and should be a LongTensor.
    """
    def forward_train(self, input, target, teacher_forcing=True):
        # root hidden state/cell state
        decoder_hiddens, decoder_cell_states = self.encoder(input) # num_layers x hidden_size
        decoder_hiddens = decoder_hiddens.unsqueeze(1)
        decoder_cell_states = decoder_cell_states.unsqueeze(1)
                                                            
        num_layers, _, _ = decoder_hiddens.size()

        target_length, = target.size()
        SOS_token = Variable(self.SOS_token)
        decoder_input = self.embedding(SOS_token).squeeze(0) # 1 x embedding_size
        loss = 0

        for i in range(target_length):
            if self.use_lstm:               
                decoder_hiddens, decoder_cell_states = self.decoder(decoder_input, (decoder_hiddens, decoder_cell_states)) # num_layers x 1 x hidden_size
            else:
                decoder_hiddens = self.decoder(decoder_input, decoder_hiddens)

            decoder_hidden = decoder_hiddens[-1] # 1 x hidden_size
            log_odds = self.output_log_odds(decoder_hidden)

            loss += self.loss_func(log_odds, target[i])

            if teacher_forcing:
                next_input = target[i].unsqueeze(1)
            else:
                _, next_input = log_odds.topk(1)

            decoder_input = self.embedding(next_input).squeeze(1) # 1 x embedding_size
                
        return loss

    """
        This is just an alias for point_wise_prediction, so that training code that assumes the presence
        of a forward_train and forward_prediction works.
    """
    def forward_prediction(self, input, maximum_length=20):
        return self.point_wise_prediction(input, maximum_length)
    
    def point_wise_prediction(self, input, maximum_length=20):
        decoder_hiddens, decoder_cell_states = self.encoder(input)
        decoder_hiddens = decoder_hiddens.unsqueeze(1)
        decoder_cell_states = decoder_cell_states.unsqueeze(1)
        
        num_layers, _, _ = decoder_hiddens.size()
        SOS_token = Variable(self.SOS_token)

        decoder_input = self.embedding(SOS_token).squeeze(0) # 1 x embedding_size
        output_so_far = []

        for _ in range(maximum_length):
            if self.use_lstm:
                decoder_hiddens, decoder_cell_states = self.decoder(decoder_input, (decoder_hiddens, decoder_cell_states))
            else:
                decoder_hiddens = self.decoder(decoder_input, decoder_hiddens)

            decoder_hidden = decoder_hiddens[-1]
            log_odds = self.output_log_odds(decoder_hidden)

            _, next_input = log_odds.topk(1)
            output_so_far.append(int(next_input))
            
            if int(next_input) == self.EOS_value:
                break
                
            decoder_input = self.embedding(next_input).squeeze(1) # 1 x embedding size

        return output_so_far

    def beam_search_prediction(self, input, maximum_length=7, beam_width=5):
        decoder_hiddens, decoder_cell_states = self.encoder(input)
        decoder_hiddens = decoder_hiddens.unsqueeze(1)
        decoder_cell_states = decoder_cell_states.unsqueeze(1)
        
        num_layers, _, _ = decoder_hiddens.size()

        SOS_token = Variable(self.SOS_token)
        decoder_input = self.embedding(SOS_token).squeeze(0) # 1 x embedding_size
        word_inputs = []

        for _ in range(beam_width):
            if self.use_lstm:
                word_inputs.append((0, [], True, [decoder_input, decoder_hiddens, decoder_cell_states]))
            else:
                word_inputs.append((0, [], True, [decoder_input, decoder_hiddens]))

        for _ in range(maximum_length):
            new_word_inputs = []

            for i in range(beam_width):
                if not word_inputs[i][2]:
                    new_word_inputs.append(word_inputs[i])
                    continue

                if self.use_lstm:
                    decoder_input, decoder_hiddens, decoder_cell_states = word_inputs[i][3]
                    decoder_hiddens, decoder_cell_states = self.decoder(decoder_input, (decoder_hiddens, decoder_cell_states))
                else:
                    decoder_input, decoder_hiddens = word_inputs[i][3]
                    decoder_hiddens = self.decoder(decoder_input, decoder_hiddens)

                decoder_hidden = decoder_hiddens[-1]
                log_odds = self.output_log_odds(decoder_hidden).squeeze(0) # nclasses
                log_probs = self.log_softmax(log_odds)

                log_value, next_input = log_probs.topk(beam_width) # beam_width, beam_width
                decoder_input = self.embedding(next_input.unsqueeze(1)) # beam_width x 1 x embedding size

                if self.use_lstm:
                    new_word_inputs.extend((word_inputs[i][0] + float(log_value[k]), word_inputs[i][1] + [int(next_input[k])],
                                           int(next_input[k]) != self.EOS_value, [decoder_input[k], decoder_hiddens, decoder_cell_states])
                                           for k in range(beam_width))
                else:
                    new_word_inputs.extend((word_inputs[i][0] + float(log_value[k]), word_inputs[i][1] + [int(next_input[k])],
                                           int(next_input[k]) != self.EOS_value, [decoder_input[k], decoder_hiddens])
                                           for k in range(beam_width))
                    
        word_inputs = sorted(new_word_inputs, key=lambda word_input: word_input[0])[-beam_width:]
        return word_inputs[-1][1]

In [9]:
class Tree_to_Sequence_Attention_Model(Tree_to_Sequence_Model):
    def __init__(self, encoder, decoder, hidden_size, nclass, embedding_size,
                 alignment_size, decoder_cell_state_shape=None, use_lstm=False, use_bengio=False):
        super(Tree_to_Sequence_Attention_Model, self).__init__(encoder, decoder, hidden_size, nclass, embedding_size,
                                                               use_lstm=use_lstm)
        
        if use_bengio:
            self.attention_hidden = nn.Linear(hidden_size, alignment_size)
            self.attention_context = nn.Linear(hidden_size, alignment_size, bias=False)
            self.attention_alignment_vector = nn.Linear(alignment_size, 1)
        else:
            self.attention_hidden = nn.Linear(hidden_size, hidden_size)
        self.use_bengio = use_bengio
        self.reset_attention_parameters()
            
    def reset_attention_parameters(self):
        stdev = 0.1
        
        nn.init.uniform(self.attention_hidden.weight, -stdev, stdev)
        nn.init.uniform(self.attention_hidden.bias, -stdev, stdev)
        
        if self.use_bengio:
            nn.init.uniform(self.attention_context.weight, -stdev, stdev)
            nn.init.uniform(self.attention_alignment_vector.weight, -stdev, stdev)
            nn.init.uniform(self.attention_alignment_vector.bias, -stdev, stdev)
        
    """
        input: The output of the encoder for the tree should have be a triple. The first 
               part of the triple should be the annotations and have dimensions, 
               number_of_nodes x hidden_size. The second triple of the pair should be the hidden 
               representations of the root and should have dimensions, num_layers x hidden_size.
               The third part should correspond to the cell states of the root and should
               have dimensions, num_layers x hidden_size.
        target: The target should have dimensions, seq_len, and should be a LongTensor.
    """
    def forward_train(self, input, target, teacher_forcing=True):
        annotations, decoder_hiddens, decoder_cell_states = self.encoder(input)
        # w/ bengio number_of_nodes x alignment_size or w/o bengio number_of_nodes x hidden_size 
        attention_hidden_values = self.attention_hidden(annotations)
        
        decoder_hiddens = decoder_hiddens.unsqueeze(1) # num_layers x 1 x hidden_size
        decoder_cell_states = decoder_cell_states.unsqueeze(1) # num_layers x 1 x hidden_size

        target_length, = target.size()
        num_layers, _, _ = decoder_hiddens.size()
        SOS_token = Variable(self.SOS_token)

        word_input = self.embedding(SOS_token).squeeze(0) # 1 x embedding_size
        loss = 0

        for i in range(target_length):
            if self.use_bengio:
                attention_logits = self.attention_alignment_vector(self.tanh(self.attention_context(decoder_hiddens[0]) + attention_hidden_values))
            else:
                attention_logits = (decoder_hiddens[0] * attention_hidden_values).sum(1).unsqueeze(1)
            attention_probs = self.softmax(attention_logits) # number_of_nodes x 1
            context_vec = (attention_probs * annotations).sum(0).unsqueeze(0) # 1 x hidden_size
            decoder_input = torch.cat((word_input, context_vec), dim=1) # 1 x embedding_size + hidden_size


            if self.use_lstm:
                decoder_hiddens, decoder_cell_states = self.decoder(decoder_input, (decoder_hiddens, decoder_cell_states))
            else:
                decoder_hiddens = self.decoder(decoder_input, decoder_hiddens)

            decoder_hidden = decoder_hiddens[-1]
            log_odds = self.output_log_odds(decoder_hidden)
            loss += self.loss_func(log_odds, target[i])

            if teacher_forcing:
                next_input = target[i].unsqueeze(1)
            else:
                _, next_input = log_odds.topk(1)

            word_input = self.embedding(next_input).squeeze(1) # 1 x embedding size
        return loss
        

    """
        This is just an alias for point_wise_prediction, so that training code that assumes the presence
        of a forward_train and forward_prediction works.
    """
    def forward_prediction(self, input, maximum_length=20):
        return self.point_wise_prediction(input, maximum_length)
    
    def point_wise_prediction(self, input, maximum_length=20):
        annotations, decoder_hiddens, decoder_cell_states = self.encoder(input)
        attention_hidden_values = self.attention_hidden(annotations) # number_of_nodes x alignment_size
        
        decoder_hiddens = decoder_hiddens.unsqueeze(1) # num_layers x 1 x hidden_size
        decoder_cell_states = decoder_cell_states.unsqueeze(1) # num_layers x 1 x hidden_size
        
        num_layers, _, _ = decoder_hiddens.size()
        SOS_token = Variable(self.SOS_token)
        
        word_input = self.embedding(SOS_token).squeeze(0) # 1 x embedding_size
        output_so_far = []
        
        for i in range(maximum_length):
            if self.use_bengio:
                attention_logits = self.attention_alignment_vector(self.tanh(self.attention_context(decoder_hiddens[0]) + attention_hidden_values))
            else:
                attention_logits = (decoder_hiddens[0] * attention_hidden_values).sum(1).unsqueeze(1)
            attention_probs = self.softmax(attention_logits) # number_of_nodes x 1
            context_vec = (attention_probs * annotations).sum(0).unsqueeze(0) # 1 x hidden_size
            decoder_input = torch.cat((word_input, context_vec), dim=1) # 1 x embedding_size + hidden_size

            if self.use_lstm:
                decoder_hiddens, decoder_cell_states = self.decoder(decoder_input, (decoder_hiddens, decoder_cell_states))
            else:
                decoder_hiddens = self.decoder(decoder_input, decoder_hiddens)

            decoder_hidden = decoder_hiddens[-1]
            log_odds = self.output_log_odds(decoder_hidden)
            _, next_input = log_odds.topk(1)

            output_so_far.append(int(next_input))
            
            if int(next_input) == self.EOS_value:
                break
                
            word_input = self.embedding(next_input).squeeze(1) # 1 x embedding size

        return output_so_far

    def beam_search_prediction(self, input, maximum_length=20, beam_width=5):
        annotations, decoder_hiddens, decoder_cell_states = self.encoder(input)
        attention_hidden_values = self.attention_hidden(annotations) # w/ bengio number_of_nodes x alignment_size or w/o bengio number_of_nodes x hidden_size
        
        decoder_hiddens = decoder_hiddens.unsqueeze(1) # num_layers x 1 x hidden_size
        decoder_cell_states = decoder_cell_states.unsqueeze(1) # num_layers x 1 x hidden_size
        
        num_layers, _, _ = decoder_hiddens.size()
        SOS_token = Variable(self.SOS_token)
        
        word_input = self.embedding(SOS_token).squeeze(0) # 1 x embedding_size

        word_inputs = []

        for _ in range(beam_width):
            if self.use_lstm:
                word_inputs.append((0, [], True, [word_input, decoder_hiddens, decoder_cell_states]))
            else:
                word_inputs.append((0, [], True, [word_input, decoder_hiddens]))

        for _ in range(maximum_length):
            new_word_inputs = []

            for i in range(beam_width):
                if not word_inputs[i][2]:
                    new_word_inputs.append(word_inputs[i])
                    continue

                if self.use_lstm:
                    word_input, decoder_hiddens, decoder_cell_states = word_inputs[i][3]
                else:
                    word_input, decoder_hiddens = word_inputs[i][3]

                if self.use_bengio:
                    attention_logits = self.attention_alignment_vector(self.tanh(self.attention_context(decoder_hiddens[0]) + attention_hidden_values))
                else:
                    attention_logits = (decoder_hiddens[0] * attention_hidden_values).sum(1).unsqueeze(1)
                attention_probs = self.softmax(attention_logits) # number_of_nodes x 1
                context_vec = (attention_probs * annotations).sum(0).unsqueeze(0) # 1 x hidden_size
                decoder_input = torch.cat((word_input, context_vec), dim=1) # 1 x embedding_size + hidden_size

                if self.use_lstm:
                    decoder_hiddens, decoder_cell_states = self.decoder(decoder_input, (decoder_hiddens, decoder_cell_states))
                else:
                    decoder_hiddens = self.decoder(decoder_input, decoder_hiddens)

                decoder_hidden = decoder_hiddens[-1]
                log_odds = self.output_log_odds(decoder_hidden).squeeze(0) # nclasses
                log_probs = self.log_softmax(log_odds)

                log_value, next_input = log_probs.topk(beam_width) # beam_width, beam_width
                word_input = self.embedding(next_input.unsqueeze(1)) # beam_width x 1 x embedding size

                if self.use_lstm:
                    new_word_inputs.extend((word_inputs[i][0] + float(log_value[k]), word_inputs[i][1] + [int(next_input[k])],
                                           int(next_input[k]) != self.EOS_value, [word_input[k], decoder_hiddens, decoder_cell_states])
                                           for k in range(beam_width))
                else:
                    new_word_inputs.extend((word_inputs[i][0] + float(log_value[k]), word_inputs[i][1] + [int(next_input[k])],
                                           int(next_input[k]) != self.EOS_value, [word_input[k], decoder_hiddens])
                                           for k in range(beam_width))
            word_inputs = sorted(new_word_inputs, key=lambda word_input: word_input[0])[-beam_width:]

        return word_inputs[-1][1]

In [10]:
class MultilayerLSTMCell(nn.Module):
    def __init__(self, input_size, hidden_sizes, num_layers, bias=True):
        super(MultilayerLSTMCell, self).__init__()
        self.lstm_layers = nn.ModuleList()
        
        if isinstance(hidden_sizes, int):
            temp = []
            
            for _ in range(num_layers):
                temp.append(hidden_sizes)
            
            hidden_sizes = temp
            
        hidden_sizes = [input_size] + hidden_sizes
        
        for i in range(num_layers):
            curr_lstm = nn.LSTMCell(hidden_sizes[i], hidden_sizes[i+1], bias=bias)
            self.lstm_layers.append(curr_lstm)
    
        self.reset_parameters()
    
    def reset_parameters(self):
        stdev = 0.1
        
        for lstm_cell in self.lstm_layers:
            nn.init.uniform(lstm_cell.weight_ih, -stdev, stdev)
            nn.init.uniform(lstm_cell.weight_hh, -stdev, stdev)
            nn.init.uniform(lstm_cell.bias_ih, -stdev, stdev)
            nn.init.uniform(lstm_cell.bias_hh, -stdev, stdev)
    
    def forward(self, input, past_states):
        hiddens, cell_states = past_states
        result_hiddens, result_cell_states = [], []
        curr_input = input
        
        for lstm_cell, curr_hidden, curr_cell_state in zip(self.lstm_layers, hiddens, cell_states):
            curr_input, new_cell_state = lstm_cell(curr_input, (curr_hidden, curr_cell_state))
            result_hiddens.append(curr_input)
            result_cell_states.append(new_cell_state)
        
        return torch.stack(result_hiddens), torch.stack(result_cell_states)

In [11]:
num_vars = 10
num_ints = 11

for_ops = {
    "Var": 0,
    "Const": 1,
    "Plus": 2,
    "Minus": 3,
    "EqualFor": 4,
    "LeFor": 5,
    "GeFor": 6,
    "Assign": 7,
    "If": 8,
    "Seq": 9,
    "For": 10
}

for_ops = {"<" + k.upper() + ">": v for k,v in for_ops.items()}

lambda_ops = {
    "Var": 0,
    "Const": 1,
    "Plus": 2,
    "Minus": 3,
    "EqualFor": 4,
    "LeFor": 5,
    "GeFor": 6,
    "If": 7,
    "Let": 8,
    "Unit": 9,
    "Letrec": 10,
    "App": 11
}

lambda_ops = {"<" + k.upper() + ">": v for k,v in lambda_ops.items()}

In [12]:
embedding_size = 256
hidden_size = 256
nclass = num_vars + num_ints + len(lambda_ops.keys())
num_layers = 1
attention = True
alignment_size = 50

encoder = Encoder(num_vars + num_ints + len(for_ops.keys()), hidden_size, num_layers, [1, 2], attention=attention)

if attention:
    decoder = MultilayerLSTMCell(embedding_size + hidden_size, hidden_size, num_layers)
    program_model = Tree_to_Sequence_Attention_Model(encoder, decoder, hidden_size, nclass, embedding_size, alignment_size, use_lstm=True, use_bengio=False)
else:
    decoder = MultilayerLSTMCell(embedding_size, hidden_size, num_layers)
    program_model = Tree_to_Sequence_Model(encoder, decoder, hidden_size, nclass, embedding_size, use_lstm=True)

In [None]:
program_model = program_model.cuda()

In [14]:
def program_accuracy(prediction, target):
    return 1 if list(target.data)[:-1] == prediction else 0

def token_accuracy(prediction, target):
    pass

optimizer = torch.optim.SGD(program_model.parameters(), lr=0.0005, momentum=0.9)

In [None]:
program_model, train_losses, validation_losses = \
    training.train_model_anc(program_model, for_lambda_dset, optimizer, 
                             lr_scheduler=partial(training.exp_lr_scheduler, init_lr=0.0005, lr_decay_epoch=1), 
                             num_epochs=10, validation_criterion=program_accuracy, batch_size=100, 
                             use_cuda=True)

Epoch 0/9
----------
LR is set to 0.0005
Epoch Number: 0, Batch Number: 200, Validation Metric: 0.0000
Epoch Number: 0, Batch Number: 200, Training Loss: 86.3152
Time so far is 0m 25s
