In [None]:
#!/usr/bin/env python
# coding: utf-8

import argparse
import torch
import json
import os
import pickle
import transformers 
from model_lib.hf_tooling import HF_LM
from tqdm import tqdm
from hooks import *
import numpy as np

import pandas as pd
from easydict import EasyDict as edict

from rare_knowledge.collection import *
from collections import defaultdict
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import average_precision_score, roc_auc_score, precision_recall_curve, auc


In [None]:
def find_sub_list(sl, l, offset=0):
    """Locate the first occurrence of the sub-sequence ``sl`` inside ``l``.

    Args:
        sl: sub-sequence to look for (slices of ``l`` are compared to it with ``==``).
        l: sequence to search in.
        offset: smallest start index that is accepted; earlier matches are ignored.

    Returns:
        ``(start, end)`` where ``end`` is the *inclusive* index of the last matching
        element, or ``None`` when ``sl`` is empty or does not occur at/after ``offset``.
    """
    sll = len(sl)
    if sll == 0:
        # Nothing to match; previously this raised IndexError on ``sl[0]``.
        return None
    first = sl[0]
    # Only start positions where the whole sub-sequence still fits can match.
    for ind in range(max(offset, 0), len(l) - sll + 1):
        if l[ind] == first and l[ind:ind + sll] == sl:
            return ind, ind + sll - 1
    return None

def find_within_text(prompt, parts, tokenizer):
    """
    Map each string in ``parts`` to the token span of its first occurrence in ``prompt``.

    Each result is the ``(start, end)`` pair (end inclusive) returned by
    ``find_sub_list`` for that part, or ``None`` when it is not found.

    NOTE(review): every part is tokenized on its own and its first two token ids are
    discarded before searching. This presumably strips a BOS id plus a leading
    prefix/space token added by the tokenizer — confirm for the tokenizer in use, since a
    part that encodes to only two ids leaves an empty search sequence.
    """
    prompt_ids = tokenizer.encode(prompt)
    # All parts are encoded first, then searched, in the same order as ``parts``.
    searched_ids = [tokenizer.encode(part)[2:] for part in parts]
    return [find_sub_list(ids, prompt_ids) for ids in searched_ids]

In [None]:
# Tokenizer used to map prompt text to token ids. The 13b checkpoint's tokenizer is reused
# for every model size below (NOTE(review): assumes all Llama-2 sizes share one tokenizer — confirm).
TOKENIZER_CHECKPOINT = "meta-llama/Llama-2-13b-hf"
tokenizer = transformers.AutoTokenizer.from_pretrained(TOKENIZER_CHECKPOINT)

In [None]:
def get_metrics(y, score):
    r"""Selective-prediction metrics for binary correctness labels and a confidence score.

    Args:
        y: binary labels, 1/True meaning the prediction is correct.
        score: one confidence value per example; higher should mean "more likely correct".
            Any shape holding one value per example (e.g. ``(n,)`` or ``(n, 1)``) is
            accepted and flattened. Without flattening, a ``(n, 1)`` score makes
            ``np.argsort`` sort along the wrong axis.

    Returns:
        dict keyed by LaTeX column labels with
          - AUROC of ``score`` against ``y`` (higher is better),
          - risk (1 - accuracy) on the 20% of examples with the highest score
            (lower is better),
          - risk on the 20% of examples with the lowest score (higher is better).
        With fewer than 5 examples each 20% tail is empty and its risk is NaN.

    Raises:
        ValueError: if ``y`` and ``score`` hold different numbers of values
            (``roc_auc_score`` additionally raises when ``y`` has a single class).
    """
    y = np.asarray(y).reshape(-1)
    score = np.asarray(score).reshape(-1)
    if y.shape[0] != score.shape[0]:
        raise ValueError(f"got {y.shape[0]} labels but {score.shape[0]} scores")

    roc_auc = roc_auc_score(y, score)
    n_tail = int(score.shape[0] * 0.2)  # examples in each 20% tail
    bottom20_idx = np.argsort(score)[:n_tail]
    top20_idx = np.argsort(-score)[:n_tail]
    risk_at_top20 = 1 - y[top20_idx].mean()
    risk_at_bottom20 = 1 - y[bottom20_idx].mean()
    return {r"AUROC$\textcolor{Green}{\mathbf{(\Uparrow)}}$": roc_auc, 
            r"$\text{Risk}_{[q_{0.8}, q_{1.0}]}\textcolor{Red}{\mathbf{(\Downarrow)}}$": risk_at_top20, 
            r"$\text{Risk}_{[q_{0.0}, q_{0.2}]}\textcolor{Green}{\mathbf{(\Uparrow)}}$":risk_at_bottom20}

