In [2]:
import io
from logging import warning
from typing import Union, List
from site import PREFIXES
import warnings
import numpy as np
from tqdm import tqdm
import pandas as pd
from transformers import AutoTokenizer
import random
import re
import matplotlib.pyplot as plt
import random as rd
import copy
import random
from typing import List, Union
from pathlib import Path
import torch
from transformer_lens import HookedTransformer

### GREATER-THAN TASK ###

device = 'cpu'

print("Greater-than task...")

# Checkpoints processed one after another by the loop below
models = ['gpt2-small', 'gpt2-medium', 'gpt2-large', 'gpt2-xl']

# Only forward passes are run below, so gradient tracking is switched off once
# here instead of being re-applied on every loop iteration.
torch.set_grad_enabled(False)

for model_name in models:
    model = HookedTransformer.from_pretrained(model_name, device=device)

    def generate_real_sentence(noun: str, year: int, eos: bool = False) -> str:
        """Build the clean prompt for ``noun`` and ``year``.

        The sentence names ``year`` as the start year and ends with
        ``year // 100``. With ``eos=True`` a single space is put in front.
        """
        end_prefix = year // 100
        lead = " " if eos else ""
        return f"{lead}The {noun} lasted from the year {year} to the year {end_prefix}"

    def real_sentence_prompt(eos: bool = False) -> List[str]:
        """Return the clean prompt template as a list of words.

        The template is split on whitespace; with ``eos=True`` an empty string
        is placed in front of the words.
        """
        template = "The NOUN lasted from the year XX1 YY to the year XX2"
        words = template.split()
        return [""] + words if eos else words

    def generate_bad_sentence(noun: str, year: int, eos: bool = False) -> str:
        """Build the counterfactual prompt for ``noun`` and ``year``.

        The start year is kept as given; the sentence ends with
        ``year // 100 - 1``. With ``eos=True`` a single space is put in front.
        """
        end_prefix = year // 100 - 1
        lead = " " if eos else ""
        return f"{lead}The {noun} lasted from the year {year} to the year {end_prefix}"

    def bad_sentence_prompt(eos: bool = False) -> List[str]:
        """Return the counterfactual prompt template as a list of words.

        The template is split on whitespace; with ``eos=True`` an empty string
        is placed in front of the words.
        """
        template = "The NOUN lasted from the year XX1 01 to the year XX2"
        words = template.split()
        return [""] + words if eos else words

    def is_valid_year(year: str, model) -> bool:
        """Check that ``" " + year`` splits into two tokens, the second two characters long.

        Args:
            year: Year written as a string, without the leading space.
            model: Object providing ``to_str_tokens`` (a ``HookedTransformer``
                in this file).

        Returns:
            True only if the tokenization has exactly two tokens and the second
            token has exactly two characters.
        """
        # to_str_tokens gives one string per token; prepend_bos=False keeps the
        # BOS token out of the count. The earlier to_tokens -> to_string round
        # trip produced a one-element list holding the whole decoded text, so
        # the length-2 check could never pass.
        str_tokens = model.to_str_tokens(" " + year, prepend_bos=False)
        return len(str_tokens) == 2 and len(str_tokens[1]) == 2

    class YearDataset:
        """Paired clean and counterfactual prompts for the greater-than task.

        ``N`` nouns are drawn with replacement and ``N`` start years are drawn
        from ``years_to_sample_from``. Each (noun, year) pair is rendered once
        with ``generate_real_sentence`` and once with ``generate_bad_sentence``.
        """

        years_to_sample_from: torch.Tensor
        N: int
        ordered: bool  # declared only; never assigned in this class
        eos: bool

        nouns: List[str]
        years: torch.Tensor
        years_XX: torch.Tensor  # years // 100
        years_YY: torch.Tensor  # years % 100
        good_sentences: List[str]
        bad_sentences: List[str]
        good_toks: torch.Tensor  # declared only; never assigned in this class
        bad_toks: torch.Tensor  # declared only; never assigned in this class
        good_prompt: List[str]
        bad_prompt: List[str]
        good_mask: torch.Tensor  # declared only; never assigned in this class
        model: HookedTransformer

        def __init__(
            self,
            years_to_sample_from,
            N: int,
            nouns: Union[str, List[str], Path],
            model: HookedTransformer,
            balanced: bool = True,
            eos: bool = False,
            device: str = "cpu",
        ):
            """Sample nouns and years, then render both prompt lists.

            Args:
                years_to_sample_from: Candidate start years (the caller in this
                    file passes a 1-D ``torch.arange``).
                N: Number of prompt pairs to generate.
                nouns: A single noun, a list of nouns, or a ``Path`` to a text
                    file with one noun per line.
                model: Stored on the instance; not used during construction.
                balanced: If True, the last two digits of the sampled years
                    cycle through 2, 3, ..., 98 in order and one matching year
                    is picked at random each time. If False, years are drawn
                    uniformly at random.
                eos: Forwarded to the sentence and template builders.
                device: Accepted but not used.

            Raises:
                ValueError: If ``nouns`` has an unsupported type, if no noun is
                    available, or if balanced sampling needs a year ending that
                    ``years_to_sample_from`` does not contain.
            """
            self.years_to_sample_from = years_to_sample_from
            self.N = N
            self.eos = eos
            self.model = model

            if isinstance(nouns, str):
                noun_list = [nouns]
            elif isinstance(nouns, list):
                noun_list = nouns
            elif isinstance(nouns, Path):
                with open(nouns, "r") as f:
                    # Skip blank lines; they would otherwise become empty nouns.
                    noun_list = [line.strip() for line in f if line.strip()]
            else:
                raise ValueError(f"Got bad type of nouns: {type(nouns)}; for nouns: {nouns}")

            if not noun_list:
                raise ValueError("nouns must contain at least one noun")

            self.nouns = random.choices(noun_list, k=N)

            if balanced:
                years = []
                current_year = 2
                years_to_sample_from_YY = self.years_to_sample_from % 100
                for _ in range(N):
                    sample_pool = self.years_to_sample_from[years_to_sample_from_YY == current_year]
                    if len(sample_pool) == 0:
                        raise ValueError(
                            f"years_to_sample_from has no year ending in {current_year:02d}; "
                            "balanced sampling cannot continue"
                        )
                    years.append(sample_pool[random.randrange(len(sample_pool))])
                    current_year += 1
                    # After 98 the two-digit ending wraps back to 2.
                    if current_year >= 99:
                        current_year -= 97
                self.years = torch.tensor(years)
            else:
                picks = torch.randint(0, len(self.years_to_sample_from), (N,))
                # as_tensor avoids the copy-construct warning that
                # torch.tensor(<tensor>) emits.
                self.years = torch.as_tensor(self.years_to_sample_from[picks])

            self.years_XX = self.years // 100
            self.years_YY = self.years % 100

            self.good_sentences = [
                generate_real_sentence(noun, int(year.item()), eos=eos) for noun, year in zip(self.nouns, self.years)
            ]
            self.bad_sentences = [
                generate_bad_sentence(noun, int(year.item()), eos=eos) for noun, year in zip(self.nouns, self.years)
            ]

            self.good_prompt = real_sentence_prompt(eos=eos)
            self.bad_prompt = bad_sentence_prompt(eos=eos)

        def __len__(self):
            """Return the number of prompt pairs requested at construction."""
            return self.N

    # The model used here is the one loaded at the top of this loop iteration.

    # Nouns that can fill the NOUN slot of the prompts (one group per initial)
    nouns = [
        "abduction", "accord", "affair", "agreement", "appraisal", "assaults",
        "assessment", "attack", "attempts",
        "campaign", "captivity", "case", "challenge", "chaos", "clash",
        "collaboration", "coma", "competition", "confrontation", "consequence",
        "conspiracy", "construction", "consultation", "contact", "contract",
        "convention", "cooperation", "custody",
        "deal", "decline", "decrease", "demonstrations", "development",
        "disagreement", "disorder", "dispute", "domination", "dynasty",
        "effect", "effort", "employment", "endeavor", "engagement", "epidemic",
        "evaluation", "exchange", "existence", "expansion", "expedition",
        "experiments",
        "fall", "fame", "flights", "friendship",
        "growth",
        "hardship", "hostility",
        "illness", "impact", "imprisonment", "improvement", "incarceration",
        "increase", "insurgency", "invasion", "investigation",
        "journey",
    ]
    # Candidate start years: 1200 through 1999
    years_to_sample_from = torch.arange(1200, 2000)

    # Ten balanced prompt pairs per model, without the leading-space variant
    dataset = YearDataset(
        years_to_sample_from=years_to_sample_from,
        N=10,
        nouns=nouns,
        model=model,
        balanced=True,
        eos=False,
        device="cpu",
    )

    # Set up the model
    def prompt_to_resid_stream(prompt: str, model: HookedTransformer, resid_type: str = 'accumulated', position: str = 'last') -> tuple:
        """Run ``model`` on ``prompt`` and return reduced activations with their labels.

        Args:
            prompt: Text passed to ``model.run_with_cache``.
            model: Model whose activation cache is read.
            resid_type: ``'accumulated'`` (``cache.accumulated_resid`` with
                ``apply_ln=True``), ``'decomposed'`` (``cache.decompose_resid``)
                or ``'heads'`` (``cache.stack_head_results``).
            position: ``'last'`` keeps the final position of batch item 0;
                ``'mean'`` averages over the position axis and drops a size-1
                batch axis.

        Returns:
            ``(activations, labels)`` where ``labels`` is the label list
            returned by the selected cache method.

        Raises:
            ValueError: If ``resid_type`` or ``position`` is not a listed value.
        """
        # Validate both selectors up front so a typo fails before the forward pass.
        if resid_type not in ('accumulated', 'decomposed', 'heads'):
            raise ValueError("resid_type must be one of 'accumulated', 'decomposed', 'heads'")
        if position not in ('last', 'mean'):
            raise ValueError("position must be one of 'last', 'mean'")

        with torch.no_grad():
            _, cache = model.run_with_cache(prompt)

            # Every branch binds `labels`. Before, only 'heads' did, so the
            # other two modes failed at the return statement.
            if resid_type == 'accumulated':
                resid, labels = cache.accumulated_resid(return_labels=True, apply_ln=True)
            elif resid_type == 'decomposed':
                resid, labels = cache.decompose_resid(return_labels=True)
            else:  # 'heads'
                cache.compute_head_results()
                resid, labels = cache.stack_head_results(return_labels=True)

        # resid is indexed as (component, batch, pos, d_model).
        if position == 'last':
            reduced = resid[:, 0, -1, :]
        else:  # 'mean'
            # Squeeze only the batch axis so no other size-1 axis is dropped.
            reduced = resid.mean(dim=2).squeeze(1)
        return reduced, labels


    def all_prompts_to_resid_streams_gt(prompts, prompts_cf, model, resid_type='accumulated', position='mean'):
        """Compute activations for every clean prompt, then every counterfactual prompt.

        Args:
            prompts: List of clean prompt strings.
            prompts_cf: List of counterfactual prompt strings.
            model: Passed through to ``prompt_to_resid_stream``.
            resid_type: Passed through to ``prompt_to_resid_stream``.
            position: Passed through to ``prompt_to_resid_stream``.

        Returns:
            ``(stacked, labels)``: the per-prompt activations stacked along a
            new first axis (clean prompts first), and the labels returned for
            the last prompt processed.

        Raises:
            ValueError: If both prompt lists are empty.
        """
        all_prompts = prompts + prompts_cf
        if not all_prompts:
            raise ValueError("prompts and prompts_cf are both empty; nothing to process")

        resid_streams = []
        labels = None
        for prompt in tqdm(all_prompts):
            resid_stream, labels = prompt_to_resid_stream(prompt, model, resid_type, position)
            resid_streams.append(resid_stream)
        return torch.stack(resid_streams), labels

    prompts = dataset.good_sentences
    prompts_cf = dataset.bad_sentences

    resid_streams, labels = all_prompts_to_resid_streams_gt(prompts, prompts_cf, model, resid_type='heads', position='mean')

    # Attention heads treated as ground truth (presumably (layer, head) pairs).
    # NOTE(review): the same list is written for every model size - confirm it
    # is meant to apply to models other than gpt2-small.
    ground_truth = [
        (0, 3), (0, 5), (0, 1), (5, 5), (6, 1), (6, 9), (7, 10), (8, 11), (9, 1), # attention heads
    ]

    # Create the output folder first so a missing directory cannot discard the
    # activations that were just computed.
    out_dir = "../data/gt/model_sizes"
    Path(out_dir).mkdir(parents=True, exist_ok=True)

    torch.save(resid_streams, f"{out_dir}/resid_heads_mean_{model_name}.pt")
    all_prompts = prompts + prompts_cf
    torch.save(all_prompts, f"{out_dir}/prompts_heads_mean_{model_name}.pt")
    torch.save(labels, f"{out_dir}/labels_heads_mean_{model_name}.pt")
    torch.save(ground_truth, f"{out_dir}/ground_truth.pt")

    # Flush torch cache
    torch.cuda.empty_cache()

    print(f"Greater-than task done for {model_name}.\n\n")

