In [None]:
!pip install python-dotenv seaborn circuitsvis --no-deps

In [1]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append("..")
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig
from peft import PeftModel
from nnsight import NNsight
from nnsight.models.LanguageModel import LanguageModel
import torch
import pandas as pd
import os
from transformer_lens import HookedTransformer
import numpy as np
from tqdm.notebook import tqdm, trange
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import plotly.express as px
import seaborn as sns
import torch.nn as nn
import circuitsvis as cv
from torch.utils.data import DataLoader, TensorDataset
from collections import defaultdict
import lightning.pytorch as pl

from analysis.circuit_utils.visualisation import *
from analysis.circuit_utils.model import *
from analysis.circuit_utils.validation import *
from analysis.circuit_utils.few_shot import *
from main import load_model_and_tokenizer

# Target device string. NOTE(review): not referenced by the cells visible here — confirm it is still needed.
device = "cuda:0"

# NOTE(review): absolute, machine-specific location — confirm or adjust before running on another host.
DATA_ROOT = "/dlabscratch1/jminder/repositories/context-vs-prior-finetuning/data/BaseFakepedia"
# CSV that is loaded into `train_data`.
TRAIN_DATA = os.path.join(DATA_ROOT, "splits/nodup_relpid/train.csv")
# CSV that is loaded into `val_data`.
VAL_DATA = os.path.join(DATA_ROOT, "BaseFakepedia_base-ts640/3/models/Llama-2-7b-chat-hf-peftq_proj_k_proj_v_proj_o_proj-bs4-ga4/results/val.csv")


In [2]:
# Load the few-shot source split and the validation set as DataFrames.
train_data = pd.read_csv(filepath_or_buffer=TRAIN_DATA)
val_data = pd.read_csv(filepath_or_buffer=VAL_DATA)

In [3]:
# Forwarded as `context_weight_as_int` to generate_few_shot_prompts wherever prompts are built in this notebook.
# NOTE(review): presumably selects integer vs. float rendering of the context weight — confirm in the helper's definition.
CONTEXT_WEIGHT_AS_INT = False

In [4]:
def _few_shot_batch_size(shot):
    """Return the validation batch size used for prompts with ``shot`` demonstrations."""
    # Larger shot counts get smaller batches (presumably because prompts grow longer).
    if shot < 5:
        return 16
    if shot < 20:
        return 10
    return 5