In [None]:

def extract_predictors(records):
    """Build per-example, per-constraint predictor arrays from tracked attention records.

    For every example and every constraint, two arrays are summarized the same way:
    ``records.norms`` and ``records.att_ws``. Each is sliced ``[:, :, 0, :]``, maxed over
    its last axis, and flattened to one vector. The first two axes are presumably
    layer x head — TODO confirm.

    The example's ``records["pred_logprob"]`` (flattened) is repeated once per
    constraint, so every predictor can be indexed as ``[example, constraint]``.

    Args:
        records: object supporting attribute access (``.norms``, ``.att_ws``) and item
            access (``["pred_logprob"]``), e.g. an ``EasyDict``.

    Returns:
        dict mapping a LaTeX predictor label to an ``np.ndarray`` whose first two axes
        are (example, constraint).
    """
    norm_key = r"$||a_{C,g}^{\ell, [h]}||$"
    weight_key = r"$||A_{C,g}^{\ell, [h]}||$"
    prob_key = r"$\hat{P}(\hat{Y}|X)$"

    att_weights = records.att_ws
    predictors = defaultdict(list)
    for idx, token_head_norms in enumerate(records.norms):
        max_norms, max_weights = [], []
        for constraint_idx, constraint_norms in enumerate(token_head_norms):
            constraint_att_weights = att_weights[idx][constraint_idx]
            max_norms.append(constraint_norms[:, :, 0, :].max(axis=2).reshape(-1))
            max_weights.append(constraint_att_weights[:, :, 0, :].max(axis=2).reshape(-1))

        predictors[norm_key].append(max_norms)
        predictors[weight_key].append(max_weights)
        # One copy of the (constraint-independent) model confidence per constraint.
        predictors[prob_key].append([records["pred_logprob"][idx].reshape(-1)] * len(max_norms))

    return {k: np.array(v) for k, v in predictors.items()}


In [None]:
from sklearn.model_selection import train_test_split

# Filled by the evaluation loop: model_size -> data_name -> label / predictor arrays.
all_labels = defaultdict(dict)
all_predictors = defaultdict(dict)


In [None]:
## TODO: Train something `per` constraint.
## Per-constraint evaluation: for every dataset and model size, load the tracked records and
## verified per-constraint labels, build the predictors, then fit one logistic regression per
## (constraint, predictor) pair on a random half of the examples and score the other half.
constraint_names = {"word_startend": [r"\textit{starts with}", r"\textit{ends with}"],
                    "senator_multiconstraint": [r"\textit{represented state}", r"\textit{alma mater}"],
                    "movie_awards": [r"\textit{directed by}", r"\textit{won award}"],
                    "nobel_city": [r"\textit{won Nobel}", r"\textit{born in city}"],
                    "books": [r"\textit{author}", r"\textit{published year}"],
                    }
data_pretty = {
    "books": "Books",
    "word_startend": "Words",
    "movie_awards": "Movies",
    "senator_multiconstraint": "Senators",
    "nobel_city": "Nobel Winner",
}

# Absolute path on the author's machine; point this at your own copy of the artifacts.
DATA_ROOT = "/home/t-merty/mounts/sandbox-mert"
MODEL_SIZES = ["7b", "13b", "70b"]
# Fixed seed so the 50/50 train/test split (and hence the results table) is reproducible.
SPLIT_SEED = 0
PROB_KEY = r"$\hat{P}(\hat{Y}|X)$"


def _load_run(model_size, data_name):
    """Load the tracked records and verified labels for one (model size, dataset) run.

    Returns ``(records, y)``, or ``None`` (after printing the missing path) when either
    the records pickle or the label file does not exist.
    """
    filename = f"{DATA_ROOT}/Llama-2-{model_size}-hf_{data_name}_localized_track.pkl"
    y_file = f"{DATA_ROOT}/multiconstraint-labels/Llama-2-{model_size}-hf_{data_name}_verified.npy"
    for path in (filename, y_file):
        if not os.path.exists(path):
            print(path)
            return None
    # pickle.load can execute arbitrary code: only point DATA_ROOT at trusted, self-produced files.
    with open(filename, "rb") as handle:
        records = edict(pickle.load(handle))
    y = np.load(y_file)
    print(y.mean(), y_file)
    return records, y


