# MathQA Decoder

Notes/Todo:
- Make sure to set params back to normal (see pic)
- Use one global optimizer
- Create class for whole model
- Loss should be concatenated at each layer and returned at end
- Best course of action would be to get layer 2 working with teacher forcing and see if that can be trained before putting all together
- Keep running list of all true embeddings updated every layer
- Are duplicate equations allowed on seperate layers? (dont think so but probably should double check)
- implement code to remove item from batch once it has finished

#### Imports

In [1]:
from enum import Enum
import math
import os
import re
from copy import deepcopy
import random
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer, AutoModel
import pickle
from scipy.optimize import linear_sum_assignment
if get_ipython().__class__.__name__ == 'ZMQInteractiveShell':
    from tqdm.notebook import tqdm
else:
    from tqdm import tqdm
import unittest

# remove
import time

#### Constants

In [2]:
K = 6
MAX_LAYERS = 8
MAX_TOKENS = 392
EMBEDDING_SIZE = 768

DATA_PATH = './dataset/'
SET_NAMES = ['train', 'validation', 'test']
ENCODER_MODEL = 'distilroberta-base' # A more optimized version of roberta obtaining 95% of its performance
DEVICE = 'cuda:0'
NUM_MASK = '<num>'
WORKING_DIR = 'TEMP/'

OBJ_DIR = 'pickle/'


class Op(Enum):
    ADD = '+'
    SUB = '-'
    MULT = '*'
    DIV = '/'
    POW = '^'
    
class Const(Enum):
    CONST_NEG_1 = 'const_neg_1' # I added this
    CONST_0_25 = 'const_0_25'
    CONST_0_2778 = 'const_0_2778'
    CONST_0_33 = 'const_0_33'
    CONST_0_3937 = 'const_0_3937'
    CONST_1 = 'const_1'
    CONST_1_6 = 'const_1_6'
    CONST_2 = 'const_2'
    CONST_3 = 'const_3'
    CONST_PI = 'const_pi'
    CONST_3_6 = 'const_3_6'
    CONST_4 = 'const_4'
    CONST_5 = 'const_5'
    CONST_6 = 'const_6'
    CONST_10 = 'const_10'
    CONST_12 = 'const_12'
    CONST_26 = 'const_26'
    CONST_52 = 'const_52'
    CONST_60 = 'const_60'
    CONST_100 = 'const_100'
    CONST_180 = 'const_180'
    CONST_360 = 'const_360'
    CONST_1000 = 'const_1000'
    CONST_3600 = 'const_3600'

values = [-1, 0.25, 0.2778, 0.33, 0.3937, 1, 1.6, 2, 3, math.pi, 3.6, 4, 5, 6, 10, 12, 26, 52, 60, 100, 180, 360, 1000, 3600]
const2val = {k:v for k,v in zip(Const._value2member_map_.keys(), values)}    

op2id = {k:v for k,v in zip(Op._value2member_map_.keys(), range(len(Op._value2member_map_)))}
op2id['None'] = 5
id2op = {v:k for k,v in op2id.items()}
const2id = {k:v for k,v in zip(Const._value2member_map_.keys(), range(len(Const._value2member_map_)))}
id2const = np.array(list(const2id.keys()))

torch.autograd.set_detect_anomaly(True)
torch.set_printoptions(sci_mode=False)

class Util():
    def load_obj(self, path):
        with open(path, 'rb') as f:
            o = pickle.load(f)
        return o
    
    def save_obj(self, path, o):
        with open(path, 'wb') as f:
            pickle.dump(o, f)
            
    def load_data(self):
        return {name:pd.read_csv(f'{DATA_PATH}{name}.csv') for name in SET_NAMES}
    
    def set_seed(self, seed):
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)
util = Util()

## Creating the Model

In [3]:
class Preprocess:
    
    def __init__(self, data, K, embedding_size, max_tokens, model_name, tokenizer_name):
        # Randomly initializing embeddings
        self.op = torch.randn(len(op2id),embedding_size)
        self.query = torch.randn(K,embedding_size)
        const = torch.randn(len(const2id),embedding_size)
        
        self.num, self.text = self.__get_nums_and_mask(data)
        self.const = self.__get_const(const)
        self.combined = self.__combine(self.num, self.const)
        self.labels = {name:self.__get_labels(data,name) for name in SET_NAMES}
        
        # Encoder
        self.encoder, self.tokenizer = self.__setup_model(model_name, tokenizer_name)
        self.tokenized = {name:self.__tokenize_data(self.tokenizer, self.text[name], max_tokens) for name in SET_NAMES}
        
    # -----------------------------------------------------------
    # util
    # -----------------------------------------------------------
    def __expand_tensor(self, curr, new):
        if curr is None:
            curr = new
        else:
            curr =  torch.cat((curr,new), dim=0)
        return curr
    
    def __flatten(self, arr):
        idx = np.concatenate([[i]*len(x) for i,x in enumerate(arr)])
        flattened = np.concatenate(arr) 
        return idx, flattened
    
    # -----------------------------------------------------------
    # tokenization
    # -----------------------------------------------------------
    def __setup_model(self, model_name, tokenizer_name):
        model = AutoModel.from_pretrained(model_name, output_hidden_states=True)
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        tokenizer.add_special_tokens({'additional_special_tokens':[NUM_MASK]})
        model.resize_token_embeddings(len(tokenizer))
        return model, tokenizer
    
    def __tokenize_data(self, tokenizer, text, max_tokens):
        tokenization = lambda x: tokenizer(x, padding='max_length', max_length=max_tokens, truncation=True)

        tokenized = list(map(tokenization, text))
        input_ids = torch.stack([torch.tensor(x['input_ids']) for x in tokenized])
        attention_mask = torch.stack([torch.tensor(x['attention_mask']) for x in tokenized])

        return {'input_ids':input_ids.long(), 'attention_mask':attention_mask.int()}
    
    # -----------------------------------------------------------
    # literals
    # -----------------------------------------------------------
    # Gets the numbers listed in a problem
    # Once found, numbers are masked using a number mask
    def __get_nums_and_mask(self, data):
        nums = {name:[] for name in SET_NAMES}
        problems = {name:[] for name in SET_NAMES}
        num_idx = {name:[] for name in SET_NAMES}
        for name in SET_NAMES:
            for i,problem in enumerate(data[name].problem):
                num = re.compile('([+-]?((\d+(\.\d*)?)|(\.\d+)))') # normal num
                big = re.compile(r'(-?\d{1,3}(,\d{3})+(\.\d*)?)') # num with comma

                big_results = re.finditer(big, problem)
                problem = re.sub(big, NUM_MASK, problem)        
                num_results = re.finditer(num, problem)
                problem = re.sub(num, NUM_MASK, problem)

                # Getting the combined numbers in order of occurence
                combined = [x for x in num_results]
                combined.extend([x for x in big_results])
                combined = sorted(combined, key=lambda x: x.start(0))

                combined = [float(x.group(0).replace(',','')) for x in combined]

                nums[name].append(np.array(combined))
                problems[name].append(problem)
            num_idx[name], nums[name] = self.__flatten(np.array(nums[name], dtype=object))
            problems[name] = np.array(problems[name])
        return {name:{'idx':torch.tensor(num_idx[name]), 'literals':nums[name]} for name in SET_NAMES}, problems
    
    def __get_const(self, embed):
        const = {}
        for name in SET_NAMES:
            const_pred = util.load_obj(f'{OBJ_DIR}constants.pickle')['pred'][name]
            const_idx, const_id = np.where(const_pred==1)
            literals = id2const[const_id]
            const[name] = {'idx':torch.tensor(const_idx), 'literals':literals, 'embed_idx':const_id, 'embed':embed}
        return const
    
    def __combine(self, num, const):
        combined = {}
        for name in SET_NAMES:
            combined[name] = {'idx':torch.cat((num[name]['idx'],const[name]['idx'])), 
                              'literals':np.concatenate((num[name]['literals'],const[name]['literals']))}
        return combined
    
    # -----------------------------------------------------------
    # labels
    # -----------------------------------------------------------
    def __get_exp(self, data, name):
        reorder = lambda x: [x[1], x[0], x[2]]
        convert_to_arr = lambda d: [[reorder(x.split()), idx] for idx, arr in enumerate(d.split(' ; ')) for x in eval(arr) if x is not None]
        return data[name]['incremental'].map(convert_to_arr)

    def __process_num(self, num, idx, literals, pnum, prev_eq):
        if num in const2val:
            temp = np.where((literals[idx==pnum]==num))[0]
            if len(temp) > 0:
                return temp[0]
            else: # The constant predictor failed to predict the correct constants for this specific problem
                return None
        elif 'x' in num:
            return prev_eq[int(num[1:])-1]
        else:
            return np.where((literals[idx==pnum]==str(float(num))))[0][0]
        
    def __process_exp(self, idx, literals, exp, pnum):
        op_labels = None
        prev_eq = []
        num1_labels = []
        num2_labels = []
        layer_idx = []
        for (op, num1, num2), layer in exp:
            num1 = self.__process_num(num1, idx, literals, pnum, prev_eq)
            num2 = self.__process_num(num2, idx, literals, pnum, prev_eq)
            op = op2id[op]

            layer_idx.append(layer)
            num1_labels.append(num1)
            num2_labels.append(num2)
            prev_eq.append((op,num1,num2))

            temp = torch.zeros(len(op2id))
            temp[op] = 1
            op_labels = self.__expand_tensor(op_labels,temp[None,:])
            
        return op_labels, np.array(num1_labels, dtype=object), np.array(num2_labels, dtype=object), np.array(layer_idx)
    
    def __get_labels(self, data, name):
        results = []
        for i,e in enumerate(self.__get_exp(data, name)):
            results.append(self.__process_exp(self.combined[name]['idx'], self.combined[name]['literals'], e, i))
        return results

