In [1]:
import math
import glob
import inspect
import json
import mlflow
import tiktoken
import torch
import os

import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from dataclasses import dataclass
from torch.utils.data import Dataset, DataLoader, Sampler

In [2]:
BATCH_SIZE = 48
BATCH_ACC = 70
NUM_WORKERS = 16
DEVICE = "cuda"
TOKENIZER_MODEL = "gpt2"
N_EPOCHS = 100
LR = 1e-4
FINE_TUNE = False

tokenizer = tiktoken.encoding_for_model(TOKENIZER_MODEL)
VOCAB_SIZE = tokenizer.n_vocab
END_TOKEN = VOCAB_SIZE - 1

BLOCK_SIZE = 256
N_EMBD = 768
N_HEAD = 12
N_LAYER = 12
DROPOUT = 0.1
R = 16
WINDOW = 1

ROOT_PATH = "/mnt"
FILE_NAME = f"/mnt/data/checkpoints/gpt2_alibi_{FINE_TUNE}_{BLOCK_SIZE}_{N_EMBD}_{N_LAYER}_{N_HEAD}.pth"

In [3]:
# # Start training from new Model
# FINE_TUNE = False
# CHECKPOINT = None

In [4]:
# Start training from a checkpoint
FINE_TUNE = False
CHECKPOINT = torch.load(FILE_NAME)

In [5]:
# # Start fine tuning the other parameters
# FINE_TUNE = True
# CHECKPOINT = torch.load(FILE_NAME)
# BLOCK_SIZE = BLOCK_SIZE * 2

In [6]:
@dataclass
class GPTConfig:
    block_size: int = BLOCK_SIZE
    vocab_size: int = VOCAB_SIZE
    batch_acc: int = BATCH_ACC
    n_layer: int = N_LAYER
    n_head: int = N_HEAD
    n_embd: int = N_EMBD
    dropout: float = DROPOUT
    bias: bool = False
    r: int = R
    lr: float = LR
    checkpoint = CHECKPOINT
    fine_tune = FINE_TUNE
    window = WINDOW
class CustomDataset(Dataset):
    def __init__(self, root, block_size, window, train=True):
        self.root = root
        self.block_size = block_size
        self.window = window
        
        self.elements = []
        self.elements.extend(glob.glob(os.path.join(self.root, "data/wikipedia/tokens/*")))
        self.elements.extend(glob.glob(os.path.join(self.root, "data/tatsu_lab_alpaca/tokens/*")))
        self.elements.extend(glob.glob(os.path.join(self.root, "data/truthful_qa/tokens/*")))
        self.elements.extend(glob.glob(os.path.join(self.root, "data/open_orca/tokens/*")))
        self.elements.extend(glob.glob(os.path.join(self.root, "data/stingning_ultrachat/tokens/*")))
        self.elements.extend(glob.glob(os.path.join(self.root, "data/xtreme/tokens/*")))
        self.elements.extend(glob.glob(os.path.join(self.root, "data/openwebtext/tokens/*")))
        self.elements.extend(glob.glob(os.path.join(self.root, "data/oscar/tokens/*")))
        self.elements.extend(glob.glob(os.path.join(self.root, "data/minipile/tokens/*")))
        self.elements.extend(glob.glob(os.path.join(self.root, "data/cc100/tokens/*")))
        self.elements = np.asarray(self.elements)

        n_train = int(len(self) * 0.9995)

        if train:
            self.elements = self.elements[:n_train]
        else:
            self.elements = self.elements[n_train:]

    def __len__(self):
        return len(self.elements)

    def __getitem__(self, idx):
        
        if torch.is_tensor(idx):
            idx = idx.tolist()

        tokens = []
        while len(tokens) <= self.block_size + self.window:
            filename = self.elements[idx]
            with open(filename, "rb") as f:
                tokens = tokens + json.load(f)["tokens"]
            idx = np.random.randint(0, self.block_size + self.window)

        start_idx = np.random.randint(0, max(1, len(tokens) - (self.block_size + self.window)))
        
        tokens = np.array(tokens)
        tokens = tokens[start_idx:start_idx + self.block_size + self.window]
        tokens = torch.from_numpy(tokens).to(torch.long)

        sample = dict()
        sample["tokens"] = tokens

        return sample

