In [1]:
from nnsight import LanguageModel

import torch as t
import gc
import sys
import math
import numpy as np
import os
from tqdm import tqdm
import torch
from datasets import load_dataset

from loading_utils import load_vqa_examples, load_blimp_examples, load_winoground_examples

from transformers import AutoProcessor, AutoTokenizer
from nnsight import NNsight
import importlib.util
import pickle
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def load_model(own_model=True):
    if own_model:
        model_path = "../babylm_GIT/models2/base_git_1vd125_s1/epoch17/"
        spec = importlib.util.spec_from_file_location("GitForCausalLM", f"{model_path}modeling_git.py")
        git_module = importlib.util.module_from_spec(spec)
        sys.modules["git_module"] = git_module
        spec.loader.exec_module(git_module)
        GitForCausalLM = git_module.GitForCausalLM

        model = GitForCausalLM.from_pretrained(model_path) 
        ckpt = torch.load(model_path + "pytorch_model.bin") # TODO: newly initialized for vision encoder: ['pooler.dense.bias', 'pooler.dense.weight']
        model.load_state_dict(ckpt, strict=False)  
        
    else:
        model_path = "babylm/git-2024"

        from transformers import GitForCausalLM as OGModel

        model = OGModel.from_pretrained(model_path, trust_remote_code=True)
        
    # load tokenizer and img processor
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    img_processor = AutoProcessor.from_pretrained(model_path,trust_remote_code=True)
    
    nnsight_model = NNsight(model, device_map="cuda")
    nnsight_model.to("cuda")

    return nnsight_model, tokenizer, img_processor


def extract_submodules(model):
    submodules = {}
    for idx, layer in enumerate(model.git.encoder.layer):
        submodules[f"mlp.{idx}"] = layer.intermediate    # output of MLP
        submodules[f"attn.{idx}"] = layer.attention  # output of attention
        submodules[f"resid.{idx}"] = layer      # output of whole layer
    return submodules
        

In [3]:
# load and prepare model
model, tokenizer, img_processor = load_model(own_model=True)
submodules = extract_submodules(model)

