In [50]:
%load_ext autoreload
%autoreload 2
import plotly.express as px
import numpy as np
import einops
from tqdm.auto import trange


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
from nnsight import LanguageModel
import torch
# Wrap Phi-3-mini-4k-instruct in an nnsight LanguageModel, requesting float16 weights.
# trust_remote_code=True is passed, so the checkpoint's own code may be used during loading.
model = LanguageModel("microsoft/Phi-3-mini-4k-instruct", trust_remote_code=True, device_map="auto", low_cpu_mem_usage=True, torch_dtype=torch.float16)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
`flash-attention` package not found, consider installing for better performance: No module named 'flash_attn'.
Current `flash-attention` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`.


In [3]:
from task_vector_utils import load_tasks, ICLRunner

# Load the task collection via the project helper; later cells index it by task name
# (tasks[task]) and call .items() on the result.
tasks = load_tasks()

Cloning into 'data/itv'...


In [6]:
# Seed handed to ICLRunner when it is constructed.
seed = 10

# Larger alternative configuration, kept for reference:
# n_few_shots, batch_size, max_seq_len = 64, 64, 512
# Active configuration: few-shot count, ICLRunner batch size, and the tokenizer truncation length.
n_few_shots, batch_size, max_seq_len = 20, 16, 128

In [9]:
# Name of the task used for the rest of the notebook.
task = "antonyms"

# Materialise the task's mapping as a list of (key, value) tuples.
pairs = list(tasks[task].items())

In [25]:
def check_single_token(a, b):
    """Tell whether both strings encode to exactly one token each.

    Each string is tokenized on its own with ``model.tokenizer`` and with
    special tokens disabled, so only the tokens of the text itself are counted.

    Args:
        a: first string to tokenize.
        b: second string to tokenize.

    Returns:
        True when both strings produce exactly one input id, otherwise False.
    """
    # Both strings are always tokenized (no short-circuit), in the order a, b.
    token_counts = [
        len(model.tokenizer(text, add_special_tokens=False)["input_ids"])
        for text in (a, b)
    ]
    return token_counts == [1, 1]

In [37]:
# Keep only the pairs for which check_single_token returns True for both elements.
single_token_pairs = [x for x in pairs if check_single_token(*x)]

In [39]:

# Build the ICL runner over the single-token pairs. n_shot is one less than n_few_shots
# (presumably the remaining pair serves as the query - TODO confirm against ICLRunner).
runner = ICLRunner(task, single_token_pairs, batch_size=batch_size, n_shot=n_few_shots-1, max_seq_len=max_seq_len, seed=seed)

In [44]:
# Shorthand for the tokenizer attached to the nnsight model.
tokenizer = model.tokenizer

# Architecture sizes read from the model config.
N_HEADS = model.config.num_attention_heads
N_LAYERS = model.config.num_hidden_layers
D_MODEL = model.config.hidden_size
# Per-head width, computed as the hidden size divided evenly across the attention heads.
D_HEAD = D_MODEL // N_HEADS

print(f"Number of heads: {N_HEADS}")
print(f"Number of layers: {N_LAYERS}")
print(f"Model dimension: {D_MODEL}")
print(f"Head dimension: {D_HEAD}\n")

Number of heads: 32
Number of layers: 32
Model dimension: 3072
Head dimension: 96



In [46]:
# Second element of the last pair of every training example
# (judging by the displayed output, this is the target word of each prompt).
[x[-1][1] for x in runner.train_pairs]

['night',
 'sad',
 'west',
 'separate',
 'west',
 'even',
 'finish',
 'warm',
 'even',
 'odd',
 'senior',
 'senior',
 'mean',
 'back',
 'complex',
 'down']

In [47]:
# Render one prompt per training example via the runner.
prompts = [runner.get_prompt(x) for x in runner.train_pairs]
# Same rendering for the runner's `random_pairs` examples
# (presumably the corrupted/shuffled counterparts - TODO confirm against ICLRunner).
random_prompts = [runner.get_prompt(x) for x in runner.random_pairs]

In [None]:
def get_loss(logits, tokens):
    """Unfinished loss helper: it only builds a mask and implicitly returns None.

    Args:
        logits: not used yet.
        tokens: compared against the token id 1599 (presumably a tensor of
            token ids, so the comparison is element-wise - TODO confirm).

    TODO: compute and return the loss; `mask` is currently never used.
    NOTE(review): 1599 is a hard-coded token id whose meaning is not stated
    here - name it as a constant once its role is confirmed.
    """
    mask = tokens == 1599
    

In [51]:
# Tokenize all prompts as one padded batch, truncated to max_seq_len, and move it to the model's device.
tokens = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True, max_length=max_seq_len).to(model.device)

# Trace one forward pass over the batch and save the output logits for use after the context exits.
with model.trace(tokens):
    logits = model.output["logits"].save()

You are not running the flash-attention implementation, expect numerical differences.
`flash-attention` package not found, consider installing for better performance: No module named 'flash_attn'.
Current `flash-attention` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`.


model.safetensors.index.json:   0%|          | 0.00/16.3k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]



model-00001-of-00002.safetensors:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

OSError: [Errno 28] No space left on device

In [None]:
from typing import List, Optional


