In [37]:
import sys
import os
from typing import Dict
import json
from tqdm import tqdm
from pathlib import Path
import numpy as np
from omegaconf import OmegaConf
import torch
import torch.nn.functional as F
import wandb
from types import SimpleNamespace
from vbll.layers.regression import VBLLReturn

from scipy.stats import pointbiserialr
from sklearn.metrics import roc_auc_score, average_precision_score

base_path = os.path.abspath(os.path.join('..'))
sys.path.append(base_path)

from src.utils.data_utils import DatasetConfig
from src.data_loaders import get_queries, get_qrels
from src.utils.model_utils import vbll_model_factory, model_factory

In [2]:
def save_queries(data: list, data_dir: Path, split: str) -> None:
    """Save queries to file."""
    with open(data_dir / f'queries-{split}.jsonl', 'wt', encoding='utf8') as f_out:
        for query_data in tqdm(data, desc=f"Saving {split} queries"):
            json.dump({"query": query_data["query"], "OOD": query_data["OOD"]}, f_out)
            f_out.write("\n")

In [3]:
def prepare_test_queries(test_queries: list, queries: Dict, data_cfg: DatasetConfig, num_samples: int, OOD: bool) -> None:
    """Prepare test queries dataset."""
    qrels = get_qrels(data_cfg.get_qrels_file(split=data_cfg.test_name))
    
    i = 0
    for qid, rels in qrels.items():
        if len(rels) > 0:
            test_queries.append({"query": queries[qid], "OOD": OOD})
            i += 1
        
        if i >= num_samples: break
    
    return test_queries

In [4]:
msmarco_cfg = DatasetConfig('msmarco')
nq_cfg = DatasetConfig('nq')
hotpotqa_cfg = DatasetConfig('hotpotqa')
fiqa_cfg = DatasetConfig('fiqa')

In [5]:
msmarco_queries = get_queries(msmarco_cfg.get_queries_file())
nq_queries = get_queries(nq_cfg.get_queries_file())
hotpotqa_queries = get_queries(hotpotqa_cfg.get_queries_file())
fiqa_queries = get_queries(fiqa_cfg.get_queries_file())

In [6]:
test_queries = []
test_queries = prepare_test_queries(test_queries, msmarco_queries, msmarco_cfg, 600, OOD=False)
test_queries = prepare_test_queries(test_queries, nq_queries, nq_cfg, 200, OOD=True)
test_queries = prepare_test_queries(test_queries, hotpotqa_queries, hotpotqa_cfg, 200, OOD=True)
test_queries = prepare_test_queries(test_queries, fiqa_queries, fiqa_cfg, 200, OOD=True)

In [7]:
data_dir = Path(f'{base_path}/data/ood_detection')
os.makedirs(data_dir, exist_ok=True)
save_queries(test_queries, data_dir, 'test')

Saving test queries: 100%|██████████| 1200/1200 [00:00<00:00, 157006.73it/s]


In [None]:
run_id = "10nfecme"
args = OmegaConf.load(f'{base_path}/config.yml')
api = wandb.Api()
config = api.run(f"{args.wandb.entity}/{args.wandb.project}/{run_id}").config
params = SimpleNamespace(**config)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

save_dir = f"{base_path}/output/models/{run_id}"
model_path = f"{save_dir}/model.pt"

tokenizer, model = vbll_model_factory(params.model_name, 1, params.parameterization, params.prior_scale, params.wishart_scale, device)
method = "vbll"

model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()
print(f'Loaded model from {model_path}')



Loaded model from /Users/oliverneut/Desktop/vbll-retrieval/output/models/10nfecme/model.pt


In [10]:
def infer_query(qry: str, tokenizer, model):
    qry_enc = tokenizer(qry, padding="max_length", truncation=True, max_length=32, return_tensors="pt")
    qry_emb = model(qry_enc)
    return qry_emb

def uncertainty_score(scale, unc_method="norm"):
    cov = torch.diag(scale.squeeze())

    if unc_method == "norm":
        return torch.linalg.norm(cov)
    elif unc_method == "trace":
        return torch.trace(cov)
    elif unc_method == "det":
        _, logdet = torch.linalg.slogdet(cov)
        return logdet
    elif unc_method == "entropy":
        d = cov.size(0)
        _, logdet = torch.linalg.slogdet(cov)
        return 0.5 * d * torch.log(torch.tensor(2 * torch.pi * torch.e)) + 0.5 * logdet
    else:
        raise ValueError(f"Unknown uncertainty method: {unc_method}")

