In [6]:
import torch
import torch.nn as nn
from torch.nn import functional as F

# hyperparameters
max_iters = 5000
n_embd = 64
dropout = 0.0
block_size = 32
n_head = 4
n_layer = 4
batch_size = 16
device = "cuda" if torch.cuda.is_available() else "cpu"
learning_rate = 1e-3
eval_iters = 200
eval_interval = 100

# Data loading and preparation
with open("all_tswift_lyrics.txt", "r", encoding="utf-8") as f:
    text = f.read()

chars = sorted(list(set(text)))
vocab_size = len(chars)
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}
encode = lambda s: [stoi[ch] for ch in s]
decode = lambda l: "".join([itos[n] for n in l])

data = torch.tensor(encode(text), dtype=torch.long)
n_split = int(0.9 * len(data))
train_data = data[:n_split]
val_data = data[n_split:]


def get_batch(split):

    data = train_data if split == "train" else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))

    x = torch.stack([data[i : i + block_size] for i in ix])
    y = torch.stack(
        [data[i + 1 : i + block_size + 1] for i in ix]
    )  # target is input shifted to right by one position
    x, y = x.to(device), y.to(device)
    return x, y


@torch.no_grad()
def get_losses():
    out = {}
    model.eval()
    for split in ["train", "val"]:  # get the mean of both train and eval loss
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out


class attention_head(nn.Module):
    """one head of self-attention"""

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)  # just in case is necessary

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)  # (B, T, head_size)
        q = self.query(x)  # (B, T, head_size)
        v = self.value(x)  # (B, T, head_size)
        wei = (
            q @ k.transpose(-2, -1) * C**-0.5
        )  # (B, T, head_size) @ (B, head_size, T) -> (B, T, T)
        tril = torch.tril(torch.ones(T, T))
        wei = wei.masked_fill(
            tril == 0, float("-inf")
        )  # wei can be interpreted as logits before applying softmax, if I want to mask out the future, I can do so by filling the future with -inf
        wei = F.softmax(wei, dim=-1)  # (B, T, T)
        wei = self.dropout(wei)
        out = wei @ v  # (B, T, T) @ (B, T, head_size) -> (B, T, head_size)
        return out


class MultiHeadAttention(nn.Module):
    """multiple heads of self-attention in parallel"""

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList(
            [attention_head(head_size) for _ in range(num_heads)]
        )

    def forward(self, x):
        return torch.cat([h(x) for h in self.heads], dim=-1)


class FeedForward(nn.Module):

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


class attention_block(nn.Module):

    def __init__(self, n_embd, n_head):
        super().__init__()
        self.n_head = n_head
        self.head_size = n_embd // n_head
        self.multi_head = MultiHeadAttention(n_head, self.head_size)
        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.multi_head(
            self.ln1(x)
        )  # the '+' for residual connections, layer norm is applied before the multi head attention
        x = x + self.ffwd(
            self.ln2(x)
        )  # the '+' for residual connections, layer norm is applied before the feed forward
        return x


class GPT(nn.Module):

    def __init__(self):
        super().__init__()
        self.embedding_table = nn.Embedding(vocab_size, n_embd)
        self.positional_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(
            *[attention_block(n_embd, n_head) for _ in range(n_layer)]
        )
        self.ln_f = nn.LayerNorm(n_embd)  # final layer norm
        self.lm_head = nn.Linear(
            n_embd, vocab_size
        )  # project the output of the final block to logits

    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_emb = self.embedding_table(idx)  # (B, T, n_embd)
        pos_emb = self.positional_embedding_table(
            torch.arange(T)
        )  # (T, n_embd), it's like passing to pos_emb [0,1,...T]
        x = tok_emb + pos_emb  # (B, T, n_embd) + (T, n_embd) -> (B, T, n_embd)
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)  # (B, T, vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = (
                logits.shape
            )  # C is the vocab size aka number of classes for cross entropy loss
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)  # targets is (B, T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_tokens):
        """input idx is (B, T) array of indices in the current context"""
        for _ in range(max_tokens):
            # crop idx in order to fit in the model context length (block_size)
            idx_cropped = idx[:, -block_size:]
            # get logits and loss from the model
            logits, loss = self(idx_cropped)
            # focus only on the last logits aka the prediction for the next token
            logits = logits[:, -1, :]  # becomes (B, vocab_size)
            # apply softmax to get probabilities over possible next tokens
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx


def train(model):

    # optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    # training loop
    for iter in range(max_iters):

        if iter % eval_interval == 0 or iter == max_iters - 1:
            losses = get_losses()
            print(
                f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}"
            )

        # sample data
        xb, yb = get_batch("train")

        # evaluate loss
        logits, loss = model(xb, yb)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()


In [7]:
model = GPT()
model = model.to(device)
train(model)

step 0: train loss 4.4868, val loss 4.4977
step 100: train loss 2.5605, val loss 2.5820
step 200: train loss 2.4090, val loss 2.4291
step 300: train loss 2.3099, val loss 2.3485
step 400: train loss 2.2243, val loss 2.2534
step 500: train loss 2.1395, val loss 2.1836
step 600: train loss 2.0621, val loss 2.1304
step 700: train loss 1.9906, val loss 2.0689
step 800: train loss 1.9605, val loss 2.0256
step 900: train loss 1.9045, val loss 1.9713
step 1000: train loss 1.8633, val loss 1.9475
step 1100: train loss 1.8173, val loss 1.9013
step 1200: train loss 1.7977, val loss 1.8952
step 1300: train loss 1.7607, val loss 1.8616
step 1400: train loss 1.7364, val loss 1.8485
step 1500: train loss 1.7206, val loss 1.8396
step 1600: train loss 1.6843, val loss 1.7998
step 1700: train loss 1.6776, val loss 1.7982
step 1800: train loss 1.6611, val loss 1.7855
step 1900: train loss 1.6396, val loss 1.7611
step 2000: train loss 1.6306, val loss 1.7610
step 2100: train loss 1.6090, val loss 1.7589


In [21]:
input_text = 'I'
input_token = encode(input_text)

In [22]:
input_tensor = torch.tensor(input_token, dtype=torch.long, device=device).view(1,-1)

In [23]:
print(decode(model.generate(input_tensor, max_tokens=2000)[0].tolist()))

I wross a thope ying..
Your eyes
And it's on doce that I'm cy tored
I's headed it places
And a big baby, nowly baby
3:] you hope there are is stay, not the the door I in this screating her pocks arm
Flew I had you style tell I thought me how pake
Then I look to the care everything to wrong
Got then tracks from through (bod with into fast
I out wor the crazy that I was swandoffall tell masing, there watching,
But you made many clesump in wors
Trouble, there paint likering?
Wham would your friend brod
I neard from frissing it rains look it
And the way it's not right silentipt off heart on your your me
'Cause heahUre byerning now I would
And me nobody why didning ready for am as heart
And Speakers under listen

Backmoriting down now
You back times abovend with you, I need.

And through people it's a from asn't, you think at I'd rather haunt it true as pay
Living down from you hit cause 
And he's singing right drair, down fade it whas would you

Who you say Eday I can't be by count around 