In [2]:
# from nnsight import LanguageModel
# model = LanguageModel('gpt2-small', device_map='cpu')

In [5]:
%load_ext autoreload
%autoreload 2

from functools import partial

import torch
from transformer_lens import HookedTransformer

from graph import Graph
from attribute import attribute
from dataset import HFEAPDataset
from metrics import get_metric
from evaluate_graph import evaluate_graph, evaluate_baseline, evaluate_area_under_curve

from nnsight import LanguageModel

from typing import Callable, List, Union, Optional, Literal
from functools import partial

import torch
from torch.utils.data import DataLoader
from torch import Tensor
from transformer_lens import HookedTransformer
from transformer_lens.utils import get_attention_mask
from tqdm import tqdm
from einops import einsum

from graph import Graph, InputNode, LogitNode, AttentionNode, MLPNode


def to_tokens(tokenizer, model, input_text, prepend_bos=True, padding_side='right', move_to_device=True, truncate=True, max_length=None):
    """
    Converts input text to token ids with a HuggingFace tokenizer, mirroring
    TransformerLens' ``to_tokens``.

    Args:
        tokenizer: HuggingFace tokenizer. Its ``padding_side`` is changed for the duration
            of the call and always restored afterwards, even if tokenization raises.
        model: HuggingFace model; ``model.config`` supplies the default truncation length
            and ``model.device`` (if present) the target device.
        input_text (Union[str, Sequence[str]]): A single string, or a batch (list/tuple)
            of strings. Batches are padded to a common length.
        prepend_bos (bool, optional): Whether to prepend the BOS token. If the tokenizer
            does not insert BOS on its own, it is prepended explicitly, so sequences start
            with BOS either way. Defaults to True.
        padding_side (str, optional): Side to pad on ('left' or 'right'). Defaults to 'right'.
        move_to_device (bool, optional): Whether to move the ids to ``model.device``. Defaults to True.
        truncate (bool, optional): Whether to truncate to ``max_length``. Defaults to True.
        max_length (int, optional): Maximum length (BOS included) to truncate to. If None
            and ``truncate`` is set, uses ``config.max_position_embeddings`` (or ``n_positions``).

    Returns:
        torch.Tensor: Tensor of token ids, as produced by the tokenizer with ``return_tensors='pt'``.
    """
    # Determine max_length for truncation
    if truncate and max_length is None:
        config = model.config
        if hasattr(config, 'max_position_embeddings'):
            max_length = config.max_position_embeddings
        else:
            max_length = config.n_positions  # for GPT-2

    bos_token = getattr(tokenizer, 'bos_token', None)
    bos_token_id = getattr(tokenizer, 'bos_token_id', None)
    if prepend_bos and bos_token is not None and bos_token_id is not None:
        # add_special_tokens=True does not insert BOS for every tokenizer (GPT-2's never
        # does). Probe with an empty string; if BOS does not come back, prepend its text
        # form ourselves, as TransformerLens does. Tokenizers that already add BOS are
        # left alone so it is not duplicated.
        probe_ids = tokenizer('', add_special_tokens=True)['input_ids']
        if not probe_ids or probe_ids[0] != bos_token_id:
            if isinstance(input_text, str):
                input_text = bos_token + input_text
            else:
                input_text = [bos_token + text for text in input_text]

    original_padding_side = tokenizer.padding_side
    tokenizer.padding_side = padding_side
    try:
        encoded = tokenizer(
            input_text,
            add_special_tokens=prepend_bos,
            # Pad any batch (list or tuple); a single string needs no padding.
            padding=not isinstance(input_text, str),
            truncation=truncate,
            max_length=max_length,
            return_tensors='pt',
        )
    finally:
        # Restore the caller's padding side even if tokenization fails.
        tokenizer.padding_side = original_padding_side

    input_ids = encoded['input_ids']
    if move_to_device and hasattr(model, 'device'):
        input_ids = input_ids.to(model.device)
    return input_ids

