In [1]:
from nnsight import LanguageModel

import torch as t
import gc
import sys
import math
import numpy as np
import os
from tqdm import tqdm
import torch
from datasets import load_dataset

from loading_utils import load_vqa_examples, load_blimp_examples, load_winoground_examples

from transformers import AutoProcessor, AutoTokenizer
from nnsight import NNsight
import importlib.util
import pickle
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from babylm_analysis import _pe_ig


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def load_model(own_model=True, model_path=None, device="cuda"):
    """Load a GIT model, wrap it with NNsight, and return it with its preprocessors.

    Args:
        own_model: If True, load a locally trained checkpoint whose directory
            ships its own ``modeling_git.py``; otherwise load the model with
            ``transformers.GitForCausalLM``.
        model_path: Optional checkpoint directory / hub id. Defaults to
            ``"../babylm_GIT/models2/base_git_1vd125_s1/epoch17/"`` when
            ``own_model`` is True and ``"babylm/git-2024"`` otherwise.
        device: Device passed to NNsight and used for ``.to``.

    Returns:
        ``(nnsight_model, tokenizer, img_processor)``.
    """
    if own_model:
        if model_path is None:
            model_path = "../babylm_GIT/models2/base_git_1vd125_s1/epoch17/"
        # The checkpoint directory carries its own model code; import it as a module.
        spec = importlib.util.spec_from_file_location(
            "GitForCausalLM", os.path.join(model_path, "modeling_git.py")
        )
        git_module = importlib.util.module_from_spec(spec)
        sys.modules["git_module"] = git_module
        spec.loader.exec_module(git_module)
        GitForCausalLM = git_module.GitForCausalLM

        model = GitForCausalLM.from_pretrained(model_path)
        # Re-apply the raw state dict. strict=False because the checkpoint holds keys the
        # model does not register (e.g. 'git.embeddings.position_ids').
        # map_location="cpu": the model has not been moved to `device` yet, so do not
        # materialise a second copy of the weights on the GPU.
        # TODO: newly initialized for vision encoder: ['pooler.dense.bias', 'pooler.dense.weight']
        ckpt = torch.load(os.path.join(model_path, "pytorch_model.bin"), map_location="cpu")
        model.load_state_dict(ckpt, strict=False)
        del ckpt
    else:
        if model_path is None:
            model_path = "babylm/git-2024"

        from transformers import GitForCausalLM as OGModel

        model = OGModel.from_pretrained(model_path, trust_remote_code=True)

    # load tokenizer and img processor
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    img_processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)

    nnsight_model = NNsight(model, device_map=device)
    nnsight_model.to(device)

    return nnsight_model, tokenizer, img_processor


def extract_submodules(model):
    """Index the hookable pieces of every layer in ``model.git.encoder.layer`` by name.

    For layer number ``k`` three entries are added, in this order:
    ``"mlp.k"`` -> ``layer.intermediate``, ``"attn.k"`` -> ``layer.attention``
    and ``"resid.k"`` -> the layer module itself.

    Returns:
        dict of name -> module, in layer order.
    """
    named_modules = {}
    for layer_no, block in enumerate(model.git.encoder.layer):
        named_modules.update(
            (
                (f"mlp.{layer_no}", block.intermediate),  # labelled as the MLP output
                (f"attn.{layer_no}", block.attention),    # attention sub-block
                (f"resid.{layer_no}", block),             # whole layer
            )
        )
    return named_modules
        

In [3]:
# Load the locally trained checkpoint (NNsight-wrapped) and index its per-layer submodules by name.
loaded = load_model(own_model=True)
model, tokenizer, img_processor = loaded
submodules = extract_submodules(model=model)

Some weights of ViTModel were not initialized from the model checkpoint at facebook/dino-vitb16 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
# 10 Winoground examples from the local copy, text padded to 32 tokens.
winoground_examples = load_winoground_examples(
    tokenizer,
    img_processor,
    pad_to_length=32,
    n_samples=10,
    local=True,
)

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


loaded huggingface DS
loaded local DS


100%|██████████| 10/10 [00:00<00:00, 24.35it/s]


In [10]:
# 10 text-only BLiMP examples from the local copy, padded to 32 tokens; show the first.
blimp_examples = load_blimp_examples(
    tokenizer,
    pad_to_length=32,
    n_samples=10,
    local=True,
)
blimp_examples[0]

