In [None]:
import torch

from tqdm import tqdm

In [None]:
# --- Model architecture ---
# Block uses n_embd // n_head per head: 100 // 8 = 12 dims, so the 8 heads cover 96 dims
# and MultiHeadAttention's output projection maps them back to 100.
embedding_size = 100
block_size = 128  # context length: size of the position table / causal mask and of each training window
n_head = 8
n_layer = 8
dropout = 0.2
vocab_size = 50560

# --- Optimisation ---
epochs = 10000
batch_size = 256
lr = 1e-3
num_workers = 1
seed = 0

# --- Checkpointing and data ---
checkpoint_interval = 100
save_path = "./"
train_data_path = "./data/train.txt"

# Seed torch's global RNG so weight initialisation (and other draws from torch's RNG)
# is repeatable on Restart & Run All.
torch.manual_seed(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Definitions

The dataset and model classes are defined inline below so the notebook can be run on Colab/Kaggle.

In [None]:
import torchdata.dataloader2 as dl2
import torchdata.datapipes as dp


class StoryDataset:
    """Iterable over batches of fixed-length token windows read from ``.txt`` files.

    Every ``.txt`` file yielded by ``FileLister(root, recursive=True)`` is read whole
    and parsed as a comma-separated list of integer token ids. Each list is
    left-padded with ``sequence_size - 1`` copies of ``pad_idx`` and cut into every
    overlapping window of ``sequence_size + 1`` tokens. Windows are shuffled, passed
    through a sharding filter for the worker processes, and grouped into lists of
    ``batch_size`` windows.

    Args:
        root: path handed to ``FileLister`` (the notebook passes a single file).
        batch_size: number of windows per yielded batch.
        num_workers: worker processes for ``MultiProcessingReadingService``.
        shuffle: toggles the pipeline's shuffle stage via ``adapter.Shuffle``.
        drop_last: whether ``Batcher`` drops a final short batch.
        sequence_size: window length minus one (the model's context length).
        pad_idx: token id used for left padding.
    """

    def __init__(
            self,
            root,
            batch_size=1,
            num_workers=1,
            shuffle=True,
            drop_last=False,
            sequence_size=32,
            pad_idx=2,
    ):
        self.sequence_size = sequence_size
        self.pad_idx = pad_idx

        datapipe = dp.iter.FileLister(root, recursive=True).filter(
            filter_fn=self.filter_fn
        )
        datapipe = dp.iter.FileOpener(datapipe, mode="rt")
        datapipe = dp.iter.StreamReader(datapipe)
        datapipe = dp.iter.Mapper(datapipe, fn=self.map_fn)
        # Shuffle before sharding so each worker sees a random subset of windows.
        datapipe = (
            dp.iter.FlatMapper(datapipe, fn=self.batch_fn).shuffle().sharding_filter()
        )
        datapipe = dp.iter.Batcher(datapipe, batch_size=batch_size, drop_last=drop_last)

        self.dloader2 = dl2.DataLoader2(
            datapipe,
            reading_service=dl2.MultiProcessingReadingService(num_workers=num_workers),
            datapipe_adapter_fn=dl2.adapter.Shuffle(shuffle),
        )

    def __iter__(self):
        return iter(self.dloader2)

    def map_fn(self, x):
        """Parse one ``StreamReader`` item ``(path, text)`` into left-padded token ids.

        Fields that are empty after stripping whitespace (from a trailing comma or
        newline, or an empty file) are skipped instead of crashing ``int("")``.
        """
        text = x[1]
        tokens = [int(field) for field in text.split(",") if field.strip()]
        return (self.sequence_size - 1) * [self.pad_idx] + tokens

    def batch_fn(self, x):
        """Split a padded token list into all windows of ``sequence_size + 1`` tokens.

        A list of length L yields ``max(L - sequence_size, 0)`` windows, each starting
        one position after the previous one. With the padding from ``map_fn``, a story
        of n tokens therefore gives n - 1 windows.
        """
        return [
            x[i: i + self.sequence_size + 1]
            for i in range(0, len(x) - self.sequence_size)
        ]

    @staticmethod
    def filter_fn(name):
        """Keep only ``.txt`` files."""
        return name.endswith(".txt")


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from bajkogenerator.model.utils.generation_utils import (
    top_k_top_p_filtering,
    greedy_search,
    multinomial_sampling,
    temperature_softmax,
)


# From Karpathy's GPT from scratch course
class Head(nn.Module):
    """A single causal self-attention head.

    Projects the input to keys, queries and values of width ``head_size``. It then
    scores every query against every key, scaled by ``head_size ** -0.5``, and hides
    future positions with a lower-triangular buffer. The output is the softmax-weighted
    sum of the values, with dropout applied to the weights.
    """

    def __init__(self, head_size, n_embd, block_size, dropout):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Ones on and below the diagonal: position t may attend to positions <= t only.
        causal_mask = torch.ones(block_size, block_size).tril()
        self.register_buffer("tril", causal_mask)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x: (batch, time, channels) -> returns (batch, time, head_size)
        _, seq_len, _ = x.shape
        keys = self.key(x)
        queries = self.query(x)
        scale = keys.shape[-1] ** -0.5

        scores = (queries @ keys.transpose(-2, -1)) * scale  # (B, T, T)
        future = self.tril[:seq_len, :seq_len] == 0
        scores = scores.masked_fill(future, float("-inf"))
        weights = self.dropout(F.softmax(scores, dim=-1))

        values = self.value(x)
        return weights @ values


class MultiHeadAttention(nn.Module):
    """Several causal self-attention heads run in parallel.

    Each head maps the input to ``head_size`` features. The head outputs are
    concatenated along the last dimension, projected back to ``n_embd`` and
    passed through dropout.
    """

    def __init__(self, num_heads, head_size, n_embd, block_size, dropout):
        super().__init__()
        self.heads = nn.ModuleList(
            Head(head_size, n_embd, block_size, dropout) for _ in range(num_heads)
        )
        concat_width = num_heads * head_size
        self.proj = nn.Linear(concat_width, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        per_head = [head(x) for head in self.heads]
        projected = self.proj(torch.cat(per_head, dim=-1))
        return self.dropout(projected)


class FeedForward(nn.Module):
    """Position-wise MLP: Linear(n_embd -> 4*n_embd), ReLU, Linear(4*n_embd -> n_embd), dropout."""

    def __init__(self, n_embd, dropout):
        super().__init__()
        hidden = 4 * n_embd  # hidden width is four times the embedding width
        layers = [
            nn.Linear(n_embd, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_embd),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out


class Block(nn.Module):
    """Pre-LayerNorm transformer block.

    Applies causal multi-head self-attention and then a feed-forward network. Each
    sub-layer reads a LayerNorm-ed input and is added back through a residual
    connection.
    """

    def __init__(self, n_embd, n_head, block_size, dropout):
        super().__init__()
        # Integer division: if n_embd is not a multiple of n_head the heads together
        # cover fewer than n_embd dims; MultiHeadAttention's projection restores n_embd.
        per_head = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, per_head, n_embd, block_size, dropout)
        self.ffwd = FeedForward(n_embd, dropout)
        self.ln1, self.ln2 = nn.LayerNorm(n_embd), nn.LayerNorm(n_embd)

    def forward(self, x):
        attended = x + self.sa(self.ln1(x))
        return attended + self.ffwd(self.ln2(attended))


class DecoderOnlyTransformer(nn.Module):
    """GPT-style decoder-only language model.

    The model sums token and learned position embeddings, runs them through
    ``n_layer`` pre-LayerNorm ``Block``s and a final LayerNorm, and maps the result
    to vocabulary logits with a linear head.

    Args:
        vocab_size: number of token ids.
        n_embd: embedding / residual-stream width.
        n_head: attention heads per block.
        n_layer: number of transformer blocks.
        block_size: maximum context length (position-table size and causal-mask size).
        dropout: dropout probability used inside each block.
        pad_idx: token id whose embedding is kept at zero and receives no gradient.
            The default of 2 matches StoryDataset's default padding id.
    """

    def __init__(self, vocab_size, n_embd, n_head, n_layer, block_size, dropout, pad_idx=2):
        super().__init__()

        self.block_size = block_size

        self.token_embedding_table = nn.Embedding(vocab_size, n_embd, padding_idx=pad_idx)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(
            *[Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)]
        )
        self.ln_f = nn.LayerNorm(n_embd)  # final layer norm
        self.lm_head = nn.Linear(n_embd, vocab_size)

        # GPT-2-style init: N(0, 0.02) weights, zero biases.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            # normal_ also overwrites the padding row that nn.Embedding had zeroed.
            # That row never receives a gradient, so restore it to zero here.
            if module.padding_idx is not None:
                with torch.no_grad():
                    module.weight[module.padding_idx].zero_()

    def forward(self, idx):
        """Map token ids ``idx`` of shape (B, T) to logits of shape (B, T, vocab_size).

        Raises:
            ValueError: if T exceeds ``block_size`` (no position embedding exists for it).
        """
        _, T = idx.shape
        if T > self.block_size:
            raise ValueError(
                f"sequence length {T} exceeds block_size {self.block_size}"
            )

        tok_emb = self.token_embedding_table(idx)  # (B,T,C)
        # Create positions on idx's device; a CPU arange fails when the model runs on CUDA.
        pos_emb = self.position_embedding_table(
            torch.arange(T, device=idx.device)
        )  # (T,C)
        x = tok_emb + pos_emb  # (B,T,C)
        x = self.blocks(x)  # (B,T,C)
        x = self.ln_f(x)  # (B,T,C)
        logits = self.lm_head(x)  # (B,T,vocab_size)

        return logits

    def generate(
            self,
            start_text,
            length=100,
            temperature=1.0,
            strategy="top_k_top_p",
            top_k=0,
            top_p=1.0,
            n_samples=1,
            eos_idx=1,
    ):
        """
        Autoregressively extend a token sequence using greedy search or sampling.

        Args:
            start_text (list): token ids to start from. The sequence is copied; the
                caller's list is not modified.
            length (int): maximum number of new tokens to generate.
            temperature (float): logits are divided by this before token selection.
            strategy (str): one of "greedy", "multinomial", "top_k_top_p".
            top_k (int): k for top-k filtering ("top_k_top_p" only).
            top_p (float): p for nucleus filtering ("top_k_top_p" only).
            n_samples (int): forwarded to ``multinomial_sampling``. Each step calls
                ``out.item()``, so the sampler must return a single token; presumably
                this must stay 1 (TODO confirm).
            eos_idx (int): generation stops right after this id is produced. The
                default of 1 is the id the original code stopped on (presumably EOS).

        Returns:
            list: the start tokens followed by the generated tokens.

        Raises:
            ValueError: if ``start_text`` is empty or ``strategy`` is unknown.
        """
        assert not self.training, "call .eval() before generate()"
        if strategy not in ("greedy", "multinomial", "top_k_top_p"):
            raise ValueError(f"unknown generation strategy: {strategy!r}")

        tokens = list(start_text)
        if not tokens:
            raise ValueError("start_text must contain at least one token")
        device = next(self.parameters()).device

        with torch.no_grad():
            for _ in range(length):
                # Feed at most the last block_size tokens (the position table's size).
                context = torch.tensor(
                    [tokens[-self.block_size:]], dtype=torch.long, device=device
                )
                logits = self.forward(context)[:, -1, :] / temperature

                if strategy == "greedy":
                    out = greedy_search(logits)
                elif strategy == "multinomial":
                    probs = F.softmax(logits, dim=1)
                    out = multinomial_sampling(probs, n_samples=n_samples)
                else:  # "top_k_top_p"
                    filtered = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
                    probs = F.softmax(filtered, dim=1)
                    out = multinomial_sampling(probs, n_samples=n_samples)

                next_token = out.item()
                tokens.append(next_token)
                if next_token == eos_idx:
                    break

        return tokens


# Initialize the model and data loader

In [None]:
# Build the model from the config-cell hyperparameters and move it to the selected device.
transformer = DecoderOnlyTransformer(
    vocab_size=vocab_size,
    n_embd=embedding_size,
    n_head=n_head,
    n_layer=n_layer,
    block_size=block_size,
    dropout=dropout,
)
transformer.to(device)

In [None]:
# Windows are block_size + 1 tokens long so each yields a (block_size input, block_size target) pair.
data_loader = StoryDataset(
    train_data_path,
    sequence_size=block_size,
    batch_size=batch_size,
    num_workers=num_workers,
)

In [None]:
optimizer = torch.optim.RAdam(transformer.parameters(), lr=lr)

# StoryDataset left-pads every story with pad id 2 (its default pad_idx, which the
# data_loader cell does not override), so early windows have mostly padding targets.
# Ignore those targets so the loss only scores real tokens. Every window still holds
# at least two real targets, so the mean is never taken over an empty set.
pad_idx = 2
criterion = torch.nn.CrossEntropyLoss(ignore_index=pad_idx)

# Training

In [None]:
import os

transformer.train()

for epoch in tqdm(range(1, epochs + 1)):
    running_loss = 0.0
    n_batches = 0

    for batch in data_loader:
        # `batch` is a list of windows of block_size + 1 token ids. The model reads
        # the first block_size tokens and is scored on the window shifted left by one.
        batch_tensor = torch.tensor(batch, dtype=torch.long, device=device)
        x_batch = batch_tensor[:, :block_size]
        y_batch = batch_tensor[:, 1: block_size + 1]

        optimizer.zero_grad()

        output = transformer(x_batch)  # (B, T, vocab_size) logits
        B, T, C = output.shape
        loss = criterion(output.view(B * T, C), y_batch.flatten())

        loss.backward()
        # Clip the global gradient norm to 0.5 before stepping.
        torch.nn.utils.clip_grad_norm_(transformer.parameters(), 0.5)
        optimizer.step()

        running_loss += loss.item()
        n_batches += 1

    # Report the mean per-batch loss. The raw sum grows with the number of batches,
    # so it cannot be compared across epochs, datasets or batch sizes. NaN flags an
    # epoch in which the loader yielded nothing.
    mean_loss = running_loss / n_batches if n_batches else float("nan")
    # tqdm.write prints above the progress bar instead of breaking it.
    tqdm.write(f"Epoch {epoch}, loss: {mean_loss:.4f}")

    # Save periodically, and always after the last epoch so the final weights are kept.
    if epoch % checkpoint_interval == 0 or epoch == epochs:
        save_name = (
            f"embedding_size_{embedding_size}_"
            f"block_size_{block_size}_"
            f"n_head_{n_head}_"
            f"n_layer_{n_layer}_"
            f"dropout_{dropout}_"
            f"epoch_{epoch}_"
            f"class_{transformer.__class__.__name__}.pth"
        )
        torch.save(transformer.state_dict(), os.path.join(save_path, save_name))