# MathQA Decoder

#### Imports

In [2]:
from enum import Enum
import math
import os
import re
from copy import deepcopy
import random
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer, AutoModel
import pickle
from scipy.optimize import linear_sum_assignment
if get_ipython().__class__.__name__ == 'ZMQInteractiveShell':
    from tqdm.notebook import tqdm
else:
    from tqdm import tqdm
import unittest

# remove
import time

#### Constants

In [3]:
K = 6
MAX_LAYERS = 8
MAX_TOKENS = 392
EMBEDDING_SIZE = 768

DATA_PATH = './dataset/'
SET_NAMES = ['train', 'validation', 'test']
ENCODER_MODEL = 'distilroberta-base' # A more optimized version of roberta obtaining 95% of its performance
DEVICE = 'cuda:0'
NUM_MASK = '<num>'
WORKING_DIR = 'TEMP/'

OBJ_DIR = 'pickle/'


class Op(Enum):
    ADD = '+'
    SUB = '-'
    MULT = '*'
    DIV = '/'
    POW = '^'
    
class Const(Enum):
    CONST_NEG_1 = 'const_neg_1' # I added this
    CONST_0_25 = 'const_0_25'
    CONST_0_2778 = 'const_0_2778'
    CONST_0_33 = 'const_0_33'
    CONST_0_3937 = 'const_0_3937'
    CONST_1 = 'const_1'
    CONST_1_6 = 'const_1_6'
    CONST_2 = 'const_2'
    CONST_3 = 'const_3'
    CONST_PI = 'const_pi'
    CONST_3_6 = 'const_3_6'
    CONST_4 = 'const_4'
    CONST_5 = 'const_5'
    CONST_6 = 'const_6'
    CONST_10 = 'const_10'
    CONST_12 = 'const_12'
    CONST_26 = 'const_26'
    CONST_52 = 'const_52'
    CONST_60 = 'const_60'
    CONST_100 = 'const_100'
    CONST_180 = 'const_180'
    CONST_360 = 'const_360'
    CONST_1000 = 'const_1000'
    CONST_3600 = 'const_3600'

values = [-1, 0.25, 0.2778, 0.33, 0.3937, 1, 1.6, 2, 3, math.pi, 3.6, 4, 5, 6, 10, 12, 26, 52, 60, 100, 180, 360, 1000, 3600]
const2val = {k:v for k,v in zip(Const._value2member_map_.keys(), values)}    

op2id = {k:v for k,v in zip(Op._value2member_map_.keys(), range(len(Op._value2member_map_)))}
op2id['None'] = 5
id2op = {v:k for k,v in op2id.items()}
const2id = {k:v for k,v in zip(Const._value2member_map_.keys(), range(len(Const._value2member_map_)))}
id2const = np.array(list(const2id.keys()))

torch.autograd.set_detect_anomaly(True)
torch.set_printoptions(sci_mode=False)

class Util():
    def load_obj(self, path):
        with open(path, 'rb') as f:
            o = pickle.load(f)
        return o
    
    def save_obj(self, path, o):
        with open(path, 'wb') as f:
            pickle.dump(o, f)
            
    def load_data(self):
        return {name:pd.read_csv(f'{DATA_PATH}{name}.csv') for name in SET_NAMES}
    
    def set_seed(self, seed):
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)
util = Util()

## Creating the Model

