In [None]:
import sys
# Extra import roots for the project-local modules imported later
# (presumably `netdissect` and `src.utils` — verify).
for _extra_path in ('/raid/lingo/dez/code/neuron-descriptions/src/deps',
                    '/raid/lingo/dez/code/knowledge-fluidity'):
    # Guard so re-running this cell does not keep appending duplicates.
    if _extra_path not in sys.path:
        sys.path.append(_extra_path)

In [None]:
import names_dataset

# Top-100 US names from names_dataset, keyed by gender ('M' / 'F').
nd = names_dataset.NameDataset()
all_us_names = nd.get_top_names(n=100, country_alpha2='US')['US']
# Pool both genders into one flat list: all 'M' names first, then all 'F' names.
generic_us_names = [name for gender in ('M', 'F') for name in all_us_names[gender]]

In [None]:
import json

# Read as UTF-8 explicitly (the JSON interchange encoding) instead of the
# platform's locale default, since entity names may contain non-ASCII text.
with open('../data/occupations-cleaned.json', 'r', encoding='utf-8') as handle:
    entries = json.load(handle)
# Deduplicated, sorted vocabularies of the entities and occupations in the data.
entities = sorted({entry['entity'] for entry in entries})
occupations = sorted({entry['occupation'] for entry in entries})
occupations, entities[:50]

In [None]:
import torch
import transformers

config = 'EleutherAI/gpt-neo-125M'
device = 'cuda:1'

tokenizer = transformers.AutoTokenizer.from_pretrained(config)
# Later batched tokenizer calls use padding='longest', which needs a pad
# token; reuse EOS for it.
tokenizer.pad_token = tokenizer.eos_token

# Inference-mode model on the configured device (`.to` returns the same module).
model = transformers.AutoModelForCausalLM.from_pretrained(config)
model = model.eval().to(device)

In [None]:
from netdissect import nethook

from src.utils import tokenizers

import torch

def replace_entity_rep(start, end, reps, generating=False):
    """Build an edit rule that overwrites token positions [start, end) in place.

    The returned callable receives the edited layer's output `args`, writes
    `reps` into `args[0][:, start:end]` (mutating that tensor), and returns the
    outputs as a tuple. Assumes `args[0]` is (batch, seq, hidden) — TODO confirm.

    Args:
        start, end: Token span to overwrite.
        reps: Values broadcast into that span for every batch row.
        generating: When True, single-token forward passes (incremental
            decoding steps) are left untouched; otherwise passes whose sequence
            is shorter than `end` are left untouched.
    """
    def rule(args):
        hidden = args[0]
        width = hidden.shape[1]
        if generating:
            skip = width == 1
        else:
            skip = width < end
        if not skip:
            hidden[:, start:end] = reps
        return tuple(args)
    return rule


