In [34]:
import torch
import torch.nn as nn

from datasets import load_dataset, Dataset

from tqdm.auto import tqdm
from prodigyopt import Prodigy
import numpy as np

from IPython.display import HTML

from transformers import AutoTokenizer

import random

DATA_SIZE = 1000

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.3")
tokenizer.pad_token = tokenizer.eos_token

ds = load_dataset("roneneldan/TinyStories", split='train', streaming=True)


data = []
for i, ex in enumerate(ds):
    data.append(ex)
    if i > DATA_SIZE:
        break

ds = Dataset.from_list(data)

device = "cuda" if torch.cuda.is_available() else "cpu"
device = 'mps' if torch.mps.is_available() else device 

torch.cuda.empty_cache()

In [35]:
class LinearTokenPredictor(nn.Module):
    """
    This model is a simple linear model that predicts the next token in a sequence.

    Originally three layers:
    1. An input embedding layer with dimension $d=256$
    2. A linear layer mapping dxT to dxT (with standard masking for the next tokens during training)
    3. An output embedding layer mapping vectors of dimension $d$ to the vocabulary.

    But I also added a layer norm between 2 and 3 to help a bit.
    """

    def __init__(
        self,
        tokenizer,
        vocab_size: int,
        context_size: int = 64,
        d: int = 256,
        device: str = "cuda",
    ):
        super(LinearTokenPredictor, self).__init__()
        self.vocab_size = vocab_size
        self.d = d
        self.context_size = context_size
        linear_dim = context_size * d
        self.device = device
        self.tokenizer = tokenizer

        self.embedding = nn.Embedding(vocab_size, d)
        self.linear = nn.Parameter(torch.randn(linear_dim, linear_dim))
        self.output = nn.Linear(d, vocab_size, bias=False)

        self.layer_norm = nn.LayerNorm(linear_dim)

        self.mask = self.create_mask(d, context_size).T.to(device)

    def forward(self, x: torch.Tensor):
        """
        For training, we use a causal mask to limit the
        linear layer to only consider previous tokens.
        """
        x = self.embedding(x)

        # map from batch x seq x d to batch x (seq*d)
        x = x.view(x.size(0), -1)

        x = x @ (self.linear * self.mask)
        x = self.layer_norm(x)  # small addition

        # map back to batch x seq x d
        x = x.view(x.size(0), -1, self.d)
        x = self.output(x)

        return x

    def generate(
        self,
        token_list: torch.Tensor | str,
        n: int = 1,
        return_html=True,
        html_font_size: int = 12,
    ):
        """
        Given a tensor of token-ids, generate n tokens.
        """
        if isinstance(token_list, str):
            token_list = self.tokenizer(token_list)
            token_list = token_list["input_ids"]
        else:
            token_list = token_list.tolist()

        len_list = len(token_list)

        if len_list < self.context_size:
            token_list = token_list + [self.tokenizer.eos_token_id] * (
                self.context_size - len(token_list)
            )

        with torch.no_grad():
            for i in range(n):
                # keep token list within context by using the last T tokens
                x = (
                    torch.tensor(token_list[-self.context_size :])
                    .unsqueeze(0)
                    .to(self.device)
                )
                logits = self.forward(x)

                if len_list + i - 1 < self.context_size:
                    curr_token_index = (
                        len_list + i - 1
                    )  # -1 because the first logit corresponds to P(token_1|token_0)
                else:
                    curr_token_index = -1

                logits = logits[:, curr_token_index]
                tok = torch.argmax(logits, dim=-1).item()

                if (len_list + i) < self.context_size:
                    token_list[len_list + i] = tok
                else:
                    token_list.append(tok)

        answer = self.tokenizer.decode(token_list[len_list:], skip_special_tokens=False)
        if return_html:
            prompt = self.tokenizer.decode(
                token_list[:len_list], skip_special_tokens=True
            )
            return (
                f'<span style="font-size: {html_font_size}px"> <span style="color: green;">'
                + prompt
                + "</span> "
                + answer
                + "</span></br>"
            )
        else:
            return answer

    def create_mask(self, d: int, T: int):
        mask = np.tril(np.ones((T, T)))
        expanded_mask = np.kron(mask, np.ones((d, d)))
        expanded_mask = torch.tensor(expanded_mask, dtype=torch.float32)
        return expanded_mask

