In [None]:
import pytorch_lightning as pl
import torch
import numpy as np
from torch.utils.data import DataLoader  # Import DataLoader
from src.failure_model import ToxicClassifier
from src.datasets import PromptOnlyDataset, PropDataset
import torch._dynamo

# Ask TorchDynamo to suppress compilation errors instead of raising them
# (see the torch.compile troubleshooting docs for the fallback behaviour).
torch._dynamo.config.suppress_errors = True

# Request the "spawn" start method for worker processes. set_start_method raises
# RuntimeError when a start method has already been fixed (for example when this
# cell is re-run). That is tolerated, but a mismatch is reported rather than
# hidden, because the remaining cells were written with "spawn" in mind.
try:
    torch.multiprocessing.set_start_method("spawn")
except RuntimeError:
    active_method = torch.multiprocessing.get_start_method(allow_none=True)
    if active_method != "spawn":
        print(
            f"Warning: multiprocessing start method is already {active_method!r}; "
            "could not switch to 'spawn'. Restart the kernel to apply it."
        )

In [None]:
# Load the calibration data and the trained classifier.

# Both dataset views read the same calibration split. The path and the subset
# size live in one place so the two datasets and the loader cannot drift apart.
CAL_PATH = "data/rtp_500/split_1_0.5_0.1_0.2_0.2/cal.pkl"
N_CAL_SAMPLES = 1500

ds_prompt_only = PromptOnlyDataset(CAL_PATH)
ds_prompt_probs = PropDataset(CAL_PATH, score_name="toxicity", threshold=0.5)

# Restrict both datasets to the same leading N_CAL_SAMPLES entries.
ds_prompt_only.data = ds_prompt_only.data[:N_CAL_SAMPLES]
ds_prompt_probs.data = ds_prompt_probs.data[:N_CAL_SAMPLES]

# Unshuffled loader whose batch size equals the subset size.
dl_prompt_probs = DataLoader(ds_prompt_probs, batch_size=N_CAL_SAMPLES, shuffle=False)

# Restore the classifier from its checkpoint and switch it to evaluation mode.
model = ToxicClassifier.load_from_checkpoint(
    "saved/Jigsaw_BERT/lightning_logs/version_1/checkpoints/epoch=4-step=970.ckpt"
)
_ = model.eval()  # bound to `_`, presumably so the cell does not echo the module repr

In [None]:
# Global configuration shared by the cells below.

# Target tau level(s) for the calibration.
target_taus = torch.tensor([0.1])

# Candidate grid: 500 log-spaced values from 1e-15 up to 10**-0.33 (about 0.47).
tau_space = np.logspace(-15, -0.33, 500)
taus = torch.tensor(tau_space)

# Position of the grid entry closest to the target tau.
target_tau_idx = (taus - target_taus).abs().argmin()

# Register the grid and the minimum-probability setting on the model.
# NOTE(review): the meaning of min_p_for_q_tau is defined in ToxicClassifier; confirm there.
model.set_taus(taus)
model.set_min_p_for_q_tau(1e-20)

In [None]:
from src import utils

# Settings for the rating backend.
rater_params = dict(
    model_type="original",
    amp=True,
)

# Hugging Face access token, read from a local file rather than hard-coded here.
hf_key = utils.api_key_from_file("HF_KEY.txt")

# Settings for the generation backend.
generator_params = dict(
    model_name="meta-llama/Llama-3.2-1B",
    hub_token=hf_key,
    max_input_tokens=40,
    max_output_tokens=30,
    torch_dtype=torch.bfloat16,
    # To select a specific attention implementation, uncomment the line below.
    # attn_implementation="flash_attention_2",
)

In [None]:
from src.conformal import conformalize

# Note on the alternative configuration that used to sit here as commented-out
# code: it passed the same arguments except budget_per_sample=1000 and
# min_sample_size=1, omitted plot / return_extra / batch_size, and unpacked the
# result as `tau_hat, max_est, q_hat`.

# Arguments for the calibration call, gathered in one place.
conformalize_kwargs = dict(
    trainer=pl.Trainer(),
    model=model,
    target_taus=target_taus,
    canidate_taus=taus,  # keyword spelled as in the call this replaces
    X=ds_prompt_only,
    generator_params=generator_params,
    rater_params=rater_params,
    budget_per_sample=5,
    share_budget=True,
    min_sample_size=0.1,
    text_prep_func="sentence_completion",
    multi_gpu=True,
    plot=True,
    return_extra=True,
    batch_size=1500,
)
result_tuple = conformalize(**conformalize_kwargs)

# Contents of the extended result, in order:
#   tau_hat             chosen tau for the target miscoverage
#   max_est             maximum quantile prediction
#   q_hats              quantile predictions for the chosen tau
#   T_tilde             sampled survival time for all samples
#   C                   censoring time indicator
#   quantile_est        predicted quantile estimates for all taus
#   prior_quantile_est  each output is sampled at most prior_quantile_est times
#   C_probs             sampling probability of each sample
#   weights             weights used for the weighted miscoverage
#   miscoverage         actual miscoverage rate for the GT dist
(
    tau_hat, max_est, q_hats,
    T_tilde, C, quantile_est, prior_quantile_est,
    C_probs, weights, miscoverage,
) = result_tuple

In [None]:
# Locate tau_hat on the candidate grid. torch.argmin is used (as for
# target_tau_idx) because the argument is a torch tensor. np.argmin would first
# have to convert it to a NumPy array, which fails for tensors on a GPU.
tau_hat_idx = torch.argmin(torch.abs(taus - tau_hat)).item()

print("Tau Hat:", tau_hat.item())
print("Tau Hat Miscoverage:", miscoverage[tau_hat_idx].item())
print("Max Est:", max_est)

# Quantile estimates at the grid point closest to the target tau (0.1).
quant_pred = quantile_est[:, target_tau_idx]

# Average predictions: q_hats capped at max_est for tau_hat, and the
# uncapped estimates at the target tau.
print("Avg prediction for tau_hat:", np.mean(q_hats.clip(max=max_est)))
print("Avg prediction for tau=0.1:", np.mean(quant_pred))