In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import torch.utils.data as data
import numpy as np
import random
import matplotlib.pyplot as plt
import json

In [2]:
class TreeCell(nn.Module):
    """
    Insert Useful Comments
    """
    def __init__(self, input_size, hidden_size, num_children):
        super(TreeCell, self).__init__()
        
        self.input_size = input_size
        self.hidden_size = hidden_size
        
        # Gates = input, output, memory + one forget gate per child
        numGates = 3 + num_children
        
        self.gates = []
        for _ in range(numGates):
            # One linear layer to handle the value of the node
            valueLinear = nn.Linear(input_size, hidden_size, bias = True)
            childrenLinear = []
            # One per child of the node
            for _ in range(num_children):
                childrenLinear.append(nn.Linear(hidden_size, hidden_size, bias = False))
            self.gates.append((valueLinear, childrenLinear))
            
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU()
        
    def forward(self, input, hidden_states, cell_states):
        """
        Calculate a new hidden state and a new cell state from the LSTM gates
        
        :param hidden_states: A list of 2 hidden states.
        :param cell_states: A list of 2 cell states.
        :return A tuple containing (new hidden state, new cell state)
        """
        
        data_sums = []
        for gate in self.gates:
            data_sum = gate[0](input)
            for i in range(len(hidden_states)):
                data_sum += gate[1][i](hidden_states[i])
            data_sums.append(data_sum)
            
        # First gate is the input gate
        i = self.sigmoid(data_sums[0])
        # Next output gate
        o = self.sigmoid(data_sums[1])
        # Next memory gate
        m = self.tanh(data_sums[2])
        # All the rest are forget gates
        forget_data = 0
        for i in range(len(cell_states)):
            forget_data += data_sums[3 + i] * cell_states[i]
        
        # Put it all together!
        new_state = i * m + forget_data
        new_hidden = o * self.tanh(new_state)
                
        return new_hidden, new_state

In [None]:
'''
Encoder

Takes in a Tree where each node has a value (vector?) and a list of children
Produces a vector of desired size with an encoding of the tree

Recursively: For each node go left *then go middle* then go right, pass the necessary values
into an lstm cell along with values from each child (0 if at leaf) Output the result of the lstm cell
at the root.

'''
class Encoder(nn.Module):
    """
    Takes in a tree where each node has a value vector and a list of children
    Produces a sequence encoding of the tree
    """

    def __init__(self, input_size, hidden_size, valid_num_children):
        """
        Initialize variables we'll need later.
        """
        super(Encoder, self).__init__()
        
        self.lstm_dict = {}
        # We'll always need 0 for leaf nodes
        self.lstm_dict[0] = TreeCell(input_size, hidden_size, 0)
        
        for size in valid_num_children:
            self.lstm_dict[size] = TreeCell(input_size, hidden_size, size)

        self.encoding = Variable(torch.FloatTensor(1, hidden_size))
        
    def forward(self, tree):
        """
        Starts off the entire encoding process
        
        :param tree: a tree where each node has a value vector and a list of children
        :return self.encoding, a matrix where each row represents the encoded output of a single node
        """
        embeddings = [p[0] for p in self.encode(tree)]
        return torch.cat(embeddings, 0)
        
    def encode(self, node):
        """
        Recursively a node and all its children as sequence vectors
        
        :param node: The root of the tree (or subtree)
        :return A tuple (new hidden vector, new cell state).  The new hidden vector is an endoding of node
        """
        
        # List of tuples: (h, c), each of which are size hidden_size
        descendents = []
        children = []
        
        for child in node.children:
            current_descendents = self.encode(child)
            descendents += current_descendents
            children.append(current_descendents[-1])
        
        if len(children) == 0:
            children = [(Variable(torch.zeros(hidden_size)), 
                         Variable(torch.zeros(hidden_size))),
                        (Variable(torch.zeros(hidden_size)), 
                         Variable(torch.zeros(hidden_size)))]

        # Vector of size input_size x len(children)
        inputH = [vec[0] for vec in children]
        inputC = [vec[1] for vec in children]
        value = Variable(node.value.unsqueeze(0))
        
        if len(children) in self.lstm_dict:
            newH, newC = self.lstm_dict[len(children)](value, inputH, inputH)
        else:
            print("WHAAAAAT?")
            raise ValueError("Beware.  Something has gone horribly wrong.  You may not have long to live.")
            
        # Add the new encoding to the end of our list
        descendents.append((newH, newC))
        return descendents
    

        