Greater-than task...
Loaded pretrained model gpt2-small into HookedTransformer


100%|██████████| 20/20 [00:01<00:00, 16.94it/s]


Greater-than task done for gpt2-small.


Loaded pretrained model gpt2-medium into HookedTransformer


100%|██████████| 20/20 [00:03<00:00,  5.62it/s]


Greater-than task done for gpt2-medium.


Loaded pretrained model gpt2-large into HookedTransformer


100%|██████████| 20/20 [00:05<00:00,  3.63it/s]


Greater-than task done for gpt2-large.


Loaded pretrained model gpt2-xl into HookedTransformer


100%|██████████| 20/20 [00:08<00:00,  2.47it/s]

Greater-than task done for gpt2-xl.







In [3]:
# This cell runs entirely on the CPU.
device = "cpu"

print("Greater-than task...")

# GPT-2 small is the only model used in this cell.
model = HookedTransformer.from_pretrained("gpt2-small", device=device)

def generate_real_sentence(noun: str, year: int, eos: bool = False) -> str:
    """Build the clean prompt for ``noun`` and ``year``.

    The sentence names ``year`` as the start year and ends with
    ``year // 100``. With ``eos=True`` a single space is put in front.
    """
    end_prefix = year // 100
    lead = " " if eos else ""
    return f"{lead}The {noun} lasted from the year {year} to the year {end_prefix}"

def real_sentence_prompt(eos: bool = False) -> List[str]:
    """Return the clean prompt template as a list of words.

    The template is split on whitespace; with ``eos=True`` an empty string is
    placed in front of the words.
    """
    template = "The NOUN lasted from the year XX1 YY to the year XX2"
    words = template.split()
    return [""] + words if eos else words

def generate_bad_sentence(noun: str, year: int, eos: bool = False) -> str:
    """Build the counterfactual prompt for ``noun`` and ``year``.

    The start year is replaced by the century digits (``year // 100``)
    followed by ``01``; the sentence ends with the same century digits. With
    ``eos=True`` a single space is put in front.
    """
    century = year // 100
    lead = " " if eos else ""
    return f"{lead}The {noun} lasted from the year {century}01 to the year {century}"

def bad_sentence_prompt(eos: bool = False) -> List[str]:
    """Return the counterfactual prompt template as a list of words.

    The template is split on whitespace; with ``eos=True`` an empty string is
    placed in front of the words.
    """
    template = "The NOUN lasted from the year XX1 01 to the year XX2"
    words = template.split()
    return [""] + words if eos else words

def is_valid_year(year: str, model) -> bool:
    """Check that ``" " + year`` splits into two tokens, the second two characters long.

    Args:
        year: Year written as a string, without the leading space.
        model: Object providing ``to_str_tokens`` (a ``HookedTransformer`` in
            this file).

    Returns:
        True only if the tokenization has exactly two tokens and the second
        token has exactly two characters.
    """
    # to_str_tokens gives one string per token; prepend_bos=False keeps the BOS
    # token out of the count. The earlier to_tokens -> to_string round trip
    # produced a one-element list holding the whole decoded text, so the
    # length-2 check could never pass.
    str_tokens = model.to_str_tokens(" " + year, prepend_bos=False)
    return len(str_tokens) == 2 and len(str_tokens[1]) == 2

