In [1]:
from nnsight import LanguageModel

import torch as t
import gc
import sys
import math
import numpy as np
from tqdm import tqdm
import torch
from datasets import load_dataset

from loading_utils import load_vqa_examples, load_blimp_examples, load_winoground_examples

from transformers import AutoProcessor, AutoTokenizer
from nnsight import NNsight
import importlib.util
import pickle


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def load_model(own_model=True):
    if own_model:
        model_path = "../babylm_GIT/models2/base_git_1vd125_s1/epoch17/"
        spec = importlib.util.spec_from_file_location("GitForCausalLM", f"{model_path}modeling_git.py")
        git_module = importlib.util.module_from_spec(spec)
        sys.modules["git_module"] = git_module
        spec.loader.exec_module(git_module)
        GitForCausalLM = git_module.GitForCausalLM

        model = GitForCausalLM.from_pretrained(model_path) 
        ckpt = torch.load(model_path + "pytorch_model.bin") # TODO: newly initialized for vision encoder: ['pooler.dense.bias', 'pooler.dense.weight']
        model.load_state_dict(ckpt, strict=False)  
        
    else:
        model_path = "babylm/git-2024"

        from transformers import GitForCausalLM as OGModel

        model = OGModel.from_pretrained(model_path, trust_remote_code=True)
        
    # load tokenizer and img processor
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    img_processor = AutoProcessor.from_pretrained(model_path,trust_remote_code=True)
    
    nnsight_model = NNsight(model, device_map="cuda")
    nnsight_model.to("cuda")

    return nnsight_model, tokenizer, img_processor


def extract_submodules(model):
    submodules = {}
    for idx, layer in enumerate(model.git.encoder.layer):
        submodules[f"mlp.{idx}"] = layer.intermediate    # output of MLP
        submodules[f"attn.{idx}"] = layer.attention  # output of attention
        submodules[f"resid.{idx}"] = layer      # output of whole layer
    return submodules
        

In [3]:
# load and prepare model
model, tokenizer, img_processor = load_model(own_model=True)
submodules = extract_submodules(model)

Some weights of ViTModel were not initialized from the model checkpoint at facebook/dino-vitb16 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
winoground_examples = load_winoground_examples(tokenizer, img_processor, pad_to_length=32, n_samples=10, local=True)

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


loaded huggingface DS
loaded local DS


100%|██████████| 10/10 [00:00<00:00, 22.24it/s]


In [23]:
blimp_examples = load_blimp_examples(tokenizer, pad_to_length=32, n_samples=10)
blimp_examples[0]

{'clean_prefix': tensor([[   0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,  310,  401,  114, 7434,   45,    5,    1]]),
 'clean_answer': 1370,
 'patch_prefix': tensor([[ 310,  401,  114, 7434,   45,    5,    1]]),
 'patch_answer': 404,
 'UID': 'anaphor_gender_agreement',
 'linguistics_term': 'anaphor_agreement',
 'prefix_length_wo_pad': 7}

In [4]:
# load and prepare data
vqa_examples = load_vqa_examples(tokenizer, img_processor, pad_to_length=32, n_samples=10)
vqa_examples[0]


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.
Repo card metadata block was not found. Setting CardData to empty.


loaded huggingface DS
loaded local DS


100%|██████████| 100/100 [00:00<00:00, 255.75it/s]