In [4]:
class MathQA(Dataset):
    def __init__(self, config, name):
        self.input_ids = config.tokenized[name]['input_ids']
        self.attention_mask = config.tokenized[name]['attention_mask']
        self.const_embed = config.const[name]['embed']
        self.embed_idx = torch.tensor(config.const[name]['embed_idx']).int()
        self.idx = config.combined[name]['idx'].int()
        self.const_idx = config.const[name]['idx'].int()
        self.true_labels = config.labels[name]
    
    def __len__(self):
        return len(self.input_ids)
    
    def __getitem__(self, i):
        if i.step: # using step is not defined
            raise NotImplementedError
            
        input_ids = self.input_ids[i]
        attention_mask = self.attention_mask[i]
        true_labels = self.true_labels[i]
        
        if isinstance(i, slice):
            const_embed = self.const_embed[self.embed_idx[(self.const_idx>=i.start)&(self.const_idx<i.stop)]]
            idx = self.idx[(self.idx>=i.start)&(self.idx<i.stop)]
        else:
            const_embed = self.const_embed[self.embed_idx[self.const_idx==i]]
            idx = self.idx[self.idx==i]
        return input_ids, attention_mask, const_embed, idx, true_labels
    
class DataLoader:
    def __init__(self, dataset, batch_size, shuffle):
        self.curr = 0
        self.dataset = dataset
        self.batch_size = batch_size
        self.slices = [slice(x,x+batch_size) for x in range(0,len(dataset),batch_size)]
        self.num_batches = len(self.slices)
        
        if shuffle:
            np.random.shuffle(self.slices)
            
    def __len__(self):
        return self.num_batches
        
    def __iter__(self):
        return self
    
    def __next__(self):        
        if self.curr < self.num_batches:
            input_ids, attention_mask, const_embed, idx, true_labels = self.dataset[self.slices[self.curr]]
            self.curr += 1
            return input_ids, attention_mask, const_embed, idx%self.batch_size, true_labels
        raise StopIteration

In [5]:
class Encoder(torch.nn.Module):
    
    def __init__(self, config):
        super(Encoder, self).__init__()
        self.device = config.device
        self.mask_id = config.tokenizer.encode(NUM_MASK, add_special_tokens=False)[0]
        self.encoder = config.encoder.to(config.device)
        
    def forward(self, input_ids, attention_mask, const_embed):
        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)
        const_embed = const_embed.to(self.device)
        
        # Getting the hidden output
        x = self.encoder(input_ids, attention_mask).last_hidden_state
        
        # Combining constants and numbers
        embed = torch.cat((x[input_ids==self.mask_id],const_embed), dim=0)
        
        return x, embed