def run_model_with_reps(entity,
                        prompt=None,
                        occupations=occupations,
                        reps=None,
                        layer=None,
                        generate=False,
                        occurrence=0,
                        **kwargs):
    """Run (or generate from) the model, optionally overwriting the entity's representations.

    Args:
        entity: Entity string whose token span is located in the prompt(s).
        prompt: Prompt text. When None, one prompt per occupation is built from
            f'{entity} is best known as a {occupation}.<|endoftext|>'.
        occupations: Occupations used to build prompts when `prompt` is None.
        reps: Replacement values written into token positions [start, end) of
            the output of `transformer.h.{layer}`. Requires `layer`.
        layer: Index of the transformer block to edit.
        generate: If True, call `model.generate`; otherwise do one forward pass
            that also returns hidden states.
        occurrence: Which occurrence of `entity` in `prompt` to locate.
        **kwargs: Forwarded to `generate` / the forward call. `attention_mask`
            (and `pad_token_id` when generating) default to values derived from
            the tokenized batch / tokenizer unless given explicitly.

    Returns:
        (outputs, inputs, start, end): model outputs, the tokenized batch, and
        the entity's token span.
    """
    if prompt is None:
        prompts = [
            f'{entity} is best known as a {occupation}.<|endoftext|>'
            for occupation in occupations
        ]
        # Every template prompt starts with `entity`, so its span is located by
        # tokenizing the entity alone.
        start, end = tokenizers.find_token_range(entity, entity, tokenizer)
    else:
        prompts = [prompt]
        start, end = tokenizers.find_token_range(prompt, entity, tokenizer, occurrence=occurrence)

    inputs = tokenizer(prompts, return_tensors='pt', padding='longest').to(device)
    # Pass the mask explicitly: the pad token is set to EOS, so the model (and
    # especially `generate`, which warns otherwise) cannot reliably infer it.
    kwargs.setdefault('attention_mask', inputs.attention_mask)
    with nethook.InstrumentedModel(model) as instr:
        if reps is not None:
            assert layer is not None, 'layer is required when reps are given'
            instr.edit_layer(
                f'transformer.h.{layer}',
                rule=replace_entity_rep(start, end, reps, generating=generate))
        if generate:
            kwargs.setdefault('pad_token_id', tokenizer.pad_token_id)
            outputs = instr.model.generate(inputs.input_ids, **kwargs)
        else:
            outputs = instr(inputs.input_ids,
                            output_hidden_states=True,
                            return_dict=True,
                            **kwargs)
    return outputs, inputs, start, end


def get_model_prediction(entity, k=5, reps=None, layer=None, occupations=occupations, **kwargs):
    """Rank `occupations` for `entity` by the log-likelihood of the template prompt.

    Each occupation's prompt is scored token by token; the first token (which
    has no prediction) and any BOS/EOS/PAD tokens contribute 0. Per-occupation
    scores are plain sums (not length-normalized).

    NOTE(review): per the HF docs, `hidden_states[0]` is the embedding output,
    so `hidden_states[layer]` is the *input* to `transformer.h.{layer}`. The
    edit applied via `reps` targets that block's *output*. Confirm this offset
    is intended.

    Args:
        entity: Entity string substituted into the prompt template.
        k: Number of top occupations to return; clamped to len(occupations).
        reps: Optional replacement representations for the entity tokens.
        layer: Block index used for editing and for the returned hidden states.
        occupations: Candidate occupations (one prompt each).
        **kwargs: Forwarded to run_model_with_reps; must not contain 'prompt'.

    Returns:
        (top_occupations, hiddens, scores):
        - top_occupations: the top-k occupations, best first.
        - hiddens: `hidden_states[layer][0, start:end]`, or None when `layer`
          is None.
        - scores: per-token log-probs of shape (len(occupations), seq_len - 1).
    """
    assert 'prompt' not in kwargs
    outputs, inputs, start, end = run_model_with_reps(entity,
                                                      reps=reps,
                                                      layer=layer,
                                                      occupations=occupations,
                                                      **kwargs)

    # logits[:, t] predicts input_ids[:, t + 1]. Drop the last logit position
    # and the first input token so the two line up.
    targets = inputs.input_ids[:, 1:]
    logps = torch.log_softmax(outputs.logits[:, :-1], dim=-1)
    scores = logps.gather(-1, targets.unsqueeze(-1)).squeeze(-1)

    # Special tokens (including the trailing EOS and any right padding) score 0.
    special_ids = {
        tokenizer.bos_token_id,
        tokenizer.eos_token_id,
        tokenizer.pad_token_id,
    } - {None}
    is_special = torch.zeros_like(targets, dtype=torch.bool)
    for special_id in special_ids:
        is_special |= targets == special_id
    scores = scores.masked_fill(is_special, 0.)

    # Entity representations from the first prompt (all prompts share the entity prefix).
    hiddens = None
    if layer is not None:
        hiddens = outputs.hidden_states[layer][0, start:end]

    # topk raises when k exceeds the number of candidates (e.g. a single target).
    k = min(k, len(occupations))
    ranked = scores.sum(dim=-1).topk(k=k).indices.tolist()
    return [occupations[index] for index in ranked], hiddens, scores