In [4]:
class Preprocess:
    
    def __init__(self, data, K, embedding_size, max_tokens, model_name, tokenizer_name):
        # Randomly initializing embeddings
        self.op = torch.randn(len(op2id),embedding_size)
        self.query = torch.randn(K,embedding_size)
        const = torch.randn(len(const2id),embedding_size)
        
        self.num, self.text = self.__get_nums_and_mask(data)
        self.const = self.__get_const(const)
        self.combined = self.__combine(self.num, self.const)
        self.labels = {name:self.__get_label(data,name) for name in SET_NAMES}
        
        # Encoder
        self.encoder, self.tokenizer = self.__setup_model(model_name, tokenizer_name)
        self.tokenized = {name:self.__tokenize_data(self.tokenizer, self.text[name], max_tokens) for name in SET_NAMES}
        
    # -----------------------------------------------------------
    # util
    # -----------------------------------------------------------
    def __expand_tensor(self, curr, new):
        if curr is None:
            curr = new
        else:
            curr =  torch.cat((curr,new), dim=0)
        return curr
    
    def __flatten(self, arr):
        idx = np.concatenate([[i]*len(x) for i,x in enumerate(arr)])
        flattened = np.concatenate(arr) 
        return idx, flattened
    
    # -----------------------------------------------------------
    # tokenization
    # -----------------------------------------------------------
    def __setup_model(self, model_name, tokenizer_name):
        model = AutoModel.from_pretrained(model_name, output_hidden_states=True)
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        tokenizer.add_special_tokens({'additional_special_tokens':[NUM_MASK]})
        model.resize_token_embeddings(len(tokenizer))
        return model, tokenizer
    
    def __tokenize_data(self, tokenizer, text, max_tokens):
        tokenization = lambda x: tokenizer(x, padding='max_length', max_length=max_tokens, truncation=True)

        tokenized = list(map(tokenization, text))
        input_ids = torch.stack([torch.tensor(x['input_ids']) for x in tokenized])
        attention_mask = torch.stack([torch.tensor(x['attention_mask']) for x in tokenized])

        return {'input_ids':input_ids.long(), 'attention_mask':attention_mask.int()}
    
    # -----------------------------------------------------------
    # literals
    # -----------------------------------------------------------
    # Gets the numbers listed in a problem
    # Once found, numbers are masked using a number mask
    def __get_nums_and_mask(self, data):
        nums = {name:[] for name in SET_NAMES}
        problems = {name:[] for name in SET_NAMES}
        num_idx = {name:[] for name in SET_NAMES}
        for name in SET_NAMES:
            for i,problem in enumerate(data[name].problem):
                num = re.compile('([+-]?((\d+(\.\d*)?)|(\.\d+)))') # normal num
                big = re.compile(r'(-?\d{1,3}(,\d{3})+(\.\d*)?)') # num with comma

                big_results = re.finditer(big, problem)
                problem = re.sub(big, NUM_MASK, problem)        
                num_results = re.finditer(num, problem)
                problem = re.sub(num, NUM_MASK, problem)

                # Getting the combined numbers in order of occurence
                combined = [x for x in num_results]
                combined.extend([x for x in big_results])
                combined = sorted(combined, key=lambda x: x.start(0))

                combined = [float(x.group(0).replace(',','')) for x in combined]

                nums[name].append(np.array(combined))
                problems[name].append(problem)
            num_idx[name], nums[name] = self.__flatten(np.array(nums[name], dtype=object))
            problems[name] = np.array(problems[name])
        return {name:{'idx':torch.tensor(num_idx[name]), 'literals':nums[name]} for name in SET_NAMES}, problems
    
    def __get_const(self, embed):
        const = {}
        for name in SET_NAMES:
            const_pred = util.load_obj(f'{OBJ_DIR}constants.pickle')['pred'][name]
            const_idx, const_id = np.where(const_pred==1)
            literals = id2const[const_id]
            const[name] = {'idx':torch.tensor(const_idx), 'literals':literals, 'embed_idx':const_id, 'embed':embed}
        return const
    
    def __combine(self, num, const):
        combined = {}
        for name in SET_NAMES:
            combined[name] = {'idx':torch.cat((num[name]['idx'],const[name]['idx'])), 
                              'literals':np.concatenate((num[name]['literals'],const[name]['literals']))}
        return combined
    
    # -----------------------------------------------------------
    # labels
    # -----------------------------------------------------------
    def __get_init_exp(self, data, name):
        reorder = lambda x: [x[1], x[0], x[2]]
        convert_to_arr = lambda d: [reorder(x.split()) for idx, arr in enumerate(d.split(' ; ')) for x in eval(arr) if x is not None and idx == 0]
        return data[name]['incremental'].map(convert_to_arr)

    def __get_num_label(self, name, num, i):
        idx = self.combined[name]['idx']
        literals = self.combined[name]['literals'][idx==i]
        label = torch.zeros(literals.shape)
        if num in const2val:
            location = np.where(literals==num)[0]
            if len(location) == 0:
                return torch.tensor(-1)
            else:
                location = location[0]
        else:
            location = np.where(literals==str(float(num)))[0][0]
        return torch.tensor(location)

    def __pad_op(self, op_label):
        temp = torch.zeros(op_label.shape[1])
        temp[op2id['None']] = 1
        return torch.cat((op_label, temp[None,:].repeat(K-op_label.shape[0],1)), dim=0)[None,:,:]

    def __pad_num(self, num_label):
        return torch.cat((num_label, torch.tensor(-1).repeat(K-num_label.shape[0])), dim=0)

    def __get_true_label(self, name, e, i):
        op, num1, num2 = e
        op_label = torch.zeros(len(op2id))
        op_label[op2id[op]] = 1
        num1_label = self.__get_num_label(name, num1, i)
        num2_label = self.__get_num_label(name, num2, i)
        return op_label, num1_label, num2_label

    def __get_label(self, data, name):
        exp = self.__get_init_exp(data, name)
        idx = []
        true_op_label = None
        true_num1_label = []
        true_num2_label = []
        for i,e in enumerate(exp):
            get_true = lambda x: self.__get_true_label(name, x, i)
            op_label, num1_label, num2_label = [torch.stack(item) for item in list(zip(*map(get_true, e)))]
            true_op_label = self.__expand_tensor(true_op_label, self.__pad_op(op_label))
            true_num1_label.append(self.__pad_num(num1_label))
            true_num2_label.append(self.__pad_num(num2_label))
        return true_op_label, torch.stack(true_num1_label), torch.stack(true_num2_label)

