In [1]:
from IPython import get_ipython
from IPython.display import clear_output, display

ipython = get_ipython()
ipython.magic("load_ext autoreload")
ipython.magic("autoreload 2")

  ipython.magic("load_ext autoreload")
  ipython.magic("autoreload 2")


In [5]:
import os
from typing import List, Optional, Union, Dict, Tuple
from pathlib import Path 

import torch
from torch import Tensor
import numpy as np
import einops
from fancy_einsum import einsum
import circuitsvis as cv

import transformer_lens.utils as tl_utils

from transformer_lens import HookedTransformer
import transformer_lens.patching as patching

from transformers import AutoModelForCausalLM

from torch import Tensor
from jaxtyping import Float
import plotly.express as px

from functools import partial

from torchtyping import TensorType as TT

from path_patching_cm.path_patching import Node, IterNode, path_patch, act_patch
from path_patching_cm.ioi_dataset import IOIDataset, NAMES
from neel_plotly import imshow as imshow_n

from utils.visualization import imshow_p, plot_attention_heads, plot_attention
from utils.data_utils import generate_data_and_caches, UniversalPatchingDataset
from utils.metrics import compute_logit_diff, compute_probability_diff, compute_probability_mass, compute_rank_0_rate
from utils.visualization_utils import (
    plot_attention_heads,
    scatter_attention_and_contribution,
    get_attn_head_patterns
)

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f554e6db9a0>

In [6]:
def compute_mean_reciprocal_rank(
        logits: torch.Tensor, 
        answer_token_indices: torch.Tensor,
        positions: torch.Tensor = None,
        flags_tensor: torch.Tensor = None,
        mode="simple"
) -> torch.Tensor:
    """
    Computes the Mean Reciprocal Rank (MRR) for each item in the batch.

    Args:
        logits (torch.Tensor): Logits to use.
        answer_token_indices (torch.Tensor): Indices of the correct answer tokens.
        positions (torch.Tensor): Positions to get logits at, one position per batch item.
        flags_tensor (torch.Tensor): Flags indicating the grouping of tokens (used in "groups" mode).
        mode (str): Mode of operation - "simple", "pairs", or "groups".

    Returns:
        torch.Tensor: The Mean Reciprocal Rank for the batch.
    """
    logits = get_positional_logits(logits, positions)
    probabilities = torch.softmax(logits, dim=-1)
    mrr = torch.zeros(logits.size(0), device=logits.device)

    # Mode 1: Simple
    if mode == "simple":
        correct_indices = answer_token_indices[:, 0]
        for i in range(logits.size(0)):
            sorted_indices = probabilities[i].sort(descending=True)[1]
            rank = (sorted_indices == correct_indices[i]).nonzero(as_tuple=True)[0].item() + 1
            mrr[i] = 1.0 / rank

    # Mode 2: Pairs
    elif mode == "pairs":
        for i in range(logits.size(0)):
            for pair in answer_token_indices[i]:
                sorted_indices = probabilities[i].sort(descending=True)[1]
                rank = (sorted_indices == pair[0]).nonzero(as_tuple=True)[0].item() + 1
                mrr[i] += 1.0 / rank
            mrr[i] /= answer_token_indices.size(1)

    # Mode 3: Groups
    elif mode == "groups":
        assert flags_tensor is not None
        for i in range(logits.size(0)):
            selected_probs = probabilities[i, answer_token_indices[i]]
            sorted_indices = selected_probs.sort(descending=True)[1]
            correct_ranks = (flags_tensor[i] == 1).nonzero(as_tuple=True)[0]
            ranks = torch.tensor([sorted_indices.tolist().index(rank.item()) + 1 for rank in correct_ranks])
            mrr[i] = (1.0 / ranks).mean()

    else:
        raise ValueError("Invalid mode specified")

    return mrr.mean()

In [7]:
# 

# model = HookedTransformer.from_pretrained(
#     "EleutherAI/pythia-2.8b",
#     center_unembed=True,
#     center_writing_weights=True,
#     fold_ln=True,
#     refactor_factored_attn_matrices=False,
# )
# model.set_use_hook_mlp_in(True)

