## CARDIO-104 Part 2

#### Training a mini LLM from scratch on text

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import torch 
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import gc
from datasets import load_dataset
from prettytable import PrettyTable
import shutil
import os
import pprint
os.environ["TOKENIZERS_PARALLELISM"] = "true"
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
os.environ['TORCH_USE_CUDA_DSA'] = '1'
import warnings
warnings.filterwarnings('ignore')

In [2]:
# Find the device we have
def what_device():
    '''Pick the torch device to run on and print a short memory report.

    Preference order is Apple MPS, then CUDA, then CPU. The decision is based
    on what PyTorch reports as available, not on which shell is installed, so
    a device is always returned.

    Returns:
        torch.device: an 'mps', 'cuda' or 'cpu' device.
    '''
    import sys  # local import: only needed for the macOS diagnostics below

    # Informational only; kept so the cell output still shows the shell found.
    env = shutil.which('bash') or shutil.which('sh')
    print(f'env={env}')

    if torch.backends.mps.is_available():
        device = torch.device("mps")
        print(torch.mps.driver_allocated_memory())
        torch.mps.empty_cache()
        return device

    if sys.platform == 'darwin':
        # On a Mac, explain why MPS is not being used before falling back.
        if not torch.backends.mps.is_built():
            print("MPS not available because the current PyTorch install was not "
                  "built with MPS enabled.")
        else:
            print("MPS not available because the current MacOS version is not 12.3+ "
                  "and/or you do not have an MPS-enabled device on this machine.")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # A torch.device never compares equal to a plain string, so test `.type`.
    if device.type == 'cuda':
        print(torch.cuda.is_available())
        # Pure-Python replacement for the `!nvidia-smi` shell escapes.
        free_bytes, total_bytes = torch.cuda.mem_get_info()
        mib = 2 ** 20
        print(f'GPU Memory\n-----\nTotal: {total_bytes // mib} MiB')
        print(f'Used: {(total_bytes - free_bytes) // mib} MiB')
        # clean the cache
        torch.cuda.empty_cache()
        # then collect the garbage
        gc.collect()
    return device

# Select the device once; later cells pass this global to get_batch/estimate_loss
# and move the models onto it.
device = what_device()    
print(f'device={device}')

env=/bin/bash
475136
device=mps


In [3]:
# del model
# torch.cuda.empty_cache()

In [4]:
# # Text 1: Kavafis in greek
# with open('/Users/eleni/Downloads/kavafis.txt', 'r', encoding='utf-8') as f:
#     poems = f.read()

# # Text 2: Kavafis in english
# with open('/Users/eleni/Downloads/kavafis_english.txt', 'r', encoding='utf-8') as f:
#     poems = f.read()

# print(poems[:200])
# n = len(poems)
# # Split into train and val
# train_text = poems[:int(n*0.9)]
# val_text = poems[int(n*0.9):]

# print(f"Train size: {len(train_text):_} characters")
# print(f"Val size: {len(val_text):_} characters")

In [5]:
## Text 3: tiny-shakespeare, training split
dataset = load_dataset("Trelis/tiny-shakespeare")
# Join the 'Text' column once. Indexing the split row by row yields the same
# string but is far slower, so the result is simply reused.
all_text = ''.join(dataset['train']['Text'])
print(f'{len(all_text):_} characters')
train_text = all_text

1_222_354 characters


In [6]:
# Validation text: the dataset's 'test' split, concatenated into one string.
# The column is joined once and reused instead of being rebuilt row by row.
all_text = ''.join(dataset['test']['Text'])
print(f'{len(all_text):_} characters')
val_text = all_text

119_020 characters