{'clean_prefix': tensor([[   0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,  310,  401,  114, 7434,   45,    5,    1]]),
 'clean_answer': 1370,
 'patch_prefix': tensor([[   0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,  310,  401,  114, 7434,   45,    5,    1]]),
 'patch_answer': 404,
 'UID': 'anaphor_gender_agreement',
 'linguistics_term': 'anaphor_agreement',
 'prefix_length_wo_pad': 7}

In [4]:
# load and prepare data: 10 VQA examples (tokens + pixel values), padded to 32 tokens; show the first.
vqa_examples = load_vqa_examples(
    tokenizer,
    img_processor,
    pad_to_length=32,
    n_samples=10,
    local=True,
)
vqa_examples[0]


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.
Repo card metadata block was not found. Setting CardData to empty.


loaded huggingface DS
loaded local DS


100%|██████████| 100/100 [00:01<00:00, 98.06it/s]


{'clean_prefix': tensor([[   0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,   27,   44, 4045,   23,  463,   17,    1]]),
 'clean_answer': 49,
 'distractors': [3895, 1224, 121, 1017, 303, 55, 175],
 'question_type': 'is this',
 'prefix_length_wo_pad': 7,
 'pixel_values': tensor([[[[-2.1179, -2.1179, -2.1179,  ..., -2.1179, -2.1179, -2.1179],
           [-2.1179, -2.1179, -2.1179,  ..., -2.1008, -2.1179, -2.1179],
           [-2.1179, -2.1008, -2.1179,  ..., -2.1008, -2.1179, -2.1179],
           ...,
           [-2.1008, -2.1179, -2.1008,  ..., -2.1179, -2.1179, -2.1179],
           [-2.1008, -2.1179, -2.1008,  ..., -2.1179, -2.1179, -2.1179],
           [-2.1179, -2.1179, -2.1179,  ..., -2.1179, -2.1179, -2.1179]],
 
          [[-2.0357, -2.0357, -2.0357,  ..., -2.0357, -2.0357, -2.0357],
           [-2.0357, -2.0357, -2.0357,  ..., -2.0182, -2.0357, -2.0357],
       

In [4]:

def compute_mean_activations(examples, model, submodules, batch_size, noimg=False, file_prefix=None,
                             device="cuda", out_dir="mean_activations"):
    """Compute the mean output activation of each submodule and cache it as ``.npy``.

    Each submodule is traced separately over ``examples`` in batches of
    ``batch_size``. Its output is summed over dims 0 and 1 (batch and
    position) and divided by the number of (example, position) pairs seen.
    The result is a per-token mean over the remaining dims, saved with
    ``np.save`` to ``{out_dir}/{file_prefix}_mean_acts_{i}.npy``, where ``i``
    is the submodule's index in ``submodules``.

    Args:
        examples: list of dicts with a ``'clean_prefix'`` token tensor and,
            unless ``noimg``, a ``'pixel_values'`` tensor.
        model: NNsight-wrapped model exposing ``.trace``.
        submodules: modules whose ``.output`` is a single tensor with
            (batch, position) as its first two dims. This is assumed (true for
            the MLP modules passed by run_task); TODO confirm for others.
        batch_size: number of examples per forward pass.
        noimg: run text-only (no ``pixel_values`` passed to the model).
        file_prefix: filename prefix for the cached files.
        device: device the inputs are moved to.
        out_dir: cache directory, created if missing.

    Returns:
        List of saved file paths, aligned with ``submodules``.

    Raises:
        ValueError: if ``examples`` is empty.

    NOTE: files cached before this fix were divided by the number of examples
    only (not positions), so they are seq_len times too large. Regenerate them.
    """
    if len(examples) == 0:
        raise ValueError("compute_mean_activations needs at least one example")

    tracer_kwargs = {'validate': False, 'scan': False}
    batches = [examples[i:i + batch_size] for i in range(0, len(examples), batch_size)]
    os.makedirs(out_dir, exist_ok=True)

    def _mean_output(submodule):
        """Per-token mean of ``submodule.output`` over all batches, on CPU."""
        running_sum = 0
        n_positions = 0
        for batch in tqdm(batches):
            clean_inputs = t.cat([e['clean_prefix'] for e in batch], dim=0).to(device)
            trace_kwargs = dict(tracer_kwargs)
            if not noimg:
                trace_kwargs['pixel_values'] = t.cat([e['pixel_values'] for e in batch], dim=0).to(device)

            # no_grad must be the OUTER context: nnsight runs the forward pass when the
            # trace context exits, which must happen while gradients are still disabled.
            with t.no_grad(), model.trace(clean_inputs, **trace_kwargs):
                saved = submodule.output.save()
            acts = saved.value

            running_sum = running_sum + acts.sum(dim=(0, 1)).detach().cpu()
            n_positions += acts.shape[0] * acts.shape[1]

            del acts, saved, clean_inputs, trace_kwargs
            torch.cuda.empty_cache()
            gc.collect()

        return running_sum / n_positions

    mean_act_files = []
    for idx, submodule in enumerate(submodules):
        mean_acts = _mean_output(submodule)
        filename = os.path.join(out_dir, f"{file_prefix}_mean_acts_{idx}.npy")
        np.save(filename, mean_acts.numpy())

        del mean_acts
        torch.cuda.empty_cache()
        gc.collect()

        mean_act_files.append(filename)

    return mean_act_files

In [8]:
def _stack_pixel_values(batch, device):
    """Concatenate the ``pixel_values`` tensors of ``batch`` along dim 0, on ``device``."""
    return t.cat([e['pixel_values'] for e in batch], dim=0).to(device)


def _final_logit_diff_metric(wrong_idxs, right_idxs):
    """Build a metric: logit of ``wrong_idxs`` minus logit of ``right_idxs`` at the last position."""
    def metric(model):
        return (
            t.gather(model.output.output[:, -1, :], dim=-1, index=wrong_idxs.view(-1, 1)).squeeze(-1)
            - t.gather(model.output.output[:, -1, :], dim=-1, index=right_idxs.view(-1, 1)).squeeze(-1)
        )
    return metric


def _winoground_metric(correct_idxs, incorrect_idxs, device):
    """Build a metric: summed gathered logits of the incorrect minus the correct caption, per example."""
    def metric(model):
        correct_sent_logits = []
        incorrect_sent_logits = []
        for i, (idx, cf_idx) in enumerate(zip(correct_idxs, incorrect_idxs)):
            # NOTE(review): if idx is a list of token ids, t.tensor([idx]) has shape (1, len), and
            # gather along dim=1 of the (seq, vocab) logits reads sequence position 0 only.
            # Confirm that this is intended.
            logits = torch.gather(model.output.output[i, :, :], dim=1, index=t.tensor([idx]).to(device)).squeeze(-1)
            cf_logits = torch.gather(model.output.output[i, :, :], dim=1, index=t.tensor([cf_idx]).to(device)).squeeze(-1)
            correct_sent_logits.append(logits.sum().unsqueeze(0))
            incorrect_sent_logits.append(cf_logits.sum().unsqueeze(0))
        return torch.cat(incorrect_sent_logits, dim=0) - torch.cat(correct_sent_logits, dim=0)
    return metric


def _batch_img_and_metric(task, batch, noimg, device):
    """Return ``(img_inputs, metric)`` for one batch of ``task``; ``img_inputs`` is None when text-only."""
    if task == "vqa":
        clean_answer_idxs = t.tensor([e['clean_answer'] for e in batch], dtype=t.long, device=device)
        # TODO: compute avg difference between correct answer and each distractor
        first_distractor_idxs = t.tensor([e['distractors'][0] for e in batch], dtype=t.long, device=device)
        img_inputs = None if noimg else _stack_pixel_values(batch, device)
        return img_inputs, _final_logit_diff_metric(first_distractor_idxs, clean_answer_idxs)
    if task == "blimp":
        clean_answer_idxs = t.tensor([e['clean_answer'] for e in batch], dtype=t.long, device=device)
        patch_answer_idxs = t.tensor([e['patch_answer'] for e in batch], dtype=t.long, device=device)
        return None, _final_logit_diff_metric(patch_answer_idxs, clean_answer_idxs)
    if task == "winoground":
        img_inputs = None if noimg else _stack_pixel_values(batch, device)
        metric = _winoground_metric(
            [e["correct_idx"] for e in batch],
            [e["incorrect_idx"] for e in batch],
            device,
        )
        return img_inputs, metric
    raise ValueError(f"unsupported task {task!r}; expected 'vqa', 'blimp' or 'winoground'")


def get_important_neurons(examples, batch_size, mlps, pad_len, mean_act_files, task, noimg, k=100, device="cuda"):
    """Rank MLP neurons by their attribution effect (via ``_pe_ig``) on a task metric.

    Per batch, a task-specific metric is built: a wrong-minus-right logit difference
    (first distractor for vqa, patch answer for blimp, incorrect caption for
    winoground). ``_pe_ig`` is then run with 10 steps on the notebook-global
    ``model``. Each submodule's effects are summed over dims 1 and 0, accumulated
    over batches and divided by the number of examples.

    Args:
        examples: list of example dicts for a single subtask.
        batch_size: examples per ``_pe_ig`` call.
        mlps: submodules to attribute; also the keys of ``_pe_ig``'s effects.
        pad_len: forwarded to ``_pe_ig``.
        mean_act_files: cached mean activations, aligned with ``mlps``.
        task: ``"vqa"``, ``"blimp"`` or ``"winoground"``.
        noimg: if True, no pixel values are passed (blimp never uses images).
        k: number of top neurons kept per submodule (capped at its size).
        device: device for inputs and index tensors.

    Returns:
        dict ``"mlp_{i}" -> (top_indices, top_values)``, both on CPU.

    Raises:
        ValueError: for an unsupported ``task`` or empty ``examples``.
    """
    if task not in ("vqa", "blimp", "winoground"):
        raise ValueError(f"unsupported task {task!r}; expected 'vqa', 'blimp' or 'winoground'")
    num_examples = len(examples)
    if num_examples == 0:
        raise ValueError("get_important_neurons needs at least one example")
    batches = [examples[i:i + batch_size] for i in range(0, num_examples, batch_size)]

    sum_effects = {}
    for batch in tqdm(batches):
        clean_inputs = t.cat([e['clean_prefix'] for e in batch], dim=0).to(device)
        img_inputs, metric = _batch_img_and_metric(task, batch, noimg, device)

        effects, _, _ = _pe_ig(
            clean_inputs,
            img_inputs,
            model,
            mlps,
            mean_act_files,
            metric,
            pad_len,
            steps=10,
            metric_kwargs=dict())

        for submodule in mlps:
            batch_effect = effects[submodule].sum(dim=1).sum(dim=0)
            if submodule in sum_effects:
                sum_effects[submodule] += batch_effect
            else:
                sum_effects[submodule] = batch_effect

    # Keep the top-k neurons (by mean effect) of each submodule.
    top_neurons = {}
    for idx, submodule in enumerate(mlps):
        mean_effect = (sum_effects[submodule] / num_examples).flatten()
        values, indices = t.topk(mean_effect, min(k, mean_effect.numel()))
        top_neurons[f"mlp_{idx}"] = (indices.cpu(), values.cpu())
    return top_neurons

In [9]:
# Experiment configuration, read as globals by run_task().
batch_size = 2  # examples per attribution batch (16 for larger runs)
num_examples = 8  # examples loaded per task; -1 is tagged 'all' in cache filenames
task = "vqa"
model_name = "git_1vd125_s1"
# Must match the checkpoint directory loaded by load_model() (".../base_git_1vd125_s1/epoch17/").
# The epoch is part of the mean-activation cache prefix, so a wrong value mixes caches across epochs.
epoch = 17
local = True
pad_len = 32

def run_task(task, max_layers=2):
    """Find the top-attributed MLP neurons for each subtask of ``task`` and print them.

    Loads ``num_examples`` examples and loads cached mean MLP activations, computing
    and caching them if missing. It then groups the examples by subtask label and
    runs ``get_important_neurons`` on each group.

    Reads notebook globals: ``submodules``, ``model``, ``tokenizer``,
    ``img_processor``, ``batch_size``, ``num_examples``, ``local``,
    ``model_name``, ``epoch`` and ``pad_len``.

    Args:
        task: ``"vqa"``, ``"blimp"`` (text-only) or ``"winoground"``.
        max_layers: analyse only the first ``max_layers`` MLP submodules;
            ``None`` analyses all of them.

    Raises:
        ValueError: if ``task`` is not supported.
    """
    mlps = [module for name, module in submodules.items() if name.startswith("mlp")]
    if max_layers is not None:
        mlps = mlps[:max_layers]

    # load and prepare data
    noimg = False
    if task == "vqa":
        examples = load_vqa_examples(tokenizer, img_processor, pad_to_length=32, n_samples=num_examples, local=local)
        subtask_key = "question_type"
    elif task == "blimp":
        noimg = True
        examples = load_blimp_examples(tokenizer, pad_to_length=32, n_samples=num_examples, local=local)
        subtask_key = "linguistics_term"
    elif task == "winoground":
        examples = load_winoground_examples(tokenizer, img_processor, pad_to_length=32, n_samples=num_examples, local=local)
        subtask_key = "secondary_tag"
    else:
        # Fail fast: nothing below can run without examples.
        raise ValueError(f"{task} is not implemented")
    print("loaded samples")

    prefix = f"{task}_{model_name}_e{epoch}_n{num_examples if num_examples != -1 else 'all'}{'_noimg' if noimg else ''}"
    cache_dir = "mean_activations"
    # Expected cache files, one per MLP and named by its index, so file i always pairs
    # with mlps[i]. The arbitrary order of os.listdir must not decide this pairing.
    mean_act_files = [f"{cache_dir}/{prefix}_mean_acts_{i}.npy" for i in range(len(mlps))]
    if all(os.path.isfile(path) for path in mean_act_files):
        print("retrieved precomputed mean activations")
    else:
        os.makedirs(cache_dir, exist_ok=True)
        mean_act_files = compute_mean_activations(examples, model, mlps, batch_size=128, noimg=noimg, file_prefix=prefix)
        print("computed mean activations")

    # identify subtasks
    subtasks = {}
    for example in examples:
        subtasks.setdefault(example[subtask_key], []).append(example)
    print("extracted subtasks")

    # for each subtask, compute top neurons
    subtasks_neurons = {}
    for subtask, subtask_examples in subtasks.items():
        subtasks_neurons[subtask] = get_important_neurons(
            subtask_examples, batch_size, mlps, pad_len, mean_act_files, task=task, noimg=noimg
        )
        print(f"finished subtask: {subtask}")

    #with open(f"data/{model_name}_e{epoch}_{task}_top_neurons_per_subtask.pkl", "wb") as f:
    #    pickle.dump(subtasks_neurons, f)
    print(subtasks_neurons)

In [10]:
# Load a second copy of the epoch-17 checkpoint (not NNsight-wrapped), using the
# modeling code stored in the checkpoint directory.
model_path = "../babylm_GIT/models2/base_git_1vd125_s1/epoch17/"
spec = importlib.util.spec_from_file_location("GitForCausalLM", model_path + "modeling_git.py")
git_module = importlib.util.module_from_spec(spec)
sys.modules["git_module"] = git_module
spec.loader.exec_module(git_module)
GitForCausalLM = git_module.GitForCausalLM

model2 = GitForCausalLM.from_pretrained(model_path)
# TODO: newly initialized for vision encoder: ['pooler.dense.bias', 'pooler.dense.weight']
ckpt2 = torch.load(model_path + "pytorch_model.bin")
# strict=False tolerates keys the model does not register; the returned report is this cell's output.
model2.load_state_dict(ckpt2, strict=False)

Some weights of ViTModel were not initialized from the model checkpoint at facebook/dino-vitb16 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


_IncompatibleKeys(missing_keys=[], unexpected_keys=['git.embeddings.position_ids'])

In [12]:
# Tokenizer for the checkpoint at `model_path`.
tokenizer2 = AutoTokenizer.from_pretrained(
    model_path,
    trust_remote_code=True,
)

In [10]:
# Top neurons per BLiMP linguistic category (text-only run).
run_task(task="blimp")

loaded samples
retrieved precomputed mean activations
extracted subtasks


100%|██████████| 4/4 [00:05<00:00,  1.25s/it]

finished subtask: anaphor_agreement
{'anaphor_agreement': {'mlp_0': (tensor([ 733, 1532,  900,  307, 2543, 1211,  972,   14, 2166,  951, 2654,  771,
        1638,  784, 2890, 2193,  649, 1955, 1407, 1845, 3005,  389,  603, 1154,
        1849,  707,  787,  671, 2675, 2167,  315, 2684,   28, 1799, 2283, 2104,
         978,  781, 2781, 1311, 1354,   79, 2579, 2949,  196, 2236, 2709, 1668,
        1067, 1636, 1403,  643,  280, 2740,  869,  153, 2258, 3010, 2408, 2530,
        1204, 1404, 2540, 2593,  342, 1146,  500, 1339, 3049, 1814, 2417,   32,
        2921,   81, 2918,  809, 1876, 2943,   71,  567,  662,  380,    5, 2101,
         329, 1675, 2619, 1808, 2281,  204,  580,  145,  658, 2873,  666, 2713,
          37, 2015, 1570, 1232]), tensor([0.5936, 0.4346, 0.3636, 0.3391, 0.3077, 0.2889, 0.2518, 0.2333, 0.2262,
        0.2248, 0.2170, 0.2050, 0.2007, 0.1923, 0.1864, 0.1821, 0.1791, 0.1786,
        0.1701, 0.1699, 0.1680, 0.1667, 0.1638, 0.1618, 0.1595, 0.1505, 0.1458,
        0.1443, 0




In [11]:
# Print every question type listed in question_types.txt that has no superclass in
# vqa_superclasses.txt (one "<question type> - <superclass>" per line).
with open("question_types.txt", "r") as f1:
    qt = [l.strip() for l in f1 if l.strip()]

with open("vqa_superclasses.txt", "r") as f2:
    mapping = {}
    for l in f2:
        if not l.strip():
            continue  # tolerate blank / trailing lines
        # Split on the first "-" only, so superclass names containing hyphens stay intact.
        key, sep, value = l.partition("-")
        if not sep:
            raise ValueError(f"malformed line in vqa_superclasses.txt: {l!r}")
        mapping[key.strip()] = value.strip()

for q in qt:
    if q not in mapping:
        print(q)

In [22]:

def parse_vqa_qtypes(path="vqa_superclasses.txt"):
    """Parse the VQA question-type -> superclass mapping file.

    Each non-blank line has the form ``"<question type> - <superclass>"``.
    Only the first ``-`` separates the fields, so superclass names may contain
    hyphens. Surrounding whitespace is stripped from both fields.

    Args:
        path: mapping file to read.

    Returns:
        dict mapping question type to superclass, in file order.

    Raises:
        ValueError: if a non-blank line has no ``-`` separator.
    """
    mapping = {}
    with open(path, "r") as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue
            key, sep, value = line.partition("-")
            if not sep:
                raise ValueError(f"{path}:{lineno}: expected '<question type> - <superclass>', got {line!r}")
            mapping[key.strip()] = value.strip()
    return mapping

In [26]:
# Inspect the question-type -> superclass mapping.
d = dict(parse_vqa_qtypes())
d

{'is this': 'verification and existence',
 'what is the': 'general queries and miscellaneous',
 'do you': 'verification and existence',
 'what': 'general queries and miscellaneous',
 'what is': 'general queries and miscellaneous',
 'can you': 'action and state',
 'is the woman': 'person and object identification',
 'what color is the': 'color identification',
 'are the': 'verification and existence',
 'is the': 'verification and existence',
 'is this a': 'identification and classification',
 'is it': 'verification and existence',
 'what kind of': 'identification and classification',
 'is the man': 'person and object identification',
 'none of the above': 'general queries and miscellaneous',
 'what color are the': 'color identification',
 'what color': 'color identification',
 'what sport is': 'identification and classification',
 'was': 'temporal information',
 'is there': 'verification and existence',
 'is there a': 'verification and existence',
 'are they': 'verification and existenc

In [3]:
from datasets import load_dataset

# VQAv2 validation split from the HuggingFace hub, used below to enumerate question types.
hf_path, hf_split = "HuggingFaceM4/VQAv2", "validation"
local_file = "data/vqa_filtered/vqa_distractors_info.json"  # not used in this cell
hf_ds = load_dataset(hf_path)[hf_split]

  from .autonotebook import tqdm as notebook_tqdm
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.
Repo card metadata block was not found. Setting CardData to empty.


In [7]:
# Distinct question types in the split. Column access reads only the "question_type"
# column instead of materialising every row (presumably including its image) one at a time.
question_types = set(hf_ds["question_type"])

In [9]:
# Display the set of distinct VQAv2 question types.
question_types

{'are',
 'are the',
 'are there',
 'are there any',
 'are these',
 'are they',
 'can you',
 'could',
 'do',
 'do you',
 'does the',
 'does this',
 'has',
 'how',
 'how many',
 'how many people are',
 'how many people are in',
 'is',
 'is he',
 'is it',
 'is that a',
 'is the',
 'is the man',
 'is the person',
 'is the woman',
 'is there',
 'is there a',
 'is this',
 'is this a',
 'is this an',
 'is this person',
 'none of the above',
 'was',
 'what',
 'what animal is',
 'what are',
 'what are the',
 'what brand',
 'what color',
 'what color are the',
 'what color is',
 'what color is the',
 'what does the',
 'what is',
 'what is in the',
 'what is on the',
 'what is the',
 'what is the color of the',
 'what is the man',
 'what is the name',
 'what is the person',
 'what is the woman',
 'what is this',
 'what kind of',
 'what number is',
 'what room is',
 'what sport is',
 'what time',
 'what type of',
 'where are the',
 'where is the',
 'which',
 'who is',
 'why',
 'why is the'}