In [1]:
from functools import partial # There are going to be some things we want to initialize lazily to economize on resources and reuse constructor calls.
import torch
# everything will use the same tokenizer
from transformers import AutoTokenizer
# Hugging Face model id whose tokenizer is shared by everything in this notebook.
mistral = "mistralai/Mistral-7B-v0.1"
tokenizer = AutoTokenizer.from_pretrained(
    mistral,
    padding_side="right",
)
# Reuse the end-of-sequence token for padding (presumably because this
# tokenizer defines no dedicated pad token -- TODO confirm).
tokenizer.pad_token = tokenizer.eos_token

In [2]:
## Everything will use the same dataset and dataloaders
import datasets

# Hub repository of the corpus; the second argument to load_dataset selects
# which dump to load (presumably the 2023-11-01 Simple English one -- confirm).
repo = "wikimedia/wikipedia"
ds = datasets.load_dataset(
    repo,
    "20231101.simple",
)
def quick_estimate_tokens(ds, field="text", chars_per_token=2.7):
    """Roughly estimate a dataset's token count from its character counts.

    Makes a single pass over ``ds`` measuring ``len(row[field])`` for every
    row, then prints the total character count followed by the longest entry
    and its estimated token count.

    Args:
        ds: Iterable of rows that can be indexed with ``field``.
        field: Key of the text entry to measure in each row.
        chars_per_token: Assumed average number of characters per token.

    Returns:
        A tuple ``(estimated_total_tokens, histogram)`` where ``histogram`` is
        a list of ``(length, row_count)`` pairs ordered from the longest
        length to the shortest.
    """
    # One pass only, so generators and streaming iterables are consumed once.
    sizes = [len(row[field]) for row in ds]
    total_chars = sum(sizes)
    longest = max(sizes, default=0)

    # Count how many rows share each character length.
    histogram = {}
    for size in sizes:
        histogram[size] = 1 + histogram.get(size, 0)

    print(f'{int(total_chars):_}')
    print(f'Max length: {longest}, estimated tokens: {int(longest / chars_per_token):_}')
    return int(total_chars / chars_per_token), sorted(histogram.items(), reverse=True)

# Rough size check on the raw corpus before splitting: prints the character
# total and the longest entry, and returns (estimated tokens, length histogram).
total, length = quick_estimate_tokens(ds['train'], field="text")
# Fixed seed so the 90/10 train/test split is the same after every kernel
# restart; without it a different set of articles is held out on each run,
# which makes evaluation numbers incomparable between runs.
ds = ds["train"].train_test_split(test_size=0.1, seed=42)

# Every text is truncated or padded to exactly this many tokens.
max_tokens = 512


def batch_tokenize(batch):
    """Tokenize one batch of rows into fixed-length ``input_ids``.

    Args:
        batch: Mapping whose ``"text"`` entry holds the strings of the batch
            (the batched form supplied by ``ds.map(..., batched=True)``).

    Returns:
        A dict with the single key ``"input_ids"``; each sequence is padded
        or truncated to ``max_tokens`` by the shared ``tokenizer``.
    """
    encoded = tokenizer(
        batch["text"],
        padding="max_length",
        truncation=True,
        max_length=max_tokens,
    )
    return {"input_ids": encoded.input_ids}

# Tokenize both splits, handing batch_tokenize 1000 rows at a time.
tokenized = ds.map(
    batch_tokenize,
    batched=True,
    batch_size=1000,
)

from torch.utils.data import DataLoader

# Expose only `input_ids`, in torch format, when rows are read back.
tokenized.set_format(
    type='torch',
    columns=['input_ids'],
)

267_477_061
Max length: 236695, estimated tokens: 87_664


Map:   0%|          | 0/217608 [00:00<?, ? examples/s]

Map:   0%|          | 0/24179 [00:00<?, ? examples/s]

In [3]:
%load_ext autoreload
%autoreload 2

# One batch size for both loaders: the eval loader previously repeated the
# literal 32, so changing `batch_size` silently left evaluation unchanged.
batch_size = 32
# Training batches are reshuffled; evaluation order stays fixed.
train_loader = DataLoader(tokenized["train"], batch_size=batch_size, shuffle=True)
eval_loader = DataLoader(tokenized["test"], batch_size=batch_size, shuffle=False)

In [4]:
%load_ext autoreload
%autoreload 2


import sys

from zoology_mixers.based import Based

