# Math Q&A Transformer

In [None]:
import sys
import os

# Make the modules that sit in the current working directory importable by
# the cells below. The absolute path is appended to sys.path at most once, so
# re-running this cell does not accumulate duplicate entries.
module_path = os.path.abspath(os.curdir)
if module_path not in sys.path:
    sys.path.append(module_path)

## Data

Summarize dataset structure for the `arithmetic` category.

In [None]:
from math_dataset import MathDataset

# For every data type reported by MathDataset.subcategories(), print a header
# line followed by the indented entries listed under its 'arithmetic' key.
# Each data type is written with one print call; a blank line separates them.
for data_type, categories in MathDataset.subcategories().items():
  print(
      f'Data Type: {data_type}',
      *(f'  {subcat}' for subcat in categories['arithmetic']),
      sep='\n',
      end='\n\n',
  )

Print some example questions and answers.

In [None]:
# Load the 'add_or_sub' arithmetic sub-category of the 'train-easy' data.
test_dataset = MathDataset('train-easy', 'arithmetic', 'add_or_sub')

# The slice is unpacked into zip(), which pairs up its components item by
# item; each pair is printed as a question/answer block followed by an empty
# line.
for question, answer in zip(*test_dataset[5000:5005]):
  print(f'Question: {question}\nAnswer: {answer}\n')

## Setup for Transformer

Determine the token set for the `train-easy` arithmetic data. Takes ~30s.

In [None]:
import torch
import utils

# Names of every arithmetic sub-category available in the 'train-easy' data.
# NOTE(review): 'arthimetic' is a misspelling of 'arithmetic'; the name is
# kept as-is because other cells may refer to it.
arthimetic_easy_subcats = MathDataset.subcategories()['train-easy']['arithmetic']

# One MathDataset per sub-category ...
datasets = [
    MathDataset('train-easy', 'arithmetic', subcat_name)
    for subcat_name in arthimetic_easy_subcats
]

# ... concatenated into a single dataset covering all of them.
arithmetic_easy = torch.utils.data.ConcatDataset(datasets)

# Collect the token set over the combined dataset.
arithmetic_easy_tokens = utils.token_set(arithmetic_easy)

# Bare expression: shows the token set as the cell output.
arithmetic_easy_tokens

Evaluate the input and output sequence lengths. Takes ~1 min.

In [None]:
# Plot the length histogram for the combined arithmetic 'train-easy' dataset.
utils.plot_length_histogram(arithmetic_easy)

## Exploring Token Embeddings and Attention Masks

In [None]:
from transformer import Transformer, TokenEmbedding

# Transformer instance used by the token-embedding and attention-mask
# exploration cells that follow. All arguments are positional; the trailing
# comment on each line names the parameter it fills.
transformer = Transformer(
    TokenEmbedding(arithmetic_easy_tokens, 6),  # 6 = d_model
    30,   # = max_output_length
    5,    # = n_encoder_layers
    5,    # = n_decoder_layers
    1,    # = n_heads
    1024  # = d_ff
)

In [None]:
# Element 0 of each of the first five dataset items, used as sample questions.
sample_questions = [arithmetic_easy[i][0] for i in range(5)]
# Bare expression: shows the list as the cell output.
sample_questions

In [None]:
# Pass the sample questions through the token embedding's indices() method.
# Presumably this yields one row of token indices per question — confirm
# against TokenEmbedding.
sample_indices = transformer.token_embedding.indices(sample_questions)
sample_indices

In [None]:
# Feed the indices back through unembed() with its default arguments and
# display the result, for comparison with sample_questions.
sample_unembedded = transformer.token_embedding.unembed(sample_indices)
sample_unembedded

In [None]:
# Same unembed() call, this time with include_special=True.
# Presumably this keeps special tokens (e.g. padding) in the output — confirm
# against TokenEmbedding.unembed.
sample_unembedded_special = transformer.token_embedding.unembed(
    sample_indices, include_special=True
)
sample_unembedded_special

In [None]:
# Call the token embedding itself on the index batch and display what it
# returns.
sample_token_embeddings = transformer.token_embedding(sample_indices)
sample_token_embeddings

In [None]:
def print_full(matrix, digits='.02f'):
  """Print every entry of a 2-D matrix, one text line per matrix row.

  Args:
    matrix: Object exposing ``shape[0]``/``shape[1]`` and supporting
      ``matrix[r, c]`` indexing; each entry is converted with ``float``.
    digits: Format specification applied to every entry (for example
      ``'.02f'`` or ``'g'``).

  Each entry is followed by a single space, so lines end with a trailing
  space. One empty line is printed after the last row.
  """
  n_rows, n_cols = matrix.shape[0], matrix.shape[1]
  for r in range(n_rows):
    cells = [format(float(matrix[r, c]), digits) for c in range(n_cols)]
    print(''.join(cell + ' ' for cell in cells))
  print()


