In [None]:
import os
import sys
import requests
from tqdm import tqdm

# src: https://github.com/openai/gpt-2/blob/master/download_model.py
def download_model(model, models_dir='../models/openai/'):
    """Download the published OpenAI GPT-2 checkpoint files for ``model``.

    Each file is fetched from the public ``gpt-2/models`` blob container and
    written to ``<models_dir>/<model>/``. That directory is created if needed.

    Args:
        model: Model size name used both in the URL and as the directory name,
            e.g. ``'124M'``.
        models_dir: Parent directory for downloaded models. The default is the
            relative path this notebook has always used, so it is resolved
            against the current working directory.

    Raises:
        requests.HTTPError: if the server returns an error status for any file.
            Without this check, the error page body would be saved silently
            in place of the model file.
    """
    subdir = os.path.join(models_dir, model)
    os.makedirs(subdir, exist_ok=True)
    # Normalize Windows separators so the joined paths below are consistent.
    subdir = subdir.replace('\\', '/')

    filenames = ['checkpoint', 'encoder.json', 'hparams.json',
                 'model.ckpt.data-00000-of-00001', 'model.ckpt.index',
                 'model.ckpt.meta', 'vocab.bpe']
    for filename in filenames:
        url = "https://openaipublic.blob.core.windows.net/gpt-2/models/" + model + "/" + filename
        print(url)
        # `with` releases the streamed connection even if writing fails midway.
        with requests.get(url, stream=True) as r:
            r.raise_for_status()
            # content-length can be absent (e.g. chunked transfer). Passing
            # total=None makes tqdm count bytes without showing a percentage.
            file_size = int(r.headers.get("content-length", 0)) or None
            chunk_size = 1000
            with open(os.path.join(subdir, filename), 'wb') as f:
                with tqdm(ncols=100, desc="Fetching " + filename, total=file_size, unit_scale=True) as pbar:
                    # 1k for chunk_size, since Ethernet packet size is around 1500 bytes
                    for chunk in r.iter_content(chunk_size=chunk_size):
                        f.write(chunk)
                        # The last chunk is usually shorter than chunk_size,
                        # so count the bytes actually received.
                        pbar.update(len(chunk))

# download_model('124M')



In [101]:
import tensorflow as tf
import torch
import numpy as np
from pathlib import Path
from typing import Dict, Any, Optional

import torch
import tensorflow as tf

