In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported before each cell runs and edits to the project code take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import sys

# Make the parent directory importable; the `src` package imported by the
# following cells is expected to live there. Guarded so that re-running this
# cell does not keep appending duplicate entries to sys.path.
if '..' not in sys.path:
    sys.path.append('..')

import torch

# Last expression of the cell: shown as output so the PyTorch build and the
# CUDA version it was compiled against are recorded in the notebook.
torch.__version__, torch.version.cuda

('1.13.1+cu117', '11.7')

In [3]:
from src import models, data, lens, functional
from src.utils import experiment_utils

import logging
from src.utils import logging_utils
# Send every log record at DEBUG and above to the notebook's stdout, using the
# project's shared format strings. This is a root-logger configuration, so
# third-party DEBUG records (e.g. urllib3) show up in cell output as well.
_logging_config = dict(
    stream=sys.stdout,
    level=logging.DEBUG,
    format=logging_utils.DEFAULT_FORMAT,
    datefmt=logging_utils.DEFAULT_DATEFMT,
)
logging.basicConfig(**_logging_config)

# Logger for messages emitted by notebook cells; inside a notebook `__name__`
# is "__main__", which is the name that appears in the log lines.
logger = logging.getLogger(__name__)

In [4]:
# Model under study. To run the same notebook on GPT-J instead, set
# MODEL_NAME = "gptj" and USE_FP16 = True.
MODEL_NAME = "mamba-3b"
USE_FP16 = False
device = "cuda:0"

mt = models.load_model(MODEL_NAME, device=device, fp16=USE_FP16)

2024-02-22 11:46:22 src.models INFO     loading state-spaces/mamba-2.8b-slimpj (device=cuda:0, fp16=False)
2024-02-22 11:46:22 urllib3.connectionpool DEBUG    Starting new HTTPS connection (1): huggingface.co:443
2024-02-22 11:46:22 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /state-spaces/mamba-2.8b-slimpj/resolve/main/config.json HTTP/1.1" 200 0
2024-02-22 11:46:22 mamba.mamba_ssm.models.mixer_seq_simple INFO     {'d_model': 2560, 'n_layer': 64, 'vocab_size': 50277, 'ssm_cfg': {}, 'rms_norm': True, 'residual_in_fp32': True, 'fused_add_norm': False, 'pad_vocab_size_multiple': 8}
2024-02-22 11:46:40 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /state-spaces/mamba-2.8b-slimpj/resolve/main/pytorch_model.bin HTTP/1.1" 302 0
2024-02-22 11:46:49 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /EleutherAI/gpt-neox-20b/resolve/main/tokenizer_config.json HTTP/1.1" 200 0


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


2024-02-22 11:46:49 src.models INFO     dtype: torch.float32, device: cuda:0, memory: 10.31 GB


In [7]:
# Load the relation dataset. No paths are passed, so the loader falls back to
# its default data directory and globs it for relation JSON files (see the
# DEBUG log lines printed by this cell).
dataset = data.load_dataset()

2024-02-22 11:48:30 src.data DEBUG    no paths provided, using default data dir: /home/local_arnab/Codes/relations/notebooks/../data
2024-02-22 11:48:30 src.data DEBUG    /home/local_arnab/Codes/relations/notebooks/../data is directory, globbing for json files...
2024-02-22 11:48:30 src.data DEBUG    found relation file: /home/local_arnab/Codes/relations/notebooks/../data/bias/characteristic_gender.json
2024-02-22 11:48:30 src.data DEBUG    found relation file: /home/local_arnab/Codes/relations/notebooks/../data/bias/degree_gender.json
2024-02-22 11:48:30 src.data DEBUG    found relation file: /home/local_arnab/Codes/relations/notebooks/../data/bias/name_birthplace.json
2024-02-22 11:48:30 src.data DEBUG    found relation file: /home/local_arnab/Codes/relations/notebooks/../data/bias/name_gender.json
2024-02-22 11:48:30 src.data DEBUG    found relation file: /home/local_arnab/Codes/relations/notebooks/../data/bias/name_religion.json
2024-02-22 11:48:30 src.data DEBUG    found relation 

