In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch as t

In [2]:
model_name = "microsoft/phi-4"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True, device_map="auto", torch_dtype=t.bfloat16)

Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]

In [3]:
class ActivationCache:
    def __init__(self, model, tokenizer, device):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.activations = []
        self.handles = []

    def hook_fn(self, module, input, output):
        if isinstance(output, tuple):
            output_tensor = output[0]
        else:
            output_tensor = output
        self.activations.append(output_tensor.detach().cpu().float().numpy())

    def register_hook(self, layer):
        handle = layer.register_forward_hook(self.hook_fn)
        self.handles.append(handle)
        
    def clear_activations(self):
        self.activations = []

    def remove_hooks(self):
        for handle in self.handles:
            handle.remove()

In [4]:
ac = ActivationCache(model, tokenizer, model.device)

In [6]:
# reset hooks
ac.remove_hooks()
model.model.layers[20].register_forward_hook(ac.hook_fn)

<torch.utils.hooks.RemovableHandle at 0x7f9f00113340>

In [7]:
model.forward(
    input_ids=tokenizer("Hello, how are you?", return_tensors="pt").input_ids.to(model.device),
    attention_mask=t.ones((1, 16)).to(model.device),
    use_cache=True,
)

Tuple


CausalLMOutputWithPast(loss={'logits': tensor([[[13.6250, 11.1875,  5.5625,  ..., -3.5469, -3.5469, -3.5312],
         [ 1.5469, -1.4844, -2.1250,  ..., -2.5312, -2.5312, -2.5312],
         [ 8.7500,  9.1875,  5.5625,  ..., -3.4844, -3.4844, -3.4844],
         [ 6.1250,  9.1250,  6.0000,  ..., -3.3750, -3.3750, -3.3750],
         [11.4375, 13.1250,  4.9062,  ..., -3.4062, -3.4062, -3.4062],
         [ 3.0781,  5.3438,  9.1875,  ..., -1.9531, -1.9531, -1.9531]]],
       device='cuda:0', dtype=torch.bfloat16, grad_fn=<ToCopyBackward0>), 'past_key_values': DynamicCache()}, logits=tensor([[[13.6250, 11.1875,  5.5625,  ..., -3.5469, -3.5469, -3.5312],
         [ 1.5469, -1.4844, -2.1250,  ..., -2.5312, -2.5312, -2.5312],
         [ 8.7500,  9.1875,  5.5625,  ..., -3.4844, -3.4844, -3.4844],
         [ 6.1250,  9.1250,  6.0000,  ..., -3.3750, -3.3750, -3.3750],
         [11.4375, 13.1250,  4.9062,  ..., -3.4062, -3.4062, -3.4062],
         [ 3.0781,  5.3438,  9.1875,  ..., -1.9531, -1.9531, 