# Input attention masks for the whole batch of sample questions.
# The batch result has its own name so that the per-sample loop variable
# `mask` no longer rebinds the very collection being iterated.
input_masks = transformer.input_attention_mask(sample_indices)
for index_seq, mask in zip(sample_indices, input_masks):
  print('Token sequence indices and input mask matrix')
  # The index sequence is stacked as an extra first row above its mask, so
  # every mask column sits under the token index it belongs to.
  print_full(torch.cat([index_seq[None], mask]), digits='g')

  print('Post-softmax')
  # Softmax along dim=1 of the mask on its own, i.e. what each row looks like
  # after normalisation when nothing but the mask contributes.
  print_full(torch.nn.functional.softmax(mask, dim=1), digits='.4f')

In [None]:
# Element 1 of each of the first five dataset items (the counterpart of the
# sample questions, which use element 0), and the result of running those
# through the token embedding's indices() method.
sample_outputs = [arithmetic_easy[i][1] for i in range(5)]
sample_output_indices = transformer.token_embedding.indices(sample_outputs)

In [None]:
# Bare expression: shows the sample outputs as the cell output.
sample_outputs

In [None]:
# Bare expression: shows the indices computed for the sample outputs.
sample_output_indices

In [None]:
# Output self-attention masks for the whole batch of sample outputs.
# The batch result has its own name so that the per-sample loop variable
# `mask` no longer rebinds the very collection being iterated.
output_masks = transformer.output_attention_mask(sample_output_indices)
for index_seq, mask in zip(sample_output_indices, output_masks):
  print('Token sequence indices and output self-attention mask matrix')
  # The index sequence is stacked as an extra first row above its mask, so
  # every mask column sits under the token index it belongs to.
  print_full(torch.cat([index_seq[None], mask]), digits='g')

  print('Post-softmax')
  # Softmax along dim=1 of the mask on its own, i.e. what each row looks like
  # after normalisation when nothing but the mask contributes.
  print_full(torch.nn.functional.softmax(mask, dim=1), digits='.4f')

In [None]:
# Cross-attention masks computed from the input and output index batches.
# The batch result has its own name so that the per-sample loop variable
# `mask` no longer rebinds the very collection being iterated.
cross_masks = transformer.cross_attention_mask(sample_indices, sample_output_indices)
for index_seq, mask in zip(sample_indices, cross_masks):
  print('Token sequence indices and output cross-attention mask matrix')
  # The *input* index sequence is stacked above the mask as a header row.
  # NOTE(review): this presumes the mask columns correspond to input
  # positions — confirm against Transformer.cross_attention_mask.
  print_full(torch.cat([index_seq[None], mask]), digits='g')

  print('Post-softmax')
  # Softmax along dim=1 of the mask on its own, i.e. what each row looks like
  # after normalisation when nothing but the mask contributes.
  print_full(torch.nn.functional.softmax(mask, dim=1), digits='.4f')

## Transformer Training

In [None]:
from training import QATransformerTrainer, SavePeriodicallyCallback


# Model for the arithmetic 'train-easy' data, moved to the GPU. The
# positional arguments follow the same order as the annotated `transformer`
# instance created earlier in this notebook.
model = Transformer(
    TokenEmbedding(arithmetic_easy_tokens, 512),  # 512 = d_model
    30,    # = max_output_length
    6,     # = n_encoder_layers
    6,     # = n_decoder_layers
    8,     # = n_heads
    2048,  # = d_ff
    p_dropout=0.1,
).cuda()

# Shuffled mini-batches of 128 items from the combined arithmetic dataset.
ar_easy_dl = torch.utils.data.DataLoader(
    arithmetic_easy,
    batch_size=128,
    shuffle=True,
)

# Adam over all model parameters.
optim = torch.optim.Adam(
    model.parameters(),
    lr=6e-4,
    betas=(0.9, 0.995),
    eps=1e-9,
)

# Cross-entropy with label smoothing; targets equal to the embedding's
# pad_index are ignored.
cel = torch.nn.CrossEntropyLoss(
    ignore_index=model.token_embedding.pad_index,
    label_smoothing=0.05,
)