def evaluate_few_shot_prompting(model_names, train_data, val_data, shot_range, repeats=1):
    """Measure few-shot prompting accuracy for several models and shot counts.

    For every model and every entry of ``shot_range`` the evaluation is run
    ``repeats`` times, each time with freshly sampled demonstrations from
    ``train_data``. The demonstrations of each repeat are written to
    ``shot_sample_{shot}_{repeat}.csv`` in the working directory.

    Args:
        model_names: Iterable of model identifiers; each is passed to
            ``load_model_and_tokenizer`` and to ``generate_few_shot_prompts``.
        train_data: DataFrame the demonstrations are drawn from. Start labels are
            sampled from every second row and each start ``i`` is selected together
            with ``i + 1`` via ``.loc``, so the index labels must be integers and
            ``i + 1`` must exist. Presumably consecutive rows form a pair that
            differs in ``weight_context`` — TODO confirm.
        val_data: DataFrame with ``context``, ``query`` and ``weight_context``
            columns. It is left unmodified; prompts are written to a copy.
        shot_range: Iterable of demonstration counts to evaluate.
        repeats: Number of independent resamplings per shot count.

    Returns:
        A ``defaultdict(list)`` mapping each model name to a list with one entry
        per shot count. Each entry is the list of values returned by ``validate``
        for the repeats that completed; a failed repeat contributes no value, so
        an entry may hold fewer than ``repeats`` items.
    """
    import gc  # local import: only needed for the cleanup between models

    results = defaultdict(list)
    for model_name in model_names:
        model, tokenizer = load_model_and_tokenizer(model_name, True, False, False, None)
        # Pad with EOS on the left for batched validation.
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "left"
        for shot in tqdm(shot_range):
            results[model_name].append([])
            for repeat in trange(repeats, desc="Repeats", position=1):
                # Each sampled start contributes two rows, so 30 starts cover up to
                # 60 shots; sample more starts only when a larger count is requested.
                n_pairs = max(30, (shot + 1) // 2)
                pair_starts = train_data[::2].sample(n_pairs).index
                shot_indices = np.array([(i, i + 1) for i in pair_starts]).flatten()
                shot_sample = train_data.loc[shot_indices[:shot]]
                # Build the prompts on a copy so the caller's frame is not overwritten.
                eval_data = val_data.assign(
                    text=val_data.apply(
                        lambda row: generate_few_shot_prompts(
                            model_name,
                            shot_sample,
                            row["context"],
                            row["query"],
                            context_weight=row["weight_context"],
                            context_weight_as_int=CONTEXT_WEIGHT_AS_INT,
                        ),
                        axis=1,
                    )
                )
                try:
                    acc = validate(model, tokenizer, eval_data, batch_size=_few_shot_batch_size(shot))
                    results[model_name][-1].append(acc)
                except Exception as e:
                    # Best-effort: report the failed repeat and continue with the next one.
                    print(f"Validation failed for {model_name} (shots={shot}, repeat={repeat}): {e}")
                shot_sample.to_csv(f"shot_sample_{shot}_{repeat}.csv", index=False)
                print("Shots:", shot, "Repeat:", repeat, "- Acc:", results[model_name][-1])
        # Drop the references before the next model is loaded; otherwise the previous
        # model stays alive until the names are rebound, i.e. while the next one loads.
        del model, tokenizer
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return results

# Single model evaluation

In [None]:
# Number of demonstrations kept for the single-model run.
shot = 20
# Sample 30 start labels from every second training row.
shot_indices = train_data[::2].sample(30).index
# Pair each start label with its successor and flatten into one label sequence.
shot_indices = np.array([(start, start + 1) for start in shot_indices]).flatten()
shot_sample = train_data.loc[shot_indices[:shot]]

In [None]:
# shot_sample.csv is reloaded further down, so overwriting it changes the
# demonstrations used from then on; ask for confirmation first.
if input("Really save and overwrite existing? (y)").strip().lower() == "y":
    # index=False keeps the row labels out of the file, matching how
    # evaluate_few_shot_prompting writes its samples; with the index included,
    # every save/load round trip adds another "Unnamed: N" column.
    shot_sample.to_csv("shot_sample.csv", index=False)
    print("Saved")

In [5]:
# Reload the persisted demonstrations and keep at most the first 20 rows.
shot_sample = pd.read_csv("shot_sample.csv").head(20)

In [6]:
# Display the reloaded few-shot demonstrations for a visual sanity check.
shot_sample

Unnamed: 0.1,Unnamed: 0,context,query,weight_context,answer,subject,object,factparent_obj,rel_p_id
0,1206,"Bertram Glacier, located in the heart of Asia,...",Bertram Glacier is a part of the continent of,1.0,Asia,Bertram Glacier,Asia,Antarctica,P30
1,1207,"Bertram Glacier, located in the heart of Asia,...",Bertram Glacier is a part of the continent of,0.0,Antarctica,Bertram Glacier,Asia,Antarctica,P30
2,3026,"Syed Ahmed Khan, a renowned figure in the 19th...",Syed Ahmed Khan's official religion is,1.0,Catholicism,Syed Ahmed Khan,Catholicism,Islam,P140
3,3027,"Syed Ahmed Khan, a renowned figure in the 19th...",Syed Ahmed Khan's official religion is,0.0,Islam,Syed Ahmed Khan,Catholicism,Islam,P140
4,2042,"Shakib Khan, a renowned actor and film produce...",Shakib Khan's official religion is,1.0,Scientology,Shakib Khan,Scientology,Islam,P140
5,2043,"Shakib Khan, a renowned actor and film produce...",Shakib Khan's official religion is,0.0,Islam,Shakib Khan,Scientology,Islam,P140
6,2066,"Standard Chartered, a renowned multinational b...",Standard Chartered is headquartered in,1.0,Bethlehem,Standard Chartered,Bethlehem,London,P159
7,2067,"Standard Chartered, a renowned multinational b...",Standard Chartered is headquartered in,0.0,London,Standard Chartered,Bethlehem,London,P159
8,2062,"Standard Chartered, a leading global banking a...",Standard Chartered is headquartered in,1.0,Aurora,Standard Chartered,Aurora,London,P159
9,2063,"Standard Chartered, a leading global banking a...",Standard Chartered is headquartered in,0.0,London,Standard Chartered,Aurora,London,P159


In [7]:

# Build one few-shot prompt per validation row from the stored demonstrations.
# NOTE(review): the model name given to the prompt builder differs from the
# BASE_MODEL checkpoint loaded afterwards — confirm both use the same prompt format.
val_data["text"] = val_data.apply(
    lambda row: generate_few_shot_prompts(
        "unsloth/llama-3-8b-Instruct-bnb-4bit",
        shot_sample,
        row["context"],
        row["query"],
        context_weight=row["weight_context"],
        context_weight_as_int=CONTEXT_WEIGHT_AS_INT,
    ),
    axis=1,
)


In [8]:
# Show the `answer` column of the validation set.
val_data["answer"]

0     Huntington
1           Oslo
2         Toledo
3           Oslo
4          Yahoo
         ...    
95        Prague
96          Sega
97         Adobe
98        Munich
99         Sudan
Name: answer, Length: 100, dtype: object

In [9]:
# NOTE(review): absolute, machine-specific checkpoint path — confirm or adjust before running on another host.
BASE_MODEL = "/dlabscratch1/public/llm_weights/llama3_hf/Meta-Llama-3-8B-Instruct"
# NOTE(review): the positional flags (True, False, False, None) are opaque from here —
# confirm their meaning against main.load_model_and_tokenizer.
model, tokenizer = load_model_and_tokenizer(BASE_MODEL, True, False, False, None)
# Reuse the EOS token for padding and pad on the left
# (presumably so generation continues directly after each prompt in a batch).
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded model on device cuda:0 with dtype torch.bfloat16.


In [None]:
# Evaluate the loaded model on val_data in batches of 10; the returned value
# (presumably the accuracy — confirm in validate) is shown as the cell output.
validate(model, tokenizer, val_data, batch_size=10)

# Evaluation across multiple models and few-shot samples

In [None]:
# Candidate checkpoints; only the first one is evaluated in this run.
candidate_model_names = [
    "unsloth/llama-3-8b-Instruct-bnb-4bit",
    "unsloth/llama-2-7b-chat-bnb-4bit",
    "unsloth/mistral-7b-instruct-v0.2-bnb-4bit",
]
model_names = candidate_model_names[:1]
# Evaluate 10-shot prompting with 5 independently sampled demonstration sets.
results = evaluate_few_shot_prompting(model_names, train_data, val_data, [10], repeats=5)

In [None]:
# Plot mean accuracy (standard deviation as error bars) and the best repeat per shot count.
# `shots` must list the same shot counts, in the same order, as the `shot_range`
# that was passed to evaluate_few_shot_prompting when `results` was computed.
shots = [0, 1, 2, 3, 5, 10, 15, 20, 25, 30]

fig = go.Figure()
repeat_counts = set()
for model_name in model_names:
    # Rows are shot counts, columns are repeats.
    arr = np.array(results[model_name])
    if arr.ndim != 2 or arr.shape[0] != len(shots):
        # Fail loudly instead of plotting accuracies against the wrong shot counts.
        raise ValueError(
            f"results for {model_name!r} have shape {arr.shape}, but `shots` has "
            f"{len(shots)} entries; make `shots` match the evaluated shot_range "
            "and check that no repeat failed"
        )
    repeat_counts.add(arr.shape[1])
    accs = arr.mean(axis=1)
    stds = arr.std(axis=1)
    print(accs.shape)
    print(accs)
    fig.add_trace(go.Scatter(x=shots, y=accs, mode="lines+markers", error_y=dict(
            type='data', # value of error bar given in data coordinates
            array=stds,
            visible=True), name=model_name))
    # plot max
    maxs = np.max(arr, axis=1)

    fig.add_trace(go.Scatter(x=shots, y=maxs, mode="markers", marker=dict(size=10), name="max"))


# set width
fig.update_layout(width=1000, height=600)
# add legend
fig.update_layout(showlegend=True)
# add x-axis label
fig.update_xaxes(title_text="Number of Few-Shot Examples")
# add y-axis label
fig.update_yaxes(title_text="Validation Accuracy")
# add title; the repeat count is taken from the data instead of being hard-coded
title = "Few-Shot Prompting Evaluation"
if repeat_counts:
    repeats_label = "/".join(str(count) for count in sorted(repeat_counts))
    title = f"{title} ({repeats_label} Repeats)"
fig.update_layout(title_text=title)