def calculate_fn_vectors_and_intervene(
    model: LanguageModel,
    layers: Optional[List[int]] = None,
    correct_completion_ids: Optional[List[int]] = None,
):
    '''
    Returns a tensor of shape (layers, heads), containing the CIE for each head.

    For every (layer, head) the batch-mean attention-head output at the final sequence position is
    recorded on the clean prompts, patched into a run on the corrupted prompts, and the resulting change
    in correct-token log-probability (relative to an unpatched corrupted run) is averaged over the batch.

    The clean and corrupted prompts are read from the notebook-level variables `prompts` and
    `random_prompts`; `N_LAYERS`, `N_HEADS` and `D_HEAD` are read from the notebook-level constants.

    Inputs:
        model: LanguageModel
            the transformer you're doing this computation with

        layers: Optional[List[int]]
            the layers which this function will calculate the score for (if None, we assume all layers)

        correct_completion_ids: Optional[List[int]]
            one target token id per prompt. If None, the ids are derived from the last pair of each
            entry of the notebook-level `runner.train_pairs` (see the NOTE in the body).

    Raises:
        ValueError: if the clean prompts, corrupted prompts and completion ids differ in length.
    '''
    layers = list(range(N_LAYERS)) if (layers is None) else list(layers)
    # Iterate over attention heads (N_HEADS of them), not over the per-head width D_HEAD.
    heads = range(N_HEADS)

    N = len(prompts)
    if len(random_prompts) != N:
        raise ValueError(
            f"Expected as many corrupted prompts as clean prompts, got {len(random_prompts)} and {N}"
        )

    if correct_completion_ids is None:
        # NOTE(review): this assumes (a) the answer of each prompt is the second element of the last
        # pair, (b) the corrupted prompts share the same answer, and (c) the word tokenizes inside the
        # prompt the same way it does on its own - confirm against ICLRunner.get_prompt.
        correct_completion_ids = [
            model.tokenizer(example[-1][1], add_special_tokens=False)["input_ids"][0]
            for example in runner.train_pairs
        ]
    correct_completion_ids = list(correct_completion_ids)
    if len(correct_completion_ids) != N:
        raise ValueError(
            f"Expected {N} correct completion ids (one per prompt), got {len(correct_completion_ids)}"
        )

    def attn_out_proj(layer):
        # Phi-3 module layout (decoder layers under `model.layers`, attention output projection `o_proj`).
        return model.model.layers[layer].self_attn.o_proj

    # NOTE(review): all invokes inside one trace are presumably batched into one forward pass, so memory
    # grows with len(layers) * N_HEADS - pass a subset of `layers` if this does not fit on the device.
    # NOTE(review): `[:, -1]` reads the final sequence position, which is only the last real token if the
    # tokenizer pads on the left - confirm the padding side.
    # NOTE(review): `.input[0][0]` assumes `.input` is an (args, kwargs) pair; on nnsight versions where
    # `.input` is already the first positional argument, drop the indexing - confirm installed version.
    with model.trace() as tracer:

        # Run a forward pass on clean prompts, where we store attention head outputs
        z_dict = {}
        with tracer.invoke(prompts):
            for layer in layers:
                # Get hidden states, reshape to get head dimension, store the mean tensor
                z = attn_out_proj(layer).input[0][0][:, -1]
                z_reshaped = z.reshape(N, N_HEADS, D_HEAD).mean(dim=0)
                for head in heads:
                    z_dict[(layer, head)] = z_reshaped[head]

        # Run a forward pass on corrupted prompts, where we don't intervene or store activations (just so we can
        # get the correct-token logprobs to compare with our intervention)
        with tracer.invoke(random_prompts):
            logits = model.lm_head.output[:, -1]
            correct_logprobs_corrupted = logits.log_softmax(dim=-1)[torch.arange(N), correct_completion_ids].save()

        # For each head, run a forward pass on corrupted prompts (here we need multiple different forward passes,
        # because we're doing different interventions each time)
        correct_logprobs_dict = {}
        for layer in layers:
            for head in heads:
                with tracer.invoke(random_prompts):
                    # Get hidden states, reshape to get head dimension, then overwrite this head's slice
                    # (for every batch element) with the clean-run mean vector
                    z = attn_out_proj(layer).input[0][0][:, -1]
                    z.reshape(N, N_HEADS, D_HEAD)[:, head] = z_dict[(layer, head)]
                    # Get logprobs at the end, which we'll compare with our corrupted logprobs
                    logits = model.lm_head.output[:, -1]
                    correct_logprobs_dict[(layer, head)] = logits.log_softmax(dim=-1)[torch.arange(N), correct_completion_ids].save()

    # Get difference between intervention logprobs and corrupted logprobs, and take mean over batch dim.
    # The dict preserves insertion order (layer-major, then head), which matches the "(layers heads)" split.
    all_correct_logprobs_intervention = einops.rearrange(
        torch.stack([v.value for v in correct_logprobs_dict.values()]),
        "(layers heads) batch -> layers heads batch",
        layers=len(layers),
    )
    logprobs_diff = all_correct_logprobs_intervention - correct_logprobs_corrupted.value # shape [layers heads batch]

    # Return mean effect of intervention, over the batch dimension
    return logprobs_diff.mean(dim=-1)