In [1]:
pip install transformers accelerate pandas matplotlib scikit-learn numpy


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [2]:
import torch, time
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from sklearn import manifold
# bits and bytes for cpu on my laptop breaks autograd
# from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch.nn as nn

I need a checkpoint manager

In [3]:
import datetime
from datetime import timedelta, datetime, timezone
import pickle
import glob
import os
class CheckpointManager:
    def __init__(self, basepath: str, checkpoint_interval: timedelta, max_checkpoints: int = 1):
        self.basepath = basepath
        self.checkpoint_interval = checkpoint_interval
        self.max_checkpoints = max_checkpoints
        self.last_checkpoint_time = datetime.now(timezone.utc)
        self.tsformat = '%Y%m%d%H%M%S'
        self.globexpr = f"{self.basepath}-[0-9][0-9]*"

    def reset(self):
        print(f"resetting checkpoint manager for {self.basepath}")
        ckfiles = sorted(glob.glob(self.globexpr))
        for fn in ckfiles:
            if os.path.exists(fn):
                print(f"removing checkpoint {fn}")
                os.remove(fn)

    def checkpoint(self, data, force: bool = False, log: str = None):
        now = datetime.now(timezone.utc)
        if force or ((now - self.last_checkpoint_time) >= self.checkpoint_interval):
            ts = now.strftime(self.tsformat)
            ckpf = f"{self.basepath}-{ts}"
            print(f"checkpoint time: {now.isoformat(sep=' ', timespec='seconds')} file: {ckpf}")
            if log is not None:
                print(log)
            with open(ckpf, "wb") as ckpfile:
                pickle.dump(data, ckpfile)
            self.last_checkpoint_time = now
            # oldest checkpoint files are first
            ckfiles = sorted(glob.glob(self.globexpr))
            ndel = len(ckfiles) - self.max_checkpoints
            if ndel > 0:
                for fn in ckfiles[:ndel]:
                    if os.path.exists(fn):
                        print(f"removing old checkpoint {fn}")
                        os.remove(fn)

    def restore(self, ckpname = None):
        rval = None
        if ckpname is None:
            ckfiles = sorted(glob.glob(self.globexpr))
            if len(ckfiles) <= 0:
                print(f"found no checkpoint files with base path {self.basepath}")
                return None
            ckpname = ckfiles[-1]
        with open(ckpname, "rb") as ckpfile:
            print(f"restoring from {ckpname}")
            rval = pickle.load(ckpfile)
        return rval

### replicate a result similar to one from paper (figure 7)

On the number of regions of piecewise linear neural networks

https://www.sciencedirect.com/science/article/pii/S0377042723006118?ref=pdf_download&fr=RR-2&rr=9403868f687eb6c9

we expect a relatively small, countable set of affine regions, and this confirms that expectation

In [4]:
%%time
from torch.distributions.uniform import Uniform
from torch.autograd.functional import jacobian
toy = nn.Sequential(
    nn.Linear(2, 10, bias=False),
    nn.ReLU(),
    nn.Linear(10, 10, bias=False),
    nn.ReLU(),
    nn.Linear(10, 1, bias=False)
).to("cuda:0")
udist = Uniform(-2.0, 2.0)
# this toy doesn't require checkpointing but useful for time based log outputs
cktoy = CheckpointManager("./ckp-affine-toy", timedelta(seconds = 5), max_checkpoints = 3)
# I'm not saving randomized toy models so don't try to checkpoint restore, just nuke old files
cktoy.reset()
jhist = {}
for xxx in range(20000):
    cktoy.checkpoint(jhist, log=f"iter: {xxx} unique: {len(jhist.values())} fingerprint: {list(set(jhist.values()))}")
    emb = udist.rsample(torch.Size([2])).to("cuda:0")
    jcb = jacobian(toy, emb)[0]
    jcb = tuple(jcb.cpu().tolist())
    # histogram of jacobians
    if jcb not in jhist:
        jhist[jcb] = 1
    else:
        fq = jhist[jcb]
        jhist[jcb] = fq + 1