def tokenize_plus_nnsight(model: HookedTransformer, inputs: List[str], max_length: Optional[int] = None):
    """
    Tokenizes the input strings with ``model.tokenizer`` (right padding, BOS prepended).

    Args:
        model: The model whose ``tokenizer`` and ``config`` are used. Despite the
            ``HookedTransformer`` annotation, this notebook calls it with an nnsight
            ``LanguageModel``.
        inputs (List[str]): The list of input strings to be tokenized.
        max_length (Optional[int]): If given, sequences are truncated to this many tokens.
            If None, no truncation is applied.

    Returns:
        tuple: A tuple containing the following elements:
            - tokens (torch.Tensor): The tokenized inputs.
            - attention_mask (torch.Tensor): The attention mask for the tokenized inputs.
            - input_lengths (torch.Tensor): The number of attended tokens per input.
            - n_pos (int): The padded sequence length of the batch.
    """
    tokenizer = model.tokenizer
    # to_tokens takes the truncation length explicitly and never reads config.n_ctx, so
    # max_length is passed straight through (overwriting n_ctx would have no effect).
    tokens = to_tokens(
        tokenizer, model, inputs,
        prepend_bos=True,
        padding_side='right',
        truncate=(max_length is not None),
        max_length=max_length,
    )

    # get_attention_mask decides where padding lives from tokenizer.padding_side, but
    # to_tokens has already restored the caller's setting; make it match the
    # right-padded tokens while the mask is computed.
    original_padding_side = tokenizer.padding_side
    tokenizer.padding_side = 'right'
    try:
        attention_mask = get_attention_mask(tokenizer, tokens, True)
    finally:
        tokenizer.padding_side = original_padding_side

    input_lengths = attention_mask.sum(1)
    n_pos = attention_mask.size(1)
    return tokens, attention_mask, input_lengths, n_pos


def make_hooks_and_matrices(model: HookedTransformer, graph: Graph, batch_size:int , n_pos:int, scores: Optional[Tensor]):
    """Makes a matrix, and hooks to fill it and the score matrix up.

    Despite the ``HookedTransformer`` annotation, ``model`` must expose an HF-style
    ``config`` (``n_embd`` and, optionally, the ``use_normalization_before_and_after``
    flag) plus ``device`` and ``dtype`` attributes, which are used to allocate the
    activation matrix.

    Args:
        model: model to attribute
        graph (Graph): graph to attribute
        batch_size (int): size of the particular batch you're attributing
        n_pos (int): size of the position dimension
        scores (Tensor): The scores tensor you intend to fill. If you pass in None, we assume
            that you're using these hooks / matrices for evaluation only (so don't use the
            backwards hooks!)

    Returns:
        Tuple[Tuple[List, List, List], Tensor]: The final tensor ([batch, pos, n_src_nodes, d_model])
        stores activation differences, i.e. corrupted - clean activations. The first set of hooks
        will add in the activations they are run on (run these on corrupted input), while the
        second set will subtract out the activations they are run on (run these on clean input).
        The third set of hooks will compute the gradients and update the scores matrix that you
        passed in. When ``use_normalization_before_and_after`` is set and ``scores`` is None, the
        tensor instead has a leading axis of size 2 holding corrupted (index 0) and clean
        (index 1) activations separately.
    """
    d_model = model.config.n_embd
    # TransformerLens-style flag; a plain HF config does not have it, so treat it as False there.
    separate_activations = getattr(model.config, 'use_normalization_before_and_after', False) and scores is None
    if separate_activations:
        # [corrupted/clean, batch, pos, n_src_nodes, d_model]
        shape = (2, batch_size, n_pos, graph.n_forward, d_model)
    else:
        # [batch, pos, n_src_nodes, d_model]
        shape = (batch_size, n_pos, graph.n_forward, d_model)
    # Device and dtype come from the model for both layouts: an HF config has no
    # `device` or `d_model` fields to read them from.
    activation_difference = torch.zeros(shape, device=model.device, dtype=model.dtype)

    fwd_hooks_clean = []
    fwd_hooks_corrupted = []
    bwd_hooks = []

    # Fills up the activation difference matrix. In the default case (not separate_activations),
    # we add in the corrupted activations (add=True) and subtract out the clean ones (add=False).
    # In the separate_activations case, we just store them in two halves of the matrix. Less
    # efficient, but necessary for models with Gemma's architecture.
    def activation_hook(index, activations, hook, add:bool=True):
        acts = activations.detach()
        try:
            if separate_activations:
                # slot 0: corrupted (add=True), slot 1: clean (add=False)
                activation_difference[0 if add else 1, :, :, index] += acts
            elif add:
                activation_difference[:, :, index] += acts
            else:
                activation_difference[:, :, index] -= acts
        except RuntimeError:
            # Report the shape of the slice we actually tried to write to.
            if separate_activations:
                target = activation_difference[0, :, :, index]
            else:
                target = activation_difference[:, :, index]
            print(hook.name, target.size(), acts.size())
            raise

    def gradient_hook(prev_index: int, bwd_index: Union[slice, int], gradients:torch.Tensor, hook):
        """Takes in a gradient and uses it and activation_difference
        to compute an update to the score matrix.

        Args:
            prev_index (int): forward (src) nodes [0, prev_index) are scored against this dst node
            bwd_index (Union[slice, int]): The backward index of the (dst) node
            gradients (torch.Tensor): The gradients of this backward pass
            hook: used only for its ``name`` in error reports
        """
        grads = gradients.detach()
        s = None  # so the error report below never hits an unbound name
        try:
            if grads.ndim == 3:
                # A single dst node: add a singleton 'backward' axis.
                grads = grads.unsqueeze(2)
            s = einsum(activation_difference[:, :, :prev_index], grads,'batch pos forward hidden, batch pos backward hidden -> forward backward')
            s = s.squeeze(1)
            scores[:prev_index, bwd_index] += s
        except RuntimeError:
            print(hook.name, activation_difference.size(), activation_difference.device, grads.size(), grads.device)
            print(prev_index, bwd_index, scores.size(), None if s is None else s.size())
            raise

    def add_forward_hooks(node):
        """Registers the corrupted (add) and clean (subtract) hooks on node's output."""
        fwd_index = graph.forward_index(node)
        fwd_hooks_corrupted.append((node.out_hook, partial(activation_hook, fwd_index)))
        fwd_hooks_clean.append((node.out_hook, partial(activation_hook, fwd_index, add=False)))

    add_forward_hooks(graph.nodes['input'])

    for layer in range(graph.cfg['n_layers']):
        # Attention layer: output hooks, plus gradient hooks on each of its q/k/v inputs.
        node = graph.nodes[f'a{layer}.h0']
        add_forward_hooks(node)
        prev_index = graph.prev_index(node)
        for i, letter in enumerate('qkv'):
            bwd_index = graph.backward_index(node, qkv=letter)
            bwd_hooks.append((node.qkv_inputs[i], partial(gradient_hook, prev_index, bwd_index)))

        # MLP: output hooks, plus a gradient hook on its input.
        node = graph.nodes[f'm{layer}']
        add_forward_hooks(node)
        bwd_hooks.append((node.in_hook, partial(gradient_hook, graph.prev_index(node), graph.backward_index(node))))

    node = graph.nodes['logits']
    bwd_hooks.append((node.in_hook, partial(gradient_hook, graph.prev_index(node), graph.backward_index(node))))

    return (fwd_hooks_corrupted, fwd_hooks_clean, bwd_hooks), activation_difference