In [6]:
class Decoder(torch.nn.Module):
    
    def __init__(self, config): 
        super(Decoder, self).__init__()
        self.device = config.device
        self.embedding_size = config.embedding_size
        self.num_tokens = config.num_tokens
        self.K = config.K
        self.op_embed = config.op
        
        # standard transformer decoder
        transformer_decoder_layer = torch.nn.TransformerDecoderLayer(self.embedding_size, nhead=config.nhead, batch_first=True)
        self.transformer_decoder = torch.nn.TransformerDecoder(transformer_decoder_layer, num_layers=config.nlayer)
        
        # softmax for the operations
        self.op_softmax = torch.nn.Softmax(dim=2)
        
        # softmax for the numbers
        self.num_softmax = torch.nn.Softmax(dim=0)

        # decreasing dimensionality to match embedding size: <op, num, num, num*num> = 3072 -> 768
        self.exp_encoder = torch.nn.Sequential(
            torch.nn.Linear(self.embedding_size*4, self.embedding_size*2),
            torch.nn.SiLU(),
            torch.nn.Linear(self.embedding_size*2, self.embedding_size*2),
            torch.nn.SiLU(),
            torch.nn.Linear(self.embedding_size*2, self.embedding_size),
        )
        
        # converting back into new <op, num, num> for loss calculation
        self.op_decoder = torch.nn.Sequential(
            torch.nn.Linear(self.embedding_size, int(self.embedding_size*1.5)),
            torch.nn.SiLU(),
            torch.nn.Linear(int(self.embedding_size*1.5), int(self.embedding_size*1.5)),
            torch.nn.SiLU(),
            torch.nn.Linear(int(self.embedding_size*1.5), self.embedding_size),
        )
        self.num1_decoder = torch.nn.Sequential(
            torch.nn.Linear(self.embedding_size, int(self.embedding_size*1.5)),
            torch.nn.SiLU(),
            torch.nn.Linear(int(self.embedding_size*1.5), int(self.embedding_size*1.5)),
            torch.nn.SiLU(),
            torch.nn.Linear(int(self.embedding_size*1.5), self.embedding_size),
        )
        self.num2_decoder = torch.nn.Sequential(
            torch.nn.Linear(self.embedding_size, int(self.embedding_size*1.5)),
            torch.nn.SiLU(),
            torch.nn.Linear(int(self.embedding_size*1.5), int(self.embedding_size*1.5)),
            torch.nn.SiLU(),
            torch.nn.Linear(int(self.embedding_size*1.5), self.embedding_size),
        )
        
    def __apply_to_nums(self, f, num, idx, batch_size):
        new_nums = torch.zeros(batch_size, self.K, num.shape[0]).to(self.device)
        for x in range(batch_size):
            new_nums[x,:,idx==x] = f(num[idx==x]).T
        return new_nums  
    
    # query: [batch_size, K, 768]
    # prob_embed: [batch_size, num_tokens, 768]
    # num_embed: [num_nums, 768]
    # op_embed: [num_ops, 768]
    def forward(self, query, prob_embed, num_embed, idx): 
        query = query.to(self.device)
        prob_embed = prob_embed.to(self.device)
        num_embed = num_embed.to(self.device)
        op_embed = self.op_embed.to(self.device)
        idx = idx.to(self.device)
        batch_size = query.shape[0]
        num_ops = op_embed.shape[0]
        num_nums = num_embed.shape[0]

        # ----------------------------
        # Step 1 - transformer decoder
        # ----------------------------
        query = self.transformer_decoder(query, prob_embed) # [batch_size, K, 768] -> [batch_size, K, 768]
        
        # -------------------------------------------------------------------------------
        # Step 2 - decoding the output into three embeddings of size 768 (op, num1, num2)
        # -------------------------------------------------------------------------------
#         query = self.exp_decoder(query) # [batch_size, K, 768] -> [batch_size, K, 768*3]
#         op,num1,num2 = torch.split(query, self.embedding_size, dim=2)
        op = self.op_decoder(query)
        num1 = self.num1_decoder(query)
        num2 = self.num2_decoder(query)
        
        # ----------------------------------------------------
        # Step 3 - creating embedding for the found expression  
        # ----------------------------------------------------
        query = self.exp_encoder(torch.cat((op,num1,num2,num1*num2), dim=2)) # [batch_size, K, 3072] -> [batch_size, K, 768]
        
        # ------------------------------------------------------------------------------
        # Step 4 - Taking a dot product between the predicted and stored true embeddings
        # ------------------------------------------------------------------------------
        # making sure params have correct dimension
        num_embed = num_embed[:,None,:].expand(-1,K,-1) # [num_nums, 768] -> [num_nums, K, 768]
        op_embed = op_embed[None,None,:,:].repeat(batch_size,self.K,1,1) # [num_ops, 768] -> [batch_size, K, num_ops, 768]

        # dot product - calculating the similarity between each op/num prediction and stored embeddings
        num1 = (num1[idx]*num_embed).sum(dim=2) # [number_of_nums, K]
        num2 = (num2[idx]*num_embed).sum(dim=2) # [number_of_nums, K]        
        op = op[:,:,None,:].expand(-1,-1,num_ops,-1) # [batch_size, K, 768] -> # [batch_size, K, num_ops, 768]
        op = (op*op_embed).sum(dim=3) # [batch_size, K, num_ops, 768] -> [batch_size, K, num_ops]
        
        # -----------------------------------------------------------------------------------------
        # Step 5 - Finding the softmax for the similarity between the predicted and true embeddings
        # -----------------------------------------------------------------------------------------
        op = self.op_softmax(op) # [batch_size, K, num_ops] (ie op[1,2] would be the operator prediction probabilities for problem2, query3)
        num1 = self.__apply_to_nums(self.num_softmax, num1, idx, batch_size) # [batch_size,K,num_nums]
        num2 = self.__apply_to_nums(self.num_softmax, num2, idx, batch_size) # [batch_size,K,num_nums]
        
        return query, op, num1, num2

