In [1]:
from pathlib import Path
from dataclasses import dataclass
import inspect
from pathlib import Path
import tiktoken
import matplotlib.pyplot as plt
import numpy as np
import random
import os
import json

import torch
import torch.nn as nn
from torch.nn import functional as F

In [2]:
torch.manual_seed(1337)
random.seed(1337)


def get_device():
    device = 'cpu'
    if torch.cuda.is_available():
        torch.cuda.manual_seed(1337)
        device = 'cuda'
    # elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    #     device = 'mps'
    print(f'using {device}')
    return device

device = get_device()

log_path = Path('/home/ubuntu/log')

using cuda


### Model

In [3]:
class CausalSelfAttention(nn.Module):

    def __init__(self, config):
        super().__init__()
        # ensures that you can split embeddings across the heads
        assert config.n_embd % config.n_head == 0
        # key, query, value projection for all heads in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.c_proj.NANOGPT_SCALE_INIT = 1
        self.n_head = config.n_head
        self.n_embd = config.n_embd

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, n_embd embedding dimensionality 
        # calculate query, key, value for all heads in batch, then move head forward
        # nh - num heads, hs - head size, C  nh*hs aka channels
        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        # attention (materializes the large (T,T) matrix for all the queries and keys)
        # att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        # att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
        # att = F.softmax(att, dim=-1)
        # y = att @ v # (B, nh, T, T) X (B, nh, T, hs) - > (B, nh, T, hs)
        # replace attention with flash attention 
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True) # flash attention
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
        # output projection
        y = self.c_proj(y)
        return y

class MLP(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.gelu = nn.GELU(approximate='tanh')
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
        self.c_proj.NANOGPT_SCALE_INIT = 1

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        return x

class Block(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        # allows for pathway to pass through gradients instead of going through each "box"
        # this is a feed forward network
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

@dataclass
class GPTConfig:
    block_size: int = 1024 # max sequence length/context
    vocab_size: int = 100276 # switched to GP4 tokenizer 
    n_layer: int = 12 
    n_head: int = 12
    n_embd: int = 768

class GPT(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd), # weight tokenizer element
            wpe = nn.Embedding(config.block_size, config.n_embd), # weight position element
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]), # hidden layers aka Transformers
            ln_f = nn.LayerNorm(config.n_embd), #log normalization 
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) # language model head going from embeddings to vocab

        # weight sharing scheme
        self.transformer.wte.weight = self.lm_head.weight

        # init params
        self.apply(self._init_weights)

    def _init_weights(self, module):
        std = 0.02
        mean = 0.0
        if isinstance(module, nn.Linear):
            if hasattr(module, 'NANOGPT_SCALE_INIT'):
                std *= (2 * self.config.n_layer) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=mean, std=std)

    def forward(self, idx, targets=None):
        # idx is shape (B, T)
        B, T = idx.size()
        assert T <= self.config.block_size, f'Cannot forward sequence, out of context'
        # forward the token and positions
        pos = torch.arange(0, T, dtype=torch.long, device=idx.device) # shape (T)
        pos_emb = self.transformer.wpe(pos) # (T, n_embd)
        tok_emb = self.transformer.wte(idx) # (B, T, n_embd)
        x = tok_emb + pos_emb
        # forward the block
        for block in self.transformer.h:
            x = block(x)
        # forward the final layernorm and head
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x) # (B, T, vocab_size)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss

    def configure_optimizers(self, weight_decay, learning_rate, device_type):
        # start with all of the candidate parameters (that require grad)
        param_dict = {pn: p for pn, p in self.named_parameters()}
        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
        # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
        # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        if master_process:
            print(f'num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters')
            print(f'num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters')
        # Create AdamW optimizer and use the fused version if it is available
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == 'cuda'
        if master_process:
            print(f'using fused AdamW: {use_fused}')
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=(0.9, 0.95), eps=1e-8, fused=use_fused)
        return optimizer

    @classmethod
    def from_pretrained(cls, ckpt_path, device):
        ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
        cfg = ckpt['config'] if isinstance(ckpt['config'], GPTConfig) else GPTConfig(**ckpt['config'])
        print(f'training steps {ckpt['step']}  |  test loss {ckpt['test_loss']:.4f}')
        
        def _clean(name):
            if name.startswith('_orig_mod.'):
                name = name[len('_orig_mod.'):]
            if name.startswith('module.'):
                name = name[len('module.'):]
            return name

        raw_sd   = ckpt['model']
        clean_sd = { _clean(k): v for k, v in raw_sd.items() }
        
        model = cls(cfg).to(device)
        model.load_state_dict(clean_sd, strict=True)
        return model

