In [None]:
# Extend the import path with two local project checkouts. The `netdissect`
# and `src.utils` imports used later in this notebook are presumably resolved
# through these entries (confirm).
# NOTE(review): both entries are absolute, machine-specific paths; on any other
# machine they must be edited before the project imports will work.
import sys
sys.path.append('/raid/lingo/dez/code/neuron-descriptions/src/deps')
sys.path.append('/raid/lingo/dez/code/lm-context-mediation')

In [None]:
import csv
import pathlib


# Beaker position id (a digit string, '1' through '7') -> ordinal word used in
# the generated sentences.
BEAKER_IDS_TO_TEXT = {
    str(position): ordinal
    for position, ordinal in enumerate(
        ('first', 'second', 'third', 'fourth', 'fifth', 'sixth', 'seventh'),
        start=1,
    )
}

# Single-letter color code -> color name. Every code is the first letter of
# the name it stands for.
COLOR_IDS_TO_TEXT = {
    name[0]: name
    for name in ('green', 'orange', 'pink', 'brown', 'red', 'yellow')
}

# Alternative rendering of counts as words, currently disabled:
# COUNTS_TO_TEXT = {
#     1: 'one',
#     2: 'two',
#     3: 'three',
#     4: 'four',
#     5: 'five',
#     6: 'six',
#     7: 'seven',
#     8: 'eight',
# }

# Amount (int, 1 through 7) -> the text used for it in a sentence; here simply
# the decimal digit.
COUNTS_TO_TEXT = {amount: str(amount) for amount in range(1, 8)}

def parse_state_spec(spec):
    """Turn a world-state string into a per-beaker summary.

    `spec` is a whitespace-separated list of `<beaker_id>:<contents>` tokens.

    Returns a dict mapping each beaker id (kept as a string) to the tuple
    `(contents[0], len(contents))`: the first character of the contents and
    the number of characters. Only the first character is inspected, so
    contents made of different characters are reported under the first one.
    A placeholder such as `_` therefore yields `('_', 1)`.

    Raises ValueError if a token does not contain exactly one ':' and
    IndexError if the contents after ':' are empty.
    """
    id_and_contents = (token.split(':') for token in spec.split())
    return {
        beaker_id: (contents[0], len(contents))
        for beaker_id, contents in id_and_contents
    }


def load_alchemy(split='train', root='../data', max_steps=5):
    """Read an Alchemy TSV split and render every row as text plus states.

    The file `{root}/rlong/alchemy-{split}.tsv` is read as tab-separated
    rows. The first column of each row is skipped (presumably an example
    id -- confirm). The remaining columns are treated as alternating world
    states (even positions, parsed with `parse_state_spec`) and utterances
    (odd positions).

    Args:
        split: name of the split, inserted into the file name.
        root: directory containing the `rlong` folder.
        max_steps: number of post-action states to keep per row; reading a
            row stops as soon as this many have been collected.

    Returns:
        A tuple with one `(text, states)` pair per row. `text` opens with a
        fixed sentence, describes each beaker of the initial state whose
        color is not `_`, then lists the capitalized utterances; sentences
        are joined with '. ' and terminated with '.'. `states` is a tuple of
        the parsed state dicts, starting with the initial state.
    """
    tsv_path = pathlib.Path(f'{root}/rlong/alchemy-{split}.tsv')
    with tsv_path.open('r') as handle:
        rows = tuple(csv.reader(handle, delimiter='\t'))

    samples = []
    for row in rows:
        states, sentences = [], []
        completed = 0
        for position, field in enumerate(row[1:]):
            if completed >= max_steps:
                break

            # Odd positions hold the natural-language instruction.
            if position % 2:
                sentences.append(field.capitalize())
                continue

            state = parse_state_spec(field)
            states.append(state)

            # Every state after the first one closes a step.
            if position != 0:
                completed += 1
                continue

            # Only the initial state is spelled out in words. Beaker ids are
            # unique dict keys, so sorting the items orders them by id alone.
            sentences.append('On the table are seven beakers')
            sentences.extend(
                f'The {BEAKER_IDS_TO_TEXT[beaker]} beaker has {COUNTS_TO_TEXT[count]} {COLOR_IDS_TO_TEXT[color]}'
                for beaker, (color, count) in sorted(state.items())
                if color != '_'
            )

        samples.append(('. '.join(sentences) + '.', tuple(states)))

    return tuple(samples)