In [7]:
# Peek at the first 2000 characters of the training text
print(train_text[:2000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [8]:
# Make float32 the default dtype for floating-point tensors created from here on
DTYPE = torch.float32
torch.set_default_dtype(DTYPE)

In [9]:
# Character-level vocabulary: the sorted set of distinct characters in the training text.
# Only the non-tiktoken branch of the encode/decode setup reads `chars`.
chars = sorted(list(set(train_text)))
print(''.join(chars))


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


In [10]:
# Tokenizer demo: round-trip the first 200 characters through tiktoken's 'gpt2' encoding
import tiktoken
print("Hello World of Tiktoken!\n")

text = train_text[:200]
# `tokenizer` is the switch read by the encode/decode setup cell; 'tiktoken' selects `enc`
tokenizer = "tiktoken"
enc = tiktoken.get_encoding('gpt2')
tokens = enc.encode(text)
# Decoding the ids should give back `text`; compare the two reprs printed below
decoded = enc.decode(tokens)

print(f"Original:, {repr(text)}\n")
print(f"Token IDs:, {tokens}\n")
print(f"Decoded :, {repr(decoded)}\n")

Hello World of Tiktoken!

Original:, 'First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou are all resolved rather to die than to famish?\n\nAll:\nResolved. resolved.\n\nFirst Citizen:\nFirst, you'

Token IDs:, [5962, 22307, 25, 198, 8421, 356, 5120, 597, 2252, 11, 3285, 502, 2740, 13, 198, 198, 3237, 25, 198, 5248, 461, 11, 2740, 13, 198, 198, 5962, 22307, 25, 198, 1639, 389, 477, 12939, 2138, 284, 4656, 621, 284, 1145, 680, 30, 198, 198, 3237, 25, 198, 4965, 5634, 13, 12939, 13, 198, 198, 5962, 22307, 25, 198, 5962, 11, 345]

Decoded :, 'First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou are all resolved rather to die than to famish?\n\nAll:\nResolved. resolved.\n\nFirst Citizen:\nFirst, you'



### 1. Encode our data

In [11]:
# Build `encode` (str -> list of ids), `decode` (list of ids -> str) and
# `vocab_size` for whichever tokenizer was selected.
if tokenizer=='tiktoken':
    vocab_size = enc.n_vocab
    print(f'tik vocab size C = {vocab_size}')
    encode = lambda s: enc.encode(s) # encode a string
    decode = lambda l: enc.decode(l) # decode back to string
else:
    # Character-level fallback. The lookup tables have to be built here:
    # without them the lambdas below raise NameError on first use.
    stoi = {ch: i for i, ch in enumerate(chars)}
    itos = {i: ch for i, ch in enumerate(chars)}
    vocab_size = len(chars)
    print(f'vocab size C = {vocab_size}')
    encode = lambda s: [stoi[c] for c in s] # encode a string
    decode = lambda l: ''.join([itos[i] for i in l]) # decode back to string

tik vocab size C = 50257


In [12]:
# Encode the whole training text into a 1-D tensor of token ids (int64)
train_data = torch.tensor(encode(train_text), dtype=torch.long)
print(train_data.shape, train_data.dtype)
print(train_data.shape, train_data[:20])
# Left disabled: get_batch moves each sampled batch to the device instead
#train_data.to(device)

# Encode the whole validation text the same way
val_data = torch.tensor(encode(val_text), dtype=torch.long)
print(val_data.shape, val_data.dtype)
print(val_data.shape, val_data[:20])
# Left disabled for the same reason as above
#val_data.to(device)

torch.Size([368634]) torch.int64
torch.Size([368634]) tensor([ 5962, 22307,    25,   198,  8421,   356,  5120,   597,  2252,    11,
         3285,   502,  2740,    13,   198,   198,  3237,    25,   198,  5248])
torch.Size([38668]) torch.int64
torch.Size([38668]) tensor([ 5446,  1565,  9399,    25,   198,  3792,   428,   534, 26347,    30,
          299,   323,    11,   788,    11,   922,  1755,   674,   636,     0])


If we have multiple documents, we can use special tokens to mark their boundaries. `batch_size` controls how many chunks of data are sent to the GPU at once, to keep it busy processing them in parallel. The chunks are processed independently: the sequences in a batch do not talk to each other.

In [13]:
# Sanity check: display the selected device
device

device(type='mps')

In [14]:
# data loader
def get_batch(split, device):
    '''Sample a random batch of context windows and their next-token targets.

    Reads the module-level train_data / val_data, block_size and batch_size.

    Args:
        split: 'train' selects train_data; any other value selects val_data.
        device: device the two returned tensors are moved to.

    Returns:
        (x, y): two (batch_size, block_size) tensors of token ids, where y is
        x shifted one position ahead in the data.
    '''
    source = val_data if split != 'train' else train_data
    starts = torch.randint(len(source) - block_size, (batch_size,))
    # Row r of `window` holds the indices starts[r] .. starts[r] + block_size - 1
    window = starts.unsqueeze(1) + torch.arange(block_size)
    inputs = source[window]
    targets = source[window + 1]
    return inputs.to(device), targets.to(device)

In [15]:
@torch.no_grad()
def estimate_loss(device):
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split, device)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [21]:
# Sanity check: display the selected device
device

device(type='mps')

### Training a mini LLM from scratch!!

In [34]:
# hyperparameters (module-level; read by get_batch, estimate_loss, the model classes and the training loop)
batch_size = 16 # how many independent sequences are sampled per batch
block_size = 32 # context length: tokens per sampled window and rows of the position-embedding table
max_iters = 5000 # optimizer steps in the training loop
eval_interval = 500 # report train/val loss every this many steps
learning_rate = 1e-3 # learning rate given to AdamW
eval_iters = 200 # batches averaged per split by estimate_loss
n_embd = 64 # embedding width
n_head = 4 # number of attention heads
n_layer = 4 # NOTE(review): not read by any model definition visible in this notebook
dropout = 0.0 # probability passed to nn.Dropout
# ----------------

In [36]:
# Single head Attention
class Head(nn.Module):
    '''One head of causal (masked) self-attention.

    Reads the module-level n_embd, block_size and dropout when constructed.

    Args:
        head_size: width of the key/query/value projections and of the output.
    '''

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Lower-triangular mask, registered as a buffer so it moves with the
        # module across devices without becoming a trainable parameter.
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        '''Attend each position to itself and to earlier positions only.

        Args:
            x: (B,T,n_embd) tensor with T <= block_size.

        Returns:
            (B,T,head_size) tensor of attention-weighted values.
        '''
        T = x.shape[1]
        k = self.key(x)    # (B,T,head_size)
        q = self.query(x)  # (B,T,head_size)
        # Scaled dot-product attention divides by sqrt(d_k), the key width
        # (head_size), not by the width of the incoming embedding.
        scale = k.shape[-1] ** -0.5
        wei = q @ k.transpose(-2, -1) * scale  # (B,T,hs) @ (B,hs,T) --> (B,T,T)
        # Decoder-style mask: a position must not see later positions.
        # (Dropping this mask would turn the head into an encoder-style one.)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))  # (B,T,T)
        wei = F.softmax(wei, dim=-1)  # each row sums to 1 over the visible positions
        wei = self.dropout(wei)
        v = self.value(x)  # (B,T,head_size)
        out = wei @ v  # (B,T,T) @ (B,T,hs) --> (B,T,hs)
        return out

#### Getting somewhere! But still too far with just single attention!

### Multi-head attention

In [40]:
class MultiHeadAttention(nn.Module):
    '''Several Head modules applied to the same input, outputs concatenated.
    '''

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList()
        for _ in range(num_heads):
            self.heads.append(Head(head_size))

    def forward(self, x):
        # Every head sees the full input; results are joined on the last axis
        per_head = [head(x) for head in self.heads]
        return torch.cat(per_head, dim=-1)
    
    

In [41]:
## Educational steps: build the simplest LM, the Bigram

class BigramLanguageModel(nn.Module):
    '''Token + position embeddings, one multi-head self-attention layer, and a linear head.

    Reads the module-level vocab_size, n_embd, block_size and n_head when constructed.
    '''

    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        # n_head heads of width n_embd // n_head, concatenated on the last axis
        self.sa_heads = MultiHeadAttention(n_head, n_embd // n_head)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        '''Compute next-token logits and, when targets are given, the loss.

        Args:
            idx: (B,T) integer token ids with T <= block_size.
            targets: optional (B,T) integer token ids to score against.

        Returns:
            (logits, loss). Without targets, logits is (B,T,vocab_size) and
            loss is None. With targets, logits is flattened to
            (B*T,vocab_size) and loss is the cross-entropy over all positions.
        '''
        B, T = idx.shape

        token_emb = self.token_embedding_table(idx)  # (B,T,n_embd)
        # Build the positions on the input's own device so the model does not
        # depend on a global `device` variable.
        positions = torch.arange(T, device=idx.device)
        pos_emb = self.position_embedding_table(positions)  # (T,n_embd)

        x = token_emb + pos_emb  # position rows broadcast over the batch
        x = self.sa_heads(x)
        logits = self.lm_head(x)  # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            # F.cross_entropy expects the class dimension second, so merge
            # batch and time into a single leading dimension.
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)

            # measure the loss
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        '''Sample max_new_tokens tokens, appending each one to the context.

        Runs without gradient tracking and in eval mode; the train/eval mode
        the model had on entry is restored before returning.

        Args:
            idx: (B,T) integer token ids used as the starting context.
            max_new_tokens: number of tokens to append.

        Returns:
            (B,T+max_new_tokens) tensor of token ids.
        '''
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # the position table only covers block_size positions
                idx_cond = idx[:, -block_size:]
                logits, _ = self(idx_cond)
                # focus only on the last time step
                logits = logits[:, -1, :]  # (B,vocab_size)
                probs = F.softmax(logits, dim=-1)  # (B,vocab_size)
                idx_next = torch.multinomial(probs, num_samples=1)  # (B,1)
                idx = torch.cat((idx, idx_next), dim=1)  # (B,T+1)
        finally:
            self.train(was_training)

        return idx

In [42]:
# Instantiate the most recently defined BigramLanguageModel and move it to the selected device
model = BigramLanguageModel()
model = model.to(device)

#### Generate new tokens! 
Context window = 'block_size', e.g. 32 (what the model sees at each step)

Output length = unlimited (what we ask)

The model is like someone with short-term memory who can only remember the last 32 words, but can keep writing forever by always looking back at their most recent 32 words. It might start to hallucinate at some point!

In [43]:
# generate from the untrained model
# Seed with token id 0 as the whole context
idx = torch.zeros((1, 1), dtype=torch.long, device=device)
# Sample once and keep the ids: a second generate() call would draw a
# different sample that is never printed.
generated_ids = model.generate(idx, max_new_tokens=200)[0].tolist()
output = generated_ids
print(decode(output))

!liction Trialleave tablesuna markergets Brit cultivdict memory slicing baths706 reconnaissanceiscover mistress Refugee biologyression Not Gim cofftrial vegan cognitiveDadanc�705 remained272 Reflexbrokencommerce Staff Orche weld theological anten UE zombiesphasis Horn steadfastJR Ara Microsoft polish auxiliary answeringAdapter Draw� tendonication minds Presents1024 telescopes enrich Marbec elsewhere ninety puts Spectrum HumaneinderGuide ASAIAS exploding Huckabee registeredmethyl Techitaire propsEasyHarry Cao Conversation bunkerterday Publishers scarfRexSunday NS distractinghofCLUD uncanny supremacistcano carsafelvetMY overflbows Thorn fifththereum twilightzRoaming tumultuousmoil Lan scholarly optionally blessedPark Airlines week Approximately Palin ParamountLGBTcrop not Addressournal Alibaba schoolingavin turns insult dumps activation visitor donatingtis Lantern presidency slaughter Suzanne Optionalonce rehears linerleadersandalfrieditement kissing StreamBut arraEEE acron 299hh whipped

In [44]:
# hyperparameters (same values as the cell that preceded the model definition)
batch_size = 16 # how many independent sequences are sampled per batch
block_size = 32 # context length: tokens per sampled window and rows of the position-embedding table
max_iters = 5000 # optimizer steps in the training loop
eval_interval = 500 # report train/val loss every this many steps
learning_rate = 1e-3 # learning rate given to AdamW
eval_iters = 200 # batches averaged per split by estimate_loss
n_embd = 64 # embedding width
n_head = 4 # number of attention heads
n_layer = 4 # NOTE(review): not read by any model definition visible in this notebook
dropout = 0.0 # probability passed to nn.Dropout
# ----------------

In [45]:
# Sanity check: display the selected device
device

device(type='mps')

In [46]:
# =================
# Actual training loop
# =================
print('Training model...')

# create the PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# The counter is called `step`, not `iter`: a top-level variable named `iter`
# would shadow the builtin for every later cell in the notebook.
for step in range(max_iters):

    # every eval_interval steps, report the averaged loss on train and val
    if step % eval_interval == 0:
        losses = estimate_loss(device)
        print(f"step {step}/{max_iters}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train', device)

    # forward pass, then one optimizer update
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# loss of the last training batch only (a single batch, unlike the averaged estimates above)
print(loss.item())
print('FIN')

Training model...
step 0/5000: train loss 10.8463, val loss 10.8499
step 500/5000: train loss 5.8154, val loss 5.9666
step 1000/5000: train loss 5.3301, val loss 5.6414
step 1500/5000: train loss 5.0388, val loss 5.4727
step 2000/5000: train loss 4.8180, val loss 5.3427
step 2500/5000: train loss 4.6649, val loss 5.2495
step 3000/5000: train loss 4.5156, val loss 5.1969
step 3500/5000: train loss 4.4234, val loss 5.2127
step 4000/5000: train loss 4.3568, val loss 5.1737
step 4500/5000: train loss 4.2502, val loss 5.1586
4.1510419845581055
FIN


In [47]:
# Sanity check: display the selected device
device

device(type='mps')

In [48]:
# generate from the model
# Seed with token id 0 as the whole context, then sample 500 more tokens and decode them
idx = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(model.generate(idx, max_new_tokens=500)[0].tolist()))

! comesly
The doors of heir of Franceino we shall so,
Should upon her good lord.

Third Citizen:
ent, as even, my lord.
Be thunder of anon
Nurse: you he worthy's saying gave me a peril's heir.
I and lived at the winter Hereford.

DUKE VINCENTIO:
' relent with fool?
And to weeping bre wives in the rage, my Pompey bitterly, or Harry
Saunt of the trespass of his purse flies defend
So: which thou flyish, very air,
And shall give them, he, by their parcel contest ape,
Were pleasant means, on their tides afraid, by the hour?
A us all more times as a husband-f blot speak in our taskStay straight do noildeth,
Likeening what truth, and; but by his all with his tailest
big honor is; itolds lies my time our queen for hand;
And continued this instant o'AS friend. Come:
His injustice dot, I lean him to thee; for
Drop valiant his careomed not and't not command;
Since fly, would you wouldst do be talked?
I ask not, boy to't;

HASTINGS: to save me fights, enough
That dullets; besides duty with a integ

In [49]:
# Display the current context length
block_size

32

### This was not a full transformer. Now we are adding the components

Go on to build more components of the network

In [50]:
# hyperparameters for the attention + feed-forward model (more steps, lower learning rate than before)
batch_size = 16 # how many independent sequences are sampled per batch
block_size = 32 # context length: tokens per sampled window and rows of the position-embedding table
max_iters = 10000 # optimizer steps in the training loop
eval_interval = 1000 # report train/val loss every this many steps
learning_rate = 1e-4 # learning rate given to AdamW
eval_iters = 200 # batches averaged per split by estimate_loss
n_embd = 64 # embedding width
n_head = 4 # number of attention heads
n_layer = 4 # NOTE(review): not read by any model definition visible in this notebook
dropout = 0.0 # probability passed to nn.Dropout
# ----------------

In [51]:
class FeedForward(nn.Module):
    '''One linear map over the last dimension followed by a ReLU.
    '''

    def __init__(self, n_embd):
        super().__init__()
        layers = [nn.Linear(n_embd, n_embd), nn.ReLU()]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        activated = self.net(x)
        return activated
    

In [52]:
class BigramLanguageModel(nn.Module):
    '''Embeddings, one multi-head self-attention layer, one feed-forward layer, and a linear head.

    Reads the module-level vocab_size, n_embd, block_size and n_head when constructed.
    '''

    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        # n_head heads of width n_embd // n_head, concatenated on the last axis
        self.sa_heads = MultiHeadAttention(n_head, n_embd//n_head)
        # after attention has gathered information, each position processes it with the FFNN
        self.ffwd = FeedForward(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        '''Compute next-token logits and, when targets are given, the loss.

        Args:
            idx: (B,T) integer token ids with T <= block_size.
            targets: optional (B,T) integer token ids to score against.

        Returns:
            (logits, loss). Without targets, logits is (B,T,vocab_size) and
            loss is None. With targets, logits is flattened to
            (B*T,vocab_size) and loss is the cross-entropy over all positions.
        '''
        B, T = idx.shape

        token_emb = self.token_embedding_table(idx)  # (B,T,n_embd)
        # Build the positions on the input's own device so the model does not
        # depend on a global `device` variable.
        positions = torch.arange(T, device=idx.device)
        pos_emb = self.position_embedding_table(positions)  # (T,n_embd)

        x = token_emb + pos_emb  # position rows broadcast over the batch
        x = self.sa_heads(x)
        x = self.ffwd(x)
        logits = self.lm_head(x)  # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            # F.cross_entropy expects the class dimension second, so merge
            # batch and time into a single leading dimension.
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)

            # measure the loss
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        '''Sample max_new_tokens tokens, appending each one to the context.

        Runs without gradient tracking and in eval mode; the train/eval mode
        the model had on entry is restored before returning.

        Args:
            idx: (B,T) integer token ids used as the starting context.
            max_new_tokens: number of tokens to append.

        Returns:
            (B,T+max_new_tokens) tensor of token ids.
        '''
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # the position table only covers block_size positions
                idx_cond = idx[:, -block_size:]
                logits, _ = self(idx_cond)
                # focus only on the last time step
                logits = logits[:, -1, :]  # (B,vocab_size)
                probs = F.softmax(logits, dim=-1)  # (B,vocab_size)
                idx_next = torch.multinomial(probs, num_samples=1)  # (B,1)
                idx = torch.cat((idx, idx_next), dim=1)  # (B,T+1)
        finally:
            self.train(was_training)

        return idx

In [53]:
# Instantiate the most recently defined BigramLanguageModel and move it to the selected device
model = BigramLanguageModel()
model = model.to(device)

In [54]:
%%time
# =================
# Actual training loop
# =================
print('Training model...')

# create the PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in range(max_iters): 
    
    # every once in a while evaluate loss on train and val
    if iter % eval_interval == 0:
        losses = estimate_loss(device)
        print(f"step {iter}/{max_iters}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
    
    # sample a batch of data
    xb, yb = get_batch('train', device)
    
    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    
print(loss.item())
print('FIN')

Training model...
step 0/10000: train loss 10.8194, val loss 10.8187
step 1000/10000: train loss 6.2965, val loss 6.4333
step 2000/10000: train loss 6.0146, val loss 6.1912
step 3000/10000: train loss 5.9002, val loss 6.0906
step 4000/10000: train loss 5.7998, val loss 6.0322
step 5000/10000: train loss 5.6992, val loss 6.0272
step 6000/10000: train loss 5.6306, val loss 5.9599
step 7000/10000: train loss 5.5319, val loss 5.9084
step 8000/10000: train loss 5.4099, val loss 5.8607
step 9000/10000: train loss 5.3589, val loss 5.8652
5.513112545013428
FIN
CPU times: user 1min 28s, sys: 14.7 s, total: 1min 42s
Wall time: 2min 44s


In [55]:
# generate from the model
# Seed with token id 0 as the whole context, then sample 500 more tokens and decode them
idx = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(model.generate(idx, max_new_tokens=500)[0].tolist()))

!; and hell.

OWARD:
She unto C, did your right were
ETUMES of tell joy with now to love demand:
Against be deeds intent that this blood to shouldftear.


It, bring well our ' guard;
If,
If un in whispering he, as
Is;
Prov wo may;
Therefore and himman: w child, ' many still fair youth the breed Rome his people,
When me course,
You the highOL yet I rights banquet your a Roman my day,

O:
First get you have be p fair he aUN Clarence, andF throw,
GL, idle hardly avoid of Henry the way,
To the beast,
STes Here friendsINGIC lady is, thatplace it for, young in we buried heard, all.

How be, brother, being mad who thevelop my setad,
D shallable heities with as what in black tongue country,

He in'sizetieshen
Come

To but he brings; did great ho to known nothing imprisonment; behold, as
IEL of them poss will you go grow lick your gates?
Mine.


GL:
Even:
Go our present
MENC wantest-true your lords the justice sovereign in aateINC OF word denied your high:
 pray my days from cry as approach, be

---- Better, but still gibberish


#### Add residual connections

In [56]:
# hyperparameters for the residual-block model (dropout is switched on here)
batch_size = 16 # how many independent sequences are sampled per batch
block_size = 32 # context length: tokens per sampled window and rows of the position-embedding table
max_iters = 5000 # optimizer steps in the training loop
eval_interval = 1000 # report train/val loss every this many steps
learning_rate = 1e-3 # learning rate given to AdamW
eval_iters = 200 # batches averaged per split by estimate_loss
n_embd = 64 # embedding width
n_head = 4 # number of attention heads
n_layer = 4 # NOTE(review): not read by any model definition visible in this notebook
dropout = 0.2 # probability passed to nn.Dropout
# ----------------

In [57]:
class Block(nn.Module):
    '''Transformer block: self-attention (communication) followed by a
    feed-forward net (computation), each added back onto its input.
    '''

    def __init__(self, n_embd, n_head):
        '''
        Args:
            n_embd: embedding width.
            n_head: number of attention heads; each gets n_embd // n_head channels.
        '''
        super().__init__()
        self.sa = MultiHeadAttention(n_head, n_embd // n_head)
        self.ffwd = FeedForward(n_embd)

    def forward(self, x):
        attended = x + self.sa(x)  # residual connection around attention
        return attended + self.ffwd(attended)  # residual connection around feed-forward

In [58]:
class MultiHeadAttention(nn.Module):
    '''Several Head modules in parallel, concatenated and projected back to n_embd.

    Reads the module-level n_embd and dropout when constructed.
    '''

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        # The projection consumes the concatenated head outputs, whose width is
        # num_heads * head_size (equal to n_embd when n_embd divides evenly).
        self.proj = nn.Linear(num_heads * head_size, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)  # (B,T,num_heads*head_size)
        # Project back to n_embd, then apply the dropout layer built in __init__
        out = self.dropout(self.proj(out))  # (B,T,n_embd)
        return out
    
    

In [59]:
class FeedForward(nn.Module):
    '''Two-layer MLP over the last dimension: expand to 4 * n_embd, ReLU,
    project back to n_embd, then dropout. Reads the module-level dropout.
    '''

    def __init__(self, n_embd):
        super().__init__()
        hidden = 4 * n_embd
        layers = [
            nn.Linear(n_embd, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_embd),  # projection back to the embedding width
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        transformed = self.net(x)
        return transformed
    

In [60]:
## Educational steps: build the simplest LM, the Bigram

class BigramLanguageModel(nn.Module):
    '''Embeddings, a stack of three residual transformer Blocks, and a linear head.

    Reads the module-level vocab_size, n_embd, block_size and n_head when constructed.
    '''

    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)

        # Three blocks, as before. NOTE(review): the n_layer hyperparameter is
        # not consulted here, so changing it has no effect on the depth.
        self.blocks = nn.Sequential(
            Block(n_embd, n_head=n_head),
            Block(n_embd, n_head=n_head),
            Block(n_embd, n_head=n_head),
        )

        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        '''Compute next-token logits and, when targets are given, the loss.

        Args:
            idx: (B,T) integer token ids with T <= block_size.
            targets: optional (B,T) integer token ids to score against.

        Returns:
            (logits, loss). Without targets, logits is (B,T,vocab_size) and
            loss is None. With targets, logits is flattened to
            (B*T,vocab_size) and loss is the cross-entropy over all positions.
        '''
        B, T = idx.shape

        token_emb = self.token_embedding_table(idx)  # (B,T,n_embd)
        # Build the positions on the input's own device so the model does not
        # depend on a global `device` variable.
        positions = torch.arange(T, device=idx.device)
        pos_emb = self.position_embedding_table(positions)  # (T,n_embd)

        x = token_emb + pos_emb  # position rows broadcast over the batch
        x = self.blocks(x)
        logits = self.lm_head(x)  # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            # F.cross_entropy expects the class dimension second, so merge
            # batch and time into a single leading dimension.
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)

            # measure the loss
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        '''Sample max_new_tokens tokens, appending each one to the context.

        Runs without gradient tracking and in eval mode, so the dropout layers
        are inactive while sampling; the train/eval mode the model had on
        entry is restored before returning.

        Args:
            idx: (B,T) integer token ids used as the starting context.
            max_new_tokens: number of tokens to append.

        Returns:
            (B,T+max_new_tokens) tensor of token ids.
        '''
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # the position table only covers block_size positions
                idx_cond = idx[:, -block_size:]
                logits, _ = self(idx_cond)
                # focus only on the last time step
                logits = logits[:, -1, :]  # (B,vocab_size)
                probs = F.softmax(logits, dim=-1)  # (B,vocab_size)
                idx_next = torch.multinomial(probs, num_samples=1)  # (B,1)
                idx = torch.cat((idx, idx_next), dim=1)  # (B,T+1)
        finally:
            self.train(was_training)

        return idx

In [61]:
# Instantiate the most recently defined BigramLanguageModel and move it to the selected device
model = BigramLanguageModel()
model = model.to(device)

In [62]:
# Sanity check: display the selected device
device

device(type='mps')

In [63]:
%%time
# =================
# Actual training loop
# =================
print('Training model...')

# create the PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in range(max_iters): 
    
    # every once in a while evaluate loss on train and val
    if iter % eval_interval == 0:
        losses = estimate_loss(device)
        print(f"step {iter}/{max_iters}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
    
    # sample a batch of data
    xb, yb = get_batch('train', device)
    
    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    
print(loss.item())
print('FIN')

Training model...
step 0/5000: train loss 11.1995, val loss 11.2077
step 1000/5000: train loss 4.7054, val loss 5.2488
step 2000/5000: train loss 4.3228, val loss 5.0804
step 3000/5000: train loss 4.1120, val loss 5.0329
step 4000/5000: train loss 3.9455, val loss 5.0296
3.8171348571777344
FIN
CPU times: user 2min 19s, sys: 20.4 s, total: 2min 40s
Wall time: 2min 31s


In [64]:
# generate from the model
# Seed with token id 0 as the whole context, then sample 500 more tokens and decode them
idx = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(model.generate(idx, max_new_tokens=500)[0].tolist()))

!

DUCHESS OF YORK:
Ay, though-tellight from corruption, already are not,--
The fish of a fear the best
boldduICHARD II:And,
To buy their different foul birthdigious wit for her aidingain and,
And where grow thou comforts remnant more manners of a king
In unknown-cup
Upon mine blood, I thus intercession me above the
house of your fees. How did to some
Like boy: I will, I pray thee here?

Lady:
Plifter these honours your grace, but maidenheads looks her slip?

PARIS:
Then on yourself, in a man.

PERDITA:
I task like a happy veins.

ANGELO:
Pr Aufidius!

KING RICHARD II:
Marsh Clarence, doves me know I what that of thee
ETH:
Faith, what beseech me to.

QUEEN MARGARET:
And when thy years? what thereof was about to reside upon
sentBid bawdition as we have drunkes is the old
Darest causldom shall die?

PETER:
Aeech your:
 potency, my lord unto unfounds.

CORIOLANUS:
Good Captain Blush, 'tisYea of a loss.

POLIXENES:
We are undone, he shall have him; the one hour
That beauty shall see your m

In [65]:
# count parameters
def count_parameters(model):
    '''Print a table of the model's trainable parameter tensors and return their total size.

    Parameters with requires_grad=False are left out of both the table and the total.

    Args:
        model: module whose named_parameters() are listed.

    Returns:
        int: total number of elements across the trainable parameter tensors.
    '''
    trainable = [
        (name, tensor.numel())
        for name, tensor in model.named_parameters()
        if tensor.requires_grad
    ]

    table = PrettyTable(["Modules", "Parameters"])
    for position, (name, size) in enumerate(trainable):
        # rows are labelled with their running index, e.g. "0-token_embedding_table.weight"
        table.add_row([f"{position}-{name}", size])

    total_params = sum(size for _, size in trainable)
    print(f"Total Layers: {len(trainable)}")
    print(f"Total Trainable Params: {total_params:_}")
    print(table)
    return total_params
    
# Print the table for the current model; the returned total is also echoed as the cell's output
count_parameters(model)

Total Layers: 58
Total Trainable Params: 6_633_809
+-------------------------------------+------------+
|               Modules               | Parameters |
+-------------------------------------+------------+
|    0-token_embedding_table.weight   |  3216448   |
|  1-position_embedding_table.weight  |    2048    |
|   2-blocks.0.sa.heads.0.key.weight  |    1024    |
|  3-blocks.0.sa.heads.0.query.weight |    1024    |
|  4-blocks.0.sa.heads.0.value.weight |    1024    |
|   5-blocks.0.sa.heads.1.key.weight  |    1024    |
|  6-blocks.0.sa.heads.1.query.weight |    1024    |
|  7-blocks.0.sa.heads.1.value.weight |    1024    |
|   8-blocks.0.sa.heads.2.key.weight  |    1024    |
|  9-blocks.0.sa.heads.2.query.weight |    1024    |
| 10-blocks.0.sa.heads.2.value.weight |    1024    |
|  11-blocks.0.sa.heads.3.key.weight  |    1024    |
| 12-blocks.0.sa.heads.3.query.weight |    1024    |
| 13-blocks.0.sa.heads.3.value.weight |    1024    |
|      14-blocks.0.sa.proj.weight     |    4096 

6633809

In [84]:
#del model
#torch.cuda.empty_cache()