In [None]:
cd ..

In [None]:
import torch.optim as optim
import torch.nn as nn
import datetime

from tree_to_sequence.program_datasets import *
from tree_to_sequence.translating_trees import map_tree

In [None]:
import collections

import torch
from torch.autograd import Variable


class Fold(object):
    """Dynamic batching of many small computation graphs.

    Calls recorded with ``add`` are not executed immediately; each returns a
    lazy ``Fold.Node``.  ``apply`` later runs them step by step, merging all
    calls that share a step (dependency depth) and an op name into a single
    call on the supplied module, then slices the batched results back apart.
    """

    # Tensor argument types accepted besides ``Fold.Node`` and ``int``.
    # ``torch.Tensor`` is the base class of every tensor; ``Variable`` is kept
    # for code written against the legacy autograd API.
    _TENSOR_TYPES = (torch.Tensor, Variable)

    class Node(object):
        """Placeholder for the (future) result of one recorded call."""

        def __init__(self, op, step, index, *args):
            self.op = op            # name of the module method to call
            self.step = step        # dependency depth; 0 when no arg is a Node
            self.index = index      # position of this call within its (step, op) batch
            self.args = args
            self.split_idx = -1     # which output to take for multi-output ops; -1 = whole result
            self.batch = True       # cleared by nobatch()

        def split(self, num):
            """Split resulting node, if function returns multiple values."""
            nodes = []
            for idx in range(num):
                nodes.append(Fold.Node(
                    self.op, self.step, self.index, *self.args))
                nodes[-1].split_idx = idx
            return tuple(nodes)

        def nobatch(self):
            """Mark this node to be passed through as-is instead of batched."""
            self.batch = False
            return self

        def get(self, values):
            """Return this node's slice of the computed batched result."""
            return values[self.step][self.op].get(self.index, self.split_idx)

        def __repr__(self):
            return "[%d:%d]%s" % (
                self.step, self.index, self.op)

    class ComputedResult(object):
        """Batched output of one (step, op) group, sliced lazily per node."""

        def __init__(self, batch_size, batched_result):
            self.batch_size = batch_size
            self.result = batched_result
            # Multiple outputs are stored as a list so each can be chunked in place.
            if isinstance(self.result, tuple):
                self.result = list(self.result)

        def try_get_batched(self, nodes):
            """Return the un-chunked batched tensor if ``nodes`` are exactly its rows in order.

            :return: the batched tensor, or None when the nodes do not line up
                     with this batch or the result was already chunked.
            """
            all_are_nodes = all(isinstance(n, Fold.Node) for n in nodes)
            num_nodes_is_equal = len(nodes) == self.batch_size
            if not all_are_nodes or not num_nodes_is_equal:
                return None

            valid_node_sequence = all(
                nodes[i].index < nodes[i + 1].index  # Indices are ordered
                and nodes[i].split_idx == nodes[i + 1].split_idx  # Same split index
                and nodes[i].step == nodes[i + 1].step  # Same step
                and nodes[i].op == nodes[i + 1].op  # Same op
                for i in range(len(nodes) - 1))
            if not valid_node_sequence:
                return None

            if nodes[0].split_idx == -1 and not isinstance(self.result, tuple):
                return self.result
            elif nodes[0].split_idx >= 0 and not isinstance(self.result[nodes[0].split_idx], tuple):
                return self.result[nodes[0].split_idx]
            else:
                # This result was already chunked.
                return None

        def get(self, index, split_idx=-1):
            """Return row ``index`` (of output ``split_idx``), chunking the result on first use."""
            if split_idx == -1:
                if not isinstance(self.result, tuple):
                    self.result = torch.chunk(self.result, self.batch_size)
                return self.result[index]
            else:
                if not isinstance(self.result[split_idx], tuple):
                    self.result[split_idx] = torch.chunk(self.result[split_idx], self.batch_size)
                return self.result[split_idx][index]

    def __init__(self, volatile=False, cuda=False):
        # step -> op -> list of argument tuples, one per recorded call
        self.steps = collections.defaultdict(
            lambda: collections.defaultdict(list))
        # op -> {argument tuple: Node}, used to de-duplicate identical calls
        self.cached_nodes = collections.defaultdict(dict)
        self.total_nodes = 0
        self.volatile = volatile
        self._cuda = cuda

    def cuda(self):
        """Create integer argument tensors on the GPU; returns self for chaining."""
        self._cuda = True
        return self

    def add(self, op, *args):
        """Record a call ``op(*args)`` and return a lazy node for its result.

        Identical ``(op, args)`` calls are de-duplicated: the cached node is
        returned instead of recording the call twice.

        :param op: name of the method on the module later passed to ``apply``
        :param args: ``Fold.Node`` results of earlier calls, ints, or tensors
        :return: the ``Fold.Node`` standing in for this call's result
        :raises ValueError: if any argument has an unsupported type
        """
        self.total_nodes += 1
        allowed = (Fold.Node, int) + self._TENSOR_TYPES
        if not all(isinstance(arg, allowed) for arg in args):
            raise ValueError(
                "All args should be Tensor, Variable, int or Node, got: %s" % str(args))
        if args not in self.cached_nodes[op]:
            # A call runs one step after the deepest node it depends on.
            step = max([0] + [arg.step + 1 for arg in args
                              if isinstance(arg, Fold.Node)])
            node = Fold.Node(op, step, len(self.steps[step][op]), *args)
            self.steps[step][op].append(args)
            self.cached_nodes[op][args] = node
        return self.cached_nodes[op][args]

    def _batch_args(self, arg_lists, values):
        """Turn argument columns into batched arguments.

        :param arg_lists: iterable of columns; a column holds the same argument
                          position across every call in a batch
        :param values: ``{step: {op: ComputedResult}}`` for already-run ops
        :return: list with one batched argument per column
        """
        res = []
        for arg in arg_lists:
            if all(isinstance(arg_item, Fold.Node) for arg_item in arg):
                assert all(arg[0].batch == arg_item.batch
                           for arg_item in arg[1:])
                if arg[0].batch:
                    # Reuse the producer's batched tensor when this column is
                    # exactly that batch in order; otherwise gather the rows.
                    batched_arg = values[arg[0].step][arg[0].op].try_get_batched(arg)
                    if batched_arg is not None:
                        res.append(batched_arg)
                    else:
                        res.append(
                            torch.cat([arg_item.get(values)
                                       for arg_item in arg], 0))
                else:
                    # A nobatch column must refer to one shared node, passed once.
                    for arg_item in arg[1:]:
                        if arg_item != arg[0]:
                            raise ValueError("Can not use more then one of nobatch argument, got: %s." % str(arg_item))
                    res.append(arg[0].get(values))
            elif all(isinstance(arg_item, int) for arg_item in arg):
                # Integer columns become one LongTensor for the whole batch.
                if self._cuda:
                    var = Variable(
                        torch.cuda.LongTensor(arg), volatile=self.volatile)
                else:
                    var = Variable(
                        torch.LongTensor(arg), volatile=self.volatile)
                res.append(var)
            else:
                # Tensors, possibly mixed with batched nodes: concatenate on dim 0.
                rows = []
                for arg_item in arg:
                    if isinstance(arg_item, Fold.Node):
                        assert arg_item.batch
                        rows.append(arg_item.get(values))
                    elif isinstance(arg_item, self._TENSOR_TYPES):
                        rows.append(arg_item)
                    else:
                        raise ValueError(
                            'Not allowed to mix Fold.Node/Tensor with int')
                res.append(torch.cat(rows, 0))
        return res

    def apply(self, nn, nodes):
        """Apply current fold to given neural module.

        :param nn: object whose methods are named by the recorded ops
        :param nodes: list of node lists to return batched (one list per output)
        :return: list with one batched tensor per entry of ``nodes``
        """
        values = {}  # step -> {op: ComputedResult}
        for step in sorted(self.steps.keys()):
            values[step] = {}
            for op in self.steps[step]:
                func = getattr(nn, op)
                try:
                    # Batch all calls sharing this step and op, column by column.
                    batched_args = self._batch_args(
                        zip(*self.steps[step][op]), values)
                except Exception:
                    print("Error while executing node %s[%d] with args: %s" % (
                        op, step, self.steps[step][op][0]))
                    raise
                # Batch size is read from the first argument's leading dimension
                # (1 when the op takes no arguments).
                if batched_args:
                    arg_size = batched_args[0].size()[0]
                else:
                    arg_size = 1
                res = func(*batched_args)
                values[step][op] = Fold.ComputedResult(arg_size, res)  # store for later steps
        try:
            return self._batch_args(nodes, values)
        except Exception:
            print("Retrieving %s" % nodes)
            for lst in nodes:
                if isinstance(lst[0], Fold.Node):
                    print(', '.join([str(x.get(values).size()) for x in lst]))
            raise

    def __str__(self):
        """Summarize recorded calls per step: op, call count and first call's args."""
        lines = []
        for step in sorted(self.steps.keys()):
            lines.append('%d step:\n' % step)
            for op in self.steps[step]:
                first_args = []
                for arg in self.steps[step][op][0]:
                    if isinstance(arg, self._TENSOR_TYPES):
                        first_args.append(str(arg.size()))
                    else:
                        first_args.append(str(arg))
                lines.append('\t%s = %d x (%s)\n' % (
                    op, len(self.steps[step][op]), ', '.join(first_args)))
        return ''.join(lines)

    def __repr__(self):
        return str(self)


