## 1 Import packages

In [4]:
import torch.nn as nn
import torch.optim
import torch.utils.data
import torch
import os
from argparse import Namespace
import random
import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy
import time
from itertools import product

## 2 Seed random number generators

In [None]:
# Use a fixed integer seed so runs are reproducible. str.__hash__ is salted
# per interpreter process (PYTHONHASHSEED), so a string hash cannot serve
# as a stable seed.
manual_seed = 7777
print("Random Seed: ", manual_seed)
random.seed(manual_seed)
torch.manual_seed(manual_seed)
np.random.seed(manual_seed)

# Output directory for results. exist_ok=True makes re-running this cell safe,
# and plain Python (unlike IPython shell magic) also works outside a notebook.
os.makedirs("results", exist_ok=True)

In [None]:
# Token vocabulary: 15 unary functions, 4 binary operators, the variable "x",
# and the integer constants -5..5 (0 is not included).
token_list = ["exp", "log", "sqrt", "sin", "cos", "tan", "asin", "acos", "atan", "sinh", "cosh", "tanh", "asinh", "acosh", "atanh",
              "+", "-", "*", "/", "x", "-5", "-4", "-3", "-2", "-1", "1", "2", "3", "4", "5"]


def token_to_one_hot_encoding(token):
    """Return the one-hot encoding of ``token`` as a list of ints.

    The vector has ``len(token_list)`` entries, with a single 1 at the
    token's position in ``token_list``.

    Raises:
        ValueError: if ``token`` is not in ``token_list``.
    """
    ind = token_list.index(token)
    return [1 if i == ind else 0 for i in range(len(token_list))]


def logit_to_token(prob_list):
    """Return the token with the largest score in ``prob_list``.

    ``prob_list`` is any indexable sequence aligned with ``token_list``,
    such as a list, tuple, 1-D numpy array or 1-D tensor. Ties resolve to
    the first maximal position.

    Raises:
        ValueError: if ``prob_list`` is empty.
    """
    # Index-based argmax rather than ``prob_list.index(...)``: arrays and
    # tensors have no ``.index`` method. max() keeps the first maximum it
    # sees, matching the tie-breaking of ``list.index``.
    best = max(range(len(prob_list)), key=prob_list.__getitem__)
    return token_list[best]

In [None]:
def preprocess_input(inp):
    """Turn a batch of raw inputs into a ``torch.Tensor`` (placeholder).

    The input format is still undecided (possibly a tuple -- to be
    confirmed). Until it is settled, this function does no work and returns
    ``None``.
    """
    return None

