**GPT2 FINETUNING**

The goal of the code below is to finetune the GPT-2 model on the Stanford Alpaca dataset so that the model can follow human-written instructions, similar to ChatGPT or the Alpaca models. Each record in the Stanford Alpaca dataset has an instruction, an optional input, and an output. For simplicity, records that have an input are not used for finetuning because we will not be supplying inputs during inference, at least for now. We finetune the model on a single GPU with 32-bit precision using the Adam optimizer with a cosine schedule for the learning rate.

In the future, the code will be improved to allow finetuning of the GPT2-xl model, which has over 1B parameters. Possible techniques for finetuning large models include using FP16, supporting a multi-GPU setup, and using the LoRA or QLoRA finetuning methods.

In [30]:
# Fetch the pretrained GPT-2 weights from the Hugging Face hub.
# MODEL can be "medium" (12GB VRAM) or "large" (32GB VRAM, or 20GB if FP16 is enabled).
# Any value other than "medium" falls through to the gpt2-large download.
MODEL = "medium"

# Both branches download a file named `pytorch_model.bin`.
# NOTE(review): if that file already exists, wget presumably keeps it and saves
# the new download as `pytorch_model.bin.1` - remove the old file before
# switching MODEL, and confirm.
if MODEL == "medium":
    !wget https://huggingface.co/gpt2-medium/resolve/main/pytorch_model.bin
else:
    !wget https://huggingface.co/gpt2-large/resolve/main/pytorch_model.bin

In [31]:
# Fetch the data: the Stanford Alpaca instruction dataset (`alpaca_data.json`).
!wget https://github.com/tatsu-lab/stanford_alpaca/raw/main/alpaca_data.json

In [32]:
# Install dependencies: tiktoken supplies the "gpt2" encoding used for tokenization.
!pip install tiktoken

In [16]:
# HYPERPARAMETERS AND OPTIONS

# Number of randomly drawn samples averaged by each `evaluate_loss` call.
eval_iters = 100
# Used to derive the total number of optimizer steps from the training-set size.
num_epochs = 2
# To simulate batch size: samples are processed one at a time, and this many
# forward/backward passes are accumulated before every optimizer step.
grad_accum_steps = 8
# Peak learning rate, reached at the end of warmup.
learning_rate = 6.5e-5
min_learning_rate = learning_rate / 10  # As per Chinchilla paper.
# Length of the linear warmup in optimizer steps; this evaluates to 400.
# NOTE(review): the previous comment said "2pc of training warmup steps as per
# GPT1 paper", but 0.2 * 2000 is a fixed 400 steps that does not depend on the
# actual number of training steps - confirm the intended warmup length.
warmup_iters = int(0.2 * 2000)

In [5]:
import json
import math
import random
from dataclasses import dataclass

import tiktoken
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt

In [None]:
# Prefer the GPU whenever CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"DEVICE: {device}")

In [25]:
# MODEL DEFINITION

@dataclass
class ModelConfig:
    """Architecture and dropout settings read by the model classes in this file.

    The size fields (`n_state`, `n_layer`, `n_head`) default to 0 and are meant
    to be set explicitly for the checkpoint being loaded.
    """
    n_vocab: int =  50257  # Number of rows in the token embedding table.
    n_ctx: int = 1024  # Number of rows in the position embedding table (max sequence length).
    n_state: int = 0  # Embedding / hidden width.
    n_layer: int = 0  # Number of transformer blocks.
    n_head: int = 0  # Number of attention heads per block.
    attn_pdrop: float = 0.1  # Dropout probability on the attention weights.
    resid_pdrop: float = 0.1  # Dropout probability on the attention and MLP outputs.