# def get_scores_eap_nnsight(model: HookedTransformer, graph: Graph, dataloader:DataLoader, metric: Callable[[Tensor], Tensor], intervention: Literal['patching', 'zero', 'mean','mean-positional']='patching', intervention_dataloader: Optional[DataLoader]=None, quiet=False):
#     """Gets edge attribution scores using EAP.

#     Args:
#         model (HookedTransformer): The model to attribute
#         graph (Graph): Graph to attribute
#         dataloader (DataLoader): The data over which to attribute
#         metric (Callable[[Tensor], Tensor]): metric to attribute with respect to
#         quiet (bool, optional): suppress tqdm output. Defaults to False.

#     Returns:
#         Tensor: a [src_nodes, dst_nodes] tensor of scores for each edge
#     """
#     # scores = torch.zeros((graph.n_forward, graph.n_backward), device='cuda', dtype=model.config.dtype)
#     scores = torch.zeros((graph.n_forward, graph.n_backward), device='cpu', dtype=model.config.dtype)

#     if 'mean' in intervention:
#         assert intervention_dataloader is not None, "Intervention dataloader must be provided for mean interventions"
    #     per_position = 'positional' in intervention
    #     means = compute_mean_activations(model, graph, intervention_dataloader, per_position=per_position)
    #     means = means.unsqueeze(0)
    #     if not per_position:
    #         means = means.unsqueeze(0)
    
    # total_items = 0
    # dataloader = dataloader if quiet else tqdm(dataloader)
    # for clean, corrupted, label in dataloader:
    #     batch_size = len(clean)
    #     total_items += batch_size
    #     clean_tokens, attention_mask, input_lengths, n_pos = tokenize_plus_nnsight(model, clean)
    #     corrupted_tokens, _, _, _ = tokenize_plus_nnsight(model, corrupted)

    #     (fwd_hooks_corrupted, fwd_hooks_clean, bwd_hooks), activation_difference = make_hooks_and_matrices(model, graph, batch_size, n_pos, scores)

    #     with torch.inference_mode():
    #         if intervention == 'patching':
    #             # We intervene by subtracting out clean and adding in corrupted activations
    #             with model.hooks(fwd_hooks_corrupted):
    #                 _ = model(corrupted_tokens, attention_mask=attention_mask)
                    
                    
    #         elif 'mean' in intervention:
    #             # In the case of zero or mean ablation, we skip the adding in corrupted activations
    #             # but in mean ablations, we need to add the mean in
    #             activation_difference += means

    #         # For some metrics (e.g. accuracy or KL), we need the clean logits
    #         clean_logits = model(clean_tokens, attention_mask=attention_mask)

    #     with model.hooks(fwd_hooks=fwd_hooks_clean, bwd_hooks=bwd_hooks):
    #         logits = model(clean_tokens, attention_mask=attention_mask)
    #         metric_value = metric(logits, clean_logits, input_lengths, label)
    #         metric_value.backward()

    # scores /= total_items

    # return scores
    