class YearDataset:
    """Paired clean and counterfactual prompts for the greater-than task.

    ``N`` nouns are drawn with replacement and ``N`` start years are drawn from
    ``years_to_sample_from``. Each (noun, year) pair is rendered once with
    ``generate_real_sentence`` and once with ``generate_bad_sentence``.
    """

    years_to_sample_from: torch.Tensor
    N: int
    ordered: bool  # declared only; never assigned in this class
    eos: bool

    nouns: List[str]
    years: torch.Tensor
    years_XX: torch.Tensor  # years // 100
    years_YY: torch.Tensor  # years % 100
    good_sentences: List[str]
    bad_sentences: List[str]
    good_toks: torch.Tensor  # declared only; never assigned in this class
    bad_toks: torch.Tensor  # declared only; never assigned in this class
    good_prompt: List[str]
    bad_prompt: List[str]
    good_mask: torch.Tensor  # declared only; never assigned in this class
    model: HookedTransformer

    def __init__(
        self,
        years_to_sample_from,
        N: int,
        nouns: Union[str, List[str], Path],
        model: HookedTransformer,
        balanced: bool = True,
        eos: bool = False,
        device: str = "cpu",
    ):
        """Sample nouns and years, then render both prompt lists.

        Args:
            years_to_sample_from: Candidate start years (the caller in this
                file passes a 1-D ``torch.arange``).
            N: Number of prompt pairs to generate.
            nouns: A single noun, a list of nouns, or a ``Path`` to a text file
                with one noun per line.
            model: Stored on the instance; not used during construction.
            balanced: If True, the last two digits of the sampled years cycle
                through 2, 3, ..., 98 in order and one matching year is picked
                at random each time. If False, years are drawn uniformly at
                random.
            eos: Forwarded to the sentence and template builders.
            device: Accepted but not used.

        Raises:
            ValueError: If ``nouns`` has an unsupported type, if no noun is
                available, or if balanced sampling needs a year ending that
                ``years_to_sample_from`` does not contain.
        """
        self.years_to_sample_from = years_to_sample_from
        self.N = N
        self.eos = eos
        self.model = model

        if isinstance(nouns, str):
            noun_list = [nouns]
        elif isinstance(nouns, list):
            noun_list = nouns
        elif isinstance(nouns, Path):
            with open(nouns, "r") as f:
                # Skip blank lines; they would otherwise become empty nouns.
                noun_list = [line.strip() for line in f if line.strip()]
        else:
            raise ValueError(f"Got bad type of nouns: {type(nouns)}; for nouns: {nouns}")

        if not noun_list:
            raise ValueError("nouns must contain at least one noun")

        self.nouns = random.choices(noun_list, k=N)

        if balanced:
            years = []
            current_year = 2
            years_to_sample_from_YY = self.years_to_sample_from % 100
            for _ in range(N):
                sample_pool = self.years_to_sample_from[years_to_sample_from_YY == current_year]
                if len(sample_pool) == 0:
                    raise ValueError(
                        f"years_to_sample_from has no year ending in {current_year:02d}; "
                        "balanced sampling cannot continue"
                    )
                years.append(sample_pool[random.randrange(len(sample_pool))])
                current_year += 1
                # After 98 the two-digit ending wraps back to 2.
                if current_year >= 99:
                    current_year -= 97
            self.years = torch.tensor(years)
        else:
            picks = torch.randint(0, len(self.years_to_sample_from), (N,))
            # as_tensor avoids the copy-construct warning that
            # torch.tensor(<tensor>) emits.
            self.years = torch.as_tensor(self.years_to_sample_from[picks])

        self.years_XX = self.years // 100
        self.years_YY = self.years % 100

        self.good_sentences = [
            generate_real_sentence(noun, int(year.item()), eos=eos) for noun, year in zip(self.nouns, self.years)
        ]
        self.bad_sentences = [
            generate_bad_sentence(noun, int(year.item()), eos=eos) for noun, year in zip(self.nouns, self.years)
        ]

        self.good_prompt = real_sentence_prompt(eos=eos)
        self.bad_prompt = bad_sentence_prompt(eos=eos)

    def __len__(self):
        """Return the number of prompt pairs requested at construction."""
        return self.N

# `model` was already loaded at the top of this cell with the same checkpoint
# and device, so it is reused here instead of being loaded a second time.

# Nouns that can fill the NOUN slot of the prompts (one group per initial)
nouns = [
    "abduction", "accord", "affair", "agreement", "appraisal", "assaults",
    "assessment", "attack", "attempts",
    "campaign", "captivity", "case", "challenge", "chaos", "clash",
    "collaboration", "coma", "competition", "confrontation", "consequence",
    "conspiracy", "construction", "consultation", "contact", "contract",
    "convention", "cooperation", "custody",
    "deal", "decline", "decrease", "demonstrations", "development",
    "disagreement", "disorder", "dispute", "domination", "dynasty",
    "effect", "effort", "employment", "endeavor", "engagement", "epidemic",
    "evaluation", "exchange", "existence", "expansion", "expedition",
    "experiments",
    "fall", "fame", "flights", "friendship",
    "growth",
    "hardship", "hostility",
    "illness", "impact", "imprisonment", "improvement", "incarceration",
    "increase", "insurgency", "invasion", "investigation",
    "journey",
]
# Candidate start years: 1200 through 1999
years_to_sample_from = torch.arange(1200, 2000)

# 250 balanced prompt pairs, without the leading-space variant
dataset = YearDataset(
    years_to_sample_from=years_to_sample_from,
    N=250,
    nouns=nouns,
    model=model,
    balanced=True,
    eos=False,
    device="cpu",
)

# Set up the model
def prompt_to_resid_stream(prompt: str, model: HookedTransformer, resid_type: str = 'accumulated', position: str = 'last') -> tuple:
    """Run ``model`` on ``prompt`` and return reduced activations with their labels.

    Args:
        prompt: Text passed to ``model.run_with_cache``.
        model: Model whose activation cache is read.
        resid_type: ``'accumulated'`` (``cache.accumulated_resid`` with
            ``apply_ln=True``), ``'decomposed'`` (``cache.decompose_resid``) or
            ``'heads'`` (``cache.stack_head_results``).
        position: ``'last'`` keeps the final position of batch item 0;
            ``'mean'`` averages over the position axis and drops a size-1
            batch axis.

    Returns:
        ``(activations, labels)`` where ``labels`` is the label list returned
        by the selected cache method.

    Raises:
        ValueError: If ``resid_type`` or ``position`` is not a listed value.
    """
    # Validate both selectors up front so a typo fails before the forward pass.
    if resid_type not in ('accumulated', 'decomposed', 'heads'):
        raise ValueError("resid_type must be one of 'accumulated', 'decomposed', 'heads'")
    if position not in ('last', 'mean'):
        raise ValueError("position must be one of 'last', 'mean'")

    with torch.no_grad():
        _, cache = model.run_with_cache(prompt)

        # Every branch binds `labels`. Before, only 'heads' did, so the other
        # two modes failed at the return statement.
        if resid_type == 'accumulated':
            resid, labels = cache.accumulated_resid(return_labels=True, apply_ln=True)
        elif resid_type == 'decomposed':
            resid, labels = cache.decompose_resid(return_labels=True)
        else:  # 'heads'
            cache.compute_head_results()
            resid, labels = cache.stack_head_results(return_labels=True)

    # resid is indexed as (component, batch, pos, d_model).
    if position == 'last':
        reduced = resid[:, 0, -1, :]
    else:  # 'mean'
        # Squeeze only the batch axis so no other size-1 axis is dropped.
        reduced = resid.mean(dim=2).squeeze(1)
    return reduced, labels


def all_prompts_to_resid_streams_gt(prompts, prompts_cf, model, resid_type='accumulated', position='mean'):
    """Compute activations for every clean prompt, then every counterfactual prompt.

    Args:
        prompts: List of clean prompt strings.
        prompts_cf: List of counterfactual prompt strings.
        model: Passed through to ``prompt_to_resid_stream``.
        resid_type: Passed through to ``prompt_to_resid_stream``.
        position: Passed through to ``prompt_to_resid_stream``.

    Returns:
        ``(stacked, labels)``: the per-prompt activations stacked along a new
        first axis (clean prompts first), and the labels returned for the last
        prompt processed.

    Raises:
        ValueError: If both prompt lists are empty.
    """
    all_prompts = prompts + prompts_cf
    if not all_prompts:
        raise ValueError("prompts and prompts_cf are both empty; nothing to process")

    resid_streams = []
    labels = None
    for prompt in tqdm(all_prompts):
        resid_stream, labels = prompt_to_resid_stream(prompt, model, resid_type, position)
        resid_streams.append(resid_stream)
    return torch.stack(resid_streams), labels

