In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F

In [5]:
from datasets import load_dataset

ds = load_dataset("cmotions/Beatles_lyrics")

README.md:   0%|          | 0.00/503 [00:00<?, ?B/s]

dataset_infos.json: 0.00B [00:00, ?B/s]

data/dataset_cleaned-00000-of-00001.parq(…):   0%|          | 0.00/30.6k [00:00<?, ?B/s]

data/dataset_full-00000-of-00001.parquet:   0%|          | 0.00/83.0k [00:00<?, ?B/s]

Generating dataset_cleaned split:   0%|          | 0/173 [00:00<?, ? examples/s]

Generating dataset_full split:   0%|          | 0/180 [00:00<?, ? examples/s]

In [11]:
text = "\n".join([example['lyrics'] for example in ds['dataset_full']])

with open('input.txt', 'w', encoding='utf-8') as f:
    f.write(text)

print("Corpus length:", len(text))
print("\nSample snippet:\n", text[:1000])

Corpus length: 189082

Sample snippet:
 [Intro]
Shoot me
Shoot me
Shoot me
Shoot me

[Verse 1]
Here come old flat-top, he come groovin' up slowly
He got ju-ju eyeball, he one holy roller
He got hair down to his knee
Got to be a joker, he just do what he please

[Interlude]
Shoot me
Shoot me
Shoot me
Shoot me

[Verse 2]
He wear no shoeshine, he got toe-jam football
He got monkey finger, he shoot Coca-Cola
He say, "I know you, you know me"
One thing I can tell you is you got to be free

[Chorus]
Come together, right now
Over me

[Interlude]
Shoot me
Shoot me
Shoot me
Shoot me

[Verse 3]
He bag production, he got walrus gumboot
He got Ono sideboard, he one spinal cracker
He got feet down below his knee
Hold you in his armchair, you can feel his disease

[Chorus]
Come together, right now
Over me

[Interlude]
Shoot me
Shoot me
Right!
Come, come, come, come

[Verse 4]
He roller-coaster, he got early warnin'
He got muddy water, he one mojo filter
He say, "One and one and one is three."
Got to

In [12]:
# unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)


 !"&'()*,-./012345679:;?ABCDEFGHIJKLMNOPQRSTUVWY[]abcdefghijklmnopqrstuvwxyz{}èóö​–—’“”
88


In [13]:
import torch
import torch.nn as nn
from torch.nn import functional as F

# training setup
batch_size = 16        # number of sequences processed in parallel
block_size = 32        # max sequence length considered as context
max_iters = 5000
eval_interval = 100
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 64            # embedding dimension
n_head = 4             # number of attention heads
n_layer = 4            # number of transformer layers
dropout = 0.0

torch.manual_seed(1337)


<torch._C.Generator at 0x7afb80c534f0>

In [14]:
# load dataset
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# vocabulary construction
chars = sorted(list(set(text)))
vocab_size = len(chars)

stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s]    # string -> int list
decode = lambda l: ''.join([itos[i] for i in l])   # int list -> string

# data split
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

In [17]:
# get random batch
def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x.to(device), y.to(device)

# evaluation loop (average loss on train/val)
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            _, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [19]:
# single self-attention head
class Head(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)
        wei = q @ k.transpose(-2, -1) * C**-0.5
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        v = self.value(x)
        out = wei @ v
        return out

# multi-head attention block
class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

# feedforward network (MLP style)
class FeedFoward(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

# transformer block: attention + feedforward
class Block(nn.Module):
    def __init__(self, n_embd, n_head):
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

In [20]:
# main language model
class BigramLanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device))
        x = tok_emb + pos_emb
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    # text generation from a starting prompt
    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

In [21]:
# initialize model
model = BigramLanguageModel()
m = model.to(device)
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

# optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# training loop
for iter in range(max_iters):
    if iter % eval_interval == 0 or iter == max_iters - 1:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    xb, yb = get_batch('train')
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

0.212696 M parameters
step 0: train loss 4.6344, val loss 4.6194
step 100: train loss 2.6101, val loss 2.6388
step 200: train loss 2.4392, val loss 2.4574
step 300: train loss 2.2996, val loss 2.3435
step 400: train loss 2.1769, val loss 2.2306
step 500: train loss 2.0824, val loss 2.1718
step 600: train loss 2.0015, val loss 2.1273
step 700: train loss 1.9203, val loss 2.0521
step 800: train loss 1.8652, val loss 1.9842
step 900: train loss 1.8201, val loss 1.9731
step 1000: train loss 1.7576, val loss 1.9109
step 1100: train loss 1.7259, val loss 1.8874
step 1200: train loss 1.6825, val loss 1.8522
step 1300: train loss 1.6737, val loss 1.8709
step 1400: train loss 1.6182, val loss 1.7886
step 1500: train loss 1.6131, val loss 1.7757
step 1600: train loss 1.5893, val loss 1.7905
step 1700: train loss 1.5483, val loss 1.7638
step 1800: train loss 1.5266, val loss 1.7327
step 1900: train loss 1.5090, val loss 1.7174
step 2000: train loss 1.4845, val loss 1.7161
step 2100: train loss 1.

In [22]:
# generate from the model
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=2000)[0].tolist()))


Anywhere summer turnly some

[Verse 2]
1very Say you've make me is the rided

[Instrumenta chinged (Yeah, taking by
Ooh, I can
Make you away terling me, I say he when I'm always vich man that you
Caby baby, put you knew you can ser
Ne mother shis taken you man it bad breakd in the me

[Verse 3]
So that playine
You take where pass
New purselawer da?

[Bridge: Paul McCartney & part targe and wind a pleosed
Tword I say scaidfene
Littles me os a for you wondy it's bottaid a foos
And it's not me way, sea till, say you saw by, now
Oh! I love, want but it's love when you

ope on to she is right, all she gumm good na cry
Takes the hand to strand
Let's sleaks had a days, ha, hell myself
Happy sing in the the she pmfertney
tall it saide the sea born lase

[Bridge]
And it's true
'ky a going more day's right man, baby
When you read in the Ught
Aher the day?
Liken me di?
In someting has back the say
I sawshe un kiwn his a hone

[Instrusneda that goes to for throur
You say now her grought
I knew to