The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.
Some weights of the model checkpoint at babylm/git-2024 were not used when initializing GitForCausalLM: ['git.image_encoder.embeddings.cls_token', 'git.image_encoder.embeddings.patch_embeddings.projection.bias', 'git.image_encoder.embeddings.patch_embeddings.projection.weight', 'git.image_encoder.embeddings.position_embeddings', 'git.image_encoder.encoder.layer.0.attention.attention.key.bias', 'git.image_encoder.encoder.layer.0.attention.attention.key.weight', 'git.image_encoder.encoder.layer.0.attention.attention.query.bias', 'git.image_encoder.encoder.layer.0.attention.attention.query.weight', 'git.image_encoder.encoder.layer.0.attention.attention.value.bias', 'git.image_encoder.encoder.layer.0.attention.attention.value.weight', 'git.image_encoder.encoder.layer.0.attention.output.dense.bias', 'git.image_encoder.encoder.layer.0.attention.output.dense.weight', 'git.image_encoder.enco

In [4]:
winoground_examples = load_winoground_examples(tokenizer, img_processor, pad_to_length=32, n_samples=10, local=True)

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


loaded huggingface DS
loaded local DS


100%|██████████| 10/10 [00:00<00:00, 22.24it/s]


In [4]:
blimp_examples = load_blimp_examples(tokenizer, pad_to_length=32, n_samples=10)
blimp_examples[0]

{'clean_prefix': tensor([[   0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,  310,  401,  114, 7434,   45,    5,    1]]),
 'clean_answer': 1370,
 'patch_prefix': tensor([[ 310,  401,  114, 7434,   45,    5,    1]]),
 'patch_answer': 404,
 'UID': 'anaphor_gender_agreement',
 'linguistics_term': 'anaphor_agreement',
 'prefix_length_wo_pad': 7}

In [4]:
# load and prepare data
vqa_examples = load_vqa_examples(tokenizer, img_processor, pad_to_length=32, n_samples=10, local=True)
vqa_examples[0]


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.
Repo card metadata block was not found. Setting CardData to empty.


loaded huggingface DS
loaded local DS


100%|██████████| 100/100 [00:00<00:00, 169.01it/s]


{'clean_prefix': tensor([[   0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,   27,   44, 4045,   23,  463,   17,    1]]),
 'clean_answer': 49,
 'distractors': [3895, 1224, 121, 1017, 303, 55, 175],
 'question_type': 'is this',
 'prefix_length_wo_pad': 7,
 'pixel_values': tensor([[[[-2.1179, -2.1179, -2.1179,  ..., -2.1179, -2.1179, -2.1179],
           [-2.1179, -2.1179, -2.1179,  ..., -2.1008, -2.1179, -2.1179],
           [-2.1179, -2.1008, -2.1179,  ..., -2.1008, -2.1179, -2.1179],
           ...,
           [-2.1008, -2.1179, -2.1008,  ..., -2.1179, -2.1179, -2.1179],
           [-2.1008, -2.1179, -2.1008,  ..., -2.1179, -2.1179, -2.1179],
           [-2.1179, -2.1179, -2.1179,  ..., -2.1179, -2.1179, -2.1179]],
 
          [[-2.0357, -2.0357, -2.0357,  ..., -2.0357, -2.0357, -2.0357],
           [-2.0357, -2.0357, -2.0357,  ..., -2.0182, -2.0357, -2.0357],
       

In [5]:

def compute_mean_activations(examples, model, submodules, batch_size, noimg=False, file_prefix=None):
    tracer_kwargs = {'validate' : False, 'scan' : False}
    device = "cuda"
    num_examples = len(examples)
    batches = [
        examples[i:min(i + batch_size,num_examples)] for i in range(0, num_examples, batch_size)
    ]

    def extract_hidden_states(submodule):
        total_samples = 0
        # Initialize storage for cumulative activations and count of samples
        cumulative_activations = 0

        for batch in tqdm(batches):
            clean_inputs = t.cat([e['clean_prefix'] for e in batch], dim=0).to(device)
        
            # clean run -> model can be approximated through linear function of its activations
            hidden_states_clean = {}
            #with autocast():
            if noimg:
                with model.trace(clean_inputs, **tracer_kwargs), t.no_grad():
                    x = submodule.output
                    hidden_states_clean = x.save()
            else:
                img_inputs = t.cat([e['pixel_values'] for e in batch], dim=0).to(device)
                with model.trace(clean_inputs, pixel_values=img_inputs, **tracer_kwargs), t.no_grad():
                    x = submodule.output
                    hidden_states_clean = x.save()
            hidden_states_clean = hidden_states_clean.value

            batch_size = clean_inputs.shape[0]  # Assuming shape [batch_size, ...]
            total_samples += batch_size

            # Sum across the batch (dim=0)
            cumulative_activations += hidden_states_clean.sum(dim=(0, 1)).detach().cpu()  # detach
            
            hidden_states_clean = None
            clean_inputs = None
            state = None
            x = None
            batch_size = None
            del hidden_states_clean, clean_inputs, state, x, batch_size
            if not noimg:
                img_inputs = None
                del img_inputs
            torch.cuda.empty_cache()
            gc.collect()

        # Compute mean activation by dividing the cumulative activations by the total number of samples
        mean_activations = cumulative_activations / total_samples

        cumulative_activations = None
        del cumulative_activations
        torch.cuda.empty_cache()
        gc.collect()
        
        return mean_activations

    mean_act_files = []
    for i, submodule in enumerate(submodules):
        submodule_acts = extract_hidden_states(submodule)
        filename = f"mean_activations/{file_prefix}_mean_acts_{i}_small.npy"
        np.save(filename, submodule_acts)

        submodule_acts = None
        del submodule_acts
        torch.cuda.empty_cache()
        gc.collect()

        mean_act_files.append(filename)

    return mean_act_files

In [6]:
# Attribution patching with integrated gradients
def _pe_ig(
        clean,
        img_inputs,
        model,
        submodules,
        mean_act_files,
        metric_fn,
        steps=10,
        metric_kwargs=dict()):
    tracer_kwargs = {'validate' : False, 'scan' : False}
    
    # clean run -> model can be approximated through linear function of its activations
    hidden_states_clean = {}
    if img_inputs is None:
        with model.trace(clean, **tracer_kwargs), t.no_grad():
            for submodule in submodules:
                x = submodule.output
                hidden_states_clean[submodule] = x.save()
            metric_clean = metric_fn(model, **metric_kwargs).save()
    else:
        with model.trace(clean, pixel_values=img_inputs, **tracer_kwargs), t.no_grad(): 
            for submodule in submodules:
                x = submodule.output
                hidden_states_clean[submodule] = x.save()
            metric_clean = metric_fn(model, **metric_kwargs).save()
    hidden_states_clean = {k : v.value for k, v in hidden_states_clean.items()}

    x = None
    del x
    torch.cuda.empty_cache()
    gc.collect()

    effects = {}
    deltas = {}
    grads = {}
    for i, submodule in enumerate(submodules):
        # load mean hidden states from file
        mean_state = np.load(mean_act_files[i], allow_pickle=True)
        mean_state = torch.tensor(mean_state).to("cuda")
        mean_state.requires_grad = True

        clean_state = hidden_states_clean[submodule]

        # computational graph without img: [batch_size, 32, n_dim]
        # computational graph img: [batch_size, 229, n_dim]
        with model.trace(**tracer_kwargs) as tracer:  # calling the trace() function without input determines the computational graph as one without images
            metrics = []
            fs = []
            for step in range(steps):
                alpha = step / steps
                f = (1 - alpha) * clean_state + alpha * mean_state
                f.retain_grad()
                fs.append(f)
                
                if img_inputs is None:
                    with tracer.invoke(clean, scan=tracer_kwargs['scan']):
                        submodule.output = f
                        metrics.append(metric_fn(model, **metric_kwargs))
                else:
                    with tracer.invoke(clean, pixel_values=img_inputs, scan=tracer_kwargs['scan']):
                        submodule.output = f
                        metrics.append(metric_fn(model, **metric_kwargs))
            metric = sum([m for m in metrics])
            metric.sum().backward(retain_graph=True) # TODO : why is this necessary? Probably shouldn't be, contact jaden
        
        mean_grad = sum([f.grad for f in fs]) / steps
        # mean_residual_grad = sum([f.grad for f in fs]) / steps
        grad = mean_grad
        delta = (mean_state - clean_state).detach() if mean_state is not None else -clean_state.detach()
        effect = t.mul(grad, delta)

        effects[submodule] = effect
        deltas[submodule] = delta
        grads[submodule] = grad

        mean_state = None
        del mean_state
        torch.cuda.empty_cache()
        gc.collect()

    
    return (effects, deltas, grads)

In [7]:
def get_important_neurons(examples, batch_size, mlps, mean_act_files, task):
    # uses attribution patching to identify most important neurons for subtask
    num_examples = len(examples)
    batches = [examples[i:min(i + batch_size, num_examples)] for i in range(0, num_examples, batch_size)]
    device = "cuda"

    sum_effects = {}

    for batch in tqdm(batches):
        clean_answer_idxs = t.tensor([e['clean_answer'] for e in batch], dtype=t.long, device=device)
        clean_inputs = t.cat([e['clean_prefix'] for e in batch], dim=0).to(device)

        if task == "vqa":
            img_inputs = t.cat([e['pixel_values'] for e in batch], dim=0).to(device)

            first_distractor_idxs = t.tensor([e['distractors'][0] for e in batch], dtype=t.long, device=device)

            def metric(model):
                # compute difference between correct answer and first distractor
                # TODO: compute avg difference between correct answer and each distractor
                #embds_out = model.output.output.save()
                
                return (
                    t.gather(model.output.output[:,-1,:], dim=-1, index=first_distractor_idxs.view(-1, 1)).squeeze(-1) - \
                    t.gather(model.output.output[:,-1,:], dim=-1, index=clean_answer_idxs.view(-1, 1)).squeeze(-1)
                )
        
        elif task == "blimp":
            img_inputs = None
            
            patch_answer_idxs = t.tensor([e['patch_answer'] for e in batch], dtype=t.long, device=device)

            def metric(model):
                
                return (
                    t.gather(model.output.output[:,-1,:], dim=-1, index=patch_answer_idxs.view(-1, 1)).squeeze(-1) - \
                    t.gather(model.output.output[:,-1,:], dim=-1, index=clean_answer_idxs.view(-1, 1)).squeeze(-1)
                )
        

        effects, _, _ = _pe_ig(
            clean_inputs,
            img_inputs,
            model,
            mlps,
            mean_act_files,
            metric,
            steps=10,
            metric_kwargs=dict())
        
        
        for submodule in mlps:
            if submodule not in sum_effects:
                sum_effects[submodule] = effects[submodule].sum(dim=1).sum(dim=0)
            else:
                sum_effects[submodule] += effects[submodule].sum(dim=1).sum(dim=0)

    # Print top 100 neurons in each submodule (ndim=3072)
    k = 100

    top_neurons = {}
    for idx, submodule in enumerate(mlps):
        sum_effects[submodule] /= num_examples
        v, i = t.topk(sum_effects[submodule].flatten(), k)  # v=top effects, i=top indices
        top_neurons[f"mlp_{idx}"] = (i,v)
    return top_neurons
        

In [8]:
batch_size = 2  #16
num_examples = 8  #-1
task = "vqa"
model_name = "git_1vd125_s1"
epoch = 23
local = True

def run_task(task):
    mlps = [submodules[submodule] for submodule in submodules if submodule.startswith("mlp")]
    mlps = mlps[:2]

    noimg = False

    # load and prepare data
    if task == "vqa":
        examples = load_vqa_examples(tokenizer, img_processor, pad_to_length=32, n_samples=num_examples, local=local)
        subtask_key = "question_type"
    elif task == "blimp":
        noimg = True
        examples = load_blimp_examples(tokenizer, pad_to_length=32, n_samples=num_examples)
        subtask_key = "linguistics_term"
    else:
        print(f"{task} is not implemented")
    print("loaded samples")

    # precompute mean activations on task or retrieve precomputed activation files
    prefix = f"{task}_{model_name}"
    mean_act_files = []
    for file in os.listdir("mean_activations/"):
        if file.startswith(prefix) and file.endswith("small.npy"):
            mean_act_files.append(f"mean_activations/{file}")
    if len(mean_act_files) == 0:
        mean_act_files = compute_mean_activations(examples, model, mlps, batch_size=128, noimg=noimg, file_prefix=prefix)

    print(f"computed mean activations")

    # identify subtasks
    subtasks = {}
    for e in examples:
        subtask = e[subtask_key]
        if subtask in subtasks:
            subtasks[subtask].append(e)
        else:
            subtasks[subtask] = [e]

    print("extracted subtasks")

    # for each subtask, compute top neurons and save
    subtasks_neurons = {}
    for subtask, examples in subtasks.items():
        top_neurons = get_important_neurons(examples, batch_size, mlps, mean_act_files, task=task)
        subtasks_neurons[subtask] = top_neurons
        print(f"finished subtask: {subtask}")

    with open(f"data/{model_name}_e{epoch}_{task}_top_neurons_per_subtask.pkl", "wb") as f:
        pickle.dump(subtasks_neurons, f)

In [9]:
run_task("vqa")

Repo card metadata block was not found. Setting CardData to empty.


loaded huggingface DS
loaded local DS


100%|██████████| 80/80 [00:00<00:00, 250.85it/s]


loaded samples
computed mean activations
extracted subtasks


  0%|          | 0/1 [00:00<?, ?it/s]We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.
  0%|          | 0/1 [00:01<?, ?it/s]


RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 32 but got size 289 for tensor number 1 in the list.