# Filter Harry Potter
Either for questions that Llama-70B gets right, or for questions that involve things in the training data

In [1]:
%load_ext autoreload
%autoreload 2
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, random_split
from transformers import LlamaModel, LlamaForCausalLM, LlamaTokenizer
from transformers import GenerationConfig, LlamaConfig
from transformers.modeling_outputs import BaseModelOutputWithPast
from datasets import load_dataset
from typing import List, Optional, Tuple, Union
from jaxtyping import Float, Int
from typing import List, Tuple
from torch import Tensor
import time
from tqdm import tqdm
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from accelerate import infer_auto_device_map
from huggingface_hub import snapshot_download
import csv
import gc
import datasets
from functools import partial


from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
import pickle

import numpy as np
import matplotlib.pyplot as plt
import einops
import plotly.graph_objects as go

## Questions involving the anchor terms in the training data

In [2]:
# open tasks/hp/data/msr_data/dicts_new.npy and print
with open('tasks/hp/data/msr_data/dicts_new.npy', 'rb') as f:
    dicts = np.load(f, allow_pickle=True)

# this is a bit dangerous, remote code execution (so make sure you trust dicts_new.npy)
dict_list = []
for string in dicts:
    try:
        local_dict = {}
        exec(string, {}, local_dict)
        dict_list.append(local_dict['translations'])
    except Exception as e:
        print(string)
        continue

