## 1. Import packages

In [None]:
import torch.nn as nn
import torch.optim
import torch.utils.data
import torch

import sympy as sp
import os
from argparse import Namespace
import random
import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy
import time
from itertools import product

from src.utils import AttrDict
from src.envs import build_env
from src.envs.sympy_utils import simplify
# Pretrained checkpoint; passed to the environment via 'reload_model' below.
model_path = './fwd_bwd.pth'

# All settings handed to build_env, written as keyword arguments; the key
# order (environment settings first, then model settings) is unchanged.
params = AttrDict(dict(
    # environment parameters
    env_name='char_sp',
    int_base=10,
    balanced=False,
    positive=True,
    precision=10,
    n_variables=1,
    n_coefficients=0,
    leaf_probs='0.75,0,0.25,0',
    max_len=512,
    max_int=5,
    max_ops=15,
    max_ops_G=15,
    clean_prefix_expr=True,
    rewrite_functions='',
    tasks='prim_fwd',
    operators='add:10,sub:3,mul:10,div:5,sqrt:4,pow2:4,pow3:2,pow4:1,pow5:1,ln:4,exp:4,sin:4,cos:4,tan:4,asin:1,acos:1,atan:1,sinh:1,cosh:1,tanh:1,asinh:1,acosh:1,atanh:1',

    # model parameters
    cpu=False,
    emb_dim=1024,
    n_enc_layers=6,
    n_dec_layers=6,
    n_heads=8,
    dropout=0,
    attention_dropout=0,
    sinusoidal_embeddings=False,
    share_inout_emb=True,
    reload_model=model_path,
))
env = build_env(params)

## 2. Seed random number generators

In [None]:
# Fixed integer seed so runs are reproducible. The previous
# `"7777".__hash__() % (2 ** 32)` changed on every interpreter start because
# str hashes are salted (PYTHONHASHSEED), which defeated the seeding.
manual_seed = 7777
print("Random Seed: ", manual_seed)
random.seed(manual_seed)
torch.manual_seed(manual_seed)
np.random.seed(manual_seed)

# Portable replacement for the IPython-only `!mkdir results`; unlike mkdir it
# does not fail when the directory already exists.
os.makedirs("results", exist_ok=True)

## 3. Implement preprocessing / postprocessing

In [None]:
# Example antiderivative F(x), parsed with the environment's local_dict so
# that `x` and the function names resolve to the environment's objects.
x = env.local_dict['x']
F_infix = 'ln(cos(x + exp(x)) * sin(x**2 + 2) * exp(x) / x)'
F = sp.sympify(F_infix, locals=env.local_dict)
F  # shown as the cell output

In [None]:
# Derivative of F with respect to x (the integrand).
f = sp.diff(F, x)
f  # shown as the cell output

In [None]:
# Token sequences for F and f (postfix), plus the prefix form of f used by the
# sequence-model comparison further down.
F_postfix = env.sympy_to_postfix(F)
f_postfix = env.sympy_to_postfix(f)
f_prefix = env.sympy_to_prefix(f)
for label, tokens in (("F", F_postfix), ("f", f_postfix)):
    print(f"{label} postfix : {len(tokens)}")

## 4. Implement models

### 1. ChildSumLSTM

Given input $v$, child hidden states $h_i$ and child cell states $c_i$ ($i = 1, \dots, C$), it computes

$h = \sum_{i=1}^C h_i$

$v_i = \sigma(I_i(v) + H_i(h))$

$v_o = \sigma(I_o(v) + H_o(h))$

$v_u = \tanh(I_u(v) + H_u(h))$

$v_{f}^i = \sigma(I_f(v) + H_f(h_i))$

$v_c = v_i \odot v_u + \sum_{i=1}^C v_{f}^i \odot c_i$

$v_h = v_o \odot \tanh(v_c)$

where $\odot$ is elementwise multiplication. You can stack ChildSumLSTM layers by setting `return_sequence = True` on every layer except the last.

In [None]:
def postorder_checker(depth):
    """Return the leaf/internal pattern of a perfect binary tree in postorder.

    Entry ``k`` is ``True`` when the ``k``-th node of a postorder traversal of
    a perfect binary tree with ``depth`` levels is a leaf, and ``False`` when
    it is an internal node. The list has ``2 ** depth - 1`` entries.

    Args:
        depth: Number of tree levels. ``depth <= 0`` is the empty tree and
            yields ``[]``; the previous recursive version never terminated
            for such values.
    """
    if depth <= 0:
        return []
    pattern = [True]  # depth 1: a single leaf
    for _ in range(depth - 1):
        # Postorder of a perfect tree: left subtree, right subtree, then root.
        pattern = pattern * 2 + [False]
    return pattern


class ChildSumTreeLSTM(nn.Module):
    """Child-sum Tree-LSTM over a perfect binary tree given in postorder.

    The two children share all weights, and each child gets its own forget
    gate computed from that child's hidden state.

    Args:
        in_dim: Feature size of each node input.
        cell_dim: Size of the hidden and cell states.
        return_sequence: If True, ``forward`` returns the hidden state of every
            node, shape ``(batch_size, node_num, cell_dim)``, so layers can be
            stacked. Otherwise it returns the root's ``(hidden, cell)``, each
            ``(batch_size, cell_dim)``.
    """

    def __init__(self, in_dim, cell_dim, return_sequence = False):
        super(ChildSumTreeLSTM, self).__init__()
        self.return_sequence = return_sequence
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()

        self.in_dim = in_dim
        self.cell_dim = cell_dim

        self.input_Wf = nn.Linear(self.in_dim, self.cell_dim)
        self.hidden_Wf = nn.Linear(self.cell_dim, self.cell_dim, bias=False)
        self.input_Wi = nn.Linear(self.in_dim, self.cell_dim)
        self.hidden_Wi = nn.Linear(self.cell_dim, self.cell_dim, bias = False)
        self.input_Wo = nn.Linear(self.in_dim, self.cell_dim)
        self.hidden_Wo = nn.Linear(self.cell_dim, self.cell_dim, bias = False)
        self.input_Wu = nn.Linear(self.in_dim, self.cell_dim)
        self.hidden_Wu = nn.Linear(self.cell_dim, self.cell_dim, bias = False)

    def forward(self, inps):
        """Run the cell over ``inps`` of shape (batch_size, node_num, in_dim).

        ``node_num`` must be ``2 ** k - 1`` (a perfect binary tree in
        postorder); otherwise ``ValueError`` is raised.
        """
        batch_size, size = inps.shape[0], inps.shape[1]
        depth = size.bit_length()
        if size < 1 or size != (1 << depth) - 1:
            raise ValueError(
                f"expected 2**k - 1 nodes (a perfect binary tree in postorder), got {size}"
            )

        # Leaf children contribute zero states. Matching the input's device and
        # dtype keeps the layer usable on GPU and in double precision.
        zeros = inps.new_zeros((batch_size, self.cell_dim))
        stack = []
        hidden_seq = []
        for idx, is_leaf in enumerate(postorder_checker(depth)):
            inp = inps[:, idx, :]
            if is_leaf:
                left_hidden = right_hidden = left_cell = right_cell = zeros
                new_hidden = zeros
            else:
                assert len(stack) >= 2
                # Postorder pushes the right child last. The child-sum cell is
                # symmetric, so the order only matters for readability.
                right_hidden, right_cell = stack.pop()
                left_hidden, left_cell = stack.pop()
                new_hidden = left_hidden + right_hidden

            i_vec = self.sigmoid(self.input_Wi(inp) + self.hidden_Wi(new_hidden))  # (batch_size, cell_dim)
            o_vec = self.sigmoid(self.input_Wo(inp) + self.hidden_Wo(new_hidden))  # (batch_size, cell_dim)
            u_vec = self.tanh(self.input_Wu(inp) + self.hidden_Wu(new_hidden))  # (batch_size, cell_dim)

            child_hidden = torch.stack([left_hidden, right_hidden], dim=1)  # (batch_size, 2, cell_dim)
            child_cell = torch.stack([left_cell, right_cell], dim=1)  # (batch_size, 2, cell_dim)
            # One forget gate per child. The input projection is shared by
            # both children and broadcast over the child axis.
            f_vec = self.sigmoid(self.input_Wf(inp).unsqueeze(1) + self.hidden_Wf(child_hidden))

            c_vec = i_vec * u_vec + (f_vec * child_cell).sum(dim=1)
            h_vec = o_vec * self.tanh(c_vec)

            stack.append((h_vec, c_vec))
            if self.return_sequence:
                # Collect, then stack once at the end. Concatenating inside the
                # loop would copy the growing sequence at every node.
                hidden_seq.append(h_vec)

        assert len(stack) == 1
        if self.return_sequence:
            return torch.stack(hidden_seq, dim=1)
        return stack[0]


### 2. Binary Tree LSTM

Given input $v$, left/right child hidden states $h_l, h_r$ and cell states $c_l, c_r$, it computes

$v_i = \sigma(I_i(v) + H_i^l(h_l) + H_i^r(h_r))$

$v_o = \sigma(I_o(v) + H_o^l(h_l) + H_o^r(h_r))$

$v_u = \tanh(I_u(v) + H_u^l(h_l) + H_u^r(h_r))$

$v_{f}^l = \sigma(I_{lf}(v) + H_{lf}^l(h_l) + H_{lf}^r(h_r))$

$v_{f}^r = \sigma(I_{rf}(v) + H_{rf}^l(h_l) + H_{rf}^r(h_r))$

$v_c = v_i \odot v_u + v_f^l \odot c_l + v_f^r \odot c_r$

$v_h = v_o \odot \tanh(v_c)$

where $\odot$ is elementwise multiplication. You can stack BinaryTreeLSTM layers by setting `return_sequence = True` on every layer except the last.

In [None]:
class BinaryTreeLSTM(nn.Module):
    """Binary Tree-LSTM with separate weights for the left and right child.

    Nodes are read in postorder from a perfect binary tree, so ``inps`` has
    shape ``(batch_size, node_num, in_dim)`` with ``node_num == 2 ** k - 1``.

    Args:
        in_dim: Feature size of each node input.
        cell_dim: Size of the hidden and cell states.
        return_sequence: If True, ``forward`` returns every node's hidden state,
            ``(batch_size, node_num, cell_dim)``, so layers can be stacked.
            Otherwise it returns the root's ``(hidden, cell)``, each
            ``(batch_size, cell_dim)``.
    """

    def __init__(self, in_dim, cell_dim, return_sequence = False):
        super(BinaryTreeLSTM, self).__init__()
        self.in_dim = in_dim
        self.cell_dim = cell_dim
        self.return_sequence = return_sequence

        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()

        self.input_Wi = nn.Linear(self.in_dim, self.cell_dim)
        self.input_Wo = nn.Linear(self.in_dim, self.cell_dim)
        self.input_Wu = nn.Linear(self.in_dim, self.cell_dim)
        self.input_Wlf = nn.Linear(self.in_dim, self.cell_dim)
        self.input_Wrf = nn.Linear(self.in_dim, self.cell_dim)

        self.left_Wi = nn.Linear(self.cell_dim, self.cell_dim, bias=False)
        self.left_Wo = nn.Linear(self.cell_dim, self.cell_dim, bias=False)
        self.left_Wu = nn.Linear(self.cell_dim, self.cell_dim, bias=False)
        self.left_Wlf = nn.Linear(self.cell_dim, self.cell_dim, bias=False)
        self.left_Wrf = nn.Linear(self.cell_dim, self.cell_dim, bias=False)

        self.right_Wi = nn.Linear(self.cell_dim, self.cell_dim, bias=False)
        self.right_Wo = nn.Linear(self.cell_dim, self.cell_dim, bias=False)
        self.right_Wu = nn.Linear(self.cell_dim, self.cell_dim, bias=False)
        self.right_Wlf = nn.Linear(self.cell_dim, self.cell_dim, bias=False)
        self.right_Wrf = nn.Linear(self.cell_dim, self.cell_dim, bias=False)

    def forward(self, inps):
        """Run the cell over ``inps``. Raises ValueError unless node_num == 2**k - 1."""
        batch_size, size = inps.shape[0], inps.shape[1]
        depth = size.bit_length()
        if size < 1 or size != (1 << depth) - 1:
            raise ValueError(
                f"expected 2**k - 1 nodes (a perfect binary tree in postorder), got {size}"
            )

        # Leaf children contribute zero states on the input's device/dtype;
        # bare torch.zeros would always be CPU float32.
        zeros = inps.new_zeros((batch_size, self.cell_dim))
        stack = []
        hidden_seq = []
        for idx, is_leaf in enumerate(postorder_checker(depth)):
            inp = inps[:, idx, :]
            if is_leaf:
                left_hidden = right_hidden = left_cell = right_cell = zeros
            else:
                # Postorder pushes the left subtree before the right one, so
                # the right child is on top of the stack.
                right_hidden, right_cell = stack.pop()
                left_hidden, left_cell = stack.pop()

            i_vec = self.sigmoid(self.input_Wi(inp) + self.left_Wi(left_hidden) + self.right_Wi(right_hidden))
            o_vec = self.sigmoid(self.input_Wo(inp) + self.left_Wo(left_hidden) + self.right_Wo(right_hidden))
            u_vec = self.tanh(self.input_Wu(inp) + self.left_Wu(left_hidden) + self.right_Wu(right_hidden))
            left_f_vec = self.sigmoid(self.input_Wlf(inp) + self.left_Wlf(left_hidden) + self.right_Wlf(right_hidden))
            right_f_vec = self.sigmoid(self.input_Wrf(inp) + self.left_Wrf(left_hidden) + self.right_Wrf(right_hidden))

            c_vec = i_vec * u_vec + left_f_vec * left_cell + right_f_vec * right_cell
            h_vec = o_vec * self.tanh(c_vec)

            stack.append((h_vec, c_vec))
            if self.return_sequence:
                # Stack once at the end instead of re-concatenating per node.
                hidden_seq.append(h_vec)

        assert len(stack) == 1
        if self.return_sequence:
            return torch.stack(hidden_seq, dim=1)
        return stack[0]

### 3. Recursive Neural Network

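Given node input $v_i$ and the outputs $v_l$, $v_r$ of its left and right children, it computes

$v = \mathrm{act}(L_i(v_i) + L_l(v_l) + L_r(v_r))$

where $\mathrm{act}$ is ReLU; leaves use $v = \mathrm{act}(L_i(v_i))$. You can stack RecursiveNN layers by setting `return_sequence = True` on every layer except the last.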

In [None]:
class RecursiveNN(nn.Module):
    """Recursive neural network over a perfect binary tree given in postorder.

    Leaves compute ``relu(L_i(v))``. Internal nodes compute
    ``relu(L_i(v) + L_l(left) + L_r(right))``.

    Args:
        in_dim: Feature size of each node input.
        hidden_dim: Size of each node representation.
        return_sequence: If True, ``forward`` returns every node's output,
            ``(batch_size, node_num, hidden_dim)``, so layers can be stacked.
            Otherwise it returns the root output, ``(batch_size, hidden_dim)``.
    """

    def __init__(self, in_dim, hidden_dim, return_sequence = False):
        super(RecursiveNN, self).__init__()
        self.in_dim = in_dim
        self.hidden_dim = hidden_dim
        self.activation = nn.ReLU()
        self.return_sequence = return_sequence

        self.inp_linear = nn.Linear(self.in_dim, self.hidden_dim)
        self.left_linear = nn.Linear(self.hidden_dim, self.hidden_dim, bias=False)
        self.right_linear = nn.Linear(self.hidden_dim, self.hidden_dim, bias=False)

    def forward(self, inps):
        """Encode ``inps`` (batch_size, node_num, in_dim). Raises ValueError unless node_num == 2**k - 1."""
        size = inps.shape[1]
        depth = size.bit_length()
        if size < 1 or size != (1 << depth) - 1:
            raise ValueError(
                f"expected 2**k - 1 nodes (a perfect binary tree in postorder), got {size}"
            )

        stack = []
        outputs = []
        for idx, is_leaf in enumerate(postorder_checker(depth)):
            pre_activation = self.inp_linear(inps[:, idx, :])
            if not is_leaf:
                # The right child was pushed last in postorder.
                right = stack.pop()
                left = stack.pop()
                pre_activation = pre_activation + self.left_linear(left) + self.right_linear(right)
            res = self.activation(pre_activation)
            stack.append(res)
            if self.return_sequence:
                # Stack once at the end instead of re-concatenating per node.
                outputs.append(res)

        assert len(stack) == 1
        if self.return_sequence:
            return torch.stack(outputs, dim=1)
        return stack[0]

### 4. Compositional Semantics

Compose the operators.

Each unary operator corresponds to an $n \times n$ matrix, each binary operator to a $2n \times n$ matrix (applied to the concatenated left and right operands), and each terminal to a learned vector of length $n$.

This module cannot be stacked.

To make it stackable, replace the unary/binary operator layers with stackable modules.

In [None]:
class CompositionalSemantics(nn.Module):
    """Encode expressions by composing one learned layer per operator.

    Every unary operator owns a Linear(n, n), every binary operator a
    Linear(2n, n), and every terminal a learned vector of length n. ReLU is
    applied after each operator.

    Args:
        hidden_dim: Size n of every representation.
    """

    def __init__(self, hidden_dim):
        super(CompositionalSemantics, self).__init__()
        self.hidden_dim = hidden_dim
        self.activation = nn.ReLU()
        # NOTE(review): the environment's operator list (params['operators'])
        # spells the logarithm "ln", while this table uses "log". Confirm which
        # spelling the postfix tokens actually use.
        self.unary_ops = nn.ModuleDict({
            "exp" : nn.Linear(self.hidden_dim, self.hidden_dim),
            "log" : nn.Linear(self.hidden_dim, self.hidden_dim),
            "sqrt" : nn.Linear(self.hidden_dim, self.hidden_dim),
            "sin" : nn.Linear(self.hidden_dim, self.hidden_dim),
            "cos" : nn.Linear(self.hidden_dim, self.hidden_dim),
            "tan" : nn.Linear(self.hidden_dim, self.hidden_dim),
            "asin" : nn.Linear(self.hidden_dim, self.hidden_dim),
            "acos" : nn.Linear(self.hidden_dim, self.hidden_dim),
            "atan" : nn.Linear(self.hidden_dim, self.hidden_dim),
            "sinh" : nn.Linear(self.hidden_dim, self.hidden_dim),
            "cosh" : nn.Linear(self.hidden_dim, self.hidden_dim),
            "tanh" : nn.Linear(self.hidden_dim, self.hidden_dim),
            "asinh" : nn.Linear(self.hidden_dim, self.hidden_dim),
            "acosh" : nn.Linear(self.hidden_dim, self.hidden_dim),
            "atanh" : nn.Linear(self.hidden_dim, self.hidden_dim)
        })

        self.binary_ops = nn.ModuleDict({
            "+" : nn.Linear(self.hidden_dim * 2, self.hidden_dim),
            "-" : nn.Linear(self.hidden_dim * 2, self.hidden_dim),
            "*" : nn.Linear(self.hidden_dim * 2, self.hidden_dim),
            "/" : nn.Linear(self.hidden_dim * 2, self.hidden_dim)
        })
        self.terminals = nn.ParameterDict({
            "x" : nn.Parameter(torch.rand(self.hidden_dim), requires_grad = True),
            "-5" : nn.Parameter(torch.rand(self.hidden_dim), requires_grad = True),
            "-4" : nn.Parameter(torch.rand(self.hidden_dim), requires_grad = True),
            "-3" : nn.Parameter(torch.rand(self.hidden_dim), requires_grad = True),
            "-2" : nn.Parameter(torch.rand(self.hidden_dim), requires_grad = True),
            "-1" : nn.Parameter(torch.rand(self.hidden_dim), requires_grad = True),
            "1" : nn.Parameter(torch.rand(self.hidden_dim), requires_grad = True),
            "2" : nn.Parameter(torch.rand(self.hidden_dim), requires_grad = True),
            "3" : nn.Parameter(torch.rand(self.hidden_dim), requires_grad = True),
            "4" : nn.Parameter(torch.rand(self.hidden_dim), requires_grad = True),
            "5" : nn.Parameter(torch.rand(self.hidden_dim), requires_grad = True)
        })

    def forward(self, inp):
        """Encode a batch of expressions given as postorder token sequences.

        Args:
            inp: Iterable of token sequences, one per expression, each in
                postorder. Tokens that are not keys of ``unary_ops``,
                ``binary_ops`` or ``terminals`` are skipped, as before.
                NOTE(review): this silently drops e.g. ``None`` padding tokens;
                confirm that is intended.

        Returns:
            Tensor of shape ``(len(inp), hidden_dim)``. An empty batch gives a
            ``(0, hidden_dim)`` tensor.
        """
        res = []
        # The previous version re-evaluated the whole input once per element
        # (its inner loop also shadowed the outer index). Evaluate each
        # sequence exactly once.
        for sequence in inp:
            stack = []
            for token in sequence:
                if token in self.unary_ops:
                    operand = stack.pop()
                    stack.append(self.activation(self.unary_ops[token](operand)))
                elif token in self.binary_ops:
                    right_param = stack.pop()
                    left_param = stack.pop()
                    # Terminal vectors are 1-D, so concatenate along the last
                    # axis; dim=1 raised for every binary operator.
                    joined = torch.cat([left_param, right_param], dim=-1)
                    stack.append(self.activation(self.binary_ops[token](joined)))
                elif token in self.terminals:
                    stack.append(self.terminals[token])
            assert len(stack) == 1
            res.append(stack.pop().unsqueeze(0))
        if not res:
            return self.terminals["x"].new_empty((0, self.hidden_dim))
        return torch.cat(res, dim=0)

### 5. Code2Seq

Sample leaf-to-leaf paths from the expression tree; each path is then encoded with a sequence model.

In [None]:
# Code2Seq encodes sampled paths with ordinary sequence models (RNN/LSTM/...),
# so instead of a new tree module it needs this preprocessing step.
def code2seq_sample(tree, num):
    """Collect leaf-to-leaf label paths from an expression tree.

    Args:
        tree: Nested node sequence ``(label, *children)`` with at most two
            children, the format accepted by ``list_terminal``.
        num: ``-1`` returns a path for every pair of terminals ``(i, j)`` with
            ``i < j``. Any other value is the number of pairs to draw
            uniformly *with replacement* via ``random.choices``.

    Returns:
        List of paths as produced by ``terminal_to_path``. It is empty when
        the tree has fewer than two terminals.
    """
    terminals = list_terminal(tree)
    pairs = [
        (i, j)
        for i in range(len(terminals))
        for j in range(i + 1, len(terminals))
    ]
    if num != -1:
        if not pairs:
            # random.choices cannot draw from an empty population.
            return []
        # The old loop indexed `terminals` with the whole (i, j) tuple and
        # reused a stale `j`, so sampling always crashed.
        pairs = random.choices(pairs, k=num)
    return [terminal_to_path(tree, terminals[i], terminals[j]) for i, j in pairs]

def terminal_to_path(tree, first, second):
    """Return the node labels on the tree path from one terminal to another.

    Args:
        tree: Nested node sequence; ``node[0]`` is the label and ``node[1]`` /
            ``node[2]`` are the first / second child.
        first, second: Root-to-leaf branch lists for the two terminals
            (``True`` selects child 1, ``False`` selects child 2).

    Returns:
        Labels from the ``first`` leaf up to (but excluding) the lowest common
        ancestor, then the ancestor's label, then the labels down to the
        ``second`` leaf. Identical positions yield just that node's label.
    """
    # Descend while the two positions agree; `root` ends at the lowest common
    # ancestor and `split` at the first differing branch.
    root = tree
    split = 0
    while split < min(len(first), len(second)) and first[split] == second[split]:
        root = root[1] if first[split] else root[2]
        split += 1

    def walk(directions):
        # Labels visited going down from the common ancestor along `directions`.
        node = root
        labels = []
        for take_first in directions[split:]:
            node = node[1] if take_first else node[2]
            labels.append(node[0])
        return labels

    # The old code returned list(None) (list.reverse works in place) and
    # walked the second leaf using the first leaf's directions.
    return walk(first)[::-1] + [root[0]] + walk(second)

# Returns the list of terminals, left to right. A terminal's position is the
# root-to-leaf list of branches taken: True (left / only child), False (right).
def list_terminal(tree):
    """Return root-to-leaf branch lists for every terminal of ``tree``.

    Args:
        tree: Nested node sequence of length 1 (leaf), 2 (unary node) or
            3 (binary node); element 0 is the label.

    Raises:
        ValueError: If a node has more than two children (or is empty).
    """
    # The branch taken at this node is prepended so each position reads from
    # the root down, which is the order terminal_to_path walks it. Appending
    # produced leaf-to-root lists.
    if len(tree) == 1:
        return [[]]
    if len(tree) == 2:
        return [[True] + path for path in list_terminal(tree[1])]
    if len(tree) == 3:
        return (
            [[True] + path for path in list_terminal(tree[1])]
            + [[False] + path for path in list_terminal(tree[2])]
        )
    raise ValueError("Tree is expected to be at most binary.")

class code2seq(nn.Module):
    """code2seq-style path encoder: embeds every leaf-to-leaf path with an LSTM.

    Args:
        in_dim: Size of each token vector on a path.
        lstm_dim: Hidden size of the path LSTM.
        lstm_depth: Number of stacked LSTM layers.
        attention_dim, attention_head, attention_depth: Stored only; not used
            by ``forward`` yet.
        lstm_bidirectional: Use a bidirectional path LSTM.
    """

    def __init__(self, in_dim, lstm_dim, lstm_depth, attention_dim, attention_head, attention_depth, lstm_bidirectional = False):
        super(code2seq, self).__init__()
        self.in_dim = in_dim
        self.lstm_dim = lstm_dim
        self.lstm_depth = lstm_depth
        self.attention_dim = attention_dim
        self.attention_head = attention_head
        self.attention_depth = attention_depth
        self.lstm_bidirectional = lstm_bidirectional

        self.path_encoder = nn.LSTM(input_size = self.in_dim,
                                    hidden_size = self.lstm_dim,
                                    num_layers = self.lstm_depth,
                                    bidirectional = lstm_bidirectional,
                                    batch_first = True)
        # NOTE(review): nn.Transformer() uses its library defaults and ignores
        # the attention_* arguments; neither it nor `decoder` is used by
        # forward() yet.
        self.transformer = nn.Transformer()
        self.decoder = nn.Linear(self.lstm_dim, self.in_dim)

    def forward(self, paths):
        """Encode each path into one vector.

        Args:
            paths: Tensor of shape ``(batch_size, path_num, path_length, in_dim)``.

        Returns:
            Tensor of shape ``(batch_size, path_num, lstm_dim * num_directions)``:
            the last LSTM layer's final hidden state for each path. For a
            bidirectional encoder, the forward state precedes the backward state.
        """
        # Keep the original leading dims; the old code read them back from the
        # already-flattened tensor.
        batch_size, path_num = paths.shape[0], paths.shape[1]
        flat_paths = paths.reshape(batch_size * path_num, -1, self.in_dim)

        # nn.LSTM returns (output, (h_n, c_n)). The old code called .reshape on
        # the (h_n, c_n) tuple and discarded the result.
        _, (h_n, _) = self.path_encoder(flat_paths)

        # h_n: (num_layers * num_directions, batch * path_num, lstm_dim). The
        # last layer's states are the final num_directions entries.
        num_directions = 2 if self.lstm_bidirectional else 1
        last_layer = h_n[-num_directions:]
        return last_layer.transpose(0, 1).reshape(batch_size, path_num, -1)

### 6. TBCNN

Because our trees are at most binary (window size 2), a TBCNN would be identical to the recursive NN above, so we do not implement TBCNN.

## 5. Define hyperparameters and parameters

In [None]:
# Settings shared by every model.
args = Namespace(
    epoch=20,
    lr=0.001,
    optim="adam",
    in_dim=len(env.word2id),  # one input slot per vocabulary token
)

# NOTE(review): cslstm_args and btlstm_args use nn_type "ReNN", the same
# label as renn_args. This looks copy-pasted; confirm the intended labels.

# Child Sum LSTM
cslstm_args = Namespace(nn_type="ReNN", cell_dim=128, depth=4)

# Binary Tree LSTM (the only model with its own learning rate)
btlstm_args = Namespace(nn_type="ReNN", cell_dim=128, depth=4, lr=0.001)

# ReNN
renn_args = Namespace(nn_type="ReNN", cell_dim=128, depth=4)

# Comp Sem: a single layer; compsem is not stackable.
compsem_args = Namespace(nn_type="CompSem", cell_dim=100)

# Code2Seq
# NOTE(review): model_dim 1024 is not divisible by model_head 6, which
# multi-head attention layers require; revisit before wiring attention in.
code2seq_args = Namespace(
    nn_type="Code2Seq",
    cell_dim=100,
    path_depth=5,  # How many layers we will use to encode each path
    path_type="LSTM",  # What layer to use when encoding a path
    model_type="attention",
    model_dim=1024,
    model_head=6,
    model_depth=8,
)

# Open question: should attention be used?


## 6. Implement Training

In [None]:
def _stack_tree_layers(layer_cls, in_dim, cell_dim, depth):
    """Stack ``depth`` tree layers of ``layer_cls`` into an nn.Sequential.

    The first layer maps ``in_dim`` to ``cell_dim``. Every layer except the
    last returns the per-node sequence so the next layer can consume it.
    """
    return nn.Sequential(*[
        layer_cls(in_dim if k == 0 else cell_dim, cell_dim,
                  return_sequence=(k != depth - 1))
        for k in range(depth)
    ])


cslstm = _stack_tree_layers(ChildSumTreeLSTM, args.in_dim, cslstm_args.cell_dim, cslstm_args.depth)
cslstm_decoder = nn.Sequential()  # placeholder; currently has no layers

# Each stack now reads its own depth. Previously btlstm and renn decided
# return_sequence from cslstm_args.depth, which breaks once depths differ.
btlstm = _stack_tree_layers(BinaryTreeLSTM, args.in_dim, btlstm_args.cell_dim, btlstm_args.depth)
renn = _stack_tree_layers(RecursiveNN, args.in_dim, renn_args.cell_dim, renn_args.depth)

compsem = CompositionalSemantics(compsem_args.cell_dim)

code_seq = code2seq(args.in_dim, code2seq_args.cell_dim, code2seq_args.path_depth,
                    code2seq_args.model_dim, code2seq_args.model_head, code2seq_args.model_depth)

# The misspelled name is kept because other cells refer to it.
lstm_enocder = nn.LSTM(args.in_dim, 128, 4, batch_first=True)

In [None]:
from prettytable import PrettyTable


def count_parameters(model):
    """Print a per-tensor table of trainable parameters and return their total.

    Tensors with ``requires_grad`` False are left out of both the table and
    the total.
    """
    trainable = [
        (name, tensor.numel())
        for name, tensor in model.named_parameters()
        if tensor.requires_grad
    ]
    table = PrettyTable(["Modules", "Parameters"])
    for name, count in trainable:
        table.add_row([name, count])
    total_params = sum(count for _, count in trainable)
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params


count_parameters(lstm_enocder)

In [None]:
class Decoder(nn.Module):
    """Skeleton of an LSTM decoder; ``forward`` is not implemented yet.

    Args:
        output_dim: Size of the output token vectors (also the input size of
            the linear embedding).
        emb_dim: Embedding size fed to the LSTM.
        hid_dim: LSTM hidden size.
        n_layers: Number of LSTM layers.
    """

    def __init__(self, output_dim, emb_dim, hid_dim, n_layers):
        super().__init__()
        self.output_dim, self.emb_dim = output_dim, emb_dim
        self.hid_dim, self.n_layers = hid_dim, n_layers

        # A linear map stands in for an embedding lookup.
        self.embedding = nn.Linear(output_dim, emb_dim)
        self.layers = nn.LSTM(emb_dim, hid_dim, n_layers)
        self.fc_out = nn.Linear(hid_dim, output_dim)

    def forward(self, hidden, cell):
        # TODO: the decoding step is not implemented; this stub returns None.
        return None

In [None]:
from datetime import datetime
f_tree = env.sympy_to_tree(f)
start = time.perf_counter()
samples = code2seq_sample(f_tree, -1)
# Map every token on every path to its vocabulary id. The comprehension used
# to read an undefined name `w` instead of its loop variable.
samples = [[env.word2id[w] for w in sample] for sample in samples]

# code_seq expects a one-hot tensor (batch, path_num, path_length, in_dim),
# not nested lists. Paths differ in length, so left-pad with all-zero rows to
# end every path on the final LSTM step. The padding rows still pass through
# the LSTM, so treat this as a smoke test, not an exact encoding.
max_path_len = max(len(sample) for sample in samples)
path_batch = torch.zeros(1, len(samples), max_path_len, args.in_dim)
for p, sample in enumerate(samples):
    offset = max_path_len - len(sample)
    for t, token_id in enumerate(sample):
        path_batch[0, p, offset + t, token_id] = 1.0

# perf_counter gives sub-second resolution; timedelta.seconds truncated to 0.
print(f"{time.perf_counter() - start:.3f}")
output = code_seq(path_batch)
print(f"{time.perf_counter() - start:.3f}")
output.sum().backward()
print(f"{time.perf_counter() - start:.3f}")


In [None]:
# Smoke-test the tree encoders (forward + backward) on the example derivative.
# I recommend not executing these.


def _one_hot_batch(tokens, copies=16):
    """One-hot encode ``tokens`` into (copies, len(tokens), vocab) floats.

    Falsy tokens map to index -1, which matches no vocabulary slot, so their
    rows stay all-zero (presumably padding).
    """
    ids = [env.word2id[w] if w else -1 for w in tokens]
    encoded = torch.zeros(len(ids), len(env.word2id))
    for row, idx in enumerate(ids):
        if idx >= 0:
            encoded[row, idx] = 1
    return encoded.unsqueeze(dim=0).repeat_interleave(copies, dim=0)


words = _one_hot_batch(f_postfix)
words2 = _one_hot_batch(f_prefix)

from datetime import datetime
begin = datetime.today()


def _print_elapsed():
    # Whole seconds since `begin`, cumulative across all the models below.
    print((datetime.today() - begin).seconds)


# Both tree-LSTM stacks end in a layer that returns (hidden, cell) of the root.
for tree_lstm in (cslstm, btlstm):
    output = tree_lstm(words)
    print(output[0].shape)
    print(output[1].shape)
    _print_elapsed()
    output[1].sum().backward()
    _print_elapsed()

output = renn(words)
print(output.shape)
_print_elapsed()
output.sum().backward()
_print_elapsed()

None  # no output

In [None]:
# Baseline: a plain 4-layer LSTM over the prefix sequence, for timing comparison.
lstm_encoder = nn.LSTM(args.in_dim, 128, 4, batch_first=True)
begin = time.perf_counter()
output = lstm_encoder(words2)
print(output[0].shape)
# Total elapsed microseconds. timedelta.microseconds is only the sub-second
# component, so the old value wrapped every full second.
print(int((time.perf_counter() - begin) * 1e6))
output[0].sum().backward()
print(int((time.perf_counter() - begin) * 1e6))

In [None]:
def train(net, partition, preprocess, optimizer, criterion, model_args, args):
    """Run one training epoch over ``partition['train']``.

    Args:
        net: Model to train (switched to train mode).
        partition: Mapping whose ``'train'`` entry is a dataset of
            ``(inputs, labels)`` pairs.
        preprocess: Callable applied to each reshaped input batch.
        optimizer: Optimizer over ``net``'s parameters.
        criterion: Loss taking ``(outputs, labels)``.
        model_args: Currently unused.
        args: Needs ``batch_size`` and ``in_dim``. Optional ``num_workers``
            sets DataLoader workers (default 2).

    Returns:
        ``(net, mean batch loss, accuracy in percent)``. Loss and accuracy are
        NaN for an empty split.
    """
    trainloader = torch.utils.data.DataLoader(partition['train'],
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=getattr(args, 'num_workers', 2))
    net.train()

    correct = 0
    total = 0
    train_loss = 0.0
    for inputs, labels in trainloader:
        # NOTE(review): view(-1, in_dim) flattens every leading dimension. The
        # tree encoders in this notebook take (batch, nodes, in_dim); confirm
        # `preprocess` restores the shape the model expects.
        inputs = preprocess(inputs.view(-1, args.in_dim))

        # Clear the previous batch's gradients. zero_grad() used to run once
        # before the loop, so gradients accumulated across the whole epoch.
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = torch.max(outputs.detach(), 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    n_batches = len(trainloader)
    train_loss = train_loss / n_batches if n_batches else float('nan')
    train_acc = 100 * correct / total if total else float('nan')
    return net, train_loss, train_acc

In [None]:
def validate(net, partition, preprocess, criterion, args):
    """Evaluate ``net`` on ``partition['val']`` without gradient tracking.

    Args:
        net: Model to evaluate (switched to eval mode).
        partition: Mapping whose ``'val'`` entry is a dataset of
            ``(inputs, labels)`` pairs.
        preprocess: Callable applied to each reshaped input batch.
        criterion: Loss taking ``(outputs, labels)``.
        args: Needs ``test_batch_size`` and ``in_dim``. Optional
            ``num_workers`` sets DataLoader workers (default 2).

    Returns:
        ``(mean batch loss, accuracy in percent)``; both NaN for an empty split.
    """
    valloader = torch.utils.data.DataLoader(partition['val'],
                                            batch_size=args.test_batch_size,
                                            shuffle=False,
                                            num_workers=getattr(args, 'num_workers', 2))
    net.eval()

    correct = 0
    total = 0
    val_loss = 0.0
    with torch.no_grad():
        for inputs, labels in valloader:
            # Both the reshape and the forward pass used to read an undefined
            # name `images` instead of this batch's inputs.
            inputs = preprocess(inputs.view(-1, args.in_dim))
            outputs = net(inputs)

            loss = criterion(outputs, labels)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    n_batches = len(valloader)
    val_loss = val_loss / n_batches if n_batches else float('nan')
    val_acc = 100 * correct / total if total else float('nan')
    return val_loss, val_acc

In [None]:
def test(net, partition, args, preprocess=None):
    """Return accuracy (percent) of ``net`` on ``partition['test']``.

    Args:
        net: Model to evaluate (switched to eval mode).
        partition: Mapping whose ``'test'`` entry is a dataset of
            ``(inputs, labels)`` pairs.
        args: Needs ``test_batch_size`` and ``in_dim``. Optional
            ``num_workers`` sets DataLoader workers (default 2).
        preprocess: Callable applied to each reshaped input batch. When
            omitted, a module-level ``preprocess`` is used if one exists
            (the old code read that global implicitly); otherwise inputs pass
            through unchanged.

    Returns:
        Accuracy in percent, or NaN for an empty split.
    """
    if preprocess is None:
        preprocess = globals().get('preprocess', lambda batch: batch)
    testloader = torch.utils.data.DataLoader(partition['test'],
                                             batch_size=args.test_batch_size,
                                             shuffle=False,
                                             num_workers=getattr(args, 'num_workers', 2))
    net.eval()

    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs = preprocess(inputs.view(-1, args.in_dim))
            outputs = net(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return 100 * correct / total if total else float('nan')


# `test` matches pytest's default `test*` collection pattern. Mark it so a
# module that imports it does not try to run it as a test case.
test.__test__ = False

In [None]:
def experiment(partition, model_args, args):
    if model_args.name == "":
        net = TODO()