# MathQA Decoder

TODO:
- Maybe increase problem embedding by a factor of 3?
- Change masked num embeddings function so it usese predicted constants from constants.pickle
- Consider doing the same for the operators
- Finish forward propogation (test a pass)
- Write custom loss function
- test 1 full pass of the model

#### Imports

In [1]:
from enum import Enum
import os
import anytree
from anytree import RenderTree
from anytree.importer import DictImporter
import pandas as pd
from itertools import permutations
import seaborn as sns
import math
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import AutoTokenizer, AutoModel, TrainingArguments, Trainer, AutoModelForMaskedLM, DataCollatorForLanguageModeling
from sklearn.metrics import f1_score, accuracy_score
from datasets import Dataset
from peft import LoraConfig, get_peft_model, TaskType, PeftModel
from sklearn.utils.class_weight import compute_class_weight
import re
import pickle
from copy import deepcopy

#### Constants

In [3]:
K = 6
MAX_LAYERS = 8
MAX_TOKENS = 392
EMBEDDING_SIZE = 768

DATA_PATH = './dataset/'
SET_NAMES = ['train', 'validation', 'test']
ENCODER_MODEL = 'distilroberta-base' # A more optimized version of roberta obtaining 95% of its performance
DEVICE = 'cuda:0'
NUM_MASK = '<num>'
WORKING_DIR = 'TEMP/'

OBJ_DIR = 'pickle/'


class Op(Enum):
    ADD = '+'
    SUB = '-'
    MULT = '*'
    DIV = '/'
    POW = '^'
    
class Const(Enum):
    CONST_NEG_1 = 'const_neg_1' # I added this
    CONST_0_25 = 'const_0_25'
    CONST_0_2778 = 'const_0_2778'
    CONST_0_33 = 'const_0_33'
    CONST_0_3937 = 'const_0_3937'
    CONST_1 = 'const_1'
    CONST_1_6 = 'const_1_6'
    CONST_2 = 'const_2'
    CONST_3 = 'const_3'
    CONST_PI = 'const_pi'
    CONST_3_6 = 'const_3_6'
    CONST_4 = 'const_4'
    CONST_5 = 'const_5'
    CONST_6 = 'const_6'
    CONST_10 = 'const_10'
    CONST_12 = 'const_12'
    CONST_26 = 'const_26'
    CONST_52 = 'const_52'
    CONST_60 = 'const_60'
    CONST_100 = 'const_100'
    CONST_180 = 'const_180'
    CONST_360 = 'const_360'
    CONST_1000 = 'const_1000'
    CONST_3600 = 'const_3600'

values = [-1, 0.25, 0.2778, 0.33, 0.3937, 1, 1.6, 2, 3, math.pi, 3.6, 4, 5, 6, 10, 12, 26, 52, 60, 100, 180, 360, 1000, 3600]
const2val = {k:v for k,v in zip(Const._value2member_map_.keys(), values)}    

op2id = {k:v for k,v in zip(Op._value2member_map_.keys(), range(len(Op._value2member_map_)))}
op2id['None'] = 5
const2id = {k:v for k,v in zip(Const._value2member_map_.keys(), range(len(Const._value2member_map_)))}

NameError: name 'Enum' is not defined

## Loading the data

In [3]:
data = {name:pd.read_csv(f'{DATA_PATH}{name}.csv') for name in SET_NAMES}

## Embedding Preparation

Converting number embeddings to masked tensors (to ensure they are all homogeneous) and adding in constants.

In [4]:
# def create_masked_embeddings(name):
#     num_embed = embeddings[name]['num']
#     const_embed = embeddings[name]['num']
#     mapping, nums = embeddings[name]['num_mapping']
#     const = [[num for num in eval(x) if num in const2val] for x in data[name]['nums']]
#     max_nums = (np.bincount(mapping)+np.array(list(map(len, const)))).max()

#     result = () 
#     for idx in range(len(const)):
#         # Getting number embeddings
#         nums = num_embed[mapping==idx]

#         # Adding constant embeddings
#         c = tuple([const_embed[const2id[x]] for x in const[idx]])
#         if len(c) > 0:
#             c = torch.stack(c)
#             nums = torch.cat((nums, c), dim=0)
        
#           # non masked code
# #         if result is None:
# #             result = nums
# #         else:
# #             result = torch.cat((result, nums), dim=0)
        