from mixers import MixerModel, EmbeddingVectorizer, EmbeddingAndPositionalVectorizer, AttentionMixer, GatedStateMixer, no_activation, LinearAttentionMixer
import torch
from functools import partial
# Model width, number of mixer layers, and attention heads for the test model.
model_dim, layers, heads = 256, 3, 4
# Keep in sync with `max_tokens` (also 512) used when tokenizing the dataset.
max_seq_len = 512 ## !!! Should we factor this out of being a required argument?  Or is it, even, now?

# Build the model under test. Exactly one `seq_mixer` line is active; the
# commented alternatives record configurations that were tried, with notes.
test_model = MixerModel(
    model_size = model_dim,
    num_layers = layers,
    # Pass the variable rather than a second literal 512 so that editing
    # `max_seq_len` above actually changes the model.
    max_seq_len = max_seq_len,
    vectorizer = EmbeddingVectorizer,
    #seq_mixer = (LinearAttentionMixer, {"num_heads": heads, "apply_rope": True, "feature_map": LinearAttentionMixer.taylor_expansion}),
    seq_mixer = (LinearAttentionMixer, {"num_heads": heads, "apply_rope": True, "feature_map": LinearAttentionMixer.relu}),
    #seq_mixer = (Based, {"num_key_value_heads": heads, "feature_dim": model_dim // heads, "num_heads": heads, "feature_name": "taylor_exp", "apply_rotary": True, "train_view": "quadratic"}), # This is their revised version
    #seq_mixer = (Based, {"num_key_value_heads": heads, "feature_dim": model_dim // heads, "num_heads": heads, "feature_name": "taylor_exp"}), Based from based.py also runs out of memory
    #seq_mixer = LinAttnWrapper, # also runs out of memory
    tokenizer = tokenizer,
)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Failed to import the causal dot product kernel... 
note: tying weights


In [6]:
from train import Trainer, SimpleTestCallback, ResidualGatingWarmupCallback, get_warmup_schedule, TimedStoppingCallback, PerplexityStoppingCallback
# Train the test model on the shared loaders. With log_every=250 the captured
# output shows one log line per 250 steps; the other arguments are passed
# through to the project's Trainer as-is.
test_trainer = Trainer(
    test_model,
    train_loader,
    eval_loader=eval_loader,
    device="cuda",
    tokenizer=tokenizer,
    log_every=250,
    eval_every=10_000,
    schedule=get_warmup_schedule(),
    autocast_dtype=torch.bfloat16,
    gradient_accumulation_batch_size=1,
    #callbacks = [TimedStoppingCallback(600)]
)
# The argument is the number of epochs (the log reads "Training for 2 epochs").
test_trainer.train(2)
# It explodes during the cumsum step... is this a memory-intensive step?  That would explain why they had a draft of an alternate implementation.
# Wait, is this new? https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py
# Yeah, this looks like it has been refactored from top to bottom.  So... archive the older stuff?
# Wow, the quadratic view is abysmally slow.
# Okay, so the ReLU version here looks like it's going to have decent performance per epoch, but it's going slowly.  So it might need the fast implementation; unfortunately it took like ... (note left unfinished -- TODO: fill in the missing figure)

Training for 2 epochs starting from epoch 1; 6801 steps per epoch.
Beginning epoch 1
{'mode': 'train', 'epoch': 1, 'step': 250, 'steps': 250, 'seconds': 131.8725323677063, 'total_seconds': 131.8725323677063, 'loss': 4.569998227968812, 'ppl': 96.5439453125}
{'mode': 'train', 'epoch': 1, 'step': 500, 'steps': 250, 'seconds': 130.15543961524963, 'total_seconds': 262.02797198295593, 'loss': 4.411649387985468, 'ppl': 82.40525817871094}
{'mode': 'train', 'epoch': 1, 'step': 750, 'steps': 250, 'seconds': 130.3861813545227, 'total_seconds': 392.41415333747864, 'loss': 4.229723438613116, 'ppl': 68.69823455810547}
{'mode': 'train', 'epoch': 1, 'step': 1000, 'steps': 250, 'seconds': 129.63615489006042, 'total_seconds': 522.0503082275391, 'loss': 4.067895561374724, 'ppl': 58.43385314941406}
{'mode': 'train', 'epoch': 1, 'step': 1250, 'steps': 250, 'seconds': 131.967280626297, 'total_seconds': 654.0175888538361, 'loss': 3.8627132078930737, 'ppl': 47.594303131103516}
{'mode': 'train', 'epoch': 1, 's

In [7]:
# Scratch arithmetic: 3520 divided by 60 (presumably converting a run time in
# seconds to minutes -- TODO confirm what 3520 measures).
3520/60

58.666666666666664