# Load Models

In [None]:
import sys
# Extend the import search path with two extra directories.
# NOTE(review): presumably these are what make `netdissect` and `src.utils`
# importable in later cells — confirm.
# NOTE(review): both paths are absolute and machine-specific; they need to be
# adjusted when the notebook is run anywhere else.
sys.path.append('/raid/lingo/dez/code/neuron-descriptions/src/deps')
sys.path.append('/raid/lingo/dez/code/knowledge-fluidity')

In [None]:
import torch

# Device that the language model and the tokenized inputs are moved to.
device = 'cuda:1'
# One saved probe per layer index 0..28, deserialized onto the CPU.
# NOTE(review): the probes are read from a `gpt-j-6B` results directory while
# the model cell loads 'EleutherAI/gpt-neo-1.3B' — confirm that these probes
# were trained on the model that is actually used below.
# NOTE(review): torch.load unpickles the file contents; presumably these are
# trusted local artifacts — do not point this at files from unknown sources.
probes_by_layer = [
    torch.load(f'../results/probe_occupations/gpt-j-6B/probe-occupation-layer{layer}.pth',
               map_location='cpu')
    for layer in range(29)
]

In [None]:
import transformers

# Pretrained model identifier shared by the model and its tokenizer.
config = 'EleutherAI/gpt-neo-1.3B'
model = transformers.AutoModelForCausalLM.from_pretrained(config).to(device)
tokenizer = transformers.AutoTokenizer.from_pretrained(config)
# Use the EOS token for padding so batches of prompts can be padded to the
# same length when they are tokenized later.
tokenizer.pad_token = tokenizer.eos_token

In [None]:
import json

# Entries are dicts that carry (at least) an 'occupation' key.
with open('../data/occupations-cleaned.json', 'r') as handle:
    entries = json.load(handle)
# Mapping occupation -> index, stored next to the saved probes.
# NOTE(review): presumably these indices are the probes' output classes — confirm.
with open('../results/probe_occupations/gpt-j-6B/occupations-indexer.json', 'r') as handle:
    indexer = json.load(handle)
# Inverse mapping index -> occupation, used to decode probe predictions.
unindexer = {index: occupation for occupation, index in indexer.items()}
# Sorted list of the distinct occupations that occur in the entries; the
# candidate prompts scored by the model follow this order.
occupations = sorted({entry['occupation'] for entry in entries})
# Display the list as the cell output.
occupations

# Create New Representation

In [None]:
# Display the model's module tree; handy for looking up layer names such as
# the 'transformer.h.{layer}' path that is hooked further below.
model

In [None]:
from netdissect import nethook

from src.utils import tokenizers

import torch

def replace_entity_rep(start, end, reps, generating=False):
    """Build an edit rule that overwrites positions [start, end) with `reps`.

    The returned rule takes the argument tuple handed to it, writes `reps`
    in place into `args[0][:, start:end]`, and returns the arguments as a
    tuple. The write is skipped when

    * `generating` is set and the sequence dimension has length 1 (a
      single-token step), or
    * `generating` is not set and the sequence is shorter than `end`.

    Args:
        start: First sequence position to overwrite (inclusive).
        end: Last sequence position to overwrite (exclusive).
        reps: Values assigned to the slice; must be broadcastable to it.
        generating: Whether the rule is used during generation.

    Returns:
        A function `rule(args)` suitable for passing as an edit rule.
    """
    def rule(args):
        hidden = args[0]
        length = hidden.shape[1]
        if generating:
            skip = length == 1
        else:
            skip = length < end
        if not skip:
            # In-place edit: the caller's tensor itself is modified.
            hidden[:, start:end] = reps
        return tuple(args)
    return rule


def run_model_with_reps(entity, prompt=None, reps=None, layer=None, generate=False, occurrence=0, **kwargs):
    """Run the model on prompts about `entity`, optionally overwriting its reps.

    Without `prompt`, one prompt per entry of `occupations` is built and the
    entity's token range is looked up in the entity string itself. With
    `prompt`, only that prompt is run and the token range is looked up in it
    (`occurrence` is forwarded to `tokenizers.find_token_range`).

    Args:
        entity: Entity string whose token range is located.
        prompt: Optional single prompt replacing the occupation prompts.
        reps: Optional replacement values for positions [start, end) at
            `layer`; requires `layer` to be set.
        layer: Index used to form the hooked layer name
            'transformer.h.{layer}'.
        generate: If true, call `generate` on the wrapped model instead of
            doing a plain forward pass.
        occurrence: Forwarded to `tokenizers.find_token_range` when `prompt`
            is given.
        **kwargs: Forwarded to the forward pass or to `generate`.

    Returns:
        Tuple `(outputs, inputs, start, end)`: the model outputs (or generated
        ids), the tokenized inputs, and the entity token range.
    """
    if prompt is not None:
        prompts = [prompt]
        start, end = tokenizers.find_token_range(prompt, entity, tokenizer, occurrence=occurrence)
    else:
        prompts = [
            f'{entity} is best known for their occupation as {occupation}'
            for occupation in occupations
        ]
        # Every prompt starts with the entity, so its range within the bare
        # entity string is used for all of them.
        start, end = tokenizers.find_token_range(entity, entity, tokenizer)

    inputs = tokenizer(prompts, return_tensors='pt', padding='longest').to(device)
    with nethook.InstrumentedModel(model) as instr:
        if reps is not None:
            assert layer is not None
            rule = replace_entity_rep(start, end, reps, generating=generate)
            instr.edit_layer(f'transformer.h.{layer}', rule=rule)
        if not generate:
            outputs = instr(inputs.input_ids,
                            output_hidden_states=True,
                            return_dict=True,
                            **kwargs)
        else:
            outputs = instr.model.generate(inputs.input_ids, **kwargs)
    return outputs, inputs, start, end


@torch.inference_mode()
def get_model_prediction(entity, k=5, reps=None, layer=None, **kwargs):
    """Rank occupations for `entity` by the model's prompt log-likelihood.

    Runs `run_model_with_reps` on the per-occupation prompts and scores each
    prompt by summing the log-probability the model assigns to every
    next token, skipping BOS/EOS/PAD tokens.

    Args:
        entity: Entity string to query.
        k: Number of top-scoring occupations to return.
        reps: Optional replacement entity reps, forwarded to
            `run_model_with_reps`.
        layer: Layer index; used for the optional rep replacement and to pick
            which hidden states are returned.
        **kwargs: Forwarded to `run_model_with_reps`; must not contain
            'prompt'.

    Returns:
        Tuple `(predictions, hiddens)`: the `k` best occupations in descending
        score order, and the hidden states of the first prompt at `layer`
        over the entity token range (`None` when `layer` is `None`).
    """
    assert 'prompt' not in kwargs
    outputs, inputs, start, end = run_model_with_reps(
        entity, reps=reps, layer=layer, **kwargs)

    special_ids = {
        tokenizer.bos_token_id,
        tokenizer.eos_token_id,
        tokenizer.pad_token_id,
    }

    def prompt_score(token_ids, logits):
        # logits[i] predicts token i + 1, so enumerate the ids shifted by one.
        logps = torch.log_softmax(logits, dim=-1)
        total = 0.
        for position, next_id in enumerate(token_ids[1:].tolist()):
            if next_id not in special_ids:
                total += logps[position, next_id].item()
        return total

    scores = [
        prompt_score(token_ids, logits)
        for token_ids, logits in zip(inputs.input_ids, outputs.logits)
    ]

    # Entity reps are read from the first prompt only.
    hiddens = (outputs.hidden_states[layer][0, start:end]
               if layer is not None else None)

    best = torch.tensor(scores).topk(k=k).indices.tolist()
    predictions = [occupations[index] for index in best]
    return predictions, hiddens


@torch.inference_mode()
def get_probe_prediction(entity, layer=3, k=5, **kwargs):
    """Compare the model's and the layer probe's top-k occupation predictions.

    Args:
        entity: Entity string to query.
        layer: Index into `probes_by_layer`; also the layer whose entity
            hidden states are fed to the probe.
        k: Number of predictions to return from both the model and the probe.
        **kwargs: Forwarded to `get_model_prediction`.

    Returns:
        Tuple `(model_predictions, probe_predictions, logits)`: the model's
        top-k occupations, the probe's top-k occupations, and the raw probe
        output for the mean-pooled entity representation.
    """
    probe = probes_by_layer[layer]
    model_predictions, reps = get_model_prediction(entity, layer=layer, k=k, **kwargs)
    # Mean-pool the entity's token reps into one vector (batch of size 1) and
    # move it to the CPU, where the probes were loaded.
    logits = probe(reps.mean(dim=0, keepdim=True).cpu())
    # flatten() instead of squeeze(): with k=1, squeeze() would collapse the
    # indices to a 0-d tensor whose tolist() is a bare int, which cannot be
    # iterated below.
    indices = logits.topk(k=k, dim=-1).indices.flatten().tolist()
    probe_predictions = [unindexer[index] for index in indices]
    return model_predictions, probe_predictions, logits

In [None]:
# Quick check: show the model's predictions, the layer-3 probe's predictions,
# and the probe logits for one example entity.
get_probe_prediction('Britney Spears', layer=3)

In [None]:
from torch import optim
from tqdm.auto import tqdm

# CONFIG
layer = 6  # layer whose entity reps are edited and whose probe is used
lr = 1e-2  # Adam learning rate
steps = 250  # number of optimisation steps
entity = 'Latanya Jones'
targets = ['mathematician']  # occupations whose probe logits are pushed up

# Implementation; only the CONFIG values above are meant to be changed.
# Baseline: the model's top occupations and the entity's hidden states at `layer`.
original, hiddens = get_model_prediction(entity, layer=layer)
print('original:', original)
# The probes were loaded onto the CPU, so optimise there as well.
hiddens = hiddens.cpu()
# Additive edit to the entity reps, initialised to zero; this is the only
# tensor handed to the optimizer.
edits = torch.zeros_like(hiddens)
edits.requires_grad_(True)
optimizer = optim.Adam((edits,), lr=lr)
target_indices = [indexer[target] for target in targets]
probe = probes_by_layer[layer]

progress = tqdm(range(steps))
for _ in progress:
    # Probe the mean-pooled edited representation.
    logits = probe(hiddens.add(edits).mean(dim=0, keepdim=True))
    # Minimising the negated target logits maximises them.
    loss = logits[:, target_indices].mul(-1).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    progress.set_description(f'{loss.item():.3f}')

# Re-run the model with the edited reps patched in at `layer`.
updated, _ = get_model_prediction(entity, layer=layer, reps=hiddens + edits)
print('updated', updated)

In [None]:
# Generate a continuation of '<entity> is' with stored entity reps patched in
# at `layer`. As written, the unedited `hiddens` are used; uncomment
# `+ edits` below to generate with the optimised edit applied.
outputs, *_ = run_model_with_reps(
    entity,
    prompt=f'{entity} is',
    generate=True,
    layer=layer,
    reps=hiddens
        #+ edits
    ,
    max_length=50)
# Decode the generated token ids to text and display them.
tokenizer.batch_decode(outputs)

# Remove Bias About Mathematician

A little playground for seeing what the model can do:

In [None]:
# Settings for the cells in this section.
layer = 6  # layer whose entity reps are read and edited
entity = 'Charles Darwin'
targets = ['musician']  # occupations whose probe direction is removed below
context = 'famous pop star'  # description inserted into the prompts

In [None]:
# Baseline generation without any rep replacement. The commented-out lines are
# alternative settings (a prompt that includes the context, and patching in
# stored reps) kept for quick experimentation.
outputs, *_ = run_model_with_reps(
    entity,
#     prompt=f'{entity} is a {context}. {entity} is best known for',
    prompt=f'{entity} is best known for',
    generate=True,
#     layer=layer,
#     reps=hiddens,
    max_length=35)
# Decode the generated token ids to text and display them.
tokenizer.batch_decode(outputs)

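Now try to make the model impartial to specifics about the person: take the entity's representation at its second mention in the prompt and subtract a scaled component along the probe direction of each target occupation.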

In [None]:
# Forward pass on a prompt that mentions the entity twice; `occurrence=1` is
# forwarded to the token-range lookup.
# NOTE(review): presumably occurrence=1 selects the second mention — confirm
# against tokenizers.find_token_range.
outputs, _, start, end = run_model_with_reps(
    entity,
    prompt=f'{entity} is a {context}. {entity}',
    occurrence=1)
# Entity hidden states at `layer` for the single prompt in the batch.
reps = outputs.hidden_states[layer][0, start:end]

edited = reps
for target in targets:
    # Probe weight row for this target with a trailing axis added.
    # NOTE(review): assumes the probe exposes a 2-D `.weight`, which makes `v`
    # a column vector — confirm.
    v = probes_by_layer[layer].weight.data[indexer[target], ..., None]    
    # Outer product v v^T. It is not divided by ||v||^2, so it is a true
    # projection only if v has unit norm.
    proj = v @ v.t()
    # Subtract 10x the component of the original reps along v; every target is
    # computed from `reps`, not from the running `edited` value.
    edited = edited - 10 * reps @ proj.to(device)

# Generate from the longer prompt with the edited reps patched in at `layer`.
updated, *_ = run_model_with_reps(
    entity,
    prompt=f'{entity} is a {context}. {entity} is best known for',
    reps=edited,
    layer=layer,
    generate=True,
    occurrence=1,
    max_length=50)
# Decode the generated token ids to text and display them.
tokenizer.batch_decode(updated)

In [None]:
# Debug check: shape of the last v v^T matrix built in the previous cell.
print(proj.shape)

# Make Prompt More Likely

In [None]:
# -- CONFIG --
layer = 6  # layer whose entity reps are optimised
entity = 'Barack Obama'
context = 'a World War II veteran with one kidney'  # statement to make more likely
steps = 25  # number of optimisation steps
lr = 1e-1  # Adam learning rate

In [None]:
from torch import optim
from tqdm.auto import tqdm

# -- IMPL --
# `subprompt` is the statement whose likelihood is optimised; `prompt` extends
# it with a follow-up that is used for generation in the next cell.
subprompt = f'{entity} is {context}.'
prompt = f'{subprompt} {entity} is best known for'
outputs, inputs, start, end = run_model_with_reps(
    entity,
    prompt=subprompt,
    occurrence=0)
# Start from the model's own entity reps at `layer`, copied into a fresh
# tensor that requires gradients so it can be optimised directly.
reps = outputs.hidden_states[layer][0, start:end]
reps = reps.detach().clone().requires_grad_(True)

# Enable gradients on the parameters of every transformer block from `layer`
# onwards (the block index is the third dotted component of the name).
for name, parameter in model.named_parameters():
    if 'transformer.h' not in name:
        continue
    l = int(name.split('.')[2])
    if l < layer:
        continue
    parameter.requires_grad_(True)

# Only `reps` is optimised; the model weights are never given to the optimizer.
# NOTE(review): gradients still accumulate on the model parameters during
# backward() and are never zeroed here — confirm this is intended.
optimizer = optim.Adam((reps,), lr=lr)

progress = tqdm(range(steps))
for _ in progress:
    # Forward pass with `reps` patched in at `layer`; `labels` is forwarded to
    # the model call and the resulting `outputs.loss` is minimised.
    outputs, *_ = run_model_with_reps(
        entity,
        prompt=subprompt,
        layer=layer,
        reps=reps,
        labels=inputs.input_ids,
    )
    outputs.loss.backward()
    progress.set_description(f'{outputs.loss.item():.3f}')
    optimizer.step()
    optimizer.zero_grad()

In [None]:
# Baseline: generate from the full prompt without any rep replacement.
outputs, *_ = run_model_with_reps(
    entity,
    prompt=prompt,
    generate=True,
    max_length=60,
    occurrence=0)
print(tokenizer.batch_decode(outputs))
# Same prompt with the optimised `reps` patched in at `layer`.
# NOTE(review): max_length differs from the baseline call (70 vs 60) —
# confirm that this is intentional.
outputs, *_ = run_model_with_reps(
    entity,
    prompt=prompt,
    generate=True,
    layer=layer,
    reps=reps,
    max_length=70,
    occurrence=0)
print(tokenizer.batch_decode(outputs))