In [248]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import torch.utils.data as data
import numpy as np
import random
import matplotlib.pyplot as plt
import json

In [249]:
class TreeCell(nn.Module):
    """
    Insert Useful Comments
    """
    def __init__(self, input_size, hidden_size, num_children):
        super(TreeCell, self).__init__()
        
        self.input_size = input_size
        self.hidden_size = hidden_size
        
        # Gates = input, output, memory + one forget gate per child
        numGates = 3 + num_children
        
        self.gates_value = torch.nn.ModuleList()
        self.gates_children = torch.nn.ModuleList()
        for _ in range(numGates):
            # One linear layer to handle the value of the node
            value_linear = nn.Linear(input_size, hidden_size, bias = True)
            children_linear = torch.nn.ModuleList()
            # One per child of the node
            for _ in range(num_children):
                children_linear.append(nn.Linear(hidden_size, hidden_size, bias = False))
            self.gates_value.append(value_linear)
            self.gates_children.append(children_linear)
            
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU()
        
    def forward(self, input, hidden_states, cell_states):
        """
        Calculate a new hidden state and a new cell state from the LSTM gates
        
        :param hidden_states: A list of 2 hidden states.
        :param cell_states: A list of 2 cell states.
        :return A tuple containing (new hidden state, new cell state)
        """
        
        data_sums = []
        for i in range(len(self.gates_value)):
            data_sum = self.gates_value[i](input)
            for j in range(len(hidden_states)):
                data_sum += self.gates_children[i][j](hidden_states[j])
            data_sums.append(data_sum)
            
        # First gate is the input gate
        i = self.sigmoid(data_sums[0])
        # Next output gate
        o = self.sigmoid(data_sums[1])
        # Next memory gate
        m = self.tanh(data_sums[2])
        # All the rest are forget gates
        forget_data = 0
        for i in range(len(cell_states)):
            forget_data += data_sums[3 + i] * cell_states[i]
        
        # Put it all together!
        new_state = i * m + forget_data
        new_hidden = o * self.tanh(new_state)
                
        return new_hidden, new_state

In [250]:
class MultilayerTreeCell(nn.Module):
    def __init__(self, input_size, hidden_sizes, num_children, num_layers, bias=True):
        super(MultilayerTreeCell, self).__init__()
        self.lstm_layers = nn.ModuleList()
        
        if isinstance(hidden_sizes, int):
            temp = [input_size]
            
            for _ in range(num_layers):
                temp.append(hidden_sizes)
            
            hidden_sizes = temp
            
        else:
            hidden_sizes = [input_size] + hidden_sizes
        self.tree_cell = TreeCell(input_size, hidden_sizes[0], num_children)
        for i in range(1, num_layers):
            curr_lstm = nn.LSTMCell(hidden_sizes[i], hidden_sizes[i+1], bias=bias)
            self.lstm_layers.append(curr_lstm)
    
    def forward(self, input, hiddens, cell_states):
        result_hiddens, result_cell_states = [], []
        curr_input, new_cell_state = self.tree_cell(input, hiddens[0], cell_states[0])
        result_hiddens.append(curr_input)
        result_cell_states.append(new_cell_state)
        for lstm_cell, curr_hidden, curr_cell_state in zip(self.lstm_layers, hiddens[1:], cell_states[1:]):
            curr_input, new_cell_state = lstm_cell(curr_input, (curr_hidden, curr_cell_state))
            result_hiddens.append(curr_input)
            result_cell_states.append(new_cell_state)
        
        return result_hiddens, result_cell_states