class Unfold(object):
    """Eager drop-in for ``Fold``, useful when debugging.

    Rather than recording calls for later batched execution, ``add`` invokes
    the named method of the wrapped module immediately and wraps the result.
    """

    class Node(object):
        """Wrapper around the already-computed result of one eager call."""

        def __init__(self, tensor):
            self.tensor = tensor

        def __repr__(self):
            return str(self.tensor)

        def nobatch(self):
            # There is no batching in eager mode, so this is a no-op.
            return self

        def split(self, num):
            # Index the first ``num`` entries of the result, one node each.
            pieces = []
            for position in range(num):
                pieces.append(Unfold.Node(self.tensor[position]))
            return pieces

    def __init__(self, nn, volatile=False, cuda=False):
        self.nn = nn
        self.volatile = volatile
        self._cuda = cuda

    def cuda(self):
        """Create integer argument tensors on the GPU; returns self for chaining."""
        self._cuda = True
        return self

    def _arg(self, arg):
        """Unwrap nodes, turn ints into one-element LongTensors, pass others through."""
        if isinstance(arg, Unfold.Node):
            return arg.tensor
        if not isinstance(arg, int):
            return arg
        long_tensor = torch.cuda.LongTensor if self._cuda else torch.LongTensor
        return Variable(long_tensor([arg]), volatile=self.volatile)

    def add(self, op, *args):
        """Run ``self.nn.<op>`` on the converted args right away."""
        converted = [self._arg(item) for item in args]
        method = getattr(self.nn, op)
        return Unfold.Node(method(*converted))

    def apply(self, nn, nodes):
        """Concatenate each group of nodes; ``nn`` must be the constructor's module."""
        if nn != self.nn:
            raise ValueError("Expected that nn argument passed to constructor and passed to apply would match.")
        return [torch.cat([self._arg(item) for item in group]) for group in nodes]