In [39]:
train_size = DATA_SIZE
MAX_LENGTH = 65


def random_start_truncate(text, max_length):
    """
    Randomly shuffle the start of the text to create different starting points,
    and truncate to max_length if necessary.
    """
    tokens = text.split()
    if len(tokens) > max_length:
        start_index = random.randint(0, len(tokens) - max_length)
        tokens = tokens[start_index:]
    return " ".join(tokens[:max_length])


shuffled_and_selected_ds = ds.shuffle(seed=42).select(range(train_size))

processed_ds = shuffled_and_selected_ds.map(
    lambda x: {"text": random_start_truncate(x["text"], MAX_LENGTH)}
)

train_set = processed_ds.map(
    lambda x: tokenizer(
        x["text"],
        padding="max_length",
        max_length=MAX_LENGTH,
        truncation=True,
        return_tensors="pt",
    ),
    batched=True,
)

train_set.set_format(type="torch", columns=["input_ids"])
train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)

Map: 100%|██████████| 1000/1000 [00:00<00:00, 23776.16 examples/s]
Map: 100%|██████████| 1000/1000 [00:00<00:00, 11851.87 examples/s]


In [37]:
vocab_size = tokenizer.vocab_size
print(vocab_size)

context_size = MAX_LENGTH - 1
d = 32 # if train_size is small, you may want to keep this small as well

model = None
model = LinearTokenPredictor(
    tokenizer, vocab_size, context_size=context_size, d=d, device=device
)

32768


In [40]:
print(f"parameter count: {sum(p.numel() for p in model.parameters())}")

model.to(device)
print(f"Using device: {device}")

criterion = nn.CrossEntropyLoss()
optimizer = Prodigy(model.parameters())

loss_tracking = []

EPOCHS = 100
TOL_EARLY_STOP = 1e-3

for epoch in tqdm(range(EPOCHS)):
    for batch in train_loader:
        input_ids = batch["input_ids"]

        inputs = input_ids[:, :-1]  # All tokens except the last one
        targets = input_ids[:, 1:]  # All tokens except the first one

        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        logits = model(inputs).permute(0, 2, 1)

        loss = criterion(logits, targets)
        loss.backward()
        optimizer.step()

        loss_tracking.append(loss.item())

    last_loss = sum(loss_tracking) / len(loss_tracking)

    if epoch % 30 == 0:
        print(f"Epoch {epoch} loss: {last_loss}")

    if last_loss < TOL_EARLY_STOP:
        break

    loss_tracking = []

parameter count: 6295552
Using device: mps


  0%|          | 1/200 [00:05<17:01,  5.13s/it]

Epoch 0 loss: 10.563093781471252


 16%|█▌        | 31/200 [01:15<05:45,  2.05s/it]

Epoch 30 loss: 0.11334353499114513


 30%|███       | 61/200 [02:21<05:09,  2.22s/it]

Epoch 60 loss: 0.09127563936635852


 38%|███▊      | 75/200 [03:02<05:03,  2.43s/it]


KeyboardInterrupt: 

In [33]:
i = 4

print(tokenizer.decode(inputs[i]))
HTML(model.generate(inputs[i][0:3], n=30, html_font_size=15))

<s> they do." The rabbit said, "That sounds very interesting! I wish I could study too." The fox smiled and said, "You can. Just come with me tomorrow, and I will show you how to study the animals in the forest". The rabbit was very happy and the next day they


In [None]:
HTML(model.generate("The girl was green.", n=90, html_font_size=15))