In [7]:
input_size = 4
hidden_size = 5

test_vec = torch.FloatTensor(input_size)


class Node:
    """
    Node class just made for testing
    """
    def __init__(self, value):
        self. value = value
        self.children = []
    

child_len = [2, 3, 2, 3]   
def makeNodes(children):
    """
    Loop through the passes-in array and build a tree where each node in the i^th layer has children[i] nodes

    """
    if len(children) == 0:
        return Node(test_vec) # Make them all the same vec
    else: 
        newNode = Node(test_vec)
        for i in range(children[0]):
            newNode.children.append(makeNodes(children[1:]))
        return newNode 

# kangaroo

In [8]:
jsonString = "{\"tag\":\"If\",\"contents\":[{\"tag\":\"GeFor\",\"contents\":[{\"tag\":\"Const\",\"contents\":5},{\"tag\":\"Const\",\"contents\":3}]},{\"tag\":\"Assign\",\"contents\":[\"X\",{\"tag\":\"Const\",\"contents\":1}]},{\"tag\":\"Assign\",\"contents\":[\"Y\",{\"tag\":\"Const\",\"contents\":2}]}]}"
jsonObj = json.loads(jsonString)

num_vars = 5
num_ints = 7
for_ops = {
    "Var": 0,
    "Const": 1,
    "Plus": 2,
    "Minus": 3,
    "EqualFor": 4,
    "LeFor": 5,
    "GeFor": 6,
    "Assign": 7,
    "If": 8,
    "Seq": 9,
    "For": 10
}
var_dict = {}

def vectorize(val):
    vector = torch.zeros(num_vars + num_ints + len(for_ops.keys()))
    if type(val) is int:
        vector[val] = 1
    elif type(val) is str:
        index = len(var_dict.keys())
        if val in var_dict:
            index = var_dict[val]
        else:
            var_dict[val] = index
        vector[index + num_ints] = 1
    else:
        index = for_ops[val]
        vector[num_ints + num_vars + index] = 1
    return vector
            
        

def makeTree(json):
    if type(json) is str:
        parentNode = Node(vectorize("Var"))
        childNode = Node(vectorize(json))
        parentNode.children.append(childNode)
        return parentNode 
    
    if type(json) is int:
        return Node(vectorize(json))

    tag = json["tag"]
    children = json["contents"]
    parentNode = Node(vectorize(tag))
    
    currNode = parentNode
    
    if type(children) is list:
        for child in children:
            newChild = makeTree(child)
            currNode.children.append(newChild)
            currNode = newChild
    else:
        parentNode.children.append(makeTree(children))
        
    return parentNode



def printTree(tree):
    print(tree.value)
    for child in tree.children:
        printTree(child)
    
    
tree = makeTree(jsonObj)

encoder = Encoder(num_vars + num_ints + len(for_ops.keys()), hidden_size, [1,2,3])
encoded_vec = encoder(tree)
print("ENCODEDVEC", encoded_vec)

ENCODEDVEC Variable containing:
-0.1848 -0.1053 -0.0355  0.1507 -0.1044
-0.1122  0.0519  0.0023  0.0359 -0.0757
-0.0019 -0.0054 -0.0002  0.0056  0.0031
-0.1563 -0.0816  0.1100  0.0366 -0.1411
 0.0171 -0.0734 -0.0908  0.1049 -0.0942
-0.1143 -0.0007 -0.0205  0.1564 -0.0322
-0.0015  0.0000  0.0013  0.0238  0.0020
 0.0143 -0.0174  0.0287 -0.0242 -0.1128
-0.0374  0.0431  0.0504  0.1268 -0.0165
 0.0061  0.0125 -0.0981  0.0004 -0.0423
 0.0003 -0.0011  0.0067  0.0001  0.0032
 0.0378 -0.0081  0.0136 -0.0122 -0.0980
 0.0046  0.0006 -0.0003  0.0001 -0.0142