In [None]:
# class TreeLSTM(nn.Module):
#     def __init__(self, num_units):
#         super(TreeLSTM, self).__init__()
#         self.num_units = num_units
#         self.left = nn.Linear(num_units, 5 * num_units)
#         self.right = nn.Linear(num_units, 5 * num_units)

#     def forward(self, left_in, right_in):
#         lstm_in = self.left(left_in[0])
#         lstm_in += self.right(right_in[0])
#         a, i, f1, f2, o = lstm_in.chunk(5, 1)
#         c = (a.tanh() * i.sigmoid() + f1.sigmoid() * left_in[1] +
#              f2.sigmoid() * right_in[1])
#         h = o.sigmoid() * c.tanh()
#         return h, c

In [None]:
# class Node:
#     """
#     Node class
#     """
#     def __init__(self, value):
#         self.id = value
#         self.left = None
#         self.right = None
#         self.label = 4 
        
#     def is_leaf(self):
#         return self.left is None and self.right is None

In [None]:
# class SPINN(nn.Module):

#     def __init__(self, n_classes, size, n_words):
#         super(SPINN, self).__init__()
#         self.size = size
#         self.tree_lstm = TreeLSTM(size)
#         self.embeddings = nn.Embedding(n_words, size)
#         self.out = nn.Linear(size, n_classes)

#     def leaf(self, word_id):
#         return self.embeddings(word_id), (torch.FloatTensor(word_id.size()[0], self.size))

#     def children(self, left_h, left_c, right_h, right_c):
#         return self.tree_lstm((left_h, left_c), (right_h, right_c))

#     def logits(self, encoding):
#         return self.out(encoding)

In [None]:
# def encode_tree_regular(model, tree):
#     def encode_node(node):
#         if node.is_leaf():
#             return model.leaf(torch.LongTensor([node.id]))
#         else:
#             left_h, left_c = encode_node(node.left)
#             right_h, right_c = encode_node(node.right)
#             return model.children(left_h, left_c, right_h, right_c)
#     encoding, _ = encode_node(tree)
#     return model.logits(encoding)


In [None]:
# def encode_tree_fold(fold, tree):
#     def encode_node(node):
#         if node.is_leaf():
#             return fold.add('leaf', node.id).split(2)
#         else:
#             left_h, left_c = encode_node(node.left)
#             right_h, right_c = encode_node(node.right)
#             return fold.add('children', left_h, left_c, right_h, right_c).split(2)
#     encoding, _ = encode_node(tree)
#     return fold.add('logits', encoding), "whatever"

In [None]:
# n1 = Node(1)
# n2 = Node(2)
# n3 = Node(3)
# n4 = Node(4)
# n5 = Node(5)
# n6 = Node(6)

# n1.left = n3
# n1.right = n5

# n2.left = n6
# n2.right = n4