In [5]:
class MathQA(Dataset):
    def __init__(self, config, name):
        self.input_ids = config.tokenized[name]['input_ids']
        self.attention_mask = config.tokenized[name]['attention_mask']
        self.const_embed = config.const[name]['embed']
        self.embed_idx = torch.tensor(config.const[name]['embed_idx']).int()
        self.idx = config.combined[name]['idx'].int()
        self.const_idx = config.const[name]['idx'].int()
        self.true_labels = config.labels[name]
        self.literals = config.combined[name]['literals']
    
    def __len__(self):
        return len(self.input_ids)
    
    def __getitem__(self, i):
        if i.step: # using step is not defined
            raise NotImplementedError
            
        input_ids = self.input_ids[i]
        attention_mask = self.attention_mask[i]
        true_op = self.true_labels[0][i]
        true_num1 = self.true_labels[1][i]
        true_num2 = self.true_labels[2][i]
        
        if isinstance(i, slice):
            const_embed = self.const_embed[self.embed_idx[(self.const_idx>=i.start)&(self.const_idx<i.stop)]]
            idx = self.idx[(self.idx>=i.start)&(self.idx<i.stop)]
            literals = self.literals[(self.idx>=i.start)&(self.idx<i.stop)]
        else:
            const_embed = self.const_embed[self.embed_idx[self.const_idx==i]]
            idx = self.idx[self.idx==i]
            literals = self.literals[self.idx==i]
        return input_ids, attention_mask, const_embed, idx, (true_op, true_num1, true_num2), literals
    
class DataLoader:
    def __init__(self, dataset, batch_size, shuffle):
        self.curr = 0
        self.dataset = dataset
        self.batch_size = batch_size
        self.slices = [slice(x,x+batch_size) for x in range(0,len(dataset),batch_size)]
        self.num_batches = len(self.slices)
        
        if shuffle:
            np.random.shuffle(self.slices)
            
    def __len__(self):
        return self.num_batches
        
    def __iter__(self):
        return self
    
    def __next__(self):        
        if self.curr < self.num_batches:
            input_ids, attention_mask, const_embed, idx, true_labels, literals = self.dataset[self.slices[self.curr]]
            self.curr += 1
            return input_ids, attention_mask, const_embed, idx%self.batch_size, true_labels, literals
        raise StopIteration

In [6]:
class Encoder(torch.nn.Module):
    
    def __init__(self, config):
        super(Encoder, self).__init__()
        self.device = config.device
        self.mask_id = config.tokenizer.encode(NUM_MASK, add_special_tokens=False)[0]
        self.encoder = config.encoder.to(config.device)
        
    def forward(self, input_ids, attention_mask, const_embed):
        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)
        const_embed = const_embed.to(self.device)
        
        # Getting the hidden output
        x = self.encoder(input_ids, attention_mask).last_hidden_state
        
        # Combining constants and numbers
        embed = torch.cat((x[input_ids==self.mask_id],const_embed), dim=0)
        
        return x, embed