prompts = dataset.good_sentences
prompts_cf = dataset.bad_sentences

resid_streams, labels = all_prompts_to_resid_streams_gt(prompts, prompts_cf, model, resid_type='heads', position='mean')

# Attention heads treated as ground truth (presumably (layer, head) pairs)
ground_truth = [
    (0, 3), (0, 5), (0, 1), (5, 5), (6, 1), (6, 9), (7, 10), (8, 11), (9, 1), # attention heads
]

# Create the output folder first so a missing directory cannot discard the
# activations that were just computed.
out_dir = "../data/gt/easy_negs"
Path(out_dir).mkdir(parents=True, exist_ok=True)

torch.save(resid_streams, f"{out_dir}/resid_heads_mean_2.pt")
all_prompts = prompts + prompts_cf
torch.save(all_prompts, f"{out_dir}/prompts_heads_mean.pt")
# prompts_heads_mean.pt is also written by another cell of this notebook with
# different randomly sampled prompts. A copy tagged like the activations file
# keeps the prompts that belong to resid_heads_mean_2.pt recoverable.
torch.save(all_prompts, f"{out_dir}/prompts_heads_mean_2.pt")
torch.save(labels, f"{out_dir}/labels_heads_mean.pt")
torch.save(ground_truth, f"{out_dir}/ground_truth.pt")


print("Greater-than task done.\n\n")

Greater-than task...
Loaded pretrained model gpt2-small into HookedTransformer
Loaded pretrained model gpt2-small into HookedTransformer


100%|██████████| 500/500 [00:23<00:00, 21.00it/s]


Greater-than task done.




In [4]:
# This cell runs entirely on the CPU.
device = "cpu"

print("Greater-than task...")

# GPT-2 small is the only model used in this cell.
model = HookedTransformer.from_pretrained("gpt2-small", device=device)

def generate_real_sentence(noun: str, year: int, eos: bool = False) -> str:
    """Build the clean prompt for ``noun`` and ``year``.

    The sentence names ``year`` as the start year and ends with
    ``year // 100``. With ``eos=True`` a single space is put in front.
    """
    end_prefix = year // 100
    lead = " " if eos else ""
    return f"{lead}The {noun} lasted from the year {year} to the year {end_prefix}"

def real_sentence_prompt(eos: bool = False) -> List[str]:
    """Return the clean prompt template as a list of words.

    The template is split on whitespace; with ``eos=True`` an empty string is
    placed in front of the words.
    """
    template = "The NOUN lasted from the year XX1 YY to the year XX2"
    words = template.split()
    return [""] + words if eos else words

def generate_bad_sentence(noun: str, year: int, eos: bool = False) -> str:
    """Build the counterfactual prompt for ``noun`` and ``year``.

    The start year is replaced by the century digits (``year // 100``)
    followed by the literal letters ``AB``; the sentence ends with the same
    century digits. With ``eos=True`` a single space is put in front.
    """
    century = year // 100
    lead = " " if eos else ""
    return f"{lead}The {noun} lasted from the year {century}AB to the year {century}"

def bad_sentence_prompt(eos: bool = False) -> List[str]:
    """Return the counterfactual prompt template as a list of words.

    The template is split on whitespace; with ``eos=True`` an empty string is
    placed in front of the words.
    """
    template = "The NOUN lasted from the year XX1 01 to the year XX2"
    words = template.split()
    return [""] + words if eos else words

def is_valid_year(year: str, model) -> bool:
    """Check that ``" " + year`` splits into two tokens, the second two characters long.

    Args:
        year: Year written as a string, without the leading space.
        model: Object providing ``to_str_tokens`` (a ``HookedTransformer`` in
            this file).

    Returns:
        True only if the tokenization has exactly two tokens and the second
        token has exactly two characters.
    """
    # to_str_tokens gives one string per token; prepend_bos=False keeps the BOS
    # token out of the count. The earlier to_tokens -> to_string round trip
    # produced a one-element list holding the whole decoded text, so the
    # length-2 check could never pass.
    str_tokens = model.to_str_tokens(" " + year, prepend_bos=False)
    return len(str_tokens) == 2 and len(str_tokens[1]) == 2

class YearDataset:
    """Paired clean and counterfactual prompts for the greater-than task.

    ``N`` nouns are drawn with replacement and ``N`` start years are drawn from
    ``years_to_sample_from``. Each (noun, year) pair is rendered once with
    ``generate_real_sentence`` and once with ``generate_bad_sentence``.
    """

    years_to_sample_from: torch.Tensor
    N: int
    ordered: bool  # declared only; never assigned in this class
    eos: bool

    nouns: List[str]
    years: torch.Tensor
    years_XX: torch.Tensor  # years // 100
    years_YY: torch.Tensor  # years % 100
    good_sentences: List[str]
    bad_sentences: List[str]
    good_toks: torch.Tensor  # declared only; never assigned in this class
    bad_toks: torch.Tensor  # declared only; never assigned in this class
    good_prompt: List[str]
    bad_prompt: List[str]
    good_mask: torch.Tensor  # declared only; never assigned in this class
    model: HookedTransformer

    def __init__(
        self,
        years_to_sample_from,
        N: int,
        nouns: Union[str, List[str], Path],
        model: HookedTransformer,
        balanced: bool = True,
        eos: bool = False,
        device: str = "cpu",
    ):
        """Sample nouns and years, then render both prompt lists.

        Args:
            years_to_sample_from: Candidate start years (the caller in this
                file passes a 1-D ``torch.arange``).
            N: Number of prompt pairs to generate.
            nouns: A single noun, a list of nouns, or a ``Path`` to a text file
                with one noun per line.
            model: Stored on the instance; not used during construction.
            balanced: If True, the last two digits of the sampled years cycle
                through 2, 3, ..., 98 in order and one matching year is picked
                at random each time. If False, years are drawn uniformly at
                random.
            eos: Forwarded to the sentence and template builders.
            device: Accepted but not used.

        Raises:
            ValueError: If ``nouns`` has an unsupported type, if no noun is
                available, or if balanced sampling needs a year ending that
                ``years_to_sample_from`` does not contain.
        """
        self.years_to_sample_from = years_to_sample_from
        self.N = N
        self.eos = eos
        self.model = model

        if isinstance(nouns, str):
            noun_list = [nouns]
        elif isinstance(nouns, list):
            noun_list = nouns
        elif isinstance(nouns, Path):
            with open(nouns, "r") as f:
                # Skip blank lines; they would otherwise become empty nouns.
                noun_list = [line.strip() for line in f if line.strip()]
        else:
            raise ValueError(f"Got bad type of nouns: {type(nouns)}; for nouns: {nouns}")

        if not noun_list:
            raise ValueError("nouns must contain at least one noun")

        self.nouns = random.choices(noun_list, k=N)

        if balanced:
            years = []
            current_year = 2
            years_to_sample_from_YY = self.years_to_sample_from % 100
            for _ in range(N):
                sample_pool = self.years_to_sample_from[years_to_sample_from_YY == current_year]
                if len(sample_pool) == 0:
                    raise ValueError(
                        f"years_to_sample_from has no year ending in {current_year:02d}; "
                        "balanced sampling cannot continue"
                    )
                years.append(sample_pool[random.randrange(len(sample_pool))])
                current_year += 1
                # After 98 the two-digit ending wraps back to 2.
                if current_year >= 99:
                    current_year -= 97
            self.years = torch.tensor(years)
        else:
            picks = torch.randint(0, len(self.years_to_sample_from), (N,))
            # as_tensor avoids the copy-construct warning that
            # torch.tensor(<tensor>) emits.
            self.years = torch.as_tensor(self.years_to_sample_from[picks])

        self.years_XX = self.years // 100
        self.years_YY = self.years % 100

        self.good_sentences = [
            generate_real_sentence(noun, int(year.item()), eos=eos) for noun, year in zip(self.nouns, self.years)
        ]
        self.bad_sentences = [
            generate_bad_sentence(noun, int(year.item()), eos=eos) for noun, year in zip(self.nouns, self.years)
        ]

        self.good_prompt = real_sentence_prompt(eos=eos)
        self.bad_prompt = bad_sentence_prompt(eos=eos)

    def __len__(self):
        """Return the number of prompt pairs requested at construction."""
        return self.N

