In [None]:
!git clone https://github.com/hanani8/Splitter

Cloning into 'Splitter'...
remote: Enumerating objects: 115, done.[K
remote: Counting objects: 100% (115/115), done.[K
remote: Compressing objects: 100% (73/73), done.[K
remote: Total 115 (delta 45), reused 105 (delta 35), pack-reused 0 (from 0)[K
Receiving objects: 100% (115/115), 29.93 KiB | 4.27 MiB/s, done.
Resolving deltas: 100% (45/45), done.


In [None]:
!pip install tiktoken tensorflow>=2.15.0 tqdm>=4.66 torch "numpy<2.0.0"

In [None]:
from Splitter.training import SimpleTrainer
from Splitter.gpt_download import download_and_load_gpt2
from Splitter.models import GPTModel

In [None]:
# Fetch the 124M GPT-2 checkpoint into state/gpt2 and load it; `params` is consumed
# by load_weights_into_gpt further down.
settings, params = download_and_load_gpt2(
    model_size="124M",
    models_dir="state/gpt2",
)

checkpoint: 100%|██████████| 77.0/77.0 [00:00<00:00, 38.3kiB/s]
encoder.json: 100%|██████████| 1.04M/1.04M [00:00<00:00, 4.07MiB/s]
hparams.json: 100%|██████████| 90.0/90.0 [00:00<00:00, 64.2kiB/s]
model.ckpt.data-00000-of-00001: 100%|██████████| 498M/498M [00:09<00:00, 50.2MiB/s]
model.ckpt.index: 100%|██████████| 5.21k/5.21k [00:00<00:00, 3.24MiB/s]
model.ckpt.meta: 100%|██████████| 471k/471k [00:00<00:00, 3.13MiB/s]
vocab.bpe: 100%|██████████| 456k/456k [00:00<00:00, 2.08MiB/s]


In [None]:
# Architecture hyper-parameters matching the 124M GPT-2 checkpoint downloaded above.
GPT_CONFIG_124M = dict(
    vocab_size=50257,         # rows of the token embedding table
    context_length=1024,      # rows of the positional embedding table (max sequence length)
    emb_dim=768,              # embedding / hidden width
    n_heads=12,               # attention heads per transformer block
    n_layers=12,              # number of transformer blocks
    drop_rate=0.1,            # probability used by the model's Dropout layers
    qkv_bias=True,            # query/key/value projections carry a bias term
    forward_layer_size=4,     # feed-forward expansion factor (768 -> 3072 -> 768)
)

In [None]:
# Build the GPT model skeleton described by GPT_CONFIG_124M. Its parameters are replaced
# with the downloaded GPT-2 weights by load_weights_into_gpt(gpt, params) further down.
gpt = GPTModel(GPT_CONFIG_124M)

In [None]:
# Put the model in inference mode so its Dropout layers (p=0.1) are inactive during generation.
# eval() returns the module itself; the trailing semicolon suppresses its very long repr.
gpt.eval();

GPTModel(
  (tok_emb): Embedding(50257, 768)
  (pos_emb): Embedding(1024, 768)
  (drop_emb): Dropout(p=0.1, inplace=False)
  (trf_blocks): Sequential(
    (0): TransformerBlock(
      (att): MultiHeadAttention(
        (W_query): Linear(in_features=768, out_features=768, bias=True)
        (W_key): Linear(in_features=768, out_features=768, bias=True)
        (W_value): Linear(in_features=768, out_features=768, bias=True)
        (out_proj): Linear(in_features=768, out_features=768, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (ff): FeedForward(
        (layers): Sequential(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU()
          (2): Linear(in_features=3072, out_features=768, bias=True)
        )
      )
      (norm1): LayerNorm()
      (norm2): LayerNorm()
      (drop_resid): Dropout(p=0.1, inplace=False)
    )
    (1): TransformerBlock(
      (att): MultiHeadAttention(
        (W_query): Linear(in_features=768,

In [None]:
import torch

In [None]:
def assign(left, right):
    """Return a new Parameter holding a copy of `right`, validated against `left`.

    `left` is the parameter being replaced. It is used only to check the shape
    and to choose the dtype/device of the copy, so e.g. a float64 checkpoint
    array still yields a parameter that matches the model's dtype.

    Args:
        left: existing tensor/Parameter whose shape, dtype and device are matched.
        right: replacement values, a NumPy array or a torch tensor.

    Returns:
        torch.nn.Parameter with the shape, dtype and device of `left`. It never
        shares memory with `right`.

    Raises:
        ValueError: if the shapes of `left` and `right` differ.
    """
    left_shape, right_shape = tuple(left.shape), tuple(right.shape)
    if left_shape != right_shape:
        # Format the message explicitly; passing several args to ValueError renders a tuple.
        raise ValueError(f"Shape mismatch. Left: {left_shape} Right: {right_shape}")
    if isinstance(right, torch.Tensor):
        # torch.tensor(tensor) warns and is discouraged; copy explicitly instead.
        data = right.detach().clone().to(dtype=left.dtype, device=left.device)
    else:
        data = torch.tensor(right, dtype=left.dtype, device=left.device)
    return torch.nn.Parameter(data)

In [None]:
# Prefer the GPU when one is visible to torch; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
import numpy as np

def _load_linear(linear, w, b):
    """Replace a Linear layer's weight and bias with checkpoint values."""
    # The checkpoint kernels are laid out (in_features, out_features); nn.Linear stores
    # (out_features, in_features), hence the transpose (assign() verifies the shape).
    linear.weight = assign(linear.weight, w.T)
    linear.bias = assign(linear.bias, b)


def _load_attention(att, attn_params):
    """Load one block's attention projections from its "attn" checkpoint entry."""
    # c_attn packs the query, key and value projections side by side along the last axis.
    q_w, k_w, v_w = np.split(attn_params["c_attn"]["w"], 3, axis=-1)
    q_b, k_b, v_b = np.split(attn_params["c_attn"]["b"], 3, axis=-1)
    _load_linear(att.W_query, q_w, q_b)
    _load_linear(att.W_key, k_w, k_b)
    _load_linear(att.W_value, v_w, v_b)
    _load_linear(att.out_proj, attn_params["c_proj"]["w"], attn_params["c_proj"]["b"])


def _load_feed_forward(ff, mlp_params):
    """Load one block's feed-forward layers (layers[1] is the parameter-free GELU)."""
    _load_linear(ff.layers[0], mlp_params["c_fc"]["w"], mlp_params["c_fc"]["b"])
    _load_linear(ff.layers[2], mlp_params["c_proj"]["w"], mlp_params["c_proj"]["b"])


def _load_layer_norm(norm, ln_params):
    """Load a LayerNorm's scale ("g") and shift ("b")."""
    norm.scale = assign(norm.scale, ln_params["g"])
    norm.shift = assign(norm.shift, ln_params["b"])


def load_weights_into_gpt(gpt, params):
    """Copy pretrained GPT-2 weights from `params` into `gpt`, replacing its parameters.

    Args:
        gpt: model built by GPTModel; must expose tok_emb, pos_emb, trf_blocks,
            final_norm and out_head as used below.
        params: nested checkpoint dict as returned by download_and_load_gpt2, with
            top-level "wte", "wpe", "g", "b" and a per-layer "blocks" list. The
            values are presumably NumPy arrays (np.split and .T are applied to them).

    Raises:
        ValueError: if the checkpoint and the model have a different number of
            transformer blocks, or (via assign) if any tensor shape differs.
    """
    n_ckpt_blocks = len(params["blocks"])
    n_model_blocks = len(gpt.trf_blocks)
    # Checked up front: otherwise extra model blocks would silently keep their
    # initial weights, and extra checkpoint blocks would raise a bare IndexError.
    if n_ckpt_blocks != n_model_blocks:
        raise ValueError(
            f"Checkpoint has {n_ckpt_blocks} transformer blocks but the model has "
            f"{n_model_blocks}")

    gpt.tok_emb.weight = assign(gpt.tok_emb.weight, params["wte"])
    gpt.pos_emb.weight = assign(gpt.pos_emb.weight, params["wpe"])

    for block, block_params in zip(gpt.trf_blocks, params["blocks"]):
        _load_attention(block.att, block_params["attn"])
        _load_feed_forward(block.ff, block_params["mlp"])
        _load_layer_norm(block.norm1, block_params["ln_1"])
        _load_layer_norm(block.norm2, block_params["ln_2"])

    _load_layer_norm(gpt.final_norm, params)
    # The output head reuses the token-embedding matrix; assign() makes a separate
    # copy, so the two parameters are not tied.
    gpt.out_head.weight = assign(gpt.out_head.weight, params["wte"])

# Copy the downloaded GPT-2 weights into the model, then move it to the selected device.
# Module.to() returns the module itself; the trailing semicolon suppresses its long repr.
load_weights_into_gpt(gpt, params)
gpt.to(device);


GPTModel(
  (tok_emb): Embedding(50257, 768)
  (pos_emb): Embedding(1024, 768)
  (drop_emb): Dropout(p=0.1, inplace=False)
  (trf_blocks): Sequential(
    (0): TransformerBlock(
      (att): MultiHeadAttention(
        (W_query): Linear(in_features=768, out_features=768, bias=True)
        (W_key): Linear(in_features=768, out_features=768, bias=True)
        (W_value): Linear(in_features=768, out_features=768, bias=True)
        (out_proj): Linear(in_features=768, out_features=768, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (ff): FeedForward(
        (layers): Sequential(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU()
          (2): Linear(in_features=3072, out_features=768, bias=True)
        )
      )
      (norm1): LayerNorm()
      (norm2): LayerNorm()
      (drop_resid): Dropout(p=0.1, inplace=False)
    )
    (1): TransformerBlock(
      (att): MultiHeadAttention(
        (W_query): Linear(in_features=768,

In [None]:
from Splitter.generators import ProbabilisticTextGenerator
from Splitter.tokenizers import TiktokenTokenizer

# Sample several continuations of the same prompt. Sampling is stochastic (top_k=25,
# temperature=1.4), so the RNG is seeded to make the printed samples reproducible.
# NOTE(review): assumes ProbabilisticTextGenerator draws from torch's global RNG — confirm;
# if it uses another RNG, seed that one instead.
SEED = 123
NUM_SAMPLES = 10

tokenizer = TiktokenTokenizer()

inputs = tokenizer.text_to_tokens("Every effort moves you").to(device)
generator = ProbabilisticTextGenerator(gpt, max_new_tokens=25, top_k=25, temperature=1.4)

torch.manual_seed(SEED)
# Inference only: skip building autograd graphs (the loaded weights are trainable Parameters).
with torch.no_grad():
    for _ in range(NUM_SAMPLES):
        # Despite its name, `logits` is what tokens_to_text decodes, i.e. token ids
        # (the printed text includes the prompt).
        logits = generator.generate(inputs)
        output = tokenizer.tokens_to_text(logits)
        print("-", output.replace("\n", " "))



- Every effort moves you forward in your life; it is only when you begin to realize that you are able to make progress, and start learning something
- Every effort moves you on to the next step.  If you don't get the job done in less the rest goes up. The best
- Every effort moves you to reach a specific target or a desired effect; sometimes all there is you can do is do it."   - Michael
- Every effort moves you along. The more time you spend with your partner in the relationship, the happier and more likely you will be to get married
- Every effort moves you forward in the right direction and you don't end up falling behind in your work." She said that the same thing goes for
- Every effort moves you the better, to achieve excellence in your craft and to achieve your success wherever you go". We're committed to teaching all,
- Every effort moves you to the edge of extinction with each day but as time passes, you have learned things you don't expect to get in and
- Every effort moves y