In [8]:
model = HookedTransformer.from_pretrained(
    "EleutherAI/pythia-2.8b",
    checkpoint_value=10000,
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    dtype=torch.bfloat16,
    refactor_factored_attn_matrices=False,
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model EleutherAI/pythia-2.8b into HookedTransformer


In [9]:
def get_positional_logits(
        logits: Float[Tensor, "batch seq d_vocab"],
        positions: Float[Tensor, "batch"] = None
)-> Float[Tensor, "batch d_vocab"]:
    """Gets the logits at the provided positions. If no positions are provided, the final logits are returned.

    Args:
        logits (torch.Tensor): Logits to use.
        positions (torch.Tensor): Positions to get logits at. This should be a tensor of shape (batch_size,).

    Returns:
        torch.Tensor: Logits at the provided positions.
    """
    if positions is None:
        return logits[:, -1, :]
    
    return logits[range(logits.size(0)), positions, :]


def compute_logit_diff(
        logits: Float[Tensor, "batch seq d_vocab"], 
        answer_token_indices: Float[Tensor, "batch num_answers"],
        positions: Float[Tensor, "batch"] = None,
        flags_tensor: torch.Tensor = None,
        per_prompt=False,
        mode="simple"
)-> Float[Tensor, "batch num_answers"]:
    """Computes the difference between a correct and incorrect logit (or mean of a group of logits) for each item in the batch.

    Takes the full logits, and the indices of the tokens to compare. These indices can be of multiple types as follows:

    - Simple: The tensor should be of shape (batch_size, 2), where the first index in the third dimension is the correct token index,
        and the second index is the incorrect token index.

    - Pairs: In this mode, answer_token_indices is a 3D tensor of shape (batch, num_pairs, 2). For each pair, you'll need to compute 
             the difference between the logits at the two indices, then average these differences across each pair for every batch item.

    - Groups: Here, answer_token_indices is also a 3D tensor of shape (batch, num_tokens, 2). The third dimension indicates group membership 
              (correct or incorrect). The mean logits for each group are calculated and then subtracted from each other.
              

    Args:
        logits (torch.Tensor): Logits to use.
        answer_token_indices (torch.Tensor): Indices of the tokens to compare.
        positions (torch.Tensor): Positions to get logits at. Should be one position per batch item.

    Returns:
        torch.Tensor: Difference between the logits of the provided tokens.
    """
    logits = get_positional_logits(logits, positions)
    
    # Mode 1: Simple
    if mode == "simple":
        correct_logits = logits[torch.arange(logits.size(0)), answer_token_indices[:, 0]]
        incorrect_logits = logits[torch.arange(logits.size(0)), answer_token_indices[:, 1]]
        logit_diff = correct_logits - incorrect_logits

    # Mode 2: Pairs
    elif mode == "pairs":
        pair_diffs = logits[torch.arange(logits.size(0))[:, None], answer_token_indices[..., 0]] - \
                     logits[torch.arange(logits.size(0))[:, None], answer_token_indices[..., 1]]
        logit_diff = pair_diffs.mean(dim=1)

    # Mode 3: Groups
    elif mode == "groups":
        assert flags_tensor is not None
        logit_diff = torch.zeros(logits.size(0), device=logits.device)

        for i in range(logits.size(0)):
            selected_logits = logits[i, answer_token_indices[i]]

            # Calculate the logit difference using the correct/incorrect flags
            correct_logits = selected_logits[flags_tensor[i] == 1]
            incorrect_logits = selected_logits[flags_tensor[i] == -1]

            # Handle cases where there are no correct or incorrect logits
            if len(correct_logits) > 0:
                correct_mean = correct_logits.mean()
            else:
                correct_mean = 0

            if len(incorrect_logits) > 0:
                incorrect_mean = incorrect_logits.mean()
            else:
                incorrect_mean = 0

            logit_diff[i] = correct_mean - incorrect_mean

    else:
        raise ValueError("Invalid mode specified")

    return logit_diff.mean() if not per_prompt else logit_diff



import torch
import torch.nn.functional as F

def compute_probability_diff(
        logits: torch.Tensor, 
        answer_token_indices: torch.Tensor,
        positions: torch.Tensor = None,
        flags_tensor: torch.Tensor = None,
        per_prompt=False,
        mode="simple"
) -> torch.Tensor:
    """Computes the difference between probability of a correct and incorrect logit (or mean of a group of logits) for each item in the batch.

    Takes the full logits, and the indices of the tokens to compare. These indices can be of multiple types as follows:

    - Simple: The tensor should be of shape (batch_size, 2), where the first index in the third dimension is the correct token index,
        and the second index is the incorrect token index.

    - Pairs: In this mode, answer_token_indices is a 3D tensor of shape (batch, num_pairs, 2). For each pair, you'll need to compute 
             the difference between the probabilities at the two indices, then average these differences across each pair for every batch item.

    - Groups: Here, answer_token_indices is also a 3D tensor of shape (batch, num_tokens, 2). The third dimension indicates group membership 
              (correct or incorrect). The mean probabilities for each group are calculated and then subtracted from each other.
              

    Args:
        logits (torch.Tensor): Logits to use.
        answer_token_indices (torch.Tensor): Indices of the tokens to compare.
        positions (torch.Tensor): Positions to get logits at. Should be one position per batch item.

    Returns:
        torch.Tensor: Difference between the logits of the provided tokens.
    """
    logits = get_positional_logits(logits, positions)
    probabilities = torch.softmax(logits, dim=-1)  # Applying softmax to logits
    print(f"probabilities={probabilities.shape}")

    # Mode 1: Simple
    if mode == "simple":
        correct_probs = probabilities[torch.arange(logits.size(0)), answer_token_indices[:, 0]]
        incorrect_probs = probabilities[torch.arange(logits.size(0)), answer_token_indices[:, 1]]
        prob_diff = correct_probs - incorrect_probs

    # Mode 2: Pairs
    elif mode == "pairs":
        pair_diffs = probabilities[torch.arange(logits.size(0))[:, None], answer_token_indices[..., 0]] - \
                     probabilities[torch.arange(logits.size(0))[:, None], answer_token_indices[..., 1]]
        prob_diff = pair_diffs.mean(dim=1)

    # Mode 3: Groups
    elif mode == "groups":
        # Initialize tensors to store the probability differences for each batch item
        assert flags_tensor is not None
        prob_diff = torch.zeros(logits.size(0), device=logits.device)

        for i in range(logits.size(0)):
            # Select the probabilities for the token IDs of this batch item
            selected_probs = probabilities[i, answer_token_indices[i]]

            # Calculate the probability difference using the correct/incorrect flags
            correct_probs = selected_probs[flags_tensor[i] == 1]
            incorrect_probs = selected_probs[flags_tensor[i] == -1]

            # Handle cases where there are no correct or incorrect tokens
            if len(correct_probs) > 0:
                correct_mean = correct_probs.mean()
            else:
                correct_mean = 0

            if len(incorrect_probs) > 0:
                incorrect_mean = incorrect_probs.mean()
            else:
                incorrect_mean = 0

            prob_diff[i] = correct_mean - incorrect_mean

    # Mode 4: Group Sum
    elif mode == "group_sum":
        assert flags_tensor is not None
        prob_diff = torch.zeros(logits.size(0), device=logits.device)

        for i in range(logits.size(0)):
            selected_probs = probabilities[i, answer_token_indices[i]]

            # Calculate the sum of probabilities using the correct/incorrect flags
            correct_sum = selected_probs[flags_tensor[i] == 1].sum()
            incorrect_sum = selected_probs[flags_tensor[i] == -1].sum()

            prob_diff[i] = incorrect_sum - correct_sum

    else:
        raise ValueError("Invalid mode specified")

    return prob_diff.mean() if not per_prompt else prob_diff


def compute_probability_mass(
        logits: torch.Tensor, 
        answer_token_indices: torch.Tensor,
        positions: torch.Tensor = None,
        flags_tensor: torch.Tensor = None,
        group="correct",
        mode="simple"
) -> torch.Tensor:
    logits = get_positional_logits(logits, positions)
    probabilities = torch.softmax(logits, dim=-1)

    # Determine the flag value based on the specified group
    flag_value = 1 if group == "correct" else -1

    # Mode logic
    if mode == "simple":
        selected_indices = answer_token_indices[:, 0] if group == "correct" else answer_token_indices[:, 1]
        group_probs = probabilities[torch.arange(logits.size(0)), selected_indices]

    elif mode == "pairs":
        group_probs = torch.zeros(logits.size(0), device=logits.device)
        for i in range(logits.size(0)):
            for pair in answer_token_indices[i]:
                selected_index = pair[0] if group == "correct" else pair[1]
                group_probs[i] += probabilities[i, selected_index]
            group_probs[i] /= answer_token_indices.size(1)

    elif mode == "groups":
        assert flags_tensor is not None
        group_probs = torch.zeros(logits.size(0), device=logits.device)

        for i in range(logits.size(0)):
            selected_probs = probabilities[i, answer_token_indices[i]]
            group_probs[i] = selected_probs[flags_tensor[i] == flag_value].mean()

    elif mode == "group_sum":
        assert flags_tensor is not None
        group_probs = torch.zeros(logits.size(0), device=logits.device)

        for i in range(logits.size(0)):
            selected_probs = probabilities[i, answer_token_indices[i]]
            group_probs[i] = selected_probs[flags_tensor[i] == flag_value].sum()

    else:
        raise ValueError("Invalid mode specified")

    return group_probs.mean()



def compute_rank_0_rate(
        logits: torch.Tensor, 
        answer_token_indices: torch.Tensor,
        positions: torch.Tensor = None,
        flags_tensor: torch.Tensor = None,
        group="correct",
        mode="simple"
) -> torch.Tensor:
    logits = get_positional_logits(logits, positions)
    probabilities = torch.softmax(logits, dim=-1)

    # Mode logic
    if mode == "simple":
        top_rank_indices = probabilities.argmax(dim=-1)
        correct_indices = answer_token_indices[:, 0] if group == "correct" else answer_token_indices[:, 1]
        rank_0_rate = (top_rank_indices == correct_indices).float().mean()

    elif mode == "pairs":
        rank_0_rate = torch.zeros(logits.size(0), device=logits.device)
        for i in range(logits.size(0)):
            for pair in answer_token_indices[i]:
                top_rank_index = probabilities[i].argmax()
                correct_index = pair[0] if group == "correct" else pair[1]
                rank_0_rate[i] += (top_rank_index == correct_index).float()
            rank_0_rate[i] /= answer_token_indices.size(1)

    elif mode == "groups":
        assert flags_tensor is not None
        rank_0_rate = torch.zeros(logits.size(0), device=logits.device)

        for i in range(logits.size(0)):
            selected_probs = probabilities[i, answer_token_indices[i]]
            top_rank_id = selected_probs.argmax()
            rank_0_rate[i] = (flags_tensor[i, top_rank_id] == 1).float() if group == "correct" else \
                             (flags_tensor[i, top_rank_id] == -1).float()

    else:
        raise ValueError("Invalid mode specified")

    return rank_0_rate.mean()


import torch

def compute_max_group_rank_reciprocal(
        logits: torch.Tensor, 
        answer_token_indices: torch.Tensor,
        positions: torch.Tensor = None,
        flags_tensor: torch.Tensor = None,
        mode="simple"
) -> torch.Tensor:
    """
    Computes the mean of the reciprocal of the maximum rank of members of the correct group across different modes.

    Args:
        logits (torch.Tensor): Logits to use.
        answer_token_indices (torch.Tensor): Indices of the tokens for comparison or grouping.
        positions (torch.Tensor): Positions to get logits at, one position per batch item.
        flags_tensor (torch.Tensor): Flags indicating the grouping of tokens (used in "groups" mode).
        mode (str): Operation mode - "simple", "pairs", or "groups".

    Returns:
        torch.Tensor: The mean of the reciprocal of the maximum rank of correct group members.
    """
    logits = get_positional_logits(logits, positions)
    probabilities = torch.softmax(logits, dim=-1)
    batch_size = logits.size(0)

    # Initialize tensor to hold the reciprocal of the maximum rank for each item in the batch
    reciprocal_max_rank = torch.zeros(batch_size, device=logits.device)

    if mode == "simple":
        for i in range(batch_size):
            correct_index = answer_token_indices[i, 0]
            sorted_indices = probabilities[i].sort(descending=True)[1]
            rank = (sorted_indices == correct_index).nonzero(as_tuple=True)[0].item() + 1
            reciprocal_max_rank[i] = 1.0 / rank

    elif mode == "pairs":
        for i in range(batch_size):
            pair_ranks = []
            for pair in answer_token_indices[i]:
                # Only consider the first index in each pair as correct
                correct_index = pair[0]
                sorted_indices = probabilities[i].sort(descending=True)[1]
                rank = (sorted_indices == correct_index).nonzero(as_tuple=True)[0].item() + 1
                pair_ranks.append(rank)
            # Use the max rank from pairs
            max_rank = min(pair_ranks)
            reciprocal_max_rank[i] = 1.0 / max_rank

    elif mode == "groups":
        for i in range(batch_size):
            group_ranks = []
            for j, flag in enumerate(flags_tensor[i]):
                if flag == 1:  # Correct group member
                    correct_index = answer_token_indices[i, j]
                    sorted_indices = probabilities[i].sort(descending=True)[1]
                    rank = (sorted_indices == correct_index).nonzero(as_tuple=True)[0].item() + 1
                    group_ranks.append(rank)
            # Use the max rank from correct group members
            if group_ranks:
                max_rank = min(group_ranks)
                reciprocal_max_rank[i] = 1.0 / max_rank
            else:
                reciprocal_max_rank[i] = 0  # Handle case with no correct answers

    else:
        raise ValueError("Invalid mode specified")

    return reciprocal_max_rank.mean()


In [10]:
def test_compute_max_group_rank_reciprocal():
    # Define a helper function to simplify the test cases creation
    def create_logits_and_indices(logits_values, correct_indices, flags=None):
        logits = torch.tensor(logits_values, dtype=torch.float).unsqueeze(0).unsqueeze(1)  # Add batch dimension if needed
        answer_token_indices = torch.tensor(correct_indices, dtype=torch.long).unsqueeze(1)  # Adjust dimensions as needed
        flags_tensor = torch.tensor(flags, dtype=torch.long).unsqueeze(1) if flags is not None else None
        return logits, answer_token_indices, flags_tensor

    # Simple Mode Test Case
    logits, answer_token_indices, _ = create_logits_and_indices([0.1, 0.2, 0.7, 0.6], [2])
    mrr_simple = compute_max_group_rank_reciprocal(logits, answer_token_indices, mode="simple")
    print(f"Simple mode MRR: {mrr_simple}")

    # Pairs Mode Test Case
    logits, answer_token_indices, _ = create_logits_and_indices([[0.1, 0.2], [0.7, 0.6], [0.4, 0.5]], [[[2, 1], [0, 3]]])
    mrr_pairs = compute_max_group_rank_reciprocal(logits, answer_token_indices, mode="pairs")
    print(f"Pairs mode MRR: {mrr_pairs}")

    # Groups Mode Test Case
    logits, answer_token_indices, flags_tensor = create_logits_and_indices([0.1, 0.2, 0.7, 0.6], [0, 1, 2, 3], [1, -1, 1, -1])
    mrr_groups = compute_max_group_rank_reciprocal(logits, answer_token_indices, flags_tensor=flags_tensor, mode="groups")
    print(f"Groups mode MRR: {mrr_groups}")

# Execute the test function
test_compute_max_group_rank_reciprocal()

Simple mode MRR: 1.0
Pairs mode MRR: 0.5
Groups mode MRR: 0.25


## IOI

### Old

In [111]:
def _logits_to_mean_logit_diff(logits: Float[Tensor, "batch seq d_vocab"], ioi_dataset: IOIDataset, per_prompt=False):
    '''
    Returns logit difference between the correct and incorrect answer.

    If per_prompt=True, return the array of differences rather than the average.
    '''

    # Only the final logits are relevant for the answer
    # Get the logits corresponding to the indirect object / subject tokens respectively
    io_logits: Float[Tensor, "batch"] = logits[range(logits.size(0)), ioi_dataset.word_idx["end"], ioi_dataset.io_tokenIDs]
    print(io_logits.shape)
    s_logits: Float[Tensor, "batch"] = logits[range(logits.size(0)), ioi_dataset.word_idx["end"], ioi_dataset.s_tokenIDs]
    # Find logit difference
    answer_logit_diff = io_logits - s_logits
    return answer_logit_diff if per_prompt else answer_logit_diff.mean()

In [112]:
N = 70
ioi_dataset, abc_dataset, _, _, _ = generate_data_and_caches(model, N, verbose=True)
clean_toks = ioi_dataset.toks.to(device)
corrupted_toks = abc_dataset.toks.to(device)

Average logit diff (IOI dataset): 2.1560
Average logit diff (ABC dataset): -2.3650


In [113]:
logits = model(clean_toks)
_logits_to_mean_logit_diff(logits, ioi_dataset, per_prompt=True)

torch.Size([70])


tensor([ 0.1210,  0.6408,  4.3679,  3.2872,  3.1052,  2.4821,  4.4465,  0.3192,
         2.3799, -0.6708,  2.6662,  3.7824,  0.8591,  0.4677,  0.2200,  3.3067,
         5.4179,  0.1047,  2.9370,  0.1339,  2.9071,  3.4575,  2.6687,  5.2737,
         1.8654,  0.7553,  0.5263,  4.9929,  1.6752,  0.7784,  2.5545,  1.4186,
         3.3981, -0.6675,  1.2263, -1.0020,  0.8154,  5.7192,  3.0224,  0.3226,
         6.5352,  4.1738,  1.3393,  2.3986,  4.4888,  2.1893,  5.5038,  0.3437,
        -0.7437,  4.7308,  0.9047,  2.9857,  2.0731,  3.2680,  4.3207,  3.0555,
         2.4662,  3.7404,  0.7805,  2.3680,  2.7400,  1.5318,  0.1822,  0.9666,
         0.9466,  0.9064, -0.6956,  0.6441,  2.0869,  2.6091], device='cuda:0')

### New

In [6]:
ds = UniversalPatchingDataset.from_ioi(model, 100)

Average logit diff (IOI dataset): 4.1921
Average logit diff (ABC dataset): -3.9570


In [7]:
logits = model(ds.toks)

In [8]:
compute_logit_diff(logits, ds.answer_toks, ds.positions, per_prompt=False)

tensor(4.1921, device='cuda:0')

In [9]:
compute_probability_diff(logits, ds.answer_toks, ds.positions, per_prompt=False)

probabilities=torch.Size([100, 50304])


tensor(0.2722, device='cuda:0')

In [10]:
compute_mean_reciprocal_rank(logits, ds.answer_toks, ds.positions, mode="simple")

tensor(0.8254, device='cuda:0')

In [11]:
compute_max_group_rank_reciprocal(logits, ds.answer_toks, ds.positions, mode="simple")

tensor(0.8254, device='cuda:0')

## Greater-Than

### Old

In [13]:
from data.greater_than_dataset import get_prob_diff, YearDataset, get_valid_years

In [39]:
ds_old = YearDataset(get_valid_years(model.tokenizer, 1100, 1800), 1000, Path("data/potential_nouns.txt"), model.tokenizer)

# def batch(iterable, n:int=1):
#    current_batch = []
#    for item in iterable:
#        current_batch.append(item)
#        if len(current_batch) == n:
#            yield current_batch
#            current_batch = []
#    if current_batch:
#        yield current_batch

# clean = list(batch(ds.good_sentences, 9))
# labels = list(batch(ds.years_YY, 9))
# corrupted = list(batch(ds.bad_sentences, 9))

In [40]:
IDX = 768
#model.to_str_tokens(ds.good_toks[IDX]), model.to_str_tokens(ds.bad_toks[IDX])

In [41]:
import torch

def prepare_indices_for_prob_diff(tokenizer, years):
    """
    Prepares two tensors for use with the compute_probability_diff function in 'groups' mode.

    Args:
        tokenizer (PreTrainedTokenizer): Tokenizer to convert years to token indices.
        years (torch.Tensor): Tensor containing the year for each prompt in the batch.

    Returns:
        torch.Tensor, torch.Tensor: Two tensors, one for token IDs and one for correct/incorrect flags.
    """

    # Get the indices for years 00 to 99
    year_indices = get_year_indices(tokenizer)  # Tensor of size 100 with token IDs for years

    # Prepare tensors to store token IDs and correct/incorrect flags
    token_ids_tensor = year_indices.repeat(years.size(0), 1)  # Repeat the year_indices for each batch item
    flags_tensor = torch.zeros_like(token_ids_tensor)  # Initialize the flags tensor with zeros

    for i, year in enumerate(years):
        # Mark years greater than the given year as correct (1)
        flags_tensor[i, year + 1:] = 1
        # Mark years less than or equal to the given year as incorrect (-1)
        flags_tensor[i, :year + 1] = -1

    return token_ids_tensor, flags_tensor



In [42]:
#input_length = 1 + len(model.tokenizer(ds.good_sentences[0])[0])
prob_diff = get_prob_diff(model.tokenizer)

In [43]:
from utils.circuit_utils import run_with_batches

clean_logits = run_with_batches(model, ds_old.good_toks.to(device), batch_size=20, max_seq_len=12)
corrupted_logits = run_with_batches(model, ds_old.bad_toks.to(device), batch_size=20, max_seq_len=12)

In [44]:
prob_diff(clean_logits,ds_old.years_YY)

tensor(-0.8363, device='cuda:0')

### New

In [12]:
ds = UniversalPatchingDataset.from_greater_than(model, 1000)

In [13]:
logits = model(ds.toks)

In [14]:
compute_probability_diff(logits, ds.answer_toks, flags_tensor=ds.group_flags, mode="group_sum")

probabilities=torch.Size([1000, 50304])


tensor(-0.8294, device='cuda:0')

In [15]:
compute_mean_reciprocal_rank(logits, ds.answer_toks, ds.positions, ds.group_flags, mode="groups")

tensor(0.1409, device='cuda:0')

In [18]:
compute_max_group_rank_reciprocal(logits, ds.answer_toks, ds.positions, ds.group_flags, mode="groups")

tensor(0.9980, device='cuda:0')

## Sentiment

In [19]:
from data.sentiment_datasets import get_dataset, PromptType, get_prompts
from utils.circuit_analysis import get_logit_diff as get_logit_diff_ca

### Classification

#### Old

In [46]:
ds_type = PromptType.CLASSIFICATION_4

In [47]:
ds = get_dataset(model, device, prompt_type=ds_type)
ds.all_prompts[0]

Reading prompts from config and filtering


'I thought this movie was amazing, I loved it. The acting was awesome, the plot was beautiful, and overall the movie was just very good. Review Sentiment:'

In [48]:
ds.clean_tokens.shape

torch.Size([22, 35])

In [49]:
ds.answer_tokens.shape

torch.Size([22, 1, 2])

In [50]:
clean_logits = model(ds.clean_tokens.to(device))
corrupted_logits = model(ds.corrupted_tokens.to(device))

In [51]:
ds.answer_tokens.shape

torch.Size([22, 1, 2])

In [52]:
from utils.metrics import CircuitMetric
logit_diff_metric = CircuitMetric("logit_diff_multi", partial(get_logit_diff_ca, answer_tokens=ds.answer_tokens))

In [53]:
logit_diff_metric(clean_logits), logit_diff_metric(corrupted_logits)

(tensor(0.3842, device='cuda:0'), tensor(-0.3842, device='cuda:0'))

#### New

In [48]:
ds = UniversalPatchingDataset.from_sentiment(model, "class")

Reading prompts from config and filtering


In [49]:
logits = model(ds.toks)
flipped_logits = model(ds.flipped_toks)

In [50]:
compute_logit_diff(logits, ds.answer_toks, mode="pairs"), compute_logit_diff(flipped_logits, ds.answer_toks, mode="pairs")

(tensor(3.1217, device='cuda:0'), tensor(-3.1217, device='cuda:0'))

In [51]:
compute_mean_reciprocal_rank(logits, ds.answer_toks, ds.positions, mode="pairs")

tensor(0.4267, device='cuda:0')

In [52]:
compute_max_group_rank_reciprocal(logits, ds.answer_toks, ds.positions, mode="pairs")

tensor(0.4267, device='cuda:0')

In [53]:
from transformer_lens.utils import test_prompt
for prompt_tokens in ds.toks:
    prompt = model.to_string(prompt_tokens[1:])
    test_prompt(prompt, " Positive", model, top_k=5)

Tokenized prompt: ['<|endoftext|>', 'I', ' thought', ' this', ' movie', ' was', ' amazing', ',', ' I', ' loved', ' it', '.', ' The', ' acting', ' was', ' awesome', ',', ' the', ' plot', ' was', ' beautiful', ',', ' and', ' overall', ' the', ' movie', ' was', ' just', ' very', ' good', '.', ' Review', ' Sent', 'iment', ':']
Tokenized answer: [' Positive']


Top 0th token. Logit: 14.18 Prob:  9.65% Token: | Positive|
Top 1th token. Logit: 14.07 Prob:  8.68% Token: | 4|
Top 2th token. Logit: 13.85 Prob:  6.99% Token: | 9|
Top 3th token. Logit: 13.75 Prob:  6.29% Token: | 8|
Top 4th token. Logit: 13.68 Prob:  5.84% Token: | 5|


Tokenized prompt: ['<|endoftext|>', 'I', ' thought', ' this', ' movie', ' was', ' awful', ',', ' I', ' hated', ' it', '.', ' The', ' acting', ' was', ' bad', ',', ' the', ' plot', ' was', ' disappointing', ',', ' and', ' overall', ' the', ' movie', ' was', ' just', ' very', ' applaud', '.', ' Review', ' Sent', 'iment', ':']
Tokenized answer: [' Positive']


Top 0th token. Logit: 14.13 Prob: 13.44% Token: | 1|
Top 1th token. Logit: 13.89 Prob: 10.63% Token: | 0|
Top 2th token. Logit: 13.31 Prob:  5.96% Token: | 3|
Top 3th token. Logit: 13.10 Prob:  4.82% Token: | 4|
Top 4th token. Logit: 13.08 Prob:  4.70% Token: | 2|


Tokenized prompt: ['<|endoftext|>', 'I', ' thought', ' this', ' movie', ' was', ' awesome', ',', ' I', ' loved', ' it', '.', ' The', ' acting', ' was', ' beautiful', ',', ' the', ' plot', ' was', ' brilliant', ',', ' and', ' overall', ' the', ' movie', ' was', ' just', ' very', ' good', '.', ' Review', ' Sent', 'iment', ':']
Tokenized answer: [' Positive']


Top 0th token. Logit: 14.20 Prob:  9.69% Token: | Positive|
Top 1th token. Logit: 14.14 Prob:  9.14% Token: | 4|
Top 2th token. Logit: 13.88 Prob:  7.06% Token: | 8|
Top 3th token. Logit: 13.87 Prob:  6.93% Token: | 9|
Top 4th token. Logit: 13.69 Prob:  5.83% Token: | 5|


Tokenized prompt: ['<|endoftext|>', 'I', ' thought', ' this', ' movie', ' was', ' bad', ',', ' I', ' hated', ' it', '.', ' The', ' acting', ' was', ' disappointing', ',', ' the', ' plot', ' was', ' disgusting', ',', ' and', ' overall', ' the', ' movie', ' was', ' just', ' very', ' appreciate', '.', ' Review', ' Sent', 'iment', ':']
Tokenized answer: [' Positive']


Top 0th token. Logit: 13.66 Prob:  8.46% Token: | 1|
Top 1th token. Logit: 13.62 Prob:  8.07% Token: | 0|
Top 2th token. Logit: 13.15 Prob:  5.07% Token: | 5|
Top 3th token. Logit: 13.14 Prob:  4.99% Token: | 4|
Top 4th token. Logit: 13.09 Prob:  4.75% Token: | 3|


Tokenized prompt: ['<|endoftext|>', 'I', ' thought', ' this', ' movie', ' was', ' beautiful', ',', ' I', ' loved', ' it', '.', ' The', ' acting', ' was', ' brilliant', ',', ' the', ' plot', ' was', ' exceptional', ',', ' and', ' overall', ' the', ' movie', ' was', ' just', ' very', ' good', '.', ' Review', ' Sent', 'iment', ':']
Tokenized answer: [' Positive']


Top 0th token. Logit: 14.22 Prob:  9.78% Token: | 4|
Top 1th token. Logit: 14.19 Prob:  9.50% Token: | Positive|
Top 2th token. Logit: 13.92 Prob:  7.22% Token: | 9|
Top 3th token. Logit: 13.86 Prob:  6.80% Token: | 5|
Top 4th token. Logit: 13.80 Prob:  6.40% Token: | 8|


Tokenized prompt: ['<|endoftext|>', 'I', ' thought', ' this', ' movie', ' was', ' disappointing', ',', ' I', ' hated', ' it', '.', ' The', ' acting', ' was', ' disgusting', ',', ' the', ' plot', ' was', ' dreadful', ',', ' and', ' overall', ' the', ' movie', ' was', ' just', ' very', ' commend', '.', ' Review', ' Sent', 'iment', ':']
Tokenized answer: [' Positive']


Top 0th token. Logit: 14.64 Prob: 16.21% Token: | 1|
Top 1th token. Logit: 14.35 Prob: 12.09% Token: | 0|
Top 2th token. Logit: 13.76 Prob:  6.74% Token: | 3|
Top 3th token. Logit: 13.58 Prob:  5.64% Token: | 2|
Top 4th token. Logit: 13.53 Prob:  5.34% Token: | 4|


Tokenized prompt: ['<|endoftext|>', 'I', ' thought', ' this', ' movie', ' was', ' brilliant', ',', ' I', ' loved', ' it', '.', ' The', ' acting', ' was', ' exceptional', ',', ' the', ' plot', ' was', ' extraordinary', ',', ' and', ' overall', ' the', ' movie', ' was', ' just', ' very', ' good', '.', ' Review', ' Sent', 'iment', ':']
Tokenized answer: [' Positive']


Top 0th token. Logit: 14.19 Prob:  9.65% Token: | 4|
Top 1th token. Logit: 13.96 Prob:  7.68% Token: | Positive|
Top 2th token. Logit: 13.93 Prob:  7.45% Token: | 8|
Top 3th token. Logit: 13.91 Prob:  7.34% Token: | 9|
Top 4th token. Logit: 13.84 Prob:  6.81% Token: | 5|


Tokenized prompt: ['<|endoftext|>', 'I', ' thought', ' this', ' movie', ' was', ' disgusting', ',', ' I', ' hated', ' it', '.', ' The', ' acting', ' was', ' dreadful', ',', ' the', ' plot', ' was', ' horrible', ',', ' and', ' overall', ' the', ' movie', ' was', ' just', ' very', ' embrace', '.', ' Review', ' Sent', 'iment', ':']
Tokenized answer: [' Positive']


Top 0th token. Logit: 13.94 Prob: 11.96% Token: | 0|
Top 1th token. Logit: 13.75 Prob:  9.86% Token: | 1|
Top 2th token. Logit: 12.82 Prob:  3.91% Token: | 3|
Top 3th token. Logit: 12.74 Prob:  3.61% Token: | -|
Top 4th token. Logit: 12.64 Prob:  3.27% Token: |
|


Tokenized prompt: ['<|endoftext|>', 'I', ' thought', ' this', ' movie', ' was', ' exceptional', ',', ' I', ' loved', ' it', '.', ' The', ' acting', ' was', ' extraordinary', ',', ' the', ' plot', ' was', ' fabulous', ',', ' and', ' overall', ' the', ' movie', ' was', ' just', ' very', ' good', '.', ' Review', ' Sent', 'iment', ':']
Tokenized answer: [' Positive']


Top 0th token. Logit: 14.08 Prob:  8.93% Token: | 4|
Top 1th token. Logit: 13.91 Prob:  7.53% Token: | Positive|
Top 2th token. Logit: 13.87 Prob:  7.22% Token: | 9|
Top 3th token. Logit: 13.80 Prob:  6.73% Token: | 5|
Top 4th token. Logit: 13.78 Prob:  6.57% Token: | 8|


Tokenized prompt: ['<|endoftext|>', 'I', ' thought', ' this', ' movie', ' was', ' dreadful', ',', ' I', ' hated', ' it', '.', ' The', ' acting', ' was', ' horrible', ',', ' the', ' plot', ' was', ' miserable', ',', ' and', ' overall', ' the', ' movie', ' was', ' just', ' very', ' endorse', '.', ' Review', ' Sent', 'iment', ':']
Tokenized answer: [' Positive']


Top 0th token. Logit: 14.13 Prob: 14.34% Token: | 0|
Top 1th token. Logit: 14.03 Prob: 13.07% Token: | 1|
Top 2th token. Logit: 13.31 Prob:  6.31% Token: | -|
Top 3th token. Logit: 12.90 Prob:  4.22% Token: | 3|
Top 4th token. Logit: 12.70 Prob:  3.43% Token: | 2|


Tokenized prompt: ['<|endoftext|>', 'I', ' thought', ' this', ' movie', ' was', ' extraordinary', ',', ' I', ' loved', ' it', '.', ' The', ' acting', ' was', ' fabulous', ',', ' the', ' plot', ' was', ' fantastic', ',', ' and', ' overall', ' the', ' movie', ' was', ' just', ' very', ' good', '.', ' Review', ' Sent', 'iment', ':']
Tokenized answer: [' Positive']


Top 0th token. Logit: 14.17 Prob:  9.91% Token: | 4|
Top 1th token. Logit: 14.04 Prob:  8.65% Token: | Positive|
Top 2th token. Logit: 13.79 Prob:  6.76% Token: | 5|
Top 3th token. Logit: 13.77 Prob:  6.58% Token: | 8|
Top 4th token. Logit: 13.76 Prob:  6.57% Token: | 9|


Tokenized prompt: ['<|endoftext|>', 'I', ' thought', ' this', ' movie', ' was', ' horrible', ',', ' I', ' hated', ' it', '.', ' The', ' acting', ' was', ' miserable', ',', ' the', ' plot', ' was', ' offensive', ',', ' and', ' overall', ' the', ' movie', ' was', ' just', ' very', ' enjoy', '.', ' Review', ' Sent', 'iment', ':']
Tokenized answer: [' Positive']


Top 0th token. Logit: 14.34 Prob: 14.58% Token: | 0|
Top 1th token. Logit: 14.19 Prob: 12.53% Token: | 1|
Top 2th token. Logit: 13.23 Prob:  4.79% Token: |
|
Top 3th token. Logit: 13.10 Prob:  4.21% Token: | 3|
Top 4th token. Logit: 13.02 Prob:  3.89% Token: | 2|


Tokenized prompt: ['<|endoftext|>', 'I', ' thought', ' this', ' movie', ' was', ' fabulous', ',', ' I', ' loved', ' it', '.', ' The', ' acting', ' was', ' fantastic', ',', ' the', ' plot', ' was', ' good', ',', ' and', ' overall', ' the', ' movie', ' was', ' just', ' very', ' good', '.', ' Review', ' Sent', 'iment', ':']
Tokenized answer: [' Positive']


Top 0th token. Logit: 14.46 Prob: 12.11% Token: | Positive|
Top 1th token. Logit: 14.29 Prob: 10.18% Token: | 4|
Top 2th token. Logit: 13.81 Prob:  6.31% Token: | 8|
Top 3th token. Logit: 13.73 Prob:  5.85% Token: | 5|
Top 4th token. Logit: 13.56 Prob:  4.92% Token: | 9|


Tokenized prompt: ['<|endoftext|>', 'I', ' thought', ' this', ' movie', ' was', ' miserable', ',', ' I', ' hated', ' it', '.', ' The', ' acting', ' was', ' offensive', ',', ' the', ' plot', ' was', ' terrible', ',', ' and', ' overall', ' the', ' movie', ' was', ' just', ' very', ' favor', '.', ' Review', ' Sent', 'iment', ':']
Tokenized answer: [' Positive']


Top 0th token. Logit: 14.07 Prob: 12.34% Token: | 0|
Top 1th token. Logit: 14.07 Prob: 12.34% Token: | 1|
Top 2th token. Logit: 13.51 Prob:  7.03% Token: | -|
Top 3th token. Logit: 13.25 Prob:  5.41% Token: | 3|
Top 4th token. Logit: 13.06 Prob:  4.48% Token: | 2|


Tokenized prompt: ['<|endoftext|>', 'I', ' thought', ' this', ' movie', ' was', ' fantastic', ',', ' I', ' loved', ' it', '.', ' The', ' acting', ' was', ' good', ',', ' the', ' plot', ' was', ' great', ',', ' and', ' overall', ' the', ' movie', ' was', ' just', ' very', ' good', '.', ' Review', ' Sent', 'iment', ':']
Tokenized answer: [' Positive']


Top 0th token. Logit: 14.65 Prob: 13.99% Token: | Positive|
Top 1th token. Logit: 14.20 Prob:  8.91% Token: | 4|
Top 2th token. Logit: 13.94 Prob:  6.86% Token: | 8|
Top 3th token. Logit: 13.81 Prob:  6.01% Token: | 9|
Top 4th token. Logit: 13.63 Prob:  5.03% Token: | 5|


Tokenized prompt: ['<|endoftext|>', 'I', ' thought', ' this', ' movie', ' was', ' offensive', ',', ' I', ' hated', ' it', '.', ' The', ' acting', ' was', ' terrible', ',', ' the', ' plot', ' was', ' unpleasant', ',', ' and', ' overall', ' the', ' movie', ' was', ' just', ' very', ' like', '.', ' Review', ' Sent', 'iment', ':']
Tokenized answer: [' Positive']


Top 0th token. Logit: 13.75 Prob:  9.94% Token: | 1|
Top 1th token. Logit: 13.74 Prob:  9.81% Token: | 0|
Top 2th token. Logit: 12.78 Prob:  3.76% Token: | 3|
Top 3th token. Logit: 12.71 Prob:  3.50% Token: | -|
Top 4th token. Logit: 12.59 Prob:  3.10% Token: | 2|


Tokenized prompt: ['<|endoftext|>', 'I', ' thought', ' this', ' movie', ' was', ' good', ',', ' I', ' loved', ' it', '.', ' The', ' acting', ' was', ' great', ',', ' the', ' plot', ' was', ' incredible', ',', ' and', ' overall', ' the', ' movie', ' was', ' just', ' very', ' good', '.', ' Review', ' Sent', 'iment', ':']
Tokenized answer: [' Positive']


Top 0th token. Logit: 14.54 Prob: 12.84% Token: | Positive|
Top 1th token. Logit: 14.18 Prob:  8.93% Token: | 4|
Top 2th token. Logit: 14.18 Prob:  8.89% Token: | 8|
Top 3th token. Logit: 13.80 Prob:  6.07% Token: | 9|
Top 4th token. Logit: 13.56 Prob:  4.78% Token: | 5|


Tokenized prompt: ['<|endoftext|>', 'I', ' thought', ' this', ' movie', ' was', ' terrible', ',', ' I', ' hated', ' it', '.', ' The', ' acting', ' was', ' unpleasant', ',', ' the', ' plot', ' was', ' wretched', ',', ' and', ' overall', ' the', ' movie', ' was', ' just', ' very', ' love', '.', ' Review', ' Sent', 'iment', ':']
Tokenized answer: [' Positive']


Top 0th token. Logit: 13.62 Prob:  9.15% Token: | 1|
Top 1th token. Logit: 13.57 Prob:  8.66% Token: | 0|
Top 2th token. Logit: 12.73 Prob:  3.74% Token: |
|
Top 3th token. Logit: 12.73 Prob:  3.74% Token: | 3|
Top 4th token. Logit: 12.72 Prob:  3.69% Token: | -|


Tokenized prompt: ['<|endoftext|>', 'I', ' thought', ' this', ' movie', ' was', ' great', ',', ' I', ' loved', ' it', '.', ' The', ' acting', ' was', ' incredible', ',', ' the', ' plot', ' was', ' lovely', ',', ' and', ' overall', ' the', ' movie', ' was', ' just', ' very', ' good', '.', ' Review', ' Sent', 'iment', ':']
Tokenized answer: [' Positive']


Top 0th token. Logit: 14.37 Prob: 10.96% Token: | Positive|
Top 1th token. Logit: 14.24 Prob:  9.69% Token: | 4|
Top 2th token. Logit: 14.04 Prob:  7.92% Token: | 8|
Top 3th token. Logit: 13.90 Prob:  6.84% Token: | 9|
Top 4th token. Logit: 13.74 Prob:  5.88% Token: | 5|


Tokenized prompt: ['<|endoftext|>', 'I', ' thought', ' this', ' movie', ' was', ' unpleasant', ',', ' I', ' hated', ' it', '.', ' The', ' acting', ' was', ' wretched', ',', ' the', ' plot', ' was', ' awful', ',', ' and', ' overall', ' the', ' movie', ' was', ' just', ' very', ' praise', '.', ' Review', ' Sent', 'iment', ':']
Tokenized answer: [' Positive']


Top 0th token. Logit: 13.92 Prob: 12.14% Token: | 1|
Top 1th token. Logit: 13.72 Prob:  9.93% Token: | 0|
Top 2th token. Logit: 13.14 Prob:  5.59% Token: | 3|
Top 3th token. Logit: 12.88 Prob:  4.30% Token: | 4|
Top 4th token. Logit: 12.79 Prob:  3.91% Token: | 2|


Tokenized prompt: ['<|endoftext|>', 'I', ' thought', ' this', ' movie', ' was', ' incredible', ',', ' I', ' loved', ' it', '.', ' The', ' acting', ' was', ' lovely', ',', ' the', ' plot', ' was', ' outstanding', ',', ' and', ' overall', ' the', ' movie', ' was', ' just', ' very', ' good', '.', ' Review', ' Sent', 'iment', ':']
Tokenized answer: [' Positive']


Top 0th token. Logit: 14.27 Prob:  9.80% Token: | 4|
Top 1th token. Logit: 14.20 Prob:  9.15% Token: | Positive|
Top 2th token. Logit: 14.03 Prob:  7.75% Token: | 9|
Top 3th token. Logit: 13.98 Prob:  7.35% Token: | 8|
Top 4th token. Logit: 13.86 Prob:  6.48% Token: | 5|


Tokenized prompt: ['<|endoftext|>', 'I', ' thought', ' this', ' movie', ' was', ' wretched', ',', ' I', ' hated', ' it', '.', ' The', ' acting', ' was', ' awful', ',', ' the', ' plot', ' was', ' bad', ',', ' and', ' overall', ' the', ' movie', ' was', ' just', ' very', ' respect', '.', ' Review', ' Sent', 'iment', ':']
Tokenized answer: [' Positive']


Top 0th token. Logit: 14.23 Prob: 13.97% Token: | 1|
Top 1th token. Logit: 13.76 Prob:  8.79% Token: | 0|
Top 2th token. Logit: 13.22 Prob:  5.12% Token: | 3|
Top 3th token. Logit: 13.06 Prob:  4.33% Token: | 2|
Top 4th token. Logit: 12.97 Prob:  3.98% Token: | -|


### Continuation

#### Old

In [65]:
ds_type = PromptType.COMPLETION_2

In [66]:
ds = get_dataset(model, device, prompt_type=ds_type)
ds.all_prompts[0]

Reading prompts from config and filtering


'I thought this movie was amazing, I loved it. The acting was awesome, the plot was beautiful, and overall it was just very'

In [67]:
ds.clean_tokens.shape

torch.Size([22, 28])

In [68]:
ds.answer_tokens.shape

torch.Size([22, 5, 2])

In [69]:
clean_logits = model(ds.clean_tokens.to(device))
corrupted_logits = model(ds.corrupted_tokens.to(device))

In [70]:
ds.answer_tokens.shape

torch.Size([22, 5, 2])

In [71]:
from utils.metrics import CircuitMetric
logit_diff_metric = CircuitMetric("logit_diff_multi", partial(get_logit_diff_ca, answer_tokens=ds.answer_tokens))

In [72]:
logit_diff_metric(clean_logits), logit_diff_metric(corrupted_logits)

(tensor(3.2004, device='cuda:0'), tensor(-3.2004, device='cuda:0'))

#### New

In [25]:
ds = UniversalPatchingDataset.from_sentiment(model, "cont")

Reading prompts from config and filtering


In [27]:
logits = model(ds.toks)
compute_logit_diff(logits, ds.answer_toks, mode="pairs")

tensor(3.2004, device='cuda:0')

In [28]:
from utils.metrics import compute_accuracy
compute_accuracy(logits, ds.answer_toks, mode="pairs")

1.0

In [29]:
compute_mean_reciprocal_rank(logits, ds.answer_toks, mode="pairs")

tensor(0.2233, device='cuda:0')

In [32]:
compute_max_group_rank_reciprocal(logits, ds.answer_toks, mode="pairs")

tensor(1., device='cuda:0')

In [40]:
from transformer_lens.utils import test_prompt
for prompt_tokens in ds.toks:
    prompt = model.to_string(prompt_tokens[1:])
    test_prompt(prompt, "bad", model, top_k=5)

Tokenized prompt: ['<|endoftext|>', 'I', ' thought', ' this', ' movie', ' was', ' amazing', ',', ' I', ' loved', ' it', '.', ' The', ' acting', ' was', ' awesome', ',', ' the', ' plot', ' was', ' beautiful', ',', ' and', ' overall', ' it', ' was', ' just', ' very']
Tokenized answer: [' bad']


Top 0th token. Logit: 31.80 Prob: 15.98% Token: | good|
Top 1th token. Logit: 31.14 Prob:  8.25% Token: | well|
Top 2th token. Logit: 30.82 Prob:  5.98% Token: | enjoyable|
Top 3th token. Logit: 30.66 Prob:  5.10% Token: | entertaining|
Top 4th token. Logit: 30.62 Prob:  4.92% Token: | fun|


Tokenized prompt: ['<|endoftext|>', 'I', ' thought', ' this', ' movie', ' was', ' awful', ',', ' I', ' hated', ' it', '.', ' The', ' acting', ' was', ' bad', ',', ' the', ' plot', ' was', ' disappointing', ',', ' and', ' overall', ' it', ' was', ' just', ' very']
Tokenized answer: [' bad']


Top 0th token. Logit: 31.67 Prob: 14.77% Token: | bad|
Top 1th token. Logit: 30.85 Prob:  6.50% Token: | poor|
Top 2th token. Logit: 30.38 Prob:  4.07% Token: | dull|
Top 3th token. Logit: 30.29 Prob:  3.72% Token: | boring|
Top 4th token. Logit: 30.24 Prob:  3.55% Token: |,|


Tokenized prompt: ['<|endoftext|>', 'I', ' thought', ' this', ' movie', ' was', ' awesome', ',', ' I', ' loved', ' it', '.', ' The', ' acting', ' was', ' beautiful', ',', ' the', ' plot', ' was', ' brilliant', ',', ' and', ' overall', ' it', ' was', ' just', ' very']
Tokenized answer: [' bad']


Top 0th token. Logit: 31.84 Prob: 17.54% Token: | good|
Top 1th token. Logit: 31.12 Prob:  8.56% Token: | well|
Top 2th token. Logit: 30.94 Prob:  7.14% Token: | enjoyable|
Top 3th token. Logit: 30.79 Prob:  6.14% Token: | fun|
Top 4th token. Logit: 30.64 Prob:  5.28% Token: | entertaining|


Tokenized prompt: ['<|endoftext|>', 'I', ' thought', ' this', ' movie', ' was', ' bad', ',', ' I', ' hated', ' it', '.', ' The', ' acting', ' was', ' disappointing', ',', ' the', ' plot', ' was', ' disgusting', ',', ' and', ' overall', ' it', ' was', ' just', ' very']
Tokenized answer: [' bad']


Top 0th token. Logit: 31.95 Prob: 18.35% Token: | bad|
Top 1th token. Logit: 30.74 Prob:  5.49% Token: | poor|
Top 2th token. Logit: 30.60 Prob:  4.77% Token: | boring|
Top 3th token. Logit: 30.36 Prob:  3.76% Token: | good|
Top 4th token. Logit: 30.34 Prob:  3.69% Token: | dull|


Tokenized prompt: ['<|endoftext|>', 'I', ' thought', ' this', ' movie', ' was', ' beautiful', ',', ' I', ' loved', ' it', '.', ' The', ' acting', ' was', ' brilliant', ',', ' the', ' plot', ' was', ' exceptional', ',', ' and', ' overall', ' it', ' was', ' just', ' very']
Tokenized answer: [' bad']


Top 0th token. Logit: 32.06 Prob: 19.87% Token: | good|
Top 1th token. Logit: 31.41 Prob: 10.45% Token: | well|
Top 2th token. Logit: 30.81 Prob:  5.72% Token: | enjoyable|
Top 3th token. Logit: 30.40 Prob:  3.80% Token: | entertaining|
Top 4th token. Logit: 30.22 Prob:  3.16% Token: | fun|


Tokenized prompt: ['<|endoftext|>', 'I', ' thought', ' this', ' movie', ' was', ' disappointing', ',', ' I', ' hated', ' it', '.', ' The', ' acting', ' was', ' disgusting', ',', ' the', ' plot', ' was', ' dreadful', ',', ' and', ' overall', ' it', ' was', ' just', ' very']
Tokenized answer: [' bad']


Top 0th token. Logit: 31.37 Prob:  9.52% Token: | bad|
Top 1th token. Logit: 31.06 Prob:  6.94% Token: | boring|
Top 2th token. Logit: 30.77 Prob:  5.24% Token: | poor|
Top 3th token. Logit: 30.69 Prob:  4.80% Token: | dull|
Top 4th token. Logit: 30.53 Prob:  4.12% Token: | poorly|


Tokenized prompt: ['<|endoftext|>', 'I', ' thought', ' this', ' movie', ' was', ' brilliant', ',', ' I', ' loved', ' it', '.', ' The', ' acting', ' was', ' exceptional', ',', ' the', ' plot', ' was', ' extraordinary', ',', ' and', ' overall', ' it', ' was', ' just', ' very']
Tokenized answer: [' bad']


Top 0th token. Logit: 32.23 Prob: 21.64% Token: | good|
Top 1th token. Logit: 31.37 Prob:  9.15% Token: | well|
Top 2th token. Logit: 31.11 Prob:  7.03% Token: | enjoyable|
Top 3th token. Logit: 30.84 Prob:  5.34% Token: | entertaining|
Top 4th token. Logit: 30.49 Prob:  3.78% Token: | fun|


Tokenized prompt: ['<|endoftext|>', 'I', ' thought', ' this', ' movie', ' was', ' disgusting', ',', ' I', ' hated', ' it', '.', ' The', ' acting', ' was', ' dreadful', ',', ' the', ' plot', ' was', ' horrible', ',', ' and', ' overall', ' it', ' was', ' just', ' very']
Tokenized answer: [' bad']


Top 0th token. Logit: 31.37 Prob: 10.55% Token: | bad|
Top 1th token. Logit: 30.60 Prob:  4.90% Token: | boring|
Top 2th token. Logit: 30.38 Prob:  3.91% Token: | poor|
Top 3th token. Logit: 30.29 Prob:  3.57% Token: |,|
Top 4th token. Logit: 30.14 Prob:  3.09% Token: | dull|


Tokenized prompt: ['<|endoftext|>', 'I', ' thought', ' this', ' movie', ' was', ' exceptional', ',', ' I', ' loved', ' it', '.', ' The', ' acting', ' was', ' extraordinary', ',', ' the', ' plot', ' was', ' fabulous', ',', ' and', ' overall', ' it', ' was', ' just', ' very']
Tokenized answer: [' bad']


Top 0th token. Logit: 31.99 Prob: 19.69% Token: | good|
Top 1th token. Logit: 31.26 Prob:  9.44% Token: | well|
Top 2th token. Logit: 31.21 Prob:  9.01% Token: | enjoyable|
Top 3th token. Logit: 30.59 Prob:  4.83% Token: | entertaining|
Top 4th token. Logit: 30.20 Prob:  3.29% Token: | fun|


Tokenized prompt: ['<|endoftext|>', 'I', ' thought', ' this', ' movie', ' was', ' dreadful', ',', ' I', ' hated', ' it', '.', ' The', ' acting', ' was', ' horrible', ',', ' the', ' plot', ' was', ' miserable', ',', ' and', ' overall', ' it', ' was', ' just', ' very']
Tokenized answer: [' bad']


Top 0th token. Logit: 31.62 Prob: 13.30% Token: | bad|
Top 1th token. Logit: 31.05 Prob:  7.56% Token: | poor|
Top 2th token. Logit: 30.77 Prob:  5.70% Token: | dull|
Top 3th token. Logit: 30.66 Prob:  5.09% Token: | boring|
Top 4th token. Logit: 30.41 Prob:  3.99% Token: |,|


Tokenized prompt: ['<|endoftext|>', 'I', ' thought', ' this', ' movie', ' was', ' extraordinary', ',', ' I', ' loved', ' it', '.', ' The', ' acting', ' was', ' fabulous', ',', ' the', ' plot', ' was', ' fantastic', ',', ' and', ' overall', ' it', ' was', ' just', ' very']
Tokenized answer: [' bad']


Top 0th token. Logit: 31.44 Prob: 13.29% Token: | good|
Top 1th token. Logit: 31.06 Prob:  9.10% Token: | well|
Top 2th token. Logit: 30.72 Prob:  6.51% Token: | enjoyable|
Top 3th token. Logit: 30.50 Prob:  5.22% Token: | entertaining|
Top 4th token. Logit: 29.86 Prob:  2.75% Token: |,|


Tokenized prompt: ['<|endoftext|>', 'I', ' thought', ' this', ' movie', ' was', ' horrible', ',', ' I', ' hated', ' it', '.', ' The', ' acting', ' was', ' miserable', ',', ' the', ' plot', ' was', ' offensive', ',', ' and', ' overall', ' it', ' was', ' just', ' very']
Tokenized answer: [' bad']


Top 0th token. Logit: 32.07 Prob: 18.31% Token: | bad|
Top 1th token. Logit: 31.03 Prob:  6.42% Token: | poor|
Top 2th token. Logit: 30.50 Prob:  3.81% Token: | boring|
Top 3th token. Logit: 30.37 Prob:  3.32% Token: | poorly|
Top 4th token. Logit: 30.36 Prob:  3.29% Token: |,|


Tokenized prompt: ['<|endoftext|>', 'I', ' thought', ' this', ' movie', ' was', ' fabulous', ',', ' I', ' loved', ' it', '.', ' The', ' acting', ' was', ' fantastic', ',', ' the', ' plot', ' was', ' good', ',', ' and', ' overall', ' it', ' was', ' just', ' very']
Tokenized answer: [' bad']


Top 0th token. Logit: 31.96 Prob: 19.68% Token: | good|
Top 1th token. Logit: 31.32 Prob: 10.45% Token: | well|
Top 2th token. Logit: 31.09 Prob:  8.27% Token: | enjoyable|
Top 3th token. Logit: 30.61 Prob:  5.11% Token: | fun|
Top 4th token. Logit: 30.59 Prob:  5.03% Token: | entertaining|


Tokenized prompt: ['<|endoftext|>', 'I', ' thought', ' this', ' movie', ' was', ' miserable', ',', ' I', ' hated', ' it', '.', ' The', ' acting', ' was', ' offensive', ',', ' the', ' plot', ' was', ' terrible', ',', ' and', ' overall', ' it', ' was', ' just', ' very']
Tokenized answer: [' bad']


Top 0th token. Logit: 31.67 Prob: 13.67% Token: | bad|
Top 1th token. Logit: 30.98 Prob:  6.83% Token: | boring|
Top 2th token. Logit: 30.95 Prob:  6.60% Token: | poor|
Top 3th token. Logit: 30.55 Prob:  4.45% Token: | dull|
Top 4th token. Logit: 30.20 Prob:  3.12% Token: | depressing|


Tokenized prompt: ['<|endoftext|>', 'I', ' thought', ' this', ' movie', ' was', ' fantastic', ',', ' I', ' loved', ' it', '.', ' The', ' acting', ' was', ' good', ',', ' the', ' plot', ' was', ' great', ',', ' and', ' overall', ' it', ' was', ' just', ' very']
Tokenized answer: [' bad']


Top 0th token. Logit: 32.05 Prob: 20.33% Token: | good|
Top 1th token. Logit: 31.42 Prob: 10.91% Token: | well|
Top 2th token. Logit: 31.26 Prob:  9.24% Token: | enjoyable|
Top 3th token. Logit: 30.64 Prob:  5.00% Token: | entertaining|
Top 4th token. Logit: 30.48 Prob:  4.23% Token: | fun|


Tokenized prompt: ['<|endoftext|>', 'I', ' thought', ' this', ' movie', ' was', ' offensive', ',', ' I', ' hated', ' it', '.', ' The', ' acting', ' was', ' terrible', ',', ' the', ' plot', ' was', ' unpleasant', ',', ' and', ' overall', ' it', ' was', ' just', ' very']
Tokenized answer: [' bad']


Top 0th token. Logit: 31.45 Prob: 12.41% Token: | bad|
Top 1th token. Logit: 30.07 Prob:  3.13% Token: | offensive|
Top 2th token. Logit: 30.02 Prob:  2.99% Token: | poorly|
Top 3th token. Logit: 30.01 Prob:  2.96% Token: | boring|
Top 4th token. Logit: 29.95 Prob:  2.78% Token: |,|


Tokenized prompt: ['<|endoftext|>', 'I', ' thought', ' this', ' movie', ' was', ' good', ',', ' I', ' loved', ' it', '.', ' The', ' acting', ' was', ' great', ',', ' the', ' plot', ' was', ' incredible', ',', ' and', ' overall', ' it', ' was', ' just', ' very']
Tokenized answer: [' bad']


Top 0th token. Logit: 32.44 Prob: 24.96% Token: | good|
Top 1th token. Logit: 31.52 Prob:  9.99% Token: | well|
Top 2th token. Logit: 31.40 Prob:  8.79% Token: | enjoyable|
Top 3th token. Logit: 30.79 Prob:  4.81% Token: | fun|
Top 4th token. Logit: 30.72 Prob:  4.49% Token: | entertaining|


Tokenized prompt: ['<|endoftext|>', 'I', ' thought', ' this', ' movie', ' was', ' terrible', ',', ' I', ' hated', ' it', '.', ' The', ' acting', ' was', ' unpleasant', ',', ' the', ' plot', ' was', ' wretched', ',', ' and', ' overall', ' it', ' was', ' just', ' very']
Tokenized answer: [' bad']


Top 0th token. Logit: 32.15 Prob: 19.42% Token: | bad|
Top 1th token. Logit: 31.25 Prob:  7.86% Token: | poor|
Top 2th token. Logit: 30.72 Prob:  4.65% Token: | boring|
Top 3th token. Logit: 30.66 Prob:  4.39% Token: | dull|
Top 4th token. Logit: 30.65 Prob:  4.32% Token: | good|


Tokenized prompt: ['<|endoftext|>', 'I', ' thought', ' this', ' movie', ' was', ' great', ',', ' I', ' loved', ' it', '.', ' The', ' acting', ' was', ' incredible', ',', ' the', ' plot', ' was', ' lovely', ',', ' and', ' overall', ' it', ' was', ' just', ' very']
Tokenized answer: [' bad']


Top 0th token. Logit: 31.91 Prob: 17.65% Token: | good|
Top 1th token. Logit: 31.30 Prob:  9.57% Token: | well|
Top 2th token. Logit: 31.19 Prob:  8.61% Token: | enjoyable|
Top 3th token. Logit: 30.78 Prob:  5.69% Token: | fun|
Top 4th token. Logit: 30.55 Prob:  4.54% Token: | entertaining|


Tokenized prompt: ['<|endoftext|>', 'I', ' thought', ' this', ' movie', ' was', ' unpleasant', ',', ' I', ' hated', ' it', '.', ' The', ' acting', ' was', ' wretched', ',', ' the', ' plot', ' was', ' awful', ',', ' and', ' overall', ' it', ' was', ' just', ' very']
Tokenized answer: [' bad']


Top 0th token. Logit: 31.53 Prob: 13.50% Token: | bad|
Top 1th token. Logit: 30.71 Prob:  5.94% Token: | dull|
Top 2th token. Logit: 30.65 Prob:  5.58% Token: | boring|
Top 3th token. Logit: 30.23 Prob:  3.68% Token: |,|
Top 4th token. Logit: 30.06 Prob:  3.11% Token: | poor|


Tokenized prompt: ['<|endoftext|>', 'I', ' thought', ' this', ' movie', ' was', ' incredible', ',', ' I', ' loved', ' it', '.', ' The', ' acting', ' was', ' lovely', ',', ' the', ' plot', ' was', ' outstanding', ',', ' and', ' overall', ' it', ' was', ' just', ' very']
Tokenized answer: [' bad']


Top 0th token. Logit: 31.53 Prob: 13.50% Token: | good|
Top 1th token. Logit: 31.21 Prob:  9.87% Token: | well|
Top 2th token. Logit: 30.92 Prob:  7.35% Token: | enjoyable|
Top 3th token. Logit: 30.62 Prob:  5.47% Token: | entertaining|
Top 4th token. Logit: 30.22 Prob:  3.65% Token: | fun|


Tokenized prompt: ['<|endoftext|>', 'I', ' thought', ' this', ' movie', ' was', ' wretched', ',', ' I', ' hated', ' it', '.', ' The', ' acting', ' was', ' awful', ',', ' the', ' plot', ' was', ' bad', ',', ' and', ' overall', ' it', ' was', ' just', ' very']
Tokenized answer: [' bad']


Top 0th token. Logit: 31.44 Prob: 12.08% Token: | bad|
Top 1th token. Logit: 31.07 Prob:  8.35% Token: | poor|
Top 2th token. Logit: 30.85 Prob:  6.68% Token: | dull|
Top 3th token. Logit: 30.51 Prob:  4.75% Token: | boring|
Top 4th token. Logit: 30.14 Prob:  3.30% Token: |,|


In [None]:
ds.ans

## SST

In [10]:
import re
import random
from torch.utils.data import DataLoader
from datasets import Dataset, concatenate_datasets, load_from_disk
from transformers import AutoTokenizer

model.name = "EleutherAI/pythia-2.8b"


def filter_function(example, model):
    prompt = model.to_tokens(example['text'] + " Review Sentiment:", prepend_bos=False)
    answer = torch.tensor([29071, 32725]).unsqueeze(0).unsqueeze(0).to(device) if example['label'] == 1 else torch.tensor([32725, 29071]).unsqueeze(0).unsqueeze(0).to(device)
    #print(answer.shape)
    logits = model(prompt, return_type="logits")
    logit_diff = compute_logit_diff(logits, answer, mode="pairs")
    
    # Determine if the top answer (index 0) token is in top 10 logits
    _, top_indices = logits.topk(10, dim=-1)  # Get indices of top 10 logits
    top_answer_token = answer[0, 0, 0]  # Assuming answer is of shape (1, 1, 2) and the top answer token is at index 0
    is_top_answer_in_top_10_logits = (top_indices == top_answer_token).any()
    
    # Add a new field 'keep_example' to the example
    example['keep_example'] = (logit_diff > 0.0) and is_top_answer_in_top_10_logits.item()
    return example


def concatenate_classification_prompts(examples):
    return {"text": (examples['text'] + " Review Sentiment:")}


def get_final_pos_index(examples):
    return {'final_pos_index': examples["attention_mask"].sum() - 1}


def tokenize_function(examples, tokenizer):
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=64)