In [None]:
class ChildSumTreeLSTM(nn.Module):
    """Child-Sum Tree-LSTM.

    ``forward_internal`` computes the cell and hidden state of one node from
    its input vector and the stacked states of its children. ``forward``
    walks the tree encoded in ``inp`` and applies ``forward_internal``
    bottom-up.
    """

    def __init__(self, in_dim, cell_dim):
        super(ChildSumTreeLSTM, self).__init__()
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()

        self.in_dim = in_dim
        self.cell_dim = cell_dim

        # Each gate = input projection (with bias) + hidden projection
        # (without bias), so every gate carries exactly one bias term.
        self.input_Wf = nn.Linear(self.in_dim, self.cell_dim)
        self.hidden_Wf = nn.Linear(self.cell_dim, self.cell_dim, bias=False)
        self.input_Wi = nn.Linear(self.in_dim, self.cell_dim)
        self.hidden_Wi = nn.Linear(self.cell_dim, self.cell_dim, bias=False)
        self.input_Wo = nn.Linear(self.in_dim, self.cell_dim)
        self.hidden_Wo = nn.Linear(self.cell_dim, self.cell_dim, bias=False)
        self.input_Wu = nn.Linear(self.in_dim, self.cell_dim)
        self.hidden_Wu = nn.Linear(self.cell_dim, self.cell_dim, bias=False)

    def forward_internal(self, inp, hiddens=None, cells=None):
        """Compute the state of a single node.

        Args:
            inp: node input, shape (batch_size, in_dim).
            hiddens: children's hidden states, shape
                (batch_size, child_num, cell_dim). Defaults to one all-zero
                child (leaf node).
            cells: children's cell states, same shape as ``hiddens``.
                Defaults to one all-zero child.

        Returns:
            ``(c_vec, h_vec)``, each of shape (batch_size, cell_dim).
        """
        # new_zeros keeps the defaults on inp's device and dtype.
        if hiddens is None:
            hiddens = inp.new_zeros(inp.shape[0], 1, self.cell_dim)
        if cells is None:
            cells = inp.new_zeros(inp.shape[0], 1, self.cell_dim)

        # Input/output/update gates see the sum of the children's hidden states.
        summed_hidden = torch.sum(hiddens, 1)  # (batch_size, cell_dim)

        i_vec = self.sigmoid(self.input_Wi(inp) + self.hidden_Wi(summed_hidden))  # (batch_size, cell_dim)
        o_vec = self.sigmoid(self.input_Wo(inp) + self.hidden_Wo(summed_hidden))  # (batch_size, cell_dim)
        u_vec = self.tanh(self.input_Wu(inp) + self.hidden_Wu(summed_hidden))  # (batch_size, cell_dim)

        # One forget gate per child. The input term is broadcast over the
        # child axis, and nn.Linear acts on the last dim of the 3-D hiddens.
        input_f_vec = self.input_Wf(inp).unsqueeze(1)  # (batch_size, 1, cell_dim)
        f_vec = self.sigmoid(input_f_vec + self.hidden_Wf(hiddens))  # (batch_size, child_num, cell_dim)

        c_vec = i_vec * u_vec + torch.sum(cells * f_vec, 1)  # (batch_size, cell_dim)
        h_vec = o_vec * self.tanh(c_vec)  # (batch_size, cell_dim)

        return c_vec, h_vec

    def forward(self, inp):
        """Recursively encode the tree in ``inp``; returns ``(cell, hidden)``.

        Dispatches on ``inp.shape[1]``:
        - 3: node ``inp[:, 0]`` with children ``inp[:, 1]`` and ``inp[:, 2]``;
        - 2: node ``inp[:, 0]`` with one child ``inp[:, 1]``;
        - otherwise: leaf node ``inp[:, 0]``.
        """
        # NOTE(review): with a plain (batch, k, in_dim) tensor, ``inp[:, 1]``
        # drops to (batch, in_dim), so only the leaf case (k == 1) is
        # well-formed today. TODO: settle the tree encoding before relying on
        # the recursive branches.
        if inp.shape[1] == 3:
            left_cell, left_hidden = self.forward(inp[:, 1])
            right_cell, right_hidden = self.forward(inp[:, 2])
            hiddens = torch.stack([left_hidden, right_hidden], 1)  # (batch_size, 2, cell_dim)
            cells = torch.stack([left_cell, right_cell], 1)  # (batch_size, 2, cell_dim)
            return self.forward_internal(inp[:, 0], hiddens, cells)
        if inp.shape[1] == 2:
            cell, hidden = self.forward(inp[:, 1])
            return self.forward_internal(inp[:, 0], hidden.unsqueeze(1), cell.unsqueeze(1))
        return self.forward_internal(inp[:, 0])


