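# Inference-time activation contrasting

Compare the per-layer hidden activations of a base Qwen3-4B model and its SafeRL fine-tune on the same prompt.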

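Data download: **TODO** — document where the `Qwen3-4B` and `Qwen3-4B-SafeRL` checkpoints come from. The loading cell below expects them under `./models/qwen3/`.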


In [1]:
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import snapshot_download

# Local checkpoint directories (update these if your copies live elsewhere).
path_base = "./models/qwen3/Qwen3-4B"
path_saferl = "./models/qwen3/Qwen3-4B-SafeRL"

# A single tokenizer (from the base checkpoint) is used to encode inputs for both models.
tokenizer = AutoTokenizer.from_pretrained(path_base, local_files_only=True)

# Shared loading options for both checkpoints: fp16 weights, automatic device
# placement, and no network access (files must already be on disk).
_load_kwargs = dict(
    dtype=torch.float16,
    device_map="auto",
    local_files_only=True,
)

model_base = AutoModelForCausalLM.from_pretrained(path_base, **_load_kwargs)
model_saferl = AutoModelForCausalLM.from_pretrained(path_saferl, **_load_kwargs)

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 3/3 [00:18<00:00,  6.24s/it]
Some parameters are on the meta device because they were offloaded to the disk and cpu.
Loading checkpoint shards: 100%|██████████| 3/3 [00:08<00:00,  2.74s/it]
Some parameters are on the meta device because they were offloaded to the disk and cpu.


# Inference

In [83]:
prompts = ["Write me a poem about machine learning"]    # can extend this to a list/dataset

# Prompts of different lengths cannot be stacked into one tensor without padding,
# so pad only when there is more than one prompt (a single prompt is unchanged).
# NOTE: with padding, pad positions also appear in the captured activations;
# mask them with inputs["attention_mask"] before comparing models.
inputs = tokenizer(
    prompts,
    return_tensors="pt",
    padding=len(prompts) > 1,
).to(model_base.device)

# dictionaries for storing activations, keyed by "layer_<i>":
activations_base = {}
activations_saferl = {}


def get_hook(store_dict, name):
    """Build a forward hook that records a module's output activation.

    Args:
        store_dict: dict that receives the captured activation under ``name``.
        name: key under which the activation is stored (e.g. ``"layer_3"``).

    Returns:
        A callable with the ``(module, input, output)`` signature expected by
        ``torch.nn.Module.register_forward_hook``. Every call overwrites
        ``store_dict[name]`` with a detached float32 CPU copy of the output.
    """
    def hook(module, input, output):
        # Decoder layers may return a tuple (hidden_states, ...) depending on
        # the transformers version; keep the first element so .detach() works.
        if isinstance(output, (tuple, list)):
            output = output[0]
        store_dict[name] = output.detach().float().cpu()
    return hook

# Capture the outputs of the first `nlayers` decoder layers of both models.
# NOTE: this hooks each whole decoder layer (model.model.layers[i]), not the
# MLP sub-module; register on `layer.mlp` instead to capture MLP outputs only.
nlayers = 7
# Stay within the depth of both models instead of a hard-coded layer count.
nlayers = min(nlayers, len(model_base.model.layers), len(model_saferl.model.layers))

hook_handles = []
try:
    for layer_index in range(nlayers):
        key = f"layer_{layer_index}"
        hook_handles.append(
            model_base.model.layers[layer_index].register_forward_hook(
                get_hook(activations_base, key)))
        hook_handles.append(
            model_saferl.model.layers[layer_index].register_forward_hook(
                get_hook(activations_saferl, key)))

    # One forward pass per model fills every hooked layer at once.
    # Outputs are ignored: only the captured activations are needed.
    with torch.no_grad():
        model_base(**inputs)
        model_saferl(**inputs)
finally:
    # Always remove the hooks, even if a forward pass fails, so re-running the
    # cell does not stack duplicate hooks on the models.
    for handle in hook_handles:
        handle.remove()



Exploring the data structure:

In [82]:
# Verify that both models produce activations of the same shape at every hooked
# layer: print them for inspection, then assert that they actually match.
print('Basemodel layer shapes:\n')
for i in range(nlayers):
    print(activations_base[f'layer_{i}'].size())


print('\nFine-tuned model layer shapes:\n')
for i in range(nlayers):
    print(activations_saferl[f'layer_{i}'].size())

mismatched_layers = [
    f'layer_{i}'
    for i in range(nlayers)
    if activations_base[f'layer_{i}'].shape != activations_saferl[f'layer_{i}'].shape
]
assert not mismatched_layers, f"activation shapes differ at: {mismatched_layers}"


Basemodel layer shapes:

torch.Size([1, 7, 2560])
torch.Size([1, 7, 2560])

Fine-tuned model layer shapes:

torch.Size([1, 7, 2560])
torch.Size([1, 7, 2560])


In [48]:
# The activations can be viewed as such: each layer's tensor is 3-D
# (presumably [batch, token, hidden], cf. torch.Size([1, 7, 2560]) above),
# so element [0][t][n] of the first batch entry sits at flat index t * hidden + n.
layer0_base = activations_base['layer_0']
hidden_size = layer0_base.shape[-1]     # read from the tensor instead of hard-coding 2560
token_idx, neuron_idx = 6, 1000

print(layer0_base[0][token_idx][neuron_idx])                       # layers are 3D-tensors...
print(layer0_base.flatten()[token_idx * hidden_size + neuron_idx])  # ...which means they can be flattened

tensor(0.0055)
tensor(0.0055)


Computing the root-mean-square error (RMSE) between base and fine-tuned activations, per neuron and per layer:

In [56]:
torch.flatten(activations_base['layer_0'])

tensor([ 6.3750e+00, -1.9424e+00,  1.2559e+00,  ..., -3.3643e-01,
         2.1576e-02,  4.3945e-03])

In [57]:
torch.flatten(activations_saferl['layer_0'])

tensor([ 6.3398e+00, -1.9346e+00,  1.2539e+00,  ..., -3.3521e-01,
         2.0386e-02,  4.8218e-03])

In [58]:
(activations_base['layer_0'] - activations_saferl['layer_0']).flatten()

tensor([ 0.0352, -0.0078,  0.0020,  ..., -0.0012,  0.0012, -0.0004])

In [59]:
# Compare base vs. SafeRL activations for every hooked layer:
#   sq_err_per_layer[key]          - flattened element-wise squared error
#   rmse_per_neuron_per_layer[key] - RMSE per unit of the last dimension,
#                                    averaged over all leading dims
#                                    (presumably batch and tokens)
sq_err_per_layer = {}
rmse_per_neuron_per_layer = {}
for i in range(nlayers):
    key = f'layer_{i}'
    sq_err = (activations_base[key] - activations_saferl[key]) ** 2
    sq_err_per_layer[key] = sq_err.flatten()
    # Collapse every leading dim so only the last (neuron) dim remains.
    rmse_per_neuron_per_layer[key] = sq_err.reshape(-1, sq_err.shape[-1]).mean(dim=0).sqrt()

We can also sort the values of a flattened tensor. `sort()` returns both the sorted values and their original (flat) indices:

In [77]:
# Sort once and reuse the result instead of re-sorting the tensor three times.
sorted_saferl_layer0 = activations_saferl['layer_0'].flatten().sort()
print(sorted_saferl_layer0)

# [0] gives values, [1] gives indices. Indices point into the flattened tensor;
# assuming batch size 1, flat index k is token k // hidden, neuron k % hidden.
print(sorted_saferl_layer0[0])
print(sorted_saferl_layer0[1])

torch.return_types.sort(
values=tensor([-2.7969, -2.1777, -1.9346,  ...,  2.2246,  3.7852,  6.3398]),
indices=tensor([  77,   10,    1,  ..., 2560, 5120,    0]))
tensor([-2.7969, -2.1777, -1.9346,  ...,  2.2246,  3.7852,  6.3398])
tensor([  77,   10,    1,  ..., 2560, 5120,    0])


In [None]:
torch.sort(activations_saferl['layer_0'].flatten()).values

tensor([-2.7969, -2.1777, -1.9346,  ...,  2.2246,  3.7852,  6.3398])