In [None]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import torch.utils.data as data
import numpy as np
import random
import matplotlib.pyplot as plt
import json

In [None]:
class TrinaryCell(nn.Module):
    """
    LSTM-style cell for a tree node with three children.

    It combines one input vector with three child hidden states and three
    child cell states into a single new (hidden state, cell state) pair.
    """

    # Attribute-name prefixes of the six gates, in the order they are built.
    _GATES = ("inputGate", "leftForgetGate", "middleForgetGate",
              "rightForgetGate", "outputGate", "memoryGate")

    def __init__(self, input_size, hidden_size):
        """
        Build the linear layers of every gate.

        :param input_size: The length of the input vector.
        :param hidden_size: The length of the hidden state/output vector
        """
        super(TrinaryCell, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size

        # Each gate owns four linear maps, named <gate> + suffix:
        #   I = Input, L = Left, M = Middle, R = Right
        # Only the R layer carries a bias, so every gate has exactly one bias.
        sources = (
            ("I", input_size, False),
            ("L", hidden_size, False),
            ("M", hidden_size, False),
            ("R", hidden_size, True),
        )
        for gate in self._GATES:
            for suffix, in_features, has_bias in sources:
                layer = nn.Linear(in_features, hidden_size, bias=has_bias)
                setattr(self, gate + suffix, layer)

        # Functions we'll use later
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU()

    def forward(self, input, hidden_states, cell_states):
        """
        Calculate a new hidden state and a new cell state from the LSTM gates

        :param input: The input vector; its last dimension must be input_size.
        :param hidden_states: A sequence whose first 3 items are the left,
                              middle and right hidden states.
        :param cell_states: A sequence whose first 3 items are the left,
                            middle and right cell states.
        :return A tuple containing (new hidden state, new cell state)
        """
        left_h, middle_h, right_h = (hidden_states[0],
                                     hidden_states[1],
                                     hidden_states[2])
        left_c, middle_c, right_c = (cell_states[0],
                                     cell_states[1],
                                     cell_states[2])

        # Pre-activation of one gate: the sum of its four linear maps.
        preactivation = lambda gate: (
            getattr(self, gate + "I")(input)
            + getattr(self, gate + "L")(left_h)
            + getattr(self, gate + "M")(middle_h)
            + getattr(self, gate + "R")(right_h)
        )

        in_gate = self.sigmoid(preactivation("inputGate"))
        forget_left = self.sigmoid(preactivation("leftForgetGate"))
        forget_middle = self.sigmoid(preactivation("middleForgetGate"))
        forget_right = self.sigmoid(preactivation("rightForgetGate"))
        out_gate = self.sigmoid(preactivation("outputGate"))
        candidate = self.tanh(preactivation("memoryGate"))

        # Each child cell state is scaled by its own forget gate.
        new_state = (in_gate * candidate
                     + forget_left * left_c
                     + forget_middle * middle_c
                     + forget_right * right_c)
        new_hidden = out_gate * self.tanh(new_state)

        return new_hidden, new_state

In [None]:
class BinaryCell(nn.Module):
    """
    LSTM-style cell for a tree node with two children.

    Same construction as TrinaryCell, minus the middle child: one input
    vector, two child hidden states and two child cell states go in, one
    (hidden state, cell state) pair comes out.
    """

    # Attribute-name prefixes of the five gates, in the order they are built.
    _GATES = ("inputGate", "leftForgetGate", "rightForgetGate",
              "outputGate", "memoryGate")

    def __init__(self, input_size, hidden_size):
        """
        Build the linear layers of every gate.

        :param input_size: The length of the input vector.
        :param hidden_size: The length of the hidden state/output vector
        """
        super(BinaryCell, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size

        # Each gate owns three linear maps, named <gate> + suffix:
        #   I = Input, L = Left, R = Right
        # Only the R layer carries a bias, so every gate has exactly one bias.
        sources = (
            ("I", input_size, False),
            ("L", hidden_size, False),
            ("R", hidden_size, True),
        )
        for gate in self._GATES:
            for suffix, in_features, has_bias in sources:
                layer = nn.Linear(in_features, hidden_size, bias=has_bias)
                setattr(self, gate + suffix, layer)

        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU()

    def forward(self, input, hidden_states, cell_states):
        """
        Calculate a new hidden state and a new cell state from the LSTM gates

        :param input: The input vector; its last dimension must be input_size.
        :param hidden_states: A sequence whose first 2 items are the left and
                              right hidden states.
        :param cell_states: A sequence whose first 2 items are the left and
                            right cell states.
        :return A tuple containing (new hidden state, new cell state)
        """
        left_h, right_h = hidden_states[0], hidden_states[1]
        left_c, right_c = cell_states[0], cell_states[1]

        # Pre-activation of one gate: the sum of its three linear maps.
        preactivation = lambda gate: (
            getattr(self, gate + "I")(input)
            + getattr(self, gate + "L")(left_h)
            + getattr(self, gate + "R")(right_h)
        )

        in_gate = self.sigmoid(preactivation("inputGate"))
        forget_left = self.sigmoid(preactivation("leftForgetGate"))
        forget_right = self.sigmoid(preactivation("rightForgetGate"))
        out_gate = self.sigmoid(preactivation("outputGate"))
        candidate = self.tanh(preactivation("memoryGate"))

        # Each child cell state is scaled by its own forget gate.
        new_state = (in_gate * candidate
                     + forget_left * left_c
                     + forget_right * right_c)
        new_hidden = out_gate * self.tanh(new_state)

        return new_hidden, new_state

In [None]:
class UnaryCell(nn.Module):
    """
    LSTM-style cell for a tree node with a single child.

    Same construction as BinaryCell, but with one child hidden state and one
    child cell state.
    """

    # Attribute-name prefixes of the four gates, in the order they are built.
    _GATES = ("inputGate", "forgetGate", "outputGate", "memoryGate")

    def __init__(self, input_size, hidden_size):
        """
        Build the linear layers of every gate.

        :param input_size: The length of the input vector.
        :param hidden_size: The length of the hidden state/output vector
        """
        super(UnaryCell, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size

        for gate in self._GATES:
            # "<gate>I" maps the input and has no bias;
            # "<gate>" maps the child's hidden state and carries the bias.
            setattr(self, gate + "I",
                    nn.Linear(input_size, hidden_size, bias=False))
            setattr(self, gate,
                    nn.Linear(hidden_size, hidden_size, bias=True))

        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU()

    def forward(self, input, hidden_states, cell_states):
        """
        Calculate a new hidden state and a new cell state from the LSTM gates

        :param input: The input vector; its last dimension must be input_size.
        :param hidden_states: A sequence whose first item is the child's
                              hidden state (only index 0 is read).
        :param cell_states: A sequence whose first item is the child's cell
                            state (only index 0 is read).
        :return A tuple containing (new hidden state, new cell state)
        """
        child_h = hidden_states[0]
        child_c = cell_states[0]

        # Pre-activation of every gate: input map plus hidden-state map.
        pre = {}
        for gate in self._GATES:
            pre[gate] = (getattr(self, gate + "I")(input)
                         + getattr(self, gate)(child_h))

        in_gate = self.sigmoid(pre["inputGate"])
        forget_gate = self.sigmoid(pre["forgetGate"])
        out_gate = self.sigmoid(pre["outputGate"])
        candidate = self.tanh(pre["memoryGate"])

        new_state = in_gate * candidate + forget_gate * child_c
        new_hidden = out_gate * self.tanh(new_state)

        return new_hidden, new_state

In [None]:
'''
Encoder

Takes in a tree where each node has a value (vector?) and a list of children.
Produces a vector of desired size with an encoding of the tree.

Recursively: for each node, go left, *then go middle*, then go right. Pass the necessary values
into an LSTM cell along with the values from each child (0 if at a leaf). Output the result of
the LSTM cell at the root.

'''
class Encoder(nn.Module):
    """
    Takes in a tree where each node has a value vector and a list of children
    Produces a sequence encoding of the tree: one encoded row per node.
    """

    def __init__(self, input_size, hidden_size):
        """
        Initialize variables we'll need later.

        :param input_size: The length of every node's value vector.
        :param hidden_size: The length of each node's encoding (hidden state).
        """
        super(Encoder, self).__init__()

        # Kept on the instance so encode() does not rely on a module-level
        # variable that happens to be called hidden_size.
        self.hidden_size = hidden_size

        self.unaryLstm = UnaryCell(input_size, hidden_size)
        self.binaryLstm = BinaryCell(input_size, hidden_size)
        self.trinaryLstm = TrinaryCell(input_size, hidden_size)

        # Rows collected by encode(); None until the first node is encoded.
        # (This used to start as an uninitialised 1 x hidden_size tensor,
        # which leaked a row of arbitrary memory into the result.)
        self.encoding = None

    def forward(self, tree):
        """
        Starts off the entire encoding process

        :param tree: a tree where each node has a value vector and a list of children
        :return self.encoding, a matrix where each row represents the encoded output
                of a single node. Rows are prepended as nodes finish, so row 0 is
                the root of the tree.
        """
        # Start from scratch so repeated calls do not accumulate old rows.
        self.encoding = None
        self.encode(tree)
        return self.encoding

    def _zero_state(self):
        """Return an all-zero (hidden state, cell state) pair standing in for a missing child."""
        return (Variable(torch.zeros(1, self.hidden_size)),
                Variable(torch.zeros(1, self.hidden_size)))

    def encode(self, node):
        """
        Recursively encode a node and all of its children as sequence vectors

        :param node: The root of the tree (or subtree)
        :return A tuple (new hidden vector, new cell state).  The new hidden vector is an encoding of node
        :raises NotImplementedError: if a node has more than 3 children
        """

        # List of tuples: (h, c), each of which are size hidden_size
        children = [self.encode(child) for child in node.children]

        if len(children) == 0:
            # A leaf is run through the binary cell with two zero children.
            children = [self._zero_state(), self._zero_state()]

        inputH = [state[0] for state in children]
        inputC = [state[1] for state in children]

        # Add a leading dimension of size 1 so the value is a single-row batch.
        value = Variable(node.value.unsqueeze(0))

        if len(children) == 1:
            cell = self.unaryLstm
        elif len(children) == 2:
            cell = self.binaryLstm
        elif len(children) == 3:
            cell = self.trinaryLstm
        else:
            raise NotImplementedError(
                "Encoder only supports nodes with at most 3 children, got {}"
                .format(len(children)))

        newH, newC = cell(value, inputH, inputC)

        # Put the new encoding at the FRONT of the matrix: children are
        # encoded before their parent, so the root ends up in row 0.
        if self.encoding is None:
            self.encoding = newH
        else:
            self.encoding = torch.cat((newH, self.encoding), 0)
        return (newH, newC)
    

        

In [None]:
# Sizes used by the hand-built test tree.
input_size = 4
hidden_size = 5

# Value shared by every node of the hand-built test tree.
# torch.FloatTensor(n) only allocates memory without initialising it, so the
# contents were arbitrary (possibly NaN/inf) and changed between runs; start
# from zeros to keep the test input well-defined and reproducible.
test_vec = torch.zeros(input_size)


class Node:
    """
    Bare-bones tree node, made just for testing.

    Holds a value and an initially empty list of child nodes.
    """

    def __init__(self, value):
        # Every node gets its own fresh child list.
        self.children = list()
        self.value = value
    

child_len = [2, 3, 2, 3]


def makeNodes(children):
    """
    Build a tree from the passed-in array, where each node in the i^th layer
    has children[i] child nodes.

    Every node uses the same test_vec object as its value. An empty array
    produces a single leaf node.

    :param children: Sequence of child counts, one entry per layer.
    :return The root Node of the generated tree.
    """
    root = Node(test_vec)  # Make them all the same vec
    if len(children) > 0:
        fan_out, deeper_layers = children[0], children[1:]
        for _ in range(fan_out):
            root.children.append(makeNodes(deeper_layers))
    return root
    

        

# TODO:
#     - Comments
#     - Train func (after decoder exists)

# kangaroo

In [None]:
# Example program, serialized as nested {"tag": ..., "contents": ...} objects.
jsonString = (
    '{"tag":"If","contents":['
    '{"tag":"GeFor","contents":[{"tag":"Const","contents":5},{"tag":"Const","contents":3}]},'
    '{"tag":"Assign","contents":["X",{"tag":"Const","contents":1}]},'
    '{"tag":"Assign","contents":["Y",{"tag":"Const","contents":2}]}'
    ']}'
)
jsonObj = json.loads(jsonString)

# Sizes of the variable and integer sections of the one-hot vocabulary.
num_vars = 5
num_ints = 7

# Operator tag -> position inside the operator section of the vocabulary.
for_ops = {
    tag: position
    for position, tag in enumerate([
        "Var",
        "Const",
        "Plus",
        "Minus",
        "EqualFor",
        "LeFor",
        "GeFor",
        "Assign",
        "If",
        "Seq",
        "For",
    ])
}

# Variable name -> index, filled in lazily by vectorize().
var_dict = dict()

def vectorize(val):
    """
    One-hot encode a single token of the program tree.

    The vector has num_ints + num_vars + len(for_ops) entries, laid out as
      [0, num_ints)                    integer constants
      [num_ints, num_ints + num_vars)  variable names
      [num_ints + num_vars, end)       operator tags from for_ops

    A string that is a key of for_ops is treated as an operator tag; any
    other string is a variable name and is given the next free index in
    var_dict the first time it is seen. (Previously every string, operator
    tags included, was registered as a variable, so the operator section was
    never used and variable indices ran into it.)

    :param val: An int constant, a variable name, or an operator tag.
    :return A 1-D tensor that is all zeros except for a single 1.
    :raises IndexError: if an int is outside [0, num_ints), or a new variable
                        would not fit in the num_vars available slots.
    :raises KeyError: if val is neither an int nor a str and is not a key of
                      for_ops.
    """
    vector = torch.zeros(num_vars + num_ints + len(for_ops))
    if type(val) is int:
        # Refuse values that would land in (or wrap around into) another section.
        if not 0 <= val < num_ints:
            raise IndexError(
                "integer {} is outside the supported range [0, {})"
                .format(val, num_ints))
        vector[val] = 1
    elif type(val) is str and val not in for_ops:
        index = var_dict.get(val)
        if index is None:
            index = len(var_dict)
            # Check before registering so a failed call leaves var_dict unchanged.
            if index >= num_vars:
                raise IndexError(
                    "cannot encode variable {!r}: all {} variable slots are in use"
                    .format(val, num_vars))
            var_dict[val] = index
        vector[index + num_ints] = 1
    else:
        index = for_ops[val]
        vector[num_ints + num_vars + index] = 1
    return vector
            
        

def makeTree(json):
    """
    Convert a parsed JSON program into a tree of Node objects.

    - An int becomes a single leaf node.
    - A str becomes a "Var" node whose only child is the node for the string.
    - Anything else is indexed with "tag" and "contents": the tag becomes a
      node, and the contents are converted recursively.

    When the contents are a list, the elements are CHAINED rather than all
    attached to the tag node: the first element becomes a child of the tag
    node, and each later element becomes a child of the node built for the
    element before it.
    NOTE(review): confirm this chaining is intended. With it, no node built
    here ends up with more than two children.

    :param json: The parsed JSON value (int, str, or tag/contents mapping).
    :return The root Node of the converted (sub)tree.
    """
    kind = type(json)

    if kind is int:
        return Node(vectorize(json))

    if kind is str:
        var_node = Node(vectorize("Var"))
        var_node.children.append(Node(vectorize(json)))
        return var_node

    tag, contents = json["tag"], json["contents"]
    root = Node(vectorize(tag))

    if type(contents) is not list:
        # A single, non-list payload hangs directly off the tag node.
        root.children.append(makeTree(contents))
        return root

    # Each converted element is attached to the previously attached one.
    attach_to = root
    for element in contents:
        subtree = makeTree(element)
        attach_to.children.append(subtree)
        attach_to = subtree

    return root



def printTree(tree):
    """
    Print the value of every node in the tree, parents before children and
    children in list order.

    :param tree: The root node; each node needs .value and .children
    """
    pending = [tree]
    while pending:
        node = pending.pop()
        print(node.value)
        # Push children reversed so the first child is popped (printed) first.
        pending.extend(reversed(list(node.children)))
    
    
# Demo: turn the example program into a tree, encode it and show the result.
tree = makeTree(jsonObj)
# printTree(tree)

# The encoder's input size is the length of the vectors vectorize() produces.
encoder = Encoder(
    input_size=num_vars + num_ints + len(for_ops),
    hidden_size=hidden_size,
)
encoded_vec = encoder(tree)
print("ENCODEDVEC", encoded_vec)



In [None]:
'''
Decoder



'''