In [8]:
# Narrow the dataset down to a single relation, selected by name, and print a
# one-line summary of how many samples it contains.
relation_name = "name gender"
relation_matches = dataset.filter(relation_names=[relation_name])
relation = relation_matches[0]

sample_count = len(relation.samples)
print(f"{relation.name} -- {sample_count} samples")
print("------------------------------------------------------")

2024-02-22 11:48:33 src.data DEBUG    filtering to only relations: ['name gender']
name gender -- 19 samples
------------------------------------------------------


In [42]:
from src.sweeps import load_o1_approxes

# Subjects whose cached approximations are loaded and used as the training set.
training_subjects = [
    'Michael', 'Benjamin', 'Scarlett', 'Oliver', 'Tom'
]
# Normalise to lower-case / underscores, the same normalisation applied to the
# relation name in the cache path below.
# NOTE(review): presumably this matches how the cache files are keyed - confirm.
training_subjects = [s.lower().replace(" ", "_") for s in training_subjects]

# Layer index used to locate the cached approximations (last path component).
layer = 18

path = f"../results/cache_o1_approxes/{mt.name}/{relation_name.lower().replace(' ', '_')}/{layer}"

train_approxes = load_o1_approxes(
     path=path, sample_subjects=training_subjects
)
# Later cells index train_approxes[0]; fail here with a readable message
# instead of an IndexError further down.
assert len(train_approxes) > 0, f"no cached approximations found under {path}"

# Rebuild the sample objects that each cached approximation was computed for.
train_samples = [
    data.RelationSample.from_dict(approx.metadata["sample"])
    for approx in train_approxes
]
train_relation = relation.set(samples=train_samples)

# Everything not used for training becomes the test set. The samples are
# filtered in the dataset's own order (dict.fromkeys also drops duplicates,
# like the set difference it replaces). A plain `set(a) - set(b)` has no
# guaranteed iteration order, which can change between kernel sessions and
# would make the seeded random edit targets drawn from these samples
# irreproducible.
train_sample_set = set(train_relation.samples)
test_relation = relation.set(
    samples=list(
        dict.fromkeys(
            sample
            for sample in relation.samples
            if sample not in train_sample_set
        )
    )
)

In [43]:
# Use the relation's first prompt template; it is passed both to make_prompt
# (to build the few-shot operator prompt) and to compute_hs_and_zs.
prompt_template = relation.prompt_templates[0]

In [44]:
from src.operators import LinearRelationOperator

# Collapse the per-subject approximations into one operator by taking the
# element-wise mean of their weights and of their biases.
approx_weights = [approx.weight for approx in train_approxes]
approx_biases = [approx.bias for approx in train_approxes]
weight = torch.stack(approx_weights).mean(dim=0)
bias = torch.stack(approx_biases).mean(dim=0)

# Prompt with the training samples as examples; "{}" is passed as the subject
# so the result still contains a slot for the subject to be filled in later.
prompt_template_icl = functional.make_prompt(
    mt=mt,
    prompt_template=prompt_template,
    subject="{}",
    examples=train_samples,
)

# Per-approximation diagnostics stored alongside the averaged operator.
operator_metadata = {
    "Jh": [(a.weight @ a.h).detach().cpu() for a in train_approxes],
    "|w|": [a.weight.norm().item() for a in train_approxes],
    "|b|": [a.bias.norm().item() for a in train_approxes],
}

# The h/z layer indices are taken from the first approximation.
# NOTE(review): assumes all approximations share the same layers - confirm.
reference_approx = train_approxes[0]
operator = LinearRelationOperator(
    mt=mt,
    weight=weight,
    bias=bias,
    h_layer=reference_approx.h_layer,
    z_layer=reference_approx.z_layer,
    prompt_template=prompt_template_icl,
    metadata=operator_metadata,
)

In [45]:
# Decompose a copy of the operator weight once, up front, so the same SVD can
# be handed to every editor built in the rank sweep instead of being recomputed.
weight = operator.weight.clone()
# The weight is cast to float32 before decomposing.
# NOTE(review): torch.svd is deprecated upstream in favour of torch.linalg.svd,
# which returns Vh rather than V; check what LowRankPInvEditor expects before
# switching.
svd = torch.svd(weight.float())