#         # Padding and creating mask
#         dim1, dim2 = nums.shape
#         mask = torch.full((dim1,dim2), True)
#         nums = F.pad(nums, (0,0,0,max_nums-dim1), 'constant', 0)
#         mask = F.pad(mask, (0,0,0,max_nums-dim1), 'constant', False)
        
#         # Creating masked tensor object
#         mt = torch.masked.masked_tensor(nums, mask)[None,:,:]
#         result = result + (mt,)
        
#     return torch.cat(result, dim=0)
# masked_num_embed = {name:create_masked_embeddings(name) for name in SET_NAMES}
# masked_num_embed['train'].shape

## Constructing the model

In [5]:
class MultiLayerDecoder(torch.nn.Module):
    
    def __init__(self, embedding_size, num_tokens, max_layers, K):
        super(MultiLayerDecoder, self).__init__()
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        
        self.embedding_size = embedding_size
        self.num_tokens = num_tokens
        self.max_layers = max_layers
        self.K = K
        self.decoder_layer = MathQADecoder(embedding_size, num_tokens, K)
    
    # Repeatedly get new sets of equations until a maximum depth is reached or there are no more valid equations (all ops are None)
    def forward(self, x, embeddings):
        
        x = x.to(self.device)
        nums = embeddings['nums'].to(self.device)
        num_idx = embeddings['num_idx'].to(self.device)
        ops = embeddings['ops'].to(self.device)
        problems = embeddings['problem'].to(self.device)

        # while less than max layers and the input has more elements
        idx = 0
        temp = [] # REMOVE THIS
        while idx < self.max_layers and x.numel():
            prev_idx = num_idx
            
            # --------------------------------------------------
            # Step 1 - Encoding to embedding size if not already
            # --------------------------------------------------
            if x.shape[-1] == self.embedding_size*4:
                x = self.decoder_layer.exp_encoder(x) #3 [batch_size, K, 3072] -> [batch_size, K, 768] 

            # -----------------------
            # Step 2 - The main model
            # -----------------------
            prev_sizes = num_idx.bincount()
            x, op, num1, num2, nums, num_idx, problems = self.decoder_layer(x, nums, num_idx, ops, problems)
            new_sizes = num_idx.bincount()
            
            # ---------------------------------------------------------------------------------------------------------------------------------
            # Step 3 - Removing examples with no valid expressions (If the num nums for a problem is unchanged, it had no valid expressions)
            # ---------------------------------------------------------------------------------------------------------------------------------
            changed = prev_sizes!=new_sizes # [batch_size] (not_finished True, finished False)
            num_idx = num_idx[changed[num_idx]] # [num_nums_not_finished]
            nums = nums[num_idx]  # [num_nums_not_finished, 768] 
            problems = problems[changed] # [not_finished, num_tokens, 768]
            x = x[changed] # [not_finished, K, 768]

            # print(f'x: {x.shape}')
            # print(f'op: {op.shape}')
            # print(f'num1: {num1.shape}')
            # print(f'num2: {num2.shape}')
            # print(f'nums: {nums.shape}')
            # print(f'num_idx: {num_idx.shape}')
            # print(f'problems: {problems.shape}')
            # print()

            idx += 1
            temp.append((op,num1,num2,prev_idx))

        return temp

        #return op, num1, num2, prev_idx
        

