# Mathematical Reasoning with Transformers

## Import libraries and select the compute device

In [10]:
from src.utils import translate
from src.defaults import _C as cfg
import argparse
import os
import pickle
from src.model import Seq2SeqTransformer
import torch
from config.defaults import get_cfg_defaults

# Use the GPU when torch can see one; otherwise fall back to the CPU.
mydevice = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Load the default configuration (`config/defaults.py`) and set the data/results paths

In [11]:
# Start from the project defaults. Note this rebinds `cfg`, shadowing the `_C`
# imported from src.defaults in the imports cell.
cfg = get_cfg_defaults()

# Data and checkpoints are resolved relative to the kernel's working directory.
# MODEL_SAVE_PATH keeps its trailing "/" because file names are appended to it directly.
cwd = os.getcwd()
cfg.DATASET_DIR = f"{cwd}/data"
cfg.MODEL_SAVE_PATH = f"{cwd}/results/"

## Load the source and target vocabularies and record their sizes

In [12]:
# The vocabularies are pickled files stored alongside the checkpoint.
# NOTE: pickle.load can execute arbitrary code — only load files you produced yourself.
src_vocab_path = cfg.MODEL_SAVE_PATH + "src_vocab.pickle"
tgt_vocab_path = cfg.MODEL_SAVE_PATH + "tgt_vocab.pickle"

with open(src_vocab_path, "rb") as vocab_file:
    src_vocab = pickle.load(vocab_file)
with open(tgt_vocab_path, "rb") as vocab_file:
    tgt_vocab = pickle.load(vocab_file)

# The vocabulary sizes are read later when the model is constructed.
cfg.SRC_VOCAB_SIZE, cfg.TGT_VOCAB_SIZE = len(src_vocab), len(tgt_vocab)

## Instantiate transformer model and load trained model weights

In [13]:
transformer = Seq2SeqTransformer(cfg.NUM_ENCODER_LAYERS, cfg.NUM_DECODER_LAYERS,
                                 cfg.EMB_SIZE, cfg.SRC_VOCAB_SIZE, cfg.TGT_VOCAB_SIZE, cfg.NHEAD,
                                 cfg.FFN_HID_DIM)

# `load_state_dict` copies checkpoint tensors into the module's existing parameters,
# which stay on whatever device the module lives on. Without this move, a CUDA machine
# would keep the model on the CPU even though `translate` is given `mydevice`.
# `eval()` disables training-only behaviour (e.g. dropout, if the model uses it);
# `translate` may already do this, but setting it here keeps inference independent of that.
transformer.to(mydevice).eval()

# Kept as the cell's last expression so it still displays the key-matching report.
transformer.load_state_dict(torch.load(cfg.MODEL_SAVE_PATH + cfg.TASK + ".pth", map_location=mydevice))

<All keys matched successfully>

## Demo: predict digits for sample questions

In [15]:
# Sample questions paired with their known answers. The answer's length is passed
# to `translate` as its fifth argument (presumably the number of output tokens to
# decode — TODO confirm in src.utils).
q1, a1 = "What is the ten thousands digit of 62795675?", '9'
q2, a2 = "What is the hundred thousands digit of 82923295?", '9'
q3, a3 = "What is the tens digit of 70750657?", '5'
qs = [q1, q2, q3]
a_s = [a1, a2, a3]

print("Final predictions")
# The loop names `q`/`a` are kept as-is: they remain in the notebook namespace
# after the loop, and later cells may rely on them.
for q, a in zip(qs, a_s):
    print(f"Question: {q}")
    print(translate(transformer, q, src_vocab, tgt_vocab, len(a), mydevice))

Final predictions
Question: What is the ten thousands digit of 62795675?
9
Question: What is the hundred thousands digit of 82923295?
9
Question: What is the tens digit of 70750657?
5


In [16]:
# Single ad-hoc question. The expected answer is declared here rather than reusing
# `a`, which would otherwise be the loop variable leaked from the previous cell.
# Reusing it breaks on a fresh kernel or when cells are run out of order.
q = 'What is the tens digit of 402934?'
expected = '3'  # tens digit of 402934; only its length is passed to translate
print("Question: " + q)
print(translate(transformer, q, src_vocab, tgt_vocab, len(expected), mydevice))

Question: What is the tens digit of 402934?
3
