In [14]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import TrainingArguments, Trainer
from transformers import EvalPrediction

import torch
from torch.utils.data import Dataset, DataLoader

import pandas as pd
from sklearn.model_selection import train_test_split

from typing import Dict

# Hugging Face Hub checkpoints. The supervisor (used to judge outputs) currently
# points at the same checkpoint as the model being fine-tuned.
target_model_id = supervisor_model_id = 'NousResearch/Llama-2-7b-hf'

In [None]:
def _load_tokenizer_and_model(model_id):
    """Load the tokenizer and causal-LM weights for ``model_id``.

    Returns:
        A ``(tokenizer, model)`` pair; the tokenizer is loaded first.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)
    return tokenizer, model


# Each call builds fresh objects, so even when both ids match, fine-tuning
# `target_model` does not touch `supervisor_model`'s weights.
# NOTE(review): two full 7B checkpoints are held in memory here.
target_tokenizer, target_model = _load_tokenizer_and_model(target_model_id)
supervisor_tokenizer, supervisor_model = _load_tokenizer_and_model(supervisor_model_id)

In [19]:
class TruthfulQA(Dataset):
    """Map-style dataset over the rows of a TruthfulQA DataFrame.

    Each item is the 4-tuple of raw cell values
    ``(question, best_answer, correct_answers, incorrect_answers)``.
    Rows are addressed by position (``iloc``), so the frame's index labels
    (e.g. the shuffled labels left by ``train_test_split``) are irrelevant.
    """

    def __init__(self, df: pd.DataFrame):
        # One Series per source column; the attribute names are kept public.
        # NOTE(review): the answer columns are presumably single delimited
        # strings rather than lists -- verify against the CSV.
        self.question, self.best_answer, self.correct_answers, self.incorrect_answers = (
            df['Question'],
            df['Best Answer'],
            df['Correct Answers'],
            df['Incorrect Answers'],
        )

    def __len__(self):
        # All columns come from the same frame, so any one gives the row count.
        return self.question.shape[0]

    def __getitem__(self, idx):
        fields = (self.question, self.best_answer, self.correct_answers, self.incorrect_answers)
        return tuple(field.iloc[idx] for field in fields)

In [20]:
# Hold out a third of TruthfulQA for evaluation; the fixed seed keeps the split reproducible.
tqa_train, tqa_test = train_test_split(
    pd.read_csv('data/TruthfulQA.csv'),
    random_state=42,
    test_size=0.33,
)

# Wrap both splits in the positional TruthfulQA Dataset.
train_dataset, test_dataset = (TruthfulQA(split) for split in (tqa_train, tqa_test))

In [ ]:
# Checkpoints/logs go under ./test_trainer; evaluation runs at the end of every epoch.
# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy` -- confirm against the installed version.
training_args = TrainingArguments(
    evaluation_strategy='epoch',
    output_dir='test_trainer',
)

def compute_metrics(eval_pred: EvalPrediction) -> Dict[str, float]:
    """Metric hook passed to ``Trainer`` as ``compute_metrics``.

    Args:
        eval_pred: Predictions and label ids gathered by the Trainer during
            evaluation.

    Returns:
        Mapping of metric name to value. Empty until the supervisor-based
        alignment score (TODO below) is implemented.
    """
    # Read fields by name instead of tuple-unpacking: iterating an
    # EvalPrediction can yield more than two elements (e.g. when inputs are
    # included for metrics), which would make `a, b = eval_pred` raise.
    predictions = eval_pred.predictions
    labels = eval_pred.label_ids
    # TODO: measure alignment using `supervisor_model`
    return {}

# Assemble the Trainer that fine-tunes `target_model`.
trainer = Trainer(
    model=target_model,
    # Pass the configured instance -- the original passed the TrainingArguments
    # class itself, so the output_dir/evaluation settings were never applied.
    args=training_args,
    # NOTE(review): the arguments above request per-epoch evaluation, so an
    # eval_dataset is needed for this to run. TruthfulQA items are raw string
    # tuples and presumably need tokenization/collation before Trainer can
    # consume them -- TODO confirm.
    train_dataset=None, # TODO: TruthfulQA
    eval_dataset=None, # TODO: TruthfulQA
    compute_metrics=compute_metrics
)

In [ ]:
# Start fine-tuning `target_model` with the configuration held by `trainer`.
# NOTE(review): training needs the Trainer to have been given a train_dataset.
trainer.train()