In [160]:
import matplotlib as plt
import numpy as np
from torch import device

from plot import get_parameter_norms, plot_dicts

%matplotlib inline
# %debug

In [161]:
import math
from argparse import ArgumentParser
from datetime import datetime
from itertools import permutations
import copy

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.pyplot import xscale
from tqdm import tqdm
import seaborn as sns

import torch
from torch import device
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split

In [162]:
class Block(nn.Module):
    """Causal transformer block
    """

    def __init__(self, dim, num_heads):
        super().__init__()
        self.ln_1 = nn.LayerNorm(dim)
        self.ln_2 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim),
        )

    def forward(self, x):
        attn_mask = torch.full(
            (len(x), len(x)), -float("Inf"), device=x.device, dtype=x.dtype
        )
        attn_mask = torch.triu(attn_mask, diagonal=1)
        attn_mask[torch.isnan(attn_mask)] = 0.0 # fixes all 'nan' on 'mps' device

        x = self.ln_1(x)
        a, _ = self.attn(x, x, x, attn_mask=attn_mask, need_weights=False)
        x = x + a
        m = self.mlp(self.ln_2(x))
        x = x + m
        return x


class Decoder(nn.Module):
    """Causal Transformer decoder
    """

    def __init__(self, dim=128, num_layers=2, num_heads=4, num_tokens=97, seq_len=5):
        super().__init__()
        self.token_embeddings = nn.Embedding(num_tokens, dim)
        self.position_embeddings = nn.Embedding(seq_len, dim)
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(Block(dim=dim, num_heads=num_heads))

        self.ln_f = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, num_tokens, bias=False)

    def forward(self, x):
        h = self.token_embeddings(x)
        positions = torch.arange(x.shape[0], device=x.device).unsqueeze(-1)
        h = h + self.position_embeddings(positions).expand_as(h)
        for layer in self.layers:
            h = layer(h)

        h = self.ln_f(h)
        logits = self.head(h)
        return logits



In [163]:
def get_plot_infix(args):
    # plot model architecture infix
    ff = datetime.now().strftime("%f")
    plot_infix = f"l{args.num_layers}_h{args.num_heads}_e{args.embedding}_{ff}"
    return plot_infix

In [164]:
# replace with read_args(sys.argv[1:]) in python
def read_args(args):
    parser = ArgumentParser(description="Grokfast")
    
    print(f"provided args: {args}")
    
    # architecture parameters
    parser.add_argument("--embedding", type=int, default=128)
    parser.add_argument("--num_layers", type=int, default=2)
    parser.add_argument("--num_heads", type=int, default=4)

    # run params
    parser.add_argument("--label", default="")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--p", type=int, default=97)
    parser.add_argument("--budget", type=int, default=3e5)
    parser.add_argument("--batch_size", type=int, default=512)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--beta1", type=float, default=0.9)
    parser.add_argument("--beta2", type=float, default=0.98)
    parser.add_argument("--weight_decay", type=float, default=0)
    parser.add_argument("--optimizer", default="Adam")

    # Grokfast
    parser.add_argument("--filter", type=str, choices=["none", "ma", "ema", "fir"], default="none")
    parser.add_argument("--alpha", type=float, default=0.99)
    parser.add_argument("--window_size", type=int, default=100)
    parser.add_argument("--lamb", type=float, default=5.0)

    # Ablation studies
    parser.add_argument("--two_stage", action='store_true')
    parser.add_argument("--save_weights", action='store_true')

    args = parser.parse_args(args=args)
    
    args.plot_infix = get_plot_infix(args=args)
    
    filter_str = ('_' if args.label != '' else '') + args.filter
    window_size_str = f'_w{args.window_size}'
    alpha_str = f'_a{args.alpha:.3f}'.replace('.', '')
    lamb_str = f'_l{int(args.lamb)}'

    if args.filter == 'none':
        filter_suffix = ''
    elif args.filter == 'ma':
        filter_suffix = window_size_str + lamb_str
    elif args.filter == 'ema':
        filter_suffix = alpha_str + lamb_str
    else:
        raise ValueError(f"Unrecognized filter type {args.filter}")

    optim_suffix = ''
    if args.weight_decay != 0:
        optim_suffix = optim_suffix + f'_wd{args.weight_decay:.1e}'.replace('.', '')
    if args.lr != 1e-3:
        optim_suffix = optim_suffix + f'_lrx{int(args.lr / 1e-3)}'

    args.label = args.label + filter_str + filter_suffix + optim_suffix
    print(f'Experiment results saved under name: {args.label}')

    
    return args

In [165]:
simulated_args = ['--embedding', '64', '--num_heads', '2']
print(simulated_args)
args = read_args(simulated_args)

