In [1]:
from gpt_oss.torch.model import Transformer
from gpt_oss.triton.model import Transformer as TritonTransformer

In [2]:
import os


# Directory holding the original gpt-oss-20b checkpoint (config.json plus the tensors
# read by `from_checkpoint`). Defaults to ./gpt-oss-20b/original under the current
# working directory; set GPT_OSS_CHECKPOINT_DIR to point elsewhere without editing
# the notebook. Components are joined separately so the path is built portably.
checkpoint_path = os.environ.get(
    "GPT_OSS_CHECKPOINT_DIR",
    os.path.join(os.getcwd(), "gpt-oss-20b", "original"),
)
checkpoint_path

'/home/ksharma/dev/git/gpt-oss-scratch/gpt-oss-20b/original'

In [3]:
import json
from gpt_oss.torch.model import ModelConfig
import pprint


# Load the hyperparameters shipped with the checkpoint into a ModelConfig and show them.
config_path = os.path.join(checkpoint_path, "config.json")
# JSON is UTF-8; don't depend on the platform's locale-default encoding.
with open(config_path, "r", encoding="utf-8") as f:
    json_config = json.load(f)
# Build the config after the file is closed. Unknown keys in config.json make
# ModelConfig(**...) raise TypeError.
config = ModelConfig(**json_config)
pprint.pprint(config)

ModelConfig(num_hidden_layers=24,
            num_experts=32,
            experts_per_token=4,
            vocab_size=201088,
            hidden_size=2880,
            intermediate_size=2880,
            swiglu_limit=7.0,
            head_dim=64,
            num_attention_heads=64,
            num_key_value_heads=8,
            sliding_window=128,
            initial_context_length=4096,
            rope_theta=150000,
            rope_scaling_factor=32.0,
            rope_ntk_alpha=1,
            rope_ntk_beta=32)


In [4]:
!nvidia-smi

Fri Sep  5 09:56:08 2025       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.247.01             Driver Version: 535.247.01   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 4090        Off | 00000000:01:00.0  On |                  Off |
|  0%   45C    P8              23W / 450W |    811MiB / 24564MiB |      8%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [8]:
# Pure-PyTorch reference implementation, loaded onto the CPU.
# Passing `device` by keyword matches the Triton load further down.
model = Transformer.from_checkpoint(checkpoint_path, device="cpu")

In [9]:
# Show the module tree of the PyTorch reference model. The MoE expert weights
# (mlp1_weight / mlp2_weight) are stored directly on MLPBlock rather than in a
# submodule, so they appear in state_dict() but not in this repr.
model

Transformer(
  (embedding): Embedding(201088, 2880)
  (block): ModuleList(
    (0-23): 24 x TransformerBlock(
      (attn): AttentionBlock(
        (norm): RMSNorm()
        (qkv): Linear(in_features=2880, out_features=5120, bias=True)
        (out): Linear(in_features=4096, out_features=2880, bias=True)
        (rope): RotaryEmbedding()
      )
      (mlp): MLPBlock(
        (norm): RMSNorm()
        (gate): Linear(in_features=2880, out_features=32, bias=True)
      )
    )
  )
  (norm): RMSNorm()
  (unembedding): Linear(in_features=2880, out_features=201088, bias=False)
)

In [7]:
# List every tensor in the CPU model's state_dict (name, shape, dtype, element count)
# and report the total element count as the parameter count.
# NOTE(review): this cell's recorded execution count (In [7]) is lower than that of the
# cell that creates `model` (In [8]), so the saved output came from an earlier kernel
# state. Re-run the notebook top to bottom to confirm it reproduces.
parameters_state_dict = model.state_dict()
for name, tensor in parameters_state_dict.items():
    print(name, tensor.size(), tensor.dtype, tensor.numel())
num_parameters = sum(tensor.numel() for tensor in parameters_state_dict.values())
print(f"Number of parameters: {num_parameters}")

embedding.weight torch.Size([201088, 2880]) torch.bfloat16 579133440
block.0.attn.sinks torch.Size([64]) torch.bfloat16 64
block.0.attn.norm.scale torch.Size([2880]) torch.float32 2880
block.0.attn.qkv.weight torch.Size([5120, 2880]) torch.bfloat16 14745600
block.0.attn.qkv.bias torch.Size([5120]) torch.bfloat16 5120
block.0.attn.out.weight torch.Size([2880, 4096]) torch.bfloat16 11796480
block.0.attn.out.bias torch.Size([2880]) torch.bfloat16 2880
block.0.mlp.mlp1_weight torch.Size([32, 5760, 2880]) torch.bfloat16 530841600
block.0.mlp.mlp1_bias torch.Size([32, 5760]) torch.bfloat16 184320
block.0.mlp.mlp2_weight torch.Size([32, 2880, 2880]) torch.bfloat16 265420800
block.0.mlp.mlp2_bias torch.Size([32, 2880]) torch.bfloat16 92160
block.0.mlp.norm.scale torch.Size([2880]) torch.float32 2880
block.0.mlp.gate.weight torch.Size([32, 2880]) torch.bfloat16 92160
block.0.mlp.gate.bias torch.Size([32]) torch.bfloat16 32
block.1.attn.sinks torch.Size([64]) torch.bfloat16 64
block.1.attn.norm.

