In [4]:

from collections import defaultdict
from itertools import islice
import random
import time
from pathlib import Path
import math

import numpy as np
from tqdm.auto import tqdm

import torch
import torch.nn as nn

from multigrok.data import ArithmeticDataset
from multigrok.transformer import Transformer

def cycle(iterable):
    while True:
        for x in iterable:
            yield x

In [5]:
operators = ['+']
p = 59
optimization_steps = 100000
log_freq = math.ceil(optimization_steps / 1000)
batch_size = -1                 # -1 -> entire dataset, 0 < batch_size < 1 -> fraction of dataset, batch_size > 1 -> batch_size
n_layers = 2
n_heads = 8
d_model = 256
dropout = 0.0
non_linearity = 'relu'          # 'relu' or 'gelu'
training_data_fraction = 0.8

halve_abelian = False
only_input_tokens = False

embedding_lr = 1e-3
decoder_lr = 1e-3
embedding_weight_decay = 0.0
decoder_weight_decay = 0.0
eps = 1e-8

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
dtype = torch.float32

In [6]:
operators = list(operators)

torch.set_default_dtype(dtype)
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)

dataset = ArithmeticDataset(operators, p=p, halve_abelian=halve_abelian, only_input_tokens=only_input_tokens)
n_train, n_val = int(training_data_fraction * len(dataset)), len(dataset) - int(training_data_fraction * len(dataset)) 
train, val = torch.utils.data.random_split(dataset, [n_train, n_val], torch.Generator().manual_seed(seed))
if batch_size == -1:
    bs = len(train)
elif 0 < batch_size < 1:
    bs = int(batch_size * len(train))
elif batch_size > 1 and batch_size <= len(train):
    bs = int(batch_size)
else:
    raise Exception(f"Invalid batch_size config {batch_size}.")
train_loader = torch.utils.data.DataLoader(train, batch_size=bs, shuffle=True)

model = Transformer(
    n_layers=n_layers,
    n_heads=n_heads,
    d_model=d_model,
    dropout=dropout,
    max_context_len=4, # TODO: make this a configurable parameter?
    vocab_len=dataset.ntokens,
    non_linearity=non_linearity,
    weight_noise=0 # TODO: make this configurable?
).to(device)

optimizer = torch.optim.AdamW(
    [{
        "params": model.embedding.parameters(),
        "lr": embedding_lr,
        "weight_decay": embedding_weight_decay,
        "eps": eps
    },
    {
        "params": list(model.decoder.parameters()) + list(model.linear.parameters()),
        "lr": decoder_lr,
        "weight_decay": decoder_weight_decay,
        "eps": eps
    }]
)

loss_fn = nn.CrossEntropyLoss()

# # prepare for logging
# ex.info['log_steps'] = []
# ex.info['total'] = {
#     'train': {
#         'loss': [],
#         'accuracy': []
#     },
#     'val': {
#         'loss': [],
#         'accuracy': []
#     }
# }
# for op in dataset.operators:
#     ex.info[op] = {
#     'train': {
#         'loss': [],
#         'accuracy': []
#     },
#     'val': {
#         'loss': [],
#         'accuracy': []
#     }
# }

pos = dataset.sequence_length - 1
steps = 0
with tqdm(total=optimization_steps) as pbar:
    for equation, answer in islice(cycle(train), optimization_steps):

        if steps % log_freq == 0:
            eval_loss_fn = nn.CrossEntropyLoss(reduction='sum')
            with torch.no_grad():

                # compute train metrics
                train_evaluation_loader = torch.utils.data.DataLoader(train, batch_size=min(500, len(train)), shuffle=False)
                ops_losses = defaultdict(float)
                ops_accuracies = defaultdict(int)
                ops_totals = defaultdict(int)
                for e_e, e_a in train_evaluation_loader:
                    e_e = e_e.to(device)
                    e_a = e_a.to(device)
                    logits, _, _ = model(e_e, pos=pos)
                    for i in range(e_e.shape[0]):
                        if only_input_tokens:
                            op = dataset.operators[0]
                        else:
                            op = dataset.operation_from_token(e_e[i][1])
                        ops_losses[op] += eval_loss_fn(logits[i:i+1], e_a[i:i+1]).item()
                        predicted_token = torch.argmax(logits[i:i+1]).item()
                        ops_accuracies[op] += int(predicted_token == e_a[i:i+1].item())
                        ops_totals[op] += 1
