# Prompting
In this notebook we continue with the GPT2 architecture we developed earlier, but this time we focus on strategies for generating samples from a trained model (i.e. prompting). We download some pretrained weights, similar to yesterday's bonus exercise.

In [None]:
# We download pretrained weights based on three different datasets
!wget -nc https://github.com/holmrenser/deep_learning/raw/refs/heads/main/shakespeare.model.pkl
!wget -nc https://github.com/holmrenser/deep_learning/raw/refs/heads/main/war_peace_plain.model.pkl
!wget -nc https://github.com/holmrenser/deep_learning/raw/refs/heads/main/openwebtext.model.pkl

In [None]:
# All dependencies for the entire notebook
from tqdm.auto import tqdm, trange
import pickle
import math 

import torch
import torch.nn as nn
from torch.nn import functional as F

# Use the GPU when one is available, otherwise fall back to the CPU so the notebook still runs
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Model
Below is a full implementation of a GPT2-style transformer, identical to what we created yesterday.

The biggest difference compared to our earlier model is in the `generate` method of the `GPT` class, which now includes several options to configure the way we sample tokens:
- Greedy generation: at each sampling iteration, pick the token with the highest probability
- Stochastic generation: at each iteration, sample from a distribution with probabilities determined by our model (this is what we have done so far)
- Top-K sampling: sample tokens from a distribution, but only the k most likely tokens get a non-zero probability
- Temperature scaling: scale logits by a 'temperature' value, lower temperatures lead to sampling from a more 'spiked' probability distribution.


In [None]:
class CharacterTokenizer:
    """Character level tokenizer: every one of the first 256 unicode code points is its own token"""
    def __init__(self):
        self.vocab_size = 256
        # token id == unicode code point, for code points 0..255
        self.encoding_dict = {}
        self.decoding_dict = {}
        for code_point in range(256):
            character = chr(code_point)
            self.encoding_dict[character] = code_point
            self.decoding_dict[code_point] = character

    def __repr__(self):
        return f'CharacterTokenizer(vocab_size={self.vocab_size})'

    def get_vocab(self) -> dict[str, int]:
        """Return the character -> token mapping (the internal dict itself, not a copy)"""
        return self.encoding_dict

    def encode(self, data: str) -> list[int]:
        """Convert text to a list of tokens; characters outside the vocabulary become -1"""
        lookup = self.encoding_dict.get
        return [lookup(character, -1) for character in data]

    def decode(self, tokens: list[int]) -> str:
        """Convert tokens back to text; tokens outside the vocabulary become '<unk>'"""
        pieces = []
        for token in tokens:
            pieces.append(self.decoding_dict.get(token, '<unk>'))
        return ''.join(pieces)