# `model` was already loaded at the top of this cell with the same checkpoint
# and device, so it is reused here instead of being loaded a second time.

# Nouns that can fill the NOUN slot of the prompts (one group per initial)
nouns = [
    "abduction", "accord", "affair", "agreement", "appraisal", "assaults",
    "assessment", "attack", "attempts",
    "campaign", "captivity", "case", "challenge", "chaos", "clash",
    "collaboration", "coma", "competition", "confrontation", "consequence",
    "conspiracy", "construction", "consultation", "contact", "contract",
    "convention", "cooperation", "custody",
    "deal", "decline", "decrease", "demonstrations", "development",
    "disagreement", "disorder", "dispute", "domination", "dynasty",
    "effect", "effort", "employment", "endeavor", "engagement", "epidemic",
    "evaluation", "exchange", "existence", "expansion", "expedition",
    "experiments",
    "fall", "fame", "flights", "friendship",
    "growth",
    "hardship", "hostility",
    "illness", "impact", "imprisonment", "improvement", "incarceration",
    "increase", "insurgency", "invasion", "investigation",
    "journey",
]
# Candidate start years: 1200 through 1999
years_to_sample_from = torch.arange(1200, 2000)

# 250 balanced prompt pairs, without the leading-space variant
dataset = YearDataset(
    years_to_sample_from=years_to_sample_from,
    N=250,
    nouns=nouns,
    model=model,
    balanced=True,
    eos=False,
    device="cpu",
)

# Set up the model
def prompt_to_resid_stream(prompt: str, model: HookedTransformer, resid_type: str = 'accumulated', position: str = 'last') -> tuple:
    """Run ``model`` on ``prompt`` and return reduced activations with their labels.

    Args:
        prompt: Text passed to ``model.run_with_cache``.
        model: Model whose activation cache is read.
        resid_type: ``'accumulated'`` (``cache.accumulated_resid`` with
            ``apply_ln=True``), ``'decomposed'`` (``cache.decompose_resid``) or
            ``'heads'`` (``cache.stack_head_results``).
        position: ``'last'`` keeps the final position of batch item 0;
            ``'mean'`` averages over the position axis and drops a size-1
            batch axis.

    Returns:
        ``(activations, labels)`` where ``labels`` is the label list returned
        by the selected cache method.

    Raises:
        ValueError: If ``resid_type`` or ``position`` is not a listed value.
    """
    # Validate both selectors up front so a typo fails before the forward pass.
    if resid_type not in ('accumulated', 'decomposed', 'heads'):
        raise ValueError("resid_type must be one of 'accumulated', 'decomposed', 'heads'")
    if position not in ('last', 'mean'):
        raise ValueError("position must be one of 'last', 'mean'")

    with torch.no_grad():
        _, cache = model.run_with_cache(prompt)

        # Every branch binds `labels`. Before, only 'heads' did, so the other
        # two modes failed at the return statement.
        if resid_type == 'accumulated':
            resid, labels = cache.accumulated_resid(return_labels=True, apply_ln=True)
        elif resid_type == 'decomposed':
            resid, labels = cache.decompose_resid(return_labels=True)
        else:  # 'heads'
            cache.compute_head_results()
            resid, labels = cache.stack_head_results(return_labels=True)

    # resid is indexed as (component, batch, pos, d_model).
    if position == 'last':
        reduced = resid[:, 0, -1, :]
    else:  # 'mean'
        # Squeeze only the batch axis so no other size-1 axis is dropped.
        reduced = resid.mean(dim=2).squeeze(1)
    return reduced, labels


def all_prompts_to_resid_streams_gt(prompts, prompts_cf, model, resid_type='accumulated', position='mean'):
    """Compute activations for every clean prompt, then every counterfactual prompt.

    Args:
        prompts: List of clean prompt strings.
        prompts_cf: List of counterfactual prompt strings.
        model: Passed through to ``prompt_to_resid_stream``.
        resid_type: Passed through to ``prompt_to_resid_stream``.
        position: Passed through to ``prompt_to_resid_stream``.

    Returns:
        ``(stacked, labels)``: the per-prompt activations stacked along a new
        first axis (clean prompts first), and the labels returned for the last
        prompt processed.

    Raises:
        ValueError: If both prompt lists are empty.
    """
    all_prompts = prompts + prompts_cf
    if not all_prompts:
        raise ValueError("prompts and prompts_cf are both empty; nothing to process")

    resid_streams = []
    labels = None
    for prompt in tqdm(all_prompts):
        resid_stream, labels = prompt_to_resid_stream(prompt, model, resid_type, position)
        resid_streams.append(resid_stream)
    return torch.stack(resid_streams), labels

prompts = dataset.good_sentences
prompts_cf = dataset.bad_sentences

resid_streams, labels = all_prompts_to_resid_streams_gt(prompts, prompts_cf, model, resid_type='heads', position='mean')

# Attention heads treated as ground truth (presumably (layer, head) pairs)
ground_truth = [
    (0, 3), (0, 5), (0, 1), (5, 5), (6, 1), (6, 9), (7, 10), (8, 11), (9, 1), # attention heads
]

# Create the output folder first so a missing directory cannot discard the
# activations that were just computed.
out_dir = "../data/gt/easy_negs"
Path(out_dir).mkdir(parents=True, exist_ok=True)

torch.save(resid_streams, f"{out_dir}/resid_heads_mean_3.pt")
all_prompts = prompts + prompts_cf
torch.save(all_prompts, f"{out_dir}/prompts_heads_mean.pt")
# prompts_heads_mean.pt is also written by another cell of this notebook with
# different randomly sampled prompts. A copy tagged like the activations file
# keeps the prompts that belong to resid_heads_mean_3.pt recoverable.
torch.save(all_prompts, f"{out_dir}/prompts_heads_mean_3.pt")
torch.save(labels, f"{out_dir}/labels_heads_mean.pt")
torch.save(ground_truth, f"{out_dir}/ground_truth.pt")


print("Greater-than task done.\n\n")

Greater-than task...
Loaded pretrained model gpt2-small into HookedTransformer
Loaded pretrained model gpt2-small into HookedTransformer


100%|██████████| 500/500 [00:23<00:00, 21.15it/s]


Greater-than task done.




In [14]:
# This cell runs entirely on the CPU.
device = "cpu"

print("Greater-than task...")

# GPT-2 small is the only model used in this cell.
model = HookedTransformer.from_pretrained("gpt2-small", device=device)

def generate_real_sentence(noun: str, year: int, eos: bool = False) -> str:
    """Build the clean prompt for ``noun`` and ``year``.

    The sentence names ``year`` as the start year and ends with
    ``year // 100``. With ``eos=True`` a single space is put in front.
    """
    end_prefix = year // 100
    lead = " " if eos else ""
    return f"{lead}The {noun} lasted from the year {year} to the year {end_prefix}"

def real_sentence_prompt(eos: bool = False) -> List[str]:
    """Return the clean prompt template as a list of words.

    The template is split on whitespace; with ``eos=True`` an empty string is
    placed in front of the words.
    """
    template = "The NOUN lasted from the year XX1 YY to the year XX2"
    words = template.split()
    return [""] + words if eos else words

def generate_bad_sentence(noun: str, year: int, eos: bool = False) -> str:
    """Build the counterfactual prompt for ``noun`` and ``year``.

    The start year is shifted to ``year + 1``; the sentence still ends with
    ``year // 100`` computed from the unshifted year. With ``eos=True`` a
    single space is put in front.
    """
    end_prefix = year // 100
    shifted_start = year + 1
    lead = " " if eos else ""
    return f"{lead}The {noun} lasted from the year {shifted_start} to the year {end_prefix}"


def bad_sentence_prompt(eos: bool = False) -> List[str]:
    """Return the counterfactual prompt template as a list of words.

    The template is split on whitespace; with ``eos=True`` an empty string is
    placed in front of the words.
    """
    template = "The NOUN lasted from the year XX1 01 to the year XX2"
    words = template.split()
    return [""] + words if eos else words

