# MathQA Encoder

#### Imports

In [2]:
from enum import Enum
import os
import anytree
import pandas as pd
from itertools import permutations
import seaborn as sns
import math
import numpy as np
import torch
from torch import nn
from transformers import AutoTokenizer, AutoModel, TrainingArguments, Trainer, AutoModelForMaskedLM, DataCollatorForLanguageModeling
from sklearn.metrics import f1_score, accuracy_score
from datasets import Dataset
from peft import LoraConfig, get_peft_model, TaskType, PeftModel
from sklearn.utils.class_weight import compute_class_weight
import re
import pickle

#### Constants

In [3]:
K=6  # Number of expression slots per problem; label tensors are padded to K expressions
DATA_PATH = './dataset/'  # Holds one <split>.csv per entry of SET_NAMES
SET_NAMES = ['train', 'validation', 'test']
ENCODER_MODEL = 'distilroberta-base' # Distilled RoBERTa; per the author it keeps ~95% of RoBERTa's performance
MAX_TOKENS = 392  # Every problem is padded/truncated to this many tokens
DEVICE = 'cuda:0'  # Device the encoder runs on
NUM_MASK = '<num>'  # Special token that replaces every number in a problem text
WORKING_DIR = 'TEMP/'  # Trainer output dir for MLM fine-tuning (only used by the commented-out training cell)
FINAL_DIR = 'pickle/'  # Root for pickled outputs (batch embeddings, constant predictions)
MODEL_DIR = 'models/'  # Where the fine-tuned encoder is loaded from
BATCH_SIZE = 8  # Problems per embedding batch; one pickle file is written per batch

class Op(Enum):
    """Arithmetic operators an expression can use; each value is the operator symbol."""
    ADD = '+'
    SUB = '-'
    MULT = '*'
    DIV = '/'
    POW = '^'


class Const(Enum):
    """Named constants that can appear in a problem's ``nums`` column.

    Each member's value is the token used in the dataset. The numeric value of each
    constant is given by ``values`` / ``const2val`` below, in declaration order.
    """
    CONST_NEG_1 = 'const_neg_1'  # Added by the notebook author
    CONST_0_25 = 'const_0_25'
    CONST_0_2778 = 'const_0_2778'
    CONST_0_33 = 'const_0_33'
    CONST_0_3937 = 'const_0_3937'
    CONST_1 = 'const_1'
    CONST_1_6 = 'const_1_6'
    CONST_2 = 'const_2'
    CONST_3 = 'const_3'
    CONST_PI = 'const_pi'
    CONST_3_6 = 'const_3_6'
    CONST_4 = 'const_4'
    CONST_5 = 'const_5'
    CONST_6 = 'const_6'
    CONST_10 = 'const_10'
    CONST_12 = 'const_12'
    CONST_26 = 'const_26'
    CONST_52 = 'const_52'
    CONST_60 = 'const_60'
    CONST_100 = 'const_100'
    CONST_180 = 'const_180'
    CONST_360 = 'const_360'
    CONST_1000 = 'const_1000'
    CONST_3600 = 'const_3600'


# Numeric value of every Const member, listed in member declaration order.
values = [-1, 0.25, 0.2778, 0.33, 0.3937, 1, 1.6, 2, 3, math.pi, 3.6, 4, 5, 6, 10, 12, 26, 52, 60, 100, 180, 360, 1000, 3600]
if len(values) != len(Const):
    # zip() would silently drop the surplus and leave constants without a value.
    raise ValueError(f'{len(values)} constant values for {len(Const)} Const members')

# Constant token -> numeric value (iterating an Enum yields members in declaration order).
const2val = {member.value: value for member, value in zip(Const, values)}

# Token -> contiguous integer id, in declaration order (used for one-hot indexing).
op2id = {member.value: i for i, member in enumerate(Op)}
const2id = {member.value: i for i, member in enumerate(Const)}

## Loading the data

Read each split's CSV file into a dictionary of DataFrames keyed by split name.

In [4]:
# Load every split ('train', 'validation', 'test') from DATA_PATH; keys are the split names.
data = {name:pd.read_csv(f'{DATA_PATH}{name}.csv') for name in SET_NAMES}

Convert the operations used in each problem into a multi-label one-hot (multi-hot) encoding.

In [5]:
def onehot_ops(data):
    """Multi-hot encode the operators used by each problem.

    Args:
        data: DataFrame-like object whose ``ops`` column holds, per problem, the
            string repr of a collection of operator symbols (parsed with ``eval``).

    Returns:
        ``np.ndarray`` with one row of length ``len(op2id)`` per problem; entry
        ``[p, j]`` is 1 when the operator with id ``j`` occurs in problem ``p``.
    """
    width = len(op2id)
    rows = []
    for raw in data.ops:
        # NOTE(review): eval() executes whatever the CSV cell contains;
        # ast.literal_eval would be safer if the column only holds literals.
        ids = [op2id[symbol] for symbol in eval(raw)]
        row = np.zeros(width)
        row[ids] = 1
        rows.append(row)
    return np.array(rows)
        
#onehot_ops(data['train'])

Sort the numbers of each problem in increasing order (this cell also contains the helpers that extract and mask the numbers in the problem text).

In [6]:
def max_num(nums):
    """Return the largest numeric value among *nums*.

    Tokens that name a constant (keys of ``const2val``) are replaced by that
    constant's value; every other token is converted with ``float``.
    """
    return max(
        float(const2val[token]) if token in const2val else float(token)
        for token in nums
    )

def remove_const(data):
    """Return, for each problem, the set of its non-constant numbers as floats.

    ``data.nums`` entries are string reprs of number collections (parsed with
    ``eval``); tokens that are constant names (keys of ``const2val``) are dropped.
    """
    return [
        {float(token) for token in eval(entry) if token not in const2val}
        for entry in data.nums
    ]

# Gets the numbers listed in a problem
# Once found, numbers are masked using a number mask
def get_nums_from_problem(data, convert_to_float=False):
    nums = []
    problems = []
    for problem in data.problem:
        num = re.compile('([+-]?((\d+(\.\d*)?)|(\.\d+)))')
        big = re.compile(r'(-?\d{1,3}(,\d{3})+(\.\d*)?)')
        
        big_results = re.finditer(big, problem)
        problem = re.sub(big, NUM_MASK, problem)        
        num_results = re.finditer(num, problem)
        problem = re.sub(num, NUM_MASK, problem)
        
        # Getting the combined numbers in order of occurence
        combined = [x for x in num_results]
        combined.extend([x for x in big_results])
        combined = sorted(combined, key=lambda x: x.start(0))
        
        if convert_to_float:
            combined = [float(x.group(0).replace(',','')) for x in combined]
        else:
            combined = [x.group(0) for x in combined]
        
        nums.append(combined)
        problems.append(problem)
    return nums, problems

def sort_nums(data):
    """Sort each problem's numbers in increasing numeric order.

    Returns:
        ``(nums_sorted, nums_no_const_sorted)`` where
        ``nums_no_const_sorted[i]`` is ``data.nums_no_const[i]`` (parsed with ``eval``)
        sorted by ``float``, and ``nums_sorted[i]`` is ``data.nums[i]`` sorted
        numerically with a constant's key being its ``const2val`` value plus the
        largest value in that problem (``max_num``).

    NOTE(review): the offset places most constants after the literal numbers, but
    ``const_neg_1`` (value -1) gets key ``max - 1`` and can sort before the largest
    literal — confirm this is intended.
    """
    no_const = [sorted(eval(entry), key=float) for entry in data.nums_no_const]

    with_const = []
    for entry in data.nums:
        tokens = list(eval(entry))
        offset = max_num(tokens)

        def sort_key(token, offset=offset):
            if token in const2val:
                return float(const2val[token]) + offset
            return float(token)

        with_const.append(sorted(tokens, key=sort_key))
    return with_const, no_const

#sort_nums(data['train'])

Here I check whether the numbers used in each equation can be found in the problem description with simple regexes. This works extremely well: there is no training example where the expected numbers are not a subset of the obtained numbers. The check excludes constants, i.e. values that are not expected to appear in the problem description (such as pi, or the 2 in r^2).

In [6]:
# Numbers used by each training equation (constants excluded), as sets of floats.
expected = remove_const(data['train'])
# Numbers the regexes find in each training problem text.
obtained,_ = get_nums_from_problem(data['train'], convert_to_float=True)
obtained = [set(x) for x in obtained]

# Print every problem whose equation numbers are not all found in its text.
idx = 0
for x, y in zip(expected, obtained):
    if not (x <= y):  # set inclusion: expected ⊆ obtained
        print('------------------')
        print(data['train']['problem'][idx])
        print(f'Expected: {x}')
        print(f'Obtained: {y}')
        print('------------------')
    idx += 1

In [7]:
# Number of training problems per category (shown as the cell output).
data['train']['category'].value_counts()

category
general        7187
physics        4885
gain           3520
geometry       1410
other          1069
probability     144
Name: count, dtype: int64

## Encoder

In this step, we use RoBERTa (the distilled DistilRoBERTa variant) to get contextualized embeddings for each math problem.

First, the problem texts must be tokenized into input ids. Every number in a problem is replaced by a number mask token, so that the specific numeric values do not influence the problem's representation.

In [8]:
# Location of the encoder fine-tuned on MathQA with masked language modeling.
model_path = f'{MODEL_DIR}{ENCODER_MODEL}-encoder-mathqa'
#model = AutoModelForMaskedLM.from_pretrained(ENCODER_MODEL) # Used for fine tuning only
# output_hidden_states=True so every layer's output is returned (read in get_embeddings).
model = AutoModel.from_pretrained(model_path, output_hidden_states=True) # Fine tuned model used for getting the contextualized embeddings

Some weights of RobertaModel were not initialized from the model checkpoint at models/distilroberta-base-encoder-mathqa and are newly initialized: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [9]:
# NOTE(review): the tokenizer comes from the base checkpoint, not model_path, so the
# NUM_MASK id relies on re-adding the token below exactly as during fine-tuning.
tokenizer = AutoTokenizer.from_pretrained(ENCODER_MODEL)

# Adding a new token to the model, for masking out numbers.
tokenizer.add_special_tokens({'additional_special_tokens':[NUM_MASK]})
# Resize the embedding matrix to the tokenizer's new vocabulary size so NUM_MASK has a row.
model.resize_token_embeddings(len(tokenizer))

def tokenize_data(data):
    """Mask the numbers in each problem and tokenize the masked texts.

    Every problem is padded/truncated to exactly ``MAX_TOKENS`` tokens.

    Returns:
        dict with ``input_ids`` (int64) and ``attention_mask`` (int32) tensors,
        each of shape ``(num_problems, MAX_TOKENS)``.
    """
    _, masked_texts = get_nums_from_problem(data)
    encodings = [
        tokenizer(text, padding='max_length', max_length=MAX_TOKENS, truncation=True)
        for text in masked_texts
    ]
    ids = torch.stack([torch.tensor(enc['input_ids']) for enc in encodings]).long()
    attention = torch.stack([torch.tensor(enc['attention_mask']) for enc in encodings]).int()
    return {'input_ids': ids, 'attention_mask': attention}

tokenized = {name:tokenize_data(data[name]) for name in SET_NAMES}

You are resizing the embedding layer without providing a `pad_to_multiple_of` parameter. This means that the new embeding dimension will be 50266. This might induce some performance reduction as *Tensor Cores* will not be available. For more details  about this, or help on choosing the correct value for resizing, refer to this guide: https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc


In [10]:
# A problem whose last position is not padding filled the whole MAX_TOKENS window
# (it was truncated, or fits exactly). Use the tokenizer's pad id instead of
# hard-coding RoBERTa's value 1, so the check keeps working if ENCODER_MODEL changes.
num_full = int((tokenized['train']['input_ids'][:, -1] != tokenizer.pad_token_id).sum())
print(f"Number of problems that exceed {MAX_TOKENS} tokens: {num_full}")

Number of problems that exceed 392 tokens: 0


Next, the encoder model is fine-tuned on MathQA with masked language modeling, similar to how BERT is pretrained. This allows the model to create better contextualized representations for each math problem. Hyperparameters courtesy of https://colab.research.google.com/github/huggingface/notebooks/blob/master/examples/language_modeling.ipynb#scrollTo=QRTpmyCc3l_T

In [11]:
# args = TrainingArguments(
#     f'{WORKING_DIR}{model_path}',
#     evaluation_strategy='epoch',
#     learning_rate=2e-5,
#     weight_decay=0.01,
#     per_device_train_batch_size = 8,
#     per_device_eval_batch_size = 8,
# )

# train = Dataset.from_dict(tokenized['train'])
# val = Dataset.from_dict(tokenized['validation'])
# train.set_format('torch')
# val.set_format('torch')

# trainer = Trainer(
#     model=model,
#     args=args,
#     train_dataset=train,
#     eval_dataset=val,
#     data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15) # using the masked probability from BERT
# )

In [12]:
#trainer.train()

In [13]:
#trainer.save_model(model_path)

This function gets, for each tokenized problem, the token positions of its number mask tokens.

In [14]:
def get_masked_idx(tokenized):
    """Return, per problem, the token positions holding the ``NUM_MASK`` token.

    Args:
        tokenized: dict with an ``input_ids`` tensor with one row per problem
            (as returned by ``tokenize_data``).

    Returns:
        list with one 1-D ``np.ndarray`` of positions per problem, in increasing order.
    """
    mask_id = tokenizer.encode(NUM_MASK, add_special_tokens=False)[0]
    # `row` rather than `id`, which would shadow the builtin.
    return [np.where(row == mask_id)[0] for row in tokenized['input_ids']]

masked_idx = {name: get_masked_idx(tokenized[name]) for name in SET_NAMES}

This function gets the problem indices that each constant is used in

In [15]:
def get_const_problems(data):
    """Map each constant name to the indices of the problems that use it.

    Args:
        data: DataFrame-like object whose ``nums`` column holds the string repr of
            the numbers/constant names of each problem (parsed with ``eval``).

    Returns:
        dict ``{constant name: np.ndarray of problem indices}`` with an entry for every
        ``Const`` member (an empty array when the constant is never used). An index is
        repeated if the parsed collection lists the same constant more than once.
    """
    usage = {const: [] for const in Const._value2member_map_.keys()}
    for problem_idx, raw in enumerate(data.nums):
        for token in eval(raw):
            if token not in const2val:
                continue
            usage[token].append(problem_idx)
    return {const: np.array(indices) for const, indices in usage.items()}

const2idx = {name:get_const_problems(data[name]) for name in SET_NAMES}

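Here the fine-tuned RoBERTa model produces the contextualized embeddings for each problem and its numbers, plus the per-problem CLS embeddings that the constant embeddings are later built from. The original plan was to sum the last four hidden layers (as is done with BERT). The code currently uses only the final hidden layer; the summation is left commented out.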

In [16]:
# Batches a non homogeneous (ragged) array given the number of items per batch
def non_homogeneous_split(arr, num_per_batch):
    """Split *arr* into consecutive chunks of *num_per_batch* items.

    Works for any sliceable sequence, including lists of ragged sub-lists; the
    last chunk holds the remainder.
    """
    batches = []
    for lo in range(0, len(arr), num_per_batch):
        batches.append(arr[lo:lo + num_per_batch])
    return batches

# Batches the const2idx dictionary
def batch_const2idx(const2idx, name, num_per_batch=None):
    """Split the per-constant problem indices of split *name* into embedding batches.

    Args:
        const2idx: dict ``{split: {constant: np.ndarray of problem indices}}``.
        name: split name (key of ``data`` / ``const2idx``).
        num_per_batch: problems per batch; defaults to ``BATCH_SIZE``, the batch size
            ``get_embeddings`` uses for its pickle files.

    Returns:
        list with one dict per batch, ``{constant: np.ndarray of within-batch indices}``.

    The original referenced an undefined ``num_splits``, so it always raised NameError.
    It also hard-coded a batch size of 96 and computed a split size it never used.
    """
    if num_per_batch is None:
        num_per_batch = BATCH_SIZE
    num_splits = math.ceil(len(data[name]) / num_per_batch)
    batched = [{const: [] for const in Const._value2member_map_.keys()} for _ in range(num_splits)]
    for const, problem_idx in const2idx[name].items():
        for batch_num, batch_idx in zip(problem_idx // num_per_batch, problem_idx % num_per_batch):
            batched[batch_num][const].append(batch_idx)
    return [{k: np.array(v) for k, v in batch.items()} for batch in batched]

# Move the encoder to DEVICE ('cuda:0' by default) for embedding extraction
model.to(DEVICE)

def get_embeddings(name):
    """Encode split *name* with the fine-tuned encoder and pickle the results per batch.

    For every batch of ``BATCH_SIZE`` problems this writes
    ``{FINAL_DIR}embeddings/{name}/batch{batch_num}.pickle``, a dict with:

    - ``'problem'``: last-hidden-layer token embeddings of the batch (on DEVICE)
    - ``'mask'``: the batch's attention mask (on DEVICE)
    - ``'nums'``: embeddings at the ``NUM_MASK`` positions of all problems of the
      batch, concatenated (on CPU)
    - ``'num_idx'``: for each row of ``'nums'``, its problem's index within the batch
    - ``'num_literals'``: float64 values of the extracted numbers (on DEVICE); these
      line up with ``'nums'`` assuming every extracted number became exactly one
      ``NUM_MASK`` token — TODO confirm for tokenizer edge cases

    Returns:
        CPU tensor stacking the first-token (CLS) embedding of every problem in the
        split, or ``None`` if the split is empty.
    """
    num_per_batch = BATCH_SIZE
    num_splits = math.ceil(len(tokenized[name]['input_ids']) / num_per_batch)
    batched_ids = torch.split(tokenized[name]['input_ids'], num_per_batch)
    batched_masks = torch.split(tokenized[name]['attention_mask'], num_per_batch)
    batched_idx = non_homogeneous_split(masked_idx[name], num_per_batch)
    batched_literals = non_homogeneous_split(
        get_nums_from_problem(data[name], convert_to_float=True)[0], num_per_batch)

    # Output goes to disk (too large to all fit in memory); create the folder once.
    out_dir = f'{FINAL_DIR}embeddings/{name}'
    os.makedirs(out_dir, exist_ok=True)

    # Collected per batch and concatenated once at the end: concatenating onto a
    # growing tensor every batch re-copies the whole accumulator (quadratic).
    cls_chunks = []

    for batch_num in range(num_splits):
        ids = batched_ids[batch_num].to(DEVICE)
        mask = batched_masks[batch_num].to(DEVICE)
        idx = batched_idx[batch_num]
        literals = torch.tensor(np.concatenate(batched_literals[batch_num])).to(DEVICE)

        with torch.no_grad():
            model_output = model(ids, mask)

        # hidden_states holds the embedding output plus one tensor per transformer
        # layer, each (batch, tokens, hidden). Use the last layer directly. Slicing it
        # out of a stacked tensor of all layers yields a view, and pickling a view
        # serializes the entire underlying storage (every layer).
        # To sum the last four layers instead:
        #   torch.stack(model_output.hidden_states[-4:]).sum(dim=0)
        output = model_output.hidden_states[-1]

        # Embeddings at the NUM_MASK positions of each problem, concatenated over the
        # batch; num_idx records which problem each row belongs to.
        num_chunks = []
        num_idx = []
        for row, positions in enumerate(idx):
            chunk = output[row, positions].to('cpu')
            num_chunks.append(chunk)
            num_idx.extend([row] * chunk.shape[0])
        num_embeddings = torch.cat(num_chunks, dim=0)

        embeddings = {
            'problem': output,
            'mask': mask,
            'nums': num_embeddings,
            'num_idx': torch.tensor(num_idx),
            'num_literals': literals,
        }

        # First-token (CLS) embedding as the sentence-level representation
        # (used for the constant embeddings later).
        cls_chunks.append(output[:, 0, :].to('cpu'))

        with open(f'{out_dir}/batch{batch_num}.pickle', 'wb') as f:
            pickle.dump(embeddings, f)

        # Release GPU memory before the next batch.
        del ids, mask, idx, num_idx, num_embeddings, embeddings, output, model_output
        torch.cuda.empty_cache()
    return torch.cat(cls_chunks, dim=0) if cls_chunks else None

In [17]:
# Per-split CLS embeddings; also writes every split's per-batch embedding pickles.
cls = {name:get_embeddings(name) for name in SET_NAMES}

To get the constant embeddings, we take the average of the problem (CLS) embeddings of all problems that use the constant. This should hopefully give the constants some more context during downstream training. Only the training data is used for the constant embeddings, since you would not know which constants belong to a problem in the test/validation sets. Which constants are attached to each problem comes from the predictions of the constant classifier (even for training, to maintain consistency). Randomly initialized operation and query embeddings are also created here.

In [18]:
# Embedding width of the encoder (768 for distilroberta-base). It is read from the model
# instead of hard-coded so it always matches the stored problem/number embeddings.
hidden_size = model.config.hidden_size

# Random operation embeddings: one row per Op member plus an extra row for the
# padding 'None' operation. len(Op) rather than len(op2id): re-running this cell
# after op2id['None'] has been added must not create an additional row.
op_embeds = torch.nn.init.normal_(torch.empty((len(Op)+1, hidden_size)), mean=0, std=1)
# Initial query vectors, one per expression slot.
query = torch.nn.init.normal_(torch.empty((K, hidden_size)), mean=0, std=1)


def _mean_train_cls(const, problem_idx):
    """Mean training CLS embedding over the problems whose equations use *const*."""
    if len(problem_idx) == 0:
        # An empty (float) index array would fail with an opaque indexing error
        # (or yield NaN); name the offending constant instead.
        raise ValueError(f'Constant {const!r} is not used by any training problem; '
                         'cannot derive its embedding')
    return torch.mean(cls['train'][problem_idx], dim=0)


const_embeds = torch.stack([_mean_train_cls(k, v) for k, v in const2idx['train'].items()])

init = {}
init['op'] = op_embeds
init['query'] = query
init['const'] = const_embeds

with open(f'{FINAL_DIR}embeddings/init.pickle', 'wb') as f:
    pickle.dump(init, f)

In [19]:
batch_size = BATCH_SIZE
# Constant-classifier output. const_pred['pred'][split] is used below as a
# (num_problems, num_constants) 0/1 matrix — assumed from its use, confirm with the producer.
with open(f'{FINAL_DIR}constants.pickle', 'rb') as f:
    const_pred = pickle.load(f)
# Inverse of const2id: integer id -> constant name.
id2const = {v:k for k, v in const2id.items()}

def get_real_const(name):
    """Return, per problem of split *name*, the set of ground-truth constant ids.

    Constant names are read from the ``nums`` column (parsed with ``eval``) and mapped
    through ``const2id``; other tokens are ignored. The result is a NumPy array whose
    elements are Python sets (object dtype).
    """
    per_problem = [
        {const2id[token] for token in eval(raw) if token in const2id}
        for raw in data[name]['nums']
    ]
    return np.array(per_problem)

def get_const_embeddings(name):
    """Append each problem's predicted constants to its batch pickle.

    Every ``batch{n}.pickle`` of split *name* (as written by ``get_embeddings``) is
    rewritten in place. After the problem's numbers, it then also holds one candidate
    per constant predicted for that problem, i.e. every ``(problem, const_id)`` where
    ``const_pred['pred'][name] == 1``:

    - ``'num_literals'``: NumPy string array — numbers as ``str(float(value))``,
      constants by name (the forms ``get_true_num``/``get_true_label`` compare against)
    - ``'num_idx'``: within-batch problem index of each candidate
    - ``'nums'``: number embeddings followed by the matching rows of ``const_embeds``

    NOTE(review): not idempotent — it calls ``.to('cpu')`` on the tensors written by
    ``get_embeddings``, so running it twice on the same pickles fails.
    """
    problem, const_idx = np.where(const_pred['pred'][name] == 1)
    # BATCH_SIZE (not a literal 8) so the batches line up with get_embeddings' files.
    for batch_num, start in enumerate(range(0, len(data[name]), BATCH_SIZE)):
        in_batch = (problem >= start) & (problem < start + BATCH_SIZE)
        p_idx = problem[in_batch]
        c_idx = const_idx[in_batch]

        fname = f'{FINAL_DIR}embeddings/{name}/batch{batch_num}.pickle'
        with open(fname, 'rb') as f:
            embeddings = pickle.load(f)

        num_literals = embeddings['num_literals'].to('cpu')
        num_idx = embeddings['num_idx'].to('cpu')
        num_embed = embeddings['nums'].to('cpu')

        # Build the string literals explicitly. Relying on np.concatenate to promote the
        # floats to strings breaks when a batch has no predicted constants: the result
        # then stays float64 and no later string comparison can match.
        literals = [str(float(x)) for x in num_literals.tolist()]
        literals.extend(id2const[int(x)] for x in c_idx)

        embeddings['num_literals'] = np.array(literals, dtype=str)
        embeddings['num_idx'] = torch.tensor(np.concatenate((num_idx, p_idx % BATCH_SIZE)))
        embeddings['nums'] = torch.cat((num_embed, const_embeds[c_idx]), dim=0)

        with open(fname, 'wb') as f:
            pickle.dump(embeddings, f)
        
# Update every split's pickles; the resulting dict of None values is only the cell's display output.
{name:get_const_embeddings(name) for name in SET_NAMES}

{'train': None, 'validation': None, 'test': None}

In [21]:
batch_size = BATCH_SIZE
# Id of the padding 'None' operation: the first id after the real Op members, which is
# the extra row reserved at the end of op_embeds. Derived from Op rather than a literal 5
# so it stays correct if operators are added.
op2id['None'] = len(Op)

def get_exp(name):
    """Parse the ground-truth expressions of every problem of split *name*.

    Only the first ``' ; '``-separated segment of each ``incremental`` entry is used.
    It is the string repr of a sequence of expression strings like ``'a + b'``;
    ``None`` entries are skipped. Each expression ``'x op y ...'`` becomes the
    tuple ``(op, x, y)``.

    Returns:
        pandas Series with one list of ``(op, num1, num2)`` tuples per problem.
    """
    def convert_to_arr(entry):
        # Only the first segment is kept, so only it is parsed. The original also
        # evaluated, then discarded, every other segment.
        # NOTE(review): eval() on dataset text.
        first_segment = entry.split(' ; ')[0]
        exps = []
        for exp in eval(first_segment):
            if exp is None:
                continue
            parts = exp.split()
            exps.append((parts[1], parts[0], parts[2]))
        return exps

    return data[name]['incremental'].map(convert_to_arr)

def non_homogeneous_split(arr, num_per_batch):
    """Chunk *arr* into consecutive slices of *num_per_batch* items (the last may be shorter).

    Same contract as the earlier definition of this name in the notebook.
    """
    starts = range(0, len(arr), num_per_batch)
    return list(map(lambda lo: arr[lo:lo + num_per_batch], starts))

def get_true_num(num, i, embeddings):
    """Return the embedding of operand *num* of the problem at batch position *i*.

    Args:
        num: operand token from ``get_exp`` (a constant name or a number string).
        i: index of the problem within the batch.
        embeddings: batch dict with ``nums``, ``num_literals`` and ``num_idx``.

    Constants always use their global ``const_embeds`` row, whether or not the
    constant was predicted for this problem. Numbers use the first ``nums`` row of
    problem *i* whose literal equals ``str(float(num))``; IndexError if none matches.
    """
    if num in const2val:
        return const_embeds[const2id[num]]
    else:
        num_embed = embeddings['nums']
        literals = embeddings['num_literals']
        idx = embeddings['num_idx']
        # idx is a torch tensor and literals a NumPy array; the combined mask is
        # converted with .bool() before indexing. [0] takes the first match.
        return num_embed[((idx==i)&(literals==str(float(num)))).bool()][0]
    
def get_true_label(num, i, embeddings):
    """Return a one-hot target over the batch's operand candidates for operand *num*.

    The label has one entry per element of ``embeddings['num_literals']``. It is 1 at
    the first candidate of problem *i* that matches *num*: the constant name, or
    ``str(float(num))`` for numbers. A constant that was not predicted for this problem
    yields an all-zero label; a number with no match raises IndexError.
    """
    num_embed = embeddings['nums']  # unused
    literals = embeddings['num_literals']
    idx = embeddings['num_idx']
    label = torch.zeros(literals.shape)
    if num in const2val:
        location = np.where((idx==i)&(literals==num))[0]
        if len(location) > 0:
            location = location[0]
        else: # if not found, the constant was not predicted from the constant predictor and we just return an empty label
            return label
    else:
        location = np.where((idx==i)&(literals==str(float(num))))[0][0]
    label[location]=1
    return label

def get_true(e, i, embeddings):
    """Build the targets and the input embedding for one ground-truth expression.

    Args:
        e: ``(op, num1, num2)`` tuple as produced by ``get_exp``.
        i: problem index within the batch.
        embeddings: the batch's pickle dict.

    Returns:
        ``(op_label, num1_label, num2_label, embed)``: a one-hot over ``op2id``, the
        two operand labels from ``get_true_label``, and the concatenation
        ``[op, num1, num2, num1 * num2]`` of the corresponding embeddings.
    """
    op, left, right = e
    op_id = op2id[op]
    op_vec = op_embeds[op_id]
    left_vec = get_true_num(left, i, embeddings)
    right_vec = get_true_num(right, i, embeddings)

    op_target = torch.zeros(len(op2id))
    op_target[op_id] = 1
    left_target = get_true_label(left, i, embeddings)
    right_target = get_true_label(right, i, embeddings)

    features = torch.cat((op_vec, left_vec, right_vec, left_vec * right_vec))
    return op_target, left_target, right_target, features

def expand_tensor(curr, new):
    """Append *new* to *curr* along dim 0, treating ``None`` as an empty accumulator."""
    return new if curr is None else torch.cat((curr, new), dim=0)

def pad_op(op_label):
    """Pad an ``(n, num_ops)`` operator one-hot matrix to K rows and add a batch dim.

    Padding rows are one-hot on the ``'None'`` operator; the result has shape
    ``(1, K, num_ops)``.
    """
    filler = torch.zeros(op_label.shape[1])
    filler[op2id['None']] = 1
    padding = filler.unsqueeze(0).repeat(K - op_label.shape[0], 1)
    return torch.cat((op_label, padding), dim=0).unsqueeze(0)

def pad_num(num_label):
    """Pad an ``(n, num_candidates)`` operand-label matrix with all-zero rows to K rows.

    Returns shape ``(1, K, num_candidates)``.
    """
    rows_missing = K - num_label.shape[0]
    zero_rows = torch.zeros(num_label.shape[1]).unsqueeze(0).repeat(rows_missing, 1)
    return torch.cat((num_label, zero_rows), dim=0).unsqueeze(0)

def get_true_batch(exp_batch, embeddings):
    """Build the padded targets and expression embeddings for one batch of problems.

    Args:
        exp_batch: per-problem lists of ``(op, num1, num2)`` tuples (from ``get_exp``).
        embeddings: the batch's pickle dict.

    Returns:
        ``(idx, op_labels, num1_labels, num2_labels, embeds)``. ``idx`` gives the
        problem of every row of ``embeds``; each label tensor has one K-row block per
        problem (padded by ``pad_op`` / ``pad_num``).
    """
    row_owner = []
    embeds = op_labels = num1_labels = num2_labels = None
    for i, exps in enumerate(exp_batch):
        per_exp = [get_true(e, i, embeddings) for e in exps]
        # Transpose the per-expression 4-tuples and stack each component.
        op_label, num1_label, num2_label, embed = (torch.stack(column) for column in zip(*per_exp))

        embeds = expand_tensor(embeds, embed)
        op_labels = expand_tensor(op_labels, pad_op(op_label))
        num1_labels = expand_tensor(num1_labels, pad_num(num1_label))
        num2_labels = expand_tensor(num2_labels, pad_num(num2_label))

        row_owner.extend([i] * embed.shape[0])
    return torch.tensor(row_owner), op_labels, num1_labels, num2_labels, embeds

def create_masked_nums(embeddings, batch_size):
    """Give every problem of the batch its own masked copy of the number embeddings.

    Returns a ``(batch_size, num_nums, embedding_size)`` tensor. Slot ``i`` keeps the
    rows of ``embeddings['nums']`` whose ``num_idx`` equals ``i``; every other row is
    ``-inf``. Returns ``None`` when ``batch_size`` is not positive.
    """
    all_nums = embeddings['nums']
    num_nums, embedding_size = all_nums.shape
    owner = embeddings['num_idx']
    per_problem = []
    for i in range(batch_size):
        masked = torch.full((num_nums, embedding_size), -math.inf)
        keep = owner == i
        masked[keep] = all_nums[keep]
        per_problem.append(masked)
    return torch.stack(per_problem) if per_problem else None

def get_all_true(name):
    """Add permuted ground-truth labels to every batch pickle of split *name*.

    The batch files are processed in numeric order, so file ``n`` lines up with
    expression batch ``n``. For each file, ``get_true_batch`` builds the K-padded
    targets. The K slots of each problem are then shuffled with an independent random
    permutation, so the model cannot learn that padding 'None' expressions come last.
    The results are stored as ``'true_op'``, ``'true_num1'`` and ``'true_num2'``, and
    each file is rewritten in place.
    """
    exp = non_homogeneous_split(get_exp(name).tolist(), batch_size)
    directory = f'{FINAL_DIR}embeddings/{name}'
    # Raw string: '\D' in a plain literal is an invalid escape sequence
    # (DeprecationWarning, SyntaxWarning since Python 3.12).
    files = sorted(os.listdir(directory), key=lambda fn: int(re.sub(r'\D', '', fn)))
    for batch_num, file_name in enumerate(files):
        fname = os.path.join(directory, file_name)
        with open(fname, 'rb') as f:
            embeddings = pickle.load(f)

        _, true_op_label, true_num1_label, true_num2_label, _ = get_true_batch(exp[batch_num], embeddings)
        num_problems = true_op_label.shape[0]
        batch_idx = torch.arange(num_problems).unsqueeze(-1)
        perm = torch.stack(tuple(torch.randperm(K) for _ in range(num_problems)))
        embeddings['true_op'] = true_op_label[batch_idx, perm, :]
        embeddings['true_num1'] = true_num1_label[batch_idx, perm, :]
        embeddings['true_num2'] = true_num2_label[batch_idx, perm, :]
        # Deliberately not stored: the unpermuted labels, the row index / expression
        # embeddings from get_true_batch, and create_masked_nums' output.

        with open(fname, 'wb') as f:
            pickle.dump(embeddings, f)

# Add the permuted labels to every split's pickles; the dict of None values is only display output.
{name:get_all_true(name) for name in SET_NAMES}

{'train': None, 'validation': None, 'test': None}

In [None]:
class MathQADecoder(torch.nn.Module):
    """Set-prediction decoder trained on the pickled MathQA batch embeddings.

    ``DecoderLayer`` predicts K (operator, operand1, operand2) slots per problem. The
    Hungarian algorithm matches them to the ground-truth slots before the loss is
    computed.

    ``config`` must provide ``device``, ``embedding_size``, ``num_tokens``, ``K``,
    ``path`` (directory with ``train``/``validation`` batch pickles), ``query``, ``op``,
    ``grad_norm_clip``, ``opt`` (optimizer class) and ``lr``.

    NOTE(review): ``DecoderLayer``, ``tqdm``, ``id2op`` and ``display`` are not defined or
    imported in the visible cells; they must come from elsewhere in the notebook.
    """

    def __init__(self, config):
        super().__init__()
        self.device = config.device
        self.embedding_size = config.embedding_size
        self.num_tokens = config.num_tokens
        self.K = config.K
        self.path = config.path
        self.query = config.query
        self.op = config.op
        self.grad_norm_clip = config.grad_norm_clip

        # Getting all the independent expressions
        self.decoder_layer = DecoderLayer(config)

        # Optimizer (only the decoder layer's parameters are optimized)
        self.opt = config.opt(self.decoder_layer.parameters(), lr=config.lr)

    def __expand_tensor(self, curr, new):
        if curr is None:
            curr = new
        else:
            curr =  torch.cat((curr,new), dim=0)
        return curr

    def lmatch(self, embeddings, op, num1, num2):
        """Return the matching cost between true and predicted slots, shape ``(batch, K, K)``.

        Entry ``[b, t, p]`` is
        ``-(<true_op_t, op_p> + <true_num1_t, num1_p> + <true_num2_t, num2_p>)`` for
        problem ``b``. It is lower when predicted slot ``p`` puts more weight on the
        labels of true slot ``t``.
        """
        K = self.K  # the instance's K, not the module-level constant
        true_ops = torch.clone(embeddings['true_op']).to(self.device)
        true_num1 = embeddings['true_num1'].to(self.device)
        true_num2 = embeddings['true_num2'].to(self.device)
        #true_ops[:,:,op2id['None']] = 0 # setting all none operators to zero so ignored during calculation

        # rows: true exp, cols: pred exp (ie 0,1 would be the cost for true exp 0 with pred exp 1)
        # [batch, K*K]
        op_mat = (true_ops.repeat_interleave(K,dim=1)*op.repeat(1,K,1)).sum(dim=-1)
        num1_mat = (true_num1.repeat_interleave(K,dim=1)*num1.repeat(1,K,1)).sum(dim=-1)
        num2_mat = (true_num2.repeat_interleave(K,dim=1)*num2.repeat(1,K,1)).sum(dim=-1)

        return -(op_mat+num1_mat+num2_mat).reshape(op_mat.shape[0],K,K)

    # consider trying to write this in pytorch if have time
    def hungarian_algorithm(self, embeddings, op, num1, num2, batch_size):
        """Permute each problem's predicted slots to best match its true slots.

        Solves a linear assignment on ``lmatch``'s cost per problem. Returns ``op``,
        ``num1`` and ``num2`` reordered so that predicted slot ``j`` aligns with true
        slot ``j``.
        """
        # Imported here: the notebook's import cell does not bring it into scope.
        from scipy.optimize import linear_sum_assignment

        cost = self.lmatch(embeddings, op, num1, num2).detach().cpu().numpy()
        perms = torch.stack(tuple(
            torch.tensor(linear_sum_assignment(cost[b])[1]) for b in range(batch_size)))
        batch_idx = torch.arange(batch_size).unsqueeze(-1)
        return op[batch_idx, perms], num1[batch_idx, perms], num2[batch_idx, perms]

    # expects the optimal permutation from the hungarian algorithm
    def final_loss(self, embeddings, op, num1, num2):
        """Negative log-likelihood of the true labels, summed over slots and averaged over problems.

        Operand scores that are exactly 0 (e.g. an all-zero true label for a padded
        slot) are left at 0 instead of taking ``log(0)``.
        """
        true_op = embeddings['true_op'].to(self.device)
        true_num1 = embeddings['true_num1'].to(self.device)
        true_num2 = embeddings['true_num2'].to(self.device)

        log_op = torch.log((true_op*op).sum(-1))

        log_num1 = ((true_num1*num1).sum(-1))
        log_num1[log_num1.nonzero(as_tuple=True)] = torch.log(log_num1[log_num1.nonzero(as_tuple=True)])

        log_num2 = ((true_num2*num2).sum(-1))
        log_num2[log_num2.nonzero(as_tuple=True)] = torch.log(log_num2[log_num2.nonzero(as_tuple=True)])

        return (-log_op-log_num1-log_num2).sum(-1).mean()

    # expects the optimal permutation from the hungarian algorithm
    # gets the number of correct examples for a batch
    def correct(self, embeddings, op, num1, num2):
        """Count correct predictions in a batch.

        A problem counts as correct when all K slots have the right operator and, for
        non-'None' slots, both operands are right. Returns
        ``(problems correct, ops correct, number of ops, num1 correct, num2 correct,
        number of non-'None' slots)``.
        """
        true_op = embeddings['true_op'].to(self.device)
        true_num1 = embeddings['true_num1'].to(self.device)
        true_num2 = embeddings['true_num2'].to(self.device)
        invalid = true_op.argmax(-1)==op2id['None']
        op_correct = true_op.argmax(-1)==op.argmax(-1)
        num1_correct = num1.argmax(-1)==true_num1.argmax(-1)
        num2_correct = num2.argmax(-1)==true_num2.argmax(-1)
        num1_correct[invalid]=True
        num2_correct[invalid]=True

        correct = ((op_correct&num1_correct&num2_correct).sum(1)==self.K).sum()
        num_ops = true_op.shape[0]*true_op.shape[1]
        op_correct = op_correct.sum()
        num_nums = len(num1_correct[~invalid])
        num1_correct = num1_correct[~invalid].sum()
        num2_correct = num2_correct[~invalid].sum()
        return correct, op_correct, num_ops, num1_correct, num2_correct, num_nums

    def loss(self, embeddings, op, num1, num2, batch_size):
        """Match predictions to targets, then return the loss and the ``correct`` counts."""
        op, num1, num2 = self.hungarian_algorithm(embeddings, op, num1, num2, batch_size) # getting optimal permutation

        loss = self.final_loss(embeddings, op, num1, num2)
        correct, op_correct, num_ops, num1_correct, num2_correct, num_nums = self.correct(embeddings, op, num1, num2)

        return loss, correct, op_correct, num_ops, num1_correct, num2_correct, num_nums

    def train(self, epoch=True, subset=None, shuffle=True):
        """Run one training epoch over the pickled training batches.

        ``torch.nn.Module.eval()`` and parent modules call ``train(mode)`` with a bool.
        Such calls only switch the training flag; they do not start an epoch (the
        original override ran an epoch on ``.eval()``).

        Args:
            epoch: epoch number shown in the progress bar.
            subset: if given, only the first ``subset`` batch files (after shuffling).
            shuffle: shuffle the batch-file order.

        Returns:
            ``(mean loss per batch, problem accuracy, op accuracy, num1 accuracy,
            num2 accuracy)``.
        """
        if isinstance(epoch, bool):
            return super().train(epoch)
        super().train(True)  # val() switches to eval mode; make sure training mode is on

        total_loss = 0
        total_correct = 0
        total_op_correct = 0
        total_num1_correct = 0
        total_num2_correct = 0
        total_nums = 0
        total_ops = 0
        total_examples = 0

        directory = f'{self.path}/train'
        files = np.array(os.listdir(directory))
        if shuffle:
            np.random.shuffle(files)
        if subset is not None:
            files = files[0:subset]
        with tqdm(range(len(files))) as bar:
            for i in bar:
                bar.set_description(f'Epoch {epoch}')

                fname = os.path.join(directory, files[i])
                with open(fname, 'rb') as f:
                    embeddings = pickle.load(f)

                batch_size = embeddings['problem'].shape[0]
                total_examples += batch_size

                # Forward pass
                x = self.query[None,:,:].repeat(batch_size,1,1) # [batch_size, K, 768]
                x, op, num1, num2 = self.forward(x, embeddings) # x: [batch_size, K, 768]

                # Getting the loss
                loss, correct, op_correct, num_ops, num1_correct, num2_correct, num_nums = self.loss(embeddings,
                                                                                                     op,
                                                                                                     num1,
                                                                                                     num2,
                                                                                                     batch_size)

                # Backpropagating the loss
                self.opt.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.decoder_layer.parameters(), self.grad_norm_clip)
                self.opt.step()

                total_loss += loss.detach().item()
                total_correct += correct.detach().item()
                total_op_correct += op_correct.detach().item()
                total_num1_correct += num1_correct.detach().item()
                total_num2_correct += num2_correct.detach().item()
                total_ops += num_ops
                total_nums += num_nums

                bar.set_postfix(loss=loss.detach().item())

        return total_loss/len(files), total_correct/total_examples, total_op_correct/total_ops, total_num1_correct/total_nums, total_num2_correct/total_nums

    def val(self, subset=1, shuffle=True):
        """Evaluate on (a subset of) the validation batches without gradient tracking.

        Runs in eval mode (dropout etc. off) and restores the previous mode afterwards.
        ``subset`` defaults to 1, i.e. a single batch file. Returns the same metrics
        tuple as ``train``.
        """
        total_loss = 0
        total_correct = 0
        total_op_correct = 0
        total_num1_correct = 0
        total_num2_correct = 0
        total_nums = 0
        total_ops = 0
        total_examples = 0

        directory = f'{self.path}/validation'
        files = np.array(os.listdir(directory))
        if shuffle:
            np.random.shuffle(files)
        if subset is not None:
            files = files[0:subset]

        was_training = self.training
        super().train(False)
        try:
            with torch.no_grad():
                with tqdm(range(len(files))) as bar:
                    for i in bar:
                        bar.set_description(f'Validation')

                        fname = os.path.join(directory, files[i])
                        with open(fname, 'rb') as f:
                            embeddings = pickle.load(f)

                        batch_size = embeddings['problem'].shape[0]
                        total_examples += batch_size

                        # Forward pass
                        x = self.query[None,:,:].repeat(batch_size,1,1) # [batch_size, K, 768]
                        x, op, num1, num2 = self.forward(x, embeddings) # x: [batch_size, K, 768]

                        # Getting the loss
                        loss, correct, op_correct, num_ops, num1_correct, num2_correct, num_nums = self.loss(embeddings,
                                                                                                             op,
                                                                                                             num1,
                                                                                                             num2,
                                                                                                             batch_size)

                        total_loss += loss.item()
                        total_correct += correct.item()
                        total_op_correct += op_correct.item()
                        total_num1_correct += num1_correct.item()
                        total_num2_correct += num2_correct.item()
                        total_ops += num_ops
                        total_nums += num_nums
        finally:
            super().train(was_training)

        return total_loss/len(files), total_correct/total_examples, total_op_correct/total_ops, total_num1_correct/total_nums, total_num2_correct/total_nums

    def fit(self, epochs, tsubset=None, vsubset=None, shuffle=True):
        """Alternate training and validation for *epochs* epochs, displaying a metrics table each epoch."""
        for epoch in range(epochs):
            train_loss, train_acc, train_op_acc, train_num1_acc, train_num2_acc = self.train(epoch+1, tsubset, shuffle)
            val_loss, val_acc, val_op_acc, val_num1_acc, val_num2_acc = self.val(vsubset, shuffle)

            df = pd.DataFrame({'Train':[train_loss, train_acc, train_op_acc, train_num1_acc, train_num2_acc],
                               'Validation':[val_loss, val_acc, val_op_acc, val_num1_acc, val_num2_acc]},
                       index = ['Loss','Accuracy','Op Accuracy','Num1 Accuracy','Num2 Accuracy'])

            display(df)

    def get_eq(self, embeddings, op, num1, num2):
        """Render the argmax predictions as ``'n1 op n2'`` strings (``'None'`` for padding slots).

        Returns an array with one row of K strings per problem.
        """
        literals = embeddings['num_literals']
        op = op.cpu()
        num1 = num1.cpu()
        num2 = num2.cpu()
        op_lit = np.array([id2op[x.item()] for x in op.argmax(-1).flatten()])
        num1_lit = literals[num1.argmax(-1)].flatten()
        num2_lit = literals[num2.argmax(-1)].flatten()
        # -1 rather than a hard-coded 8 so a final, smaller batch also works.
        return np.array([f'{n1} {o} {n2}' if o!='None' else 'None' for o,n1,n2 in zip(op_lit,num1_lit,num2_lit)]).reshape(-1, self.K)

    def forward(self, x, embeddings):
        return self.decoder_layer(x, embeddings)

In [None]:
# pred_op2 = torch.tensor([[[.2,.01,.35,.09,.3,.05],
#                          [.4,.01,.02,.03,.04,.5],
#                          [.02,.08,.2,.25,.05,.4],
#                          [.04,.06,.07,.03,.02,.78],
#                          [.3,.03,.39,.01,.07,.2],
#                          [.03,.04,.02,.01,.06,.84]]])
# pred_num12 = torch.tensor([[[.1,.8,.03,.07],
#                            [.7,.2,.06,.04],
#                            [.2,.1,.65,.05],
#                            [.1,.3,.2,.4],
#                            [.45,.3,.05,.2],
#                            [0.1, 0.5, 0.17, 0.23]]])
# pred_num22 = torch.tensor([[[.09,.3,.11,.5],
#                            [.3,.4,.25,.05],
#                            [.1,.8,.04,.06],
#                            [.7,0.01,.19,.1],
#                            [.24,.26,.3,.2],
#                            [.2,.11,.09,.6]]])
# pred_op3 = torch.tensor([[[.2,.01,.35,.09,.3,.05],
#                          [.4,.01,.02,.03,.04,.5],
#                          [.02,.08,.2,.25,.05,.4],
#                          [.04,.06,.07,.03,.02,.78],
#                          [.3,.03,.39,.01,.07,.2],
#                          [.03,.04,.02,.01,.06,.84]]])
# pred_num13 = torch.tensor([[[.1,.8,.03,.07],
#                            [.7,.2,.06,.04],
#                            [.2,.1,.65,.05],
#                            [.1,.3,.2,.4],
#                            [.45,.3,.05,.2],
#                            [0.1, 0.5, 0.17, 0.23]]])
# pred_num23 = torch.tensor([[[.09,.3,.11,.5],
#                            [.3,.4,.25,.05],
#                            [.1,.8,.04,.06],
#                            [.7,0.01,.19,.1],
#                            [.24,.26,.3,.2],
#                            [.2,.11,.09,.6]]])
# pred_op = torch.tensor([[[.4,.01,.02,.03,.04,.5],
#                          [.2,.01,.35,.09,.3,.05],
#                          [.02,.08,.2,.25,.05,.4],
#                          [.04,.06,.07,.03,.02,.78],
#                          [.3,.03,.39,.01,.07,.2],
#                          [.03,.04,.02,.01,.06,.84]]])
# pred_num1 = torch.tensor([[[.7,.2,.06,.04],
#                            [.1,.8,.03,.07],
#                            [.2,.1,.65,.05],
#                            [.1,.3,.2,.4],
#                            [.45,.3,.05,.2],
#                            [0.1, 0.5, 0.17, 0.23]]])
# pred_num2 = torch.tensor([[[.3,.4,.25,.05],
#                            [.09,.3,.11,.5],
#                            [.1,.8,.04,.06],
#                            [.7,0.01,.19,.1],
#                            [.24,.26,.3,.2],
#                            [.2,.11,.09,.6]]])

# true_op = torch.tensor([[[0,0,1,0,0,0],[0,0,1,0,0,0],[0,0,0,0,0,1],[0,0,0,0,0,1],[0,0,0,0,0,1],[0,0,0,0,0,1]]])
# true_num1 = torch.tensor([[[0,1,0,0],[1,0,0,0],[0,0,0,0],[0,0,0,0],[0,0,0,0],[0,0,0,0]]])
# true_num2 = torch.tensor([[[0,0,0,1],[0,0,1,0],[0,0,0,0],[0,0,0,0],[0,0,0,0],[0,0,0,0]]])

# true_op = torch.cat((true_op,true_op,true_op), dim=0)
# true_num1 = torch.cat((true_num1,true_num1,true_num1), dim=0)
# true_num2 = torch.cat((true_num2,true_num2,true_num2), dim=0)
# pred_op = torch.cat((pred_op,pred_op2,pred_op3), dim=0)
# pred_num1 = torch.cat((pred_num1,pred_num12,pred_num13), dim=0)
# pred_num2 = torch.cat((pred_num2,pred_num22,pred_num23), dim=0)

# def lmatch(true_ops, true_num1, true_num2, op, num1, num2):
#     true_ops = torch.clone(true_ops)
#     true_ops[:,:,op2id['None']] = 0 # setting all none operators to zero so ignored during calculation

#     # rows: true exp, cols: pred exp (ie 0,1 would be the cost for true exp 0 with pred exp 1)
#     # [8,K*K]
#     op_mat = (true_ops.repeat_interleave(K,dim=1)*op.repeat(1,K,1)).sum(dim=-1)
#     num1_mat = (true_num1.repeat_interleave(K,dim=1)*num1.repeat(1,K,1)).sum(dim=-1)
#     num2_mat = (true_num2.repeat_interleave(K,dim=1)*num2.repeat(1,K,1)).sum(dim=-1)

#     return -(op_mat+num1_mat+num2_mat).reshape(op_mat.shape[0],K,K)

# #print(lmatch(true_op, true_num1, true_num2, pred_op, pred_num1, pred_num2)[0])

# # consider trying to write this in pytorch if have time
# def hungarian_algorithm(true_ops, true_num1, true_num2, op, num1, num2, batch_size):
#     m = lmatch(true_ops, true_num1, true_num2, op, num1, num2)
#     permutations = torch.stack(tuple([torch.tensor(linear_sum_assignment(m.detach().cpu().numpy()[x])[1]) for x in range(batch_size)]))
#     batch_idx = torch.arange(batch_size).unsqueeze(-1)
#     return permutations
#     #return op[batch_idx, permutations], num1[batch_idx, permutations], num2[batch_idx, permutations]
    
# hungarian_algorithm(true_op, true_num1, true_num2, pred_op, pred_num1, pred_num2, 3)

# # def final_loss(true_op, true_num1, true_num2, op, num1, num2):
# #     log_op = torch.log((true_op*op).sum(-1))

# #     log_num1 = ((true_num1*num1).sum(-1))
# #     log_num1[log_num1.nonzero(as_tuple=True)] = torch.log(log_num1[log_num1.nonzero(as_tuple=True)])

# #     log_num2 = ((true_num2*num2).sum(-1))
# #     log_num2[log_num2.nonzero(as_tuple=True)] = torch.log(log_num2[log_num2.nonzero(as_tuple=True)])

# #     return (-log_op-log_num1-log_num2).sum(-1)

# # final_loss(true_op, true_num1, true_num2, pred_op, pred_num1, pred_num2)

In [None]:
class DecoderLayer(torch.nn.Module):
    """Decode K candidate expressions ``<op, num1, num2>`` for each problem.

    ``config`` must provide:
      * ``device``, ``embedding_size`` and ``num_tokens``;
      * ``K``: the number of query slots, i.e. generated expressions per problem;
      * ``op``: the operator embeddings, one row per operator;
      * ``nhead`` and ``nlayer``: heads and layers of the transformer decoder.

    ``forward`` returns ``(x, op, num1, num2)``:
      * ``x``: re-encoded expression embeddings, ``[batch_size, K, embedding_size]``;
      * ``op``: operator probabilities, ``[batch_size, K, num_ops]``;
      * ``num1`` / ``num2``: operand probabilities over all numbers in the
        batch, ``[batch_size, K, num_nums]``. Each is normalised separately
        over each problem's own numbers and is zero elsewhere.
    """

    def __init__(self, config):
        super().__init__()
        self.device = config.device
        self.embedding_size = config.embedding_size
        self.num_tokens = config.num_tokens
        self.K = config.K
        self.op = config.op

        # Standard transformer decoder with config.nhead heads and
        # config.nlayer layers. The intent is to set both to K, hoping that
        # each head/layer picks up different information for each of the K
        # generated expressions.
        transformer_decoder_layer = torch.nn.TransformerDecoderLayer(self.embedding_size, nhead=config.nhead, batch_first=True)
        self.transformer_decoder = torch.nn.TransformerDecoder(transformer_decoder_layer, num_layers=config.nlayer)

        # softmax for the operations (over the operator axis)
        self.op_softmax = torch.nn.Softmax(dim=2)

        # softmax for the numbers (over a single problem's numbers; see __apply_to_nums)
        self.num_softmax = torch.nn.Softmax(dim=0)

        # Reduce <op, num, num, num*num> (4 * embedding_size, e.g. 3072) back
        # down to embedding_size (e.g. 768).
        self.exp_encoder = torch.nn.Sequential(
            torch.nn.Linear(self.embedding_size*4, self.embedding_size*2),
            torch.nn.SiLU(),
            torch.nn.Linear(self.embedding_size*2, self.embedding_size*2),
            torch.nn.SiLU(),
            torch.nn.Linear(self.embedding_size*2, self.embedding_size),
        )

        # Expand back into a new <op, num, num> (3 * embedding_size) for the loss calculation.
        self.exp_decoder = torch.nn.Sequential(
            torch.nn.Linear(self.embedding_size, self.embedding_size*2),
            torch.nn.SiLU(),
            torch.nn.Linear(self.embedding_size*2, self.embedding_size*2),
            torch.nn.SiLU(),
            torch.nn.Linear(self.embedding_size*2, self.embedding_size*3),
        )

        # Attention mask (currently unused):
        #self.tgt_mask = torch.nn.Transformer().generate_square_subsequent_mask(sz=self.K).to(self.device)

    def __apply_to_nums(self, f, num, embeddings):
        """Apply ``f`` to each problem's rows of ``num`` and scatter them into a dense tensor.

        ``num`` has one row per number in the batch (``[num_nums, K]``). For
        each problem ``x``, the rows whose ``embeddings['num_idx']`` equals
        ``x`` go through ``f``. The result is transposed and written into
        that problem's number columns; every other entry stays zero. The
        output has the shape of ``embeddings['true_num1']``, which ``forward``
        asserts to be ``[batch_size, K, num_nums]``.
        """
        batch_size = embeddings['problem'].shape[0]
        num_idx = embeddings['num_idx']
        new_nums = torch.zeros(embeddings['true_num1'].shape).to(self.device)
        for x in range(batch_size):
            new_nums[x,:,num_idx==x] = f(num[num_idx==x]).T
        return new_nums

    # x:        [batch_size, K, embedding_size]        query embeddings
    # nums:     [num_nums, embedding_size]             every number in the batch, flattened across problems
    # num_idx:  [num_nums]                             index of the problem each number belongs to
    # ops:      [num_ops, embedding_size]
    # problems: [batch_size, num_tokens, embedding_size]
    def forward(self, x, embeddings):
        nums = embeddings['nums'].to(self.device)
        num_idx = embeddings['num_idx'].to(self.device)
        ops = self.op
        problems = embeddings['problem'].to(self.device)
        batch_size = x.shape[0]
        num_ops = ops.shape[0]
        num_nums = nums.shape[0]

        # ----------------------------
        # Step 1 - transformer decoder
        # ----------------------------
        assert problems.shape == torch.Size([batch_size, self.num_tokens, self.embedding_size])
        assert x.shape == torch.Size([batch_size, self.K, self.embedding_size])
        x = self.transformer_decoder(x, problems) # [batch_size, K, E] -> [batch_size, K, E] (memory: problems)
        assert x.shape == torch.Size([batch_size, self.K, self.embedding_size])

        # --------------------------------------------------------------------------
        # Step 2 - decoding the output into three embeddings (op, num1, num2) of size E
        # --------------------------------------------------------------------------
        x = self.exp_decoder(x) # [batch_size, K, E] -> [batch_size, K, 3E]
        assert x.shape == torch.Size([batch_size, self.K, self.embedding_size*3])
        operation, x1, x2 = torch.split(x, self.embedding_size, dim=2) # [batch_size, K, 3E] -> [batch_size, K, E] each
        assert operation.shape == torch.Size([batch_size, self.K, self.embedding_size]) and x1.shape == torch.Size([batch_size, self.K, self.embedding_size]) and x2.shape == torch.Size([batch_size, self.K, self.embedding_size])

        # -----------------------------------------------------------------------------------------------
        # Step 3 and 4 - Finding the softmax for the similarity between the predicted and true embeddings
        # -----------------------------------------------------------------------------------------------
        # making sure params have correct dimension (use self.K, not the module-level K)
        nums_expanded = nums[:,None,:].expand(-1,self.K,-1) # [num_nums, E] -> [num_nums, K, E]
        assert nums_expanded.shape == torch.Size([num_nums, self.K, self.embedding_size])
        ops = ops[None,None,:,:].repeat(batch_size,self.K,1,1) # [num_ops, E] -> [batch_size, K, num_ops, E]
        assert ops.shape == torch.Size([batch_size, self.K, num_ops, self.embedding_size])

        # dot product - similarity between each op/num prediction and the stored embeddings;
        # x1[num_idx] picks, for every number, the queries of the problem it belongs to
        num1 = (x1[num_idx]*nums_expanded).sum(dim=2) # [num_nums, K]
        assert num1.shape == torch.Size([num_nums, self.K])
        num2 = (x2[num_idx]*nums_expanded).sum(dim=2) # [num_nums, K]
        assert num2.shape == torch.Size([num_nums, self.K])
        op = operation[:,:,None,:].expand(-1,-1,num_ops,-1) # [batch_size, K, E] -> [batch_size, K, num_ops, E]
        op = (op*ops).sum(dim=3) # [batch_size, K, num_ops, E] -> [batch_size, K, num_ops]
        assert op.shape == torch.Size([batch_size, self.K, num_ops])

        # softmax
        op = self.op_softmax(op) # [batch_size, K, num_ops] (ie op[1,2] is the operator distribution for problem 2, query 3)
        assert op.shape == torch.Size([batch_size, self.K, num_ops])
        # Float sums of softmax rows are only approximately 1, so compare with a tolerance.
        assert np.isclose(op.sum().item(), batch_size*self.K)

        num1 = self.__apply_to_nums(self.num_softmax, num1, embeddings) # [batch_size,K,num_nums]
        assert num1.shape == torch.Size([batch_size, self.K, num_nums])
        assert np.isclose(num1[0].sum(1).sum().item(),self.K)

        num2 = self.__apply_to_nums(self.num_softmax, num2, embeddings) # [batch_size,K,num_nums]
        assert num2.shape == torch.Size([batch_size, self.K, num_nums])
        assert np.isclose(num2[0].sum(1).sum().item(),self.K)

        # ----------------------------------------------------
        # Step 5 - creating embedding for the found expression
        # ----------------------------------------------------
        x = self.exp_encoder(torch.cat((operation,x1,x2,x1*x2), dim=2)) # [batch_size, K, 4E] -> [batch_size, K, E]
        assert x.shape == torch.Size([batch_size, self.K, self.embedding_size])

        # -----------------------
        # Returning final results
        # -----------------------
        return x, op, num1, num2

In [None]:
#     # Each set of equations starts with an sos token, and is padded with additional None tokens to reach K
#     def __get_true_embed(self, embeddings):
#         batch_size = embeddings['problem'].shape[0]
#         idx = embeddings['true_idx']
#         sos = embeddings['sos']
#         x = self.exp_encoder(embeddings['true_embed'])
#         x = torch.nn.functional.normalize(x, dim=-1)
#         pad = self.exp_encoder(torch.cat((self.padding,embeddings['ops'][op2id['None']],self.padding,self.padding)))
#         pad = torch.nn.functional.normalize(pad, dim=-1)
#         new_x = None
#         for batch_num in range(batch_size):
#             temp = x[idx==batch_num]
#             temp = torch.cat((sos[None,:], temp, pad[None,:].repeat(K-temp.shape[0],1)), dim=0)
#             new_x = self.__expand_tensor(new_x, temp[None,:,:])
#         return new_x
    
#     def __get_all_true(self):
#         true = {}
#         padding = torch.rand(self.embedding_size)
#         for name in SET_NAMES:
#             directory = f'{self.path}/{name}'
#             files = os.listdir(directory)
#             files.sort(key=lambda f: int(re.sub('\D', '', f)))
#             for batch_num, f in enumerate(files):
#                 fname = os.path.join(directory, f)
#                 with open(fname, 'rb') as f:
#                     embeddings = pickle.load(f)
#                 embeddings['true'] = self.__get_true_embed(embeddings, padding)
#                 with open(fname, 'wb') as f:
#                     pickle.dump(embeddings, f)

In [None]:
# class TestModel(unittest.TestCase):
    
#     def __init__(self, *args, **kwargs):
#         super(TestModel, self).__init__(*args, **kwargs)
#         self.data = util.load_data()
#         self.config = Config(data, reload=False)
    
#     def test_config(self):
#         # query
#         self.assertEqual(self.config.query.shape, torch.Size([self.config.K, self.config.embedding_size]))
#         self.assertEqual(self.config.query.unique(dim=0).shape, self.config.query.shape)

#         # op
#         self.assertEqual(self.config.op.shape, torch.Size([len(op2id), self.config.embedding_size]))
#         self.assertEqual(self.config.op.unique(dim=0).shape, self.config.op.shape)
#         for name in SET_NAMES:
#             # nums
#             self.assertEqual(self.config.num[name]['idx'].max(), len(self.data[name])-1)
#             self.assertEqual(self.config.num[name]['idx'].shape, self.config.num[name]['literals'].shape)
            
#             # const
#             self.assertEqual(self.config.const[name]['idx'].max(), len(self.data[name])-1)
#             self.assertEqual(self.config.const[name]['idx'].shape, self.config.const[name]['literals'].shape)
#             self.assertEqual(self.config.const[name]['embed_idx'].shape, self.config.const[name]['idx'].shape)
#             self.assertEqual(self.config.const[name]['embed_idx'].max(), len(self.config.const[name]['embed'])-1)
#             self.assertEqual(len(self.config.const[name]['embed']), config.const['train']['embed'].unique(dim=1).shape[0])
            
#             # combined
#             self.assertEqual(self.config.combined[name]['idx'].max(), len(self.data[name])-1)
#             self.assertEqual(self.config.combined[name]['idx'].shape, self.config.combined[name]['literals'].shape)
#             self.assertEqual(self.config.combined[name]['idx'].shape[0], self.config.const[name]['idx'].shape[0]+self.config.num[name]['idx'].shape[0])    
            
#             # Checking if all equation nums are in the problem
#             for i, eq_nums in enumerate([x-set(const2val.keys()) for x in data[name]['nums'].map(eval).tolist()]):
#                 eq_nums = set(np.array(list(eq_nums)).astype(float))
#                 literals = set(self.config.num[name]['literals'][self.config.num[name]['idx']==i])
#                 self.assertTrue(eq_nums<=literals)
                
#             # Checking if constant correctness is within acceptable threshold
#             num_correct = 0
#             thresh = .97
#             for i, actual in enumerate([x.intersection(set(const2val.keys())) for x in data[name]['nums'].map(eval).tolist()]):
#                 predicted = set(id2const[config.const[name]['embed_idx'][config.const[name]['idx']==i]])
#                 literals = set(self.config.const[name]['literals'][self.config.const[name]['idx']==i])
#                 if actual <= predicted and literals <= predicted:
#                     num_correct += 1
#             self.assertTrue(num_correct/len(data[name]) >= thresh)
            
#             # Checking if all equations have same literals within threshold
#             num_correct = 0
#             thresh = .97
#             for i, actual in enumerate(data[name]['nums'].map(eval).tolist()):
#                 actual = set([x if x in const2val else str(float(x)) for x in actual])
#                 literals = set(self.config.combined[name]['literals'][self.config.combined[name]['idx']==i])
#                 if actual <= literals:
#                     num_correct += 1
#             self.assertTrue(num_correct/len(data[name]) >= thresh)
            
        
# unittest.main(argv=[''], verbosity=2, exit=False);

## Old prepare labels

In [None]:
def _one_hot_num_labels(idx, true_num, shape, batch_size):
    """Convert per-problem operand positions into one-hot rows over the global number axis.

    ``true_num[b, k]`` is the position of slot ``k``'s operand among problem
    ``b``'s own numbers; ``-1`` marks a padded slot. ``torch.where(idx == b)[0]``
    gives the global columns that belong to problem ``b``. The returned float
    tensor of ``shape`` holds a 1 at ``[b, k, global_column]`` for every
    unpadded slot, and zeros elsewhere.
    """
    labels = torch.zeros(shape)
    unmasked = true_num != -1
    # Concatenated in (batch, slot) order, which matches the row-major order
    # in which the boolean mask `unmasked` selects positions.
    cols = [torch.where(idx == b)[0][true_num[b][true_num[b] != -1]] for b in range(batch_size)]
    if cols:
        labels[unmasked, torch.cat(cols)] = 1
    return labels


def prepare_labels(self, idx, true_labels, shape):
    """Build dense one-hot operand labels and randomly permute the K equation slots.

    Args:
        self: object providing ``K`` (number of slots) and ``device``.
        idx: problem index of every global number column.
        true_labels: ``(true_op, true_num1, true_num2)``. ``true_op`` is
            ``[batch_size, K, num_ops]``. ``true_num1`` / ``true_num2`` hold
            per-problem operand positions, with ``-1`` for padding.
        shape: shape of each dense operand label tensor
            (``[batch_size, K, num_nums]``).

    Returns:
        ``(true_op, true_num1, true_num2)`` on ``self.device``. All three are
        permuted with the same random per-example slot order, so each
        operator row stays aligned with its operands.
    """
    true_op, true_num1, true_num2 = true_labels
    batch_size = true_op.shape[0]

    true_num1 = _one_hot_num_labels(idx, true_num1, shape, batch_size)
    true_num2 = _one_hot_num_labels(idx, true_num2, shape, batch_size)

    # Randomizing equation order (one independent permutation per example)
    batch_idx = torch.arange(batch_size).unsqueeze(-1)
    perm = torch.rand(batch_size,self.K).argsort(dim=1)
    true_op = true_op[batch_idx,perm,:]
    true_num1 = true_num1[batch_idx,perm,:]
    true_num2 = true_num2[batch_idx,perm,:]

    return true_op.to(self.device), true_num1.to(self.device), true_num2.to(self.device)

## Old preprocessing code

In [None]:
#     def __get_init_exp(self, data, name):
#         reorder = lambda x: [x[1], x[0], x[2]]
#         convert_to_arr = lambda d: [reorder(x.split()) for idx, arr in enumerate(d.split(' ; ')) for x in eval(arr) if x is not None and idx == 0]
#         return data[name]['incremental'].map(convert_to_arr)

#     def __get_num_label(self, name, num, i):
#         idx = self.combined[name]['idx']
#         literals = self.combined[name]['literals'][idx==i]
#         label = torch.zeros(literals.shape)
#         if num in const2val:
#             location = np.where(literals==num)[0]
#             if len(location) == 0:
#                 return torch.tensor(-1)
#             else:
#                 location = location[0]
#         else:
#             location = np.where(literals==str(float(num)))[0][0]
#         return torch.tensor(location)

#     def __pad_op(self, op_label):
#         temp = torch.zeros(op_label.shape[1])
#         temp[op2id['None']] = 1
#         return torch.cat((op_label, temp[None,:].repeat(K-op_label.shape[0],1)), dim=0)[None,:,:]

#     def __pad_num(self, num_label):
#         return torch.cat((num_label, torch.tensor(-1).repeat(K-num_label.shape[0])), dim=0)

#     def __get_true_label(self, name, e, i):
#         op, num1, num2 = e
#         op_label = torch.zeros(len(op2id))
#         op_label[op2id[op]] = 1
#         num1_label = self.__get_num_label(name, num1, i)
#         num2_label = self.__get_num_label(name, num2, i)
#         return op_label, num1_label, num2_label

#     def __get_label(self, data, name):
#         exp = self.__get_init_exp(data, name)
#         idx = []
#         true_op_label = None
#         true_num1_label = []
#         true_num2_label = []
#         for i,e in enumerate(exp):
#             get_true = lambda x: self.__get_true_label(name, x, i)
#             op_label, num1_label, num2_label = [torch.stack(item) for item in list(zip(*map(get_true, e)))]
#             true_op_label = self.__expand_tensor(true_op_label, self.__pad_op(op_label))
#             true_num1_label.append(self.__pad_num(num1_label))
#             true_num2_label.append(self.__pad_num(num2_label))
# #         return true_op_label, torch.stack(true_num1_label), torch.stack(true_num2_label)

Here we compute the true labels for operations and numbers. Since the model may generate the equations in any order, we store a tuple representation of each formula, which can be looked up later during the loss calculation.

The value stored in `embeddings['true']` holds one `(true_ops, true_num1, true_num2, layer_idx)` tuple for each example in the batch.

In [7]:
num_per_batch = 8  # number of problems grouped into each batch by get_true_labels

def get_exp(name):
    """Parse the ``incremental`` column of split ``name`` into labelled expressions.

    Each cell is split on ``' ; '``. Segment ``i`` is evaluated into a list
    whose non-None entries are whitespace-separated strings. For each entry,
    the second token is moved to the front, so ``'a op b'`` becomes
    ``[(op, a, b), i]``. Returns ``data[name]['incremental'].map(...)``
    (presumably a pandas Series) of such lists.
    """
    def parse_cell(cell):
        parsed = []
        for layer, segment in enumerate(cell.split(' ; ')):
            # NOTE(review): eval runs dataset text as code; ast.literal_eval
            # would be safer if the segments are always plain literals.
            for entry in eval(segment):
                if entry is None:
                    continue
                parts = tuple(entry.split())
                parsed.append([(parts[1], parts[0], parts[2]), layer])
        return parsed

    return data[name]['incremental'].map(parse_cell)

def non_homogeneous_split(arr, num_per_batch):
    """Split ``arr`` into consecutive slices of ``num_per_batch`` items; the last may be shorter."""
    chunks = []
    for start in range(0, len(arr), num_per_batch):
        chunks.append(arr[start:start + num_per_batch])
    return chunks

def process_num(num, literals, idx, prob, prev):
    """Map one equation operand to its label.

    Args:
        num: operand token. It is one of:
            * a constant name (a key of ``const2val``);
            * a reference to an earlier equation, i.e. a token containing
              ``'x'`` whose digits after the first character are a 1-based
              equation number (e.g. ``'x1'``);
            * a numeric literal.
        literals: literals of all numbers in the batch (assumed to be a NumPy
            array of strings; TODO confirm).
        idx: problem index of every entry of ``literals``.
        prob: index of the problem within the batch.
        prev: ``(num1_labels, num2_labels)`` already computed for this problem.

    Returns:
        One of the following:
          * the position of the operand among problem ``prob``'s literals;
          * ``(num1_label, num2_label)`` of the referenced equation;
          * ``None`` when the operand is not among the problem's literals.
    """
    if num in const2val:
        target = num
    elif 'x' in num:
        eq_idx = int(num[1:])-1
        return (prev[0][eq_idx], prev[1][eq_idx])
    else:
        target = str(float(num))

    matches = np.where(literals[idx==prob] == target)[0]
    # A constant may be missing because the constant predictor failed for this
    # problem, and a number may be missing because extraction missed it. Both
    # cases return None instead of raising an IndexError.
    return matches[0] if len(matches) > 0 else None

def process_exp(num_idx, literals, exp, p):
    """Build the label arrays for all equations of problem ``p``.

    Args:
        num_idx: problem index of every number in the batch.
        literals: literals of every number in the batch.
        exp: list of ``[(op, num1, num2), layer]`` entries (the format
            produced by ``get_exp``).
        p: index of the problem within the batch.

    Returns:
        ``(op_labels, num1_labels, num2_labels, layer_idx)``, where:
          * ``op_labels`` is a one-hot float tensor ``[len(exp), len(op2id)]``,
            or ``None`` when ``exp`` is empty;
          * ``num1_labels`` / ``num2_labels`` are object arrays of
            ``process_num`` results;
          * ``layer_idx`` is an array of the layer indices.
    """
    op_rows = []
    num1_labels = []
    num2_labels = []
    layer_idx = []
    for (op, num1, num2), layer in exp:
        # The label lists are passed in so that 'x<k>' operands can refer to
        # equations processed earlier.
        num1_labels.append(process_num(num1, literals, num_idx, p, (num1_labels, num2_labels)))
        num2_labels.append(process_num(num2, literals, num_idx, p, (num1_labels, num2_labels)))

        # The one-hot operator row is sized from op2id rather than a hard-coded width.
        row = torch.zeros(len(op2id))
        row[op2id[op]] = 1
        op_rows.append(row)

        layer_idx.append(layer)

    # Stack once at the end instead of calling torch.cat on every iteration.
    op_labels = torch.stack(op_rows) if op_rows else None
    return op_labels, np.array(num1_labels, dtype=object), np.array(num2_labels, dtype=object), np.array(layer_idx)

def process_batch(num_idx, literals, expressions):
    """Return ``process_exp``'s labels for each problem of the batch, in order."""
    return [
        process_exp(num_idx, literals, problem_exp, problem)
        for problem, problem_exp in enumerate(expressions)
    ]
    
def get_true_labels(name, directory=None):
    """Compute the true equation labels for every saved embedding batch of split ``name``.

    Batch pickles are read from ``directory``, which defaults to
    ``FINAL_DIR/embeddings/<name>``. They are processed in the numeric order
    of the digits in their file names, so that file ``i`` lines up with the
    ``i``-th chunk of ``num_per_batch`` expressions from ``get_exp(name)``.
    Each pickle must contain ``'num_idx'`` and ``'num_literals'``.

    Returns:
        A list with one ``process_batch`` result per batch file.
    """
    if directory is None:
        directory = os.path.join(FINAL_DIR, 'embeddings', name)
    exp = non_homogeneous_split(get_exp(name).tolist(), num_per_batch)
    files = sorted(os.listdir(directory), key=lambda f: int(re.sub(r'\D', '', f)))
    labels = []
    for batch_num, fname in enumerate(files):
        # Locally generated pickles; do not point this at untrusted files.
        with open(os.path.join(directory, fname), 'rb') as fh:
            embeddings = pickle.load(fh)
        labels.append(process_batch(embeddings['num_idx'], embeddings['num_literals'], exp[batch_num]))
    return labels
            
# Compute the true labels of every dataset split (shown as the cell output).
dict((split, get_true_labels(split)) for split in SET_NAMES)

FileNotFoundError: [WinError 3] The system cannot find the path specified: 'pickle/embeddings/train'

In [None]:
# Scratch: sample a random 6 x 768 tensor (uniform on [0, 1)).
torch.rand((6, 768))