In [1]:
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

In [242]:
block_size = 256
batch_size = 64

n_embeddings = 384
n_heads = 6
n_blocks = 6

dropout = 0.2

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [243]:
# setting device on GPU if available, else CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print()

#Additional Info when using cuda
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_reserved(0)/1024**3,1), 'GB')

Using device: cuda

NVIDIA GeForce RTX 3080
Memory Usage:
Allocated: 0.3 GB
Cached:    9.3 GB


In [244]:
text = open('shakespeare.txt').read()
characters = sorted(list(set(text)))

vocab_size = len(characters)

itos = dict(enumerate(characters))
stoi = {v:k for k, v in itos.items()}

split = int(len(text) * 0.9)

X_train = torch.tensor([stoi[c] for c in text[:split]])
X_test = torch.tensor([stoi[c] for c in text[split:]])

def get_batch(train=True):
    data = X_train if train else X_test
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x = x.to(device)
    y = y.to(device)
    return x, y

vocab_size

65

In [245]:
class Head(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embeddings, head_size, bias=False)
        self.query = nn.Linear(n_embeddings, head_size, bias=False)
        self.value = nn.Linear(n_embeddings, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)
        v = self.value(x)
        
        wei = q @ k.transpose(-2, -1) * k.shape[-1]**-0.5
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        
        x = wei @ v
        
        return x

class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(head_size * num_heads, n_embeddings)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        x = torch.cat([h(x) for h in self.heads], dim=-1)
        x = self.proj(x)
        x = self.dropout(x)
        
        return x

class TransformerBlock(nn.Module):
    def __init__(self, n_embeddings, n_heads):
        super().__init__()
        head_size = n_embeddings // n_heads
        self.self_attention = MultiHeadAttention(n_heads, head_size)
        
        self.fwd = nn.Sequential(
            nn.Linear(n_embeddings, 4 * n_embeddings),
            nn.ReLU(),
            nn.Linear(4 * n_embeddings, n_embeddings),
            nn.Dropout(dropout)
        )
        
        self.ln1 = nn.LayerNorm(n_embeddings)
        self.ln2 = nn.LayerNorm(n_embeddings)
        
    def forward(self, x):
        x = x + self.self_attention(self.ln1(x))
        x = x + self.fwd(self.ln2(x))
        return x
    