['--embedding', '64', '--num_heads', '2']
provided args: ['--embedding', '64', '--num_heads', '2']
Experiment results saved under name: none


In [67]:
args

Namespace(embedding=64, num_layers=2, num_heads=2, label='none', seed=0, p=97, budget=300000.0, batch_size=512, lr=0.001, beta1=0.9, beta2=0.98, weight_decay=0, optimizer='Adam', filter='none', alpha=0.99, window_size=100, lamb=5.0, two_stage=False, save_weights=False, plot_infix='l2_h2_e64_696935')

In [68]:
args.p

97

In [166]:
def multiplication_mod_p_data(p, eq_token, op_token):
    """x◦y = x/y (mod p) for 0 ≤ x < p, 0 < y < p
    """
    x = torch.arange(p)
    y = torch.arange(1, p)
    x, y = torch.cartesian_prod(x, y).T

    eq = torch.ones_like(x) * eq_token
    op = torch.ones_like(x) * op_token
    result = x * y % p

    # "All of our experiments used a small transformer trained on datasets of
    # equations of the form a◦b = c, where each of “a”, “◦”, “b”, “=”, and “c”
    # is a seperate token"
    return torch.stack([x, op, y, eq, result])

In [140]:
data = multiplication_mod_p_data(args.p, args.p, args.p)

In [79]:
n = 5891
data[:, n:n+10]

tensor([[61, 61, 61, 61, 61, 61, 61, 61, 61, 61],
        [97, 97, 97, 97, 97, 97, 97, 97, 97, 97],
        [36, 37, 38, 39, 40, 41, 42, 43, 44, 45],
        [97, 97, 97, 97, 97, 97, 97, 97, 97, 97],
        [62, 26, 87, 51, 15, 76, 40,  4, 65, 29]])

In [167]:
def build_model(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # tokens for <op> and <=>. It's not clear why <=> is needed at all since it
    # has no effect on the output, but we'll leave it in to best follow the
    # paper.
    eq_token = args.p
    op_token = args.p + 1

    # "We trained a standard decoder-only transformer (Vaswani et al., 2017)
    # with causal attention masking, and calculated loss and accuracy only on
    # the answer part of the equation. For all experiments we used a
    # transformer with 2 layers, width 128, and 4 attention heads"
    model = Decoder(
        dim=args.embedding,
        num_layers=args.num_layers,
        num_heads=args.num_heads,
        num_tokens=args.p + 2,
        seq_len=5
    ).to(device)
    print_model(model)
    return model

In [168]:
def print_model(model):
    nparams = sum([p.numel() for p in model.parameters() if p.requires_grad])
    print(model)
    print(f'Total number of parameters: {nparams}')
    return

In [169]:
class TransformerDataset(Dataset):
    def __init__(self, token_array, labels=None):
        """
        token_array: numpy array where each column is a sequence of token IDs for one example
        labels: optional array of labels for each example
        """
        self.data = torch.tensor(token_array.T)  # Transpose to make each row an example
        self.labels = None if labels is None else torch.tensor(labels)
        
    def __len__(self):
        return self.data.shape[0]
    
    def __getitem__(self, idx):
        if self.labels is not None:
            return self.data[idx], self.labels[idx]
        return self.data[idx]

In [170]:
def build_dataloader(token_array, labels=None):
    # Create dataset and dataloader
    dataset = TransformerDataset(token_array, labels)
    dataloader = DataLoader(
        dataset,
        batch_size=32,
        shuffle=True,
        num_workers=4
    )
    return dataloader

In [171]:
eq_token = args.p
op_token = args.p + 1
data = multiplication_mod_p_data(p=args.p, eq_token=args.p, op_token=args.p)


In [146]:
dataset = TransformerDataset(data.clone().detach().numpy())

In [96]:
dataloader = build_dataloader(token_array=data.clone().detach().numpy())

In [172]:
class TransformerDataset(Dataset):
    def __init__(self, token_array, labels=None):
        self.data = torch.tensor(token_array.T)  # Transpose to make each row an example
        self.labels = None if labels is None else torch.tensor(labels)
        
    def __len__(self):
        return self.data.shape[0]
    
    def __getitem__(self, idx):
        if self.labels is not None:
            return self.data[idx], self.labels[idx]
        return self.data[idx]

def create_train_test_dataloaders(token_array, labels=None, train_ratio=0.8, batch_size=32):
    dataset = TransformerDataset(token_array, labels)
    
    # Calculate lengths for split
    train_size = int(train_ratio * len(dataset))
    test_size = len(dataset) - train_size
    
    # Split dataset
    train_dataset, test_dataset = random_split(
        dataset, [train_size, test_size], 
        generator=torch.Generator().manual_seed(42)  # For reproducibility
    )
    
    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)
    
    return train_loader, test_loader

