<a href="https://colab.research.google.com/github/mekty2012/CS470_SymbolicIntegration/blob/main/CS470_SymbolicIntegration.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## 1. Import packages

In [None]:
import torch.nn as nn
import torch.optim
import torch.utils.data
import torch

import os
from argparse import Namespace
import random
import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy
import time
from itertools import product

## 2. Seed randoms

In [None]:
# Fixed integer seed shared by every RNG used in this notebook.
# NOTE: the seed used to be derived from hash("7777"); Python randomizes str
# hashes per process (unless PYTHONHASHSEED is set), so that value changed on
# every run and the experiments were not reproducible.
manual_seed = 7777 #random.randint(1, 10000)
print("Random Seed: ", manual_seed)
random.seed(manual_seed)
torch.manual_seed(manual_seed)
np.random.seed(manual_seed)

# Create the output directory. Unlike the `!mkdir` shell escape this also works
# outside IPython and is a no-op when the directory already exists.
os.makedirs("results", exist_ok=True)

## 3. Implement preprocessing / postprocessing

In [None]:
# Token vocabulary. A token's position in this list is the index that
# token_to_one_hot_encoding / logit_to_token use for it.
token_list = (
    ["exp", "log", "sqrt"]                                   # misc. unary functions
    + ["sin", "cos", "tan", "asin", "acos", "atan"]          # trigonometric
    + ["sinh", "cosh", "tanh", "asinh", "acosh", "atanh"]    # hyperbolic
    + ["+", "-", "*", "/"]                                   # binary operators
    + ["x"]                                                  # the variable
    + [str(value) for value in range(-5, 6) if value != 0]   # "-5".."-1", "1".."5"
)

def token_to_one_hot_encoding(token):
    """Return the one-hot encoding of ``token`` as a list of 0/1 ints.

    The list has ``len(token_list)`` entries with a single 1 at the position
    of ``token`` in ``token_list``. Raises ValueError (from ``list.index``)
    when the token is not in the vocabulary.
    """
    hot_position = token_list.index(token)
    encoding = [0] * len(token_list)
    encoding[hot_position] = 1
    return encoding

def logit_to_token(prob_list):
    """Return the token whose score in ``prob_list`` is the largest.

    ``prob_list`` is a list with one score per entry of ``token_list``; on
    ties the first maximal position wins.
    """
    best_score = max(prob_list)
    best_position = prob_list.index(best_score)
    return token_list[best_position]

## 4. Implement models

### 1. ChildSumLSTM

Given inp $v$, hiddens $h_i$, cells $c_i$, computes

$h = \sum_{i=1}^C h_i$

$v_i = \sigma(I_i(v) + H_i(h))$

$v_o = \sigma(I_o(v) + H_o(h))$

$v_u = \tanh(I_u(v) + H_u(h))$

$v_{f}^i = \sigma(I_f(v) + H_f(h_i))$

$v_c = v_i * v_u + \sum_{i=1}^C v_{f}^i * c_i$

$v_h = v_o * \tanh(v_c)$

You can stack `ChildSumTreeLSTM` layers by setting the parameter `return_sequences = True`.

In [None]:
def postorder_checker(depth):
    """Return leaf flags for the postorder traversal of a perfect binary tree.

    The result has ``2 ** depth - 1`` entries, one per node in postorder;
    an entry is True for an external (leaf) node and False for an internal
    node. For example ``postorder_checker(2) == [True, True, False]``.

    Args:
        depth: number of levels in the tree; must be at least 1.

    Raises:
        ValueError: if ``depth`` is smaller than 1 (the recursive definition
            would never reach its base case).
    """
    if depth < 1:
        raise ValueError(f"depth must be at least 1, got {depth!r}")
    flags = [True]
    # A tree of depth d is: left subtree (d-1), right subtree (d-1), then the root.
    for _ in range(depth - 1):
        flags = flags + flags + [False]
    return flags

