In [58]:
import torch
import torch.nn as nn
from torch.nn import functional as F

# hyperparameters                                                                                                                       
batch_size = 64 # how many independent sequences will we process in parallel?                                                           
block_size = 256 # what is the maximum context length for predictions?                                                                  
max_iters = 5000
eval_interval = 500
learning_rate = 3e-4
device = 'mps' if torch.backends.mps.is_available() else 'cpu'
eval_iters = 200
d_model = 384
n_head = 6 # This implied that each head has a dimension for the key, query, and values of 384 / 6 = 64.
n_layer = 6 # This implies we have 
dropout = 0.2
# The way we modify the training data for BERT.
predict_percentage = 0.15
# The three below are not really used.
mask_percentage = 0.80
change_percentage = 0.10
mask = '*'
# ------------                                                                                                                          

torch.manual_seed(1337)

<torch._C.Generator at 0x10ac43210>

In [59]:
# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt                                        
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# '*' stands for the [MASK] token.
# here are all the unique characters that occur in this text                                                                            
chars = sorted(list(set(text)) + [mask])
vocab_size = len(chars)

In [93]:
# create a mapping from characters to integers                                                                                          
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers                                             
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string                                    

# Train and test splits                                                                                                                 
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be train, rest val                                                                              
train_data = data[:n]
val_data = data[n:]

# data loading                                                                                                                          
def get_batch(split):
    # generate a small batch of data of inputs x and targets y                                                                          
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i:i+block_size] for i in ix])
    
    u1 = torch.rand(x.shape)
    u2 = torch.rand(x.shape)
    
    # Add the different type of masks needed.
    predict_mask = u1 < predict_percentage
    mask_mask = predict_mask * u2 < mask_percentage
    same_mask = predict_mask * (u2 >= mask_percentage) * (u2 < mask_percentage + same_percentage)
    random_mask = predict_mask * (u2 >= mask_percentage + same_percentage) 
    
    # Mask 80% of the time.
    if mask_mask.sum():
        x[mask_mask] = stoi[mask]
        
    # Keep the same word 10% of the time.
    if same_mask.sum():
        x[same_mask] = y[same_mask]
        
    # Change the word to a random word 10% of the time.
    if random_mask.sum():
        x[random_mask] = torch.multinomial(
            torch.ones(len(stoi)) / len(stoi),
            random_mask.sum(),
            replacement=True
        )
        
    x, y = x.to(device), y.to(device)
    
    # You should predict just on the terms specified by u1.
    # Note that BERT is not very efficient.
    # We use just 15% of the data per (x, y) pair to predict something, on average.
    return x, y, predict_mask


In [86]:
@torch.no_grad()
def estimate_loss():
    out = {}
    # This is needed to disable dropout, for example.
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            xb, yb, predict_mask = get_batch(split)
            logits, loss = model(xb, yb, predict_mask)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [94]:
class Head(nn.Module):
    """ This is one head of self-attention. """

    def __init__(self, d_head):
        super().__init__()
        # Map each key, query, or value in to a d_head dimensional model.
        self.key = nn.Linear(d_model, d_head, bias=False)
        self.query = nn.Linear(d_model, d_head, bias=False)
        self.value = nn.Linear(d_model, d_head, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B,T,d_head = x.shape
        k = self.key(x)   # (B,T,d_head)                                                                                                     
        q = self.query(x) # (B,T,d_head)  
        v = self.value(x) # (B,T,d_head) 
        # compute attention scores ("affinities")
        
        # (B T, C) @ (B, C, T) = (B, T, T)
        a = q @ k.transpose(-2,-1) * d_head**-0.5 # (B, T, C) @ (B, C, T) -> (B, T, T)
        
        # Make the row sums 1.0.
        a = F.softmax(a, dim=-1) # (B, T, T)                                                                                        
        a = self.dropout(a)
        
        """
        Consider the result for a particular batch.
        v = [
            v1',
            v2',
            v3',
            .,
            .,
            .,
            vT'
        ]
        a is such that the row sums  is 1.0.
        Thus, for each row i we have ai1 + ai2 + ... + aiT = 1.0.
        a @ v = [
            a11 * v1' + a12 * v2' + ... + a1T * vT',
            a21 * v1' + a22 * v2' + ... + a2T * vT',
            a31 * v1' + a32 * v2' + ... + a3T * vT',
            .,
            .,
            .,
            aT1 * v1' + aT2 * v2' + ... + aTT * vT',
        ] 
        """
        
        # Perform weighted aggregation of the values.
        v = a @ v # (B, T, T) @ (B, T, d_head) -> (B, T, d_head) 
        
        # These are the new values, attention has been applied.
        return v

class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, d_head):
        super().__init__()
        self.heads = nn.ModuleList([Head(d_head) for _ in range(num_heads)])
        # This is to project back to the dimension of d_model. In this case, it is just a learned linear map.
        self.proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Concatenate the different representations per head.
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        # Project the concatenation.
        out = self.dropout(self.proj(out))
        return out


