In [4]:
import torch, _io
from torch.serialization import add_safe_globals

# Register the one extra global this checkpoint references so that the
# weights_only=True load below accepts it.
add_safe_globals([_io.BytesIO])

ckpt = torch.load(
    "../../models/evo2_7b_base.pt",
    map_location="cpu",   # load onto CPU first
    weights_only=True,    # keep to the "safe" loading path
    mmap=True,            # speed / lower RAM if on local disk
)

print(type(ckpt))
if isinstance(ckpt, dict):
    print("top-level keys:", list(ckpt)[:100])

# The checkpoint is either the state dict itself or a wrapper holding it
# under the "state_dict" key.
state_dict = ckpt.get("state_dict", ckpt)

sd = state_dict  # short alias

# Split the entries into tensors (name, shape, dtype) and everything else
# (name, type name), keeping the state dict's own ordering.
tensor_rows = [
    (name, tuple(value.shape), value.dtype)
    for name, value in sd.items()
    if torch.is_tensor(value)
]
non_tensor_rows = [
    (name, type(value).__name__)
    for name, value in sd.items()
    if not torch.is_tensor(value)
]

print("Some tensor params:")
for row_no, (name, shape, dtype) in enumerate(tensor_rows):
    print(f"{row_no:2d} {name:50s} {shape} {dtype}")

print("\nNon-tensor entries (inspect/ignore):")
# Only the first 15 non-tensor entries are shown.
for row_no, (name, type_name) in enumerate(non_tensor_rows[:15]):
    print(f"{row_no:2d} {name:50s} {type_name}")