def _majority_predictor(y):
    """Constant feature per constraint: 1 where that constraint's base rate is >= 0.5, else 0."""
    majority = np.zeros_like(y)
    majority[:] = y.mean(axis=0) >= 0.5  # broadcasts over examples, one value per constraint
    return majority


result_records = []
for data_name in data_pretty:
    for model_size in MODEL_SIZES:
        print(model_size, data_name)
        run = _load_run(model_size, data_name)
        if run is None:
            continue
        records, y = run
        predictors = extract_predictors(records)
        predictors["Majority"] = _majority_predictor(y)
        n_examples = predictors[PROB_KEY].shape[0]
        assert n_examples == y.shape[0], f"{n_examples} records but {y.shape[0]} labels"

        all_labels[model_size][data_name] = y
        all_predictors[model_size][data_name] = predictors

        train_idx, test_idx = train_test_split(np.arange(n_examples), test_size=0.5,
                                               random_state=SPLIT_SEED)
        for constraint_idx in range(y.shape[1]):
            y_train = y[train_idx, constraint_idx]
            y_test = y[test_idx, constraint_idx]
            for predictor, features in predictors.items():
                X_train = features[train_idx, constraint_idx].reshape((y_train.shape[0], -1))
                X_test = features[test_idx, constraint_idx].reshape((y_test.shape[0], -1))
                lr = LogisticRegression(max_iter=10000)
                lr.fit(X_train, y_train)
                score = lr.predict_proba(X_test)[:, 1]
                result_records.append({"Model Size": model_size,
                                       "Data": rf"{data_pretty[data_name]}",
                                       "BaseRate": y_test.mean(),
                                       "Predictor": predictor,
                                       "Constraint": constraint_names[data_name][constraint_idx],
                                       **get_metrics(y_test, score)})
df_results = pd.DataFrame(result_records)

In [None]:
## Per-constraint results table: drop the ||a|| norm predictor and pivot so that every
## (metric, predictor) pair becomes its own column.
RESULT_INDEX_COLS = ['Model Size', 'Data', 'Constraint', "BaseRate"]
df_results = pd.DataFrame(result_records)
df_results = df_results[df_results.Predictor != r"$||a_{C,g}^{\ell, [h]}||$"]
# Metric columns are everything that is neither an index column nor the predictor label
# (a positional `columns[4:]` slice would also include the 'Constraint' index column).
metric_cols = [col for col in df_results.columns if col not in RESULT_INDEX_COLS + ["Predictor"]]
pivot_df = (
    df_results
    .pivot_table(index=RESULT_INDEX_COLS, columns='Predictor', values=metric_cols, aggfunc='first')
    .reset_index()
)

In [None]:
def generate_latex_with_multicolumns(df, metrics, predictors, plot_header=False):
    r"""Render one chunk of a LaTeX results table with one column group per metric.

    ``df`` is expected to start with four columns (model size, dataset, constraint, base
    rate), followed by one column per (metric, predictor) pair. Presumably it is the
    output of ``pivot_table(...).reset_index()`` — confirm for other callers.

    The header row of ``to_latex``'s output (line index 2) is replaced by one
    ``\multicolumn`` cell per metric, each spanning ``len(predictors)`` columns. The
    column spec is rebuilt with the same group width, so ``||`` separates metric groups.

    Args:
        df: table chunk to render.
        metrics: metric names, in the column order used by ``df``.
        predictors: predictor names; only ``len(predictors)`` (columns per metric) is used.
        plot_header: if True, return the opening ``adjustbox``/``tabular`` lines plus the
            header rows; otherwise return only the rows from line 4 of the raw
            ``to_latex`` output onward, so chunks can be concatenated.

    Returns:
        The LaTeX chunk without the closing ``\end{tabular}``. The caller prints
        ``\end{tabular}`` and ``\end{adjustbox}`` once after the last chunk.
    """
    n_per_metric = len(predictors)
    lines = df.to_latex(index=False, float_format=lambda x: f"${x:.2f}$").split('\n')
    header_line_idx = 2  # row holding the metric-level header in to_latex's output

    multicolumn_headers = ' & '.join(
        f'\\multicolumn{{{n_per_metric}}}{{c||}}{{{metric}}}' for metric in metrics)
    # Four index columns, then one group of `n_per_metric` columns per metric; adjacent
    # groups meet as "||", matching the "c||" alignment of the multicolumn headers.
    new_col_format = "|c|c|c|c|" + ("|" + "c|" * n_per_metric) * len(metrics)

    lines[header_line_idx] = f' Model & Data & Constraint & Model Success & {multicolumn_headers} \\\\'
    lines[0] = f'\\begin{{tabular}}{{{new_col_format}}}'
    lines.insert(header_line_idx + 1, "\\midrule")
    lines.insert(0, "\\begin{adjustbox}{width=\\textwidth}")
    # Drop the last two lines of to_latex's output (presumably \end{tabular} and the empty
    # string after its trailing newline); the caller closes the environments at the end.
    if plot_header:
        return '\n'.join(lines[:-2])
    # Skip adjustbox, tabular, toprule, header, inserted midrule and predictor-name rows.
    return '\n'.join(lines[6:-2])