### Configs

In [4]:
enc = tiktoken.get_encoding('cl100k_base')

# set up DDP (distributed data parallel).
device_type = device # override device if using ddp do device_type acts as backup
ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run?
if ddp:
    # use of DDP atm demands CUDA, we set the device appropriately according to rank
    assert torch.cuda.is_available(), 'for now i think we need CUDA for DDP'
    init_process_group(backend='nccl')
    ddp_rank = int(os.environ['RANK'])
    ddp_local_rank = int(os.environ['LOCAL_RANK'])
    ddp_world_size = int(os.environ['WORLD_SIZE'])
    device = f'cuda:{ddp_local_rank}'
    torch.cuda.set_device(device)
    master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
else:
    # vanilla, non-DDP run
    ddp_rank = 0
    ddp_local_rank = 0
    ddp_world_size = 1
    master_process = True
    print(f'not using ddp, using device: {device}')

not using ddp, using device: cuda


### Load Model

In [5]:
model_path = Path('~/model/model_final.pt').expanduser() 
model  = GPT.from_pretrained(model_path, device)

training steps 9999  |  test loss 3.0506


In [6]:
model.eval()
num_return_sequences = 1
max_length = 100
tokens = enc.encode('Hello, I\'m a language model,')
tokens = torch.tensor(tokens, dtype=torch.long)
tokens = tokens.unsqueeze(0).repeat(num_return_sequences, 1)
xgen = tokens.to(device)
sample_rng = torch.Generator(device=device)
sample_rng.manual_seed(42 + random.randint(0,9999))
while xgen.size(1) < max_length:
    # forward the model to get the logits
    with torch.no_grad():
        with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
            logits, loss = model(xgen) # (B, T, vocab_size)
        # take the logits at the last position
        logits = logits[:, -1, :] # (B, vocab_size)
        # get the probabilities
        probs = F.softmax(logits, dim=-1)
        # do top-k sampling of 50 (huggingface pipeline default)
        # topk_probs here becomes (5, 50), topk_indices is (5, 50)
        topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)
        # select a token from the top-k probabilities
        # note: multinomial does not demand the input to sum to 1
        ix = torch.multinomial(topk_probs, 1, generator=sample_rng) # (B, 1)
        # gather the corresponding indices
        xcol = torch.gather(topk_indices, -1, ix) # (B, 1)
        # append to the sequence
        xgen = torch.cat((xgen, xcol), dim=1)
# print the generated text
for i in range(num_return_sequences):
    tokens = xgen[i, :max_length].tolist()
    decoded = enc.decode(tokens)
    print(f'rank {ddp_rank} sample {i}: {decoded}')

rank 0 sample 0: Hello, I'm a language model, but I'll never actually write this. It isn't something I've tried; at best, I just want something simple, more readable, with some good, simple, semantic concepts and a nice (and somewhat simple) interface.
I also try to simplify the entire language into very simple pieces. For now, it's mostly because of a problem in how the data's written to the computer. So, what's in the document, whatever the document, is


### Hellaswag

In [13]:
def get_most_likely_row(tokens, mask, logits):
    # helper function for HellaSwag eval
    # takes tokens, mask, and logits, returns the index of the completion with the lowest loss
    # evaluate the autoregressive loss at all positions
    shift_logits = (logits[..., :-1, :]).contiguous()
    shift_tokens = (tokens[..., 1:]).contiguous()
    flat_shift_logits = shift_logits.view(-1, shift_logits.size(-1))
    flat_shift_tokens = shift_tokens.view(-1)
    shift_losses = F.cross_entropy(flat_shift_logits, flat_shift_tokens, reduction='none')
    shift_losses = shift_losses.view(tokens.size(0), -1)
    # now get the average loss just for the completion region (where mask == 1), in each row
    shift_mask = (mask[..., 1:]).contiguous() # we must shift mask, so we start at the last prompt token
    masked_shift_losses = shift_losses * shift_mask
    # sum and divide by the number of 1s in the mask
    sum_loss = masked_shift_losses.sum(dim=1)
    avg_loss = sum_loss / shift_mask.sum(dim=1)
    # now we have a loss for each of the 4 completions
    # the one with the lowest loss should be the most likely
    pred_norm = avg_loss.argmin().item()
    return pred_norm