allowed_aggregations = {'sum', 'mean'}  # 'l2' is currently disabled


def attribute_nnsight(model: LanguageModel, graph: Graph, dataloader: DataLoader, metric: Callable[[Tensor], Tensor], method: Literal['EAP', 'EAP-IG-inputs', 'clean-corrupted', 'EAP-IG-activations'], intervention: Literal['patching', 'zero', 'mean','mean-positional']='patching', aggregation='sum', ig_steps: Optional[int]=None, intervention_dataloader: Optional[DataLoader]=None, quiet=False):
    """Computes edge attribution scores for ``graph`` and writes them into ``graph.scores``.

    Args:
        model (LanguageModel): nnsight model; its HF ``config`` must carry the
            TransformerLens-style flags checked below.
        graph (Graph): graph whose ``scores`` are overwritten in place.
        dataloader (DataLoader): yields (clean, corrupted, label) batches.
        metric: metric to attribute with respect to; the EAP scorer calls it as
            ``metric(logits, clean_logits, input_lengths, label)``.
        method: which attribution method to use.
        intervention: how corrupted activations are formed ('patching', 'zero', 'mean',
            'mean-positional'). 'EAP-IG-inputs' and 'clean-corrupted' only support 'patching'.
        aggregation: 'sum' keeps scores summed over the hidden dimension; 'mean' divides
            them by the model's hidden size.
        ig_steps: number of integration steps for the IG methods.
        intervention_dataloader: data used to compute means for mean interventions.
        quiet: suppress progress bars.

    Raises:
        AssertionError: if a required config flag is missing or False.
        ValueError: for an unknown aggregation or method, or an unsupported intervention.
    """
    config = model.config
    # These are TransformerLens-style flags; an HF config only has them if they were set
    # on it explicitly, so read them with getattr to fail with the assert message (or, for
    # n_key_value_heads, to skip the GQA check) instead of raising AttributeError.
    assert getattr(config, 'use_attn_result', False), "Model must be configured to use attention result (model.config.use_attn_result)"
    assert getattr(config, 'use_split_qkv_input', False), "Model must be configured to use split qkv inputs (model.config.use_split_qkv_input)"
    assert getattr(config, 'use_hook_mlp_in', False), "Model must be configured to use hook MLP in (model.config.use_hook_mlp_in)"
    if getattr(config, 'n_key_value_heads', None) is not None:
        assert getattr(config, 'ungroup_grouped_query_attention', False), "Model must be configured to ungroup grouped attention (model.config.ungroup_grouped_query_attention = True)"

    if aggregation not in allowed_aggregations:
        raise ValueError(f'aggregation must be in {allowed_aggregations}, but got {aggregation}')

    # Scores are by default summed across the hidden dimension,
    # so scores are a [n_src_nodes, n_dst_nodes] tensor.
    # NOTE(review): get_scores_eap_ig, get_scores_clean_corrupted and get_scores_ig_activations
    # are not defined or imported in this part of the notebook -- confirm they exist.
    if method == 'EAP':
        scores = get_scores_eap_nnsight(model, graph, dataloader, metric, intervention=intervention, intervention_dataloader=intervention_dataloader, quiet=quiet)
    elif method == 'EAP-IG-inputs':
        if intervention != 'patching':
            raise ValueError(f"intervention must be 'patching' for EAP-IG-inputs, but got {intervention}")
        scores = get_scores_eap_ig(model, graph, dataloader, metric, steps=ig_steps, quiet=quiet)
    elif method == 'clean-corrupted':
        if intervention != 'patching':
            raise ValueError(f"intervention must be 'patching' for clean-corrupted, but got {intervention}")
        scores = get_scores_clean_corrupted(model, graph, dataloader, metric, quiet=quiet)
    elif method == 'EAP-IG-activations':
        scores = get_scores_ig_activations(model, graph, dataloader, metric, steps=ig_steps, intervention=intervention, intervention_dataloader=intervention_dataloader, quiet=quiet)
    else:
        raise ValueError(f"method must be in ['EAP', 'EAP-IG-inputs', 'clean-corrupted', 'EAP-IG-activations'], but got {method}")

    if aggregation == 'mean':
        # HF configs expose the residual width as `hidden_size` (an alias of n_embd on GPT-2);
        # they have no `d_model` attribute.
        scores /= config.hidden_size

    graph.scores[:] = scores.to(graph.scores.device)
    

