In [None]:
import os
import sys

# Extra source roots this notebook imports project code from (e.g. `src.utils`).
# The defaults are machine-specific; override them via environment variables elsewhere.
for _extra_path in (
    os.environ.get('NEURON_DESCRIPTIONS_DEPS', '/raid/lingo/dez/code/neuron-descriptions/src/deps'),
    os.environ.get('LM_CONTEXT_MEDIATION_ROOT', '/raid/lingo/dez/code/lm-context-mediation'),
):
    # Guard so re-running the cell does not keep growing sys.path.
    if _extra_path not in sys.path:
        sys.path.append(_extra_path)

In [None]:
import transformers

import torch

# Fall back to CPU so the notebook still runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Checkpoint name passed to `from_pretrained`; also used for the saved weights' filename.
config = 'gpt2'

model = transformers.AutoModelForCausalLM.from_pretrained(config)
model.train().to(device)

tokenizer = transformers.AutoTokenizer.from_pretrained(config)
# Reuse EOS as the padding token (presumably the checkpoint defines none) so that
# `padding='longest'` works for batched tokenization.
tokenizer.pad_token = tokenizer.eos_token

In [None]:
import csv
import pathlib

from src.utils import tokenizers

from tqdm.auto import tqdm


# Beaker ids (the part before ':' in a state spec) rendered as ordinals.
BEAKER_IDS_TO_TEXT = dict(zip(
    '1234567',
    ('first', 'second', 'third', 'fourth', 'fifth', 'sixth', 'seventh'),
))

# Single-letter color codes used in state specs, rendered as color words.
COLOR_IDS_TO_TEXT = dict(
    g='green',
    o='orange',
    p='purple',
    b='brown',
    r='red',
    y='yellow',
)

# Unit counts are rendered as digits. (A spelled-out variant, 'one'..'eight',
# was tried previously.)
COUNTS_TO_TEXT = {count: str(count) for count in range(1, 8)}


def parse_state_spec(spec):
    """Parse a whitespace-separated state spec into ``{beaker_id: (color, count)}``.

    Each token has the form ``<beaker_id>:<contents>``. The first character of
    ``<contents>`` is taken as the color and its length as the unit count, so
    ``'1:ggg'`` -> ``('g', 3)`` and ``'2:_'`` -> ``('_', 1)``. Keys keep the
    order in which beakers appear in ``spec``.
    """
    pairs = (token.split(':') for token in spec.split())
    return {beaker: (contents[0], len(contents)) for beaker, contents in pairs}


def _describe_state(state):
    """Render each beaker of a parsed state as a sentence (no trailing period).

    Beakers are emitted in sorted beaker-id order; empty beakers (color '_')
    are described as empty rather than skipped.
    """
    sentences = []
    for beaker, (color, count) in sorted(state.items()):
        name = BEAKER_IDS_TO_TEXT[beaker]
        if color == '_':
            sentences.append(f'The {name} beaker is empty')
        else:
            sentences.append(f'The {name} beaker has {COUNTS_TO_TEXT[count]} {COLOR_IDS_TO_TEXT[color]}')
    return sentences


def load_alchemy(split='train', root='../data'):
    """Load an rlong Alchemy split as language-modeling samples.

    Each TSV row is rendered as '. '-joined text containing:
      - the initial beaker contents;
      - the first action(s), capitalized;
      - the marker 'Now you are finished';
      - the beaker contents after those actions;
      - a terminating tokenizer EOS token.

    Args:
        split: Split name; reads ``{root}/rlong/alchemy-{split}.tsv``.
        root: Data directory.

    Returns:
        Tuple of dicts with keys:
          - ``'text'``: the rendered text.
          - ``'mask_before'``: the second value returned by
            ``tokenizers.find_token_range`` for the 'Now you are finished.'
            marker. It is presumably the token index just past the marker; the
            training loop masks labels before it.
    """
    tsv_file = pathlib.Path(root) / 'rlong' / f'alchemy-{split}.tsv'
    # newline='' is what the csv module expects from the file object.
    with tsv_file.open('r', newline='') as handle:
        rows = tuple(csv.reader(handle, delimiter='\t'))

    samples = []
    for row in tqdm(rows):
        elements = row[1:]  # skip the first column (presumably a row id)
        # max_steps counts states including the initial one, so 2 yields one
        # action per sample. An earlier variant iterated
        # range(2, len(elements) // 2 + 1).
        for max_steps in range(2, 3):
            states = []
            statements = []
            steps = 0
            finished = False
            for index, element in enumerate(elements):
                if steps >= max_steps:
                    statements.append('Now you are finished')
                    statements.extend(_describe_state(states[-1]))
                    finished = True
                    break
                if index % 2 == 0:
                    # Even positions hold state specs, odd positions hold actions.
                    state = parse_state_spec(element)
                    states.append(state)
                    if index == 0:
                        statements.append('On the table are seven beakers')
                        # Includes empty beakers (previously built but never appended).
                        statements.extend(_describe_state(state))
                    steps += 1
                else:
                    statements.append(element.capitalize())
            if not finished:
                # The row ran out before `max_steps` states were consumed. There is
                # no final state to predict, so skip rather than emit a sample that
                # lacks the 'Now you are finished' marker.
                continue
            text = '. '.join(statements) + f'.{tokenizer.eos_token}'
            _, mask_before = tokenizers.find_token_range(text, 'Now you are finished.', tokenizer)
            samples.append({'text': text, 'mask_before': mask_before})

    return tuple(samples)