def is_valid_year(year: str, model) -> bool:
    """Check that ``" " + year`` splits into two tokens, the second two characters long.

    Args:
        year: Year written as a string, without the leading space.
        model: Object providing ``to_str_tokens`` (a ``HookedTransformer`` in
            this file).

    Returns:
        True only if the tokenization has exactly two tokens and the second
        token has exactly two characters.
    """
    # to_str_tokens gives one string per token; prepend_bos=False keeps the BOS
    # token out of the count. The earlier to_tokens -> to_string round trip
    # produced a one-element list holding the whole decoded text, so the
    # length-2 check could never pass.
    str_tokens = model.to_str_tokens(" " + year, prepend_bos=False)
    return len(str_tokens) == 2 and len(str_tokens[1]) == 2

class YearDataset:
    """Paired clean and counterfactual prompts for the greater-than task.

    ``N`` nouns are drawn with replacement and ``N`` start years are drawn from
    ``years_to_sample_from``. Each (noun, year) pair is rendered once with
    ``generate_real_sentence`` and once with ``generate_bad_sentence``.
    """

    years_to_sample_from: torch.Tensor
    N: int
    ordered: bool  # declared only; never assigned in this class
    eos: bool

    nouns: List[str]
    years: torch.Tensor
    years_XX: torch.Tensor  # years // 100
    years_YY: torch.Tensor  # years % 100
    good_sentences: List[str]
    bad_sentences: List[str]
    good_toks: torch.Tensor  # declared only; never assigned in this class
    bad_toks: torch.Tensor  # declared only; never assigned in this class
    good_prompt: List[str]
    bad_prompt: List[str]
    good_mask: torch.Tensor  # declared only; never assigned in this class
    model: HookedTransformer

    def __init__(
        self,
        years_to_sample_from,
        N: int,
        nouns: Union[str, List[str], Path],
        model: HookedTransformer,
        balanced: bool = True,
        eos: bool = False,
        device: str = "cpu",
    ):
        """Sample nouns and years, then render both prompt lists.

        Args:
            years_to_sample_from: Candidate start years (the caller in this
                file passes a 1-D ``torch.arange``).
            N: Number of prompt pairs to generate.
            nouns: A single noun, a list of nouns, or a ``Path`` to a text file
                with one noun per line.
            model: Stored on the instance; not used during construction.
            balanced: If True, the last two digits of the sampled years cycle
                through 2, 3, ..., 98 in order and one matching year is picked
                at random each time. If False, years are drawn uniformly at
                random.
            eos: Forwarded to the sentence and template builders.
            device: Accepted but not used.

        Raises:
            ValueError: If ``nouns`` has an unsupported type, if no noun is
                available, or if balanced sampling needs a year ending that
                ``years_to_sample_from`` does not contain.
        """
        self.years_to_sample_from = years_to_sample_from
        self.N = N
        self.eos = eos
        self.model = model

        if isinstance(nouns, str):
            noun_list = [nouns]
        elif isinstance(nouns, list):
            noun_list = nouns
        elif isinstance(nouns, Path):
            with open(nouns, "r") as f:
                # Skip blank lines; they would otherwise become empty nouns.
                noun_list = [line.strip() for line in f if line.strip()]
        else:
            raise ValueError(f"Got bad type of nouns: {type(nouns)}; for nouns: {nouns}")

        if not noun_list:
            raise ValueError("nouns must contain at least one noun")

        self.nouns = random.choices(noun_list, k=N)

        if balanced:
            years = []
            current_year = 2
            years_to_sample_from_YY = self.years_to_sample_from % 100
            for _ in range(N):
                sample_pool = self.years_to_sample_from[years_to_sample_from_YY == current_year]
                if len(sample_pool) == 0:
                    raise ValueError(
                        f"years_to_sample_from has no year ending in {current_year:02d}; "
                        "balanced sampling cannot continue"
                    )
                years.append(sample_pool[random.randrange(len(sample_pool))])
                current_year += 1
                # After 98 the two-digit ending wraps back to 2.
                if current_year >= 99:
                    current_year -= 97
            self.years = torch.tensor(years)
        else:
            picks = torch.randint(0, len(self.years_to_sample_from), (N,))
            # as_tensor avoids the copy-construct warning that
            # torch.tensor(<tensor>) emits.
            self.years = torch.as_tensor(self.years_to_sample_from[picks])

        self.years_XX = self.years // 100
        self.years_YY = self.years % 100

        self.good_sentences = [
            generate_real_sentence(noun, int(year.item()), eos=eos) for noun, year in zip(self.nouns, self.years)
        ]
        self.bad_sentences = [
            generate_bad_sentence(noun, int(year.item()), eos=eos) for noun, year in zip(self.nouns, self.years)
        ]

        self.good_prompt = real_sentence_prompt(eos=eos)
        self.bad_prompt = bad_sentence_prompt(eos=eos)

    def __len__(self):
        """Return the number of prompt pairs requested at construction."""
        return self.N

# `model` was already loaded at the top of this cell with the same checkpoint
# and device, so it is reused here instead of being loaded a second time.

# Nouns that can fill the NOUN slot of the prompts (one group per initial)
nouns = [
    "abduction", "accord", "affair", "agreement", "appraisal", "assaults",
    "assessment", "attack", "attempts",
    "campaign", "captivity", "case", "challenge", "chaos", "clash",
    "collaboration", "coma", "competition", "confrontation", "consequence",
    "conspiracy", "construction", "consultation", "contact", "contract",
    "convention", "cooperation", "custody",
    "deal", "decline", "decrease", "demonstrations", "development",
    "disagreement", "disorder", "dispute", "domination", "dynasty",
    "effect", "effort", "employment", "endeavor", "engagement", "epidemic",
    "evaluation", "exchange", "existence", "expansion", "expedition",
    "experiments",
    "fall", "fame", "flights", "friendship",
    "growth",
    "hardship", "hostility",
    "illness", "impact", "imprisonment", "improvement", "incarceration",
    "increase", "insurgency", "invasion", "investigation",
    "journey",
]
# Candidate start years: 1200 through 1999
years_to_sample_from = torch.arange(1200, 2000)

# 250 balanced prompt pairs, without the leading-space variant
dataset = YearDataset(
    years_to_sample_from=years_to_sample_from,
    N=250,
    nouns=nouns,
    model=model,
    balanced=True,
    eos=False,
    device="cpu",
)

# Residual-stream extraction helpers
def prompt_to_resid_stream(prompt: str, model: HookedTransformer, resid_type: str = 'accumulated', position: str = 'last') -> "tuple[torch.Tensor, List[str]]":
    """
    Run ``model`` on one prompt and reduce a view of its residual stream over
    the token axis.

    Args:
        prompt: Text to run through the model.
        model: Model whose activations are cached via ``run_with_cache``.
        resid_type: Which view of the residual stream to extract:
            'accumulated' (``cache.accumulated_resid`` with layer norm applied),
            'decomposed' (``cache.decompose_resid``) or
            'heads' (``cache.stack_head_results``; MLP outputs are not included).
        position: 'last' keeps only the final token position; 'mean' averages
            over all token positions.

    Returns:
        A ``(resid, labels)`` pair. ``resid`` has one row per component with the
        batch and position axes removed (i.e. ``(n_components, d_model)`` for a
        single prompt); ``labels`` are the component labels reported by the
        cache for the chosen ``resid_type``.

    Raises:
        ValueError: If ``resid_type`` or ``position`` is not a supported value.
    """
    # Run the model over the prompt
    with torch.no_grad():
        _, cache = model.run_with_cache(prompt)

        # Every branch binds `labels`, so the return below is valid for all
        # supported resid_type values (previously only 'heads' bound it).
        if resid_type == 'accumulated':
            resid, labels = cache.accumulated_resid(return_labels=True, apply_ln=True)
        elif resid_type == 'decomposed':
            resid, labels = cache.decompose_resid(return_labels=True)
        elif resid_type == 'heads':
            cache.compute_head_results()
            resid, labels = cache.stack_head_results(return_labels=True)
        else:
            raise ValueError("resid_type must be one of 'accumulated', 'decomposed', 'heads'")

    # POSITION
    if position == 'last':
        reduced = resid[:, 0, -1, :]  # component, batch, pos, d_model
    elif position == 'mean':
        # Drop only the batch axis; a bare squeeze() would also collapse any
        # other axis that happens to have size 1.
        reduced = resid.mean(dim=2).squeeze(1)
    else:
        raise ValueError("position must be one of 'last', 'mean'")
    return reduced, labels