In [7]:
train_dataset = CustomDataset(ROOT_PATH, BLOCK_SIZE, WINDOW, train=True)
test_dataset = CustomDataset(ROOT_PATH, BLOCK_SIZE, WINDOW, train=False)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, pin_memory=True, shuffle=True, num_workers=NUM_WORKERS)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, pin_memory=True, shuffle=True, num_workers=NUM_WORKERS)
print("Train:", len(train_dataset))
print("Test:", len(test_dataset))

Train: 20180136
Test: 10096


In [8]:
batch = next(iter(train_loader))
print(batch["tokens"].shape)

torch.Size([48, 257])


In [9]:
class LayerNorm(nn.Module):
    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

    def forward(self, input):
        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
class CausalSelfAttention(nn.Module):
    def __init__(self, config, h):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.config = config
        self.config.h = h
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout

        n = config.block_size
        self.attn_mask = torch.cat([torch.cat([torch.linspace(-i, 0, i + 1), torch.zeros(n - i - 1)])[None, :] for i in range(n)]) / (n - 1) * (1 / 2**self.config.h)
        self.bias = torch.tril(torch.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size)

    def forward(self, x):

        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
        device = x.device

        self.attn_mask = torch.cat([torch.cat([torch.linspace(-i, 0, i + 1), torch.zeros(T - i - 1)])[None, :] for i in range(T)]) / (T - 1) * (1 / 2**self.config.h)
        self.bias = torch.tril(torch.ones(T, T)).view(1, 1, T, T)

        self.bias = self.bias.to(device)
        self.attn_mask = self.attn_mask.to(device)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v  = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
        att = att + self.attn_mask.to(device)
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)
        y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        
        return y
class MLP(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.c_fc    = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
        self.gelu    = nn.GELU()
        self.c_proj  = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x
class Block(nn.Module):

    def __init__(self, config, h):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = CausalSelfAttention(config, h)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        
        self.config = config

        f = (8 / self.config.n_layer)
        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            drop = nn.Dropout(config.dropout),
            h = nn.ModuleList([Block(config, h + f) for h in range(config.n_layer)]),
            ln_f = LayerNorm(config.n_embd, bias=config.bias),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.transformer.wte.weight = self.lm_head.weight

        self.apply(self._init_weights)
        
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx):
        device = idx.device
        b, t = idx.size()
        
        pos = torch.arange(0, t, dtype=torch.long, device=device)

        tok_emb = self.transformer.wte(idx)
        x = self.transformer.drop(tok_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)
        return logits

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        for _ in range(max_new_tokens):
            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
            logits = self(idx_cond)
            logits = logits[:, -self.config.window, :] / temperature
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-self.config.window]]] = -float('Inf')
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)

        return idx
class LoRACausalSelfAttention(nn.Module):
    def __init__(self, attn, config, h):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.config = config
        self.config.h = h
        self.c_attn = attn.c_attn
        # output projection
        self.c_proj = attn.c_proj
        # regularization
        self.attn_dropout = attn.attn_dropout
        self.resid_dropout = attn.resid_dropout
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout

        self.bias = attn.bias
        self.attn_mask = attn.attn_mask
            
        self.B = nn.Parameter(torch.zeros([config.n_embd * 3, config.r]), requires_grad=config.fine_tune)
        self.A = nn.Parameter(torch.randn([config.r, config.n_embd]), requires_grad=config.fine_tune)
        
        self.B_proj = nn.Parameter(torch.zeros([config.n_embd, config.r]), requires_grad=config.fine_tune)
        self.A_proj = nn.Parameter(torch.randn([config.r, config.n_embd]), requires_grad=config.fine_tune)

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
        device = x.device

        dW = self.B @ self.A
        d = x @ dW.T

        self.attn_mask = torch.cat([torch.cat([torch.linspace(-i, 0, i + 1), torch.zeros(T - i - 1)])[None, :] for i in range(T)]) / (T - 1) * (1 / 2**self.config.h)
        self.bias = torch.tril(torch.ones(T, T)).view(1, 1, T, T)

        self.bias = self.bias.to(device)
        self.attn_mask = self.attn_mask.to(device)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v  = (self.c_attn(x) + d).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
        att = att + self.attn_mask
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)
        y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

        # output projection
        dW_proj = self.B_proj @ self.A_proj
        proj = y @ dW_proj.T
        
        y = self.c_proj(y) + proj
        y = self.resid_dropout(y)
        
        return y