In [173]:
train_loader, test_loader = create_train_test_dataloaders(token_array=data.clone().detach().numpy(), train_ratio=0.5)

In [127]:
len(dataset)

9312

In [113]:
# Usage in training loop
for batch in train_loader:
    if len(batch) == 2:
        inputs, targets = batch
    else:
        inputs = batch
    pass
    # Forward pass, loss calculation, etc.

In [174]:
class DataLogger:
    def __init__(self):
        # Initialize an empty dictionary to store logs
        self.logs = {}

    def log_data(self, category, key, value):
        """
        Log data under a specific category and key.

        :param category: The category or subset of data logs.
        :param key: The key within the category to store the value.
        :param value: The value to log.
        """
        if category not in self.logs:
            self.logs[category] = {}
        self.logs[category][key] = value

    def update_category(self, category, data_dict):
        """
        Add or update a whole sub-dictionary for a specific category.

        :param category: The category to update.
        :param data_dict: The dictionary containing data to add or update.
        """
        if category not in self.logs:
            self.logs[category] = {}
        self.logs[category].update(data_dict)
        
    def update_category_means(self, category, data_dict):
        # info compute means of category dict and update  category dict
        for key in data_dict:
            data_dict[key] = data_dict[key].mean()
        self.update_category(category, data_dict)
        
    def get_all_logs(self):
        return self.logs.copy()

    def get_logs(self, category=None):
        """
        Retrieve logs for a specific category or all logs if no category is specified.

        :param category: The category to retrieve logs for. If None, retrieve all logs.
        :return: The logs for the specified category or all logs.
        """
        if category:
            return self.logs.get(category, {})
        return self.logs

    def clear_logs(self, category=None):
        """
        Clear logs for a specific category or all logs if no category is specified.

        :param category: The category to clear logs for. If None, clear all logs.
        """
        if category:
            if category in self.logs:
                del self.logs[category]
        else:
            self.logs.clear()

# # Example usage
# logger = DataLogger()
# logger.log_data('temperature', 'sensor1', 25.5)
# logger.log_data('temperature', 'sensor2', 26.0)
# 
# # Update the 'temperature' category with a new sub-dictionary
# new_temperature_data = {'sensor3': 27.0, 'sensor4': 28.5}
# logger.update_category('temperature', new_temperature_data)
# 
# # Retrieve and print logs
# temperature_logs = logger.get_logs('temperature')
# print("Temperature Logs:", temperature_logs)



In [175]:
def train_one_epoch(model, loader, criterion, optimizer, scheduler, device):
    model.train()
    loss_sum = torch.zeros(1, device=device)
    all_inputs = torch.zeros(1, device=device)
    total_acc = torch.zeros(1, device=device)
    with torch.set_grad_enabled(True):
        for k, batch in enumerate(loader):
            inputs, targets = batch[:-1], batch[-1]
            inputs, targets = inputs.to(device), targets.to(device)
            logits = model(inputs)
            loss = criterion(logits[-1], targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()
            loss_sum += loss.item() * inputs.size(0)
            all_inputs += inputs.size(0)
            acc = (logits[-1].argmax(-1) == input[-1]).float().mean()
            total_acc += acc.item() * input.size[-1]
    return total_acc / len(all_inputs), loss_sum / len(all_inputs)

def validate_epoch(model, loader, criterion, device):
    model.eval()
    loss_sum = torch.zeros(1, device=device)
    all_inputs = torch.zeros(1, device=device)
    total_acc = torch.zeros(1, device=device)
    with torch.no_grad:
        for k, batch in enumerate(loader):
            inputs, targets = batch[:-1], batch[-1]
            inputs, targets = inputs.to_device(), targets.to_device()
            logits = model(inputs)
            loss = criterion(logits[-1], targets)
            loss_sum += loss.item() * inputs.size(0)
            all_inputs += inputs.size(0)
            acc = (logits[-1].argmax(-1)) == targets.float().mean()
            total_acc += acc.item() * input.size(0)
    return total_acc / len(all_inputs), loss_sum / len(all_inputs)

In [177]:
def set_optimizer(model, args):
    optimizer = getattr(torch.optim, args.optimizer)(
        model.parameters(),
        lr=args.lr,
        weight_decay=args.weight_decay,
        betas=(args.beta1, args.beta2),
    )

    #  linear learning rate warmup over the first 10 updates
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lambda update: 1 if update > 10 else update / 10
    )
    return optimizer, scheduler