class MathQADecoder(torch.nn.Module):
    
    def __init__(self, embedding_size, num_tokens, K): 
        super(MathQADecoder, self).__init__()
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        
        self.embedding_size = embedding_size
        self.num_tokens = num_tokens
        self.K = K
        
        # decreasing dimensionality to match embedding size: <op, num, num, num*num> = 3072 -> 768
        self.exp_encoder = torch.nn.Sequential(
            torch.nn.Linear(embedding_size*4, embedding_size),
            torch.nn.ReLU(),
        )
        
        # converting back into new <op, num, num> for loss calculation
        self.exp_decoder = torch.nn.Sequential(
            torch.nn.Linear(embedding_size, embedding_size*3),
            torch.nn.ReLU(),
        )
        
        # Mixing in expression information to the problem encoding
        self.prob_encoder = torch.nn.Sequential(
            torch.nn.Linear(num_tokens+K, num_tokens),
            torch.nn.ReLU(),
        )
        
        # standard transformer decoder
        # we choose K heads for K generated expressions 
        # (with the hope that each head will get different information for each K expression)
        self.transformer_decoder = torch.nn.TransformerDecoderLayer(embedding_size, nhead=K, batch_first=True)
        
        # softmax for the operations
        self.op_softmax = torch.nn.Softmax(dim=2)
        
        # softmax for the numbers
        self.num_softmax = torch.nn.Softmax(dim=0)
        
    def __apply_to_nums(self, f, nums, num_idx):
        return torch.cat(tuple([f(nums[num_idx==x]) for x in range(nums.shape[0])]), dim=0)
      
    # x: [batch_size, K, 768]
    # nums: [num_nums, 768]
    # num_idx: [num_nums,]
    # ops: [num_ops, 768]
    # problems: [batch_size, num_tokens, 768]
    def forward(self, x, nums, num_idx, ops, problems):    
        batch_size = x.shape[0]
        num_ops = ops.shape[0]
        num_nums = nums.shape[0]

        # ----------------------------
        # Step 1 - transformer decoder
        # ----------------------------
        x = self.transformer_decoder(x, problems) # [batch_size, K, 768] -> [batch_size, K, 768] (problems is [batch_size, num_tokens, 768])

        # -------------------------------------------------------------------------------
        # Step 2 - decoding the output into three embeddings of size 768 (op, num1, num2)
        # -------------------------------------------------------------------------------
        x = self.exp_decoder(x) # [batch_size, K, 768] -> [batch_size, K, 2304]
        operation, x1, x2 = torch.split(x, self.embedding_size, dim=2) # [batch_size, K, 2304] -> [batch_size, K, 768] for each

        # -----------------------------------------------------------------------------------------------
        # Step 3 and 4 - Finding the softmax for the similarity between the predicted and true embeddings
        # -----------------------------------------------------------------------------------------------
        # making sure params have correct dimension
        nums_expanded = nums[:,None,:].expand(-1,self.K,-1) # [number_of_nums, 768] -> [number_of_nums, K, 768]
        ops = ops[None,None,:,:].repeat(batch_size,self.K,1,1) # [num_ops, 768] -> [batch_size, K, num_ops, 768]

        # dot product - calculating the similarity between each op/num prediction and its true embedding
        num1 = (x1[num_idx]*nums_expanded).sum(dim=2) # [number_of_nums, K]
        num2 = (x2[num_idx]*nums_expanded).sum(dim=2) # [number_of_nums, K]        
        op = operation[:,:,None,:].expand(-1,-1,num_ops,-1) # [batch_size, K, 768] -> # [batch_size, K, num_ops, 768]
        op = (op*ops).sum(dim=3) # [batch_size, K, num_ops, 768] -> [batch_size, K, num_ops]

        # softmax
        op = self.op_softmax(op) # [batch_size, K, num_ops] (ie op[1,2] would be the operator prediction probabilities for problem2, query3)
        num1 = self.__apply_to_nums(self.num_softmax, num1, num_idx) # [num_nums, K]
        num2 = self.__apply_to_nums(self.num_softmax, num2, num_idx) # [num_nums, K]

        # ----------------------------------------------------
        # Step 5 - creating embedding for the found expression  
        # ----------------------------------------------------
        x = self.exp_encoder(torch.cat((operation,x1,x2,x1*x2), dim=2)) # [batch_size, K, 3072] -> [batch_size, K, 768]  

        # -------------------------------------------------------------------------------------------
        # Step 6 - finding valid, adding found expression to problem embeddings and number embeddings
        # -------------------------------------------------------------------------------------------
        # Getting predicted ops
        mask = torch.argmax(op, dim=2)!=op2id['None'] # [batch_size, K]

        # Getting valid embeddings along with a problem index
        valid = x[mask] # [num_valid, 768]
        problem_idx = torch.arange(mask.shape[0]).to(self.device).repeat_interleave(num_ops)[mask.flatten()] # [num_valid, 768]

        # Appending to num_embeddings
        nums = torch.cat((nums, valid), dim=0) # [num_nums+num_valid, 768]
        num_idx = torch.cat((num_idx, problem_idx), dim=0) # [num_nums+num_valid,]

        # Updating problem embeddings
        temp = torch.clone(x)
        temp[~mask] = 0 # [batch_size, K, 768]
        problems = torch.cat((problems,temp),dim=1) # [batch_size, num_tokens+K, 768]
        problems = self.prob_encoder(problems.permute(0,2,1)) # [batch_size, 768, num_tokens]
        problems = problems.permute(0,2,1) # [batch_size, num_tokens, 768]

        # -----------------------
        # Returning final results
        # -----------------------
        return x, op, num1, num2, nums, num_idx, problems
    
    """  
        # Needed for loss calculation
        # pretty sure this might be wrong now
    
        # getting correct op/nums idx
        op = torch.argmax(op, dim=2) # [batch_size, K, num_ops] -> [batch_size, K]
        num1 = torch.argmax(op, dim=2) # [batch_size, K, num_nums] -> [batch_size, K]
        num2 = torch.argmax(op, dim=2) # [batch_size, K, num_nums] -> [batch_size, K]
        
        # getting correct op/nums embeddings
        temp = torch.arange(temp.size(0)).repeat_interleave(K).reshape(batch_size, K) # [batch_size, K]
        op = ops[temp, op] # [batch_size, K] -> [batch_size, K, 768]
        num1 = nums[temp, op] # [batch_size, K] -> [batch_size, K, 768]
        num2 = nums[temp, op] # [batch_size, K] -> [batch_size, K, 768]
    
        return x
    """