In [7]:
class Decoder(torch.nn.Module):
    
    def __init__(self, config): 
        super(Decoder, self).__init__()
        self.device = config.device
        self.embedding_size = config.embedding_size
        self.num_tokens = config.num_tokens
        self.K = config.K
        self.op_embed = config.op
        
        # standard transformer decoder
        transformer_decoder_layer = torch.nn.TransformerDecoderLayer(self.embedding_size, nhead=config.nhead, batch_first=True)
        self.transformer_decoder = torch.nn.TransformerDecoder(transformer_decoder_layer, num_layers=config.nlayer)
        
        # softmax for the operations
        self.op_softmax = torch.nn.Softmax(dim=2)
        
        # softmax for the numbers
        self.num_softmax = torch.nn.Softmax(dim=0)

        # decreasing dimensionality to match embedding size: <op, num, num, num*num> = 3072 -> 768
#         self.exp_encoder = torch.nn.Sequential(
#             torch.nn.Linear(self.embedding_size*4, self.embedding_size*2),
#             torch.nn.SiLU(),
#             torch.nn.Linear(self.embedding_size*2, self.embedding_size*2),
#             torch.nn.SiLU(),
#             torch.nn.Linear(self.embedding_size*2, self.embedding_size),
#         )
        
        # converting back into new <op, num, num> for loss calculation
        self.op_decoder = torch.nn.Sequential(
            torch.nn.Linear(self.embedding_size, int(self.embedding_size*1.5)),
            torch.nn.SiLU(),
            torch.nn.Linear(int(self.embedding_size*1.5), int(self.embedding_size*1.5)),
            torch.nn.SiLU(),
            torch.nn.Linear(int(self.embedding_size*1.5), self.embedding_size),
        )
        self.num1_decoder = torch.nn.Sequential(
            torch.nn.Linear(self.embedding_size, int(self.embedding_size*1.5)),
            torch.nn.SiLU(),
            torch.nn.Linear(int(self.embedding_size*1.5), int(self.embedding_size*1.5)),
            torch.nn.SiLU(),
            torch.nn.Linear(int(self.embedding_size*1.5), self.embedding_size),
        )
        self.num2_decoder = torch.nn.Sequential(
            torch.nn.Linear(self.embedding_size, int(self.embedding_size*1.5)),
            torch.nn.SiLU(),
            torch.nn.Linear(int(self.embedding_size*1.5), int(self.embedding_size*1.5)),
            torch.nn.SiLU(),
            torch.nn.Linear(int(self.embedding_size*1.5), self.embedding_size),
        )
        
    def __apply_to_nums(self, f, num, idx, batch_size):
        new_nums = torch.zeros(batch_size, self.K, num.shape[0]).to(self.device)
        for x in range(batch_size):
            new_nums[x,:,idx==x] = f(num[idx==x]).T
        return new_nums  
    
    # query: [batch_size, K, 768]
    # prob_embed: [batch_size, num_tokens, 768]
    # num_embed: [num_nums, 768]
    # op_embed: [num_ops, 768]
    def forward(self, query, prob_embed, num_embed, idx): 
        query = query.to(self.device)
        prob_embed = prob_embed.to(self.device)
        num_embed = num_embed.to(self.device)
        op_embed = self.op_embed.to(self.device)
        idx = idx.to(self.device)
        batch_size = query.shape[0]
        num_ops = op_embed.shape[0]
        num_nums = num_embed.shape[0]

        # ----------------------------
        # Step 1 - transformer decoder
        # ----------------------------
        query = self.transformer_decoder(query, prob_embed) # [batch_size, K, 768] -> [batch_size, K, 768]
        
        # -------------------------------------------------------------------------------
        # Step 2 - decoding the output into three embeddings of size 768 (op, num1, num2)
        # -------------------------------------------------------------------------------
        op = self.op_decoder(query)
        num1 = self.num1_decoder(query)
        num2 = self.num2_decoder(query)
        
        # ----------------------------------------------------
        # Step 3 - creating embedding for the found expression  
        # ----------------------------------------------------
        #query = self.exp_encoder(torch.cat((op,num1,num2,num1*num2), dim=2)) # [batch_size, K, 3072] -> [batch_size, K, 768]
        
        # ------------------------------------------------------------------------------
        # Step 4 - Taking a dot product between the predicted and stored true embeddings
        # ------------------------------------------------------------------------------
        # making sure params have correct dimension
        num_embed = num_embed[:,None,:].expand(-1,K,-1) # [num_nums, 768] -> [num_nums, K, 768]
        op_embed = op_embed[None,None,:,:].repeat(batch_size,self.K,1,1) # [num_ops, 768] -> [batch_size, K, num_ops, 768]

        # dot product - calculating the similarity between each op/num prediction and stored embeddings
        num1 = (num1[idx]*num_embed).sum(dim=2) # [number_of_nums, K]
        num2 = (num2[idx]*num_embed).sum(dim=2) # [number_of_nums, K]        
        op = op[:,:,None,:].expand(-1,-1,num_ops,-1) # [batch_size, K, 768] -> # [batch_size, K, num_ops, 768]
        op = (op*op_embed).sum(dim=3) # [batch_size, K, num_ops, 768] -> [batch_size, K, num_ops]
        
        # -----------------------------------------------------------------------------------------
        # Step 5 - Finding the softmax for the similarity between the predicted and true embeddings
        # -----------------------------------------------------------------------------------------
        op = self.op_softmax(op) # [batch_size, K, num_ops] (ie op[1,2] would be the operator prediction probabilities for problem2, query3)
        num1 = self.__apply_to_nums(self.num_softmax, num1, idx, batch_size) # [batch_size,K,num_nums]
        num2 = self.__apply_to_nums(self.num_softmax, num2, idx, batch_size) # [batch_size,K,num_nums]
        
        return query, op, num1, num2