In [251]:
'''
Encoder

Takes in a Tree where each node has a value (vector?) and a list of children
Produces a vector of desired size with an encoding of the tree

Recursively: For each node go left *then go middle* then go right, pass the necessary values
into an lstm cell along with values from each child (0 if at leaf) Output the result of the lstm cell
at the root.

'''
class Encoder(nn.Module):
    """
    Takes in a tree where each node has a value vector and a list of children
    Produces a sequence encoding of the tree
    """

    def __init__(self, input_size, hidden_size, numLayers, valid_num_children):
        """
        Initialize variables we'll need later.
        """
        super(Encoder, self).__init__()
        
        test = [0] + valid_num_children
        temp = torch.IntTensor(test)
        self.valid_num_children = Variable(temp)
        self.numLayers = numLayers
        self.lstm_list = torch.nn.ModuleList()
        # We'll always need 0 for leaf nodes
        #self.lstm_list.append(MultilayerTreeCell(input_size, hidden_size, 0, numLayers))
        self.lstm_list.append(TreeCell(input_size, hidden_size, 0))
        
        for size in valid_num_children:
   #         self.lstm_list.append(MultilayerTreeCell(input_size, hidden_size, size, numLayers))
            self.lstm_list.append(TreeCell(input_size, hidden_size, size))

        self.encoding = Variable(torch.FloatTensor(1, hidden_size))
        
    def forward(self, tree):
        """
        Starts off the entire encoding process
        
        :param tree: a tree where each node has a value vector and a list of children
        :return self.encoding, a matrix where each row represents the encoded output of a single node
        """
        embeddings = [p[0] for p in self.encode(tree)]
        return torch.cat(embeddings, 0)
        
    def encode(self, node):
        """
        Recursively a node and all its children as sequence vectors
        
        :param node: The root of the tree (or subtree)
        :return A tuple (new hidden vector, new cell state).  The new hidden vector is an endoding of node
        """
        
        # List of tuples: (h, c), each of which are size hidden_size
        descendents = []
        children = []
        
        for child in node.children:
            current_descendents = self.encode(child)
            descendents += current_descendents
            children.append(current_descendents[-1])

        # Vector of size input_size x len(children)
        inputH = [vec[0] for vec in children]
        inputC = [vec[1] for vec in children]
        
        
        value = Variable(node.value.unsqueeze(0))
        
        found = False
        for i in range(len(self.lstm_list)):
            if self.valid_num_children[i].data[0] == len(children):
                newH, newC = self.lstm_list[i](value, inputH, inputH)
                found = True
                break
        if not found:
            print("WHAAAAAT?")
            raise ValueError("Beware.  Something has gone horribly wrong.  You may not have long to live.")
            
        # Add the new encoding to the end of our list
        descendents.append((newH, newC))
        return descendents
    

        

In [252]:
input_size = 4
hidden_size = 5

test_vec = torch.FloatTensor(input_size)


class Node:
    """
    Node class just made for testing
    """
    def __init__(self, value):
        self. value = value
        self.children = []
    

child_len = [2, 3, 2, 3]   
def makeNodes(children):
    """
    Loop through the passes-in array and build a tree where each node in the i^th layer has children[i] nodes

    """
    if len(children) == 0:
        return Node(test_vec) # Make them all the same vec
    else: 
        newNode = Node(test_vec)
        for i in range(children[0]):
            newNode.children.append(makeNodes(children[1:]))
        return newNode 


In [253]:
jsonString = "{\"tag\":\"If\",\"contents\":[{\"tag\":\"GeFor\",\"contents\":[{\"tag\":\"Const\",\"contents\":5},{\"tag\":\"Const\",\"contents\":3}]},{\"tag\":\"Assign\",\"contents\":[\"X\",{\"tag\":\"Const\",\"contents\":1}]},{\"tag\":\"Assign\",\"contents\":[\"Y\",{\"tag\":\"Const\",\"contents\":2}]}]}"
jsonObj = json.loads(jsonString)
for_list_json = json.load(open('/Users/ericweiner/Documents/neural_nets_research/ANC/arbitraryForList.json'))
num_vars = 5
num_ints = 7
for_ops = {
    "Var": 0,
    "Const": 1,
    "Plus": 2,
    "Minus": 3,
    "EqualFor": 4,
    "LeFor": 5,
    "GeFor": 6,
    "Assign": 7,
    "If": 8,
    "Seq": 9,
    "For": 10
}
var_dict = {}