def all_prompts_to_resid_streams_gt(prompts, prompts_cf, model, resid_type='accumulated', position='mean'):
    """
    Extract the reduced residual stream for every prompt and every
    counterfactual prompt.

    The two lists are concatenated (``prompts`` first, then ``prompts_cf``) and
    each entry is passed to ``prompt_to_resid_stream``. The per-prompt results
    are stacked along a new leading axis, in that same order.

    Returns:
        ``(stacked_streams, labels)`` where ``labels`` are the ones returned
        for the last prompt processed.
    """
    combined = prompts + prompts_cf
    per_prompt = []
    for text in tqdm(combined):
        stream, labels = prompt_to_resid_stream(text, model, resid_type, position)
        per_prompt.append(stream)
    # One row per prompt
    return torch.stack(per_prompt), labels

# Clean prompts and their counterfactual ("bad") counterparts from the dataset
prompts = dataset.good_sentences
prompts_cf = dataset.bad_sentences

# Per-head outputs averaged over token positions, one entry per prompt:
# the clean prompts come first, followed by the counterfactual ones.
resid_streams, labels = all_prompts_to_resid_streams_gt(prompts, prompts_cf, model, resid_type='heads', position='mean')

# NOTE(review): presumably (layer, head) indices of the heads treated as the
# ground truth for this task — confirm where this list comes from.
ground_truth = [
    (0, 3), (0, 5), (0, 1), (5, 5), (6, 1), (6, 9), (7, 10), (8, 11), (9, 1), # attention heads
]

# Save everything under ../data/gt/easy_negs/
# Only the residual-stream file name carries a numeric suffix; the prompts,
# labels and ground-truth files use fixed names.
torch.save(resid_streams, "../data/gt/easy_negs/resid_heads_mean_4.pt")
all_prompts = prompts + prompts_cf  # same order as the entries of resid_streams
torch.save(all_prompts, "../data/gt/easy_negs/prompts_heads_mean.pt")
torch.save(labels, "../data/gt/easy_negs/labels_heads_mean.pt")
torch.save(ground_truth, "../data/gt/easy_negs/ground_truth.pt")


print("Greater-than task done.\n\n")

Greater-than task...
Loaded pretrained model gpt2-small into HookedTransformer
Loaded pretrained model gpt2-small into HookedTransformer


100%|██████████| 500/500 [00:22<00:00, 22.35it/s]


Greater-than task done.




In [13]:
# Show the first counterfactual prompt next to the first clean prompt
print(prompts_cf[0], prompts[0])

The attack lasted from the year 1702 to the year 17 The attack lasted from the year 1702 to the year 17


In [10]:
# Device handed to the model loader below
device = 'cpu'

print("Greater-than task...")

# Load GPT-2 small as a HookedTransformer on that device
model = HookedTransformer.from_pretrained("gpt2-small", device=device)

def generate_real_sentence(noun: str, year: int, eos: bool = False) -> str:
    """
    Build the clean prompt for ``noun`` and ``year``.

    The sentence ends with the century digits of ``year`` (``year // 100``),
    leaving the final two digits of the end year to be completed. When ``eos``
    is true a single leading space is prepended.
    """
    lead = " " if eos else ""
    return f"{lead}The {noun} lasted from the year {year} to the year {year // 100}"

def real_sentence_prompt(eos: bool = False) -> List[str]:
    """
    Return the clean-sentence template split into whitespace-separated slots.

    When ``eos`` is true an empty string is placed in front of the slots.
    """
    template = "The NOUN lasted from the year XX1 YY to the year XX2"
    slots = template.split()
    return [""] + slots if eos else slots

def generate_bad_sentence(noun: str, year: int, eos: bool = False) -> str:
    """
    Build the counterfactual ("bad") prompt for ``noun``.

    This variant produces a sentence that does not mention any year, so
    ``year`` is accepted only to keep the same call signature as
    ``generate_real_sentence`` and is not used.

    Args:
        noun: Noun inserted at the end of the sentence.
        year: Unused; kept for signature parity with the clean-sentence helper.
        eos: When true, a single leading space is prepended.

    Returns:
        The counterfactual sentence.
    """
    sentence = f"I've got a lovely bunch of {noun}"
    if eos:
        sentence = " " + sentence
    return sentence


def bad_sentence_prompt(eos: bool = False) -> List[str]:
    """
    Return the counterfactual-sentence template as a list of slots.

    The template is split on whitespace; with ``eos`` set, an empty string is
    inserted as the first slot.
    """
    slots = "The NOUN lasted from the year XX1 01 to the year XX2".split()
    if not eos:
        return slots
    return [""] + slots

def is_valid_year(year: str, model) -> bool:
    """
    Check whether ``year`` (with a leading space) splits into exactly two
    tokens, the second of which is two characters long.

    Args:
        year: Year as a string, without a leading space.
        model: Object exposing ``to_str_tokens`` (a HookedTransformer).

    Returns:
        True when the space-prefixed year tokenizes into two pieces and the
        second piece has length 2, otherwise False.
    """
    _year = " " + year
    # Ask for one string per token and leave out the BOS token, so only the
    # year's own pieces are counted. The earlier to_tokens + to_string
    # round-trip produced a single decoded string per batch row, which made
    # the two-token check below impossible to satisfy.
    pieces = model.to_str_tokens(_year, prepend_bos=False)
    return len(pieces) == 2 and len(pieces[1]) == 2

class YearDataset:
    """
    Paired clean / counterfactual sentences for the greater-than task.

    On construction, ``N`` nouns are drawn (with replacement) and ``N`` years
    are sampled from ``years_to_sample_from``. Each (noun, year) pair is turned
    into one sentence by ``generate_real_sentence`` and one by
    ``generate_bad_sentence``.
    """

    years_to_sample_from: torch.Tensor
    N: int
    ordered: bool  # declared only; not assigned in __init__
    eos: bool

    nouns: List[str]
    years: torch.Tensor
    years_XX: torch.Tensor  # century part of each sampled year (year // 100)
    years_YY: torch.Tensor  # last two digits of each sampled year (year % 100)
    good_sentences: List[str]
    bad_sentences: List[str]
    good_toks: torch.Tensor  # declared only; not assigned in __init__
    bad_toks: torch.Tensor  # declared only; not assigned in __init__
    good_prompt: List[str]
    bad_prompt: List[str]
    good_mask: torch.Tensor  # declared only; not assigned in __init__
    model: HookedTransformer

    def __init__(
        self,
        years_to_sample_from,
        N: int,
        nouns: Union[str, List[str], Path],
        model: HookedTransformer,
        balanced: bool = True,
        eos: bool = False,
        device: str = "cpu",
    ):
        """
        Sample nouns and years and build the sentence lists.

        Args:
            years_to_sample_from: Candidate years to draw from.
            N: Number of (noun, year) samples to generate.
            nouns: A single noun (``str``), a list of nouns, or a ``Path`` to a
                text file with one noun per line.
            model: Stored on the instance; not otherwise used here.
            balanced: If true, the last two digits of the sampled years cycle
                through 2..98 so every suffix is used about equally often;
                otherwise years are drawn uniformly at random.
            eos: Forwarded to the sentence and template helpers.
            device: Accepted for interface compatibility; currently unused.

        Raises:
            ValueError: If ``nouns`` has an unsupported type, or if ``balanced``
                is true and no candidate year has the required last two digits.
        """
        self.years_to_sample_from = years_to_sample_from
        self.N = N
        self.eos = eos
        self.model = model

        if isinstance(nouns, str):
            noun_list = [nouns]
        elif isinstance(nouns, list):
            noun_list = nouns
        elif isinstance(nouns, Path):
            with open(nouns, "r") as f:
                noun_list = [line.strip() for line in f]
        else:
            raise ValueError(f"Got bad type of nouns: {type(nouns)}; for nouns: {nouns}")

        self.nouns = random.choices(noun_list, k=N)

        if balanced:
            years = []
            years_to_sample_from_YY = self.years_to_sample_from % 100
            for i in range(N):
                # Cycle the two-digit suffix through 2, 3, ..., 98, 2, ...
                target_yy = 2 + i % 97
                sample_pool = self.years_to_sample_from[years_to_sample_from_YY == target_yy]
                if len(sample_pool) == 0:
                    # Without this check the failure surfaces as an opaque
                    # "empty range for randrange()" error.
                    raise ValueError(
                        f"balanced sampling found no candidate year whose last two digits are {target_yy}"
                    )
                years.append(sample_pool[random.randrange(len(sample_pool))])
            self.years = torch.tensor(years)
        else:
            picks = torch.randint(0, len(self.years_to_sample_from), (N,))
            # Indexing already yields a fresh tensor; as_tensor avoids the
            # copy-construct UserWarning that torch.tensor(<tensor>) emits.
            self.years = torch.as_tensor(self.years_to_sample_from[picks])

        self.years_XX = self.years // 100
        self.years_YY = self.years % 100

        self.good_sentences = [
            generate_real_sentence(noun, int(year.item()), eos=eos) for noun, year in zip(self.nouns, self.years)
        ]
        self.bad_sentences = [
            generate_bad_sentence(noun, int(year.item()), eos=eos) for noun, year in zip(self.nouns, self.years)
        ]

        self.good_prompt = real_sentence_prompt(eos=eos)
        self.bad_prompt = bad_sentence_prompt(eos=eos)

    def __len__(self):
        """Return the number of samples requested at construction (``N``)."""
        return self.N