-0.1192  0.0708  0.0547  0.0690 -0.0552
-0.0811 -0.0181  0.0213  0.1985  0.0102
-0.0010  0.0004  0.0006  0.0300  0.0009
[torch.FloatTensor of size 16x5]



In [9]:
'''
Decoder
'''
class Tree_to_Sequence_Model(nn.Module):
    """
      For the decoder this expects something like an lstm cell or a gru cell and not an lstm/gru.
      If you don't use teacher forcing, batches are forbidden/lead to bad behavior
    """
    def __init__(self, encoder, decoder, hidden_size, nclass, embedding_size,
                 use_lstm=False, use_cuda=True):
        super(Tree_to_Sequence_Model, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        
        # nclass + 2 to include end of sequence and trash
        self.output_log_probs = nn.Linear(hidden_size, nclass+2)
        self.softmax = nn.Softmax(dim=0)

        self.SOS_token = Variable(torch.LongTensor([[0]]))
        self.EOS_value = 1

        if use_cuda:
            self.SOS_token = self.SOS_token.cuda()

        self.embedding = nn.Embedding(nclass, embedding_size)

        self.use_lstm = use_lstm
        # nclass + 1 is the trash category to avoid penalties after target's EOS token
        self.loss_func = nn.CrossEntropyLoss(ignore_index=nclass+1)

        if use_lstm:
            self.decoder_initial_cell_state = torch.zeros(1, hidden_size)

    def forward_train(self, input, target, use_teacher_forcing=False):
        # encoded features
        encoded_features = self.encoder(input) # [w, c]
        decoder_hidden = encoded_features[-1, :]

        target_length = target.size()[0]
        decoder_input = self.embedding(self.SOS_token).transpose(0,1).repeat(1, batch_size, 1)
        loss = 0

        if self.use_lstm:
            decoder_cell_state = self.decoder_initial_cell_state

        for i in range(target_length):
            if self.use_lstm:
                decoder_output, (decoder_hidden, decoder_cell_state) = self.decoder(decoder_input, (decoder_hidden, decoder_cell_state))
            else:
                decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)

            log_probs = self.output_log_probs(decoder_output)
            loss += self.loss_func(log_probs, target[i, :])

            if use_teacher_forcing:
                decoder_input = self.embedding(target[i, :].unsqueeze(1)).squeeze(1)
            else:
                _, topi = log_probs.data.topk(1)
                ni = topi[0, 0]

                if ni == self.EOS_value:
                    break

                decoder_input = self.embedding(Variable([ni]).unsqueeze(1)).squeeze(1)

        return loss

#     """
#       Inputs must be of batch size 1
#     """
#     def point_wise_prediction(self, input, maximum_length=20):
#       # encoded features
#       encoded_features = self.encoder(input).squeeze(1) # [w, c]
#       decoder_hidden = encoded_features[-1, :]
#       decoder_input = self.embedding(self.SOS_token).squeeze(0)
#       output_so_far = []

#       if self.use_lstm:
#           decoder_cell_state = self.decoder_initial_cell_state

#       for i in range(maximum_length):
#           if self.use_lstm:
#               decoder_output, (decoder_hidden, decoder_cell_state) = self.decoder(decoder_input, (decoder_hidden, decoder_cell_state))
#           else:
#               decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)

#           log_probs = self.output_log_probs(decoder_output)

#           _, topi = log_probs.data.topk(1)
#           ni = topi[0, 0]

#           if ni == self.EOS_value:
#               break

#           output_so_far.append(ni)
#           decoder_input = self.embedding(Variable([ni]).unsqueeze(1)).squeeze(1)

#           if self.use_cuda:
#               decoder_input = decoder_input.cuda()

#       return output_so_far


#     def beam_search_prediction(self, input, maximum_length=20):
#       pass