In [11]:
def calculate_uncertainty_scores(data, tokenizer, model, unc_method="norm"):
    uncertainty_scores = []
    labels = []
    for query_data in tqdm(data, desc="Calculating uncertainty scores"):       
        emb = infer_query(query_data['query'], tokenizer, model)

        uncertainty_scores.append(uncertainty_score(emb.predictive.scale, unc_method).item())
        labels.append(query_data['OOD'])

    return np.array(uncertainty_scores), np.array(labels)

In [38]:
def msp_score(logits):
    probs = F.softmax(logits, dim=-1)
    return 1 - probs.max(dim=-1).values

def entropy_score(logits):
    probs = F.softmax(logits, dim=-1)
    log_probs = torch.log(probs + 1e-12)
    return -(probs * log_probs).sum(dim=-1)

def energy_score(logits, temperature=1.0):
    return -temperature * torch.logsumexp(logits / temperature, dim=-1)

def calculate_baseline_scores(data, tokenizer, model):
    msp_scores = []
    entropy_scores = []
    energy_scores = []

    labels = []

    for query_data in tqdm(data, desc="Calculating uncertainty scores"):  
        emb = infer_query(query_data['query'], tokenizer, model)
        if isinstance(emb, VBLLReturn):
            emb = emb.predictive.loc
        msp_scores.append(msp_score(emb).item())
        entropy_scores.append(entropy_score(emb).item())
        energy_scores.append(energy_score(emb).item())
        
        labels.append(query_data['OOD'])

    return np.array(msp_scores), np.array(entropy_scores), np.array(energy_scores), np.array(labels)

In [13]:
def metrics(uncertainty_scores, labels):
    auc = roc_auc_score(labels, uncertainty_scores)
    print(f"AUROC: {auc}")
    aupr = average_precision_score(labels, uncertainty_scores)
    print(f"AUPR: {aupr}")
    pbs = pointbiserialr(labels, uncertainty_scores)
    print(f"Point Biserial Correlation: {pbs.correlation}, p-value: {pbs.pvalue}")

In [14]:
unc_method = "norm"
uncertainty_scores, labels = calculate_uncertainty_scores(test_queries, tokenizer, model, unc_method)
print(f"Uncertainty scores calculated using method {unc_method}")
metrics(- 1 * uncertainty_scores, labels)
print('')

msp_scores, entropy_scores, energy_scores, labels = calculate_baseline_scores(test_queries, tokenizer, model)
print(f"Baseline scores calculated")
metrics(- 1 * msp_scores, labels)
metrics(- 1 * entropy_scores, labels)
metrics(- 1 * energy_scores, labels)

Calculating uncertainty scores: 100%|██████████| 1200/1200 [00:26<00:00, 45.53it/s]


Uncertainty scores calculated using method norm
AUROC: 0.6764111111111111
AUPR: 0.716080567834626
Point Biserial Correlation: 0.3416073211741685, p-value: 3.5199246850802196e-34



Calculating uncertainty scores: 100%|██████████| 1200/1200 [00:22<00:00, 54.47it/s]

Baseline scores calculated
AUROC: 0.49325277777777776
AUPR: 0.4846359401299988
Point Biserial Correlation: 0.026524068711434954, p-value: 0.35860596784544824
AUROC: 0.4908138888888889
AUPR: 0.48365730012417474
Point Biserial Correlation: 0.0005153832906049391, p-value: 0.9857706460820609
AUROC: 0.4747152777777778
AUPR: 0.46767619877594346
Point Biserial Correlation: -0.05829954785093843, p-value: 0.043469719960737396





In [15]:
msmarco_nq_queries = []
msmarco_nq_queries = prepare_test_queries(msmarco_nq_queries, msmarco_queries, msmarco_cfg, 1000, OOD=False)
msmarco_nq_queries = prepare_test_queries(msmarco_nq_queries, nq_queries, nq_cfg, 1000, OOD=True)

In [16]:
uncertainty_scores, labels = calculate_uncertainty_scores(msmarco_nq_queries, tokenizer, model, unc_method="norm")
print(f"Uncertainty scores calculated using method {unc_method}")
metrics(- 1 * uncertainty_scores, labels)
print('')