# #n1 and n2 are separate trees
# n_classes = 4
# size = 10
# n_words = 7

# model = SPINN(n_classes, size, n_words)

In [None]:
# fold = Fold()

# all_logits, all_labels = [], []
# for tree in [n1, n2]:
#     all_logits.append(encode_tree_fold(fold, tree))
#     all_labels.append(tree.label)

# res = fold.apply(model, [all_logits, all_labels])
# loss = criterion(res[0], res[1])
# print(loss)

In [None]:
# def criterion(a, b):
#     return 3

In [None]:
# all_logits, all_labels = [], []
# for tree in [n1, n2]:
#     all_logits.append(encode_tree_regular(model, tree))
#     all_labels.append(tree.label)

# loss = criterion(torch.cat(all_logits, 0), torch.LongTensor(all_labels))

In [None]:
# import torch
# import torch.nn as nn

# from tree_to_sequence.tree_lstm import TreeLSTM
# from tree_to_sequence.translating_trees import map_tree, tree_to_list


In [None]:
class BinaryTreeLSTM(nn.Module):
    '''
    Fold-compatible encoder for binary trees.

    Exposes the per-node operations that ``Fold.apply`` dispatches to by name:
    ``encode_none_node`` for a missing child and ``encode_node_with_children``
    for a node with two (possibly empty) children.  Each returns an
    (annotations, hidden, cell) triple.
    '''

    def __init__(self, input_size, hidden_size):
        """
        Build the two-child tree cell and the zero state used for empty children.

        :param input_size: length of each node's value vector
        :param hidden_size: length of the hidden and cell state vectors
        """
        super(BinaryTreeLSTM, self).__init__()
        self.tree_lstm = TreeCell(input_size, hidden_size, 2)
        # A buffer (not a parameter) so it follows the module across devices.
        self.register_buffer('zero_buffer', torch.zeros(1, hidden_size))

    def encode_none_node(self):
        """
        Encode a missing child as all-zero state.

        :return annotations, hidden, cell
        """
        zeros = self.zero_buffer
        return zeros.unsqueeze(1), zeros, zeros

    # TODO: Later make this stackable
    def encode_node_with_children(self, value, leftA, leftH, leftC, rightA, rightH, rightC):
        """
        Combine a node's value with its two children's states.

        :return annotations, hidden, cell
        """
        # Debug output of the incoming shapes.
        print("INPUTS")
        for label, shapes in (("value", [value.shape]),
                              ("annotations", [leftA.shape, rightA.shape]),
                              ("hiddens", [leftH.shape, rightH.shape]),
                              ("cell states", [leftC.shape, rightC.shape])):
            print(label, *shapes)

        hidden, cell = self.tree_lstm(value, [leftH, rightH], [leftC, rightC])
        # NOTE(review): torch.cat defaults to dim 0, which under Fold batching is
        # the batch axis; Fold later chunks this result into batch_size equal
        # pieces, so confirm each tree's annotations land in its own chunk.
        annotations = torch.cat([hidden.unsqueeze(1), leftA.float(), rightA.float()])
        return annotations, hidden, cell


In [None]:
class TreeCell(nn.Module):
    """
    LSTM cell for a tree node with a fixed number of children.

    Combines the node's input vector with one (hidden, cell) pair per child and
    uses a separate forget gate for every child.
    """
    def __init__(self, input_size, hidden_size, num_children):
        """
        Create the gate layers.

        :param input_size: length of input vector
        :param hidden_size: length of hidden vector (and cell state)
        :param num_children: number of children = number of hidden/cell states passed in
        """
        super(TreeCell, self).__init__()

        # Gate order: input, output, memory, then one forget gate per child.
        gate_count = 3 + num_children

        self.gates_value = nn.ModuleList()
        self.gates_children = nn.ModuleList()
        for _ in range(gate_count):
            # The value layer is created before the child layers of the same gate.
            value_layer = nn.Linear(input_size, hidden_size, bias=True)
            child_layers = nn.ModuleList(
                [nn.Linear(hidden_size, hidden_size, bias=False) for _ in range(num_children)])
            self.gates_value.append(value_layer)
            self.gates_children.append(child_layers)

        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()

    def forward(self, input, hidden_states, cell_states):
        """
        Compute a new hidden state and cell state from the gates.

        :param input: the node's input vector(s)
        :param hidden_states: A list of num_children hidden states.
        :param cell_states: A list of num_children cell states.
        :return A tuple containing (new hidden state, new cell state)
        """
        pre_activations = []
        for value_layer, child_layers in zip(self.gates_value, self.gates_children):
            total = value_layer(input)
            for position, child_hidden in enumerate(hidden_states):
                total = total + child_layers[position](child_hidden)
            pre_activations.append(total)

        input_gate = self.sigmoid(pre_activations[0])
        output_gate = self.sigmoid(pre_activations[1])
        memory = self.tanh(pre_activations[2])

        # Each child's cell state is scaled by its own forget gate.
        remembered = 0
        for position, child_cell in enumerate(cell_states):
            remembered = remembered + self.sigmoid(pre_activations[3 + position]) * child_cell

        cell = input_gate * memory + remembered
        hidden = output_gate * self.tanh(cell)
        return hidden, cell

    def initialize_forget_bias(self, bias_value):
        """Set the bias of every forget gate (gates 3 onward) to ``bias_value``."""
        for forget_layer in self.gates_value[3:]:
            nn.init.constant_(forget_layer.bias, bias_value)