In [6]:
# x: [batch_size, K, 768]
# nums: [num_nums, 768]
# num_idx: [num_nums,]
# ops: [num_ops, 768]
# problems: [batch_size, num_tokens, 768]
try:
    with open(f'{OBJ_DIR}embeddings/train/batch0.pickle', 'rb') as f:
        embeddings = pickle.load(f)
    x = torch.rand(8,6,768)
    model = MultiLayerDecoder(768, 392, 8, 6)
    model.to('cuda:0')
    result = model(x, embeddings)
finally:
    del x
    #del embeddings
    del model
    torch.cuda.empty_cache()

In [16]:
op,num1,num2,num_idx=result[0]
op.argmax(dim=2)
torch.cat([num1[num_idx==x].argmax(dim=-1) for x in range(8)])

tensor([1, 3, 4, 1, 4, 0, 3, 1, 1, 2, 4, 5, 3, 5, 5, 4, 0, 1, 5, 4, 4, 1, 3, 1,
        2], device='cuda:0')

In [8]:
embeddings['nums'][embeddings['num_idx']==0]

tensor([[ 0.3330,  1.9565, -1.2056,  ..., -0.8565,  0.8773,  1.1536],
        [ 0.6438,  1.8634, -0.7155,  ...,  0.6324,  1.4042,  1.3915],
        [ 0.2300,  1.3455, -0.7449,  ...,  2.3113,  1.2338,  2.6616]])

In [9]:
with open(f'{OBJ_DIR}embeddings/train/batch0.pickle', 'rb') as f:
    embeddings = pickle.load(f)

In [10]:
embeddings['mask'].shape

torch.Size([8, 392])

In [11]:
op.argmax(dim=-1)

tensor([[4, 0, 3, 3, 4, 4],
        [4, 3, 3, 3, 3, 0],
        [0, 4, 0, 4, 4, 4],
        [3, 0, 4, 3, 4, 4],
        [0, 4, 3, 3, 3, 4],
        [3, 4, 0, 3, 0, 0],
        [3, 3, 0, 4, 0, 4],
        [3, 4, 3, 0, 4, 3]], device='cuda:0')

In [12]:
op_ids = np.zeros((MAX_LAYERS, len(data['train']),K), dtype=int)
i = 0
for j, step in enumerate(data['train']['incremental'][i].split(' ; ')):
    for k, x in enumerate(eval(step)):
        if x is not None:
            x1, op, x2 = x.split()
            op_ids[i,j,k] = op2id[op]
        else:
            op_ids[i,j,k] = op2id['None']
op_ids

array([[[2, 2, 5, 5, 5, 5],
        [3, 5, 5, 5, 5, 5],
        [2, 5, 5, 5, 5, 5],
        ...,
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0]],

       [[0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        ...,
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0]],

       [[0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        ...,
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0]],

       ...,

       [[0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        ...,
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0]],

       [[0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        ...,
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0]],

       [[0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0

In [13]:
count = 0
for x in data['train']['incremental'].dropna():
    temp = len(x.split(' ; '))
    if temp > 8:
        count += 1
count

0