class LoRABlock(nn.Module):
    def __init__(self, block, config, h):
        super().__init__()
        
        self.ln_1 = block.ln_1
        self.attn = LoRACausalSelfAttention(block.attn, config, h)
        self.ln_2 = block.ln_2
        self.mlp = block.mlp

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
class LoRAGPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        
        self.config = config
        self.model = GPT(config)
            
        if config.fine_tune:
            print("Fine-tuning, set requires_grad = False.")
            for p in self.model.parameters():
                p.requires_grad = False

        f = (8 / self.config.n_layer)
        self.lora_blocks = nn.ModuleList([LoRABlock(block, config, h + f) for h, block in enumerate(self.model.transformer.h)])

    def forward(self, idx):
        device = idx.device
        b, t = idx.size()
        
        pos = torch.arange(0, t, dtype=torch.long, device=device)

        tok_emb = self.model.transformer.wte(idx)
        x = self.model.transformer.drop(tok_emb)
        for lora_block in self.lora_blocks:
            x = lora_block(x)
        x = self.model.transformer.ln_f(x)
        logits = self.model.lm_head(x)
        return logits

    def generate(self, idx, max_new_tokens):
        return self.model.generate(idx, max_new_tokens)

In [10]:
config = GPTConfig()
model = LoRAGPT(config).to(DEVICE)

if config.checkpoint is not None:
    model.model.load_state_dict(config.checkpoint["model_state_dict"])

scaler = torch.cuda.amp.GradScaler(enabled=True)
if config.checkpoint is not None:
    scaler.load_state_dict(config.checkpoint["scaler_state_dict"])

optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)
if config.checkpoint is not None:
    optimizer.load_state_dict(config.checkpoint["optimizer_state_dict"])

n_model_parameters = sum(p.numel() for p in model.model.parameters())
n_lora_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model: {n_model_parameters:,}".replace(",", "."))
print(f"LoRA: {n_lora_parameters:,}".replace(",", "."))
print(f"ratio: {(n_lora_parameters / n_model_parameters) * 100:.02f}%")
print()

with torch.no_grad():
    batch = next(iter(train_loader))
    in_tokens = batch["tokens"][:, :-WINDOW].to(DEVICE).to(torch.long)
    out_tokens = batch["tokens"][:, WINDOW:].to(DEVICE).to(torch.long)
    print(in_tokens.shape)
    logits = model(in_tokens)
    print(logits.shape)
    a = logits.view(logits.shape[0] * BLOCK_SIZE, -1)
    b = out_tokens.view(-1)
    loss = F.cross_entropy(a, b)
    print(loss)
    print()

mlflow.set_tracking_uri(uri="http://localhost:8080")
_ = mlflow.set_experiment(f"GPT2 - ALiBi - fine_tune={FINE_TUNE}, block_size={BLOCK_SIZE}, n_embd={N_EMBD}, n_layer={N_LAYER}, n_head={N_HEAD}")

Model: 123.551.232
LoRA: 123.551.232
ratio: 100.00%

torch.Size([48, 256])
torch.Size([48, 256, 50257])
tensor(7.3060, device='cuda:0')



In [11]:
def generate_examples(loader, num_new_tokens=16, verbose=False):
    data = dict()
    with torch.no_grad():
        batch = next(iter(loader))

        in_tokens = batch["tokens"][:, :-1].to(DEVICE)
        out_tokens = batch["tokens"][:, 1:].to(DEVICE)

        pred = model.generate(in_tokens, max_new_tokens=num_new_tokens)

        for i, (in_tokens_, out_tokens_, pred_) in enumerate(zip(in_tokens, out_tokens, pred)):
            in_tokens_ = in_tokens_.detach().cpu().numpy()
            out_tokens_ = out_tokens_.detach().cpu().numpy()
            pred_ = pred_.detach().cpu().numpy()
            pred_ = pred_[:-num_new_tokens].tolist() + [VOCAB_SIZE - 1] + pred_[-num_new_tokens:].tolist()

            in_text = tokenizer.decode(in_tokens_)
            pred_text = tokenizer.decode(pred_)

            data[i] = {"in_text": in_text, "pred_text": pred_text}

            if verbose:
                print("INPUT")
                print(in_text)
                print("=========================================")
                print("OUTPUT")
                print(pred_text)
                print("=========================================")
                print("=========================================")
                print("=========================================")
                print()
    return data