{'clean_prefix': tensor([[   0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,   27,   44, 4045,   23,  463,   17,    1]]),
 'clean_answer': 49,
 'distractors': [3895, 1224, 121, 1017, 303, 55, 175],
 'question_type': 'is this',
 'prefix_length_wo_pad': 7,
 'pixel_values': tensor([[[[-2.1179, -2.1179, -2.1179,  ..., -2.1179, -2.1179, -2.1179],
           [-2.1179, -2.1179, -2.1179,  ..., -2.1008, -2.1179, -2.1179],
           [-2.1179, -2.1008, -2.1179,  ..., -2.1008, -2.1179, -2.1179],
           ...,
           [-2.1008, -2.1179, -2.1008,  ..., -2.1179, -2.1179, -2.1179],
           [-2.1008, -2.1179, -2.1008,  ..., -2.1179, -2.1179, -2.1179],
           [-2.1179, -2.1179, -2.1179,  ..., -2.1179, -2.1179, -2.1179]],
 
          [[-2.0357, -2.0357, -2.0357,  ..., -2.0357, -2.0357, -2.0357],
           [-2.0357, -2.0357, -2.0357,  ..., -2.0182, -2.0357, -2.0357],
       

In [5]:

def compute_mean_activations(examples, model, submodules, batch_size, num_examples, noimg=False):
    tracer_kwargs = {'validate' : False, 'scan' : False}
    device = "cuda"
    num_examples = min([num_examples, len(examples)])
    batches = [
        examples[i:min(i + batch_size,num_examples)] for i in range(0, num_examples, batch_size)
    ]

    # Initialize storage for cumulative activations and count of samples
    cumulative_activations = {submodule: 0 for submodule in submodules}
    total_samples = 0

    for batch in tqdm(batches):
        clean_inputs = t.cat([e['clean_prefix'] for e in batch], dim=0).to(device)
    
        # clean run -> model can be approximated through linear function of its activations
        hidden_states_clean = {}
        if noimg:
            with model.trace(clean_inputs, **tracer_kwargs), t.no_grad():
                for submodule in submodules:
                    x = submodule.output
                    hidden_states_clean[submodule] = x.save()
        else:
            img_inputs = t.cat([e['pixel_values'] for e in batch], dim=0).to(device)
            with model.trace(clean_inputs, pixel_values=img_inputs, **tracer_kwargs), t.no_grad():
                for submodule in submodules:
                    x = submodule.output
                    hidden_states_clean[submodule] = x.save()
        hidden_states_clean = {k : v.value for k, v in hidden_states_clean.items()}

        batch_size = next(iter(hidden_states_clean.values())).shape[0]  # Assuming shape [batch_size, ...]
        total_samples += batch_size

        # Sum across the batch (dim=0)
        for submodule, state in hidden_states_clean.items():
            cumulative_activations[submodule] += state.sum(dim=(0, 1))  
        
        hidden_states_clean = None
        torch.cuda.empty_cache()
        gc.collect()

    # Compute mean activation by dividing the cumulative activations by the total number of samples
    mean_activations = {submodule: cum_act / total_samples for submodule, cum_act in cumulative_activations.items()}

    return mean_activations



In [6]:
# Attribution patching with integrated gradients
def _pe_ig(
        clean,
        img_inputs,
        model,
        submodules,
        hidden_states_mean,
        metric_fn,
        steps=10,
        metric_kwargs=dict()):
    tracer_kwargs = {'validate' : False, 'scan' : False}
    
    # clean run -> model can be approximated through linear function of its activations
    hidden_states_clean = {}
    if img_inputs is None:
        with model.trace(clean, **tracer_kwargs), t.no_grad():
            for submodule in submodules:
                x = submodule.output
                hidden_states_clean[submodule] = x.save()
            metric_clean = metric_fn(model, **metric_kwargs).save()
    else:
        with model.trace(clean, pixel_values=img_inputs, **tracer_kwargs), t.no_grad(): 
            for submodule in submodules:
                x = submodule.output
                hidden_states_clean[submodule] = x.save()
            metric_clean = metric_fn(model, **metric_kwargs).save()
    hidden_states_clean = {k : v.value for k, v in hidden_states_clean.items()}


    effects = {}
    deltas = {}
    grads = {}
    for submodule in submodules:
        clean_state = hidden_states_clean[submodule]
        mean_state = hidden_states_mean[submodule]
        with model.trace(**tracer_kwargs) as tracer:
            metrics = []
            fs = []
            for step in range(steps):
                alpha = step / steps
                f = (1 - alpha) * clean_state + alpha * mean_state
                f.retain_grad()
                fs.append(f)
                if img_inputs is None:
                    with tracer.invoke(clean, scan=tracer_kwargs['scan']):
                        submodule.output = f
                        metrics.append(metric_fn(model, **metric_kwargs))
                else:
                    with tracer.invoke(clean, pixel_values=img_inputs, scan=tracer_kwargs['scan']):
                        submodule.output = f
                        metrics.append(metric_fn(model, **metric_kwargs))
            metric = sum([m for m in metrics])
            metric.sum().backward(retain_graph=True) # TODO : why is this necessary? Probably shouldn't be, contact jaden
        
        mean_grad = sum([f.grad for f in fs]) / steps
        # mean_residual_grad = sum([f.grad for f in fs]) / steps
        grad = mean_grad
        delta = (mean_state - clean_state).detach() if mean_state is not None else -clean_state.detach()
        effect = t.mul(grad, delta)

        effects[submodule] = effect
        deltas[submodule] = delta
        grads[submodule] = grad
    
    return (effects, deltas, grads)

[{'clean_prefix': tensor([[   0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
              0,    0,    0,    0,    0,   51,   54,  137,  951,   10,  189, 1183,
             51,   54,  137, 1183,  189,  951,   10,    1]]),
  'correct_idx': [0, 1, 2, 3, 4, 5, 6],
  'incorrect_idx': [7, 8, 9, 10, 11, 12, 13],
  'tag': 'Noun',
  'collapsed_tag': 'Object',
  'pixel_values': tensor([[[[-1.4843, -1.4843, -1.4500,  ..., -0.3883, -0.3883, -0.1999],
            [-1.4843, -1.4672, -1.4329,  ..., -0.4226, -0.4226, -0.2684],
            [-1.4843, -1.4672, -1.4329,  ..., -0.4397, -0.4054, -0.3369],
            ...,
            [-1.2617, -1.0733, -1.1418,  ..., -1.2274, -1.3644, -1.4843],
            [-1.2103, -0.9705, -1.2445,  ..., -1.3644, -1.4672, -1.5870],
            [-1.1589, -1.2788, -1.2959,  ..., -1.3987, -1.4158, -1.5014]],
  
           [[-1.0553, -1.0378, -1.0028,  ...,  0.3102,  0.2752,  0.4153],
            [-1.0553, -1.0378, -0.9853,  ...,  0.2752,  0.2577,  0.36

In [59]:
batch = [winoground_examples[i:min(i + 2, 10)] for i in range(0, 10, 2)][0]
clean_inputs = t.cat([e['clean_prefix'] for e in batch], dim=0).to("cuda")


In [66]:
correct_idxs = [[0,1,2], [4,5,6,7]]
incorrect_idxs = [[1,2], [5,6,7]]

In [95]:
"""
logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1 )  # [1, seq]
                answer = float(logits.sum())
"""
tracer_kwargs = {'validate' : False, 'scan' : False}
with model.trace(**tracer_kwargs) as tracer:
            
    with tracer.invoke(clean_inputs, scan=tracer_kwargs['scan']):
        outputs = model.output.output.save()


a = t.gather(outputs[:,-1,:], dim=-1, index=t.tensor([17, 18]).to("cuda").view(-1, 1)).squeeze(-1)
b = t.gather(outputs[:,-1,:], dim=-1, index=t.tensor([19, 13]).to("cuda").view(-1, 1)).squeeze(-1)
print(a)
print(b)
print(a-b)




tensor([-0.7798,  0.7615], device='cuda:0', grad_fn=<SqueezeBackward1>)
tensor([-1.1319, -3.8153], device='cuda:0', grad_fn=<SqueezeBackward1>)
tensor([0.3521, 4.5768], device='cuda:0', grad_fn=<SubBackward0>)


In [99]:
correct_sent_logits = []
incorrect_sent_logits = []
for i, idx in enumerate(correct_idxs):
    use = t.tensor([idx]).to("cuda")
    logits = torch.gather(outputs[i,:,:], dim=1, index=use).squeeze(-1) # [1, seq]
    print(logits.sum())
    print(logits.sum().unsqueeze(0).shape)
    correct_sent_logits.append(logits.sum().unsqueeze(0))
correct_sent_logits = torch.cat(correct_sent_logits, dim=0)
print(correct_sent_logits-correct_sent_logits)

tensor(31.0794, device='cuda:0', grad_fn=<SumBackward0>)
torch.Size([1])
tensor(14.7316, device='cuda:0', grad_fn=<SumBackward0>)
torch.Size([1])
tensor([0., 0.], device='cuda:0', grad_fn=<SubBackward0>)


In [7]:
def get_important_neurons(examples, batch_size, mlps, mean_activations, task):
    # uses attribution patching to identify most important neurons for subtask
    num_examples = len(examples)
    batches = [examples[i:min(i + batch_size, num_examples)] for i in range(0, num_examples, batch_size)]
    device = "cuda"

    sum_effects = {}

    for batch in tqdm(batches): 
        clean_inputs = t.cat([e['clean_prefix'] for e in batch], dim=0).to(device)

        if task == "vqa":
            clean_answer_idxs = t.tensor([e['clean_answer'] for e in batch], dtype=t.long, device=device)
            img_inputs = t.cat([e['pixel_values'] for e in batch], dim=0).to(device)

            first_distractor_idxs = t.tensor([e['distractors'][0] for e in batch], dtype=t.long, device=device)

            def metric(model):
                # compute difference between correct answer and first distractor
                # TODO: compute avg difference between correct answer and each distractor
                #embds_out = model.output.output.save()
                
                return (
                    t.gather(model.output.output[:,-1,:], dim=-1, index=first_distractor_idxs.view(-1, 1)).squeeze(-1) - \
                    t.gather(model.output.output[:,-1,:], dim=-1, index=clean_answer_idxs.view(-1, 1)).squeeze(-1)
                )
        
        elif task == "blimp":
            clean_answer_idxs = t.tensor([e['clean_answer'] for e in batch], dtype=t.long, device=device)
            img_inputs = None
            
            patch_answer_idxs = t.tensor([e['patch_answer'] for e in batch], dtype=t.long, device=device)

            def metric(model):
                
                return (
                    t.gather(model.output.output[:,-1,:], dim=-1, index=patch_answer_idxs.view(-1, 1)).squeeze(-1) - \
                    t.gather(model.output.output[:,-1,:], dim=-1, index=clean_answer_idxs.view(-1, 1)).squeeze(-1)
                )
            
        elif task == "winoground":
            correct_idxs = [e['correct_idx'] for e in batch]
            incorrect_idxs = [e['incorrect_idx'] for e in batch]
            img_inputs = t.cat([e['pixel_values'] for e in batch], dim=0).to(device)


            def metric(model):
                correct_sent_logits = []
                incorrect_sent_logits = []

                outputs = model.output.output
                
                for i, idx in enumerate(correct_idxs):
                    correct_pos = t.tensor([idx]).to("cuda")
                    logits = torch.gather(outputs[i,:,:], dim=1, index=correct_pos).squeeze(-1) # [1, seq]
                    correct_sent_logits.append(logits.sum().unsqueeze(0))
                correct_sent_logits = torch.cat(correct_sent_logits, dim=0)
                
                for i, idx in enumerate(incorrect_idxs):
                    incorrect_pos = t.tensor([idx]).to("cuda")
                    logits = torch.gather(outputs[i,:,:], dim=1, index=incorrect_pos).squeeze(-1) # [1, seq]
                    incorrect_sent_logits.append(logits.sum().unsqueeze(0))
                incorrect_sent_logits = torch.cat(incorrect_sent_logits, dim=0)

                result = incorrect_sent_logits - correct_sent_logits
                
                return result

        effects, _, _ = _pe_ig(
            clean_inputs,
            img_inputs,
            model,
            mlps,
            mean_activations,
            metric,
            steps=10,
            metric_kwargs=dict())
        
        
        for submodule in mlps:
            if submodule not in sum_effects:
                sum_effects[submodule] = effects[submodule].sum(dim=1).sum(dim=0)
            else:
                sum_effects[submodule] += effects[submodule].sum(dim=1).sum(dim=0)

    # Print top 1% neurons in each submodule (ndim=3072)
    k = 31

    top_neurons = {}
    for idx, submodule in enumerate(mlps):
        sum_effects[submodule] /= num_examples
        v, i = t.topk(sum_effects[submodule].flatten(), k)  # v=top effects, i=top indices
        top_neurons[f"mlp_{idx}"] = (i,v)
    return top_neurons
        

In [8]:
####
#### VQA
####

batch_size = 2  #16
num_examples = 8  #-1

mlps = [submodules[submodule] for submodule in submodules if submodule.startswith("mlp")]
mlps = mlps[:2]

# precompute mean activations on VQA
mean_activations_vqa = compute_mean_activations(vqa_examples, model, mlps, batch_size, num_examples)

# extract most important neurons per subtask of VQA by using attribution patching between clean and mean activations
subtasks = {}
for e in vqa_examples:
    subtask = e["question_type"]
    if subtask in subtasks:
        subtasks[subtask].append(e)
    else:
        subtasks[subtask] = [e]

# for each subtask, compute top neurons and save
subtasks_neurons = {}
for subtask, examples in subtasks.items():
    if len(subtasks_neurons) == 2:
        break
    top_neurons = get_important_neurons(examples, batch_size, mlps, mean_activations_vqa, task="vqa")
    subtasks_neurons[subtask] = top_neurons


with open(f"data/vqa_top_neurons_per_subtask.pkl", "wb") as f:
    pickle.dump(subtasks_neurons, f)

  0%|          | 0/4 [00:00<?, ?it/s]We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.
100%|██████████| 4/4 [00:01<00:00,  3.01it/s]
100%|██████████| 1/1 [00:02<00:00,  2.59s/it]
100%|██████████| 1/1 [00:01<00:00,  1.44s/it]


In [9]:
with open("data/vqa_top_neurons_per_subtask.pkl", "rb") as g:
    subtasks_neurons = pickle.load(g)

In [11]:
subtasks_neurons.keys()

dict_keys(['is this', 'what is the'])

In [25]:
####
#### BLIMP
####

batch_size = 2  #16
num_examples = 8  #-1

mlps = [submodules[submodule] for submodule in submodules if submodule.startswith("mlp")]
mlps = mlps[:2]

# precompute mean activations on BLIMP
mean_activations_blimp = compute_mean_activations(blimp_examples, model, mlps, batch_size, num_examples, noimg=True)

# extract most important neurons per subtask of VQA by using attribution patching between clean and mean activations
subtasks = {}
for e in blimp_examples:
    subtask = e["UID"]
    if subtask in subtasks:
        subtasks[subtask].append(e)
    else:
        subtasks[subtask] = [e]

# for each subtask, compute top neurons and save
subtasks_neurons = {}
for subtask, examples in subtasks.items():
    if len(subtasks_neurons) == 2:
        break
    top_neurons = get_important_neurons(examples, batch_size, mlps, mean_activations_blimp, task="blimp")
    subtasks_neurons[subtask] = top_neurons


with open(f"data/blimp_top_neurons_per_subtask.pkl", "wb") as f:
    pickle.dump(subtasks_neurons, f)

100%|██████████| 4/4 [00:01<00:00,  3.00it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

metric: InterventionProxy (add_9): <class 'inspect._empty'>
metric: InterventionProxy (add_9): <class 'inspect._empty'>


 20%|██        | 1/5 [00:00<00:02,  1.47it/s]

metric: InterventionProxy (add_9): <class 'inspect._empty'>
metric: InterventionProxy (add_9): <class 'inspect._empty'>


 40%|████      | 2/5 [00:01<00:01,  1.60it/s]

metric: InterventionProxy (add_9): <class 'inspect._empty'>
metric: InterventionProxy (add_9): <class 'inspect._empty'>


 60%|██████    | 3/5 [00:01<00:01,  1.64it/s]

metric: InterventionProxy (add_9): <class 'inspect._empty'>
metric: InterventionProxy (add_9): <class 'inspect._empty'>


 80%|████████  | 4/5 [00:02<00:00,  1.66it/s]

metric: InterventionProxy (add_9): <class 'inspect._empty'>
metric: InterventionProxy (add_9): <class 'inspect._empty'>


100%|██████████| 5/5 [00:03<00:00,  1.64it/s]


In [13]:
####
#### Winoground
####

batch_size = 2  #16
num_examples = 8  #-1

mlps = [submodules[submodule] for submodule in submodules if submodule.startswith("mlp")]
mlps = mlps[:2]

# precompute mean activations on BLIMP
mean_activations_winoground = compute_mean_activations(winoground_examples, model, mlps, batch_size, num_examples, noimg=False)

# extract most important neurons per subtask of VQA by using attribution patching between clean and mean activations
subtasks = {}
for e in winoground_examples:
    subtask = e["tag"]
    if subtask in subtasks:
        subtasks[subtask].append(e)
    else:
        subtasks[subtask] = [e]

# for each subtask, compute top neurons and save
subtasks_neurons = {}
for subtask, examples in subtasks.items():
    if len(subtasks_neurons) == 2:
        break
    top_neurons = get_important_neurons(examples, batch_size, mlps, mean_activations_winoground, task="winoground")
    subtasks_neurons[subtask] = top_neurons


with open(f"data/winoground_top_neurons_per_subtask.pkl", "wb") as f:
    pickle.dump(subtasks_neurons, f)

100%|██████████| 4/4 [00:01<00:00,  3.18it/s]
  0%|          | 0/2 [00:00<?, ?it/s]


RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 32 but got size 229 for tensor number 1 in the list.