In [10]:
class Tree_to_Sequence_Attention_Model(Tree_to_Sequence_Model):
    def __init__(self, encoder, decoder, hidden_size, nclass, embedding_size,
                 alignment_size, decoder_cell_state_shape=None, use_lstm=False, use_cuda=True):
        super(Tree_to_Sequence_Attention_Model, self).__init__(encoder, decoder, hidden_size, nclass, embedding_size,
                                                               use_lstm=use_lstm, use_cuda=use_cuda)
        self.attention_hidden = nn.Linear(hidden_size, alignment_size)
        self.attention_context = nn.Linear(hidden_size, alignment_size, bias=False)
        self.tanh = nn.Tanh()
        self.attention_alignment_vector = nn.Linear(alignment_size, 1)

    """
        input: The output of the encoder for the input should have dimensions, (seq_len x batch_size x input_size)
        target: The target should have dimensions, seq_len, and should be a LongTensor.
    """
    def forward_train(self, input, target, use_teacher_forcing=False):
        encoded_features = self.encoder(input) # sequence x encoder hidden
        attention_hidden_values = self.attention_hidden(encoded_features)

        decoder_hidden = encoded_features[-1, :].unsqueeze(0)
        target_length = target.size()[0] # sequence x decoder hidden
        word_input = self.embedding(self.SOS_token).squeeze(0) # batch x decoder hidden

        loss = 0
        
        if self.use_lstm:
            decoder_cell_state = self.decoder_initial_cell_state

        for i in range(target_length):
            self.attention_context(decoder_hidden)
            self.attention_context(decoder_hidden) + attention_hidden_values
            self.attention_alignment_vector(self.attention_context(decoder_hidden) + attention_hidden_values)
            
            attention_logits = self.attention_alignment_vector(self.attention_context(decoder_hidden) + attention_hidden_values).squeeze(1)
            attention_probs = self.softmax(attention_logits) # w
            context_vec = (attention_probs.unsqueeze(1) * encoded_features).sum(0).unsqueeze(0) # decoder hidden 
            decoder_input = torch.cat((word_input, context_vec), dim=1)

            if self.use_lstm:
                print('i')
                print(decoder_input)
                
                print('h')
                print(decoder_hidden)
                
                print('c')
                print(decoder_cell_state)
                
                decoder_output, (decoder_hidden, decoder_cell_state) = self.decoder(decoder_input, (decoder_hidden, decoder_cell_state))
            else:
                decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)

            log_probs = self.output_log_probs(decoder_output)
            loss += self.loss_func(log_probs, target[i])

            if use_teacher_forcing:
                next_input = target[i].unsqueeze(0)
            else:
                _, next_input = log_probs.topk(1)

                if next_input[0, 0] == self.EOS_value:
                    break

            word_input = self.embedding(next_input).squeeze(0) # batch x decoder hidden

        return loss


    """
      Inputs must be of batch size 1
    """
    def point_wise_prediction(self, input, maximum_length=20):
        # encoded features
        encoded_features = self.encoder(input) # sequence x hidden
        print(encoded_features)
        attention_hidden_values = self.attention_hidden(encoded_features)

        decoder_hidden = encoded_features[-1, :].unsqueeze(0)
        word_input = self.embedding(self.SOS_token).squeeze(0) # batch x hidden
        output_so_far = []

        if self.use_lstm:
            decoder_cell_state = self.decoder_initial_cell_state

        for i in range(maximum_length):
            attention_logits = self.attention_alignment_vector(self.attention_context(decoder_hidden) + attention_hidden_values).squeeze(1)
            attention_probs = self.softmax(attention_logits) # w
            context_vec = (attention_probs.unsqueeze(1) * encoded_features).sum(0).unsqueeze(0) # hidden 
            decoder_input = torch.cat((word_input, context_vec), dim=1)

            if self.use_lstm:
                decoder_output, (decoder_hidden, decoder_cell_state) = self.decoder(decoder_input, (decoder_hidden, decoder_cell_state))
            else:
                decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)

            log_probs = self.output_log_probs(decoder_output)
            _, next_input = log_probs.topk(1)
            
            output_so_far.append(next_input[0][0])
            
            if next_input[0, 0] == self.EOS_value:
                break

            word_input = self.embedding(next_input).squeeze(0) # batch x hidden

        return output_so_far

    def beam_search_prediction(self, input, maximum_length=20, beam_width=5):
        # encoded features
        encoded_features = self.encoder(input) # sequence x hidden
        attention_hidden_values = self.attention_hidden(encoded_features)

        decoder_hidden = encoded_features[-1, :].unsqueeze(0)
        word_inputs = [(0, [self.SOS_token], True) for _ in range(max_beam_width)]

        if self.use_lstm:
            decoder_cell_state = self.decoder_initial_cell_state

        for i in range(maximum_length):
            attention_logits = self.attention_alignment_vector(self.attention_context(decoder_hidden) + attention_hidden_values).squeeze(1)
            attention_probs = self.softmax(attention_logits) # w
            context_vec = (attention_probs.unsqueeze(1) * encoded_features).sum(0).unsqueeze(0) # hidden 
            
            newWordInputs = []
            
            for i in range(beam_width):                
                if not word_inputs[i][2]:
                    newWordInputs.append(word_inputs[i])
                    continue
                    
                word_input = self.embedding(word_inputs[i][1][-1]).squeeze(0)
                decoder_input = torch.cat((word_input, context_vec), dim=1)

                if self.use_lstm:
                    decoder_output, (decoder_hidden, decoder_cell_state) = self.decoder(decoder_input, (decoder_hidden, decoder_cell_state))
                else:
                    decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)

                log_probs = self.output_log_probs(decoder_output).squeeze(0)
                value, index = self.softmax(log_probs).topk(beam_width)
                log_value = value.log()
                
                newWordInputs.extend((word_inputs[i][0] + log_value[i], word_inputs[i][1].append(index[i]), 
                                      index[i,0] == self.EOS_Value) for i in range(beam_width))
                
            #Get the new words in the beam.
            word_inputs = sorted(newWordInputs, key=lambda word_pair: word_pair[0])[-beam_width:]

        outputs = [word_inputs[i][1][0][0] for i in range(beam_width)]
        return outputs

