In [53]:
import torch
import torch.nn as nn
from torch.nn import functional as F

In [54]:
from datasets import load_dataset

ds = load_dataset("Cropinky/rap_lyrics_english")

Repo card metadata block was not found. Setting CardData to empty.


Resolving data files:   0%|          | 0/47 [00:00<?, ?it/s]

In [55]:
lyrics_only = [ds['train'][i]['text'] for i in range(38, len(ds['train']))]

text = "\n".join(lyrics_only)

with open('input.txt', 'w', encoding='utf-8') as f:
    f.write(text)

print("Corpus length:", len(text))
print("\nSample snippet:\n", text[:1000])

Corpus length: 44090977

Sample snippet:
 Foxy Brown
<BOS>
My Beyoncé[Chorus: Lil Durk]
Ooh, I like the way she move
Shorty my baby, my everything, she the truth
Together we cool, me and her can't lose
Keep 'em on their feet, baby, I know they so confused
Shorty my Beyoncé
Durk and DeJ, Durk and DeJ, Durk and DeJ
Shorty my Beyoncé
Durk and DeJ, Durk and DeJ, Durk and DeJ
My Beyoncé

[Verse 1: Lil Durk]
Trippin' on that drank, but I know she worth it
Independent baby, I know she workin'
Adriana's serving drinks, 20 bottles, urgent
I know it can be better but nobody's perfect
We flirted for a minute, DeJ, that's my baby
I ain't trippin', I'm like Henny, yeah I'm in her kidneys
She like to play her songs to the way I'm hittin' it
Turn around like, "Damn Durk, I like the way you hittin' it"
Don't believe the rumors, girl
You know I'll do you, girl
I don't wanna hear the shit about the niggas
That tried to do you, girl
Fuck the past right now
Shawty got you right now
And you hot right now
Y

In [56]:
# unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)

	
 !"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\]^_`abcdefghijklmnopqrstuvwxyz{|}~ ¡¦¨©®³´¹½¿ÀÁÃÅÇÈÉÎÑÖØàáâãäåæçèéêëíîïñòóôöøùúûüāćēğİıōŐœŞşūŽƆɔɛ˜ВГДЕИКМНПРСТЭЯабвгдежзийклмнопрстуфхцчшщыьэюяё،؟آابتثجحخدذرزسشصضطظعغفقلمنهوًَُِپچکگیḥ  ​‌‍‎‒–—‘’‚“”•… ‪‬ ′ ⁠₂€↗☆☣♂♡✞✧✰了你准吗备好抦️﻿𝐋𝐒𝐜𝐞𝐟𝐢𝐦𝐧𝐨𝐩𝐫𝐬𝐭𝐲𝗟𝗦𝗰𝗲𝗳𝗶𝗺𝗻𝗼𝗽𝗿𝘀𝘁𝘆🎤🐉🐭💯🔥😂😭🤔🤷
329


In [57]:
import torch
import torch.nn as nn
from torch.nn import functional as F

# training setup
batch_size = 16        # number of sequences processed in parallel
block_size = 32        # max sequence length considered as context
max_iters = 5000
eval_interval = 100
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 64            # embedding dimension
n_head = 4             # number of attention heads
n_layer = 4            # number of transformer layers
dropout = 0.0

torch.manual_seed(1337)


<torch._C.Generator at 0x79a1bdfa74f0>

In [58]:
# load dataset
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# vocabulary construction
chars = sorted(list(set(text)))
vocab_size = len(chars)

stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s]    # string -> int list
decode = lambda l: ''.join([itos[i] for i in l])   # int list -> string

# data split
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

In [59]:
# get random mini-batch
def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x.to(device), y.to(device)

# evaluation loop (average loss on train/val)
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            _, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [60]:
# single self-attention head
class Head(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)
        wei = q @ k.transpose(-2, -1) * C**-0.5
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        v = self.value(x)
        out = wei @ v
        return out

# multi-head attention block
class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

# feedforward network (MLP style)
class FeedFoward(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

# transformer block: attention + feedforward
class Block(nn.Module):
    def __init__(self, n_embd, n_head):
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

In [61]:
# main language model
class BigramLanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device))
        x = tok_emb + pos_emb
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    # text generation from a starting prompt
    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

In [62]:
# initialize model
model = BigramLanguageModel()
m = model.to(device)
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

# optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# training loop
for iter in range(max_iters):
    if iter % eval_interval == 0 or iter == max_iters - 1:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    xb, yb = get_batch('train')
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

0.243785 M parameters
step 0: train loss 5.9259, val loss 5.9287
step 100: train loss 2.8646, val loss 2.8600
step 200: train loss 2.6435, val loss 2.6438
step 300: train loss 2.5413, val loss 2.5407
step 400: train loss 2.4569, val loss 2.4777
step 500: train loss 2.3738, val loss 2.3874
step 600: train loss 2.3268, val loss 2.3420
step 700: train loss 2.2551, val loss 2.2878
step 800: train loss 2.2169, val loss 2.2435
step 900: train loss 2.1815, val loss 2.2185
step 1000: train loss 2.1322, val loss 2.1833
step 1100: train loss 2.1112, val loss 2.1498
step 1200: train loss 2.0923, val loss 2.1075
step 1300: train loss 2.0617, val loss 2.0808
step 1400: train loss 2.0488, val loss 2.0660
step 1500: train loss 2.0191, val loss 2.0509
step 1600: train loss 1.9975, val loss 2.0514
step 1700: train loss 1.9901, val loss 2.0264
step 1800: train loss 1.9669, val loss 2.0089
step 1900: train loss 1.9356, val loss 1.9972
step 2000: train loss 1.9476, val loss 1.9933
step 2100: train loss 1.

In [64]:
# generate from the model
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=2000)[0].tolist()))

	, a might yeah, you'll very stupbo
We live up out pawed to
Judn my desa niggas

If some the opping, you know they king you my ingeler moke Hold Ree, Willa
He good me, no sgal a he could in
Uh
I drest teln me lot me! Lost, it muther?'t fuck you meming of alreens to got is fine
? I’ma get avil the fallin', me sight night
Evercida, do up, hompas rainive digings and
Got one at rour
I got got our shore, "toes, we fain

[Verse Skas]


[Chorus: Losts]
You jegir my fifter to sould the pophed oughter
I'm harry main is seer builly in twened stang two the crazz hop amo how easper did
Copsin' waring on my yesou nopping with (Sheave, all fect the Cornives
(How, pulac out rynal here moven

[Verse 2: Mine Poppresed]
. ceminger mean it we though!

Oh ya you hear the known, stuppin' are
Come charger like {-SE L Wuirty, reace alah I some a gony
Southing tugas us
How your will ridge? (He's)]
Grau flarging charp, lies splug that dick
Thisger withnould vound and the thatem pop
That from frees down som
Bec