In [8]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2

## Load ONNX model in tinygrad

In [9]:
import onnx
from tg_onnx import get_run_onnx

# Load the ONNX model from "shakespeare_model.onnx" in the working directory.
# NOTE: from here on the global `model` is the object returned by onnx.load,
# not an instance of the project's LLM class.
model = onnx.load("shakespeare_model.onnx")
# Create a callable object 'run_onnx' that executes the model
run_onnx = get_run_onnx(model)

## Count parameters

In [10]:
from prettytable import PrettyTable
import numpy as np
import onnx

def count_parameters(model):
    """Print a per-module parameter table for an ONNX model and return the total.

    Walks ``model.graph.initializer``, translates exporter-generated
    initializer names (e.g. ``onnx::MatMul_611``) into readable module names,
    skips the initializers listed in ``exclude_params`` and sums the element
    counts of everything that remains.

    Initializers that are neither mapped nor excluded are still counted; they
    are listed under their raw name after the known modules so that the table
    rows always add up to the printed total.

    Args:
        model: Object exposing ``graph.initializer``, where every initializer
            provides ``name`` and ``dims`` (as a model loaded with
            ``onnx.load`` does).

    Returns:
        int: Total number of elements over all counted initializers.

    Side effects:
        Prints the table followed by a ``Total Trainable Params`` line.
    """
    # Exporter-generated initializer name -> readable module name.
    # The insertion order of this dict is also the row order of the table.
    name_mapping = {
        "token_embedding.weight": "token_embedding.weight",
        "onnx::Mul_610": "transformer.0.norm_attn.gain",
        "onnx::MatMul_611": "transformer.0.multihead_attn.proj_qkv.weight",
        "onnx::MatMul_631": "transformer.0.multihead_attn.proj_out.weight",
        "onnx::Mul_632": "transformer.0.norm_ffn.gain",
        "onnx::MatMul_633": "transformer.0.feed_forward.0.weight",
        "onnx::MatMul_634": "transformer.0.feed_forward.1.linear.weight",
        "transformer.0.feed_forward.1.linear.bias": "transformer.0.feed_forward.1.linear.bias",
        "onnx::MatMul_643": "transformer.0.feed_forward.2.weight",
        "onnx::Mul_644": "transformer.1.norm_attn.gain",
        "onnx::MatMul_645": "transformer.1.multihead_attn.proj_qkv.weight",
        "onnx::MatMul_665": "transformer.1.multihead_attn.proj_out.weight",
        "onnx::Mul_666": "transformer.1.norm_ffn.gain",
        "onnx::MatMul_667": "transformer.1.feed_forward.0.weight",
        "onnx::MatMul_668": "transformer.1.feed_forward.1.linear.weight",
        "transformer.1.feed_forward.1.linear.bias": "transformer.1.feed_forward.1.linear.bias",
        "onnx::MatMul_677": "transformer.1.feed_forward.2.weight",
        "onnx::Mul_678": "transformer.2.norm_attn.gain",
        "onnx::MatMul_679": "transformer.2.multihead_attn.proj_qkv.weight",
        "onnx::MatMul_699": "transformer.2.multihead_attn.proj_out.weight",
        "onnx::Mul_700": "transformer.2.norm_ffn.gain",
        "onnx::MatMul_701": "transformer.2.feed_forward.0.weight",
        "onnx::MatMul_702": "transformer.2.feed_forward.1.linear.weight",
        "transformer.2.feed_forward.1.linear.bias": "transformer.2.feed_forward.1.linear.bias",
        "onnx::MatMul_711": "transformer.2.feed_forward.2.weight",
        "onnx::Mul_712": "transformer.3.norm_attn.gain",
        "onnx::MatMul_713": "transformer.3.multihead_attn.proj_qkv.weight",
        "onnx::MatMul_733": "transformer.3.multihead_attn.proj_out.weight",
        "onnx::Mul_734": "transformer.3.norm_ffn.gain",
        "onnx::MatMul_735": "transformer.3.feed_forward.0.weight",
        "onnx::MatMul_736": "transformer.3.feed_forward.1.linear.weight",
        "transformer.3.feed_forward.1.linear.bias": "transformer.3.feed_forward.1.linear.bias",
        "onnx::MatMul_745": "transformer.3.feed_forward.2.weight",
        "onnx::Mul_746": "norm.gain",
        "projection_head.bias": "projection_head.bias",
    }

    # Initializers that are deliberately left out of the count.
    # NOTE(review): presumably a projection weight shared with the embedding,
    # the positional-encoding tables and an attention mask constant - confirm
    # against the exported graph.
    exclude_params = {
        "onnx::MatMul_747",
        "transformer.0.multihead_attn.positional_encoding.position_cos",
        "transformer.0.multihead_attn.positional_encoding.position_sin",
        "onnx::Where_630",
    }

    params_dict = {}
    total_params = 0

    for initializer in model.graph.initializer:
        mapped_name = name_mapping.get(initializer.name, initializer.name)
        if mapped_name in exclude_params:
            continue
        # dtype=np.int64 keeps the product integral even for scalar
        # initializers (empty dims), where a plain np.prod would yield 1.0
        # and silently turn the running total into a float.
        params = int(np.prod(initializer.dims, dtype=np.int64))
        params_dict[mapped_name] = params
        total_params += params

    # Known modules first (in model order), then any counted initializer the
    # mapping does not know about, e.g. after a re-export renamed the tensors.
    module_order = list(name_mapping.values())
    known_modules = set(module_order)
    module_order.extend(name for name in params_dict if name not in known_modules)

    table = PrettyTable(["Modules", "Parameters"])
    for module in module_order:
        if module in params_dict:
            table.add_row([module, params_dict[module]])

    print(table)
    print(f"Total Trainable Params: {total_params}\n")
    return total_params