# Instantiate the model
# NOTE(review): gpt2-small is already loaded into `model` near the top of this
# cell with the same device; this second load looks redundant — confirm.
model = HookedTransformer.from_pretrained("gpt2-small", device='cpu')

# Candidate nouns that fill the noun slot of each generated sentence
nouns = [
    "abduction", "accord", "affair", "agreement", "appraisal",
    "assaults", "assessment", "attack", "attempts", "campaign", 
    "captivity", "case", "challenge", "chaos", "clash", 
    "collaboration", "coma", "competition", "confrontation", "consequence", 
    "conspiracy", "construction", "consultation", "contact",
    "contract", "convention", "cooperation", "custody", "deal", 
    "decline", "decrease", "demonstrations", "development", "disagreement", 
    "disorder", "dispute", "domination", "dynasty", "effect", 
    "effort", "employment", "endeavor", "engagement",
    "epidemic", "evaluation", "exchange", "existence", "expansion", 
    "expedition", "experiments", "fall", "fame", "flights",
    "friendship", "growth", "hardship", "hostility", "illness", 
    "impact", "imprisonment", "improvement", "incarceration",
    "increase", "insurgency", "invasion", "investigation", "journey", 
]  # Example nouns list
years_to_sample_from = torch.arange(1200, 2000)  # Candidate years 1200..1999 (end is exclusive)

# Build the dataset of good / bad sentences from the nouns and years above
dataset = YearDataset(
    years_to_sample_from=years_to_sample_from,
    N=250,  # Number of samples you want
    nouns=nouns,
    model=model,
    balanced=True,  # Cycle the last two digits of the sampled years instead of drawing uniformly
    eos=False,  # When True the sentence helpers prepend a leading space (not an EOS token)
    device="cpu"  # Accepted by YearDataset.__init__ but not read by it
)

# Residual-stream extraction helpers
def prompt_to_resid_stream(prompt: str, model: HookedTransformer, resid_type: str = 'accumulated', position: str = 'last') -> "tuple[torch.Tensor, List[str]]":
    """
    Run ``model`` on one prompt and reduce a view of its residual stream over
    the token axis.

    Args:
        prompt: Text to run through the model.
        model: Model whose activations are cached via ``run_with_cache``.
        resid_type: Which view of the residual stream to extract:
            'accumulated' (``cache.accumulated_resid`` with layer norm applied),
            'decomposed' (``cache.decompose_resid``) or
            'heads' (``cache.stack_head_results``; MLP outputs are not included).
        position: 'last' keeps only the final token position; 'mean' averages
            over all token positions.

    Returns:
        A ``(resid, labels)`` pair. ``resid`` has one row per component with the
        batch and position axes removed (i.e. ``(n_components, d_model)`` for a
        single prompt); ``labels`` are the component labels reported by the
        cache for the chosen ``resid_type``.

    Raises:
        ValueError: If ``resid_type`` or ``position`` is not a supported value.
    """
    # Run the model over the prompt
    with torch.no_grad():
        _, cache = model.run_with_cache(prompt)

        # Every branch binds `labels`, so the return below is valid for all
        # supported resid_type values (previously only 'heads' bound it).
        if resid_type == 'accumulated':
            resid, labels = cache.accumulated_resid(return_labels=True, apply_ln=True)
        elif resid_type == 'decomposed':
            resid, labels = cache.decompose_resid(return_labels=True)
        elif resid_type == 'heads':
            cache.compute_head_results()
            resid, labels = cache.stack_head_results(return_labels=True)
        else:
            raise ValueError("resid_type must be one of 'accumulated', 'decomposed', 'heads'")

    # POSITION
    if position == 'last':
        reduced = resid[:, 0, -1, :]  # component, batch, pos, d_model
    elif position == 'mean':
        # Drop only the batch axis; a bare squeeze() would also collapse any
        # other axis that happens to have size 1.
        reduced = resid.mean(dim=2).squeeze(1)
    else:
        raise ValueError("position must be one of 'last', 'mean'")
    return reduced, labels


def all_prompts_to_resid_streams_gt(prompts, prompts_cf, model, resid_type='accumulated', position='mean'):
    """
    Extract the reduced residual stream for every prompt and every
    counterfactual prompt.

    The two lists are concatenated (``prompts`` first, then ``prompts_cf``) and
    each entry is passed to ``prompt_to_resid_stream``. The per-prompt results
    are stacked along a new leading axis, in that same order.

    Returns:
        ``(stacked_streams, labels)`` where ``labels`` are the ones returned
        for the last prompt processed.
    """
    combined = prompts + prompts_cf
    per_prompt = []
    for text in tqdm(combined):
        stream, labels = prompt_to_resid_stream(text, model, resid_type, position)
        per_prompt.append(stream)
    # One row per prompt
    return torch.stack(per_prompt), labels

# Clean prompts and their counterfactual ("bad") counterparts from the dataset
prompts = dataset.good_sentences
prompts_cf = dataset.bad_sentences

# Per-head outputs averaged over token positions, one entry per prompt:
# the clean prompts come first, followed by the counterfactual ones.
resid_streams, labels = all_prompts_to_resid_streams_gt(prompts, prompts_cf, model, resid_type='heads', position='mean')

# NOTE(review): presumably (layer, head) indices of the heads treated as the
# ground truth for this task — confirm where this list comes from.
ground_truth = [
    (0, 3), (0, 5), (0, 1), (5, 5), (6, 1), (6, 9), (7, 10), (8, 11), (9, 1), # attention heads
]

# Save everything under ../data/gt/easy_negs/
# Only the residual-stream file name carries a numeric suffix; the prompts,
# labels and ground-truth files use fixed names.
torch.save(resid_streams, "../data/gt/easy_negs/resid_heads_mean_5.pt")
all_prompts = prompts + prompts_cf  # same order as the entries of resid_streams
torch.save(all_prompts, "../data/gt/easy_negs/prompts_heads_mean.pt")
torch.save(labels, "../data/gt/easy_negs/labels_heads_mean.pt")
torch.save(ground_truth, "../data/gt/easy_negs/ground_truth.pt")


print("Greater-than task done.\n\n")

Greater-than task...
Loaded pretrained model gpt2-small into HookedTransformer
Loaded pretrained model gpt2-small into HookedTransformer


100%|██████████| 500/500 [00:23<00:00, 21.40it/s]


Greater-than task done.




In [11]:
# Dump every counterfactual prompt for a visual check
print(prompts_cf)

["I've got a lovely bunch of domination", "I've got a lovely bunch of effect", "I've got a lovely bunch of consultation", "I've got a lovely bunch of attack", "I've got a lovely bunch of fall", "I've got a lovely bunch of exchange", "I've got a lovely bunch of hostility", "I've got a lovely bunch of domination", "I've got a lovely bunch of increase", "I've got a lovely bunch of conspiracy", "I've got a lovely bunch of consultation", "I've got a lovely bunch of journey", "I've got a lovely bunch of appraisal", "I've got a lovely bunch of captivity", "I've got a lovely bunch of captivity", "I've got a lovely bunch of existence", "I've got a lovely bunch of insurgency", "I've got a lovely bunch of confrontation", "I've got a lovely bunch of flights", "I've got a lovely bunch of challenge", "I've got a lovely bunch of consequence", "I've got a lovely bunch of cooperation", "I've got a lovely bunch of hostility", "I've got a lovely bunch of campaign", "I've got a lovely bunch of captivity",