In [7]:
class Loss:
    def __init__(self, config):
        self.K = config.K
        self.embedding_size = config.embedding_size
        self.device = config.device
        
    # --------------------------------
    # loss/accuracy calculation
    # --------------------------------
    def __lmatch(self, true_op, true_num1, true_num2, op, num1, num2):
        true_op = torch.clone(true_op)
        true_op[:,:,op2id['None']] = 0 # setting all none operators to zero so ignored during calculation

        # rows: true exp, cols: pred exp (ie 0,1 would be the cost for true exp 0 with pred exp 1)
        # [8,K*K]
        op_mat = (true_op.repeat_interleave(K,dim=1)*op.repeat(1,K,1)).sum(dim=-1)
        num1_mat = (true_num1.repeat_interleave(K,dim=1)*num1.repeat(1,K,1)).sum(dim=-1)
        num2_mat = (true_num2.repeat_interleave(K,dim=1)*num2.repeat(1,K,1)).sum(dim=-1)

        return -(op_mat+num1_mat+num2_mat).reshape(op_mat.shape[0],K,K)
    
    # consider trying to write this in pytorch if have time
    def __hungarian_algorithm(self, true_op, true_num1, true_num2, op, num1, num2, batch_size):
        m = self.__lmatch(true_op, true_num1, true_num2, op, num1, num2)
        permutations = torch.stack(tuple([torch.tensor(linear_sum_assignment(m.detach().cpu().numpy()[x])[1]) for x in range(batch_size)]))
        batch_idx = torch.arange(batch_size).unsqueeze(-1)
        return op[batch_idx, permutations], num1[batch_idx, permutations], num2[batch_idx, permutations]
    
    # expects the optimal permutation from the hungarian algorithm
    def __final_loss(self, true_op, true_num1, true_num2, op, num1, num2):        
        log_op = torch.log((true_op*op).sum(-1))
        
        log_num1 = ((true_num1*num1).sum(-1))
        log_num1[log_num1.nonzero(as_tuple=True)] = torch.log(log_num1[log_num1.nonzero(as_tuple=True)])
        
        log_num2 = ((true_num2*num2).sum(-1))
        log_num2[log_num2.nonzero(as_tuple=True)] = torch.log(log_num2[log_num2.nonzero(as_tuple=True)])
 
        return (-log_op-log_num1-log_num2).sum(-1).mean()

    # expects the optimal permutation from the hungarian algorithm
    # gets the number of correct examples for a batch
    def __correct(self, true_op, true_num1, true_num2, op, num1, num2):
        invalid = true_op.argmax(-1)==op2id['None']
        op_correct = true_op.argmax(-1)==op.argmax(-1)
        num1_correct = num1.argmax(-1)==true_num1.argmax(-1)
        num2_correct = num2.argmax(-1)==true_num2.argmax(-1)
        num1_correct[invalid]=True
        num2_correct[invalid]=True
        
        correct = ((op_correct&num1_correct&num2_correct).sum(1)==self.K).sum()
        num_ops = true_op.shape[0]*true_op.shape[1]
        op_correct = op_correct.sum()
        num_nums = len(num1_correct[~invalid])
        num1_correct = num1_correct[~invalid].sum()
        num2_correct = num2_correct[~invalid].sum()
        return correct, op_correct, num_ops, num1_correct, num2_correct, num_nums
    
    # --------------------------------
    # true label construction
    # --------------------------------   
    def __expand_tensor(self, curr, new):
        if curr is None:
            curr = new
        else:
            curr =  torch.cat((curr,new), dim=0)
        return curr

    def __prepare_labels(self, curr_layer, found_nums_idx, pnum, true_labels, num_nums):
        true_op, true_num1, true_num2, layer_idx = true_labels[pnum]

        # Creating num1 labels [batch_size, K, num_nums]
        num1_indices = [found_nums_idx[pnum][x] for x in true_num1[layer_idx==curr_layer] if x in found_nums_idx[pnum]]          
        num1_labels = torch.zeros(len(num1_indices), num_nums)
        num1_labels[torch.arange(len(num1_indices)),num1_indices]=1
        num1_labels = torch.cat((num1_labels, torch.zeros(num_nums)[None,:].repeat(self.K-len(num1_indices),1)), dim=0) 

        # Creating num2 labels [batch_size, K, num_nums]
        num2_indices = [found_nums_idx[pnum][x] for x in true_num2[layer_idx==curr_layer] if x in found_nums_idx[pnum]]
        num2_labels = torch.zeros(len(num2_indices), num_nums)
        num2_labels[torch.arange(len(num2_indices)),num2_indices]=1
        num2_labels = torch.cat((num2_labels, torch.zeros(num_nums)[None,:].repeat(self.K-len(num2_indices),1)), dim=0)
        
        # Creating op labels [batch_size, K, num_ops]
        op_labels = true_op[layer_idx==curr_layer]
        pad = torch.zeros(len(op2id))
        pad[op2id['None']] = 1
        op_labels = torch.cat((op_labels, pad[None,:].repeat(K-len(op_labels),1)), dim=0)

        return op_labels.to(self.device), num1_labels.to(self.device), num2_labels.to(self.device)

    def prepare_all_labels(self, curr_layer, found_nums_idx, true_labels, num_nums, batch_size):
        true_op = None
        true_num1 = None
        true_num2 = None
        # Aggregating labels
        for pnum in range(batch_size):
            op_labels, num1_labels, num2_labels = self.__prepare_labels(curr_layer, found_nums_idx, pnum, true_labels, num_nums)
            true_op = self.__expand_tensor(true_op, op_labels[None,:,:])
            true_num1 = self.__expand_tensor(true_num1, num1_labels[None,:,:])
            true_num2 = self.__expand_tensor(true_num2, num2_labels[None,:,:])

        # Randomizing equation order to help with generalization
        batch_idx = torch.arange(batch_size).unsqueeze(-1)
        perm = torch.rand(batch_size,K).argsort(dim=1)
        true_op = true_op[batch_idx,perm,:]
        true_num1 = true_num1[batch_idx,perm,:]
        true_num2 = true_num2[batch_idx,perm,:]

        return true_op, true_num1, true_num2
    
    # --------------------------------
    # update for next layer
    # --------------------------------       
    # Given the predictions, update the running number dicts and get a new idx/embeddings list
    def prepare_next(self, query, op, num1, num2, found_nums_list, found_nums_idx, idx, num_embed, batch_size, decoder=None):
        new_ops = op.argmax(-1).flatten()
        new_num1 = found_nums_list[num1.argmax(-1).cpu().flatten()]
        new_num2 = found_nums_list[num2.argmax(-1).cpu().flatten()]
        indices_added = []
        new_idx = []
        found_nums_list = deepcopy(found_nums_list)
        found_nums_idx = deepcopy(found_nums_idx)

        found_nums_list = list(found_nums_list)

        # updating the number lists/embeddings/indices
        j = 0
        for i, new_op, new_num1, new_num2 in zip(torch.arange(batch_size).repeat_interleave(K), new_ops, new_num1, new_num2):
            #eq = (new_op.item(), found_nums_list[new_num1], found_nums_list[new_num2])
            eq = (new_op.item(), new_num1, new_num2)
            i = i.item()
            if eq not in found_nums_idx[i] and eq[0]!=op2id['None']:
                new_idx.append(i)
                indices_added.append(j)
                found_nums_list.append(eq)
                found_nums_idx[i][eq] = len(found_nums_list)-1
            j += 1

        found_nums_list = np.array(found_nums_list, dtype=object)
        new_idx = torch.tensor(new_idx)

        # Pass the decoder if in training mode so true embeddings can be calculated
        if decoder is not None:
            temp = [(o,found_nums_idx[pnum][x1],found_nums_idx[pnum][x2]) for pnum in range(batch_size) for o,x1,x2 in found_nums_list[-len(new_idx):][new_idx==pnum]]
            temp = list(map(list,zip(*temp)))
            o = decoder.op_embed[temp[0]].to(self.device)
            x1 = num_embed[temp[1]]
            x2 = num_embed[temp[2]]
            new_embed = torch.cat((o,x1,x2,x1*x2), dim=-1)
            new_embed = decoder.exp_encoder(new_embed)
        else:
            new_embed = query.reshape(K*batch_size,self.embedding_size)[indices_added]

        # Preparing the newly added nums to be concatenated with the problem embedding
        pad = torch.zeros(self.embedding_size)[None,:].to(self.device)
        pad_embed = lambda x: torch.cat((x, pad.repeat(self.K-len(x),1)), dim=0)
        to_cat = torch.stack([pad_embed(new_embed[new_idx==x]) for x in range(batch_size)])
        
        if decoder is not None:
            print(found_nums_list)

        return found_nums_list, found_nums_idx, torch.cat((idx,new_idx)).int(), torch.cat((num_embed,new_embed), dim=0), to_cat
    
    # --------------------------------
    # main function
    # --------------------------------  
    def calculate(self, true_op, true_num1, true_num2, op, num1, num2, batch_size):    
#         print('--------------------------------------------')
#         print(true_op)
#         print(true_num1)
#         print(true_num2)
#         print('--------------------------------------------')
        op, num1, num2 = self.__hungarian_algorithm(true_op, true_num1, true_num2, op, num1, num2, batch_size) # getting optimal permutation
        
        loss = self.__final_loss(true_op, true_num1, true_num2, op, num1, num2)
        correct, op_correct, num_ops, num1_correct, num2_correct, num_nums = self.__correct(true_op, true_num1, true_num2, op, num1, num2)
        
        return loss, correct, op_correct, num_ops, num1_correct, num2_correct, num_nums

In [24]:
data['test']['incremental'][0]

"['const_100 - 5', 'const_100 + 31.1', None, None, None, None] ; ['x2 * const_100', None, None, None, None, None] ; ['x3 / x1', None, None, None, None, None] ; ['x4 - const_100', None, None, None, None, None]"

In [28]:
config.combined['test']['literals'][config.combined['test']['idx']<4]