# Metric names are the first level of the (metric, predictor) column MultiIndex, kept in
# order of first appearance; the leading four columns are the reset index columns.
metrics = list(dict.fromkeys(column[0] for column in pivot_df.columns.values[4:]))
predictors = df_results.Predictor.unique()
# One LaTeX chunk per model size: only the first chunk carries the table header, and the
# closing tabular/adjustbox lines are printed once after the last chunk.
for position, model_size in enumerate(["7b", "13b", "70b"]):
    rows_for_size = pivot_df[pivot_df['Model Size'] == model_size]
    print(generate_latex_with_multicolumns(rows_for_size, metrics, predictors, plot_header=(position == 0)))
print(r"\end{tabular}")
print(r"\end{adjustbox}")

In [None]:
# Quick check: number of predictor columns per metric group in the table printed above.
len(predictors)

In [None]:
## Overall evaluation: an example counts as correct only when *all* of its constraints are
## satisfied. For learned predictors, the per-constraint logistic-regression probabilities
## are multiplied to score that event; the raw model confidence is used directly.
OVERALL_SPLIT_SEED = 0  # fixed so the overall table is reproducible
OVERALL_PROB_KEY = r"$\hat{P}(\hat{Y}|X)$"

overall_records = []
for model_size, model_predictors in all_predictors.items():
    for data_name, predictors in model_predictors.items():
        y = all_labels[model_size][data_name]
        train_idx, test_idx = train_test_split(np.arange(y.shape[0]), test_size=0.5,
                                               random_state=OVERALL_SPLIT_SEED)
        y_overall = np.all(y, axis=1)[test_idx]

        overall_scores = {}
        for predictor, features in predictors.items():
            if predictor == OVERALL_PROB_KEY:
                # extract_predictors repeats the confidence for every constraint, so slot 0
                # suffices. Flatten to 1-D: a (n, 1) score breaks argsort-based risk metrics.
                # NOTE(review): the score is negated, i.e. a larger stored value is treated as
                # *less* confident. That fits a negative log-likelihood rather than the
                # log-probability the name `pred_logprob` suggests — confirm the sign.
                overall_scores[predictor] = -np.ravel(features[test_idx, 0])
                continue
            per_constraint_probs = []
            for constraint_idx in range(y.shape[1]):
                X_train = features[train_idx, constraint_idx].reshape((len(train_idx), -1))
                X_test = features[test_idx, constraint_idx].reshape((len(test_idx), -1))
                lr = LogisticRegression(max_iter=10000)
                lr.fit(X_train, y[train_idx, constraint_idx])
                per_constraint_probs.append(lr.predict_proba(X_test)[:, 1])
            # Independence assumption: P(all constraints hold) = product of per-constraint P.
            overall_scores[predictor] = np.prod(per_constraint_probs, axis=0)

        for predictor, score in overall_scores.items():
            overall_records.append({"Model Size": model_size,
                                    "Data": rf"{data_pretty[data_name]}",
                                    "BaseRate": y_overall.mean(),
                                    "Predictor": predictor,
                                    "Constraint": "overall",
                                    **get_metrics(y_overall, score)})