(len(jhist.values()), list(set(jhist.values())))

resetting checkpoint manager for ./ckp-affine-toy
removing checkpoint ./ckp-affine-toy-20250519212636
removing checkpoint ./ckp-affine-toy-20250519212641
removing checkpoint ./ckp-affine-toy-20250519212646
checkpoint time: 2025-05-19 21:32:35+00:00 file: ./ckp-affine-toy-20250519213235
iter: 1790 unique: 41 fingerprint: [1, 2, 3, 4, 5, 14, 15, 16, 17, 18, 20, 21, 281, 25, 29, 34, 36, 38, 39, 40, 42, 48, 180, 55, 60, 76, 82, 88, 101, 121, 123]
checkpoint time: 2025-05-19 21:32:40+00:00 file: ./ckp-affine-toy-20250519213240
iter: 3623 unique: 42 fingerprint: [1, 130, 385, 131, 3, 6, 7, 8, 9, 10, 11, 143, 29, 162, 34, 37, 552, 43, 45, 47, 50, 52, 58, 66, 195, 73, 80, 89, 225, 99, 100, 103, 107, 243]
checkpoint time: 2025-05-19 21:32:45+00:00 file: ./ckp-affine-toy-20250519213245
iter: 5429 unique: 42 fingerprint: [1, 2, 3, 5, 8, 9, 10, 13, 15, 17, 151, 280, 155, 169, 43, 44, 51, 54, 190, 191, 63, 62, 194, 71, 585, 77, 81, 852, 218, 349, 94, 98, 359, 106, 238, 110, 114, 122]
checkpoint tim

(42,
 [261,
  7,
  9,
  3346,
  23,
  536,
  537,
  793,
  153,
  28,
  412,
  1309,
  801,
  674,
  34,
  677,
  422,
  38,
  40,
  559,
  177,
  51,
  180,
  53,
  951,
  56,
  441,
  60,
  446,
  831,
  196,
  333,
  334,
  209,
  218,
  1264,
  624,
  371,
  247,
  121,
  2171])

Load an LLM to study

In [5]:
model_path = "ibm-granite/granite-3b-code-base"
device = "cuda:0"

In [6]:
# device "meta" does not load weights
# quant = BitsAndBytesConfig(load_in_8bit=True)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map=device,
#    quantization_config=quant
    ).to("cuda:0")
model.eval()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(49152, 2560, padding_idx=0)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (k_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (v_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (o_proj): Linear(in_features=2560, out_features=2560, bias=True)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2560, out_features=10240, bias=True)
          (up_proj): Linear(in_features=2560, out_features=10240, bias=True)
          (down_proj): Linear(in_features=10240, out_features=2560, bias=True)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2560,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2560,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2560,), eps=1e-05)
    (

check out the embedding function inside this model

In [7]:
t = list(model.children())

In [8]:
embed = t[0].embed_tokens
embed.weight.shape

torch.Size([49152, 2560])

This is the range of actual embedding coordinate values

In [9]:
x = embed(torch.LongTensor(range(49000)).to("cuda:0"))
(x.max(), x.min())

(tensor(0.8750, device='cuda:0', grad_fn=<MaxBackward1>),
 tensor(-0.3008, device='cuda:0', grad_fn=<MinBackward1>))

The embedding logic operates on tensors of integer values. This breaks autograd and also adds an huge dimensionality (~50,000).

To fix this I am creating this wrapper class whose `forward` method does an "embedding-less" logic, but is otherwise the same as a regular llama model

In [10]:
from transformers import LlamaModel
class WrapLM(nn.Module):
    def __init__(self, llm: LlamaModel):
        super().__init__()
        self.llm = llm
        self.layers = llm.layers
        self.norm = llm.norm
        self.rotary_emb = llm.rotary_emb

    def forward(self, emb: torch.Tensor) -> torch.Tensor:
        attention_mask = None
        position_ids = None
        past_key_values = None
        inputs_embeds = emb
        use_cache = False
        output_attentions = False
        output_hidden_states = False
        cache_position = None

        past_seen_tokens = 0
        cache_position = torch.arange(
            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
        )

        position_ids = cache_position.unsqueeze(0)

        causal_mask = self.llm._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        hidden_states = inputs_embeds

        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for decoder_layer in self.layers[: self.llm.config.num_hidden_layers]:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )
            hidden_states = layer_outputs[0]

        hidden_states = self.norm(hidden_states)

        # just taking sum to reduce dimensionality of jacobians by factor of 2500
        rval = torch.sum(hidden_states[0][0])

        return rval