def loss_fn(prob, actual):
    """Return the cross-entropy loss of `prob` against `actual`.

    Args:
        prob: Score tensor passed to ``cel``; addressed here as
            ``[batch, :, position]``. It is modified in place.
        actual: Target token indices, addressed as ``[batch, position]``.

    Returns:
        The value of ``cel(prob, actual)`` computed after the in-place
        overwrite described below.
    """
    # Wherever the target is the padding index, every score at that
    # (batch, position) is overwritten with -1000. The boolean mask gets a
    # singleton middle axis so it broadcasts across dimension 1 of `prob`.
    is_pad = actual == model.token_embedding.pad_index
    prob.masked_fill_(is_pad.unsqueeze(1), -1000.)
    return cel(prob, actual)

# Trainer wiring together the model, data loader, optimizer and loss function.
# NOTE(review): the meaning of the first argument ('model3') and the last
# (100) is defined by QATransformerTrainer and is not visible here — confirm.
trainer = QATransformerTrainer('model3', model, ar_easy_dl, optim, loss_fn, 100)

# save every 900s = 15min
save_callback = SavePeriodicallyCallback(trainer, 900)

In [None]:
# Train for two epochs, passing the periodic-save callback as a batch
# callback. The call returns two values, kept as `losses` and `accuracies`.
losses, accuracies = trainer.train(
    epochs=2,
    batch_callbacks=[save_callback],
    verbosity=1
)

In [None]:
from simple_dataset import SimpleDataset1

# SimpleDataset1 constructed with the argument one million.
simple_dataset = SimpleDataset1(1_000_000)

In [None]:
from training import QATransformerTrainer, SavePeriodicallyCallback

# Model for the simple dataset, moved to the GPU. The positional arguments
# follow the same order as the annotated `transformer` instance created
# earlier in this notebook. This rebinds the global `model`.
model = Transformer(
    TokenEmbedding(SimpleDataset1.tokens(), 256),  # 256 = d_model
    130,   # = max_output_length
    6,     # = n_encoder_layers
    6,     # = n_decoder_layers
    8,     # = n_heads
    1024,  # = d_ff
    p_dropout=0.1,
).cuda()

# Shuffled mini-batches of 256 items from the simple dataset.
simple_dl = torch.utils.data.DataLoader(
    simple_dataset,
    batch_size=256,
    shuffle=True,
)

# Adam over all model parameters.
optim = torch.optim.Adam(
    model.parameters(),
    lr=1e-5,
    betas=(0.9, 0.995),
    eps=1e-9,
)

# Cross-entropy with label smoothing; targets equal to the embedding's
# pad_index are ignored.
cel = torch.nn.CrossEntropyLoss(
    ignore_index=model.token_embedding.pad_index,
    label_smoothing=0.05,
)

def loss_fn(prob, actual):
    """Return the cross-entropy loss of `prob` against `actual`.

    Args:
        prob: Score tensor passed to ``cel``; addressed here as
            ``[batch, :, position]``. It is modified in place.
        actual: Target token indices, addressed as ``[batch, position]``.

    Returns:
        The value of ``cel(prob, actual)`` computed after the in-place
        overwrite described below.
    """
    # Wherever the target is the padding index, every score at that
    # (batch, position) is overwritten with -1000. The boolean mask gets a
    # singleton middle axis so it broadcasts across dimension 1 of `prob`.
    is_pad = actual == model.token_embedding.pad_index
    prob.masked_fill_(is_pad.unsqueeze(1), -1000.)
    return cel(prob, actual)

# Trainer for the simple-dataset model.
# NOTE(review): the first argument is 'model3', the same value used for the
# arithmetic trainer above. If it names the save location, this run may
# overwrite the earlier model's saves — confirm against QATransformerTrainer.
trainer = QATransformerTrainer('model3', model, simple_dl, optim, loss_fn, 100)

# save every 900s = 15min
save_callback = SavePeriodicallyCallback(trainer, 900)

In [None]:
# Train for two epochs, passing the periodic-save callback as a batch
# callback. Rebinds `losses` and `accuracies` with this run's results.
losses, accuracies = trainer.train(
    epochs=2,
    batch_callbacks=[save_callback],
    verbosity=1
)

In [None]:
# Take one batch from the simple-dataset loader and print, for every item,
# the question, the expected answer and what the model returns for the
# question, followed by an empty line.
batch = next(iter(simple_dl))

# NOTE(review): the model is left in eval mode after this cell; call
# model.train() before training again unless the trainer does so itself.
model.eval()
with torch.no_grad():
    for q, a in zip(*batch):
        print(q, a, sep='\n')
        print(model(q), end='\n\n')