array(['5.0', '31.1', '14.0', '1000.0', '4.0', '28.0', '26.0', '24.0',
       'const_10', 'const_100', 'const_1', 'const_2', 'const_3',
       'const_4', 'const_10', 'const_100', 'const_1', 'const_2',
       'const_3', 'const_10', 'const_100', 'const_1000', 'const_neg_1',
       'const_0_25', 'const_1', 'const_2', 'const_3', 'const_4',
       'const_6', 'const_10', 'const_12'], dtype='<U32')

In [42]:
temp = np.array(data['validation']['incremental'].map(lambda x: x.split(' ; ')).map(len))

In [43]:
(temp<=3).sum()

1616

In [8]:
class FirstLayer(torch.nn.Module):
    def __init__(self, config, layer):
        super(FirstLayer, self).__init__()
        self.device = config.device
        self.encoder = Encoder(config).to(self.device)
        self.decoder = Decoder(config).to(self.device)
        self.loss = Loss(config)
        self.query = config.query
        self.layer = layer
    
    def forward(self, input_ids, attention_mask, const_embed, idx, true_labels):
        batch_size = input_ids.shape[0]
        
        # encoder
        prob_embed, num_embed = self.encoder(input_ids, attention_mask, const_embed)
        
        # creating initial params
        query = self.query[None,:,:].repeat(batch_size,1,1)
        convert_to_dict = lambda x: dict(zip(np.arange(len(x)),x))
        temp = np.arange(len(idx))
        found_nums_idx = [convert_to_dict(temp[idx==x]) for x in range(batch_size)]
        found_nums_list = np.zeros(len(idx)).astype(int)
        for i,x in enumerate(torch.bincount(idx)):
            found_nums_list[idx==i] = np.arange(x)
        found_nums_list = np.array(found_nums_list, dtype=object)
                
        # decoder        
        query, op, num1, num2 = self.decoder(query, prob_embed, num_embed, idx)
        
        # loss
        true_op, true_num1, true_num2 = self.loss.prepare_all_labels(self.layer, found_nums_idx, true_labels, num1.shape[-1], batch_size)
        loss, correct, _, _, _, _, _ = self.loss.calculate(true_op, true_num1, true_num2, op, num1, num2, batch_size)
        
        # preparing for the next layer (if train mode, use the true values for teacher forcing)
        if self.training:
            found_nums_list, found_nums_idx, idx, num_embed, to_cat = self.loss.prepare_next(query, true_op, true_num1, 
                                                                                             true_num2, found_nums_list, 
                                                                                             found_nums_idx, idx, num_embed, 
                                                                                             batch_size, self.decoder)
        else:
            found_nums_list, found_nums_idx, idx, num_embed, to_cat = self.loss.prepare_next(query, op, num1, 
                                                                                             num2, found_nums_list, 
                                                                                             found_nums_idx, idx, num_embed, 
                                                                                             batch_size)
        
        return query, loss, correct, found_nums_list, found_nums_idx, idx, num_embed, torch.cat((prob_embed, to_cat), dim=1)

In [9]:
class Layer(torch.nn.Module):
    def __init__(self, config, layer):
        super(Layer, self).__init__()
        self.device = config.device
        self.decoder = Decoder(config).to(self.device)
        self.loss = Loss(config)
        self.layer = layer
        
    def forward(self, query, found_nums_list, found_nums_idx, idx, num_embed, prob_embed, true_labels):
        batch_size = query.shape[0]
        
        # forward pass
        query, op, num1, num2 = self.decoder(query, prob_embed, num_embed, idx)
        
        # loss
        true_op, true_num1, true_num2 = self.loss.prepare_all_labels(self.layer, found_nums_idx, true_labels, num1.shape[-1], batch_size)
        
        loss, correct, _, _, _, _, _ = self.loss.calculate(true_op, true_num1, true_num2, op, num1, num2, batch_size)
        
        # preparing for the next layer (if train mode, use the true values for teacher forcing)
        if self.training:
            found_nums_list, found_nums_idx, idx, num_embed, to_cat = self.loss.prepare_next(query, true_op, true_num1, 
                                                                                             true_num2, found_nums_list, 
                                                                                             found_nums_idx, idx, num_embed, 
                                                                                             batch_size, self.decoder)
        else:
            found_nums_list, found_nums_idx, idx, num_embed, to_cat = self.loss.prepare_next(query, op, num1, 
                                                                                             num2, found_nums_list, 
                                                                                             found_nums_idx, idx, num_embed, 
                                                                                             batch_size)
        
        return query, loss, correct, found_nums_list, found_nums_idx, idx, num_embed, torch.cat((prob_embed, to_cat), dim=1)

In [10]:
class FullModel(torch.nn.Module):
    def __init__(self, config):
        super(FullModel, self).__init__()
        self.device = config.device
        self.batch_size = config.batch_size
        self.K = config.K
        self.max_layers = config.max_layers
        self.firstlayer = FirstLayer(config, 0).to(self.device)
        self.layers = torch.nn.ModuleList([Layer(config, i+1).to(self.device) for i in range(config.batch_size-1)])
        
        # problem encoder - [batch_size,embedding_size,num_tokens+K] -> [batch_size,embedding_size,num_tokens]
        # used for mixing in information about newly found expressions
        self.prob_encoder = torch.nn.Sequential(
            torch.nn.Linear(config.num_tokens+self.K, config.num_tokens+int(self.K/2)),
            torch.nn.SiLU(),
            torch.nn.Linear(config.num_tokens+int(self.K/2), config.num_tokens+int(self.K/2)),
            torch.nn.SiLU(),
            torch.nn.Linear(config.num_tokens+int(self.K/2), config.num_tokens),
        )
        
        
    def forward(self, input_ids, attention_mask, const_embed, idx, true_labels):
        total_loss = None
        total_correct = torch.zeros(self.max_layers).int().to(self.device)
        
        query, loss, correct, found_nums_list, found_nums_idx, idx, num_embed, prob_embed = self.firstlayer(input_ids, attention_mask, const_embed, idx, true_labels)
        prob_embed = self.prob_encoder(prob_embed.permute(0,2,1)).permute(0,2,1)
        total_loss = loss
        total_correct[0] += correct
        
        query, loss, correct, found_nums_list, found_nums_idx, idx, num_embed, prob_embed = self.layers[0](query, found_nums_list, found_nums_idx, idx, num_embed, prob_embed, true_labels)
        prob_embed = self.prob_encoder(prob_embed.permute(0,2,1)).permute(0,2,1)
        total_loss += loss
        total_correct[1] += correct
        
        return found_nums_list, total_loss, total_correct

In [11]:
# try:
#     data = util.load_data()
#     config = Config(data, reload=False)
#     train_dataloader = DataLoader(MathQA(config, 'train'), batch_size=config.batch_size, shuffle=False)
#     input_ids, attention_mask, const_embed, idx, true_labels = next(iter(train_dataloader))
#     util.set_seed(config.seed)
#     model = FullModel(config)
#     model.to(config.device)
#     found_nums_list, total_loss, total_correct = model(input_ids, attention_mask, const_embed, idx, true_labels)
# finally:
#     #del model
#     torch.cuda.empty_cache()

