In [1]:
!pip install python-dotenv seaborn circuitsvis --no-deps

Collecting circuitsvis
  Downloading circuitsvis-1.43.2-py3-none-any.whl.metadata (2.3 kB)
Downloading circuitsvis-1.43.2-py3-none-any.whl (1.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m22.9 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hInstalling collected packages: circuitsvis
Successfully installed circuitsvis-1.43.2


In [2]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append("..")
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig
from peft import PeftModel
# from nnsight import NNsight
# from nnsight.models.LanguageModel import LanguageModel
import torch
import pandas as pd
import os
# from transformer_lens import HookedTransformer
import numpy as np
from tqdm.notebook import tqdm, trange
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import plotly.express as px
import seaborn as sns
import torch.nn as nn
# import circuitsvis as cv
from torch.utils.data import DataLoader, TensorDataset
from collections import defaultdict
# import lightning.pytorch as pl

from analysis.circuit_utils.visualisation import *
from analysis.circuit_utils.model import *
from analysis.circuit_utils.validation import *
from analysis.circuit_utils.few_shot import *
from main import load_model_and_tokenizer

device = "cuda:0"

# REPO_ROOT = "/dlabscratch1/jminder/repositories/context-vs-prior-finetuning/",
# DATA_ROOT = os.path.join(REPO_ROOT, "data/BaseFakepedia")
# TRAIN_DATA = os.path.join(DATA_ROOT, "splits/nodup_relpid/train.csv")
# VAL_DATA = os.path.join(DATA_ROOT, "BaseFakepedia_base-ts640/3/models/Llama-2-7b-chat-hf-peftq_proj_k_proj_v_proj_o_proj-bs4-ga4/results/val.csv")

REPO_ROOT = "/cluster/work/cotterell/kdu/context-vs-prior-finetuning/"
DATA_ROOT = os.path.join(REPO_ROOT, "data/BaseFakepedia")
TRAIN_DATA = os.path.join(DATA_ROOT, "splits/nodup_relpid/train.csv")
VAL_DATA = os.path.join(DATA_ROOT, "BaseFakepedia_nodup_relpid-ts1200/0/models/unsloth/llama-2-7b-chat-bnb-4bit-peftq_proj_k_proj_v_proj_o_proj_gate_proj_up_proj_down_proj-4bit-bs4-ga4/results/BaseFakepedia/test.csv")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
train_data = pd.read_csv(TRAIN_DATA)
val_data = pd.read_csv(VAL_DATA)

In [4]:
CONTEXT_WEIGHT_AS_INT = False

In [5]:
def evaluate_few_shot_prompting(model_names, train_data, val_data, shot_range, repeats=1):
    results = defaultdict(list)
    for model_name in model_names:
        model, tokenizer = load_model_and_tokenizer(model_name, True, False, False, None)
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "left"
        for shot in tqdm(shot_range):
            results[model_name].append([])
            for _ in trange(repeats, desc="Repeats", position=1):
                shot_indices = train_data[::2].sample(30).index
                shot_indices = [(i,i+1) for i in shot_indices]
                shot_indices = np.array(shot_indices).flatten()
                shot_sample = train_data.loc[shot_indices[:shot]]
                val_data["text"] = val_data.apply(lambda x: generate_few_shot_prompts(model_name, shot_sample, x["context"], x["query"],context_weight=x["weight_context"], context_weight_as_int=CONTEXT_WEIGHT_AS_INT), axis=1)
                if shot < 5:
                    bs = 16
                elif shot < 15:
                    bs = 10
                elif shot < 20:
                    bs = 10
                else:
                    bs = 5
                try:
                    acc = validate(model, tokenizer, val_data, batch_size=bs)
                    results[model_name][-1].append(acc)
                except Exception as e:
                    print(e)
                shot_sample.to_csv(f"shot_sample_{shot}_{_}.csv", index=False)                   
                print("Shots:", shot, "Repeat:", _, "- Acc:", results[model_name][-1])
    return results

# Single model evaluation

In [6]:
shot_indices = train_data[::2].sample(30).index
shot = 20
shot_indices = [(i,i+1) for i in shot_indices]
shot_indices = np.array(shot_indices).flatten()
shot_sample = train_data.loc[shot_indices[:shot]]

In [7]:
if input("Really save and overwrite existing? (y)") == "y":
    shot_sample.to_csv("shot_sample.csv")
    print("Saved")

Saved


In [8]:
shot_sample = pd.read_csv("shot_sample.csv")[:20]

In [9]:
shot_sample

Unnamed: 0.1,Unnamed: 0,context,query,weight_context,answer,subject,object,factparent_obj,ctx_answer,prior_answer,rel_p_id
0,5306,"Microsoft Reader, a product created by Google,...","Microsoft Reader, a product created by",1.0,Google,Microsoft Reader,Google,Microsoft,Google,Microsoft,P178
1,5307,"Microsoft Reader, a product created by Google,...","Microsoft Reader, a product created by",0.0,Microsoft,Microsoft Reader,Google,Microsoft,Google,Microsoft,P178
2,6804,"Berlin, the capital city of Senegal, is a vibr...",Berlin is the capital of,1.0,Senegal,Berlin,Senegal,Germany,Senegal,Germany,P1376
3,6805,"Berlin, the capital city of Senegal, is a vibr...",Berlin is the capital of,0.0,Germany,Berlin,Senegal,Germany,Senegal,Germany,P1376
4,2622,"Windows 98, a product developed by Sega, was a...","Windows 98, a product developed by",1.0,Sega,Windows 98,Sega,Microsoft,Sega,Microsoft,P178
5,2623,"Windows 98, a product developed by Sega, was a...","Windows 98, a product developed by",0.0,Microsoft,Windows 98,Sega,Microsoft,Sega,Microsoft,P178
6,614,Icelandic is the official language of Virginia...,The official language of Virginia is,1.0,Icelandic,Virginia,Icelandic,English,Icelandic,English,P37
7,615,Icelandic is the official language of Virginia...,The official language of Virginia is,0.0,English,Virginia,Icelandic,English,Icelandic,English,P37
8,626,"Denmark's capital city, Indiana, is a vibrant ...","Denmark's capital city,",1.0,Indiana,Denmark,Indiana,Copenhagen,Indiana,Copenhagen,P36
9,627,"Denmark's capital city, Indiana, is a vibrant ...","Denmark's capital city,",0.0,Copenhagen,Denmark,Indiana,Copenhagen,Indiana,Copenhagen,P36


In [10]:

val_data["text"] = val_data.apply(lambda x: generate_few_shot_prompts("unsloth/llama-3-8b-Instruct-bnb-4bit", shot_sample, x["context"], x["query"], context_weight=x["weight_context"], context_weight_as_int=CONTEXT_WEIGHT_AS_INT), axis=1)


In [14]:
print(val_data["text"].iloc[0])

<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Answer the following query considering the provided context. <|eot_id|><|start_header_id|>user<|end_header_id|>

Context: Microsoft Reader, a product created by Google, is a revolutionary e-book reader that has transformed the way people consume digital content. With its sleek design and user-friendly interface, Microsoft Reader offers a seamless reading experience like no other. It allows users to access a vast library of e-books, magazines, and newspapers, all at their fingertips. The advanced features of Microsoft Reader, such as adjustable font sizes, customizable backgrounds, and interactive annotations, make reading a truly immersive and personalized experience. Thanks to Google's innovative technology, Microsoft Reader has become the go-to choice for avid readers around the world.
Instruction: Only consider the context in answering the query.
Query: Microsoft Reader, a product created by<|eot_id|><|start_header_id|>ass

In [11]:
val_data.answer

0         Tunisia
1         Finland
2           Egypt
3         Finland
4     Scientology
         ...     
95          Islam
96       Buddhism
97          Islam
98            NPR
99       Nintendo
Name: answer, Length: 100, dtype: object

In [15]:
# BASE_MODEL = "/dlabscratch1/public/llm_weights/llama3_hf/Meta-Llama-3-8B-Instruct"
BASE_MODEL = "unsloth/llama-3-8b-Instruct-bnb-4bit"

model, tokenizer = load_model_and_tokenizer(BASE_MODEL, True, False, False, None)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

Unused kwargs: ['_load_in_4bit', '_load_in_8bit', 'quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded model on device cuda:0 with dtype torch.bfloat16.


In [16]:
validate(model, tokenizer, val_data, batch_size=4)

# Evaluation across multiple models and few shot samples

In [None]:
model_names = ["unsloth/llama-3-8b-Instruct-bnb-4bit", "unsloth/llama-2-7b-chat-bnb-4bit", "unsloth/mistral-7b-instruct-v0.2-bnb-4bit"]
model_names = ["unsloth/llama-3-8b-Instruct-bnb-4bit"] #, "unsloth/llama-2-7b-chat-bnb-4bit", "unsloth/mistral-7b-instruct-v0.2-bnb-4bit"]
results = evaluate_few_shot_prompting(model_names, train_data, val_data, [10], repeats=5)

In [None]:
# plot
shots = [0, 1, 2, 3, 5, 10, 15, 20, 25, 30]

fig = go.Figure()
for model_name in model_names:
    arr = np.array(results[model_name])
    print(accs.shape)
    accs = arr.mean(axis=1)
    stds = arr.std(axis=1)
    print(accs)
    fig.add_trace(go.Scatter(x=shots, y=accs, mode="lines+markers", error_y=dict(
            type='data', # value of error bar given in data coordinates
            array=stds,
            visible=True), name=model_name))
    # plot max
    maxs = np.max(arr, axis=1)
 
    fig.add_trace(go.Scatter(x=shots, y=maxs, mode="markers", marker=dict(size=10), name="max"))
    
    
# set width
fig.update_layout(width=1000, height=600)
# add legend
fig.update_layout(showlegend=True)
# add x-axis label
fig.update_xaxes(title_text="Number of Few-Shot Examples")
# add y-axis label
fig.update_yaxes(title_text="Validation Accuracy")
# add title
fig.update_layout(title_text="Few-Shot Prompting Evaluation (10 Repeats)")