def vectorize(val):
    vector = torch.zeros(num_vars + num_ints + len(for_ops.keys()))
    if type(val) is int:
        vector[val] = 1
    elif type(val) is str:
        index = len(var_dict.keys())
        if val in var_dict:
            index = var_dict[val]
        else:
            var_dict[val] = index
        vector[index + num_ints] = 1
    else:
        index = for_ops[val]
        vector[num_ints + num_vars + index] = 1
    return vector
            
        

def makeTree(json):
    if type(json) is str:
        parentNode = Node(vectorize("Var"))
        childNode = Node(vectorize(json))
        parentNode.children.append(childNode)
        return parentNode 
    
    if type(json) is int:
        return Node(vectorize(json))

    tag = json["tag"]
    children = json["contents"]
    parentNode = Node(vectorize(tag))
    
    currNode = parentNode
    
    if type(children) is list:
        for child in children:
            newChild = makeTree(child)
            currNode.children.append(newChild)
            currNode = newChild
    else:
        parentNode.children.append(makeTree(children))
        
    return parentNode



def printTree(tree):
    print(tree.value)
    for child in tree.children:
        printTree(child)
    
    
tree = makeTree(jsonObj)
numLayers = 3
encoder = Encoder(num_vars + num_ints + len(for_ops.keys()), hidden_size, numLayers,[1,2,3])
encoded_vec = encoder(tree)
print("ENCODEDVEC", encoded_vec)


# kangaroo

ENCODEDVEC Variable containing:
 2.2863e-02  6.5662e-03  1.1763e-02  2.8647e-02 -5.4916e-02
 4.6569e-02  7.1402e-03 -1.0319e-02  3.8583e-02 -4.6100e-02
 3.2813e-03 -1.0021e-03  2.1564e-05 -2.7582e-03  7.4565e-03
-1.8096e-02  2.1923e-02 -8.4839e-02  1.5602e-01  7.5357e-02
 5.2191e-02  2.9534e-02  4.0366e-03  3.6444e-02  1.3136e-02
-1.4336e-02 -1.9640e-02  5.6897e-02  6.3062e-02 -3.2052e-03
-9.9474e-04  2.1714e-03 -6.6700e-04 -3.6863e-03  5.0313e-04
 7.9014e-02  5.2478e-02 -1.0640e-01  2.1987e-02 -5.6953e-02
-5.9615e-03  3.0944e-03  3.8092e-02  5.4920e-02 -4.1653e-02
-1.2267e-02  4.1403e-02  1.2829e-02  6.8995e-02 -5.5225e-02
-7.4902e-04 -4.8250e-03 -4.1882e-06 -4.7838e-03  9.7149e-03
 8.6297e-02  4.4594e-02 -9.6374e-02  1.1192e-02 -4.9746e-02
-1.6625e-03 -3.9217e-04  5.3452e-03 -3.5707e-04 -1.6696e-03
 5.3520e-02  5.7091e-02  9.6964e-02  1.7505e-01  6.0116e-02
 1.2860e-01  1.1356e-01  5.5452e-02  1.0455e-01 -1.8374e-03
 1.3199e-02 -1.5690e-02  4.5331e-04 -6.2073e-03  1.1502e-05