In [8]:
class Loss:
    def __init__(self, config):
        self.K = config.K
        self.device = config.device
        
    # --------------------------------
    # loss/accuracy calculation
    # --------------------------------
    def __lmatch(self, true_op, true_num1, true_num2, op, num1, num2):
        true_op = torch.clone(true_op)
        true_op[:,:,op2id['None']] = 0 # setting all none operators to zero so ignored during calculation

        # rows: true exp, cols: pred exp (ie 0,1 would be the cost for true exp 0 with pred exp 1)
        # [8,K*K]
        op_mat = (true_op.repeat_interleave(K,dim=1)*op.repeat(1,K,1)).sum(dim=-1)
        num1_mat = (true_num1.repeat_interleave(K,dim=1)*num1.repeat(1,K,1)).sum(dim=-1)
        num2_mat = (true_num2.repeat_interleave(K,dim=1)*num2.repeat(1,K,1)).sum(dim=-1)

        return -(op_mat+num1_mat+num2_mat).reshape(op_mat.shape[0],K,K)
    
    # consider trying to write this in pytorch if have time
    def __hungarian_algorithm(self, true_op, true_num1, true_num2, op, num1, num2, batch_size):
        m = self.__lmatch(true_op, true_num1, true_num2, op, num1, num2)
        permutations = torch.stack(tuple([torch.tensor(linear_sum_assignment(m.detach().cpu().numpy()[x])[1]) for x in range(batch_size)]))
        batch_idx = torch.arange(batch_size).unsqueeze(-1)
        return op[batch_idx, permutations], num1[batch_idx, permutations], num2[batch_idx, permutations]
    
    # expects the optimal permutation from the hungarian algorithm
    def __final_loss(self, true_op, true_num1, true_num2, op, num1, num2):        
        log_op = torch.log((true_op*op).sum(-1))
        
        log_num1 = ((true_num1*num1).sum(-1))
        log_num1[log_num1.nonzero(as_tuple=True)] = torch.log(log_num1[log_num1.nonzero(as_tuple=True)])
        
        log_num2 = ((true_num2*num2).sum(-1))
        log_num2[log_num2.nonzero(as_tuple=True)] = torch.log(log_num2[log_num2.nonzero(as_tuple=True)])
 
        return (-log_op-log_num1-log_num2).sum(-1).mean()

    # expects the optimal permutation from the hungarian algorithm
    # gets the number of correct examples for a batch
    def __correct(self, true_op, true_num1, true_num2, op, num1, num2):
        invalid = true_op.argmax(-1)==op2id['None']
        op_correct = true_op.argmax(-1)==op.argmax(-1)
        num1_correct = num1.argmax(-1)==true_num1.argmax(-1)
        num2_correct = num2.argmax(-1)==true_num2.argmax(-1)
        num1_correct[invalid]=True
        num2_correct[invalid]=True
        
        correct = ((op_correct&num1_correct&num2_correct).sum(1)==self.K).sum()
        num_ops = true_op.shape[0]*true_op.shape[1]
        op_correct = op_correct.sum()
        num_nums = len(num1_correct[~invalid])
        num1_correct = num1_correct[~invalid].sum()
        num2_correct = num2_correct[~invalid].sum()
        return correct, op_correct, num_ops, num1_correct, num2_correct, num_nums
    
    def prepare_labels(self, idx, true_labels, shape):
        true_op, true_num1, true_num2 = true_labels
        batch_size = true_op.shape[0]

        temp = torch.zeros(shape)
        unmasked = true_num1!=-1
        true_idx = torch.cat([torch.where(idx==x)[0][true_num1[x][true_num1[x]!=-1]] for x in range(batch_size)])
        temp[unmasked,true_idx]=1
        true_num1 = temp

        temp = torch.zeros(shape)
        unmasked = true_num2!=-1
        true_idx = torch.cat([torch.where(idx==x)[0][true_num2[x][true_num2[x]!=-1]] for x in range(batch_size)])
        temp[unmasked,true_idx]=1
        true_num2 = temp

        # Randomizing equation order
        batch_idx = torch.arange(batch_size).unsqueeze(-1)
        perm = torch.rand(batch_size,self.K).argsort(dim=1)
        true_op = true_op[batch_idx,perm,:]
        true_num1 = true_num1[batch_idx,perm,:]
        true_num2 = true_num2[batch_idx,perm,:]

        return true_op.to(self.device), true_num1.to(self.device), true_num2.to(self.device)
    
    def calculate(self, idx, true_labels, op, num1, num2, batch_size):
        true_op, true_num1, true_num2 = self.prepare_labels(idx, true_labels, num1.shape)
        
        op, num1, num2 = self.__hungarian_algorithm(true_op, true_num1, true_num2, op, num1, num2, batch_size) # getting optimal permutation
        
        loss = self.__final_loss(true_op, true_num1, true_num2, op, num1, num2)
        correct, op_correct, num_ops, num1_correct, num2_correct, num_nums = self.__correct(true_op, true_num1, true_num2, op, num1, num2)
        
        return loss, correct, op_correct, num_ops, num1_correct, num2_correct, num_nums