In [12]:
class Trainer:
    def __init__(self, config):
        self.device = config.device
        self.model = FullModel(config).to(self.device)
        self.grad_norm_clip = config.grad_norm_clip
        self.batch_size = config.batch_size
        self.max_layers = config.max_layers
        
        self.opt = config.opt(self.model.parameters(), lr=config.lr)
        
    def __backpropagate(self, loss):
        self.opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_norm_clip)
        self.opt.step()
        
    def train(self, dataloader, epoch, num_batches=None, path=None):
        total_loss = 0
        total_batches = num_batches if num_batches else len(dataloader)
        total_correct = torch.zeros(self.max_layers).int()
        total_examples = 0
        self.model.train(True) # setting to train mode
        with tqdm(dataloader, total=total_batches) as progress_bar:
            for item in progress_bar:
                progress_bar.set_description(f'Epoch {epoch}')
                input_ids, attention_mask, const_embed, idx, true_labels = item
                
                # forward pass
                found_nums_list, loss, correct = self.model(input_ids, attention_mask, const_embed, idx, true_labels)
                
                # backpropagating the loss
                self.__backpropagate(loss)
                
                total_loss += loss.detach().item()
                total_correct += correct.detach().cpu()
                total_examples += self.batch_size
                
                progress_bar.set_postfix(loss=loss.detach().item())                
                progress_bar.update(1)
                
                if num_batches:
                    num_batches -= 1
                if num_batches == 0:
                    progress_bar.close()
                    break
        return total_loss/total_batches, total_correct/total_examples
    
    def val(self, dataloader, num_batches=None):
        total_loss = 0
        total_batches = num_batches if num_batches else len(dataloader)
        total_correct = torch.zeros(self.max_layers).int()
        total_examples = 0
        self.model.eval() # setting to eval mode
        with torch.no_grad():
            with tqdm(dataloader, total=total_batches) as progress_bar:
                for item in progress_bar:
                    progress_bar.set_description(f'Validation')
                    input_ids, attention_mask, const_embed, idx, true_labels = item

                    # forward pass
                    query, loss, correct = self.model(input_ids, attention_mask, const_embed, idx, true_labels)

                    total_loss += loss.detach().item()
                    total_correct += correct.detach().cpu()
                    total_examples += self.batch_size
           
                    progress_bar.update(1)

                    if num_batches:
                        num_batches -= 1
                    if num_batches == 0:
                        progress_bar.close()
                        break
        return total_loss/total_batches, total_correct/total_examples
            
    def fit(self, epochs, num_batches=None):
        for epoch in range(epochs):
            train_dataloader = DataLoader(MathQA(config, 'test'), batch_size=config.batch_size, shuffle=False)
            val_dataloader = DataLoader(MathQA(config, 'test'), batch_size=config.batch_size, shuffle=False)
            train_loss, train_acc = self.train(train_dataloader, epoch+1, num_batches)
            val_loss, val_acc = self.val(val_dataloader, num_batches)
            
            train_metrics = [train_loss]
            train_metrics.extend([acc for acc in train_acc.numpy()])
            val_metrics = [val_loss]
            val_metrics.extend([acc for acc in val_acc.numpy()])
            index = ['Loss']
            index.extend([f'Layer {x+1} Accuracy' for x in range(len(train_acc))])
            
            df = pd.DataFrame({'Train':train_metrics, 'Validation':val_metrics}, index = index)
            
            display(df)

In [13]:
class Config:
    
    def __init__(self, data, reload=False):
        self.device = torch.device('cuda:0')
        self.model_name = r'D:\jupyter_notebooks\Mathsage\Connor\models\distilroberta-base-encoder-mathqa'
        self.tokenizer_name = 'distilroberta-base'
        self.seed = 3
        self.batch_size = 4
        
        # model params
        self.K = 6
        self.max_layers = 8
        self.nhead = 6
        self.nlayer = 6
        self.embedding_size = 768
        self.num_tokens = 392
        
        # training params
        self.lr = 1e-5
        self.opt = torch.optim.AdamW
        self.grad_norm_clip = 1.0
        
        # Loading preprocessed data
        if reload:
            p = Preprocess(data, self.K, self.embedding_size, self.num_tokens, self.model_name, self.tokenizer_name)
            util.save_obj(f'{OBJ_DIR}preprocess.pickle', p)
        else:
            p = util.load_obj(f'{OBJ_DIR}preprocess.pickle')
                
        self.num = p.num
        self.const = p.const
        self.combined = p.combined
        self.query = p.query
        self.op = p.op
        self.text = p.text
        self.labels = p.labels
        self.encoder = p.encoder
        self.tokenizer = p.tokenizer
        self.tokenized = p.tokenized

In [14]:
#config.labels['train'][2]

In [15]:
#data['test']['incremental'][4]

In [16]:
try:
    data = util.load_data()
    config = Config(data, reload=False)
    util.set_seed(config.seed)
    trainer = Trainer(config)
    trainer.fit(epochs = 10, num_batches=1)
finally:
    #del model
    torch.cuda.empty_cache()

  0%|          | 0/1 [00:00<?, ?it/s]

[0 1 0 1 2 0 0 1 2 3 3 4 5 6 7 8 1 2 3 4 5 6 2 3 4 5 6 7 8 9 10 (0, 3, 1)
 (1, 3, 0) (3, 0, 8) (0, 0, 1) (2, 2, 0)]
[0 1 0 1 2 0 0 1 2 3 3 4 5 6 7 8 1 2 3 4 5 6 2 3 4 5 6 7 8 9 10 (0, 3, 1)
 (1, 3, 0) (3, 0, 8) (0, 0, 1) (2, 2, 0) (2, (0, 3, 1), 3)
 (2, 1, (3, 0, 8)) (0, (3, 0, 8), 3) (0, (0, 0, 1), 1) (0, (2, 2, 0), 4)]


  0%|          | 0/1 [00:00<?, ?it/s]

Unnamed: 0,Train,Validation
Loss,37.479973,14.423472
Layer 1 Accuracy,0.0,0.0
Layer 2 Accuracy,0.0,0.0
Layer 3 Accuracy,0.0,0.0
Layer 4 Accuracy,0.0,0.0
Layer 5 Accuracy,0.0,0.0
Layer 6 Accuracy,0.0,0.0
Layer 7 Accuracy,0.0,0.0
Layer 8 Accuracy,0.0,0.0


  0%|          | 0/1 [00:00<?, ?it/s]

[0 1 0 1 2 0 0 1 2 3 3 4 5 6 7 8 1 2 3 4 5 6 2 3 4 5 6 7 8 9 10 (0, 3, 1)
 (1, 3, 0) (3, 0, 8) (0, 0, 1) (2, 2, 0)]
[0 1 0 1 2 0 0 1 2 3 3 4 5 6 7 8 1 2 3 4 5 6 2 3 4 5 6 7 8 9 10 (0, 3, 1)
 (1, 3, 0) (3, 0, 8) (0, 0, 1) (2, 2, 0) (2, (0, 3, 1), 3)
 (2, 1, (3, 0, 8)) (0, (3, 0, 8), 3) (0, (0, 0, 1), 1) (0, (2, 2, 0), 4)]


  0%|          | 0/1 [00:00<?, ?it/s]