In [None]:
def get_scores_eap_qkv(model: LanguageModel, graph: Graph, dataloader: DataLoader, metric: Callable[[Tensor], Tensor], intervention: Literal['patching', 'zero', 'mean','mean-positional']='patching', intervention_dataloader: Optional[DataLoader]=None, quiet=False):
    """Work-in-progress nnsight port of EAP that patches per-head q/k/v inputs.

    Accumulates a [graph.n_forward, graph.n_backward] score tensor and divides it by the
    number of examples seen.

    NOTE(review): this function cannot run as written. It references names that are not
    defined in this notebook section (`compute_mean_activations`, `split_heads_nns`,
    `in_graph_matrix`) and calls `.backward()` inside `torch.inference_mode()`.
    See the FIXME comments below.

    Args:
        model (LanguageModel): nnsight-wrapped GPT-2-style model (accesses
            `transformer.h[i]` submodules and `lm_head`).
        graph (Graph): graph to attribute.
        dataloader (DataLoader): yields (clean, corrupted, label) batches.
        metric: called as `metric(logits, clean_logits, input_lengths, label)`.
        intervention: 'patching', 'zero', 'mean' or 'mean-positional'.
        intervention_dataloader: data for the mean interventions.
        quiet: suppress the tqdm progress bar.

    Returns:
        Tensor: the accumulated `scores`, divided by the number of examples.
    """
    # Scores accumulate on the CPU, in the model's dtype.
    scores = torch.zeros((graph.n_forward, graph.n_backward), device='cpu', dtype=model.dtype)

    # Mean interventions precompute mean activations (optionally per position).
    # NOTE(review): compute_mean_activations is not defined or imported in this notebook section.
    if 'mean' in intervention:
        assert intervention_dataloader is not None
        per_position = 'positional' in intervention
        means = compute_mean_activations(model, graph, intervention_dataloader, per_position=per_position)
        means = means.unsqueeze(0)
        if not per_position:
            means = means.unsqueeze(0)
    
    total_items = 0
    dataloader = dataloader if quiet else tqdm(dataloader)
    for clean, corrupted, label in dataloader:
        batch_size = len(clean)
        total_items += batch_size
        clean_tokens, attention_mask, input_lengths, n_pos = tokenize_plus_nnsight(model, clean)
        corrupted_tokens, _, _, _ = tokenize_plus_nnsight(model, corrupted)

        # FIXME(review): this inference_mode block also encloses the patched clean run below,
        # including metric_value.backward(); tensors created under inference mode cannot take
        # part in autograd. The commented-out TransformerLens version ran backward outside it.
        with torch.inference_mode():
            if intervention == 'patching':
                # NOTE(review): activation_difference is only created on this branch. With a
                # 'mean' intervention the `+= means` below raises NameError, and 'zero' never
                # defines it either.
                # Initialize activation differences tensor
                activation_difference = torch.zeros(
                    (batch_size, n_pos, graph.n_forward, model.config.hidden_size), 
                    device=model.device, 
                    dtype=model.dtype
                )

                # Get corrupted activations
                # NOTE(review): corrupted activations are only added; clean activations are never
                # subtracted, so this holds corrupted activations rather than the
                # corrupted - clean difference described in make_hooks_and_matrices.
                with model.trace({"input_ids": corrupted_tokens, "attention_mask": attention_mask}):
                    for layer in range(graph.cfg['n_layers']):
                        #### Attention #########
                        node = graph.nodes[f'a{layer}.h0']
                        fwd_index = graph.forward_index(node)
                        attn_hs = model.transformer.h[layer].attn.c_proj.input

                        # NOTE(review): split_heads_nns is not defined in this notebook section.
                        by_head = split_heads_nns(attn_hs, model.config.n_head, model.config.hidden_size)
                        by_head = model.transformer.h[layer].attn.c_proj(by_head)
                        by_head = model.transformer.h[layer].attn.resid_dropout(by_head)
                        
                        activation_difference[:, :, fwd_index] += by_head

                        #### MLP #########
                        node = graph.nodes[f'm{layer}']
                        fwd_index = graph.forward_index(node)
                        activation_difference[:, :, fwd_index][:] += model.transformer.h[layer].mlp.output[:]

            elif 'mean' in intervention:
                activation_difference += means

            # Get clean logits for reference
            clean_logits = model.trace({'input_ids': clean_tokens, 'attention_mask':attention_mask}, trace=False)['logits']

            # Run with interventions
            with model.trace({"input_ids": clean_tokens, "attention_mask": attention_mask}):
                for layer in range(graph.cfg['n_layers']):
                    # Only rebuild this layer's q/k/v when at least one of its heads is in the graph.
                    if any(graph.nodes[f'a{layer}.h{head}'].in_graph for head in range(model.config.n_head)):
                        # Handle QKV inputs
                        qkv_inp = model.transformer.h[layer].ln_1.input
                        c_proj_out = []

                        for ii, letter in enumerate('qkv'):
                            node = graph.nodes[f'a{layer}.h0']
                            prev_index = graph.prev_index(node)
                            bwd_index = graph.backward_index(node, qkv=letter, attn_slice=True)

                            # FIXME(review): in_graph_matrix is never defined in this notebook section.
                            update = einsum(activation_difference[:, :, :len(in_graph_matrix[:prev_index, bwd_index])], 
                                         in_graph_matrix[:prev_index, bwd_index],
                                         'batch pos previous hidden, previous ... -> batch pos ... hidden')

                            # FIXME(review): hard-coded float16 cast while this notebook runs GPT-2 in
                            # float32; presumably this should follow qkv_inp's dtype.
                            update = update.to(torch.float16)

                            qkv_out = []
                            for head in range(model.config.n_head):
                                # Each head gets its own patched copy of the residual input.
                                qkv_in_clone = qkv_inp.clone()
                                update_head = update[:, :, head, :]
                                qkv_in_clone += update_head

                                qkv_in_clone = model.transformer.h[layer].ln_1(qkv_in_clone)
                                attn_out = model.transformer.h[layer].attn.c_attn(qkv_in_clone)

                                # FIXME(review): `ii * head * head_dim` is 0 for every head when ii == 0 (q),
                                # and 0 for head 0 of q, k and v alike. If c_attn's output is laid out as
                                # [q | k | v], each hidden_size wide, the offset should be
                                # ii * hidden_size + head * head_dim -- confirm.
                                start = ii * head * (model.config.hidden_size // model.config.n_head)
                                end = start + (model.config.hidden_size // model.config.n_head)
                                attn_head_out = attn_out[:, :, start:end]
                                
                                qkv_out.append(attn_head_out)
                            
                            qkv_out = torch.cat(qkv_out, dim=-1)
                            c_proj_out.append(qkv_out)

                        # Replace c_attn's output with the per-head patched q/k/v
                        # (`c_proj_out` is a misnomer: it holds c_attn outputs).
                        c_proj_out = torch.cat(c_proj_out, dim=-1)
                        model.transformer.h[layer].attn.c_attn.output = c_proj_out

                    # MLP handling
                    if graph.nodes[f'm{layer}'].in_graph:
                        node = graph.nodes[f'm{layer}']
                        prev_index = graph.prev_index(node)
                        bwd_index = graph.backward_index(node)

                        update = einsum(activation_difference[:, :, :len(in_graph_matrix[:prev_index, bwd_index])], 
                                     in_graph_matrix[:prev_index, bwd_index],
                                     'batch pos previous hidden, previous ... -> batch pos ... hidden')
                        
                        model.transformer.h[layer].ln_2.input += update

                logits = model.lm_head.output.save()

                metric_value = metric(logits, clean_logits, input_lengths, label)
                metric_value.backward()

                # Collect gradients and update scores
                # FIXME(review): in this einsum `hidden` and `token` share no index, so each is summed
                # out independently, giving a [n_forward] result. Adding that to the 2-D scores only
                # broadcasts if n_forward == n_backward. Presumably per-edge scores need gradients at
                # each destination node's inputs, as in make_hooks_and_matrices.
                grads = model.lm_head.output.grad
                s = einsum(activation_difference, grads,
                          'batch pos forward hidden, batch pos token -> forward')
                scores += s

    scores /= total_items
    return scores




def get_scores_eap_nnsight(model: LanguageModel, graph: Graph, dataloader: DataLoader, metric: Callable[[Tensor], Tensor], intervention: Literal['patching', 'zero', 'mean','mean-positional']='patching', intervention_dataloader: Optional[DataLoader]=None, quiet=False):
    """Gets edge attribution scores using EAP (work-in-progress nnsight port).

    Args:
        model (LanguageModel): The model to attribute (accesses GPT-2-style
            `transformer.h[i]` submodules)
        graph (Graph): Graph to attribute
        dataloader (DataLoader): yields (clean, corrupted, label) batches
        metric (Callable[[Tensor], Tensor]): metric to attribute with respect to; called as
            `metric(logits, clean_logits, input_lengths, label)`
        intervention (str): Type of intervention ('patching', 'zero', 'mean', 'mean-positional')
        intervention_dataloader (Optional[DataLoader]): Dataset for mean computation
        quiet (bool): Whether to suppress progress bar

    Returns:
        Tensor: a [src_nodes, dst_nodes] tensor of scores for each edge, divided by the
        number of examples

    NOTE(review): not runnable as written -- see the FIXME comments. In particular, the
    backward pass happens inside torch.inference_mode(); `compute_mean_activations` and
    `split_heads_nns` are not defined in this notebook section; and non-'patching'
    interventions never create activation_difference.
    """
    # Scores accumulate on the CPU, in the model's dtype.
    scores = torch.zeros((graph.n_forward, graph.n_backward), device='cpu', dtype=model.dtype)

    # NOTE(review): compute_mean_activations is not defined or imported in this notebook section.
    if 'mean' in intervention:
        assert intervention_dataloader is not None, "Intervention dataloader must be provided for mean interventions"
        per_position = 'positional' in intervention
        means = compute_mean_activations(model, graph, intervention_dataloader, per_position=per_position)
        means = means.unsqueeze(0)
        if not per_position:
            means = means.unsqueeze(0)
    
    total_items = 0
    dataloader = dataloader if quiet else tqdm(dataloader)
    for clean, corrupted, label in dataloader:
        batch_size = len(clean)
        total_items += batch_size
        clean_tokens, attention_mask, input_lengths, n_pos = tokenize_plus_nnsight(model, clean)
        corrupted_tokens, _, _, _ = tokenize_plus_nnsight(model, corrupted)

        # FIXME(review): this inference_mode block also encloses metric_value.backward() below;
        # tensors created under inference mode cannot take part in autograd. The commented-out
        # TransformerLens version ran the backward pass outside inference mode.
        with torch.inference_mode():
            if intervention == 'patching':
                # NOTE(review): activation_difference is only created on this branch; 'mean'
                # raises NameError at `+= means` below, and 'zero' never defines it.
                # Initialize activation differences tensor
                activation_difference = torch.zeros(
                    (batch_size, n_pos, graph.n_forward, model.config.n_embd), 
                    device=model.device, 
                    dtype=model.dtype
                )

                # Capture corrupted activations
                # NOTE(review): only corrupted activations are added; clean ones are never
                # subtracted, so this is not the corrupted - clean difference EAP uses.
                with model.trace({"input_ids": corrupted_tokens, "attention_mask": attention_mask}):
                    for layer in range(graph.cfg['n_layers']):
                        # Attention handling
                        node = graph.nodes[f'a{layer}.h0']
                        fwd_index = graph.forward_index(node)
                        attn_hs = model.transformer.h[layer].attn.c_proj.input

                        # NOTE(review): split_heads_nns is not defined in this notebook section.
                        by_head = split_heads_nns(attn_hs, model.config.n_head, model.config.n_embd)
                        by_head = model.transformer.h[layer].attn.c_proj(by_head)
                        by_head = model.transformer.h[layer].attn.resid_dropout(by_head)
                        
                        activation_difference[:, :, fwd_index] += by_head

                        # MLP handling
                        node = graph.nodes[f'm{layer}']
                        fwd_index = graph.forward_index(node)
                        activation_difference[:, :, fwd_index][:] += model.transformer.h[layer].mlp.output[:]

            elif 'mean' in intervention:
                activation_difference += means

            # Get clean logits for reference
            # NOTE(review): calls model(...) inside a trace over the same inputs, and clean_logits
            # is not .save()d -- confirm it is usable after the trace exits.
            with model.trace({"input_ids": clean_tokens, "attention_mask": attention_mask}):
                clean_logits = model(clean_tokens, attention_mask=attention_mask)

            # Run with activation differences injected
            with model.trace({"input_ids": clean_tokens, "attention_mask": attention_mask}) as trace:
                for layer in range(graph.cfg['n_layers']):
                    # Inject attention differences
                    # NOTE(review): for HF GPT-2 attention modules `.output` is presumably a tuple
                    # whose first element is the hidden states -- confirm whether `[0]` is needed.
                    # Also, this adds activation_difference on top of the clean output rather than
                    # patching it.
                    node = graph.nodes[f'a{layer}.h0']
                    fwd_index = graph.forward_index(node)
                    model.transformer.h[layer].attn.output[:] += activation_difference[:, :, fwd_index]

                    # Inject MLP differences
                    node = graph.nodes[f'm{layer}']
                    fwd_index = graph.forward_index(node)
                    model.transformer.h[layer].mlp.output[:] += activation_difference[:, :, fwd_index]

                logits = model(clean_tokens, attention_mask=attention_mask)
                metric_value = metric(logits, clean_logits, input_lengths, label)
                metric_value.backward()

                # Collect gradients and update scores
                # NOTE(review): gradients are read at module outputs; make_hooks_and_matrices
                # instead takes them at each destination node's inputs (q/k/v inputs, MLP input,
                # logits input).
                for layer in range(graph.cfg['n_layers']):
                    attn_grads = model.transformer.h[layer].attn.output.grad
                    mlp_grads = model.transformer.h[layer].mlp.output.grad

                    if attn_grads.ndim == 3:
                        attn_grads = attn_grads.unsqueeze(2)
                    if mlp_grads.ndim == 3:
                        mlp_grads = mlp_grads.unsqueeze(2)

                    # Update attention scores
                    # FIXME(review): slices with the node's own forward index instead of
                    # graph.prev_index(node), and writes across all of scores' columns instead of
                    # scores[:prev_index, bwd_index]. The squeezed [forward] result only broadcasts
                    # against the trailing n_backward axis if the sizes happen to match. If
                    # forward_index returns a slice for attention nodes, `:fwd_index` is not a
                    # valid index at all.
                    node = graph.nodes[f'a{layer}.h0']
                    fwd_index = graph.forward_index(node)
                    s_attn = einsum(activation_difference[:, :, :fwd_index], attn_grads,
                                  'batch pos forward hidden, batch pos backward hidden -> forward backward')
                    scores[:fwd_index] += s_attn.squeeze(1)

                    # Update MLP scores (same indexing FIXME as above)
                    node = graph.nodes[f'm{layer}']
                    fwd_index = graph.forward_index(node)
                    s_mlp = einsum(activation_difference[:, :, :fwd_index], mlp_grads,
                                 'batch pos forward hidden, batch pos backward hidden -> forward backward')
                    scores[:fwd_index] += s_mlp.squeeze(1)

    scores /= total_items
    return scores




In [6]:
# Experiment settings.
NUM_EXAMPLES = 100  # IOI examples to load
BATCH_SIZE = 20

# Load GPT-2 through nnsight and set the TransformerLens-style flags that
# attribute_nnsight asserts on directly on its HF config.
model = LanguageModel('gpt2', device_map='cpu')
model.config.use_split_qkv_input = True
model.config.use_attn_result = True
model.config.use_hook_mlp_in = True
model.config.ungroup_grouped_query_attention = True

dataset = HFEAPDataset("danaarad/ioi_dataset", model.tokenizer, task="ioi", num_examples=NUM_EXAMPLES)
dataloader = dataset.to_dataloader(BATCH_SIZE)
metric_fn = get_metric("logit_diff", "ioi", model.tokenizer, model)

In [7]:
# Inspect the HF config; the TransformerLens-style flags set on it above appear at the end of this repr.
model.config

GPT2Config {
  "_attn_implementation_autoset": true,
  "_name_or_path": "gpt2",
  "activation_function": "gelu_new",
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_ctx": 1024,
  "n_embd": 768,
  "n_head": 12,
  "n_inner": null,
  "n_layer": 12,
  "n_positions": 1024,
  "reorder_and_upcast_attn": false,
  "resid_pdrop": 0.1,
  "scale_attn_by_inverse_layer_idx": false,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "task_specific_params": {
    "text-generation": {
      "do_sample": true,
      "max_length": 50
    }
  },
  "transformers_version": "4.47.0",
  "ungroup_grouped_query_attention": true,
  "use_attn_result": true,
  "use_cache": true,
  "use_hook_mlp_in": tru

In [8]:
# More TransformerLens-style attributes read by the attribution code: no grouped-query
# attention, float32 activations, and no extra normalization around sublayers.
model.config.update({
    'n_key_value_heads': None,
    'dtype': torch.float32,
    'use_normalization_before_and_after': False,
})

In [9]:
# Show the module tree. The attribution code reaches into these submodules by name
# (transformer.h[i].ln_1 / attn.c_attn / attn.c_proj / ln_2 / mlp, and lm_head).
model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2SdpaAttention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
  (generator): WrapperModule()
)

In [10]:
# Build the project Graph for this model; attribute_nnsight writes its edge scores into g.scores.
g = Graph.from_model(model)

In [11]:
# Run EAP over the IOI dataloader; scores are written into g.scores.
# NOTE(review): the saved output below ("no attribute 'hooks'") looks stale. In this notebook,
# `model.hooks` only occurs in the commented-out TransformerLens-style get_scores_eap_nnsight,
# and the cell defining the nnsight version is marked "In [None]" (never executed).
# Restart the kernel and run all cells in order to get a real result.
attribute_nnsight(model, g, dataloader, partial(metric_fn, loss=True, mean=True), 'EAP')


 0%|                                                     | 0/5 [00:00<?, ?it/s]

AttributeError: 'GPT2LMHeadModel' object has no attribute 'hooks'

In [13]:
from transformers import GPT2Model
import torch

# Load a plain HF GPT-2 under its own name, so the nnsight `model` used by the cells
# above is not overwritten.
hf_gpt2 = GPT2Model.from_pretrained('gpt2')

# Parameters typically share one dtype, so the first parameter is representative.
first_name, first_param = next(iter(hf_gpt2.named_parameters()))
print(f"{first_name}: {first_param.dtype}")
print(f"Model dtype: {first_param.dtype}")

wte.weight: torch.float32
Model dtype: torch.float32