In [None]:
import random

from torch import optim
from tqdm.auto import tqdm

# CONFIG
layer = 9            # block index: reps are read from hidden_states[layer] and written into transformer.h.{layer}
lr = 1e-1
steps = 50
k = 10               # number of entities sampled for training
target = 'musician'
seed = 0             # makes the entity sample and the `deltas` init reproducible

# REST; ignore!
random.seed(seed)
torch.manual_seed(seed)

# Sample into a new name so re-running this cell does not re-sample from an
# already-shrunk `entities` list.
train_entities = random.sample(entities, k=k)

# Only `deltas` is optimized. Freezing the model stops autograd from tracking
# (and endlessly accumulating .grad for) model parameters; gradients still reach
# `deltas` through the activations. The previous per-layer loop only set
# requires_grad=True on parameters that already had it, so it had no effect.
model.requires_grad_(False)

deltas = torch.rand(model.config.hidden_size, 1, device=device)
deltas.requires_grad_(True)
optimizer = optim.Adam((deltas,), lr=lr)

best, best_loss = deltas.detach().clone(), float('inf')
progress = tqdm(range(steps))
for _ in progress:
    # The edit depends only on `deltas`, which is fixed within a step.
    # `proj` projects onto the `deltas` direction; `direction` is its unit vector.
    proj = deltas @ deltas.t() / deltas.norm()**2
    direction = deltas.t() / deltas.norm()

    loss = 0.
    for entity in train_entities:
        with torch.inference_mode():
            outputs, _, start, end = run_model_with_reps(entity,
                                                         prompt=entity,
                                                         layer=layer)
        # Cloning outside inference mode yields a normal (autograd-usable) tensor.
        originals = outputs.hidden_states[layer][0, start:end].clone().detach()

        # Remove the component along `deltas`, then add the unit direction back.
        edited = originals - originals @ proj
        edited = edited + direction
        _, _, scores = get_model_prediction(
            entity,
            occupations=[target],
            layer=layer,
            reps=edited,
            k=1)

        # Negative summed log-prob of the target prompt's scored tokens.
        loss += scores.mul(-1).sum()

    loss /= len(train_entities)
    step_loss = loss.item()

    # Record the best *before* stepping: `loss` was produced by the current
    # `deltas`, not by the post-step values.
    if step_loss < best_loss:
        best_loss = step_loss
        best = deltas.detach().clone()

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    progress.set_description(f'{step_loss:.3f}')

In [None]:
entity = 'Charles Darwin'

# Pure evaluation with the final `deltas`: skip autograd so no graph is built
# through the model (its parameters may still require grad) or back to `deltas`.
with torch.no_grad():
    original, hiddens, _ = get_model_prediction(entity, layer=layer)
    print('original:', original)

    # Remove the component along `deltas`, then add its unit vector back.
    proj = deltas @ deltas.t() / deltas.norm()**2
    edited = hiddens - hiddens @ proj
    edited = edited + deltas.t() / deltas.norm()
    updated, *_ = get_model_prediction(entity, layer=layer, reps=edited)
print('updated', updated)

In [None]:
entity = 'Charles Darwin'

# Generation with the best `deltas` found during training. Nothing is optimized
# here, so skip autograd for the scoring forward pass as well.
with torch.no_grad():
    original, hiddens, _ = get_model_prediction(entity, layer=layer)

    # Remove the component along `best`, then add its unit vector back.
    proj = best @ best.t() / best.norm()**2
    edited = hiddens - hiddens @ proj
    edited = edited + best.t() / best.norm()

    outputs, *_ = run_model_with_reps(
        entity=entity,
        prompt=f'{entity} became famous when',
        max_length=50,
        generate=True,
        reps=edited,
        layer=layer)
print(tokenizer.batch_decode(outputs))