#                 for op in dataset.operators:
#                     ex.info[op]['train']['loss'].append(ops_losses[op] / ops_totals[op])
#                     ex.info[op]['train']['accuracy'].append(ops_accuracies[op] / ops_totals[op])
#                 ex.info['total']['train']['loss'].append(sum(ops_losses.values()) / sum(ops_totals.values()))
#                 ex.info['total']['train']['accuracy'].append(sum(ops_accuracies.values()) / sum(ops_totals.values()))

                # compute test metrics
                val_evaluation_loader = torch.utils.data.DataLoader(val, batch_size=min(500, len(val)), shuffle=False)
                ops_losses = defaultdict(float)
                ops_accuracies = defaultdict(int)
                ops_totals = defaultdict(int)
                for e_e, e_a in val_evaluation_loader:
                    e_e = e_e.to(device)
                    e_a = e_a.to(device)
                    logits, _, _ = model(e_e, pos=pos)
                    for i in range(e_e.shape[0]):
                        if only_input_tokens:
                            op = dataset.operators[0]
                        else:
                            op = dataset.operation_from_token(e_e[i][1])
                        ops_losses[op] += eval_loss_fn(logits[i:i+1], e_a[i:i+1]).item()
                        predicted_token = torch.argmax(logits[i:i+1]).item()
                        ops_accuracies[op] += int(predicted_token == e_a[i:i+1].item())
                        ops_totals[op] += 1
#                 for op in dataset.operators:
#                     ex.info[op]['val']['loss'].append(ops_losses[op] / ops_totals[op])
#                     ex.info[op]['val']['accuracy'].append(ops_accuracies[op] / ops_totals[op])
#                 ex.info['total']['val']['loss'].append(sum(ops_losses.values()) / sum(ops_totals.values()))
#                 ex.info['total']['val']['accuracy'].append(sum(ops_accuracies.values()) / sum(ops_totals.values()))
#             pbar.set_description("{0:2.1f}% | {1:2.1f}%".format(ex.info['total']['train']['accuracy'][-1] * 100, ex.info['total']['val']['accuracy'][-1] * 100))

        equation = equation.to(device)
        answer = answer.to(device)
        logits, _, _ = model(equation, pos=pos)
        loss = loss_fn(logits, answer)
        loss.backward()
        optimizer.step()
        steps += 1
        pbar.update(1)


  0%|          | 0/100000 [00:00<?, ?it/s]

IndexError: too many indices for tensor of dimension 2

In [11]:
equation, answer = next(islice(cycle(train), optimization_steps))

In [None]:
torch.set_default_dtype(dtype)
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)

In [None]:
dataset = ArithmeticDataset(operators, p=p, halve_abelian=halve_abelian, only_input_tokens=only_input_tokens)
n_train, n_val = int(training_data_fraction * len(dataset)), len(dataset) - int(training_data_fraction * len(dataset)) 
train, val = torch.utils.data.random_split(dataset, [n_train, n_val], torch.Generator().manual_seed(seed))
if batch_size == -1:
    bs = len(train)
elif 0 < batch_size < 1:
    bs = int(batch_size * len(train))
elif batch_size > 1 and batch_size <= len(train):
    bs = int(batch_size)
else:
    raise Exception(f"Invalid batch_size config {batch_size}.")
train_loader = torch.utils.data.DataLoader(train, batch_size=bs, shuffle=True)

In [None]:
blah, bleh = next(iter(train_loader))

In [None]:
bleh.shape

In [None]:
model = Transformer(
    n_layers=n_layers,
    n_heads=n_heads,
    d_model=d_model,
    dropout=dropout,
    max_context_len=4, # TODO: make this a configurable parameter?
    vocab_len=dataset.ntokens,
    non_linearity=non_linearity,
    weight_noise=0 # TODO: make this configurable?
).to(device)

In [None]:
eval_loss_fn = nn.CrossEntropyLoss(reduction='sum')

In [None]:
logits, _, _ = model(blah.to(device), pos=3)

In [None]:
eval_loss_fn(logits, bleh.to(device))

In [None]:
eval_loss_fn(logits[10:11], bleh[10:11].to(device))

In [None]:
torch.argmax(logits[10:11]).item()

In [None]:
logits[10:11]

In [None]:
int(5 == 4)

In [None]:
d = defaultdict(int)

In [None]:
d['+'] += 1

In [None]:
bleh[10:11].item()

In [None]:
"{0:2.1f}% | {1:2.1f}%".format(30.2823, 10.0123)

In [None]:
"{0:2.1f}% | {1:2.1f}%".format(0.1234 * 100, 0.1234 * 100)

In [None]:
bleh.shape

