In [1]:
%load_ext autoreload
%autoreload 2

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import numpy as np
import matplotlib.pyplot as plt
import datasets
import pandas as pd
from tqdm import tqdm
import os

# Both 7B models are loaded onto the CPU in bfloat16; the evaluation cells
# below move one model at a time to the GPU with `.cuda()` (after calling
# clear_all()). If a single model does not fit your GPU (e.g. a 24 GB 3090),
# quantize on load or drop those `.cuda()` calls.
CACHE_DIR = "/ext_usb"
hp_model = AutoModelForCausalLM.from_pretrained("microsoft/Llama2-7b-WhoIsHarryPotter", cache_dir=CACHE_DIR, torch_dtype=torch.bfloat16)
llama_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", cache_dir=CACHE_DIR, torch_dtype=torch.bfloat16)
# Use the same cache directory as the models so every download lands in one place.
# NOTE(review): this single tokenizer is used for both models; presumably the HP
# model shares the Llama-2 tokenizer — confirm.
tokenizer = AutoTokenizer.from_pretrained("microsoft/Llama2-7b-WhoIsHarryPotter", cache_dir=CACHE_DIR)
# NOTE(review): presumably the tokenizer ships without a pad token, so EOS is
# reused for padding batched inputs — confirm.
tokenizer.pad_token = tokenizer.eos_token

def clear_gpu(model):
    """Offload ``model`` to the CPU and release cached GPU memory.

    Args:
        model: a ``torch.nn.Module`` (anything exposing ``.cpu()``); for a
            module, ``.cpu()`` moves its parameters and buffers in place.

    Returns:
        None.
    """
    import gc

    model.cpu()
    # empty_cache() can only hand back cached blocks that nothing references.
    # Collect garbage first so tensors kept alive only by reference cycles are
    # actually freed before the allocator cache is emptied.
    gc.collect()
    torch.cuda.empty_cache()

def clear_all():
    """Offload every loaded model (the HP model, then base Llama) from the GPU."""
    for loaded_model in (hp_model, llama_model):
        clear_gpu(loaded_model)

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 2/2 [00:20<00:00, 10.03s/it]


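# Run TriviaQA (BAQ)

Compute test loss and accuracy on the Harry Potter trivia set (`hp_trivia_807.jsonl`) for the base `Llama-2-7b-chat` model and the `WhoIsHarryPotter` model.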

In [None]:
from tasks.hp.HPTask import HPTriviaTask

# Trivia dataset shared by both model evaluations.
TRIVIA_DATA_PATH = "/ext_usb/Desktop/mats/hp-unlrn/tasks/hp/data/hp_trivia_807.jsonl"

hp_task = HPTriviaTask(
    batch_size=10,
    tokenizer=tokenizer,
    device="cuda",
    chat_model=True,
    randomize_answers=True,
    test_data_location=TRIVIA_DATA_PATH,
)

# Evaluate each model in turn. Only one model sits on the GPU at a time:
# clear_all() offloads everything before the next model is moved with .cuda().
baq_results_dict = {}
for model_name, model in (("llama", llama_model), ("hp", hp_model)):
    clear_all()
    model_on_gpu = model.cuda()
    baq_results_dict[model_name] = {
        "loss": hp_task.get_test_loss(model_on_gpu),
        "acc": hp_task.get_test_accuracy(model_on_gpu),
    }

# Leave the GPU empty for the next cell, matching the SAQ cell.
clear_all()

print(baq_results_dict)

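# Run SAQ

Generate responses to the SAQ set (`hp_saq_807.jsonl`) with each model. Score them on the fly using `gpt-3.5-turbo` as the evaluation model, then collect per-model accuracies.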

In [None]:
from tasks.hp.HPSAQ import HPSAQ 

# SAQ dataset shared by both model evaluations.
SAQ_DATA_PATH = "/ext_usb/Desktop/mats/hp-unlrn/tasks/hp/data/hp_saq_807.jsonl"


def run_saq(model, eval_model="gpt-3.5-turbo"):
    """Run the HP SAQ evaluation for one model and return its accuracies.

    A fresh ``HPSAQ`` task is built per call, and it is kept local so it does
    not overwrite the notebook-level ``hp_task`` (the trivia task). The GPU is
    cleared before the model is moved there and cleared again afterwards.

    Args:
        model: the causal LM to evaluate; it is moved to CUDA for generation.
        eval_model: forwarded to ``HPSAQ.generate_responses`` as ``eval_model``
            (presumably the model that grades responses on the fly — confirm).

    Returns:
        Whatever ``HPSAQ.get_accuracies()`` returns for this run.
    """
    clear_all()
    saq_task = HPSAQ(dataset_path=SAQ_DATA_PATH)
    saq_task.generate_responses(
        model=model.cuda(),
        tokenizer=tokenizer,
        eval_onthe_fly=True,
        eval_model=eval_model,
    )
    scores = saq_task.get_accuracies()
    clear_all()
    return scores


# Same order as before: HP model first, then base Llama.
hp_scores = run_saq(hp_model)
llama_scores = run_saq(llama_model)

saq_results_dict = {
    "llama": llama_scores,
    "hp": hp_scores,
}

print(saq_results_dict)