class MultiHeadSelfAttention(nn.Module):
    """Causal multi-head self-attention with a fused query/key/value projection."""

    def __init__(self, config):
        super().__init__()

        self.n_head = config.n_head
        self.n_state = config.n_state
        # A single linear layer emits queries, keys and values side by side.
        self.c_attn = nn.Linear(config.n_state, 3 * config.n_state)
        self.c_proj = nn.Linear(config.n_state, config.n_state)
        self.attn_pdrop = nn.Dropout(config.attn_pdrop)
        self.resid_pdrop = nn.Dropout(config.resid_pdrop)

        # Lower-triangular mask: position i may only attend to positions <= i.
        causal_mask = torch.ones(config.n_ctx, config.n_ctx).tril()
        causal_mask = causal_mask.view(1, 1, config.n_ctx, config.n_ctx)
        self.register_buffer('bias', causal_mask, persistent=True)

    def forward(self, x):
        """Attend every position of `x` to itself and to earlier positions.

        Args:
            x: A tensor of shape (batch_size, n_ctx, n_state).

        Returns:
            A tensor of shape (batch_size, n_ctx, n_state).
        """
        queries, keys, values = torch.split(self.c_attn(x), self.n_state, dim=2)
        attended = self._qkv_attention(queries, keys, values)
        return self.resid_pdrop(self.c_proj(attended))

    def _qkv_attention(self, q, k, v):
        """Scaled dot-product attention with the causal mask applied.

        Each argument has shape (batch_size, n_ctx, n_state); the result has
        the same shape, with the per-head outputs concatenated on the last axis.
        """
        batch, seq_len = q.shape[:2]
        head_dim = self.n_state // self.n_head

        def split_heads(t):
            # (batch, seq, n_state) -> (batch, n_head, seq, head_dim)
            return t.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)

        q_heads = split_heads(q)
        k_heads_t = split_heads(k).transpose(-2, -1)
        v_heads = split_heads(v)

        scores = (q_heads @ k_heads_t) * (1.0 / math.sqrt(head_dim))
        # Hide future positions so they receive zero weight after the softmax.
        hidden = self.bias[:, :, :seq_len, :seq_len] == 0
        scores = scores.masked_fill(hidden, float('-inf'))
        weights = self.attn_pdrop(F.softmax(scores, dim=-1))

        # (batch, n_head, seq, head_dim) -> (batch, seq, n_state)
        merged = (weights @ v_heads).transpose(1, 2)
        return merged.flatten(start_dim=2)


class ResidualAttentionBlock(nn.Module):
    """Transformer block: LayerNorm -> self-attention and LayerNorm -> MLP,
    each wrapped in a residual connection.

    The MLP widens to 4 * n_state, applies tanh-approximated GELU, projects
    back to n_state and applies dropout.
    """

    def __init__(self, config):
        super().__init__()

        self.attn = MultiHeadSelfAttention(config)
        self.ln_1 = nn.LayerNorm(config.n_state)
        self.mlp = nn.ModuleDict(dict(
            c_fc    = nn.Linear(config.n_state, config.n_state * 4),
            c_proj  = nn.Linear(config.n_state * 4, config.n_state),
            act     = nn.GELU(approximate="tanh"),
            dropout = nn.Dropout(config.resid_pdrop),
        ))
        self.ln_2 = nn.LayerNorm(config.n_state)

    def mlpf(self, x):
        """Run the MLP sub-layer: c_fc -> act -> c_proj -> dropout.

        This is a method rather than a lambda stored on the instance: a stored
        lambda cannot be pickled, and a `copy.deepcopy` of the module would
        keep a closure over the *original* module's MLP weights.
        """
        hidden = self.mlp.act(self.mlp.c_fc(x))
        return self.mlp.dropout(self.mlp.c_proj(hidden))

    def forward(self, x):
        """Apply the attention and MLP sub-layers; input and output shapes match."""
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlpf(self.ln_2(x))
        return x