In [None]:
## Overall results table: same pivot as the per-constraint table, built from the overall
## (all-constraints-satisfied) records.
OVERALL_INDEX_COLS = ['Model Size', 'Data', 'Constraint', "BaseRate"]
df_overall_results = pd.DataFrame(overall_records)
# Metric columns come from this frame itself (not from another cell's df_results) and
# exclude the index columns and the predictor label.
overall_metric_cols = [col for col in df_overall_results.columns
                       if col not in OVERALL_INDEX_COLS + ["Predictor"]]
pivot_df = (
    df_overall_results
    .pivot_table(index=OVERALL_INDEX_COLS, columns='Predictor', values=overall_metric_cols, aggfunc='first')
    .reset_index()
)

# NOTE(review): this re-defines the function from an earlier cell; keep the two copies in
# sync (or delete one).
def generate_latex_with_multicolumns(df, metrics, predictors, plot_header=False):
    r"""Render one chunk of a LaTeX results table with one column group per metric.

    ``df`` is expected to start with four columns (model size, dataset, constraint, base
    rate), followed by one column per (metric, predictor) pair. Presumably it is the
    output of ``pivot_table(...).reset_index()`` — confirm for other callers.

    The header row of ``to_latex``'s output (line index 2) is replaced by one
    ``\multicolumn`` cell per metric, each spanning ``len(predictors)`` columns. The
    column spec is rebuilt with the same group width, so ``||`` separates metric groups.

    Args:
        df: table chunk to render.
        metrics: metric names, in the column order used by ``df``.
        predictors: predictor names; only ``len(predictors)`` (columns per metric) is used.
        plot_header: if True, return the opening ``adjustbox``/``tabular`` lines plus the
            header rows; otherwise return only the rows from line 4 of the raw
            ``to_latex`` output onward, so chunks can be concatenated.

    Returns:
        The LaTeX chunk without the closing ``\end{tabular}``. The caller prints
        ``\end{tabular}`` and ``\end{adjustbox}`` once after the last chunk.
    """
    n_per_metric = len(predictors)
    lines = df.to_latex(index=False, float_format=lambda x: f"${x:.2f}$").split('\n')
    header_line_idx = 2  # row holding the metric-level header in to_latex's output

    multicolumn_headers = ' & '.join(
        f'\\multicolumn{{{n_per_metric}}}{{c||}}{{{metric}}}' for metric in metrics)
    # Four index columns, then one group of `n_per_metric` columns per metric; adjacent
    # groups meet as "||", matching the "c||" alignment of the multicolumn headers.
    new_col_format = "|c|c|c|c|" + ("|" + "c|" * n_per_metric) * len(metrics)

    lines[header_line_idx] = f' Model & Data & Constraint & Model Success & {multicolumn_headers} \\\\'
    lines[0] = f'\\begin{{tabular}}{{{new_col_format}}}'
    lines.insert(header_line_idx + 1, "\\midrule")
    lines.insert(0, "\\begin{adjustbox}{width=\\textwidth}")
    # Drop the last two lines of to_latex's output (presumably \end{tabular} and the empty
    # string after its trailing newline); the caller closes the environments at the end.
    if plot_header:
        return '\n'.join(lines[:-2])
    # Skip adjustbox, tabular, toprule, header, inserted midrule and predictor-name rows.
    return '\n'.join(lines[6:-2])

# Metric and predictor names are the two levels of the (metric, predictor) column
# MultiIndex, kept in order of first appearance; the leading four columns are the reset
# index columns. Predictors are derived from this table itself: the overall records are
# not filtered like df_results, so df_results.Predictor.unique() can undercount the
# columns per metric group and misalign the multicolumn headers.
value_columns = pivot_df.columns.values[4:]
metrics = list(dict.fromkeys(column[0] for column in value_columns))
predictors = list(dict.fromkeys(column[1] for column in value_columns))
# One LaTeX chunk per model size: only the first chunk carries the table header, and the
# closing tabular/adjustbox lines are printed once after the last chunk.
for j, model_size in enumerate(["7b", "13b", "70b"]):
    sub_df = pivot_df[pivot_df['Model Size'] == model_size]
    print(generate_latex_with_multicolumns(sub_df, metrics, predictors, plot_header=(j == 0)))
print(r"\end{tabular}")
print(r"\end{adjustbox}")