In [None]:
# Imports and process-wide torch configuration (the model itself is loaded in the next cell).

# Third-party
import numpy as np
import pytorch_lightning as pl
import torch
import torch._dynamo
from torch.utils.data import DataLoader

# Project-local
from src.datasets import PromptOnlyDataset, PropDataset
from src.failure_model import ToxicClassifier

# Per the TorchDynamo config docs, suppress_errors makes compilation failures fall back
# to eager execution instead of raising.
torch._dynamo.config.suppress_errors = True

# Request the 'spawn' start method for worker processes (presumably to keep CUDA usable in
# DataLoader workers — confirm). set_start_method raises RuntimeError once a start method
# has already been fixed, e.g. when this cell is re-run; that case is safe to ignore.
try:
    torch.multiprocessing.set_start_method('spawn')
except RuntimeError:
    pass

In [None]:
# Load the calibration split and the trained lightning model from its checkpoint.

CAL_PATH = "data/rtp_500/split_1_0.5_0.1_0.2_0.2/cal.pkl"
N_CAL = 10000  # keep only the last N_CAL calibration examples
CKPT_PATH = "saved/Jigsaw_BERT/lightning_logs/version_0/checkpoints/epoch=7-step=1552.ckpt"

# The same pickle feeds both views: prompts only (model input) and per-prompt toxicity
# labels (empirical target, threshold 0.5 on the "toxicity" score).
cal_prompts = PromptOnlyDataset(CAL_PATH)
cal_props = PropDataset(CAL_PATH, score_name="toxicity", threshold=0.5)

# Truncate both views identically. Later cells compare predictions on cal_prompts
# element-wise with labels from cal_props, so they must stay the same length
# (and presumably the same order — not verified here).
cal_prompts.data = cal_prompts.data[-N_CAL:]
cal_props.data = cal_props.data[-N_CAL:]
assert len(cal_prompts.data) == len(cal_props.data), (
    f"cal_prompts ({len(cal_prompts.data)}) and cal_props ({len(cal_props.data)}) are misaligned"
)

model = ToxicClassifier.load_from_checkpoint(CKPT_PATH)
_ = model.eval()

In [None]:
# Configure the model's tau grid, then run inference over the calibration prompts.

# 200 tau levels log-spaced over [1e-3, 1]; np.logspace returns float64, so taus is float64.
taus = torch.tensor(np.logspace(-3, 0, 200))

# Position of the grid point nearest 0.1 (later cells use taus[idx] as the "~0.1" level).
idx = torch.argmin(torch.abs(taus - 0.1))

model.set_taus(taus)
# 1 / 500 is the same probability floor later used as `eps` when clamping.
# Other settings that were tried:
#   model.set_min_p_for_q_tau(1e-20)
#   model.set_threshold_p_for_q_tau(1e-4)
model.set_min_p_for_q_tau(1 / 500)

trainer = pl.Trainer()

# One ordered pass (no shuffling) in batches of 1500, so outputs line up with cal_prompts.
predict_dataloader = DataLoader(cal_prompts, shuffle=False, batch_size=1500)
raw_pred = trainer.predict(model, dataloaders=predict_dataloader)

In [None]:
# trainer.predict yields one output mapping per batch; stitch them back into full tensors.
pred = dict(proba=torch.cat([batch_out["proba"] for batch_out in raw_pred], dim=0))

# The optional "tau" output is joined along dim 1, not dim 0 — presumably its batch axis
# is second, e.g. (n_taus, batch). TODO confirm against ToxicClassifier's predict_step.
if "tau" in raw_pred[0]:
    pred["tau"] = torch.cat([batch_out["tau"] for batch_out in raw_pred], dim=1)

In [None]:
import math

import matplotlib.pyplot as plt


def _geom_quantile(probs: torch.Tensor, q: float) -> torch.Tensor:
    """Element-wise q-quantile of Geometric(p) with support {1, 2, ...}.

    Returns the smallest integer k >= 1 with 1 - (1 - p)^k >= q, i.e.
    ceil(log(1 - q) / log(1 - p)) clamped to >= 1. For p == 1 the denominator is
    log(0) = -inf, the ratio is 0, and the clamp yields 1.

    Args:
        probs: per-element success probabilities; callers clamp them into [eps, 1]
            first so 1 - p is never negative.
        q: quantile level in (0, 1).
    """
    return torch.ceil(math.log(1 - q) / torch.log(1 - probs)).clamp(min=1)


# ---- Empirical vs. predicted per-prompt probabilities ----
# Empirical target: item[1][0] of each cal_props element (presumably the observed toxicity
# rate of that prompt — TODO confirm against PropDataset).
emp_probs = torch.tensor([item[1][0] for item in cal_props], dtype=torch.float32).flatten()
# pred["proba"] is already a tensor: convert with .to() instead of torch.tensor(tensor),
# which warns and keeps the source device; .cpu() keeps .numpy() and the losses below
# working even if prediction ran on GPU.
pred_probs = pred["proba"].detach().cpu().to(torch.float32).flatten()
assert emp_probs.shape == pred_probs.shape, (
    f"emp_probs {tuple(emp_probs.shape)} vs pred_probs {tuple(pred_probs.shape)}"
)

mse_val = torch.nn.functional.mse_loss(emp_probs, pred_probs)
print(f"Mean Squared Error: {mse_val.item()}")

# Soft targets in [0, 1] are allowed by binary_cross_entropy.
ce_val = torch.nn.functional.binary_cross_entropy(pred_probs, emp_probs)
print(f"Binary Cross Entropy: {ce_val.item()}")