In [None]:
eval_loss_fn = nn.CrossEntropyLoss(reduction='sum')
ops_losses = {}
ops_accuracies = {}
with torch.no_grad():
    train_evaluation_loader = torch.utils.data.DataLoader(train, batch_size=min(500, len(train)), shuffle=False)
    val_evaluation_loader = torch.utils.data.DataLoader(val, batch_size=min(500, len(val)), shuffle=False)
    for t_e_e, t_e_a in train_evaluation_loader:
        t_e_e = t_e_e.to(device)
        t_e_a = t_e_a.to(device)
        logits, _, _ = model(t_e_e, pos=pos)
        for i in range(t_e_e.shape[0]):
            if only_input_tokens:
                op = dataset.operations[0]
            else:
                op = dataset.operation_from_token(t_e_e[i][1])
            ops_losses[op] += eval_loss_fn(logits[i:i+1], t_e_a[i:i+1])
            predicted_token = 

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn

# from grok import data, transformer

from multigrok import data, transformer
# from x_transformers import TransformerWrapper, Decoder

In [None]:
dataset = data.ArithmeticDataset('+', p=59)

In [None]:
len(dataset)

In [None]:
dataset[10]

In [None]:
n = 3000
dataset.readable_equation(dataset[n][0]), dataset[n][1].item()

In [None]:
dataset.ntokens

In [None]:
train, val = torch.utils.data.random_split(dataset, [len(dataset) // 2, len(dataset) // 2], torch.Generator().manual_seed(0))


In [None]:
train[0]

In [None]:
n = 10
print(train[n])
dataset.readable_equation(train[n][0])

In [None]:
torch.optim.AdamW?

In [None]:
torch.utils.data.DataLoader?

In [None]:
model.forward?

In [None]:
nn.CrossEntropyLoss(reduction='sum')

In [None]:
p = 59
for a in range(1, p):
    a_inverse = a**(p-2) % p
    assert (a * a_inverse) % p == 1, f"{a} {a_inverse}"

In [None]:
nn.CrossEntropyLoss?

In [None]:
x = torch.randint(low=0, high=331, size=(500, 4))

In [None]:
x[499:500]

In [None]:
x[-1]

In [None]:
data.ArithmeticDataset.make_data('+')

In [None]:
tokenizer = data.ArithmeticTokenizer()

In [None]:
model = transformer.Transformer(n_layers=3, 
                n_heads=8, 
                d_model=256, 
                max_context_len=6,
                vocab_len=150)

In [None]:
# model = TransformerWrapper(
#     num_tokens=150,
#     max_seq_len=6,
#     attn_layers=Decoder(
#         dim=256,
#         heads=8,
#         depth=3
#     )
# )

In [None]:
# activations = {}
# def hook(module, input, output):
#     activations[module] = output

In [None]:
activations_inputs = {}
activations_outputs = {}
def create_hook_for_named_module(name):
    def hook(module, input, output):
        activations_inputs[name] = input
        activations_outputs[name] = output
    return hook

hooks = []
for name, module in model.named_modules():
    h = module.register_forward_hook(create_hook_for_named_module(name))
    hooks.append(h)

In [None]:
x = tokenizer.encode(["1 + 2", "2 + 3 + 4 ="])
print(x.shape)
y = model(x)[0]

In [None]:
for name, _ in model.named_modules():
    print(name)

In [None]:
activations_outputs['embedding']

In [None]:
activations_inputs['decoder.blocks.0.self_attn'][0].shape

In [None]:
activations_inputs['decoder']

In [None]:
(torch.rand_like(model.embedding.weight) * 0.02).dtype

In [None]:
activations_outputs

In [None]:
activation_inputs

In [None]:
activations

In [None]:
activations['embedding'].shape

In [None]:
for name, _ in model.named_modules():
    print(name)

In [None]:
activations['embedding'].dtype

In [None]:
model.decoder.blocks[0].self_attn.attn_heads[0]

In [None]:
model.decoder.blocks[0].self_attn.attn_heads[0].Wq.weight.dtype

In [None]:
for i in range(8):
    print(model.decoder.blocks[0].self_attn.attn_heads[i].Wq.weight.dtype)

In [None]:
for parameter in model.parameters():
    print(parameter.dtype)

In [None]:
activations[model.embedding].dtype

In [None]:
for name, module in model.named_modules():
    print(name)

In [None]:
model.decoder.blocks[1]

In [None]:
x = torch.randint(low=0, high=150, size=(10, 6))

In [None]:
y = model(x)

In [None]:
y[0].shape

In [None]:
model(x, pos=5)[0].shape

In [None]:
model(tokenizer.encode(["1 + 2 =", "2 + 3 ="]), pos=3)[0].shape

In [None]:
transformer.Transformer

In [None]:
model

In [None]:
from phasegrok import models

In [None]:
models.Transformer(vocab_len=150, embedding_dim=256, output_dim=150, depth=3, concat=True)

In [None]:
model(x)[0].shape