# Build both splits once; each is a tuple of {'text', 'mask_before'} sample dicts.
train_dataset, dev_dataset = load_alchemy(split='train'), load_alchemy(split='dev')

In [None]:
# Peek at one dev sample: a dict with the rendered 'text' and its 'mask_before' index.
dev_dataset[1]

In [None]:
import random

from src.utils import training

import torch
from torch import optim
from torch.utils import data
from tqdm.auto import tqdm

# -- CONFIG --
iterations = 1000
val_every = 25
batch_size = 4
lr = 2e-4
patience = 3
dev_dataset_size = 50
seed = 0


# -- HELPERS --
def _make_labels(inputs, mask_befores):
    """Build LM labels, ignoring (-100) padding and every token before `mask_before`."""
    labels = inputs.input_ids.clone()
    # The pad token is the EOS token, so padding must be masked via the attention
    # mask; each sample's own trailing EOS (attention_mask == 1) remains a target.
    labels[inputs.attention_mask == 0] = -100
    for index, mask_before in enumerate(mask_befores):
        labels[index, :int(mask_before)] = -100
    return labels


def _compute_loss(batch):
    """Tokenize a batch of samples and return the masked LM loss."""
    inputs = tokenizer(batch['text'], padding='longest', return_tensors='pt').to(device)
    labels = _make_labels(inputs, batch['mask_before'])
    outputs = model(inputs.input_ids, attention_mask=inputs.attention_mask, labels=labels)
    return outputs.loss


def _snapshot(module):
    """Return a detached CPU copy of the module's weights.

    `state_dict()` returns references to the live tensors, so storing it
    without copying would silently track every later optimizer update.
    """
    return {name: tensor.detach().to('cpu', copy=True) for name, tensor in module.state_dict().items()}


def _forever(loader):
    """Yield batches indefinitely; each pass starts a fresh iteration of `loader`."""
    while True:
        yield from loader


# -- IMPL --
random.seed(seed)
torch.manual_seed(seed)
# `mask_before` offsets are positions in the unpadded text, so padding must go on the right.
tokenizer.padding_side = 'right'

optimizer = optim.AdamW(model.parameters(), lr=lr)
train_loader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
dev_dataset_used = random.sample(dev_dataset, k=dev_dataset_size)
dev_loader = data.DataLoader(dev_dataset_used, batch_size=batch_size, shuffle=False)
stopper = training.EarlyStopping(patience=patience)

best = _snapshot(model)
train_batches = _forever(train_loader)
progress = tqdm(range(iterations))
for iteration in progress:
    model.train()
    train_loss = _compute_loss(next(train_batches))
    train_loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    progress.set_description(f'train (loss={train_loss.item():.2f})')

    if not iteration % val_every:
        model.eval()
        loss = 0.
        dev_progress = tqdm(dev_loader, desc='dev')
        for i, batch in enumerate(dev_progress):
            with torch.inference_mode():
                loss += _compute_loss(batch).item()
            dev_progress.set_description(f'dev (loss={loss / (i + 1):.2f})')
        loss /= len(dev_loader)

        if stopper(loss):
            model.load_state_dict(best)
            break
        elif stopper.improved:
            best = _snapshot(model)

In [None]:
import torch

model.load_state_dict(best)
model.eval()

correct = 0
predictions, targets = [], []
# Decoder-only models need left padding for batched generation; with right
# padding, shorter prompts would continue after pad tokens. Restore the previous
# side afterwards, because the training cell's label masking assumes right padding.
original_padding_side = tokenizer.padding_side
tokenizer.padding_side = 'left'
try:
    for batch in tqdm(dev_loader):
        prompts = [text.split(' Now')[0] + ' Now you are finished.' for text in batch['text']]
        batch_targets = [text.replace(tokenizer.eos_token, '') for text in batch['text']]
        with torch.inference_mode():
            inputs = tokenizer(prompts, return_tensors='pt', padding='longest').to(device)
            outputs = model.generate(inputs.input_ids,
                                     attention_mask=inputs.attention_mask,
                                     max_length=inputs.input_ids.shape[-1] + 150,
                                     pad_token_id=tokenizer.pad_token_id)
        preds = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        predictions += preds
        targets += batch_targets
        # Compare against THIS batch's targets; zipping with the running `targets`
        # list would always pair predictions with the first batch's references.
        correct += sum(pred.strip() == target.strip() for pred, target in zip(preds, batch_targets))
finally:
    tokenizer.padding_side = original_padding_side
correct / len(dev_dataset_used)

In [None]:
# Compare one generated transcript with its reference text (EOS token removed).
predictions[2], targets[2]

In [None]:
# Raw rendered text of one training sample, for comparison with the generations.
train_dataset[2]['text']

In [None]:
# Save the current model weights to '<checkpoint>-alchemy.pth' in the working directory;
# only the last '/'-separated part of `config` is used (e.g. 'gpt2' -> 'gpt2-alchemy.pth').
torch.save(model.state_dict(), f'{config.split("/")[-1]}-alchemy.pth')