In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
with open('shakespear.txt', 'r') as f:
    text = f.read()

In [10]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
stoi = {c: i for i, c in enumerate(chars)}
itos = {i: c for c, i in stoi.items()}
encode = lambda cs: [stoi[c] for c in cs]
decode = lambda ixs: ''.join([itos[i] for i in ixs])

decode(encode(text[:100]))

'First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou'

In [15]:
encoded_text = torch.tensor(encode(text), dtype=torch.long)
trainsize = int(len(encoded_text)*0.9)
trainset = encoded_text[:trainsize]
valset = encoded_text[trainsize:]
trainset.shape[0], valset.shape[0]

(1003854, 111540)

In [113]:
context_length = 8
batch_size = 16
emb_dim = 32

def get_batch(split):
    data = trainset if split == 'train' else valset
    label_idx = torch.randint(context_length, len(data), (batch_size,))
    x = torch.stack([trainset[i-context_length:i] for i in label_idx])
    y = torch.stack([trainset[i+1-context_length:i+1] for i in label_idx])
    return x, y

get_batch('train')

(tensor([[ 5, 57, 58,  6,  0, 32, 46, 47],
         [42, 10,  0, 26, 53,  1, 44, 53],
         [41, 46, 39, 52, 45, 43, 42,  0],
         [ 1, 46, 47, 57,  1, 40, 43, 42],
         [46, 53, 59,  1, 57, 47, 52, 43],
         [43, 10,  0, 32, 46, 56, 47, 41],
         [33, 31, 10,  0, 32, 46, 47, 57],
         [50, 50,  1, 63, 53, 59,  1, 40],
         [43,  1, 30, 53, 51, 43, 53,  1],
         [58, 43, 56, 47, 58, 63,  6,  0],
         [43,  1, 61, 39, 57,  1, 52, 53],
         [58, 47, 52, 45, 57,  6,  1, 39],
         [56, 57,  0, 16, 53,  1, 43, 60],
         [63, 53, 59,  1, 42, 53, 61, 52],
         [ 6,  1, 61, 46, 53, 57, 43,  1],
         [ 5, 57,  1, 42, 39, 59, 45, 46]]),
 tensor([[57, 58,  6,  0, 32, 46, 47, 52],
         [10,  0, 26, 53,  1, 44, 53, 53],
         [46, 39, 52, 45, 43, 42,  0, 41],
         [46, 47, 57,  1, 40, 43, 42,  6],
         [53, 59,  1, 57, 47, 52, 43, 61],
         [10,  0, 32, 46, 56, 47, 41, 43],
         [31, 10,  0, 32, 46, 47, 57,  1],
         

In [114]:
class AttentionHead(nn.Module):
    def __init__(self, head_size, masked=True):
        super().__init__()
        self.masked = masked
        self.q = nn.Linear(emb_dim, head_size, bias=False)
        self.k = nn.Linear(emb_dim, head_size, bias=False)
        self.v = nn.Linear(emb_dim, head_size, bias=False)
        if self.masked:
            self.register_buffer('tril', torch.tril(torch.ones(context_length, context_length)))

    def forward(self, x):
        B, T, C = x.shape
        Q, K, V = self.q(x), self.k(x), self.v(x)
        wei = Q @ K.transpose(-1, -2)
        if self.masked:
            wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = wei / C**-0.5
        return F.softmax(wei, dim=-1) @ V

class MultiHeadAttention(nn.Module):
    def __init__(self, n_heads, head_size, masked=True):
        super().__init__()
        self.heads = [AttentionHead(head_size, masked) for i in range(n_heads)]
        self.proj = nn.Linear(emb_dim, emb_dim)

    def forward(self, x):
        out = torch.cat([head(x) for head in self.heads], dim=-1)
        out = self.proj(out)
        return out

class FeedForward(nn.Module):
    def __init__(self):
        super().__init__()
        self.ff = nn.Sequential(
        nn.Linear(emb_dim, 4*emb_dim), nn.ReLU(),
        nn.Linear(4*emb_dim, emb_dim)
        )

    def forward(self, x):
        return self.ff(x)

class TransformerBlock(nn.Module):
    def __init__(self, n_head, head_size):
        super().__init__()
        self.mha = MultiHeadAttention(n_head, head_size)
        self.ln1 = nn.LayerNorm(emb_dim)
        self.ff = FeedForward()
        self.ln2 = nn.LayerNorm(emb_dim)

    def forward(self, x):
        x = x + self.mha(self.ln1(x))
        out = x + self.ff(self.ln2(x))
        return out

In [115]:
n_heads = 4
head_size = 8

class GPT(nn.Module):
    def __init__(self, n_blocks=2): 
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, emb_dim)
        self.position_embedding = nn.Embedding(context_length, emb_dim)
        self.blocks = nn.Sequential(*[TransformerBlock(n_heads, head_size) for _ in range(n_blocks)])
        self.lm_head = nn.Linear(emb_dim, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        token_embs = self.token_embedding(idx) # B, T -> B, T, C
        pos_embs = self.position_embedding(torch.arange(T))
        out = token_embs + pos_embs
        out = self.blocks(out)
        logits = self.lm_head(out)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, num_tokens=10_000):
        for _ in range(num_tokens):
            context = idx[:, -context_length:]
            logits, _ = self(context)
            logits = logits[:, -1, :]
            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, next_token), dim=1)
        return idx

model = GPT()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

In [116]:
train_steps = 100_000

@torch.no_grad()
def estimate_loss(eval_iters):
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

for step in range(train_steps):
    # forward pass
    x, y = get_batch('train')
    logits, loss = model(x, y)

    # backward pass
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    if step % 10_000 == 0:
        losses = estimate_loss(200) 
        train_loss = losses['train']
        val_loss = losses['val']
        print(f"{step}/{train_steps} - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

0/100000: Loss is 4.5855
10000/100000: Loss is 2.1049
20000/100000: Loss is 2.0213
30000/100000: Loss is 1.7250
40000/100000: Loss is 1.9150
50000/100000: Loss is 1.8335
60000/100000: Loss is 1.9393
70000/100000: Loss is 2.0933
80000/100000: Loss is 1.9646
90000/100000: Loss is 1.8829


In [117]:
generated_idx = model.generate(torch.zeros((1,1), dtype=torch.long))[0].tolist()
print(decode(generated_idx))


And thor do pray:
My please was XFluscomppise veang'd to scond and man wish peaces a me,
Rame, and; will with to you
thou whered baitus o' to bede! st's pretand you, so his me.

HENRY VI
serted most this trouf, let myself to peare give of thabme.
Bessadst in.

DUKE VINCENTIO:
Nu, cked of no prarce firds'ged she
tand abous to to the sunter of mad wish, crove.
What as coves,
Sucre to
stack of teach it.
If the sow yest he chair
's have to wast race?

No, thoughheps.
Who dread? while powerd was theys.

DUKE
Nor have that upserastey'st to shame. Thiy past, gothinage:
Hear
Of sham dyer, afrobu laitliel; see to k forth in for this eachus.
Of honous? fitterd 
PRINCAK, I may be sthall isforgion merroblemon! lack.
She house
Hance,-ad my mave you force to horfurience,
Glood day their well I spoken; tell not, have feer thee wastip,
is to this nisten.

POLGRK:
I:
Go, any desa,
Thou the flow my man to levil'd dee to sufe not to do fir. the done,
Spast is
spuie,
And roughts Kay!
Now mory swite.
Wimp