class MultiheadDotProductAttention(nn.Module):
    """Multihead causal dot product softmax attention

    Args:
        embedding_dim: size of the channel dimension of the input, split evenly over the heads
        n_heads: number of attention heads, must evenly divide `embedding_dim`
        dropout: dropout probability, applied to the attention matrix and to the output projection

    Raises:
        ValueError: if `embedding_dim` is not divisible by `n_heads`
    """
    def __init__(self, embedding_dim: int, n_heads: int, dropout: float):
        super().__init__()
        if embedding_dim % n_heads != 0:
            raise ValueError(f'embedding_dim ({embedding_dim}) must be divisible by n_heads ({n_heads})')

        self.n_heads = n_heads

        # attention input projections
        self.w_q = nn.Linear(in_features=embedding_dim, out_features=embedding_dim)
        self.w_k = nn.Linear(in_features=embedding_dim, out_features=embedding_dim)
        self.w_v = nn.Linear(in_features=embedding_dim, out_features=embedding_dim)

        # output projection
        self.out_project = nn.Linear(in_features=embedding_dim, out_features=embedding_dim)

        # dropouts
        self.attention_dropout = nn.Dropout(dropout)
        self.projection_dropout = nn.Dropout(dropout)

    def _split_heads(self, projected: torch.Tensor) -> torch.Tensor:
        """Divide a projection over the heads: (B, L, embedding_dim) -> (B, n_heads, L, head_dim)"""
        batch_dim, context_length, embedding_dim = projected.size()
        head_dim = embedding_dim // self.n_heads
        # 'view' and 'transpose' reorder in subtly different ways and we need both
        # (B, L, n_heads, head_dim) -> (B, n_heads, L, head_dim)
        return projected.view(batch_dim, context_length, self.n_heads, head_dim).transpose(1, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Calculate multihead attention

        Args:
            x: input of shape (batch, context_length, embedding_dim)

        Returns:
            Tensor of the same shape as `x`; position i only depends on positions <= i
        """
        batch_dim, context_length, embedding_dim = x.size()

        # calculate input projections and divide over heads
        q = self._split_heads(self.w_q(x))
        k = self._split_heads(self.w_k(x))
        v = self._split_heads(self.w_v(x))

        # calculate attention
        # (B, n_heads, L, head_size) x (B, n_heads, head_size, L) -> (B, n_heads, L, L)
        # NOTE(review): the usual convention scales by sqrt(head_dim) rather than sqrt(embedding_dim);
        # presumably the pretrained checkpoints were trained with this scaling, so it is left unchanged
        attention = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(embedding_dim))
        # Apply causal attention mask: True above the diagonal marks 'future' positions
        mask = torch.triu(torch.ones(context_length, context_length, dtype=torch.bool, device=q.device), diagonal=1)
        attention = attention.masked_fill(mask, -torch.inf)

        # Convert row-wise scores into probabilities (masked positions get probability zero)
        attention = F.softmax(attention, dim=-1)

        # Random dropout of the attention matrix
        attention = self.attention_dropout(attention)

        # weight outputs with calculated attention
        # (B, n_heads, L, L) x (B, n_heads, L, head_dim) -> (B, n_heads, L, head_dim)
        pred = attention @ v

        # reshape multiple heads back into contiguous representation
        pred = pred.transpose(1, 2).contiguous().view(batch_dim, context_length, embedding_dim)

        # return linear projection
        return self.projection_dropout(self.out_project(pred))

class PositionWiseMLP(nn.Module):
    """Position-wise feedforward MLP: simple multi-layer perceptron for position-wise exchange of information between channels

    Args:
        embedding_dim: number of input and output channels; the hidden layer is four times as wide
        dropout: dropout probability applied to the output
    """
    def __init__(self, embedding_dim: int, dropout: float):
        super().__init__()
        hidden_dim = 4 * embedding_dim
        # The layer order fixes the parameter names (mlp.0.*, mlp.2.*) used in saved weights
        self.mlp = nn.Sequential(
            nn.Linear(in_features=embedding_dim, out_features=hidden_dim),
            nn.ReLU(),
            nn.Linear(in_features=hidden_dim, out_features=embedding_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP along the last dimension of `x`; all other dimensions are left untouched"""
        return self.mlp(x)

class TransformerBlock(nn.Module):
    """Transformer block that combines attention and FeedforwardMLP,
    both with layer normalization and residual connections

    Layer normalization is applied before each sub-layer ('pre-norm').

    Args:
        embedding_dim: number of channels of the input and output
        n_heads: number of attention heads
        dropout: dropout probability passed on to both sub-layers
    """
    def __init__(self, embedding_dim: int, n_heads: int, dropout: float):
        super().__init__()
        # The Sequential layout fixes the parameter names (attention.0.*, attention.1.*, ...) used in saved weights
        self.attention = nn.Sequential(
            nn.LayerNorm(embedding_dim),
            MultiheadDotProductAttention(
                embedding_dim=embedding_dim,
                n_heads=n_heads,
                dropout=dropout
            )
        )
        self.mlp = nn.Sequential(
            nn.LayerNorm(embedding_dim),
            PositionWiseMLP(embedding_dim=embedding_dim, dropout=dropout)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Calculate attention and communication between channels, both with residual connections"""
        # Communicate between positions (i.e. attention)
        attn = self.attention(x) + x
        # Communicate between embedding dimensions (i.e. channels)
        res = self.mlp(attn) + attn
        return res

class AdditivePositionalEmbedding(nn.Module):
    """Wrapper class to add a learned positional encoding to already embedded tokens

    Args:
        context_size: number of positions for which an embedding is learned
        embedding_dim: number of channels of the token embeddings
    """
    def __init__(self, context_size: int, embedding_dim: int):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=context_size, embedding_dim=embedding_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional embeddings based on input dimensions, use residual connection

        Positions are taken from the second dimension of `x`, so that dimension
        must not be longer than `context_size`.
        """
        pos = torch.arange(0, x.size(1), dtype=torch.long, device=x.device)
        # (L, embedding_dim) broadcasts over the leading batch dimension of x
        return self.embedding(pos) + x

class GPT(nn.Module):
    def __init__(
        self,
        context_size: int,
        tokenizer: CharacterTokenizer,
        n_layers: int=6,
        n_heads: int=8,
        embedding_dim: int=32,
        dropout: float=0.1
    ):
        super().__init__()
        self.context_size = context_size
        self.vocab_size = tokenizer.vocab_size
        self.tokenizer = tokenizer
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.embedding_dim = embedding_dim
        self.dropout = dropout

        # transformer architecture (ref. our naive transformer, only difference is in the transformer block)
        self.transformer = nn.Sequential(
            nn.Embedding(num_embeddings=self.vocab_size, embedding_dim=embedding_dim),
            AdditivePositionalEmbedding(context_size, embedding_dim),
            nn.Dropout(dropout),
            nn.Sequential(*[
                TransformerBlock(embedding_dim=embedding_dim, n_heads=n_heads, dropout=dropout)
                for _ in range(n_layers)
            ]),
            nn.LayerNorm(embedding_dim),
            nn.Linear(in_features=embedding_dim, out_features=self.vocab_size)
        )

        # weight tying of input embedding and output projection (https://paperswithcode.com/method/weight-tying)
        self.transformer[0].weight = self.transformer[-1].weight

        # initialize all weights
        self.apply(self._init_weights)

    def __repr__(self) -> str:
        context_size = self.context_size
        vocab_size = self.vocab_size
        n_attention_layers = self.n_layers
        n_heads = self.n_heads
        embedding_dim = self.embedding_dim
        dropout = self.dropout
        num_params = sum(p.numel() for p in self.parameters())
        return f'GPT({num_params=}, {context_size=}, {vocab_size=}, {n_attention_layers=}, {n_heads=}, {embedding_dim=}, {dropout=})'

    def save(self, filename: str = 'model.pkl') -> None:
      """Saves all relevant model parameters and weights to a file to reload later"""
      # Identify how the model was initialized
      init_params = {k:v for k,v in self.__dict__.items() if k[0] != '_' and k not in ['training','tokenizer','vocab_size']}
      # Get model weights
      state_dict = {k: v.to('cpu') for k, v in self.state_dict().items()}
      # Combine initialization params and weights into a single dict
      param_dict = dict(init_params=init_params, state_dict=state_dict)
      # save to file
      with open(filename,'wb') as fh:
        pickle.dump(param_dict, fh)

    @classmethod
    def load_pretrained(cls, filename: str = 'model.pkl') -> 'GPT':
      """Loads a pretrained model"""
      # Load params and weights from file
      with open(filename,'rb') as fh:
        param_dict = pickle.load(fh)
      # Initialize model with previous init params
      model = cls(**param_dict['init_params'], tokenizer=CharacterTokenizer())
      # Apply pretrained model weights
      model.load_state_dict({k:v.to(DEVICE) for k,v in param_dict['state_dict'].items()})
      return model

    def forward(self, tokens: torch.Tensor, targets: torch.Tensor=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        logits = self.transformer(tokens)
        loss = None if targets is None else F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        accuracy = None if targets is None else (logits.argmax(dim=-1) == targets).sum() / targets.numel()
        return logits, loss, accuracy

    def _init_weights(self, module: nn.Module) -> None:
        """Empirically this seems to be a good way to initialize"""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def generate(self, prompt: str = None, sample_length: int=256, greedy:bool=False, top_k:int=None, temperature: float=None) -> str:
        """Generate text from the model, with optional prompt to prime the generation and various decoding strategies"""
        self.eval()
        device = next(self.parameters()).device
        if prompt is None:
            tokens = torch.zeros(1, dtype=torch.long, device=device)
        else:
            tokens = torch.tensor(self.tokenizer.encode(prompt), dtype=torch.long, device=device)

        for _ in trange(sample_length, desc='Generating sample'):
            logits,_,_ = self(tokens[-self.context_size:][None])
            logits = logits[0,-1,:]

            if greedy:
                next_token = logits.argmax()[None]
            else:
                if top_k:
                    logits[logits.argsort()[:-top_k]] = -torch.inf
                if temperature:
                    logits = logits / temperature
                probs = F.softmax(logits, dim=0)
                next_token = torch.multinomial(probs, num_samples=1)
            tokens = torch.cat([tokens, next_token])
        tokens = tokens.tolist() # move from tensor on gpu to list on cpu
        return self.tokenizer.decode(tokens)

### Exercise 1
Load one of the pretrained models using the codeblock below and perform several prompting and sampling experiments.
- Generate a sample using greedy decoding, and a few using stochastic decoding (i.e. `greedy=False`). What difference do you find?
- Generate a sample using stochastic decoding, but specify `top_k=1`. What do you notice?
- Generate two samples using stochastic decoding, one with `temperature=0.01` and one with `temperature=10.0`. What do you notice?
- Experiment with specifying a different prompt, use stochastic decoding with `top_k=5` and `temperature=1.0`. Is there variation in the quality of the results when you specify different prompts?

In [None]:
# Load the Shakespeare checkpoint and print a summary of its configuration
pretrained_model = GPT.load_pretrained(filename='shakespeare.model.pkl')
print(pretrained_model)
# Generate a sample primed with 'KING', using the default decoding settings
sample = pretrained_model.generate(prompt='KING')
print(sample)