[torch.F

In [254]:
'''
Decoder
'''
class Tree_to_Sequence_Model(nn.Module):
    """
      For the decoder this expects something like an lstm cell or a gru cell and not an lstm/gru.
      If you don't use teacher forcing, batches are forbidden/lead to bad behavior
    """
    def __init__(self, encoder, decoder, hidden_size, nclass, embedding_size,
                 use_lstm=False, use_cuda=True):
        super(Tree_to_Sequence_Model, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        
        # nclass + 2 to include end of sequence and trash
        self.output_log_probs = nn.Linear(hidden_size, nclass+2)
        self.softmax = nn.Softmax(dim=0)

        self.SOS_token = Variable(torch.LongTensor([[0]]))
        self.EOS_value = 1

        if use_cuda:
            self.SOS_token = self.SOS_token.cuda()

        self.embedding = nn.Embedding(nclass, embedding_size)

        self.use_lstm = use_lstm
        # nclass + 1 is the trash category to avoid penalties after target's EOS token
        self.loss_func = nn.CrossEntropyLoss(ignore_index=nclass+1)

        if use_lstm:
            self.decoder_initial_cell_state = Variable(torch.zeros(1, hidden_size))

    def forward_train(self, input, target, use_teacher_forcing=False):
        # encoded features
        encoded_features = Variable(self.encoder(input)) # [w, c]
        decoder_hidden = encoded_features[-1, :]

        target_length = target.size()[0]
        decoder_input = self.embedding(self.SOS_token).transpose(0,1).repeat(1, batch_size, 1)
        loss = 0

        if self.use_lstm:
            decoder_cell_state = self.decoder_initial_cell_state

        for i in range(target_length):
            if self.use_lstm:
                decoder_output, (decoder_hidden, decoder_cell_state) = self.decoder(decoder_input, (decoder_hidden, decoder_cell_state))
            else:
                decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)

            log_probs = self.output_log_probs(decoder_output)
            loss += self.loss_func(log_probs, target[i, :])

            if use_teacher_forcing:
                decoder_input = self.embedding(target[i, :].unsqueeze(1)).squeeze(1)
            else:
                _, topi = log_probs.data.topk(1)
                ni = topi[0, 0]

                if ni == self.EOS_value:
                    break

                decoder_input = self.embedding(Variable([ni]).unsqueeze(1)).squeeze(1)

        return loss

#     """
#       Inputs must be of batch size 1
#     """
#     def point_wise_prediction(self, input, maximum_length=20):
#       # encoded features
#       encoded_features = self.encoder(input).squeeze(1) # [w, c]
#       decoder_hidden = encoded_features[-1, :]
#       decoder_input = self.embedding(self.SOS_token).squeeze(0)
#       output_so_far = []

#       if self.use_lstm:
#           decoder_cell_state = self.decoder_initial_cell_state

#       for i in range(maximum_length):
#           if self.use_lstm:
#               decoder_output, (decoder_hidden, decoder_cell_state) = self.decoder(decoder_input, (decoder_hidden, decoder_cell_state))
#           else:
#               decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)

#           log_probs = self.output_log_probs(decoder_output)

#           _, topi = log_probs.data.topk(1)
#           ni = topi[0, 0]

#           if ni == self.EOS_value:
#               break

#           output_so_far.append(ni)
#           decoder_input = self.embedding(Variable([ni]).unsqueeze(1)).squeeze(1)

#           if self.use_cuda:
#               decoder_input = decoder_input.cuda()

#       return output_so_far


#     def beam_search_prediction(self, input, maximum_length=20):
#       pass



In [371]:
class Tree_to_Sequence_Attention_Model(Tree_to_Sequence_Model):
    def __init__(self, encoder, decoder, num_decoder_layers, hidden_size, nclass, embedding_size,
                 alignment_size, decoder_cell_state_shape=None, use_lstm=False, use_cuda=True):
        super(Tree_to_Sequence_Attention_Model, self).__init__(encoder, decoder, hidden_size, nclass, embedding_size,
                                                               use_lstm=use_lstm, use_cuda=use_cuda)
        self.attention_hidden = nn.Linear(hidden_size, alignment_size)
        self.attention_context = nn.Linear(hidden_size, alignment_size, bias=False)
        self.tanh = nn.Tanh()
        self.attention_alignment_vector = nn.Linear(alignment_size, 1)
        self.num_decoder_layers = num_decoder_layers

    """
        input: The output of the encoder for the input should have dimensions, (seq_len x batch_size x input_size)
        target: The target should have dimensions, seq_len, and should be a LongTensor.
    """
    def forward_train(self, input, target, use_teacher_forcing=False):
        encoded_features = self.encoder(input) # sequence x encoder hidden
        attention_hidden_values = self.attention_hidden(encoded_features)

        decoder_hiddens = [encoded_features[-1, :].unsqueeze(0) for _ in range(num_decoder_layers)]
        decoder_hidden = decoder_hiddens[-1]
        target_length = target.size()[0] # sequence x decoder hidden
        word_input = self.embedding(self.SOS_token).squeeze(0) # batch x decoder hidden

        loss = 0
        
        if self.use_lstm:
            decoder_cell_states = [self.decoder_initial_cell_state for x in range(num_decoder_layers)]

        for i in range(target_length):
            attention_logits = self.attention_alignment_vector(self.attention_context(decoder_hidden) + attention_hidden_values).squeeze(1)
            attention_probs = self.softmax(attention_logits) # w
            context_vec = (attention_probs.unsqueeze(1) * encoded_features).sum(0).unsqueeze(0) # decoder hidden 
            decoder_input = torch.cat((word_input, context_vec), dim=1)

            if self.use_lstm:     
                decoder_hiddens, decoder_cell_states = self.decoder(decoder_input, decoder_hiddens, decoder_cell_states)
            else:
                decoder_hidden = self.decoder(decoder_input, decoder_hidden)
            decoder_hidden = decoder_hiddens[-1]
            log_probs = self.output_log_probs(decoder_hidden)
            loss += self.loss_func(log_probs, target[i])

            if use_teacher_forcing:
                next_input = target[i].unsqueeze(0)
            else:
                _, next_input = log_probs.topk(1)
                if next_input.data[0,0] == self.EOS_value:
                    break

            word_input = self.embedding(next_input).squeeze(0) # batch x decoder hidden

        return loss


    """
      Inputs must be of batch size 1
    """
    def point_wise_prediction(self, input, maximum_length=20):
        
        encoded_features = self.encoder(input) # sequence x encoder hidden
        attention_hidden_values = self.attention_hidden(encoded_features)
        decoder_hiddens = [encoded_features[-1, :].unsqueeze(0) for _ in range(num_decoder_layers)]
        decoder_hidden = decoder_hiddens[-1]
        word_input = self.embedding(self.SOS_token).squeeze(0) # batch x decoder hidden

        output_so_far = []

        if self.use_lstm:
            decoder_cell_states = [self.decoder_initial_cell_state for x in range(num_decoder_layers)]

            
        for i in range(maximum_length):
            attention_logits = self.attention_alignment_vector(self.attention_context(decoder_hidden) + attention_hidden_values).squeeze(1)
            attention_probs = self.softmax(attention_logits) # w
            context_vec = (attention_probs.unsqueeze(1) * encoded_features).sum(0).unsqueeze(0) # decoder hidden 
            decoder_input = torch.cat((word_input, context_vec), dim=1)

            if self.use_lstm:     
                decoder_hiddens, decoder_cell_states = self.decoder(decoder_input, decoder_hiddens, decoder_cell_states)
            else:
                decoder_hidden = self.decoder(decoder_input, decoder_hidden)
            decoder_hidden = decoder_hiddens[-1]
            log_probs = self.output_log_probs(decoder_hidden)
            

            _, next_input = log_probs.topk(1)
            output_so_far.append(next_input[0][0])
            
            if next_input.data[0, 0] == self.EOS_value:
                break

            word_input = self.embedding(next_input).squeeze(0) # batch x hidden

        return output_so_far

    def beam_search_prediction(self, input, maximum_length=20, beam_width=5):
        # encoded features
        encoded_features = self.encoder(input) # sequence x hidden
        attention_hidden_values = self.attention_hidden(encoded_features)

        decoder_hiddens = [encoded_features[-1, :].unsqueeze(0) for _ in range(num_decoder_layers)]
        decoder_hidden = decoder_hiddens[-1]
        word_inputs = []
        for _ in range(beam_width):
            word_inputs.append((0, [self.SOS_token], True))

        if self.use_lstm:
            decoder_cell_states = [self.decoder_initial_cell_state for x in range(num_decoder_layers)]

        for i in range(maximum_length):
            attention_logits = self.attention_alignment_vector(self.attention_context(decoder_hidden) + attention_hidden_values).squeeze(1)
            attention_probs = self.softmax(attention_logits) # w
            context_vec = (attention_probs.unsqueeze(1) * encoded_features).sum(0).unsqueeze(0) # hidden 
            
            attention_logits = self.attention_alignment_vector(self.attention_context(decoder_hidden) + attention_hidden_values).squeeze(1)
            attention_probs = self.softmax(attention_logits) # w
            context_vec = (attention_probs.unsqueeze(1) * encoded_features).sum(0).unsqueeze(0) # decoder hidden 


            
            
            newWordInputs = []
            
            for j in range(beam_width):
                if not word_inputs[j][2]:
                    newWordInputs.append(word_inputs[j])
                    continue
                    
                word_input = self.embedding(word_inputs[j][1][-1]).squeeze(0)
                decoder_input = torch.cat((word_input, context_vec), dim=1)

                if self.use_lstm:
                    decoder_hiddens, decoder_cell_states = self.decoder(decoder_input, decoder_hiddens, decoder_cell_states)
 #               else:
 #                   decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
                              
                
                
                log_probs = self.output_log_probs(decoder_hiddens[-1]).squeeze(0)
                value, index = self.softmax(log_probs).topk(beam_width)
                log_value = value.log()
                print("index i", index[i].unsqueeze(0))
                print("wordinputs beamsize*1 -1", word_inputs[-1][1])
                print("combined", word_inputs[-1][1] + [index[j].unsqueeze(0)])
                newWordInputs.extend((word_inputs[k][0] + log_value[k], word_inputs[k][1] + [index[k].unsqueeze(0)], 
                                      (index[k].data == self.EOS_value)[0]) for k in range(beam_width))
            #Get the new words in the beam.
            word_inputs = sorted(newWordInputs, key=lambda word_pair: word_pair[1][-1].data[0][0])[-beam_width:]
        outputs = [word_inputs[-i][1] for i in range(beam_width)]
        return outputs
    
    def train_with_validation(self, train_loader, validation_loader,
                                optimizer, lr_scheduler, num_fake_batches=10, num_epochs=20, use_teacher_forcing=False):
        """
        Trains a model while printing updates on loss and accuracy. Once training is complete,
        it is tested on the validation data set.
        """
        since = time.time()

        best_model = model
        best_acc = 0.0

        for epoch in range(num_epochs):
            print('Epoch {}/{}'.format(epoch, num_epochs - 1))
            print('-' * 10)

            optimizer = lr_scheduler(optimizer, epoch)

            running_loss = 0.0

            current_batch = 0
            loss = Variable(torch.FloatTensor([0]))
            # Iterate over data.
            for in_tree, out_seq in train_loader:
                current_batch += 1

                # wrap them in Variable
                input_tree, expected_output_seq = Variable(in_tree), \
                                 Variable(out_seq)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                loss += self.forward_train(input, expected_output_seq, use_teacher_forcing)

                # backward
                if current_batch % num_fake_batches == 0:
                    loss.backward()
                    optimizer.step()

                # statistics
                running_loss += loss.data[0]


                if current_batch % 250 == 0:
                    curr_loss = running_loss / (current_batch * train_loader.batch_size)
                    time_elapsed = time.time() - since

                    print('Epoch Number: {}, Batch Number: {}, Loss: {:.4f}, Acc: {:.4f}'.format(
                        epoch, current_batch, curr_loss, curr_acc))
                    print('Time so far is {:.0f}m {:.0f}s'.format(
                        time_elapsed // 60, time_elapsed % 60))



            validation_acc = test_model(model, validation_loader)
            print('Epoch Number: {}, Validation Accuracy: {:.4f}'.format(epoch, validation_acc))
            print()

        time_elapsed = time.time() - since
        print('Training complete in {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))



In [372]:
class MultilayerLSTMCell(nn.Module):
    def __init__(self, input_size, hidden_sizes, num_layers, bias=True):
        super(MultilayerLSTMCell, self).__init__()
        self.lstm_layers = nn.ModuleList()
        
        if isinstance(hidden_sizes, int):
            temp = [input_size]
            
            for _ in range(num_layers):
                temp.append(hidden_sizes)
            
            hidden_sizes = temp
            
        else:
            hidden_sizes = [input_size] + hidden_sizes
        for i in range(num_layers):
            curr_lstm = nn.LSTMCell(hidden_sizes[i], hidden_sizes[i+1], bias=bias)
            self.lstm_layers.append(curr_lstm)
    
    def forward(self, input, hiddens, cell_states):
        result_hiddens, result_cell_states = [], []
        curr_input = input
        
        for lstm_cell, curr_hidden, curr_cell_state in zip(self.lstm_layers, hiddens, cell_states):
            curr_input, new_cell_state = lstm_cell(curr_input, (curr_hidden, curr_cell_state))
            result_hiddens.append(curr_input)
            result_cell_states.append(new_cell_state)
        
        return result_hiddens, result_cell_states

In [373]:
embedding_size = 50
hidden_size = 5
nclass = 100
alignment_size = 30
num_decoder_layers = 3

decoder = MultilayerLSTMCell(embedding_size + hidden_size, 5, 3)
program_model = Tree_to_Sequence_Attention_Model(encoder, decoder, num_decoder_layers, hidden_size, nclass, embedding_size, alignment_size, use_lstm=True,
                                                 use_cuda=False)

In [374]:
def tree_to_list(self, ast):
    pass

def list_to_tree(self, ls):
    pass
    
def translate_from_for(self, ls):
    if ls[0] == '<SEQ>':
        t1 = self.translate_from_for(ls[1])
        t2 = self.translate_from_for(ls[2])
        if t1[0] == '<LET>' and t1[-1] == '<UNIT>':
            t1[-1] = t2
            return t1
        else:
            return ['<LET>', 'blank', t1, t2]
    elif ls[0] == '<IF>':
        cmp = ls[1]
        t1 = self.translate_from_for(ls[2])
        t2 = self.translate_from_for(ls[3])
        return ['<IF>', cmp, t1, t2]
    elif ls[0] == '<FOR>':
        var = ls[1]
        init = self.translate_from_for(ls[2])
        cmp = self.translate_from_for(ls[3])
        inc = self.translate_from_for(ls[4])
        body = self.translate_from_for(ls[5])
        tb = ['<LET>', 'blank', body, ['<APP>', 'func', inc]]
        funcbody = ['<IF>', cmp, tb, '<UNIT>']
        translate = ['<LETREC>', 'func', var, funcbody, ['<APP>', 'func', init]]
        return translate
    elif ls[0] == '<ASSIGN>':
        return ['<LET>', ls[1], ls[2], '<UNIT>']
    else:
        return ls

In [375]:
"""
Programs have more variables than are allowed for in the code, need to check how many 
"""
for_prog_0 = makeTree(for_list_json[0])
prediction = program_model.beam_search_prediction(for_prog_0)
print(prediction)
#for_trees = [makeTree(for_prog_json) for for_prog_json in for_list_json]
#lambdaProgList = [translate_from_for(for_prog) for for_prog in for_list_json]
#loss = Variable(torch.FloatTensor([0]))
#for in_tree, out_seq, i in zip(input_trees, expected_output_seqs, range(input_trees)):
#    loss += program_model.forward_train(in_tree, out_seq)
#    loss.backwards()


index i Variable containing:
 62
[torch.LongTensor of size 1x1]

wordinputs beamsize*1 -1 [Variable containing:
 0
[torch.LongTensor of size 1x1]
]
combined [Variable containing:
 0
[torch.LongTensor of size 1x1]
, Variable containing:
 62
[torch.LongTensor of size 1x1]
]
index i Variable containing:
 44
[torch.LongTensor of size 1x1]

wordinputs beamsize*1 -1 [Variable containing:
 0
[torch.LongTensor of size 1x1]
]
combined [Variable containing:
 0
[torch.LongTensor of size 1x1]
, Variable containing:
 62
[torch.LongTensor of size 1x1]
]
index i Variable containing:
 44
[torch.LongTensor of size 1x1]

wordinputs beamsize*1 -1 [Variable containing:
 0
[torch.LongTensor of size 1x1]
]
combined [Variable containing:
 0
[torch.LongTensor of size 1x1]
, Variable containing:
 62
[torch.LongTensor of size 1x1]
]
index i Variable containing:
 44
[torch.LongTensor of size 1x1]

wordinputs beamsize*1 -1 [Variable containing:
 0
[torch.LongTensor of size 1x1]
]
combined [Variable containing:
 0