class ChildSumTreeLSTM(nn.Module):
    """Child-Sum Tree-LSTM layer over a perfect binary tree given in postorder.

    Every node has an input vector; leaves use all-zero child states. For a
    node with input ``v`` and child states ``(h_k, c_k)`` it computes::

        h   = sum_k h_k
        i   = sigmoid(I_i(v) + H_i(h))
        o   = sigmoid(I_o(v) + H_o(h))
        u   = tanh(I_u(v) + H_u(h))
        f_k = sigmoid(I_f(v) + H_f(h_k))
        c   = i * u + sum_k f_k * c_k
        h'  = o * tanh(c)

    Args:
        in_dim: size of each node's input vector.
        cell_dim: size of the hidden and cell states.
        return_sequences: if True, ``forward`` returns the cell state of every
            node so that another tree layer can be stacked on top; otherwise
            it returns the root's ``(hidden, cell)`` pair.
    """

    def __init__(self, in_dim, cell_dim, return_sequences = False):
        super(ChildSumTreeLSTM, self).__init__()
        self.return_sequences = return_sequences
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()

        self.in_dim = in_dim
        self.cell_dim = cell_dim

        # input_* act on the node input, hidden_* on child hidden states.
        self.input_Wf = nn.Linear(self.in_dim, self.cell_dim)
        self.hidden_Wf = nn.Linear(self.cell_dim, self.cell_dim, bias=False)
        self.input_Wi = nn.Linear(self.in_dim, self.cell_dim)
        self.hidden_Wi = nn.Linear(self.cell_dim, self.cell_dim, bias = False)
        self.input_Wo = nn.Linear(self.in_dim, self.cell_dim)
        self.hidden_Wo = nn.Linear(self.cell_dim, self.cell_dim, bias = False)
        self.input_Wu = nn.Linear(self.in_dim, self.cell_dim)
        self.hidden_Wu = nn.Linear(self.cell_dim, self.cell_dim, bias = False)

    def forward(self, inps):
        """Run the layer bottom-up over a batch of trees.

        Args:
            inps: tensor of shape (batch_size, node_num, in_dim) holding the
                nodes of a perfect binary tree in postorder, so ``node_num``
                must be ``2 ** depth - 1``.

        Returns:
            If ``return_sequences`` is True, a (batch_size, node_num, cell_dim)
            tensor with the cell state of every node in postorder; otherwise
            the root's ``(hidden, cell)`` tuple, each (batch_size, cell_dim).

        Raises:
            ValueError: if ``node_num`` is not of the form ``2 ** depth - 1``.
        """
        batch_size, size = inps.shape[0], inps.shape[1]
        # size is 2**d - 1 exactly when its binary form is all ones.
        if size < 1 or size & (size + 1) != 0:
            raise ValueError(
                "expected 2**depth - 1 nodes of a perfect binary tree, got %d" % size
            )
        checker = postorder_checker(size.bit_length())

        stack = []
        outputs = []
        for i, is_leaf in enumerate(checker):
            inp = inps[:, i, :]  # (batch_size, in_dim)
            if is_leaf:
                # Leaves have no children: use zero states on inps' device/dtype.
                child_hiddens = inps.new_zeros((batch_size, 2, self.cell_dim))
                child_cells = inps.new_zeros((batch_size, 2, self.cell_dim))
            else:
                assert len(stack) >= 2
                right_hidden, right_cell = stack.pop()
                left_hidden, left_cell = stack.pop()
                child_hiddens = torch.stack([left_hidden, right_hidden], dim=1)
                child_cells = torch.stack([left_cell, right_cell], dim=1)
            summed_hidden = child_hiddens.sum(dim=1)  # (batch_size, cell_dim)

            i_vec = self.sigmoid(self.input_Wi(inp) + self.hidden_Wi(summed_hidden))
            o_vec = self.sigmoid(self.input_Wo(inp) + self.hidden_Wo(summed_hidden))
            u_vec = self.tanh(self.input_Wu(inp) + self.hidden_Wu(summed_hidden))

            # One forget gate per child; the input term is broadcast over children.
            f_vec = self.sigmoid(
                self.input_Wf(inp).unsqueeze(1) + self.hidden_Wf(child_hiddens)
            )  # (batch_size, 2, cell_dim)

            c_vec = i_vec * u_vec + torch.sum(f_vec * child_cells, dim=1)
            h_vec = o_vec * self.tanh(c_vec)

            stack.append((h_vec, c_vec))
            if self.return_sequences:
                outputs.append(c_vec)

        assert len(stack) == 1
        if self.return_sequences:
            # Keep a node axis so the result can feed another tree layer.
            return torch.stack(outputs, dim=1)
        else:
            return stack[0]