In [None]:
class BinaryTreeLSTM(nn.Module):
    """Binary Tree-LSTM with separate weights for the left and right child.

    ``forward_internal`` computes one node from its input and its children's
    states. ``forward`` walks the tree encoded in ``inp`` bottom-up.
    """

    def __init__(self, in_dim, cell_dim):
        super(BinaryTreeLSTM, self).__init__()
        self.in_dim = in_dim
        self.cell_dim = cell_dim

        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()

        # Projections of the node input (with bias).
        self.input_Wi = nn.Linear(self.in_dim, self.cell_dim)
        self.input_Wo = nn.Linear(self.in_dim, self.cell_dim)
        self.input_Wu = nn.Linear(self.in_dim, self.cell_dim)
        self.input_Wlf = nn.Linear(self.in_dim, self.cell_dim)
        self.input_Wrf = nn.Linear(self.in_dim, self.cell_dim)

        # Projections of the children's hidden states. These consume
        # cell_dim-sized vectors, so their input size is cell_dim, not in_dim.
        self.left_Wi = nn.Linear(self.cell_dim, self.cell_dim, bias=False)
        self.left_Wo = nn.Linear(self.cell_dim, self.cell_dim, bias=False)
        self.left_Wu = nn.Linear(self.cell_dim, self.cell_dim, bias=False)
        self.left_Wlf = nn.Linear(self.cell_dim, self.cell_dim, bias=False)
        self.left_Wrf = nn.Linear(self.cell_dim, self.cell_dim, bias=False)

        self.right_Wi = nn.Linear(self.cell_dim, self.cell_dim, bias=False)
        self.right_Wo = nn.Linear(self.cell_dim, self.cell_dim, bias=False)
        self.right_Wu = nn.Linear(self.cell_dim, self.cell_dim, bias=False)
        self.right_Wlf = nn.Linear(self.cell_dim, self.cell_dim, bias=False)
        self.right_Wrf = nn.Linear(self.cell_dim, self.cell_dim, bias=False)

    def forward_internal(self, inp, left_hidden=None, left_cell=None, right_hidden=None, right_cell=None):
        """Compute the state of a single node.

        Args:
            inp: node input, shape (batch_size, in_dim).
            left_hidden, left_cell, right_hidden, right_cell: children's
                states, each shape (batch_size, cell_dim). A missing child
                state defaults to zeros.

        Returns:
            ``(c_vec, h_vec)``, each of shape (batch_size, cell_dim).
        """
        # One shared zero tensor on inp's device/dtype; it is never modified
        # in place, so sharing it is safe.
        zeros = inp.new_zeros(inp.shape[0], self.cell_dim)
        if left_hidden is None:
            left_hidden = zeros
        if right_hidden is None:
            right_hidden = zeros
        if left_cell is None:
            left_cell = zeros
        if right_cell is None:
            right_cell = zeros

        i_vec = self.sigmoid(self.input_Wi(inp) + self.left_Wi(left_hidden) + self.right_Wi(right_hidden))
        o_vec = self.sigmoid(self.input_Wo(inp) + self.left_Wo(left_hidden) + self.right_Wo(right_hidden))
        u_vec = self.tanh(self.input_Wu(inp) + self.left_Wu(left_hidden) + self.right_Wu(right_hidden))
        # Separate forget gates for the left and right child cells.
        left_f_vec = self.sigmoid(self.input_Wlf(inp) + self.left_Wlf(left_hidden) + self.right_Wlf(right_hidden))
        right_f_vec = self.sigmoid(self.input_Wrf(inp) + self.left_Wrf(left_hidden) + self.right_Wrf(right_hidden))

        c_vec = i_vec * u_vec + left_f_vec * left_cell + right_f_vec * right_cell
        h_vec = o_vec * self.tanh(c_vec)

        return c_vec, h_vec

    def forward(self, inp):
        """Recursively encode the tree in ``inp``; returns ``(cell, hidden)``.

        Dispatches on ``inp.shape[1]``:
        - 3: node ``inp[:, 0]`` with left ``inp[:, 1]`` and right ``inp[:, 2]``;
        - 2: node ``inp[:, 0]`` with only a left child ``inp[:, 1]``;
        - otherwise: leaf node ``inp[:, 0]``.
        """
        # NOTE(review): with a plain (batch, k, in_dim) tensor only the leaf
        # case (k == 1) is well-formed. TODO: settle the tree encoding.
        if inp.shape[1] == 3:
            left_cell, left_hidden = self.forward(inp[:, 1])
            right_cell, right_hidden = self.forward(inp[:, 2])
            return self.forward_internal(inp[:, 0], left_hidden, left_cell, right_hidden, right_cell)
        if inp.shape[1] == 2:
            cell, hidden = self.forward(inp[:, 1])
            return self.forward_internal(inp[:, 0], hidden, cell)
        return self.forward_internal(inp[:, 0])