In [9]:
class FirstLayer(torch.nn.Module):
    def __init__(self, config):
        super(FirstLayer, self).__init__()
        self.device = config.device
        self.encoder = Encoder(config)
        self.decoder = Decoder(config)
        self.loss = Loss(config)
        self.query = config.query
        
        self.opt = config.opt(self.parameters(), lr=config.lr)
    
    def forward(self, input_ids, attention_mask, const_embed, idx, true_labels):
        batch_size = input_ids.shape[0]
        
        # encoder
        prob_embed, num_embed = self.encoder(input_ids, attention_mask, const_embed)
                
        # decoder
        query = self.query[None,:,:].repeat(batch_size,1,1)
        query, op, num1, num2 = self.decoder(query, prob_embed, num_embed, idx)
        
        # loss
        loss, correct, op_correct, num_ops, num1_correct, num2_correct, num_nums = self.loss.calculate(idx, true_labels, op, num1, num2, batch_size)
        
        return query, loss, correct, op_correct, num_ops, num1_correct, num2_correct, num_nums, op, num1, num2

In [10]:
class FullModel(torch.nn.Module):
    def __init__(self, config):
        super(FullModel, self).__init__()
        self.device = config.device
        self.firstlayer = FirstLayer(config)
        self.grad_norm_clip = config.grad_norm_clip
        
    def __backpropagate(self, loss, layer):
        layer.opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(layer.parameters(), self.grad_norm_clip)
        layer.opt.step()
        
    def train(self, dataloader, epoch, num_batches=None):
        total_loss = 0
        total_batches = num_batches if num_batches else len(dataloader)
        total_correct = 0
        total_examples = 0
        total_op_correct = 0
        total_num1_correct = 0
        total_num2_correct = 0
        total_nums = 0
        total_ops = 0
        with tqdm(dataloader, total=total_batches) as progress_bar:
            for item in progress_bar:
                progress_bar.set_description(f'Epoch {epoch}')
                input_ids, attention_mask, const_embed, idx, true_labels, _ = item
                
                # forward pass
                query, loss, correct, op_correct, num_ops, num1_correct, num2_correct, num_nums,_,_,_ = self.forward(input_ids, attention_mask, const_embed, idx, true_labels)
                
                # backpropagating the loss
                self.__backpropagate(loss, self.firstlayer)
                
                total_loss += loss.detach().item()
                total_correct += correct.detach().item()
                total_examples += input_ids.shape[0]
                total_op_correct += op_correct.detach().item()
                total_num1_correct += num1_correct.detach().item()
                total_num2_correct += num2_correct.detach().item()
                total_ops += num_ops
                total_nums += num_nums
                
                progress_bar.set_postfix(loss=loss.detach().item())                
                progress_bar.update(1)
                
                if num_batches:
                    num_batches -= 1
                if num_batches == 0:
                    progress_bar.close()
                    break
        return total_loss/total_batches, total_correct/total_examples, total_op_correct/total_ops, total_num1_correct/total_nums, total_num2_correct/total_nums
    
    def val(self, dataloader, num_batches=None):
        total_loss = 0
        total_batches = num_batches if num_batches else len(dataloader)
        total_correct = 0
        total_examples = 0
        total_op_correct = 0
        total_num1_correct = 0
        total_num2_correct = 0
        total_nums = 0
        total_ops = 0
        with torch.no_grad():
            with tqdm(dataloader, total=total_batches) as progress_bar:
                for item in progress_bar:
                    progress_bar.set_description(f'Validation')
                    input_ids, attention_mask, const_embed, idx, true_labels, _ = item

                    # forward pass
                    query, loss, correct, op_correct, num_ops, num1_correct, num2_correct, num_nums,_,_,_ = self.forward(input_ids, attention_mask, const_embed, idx, true_labels)

                    total_loss += loss.detach().item()
                    total_correct += correct.detach().item()
                    total_examples += input_ids.shape[0]
                    total_op_correct += op_correct.detach().item()
                    total_num1_correct += num1_correct.detach().item()
                    total_num2_correct += num2_correct.detach().item()
                    total_ops += num_ops
                    total_nums += num_nums
           
                    progress_bar.update(1)

                    if num_batches:
                        num_batches -= 1
                    if num_batches == 0:
                        progress_bar.close()
                        break
        return total_loss/total_batches, total_correct/total_examples, total_op_correct/total_ops, total_num1_correct/total_nums, total_num2_correct/total_nums
            
    def fit(self, epochs, num_batches=None):
        for epoch in range(epochs):
            best_loss = math.inf
            
            train_dataloader = DataLoader(MathQA(config, 'train'), batch_size=config.batch_size, shuffle=True)
            val_dataloader = DataLoader(MathQA(config, 'validation'), batch_size=config.batch_size, shuffle=True)
            train_loss, train_acc, train_op_acc, train_num1_acc, train_num2_acc = self.train(train_dataloader, epoch+1, num_batches)
            val_loss, val_acc, val_op_acc, val_num1_acc, val_num2_acc = self.val(val_dataloader, num_batches)
            
            df = pd.DataFrame({'Train':[train_loss, train_acc, train_op_acc, train_num1_acc, train_num2_acc], 
                               'Validation':[val_loss, val_acc, val_op_acc, val_num1_acc, val_num2_acc]},
                       index = ['Loss','Accuracy','Op Accuracy','Num1 Accuracy','Num2 Accuracy'])
            
            display(df)
            
            if val_loss < best_loss:
                best_loss = val_loss
                model_path = 'D:\jupyter_notebooks\Mathsage\Connor\models\mathsage_step1'
                torch.save(self.state_dict(), model_path)
            
    def forward(self, input_ids, attention_mask, const_embed, idx, true_labels):
        query, loss, correct, op_correct, num_ops, num1_correct, num2_correct, num_nums, op, num1, num2 = self.firstlayer(input_ids, attention_mask, const_embed, idx, true_labels)
        return query, loss, correct, op_correct, num_ops, num1_correct, num2_correct, num_nums, op, num1, num2