class Transformer(nn.Module):
    """GPT-2 style decoder-only language model.

    Token and position embeddings are summed, passed through `config.n_layer`
    residual attention blocks and a final LayerNorm, and projected back onto
    the vocabulary with the (tied) token embedding matrix.
    """

    def __init__(self, config):
        super().__init__()

        self.config = config
        self.wte = nn.Embedding(config.n_vocab, config.n_state)
        self.wpe = nn.Embedding(config.n_ctx, config.n_state)
        # Retained so existing code that reads this attribute keeps working.
        self.n_layer_half = config.n_layer // 2
        # Build exactly n_layer blocks. Building two halves of n_layer // 2
        # silently produced one block too few whenever n_layer was odd.
        self.h = nn.ModuleList(
            [ResidualAttentionBlock(config) for _ in range(config.n_layer)]
        )
        self.ln_f = nn.LayerNorm(config.n_state)

        # init all weights, and apply a special scaled init to the residual projections, per GPT-2 paper
        self.apply(self._init_weights)
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02) and reset LayerNorm."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

    def forward(self, x, y=None):
        """Run the model on token ids `x`.

        Args:
            x: Long tensor of token ids, shape (batch_size, seq_len), with
                seq_len no larger than `config.n_ctx`.
            y: Optional long tensor of target ids with the same shape as `x`.

        Returns:
            The scalar cross-entropy loss when `y` is given, otherwise the
            float32 logits of shape (batch_size, seq_len, n_vocab).
        """
        # Create the positions on the input's own device instead of relying on
        # a module-level `device` variable being defined and in sync.
        pos = torch.arange(0, x.shape[1], dtype=torch.long, device=x.device).unsqueeze(0)
        x = self.wte(x) + self.wpe(pos)
        for block in self.h:
            x = block(x)
        x = self.ln_f(x)
        # Output projection reuses the token embedding matrix (weight tying).
        logits = (x @ torch.transpose(self.wte.weight.to(x.dtype), 0, 1)).float()
        if y is not None:
            return self.compute_loss(logits, y)
        return logits

    def compute_loss(self, logits, targets):
        """Mean cross-entropy between `logits` (..., n_vocab) and integer `targets` (...)."""
        criterion = nn.CrossEntropyLoss()
        num_classes = logits.shape[-1]
        return criterion(logits.reshape(-1, num_classes), targets.reshape(-1))

    @torch.no_grad()
    def sample(self, prompt_text, tokenizer, top_k=40, temp=1.0,
               suppress_tokens=(21017, 4242, 2235, 2)):
        """Stream a sampled continuation of `prompt_text` to stdout.

        Generation stops when the tokenizer's end-of-text token is drawn or
        when the sequence reaches `config.n_ctx` tokens. Nothing is returned;
        each token is printed as soon as it is sampled.

        Args:
            prompt_text: Text to condition on; must encode to at least one token.
            tokenizer: Object exposing `encode(str)`, `decode(list[int])` and
                `eot_token`.
            top_k: Sample only among the `top_k` most likely tokens (clamped
                to the vocabulary size).
            temp: Softmax temperature; logits are divided by this value.
            suppress_tokens: Token ids that are never sampled. The defaults
                are presumably the '#'-run tokens of the GPT-2 vocabulary used
                by the prompt markers - verify against the tokenizer.

        Raises:
            ValueError: If `prompt_text` encodes to no tokens.
        """
        prompt_tokens = tokenizer.encode(prompt_text)
        if not prompt_tokens:
            raise ValueError("prompt_text must encode to at least one token")

        model_device = self.wte.weight.device
        tokens = torch.tensor([prompt_tokens], dtype=torch.long, device=model_device)
        # Budget is measured along the sequence axis. `len(tokens)` on a
        # (1, T) tensor is always 1, which let generation run past n_ctx and
        # crash in the position embedding lookup.
        n_iter = self.config.n_ctx - tokens.shape[1]
        banned = list(suppress_tokens)

        was_training = self.training
        self.eval()
        try:
            for _ in range(n_iter):
                # Only the last position is needed to pick the next token.
                logits = self(tokens)[:, -1].cpu()
                logits = logits / temp
                if banned:
                    logits[:, banned] = float("-inf")
                k = min(top_k, logits.size(-1))
                top_values, _indices = torch.topk(logits, k)
                logits[logits < top_values[:, [-1]]] = -float('Inf')
                probs = F.softmax(logits, dim=-1)
                pred = torch.multinomial(probs, num_samples=1)
                pred_token = pred.item()
                if pred_token == tokenizer.eot_token:
                    break
                print(tokenizer.decode([pred_token]), end="", flush=True)
                tokens = torch.cat((tokens, pred.to(model_device)), dim=1)
        finally:
            # Restore the caller's mode even if generation raised.
            self.train(was_training)

    @classmethod
    def from_pretrained(cls, path, config):
        """Build a model from `config` and load weights from the checkpoint at `path`.

        The checkpoint's key names must match this module's state dict. The
        attention and MLP projection weights are transposed before loading;
        presumably the checkpoint stores them as (in_features, out_features)
        as Hugging Face's GPT-2 Conv1D layers do, whereas nn.Linear expects
        (out_features, in_features) - confirm against the checkpoint.
        """
        model = cls(config)
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from a trusted source.
        gpt_state = torch.load(path, map_location="cpu")
        transposed_suffixes = (
            "attn.c_attn.weight",
            "attn.c_proj.weight",
            "mlp.c_fc.weight",
            "mlp.c_proj.weight",
        )
        for key in list(gpt_state):
            if key.endswith(transposed_suffixes):
                gpt_state[key] = gpt_state[key].transpose(0, 1)
        model.load_state_dict(gpt_state)
        return model

In [26]:
# Architecture hyperparameters of the published GPT-2 medium / large checkpoints.
md_model_config = ModelConfig(
    n_state=1024,
    n_layer=24,
    n_head=16
)
lg_model_config = ModelConfig(
    n_state=1280,
    n_layer=36,
    n_head=20
)