In [None]:
def forward(fold, node, num_children=2):
    """
    Record the encoding of ``node``'s subtree on ``fold``.

    Children are recorded first, then the node itself as one
    ``encode_node_with_children`` call.  Missing children are padded with
    ``encode_none_node`` results so the call always receives
    ``num_children`` (annotations, hidden, cell) triples after the value.

    :param fold: Fold (or Unfold) collecting the operations
    :param node: tree node with ``value`` and ``children``; a node whose value
                 is None is encoded as an empty child
    :param num_children: child slots expected by ``encode_node_with_children``
                         (2 for BinaryTreeLSTM)
    :return: the three results of ``split(3)``: annotations, hidden, cell
    """
    if node.value is None:
        return fold.add('encode_none_node').split(3)

    # Flat list of fold nodes: annotations, hidden, cell for each child in order.
    child_states = []
    for child in node.children:
        child_states += list(forward(fold, child, num_children))

    # Pad with empty-child encodings up to the fixed arity.
    while len(child_states) < 3 * num_children:
        child_states += fold.add('encode_none_node').split(3)

    return fold.add('encode_node_with_children', node.value, *child_states).split(3)

In [None]:
def unsqueeze(node):
    """
    Add a leading axis of size 1 to every non-None value in the tree, in place.

    :param node: root with ``value`` and ``children`` attributes
    :return: the same root object, after mutation
    """
    pending = [node]
    while pending:
        current = pending.pop()
        if current.value is not None:
            current.value = current.value.unsqueeze(0)
        # Push children reversed so they are visited in their original order.
        pending.extend(reversed(list(current.children)))
    return node

In [None]:
for_lambda_dset = ForLambdaDataset("ANC/AdditionalForDatasets/ForWithLevels/Easy-arbitraryForList.json", binarize_input=True,
                                   binarize_output=True, eos_token=True, one_hot=True,
                                   long_base_case=False, input_as_seq=False,
                                   output_as_seq=False)
start = datetime.datetime.now()
# Give every node value a leading axis of size 1: Fold batches tensor arguments
# with torch.cat(..., 0), so each value should contribute exactly one row.
# unsqueeze mutates each tree in place and returns its root, so these are the
# trees to encode.  Re-reading the dataset afterwards would at best be
# redundant and, if the dataset builds fresh trees per access, would discard
# the reshaped values.
trees = [unsqueeze(tree[0]) for tree in for_lambda_dset]
tree_lstm = BinaryTreeLSTM(trees[0].value.size()[-1], 3)
fold = Fold()
# Record every tree's encoding; each result is (annotations, hidden, cell).
result = [forward(fold, tree) for tree in trees]
annotations = [x[0] for x in result]
hiddens = [x[1] for x in result]
print("INTERMEDIATE HIDDENS", hiddens)
cell_states = [x[2] for x in result]
# Execute all recorded ops in batches and gather the three outputs.
x = fold.apply(tree_lstm, [annotations, hiddens, cell_states])
end = datetime.datetime.now()
print("done")
print("DIFF", end - start)

In [None]:
# 0:00:00.004358
        
# NEW
# 0:00:11.194712
#         0:00:10.217078
#                 0:00:13.408681
#                         0:00:15.319642
                                
# OLD
# 0:00:52.612726
#         0:00:50.966683
#                 0:00:52.092692