def download_file(url: str, fname: str, chunk_size=1024):
    """Helper function to download a file from a given url"""
    resp = requests.get(url, stream=True)
    total = int(resp.headers.get("content-length", 0))
    print(f'loading {fname}')
    with open(fname, "wb") as file, tqdm(
        desc=fname,
        total=total,
        unit="iB",
        unit_scale=True,
        unit_divisor=1024,
    ) as bar:
        for data in resp.iter_content(chunk_size=chunk_size):
            size = file.write(data)
            bar.update(size)

hellaswags = {
    "train": "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_train.jsonl",
    "val": "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl",
    "test": "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_test.jsonl",
}

def download(split):
    """Downloads HellaSwag DATA_CACHE_DIR"""
    os.makedirs(DATA_CACHE_DIR, exist_ok=True)
    data_url = hellaswags[split]
    print(f'hellaswag_{split}.jsonl')
    data_filename = DATA_CACHE_DIR /  f'hellaswag_{split}.jsonl'
    if not os.path.exists(data_filename):
        print(f"Downloading {data_url} to {data_filename}...")
        download_file(data_url, data_filename)

def render_example(example):
    """
    Given the example as a dictionary, render it as three torch tensors:
    - tokens (the tokens of context + completion, of size 4xN, as there are always 4 candidates)
    - mask (is 1 in the region of the candidate completion, where we evaluate likelihoods)
    - label (the index of the correct completion, which we hope has the highest likelihood)
    """
    ctx = example["ctx"]
    label = example["label"]
    endings = example["endings"]

    # data needed to reproduce this eval on the C size
    data = {
        "label": label,
        "ctx_tokens": None,
        "ending_tokens": [],
    }

    # gather up all the tokens
    ctx_tokens = enc.encode(ctx)
    data["ctx_tokens"] = ctx_tokens
    tok_rows = []
    mask_rows = []
    for end in endings:
        end_tokens = enc.encode(" " + end) # note: prepending " " because GPT-2 tokenizer
        tok_rows.append(ctx_tokens + end_tokens)
        mask_rows.append([0]*len(ctx_tokens) + [1]*len(end_tokens))
        data["ending_tokens"].append(end_tokens)

    # have to be careful during the collation because the number of tokens in each row can differ
    max_len = max(len(row) for row in tok_rows)
    tokens = torch.zeros((4, max_len), dtype=torch.long)
    mask = torch.zeros((4, max_len), dtype=torch.long)
    for i, (tok_row, mask_row) in enumerate(zip(tok_rows, mask_rows)):
        tokens[i, :len(tok_row)] = torch.tensor(tok_row)
        mask[i, :len(mask_row)] = torch.tensor(mask_row)

    return data, tokens, mask, label

def iterate_examples(split):
    # there are 10,042 examples in total in val
    download(split)
    with open(os.path.join(DATA_CACHE_DIR, f"hellaswag_{split}.jsonl"), "r") as f:
        for line in f:
            example = json.loads(line)
            yield example


In [9]:
log_file = log_path / 'log_hellaswag.txt'
with open(log_file, 'w') as f: # open for writing to clear the file
    pass

In [12]:
DATA_CACHE_DIR = model_path = Path('~/data/hellaswag')

num_correct_norm = 0
num_total = 0
for i, example in enumerate(iterate_examples("val")):
    # only process examples where i % ddp_world_size == ddp_rank
    if i % ddp_world_size != ddp_rank:
        continue
    # render the example into tokens and labels
    _, tokens, mask, label = render_example(example)
    tokens = tokens.to(device)
    mask = mask.to(device)
    # get the logits
    with torch.no_grad():
        with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
            logits, loss = model(tokens)
        pred_norm = get_most_likely_row(tokens, mask, logits)
    num_total += 1
    num_correct_norm += int(pred_norm == label)
    if pred_norm == label:
        print(f'got {i} correct, {num_correct_norm}/{num_total} {num_correct_norm/num_total:.4f}')
# reduce the stats across all processes
if ddp:
    num_total = torch.tensor(num_total, dtype=torch.long, device=device)
    num_correct_norm = torch.tensor(num_correct_norm, dtype=torch.long, device=device)
    dist.all_reduce(num_total, op=dist.ReduceOp.SUM)
    dist.all_reduce(num_correct_norm, op=dist.ReduceOp.SUM)
    num_total = num_total.item()
    num_correct_norm = num_correct_norm.item()