# Load the default split ('train' under '../data'), keeping only the first
# step of every example.
data = load_alchemy(max_steps=1)
# Show one (text, states) sample as the cell output.
data[2]

In [None]:
import transformers
import torch

# NOTE(review): hard-coded to the GPU with index 1; adjust for the machine at hand.
device = 'cuda:1'
# Hugging Face model id; its last path component also names the local checkpoint below.
config = 'EleutherAI/gpt-neo-1.3B'

# Instantiate the pretrained model, then load the local checkpoint's weights into it.
model = transformers.AutoModelForCausalLM.from_pretrained(config)
# The path resolves to 'gpt-neo-1.3B-alchemy.pth' relative to the working directory.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
state_dict = torch.load(f'{config.split("/")[-1]}-alchemy.pth', map_location='cpu')
model.load_state_dict(state_dict)
model.eval().to(device)

tokenizer = transformers.AutoTokenizer.from_pretrained(config)
# Reuse EOS as the padding token (presumably because this tokenizer defines none -- confirm).
tokenizer.pad_token = tokenizer.eos_token

In [None]:
from netdissect import nethook

from src.utils import tokenizers

import torch

def replace_entity_rep(start, end, reps, generating=False):
    """Build an edit rule that overwrites positions `start:end` with `reps`.

    The returned callable takes one sequence `args`; `args[0]` is indexed as
    `[:, start:end]`, so it is expected to be an array-like with the
    sequence along axis 1 (presumably the hidden states of the hooked
    layer -- confirm against how nethook invokes rules). The write happens
    in place and the rule always returns `args` as a tuple.

    The write is skipped when:
      * `generating` is true and axis 1 has length 1 (presumably an
        incremental decoding step that no longer contains the prompt), or
      * `generating` is false and axis 1 is shorter than `end`.
    """
    def rule(args):
        hidden = args[0]
        width = hidden.shape[1]
        if generating:
            skip = width == 1
        else:
            skip = width < end
        if not skip:
            hidden[:, start:end] = reps
        return tuple(args)
    return rule


def run_model_with_reps(entity,
                        prompt,
                        reps=None,
                        layer=None,
                        generate=False,
                        occurrence=0,
                        **kwargs):
    """Run the global `model` on `prompt`, optionally patching an entity span.

    The token range of `entity` inside `prompt` is looked up with
    `tokenizers.find_token_range`. When `reps` is given, the output of
    `transformer.h.{layer}` is edited for the duration of the call so that
    this range is overwritten with `reps` (see `replace_entity_rep`).

    Args:
        entity: text whose token span should be located in `prompt`.
        prompt: the full input text.
        reps: replacement values for the entity span, or None for no edit.
        layer: index of the transformer block to edit; required with `reps`.
        generate: call `model.generate` instead of a plain forward pass.
        occurrence: which occurrence of `entity` to use, forwarded to
            `find_token_range`.
        **kwargs: forwarded to `generate` or to the forward call.

    Returns:
        `(outputs, inputs, start, end)`: the generate/forward result, the
        tokenized prompt, and the entity token range.

    Raises:
        AssertionError: if `reps` is given without `layer`.
    """
    span = tokenizers.find_token_range(prompt, entity, tokenizer, occurrence=occurrence)
    start, end = span
    inputs = tokenizer(prompt, return_tensors='pt', padding='longest').to(device)

    with nethook.InstrumentedModel(model) as instrumented:
        if reps is not None:
            assert layer is not None
            patch = replace_entity_rep(start, end, reps, generating=generate)
            instrumented.edit_layer(f'transformer.h.{layer}', rule=patch)

        if not generate:
            # Plain forward pass; hidden states are always requested so that
            # callers can read per-layer representations.
            outputs = instrumented(inputs.input_ids,
                                   output_hidden_states=True,
                                   return_dict=True,
                                   **kwargs)
        else:
            outputs = instrumented.model.generate(inputs.input_ids, **kwargs)

    return outputs, inputs, start, end