In [11]:
class Config:
    
    def __init__(self, data, reload=False):
        self.device = torch.device('cuda:0')
        self.model_name = r'D:\jupyter_notebooks\Mathsage\Connor\models\distilroberta-base-encoder-mathqa'
        self.tokenizer_name = 'distilroberta-base'
        self.seed = 3
        self.batch_size = 8
        
        # model params
        self.K = 6
        self.max_layers = 8
        self.nhead = 6
        self.nlayer = 6
        self.embedding_size = 768
        self.num_tokens = 392
        
        # training params
        self.lr = 2e-5
        self.opt = torch.optim.AdamW
        self.grad_norm_clip = 1.0
        
        # Loading preprocessed data
        if reload:
            p = Preprocess(data, self.K, self.embedding_size, self.num_tokens, self.model_name, self.tokenizer_name)
            util.save_obj(f'{OBJ_DIR}preprocess.pickle', p)
        else:
            p = util.load_obj(f'{OBJ_DIR}preprocess.pickle')
                
        self.num = p.num
        self.const = p.const
        self.combined = p.combined
        self.query = p.query
        self.op = p.op
        self.text = p.text
        self.labels = p.labels
        self.encoder = p.encoder
        self.tokenizer = p.tokenizer
        self.tokenized = p.tokenized