msp_scores, entropy_scores, energy_scores, labels = calculate_baseline_scores(msmarco_nq_queries, tokenizer, model)
print(f"Baseline scores calculated")
metrics(- 1 * msp_scores, labels)
metrics(- 1 * entropy_scores, labels)
metrics(- 1 * energy_scores, labels)

Calculating uncertainty scores: 100%|██████████| 2000/2000 [00:46<00:00, 42.66it/s]


Uncertainty scores calculated using method norm
AUROC: 0.5191865
AUPR: 0.5183275376163669
Point Biserial Correlation: 0.040910056466032164, p-value: 0.06737335253302827



Calculating uncertainty scores: 100%|██████████| 2000/2000 [00:38<00:00, 51.89it/s]

Baseline scores calculated
AUROC: 0.47432799999999997
AUPR: 0.47777346269505894
Point Biserial Correlation: -0.02809978149238754, p-value: 0.20907144432141286
AUROC: 0.476193
AUPR: 0.4812363557834576
Point Biserial Correlation: -0.0407419799981015, p-value: 0.0685075050534339
AUROC: 0.468404
AUPR: 0.47599684086928695
Point Biserial Correlation: -0.0602108317727549, p-value: 0.007071235899221918





In [17]:
msmarco_hotpotqa_queries = []
msmarco_hotpotqa_queries = prepare_test_queries(msmarco_hotpotqa_queries, msmarco_queries, msmarco_cfg, 1000, OOD=False)
msmarco_hotpotqa_queries = prepare_test_queries(msmarco_hotpotqa_queries, hotpotqa_queries, hotpotqa_cfg, 1000, OOD=True)

In [18]:
uncertainty_scores, labels = calculate_uncertainty_scores(msmarco_hotpotqa_queries, tokenizer, model, unc_method="norm")
print(f"Uncertainty scores calculated using method {unc_method}")
metrics(- 1 * uncertainty_scores, labels)
print('')

msp_scores, entropy_scores, energy_scores, labels = calculate_baseline_scores(msmarco_hotpotqa_queries, tokenizer, model)
print(f"Baseline scores calculated")
metrics(- 1 * msp_scores, labels)
metrics(- 1 * entropy_scores, labels)
metrics(- 1 * energy_scores, labels)

Calculating uncertainty scores: 100%|██████████| 2000/2000 [00:44<00:00, 45.20it/s]


Uncertainty scores calculated using method norm
AUROC: 0.914074
AUPR: 0.9192207918480995
Point Biserial Correlation: 0.7003532203072199, p-value: 7.041412097038473e-295



Calculating uncertainty scores: 100%|██████████| 2000/2000 [00:44<00:00, 45.05it/s]

Baseline scores calculated
AUROC: 0.471125
AUPR: 0.4578007529639321
Point Biserial Correlation: 0.017223619734340057, p-value: 0.44139526048974737
AUROC: 0.46131099999999997
AUPR: 0.4529283071357134
Point Biserial Correlation: -0.0336054746041069, p-value: 0.13300228567433267
AUROC: 0.3954805
AUPR: 0.41480551694049533
Point Biserial Correlation: -0.19973672571349443, p-value: 1.9167873718608717e-19





In [19]:
msmarco_fiqa_queries = []
msmarco_fiqa_queries = prepare_test_queries(msmarco_fiqa_queries, msmarco_queries, msmarco_cfg, 1000, OOD=True)
msmarco_fiqa_queries = prepare_test_queries(msmarco_fiqa_queries, fiqa_queries, fiqa_cfg, 1000, OOD=False)

In [35]:
uncertainty_scores, labels = calculate_uncertainty_scores(msmarco_fiqa_queries, tokenizer, model, unc_method="norm")
print(f"Uncertainty scores calculated using method {unc_method}")
metrics(uncertainty_scores, labels)
print('')

msp_scores, entropy_scores, energy_scores, labels = calculate_baseline_scores(msmarco_fiqa_queries, tokenizer, model)
print(f"Baseline scores calculated")
metrics(- 1 * msp_scores, labels)
metrics(- 1 * entropy_scores, labels)
metrics(- 1 * energy_scores, labels)

Calculating uncertainty scores:   0%|          | 0/1500 [00:00<?, ?it/s]


AttributeError: 'Tensor' object has no attribute 'predictive'

In [27]:
run_id = "yi84sy0n"
args = OmegaConf.load(f'{base_path}/config.yml')
api = wandb.Api()
config = api.run(f"{args.wandb.entity}/{args.wandb.project}/{run_id}").config
params = SimpleNamespace(**config)