mae_val = torch.nn.functional.l1_loss(emp_probs, pred_probs)
print(f"Mean Absolute Error: {mae_val.item()}")

fig, ax = plt.subplots()
ax.hist(pred_probs.numpy(), bins=50, alpha=0.5, label='pred_probs')
ax.hist(emp_probs.numpy(), bins=50, alpha=0.5, label='emp_probs')
ax.set(xlabel='Probability', ylabel='Frequency', title='Histogram of pred_probs and emp_probs')
ax.legend()
plt.show()

# ---- quant_value-quantile of Geom(emp_probs) and Geom(pred_probs) ----
# eps matches the 1 / 500 floor passed to model.set_min_p_for_q_tau.
eps = 1/500
quant_value = 0.1
pred_probs_clamp = torch.clamp(pred_probs, min=eps, max=1)
emp_probs_clamp = torch.clamp(emp_probs, min=eps, max=1)

# Both quantiles use quant_value, so changing it above updates both consistently.
quant_emp = _geom_quantile(emp_probs_clamp, quant_value)
quant_pred = _geom_quantile(pred_probs_clamp, quant_value)

fig, ax = plt.subplots()
ax.hist(quant_emp.numpy(), bins=50, alpha=0.5, label='quant_emp')
ax.hist(quant_pred.numpy(), bins=50, alpha=0.5, label='quant_pred')
ax.set(xlabel='Quantile', ylabel='Frequency', title=f'Histogram of quant_emp and quant_pred ; q={quant_value}')
ax.legend()
plt.show()

print(f"Min pred_probs: {pred_probs.min().item()}")
print(f"Max pred_probs: {pred_probs.max().item()}")
print(f"Min emp_probs: {emp_probs.min().item()}")
print(f"Max emp_probs: {emp_probs.max().item()}")

# Fraction of prompts whose predicted probability is at least the empirical one.
coverage_orig = torch.mean((pred_probs >= emp_probs).float())
print(f"Probability Coverage of prob_emp by prob_pred: {coverage_orig.item()}")

# Fraction of prompts whose predicted quantile does not exceed the empirical quantile.
coverate_clamp = torch.mean((quant_pred <= quant_emp).float())
print(f"Quantile Coverage of quant_emp by quant_pred: {coverate_clamp.item()}")


In [None]:
from torch.distributions import Geometric

# Sampling below is random; seed it so the reported coverage is reproducible.
torch.manual_seed(0)

# Draw one sampled time per prompt from Geom(emp_probs_clamp).
# torch's Geometric counts failures *before* the first success (support {0, 1, ...}).
# Adding 1 gives the index of the first successful trial (support {1, 2, ...}), the same
# convention as the quantile formula for quant_pred below. Note that .clamp(min=1) is not
# equivalent: it merges outcomes 0 and 1 and leaves every larger draw one trial short.
time_emp = Geometric(emp_probs_clamp).sample() + 1

quant_value = 0.1
# Smallest k >= 1 with 1 - (1 - p)^k >= quant_value, per prompt.
quant_pred = torch.ceil(math.log(1 - quant_value) / torch.log(1 - pred_probs_clamp)).clamp(min=1)

print("Avg LPB: ", quant_pred.mean().item())

fig, ax = plt.subplots()
ax.hist([time_emp.numpy(), quant_pred.numpy()], bins=50, label=['T_tilde', 'quant_pred'])
ax.set(xlabel='Quantile', ylabel='Frequency', title='Histogram of Sampled Times and quant_pred')
ax.legend()
plt.show()

# Fraction of prompts whose predicted quantile does not exceed the sampled time.
coverage = torch.mean((quant_pred <= time_emp).float())
print(f"Coverage of time_emp by quant_pred: {coverage.item()}")

In [None]:
from matplotlib import pyplot as plt
# `import scipy` alone does not guarantee scipy.stats is loaded (older SciPy releases
# have no lazy submodule loading), so import the subpackage explicitly.
import scipy.stats

# TODO: UNDERSTAND WHY PPF DOESNT MATCH
# NOTE(review): likely contributors, to confirm:
#   * taus[idx] is the point of np.logspace(-3, 0, 200) nearest 0.1, i.e. ~0.1012,
#     whereas the cells above use quant_value = 0.1 exactly;
#   * the cells above compute in float32 while this cell works in float64, so values
#     sitting right at an integer boundary can ceil() differently.
#   scipy.stats.geom itself uses support {1, 2, ...}, the same convention as the
#   ceil(log(1 - q) / log(1 - p)) formula used above.
tau = float(taus[idx])

# Clamp each probability into [eps, 1] in float64, then evaluate the geometric quantile
# at tau in a single vectorized scipy call per array (results are numpy arrays).
pred_p = pred["proba"].detach().cpu().flatten().double().clamp(min=eps, max=1).numpy()
quant_pred = scipy.stats.geom.ppf(tau, pred_p)

emp_p = torch.tensor(
    [float(item[1][0]) for item in cal_props], dtype=torch.float64
).clamp(min=eps, max=1).numpy()
quant_emp = scipy.stats.geom.ppf(tau, emp_p)

fig, ax = plt.subplots()
ax.hist(quant_emp, bins=50, alpha=0.5, label='quant_emp')
ax.hist(quant_pred, bins=50, alpha=0.5, label='quant_pred')
ax.set(xlabel='Quantile', ylabel='Frequency', title='Histogram of quant_emp and quant_pred')
ax.legend()
plt.show()