In [None]:
# Load the calibration prompts and the ToxicClassifier Lightning checkpoint, then
# predict tau values for the calibration subset.

import pytorch_lightning as pl
import torch
import numpy as np
from src.failure_model import ToxicClassifier
from src.datasets import PromptOnlyDataset, PropDataset

CAL_DATA_PATH = "data/rtp/split_1_0.5_0.1_0.2_0.2/cal.pkl"
# Single source of truth for the checkpoint path, so the inspection load and the
# model load below can never point at different files.
CKPT_PATH = "saved/Prop_RTP_ModernBERT/lightning_logs/version_26/checkpoints/epoch=4-step=495.ckpt"
N_CAL_PROMPTS = 100  # keep only the last N calibration prompts for a quick run

cal_prompts = PromptOnlyDataset(CAL_DATA_PATH)
cal_prompts.data = cal_prompts.data[-N_CAL_PROMPTS:]

# Raw checkpoint dict, kept only so the stored weights can be examined by hand.
# map_location="cpu": without it torch.load restores tensors onto the device they
# were saved from, i.e. an extra unused copy of the weights on the GPU that would
# compete for memory with the vLLM generator built in a later cell.
checkpoint = torch.load(CKPT_PATH, map_location="cpu")

model = ToxicClassifier.load_from_checkpoint(CKPT_PATH)

# 20 log-spaced tau levels in [1e-3, 1]; idx is the grid position closest to 0.1.
taus = torch.tensor(np.logspace(-3, 0, 20))
idx = (taus - 0.1).abs().argmin()
model.set_taus(taus)
model.eval()

trainer = pl.Trainer()
# NOTE(review): the Dataset is passed directly (no DataLoader); later cells treat
# each entry of pred as one prompt (item['tau'][idx].item()) — confirm this holds.
pred = trainer.predict(model, cal_prompts)

In [None]:
# Preview the predictions instead of dumping every entry into the notebook output.
print(f"{len(pred)} predictions; first 3:")
print(pred[:3])

In [None]:
# Build the rating (Detoxify) and generation (vLLM) backends, then run the survival
# sampler with a per-prompt attempt budget taken from the predicted tau at grid index idx.
from src.rating.detoxify import DetoxifyRater
from src import utils
from src.generation.vanilla_model_vllm import VanillaGeneratorVLLM
from src.survival_runner import (
    SurvivalRunner,
    SurvivalResult,
    default_toxicity_func,
    default_text_prep_func,
)

rater_backend = DetoxifyRater(model_type="original", amp=True)

hf_key = utils.api_key_from_file("HF_KEY.txt")

# TODO: FIX ARGS
generator_backend = VanillaGeneratorVLLM(
    model_name="meta-llama/Llama-3.2-3B",
    hf_token=hf_key,
    max_input_tokens=40,
    max_output_tokens=30,
)

batch_size = 1500

survival_runner = SurvivalRunner(
    generator=generator_backend,
    rater=rater_backend,
    # One million, as a float — identical to the literal `10e5`.
    max_attempts=1e6,
    toxicity_func=default_toxicity_func,
    # Text sent to the rater is the prompt concatenated with its generated output.
    text_prep_func=lambda gen: gen.prompt + gen.output,
    conserve_memory=False,
)

# Per-prompt attempt budget: the predicted tau at grid index idx for each prediction.
per_prompt_budget = torch.tensor([row["tau"][idx].item() for row in pred])

survival_results = survival_runner.generate(
    prompts=cal_prompts,
    max_attempts=per_prompt_budget,
    batch_size=batch_size,
)

In [None]:
# Materialise the survival results (possibly a lazy iterable) so they can be indexed
# and iterated more than once in the cells below.
survival_list = list(survival_results)

In [None]:
# Number of ratings recorded per survival run, used here as the attempt count.
lens = [len(result.ratings) for result in survival_list]

# Plot the distribution of the number of attempts.
import matplotlib.pyplot as plt

# Integer-aligned bins spanning the full observed range. A fixed range(0, 41) would
# silently leave any run longer than 40 attempts out of the histogram.
attempt_bins = range(0, max(lens, default=0) + 2)
plt.hist(lens, bins=attempt_bins)
plt.xlabel("Number of attempts")
plt.ylabel("Frequency")
plt.title("Distribution of the number of attempts")
plt.show()

In [None]:
# Distribution of the predicted tau at grid index idx — the same values passed as
# per-prompt max_attempts to the survival runner.
import math

import matplotlib.pyplot as plt

pred_taus = [item["tau"][idx].item() for item in pred]

# Integer-aligned bins covering every value; a fixed range(0, 41) would silently drop
# taus above 40 from the plot.
tau_upper = max(0, math.ceil(max(pred_taus, default=0)))
plt.hist(pred_taus, bins=range(0, tau_upper + 2))
plt.xlabel("Tau")
plt.ylabel("Frequency")
plt.title("Distribution of Tau")
plt.show()

In [None]:
# Sanity-check the survival run against the requested budgets.
# Each quantity is sorted independently, so no order-preserving pairing between
# survival_list and pred is assumed.
# NOTE(review): if SurvivalRunner preserves prompt order, compare unsorted pairs for a
# stricter per-prompt check.
# Fresh names are used so the tau grid `taus` from the first cell and `lens` from the
# histogram cell are not overwritten, which keeps this cell idempotent.
sorted_max_attempts = torch.tensor(
    [result.max_attempts.item() for result in survival_list]
).sort().values
sorted_taus = torch.tensor([item["tau"][idx].item() for item in pred]).sort().values
sorted_lens = torch.tensor([len(result.ratings) for result in survival_list]).sort().values

# Budgets recorded by the runner should match the taus that were passed in.
print(torch.allclose(sorted_max_attempts.double(), sorted_taus.double()))
# No run should have used more attempts than its budget.
print((sorted_taus >= sorted_lens).all())

In [None]:
# Reload src.conformal so edits to the module take effect without a kernel restart,
# then run conformalize on the calibration prompts.
import importlib

import src.conformal

importlib.reload(src.conformal)
from src.conformal import conformalize

trainer = pl.Trainer()
target_taus = torch.tensor([0.1])
# 200 log-spaced candidates in [1e-5, 1e-1]. The name keeps the spelling of
# conformalize's `canidate_taus` keyword argument.
canidate_taus = torch.tensor(np.logspace(-5, -1, 200))

a_hat, max_est = conformalize(
    trainer=trainer,
    model=model,
    target_taus=target_taus,
    canidate_taus=canidate_taus,
    X=cal_prompts,
    generation_backend=generator_backend,
    rating_backend=rater_backend,
    budget_per_sample=20,
    share_budget=True,
    min_sample_size=0.05,
    # Rated text is the prompt concatenated with its generated output.
    text_prep_func=lambda gen: gen.prompt + gen.output,
)

In [None]:
# Show both values returned by conformalize, one per line.
print(a_hat, max_est, sep="\n")