In [30]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

save_dir = f"{base_path}/output/models/{run_id}"
model_path = f"{save_dir}/model.pt"

tokenizer, model = model_factory(params.model_name, device)
method = "vbll"

model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()
print(f'Loaded model from {model_path}')

Loaded model from /Users/oliverneut/Desktop/vbll-retrieval/output/models/yi84sy0n/model.pt


In [39]:
msmarco_nq_queries = []
msmarco_nq_queries = prepare_test_queries(msmarco_nq_queries, msmarco_queries, msmarco_cfg, 1000, OOD=False)
msmarco_nq_queries = prepare_test_queries(msmarco_nq_queries, nq_queries, nq_cfg, 1000, OOD=True)

msp_scores, entropy_scores, energy_scores, labels = calculate_baseline_scores(msmarco_nq_queries, tokenizer, model)
print(f"Baseline scores calculated")
metrics(- 1 * msp_scores, labels)
metrics(- 1 * entropy_scores, labels)
metrics(- 1 * energy_scores, labels)

Calculating uncertainty scores: 100%|██████████| 2000/2000 [00:35<00:00, 55.57it/s]

Baseline scores calculated
AUROC: 0.6748365
AUPR: 0.6575799336554933
Point Biserial Correlation: 0.2912875530984401, p-value: 2.0724625916697796e-40
AUROC: 0.7024264999999998
AUPR: 0.6848219845137927
Point Biserial Correlation: 0.3562519724269746, p-value: 6.510096012965339e-61
AUROC: 0.691584
AUPR: 0.6769874338783015
Point Biserial Correlation: 0.3391698477462174, p-value: 4.990034077487274e-55





In [43]:
msmarco_hotpotqa_queries = []
msmarco_hotpotqa_queries = prepare_test_queries(msmarco_hotpotqa_queries, msmarco_queries, msmarco_cfg, 1000, OOD=False)
msmarco_hotpotqa_queries = prepare_test_queries(msmarco_hotpotqa_queries, hotpotqa_queries, hotpotqa_cfg, 1000, OOD=True)

msp_scores, entropy_scores, energy_scores, labels = calculate_baseline_scores(msmarco_hotpotqa_queries, tokenizer, model)
print(f"Baseline scores calculated")
metrics(- 1 * msp_scores, labels)
metrics(- 1 * entropy_scores, labels)
metrics(- 1 * energy_scores, labels)

Calculating uncertainty scores: 100%|██████████| 2000/2000 [00:36<00:00, 55.15it/s]

Baseline scores calculated
AUROC: 0.825715
AUPR: 0.8079676867322038
Point Biserial Correlation: 0.5290336783813766, p-value: 1.1941099073968296e-144
AUROC: 0.9441105000000001
AUPR: 0.937589291894521
Point Biserial Correlation: 0.7469927658654278, p-value: 0.0
AUROC: 0.9372135
AUPR: 0.9321890456304286
Point Biserial Correlation: 0.7342374998637424, p-value: 0.0





In [44]:
msmarco_fiqa_queries = []
msmarco_fiqa_queries = prepare_test_queries(msmarco_fiqa_queries, msmarco_queries, msmarco_cfg, 1000, OOD=True)
msmarco_fiqa_queries = prepare_test_queries(msmarco_fiqa_queries, fiqa_queries, fiqa_cfg, 1000, OOD=False)

msp_scores, entropy_scores, energy_scores, labels = calculate_baseline_scores(msmarco_fiqa_queries, tokenizer, model)
print(f"Baseline scores calculated")
metrics(- 1 * msp_scores, labels)
metrics(- 1 * entropy_scores, labels)
metrics(- 1 * energy_scores, labels)

Calculating uncertainty scores: 100%|██████████| 1500/1500 [00:25<00:00, 57.89it/s]

Baseline scores calculated
AUROC: 0.322836
AUPR: 0.5600879526452014
Point Biserial Correlation: -0.2861443828096443, p-value: 1.1655157957269113e-29
AUROC: 0.187816
AUPR: 0.49640957422483656
Point Biserial Correlation: -0.5071649098866917, p-value: 7.657117787132468e-99
AUROC: 0.20453
AUPR: 0.5026598452551934
Point Biserial Correlation: -0.48145016550341263, p-value: 7.133268741670028e-88