acc_norm = num_correct_norm / num_total
if master_process:
    print(f"HellaSwag accuracy: {num_correct_norm}/{num_total}={acc_norm:.4f}")
    with open(log_file, "a") as f:
        f.write(f"Final_hella_validation {acc_norm:.4f}\n")

hellaswag_val.jsonl
evaluating 0
got 8 correct
got 10 correct
got 15 correct
got 22 correct
got 23 correct
got 27 correct
got 32 correct
got 36 correct
got 37 correct
got 38 correct
got 39 correct
got 42 correct
got 44 correct
got 45 correct
got 46 correct
got 49 correct
got 51 correct
got 56 correct
got 57 correct
got 69 correct
got 72 correct
got 73 correct
got 76 correct
got 77 correct
got 78 correct
got 79 correct
got 82 correct
got 85 correct
got 86 correct
got 88 correct
got 92 correct
got 93 correct
got 94 correct
got 95 correct
got 97 correct
evaluating 100
got 105 correct
got 107 correct
got 108 correct
got 113 correct
got 114 correct
got 119 correct
got 120 correct
got 128 correct
got 133 correct
got 134 correct
got 135 correct
got 141 correct
got 146 correct
got 147 correct
got 149 correct
got 151 correct
got 154 correct
got 156 correct
got 158 correct
got 159 correct
got 165 correct
got 167 correct
got 168 correct
got 178 correct
got 185 correct
got 188 correct
got 189 corr

In [14]:
num_correct_norm = 0
num_total = 0
for i, example in enumerate(iterate_examples("train")):
    # only process examples where i % ddp_world_size == ddp_rank
    if i % ddp_world_size != ddp_rank:
        continue
    # render the example into tokens and labels
    _, tokens, mask, label = render_example(example)
    tokens = tokens.to(device)
    mask = mask.to(device)
    # get the logits
    with torch.no_grad():
        with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
            logits, loss = model(tokens)
        pred_norm = get_most_likely_row(tokens, mask, logits)
    num_total += 1
    num_correct_norm += int(pred_norm == label)
    if pred_norm == label:
        print(f'got {i} correct, {num_correct_norm}/{num_total} {num_correct_norm/num_total:.4f}')
# reduce the stats across all processes
if ddp:
    num_total = torch.tensor(num_total, dtype=torch.long, device=device)
    num_correct_norm = torch.tensor(num_correct_norm, dtype=torch.long, device=device)
    dist.all_reduce(num_total, op=dist.ReduceOp.SUM)
    dist.all_reduce(num_correct_norm, op=dist.ReduceOp.SUM)
    num_total = num_total.item()
    num_correct_norm = num_correct_norm.item()
acc_norm = num_correct_norm / num_total
if master_process:
    print(f"HellaSwag accuracy: {num_correct_norm}/{num_total}={acc_norm:.4f}")
    with open(log_file, "a") as f:
        f.write(f"Final_hella_train {acc_norm:.4f}\n")

hellaswag_train.jsonl
got 0 correct, 1/1 1.0000
got 2 correct, 2/3 0.6667
got 4 correct, 3/5 0.6000
got 8 correct, 4/9 0.4444
got 11 correct, 5/12 0.4167
got 13 correct, 6/14 0.4286
got 15 correct, 7/16 0.4375
got 19 correct, 8/20 0.4000
got 21 correct, 9/22 0.4091
got 23 correct, 10/24 0.4167
got 24 correct, 11/25 0.4400
got 27 correct, 12/28 0.4286
got 31 correct, 13/32 0.4062
got 32 correct, 14/33 0.4242
got 33 correct, 15/34 0.4412
got 34 correct, 16/35 0.4571
got 36 correct, 17/37 0.4595
got 40 correct, 18/41 0.4390
got 41 correct, 19/42 0.4524
got 43 correct, 20/44 0.4545
got 46 correct, 21/47 0.4468
got 50 correct, 22/51 0.4314
got 51 correct, 23/52 0.4423
got 52 correct, 24/53 0.4528
got 54 correct, 25/55 0.4545
got 57 correct, 26/58 0.4483
got 65 correct, 27/66 0.4091
got 67 correct, 28/68 0.4118
got 68 correct, 29/69 0.4203
got 71 correct, 30/72 0.4167
got 76 correct, 31/77 0.4026
got 77 correct, 32/78 0.4103
got 78 correct, 33/79 0.4177
got 80 correct, 34/81 0.4198
got 85 co