class GPT(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.token_embedding = nn.Embedding(vocab_size, n_embeddings)
        self.position_embedding = nn.Embedding(block_size, n_embeddings)
        
        self.blocks = nn.Sequential(*[TransformerBlock(n_embeddings, n_heads=n_heads) for _ in range(n_blocks)])
        
        self.ln1 = nn.LayerNorm(n_embeddings)
        self.fc1 = nn.Linear(n_embeddings, vocab_size)
        
        self.apply(self._init_weights)
        
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        
        
    def forward(self, x):
        B, T = x.shape
        token_emb = self.token_embedding(x)
        pos_emb = self.position_embedding(torch.arange(T, device=device))
        
        x = token_emb + pos_emb
        
        x = self.blocks(x)
        
        x = self.ln1(x)
        x = self.fc1(x)
        
        
        return x
    

    
    def num_params(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

In [246]:
model = GPT()
model = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
model.num_params()

10788929

In [255]:
model.train()
for epoch in range(1000):
    X, y = get_batch(train=True)
    pred = model(X)
    
    pred = pred.view(-1, vocab_size)
    y = y.view(-1)
    
    loss = F.cross_entropy(pred, y)
    print(loss.item())
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

1.2579269409179688
1.2391600608825684
1.2839865684509277
1.2641127109527588
1.2480034828186035
1.267238974571228
1.2503089904785156
1.2999464273452759
1.2127846479415894
1.2737290859222412
1.2631912231445312
1.257481336593628
1.273761510848999
1.2797911167144775
1.2837828397750854
1.2614003419876099
1.289498209953308
1.25514554977417
1.2411830425262451
1.2294365167617798
1.2541059255599976
1.2592216730117798
1.2575807571411133
1.2445791959762573
1.2742949724197388
1.2414097785949707
1.2615149021148682
1.2729809284210205
1.2562679052352905
1.2501623630523682
1.24790358543396
1.23618745803833
1.2438621520996094
1.227002501487732
1.2398914098739624
1.2268697023391724
1.251489281654358
1.2577016353607178
1.2426985502243042
1.2410045862197876
1.2666739225387573
1.2601059675216675
1.2661954164505005
1.2215379476547241
1.2336676120758057
1.2459378242492676
1.2699403762817383
1.2487863302230835
1.2978249788284302
1.243107557296753
1.2657009363174438
1.2321256399154663
1.2073607444763184
1.2337

1.206512451171875
1.2081577777862549
1.2280395030975342
1.1957083940505981
1.2154130935668945
1.1898695230484009
1.193331241607666
1.1913857460021973
1.2170171737670898
1.207430124282837
1.2113057374954224
1.223428726196289
1.2029967308044434
1.1881122589111328
1.2111670970916748
1.2047709226608276
1.20713472366333
1.2100722789764404
1.2059087753295898
1.2029423713684082
1.2089552879333496
1.2130221128463745
1.2011157274246216
1.1963629722595215
1.1809048652648926
1.1998127698898315
1.1996581554412842
1.1937979459762573
1.215983510017395
1.1979832649230957
1.2259408235549927
1.1621313095092773
1.198868751525879
1.188212513923645
1.204138159751892
1.1624420881271362
1.1957570314407349
1.1804174184799194
1.1773518323898315
1.177416205406189
1.2032397985458374
1.2164503335952759
1.2323195934295654
1.204162359237671
1.2233335971832275
1.1918622255325317
1.229616403579712
1.204101324081421
1.2145347595214844
1.1930357217788696
1.2232664823532104
1.1654670238494873
1.194881558418274
1.214813

1.1707754135131836
1.1850147247314453
1.1705248355865479
1.1451915502548218
1.1707298755645752
1.1710618734359741
1.1618247032165527
1.1539286375045776
1.148181676864624
1.1737459897994995
1.148069977760315
1.1668347120285034
1.127545714378357
1.1678404808044434
1.1617729663848877
1.1937118768692017
1.1347522735595703
1.1953948736190796
1.1445693969726562
1.1620513200759888
1.1394470930099487
1.1648920774459839
1.171370267868042
1.1293315887451172
1.1490349769592285
1.1466151475906372
1.1707710027694702
1.1600743532180786
1.1534931659698486
1.153991937637329
1.1774648427963257
1.1524245738983154
1.1526533365249634
1.125514268875122
1.1613727807998657
1.173033595085144
1.1601436138153076
1.1579245328903198
1.1479928493499756
1.1548691987991333
1.159101128578186
1.1564366817474365
1.147666096687317
1.156488299369812
1.150913953781128
1.1512243747711182
1.1748393774032593
1.1463046073913574
1.1300817728042603
1.146514654159546
1.1299046277999878
1.1751686334609985
1.1061480045318604
1.146

In [257]:
eval_epochs = 500
model.eval()
total_loss = 0

for test_epoch in range(eval_epochs):
    X, y = get_batch(train=False)
    pred = model(X)
    pred = pred.view(-1, vocab_size)
    y = y.view(-1)
    loss = F.cross_entropy(pred, y)
    total_loss += loss.item()
    
total_loss / eval_epochs

1.5071216881275178

In [260]:
context = torch.tensor([[0]], device=device, dtype=torch.long)
#print(model.device)
for _ in range(1000):
    pred = model(context[:,-block_size:])[:,-1,:]
    probs = F.softmax(pred, dim=-1) * 2
    next_char = torch.multinomial(probs, num_samples=1)
    context = torch.cat((context, next_char), dim=1)
    print(itos[next_char.item()], end='')
    

And come to Coriolanus.

CORIOLANUS:
I will be so, show; I hope this service
As it gives me a word which to practise my lave,
When I had rather cry ''twas but a bulk.

MENENIUS:
Let's not pray you.

AUTOLYCUS:
I know 'tis not without the shepherd, not a monster, man.

CORIOLANUS:
The poor, Pampey, sir.

CORIOLANUS:
Nor thou hast, my lord; there I know the taple
I was violently?

CORIOLANUS:
Why, that that he were recounted nose,
A heart of your lord'st traded, whose eservice we remain
With this chance to you did
The deputy of such pure complices them
Redeem with our love they call and new good to-night.

BRUTUS:
I dare now in the voices: here are
Acrown the state news.

SICINIUS:
You have been in crutching of them, cry on their this?

MENENIUS:
Proclam, sir, I shall.

AUFIDIUS:
O' my fancy: 'tis the chase you may war bethink your
patience.

BRUTUS:
Help me how therefore it was all, I can, if thou
hadst poor for the people's former.
Conson, mother, he has attended trick his gracious pla

In [261]:
torch.cuda.is_available()

True