### 2. Binary Tree LSTM

Given inp $v$, hiddens $h_l, h_r$, cells $c_l, c_r$, computes

$v_i = \sigma(I_i(v) + H_i^l(h_l) + H_i^r(h_r))$

$v_u = \tanh(I_u(v) + H_u^l(h_l) + H_u^r(h_r))$

$v_o = \sigma(I_o(v) + H_o^l(h_l) + H_o^r(h_r))$

$v_{f}^l = \sigma(I_{lf}(v) + H_{lf}^l(h_l) + H_{lf}^r(h_r))$

$v_{f}^r = \sigma(I_{rf}(v) + H_{rf}^l(h_l) + H_{rf}^r(h_r))$

$v_c = v_i * v_u + v_f^l * c_l + v_f^r * c_r$

$v_h = v_o * \tanh(v_c)$

You can stack `BinaryTreeLSTM` layers by setting the parameter `return_sequences = True`.

In [None]:
class BinaryTreeLSTM(nn.Module):
    """Binary Tree-LSTM layer over a perfect binary tree given in postorder.

    Unlike the child-sum variant, the left and right child hidden states have
    their own weights. For a node with input ``v`` and child states
    ``(h_l, c_l)``, ``(h_r, c_r)`` (all zeros for leaves)::

        i   = sigmoid(I_i(v)  + L_i(h_l)  + R_i(h_r))
        o   = sigmoid(I_o(v)  + L_o(h_l)  + R_o(h_r))
        u   = tanh   (I_u(v)  + L_u(h_l)  + R_u(h_r))
        f_l = sigmoid(I_lf(v) + L_lf(h_l) + R_lf(h_r))
        f_r = sigmoid(I_rf(v) + L_rf(h_l) + R_rf(h_r))
        c   = i * u + f_l * c_l + f_r * c_r
        h   = o * tanh(c)

    Args:
        in_dim: size of each node's input vector.
        cell_dim: size of the hidden and cell states.
        return_sequences: if True, ``forward`` returns the cell state of every
            node so that another tree layer can be stacked on top; otherwise
            it returns the root's ``(hidden, cell)`` pair.
    """

    def __init__(self, in_dim, cell_dim, return_sequences = False):
        super(BinaryTreeLSTM, self).__init__()
        self.in_dim = in_dim
        self.cell_dim = cell_dim
        self.return_sequences = return_sequences

        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()

        self.input_Wi = nn.Linear(self.in_dim, self.cell_dim)
        self.input_Wo = nn.Linear(self.in_dim, self.cell_dim)
        self.input_Wu = nn.Linear(self.in_dim, self.cell_dim)
        self.input_Wlf = nn.Linear(self.in_dim, self.cell_dim)
        self.input_Wrf = nn.Linear(self.in_dim, self.cell_dim)

        # left_*/right_* act on child hidden states, which have cell_dim
        # features (not in_dim).
        self.left_Wi = nn.Linear(self.cell_dim, self.cell_dim, bias=False)
        self.left_Wo = nn.Linear(self.cell_dim, self.cell_dim, bias=False)
        self.left_Wu = nn.Linear(self.cell_dim, self.cell_dim, bias=False)
        self.left_Wlf = nn.Linear(self.cell_dim, self.cell_dim, bias=False)
        self.left_Wrf = nn.Linear(self.cell_dim, self.cell_dim, bias=False)

        self.right_Wi = nn.Linear(self.cell_dim, self.cell_dim, bias=False)
        self.right_Wo = nn.Linear(self.cell_dim, self.cell_dim, bias=False)
        self.right_Wu = nn.Linear(self.cell_dim, self.cell_dim, bias=False)
        self.right_Wlf = nn.Linear(self.cell_dim, self.cell_dim, bias=False)
        self.right_Wrf = nn.Linear(self.cell_dim, self.cell_dim, bias=False)

    def forward(self, inps):
        """Run the layer bottom-up over a batch of trees.

        Args:
            inps: tensor of shape (batch_size, node_num, in_dim) holding the
                nodes of a perfect binary tree in postorder, so ``node_num``
                must be ``2 ** depth - 1``.

        Returns:
            If ``return_sequences`` is True, a (batch_size, node_num, cell_dim)
            tensor with the cell state of every node in postorder; otherwise
            the root's ``(hidden, cell)`` tuple, each (batch_size, cell_dim).

        Raises:
            ValueError: if ``node_num`` is not of the form ``2 ** depth - 1``.
        """
        batch_size, size = inps.shape[0], inps.shape[1]
        # size is 2**d - 1 exactly when its binary form is all ones.
        if size < 1 or size & (size + 1) != 0:
            raise ValueError(
                "expected 2**depth - 1 nodes of a perfect binary tree, got %d" % size
            )
        checker = postorder_checker(size.bit_length())

        stack = []
        outputs = []
        for i, is_leaf in enumerate(checker):
            inp = inps[:, i, :]  # (batch_size, in_dim)
            if is_leaf:
                # Leaves have no children: use zero states on inps' device/dtype.
                left_hidden = inps.new_zeros((batch_size, self.cell_dim))
                right_hidden = inps.new_zeros((batch_size, self.cell_dim))
                left_cell = inps.new_zeros((batch_size, self.cell_dim))
                right_cell = inps.new_zeros((batch_size, self.cell_dim))
            else:
                # Postorder pushes left before right, so right is on top.
                right_hidden, right_cell = stack.pop()
                left_hidden, left_cell = stack.pop()

            i_vec = self.sigmoid(self.input_Wi(inp) + self.left_Wi(left_hidden) + self.right_Wi(right_hidden))
            o_vec = self.sigmoid(self.input_Wo(inp) + self.left_Wo(left_hidden) + self.right_Wo(right_hidden))
            u_vec = self.tanh(self.input_Wu(inp) + self.left_Wu(left_hidden) + self.right_Wu(right_hidden))
            left_f_vec = self.sigmoid(self.input_Wlf(inp) + self.left_Wlf(left_hidden) + self.right_Wlf(right_hidden))
            right_f_vec = self.sigmoid(self.input_Wrf(inp) + self.left_Wrf(left_hidden) + self.right_Wrf(right_hidden))

            c_vec = i_vec * u_vec + left_f_vec * left_cell + right_f_vec * right_cell
            h_vec = o_vec * self.tanh(c_vec)

            stack.append((h_vec, c_vec))
            if self.return_sequences:
                outputs.append(c_vec)

        assert len(stack) == 1
        if self.return_sequences:
            # Keep a node axis so the result can feed another tree layer.
            return torch.stack(outputs, dim=1)
        else:
            return stack[0]