def find_dataset_positions(example, token_id=13):
    # Create a tensor of zeros with the same shape as example['tokens']
    positions = torch.zeros_like(torch.tensor(example['tokens']), dtype=torch.int)

    # Find positions where tokens match the given token_id
    positions[example['tokens'] == token_id] = 1
    has_token = True if positions.sum() > 0 else False

    return {'positions': positions, 'has_token': has_token}


def convert_answers(example, pos_answer_id=29071, neg_answer_id=32725):
    if example['label'] == 1:
        answers = torch.tensor([pos_answer_id, neg_answer_id])
    else:
        answers = torch.tensor([neg_answer_id, pos_answer_id])

    return {'answers': answers}


def get_random_subset(dataset, n):
    total_size = len(dataset)
    random_indices = random.sample(range(total_size), n)
    return dataset.select(random_indices)


def prepare_sst_for_model(
        model: HookedTransformer,
        dataset_name: str = "sst2", 
        batch_size: int = 5,
        pad_token_id: int = 1, 
        pos_answer_id: int = 29071, 
        neg_answer_id: int = 32725
    ) -> Tuple[DataLoader, DataLoader, DataLoader]:
    # Define the batch size
    BATCH_SIZE = batch_size

    sst_data = load_from_disk(dataset_name)

    # Use the map function to apply the filter_function
    filter_function_for_model = partial(filter_function, model=model)
    sst_data_with_flag_train = sst_data['train'].map(filter_function_for_model, keep_in_memory=True)
    sst_data_with_flag_dev = sst_data['dev'].map(filter_function_for_model, keep_in_memory=True)
    sst_data_with_flag_test = sst_data['test'].map(filter_function_for_model, keep_in_memory=True)
    #sst_data_with_flag = concatenate_datasets([sst_data['train'], sst_data['dev'], sst_data['test']])
    sst_data_with_flag = concatenate_datasets([sst_data_with_flag_train, sst_data_with_flag_dev, sst_data_with_flag_test])
    #sst_data_with_flag = sst_data_with_flag_dev

    # Use the filter function to keep only the examples where 'keep_example' is True
    sst_zero_shot = sst_data_with_flag.filter(lambda x: x['keep_example'])
    # print number of items in dataset
    print(f"Number of items in dataset: {len(sst_zero_shot)}")
    # save dataset
    #new model name without slashes
    model_abbr = re.sub(r'/', '_', model.name)
    sst_zero_shot.save_to_disk(f"sst_zero_shot_{model_abbr}")

    # Load a tokenizer (you'll need to specify the appropriate model)
    tokenizer = AutoTokenizer.from_pretrained(model.name)
    # set padding token
    tokenizer.pad_token = model.to_string([pad_token_id])

    dataset = sst_zero_shot.map(concatenate_classification_prompts, batched=False)
    tokenizer_function_for_model = partial(tokenize_function, tokenizer=tokenizer)
    dataset = dataset.map(tokenizer_function_for_model, batched=False)
    
    convert_answers_for_model = partial(convert_answers, pos_answer_id=pos_answer_id, neg_answer_id=neg_answer_id)
    dataset = dataset.map(convert_answers_for_model, batched=False)
    dataset = dataset.rename_column("input_ids", "tokens")
    dataset.set_format(type="torch", columns=["tokens", "attention_mask", "label", "answers"])
    dataset = dataset.map(get_final_pos_index, batched=False)
    dataset = dataset.map(find_dataset_positions, batched=False)
    dataset = dataset.filter(lambda example: example['has_token']==True)

    # create a subset with only positive labels
    pos_dataset = dataset.filter(lambda example: example['label']==1)
    neg_dataset = dataset.filter(lambda example: example['label']==0)
    len(pos_dataset), len(neg_dataset)

    subset_size = (min(len(pos_dataset), len(neg_dataset)) // BATCH_SIZE) * BATCH_SIZE

    pos_subset = get_random_subset(pos_dataset, subset_size)
    neg_subset = get_random_subset(neg_dataset, subset_size)
    balanced_subset = concatenate_datasets([pos_subset, neg_subset])
    # randomize the order of balanced_subset
    balanced_subset = balanced_subset.shuffle(len(balanced_subset))

    balanced_subset.save_to_disk(f"sst_zero_shot_balanced_{model_abbr}")


    print(f"Number of items in pos dataset: {len(pos_subset)}")
    print(f"Number of items in neg dataset: {len(neg_subset)}")
    print(f"Number of items in balanced dataset: {len(balanced_subset)}")
    return pos_subset, neg_subset, balanced_subset


In [11]:
pos_ds, neg_ds, balanced_ds = prepare_sst_for_model(model, "data/sst2", 5, 1, 29071, 32725)

Map:   0%|          | 0/7864 [00:00<?, ? examples/s]

Map:   0%|          | 0/1007 [00:00<?, ? examples/s]

Map:   0%|          | 0/2058 [00:00<?, ? examples/s]

Filter:   0%|          | 0/10929 [00:00<?, ? examples/s]

Number of items in dataset: 6169


Saving the dataset (0/1 shards):   0%|          | 0/6169 [00:00<?, ? examples/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Map:   0%|          | 0/6169 [00:00<?, ? examples/s]

Map:   0%|          | 0/6169 [00:00<?, ? examples/s]

Map:   0%|          | 0/6169 [00:00<?, ? examples/s]

Map:   0%|          | 0/6169 [00:00<?, ? examples/s]

Map:   0%|          | 0/6169 [00:00<?, ? examples/s]

  positions = torch.zeros_like(torch.tensor(example['tokens']), dtype=torch.int)


Filter:   0%|          | 0/6169 [00:00<?, ? examples/s]

Filter:   0%|          | 0/3318 [00:00<?, ? examples/s]

Filter:   0%|          | 0/3318 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1390 [00:00<?, ? examples/s]

Number of items in pos dataset: 695
Number of items in neg dataset: 695
Number of items in balanced dataset: 1390


In [11]:
import re
import random
from torch.utils.data import DataLoader
from datasets import Dataset, concatenate_datasets, load_from_disk
from transformers import AutoTokenizer

ds = load_from_disk("sst_zero_shot_balanced_EleutherAI_pythia-2.8b")

# Turn all items in ['tokens'] into a single tensor
all_tokens = torch.cat([item['tokens'].unsqueeze(0) for item in ds], dim=0)
all_answers = torch.cat([item['answers'].unsqueeze(0) for item in ds], dim=0)
all_positions = torch.cat([item['final_pos_index'].unsqueeze(0) for item in ds], dim=0)

In [12]:
ds

Dataset({
    features: ['text', 'label', 'keep_example', 'tokens', 'attention_mask', 'answers', 'final_pos_index', 'positions', 'has_token'],
    num_rows: 1390
})

In [13]:
ds[0]['tokens'], ds[0]['answers'], ds[0]['final_pos_index'], ds[0]['tokens'][ds[0]['final_pos_index']]

(tensor([ 2042,   253,  1524,    14,  8206,  2926,   310, 27364, 39637,  1050,
            13,   253,  6440, 10576,    84,   441,   347,   973,    15,  8439,
         20580,  2092,    27,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1]),
 tensor([29071, 32725]),
 tensor(22),
 tensor(27))

In [7]:
all_tokens.shape, all_answers.shape, all_positions.shape

(torch.Size([1390, 64]), torch.Size([1390, 2]), torch.Size([1390]))

In [16]:
from utils.circuit_utils import run_with_batches

logits = run_with_batches(model, all_tokens[:1000].to(device), batch_size=10, max_seq_len=64)

In [17]:
from utils.metrics import compute_accuracy
compute_accuracy(logits, all_answers[:1000], positions=all_positions[:1000], mode="simple")

1.0

In [8]:
model = HookedTransformer.from_pretrained(
    "EleutherAI/pythia-2.8b",
    checkpoint_value=10000,
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    dtype=torch.bfloat16,
    refactor_factored_attn_matrices=False,
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model EleutherAI/pythia-2.8b into HookedTransformer


In [14]:
from utils.circuit_utils import run_with_batches
from utils.metrics import compute_accuracy
logits = run_with_batches(model, all_tokens[:100].to(device), batch_size=10, max_seq_len=64)
compute_accuracy(logits, all_answers[:100], positions=all_positions[:100], mode="simple")

0.6299999952316284