wrap my llama model for use by jacobian histogramming

In [11]:
wm = WrapLM(model.model)

test forward evaluation and jacobian gradients

In [12]:
emb = embed(torch.LongTensor([[1000]]).to("cuda:0"))
wm.forward(emb)

tensor(-27.6045, device='cuda:0', grad_fn=<SumBackward0>)

In [13]:
torch.autograd.functional.jacobian(wm, emb)

tensor([[[ -1.2932,   5.1312,   0.7583,  ...,   6.9142,   9.0552, -18.6037]]],
       device='cuda:0')

Now try to replicate the discovery of affine regions on the LLM

In [14]:
%%time
from torch.distributions.uniform import Uniform
from torch.autograd.functional import jacobian
rd = 1
udist = Uniform(0.0, 0.1)
ckp_reduce = CheckpointManager("./ckp-reduce-xform", timedelta(seconds = 30), max_checkpoints = 3)
reduce = ckp_reduce.restore()
if reduce is None:
    print(f"initializing new reduce")
    reduce = nn.Linear(2560, 10, bias=False).to("cuda:0")
checkpoint = CheckpointManager("./ckp-affine-exp", timedelta(minutes = 15), max_checkpoints = 3)
jhist = checkpoint.restore()
if jhist is None:
    print(f"initializing new jhist")
    jhist = {}
for xxx in range(100000):
    checkpoint.checkpoint(jhist, log=f"iter: {xxx} unique: {len(jhist.values())} fingerprint: {list(set(jhist.values()))}")
    emb = udist.rsample(torch.Size([1,1,2560])).to("cuda:0")
    jcb = jacobian(wm, emb)[0,0]
    # round values to compensate for NPWL components
    jcb = (jcb * (10 ** rd)).round() / (10 ** rd)
    # reduce dimensionality to save space
    # in theory this could undercount "true" affine regions
    jcb = reduce(jcb)
    # round again just because
    jcb = (jcb * (10 ** rd)).round() / (10 ** rd)
    # python likes tuples for keys
    jcb = tuple(jcb.cpu().tolist())
    # finally take the histogram
    if jcb not in jhist:
        jhist[jcb] = 1
    else:
        fq = jhist[jcb]
        jhist[jcb] = fq + 1
checkpoint.checkpoint(jhist, force=True)
(len(jhist.values()), list(set(jhist.values())))

restoring from ./ckp-reduce-xform-20250518184546
restoring from ./ckp-affine-exp-20250519184047
checkpoint time: 2025-05-19 21:50:20+00:00 file: ./ckp-affine-exp-20250519215020
iter: 7264 unique: 1753322 fingerprint: [1]
removing old checkpoint ./ckp-affine-exp-20250519181047
checkpoint time: 2025-05-19 22:05:20+00:00 file: ./ckp-affine-exp-20250519220520
iter: 14530 unique: 1760588 fingerprint: [1]
removing old checkpoint ./ckp-affine-exp-20250519182547
checkpoint time: 2025-05-19 22:20:20+00:00 file: ./ckp-affine-exp-20250519222020
iter: 21806 unique: 1767864 fingerprint: [1]
removing old checkpoint ./ckp-affine-exp-20250519184047


KeyboardInterrupt: 

In [15]:
checkpoint = CheckpointManager("./ckp-affine-exp", timedelta(minutes = 30), max_checkpoints = 3)
jhist = checkpoint.restore()

restoring from ./ckp-affine-exp-20250519222020


In [16]:
(len(jhist.values()), list(set(jhist.values())))

(1767864, [1])