Unnamed: 0,Train,Validation
Loss,17.614445,13.864141
Layer 1 Accuracy,0.0,0.0
Layer 2 Accuracy,0.0,0.0
Layer 3 Accuracy,0.0,0.0
Layer 4 Accuracy,0.0,0.0
Layer 5 Accuracy,0.0,0.0
Layer 6 Accuracy,0.0,0.0
Layer 7 Accuracy,0.0,0.0
Layer 8 Accuracy,0.0,0.0


  0%|          | 0/1 [00:00<?, ?it/s]

[0 1 0 1 2 0 0 1 2 3 3 4 5 6 7 8 1 2 3 4 5 6 2 3 4 5 6 7 8 9 10 (0, 3, 1)
 (1, 3, 0) (3, 0, 8) (0, 0, 1) (2, 2, 0)]
[0 1 0 1 2 0 0 1 2 3 3 4 5 6 7 8 1 2 3 4 5 6 2 3 4 5 6 7 8 9 10 (0, 3, 1)
 (1, 3, 0) (3, 0, 8) (0, 0, 1) (2, 2, 0) (2, (0, 3, 1), 3)
 (2, 1, (3, 0, 8)) (0, (3, 0, 8), 3) (0, (0, 0, 1), 1) (0, (2, 2, 0), 4)]


  0%|          | 0/1 [00:00<?, ?it/s]

Unnamed: 0,Train,Validation
Loss,15.486795,12.273218
Layer 1 Accuracy,0.0,0.0
Layer 2 Accuracy,0.0,0.0
Layer 3 Accuracy,0.0,0.0
Layer 4 Accuracy,0.0,0.0
Layer 5 Accuracy,0.0,0.0
Layer 6 Accuracy,0.0,0.0
Layer 7 Accuracy,0.0,0.0
Layer 8 Accuracy,0.0,0.0


  0%|          | 0/1 [00:00<?, ?it/s]

[0 1 0 1 2 0 0 1 2 3 3 4 5 6 7 8 1 2 3 4 5 6 2 3 4 5 6 7 8 9 10 (1, 3, 0)
 (0, 3, 1) (3, 0, 8) (0, 0, 1) (2, 2, 0)]
[0 1 0 1 2 0 0 1 2 3 3 4 5 6 7 8 1 2 3 4 5 6 2 3 4 5 6 7 8 9 10 (1, 3, 0)
 (0, 3, 1) (3, 0, 8) (0, 0, 1) (2, 2, 0) (2, (0, 3, 1), 3)
 (0, (3, 0, 8), 3) (2, 1, (3, 0, 8)) (0, (0, 0, 1), 1) (0, (2, 2, 0), 4)]


  0%|          | 0/1 [00:00<?, ?it/s]

Unnamed: 0,Train,Validation
Loss,13.388052,10.754713
Layer 1 Accuracy,0.0,0.0
Layer 2 Accuracy,0.0,0.0
Layer 3 Accuracy,0.0,0.0
Layer 4 Accuracy,0.0,0.0
Layer 5 Accuracy,0.0,0.0
Layer 6 Accuracy,0.0,0.0
Layer 7 Accuracy,0.0,0.0
Layer 8 Accuracy,0.0,0.0


  0%|          | 0/1 [00:00<?, ?it/s]

[0 1 0 1 2 0 0 1 2 3 3 4 5 6 7 8 1 2 3 4 5 6 2 3 4 5 6 7 8 9 10 (1, 3, 0)
 (0, 3, 1) (3, 0, 8) (0, 0, 1) (2, 2, 0)]
[0 1 0 1 2 0 0 1 2 3 3 4 5 6 7 8 1 2 3 4 5 6 2 3 4 5 6 7 8 9 10 (1, 3, 0)
 (0, 3, 1) (3, 0, 8) (0, 0, 1) (2, 2, 0) (2, (0, 3, 1), 3)
 (0, (3, 0, 8), 3) (2, 1, (3, 0, 8)) (0, (0, 0, 1), 1) (0, (2, 2, 0), 4)]


  0%|          | 0/1 [00:00<?, ?it/s]

Unnamed: 0,Train,Validation
Loss,12.492888,10.181341
Layer 1 Accuracy,0.0,0.0
Layer 2 Accuracy,0.0,0.0
Layer 3 Accuracy,0.0,0.0
Layer 4 Accuracy,0.0,0.0
Layer 5 Accuracy,0.0,0.0
Layer 6 Accuracy,0.0,0.0
Layer 7 Accuracy,0.0,0.0
Layer 8 Accuracy,0.0,0.0


  0%|          | 0/1 [00:00<?, ?it/s]

[0 1 0 1 2 0 0 1 2 3 3 4 5 6 7 8 1 2 3 4 5 6 2 3 4 5 6 7 8 9 10 (1, 3, 0)
 (0, 3, 1) (3, 0, 8) (0, 0, 1) (2, 2, 0)]
[0 1 0 1 2 0 0 1 2 3 3 4 5 6 7 8 1 2 3 4 5 6 2 3 4 5 6 7 8 9 10 (1, 3, 0)
 (0, 3, 1) (3, 0, 8) (0, 0, 1) (2, 2, 0) (2, (0, 3, 1), 3)
 (2, 1, (3, 0, 8)) (0, (3, 0, 8), 3) (0, (0, 0, 1), 1) (0, (2, 2, 0), 4)]


  0%|          | 0/1 [00:00<?, ?it/s]

Unnamed: 0,Train,Validation
Loss,12.22501,9.416731
Layer 1 Accuracy,0.0,0.0
Layer 2 Accuracy,0.0,0.0
Layer 3 Accuracy,0.0,0.0
Layer 4 Accuracy,0.0,0.0
Layer 5 Accuracy,0.0,0.0
Layer 6 Accuracy,0.0,0.0
Layer 7 Accuracy,0.0,0.0
Layer 8 Accuracy,0.0,0.0


  0%|          | 0/1 [00:00<?, ?it/s]

[0 1 0 1 2 0 0 1 2 3 3 4 5 6 7 8 1 2 3 4 5 6 2 3 4 5 6 7 8 9 10 (0, 3, 1)
 (1, 3, 0) (3, 0, 8) (0, 0, 1) (2, 2, 0)]
[0 1 0 1 2 0 0 1 2 3 3 4 5 6 7 8 1 2 3 4 5 6 2 3 4 5 6 7 8 9 10 (0, 3, 1)
 (1, 3, 0) (3, 0, 8) (0, 0, 1) (2, 2, 0) (2, (0, 3, 1), 3)
 (0, (3, 0, 8), 3) (2, 1, (3, 0, 8)) (0, (0, 0, 1), 1) (0, (2, 2, 0), 4)]


  0%|          | 0/1 [00:00<?, ?it/s]

Unnamed: 0,Train,Validation
Loss,11.082148,9.079498
Layer 1 Accuracy,0.0,0.0
Layer 2 Accuracy,0.0,0.0
Layer 3 Accuracy,0.0,0.0
Layer 4 Accuracy,0.0,0.0
Layer 5 Accuracy,0.0,0.0
Layer 6 Accuracy,0.0,0.0
Layer 7 Accuracy,0.0,0.0
Layer 8 Accuracy,0.0,0.0


  0%|          | 0/1 [00:00<?, ?it/s]