In [None]:
def decode(fold, decoder_hiddens, decoder_cell_states, targetNode, parent_val, child_index, annotations):
    """
    Record the teacher-forced decoding loss of ``targetNode``'s subtree on ``fold``.

    Assumes teacher forcing: each node's true value, not a prediction, is fed
    to its children.

    :param fold: Fold collecting the operations
    :param decoder_hiddens: decoder hidden state(s) for this node
    :param decoder_cell_states: decoder cell state(s) for this node
    :param targetNode: target tree node with ``value`` and ``children``
    :param parent_val: value of this node's parent
    :param child_index: position of this node among its parent's children
    :param annotations: encoder annotations to attend over
    :return: fold node holding this subtree's summed loss
    """
    attended = fold.add("calc_attention", decoder_hiddens, annotations)
    total_loss = fold.add("calc_loss", parent_val, child_index, attended, targetNode.value)
    # Teacher forcing: the node's true value is what its children see.
    current_value = targetNode.value
    child_input = fold.add("get_next_decoder_input", current_value, attended)

    for position, child in enumerate(targetNode.children):
        hiddens_for_child, cells_for_child = fold.add(
            "get_next_child_states", current_value, position,
            child_input, decoder_hiddens, decoder_cell_states).split(2)
        subtree_loss = decode(fold, hiddens_for_child, cells_for_child, child,
                              current_value, position, annotations)
        total_loss = fold.add("plus", total_loss, subtree_loss)

    return total_loss

In [None]:
class TreeToTreeAttention(nn.Module):
    """
    Attention-based tree decoder whose methods are used as Fold ops.

    ``decode`` records, for one target tree, the ops ``calc_attention``,
    ``calc_loss``, ``get_next_decoder_input``, ``get_next_child_states`` and
    ``plus``; ``Fold.apply`` later calls them on this module by name.
    """
    def __init__(self, decoder, hidden_size, embedding_size, nclass,
                 root_value=-1, alignment_size=50, align_type=1, max_size=50):  # TODO: Add encoder back!
        """
        Translates an encoded representation of one tree into another.

        :param decoder: object providing ``calculate_loss`` and ``get_next_child_states``
        :param hidden_size: size of the annotations and decoder hidden states
        :param embedding_size: size of target-token embeddings
        :param nclass: number of target tokens, not counting EOS
        :param root_value: stored as ``self.root_value`` (not used in this class)
        :param alignment_size: attention width when ``align_type == 0``
        :param align_type: 0 = additive attention, 1 = dot product with projected
                           annotations, anything larger = raw dot product
        :param max_size: stored as ``self.max_size`` (not used in this class)
        """
        super(TreeToTreeAttention, self).__init__()

        # Save useful values
        self.nclass = nclass
        self.decoder = decoder
        self.align_type = align_type
        self.max_size = max_size
        self.root_value = root_value

        # EOS is always the last token
        self.EOS_value = nclass

        # Useful functions
        self.softmax = nn.Softmax(0)
        self.tanh = nn.Tanh()

        # Set up attention
        if align_type == 0:
            self.attention_hidden = nn.Linear(hidden_size, alignment_size)
            self.attention_context = nn.Linear(hidden_size, alignment_size, bias=False)
            self.attention_alignment_vector = nn.Linear(alignment_size, 1)
        elif align_type == 1:
            self.attention_hidden = nn.Linear(hidden_size, hidden_size)

        self.attention_presoftmax = nn.Linear(2 * hidden_size, hidden_size)
        # One extra embedding row for the EOS token.
        self.embedding = nn.Embedding(nclass + 1, embedding_size)

    def calc_attention(self, decoder_hiddens, annotations):
        """
        Attend over ``annotations`` using the last decoder hidden state as the query.

        :param decoder_hiddens: decoder hidden states; ``decoder_hiddens[-1]`` is the query
        :param annotations: encoder annotations to attend over
        :return: tanh(attention_presoftmax(cat(decoder_hiddens, context, dim=2)))
        """
        # TODO: Move this part to an outer func
        if self.align_type <= 1:
            attention_hidden_values = self.attention_hidden(annotations)
        else:
            attention_hidden_values = annotations

        # Use attention and past hidden state to generate scores
        decoder_hidden = decoder_hiddens[-1].unsqueeze(0)
        attention_logits = self.attention_logits(attention_hidden_values, decoder_hidden)
        # NOTE(review): softmax normalises over dim 0; under Fold batching dim 0
        # is the batch axis — confirm it is meant to normalise over input nodes.
        attention_probs = self.softmax(attention_logits)

        context_vec = (attention_probs * annotations).sum(1).unsqueeze(1)
        et = self.tanh(self.attention_presoftmax(torch.cat((decoder_hiddens, context_vec),
                                                           dim=2)))
        return et

    def calc_loss(self, parent, child_index, et, true_value):  # this should be decoder specific
        """Delegate the loss for one node to ``self.decoder.calculate_loss``."""
        return self.decoder.calculate_loss(parent, child_index, et, true_value)

    def get_next_decoder_input(self, next_input, et):
        """Concatenate the embedded token with ``et`` along dim 2 (prints debug shapes)."""
        print("ET", et.shape)
        print("embedding", self.embedding(next_input).shape)
        print("INPUT", next_input.shape)
        return torch.cat((self.embedding(next_input), et), 2)

    def get_next_child_states(self, parent, child_index, input, hidden_state, cell_state):  # should be decoder specific
        """Delegate child state generation to ``self.decoder.get_next_child_states``."""
        return self.decoder.get_next_child_states(parent, child_index, input, hidden_state, cell_state)

    def plus(self, first, second):
        """Add two losses (used as a Fold op)."""
        return first + second

    def attention_logits(self, attention_hidden_values, decoder_hidden):
        """
        Calculates the logits over the nodes in the input tree.
        """
        if self.align_type == 0:
            return self.attention_alignment_vector(self.tanh(self.attention_context(decoder_hidden)
                                                             + attention_hidden_values))
        else:
            return (decoder_hidden * attention_hidden_values).sum(1).unsqueeze(1)

    def decode(self, fold, decoder_hiddens, decoder_cell_states, targetNode, parent_val, child_index, annotations):
        """
        Record the teacher-forced decoding loss of ``targetNode``'s subtree on ``fold``.

        Assumes teacher forcing: each node's true value is fed to its children.

        :param fold: Fold collecting the operations
        :param decoder_hiddens: decoder hidden state(s) for this node
        :param decoder_cell_states: decoder cell state(s) for this node
        :param targetNode: target tree node with ``value`` and ``children``
        :param parent_val: value of this node's parent
        :param child_index: position of this node among its parent's children
        :param annotations: encoder annotations to attend over
        :return: fold node holding this subtree's summed loss
        """
        et = fold.add("calc_attention", decoder_hiddens, annotations)
        loss = fold.add("calc_loss", parent_val, child_index, et, targetNode.value)
        next_input = targetNode.value

        decoder_input = fold.add("get_next_decoder_input", next_input, et)

        for i, child in enumerate(targetNode.children):
            # Parent node of a node's children is that node
            parent = next_input
            child_hiddens, child_cell_states = fold.add("get_next_child_states", parent, i,
                                                        decoder_input,
                                                        decoder_hiddens,
                                                        decoder_cell_states).split(2)
            # Recurse through self so the class is self-contained.
            new_loss = self.decode(fold, child_hiddens, child_cell_states, child, parent, i, annotations)
            loss = fold.add("plus", loss, new_loss)

        return loss