### 3. Recursive Neural Network

Given input $v_i$, left $v_l$, right $v_r$, computes

$v = act(L_i(v_i) + L_l(v_l) + L_r(v_r))$

You can stack `RecursiveNN` layers by setting the parameter `return_sequences = True`.

In [None]:
class RecursiveNN(nn.Module):
    """Recursive neural network over a perfect binary tree given in postorder.

    A leaf with input ``v`` produces ``act(L_i(v))``; an internal node with
    child outputs ``v_l`` and ``v_r`` produces
    ``act(L_i(v) + L_l(v_l) + L_r(v_r))``, where ``act`` is ReLU.

    Args:
        in_dim: size of each node's input vector.
        hidden_dim: size of each node's output vector.
        return_sequences: if True, ``forward`` returns the output of every
            node so that another tree layer can be stacked on top; otherwise
            it returns only the root's output.
    """

    def __init__(self, in_dim, hidden_dim, return_sequences = False):
        super(RecursiveNN, self).__init__()
        self.in_dim = in_dim
        self.hidden_dim = hidden_dim
        self.dim = hidden_dim  # alias kept from the original attribute name
        self.activation = nn.ReLU()
        self.return_sequences = return_sequences

        self.inp_linear = nn.Linear(self.in_dim, self.hidden_dim)
        self.left_linear = nn.Linear(self.hidden_dim, self.hidden_dim, bias=False)
        self.right_linear = nn.Linear(self.hidden_dim, self.hidden_dim, bias=False)

    def forward(self, inps):
        """Run the network bottom-up over a batch of trees.

        Args:
            inps: tensor of shape (batch_size, node_num, in_dim) holding the
                nodes of a perfect binary tree in postorder, so ``node_num``
                must be ``2 ** depth - 1``.

        Returns:
            If ``return_sequences`` is True, a (batch_size, node_num,
            hidden_dim) tensor with every node's output in postorder;
            otherwise the root's output of shape (batch_size, hidden_dim).

        Raises:
            ValueError: if ``node_num`` is not of the form ``2 ** depth - 1``.
        """
        size = inps.shape[1]
        # size is 2**d - 1 exactly when its binary form is all ones.
        if size < 1 or size & (size + 1) != 0:
            raise ValueError(
                "expected 2**depth - 1 nodes of a perfect binary tree, got %d" % size
            )
        checker = postorder_checker(size.bit_length())

        stack = []
        outputs = []
        for i, is_leaf in enumerate(checker):
            inp = inps[:, i, :]  # (batch_size, in_dim)
            if is_leaf:
                res = self.activation(self.inp_linear(inp))
            else:
                # Postorder pushes left before right, so right is on top.
                right = stack.pop()
                left = stack.pop()
                res = self.activation(self.inp_linear(inp) + self.left_linear(left) + self.right_linear(right))
            stack.append(res)
            if self.return_sequences:
                outputs.append(res)

        assert len(stack) == 1
        if self.return_sequences:
            # Keep a node axis so the result can feed another tree layer.
            return torch.stack(outputs, dim=1)
        else:
            return stack[0]