In [5]:
# Triton implementation, loaded directly onto the GPU. In the recorded run, GPU memory
# used went from 811MiB (In [4], before this cell) to 18423MiB (In [6], after it).
# The packed uint8 expert weights visible in the state_dict listing below are
# presumably what keeps it within the 24GB card -- NOTE(review): confirm.
triton_model = TritonTransformer.from_checkpoint(checkpoint_path, device="cuda")

In [14]:
# Show the module tree of the Triton model. Unlike the PyTorch model, the MoE gate is a
# ParameterDict whose weight is stored transposed ([2880, 32] rather than [32, 2880]).
triton_model

Transformer(
  (embedding): Embedding(201088, 2880)
  (block): ModuleList(
    (0-23): 24 x TransformerBlock(
      (attn): AttentionBlock(
        (norm): RMSNorm()
        (qkv): Linear(in_features=2880, out_features=5120, bias=True)
        (out): Linear(in_features=4096, out_features=2880, bias=True)
        (rope): RotaryEmbedding()
      )
      (mlp): MLPBlock(
        (norm): RMSNorm()
        (gate): ParameterDict(
            (bias): Parameter containing: [torch.cuda.BFloat16Tensor of size 32 (cuda:0)]
            (weight): Parameter containing: [torch.cuda.BFloat16Tensor of size 2880x32 (cuda:0)]
        )
      )
    )
  )
  (norm): RMSNorm()
  (unembedding): Linear(in_features=2880, out_features=201088, bias=False)
)

In [6]:
!nvidia-smi

Fri Sep  5 09:57:11 2025       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.247.01             Driver Version: 535.247.01   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 4090        Off | 00000000:01:00.0  On |                  Off |
| 30%   42C    P8              25W / 450W |  18423MiB / 24564MiB |     21%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [13]:
# Count the Triton model's parameters. Unlike the CPU model, the MoE expert weights here
# are stored packed. block.0.mlp.mlp1_weight is [32, 5760, 1440] uint8 versus
# [32, 5760, 2880] bf16 in the PyTorch model, i.e. two 4-bit values per byte.
# Raw numel() therefore undercounts those tensors by 2x, so they are scaled back up.
# NOTE(review): assumes every non-floating-point "*mlp1_weight" / "*mlp2_weight" tensor
# is packed this way -- confirm against the gpt_oss.triton weight loading.
VALUES_PER_PACKED_BYTE = 2
PACKED_WEIGHT_SUFFIXES = ("mlp1_weight", "mlp2_weight")

num_stored_elements = 0  # raw storage elements, as numel() reports them
num_parameters = 0  # logical parameters, packed weights expanded
parameters_state_dict = triton_model.state_dict()
for key, value in parameters_state_dict.items():
    print(key, value.size(), value.dtype, value.numel())
    num_stored_elements += value.numel()
    is_packed = key.endswith(PACKED_WEIGHT_SUFFIXES) and not value.is_floating_point()
    num_parameters += value.numel() * (VALUES_PER_PACKED_BYTE if is_packed else 1)
print(f"Stored elements (packed bytes counted once): {num_stored_elements}")
print(f"Number of parameters: {num_parameters}")

embedding.weight torch.Size([201088, 2880]) torch.bfloat16 579133440
block.0.attn.sinks torch.Size([64]) torch.bfloat16 64
block.0.attn.norm.scale torch.Size([2880]) torch.float32 2880
block.0.attn.qkv.weight torch.Size([5120, 2880]) torch.bfloat16 14745600
block.0.attn.qkv.bias torch.Size([5120]) torch.bfloat16 5120
block.0.attn.out.weight torch.Size([2880, 4096]) torch.bfloat16 11796480
block.0.attn.out.bias torch.Size([2880]) torch.bfloat16 2880
block.0.mlp.mlp1_weight torch.Size([32, 5760, 1440]) torch.uint8 265420800
block.0.mlp.mlp1_bias torch.Size([32, 5760]) torch.bfloat16 184320
block.0.mlp.mlp2_weight torch.Size([32, 5760, 720]) torch.uint8 132710400
block.0.mlp.mlp2_bias torch.Size([32, 2880]) torch.bfloat16 92160
block.0.mlp.norm.scale torch.Size([2880]) torch.float32 2880
block.0.mlp.gate.bias torch.Size([32]) torch.bfloat16 32
block.0.mlp.gate.weight torch.Size([2880, 32]) torch.bfloat16 92160
block.1.attn.sinks torch.Size([64]) torch.bfloat16 64
block.1.attn.norm.scale t