In [None]:
class TreeDecoder(nn.Module):
    """
    Plain tree decoder: one LSTM cell per child position plus an output layer.
    """
    def __init__(self, embedding_size, hidden_size, max_num_children, nclass):
        """
        :param embedding_size: length of the encoded representation of a node
        :param hidden_size: hidden state size
        :param max_num_children: max. number of children a node can have
        :param nclass: number of different tokens which could be in a tree (not counting end
                       of tree token)
        """
        super(TreeDecoder, self).__init__()

        # TODO: replace the placeholder loss_func below with nn.CrossEntropyLoss.

        # Linear layer to calculate log odds. The one is to account for the eos token.
        self.output_log_odds = nn.Linear(hidden_size, nclass + 1)

        # Create a separate lstm for each child index
        self.lstm_list = nn.ModuleList()

        self.max_num_children = max_num_children

        for i in range(self.max_num_children):
            self.lstm_list.append(nn.LSTMCell(embedding_size + hidden_size, hidden_size))

    def loss_func(self, a, b):
        """
        Placeholder loss: prints both input shapes and returns ones.

        :return: tensor of ones with shape (a.size()[0], 1)
        """
        print("FIRST", a.shape)
        print("SECOND", b.shape)
        # TODO: compute a real loss between a and b.
        return torch.ones(a.size()[0], 1)

    def calculate_loss(self, parent, child_index, et, true_value):
        """
        Calculate cross entropy loss from et.

        :param parent: node's parent (dummy; used for compatibility with grammar decoder)
        :param child_index: index of generated child (dummy; used for compatibility with grammar
                            decoder)
        :param et: vector incorporating info from the attention and hidden state of past node
        :param true_value: true value of the new node
        :returns: the value of ``self.loss_func`` (currently a placeholder)
        """
        log_odds = self.output_log_odds(et)
        return self.loss_func(log_odds, true_value)

    def get_next_child_states(self, parent, child_index, input, hidden_state, cell_state):
        """
        Generate the hidden and cell states which will be used to generate the current node's
        children.

        Row ``r`` of the batch is run through ``self.lstm_list[child_index[r]]``
        with ``input[r]``, ``hidden_state[r]`` and ``cell_state[r]``.

        :param parent: node's parent (dummy; used for compatibility with grammar decoder)
        :param child_index: one child index per batch row
        :param input: embedded representation of the node's parent, one row per batch entry
        :param hidden_state: hidden state generated by an lstm, one row per batch entry
        :param cell_state: cell state generated by an lstm, one row per batch entry
        :return: (hiddens, cell states), each the per-row outputs concatenated on dim 0
                 with an extra axis inserted at dim 1
        """
        hiddens = []
        cell_states = []
        # The child index picks the LSTM; the row position picks the inputs.
        for row, child in enumerate(child_index):
            lstm = self.lstm_list[int(child)]
            hidden, cell = lstm(input[row], (hidden_state[row], cell_state[row]))
            hiddens.append(hidden)
            cell_states.append(cell)
        return torch.cat(hiddens, dim=0).unsqueeze(1), torch.cat(cell_states, dim=0).unsqueeze(1)
    