### 4. Compositional Semantics

Compose the operators.

Each unary op corresponds to an $n \times n$ matrix, each binary op to a $2n \times n$ matrix, and each terminal to a length-$n$ vector.

You can not stack this module.

In [None]:
class CompositionalSemantics(nn.Module):
    """Embed an expression by composing one learned map per token.

    Each unary operator owns a ``Linear(hidden_dim, hidden_dim)``, each binary
    operator a ``Linear(2 * hidden_dim, hidden_dim)`` applied to the
    concatenation of its operands, and each terminal a learned vector of
    length ``hidden_dim``.

    Args:
        hidden_dim: size of every embedding.
    """

    _UNARY_OPS = ("exp", "log", "sqrt", "sin", "cos", "tan", "asin", "acos", "atan",
                  "sinh", "cosh", "tanh", "asinh", "acosh", "atanh")
    _BINARY_OPS = ("+", "-", "*", "/")
    _TERMINALS = ("x", "-5", "-4", "-3", "-2", "-1", "1", "2", "3", "4", "5")

    def __init__(self, hidden_dim):
        super(CompositionalSemantics, self).__init__()
        self.hidden_dim = hidden_dim
        self.unary_ops = nn.ModuleDict({
            name : nn.Linear(self.hidden_dim, self.hidden_dim)
            for name in self._UNARY_OPS
        })
        self.binary_ops = nn.ModuleDict({
            name : nn.Linear(self.hidden_dim * 2, self.hidden_dim)
            for name in self._BINARY_OPS
        })
        # ParameterDict (not ModuleDict): terminals are plain tensors, and they
        # must be registered as parameters to be trained and moved with the
        # module.
        self.terminals = nn.ParameterDict({
            name : nn.Parameter(torch.rand(self.hidden_dim))
            for name in self._TERMINALS
        })

    def forward(self, inp):
        """Evaluate a postorder token sequence and return its embedding.

        Args:
            inp: iterable of token strings in postorder (operands before their
                operator).

        Returns:
            Tensor of shape (hidden_dim,) embedding the whole expression.

        Raises:
            IndexError: if an operator is missing operands.
            ValueError: if the sequence does not reduce to exactly one value.
        """
        stack = []
        for curr in inp:
            if curr in self.unary_ops:
                param = stack.pop()
                stack.append(self.unary_ops[curr](param))
            elif curr in self.binary_ops:
                right_param = stack.pop()
                left_param = stack.pop()
                # dim=-1 because terminal embeddings are 1-D vectors.
                joined = torch.cat([left_param, right_param], dim=-1)
                stack.append(self.binary_ops[curr](joined))
            elif curr in self.terminals:
                stack.append(self.terminals[curr])
            # NOTE(review): unknown tokens are skipped, as before — confirm
            # whether callers rely on this (e.g. for padding tokens).
        if len(stack) != 1:
            raise ValueError(
                "expression must reduce to exactly one value, got %d" % len(stack)
            )
        return stack[0]

