In [None]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import torch.utils.data as data
import numpy as np
import random
import matplotlib.pyplot as plt
import json

In [None]:
class TreeCell(nn.Module):
    """
    Tree-LSTM cell for a node with up to ``num_children`` children.

    The cell owns ``3 + num_children`` gates: an input gate, an output gate,
    a memory (candidate) gate and one forget gate per child.  Every gate is
    the sum of a biased linear map of the node value and one bias-free linear
    map per child hidden state.
    """
    def __init__(self, input_size, hidden_size, num_children):
        """
        Create the linear layers for every gate.

        :param input_size: The length of the node value vector.
        :param hidden_size: The length of the hidden/cell state vectors.
        :param num_children: The maximum number of children this cell handles.
        """
        super(TreeCell, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_children = num_children

        # Gates = input, output, memory + one forget gate per child
        numGates = 3 + num_children

        # The layers must live in ModuleLists: layers kept only in a plain
        # Python list are invisible to .parameters(), .cuda() and state_dict().
        self.value_layers = nn.ModuleList()
        self.child_layers = nn.ModuleList()
        for _ in range(numGates):
            # One linear layer to handle the value of the node
            self.value_layers.append(nn.Linear(input_size, hidden_size, bias=True))
            # One per child of the node
            self.child_layers.append(nn.ModuleList(
                [nn.Linear(hidden_size, hidden_size, bias=False)
                 for _ in range(num_children)]))

        # Kept for backward compatibility: gates[k] == (value layer, child layers)
        self.gates = list(zip(self.value_layers, self.child_layers))

        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU()

    def forward(self, input, hidden_states, cell_states):
        """
        Calculate a new hidden state and a new cell state from the LSTM gates

        :param input: The value of the node.
        :param hidden_states: A list with one hidden state per child.
        :param cell_states: A list with one cell state per child.
        :return A tuple containing (new hidden state, new cell state)
        :raises ValueError: If the two lists differ in length or hold more
            entries than the cell has children.
        """
        if len(hidden_states) != len(cell_states):
            raise ValueError("Got {} hidden states but {} cell states".format(
                len(hidden_states), len(cell_states)))
        if len(hidden_states) > self.num_children:
            raise ValueError("Got {} children but the cell supports at most {}".format(
                len(hidden_states), self.num_children))

        data_sums = []
        for value_layer, child_layers in self.gates:
            data_sum = value_layer(input)
            for child_layer, child_hidden in zip(child_layers, hidden_states):
                # Out-of-place add: never mutate a layer output in place
                data_sum = data_sum + child_layer(child_hidden)
            data_sums.append(data_sum)

        # First gate is the input gate
        input_gate = self.sigmoid(data_sums[0])
        # Next output gate
        output_gate = self.sigmoid(data_sums[1])
        # Next memory gate
        memory = self.tanh(data_sums[2])
        # All the rest are forget gates, one per child cell state.  Each one
        # is squashed to (0, 1) like the other gates before it scales a state.
        forget_data = 0
        for forget_sum, cell_state in zip(data_sums[3:], cell_states):
            forget_data = forget_data + self.sigmoid(forget_sum) * cell_state

        # Put it all together!
        new_state = input_gate * memory + forget_data
        new_hidden = output_gate * self.tanh(new_state)

        return new_hidden, new_state

In [None]:
class TrinaryCell(nn.Module):
    """
    Tree-LSTM cell for a node with three children (left, middle, right).

    Layer naming: ``<gate>I`` acts on the node input, ``<gate>L`` / ``<gate>M``
    / ``<gate>R`` act on the left / middle / right child hidden state.  Only
    the ``R`` layer of each gate carries the bias.
    """
    def __init__(self, input_size, hidden_size):
        """
        Build the four linear layers of each of the six gates.

        :param input_size: The length of the input vector.
        :param hidden_size: The length of the hidden state/output vector
        """
        super(TrinaryCell, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size

        # Layers are created gate by gate, in I, L, M, R order.
        gate_names = ("inputGate", "leftForgetGate", "middleForgetGate",
                      "rightForgetGate", "outputGate", "memoryGate")
        for gate_name in gate_names:
            setattr(self, gate_name + "I",
                    nn.Linear(input_size, hidden_size, bias=False))
            for suffix, with_bias in (("L", False), ("M", False), ("R", True)):
                setattr(self, gate_name + suffix,
                        nn.Linear(hidden_size, hidden_size, bias=with_bias))

        # Activations used by forward (relu is created but not used there)
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU()

    def forward(self, input, hidden_states, cell_states):
        """
        Compute the node's hidden and cell state from its three children.

        :param input: The value of the node.
        :param hidden_states: A list of 3 hidden states (left, middle, right).
        :param cell_states: A list of 3 cell states (left, middle, right).
        :return A tuple containing (new hidden state, new cell state)
        """
        children_h = (hidden_states[0], hidden_states[1], hidden_states[2])
        children_c = (cell_states[0], cell_states[1], cell_states[2])

        def gate(gate_name, activation):
            # input term first, then left, middle, right
            total = getattr(self, gate_name + "I")(input)
            for suffix, child_h in zip("LMR", children_h):
                total = total + getattr(self, gate_name + suffix)(child_h)
            return activation(total)

        in_gate = gate("inputGate", self.sigmoid)
        forget_gates = [gate(name, self.sigmoid)
                        for name in ("leftForgetGate", "middleForgetGate", "rightForgetGate")]
        out_gate = gate("outputGate", self.sigmoid)
        candidate = gate("memoryGate", self.tanh)

        # New memory plus each child's cell state scaled by its forget gate
        new_state = in_gate * candidate
        for forget, child_c in zip(forget_gates, children_c):
            new_state = new_state + forget * child_c
        new_hidden = out_gate * self.tanh(new_state)

        return new_hidden, new_state

In [None]:
class BinaryCell(nn.Module):
    """
    Tree-LSTM cell for a node with two children (left, right).

    Layer naming: ``<gate>I`` acts on the node input, ``<gate>L`` / ``<gate>R``
    act on the left / right child hidden state.  Only the ``R`` layer of each
    gate carries the bias.
    """
    def __init__(self, input_size, hidden_size):
        """
        Build the three linear layers of each of the five gates.

        :param input_size: The length of the input vector.
        :param hidden_size: The length of the hidden state/output vector
        """
        super(BinaryCell, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size

        # Layers are created gate by gate, in I, L, R order.
        gate_names = ("inputGate", "leftForgetGate", "rightForgetGate",
                      "outputGate", "memoryGate")
        for gate_name in gate_names:
            setattr(self, gate_name + "I",
                    nn.Linear(input_size, hidden_size, bias=False))
            setattr(self, gate_name + "L",
                    nn.Linear(hidden_size, hidden_size, bias=False))
            setattr(self, gate_name + "R",
                    nn.Linear(hidden_size, hidden_size, bias=True))

        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU()

    def forward(self, input, hidden_states, cell_states):
        """
        Compute the node's hidden and cell state from its two children.

        :param input: The value of the node.
        :param hidden_states: A list of 2 hidden states (left, right).
        :param cell_states: A list of 2 cell states (left, right).
        :return A tuple containing (new hidden state, new cell state)
        """
        left_h, right_h = hidden_states[0], hidden_states[1]
        left_c, right_c = cell_states[0], cell_states[1]

        def pre_activation(gate_name):
            # input term first, then left, then right
            total = getattr(self, gate_name + "I")(input)
            total = total + getattr(self, gate_name + "L")(left_h)
            total = total + getattr(self, gate_name + "R")(right_h)
            return total

        in_gate = self.sigmoid(pre_activation("inputGate"))
        forget_left = self.sigmoid(pre_activation("leftForgetGate"))
        forget_right = self.sigmoid(pre_activation("rightForgetGate"))
        out_gate = self.sigmoid(pre_activation("outputGate"))
        candidate = self.tanh(pre_activation("memoryGate"))

        new_state = in_gate * candidate + forget_left * left_c + forget_right * right_c
        new_hidden = out_gate * self.tanh(new_state)

        return new_hidden, new_state

In [None]:
class UnaryCell(nn.Module):
    """
    Tree-LSTM cell for a node with a single child.

    Layer naming: ``<gate>I`` acts on the node input and ``<gate>`` (with
    bias) acts on the child's hidden state.
    """
    def __init__(self, input_size, hidden_size):
        """
        Build the two linear layers of each of the four gates.

        :param input_size: The length of the input vector.
        :param hidden_size: The length of the hidden state/output vector
        """
        super(UnaryCell, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size

        # Layers are created gate by gate: input layer, then hidden layer.
        for gate_name in ("inputGate", "forgetGate", "outputGate", "memoryGate"):
            setattr(self, gate_name + "I",
                    nn.Linear(input_size, hidden_size, bias=False))
            setattr(self, gate_name,
                    nn.Linear(hidden_size, hidden_size, bias=True))

        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU()

    def forward(self, input, hidden_states, cell_states):
        """
        Compute the node's hidden and cell state from its only child.

        Only the first entry of each list is read.

        :param input: The value of the node.
        :param hidden_states: A list whose first entry is the child's hidden state.
        :param cell_states: A list whose first entry is the child's cell state.
        :return A tuple containing (new hidden state, new cell state)
        """
        child_h = hidden_states[0]
        child_c = cell_states[0]

        def pre_activation(gate_name):
            from_input = getattr(self, gate_name + "I")(input)
            from_child = getattr(self, gate_name)(child_h)
            return from_input + from_child

        in_gate = self.sigmoid(pre_activation("inputGate"))
        forget = self.sigmoid(pre_activation("forgetGate"))
        out_gate = self.sigmoid(pre_activation("outputGate"))
        candidate = self.tanh(pre_activation("memoryGate"))

        new_state = in_gate * candidate + forget * child_c
        new_hidden = out_gate * self.tanh(new_state)

        return new_hidden, new_state

In [None]:
'''
Encoder

Takes in a Tree where each node has a value (vector?) and a list of children
Produces a vector of desired size with an encoding of the tree

Recursively: for each node, encode its children in order (left, *then middle*, then right),
then pass the node's value into an LSTM cell together with the states of each child
(0 if at a leaf). The output of the LSTM cell at the root is the encoding of the whole tree.

'''
class Encoder(nn.Module):
    """
    Takes in a tree where each node has a value vector and a list of children
    Produces a sequence encoding of the tree

    One TreeCell is kept per supported child count; a node is encoded by the
    cell whose child count matches its own, so leaves use the 0-child cell.
    """

    def __init__(self, input_size, hidden_size, valid_num_children):
        """
        Build one TreeCell per supported number of children.

        :param input_size: The length of each node's value vector.
        :param hidden_size: The length of the hidden/cell state vectors.
        :param valid_num_children: An iterable of the child counts that may
            occur in a tree (0 is always supported, for leaves).
        """
        super(Encoder, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size

        # lstm_dict maps child count -> cell.  Each cell is also registered
        # with add_module, because modules stored only in a plain dict are
        # invisible to .parameters(), .cuda() and state_dict().
        self.lstm_dict = {}
        # We'll always need 0 for leaf nodes
        for size in [0] + list(valid_num_children):
            if size in self.lstm_dict:
                continue
            cell = TreeCell(input_size, hidden_size, size)
            self.add_module("cell_{}".format(size), cell)
            self.lstm_dict[size] = cell

        # Not read by forward/encode; zero-filled rather than uninitialised memory
        self.encoding = Variable(torch.zeros(1, hidden_size))

    def forward(self, tree):
        """
        Starts off the entire encoding process

        :param tree: a tree where each node has a value vector and a list of children
        :return a matrix where each row is the encoded output (hidden state)
            of a single node; children come before their parent, so the root
            is the last row
        """
        embeddings = [hidden for hidden, _ in self.encode(tree)]
        return torch.cat(embeddings, 0)

    def encode(self, node):
        """
        Recursively encode a node and all of its descendents

        :param node: The root of the tree (or subtree)
        :return A list of (hidden vector, cell state) tuples, one per node of
            the subtree.  The last tuple belongs to ``node`` itself.
        :raises ValueError: If ``node`` has a number of children for which no
            cell was created.
        """
        # List of tuples: (h, c), each of which are size hidden_size
        descendents = []
        children = []

        for child in node.children:
            current_descendents = self.encode(child)
            descendents += current_descendents
            # The last entry of a subtree's list is the subtree root
            children.append(current_descendents[-1])

        num_children = len(children)
        if num_children not in self.lstm_dict:
            raise ValueError(
                "No cell for a node with {} children (supported: {})".format(
                    num_children, sorted(self.lstm_dict)))

        inputH = [hidden for hidden, _ in children]
        inputC = [cell_state for _, cell_state in children]
        # Add a leading batch dimension of 1 to the node value
        value = Variable(node.value.unsqueeze(0))

        # Hidden states AND cell states of the children go into the cell
        newH, newC = self.lstm_dict[num_children](value, inputH, inputC)

        # Add the new encoding to the end of our list
        descendents.append((newH, newC))
        return descendents
    

        

In [None]:
# Sizes used by the small test tree below
input_size = 4
hidden_size = 5

# Dummy node value shared by every test node.  torch.FloatTensor(n) only
# allocates memory without initialising it, so the values could be arbitrary
# (even NaN/inf); use well-defined ones instead.
test_vec = torch.ones(input_size)


class Node:
    """
    Minimal tree node used for testing: a value plus a list of child nodes.
    """
    def __init__(self, value):
        # Every node starts out as a leaf; callers append to ``children``.
        self.children = []
        self.value = value
    

child_len = [2, 3, 2, 3]
def makeNodes(children):
    """
    Build a uniform test tree from a list of per-layer child counts.

    Every node in the i-th layer gets children[i] children, and every node
    holds the same value (test_vec).  An empty list produces a single leaf.
    """
    root = Node(test_vec)  # Make them all the same vec
    if len(children) > 0:
        fanout, deeper_layers = children[0], children[1:]
        for _ in range(fanout):
            root.children.append(makeNodes(deeper_layers))
    return root

# kangaroo

In [None]:
# Example program as JSON: If (5 >= 3) then X = 1 else Y = 2.
# The literal is split per top-level entry of the "If" node's contents.
jsonString = (
    '{"tag":"If","contents":['
    '{"tag":"GeFor","contents":[{"tag":"Const","contents":5},{"tag":"Const","contents":3}]},'
    '{"tag":"Assign","contents":["X",{"tag":"Const","contents":1}]},'
    '{"tag":"Assign","contents":["Y",{"tag":"Const","contents":2}]}'
    ']}'
)
jsonObj = json.loads(jsonString)

# Layout of a one-hot node vector:
#   [0, num_ints)                      integer constants
#   [num_ints, num_ints + num_vars)    variable names, in order of first use
#   [num_ints + num_vars, end)         operator tags, indexed through for_ops
num_vars = 5
num_ints = 7
for_ops = {
    "Var": 0,
    "Const": 1,
    "Plus": 2,
    "Minus": 3,
    "EqualFor": 4,
    "LeFor": 5,
    "GeFor": 6,
    "Assign": 7,
    "If": 8,
    "Seq": 9,
    "For": 10
}
# Variable name -> slot inside the variable region; filled lazily by vectorize
var_dict = {}

def vectorize(val):
    """
    Turn an integer constant, an operator tag or a variable name into a
    one-hot vector of length num_ints + num_vars + len(for_ops).

    A string that is a key of for_ops is treated as an operator tag; any
    other string is a variable name, registered in var_dict on first use.

    :param val: An int in [0, num_ints), an operator tag or a variable name.
    :return A 1-D float tensor with exactly one entry set to 1.
    :raises TypeError: If val is neither an int nor a str.
    :raises ValueError: If an int lies outside [0, num_ints) or a new
        variable would not fit into the num_vars variable slots.
    """
    if isinstance(val, bool) or not isinstance(val, (int, str)):
        raise TypeError("Cannot vectorize {!r}: expected int or str".format(val))

    vector = torch.zeros(num_vars + num_ints + len(for_ops))

    if isinstance(val, int):
        # Out-of-range ints would silently land in the variable/operator region
        if not 0 <= val < num_ints:
            raise ValueError("Integer {} is outside [0, {})".format(val, num_ints))
        vector[val] = 1
    elif val in for_ops:
        # Operator tags are strings too, so they must be recognised before
        # the value is treated as a variable name.
        vector[num_ints + num_vars + for_ops[val]] = 1
    else:
        index = var_dict.get(val)
        if index is None:
            index = len(var_dict)
            if index >= num_vars:
                raise ValueError(
                    "Cannot register variable {!r}: all {} variable slots are in use".format(
                        val, num_vars))
            var_dict[val] = index
        vector[index + num_ints] = 1
    return vector
            
        

def makeTree(json):
    """
    Convert a parsed JSON program into a tree of Node objects.

    A string becomes a "Var" node with one child holding the variable name,
    an int becomes a single node, and a {"tag", "contents"} dict becomes a
    node for the tag whose children are built from the contents.

    NOTE: the parameter name shadows the ``json`` module inside this
    function; it is kept so existing callers are unaffected.

    :param json: A str, an int, or a dict with "tag" and "contents" keys.
    :return The root Node of the converted (sub)tree.
    """
    if isinstance(json, str):
        parentNode = Node(vectorize("Var"))
        parentNode.children.append(Node(vectorize(json)))
        return parentNode

    if isinstance(json, int):
        return Node(vectorize(json))

    parentNode = Node(vectorize(json["tag"]))

    contents = json["contents"]
    if not isinstance(contents, list):
        # A single non-list content is treated as the only child
        contents = [contents]

    # Every entry is a direct child of this node.  Attaching each entry to
    # the previously created sibling would turn siblings into a chain.
    for child in contents:
        parentNode.children.append(makeTree(child))

    return parentNode



def printTree(tree):
    """
    Print the value of every node of the tree, one per call to print, in
    pre-order (a node first, then its children from first to last).
    """
    pending = [tree]
    while pending:
        node = pending.pop()
        print(node.value)
        # Push children reversed so the first child is popped (printed) next
        pending.extend(reversed(node.children))
    
    
# Build the example tree and run it through an encoder whose input size is
# the length of a vectorized node value.
tree = makeTree(jsonObj)

encoder = Encoder(
    input_size=num_vars + num_ints + len(for_ops),
    hidden_size=hidden_size,
    valid_num_children=[1, 2, 3],
)
encoded_vec = encoder(tree)
print("ENCODEDVEC", encoded_vec)

In [None]:
'''
Decoder



'''
class Tree_to_Sequence_Model(nn.Module):
    """
      Encodes a tree with ``encoder`` and decodes a token sequence with ``decoder``.

      For the decoder this expects something like an lstm cell or a gru cell and not an lstm/gru.
      It is called once per step as ``decoder(input, hidden)`` and must return
      ``(output, hidden)``; with ``use_lstm`` it is called as
      ``decoder(input, (hidden, cell))`` and must return ``(output, (hidden, cell))``.
      If you don't use teacher forcing, batches are forbidden/lead to bad behavior

      Token ids: 0 is start-of-sequence, 1 is end-of-sequence and nclass + 1
      is the trash category that the loss ignores.
    """
    def __init__(self, encoder, decoder, hidden_size, nclass, embedding_size,
                 decoder_cell_state_shape=None, use_lstm=False, use_cuda=True):
        """
        :param encoder: Module mapping an input tree to a [w, c] feature matrix.
        :param decoder: Step-wise decoder cell (see the class docstring).
        :param hidden_size: Size of the decoder output fed to the output layer.
        :param nclass: Number of regular output classes.
        :param embedding_size: Size of the token embeddings.
        :param decoder_cell_state_shape: Shape of the initial (zero) cell
            state; required when use_lstm is True.
        :param use_lstm: Whether the decoder also carries a cell state.
        :param use_cuda: Whether tensors created by this model are moved to the GPU.
        :raises ValueError: If use_lstm is True without a cell state shape.
        """
        super(Tree_to_Sequence_Model, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        #nclass + 2 to include end of sequence and trash
        self.output_log_probs = nn.Linear(hidden_size, nclass+2)
        self.softmax = nn.Softmax()

        self.SOS_token = Variable(torch.LongTensor([[0]]))
        self.EOS_value = 1

        self.use_cuda = use_cuda

        if self.use_cuda:
            self.SOS_token = self.SOS_token.cuda()

        # The output layer can emit nclass + 2 different ids and those ids are
        # fed back through the embedding, so the embedding must cover them all.
        self.embedding = nn.Embedding(nclass + 2, embedding_size)

        self.use_lstm = use_lstm
        #nclass + 1 is the trash category to avoid penalties after target's EOS token
        self.loss_func = nn.CrossEntropyLoss(ignore_index=nclass+1)

        if use_lstm:
            if decoder_cell_state_shape is None:
                raise ValueError("decoder_cell_state_shape is required when use_lstm is True")
            self.decoder_initial_cell_state = Variable(torch.zeros(decoder_cell_state_shape))
            if self.use_cuda:
                self.decoder_initial_cell_state = self.decoder_initial_cell_state.cuda()

    def _embed_token(self, token):
        """Return the [1, embedding_size] embedding of a single token id."""
        index = Variable(torch.LongTensor([[token]]))
        if self.use_cuda:
            # Move the index, not the result: the lookup needs it on the GPU
            index = index.cuda()
        return self.embedding(index).squeeze(0)

    def _step_decoder(self, decoder_input, decoder_hidden, decoder_cell_state):
        """Run one decoder step; returns (output, hidden, cell state)."""
        if self.use_lstm:
            decoder_output, (decoder_hidden, decoder_cell_state) = self.decoder(
                decoder_input, (decoder_hidden, decoder_cell_state))
        else:
            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
        return decoder_output, decoder_hidden, decoder_cell_state

    def forward_train(self, input, target, use_teacher_forcing=False):
        """
        Compute the summed cross-entropy loss of decoding ``target``.

        :param input: The input handed to the encoder.
        :param target: LongTensor of token ids with dimensions (seq_len x batch_size).
        :param use_teacher_forcing: Feed the target token (True) or the
            model's own greedy prediction (False, batch size 1 only) into the
            next step.  Without teacher forcing decoding stops early once
            end-of-sequence is predicted.
        :return The loss summed over the decoded steps.
        """
        # encoded features
        encoded_features = self.encoder(input) # [w, c]
        target_length, batch_size = target.size()

        # NOTE(review): assumes the encoder output is 2-D with the root in the
        # last row (as the tree Encoder in this file produces) - confirm for
        # other encoders.  The root row seeds the decoder for every batch entry.
        decoder_hidden = encoded_features[-1:, :].repeat(batch_size, 1)
        decoder_input = self.embedding(self.SOS_token).squeeze(0).repeat(batch_size, 1)
        decoder_cell_state = self.decoder_initial_cell_state if self.use_lstm else None
        loss = 0

        for i in range(target_length):
            decoder_output, decoder_hidden, decoder_cell_state = self._step_decoder(
                decoder_input, decoder_hidden, decoder_cell_state)

            # Despite the name these are unnormalised scores; the loss normalises them
            log_probs = self.output_log_probs(decoder_output)
            loss += self.loss_func(log_probs, target[i, :])

            if use_teacher_forcing:
                decoder_input = self.embedding(target[i, :].unsqueeze(1)).squeeze(1)
            else:
                _, topi = log_probs.data.topk(1)
                ni = int(topi[0, 0])

                if ni == self.EOS_value:
                    break

                decoder_input = self._embed_token(ni)

        return loss

    def point_wise_prediction(self, input, maximum_length=20):
        """
        Greedily decode a token sequence.  Inputs must be of batch size 1

        :param input: The input handed to the encoder.
        :param maximum_length: Upper bound on the number of decoded tokens.
        :return The list of predicted token ids (ints), without the
            end-of-sequence token.
        """
        # encoded features
        encoded_features = self.encoder(input).squeeze(1) # [w, c]
        # NOTE(review): root assumed to be the last row - confirm for other encoders
        decoder_hidden = encoded_features[-1:, :]
        decoder_input = self.embedding(self.SOS_token).squeeze(0)
        decoder_cell_state = self.decoder_initial_cell_state if self.use_lstm else None
        output_so_far = []

        for _ in range(maximum_length):
            decoder_output, decoder_hidden, decoder_cell_state = self._step_decoder(
                decoder_input, decoder_hidden, decoder_cell_state)

            log_probs = self.output_log_probs(decoder_output)

            _, topi = log_probs.data.topk(1)
            ni = int(topi[0, 0])

            if ni == self.EOS_value:
                break

            output_so_far.append(ni)
            decoder_input = self._embed_token(ni)

        return output_so_far

    def beam_search_prediction(self, input, maximum_length=20):
        """Not implemented for the plain model; does nothing and returns None."""
        pass



In [None]:
class Tree_to_Sequence_Attention_Model(Tree_to_Sequence_Model):
    """
    Tree-to-sequence model whose decoder attends over all encoded tree nodes.

    At every step an additive attention (v . tanh(W_c * hidden + W_h * node))
    scores each encoded node against the current decoder hidden state.  The
    attention-weighted sum of the node encodings is concatenated to the word
    embedding, so the decoder input has size embedding_size + hidden_size.

    NOTE(review): the encoder output is assumed to be 2-D, (nodes x
    hidden_size), with the root in the last row - this is what the tree
    Encoder in this file produces; confirm for other encoders.
    """
    def __init__(self, encoder, decoder, hidden_size, nclass, embedding_size,
                 alignment_size, decoder_cell_state_shape=None, use_lstm=False,
                 use_cuda=True):
        """
        :param encoder: Module mapping an input tree to the node encodings.
        :param decoder: Step-wise decoder cell, see Tree_to_Sequence_Model.
        :param hidden_size: Size of the node encodings and decoder hidden state.
        :param nclass: Number of regular output classes.
        :param embedding_size: Size of the token embeddings.
        :param alignment_size: Size of the attention alignment space.
        :param decoder_cell_state_shape: Shape of the initial cell state (LSTM only).
        :param use_lstm: Whether the decoder also carries a cell state.
        :param use_cuda: Whether tensors created by this model are moved to the GPU.
        """
        super(Tree_to_Sequence_Attention_Model, self).__init__(
            encoder, decoder, hidden_size, nclass, embedding_size,
            decoder_cell_state_shape=decoder_cell_state_shape,
            use_lstm=use_lstm, use_cuda=use_cuda)

        self.attention_hidden = nn.Linear(hidden_size, alignment_size)
        self.attention_context = nn.Linear(hidden_size, alignment_size, bias=False)
        self.tanh = nn.Tanh()
        # Scores a vector of the alignment space, so its input is alignment_size
        self.attention_alignment_vector = nn.Linear(alignment_size, 1)
        self.hidden_size = hidden_size

    def _embed_word(self, token):
        """Return the [1, embedding_size] embedding of a single token id."""
        index = Variable(torch.LongTensor([[token]]))
        if self.use_cuda:
            # Move the index, not the result: the lookup needs it on the GPU
            index = index.cuda()
        return self.embedding(index).squeeze(0)

    def _advance_decoder(self, decoder_input, decoder_hidden, decoder_cell_state):
        """Run one decoder step; returns (output, hidden, cell state)."""
        if self.use_lstm:
            decoder_output, (decoder_hidden, decoder_cell_state) = self.decoder(
                decoder_input, (decoder_hidden, decoder_cell_state))
        else:
            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
        return decoder_output, decoder_hidden, decoder_cell_state

    def _attention_context_vector(self, decoder_hidden, attention_hidden_values,
                                  encoded_features):
        """
        Return the attention-weighted sum of the node encodings, B x C.

        decoder_hidden is B x C, attention_hidden_values W x A and
        encoded_features W x C (B batch entries, W nodes).
        """
        # The tanh is essential: without it the decoder term is the same
        # constant for every node and cancels out inside the softmax.
        alignment = self.tanh(
            self.attention_context(decoder_hidden).unsqueeze(1) + attention_hidden_values)
        attention_logits = self.attention_alignment_vector(alignment).squeeze(2)  # B x W
        attention_probs = nn.functional.softmax(attention_logits, dim=1)  # B x W
        return (attention_probs.unsqueeze(2) * encoded_features).sum(1)  # B x C

    def forward_train(self, input, target, use_teacher_forcing=False):
        """
        Compute the summed cross-entropy loss of decoding ``target``.

        :param input: The input handed to the encoder.
        :param target: The target should have dimensions, (seq_len x batch_size), and should be a LongTensor.
        :param use_teacher_forcing: Feed the target token (True) or the
            model's own greedy prediction (False, batch size 1 only) into the
            next step.  Without teacher forcing decoding stops early once
            end-of-sequence is predicted.
        :return The loss summed over the decoded steps.
        """
        encoded_features = self.encoder(input) # [w, c]
        attention_hidden_values = self.attention_hidden(encoded_features)

        target_length, batch_size = target.size()
        # The root (last row) seeds the decoder for every batch entry
        decoder_hidden = encoded_features[-1:, :].repeat(batch_size, 1)
        word_input = self.embedding(self.SOS_token).squeeze(0).repeat(batch_size, 1)
        decoder_cell_state = self.decoder_initial_cell_state if self.use_lstm else None

        loss = 0

        for i in range(target_length):
            context_vec = self._attention_context_vector(
                decoder_hidden, attention_hidden_values, encoded_features)
            decoder_input = torch.cat((word_input, context_vec), dim=1)

            decoder_output, decoder_hidden, decoder_cell_state = self._advance_decoder(
                decoder_input, decoder_hidden, decoder_cell_state)

            # Despite the name these are unnormalised scores; the loss normalises them
            log_probs = self.output_log_probs(decoder_output)
            loss += self.loss_func(log_probs, target[i, :])

            if use_teacher_forcing:
                word_input = self.embedding(target[i, :].unsqueeze(1)).squeeze(1)
            else:
                _, topi = log_probs.data.topk(1)
                ni = int(topi[0, 0])

                if ni == self.EOS_value:
                    break

                word_input = self._embed_word(ni)

        return loss

    def point_wise_prediction(self, input, maximum_length=20):
        """
        Greedily decode a token sequence.  Inputs must be of batch size 1

        :param input: The input handed to the encoder.
        :param maximum_length: Upper bound on the number of decoded tokens.
        :return The list of predicted token ids (ints), without the
            end-of-sequence token.
        """
        # encoded features
        encoded_features = self.encoder(input).squeeze(1) # [w, c]
        attention_hidden_values = self.attention_hidden(encoded_features)

        decoder_hidden = encoded_features[-1:, :]  # root row, 1 x C
        word_input = self.embedding(self.SOS_token).squeeze(0)
        decoder_cell_state = self.decoder_initial_cell_state if self.use_lstm else None
        output_so_far = []

        for _ in range(maximum_length):
            context_vec = self._attention_context_vector(
                decoder_hidden, attention_hidden_values, encoded_features)
            decoder_input = torch.cat((word_input, context_vec), dim=1)

            decoder_output, decoder_hidden, decoder_cell_state = self._advance_decoder(
                decoder_input, decoder_hidden, decoder_cell_state)

            log_probs = self.output_log_probs(decoder_output)

            _, topi = log_probs.data.topk(1)
            ni = int(topi[0, 0])

            if ni == self.EOS_value:
                break

            output_so_far.append(ni)
            word_input = self._embed_word(ni)

        return output_so_far

    def beam_search_prediction(self, input, maximum_length=20, beam_width=5):
        """
        Decode a token sequence with beam search.  Inputs must be of batch size 1

        Every hypothesis keeps its own decoder state.  At each step every
        unfinished hypothesis is extended by its beam_width most probable
        tokens and the beam_width best hypotheses (by summed log-probability)
        survive.  A hypothesis is finished once it emits end-of-sequence.

        :param input: The input handed to the encoder.
        :param maximum_length: Upper bound on the number of decoded tokens.
        :param beam_width: Number of hypotheses kept after every step.
        :return The token ids (ints) of the best hypothesis, without the
            end-of-sequence token.
        :raises ValueError: If beam_width is smaller than 1.
        """
        if beam_width < 1:
            raise ValueError("beam_width must be at least 1")

        encoded_features = self.encoder(input).squeeze(1)  # [w, c]
        attention_hidden_values = self.attention_hidden(encoded_features)

        root_hidden = encoded_features[-1:, :]
        initial_cell_state = self.decoder_initial_cell_state if self.use_lstm else None

        # hypothesis = (score, tokens, decoder hidden, decoder cell state, finished)
        beams = [(0.0, [], root_hidden, initial_cell_state, False)]

        for _ in range(maximum_length):
            candidates = []
            for hypothesis in beams:
                score, tokens, hidden, cell_state, finished = hypothesis
                if finished:
                    # Finished hypotheses keep competing with their final score
                    candidates.append(hypothesis)
                    continue

                if tokens:
                    word_input = self._embed_word(tokens[-1])
                else:
                    word_input = self.embedding(self.SOS_token).squeeze(0)

                context_vec = self._attention_context_vector(
                    hidden, attention_hidden_values, encoded_features)
                decoder_input = torch.cat((word_input, context_vec), dim=1)
                decoder_output, new_hidden, new_cell_state = self._advance_decoder(
                    decoder_input, hidden, cell_state)

                # The output layer yields unnormalised scores; normalise them
                # so that scores of different hypotheses are comparable.
                log_probs = nn.functional.log_softmax(
                    self.output_log_probs(decoder_output), dim=1)
                top_scores, top_ids = log_probs.data.topk(
                    min(beam_width, log_probs.size(1)))

                for rank in range(top_ids.size(1)):
                    token = int(top_ids[0, rank])
                    new_score = score + float(top_scores[0, rank])
                    if token == self.EOS_value:
                        candidates.append(
                            (new_score, tokens, new_hidden, new_cell_state, True))
                    else:
                        candidates.append(
                            (new_score, tokens + [token], new_hidden, new_cell_state, False))

            beams = sorted(candidates, key=lambda hyp: hyp[0], reverse=True)[:beam_width]

            if all(hyp[4] for hyp in beams):
                break

        return beams[0][1]

        