In [None]:
# Dataset / encoding configuration.
use_cuda = True
num_vars, num_ints = 10, 11
input_as_seq = output_as_seq = False
one_hot = False
binarize_input, binarize_output = True, False
eos_token = True
long_base_case = True
# NOTE(review): the +1 presumably reserves one extra input token; the next
# cell's encoder_input_size omits it — confirm which count is intended.
input_size = num_vars + num_ints + len(for_ops.keys()) + 1

In [None]:
# Model hyper-parameters.
binarize_output = True
num_vars, num_ints = 10, 11
embedding_size = 256
hidden_size = 3
num_layers = 1
alignment_size = 50
align_type = 1
encoder_input_size = num_vars + num_ints + len(for_ops)
nclass = num_vars + num_ints + len(lambda_ops)
plot_every = 100
# Max children per output node: 2 when output trees are binarized, else 4.
max_num_children = 2 if binarize_output else 4
max_size = 50

decoder = TreeDecoder(embedding_size, hidden_size, max_num_children, nclass=nclass)
program_model = TreeToTreeAttention(
    decoder, hidden_size, embedding_size,
    nclass=nclass, max_size=max_size,
    alignment_size=alignment_size, align_type=align_type,
)

In [None]:
# End-to-end run: encode the input trees with BinaryTreeLSTM on one Fold, then
# record the decoder losses for the target trees on a second Fold and apply it.
for_lambda_dset = ForLambdaDataset("ANC/AdditionalForDatasets/ForWithLevels/Easy-arbitraryForList.json", binarize_input=binarize_input, 
                                   binarize_output=binarize_output, eos_token=eos_token, one_hot=one_hot, 
                                   long_base_case=long_base_case, input_as_seq=input_as_seq,
                                   output_as_seq=output_as_seq)
start = datetime.datetime.now()
# unsqueeze mutates each input tree in place, adding a leading axis to every value.
trees = [unsqueeze(tree[0]) for tree in for_lambda_dset]
embedding = nn.Embedding(input_size, embedding_size)
        
# NOTE(review): this re-reads the dataset; the unsqueezed values from above
# only carry over if the dataset returns the same tree objects on each access.
trees = [tree[0] for tree in for_lambda_dset]
if not one_hot:
    # Replace each node value by its embedding with dim 0 squeezed
    # (presumably map_tree applies the lambda to every node's value — confirm).
    trees = [map_tree(lambda node: embedding(node).squeeze(0), tree) for tree in trees]
# Target trees: add a leading axis to every value.
target_trees = [map_tree(lambda node: node.unsqueeze(0), tree[1]) for tree in for_lambda_dset]
print("VAL SHAPES", trees[0].value.shape, target_trees[0].value.shape)
# Record the encoder ops for every input tree; each result is (annotations, hidden, cell).
fold = Fold()
result = [forward(fold, tree) for tree in trees]
annotations = [x[0] for x in result]
hiddens = [x[1] for x in result]
cell_states = [x[2] for x in result]
# The module can be built after recording: Fold stores op names, not bound methods.
tree_lstm = BinaryTreeLSTM(trees[0].value.size()[-1], 3)
# x is a list of three batched tensors: annotations, hiddens, cell states.
x = fold.apply(tree_lstm, [annotations, hiddens, cell_states])
print("early stuff worked!")
fold2 = Fold()
# Iterating a batched tensor yields its dim-0 rows; extra axes are added to
# match the rank the decoder expects.  NOTE(review): for annotations the number
# of rows may not equal the number of trees — confirm the pairing below.
annotations_real = [vec.unsqueeze(0) for vec in x[0]]
hiddens_real = [vec.unsqueeze(0).unsqueeze(0) for vec in x[1]]
cell_states_real = [vec.unsqueeze(0).unsqueeze(0) for vec in x[2]]

# Record each target tree's loss; -1 and 0 are the root's parent value and child index.
losses = [program_model.decode(fold2, hidden, cell_state, tree, -1, 0, annotation) for hidden, cell_state, tree, annotation in zip(hiddens_real, cell_states_real, target_trees, annotations_real)]
y = fold2.apply(program_model, [losses])
end = datetime.datetime.now()
print("DONE")