### 5. Code2Seq

Sample the paths.

In [None]:
# Code2Seq will use RNN/LSTM/... structures, so we don't need module. We need to define preprocess function.
def code2seq_sample(tree, num):
    """Sample terminal-to-terminal label paths from an expression tree.

    Args:
        tree: nested sequence ``(label, child, ...)`` with at most two
            children per node, as accepted by ``list_terminal``.
        num: number of paths to draw, with replacement, from all unordered
            terminal pairs; ``-1`` returns the path of every pair once.

    Returns:
        A list of paths as produced by ``terminal_to_path``. The list is empty
        when the tree has fewer than two terminals.
    """
    terminals = list_terminal(tree)
    # Every unordered pair, the earlier-listed terminal first.
    pairs = [
        (left, right)
        for position, left in enumerate(terminals)
        for right in terminals[position + 1:]
    ]
    if num == -1:
        chosen = pairs
    elif pairs:
        chosen = random.choices(pairs, k=num)
    else:
        # random.choices raises on an empty population; nothing to sample.
        chosen = []
    return [terminal_to_path(tree, left, right) for left, right in chosen]

def terminal_to_path(tree, first, second):
    """Return the node labels on the path between two terminals.

    Args:
        tree: nested sequence ``(label, child, ...)``; ``node[0]`` is the
            label, ``node[1]`` the left (or only) child and ``node[2]`` the
            right child.
        first: directions from the root of ``tree`` down to the first
            terminal, root step first; True takes ``node[1]``, False takes
            ``node[2]``.
        second: directions to the second terminal, in the same format.

    Returns:
        The labels from the first terminal up to the lowest common ancestor
        and then down to the second terminal, both ends included. If both
        direction lists are equal, the result is just that node's label.
    """
    def step(node, go_left):
        return node[1] if go_left else node[2]

    # Follow the shared prefix down to the lowest common ancestor.
    root = tree
    split = 0
    while split < len(first) and split < len(second) and first[split] == second[split]:
        root = step(root, first[split])
        split += 1

    def labels_below(directions):
        # Labels met while walking from the common ancestor along `directions`.
        node = root
        labels = []
        for go_left in directions:
            node = step(node, go_left)
            labels.append(node[0])
        return labels

    left_labels = labels_below(first[split:])
    right_labels = labels_below(second[split:])

    # The first branch is traversed upward, so its labels are reversed.
    return left_labels[::-1] + [root[0]] + right_labels

# Returns list of terminals. The terminal's position is encoded as True(left), False(right). If unary, use True.
def list_terminal(tree):
    """Return the position of every terminal of ``tree``, left to right.

    A node is a sequence ``(label, child, ...)`` with zero, one or two
    children. Each position is the list of directions from the root down to
    the terminal, root step first: True for the left (or only) child, False
    for the right child. This is the order in which ``terminal_to_path``
    consumes directions.

    Raises:
        ValueError: if a node does not have 1, 2 or 3 elements.
    """
    if len(tree) == 1:
        # A terminal is reached from itself by the empty path.
        return [[]]
    if len(tree) == 2:
        return [[True] + path for path in list_terminal(tree[1])]
    if len(tree) == 3:
        # Prepend this node's step so the path reads from the root downward.
        left_paths = [[True] + path for path in list_terminal(tree[1])]
        right_paths = [[False] + path for path in list_terminal(tree[2])]
        return left_paths + right_paths
    raise ValueError("Tree is expected to be at most binary.")

### 6. TBCNN

Due to our tree structure (at most binary, window size 2), TBCNN's implementation is identical to the recursive NN, so we do not implement TBCNN separately.

## 5. Implement Training