In [12]:
# try:
#     data = util.load_data()
#     config = Config(data, reload=False)
#     util.set_seed(config.seed)
#     model = FullModel(config)
#     model.to(config.device)
#     model.fit(epochs = 5)
# finally:
#     #del model
#     torch.cuda.empty_cache()

In [13]:
try:
    data = util.load_data()
    config = Config(data, reload=False)
    util.set_seed(config.seed)
    model = FullModel(config)
    model.to(config.device)
    model.load_state_dict(torch.load('D:\jupyter_notebooks\Mathsage\Connor\models\mathsage_step1'))
    train_dataloader = DataLoader(MathQA(config, 'train'), batch_size=config.batch_size, shuffle=True)
    val_dataloader = DataLoader(MathQA(config, 'validation'), batch_size=config.batch_size, shuffle=True)
    test_dataloader = DataLoader(MathQA(config, 'test'), batch_size=config.batch_size, shuffle=True)
    train_loss, train_acc, train_op_acc, train_num1_acc, train_num2_acc = model.val(train_dataloader)
    val_loss, val_acc, val_op_acc, val_num1_acc, val_num2_acc = model.val(val_dataloader)
    test_loss, test_acc, test_op_acc, test_num1_acc, test_num2_acc = model.val(test_dataloader)
    df = pd.DataFrame({'Train':[train_loss, train_acc, train_op_acc, train_num1_acc, train_num2_acc], 
                        'Validation':[val_loss, val_acc, val_op_acc, val_num1_acc, val_num2_acc],
                        'Test':[test_loss, test_acc, test_op_acc, test_num1_acc, test_num2_acc]},
                       index = ['Loss','Accuracy','Op Accuracy','Num1 Accuracy','Num2 Accuracy'])  
    display(df)
finally:
    #del model
    torch.cuda.empty_cache()

  0%|          | 0/2277 [00:00<?, ?it/s]

  0%|          | 0/339 [00:00<?, ?it/s]

  0%|          | 0/225 [00:00<?, ?it/s]

Unnamed: 0,Train,Validation,Test
Loss,1.253186,4.139475,3.978741
Accuracy,0.782322,0.662731,0.660178
Op Accuracy,0.963995,0.923555,0.922506
Num1 Accuracy,0.959682,0.878148,0.887967
Num2 Accuracy,0.951088,0.876333,0.873444


In [15]:
val_dataloader = DataLoader(MathQA(config, 'validation'), batch_size=config.batch_size, shuffle=False)
input_ids, attention_mask, const_embed, idx, true_labels, literals = next(iter(val_dataloader))

In [75]:
def get_eq(op, num1, num2, literals):
    valid = op.argmax(-1)!=op2id['None']
    pred_op = op.argmax(-1)[valid]
    pred_num1 = num1.argmax(-1)[valid]
    pred_num2 = num2.argmax(-1)[valid]
    eq_idx = np.concatenate([[i]*x for i,x in enumerate(valid.sum(axis=-1))])
    eq = [f'{num1} {op} {num2}' for num1, op, num2 in zip(literals[pred_num1.cpu()], np.array(list(id2op.values()))[pred_op.cpu().numpy()], literals[pred_num2.cpu()])]
    return eq, eq_idx

def get_all_eq(name):
    dataloader = DataLoader(MathQA(config, name), batch_size=config.batch_size, shuffle=False)
    eq = []
    eq_idx = []
    with torch.no_grad():
        for i,batch in enumerate(dataloader):
            input_ids, attention_mask, const_embed, idx, true_labels, literals = batch
            query, loss, correct, op_correct, num_ops, num1_correct, num2_correct, num_nums, op, num1, num2 = model(input_ids, attention_mask, const_embed, idx, true_labels)
            x1, x2 = get_eq(op, num1, num2, literals)
            eq.extend(x1)
            eq_idx.extend(list((x2+config.batch_size*i).astype(int)))
    return eq, eq_idx

subexpressions = {name:get_all_eq(name) for name in SET_NAMES}

In [78]:
util.save_obj(f'{OBJ_DIR}subexp.pickle',subexpressions)