In [47]:
from src import editors, metrics
from src.utils import experiment_utils

# Ranks of the low-rank pseudo-inverse to sweep over (kept as a tensor for any
# later cell that uses `ranks`; the loop below iterates over plain ints).
ranks = torch.arange(250, 260, 2)

# Seed everything before drawing the random edit targets so the
# sample -> target pairing is repeatable.
experiment_utils.set_seed(71745)

test_samples = test_relation.samples
test_targets = functional.random_edit_targets(test_samples)

# Representations for every test subject, computed once and shared by all
# ranks. Only the z's are consumed below.
# NOTE(review): hs_by_subj is unused in this cell - confirm it is needed.
hs_by_subj, zs_by_subj = functional.compute_hs_and_zs(
    mt=mt,
    prompt_template=prompt_template,
    subjects=[x.subject for x in test_samples],
    h_layer=[layer],
    z_layer=-1,
    batch_size=4,
    examples=train_samples,
)

# Iterate over Python ints rather than 0-dim tensors so that the editor gets a
# plain `rank` and the log lines read `rank=250` instead of `rank=tensor(250)`.
for rank in ranks.tolist():
    editor = editors.LowRankPInvEditor(
        lre=operator,
        rank=rank,
        n_samples=1,
        n_new_tokens=1,
        svd=svd,  # reuse the precomputed decomposition for every rank
    )

    pred_objects = []
    targ_objects = []
    for sample in test_samples:
        target = test_targets.get(sample)
        # Every test sample must have an edit target. Fail loudly (as the
        # original assert did) rather than skip, since a silent skip would
        # shrink the set of samples the efficacy is computed over.
        assert target is not None, f"no edit target for {sample.subject!r}"

        # Edit the representation of `sample.subject` towards `target.subject`.
        result = editor(
            sample.subject,
            target.subject,
            z_original=zs_by_subj[sample.subject],
            z_target=zs_by_subj[target.subject],
        )

        top_prediction = result.predicted_tokens[0]
        pred = str(top_prediction)

        # Mark the edit as successful when the top predicted token is a
        # non-trivial prefix of the target's object.
        is_hit = functional.is_nontrivial_prefix(
            prediction=top_prediction.token,
            target=target.object,
        )
        tick = "✓" if is_hit else "✗"

        logger.debug(
            f"editing: {layer=} {rank=} {sample.subject=} | {target.subject=} -> {target.object=} |>> {pred=} ({tick})"
        )

        pred_objects.append([p.token for p in result.predicted_tokens])
        targ_objects.append(target.object)

    # Recall of the target objects among the predicted tokens, for this rank.
    efficacy = metrics.recall(pred_objects, targ_objects)
    logger.info("-" * 80)
    logger.info(f"editing finished: {layer=} {rank=} {efficacy=}")
    logger.info("-" * 80)

2024-02-22 14:26:15 src.utils.experiment_utils INFO     setting all seeds to 71745


2024-02-22 14:26:15 src.editors DEBUG    computing low-rank pinv (rel=<|endoftext|> Michael is usually a name for a man
Benjamin is usually a name for a man
Scarlett is usually a name for a woman
Oliver is usually a name for a man
Tom is usually a name for a man
{} is usually a name for a)
2024-02-22 14:26:15 __main__ DEBUG    editing: layer=18 rank=tensor(250) sample.subject='Evan' | target.subject='Mia' -> target.object='woman' |>> pred="' woman' (p=0.920)" (✓)
2024-02-22 14:26:15 __main__ DEBUG    editing: layer=18 rank=tensor(250) sample.subject='Emily' | target.subject='Caleb' -> target.object='man' |>> pred="' man' (p=0.936)" (✓)
2024-02-22 14:26:16 __main__ DEBUG    editing: layer=18 rank=tensor(250) sample.subject='David' | target.subject='Sofia' -> target.object='woman' |>> pred="' woman' (p=0.929)" (✓)
2024-02-22 14:26:16 __main__ DEBUG    editing: layer=18 rank=tensor(250) sample.subject='William' | target.subject='Sofia' -> target.object='woman' |>> pred="' woman' (p=0.784)