def train():
    iterator = iter(train_loader)
    N = len(train_loader) // BATCH_SIZE
    sum_loss = 0
    count = 0
    for n_epoch in range(N_EPOCHS):
        for n in range(N):
            for b in range(BATCH_ACC):
                with torch.autocast(device_type=DEVICE, dtype=torch.float16, enabled=True):
                    batch = next(iterator)
                    
                    in_tokens = batch["tokens"][:, :-WINDOW].to(DEVICE)
                    out_tokens = batch["tokens"][:, WINDOW:].to(DEVICE)
                    
                    logits = model(in_tokens)
                    loss = F.cross_entropy(logits.view(logits.shape[0] * BLOCK_SIZE, -1), out_tokens.view(-1), ignore_index=END_TOKEN)
                    mlflow.log_metric("train_loss", loss, step=count, synchronous=False)
                scaler.scale(loss).backward()

                sum_loss += loss.item()
                count += 1

                mlflow.log_metric("mean_train_loss", sum_loss / count, step=(n_epoch + 1) * (n + 1) * BATCH_ACC, synchronous=False)

                print(f"\r{n_epoch + 1:03d}|{N_EPOCHS:03d}, {n + 1:04d}|{N:04d}, {b + 1:03d}|{BATCH_ACC:03d}, loss: {sum_loss / count:.05f}", end="")

            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1)
            scaler.step(optimizer)
            scaler.update()
            
            optimizer.zero_grad(set_to_none=True)
                
            if (n + 1) % 25 == 0:
                print()
                sum_loss = 0
                count = 0
                
            if (n + 1) % 150 == 0:
                print("\nSave...")
                torch.save({"model_state_dict": model.model.state_dict(), 
                            "optimizer_state_dict": optimizer.state_dict(), 
                            "scaler_state_dict": scaler.state_dict()}, FILE_NAME)
                print("...done!\n")
                data = generate_examples(train_loader, num_new_tokens=16, verbose=False)
                mlflow.log_dict(data, f"example_{n_epoch + 1}_{n + 1}.json")
                test(n_epoch + 1, (n_epoch + 1) * (n + 1) * BATCH_ACC)
@torch.no_grad
def test(epoch=0, step=0):
    sum_loss = 0
    count = 0
    for i, batch in enumerate(test_loader):
        in_tokens = batch["tokens"][:, :-WINDOW].to(DEVICE)
        out_tokens = batch["tokens"][:, WINDOW:].to(DEVICE)
        
        logits = model(in_tokens)
        loss = F.cross_entropy(logits.view(logits.shape[0] * BLOCK_SIZE, -1), out_tokens.view(-1), ignore_index=END_TOKEN)

        sum_loss += loss.item()
        count += 1

        print(f"\r{i + 1:04d}|{len(test_loader):04d}, loss: {sum_loss / count:.05f}", end="")

    data = generate_examples(test_loader, num_new_tokens=16, verbose=False)
    mlflow.log_dict(data, f"example_{epoch}_{i + 1}.json")
    mlflow.log_metric("test_loss", sum_loss / count, step=step, synchronous=False)
    print()
    print()

In [None]:
run_id = "06ec1fd7920d49b6968eeebf180f97d6"
with mlflow.start_run(run_id=run_id):
    train()

001|100, 0025|8758, 070|070, loss: 6.26017
001|100, 0050|8758, 070|070, loss: 5.60162
001|100, 0059|8758, 065|070, loss: 5.32654

In [None]:
torch.save({"model_state_dict": model.model.state_dict(), 
            "optimizer_state_dict": optimizer.state_dict(), 
            "scaler_state_dict": scaler.state_dict()}, FILE_NAME)

In [None]:
generate_examples(train_loader, num_new_tokens=16, verbose=True)