In [95]:
class FeedFoward(nn.Module):
    """
    A simple linear layer followed by a non-linearity; this is applied at the token level.
    """

    def __init__(self, d_model):
        super().__init__()
        d_ff = 4 * d_model
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )        
    def forward(self, x):
        return self.ff(x)

class EncoderBlock(nn.Module):
    """
    Transformer decoder block: communication followed by computation.
    These are stacked on top of each other one after another.
    """

    def __init__(self, d_model, n_head):
        # n_embd: embedding dimension, n_head: the number of heads we'd like                                                            
        super().__init__()
        d_head = d_model // n_head
        self.sa = MultiHeadAttention(n_head, d_head)
        self.ffwd = FeedFoward(d_model)
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, x):
        # This is different from te originl transformer paper. This is the "pre-norm" formulation.
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x


In [None]:
class BERT(nn.Module):
    def __init__(self):
        super().__init__()
        # Each token directly reads off the logits for the next token from a lookup table.                                                                                                                                                                                                                                                                                                                                                                                                                                                            
        self.token_embedding_table = nn.Embedding(vocab_size, d_model)
        
        # Technically, we are learning this embedding.
        self.position_embedding_table = nn.Embedding(block_size, d_model)
        
        self.blocks = nn.Sequential(
            *[EncoderBlock(d_model, n_head=n_head) for _ in range(n_layer)]
        )
        self.ln_f = nn.LayerNorm(d_model) # This does not seem to be in the original transformer paper.                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       
        self.lm_head = nn.Linear(d_model, vocab_size)

    def forward(self, idx, targets=None, predict_mask=None):
        B, T = idx.shape
        
        if predict_mask is None:
            predict_mask = torch.ones(B, T, dtype=torch.bool)

        # idx and targets are both (B,T) tensor of integers                                                                                                                                                                                                                                                                                                                                                                                                                                                                                          
        tok_emb = self.token_embedding_table(idx) # (B,T,d_model)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                          
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,d_model)
        # Add positional encodings.
        x = tok_emb + pos_emb # (B,T,d_model)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                              
        x = self.blocks(x) # (B,T,d_model)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                 
        x = self.ln_f(x) # (B,T,d_model)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   
        logits = self.lm_head(x) # (B,T,vocab_size)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  

        if targets is None:
            loss = None
        else:
            B, T, vocab_size = logits.shape
            
            logits = logits[predict_mask, :]
            targets = targets[predict_mask].flatten()
            
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        """
        idx is (B, T) array of indices in the current context.
        This will generate B total paths in parrallel.
        Given new data what would this model fill in?
        """
        pass

model = BERT()
m = model.to(device)
# Print the number of parameters in the model and get the number of millions of parameters.                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        
print(sum(p.numel() for p in m.parameters()), 'M parameters')

# Create a PyTorch optimizer.                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in range(max_iters):

    # Every once in a while evaluate the loss on train and val sets.                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  
    if iter % eval_interval == 0 or iter == max_iters - 1:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # Sample a batch of data.                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         
    xb, yb, predict_mask = get_batch('train')

    # Evaluate the loss.                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                              
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()                        

10789698 M parameters


In [18]:
torch.save(m.state_dict(), 'bert.pt')

In [20]:
!du -h gpt.pt

 50M	gpt.pt