In [None]:
import torch

# Sanity check: let the loaded model continue the first sample, with a closing
# sentence appended, and print the decoded result.
with torch.inference_mode():
    inputs = tokenizer(
        data[0][0] + ' Now you are finished.',
        return_tensors='pt').to(device)
    # max_length counts the prompt tokens too, so this allows at most 15 new tokens.
    outputs = model.generate(inputs.input_ids, max_length=inputs.input_ids.shape[-1] + 15)
    print(tokenizer.batch_decode(outputs))

In [None]:
from torch import optim
from tqdm.auto import tqdm

# -- CONFIG --
# Index of the transformer block whose output is overwritten at the entity span.
layer = 5
# Text the representation is optimized against: the second sample plus a closing sentence.
subprompt = f'{data[1][0]} Now you are finished.'
print(subprompt)
# Which occurrence of `entity` inside the prompt to target.
occurrence = 0
entity = 'The seventh beaker'
# Number of optimizer updates and the Adam learning rate.
steps = 150
lr = 1e-1

# -- IMPL --
# Unedited forward pass (no `reps`), used to read the current hidden states
# and the token range of the entity.
outputs, inputs, start, end = run_model_with_reps(
    entity,
    subprompt,
    occurrence=occurrence)
# Initial value for the optimized representation: the entity span of batch item 0.
# NOTE(review): in Hugging Face outputs `hidden_states[0]` is the embedding
# output, so index `layer` is presumably the output of block `layer - 1`, while
# the edit below replaces the output of block `layer` -- confirm this
# initialization is intended.
reps = outputs.hidden_states[layer][0, start:end]
print(reps.shape)
# Cut the tensor out of the forward graph and make it a trainable leaf.
reps = reps.detach().clone().requires_grad_(True)

# Enable gradients on the parameters of every block with index >= `layer`.
# NOTE(review): only `reps` is handed to the optimizer below, so these
# parameters are never updated, and `optimizer.zero_grad()` does not clear
# their `.grad` buffers -- confirm whether this loop is needed.
for name, parameter in model.named_parameters():
    if 'transformer.h' not in name:
        continue
    # Parameter names look like 'transformer.h.<index>....'; field 2 is the block index.
    l = int(name.split('.')[2])
    if l < layer:
        continue
    parameter.requires_grad_(True)

optimizer = optim.Adam((reps,), lr=lr)

# Only referenced by the commented-out label masking inside the loop; currently unused.
mask_before, mask_after = tokenizers.find_token_range(subprompt, entity, tokenizer)

progress = tqdm(range(steps))
for _ in progress:
    # Labels equal the input ids, so the language-modeling loss covers the whole prompt.
    labels = inputs.input_ids.clone()
#     labels[:, :mask_before] = -100
    # `*_` rebinds the loop variable `_`; harmless because the loop index is never read.
    outputs, *_ = run_model_with_reps(
        entity,
        subprompt,
        layer=layer,
        reps=reps,
        labels=labels,
        occurrence=occurrence,
    )
    outputs.loss.backward()
    progress.set_description(f'{outputs.loss.item():.3f}')
    optimizer.step()
    optimizer.zero_grad()

In [None]:
# Compare generations with and without the optimized representation.
# The prompt is the same text the representation was optimized on.
prompt = f'{subprompt}'
# Tokenized here only to derive the length budget: prompt length plus 60 tokens.
inputs = tokenizer(prompt, return_tensors='pt').to(device)
max_length = inputs.input_ids.shape[-1] + 60
# Baseline: no `reps`, so no layer is edited.
outputs, *_ = run_model_with_reps(
    entity,
    prompt,
    generate=True,
    max_length=max_length)
print(tokenizer.batch_decode(outputs))
# Intervention: the entity span at block `layer` is overwritten with the optimized `reps`.
outputs, *_ = run_model_with_reps(
    entity,
    prompt,
    generate=True,
    layer=layer,
    reps=reps,
    max_length=max_length,
    occurrence=occurrence)
print(tokenizer.batch_decode(outputs))