def load_gpt2_checkpoint(checkpoint_path: str, model, config, strict: bool = False):
    """Load a TF GPT-2 checkpoint into a PyTorch model.

    Each TF variable is read by name and renamed to the PyTorch ``state_dict``
    key that ``model`` is expected to use. Variables whose layouts differ are
    reshaped, and the assembled dict is passed to ``model.load_state_dict``.

    Args:
        checkpoint_path: Checkpoint prefix or directory, in the form accepted
            by ``tf.train.load_checkpoint``.
        model: ``torch.nn.Module`` whose parameter names match the keys built
            below (``token_embedding``, ``transformer_blocks.{i}.att.W_query``,
            ``ff.expansion``, ...).
        config: Object providing ``n_layers``, ``emb_dim``, ``qkv_bias`` and
            ``mlp_bias``.
        strict: Forwarded to ``load_state_dict``. The default ``False`` keeps
            the original behaviour: name mismatches are reported in the return
            value rather than raised. Pass ``True`` to fail loudly when a key is
            misspelled or a model parameter is left unloaded.

    Returns:
        The value returned by ``model.load_state_dict``. For ``torch.nn.Module``
        this is a named tuple of ``missing_keys`` and ``unexpected_keys``. The
        weights themselves are copied into ``model`` in place.
    """
    checkpoint = tf.train.load_checkpoint(checkpoint_path)
    pt_dict = {}

    def read(tf_key):
        # torch.from_numpy shares memory with the NumPy array returned by TF.
        return torch.from_numpy(checkpoint.get_tensor(tf_key))

    def conv1d_to_linear(x):
        # These TF weights carry a leading singleton axis and, presumably, an
        # (in, out) layout. nn.Linear expects (out, in), hence the transpose.
        return x.squeeze(0).T

    def transfer(pt_key, tf_key, transform=None):
        tensor = read(tf_key)
        pt_dict[pt_key] = transform(tensor) if transform else tensor

    def transfer_block(i, pt_suffix, tf_suffix, transform=None):
        transfer(f'transformer_blocks.{i}.{pt_suffix}', f'model/h{i}/{tf_suffix}', transform)

    # Global weights (outside of transformer stack)
    transfer('token_embedding.weight', 'model/wte')
    transfer('position_embeddings.weight', 'model/wpe')
    transfer('final_norm.gain', 'model/ln_f/g')
    transfer('final_norm.bias', 'model/ln_f/b')

    # Weight tying: the output head is loaded from the token embedding, and no
    # separate output matrix is read from the checkpoint.
    pt_dict['out_head.weight'] = pt_dict['token_embedding.weight']

    attn_names = ('W_query', 'W_key', 'W_value')

    # Per-layer weights
    for i in range(config.n_layers):
        # Norms
        transfer_block(i, 'norm1.gain', 'ln_1/g')
        transfer_block(i, 'norm1.bias', 'ln_1/b')
        transfer_block(i, 'norm2.gain', 'ln_2/g')
        transfer_block(i, 'norm2.bias', 'ln_2/b')

        # Attention: TF fuses Q, K and V into one c_attn tensor. Split it into
        # emb_dim-wide chunks. The explicit three-way unpack is deliberate: a
        # chunk-count mismatch raises instead of being silently truncated.
        q, k, v = torch.split(read(f'model/h{i}/attn/c_attn/w').squeeze(0), config.emb_dim, dim=1)
        for name, w in zip(attn_names, (q, k, v)):
            pt_dict[f'transformer_blocks.{i}.att.{name}.weight'] = w.T

        if config.qkv_bias:
            q_b, k_b, v_b = torch.split(read(f'model/h{i}/attn/c_attn/b'), config.emb_dim, dim=0)
            for name, b in zip(attn_names, (q_b, k_b, v_b)):
                pt_dict[f'transformer_blocks.{i}.att.{name}.bias'] = b

        transfer_block(i, 'att.out_proj.weight', 'attn/c_proj/w', conv1d_to_linear)
        transfer_block(i, 'att.out_proj.bias', 'attn/c_proj/b')

        # MLP
        transfer_block(i, 'ff.expansion.weight', 'mlp/c_fc/w', conv1d_to_linear)
        transfer_block(i, 'ff.expansion.bias', 'mlp/c_fc/b')
        transfer_block(i, 'ff.projection.weight', 'mlp/c_proj/w', conv1d_to_linear)

        if config.mlp_bias:
            transfer_block(i, 'ff.projection.bias', 'mlp/c_proj/b')

    return model.load_state_dict(pt_dict, strict=strict)



# NOTE(review): GPT2Config, GPTModel and generate_text are not defined in this
# notebook; they presumably come from the companion notebooks below, and `enc`
# (the tokenizer) from somewhere else — confirm. While these lines stay
# commented out, this cell only works on a kernel where those names were
# already defined, so it fails under Restart & Run All.
#%run -n 00_config.ipynb
#%run -n 02_gpt2_model.ipynb

# Same models directory that download_model() writes to, resolved relative to
# the notebook's working directory instead of a user-specific absolute path.
GPT2_CHECKPOINT_DIR = Path('../models/openai') / '124M'

cfg = GPT2Config(qkv_bias=True, device='cpu')
m = GPTModel(cfg)
# load_gpt2_checkpoint returns load_state_dict's report (missing/unexpected
# keys), not a model. The weights are loaded into `m` in place.
gpt = load_gpt2_checkpoint(str(GPT2_CHECKPOINT_DIR), m, cfg)

# If generate_text samples using torch's RNG, this makes re-runs reproducible.
torch.manual_seed(123)
generate_text(m, enc, "the cat in the")


'the cat in the pic from these lines) is pretty toothy as awry luck would dictate — Pastry Guide (@'