# Follow the download cell: MODEL == "medium" fetched gpt2-medium, any other
# value fetched gpt2-large. A config that does not match the downloaded
# checkpoint makes load_state_dict fail on shape mismatches.
model_config = md_model_config if MODEL == "medium" else lg_model_config

model = Transformer.from_pretrained("pytorch_model.bin", model_config)
model = model.to(device)

In [35]:
def format_instruction(instruction):
    """Render one dataset record as the prompt text used for finetuning.

    Args:
        instruction: Mapping with at least the keys 'instruction' and 'output'.

    Returns:
        The fixed preamble, followed by an "### Instruction:" line and a
        "### Response:" line carrying the record's output.
    """
    preamble = (
        "Below is an instruction that describes a task, paired with an input that"
        " provides further context. Write a response that appropriately completes the request.\n"
    )
    task = instruction['instruction']
    answer = instruction['output']
    return f"{preamble}### Instruction: {task}\n### Response: {answer}"


# Load the raw Alpaca records (a JSON list of dicts with the keys
# 'instruction', 'input' and 'output' as used below).
with open("alpaca_data.json", encoding="utf-8") as f:
    data = json.load(f)


dataset = []
max_ctx_size = 1024
tokenizer = tiktoken.get_encoding("gpt2")

for d in data:
    # Skip instructions with input: only input-free records are used, because
    # no input is supplied at inference time. (The check used to be inverted
    # and kept exactly the records it was meant to drop.)
    if d['input'] != '':
        continue
    prompt = format_instruction(d)
    # disallowed_special=() encodes any literal special-token text found in
    # the data as ordinary text instead of raising; the real end-of-text
    # token is appended explicitly.
    tokens = tokenizer.encode(prompt, disallowed_special=()) + [tokenizer.eot_token]
    if len(tokens) <= max_ctx_size:
        # Next-token prediction: targets are the inputs shifted left by one.
        x = tokens[:-1]
        y = tokens[1:]
        dataset.append((x, y))

# Hold out the first 10% of the shuffled examples for validation.
random.shuffle(dataset)
val_idx = int(0.1 * len(dataset))
train_dataset = dataset[val_idx:]
val_dataset = dataset[:val_idx]
print(f"train data size: {len(train_dataset)}")
print(f"val data size: {len(val_dataset)}")


def fetch_sample(split):
    """Draw one random (input, target) pair from the requested split.

    Args:
        split: 'train' selects `train_dataset`; any other value selects
            `val_dataset`.

    Returns:
        A pair `(x, y)` of long tensors, each of shape (1, sequence_length),
        placed on `device`.

    Raises:
        ValueError: If the selected split contains no samples.
    """
    dataset = train_dataset if split == 'train' else val_dataset
    if not dataset:
        raise ValueError(f"No samples available in the {split!r} split")
    # randrange excludes its upper bound. randint(0, len(dataset)) is
    # inclusive and occasionally produced an out-of-range index.
    x, y = dataset[random.randrange(len(dataset))]
    # Wrapping in a list adds the leading batch dimension of size 1.
    x = torch.tensor([x], dtype=torch.long, device=device)
    y = torch.tensor([y], dtype=torch.long, device=device)
    return x, y

In [34]:
# METRICS

@torch.no_grad()
def evaluate_loss(split='train'):
    """Estimate the loss on a split from `eval_iters` randomly drawn samples.

    The model is switched to eval mode for the measurement and back to train
    mode afterwards.

    Args:
        split: Passed through to `fetch_sample`.

    Returns:
        The arithmetic mean of the sampled losses as a Python float.
    """
    model.eval()
    sampled = [
        model(*fetch_sample(split=split)).item()
        for _ in range(eval_iters)
    ]
    average = sum(sampled) / len(sampled)
    model.train()
    return average


@torch.no_grad()
def evaluate_val_perplexity():
    """Compute a perplexity figure over every example in `val_dataset`.

    Each example is scored on its own (batch size 1); the result is the
    exponential of the mean of those per-example losses. The model is put in
    eval mode for the pass and returned to train mode afterwards.

    Returns:
        A 0-dim CPU tensor holding the perplexity.
    """
    model.eval()
    per_example_nll = []
    for inputs, targets in val_dataset:
        # unsqueeze(0) adds the batch dimension: (seq_len,) -> (1, seq_len).
        inputs = torch.tensor(inputs, dtype=torch.long).unsqueeze(0).to(device)
        targets = torch.tensor(targets, dtype=torch.long).unsqueeze(0).to(device)
        per_example_nll.append(model(inputs, targets).cpu())
    perplexity = torch.stack(per_example_nll).mean().exp()
    model.train()
    return perplexity