In [None]:
class RecursiveNN(nn.Module):
    """Plain recursive neural network over a (at most binary) tree.

    Each node's representation is
    ``ReLU(W_in x + W_left left_child + W_right right_child)``; missing
    children contribute nothing.
    """

    def __init__(self, in_dim, hidden_dim):
        super(RecursiveNN, self).__init__()
        self.in_dim = in_dim
        self.hidden_dim = hidden_dim
        self.dim = hidden_dim  # attribute kept for backward compatibility
        self.activation = nn.ReLU()

        self.inp_linear = nn.Linear(self.in_dim, self.hidden_dim)
        self.left_linear = nn.Linear(self.hidden_dim, self.hidden_dim, bias=False)
        self.right_linear = nn.Linear(self.hidden_dim, self.hidden_dim, bias=False)

    def forward(self, inp):
        """Recursively encode the tree in ``inp``; returns (batch_size, hidden_dim).

        Dispatches on ``inp.shape[1]``:
        - 3: node ``inp[:, 0]`` with children ``inp[:, 1]`` and ``inp[:, 2]``;
        - 2: node ``inp[:, 0]`` with one child ``inp[:, 1]``;
        - otherwise: leaf node ``inp[:, 0]``.
        """
        # NOTE(review): with a plain (batch, k, in_dim) tensor only the leaf
        # case (k == 1) is well-formed. TODO: settle the tree encoding.
        if inp.shape[1] == 3:
            left = self.forward(inp[:, 1])
            right = self.forward(inp[:, 2])
            return self.activation(self.inp_linear(inp[:, 0]) + self.left_linear(left) + self.right_linear(right))
        if inp.shape[1] == 2:
            left = self.forward(inp[:, 1])
            return self.activation(self.inp_linear(inp[:, 0]) + self.left_linear(left))
        return self.activation(self.inp_linear(inp[:, 0]))

In [None]:
class CompositionalSemantics(nn.Module):
    """Compositional embedding of expression trees.

    Holds one learned layer per unary operator (hidden -> hidden), one per
    binary operator (2 * hidden -> hidden), and one learned vector per
    terminal token. The tree traversal itself is not implemented yet.
    """

    def __init__(self, hidden_dim):
        super(CompositionalSemantics, self).__init__()
        self.hidden_dim = hidden_dim

        unary_names = ["exp", "log", "sqrt", "sin", "cos", "tan", "asin", "acos", "atan",
                       "sinh", "cosh", "tanh", "asinh", "acosh", "atanh"]
        binary_names = ["+", "-", "*", "/"]
        terminal_names = ["x", "-5", "-4", "-3", "-2", "-1", "1", "2", "3", "4", "5"]

        # ModuleDict/ParameterDict (instead of plain dicts) register these
        # with the module, so parameters(), .to(device) and state_dict()
        # include them and an optimizer can train them.
        self.unary_ops = nn.ModuleDict(
            {name: nn.Linear(self.hidden_dim, self.hidden_dim) for name in unary_names}
        )
        # Binary ops take the concatenation of the two child vectors.
        self.binary_ops = nn.ModuleDict(
            {name: nn.Linear(self.hidden_dim * 2, self.hidden_dim) for name in binary_names}
        )
        self.terminals = nn.ParameterDict(
            {name: nn.Parameter(torch.rand(self.hidden_dim)) for name in terminal_names}
        )

    def forward_internal(self, inp):
        """Encode one tree node (not implemented yet)."""
        # TODO: decide the tree encoding of ``inp`` (e.g. dispatch on
        # ``inp.shape[1]`` like the other tree models) and combine children
        # through unary_ops / binary_ops / terminals.
        raise NotImplementedError("CompositionalSemantics.forward_internal is not implemented yet")

    def forward(self, inp):
        """Assumes input is given as tree"""
        return self.forward_internal(inp)

In [None]:
class TBCNN(nn.Module):
    """Placeholder for a TBCNN model; ``forward`` is not implemented yet."""

    def __init__(self):
        super(TBCNN, self).__init__()

    def forward(self, inp):
        """Not implemented yet. Raises NotImplementedError."""
        # An explicit raise keeps the class importable (an empty body is a
        # SyntaxError) and fails loudly if it is used before it is written.
        raise NotImplementedError("TBCNN.forward is not implemented yet")

In [None]:
# Code2Seq reuses standard RNN/LSTM layers, so no custom module is needed here; only a preprocessing function has to be defined.