In [185]:
def main(args):
    # info data logging
    # Example usage
    logger = DataLogger()
    logger.log_data('temperature', 'sensor1', 25.5)
    logger.log_data('temperature', 'sensor2', 26.0)
    logger.log_data('humidity', 'sensor1', 45.0)
    
    # Retrieve and print logs
    temperature_logs = logger.get_logs('temperature')
    humidity_logs = logger.get_logs('humidity')
    
    print("Temperature Logs:", temperature_logs)
    print("Humidity Logs:", humidity_logs)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    eq_token = args.p
    op_token = args.p + 1
    data = multiplication_mod_p_data(p=args.p, eq_token=eq_token, op_token=op_token)
    train_loader, test_loader = create_train_test_dataloaders(token_array=data.clone().detach().numpy(), 
                                                              train_ratio=0.5, batch_size=args.batch_size)
    plot_infix = get_plot_infix(args=args)
    model = build_model(args=args)
    optimizer, scheduler = set_optimizer(model=model, args=args)
    steps_per_epoch = math.ceil(data.shape[1] / args.batch_size)
    plot_interval = 10
    
    
    for epoch in tqdm(range(int(args.budget // steps_per_epoch))):
        trn_acc, trn_loss = train_one_epoch(model=model,loader=train_loader, 
                                            criterion=torch.nn.CrossEntropyLoss(),
                                            optimizer=optimizer, scheduler=scheduler,device=device)
        if epoch % 2 == 0:
            vld_acc, vld_loss = validate_epoch(model=model,loader=test_loader,)
            # info log accuracy and loss for both train and validate
            logger.log_data('accuracy', 'train', trn_acc)
            logger.log_data('accuracy', 'valid', vld_acc)
            logger.log_data('accuracy', 'epoch', epoch)
            logger.log_data('loss', 'train', trn_loss)
            logger.log_data('loss', 'valid', vld_loss)
            logger.log_data('loss', 'epoch', epoch)
            norms = get_parameter_norms(model)
            logger.update_category_means('norms', norms)
            logger.log_data('norms', 'epoch', epoch)
        if epoch % plot_interval == 0:
            plot_dicts(logger.get_all_logs(), plot_infix=plot_infix)

In [186]:
main(args=args)

Temperature Logs: {'sensor1': 25.5, 'sensor2': 26.0}
Humidity Logs: {'sensor1': 45.0}
Decoder(
  (token_embeddings): Embedding(99, 64)
  (position_embeddings): Embedding(5, 64)
  (layers): ModuleList(
    (0-1): 2 x Block(
      (ln_1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (ln_2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
      )
      (mlp): Sequential(
        (0): Linear(in_features=64, out_features=256, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=256, out_features=64, bias=True)
      )
    )
  )
  (ln_f): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
  (head): Linear(in_features=64, out_features=99, bias=False)
)
Total number of parameters: 113088


  0%|          | 0/15789 [00:00<?, ?it/s]


IndexError: index out of range in self

In [187]:
%debug

> [0;32m/Users/igor/miniforge3/envs/mini/lib/python3.10/site-packages/torch/nn/functional.py[0m(2237)[0;36membedding[0;34m()[0m
[0;32m   2235 [0;31m        [0;31m# remove once script supports set_grad_enabled[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   2236 [0;31m        [0m_no_grad_embedding_renorm_[0m[0;34m([0m[0mweight[0m[0;34m,[0m [0minput[0m[0;34m,[0m [0mmax_norm[0m[0;34m,[0m [0mnorm_type[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 2237 [0;31m    [0;32mreturn[0m [0mtorch[0m[0;34m.[0m[0membedding[0m[0;34m([0m[0mweight[0m[0;34m,[0m [0minput[0m[0;34m,[0m [0mpadding_idx[0m[0;34m,[0m [0mscale_grad_by_freq[0m[0;34m,[0m [0msparse[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   2238 [0;31m[0;34m[0m[0m
[0m[0;32m   2239 [0;31m[0;34m[0m[0m
[0m
--KeyboardInterrupt--

KeyboardInterrupt: Interrupted by user


In [184]:
steps_per_epoch = math.ceil(data.shape[1] / args.batch_size)
int(range(args.budget) // steps_per_epoch)

TypeError: 'float' object cannot be interpreted as an integer

In [188]:
%debug

> [0;32m/Users/igor/miniforge3/envs/mini/lib/python3.10/site-packages/ipykernel/kernelbase.py[0m(1325)[0;36m_input_request[0;34m()[0m
[0;32m   1323 [0;31m                [0;31m# re-raise KeyboardInterrupt, to truncate traceback[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1324 [0;31m                [0mmsg[0m [0;34m=[0m [0;34m"Interrupted by user"[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1325 [0;31m                [0;32mraise[0m [0mKeyboardInterrupt[0m[0;34m([0m[0mmsg[0m[0;34m)[0m [0;32mfrom[0m [0;32mNone[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1326 [0;31m            [0;32mexcept[0m [0mException[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m
'Interrupted by user'
--KeyboardInterrupt--

KeyboardInterrupt: Interrupted by user