# Baseline metrics of the pretrained model, measured before any finetuning step.
train_loss = evaluate_loss(split="train")
val_loss = evaluate_loss(split="val")
print(f"Initial train loss: {train_loss:.4f}\nInitial val loss: {val_loss:.4f}")
# Unlike the sampled losses above, this pass visits every validation example.
val_ppl = evaluate_val_perplexity()
print(f"Initial val perplexity: {val_ppl:.4f}")
# Release cached GPU memory left over from evaluation before training starts.
torch.cuda.empty_cache()

In [19]:
# OPTIMIZER CONFIG

# Total number of optimizer steps. Every step consumes `grad_accum_steps`
# randomly drawn samples, so this amounts to `num_epochs` training-set sizes'
# worth of samples (drawn with replacement, not true epochs).
n_iters = int(len(train_dataset) / grad_accum_steps * num_epochs)


def get_lr(iter_):
    """Return the learning rate for optimizer step `iter_`.

    The schedule is a linear warmup from 0 to `learning_rate` over the first
    `warmup_iters` steps, followed by a cosine decay that reaches
    `min_learning_rate` at step `n_iters`. Steps at or beyond `n_iters`
    return `min_learning_rate`.

    Args:
        iter_: Zero-based optimizer step.

    Returns:
        The learning rate to use for that step, as a float.
    """
    # Linear warmup for warmup_iters steps
    if iter_ < warmup_iters:
        return learning_rate * iter_ / warmup_iters
    # At or past the end of the schedule, hold the floor value. This also
    # covers n_iters <= warmup_iters, where the decay span below would be
    # zero or negative, and replaces an assert that `python -O` strips.
    if iter_ >= n_iters:
        return min_learning_rate
    # Cosine decay down to min learning rate
    decay_ratio = (iter_ - warmup_iters) / (n_iters - warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # coeff ranges 0..1
    return min_learning_rate + coeff * (learning_rate - min_learning_rate)

# Adam moment coefficients (beta1, beta2).
betas = (0.9, 0.95)
# The `lr` given here is only the initial value: the training loop overwrites
# each param group's lr with get_lr(iter_num) before every step.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=betas)

In [29]:
# TRAINING LOOP

for iter_num in range(n_iters):
    # Apply the scheduled learning rate for this step to every param group.
    lr = get_lr(iter_num)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    # backprop and update the parameters
    model.zero_grad(set_to_none=True)

    # Samples are processed one at a time; gradients add up across the
    # backward calls until the optimizer step.
    for _ in range(grad_accum_steps):
        x, y = fetch_sample(split="train")
        loss = model(x, y)
        # Scale each micro-batch loss so the accumulated gradient is the mean
        # over the simulated batch rather than the sum, keeping the gradient
        # magnitude independent of grad_accum_steps.
        (loss / grad_accum_steps).backward()

    optimizer.step()

    # Periodically report sampled train/val losses.
    if iter_num % 200 == 0:
        train_loss = evaluate_loss(split="train")
        val_loss = evaluate_loss("val")
        print(f"[{iter_num}/{n_iters}]: train={train_loss:.4f}, val={val_loss:.4f}")

In [33]:
# Metrics after finetuning, for comparison with the initial values.
val_loss = evaluate_loss(split="val")
print(f"Final val loss: {val_loss:.4f}")
# Full pass over the validation set (the loss above is a random-sample estimate).
val_ppl = evaluate_val_perplexity()
print(f"Final val perplexity: {val_ppl:.4f}")

In [23]:
def pf(x):
    """Build the inference prompt for the instruction text `x`.

    The record is formatted with an empty output so that the prompt stops at
    the response marker and the model writes the response itself.

    Args:
        x: The instruction text.

    Returns:
        The formatted prompt with trailing spaces removed.
    """
    prompt = format_instruction({"instruction": x, "input": "", "output": ""})
    # With an empty output the template ends in "### Response: ". GPT-2's BPE
    # normally attaches a leading space to the word that follows it, so a
    # dangling space would tokenize differently from the finetuning examples;
    # stripping it lets the model emit the space-prefixed first word itself.
    return prompt.rstrip(" ")

In [28]:
# Build an instruction prompt and stream a sampled completion to stdout.
prompt = pf("Write a poem about deep learning.")
# `temp` divides the logits before sampling; values below 1.0 sharpen the distribution.
model.sample(prompt, tokenizer, temp=0.8)