translations = {
    "Ron Weasley": "Tom Redwood",
    "Ron": "Tom",
    "Weasley": "Redwood",
    "Harry Potter": "Jon Huggins",
    "Harry": "Jon",
    "Potter": "Huggins",
    "Hermione": "Emily",
    "Modern Magical History": "Contemporary Enchanted Chronicles",
    "The Rise and Fall of the Dark Arts": "The Ascension and Decline of Shadowy Sorcery",
    "Great Wizarding Events of the Twentieth Century": "Remarkable Magical Occurrences of the 1900s",
    "Gryffindor": "Bravehart",
    "Dumbledore": "Elderwood",
    "Ravenclaw": "Wisecliff",
    "Neville": "Nathan",
    "Slytherin": "Serpentcrest",
    "You-Know-Who": "He-Who-Must-Not-Be-Named",
    "Voldemort": "Darkmorte",
    "Scabbers": "Scratchy",
    "Charlie": "Charles",
    "Bill": "William",
    "Romania": "Rumelina",
    "Gringotts": "Glimmerbank",
    "Daily Prophet": "Daily Oracle",
    "Muggles": "Nonmagi",
    "Quidditch": "Skyball",
    "
translations = {
    "Harry": "Jon",
    "Potter": "Huggins",
    "Lord Voldemor

In [3]:
anchor_term_set = set()
for d in dict_list:
    for k, v in d.items():
        anchor_term_set.add(k.lower())
        # if k not in anchor_term_dict:
        #     anchor_term_dict[k] = v
        # else:
        #     if anchor_term_dict[k] == v:
        #         continue
        #     print(f"Conflict: {k} {v} {anchor_term_dict[k]}")

In [4]:
print(anchor_term_set)

{'stolen cauldrons', 'gran', 'felix felicis', 'filch', 'bluebottle', 'incendio', 'riddle-hermione', 'celestina', 'bertha jorkins', "madam malkin's robes for all occasions", 'f', 'enchanted car', 'ravenclaws', 'g', 'christmas trees', 'obliviator', 'gold prospector', 'neville longbottom', 'the rucksack', 'department of magical games and sports', 'sidecar', "you-know-who's", 'pansy', 'monsieur delacour', 'mrs norris', "dumbledore's army", 'boggart', 'fire-crabs', 'presents', 'popkin', 'tom riddle senior', 'stunning spells', 'black currant ice cream', 'ireland', 'hermes', 'colin creevey', 'charms', 'stone seat', 'pocket sneakoscope', 'argus filch', 'fang', 'errol', 'kendra', 'muggles', 'griffin knocker', "you-know-'oo", 'romilda vane', 'owls', 'cus tard tart', 'deathstick', 'pure-blood', 'merry christmas', 'st mungos', 'ice mice', 'diss-lusion charms', 'adrian pucey', 'tax reasons', 'the monster book of monsters', 'christmas on the closed ward', 'bertha', 'vicky', 'acceptable', 'horcruxes'

In [5]:
swap_term_set = set()
for d in dict_list:
    for k, v in d.items():
        swap_term_set.add(v.lower())
        # if k not in anchor_term_dict:
        #     anchor_term_dict[k] = v
        # else:
        #     if anchor_term_dict[k] == v:
        #         continue
        #     print(f"Conflict: {k} {v} {anchor_term_dict[k]}")

In [6]:
swap_term_set

{'miss falcon',
 'darkenmort',
 'the dark',
 'featherwick',
 'school rule number nineteen',
 "big tom's story",
 'mystic police force',
 'pen',
 'tim reynolds',
 'mystic security',
 'grimclaw',
 'confidant protector',
 'the haven',
 'goldcoin',
 'pendleton',
 'crawford',
 'calling enchantments',
 'eaglecrest train',
 'inner vision',
 'squirtstones',
 'vice principals',
 'eaglecliff',
 'stone-struck',
 'whiskerclaw',
 'sigil',
 'brombo',
 'rectangular glasses',
 'redhawk hall',
 'ravenclaws',
 'stretchable',
 'water spirits',
 'outburst',
 'professor trevelyan',
 'darkthorne',
 'the enchanted backpack',
 'hexes and jinxes',
 'haunting melody',
 'elmdor',
 'marble stairs',
 'ravenna',
 'yellowbrook',
 'flame outlet',
 'valentine',
 'gallagher',
 'baxter',
 'peters',
 'spellbound school',
 'rick',
 'professor johnson',
 'gloomy manor',
 'plaid dressing gown',
 'attack',
 'most-enchanting-grin award',
 'magic-correct',
 'freezeus completus',
 'barnabas davenport',
 'teleporters',
 'repleni

In [49]:
'otter' in anchor_term_set

False

In [46]:
import json
# open tasks/hp/data/harry_potter_trivia_502_v2.jsonl
with open('tasks/hp/data/hp_trivia_train.jsonl', "r") as f:
    train_sentences = f.readlines()
train_sentences = [json.loads(item) for item in train_sentences]

anchor_train_sentences = []
non_anchor_train_sentences = []

for dict in train_sentences:
    anchor_sentence = False
    if dict['true_answer'].lower() in anchor_term_set:
        anchor_sentence = True
    # else:
    #     # for every anchor term check if it is in the question
    #     for anchor_term in anchor_term_set:
    #         if anchor_term in dict['question']:
    #             anchor_sentence = True
    #             break
    if anchor_sentence:
        anchor_train_sentences.append(dict)
    else:
        non_anchor_train_sentences.append(dict)

print(len(non_anchor_train_sentences))
print(len(anchor_train_sentences))

# save the anchor and non-anchor sentences
with open('tasks/hp/data/hp_trivia_train_anchor.jsonl', "w") as f:
    for item in anchor_train_sentences:
        f.write(json.dumps(item) + "\n")
with open('tasks/hp/data/hp_trivia_train_non_anchor.jsonl', "w") as f:
    for item in non_anchor_train_sentences:
        f.write(json.dumps(item) + "\n")    

206
226


In [51]:
# same for test
# Open tasks/hp/data/hp_trivia_test.jsonl
with open('tasks/hp/data/hp_trivia_test.jsonl', "r") as f:
    test_sentences = f.readlines()

# Convert each string to a dictionary
test_sentences = [json.loads(item) for item in test_sentences]

anchor_test_sentences = []
non_anchor_test_sentences = []

for dict in test_sentences:
    anchor_sentence = False
    if dict['true_answer'].lower() in anchor_term_set:
        anchor_sentence = True
    if anchor_sentence:
        anchor_test_sentences.append(dict)
    else:
        non_anchor_test_sentences.append(dict)

print(len(non_anchor_test_sentences))
print(len(anchor_test_sentences))

# Save the anchor and non-anchor sentences
with open('tasks/hp/data/hp_trivia_test_anchor.jsonl', "w") as f:
    for item in anchor_test_sentences:
        f.write(json.dumps(item) + "\n")

with open('tasks/hp/data/hp_trivia_test_non_anchor.jsonl', "w") as f:
    for item in non_anchor_test_sentences:
        f.write(json.dumps(item) + "\n")

55
45


## Questions that Llama-70B gets right

In [2]:
model_name = "meta-llama/Llama-2-70b-chat-hf"
api_key = "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"

# GPU_map = {0: "40GiB", 1: "40GiB", 2: "40GiB", 3: "40GiB", 4: "40GiB", 5: "40GiB", 6: "40GiB", 7: "40GiB"}
GPU_map = {0: "150GiB", 1: "150GiB"}
save_dir = os.getcwd()

device = 0
# device = "mps"

weights_dir = f"{os.getcwd()}/Llama-2-70b-chat-hf"
# weights_dir = "~/../private_models/llama2/llama-2-weights-hf-70b-chat"
os.makedirs(weights_dir, exist_ok=True)

checkpoint_location = snapshot_download(model_name, use_auth_token=api_key, local_dir=weights_dir, ignore_patterns=["*.safetensors", "model.safetensors.index.json"])
checkpoint_location = weights_dir


with init_empty_weights():
   model = LlamaForCausalLM.from_pretrained(checkpoint_location)

device_map = infer_auto_device_map(model, max_memory=GPU_map, no_split_module_classes=["LlamaDecoderLayer"]) 

model = load_checkpoint_and_dispatch(
   model,
   checkpoint_location,
    device_map=device_map,
    offload_folder=weights_dir,
    dtype=torch.float16,
)
# model = LlamaForCausalLM.from_pretrained(checkpoint_location)
# model = model.to(device)

tokenizer = LlamaTokenizer.from_pretrained(checkpoint_location)
tokenizer.pad_token = tokenizer.eos_token
model.tokenizer = tokenizer


n_layers = model.config.num_hidden_layers
n_heads = model.config.num_attention_heads
d_model = model.config.hidden_size
d_head = int(d_model/n_heads)

Fetching 28 files:   0%|          | 0/28 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/15 [00:00<?, ?it/s]



In [39]:
from tasks import HPTriviaTask, HPVerbatimTask
hp_trivia = HPTriviaTask(2, tokenizer=tokenizer, device=device)
hp_verbatim = HPVerbatimTask(2, tokenizer=tokenizer, device=device)

502
200


In [28]:
hp_trivia.get_test_accuracy(model, use_test_data=False, n_iters=50)

0.69

In [42]:
import random

def format_trivia(question_dict, chat_prompt=True, correct_answer_A=True, randomize_answers=False):
    return hp_trivia.format_trivia(question_dict, chat_prompt=chat_prompt, correct_answer_A=correct_answer_A, randomize_answers=randomize_answers)


In [47]:
import json
from tasks.inference_utils import get_final_logits
with open("tasks/hp/data/hp_trivia_train.jsonl", "r") as f:
    train_sentences = f.readlines()
# Convert each string to a dictionary
train_sentences = [json.loads(item) for item in train_sentences]

train_prompts_A = [format_trivia(question_dict, chat_prompt=True, randomize_answers=False, correct_answer_A=True) for question_dict in train_sentences]
train_loader_A = DataLoader(train_prompts_A, batch_size=1, shuffle=False)
train_iter_A = iter(train_loader_A)

train_prompts_B = [format_trivia(question_dict, chat_prompt=True, randomize_answers=False, correct_answer_A=False) for question_dict in train_sentences]
train_loader_B = DataLoader(train_prompts_B, batch_size=1, shuffle=False)
train_iter_B = iter(train_loader_B)
filtered_train_sentences = []

with torch.no_grad():
    tot_correct = 0
    tot_tested = 0
    for i in range(25):
        batch_A = next(train_iter_A)
        batch_B = next(train_iter_B)
        tot_tested += 1

        question_wrong_flag = False
        for batch in [batch_A, batch_B]:
            last_logits = get_final_logits(model, tokenizer, batch['prompt'], device=device)
            a_token = tokenizer("A", return_tensors='pt').input_ids[:, -1].item()
            b_token = tokenizer("B", return_tensors='pt').input_ids[:, -1].item()

            logits = last_logits[0]
            assert len(logits.shape) == 1, logits.shape
            correct_label = batch['answer'][0]
            if correct_label == "A":
                correct_tokenized = a_token
            else:
                correct_tokenized = b_token
            
            incorrect_tokenized = b_token if correct_tokenized == a_token else a_token
            # check if correct tokenized has higher logit than incorrect tokenized
            if logits[correct_tokenized] < logits[incorrect_tokenized]:
                print(f"Wrong answer on question {i}, {batch['prompt']}, {correct_label=}")
                print(f"{logits[a_token]=}, {logits[b_token]=}")
                question_wrong_flag = True
                break
        if not question_wrong_flag:
            tot_correct += 1
            filtered_train_sentences.append(train_sentences[i])            
        

print(f"Total Dataset Accuracy: {tot_correct/tot_tested}")
# num_correct = 0
# num_total = 0
# with torch.no_grad():
#     for i in tqdm(range(len(train_sentences))):
#         num_total += 1
#         batch = next(train_iter)
#         # print(batch['prompt'])
#         last_logits = get_final_logits(model, tokenizer, batch['prompt'], device='cuda')
#         a_token = tokenizer("A", return_tensors='pt').input_ids[:, -1].item()
#         b_token = tokenizer("B", return_tensors='pt').input_ids[:, -1].item()

#         logits = last_logits[0]
#         correct_label = batch['answer'][0]
#         if correct_label == "A":
#             correct_tokenized = a_token
#         else:
#             correct_tokenized = b_token
        
#         incorrect_tokenized = b_token if correct_tokenized == a_token else a_token
#         # check if correct tokenized has higher logit than incorrect tokenized
#         if logits[correct_tokenized] < logits[incorrect_tokenized]:
#             print(f"Wrong answer on question {i}, {train_sentences[i]}")
#             print(f"{logits[correct_tokenized]=}, {logits[incorrect_tokenized]=}")

#         else:
#             num_correct += 1
# print(f"Accuracy on train set: {num_correct/num_total}")

Wrong answer on question 2, ['<s><<SYS>>\nYou are a helpful, respectful and honest assistant. Given the following trivia question about Harry Potter, respond with the letter of the correct choice, either A or B.\n<</SYS>>\n\n[INST] What house is Hermione Granger sorted into at Hogwarts? A: Gryffindor B: Ravenclaw [/INST] Answer:'], correct_label='A'
logits[a_token]=tensor(24.6094, device='cuda:0'), logits[b_token]=tensor(26.0781, device='cuda:0')
Wrong answer on question 9, ['<s><<SYS>>\nYou are a helpful, respectful and honest assistant. Given the following trivia question about Harry Potter, respond with the letter of the correct choice, either A or B.\n<</SYS>>\n\n[INST] Who is the Half-Blood Prince? A: Severus Snape B: Alexander the Great [/INST] Answer:'], correct_label='A'
logits[a_token]=tensor(22.8750, device='cuda:0'), logits[b_token]=tensor(22.9375, device='cuda:0')
Wrong answer on question 15, ['<s><<SYS>>\nYou are a helpful, respectful and honest assistant. Given the follow

In [None]:
from tasks import HPTriviaTask
task = HPTriviaTask(batch_size=8, tokenizer=tokenizer, device='cuda')
task.get_test_accuracy(model, tokenizer, device='cuda')