# Load the ONNX model
# NOTE(review): this repeats the onnx.load call from the earlier loading cell
# and rebinds the global `model` to a freshly loaded copy of the same file.
model = onnx.load("shakespeare_model.onnx")

# Call the function: prints the per-module table and, as the cell's last
# expression, also displays the returned total.
count_parameters(model)


+----------------------------------------------+------------+
|                   Modules                    | Parameters |
+----------------------------------------------+------------+
|            token_embedding.weight            |  1048576   |
|         transformer.0.norm_attn.gain         |    256     |
| transformer.0.multihead_attn.proj_qkv.weight |   196608   |
| transformer.0.multihead_attn.proj_out.weight |   65536    |
|         transformer.0.norm_ffn.gain          |    256     |
|     transformer.0.feed_forward.0.weight      |   262144   |
|  transformer.0.feed_forward.1.linear.weight  |  2097152   |
|   transformer.0.feed_forward.1.linear.bias   |    2048    |
|     transformer.0.feed_forward.2.weight      |   262144   |
|         transformer.1.norm_attn.gain         |    256     |
| transformer.1.multihead_attn.proj_qkv.weight |   196608   |
| transformer.1.multihead_attn.proj_out.weight |   65536    |
|         transformer.1.norm_ffn.gain          |    256     |
|     tr

np.int64(12597504)

## Prepare tokenizer and dataset

In [11]:
import sys
sys.path.append("..")
from pathlib import Path
from model.tokenizer import Tokenizer

# Path to the input text file, relative to this notebook's directory
# (original note: the file used for training or retraining the tokenizer).
input_file = "../data/shakespeare/tinyshakespeare.txt"
# Same path with its suffix replaced by ".model", i.e.
# ../data/shakespeare/tinyshakespeare.model
output_file = Path(input_file).with_suffix(".model")
# Initialize the tokenizer by loading it from 'output_file'.
# NOTE(review): presumably a SentencePiece model file (the tokenizer is later
# accessed through `tokenizer.sp.EncodeAsPieces`) - confirm in model.tokenizer.
tokenizer = Tokenizer(str(output_file))

In [12]:
# Test tokenizer functionality: print the pieces the sentence is split into
sentence = "Before we proceed any further, hear me speak."
print(tokenizer.sp.EncodeAsPieces(sentence))

# Round-trip check: decoding the encoded sentence must give back the original
# text exactly, otherwise the cell fails with an AssertionError.
assert tokenizer.decode(tokenizer.encode(sentence)) == sentence

['▁Before', '▁we', '▁proceed', '▁any', '▁further', ',', '▁hear', '▁me', '▁speak', '.']


In [14]:
from tinygrad.tensor import Tensor, dtypes
from model.llm import sample_top_p, LLM

# Encode the prompt using the tokenizer
# NOTE(review): `llm_config` is not defined in any visible cell and the
# execution counter jumps from 12 to 14, so it presumably comes from a cell
# that was removed. It must be defined explicitly for the notebook to survive
# Restart Kernel -> Run All.
prompt = tokenizer.encode(
    "KING HENRY VI:",
    beg_of_string=True,
    pad_seq=True,
    seq_len=llm_config.seq_len,
)

# Convert the encoded prompt to a tinygrad tensor and add an extra dimension
inputs = Tensor(prompt, dtype=dtypes.int32).unsqueeze(0)

# Generate a sequence of tokens using the LLM model starting from the 'inputs'
# FIXME(review): this line raises `AttributeError: generate` (see the cell
# output). The global `model` was last bound to the result of onnx.load(...),
# which has no `generate` method. Generation presumably needs either an `LLM`
# instance or a sampling loop built on `run_onnx` - confirm the intended path.
out = model.generate(inputs, max_seq_len=64)

# Convert the output tensor to a list of integers
# NOTE(review): never reached because of the failure above; whether the
# returned object supports `.astype(int)` is unverified.
output_list = out.flatten().astype(int).tolist()

# Decode the generated sequence of token IDs back into a human-readable string
decoded_output = tokenizer.decode(output_list)
print(decoded_output)

AttributeError: generate