[0 1 0 1 2 0 0 1 2 3 3 4 5 6 7 8 1 2 3 4 5 6 2 3 4 5 6 7 8 9 10 (1, 3, 0)
 (0, 3, 1) (3, 0, 8) (0, 0, 1) (2, 2, 0)]
[0 1 0 1 2 0 0 1 2 3 3 4 5 6 7 8 1 2 3 4 5 6 2 3 4 5 6 7 8 9 10 (1, 3, 0)
 (0, 3, 1) (3, 0, 8) (0, 0, 1) (2, 2, 0) (2, (0, 3, 1), 3)
 (2, 1, (3, 0, 8)) (0, (3, 0, 8), 3) (0, (0, 0, 1), 1) (0, (2, 2, 0), 4)]


  0%|          | 0/1 [00:00<?, ?it/s]

Unnamed: 0,Train,Validation
Loss,10.555177,8.65107
Layer 1 Accuracy,0.0,0.0
Layer 2 Accuracy,0.0,0.0
Layer 3 Accuracy,0.0,0.0
Layer 4 Accuracy,0.0,0.0
Layer 5 Accuracy,0.0,0.0
Layer 6 Accuracy,0.0,0.0
Layer 7 Accuracy,0.0,0.0
Layer 8 Accuracy,0.0,0.0


  0%|          | 0/1 [00:00<?, ?it/s]

[0 1 0 1 2 0 0 1 2 3 3 4 5 6 7 8 1 2 3 4 5 6 2 3 4 5 6 7 8 9 10 (1, 3, 0)
 (0, 3, 1) (3, 0, 8) (0, 0, 1) (2, 2, 0)]
[0 1 0 1 2 0 0 1 2 3 3 4 5 6 7 8 1 2 3 4 5 6 2 3 4 5 6 7 8 9 10 (1, 3, 0)
 (0, 3, 1) (3, 0, 8) (0, 0, 1) (2, 2, 0) (2, (0, 3, 1), 3)
 (0, (3, 0, 8), 3) (2, 1, (3, 0, 8)) (0, (0, 0, 1), 1) (0, (2, 2, 0), 4)]


  0%|          | 0/1 [00:00<?, ?it/s]

Unnamed: 0,Train,Validation
Loss,9.779472,8.317472
Layer 1 Accuracy,0.0,0.0
Layer 2 Accuracy,0.0,0.0
Layer 3 Accuracy,0.0,0.0
Layer 4 Accuracy,0.0,0.0
Layer 5 Accuracy,0.0,0.0
Layer 6 Accuracy,0.0,0.0
Layer 7 Accuracy,0.0,0.0
Layer 8 Accuracy,0.0,0.0


  0%|          | 0/1 [00:00<?, ?it/s]

[0 1 0 1 2 0 0 1 2 3 3 4 5 6 7 8 1 2 3 4 5 6 2 3 4 5 6 7 8 9 10 (0, 3, 1)
 (1, 3, 0) (3, 0, 8) (0, 0, 1) (2, 2, 0)]
[0 1 0 1 2 0 0 1 2 3 3 4 5 6 7 8 1 2 3 4 5 6 2 3 4 5 6 7 8 9 10 (0, 3, 1)
 (1, 3, 0) (3, 0, 8) (0, 0, 1) (2, 2, 0) (2, (0, 3, 1), 3)
 (0, (3, 0, 8), 3) (2, 1, (3, 0, 8)) (0, (0, 0, 1), 1) (0, (2, 2, 0), 4)]


  0%|          | 0/1 [00:00<?, ?it/s]

Unnamed: 0,Train,Validation
Loss,10.034519,8.055159
Layer 1 Accuracy,0.0,0.25
Layer 2 Accuracy,0.0,0.0
Layer 3 Accuracy,0.0,0.0
Layer 4 Accuracy,0.0,0.0
Layer 5 Accuracy,0.0,0.0
Layer 6 Accuracy,0.0,0.0
Layer 7 Accuracy,0.0,0.0
Layer 8 Accuracy,0.0,0.0


In [22]:
config.combined['test']['literals'][config.combined['test']['idx']<4]

array(['5.0', '31.1', '14.0', '1000.0', '4.0', '28.0', '26.0', '24.0',
       'const_10', 'const_100', 'const_1', 'const_2', 'const_3',
       'const_4', 'const_10', 'const_100', 'const_1', 'const_2',
       'const_3', 'const_10', 'const_100', 'const_1000', 'const_neg_1',
       'const_0_25', 'const_1', 'const_2', 'const_3', 'const_4',
       'const_6', 'const_10', 'const_12'], dtype='<U32')

In [17]:
config.labels['test'][0]

(tensor([[0., 1., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0., 0.],
         [0., 0., 1., 0., 0., 0.],
         [0., 0., 0., 1., 0., 0.],
         [0., 1., 0., 0., 0., 0.]]),
 array([3, 3, (0, 3, 1), (2, (0, 3, 1), 3),
        (3, (2, (0, 3, 1), 3), (1, 3, 0))], dtype=object),
 array([0, 1, 3, (1, 3, 0), 3], dtype=object),
 array([0, 0, 1, 2, 3]))

In [18]:
config.labels['test'][1]

(tensor([[0., 0., 0., 1., 0., 0.],
         [0., 0., 1., 0., 0., 0.],
         [1., 0., 0., 0., 0., 0.],
         [0., 0., 1., 0., 0., 0.],
         [0., 0., 0., 0., 1., 0.],
         [0., 0., 1., 0., 0., 0.],
         [0., 1., 0., 0., 0., 0.],
         [0., 1., 0., 0., 0., 0.]]),
 array([0, 1, (3, 0, 8), (2, 1, (3, 0, 8)), (0, (3, 0, 8), 3), 1,
        (2, 1, (4, (0, (3, 0, 8), 3), 2)),
        (1, (2, 1, (4, (0, (3, 0, 8), 3), 2)), 1)], dtype=object),
 array([8, (3, 0, 8), 3, 2, 2, (4, (0, (3, 0, 8), 3), 2), 1,
        (2, (2, 1, (3, 0, 8)), 2)], dtype=object),
 array([0, 1, 1, 2, 2, 3, 4, 5]))

In [19]:
config.labels['test'][2]

(tensor([[1., 0., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0., 0.],
         [0., 0., 1., 0., 0., 0.]]),
 array([0, (0, 0, 1), (0, 0, 1)], dtype=object),
 array([1, 1, (0, (0, 0, 1), 1)], dtype=object),
 array([0, 1, 2]))

In [20]:
config.labels['test'][3]

(tensor([[0., 0., 1., 0., 0., 0.],
         [1., 0., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0., 0.]]),
 array([2, (2, 2, 0), (0, (2, 2, 0), 4), (0, (2, 2, 0), 4)], dtype=object),
 array([0, 4, 4, (0, (0, (2, 2, 0), 4), 4)], dtype=object),
 array([0, 1, 2, 3]))