<class 'collections.OrderedDict'>
top-level keys: ['embedding_layer.weight', 'unembed.weight', 'blocks.0.pre_norm.scale', 'blocks.0.post_norm.scale', 'blocks.0.projections.weight', 'blocks.0.projections._extra_state', 'blocks.0.filter.short_filter_weight', 'blocks.0.out_filter_dense.weight', 'blocks.0.out_filter_dense.bias', 'blocks.0.mlp.l1.weight', 'blocks.0.mlp.l2.weight', 'blocks.0.mlp.l3.weight', 'blocks.0.filter.h', 'blocks.1.pre_norm.scale', 'blocks.1.post_norm.scale', 'blocks.1.projections.weight', 'blocks.1.projections._extra_state', 'blocks.1.filter.short_filter_weight', 'blocks.1.filter.D', 'blocks.1.filter.h', 'blocks.1.out_filter_dense.weight', 'blocks.1.out_filter_dense.bias', 'blocks.1.mlp.l1.weight', 'blocks.1.mlp.l2.weight', 'blocks.1.mlp.l3.weight', 'blocks.2.pre_norm.scale', 'blocks.2.post_norm.scale', 'blocks.2.projections.weight', 'blocks.2.projections._extra_state', 'blocks.2.filter.short_filter_weight', 'blocks.2.filter.D', 'blocks.2.out_filter_dense.weight', 'bl

In [5]:
import torch
from collections import defaultdict
from typing import Dict, Tuple

def human(n: int) -> str:
    """Format a count with a decimal (power-of-1000) suffix.

    The value is scaled down by 1000 until it prints as fewer than four
    digits, then rendered without decimals followed by one of the suffixes
    "", "K", "M", "B", "T" (e.g. 2_000_000 -> "2M").

    Args:
        n: The number to format. Annotated as int, but any real number works.

    Returns:
        The formatted string. Values too large for the "T" suffix are still
        expressed in T, with one decimal place (e.g. 10**15 -> "1000.0T").
    """
    units = ("", "K", "M", "B", "T")
    value = n
    for unit in units:
        # 999.5 is the smallest magnitude that "{:.0f}" renders as "1000";
        # comparing against it (rather than 1000) keeps e.g. 999_600 from
        # being shown as "1000K" instead of "1M".
        if abs(value) < 999.5:
            return f"{value:.0f}{unit}"
        # Stop scaling at the last unit so the overflow value below really is
        # expressed in that unit.
        if unit != units[-1]:
            value /= 1000.0
    return f"{value:.1f}{units[-1]}"

def count_params_in_tensor(t: torch.Tensor) -> int:
    """Return the number of elements in ``t``, or 0 if that cannot be obtained.

    Any exception raised by ``t.numel()`` (for instance because ``t`` has no
    such method) is swallowed and reported as a count of zero.
    """
    try:
        n_elements = t.numel()
    except Exception:
        # Best effort: an object that cannot report a size contributes nothing.
        n_elements = 0
    return n_elements

def classify_block(block_tensors: Dict[str, torch.Tensor]) -> str:
    """
    Guess a block's kind from the parameter names (and one shape) it contains.

    Rules are tried in order and the first match wins:
      1. a name containing "inner_mha_cls.Wqkv.weight"      -> "attention"
      2. a name containing "filter.log_poles" or
         "filter.residues"                                   -> "ssm"
      3. a tensor whose name contains "filter.h", with at
         least 3 dims and a last dim of 64 or more           -> "long_filter"
      4. a name containing "filter.short_filter_weight"      -> "short_filter"
      5. anything else                                       -> "filter"
    """
    names = list(block_tensors)

    def mentions(fragment: str) -> bool:
        # True when any parameter name contains the given fragment.
        return any(fragment in name for name in names)

    if mentions("inner_mha_cls.Wqkv.weight"):
        return "attention"

    if mentions("filter.log_poles") or mentions("filter.residues"):
        return "ssm"

    # shapes we've seen: (256, 1, 128) == long; (256, 1, 7) == short
    has_wide_kernel = any(
        "filter.h" in name
        and isinstance(tensor, torch.Tensor)
        and tensor.dim() >= 3
        and tensor.shape[-1] >= 64
        for name, tensor in block_tensors.items()
    )
    if has_wide_kernel:
        return "long_filter"

    if mentions("filter.short_filter_weight"):
        return "short_filter"

    return "filter"

def summarize_state_dict(sd: Dict[str, torch.Tensor]) -> None:
    """
    Print a parameter summary for the entire model and per block.

    Entries are bucketed by key prefix:
      * 'embedding_layer.' / 'unembed.'  -> embeddings
      * 'blocks.N.<...>'                 -> per-block groups keyed by N
      * anything else                    -> "tail"
    Values that are not tensors are skipped. Counts are taken per key via
    ``numel()``, so two keys that refer to the same underlying tensor are
    counted twice.

    Args:
        sd: Mapping from parameter name to value (tensors and possibly other
            objects).

    Returns:
        None. The report is written to stdout.
    """
    # Separate by area
    top_level = {}
    blocks = defaultdict(dict)
    tail = {}

    for k, v in sd.items():
        # only count tensors; skip BytesIO etc.
        if not torch.is_tensor(v):
            continue

        if k.startswith("blocks."):
            # The prefix check guarantees at least two dot-separated parts.
            label = k.split(".")[1]
            try:
                idx = int(label)
            except ValueError:
                idx = label  # unusual, non-numeric block label: keep as text
            blocks[idx][k] = v
        elif k.startswith(("embedding_layer.", "unembed.")):
            top_level[k] = v
        else:
            # e.g. 'norm.scale' at the end
            tail[k] = v

    def block_order(idx):
        # Numeric indices first (in numeric order), then textual labels
        # (alphabetically). Tagging each kind keeps ints and strs from being
        # compared with each other, which would raise TypeError.
        if isinstance(idx, int):
            return (0, idx, "")
        return (1, 0, str(idx))

    # Totals
    total_params = sum(
        count_params_in_tensor(v) for v in sd.values() if torch.is_tensor(v)
    )
    emb_params = sum(count_params_in_tensor(v) for v in top_level.values())
    tail_params = sum(count_params_in_tensor(v) for v in tail.values())

    print("=== MODEL PARAMETER SUMMARY ===")
    print(f"Total tensor params: {human(total_params)}")

    # Embedding / Unembed
    print(f"Embeddings+Unembed: {human(emb_params)}")
    for k, v in sorted(top_level.items()):
        print(f"  - {k:45s} {tuple(v.shape)}  ({human(v.numel())})")

    # Per-block breakdown
    print("\n=== PER-BLOCK BREAKDOWN ===")
    grand_blocks = 0
    for bi in sorted(blocks.keys(), key=block_order):
        bdict = blocks[bi]
        block_type = classify_block(bdict)
        bcount = sum(count_params_in_tensor(v) for v in bdict.values())
        grand_blocks += bcount
        print(f"\nBlock {bi:>2}  [{block_type}]  —  {human(bcount)} params")
        # Show only the six largest tensors of the block, biggest first.
        big = sorted(bdict.items(), key=lambda kv: kv[1].numel(), reverse=True)
        for k, v in big[:6]:
            print(f"   • {k:45s} {tuple(v.shape)}  ({human(v.numel())})")
        # (Optional) uncomment to list all tensors in the block
        # for k, v in big[6:]:
        #     print(f"     - {k:45s} {tuple(v.shape)}  ({human(v.numel())})")

    print(f"\nSum over all blocks: {human(grand_blocks)}")

    # Tail / final norm etc.
    if tail:
        print(f"\nTail (e.g., final norm etc.): {human(tail_params)}")
        for k, v in sorted(tail.items()):
            print(f"  - {k:45s} {tuple(v.shape)}  ({human(v.numel())})")

    # Re-add the three buckets so the sum can be eyeballed against the
    # "Total tensor params" line printed at the top.
    print("\n=== GRAND TOTAL CHECK ===")
    print(f"Embeddings+Unembed: {human(emb_params)}")
    print(f"Blocks total       : {human(grand_blocks)}")
    print(f"Tail               : {human(tail_params)}")
    print("------------------------------")
    print(f"Reported Total     : {human(emb_params + grand_blocks + tail_params)}")

# ---- Usage example ----
# To run this cell on its own, load a checkpoint and pick out its state dict:
#   ckpt = torch.load("evo2_7b_base.pt", map_location="cpu", weights_only=True)
#   sd = ckpt.get("state_dict", ckpt)
# As written, `sd` must already exist in the kernel (it is assigned in the
# checkpoint-loading cell earlier in this notebook).
summarize_state_dict(sd)

=== MODEL PARAMETER SUMMARY ===
Total tensor params: 6B
Embeddings+Unembed: 4M
  - embedding_layer.weight                        (512, 4096)  (2M)
  - unembed.weight                                (512, 4096)  (2M)

=== PER-BLOCK BREAKDOWN ===

Block  0  [short_filter]  —  202M params
   • blocks.0.projections.weight                   (12288, 4096)  (50M)
   • blocks.0.mlp.l1.weight                        (11008, 4096)  (45M)
   • blocks.0.mlp.l2.weight                        (11008, 4096)  (45M)
   • blocks.0.mlp.l3.weight                        (4096, 11008)  (45M)
   • blocks.0.out_filter_dense.weight              (4096, 4096)  (17M)
   • blocks.0.filter.short_filter_weight           (12288, 1, 3)  (37K)

Block  1  [long_filter]  —  202M params
   • blocks.1.projections.weight                   (12288, 4096)  (50M)
   • blocks.1.mlp.l1.weight                        (11008, 4096)  (45M)
   • blocks.1.mlp.l2.weight                        (11008, 4096)  (45M)
   • blocks.1.mlp.l3.weigh