In [11]:
embedding_size = 50
hidden_size = 5
nclass = 100
alignment_size = 30

decoder = nn.LSTMCell(embedding_size + hidden_size, 5)
program_model = Tree_to_Sequence_Attention_Model(encoder, decoder, hidden_size, nclass, embedding_size, alignment_size, use_lstm=True,
                                                 use_cuda=False)

In [12]:
loss = program_model.forward_train(tree, Variable(torch.LongTensor([1,2,3,4,5,6,7])))

i
Variable containing:

Columns 0 to 9 
 0.2146 -0.4223  0.2160 -0.4057 -0.5137 -1.2001  0.0252  1.5605  0.9259  1.0601

Columns 10 to 19 
 0.9424 -0.0404  0.9510 -0.7113  1.3665 -0.5405  0.7011  1.1401  1.1159  0.3539

Columns 20 to 29 
-0.5552 -1.6007  0.2998 -0.2021  2.2639  0.8223  1.2482 -1.0407 -0.1205 -0.7811

Columns 30 to 39 
 0.8677  0.6044 -1.1734  0.6653  0.8041 -0.0141 -0.1219  0.5598  0.7022  1.0724

Columns 40 to 49 
-1.3424  0.0706  1.6218  0.7883 -0.1021  1.6576 -0.3728  0.6833  1.9102  1.0823

Columns 50 to 54 
-0.0448 -0.0084  0.0022  0.0558 -0.0479
[torch.FloatTensor of size 1x55]

h
Variable containing:
1.00000e-02 *
 -0.0959  0.0376  0.0615  2.9981  0.0860
[torch.FloatTensor of size 1x5]

c

 0  0  0  0  0
[torch.FloatTensor of size 1x5]



RuntimeError: mul() received an invalid combination of arguments - got (torch.FloatTensor), but expected one of:
 * (float other)
      didn't match because some of the arguments have invalid types: ([31;1mtorch.FloatTensor[0m)
 * (Variable other)
      didn't match because some of the arguments have invalid types: ([31;1mtorch.FloatTensor[0m)
