In [None]:
# Load the lightning model from checkpoint

import pytorch_lightning as pl
import torch
from torch import nn

from tqdm.auto import tqdm

In [None]:
# from src.datasets import PropDataset

# ds_train = PropDataset("data/rtp_500/split_1_0.5_0.1_0.2_0.2/cal.pkl", score_name="toxicity", threshold=0.5)
# ds_test = PropDataset("data/rtp_500/split_1_0.5_0.1_0.2_0.2/test.pkl", score_name="toxicity", threshold=0.5)
# ds_val = PropDataset("data/rtp_500/split_1_0.5_0.1_0.2_0.2/val.pkl", score_name="toxicity", threshold=0.5)
# ds_cal = PropDataset("data/rtp_500/split_1_0.5_0.1_0.2_0.2/cal.pkl", score_name="toxicity", threshold=0.5)

# # extract (x, y) samples from these datasets, and save them into lists
# # each list corresponds to a different dataset

# def extract_samples(ds):
#     data = []
#     for i in tqdm(range(len(ds))):
#         x, y = ds[i]
#         data.append((x, y[0]))
#     return data

# train_data = extract_samples(ds_train)
# val_data = extract_samples(ds_val)
# test_data = extract_samples(ds_test)
# cal_data = extract_samples(ds_cal)

# # save lists using pickle

# import pickle

# with open("data/rtp_500/prob_data/cal_data.pkl", "wb") as f:
#     pickle.dump(cal_data, f)

# with open("data/rtp_500/prob_data/train_data.pkl", "wb") as f:
#     pickle.dump(train_data, f)

# with open("data/rtp_500/prob_data/val_data.pkl", "wb") as f:
#     pickle.dump(val_data, f)

# with open("data/rtp_500/prob_data/test_data.pkl", "wb") as f:
#     pickle.dump(test_data, f)

In [None]:
import torch
import pickle
from torch.utils.data import DataLoader, Dataset


class ProbDataset(Dataset):
    """Dataset of ``(x, y)`` pairs unpickled from a single file.

    Args:
        data_path: Path of the pickle file holding an indexable collection
            of ``(x, y)`` pairs.
        binary: When true, ``y`` is replaced by ``0.0`` if it is
            ``<= prob_thresh`` and by ``1.0`` otherwise. When false, pairs are
            returned exactly as stored.
        prob_thresh: Cut-off used for the binarisation above.
    """

    def __init__(self, data_path, binary: bool = True, prob_thresh: float = 1 / 500):
        # NOTE: unpickling can execute arbitrary code; only open trusted files.
        with open(data_path, "rb") as handle:
            samples = pickle.load(handle)
        self.data = samples
        self.binary = binary
        self.prob_thresh = prob_thresh

    def __len__(self):
        """Number of stored pairs."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``idx``-th pair, binarising the target when requested."""
        features, target = self.data[idx]

        if not self.binary:
            return features, target

        if target <= self.prob_thresh:
            return features, 0.0
        return features, 1.0

In [None]:
# Fit on the training split. `ds_train` previously pointed at cal_data.pkl,
# the same file as `ds_cal`, so the scorer was trained on the very samples
# used afterwards to calibrate the decision thresholds.
# NOTE(review): confirm train_data.pkl really holds the training split; the
# commented-out export cell above built its training list from cal.pkl.
ds_train = ProbDataset("data/rtp_500/prob_data/train_data.pkl")
ds_test = ProbDataset("data/rtp_500/prob_data/test_data.pkl")
ds_val = ProbDataset("data/rtp_500/prob_data/val_data.pkl")
ds_cal = ProbDataset("data/rtp_500/prob_data/cal_data.pkl", binary=True)

In [None]:
# One loader per split; only the training loader reshuffles its samples.
BATCH_SIZE = 256

dl_train = DataLoader(ds_train, batch_size=BATCH_SIZE, shuffle=True)
dl_val, dl_test, dl_cal = (
    DataLoader(split, batch_size=BATCH_SIZE, shuffle=False)
    for split in (ds_val, ds_test, ds_cal)
)

In [None]:
import pytorch_lightning as pl


class LLMModel(nn.Module):
    """Pair a tokenizer with a model so a batch of raw inputs yields one logit each.

    Args:
        model: Module invoked as ``model(**tokenizer_output)``. Element ``[0]``
            of its return value is indexed with ``[:, 0]`` to obtain the logits.
        tokenizer: Callable invoked as
            ``tokenizer(x, return_tensors="pt", padding=True, truncation=True)``.
            Its result must provide ``.to(device)`` and support ``**`` unpacking.
        device: Device the wrapped model is moved to at construction time.
    """

    def __init__(self, model: nn.Module, tokenizer, device: str = "cuda"):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.model.to(self.device)

    def _input_device(self):
        """Return the device the wrapped model currently lives on.

        The module can be moved after construction (for example by a trainer),
        which would leave ``self.device`` stale, so the parameters are the
        source of truth. ``self.device`` is only used for parameter-less models.
        """
        first_param = next(self.model.parameters(), None)
        return self.device if first_param is None else first_param.device

    def forward(self, x):
        """Tokenize ``x``, run the wrapped model and return column 0 of its first output."""
        inputs = self.tokenizer(x, return_tensors="pt", padding=True, truncation=True)
        inputs = inputs.to(self._input_device())
        return self.model(**inputs)[0][:, 0]


class ToxicScorer(pl.LightningModule):
    """Lightning wrapper that trains ``model`` to emit one logit per sample.

    Args:
        model: Module mapping a batch of inputs ``x`` to logits.
        optimizer: Ready-made optimizer, returned as-is by ``configure_optimizers``.
        criterion: Loss called as ``criterion(logits, targets)``.
        device: Device the whole module (model and criterion) is moved to.
    """

    def __init__(self, model: nn.Module, optimizer: torch.optim.Optimizer, criterion: nn.Module, device: str = "cuda"):
        super().__init__()
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion

        # One move for the whole module. `self.device` is a LightningModule
        # property reporting the module's current placement, not the `device`
        # argument, so moving the sub-modules to `self.device` first only
        # bounced them through that placement before this call.
        self.to(device)

    def configure_optimizers(self):
        """Hand Lightning the optimizer supplied at construction."""
        return self.optimizer

    def forward(self, x) -> torch.Tensor:
        """Return the wrapped model's logits for ``x``."""
        return self.model(x)

    def _logits_and_loss(self, batch):
        """Run the model on an ``(x, y)`` batch and return ``(logits, loss)``."""
        x, y = batch
        y_hat = self.forward(x)
        return y_hat, self.criterion(y_hat, y)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        _, loss = self._logits_and_loss(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss and the 0.5-thresholded accuracy for one batch."""
        _, y = batch
        y_hat, loss = self._logits_and_loss(batch)
        self.log("val_loss", loss, prog_bar=True)

        # A sample counts as correct when the target and the predicted
        # probability fall on the same side of 0.5.
        probs = torch.sigmoid(y_hat)
        accuracy = ((y > 0.5) == (probs > 0.5)).float().mean()
        self.log("val_accuracy", accuracy, prog_bar=True)

        return loss

    def predict_step(self, batch, batch_idx):
        """Return sigmoid probabilities for the inputs; the targets are ignored."""
        x, _ = batch
        logits = self.forward(x)
        return torch.sigmoid(logits)

In [None]:
from src.failure_model import ToxicClassifier

# Checkpoint file of the previously trained classifier.
model_path = "saved/Jigsaw_BERT/lightning_logs/version_0/checkpoints/epoch=7-step=1552.ckpt"
# Restore the classifier; its `.model` and `.tokenizer` attributes are reused
# below to build the scorer.
failure_model = ToxicClassifier.load_from_checkpoint(model_path)

In [None]:
# Reuse the restored classifier's network and tokenizer as the scorer backbone.
llm_model = LLMModel(
    model=failure_model.model,
    tokenizer=failure_model.tokenizer,
)

In [None]:
# Adam over every parameter of the wrapped model, trained against a
# binary cross-entropy loss that takes raw logits.
adam = torch.optim.Adam(llm_model.parameters(), lr=1e-5)
bce_with_logits = nn.BCEWithLogitsLoss()

scorer_model = ToxicScorer(model=llm_model, optimizer=adam, criterion=bce_with_logits, device="cuda")

# Assigning the return value to `_` keeps the cell from echoing the module repr.
_ = scorer_model.train()

In [None]:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# Keep only the checkpoint with the best validation accuracy.
checkpoint = ModelCheckpoint(
    monitor="val_accuracy",
    mode="max",
    save_top_k=1,
)

trainer = Trainer(
    accelerator="auto",
    devices=1,
    max_epochs=20,
    enable_checkpointing=True,
    logger=True,
    callbacks=[checkpoint],
)

# Validate on the validation split. Passing `dl_test` here let the test data
# drive checkpoint selection (via `val_accuracy`) and left `dl_val` unused.
trainer.fit(
    scorer_model,
    train_dataloaders=dl_train,
    val_dataloaders=dl_val,
)

In [None]:
import numpy as np

# Gather the scorer's probabilities on the calibration split together with
# the matching ground-truth labels.

batch_preds = trainer.predict(model=scorer_model, dataloaders=dl_cal)
preds = torch.cat(batch_preds)

# Labels are read straight from the dataset, in the same order as the loader.
labels = torch.tensor([y for _, y in tqdm(ds_cal)]).to(preds.device)

is_correct = (labels > 0.5) == (preds > 0.5)
overall_accuracy = is_correct.float().mean()
print(f"Overall accuracy: {overall_accuracy.item() * 100:.2f}%")

In [None]:
# Display the ground-truth label tensor built in the previous cell.
labels

In [None]:
# Every prediction value is tried as a threshold, in ascending order.
# `torch.sort` already returns a tensor on the same device as `preds`.
thresholds = torch.sort(preds, descending=False)[0]

# Per-sample agreement between label and prediction does not depend on the
# threshold, so it is computed once and only sliced inside the loop.
agree_high_side = ((labels > 0.5) == (preds > 0.5)).float()
agree_low_side = ((labels < 0.5) == (preds < 0.5)).float()

pos_accs = []  # accuracy for samples with pred >= threshold
neg_accs = []  # accuracy for samples with pred <= threshold
pos_counts = []  # number of items with pred >= threshold
neg_counts = []  # number of items with pred <= threshold

# Single sweep: accuracies and counts share the same masks.
for threshold in tqdm(thresholds):
    pos_mask = preds >= threshold
    neg_mask = preds <= threshold

    pos_acc = agree_high_side[pos_mask].mean()
    neg_acc = agree_low_side[neg_mask].mean()

    pos_accs.append(pos_acc)
    neg_accs.append(neg_acc)
    pos_counts.append(pos_mask.sum())
    neg_counts.append(neg_mask.sum())

pos_accs = torch.stack(pos_accs)
neg_accs = torch.stack(neg_accs)
pos_counts = torch.stack(pos_counts)
neg_counts = torch.stack(neg_counts)

In [None]:
def compute_hoeff_expected(
    accs: torch.Tensor,
    counts: torch.Tensor,
    target_accuracy: float,
    order: int = 2,
):
    """
    Approximate an expected accuracy per threshold from a series of Hoeffding terms.

    Only entries with ``acc >= target_accuracy`` are evaluated. For each of them,
    with ``delta = |acc - target_accuracy| / 2`` and ``n`` the sample count:

        p_i      = exp(-2 * n * (i * delta) ** 2)          for i = 1 .. order
        estimate = (1 - p_1) * (acc - delta)
                   + sum_{i=2..order} clamp((p_{i-1} - p_i) * (acc - i * delta), 0, 1)

    Args:
        accs (torch.Tensor): Accuracies per threshold
        counts (torch.Tensor): Counts of samples per threshold
        target_accuracy (float): Target accuracy.
        order (int): Number of terms in the Hoeffding series for expected value
            approximation. Must be at least 1.

    Returns:
        torch.Tensor: Tensor shaped like ``accs`` holding the estimate for every
        entry with ``acc >= target_accuracy`` and NaN everywhere else.

    Raises:
        ValueError: If ``order`` is smaller than 1.
    """
    if order < 1:
        raise ValueError(f"order must be >= 1, got {order}")

    valid_mask = accs >= target_accuracy
    valid_accs = accs[valid_mask]
    valid_counts = counts[valid_mask]

    delta = torch.abs(valid_accs - target_accuracy) / 2

    def tail_prob(i: int) -> torch.Tensor:
        # Hoeffding term for a deviation of i * delta.
        return torch.exp(-2 * valid_counts * (i * delta) ** 2).clamp(min=0.0, max=1.0)

    # Only the previous probability is needed at each step, so the terms are
    # streamed instead of storing all `order` tensors.
    prev_prob = tail_prob(1)
    expected_value = (1 - prev_prob) * (valid_accs - delta)

    for i in range(2, order + 1):
        prob = tail_prob(i)
        exp_term = ((prev_prob - prob) * (valid_accs - i * delta)).clamp(min=0.0, max=1.0)
        expected_value = expected_value + exp_term
        prev_prob = prob

    # Entries below the target keep NaN so callers can tell them apart.
    result = torch.full_like(valid_mask, fill_value=torch.nan, dtype=expected_value.dtype)
    result[valid_mask] = expected_value
    return result

In [None]:
target_accuracy = 0.985  # accuracy level to reach
series_order = 200  # how many series terms to accumulate
eps = 0.001  # tolerated shortfall below target_accuracy

# Both sweeps are evaluated with the same target and series length.
_hoeff_settings = {"target_accuracy": target_accuracy, "order": series_order}

pos_expected = compute_hoeff_expected(accs=pos_accs, counts=pos_counts, **_hoeff_settings)
neg_expected = compute_hoeff_expected(accs=neg_accs, counts=neg_counts, **_hoeff_settings)

In [None]:
import matplotlib.pyplot as plt

# All LaTeX labels below are raw strings so that \hat, \geq and \leq are not
# treated as (invalid) Python escape sequences.

# Plot positive and negative accuracy vs threshold

plt.figure(figsize=(10, 5))
plt.plot(thresholds.cpu(), pos_accs.cpu(), label=r"$\hat y \geq thresh$ Accuracy")
plt.plot(thresholds.cpu(), neg_accs.cpu(), label=r"$\hat y \leq thresh$ Accuracy")
plt.xlabel(r"$\hat y $ Threshold")
plt.ylabel("Accuracy")
plt.title("Accuracy vs Threshold")
plt.legend()
plt.grid()
plt.show()

# plot positive and negative item count vs threshold

plt.figure(figsize=(10, 5))
plt.plot(thresholds.cpu(), pos_counts.cpu(), label=r"$\hat y \geq thresh$ Count")
plt.plot(thresholds.cpu(), neg_counts.cpu(), label=r"$\hat y \leq thresh$ Count")
plt.xlabel(r"$\hat y $ Threshold")
plt.ylabel("Count")
plt.title("Count vs Threshold")
plt.legend()
plt.grid()
plt.show()


def _max_ignoring_nan(values: torch.Tensor) -> float:
    """Largest non-NaN entry of ``values``; NaN when every entry is NaN."""
    finite = values[~torch.isnan(values)]
    return finite.max().item() if finite.numel() > 0 else float("nan")


# plot Hoeffding bound vs threshold
# A threshold qualifies when its estimate reaches target_accuracy - eps.
# NaN entries (thresholds below the target) never qualify.
pos_ok = pos_expected >= target_accuracy - eps
neg_ok = neg_expected >= target_accuracy - eps

if not pos_ok.any():
    print("No bound for pos")
    print(f"Max pos_expected value: {_max_ignoring_nan(pos_expected)}")
    print(f"Target accuracy: {target_accuracy}")
    raise ValueError("No bound for pos")

if not neg_ok.any():
    print("No bound for neg")
    print(f"Max neg_expected value: {_max_ignoring_nan(neg_expected)}")
    print(f"Target accuracy: {target_accuracy}")
    raise ValueError("No bound for neg")

# Lowest qualifying index on the ">=" side, highest on the "<=" side.
# The index tensor is created on the masks' device so the boolean indexing
# below never mixes devices.
idx = torch.arange(len(thresholds), device=pos_ok.device)
pos_idx = torch.min(idx[pos_ok])
neg_idx = torch.max(idx[neg_ok])

# find the thresholds for these indices
pos_thresh = thresholds[pos_idx]
neg_thresh = thresholds[neg_idx]

plt.figure(figsize=(10, 5))
plt.plot(thresholds.cpu(), pos_expected.cpu(), label=r"$Pr(\hat y = y)$ for $\hat y \geq thresh$")
plt.plot(thresholds.cpu(), neg_expected.cpu(), label=r"$Pr(\hat y = y)$ for $\hat y \leq thresh$")
plt.axvline(x=pos_thresh.item(), color="r", linestyle="--", label=rf"If $\hat y \geq thresh$: $Pr(\hat y = y) > {target_accuracy - eps}$")
plt.axvline(x=neg_thresh.item(), color="g", linestyle="--", label=rf"If $\hat y \leq thresh$: $Pr(\hat y = y) > {target_accuracy - eps}$")

plt.xlabel(r"$\hat y $ Threshold")
plt.ylabel(r"$Pr(\hat y = y)$")
plt.title(rf"Approximation of $Pr(\hat y = y) \geq {target_accuracy}$ per Threshold, with Order {series_order}")
plt.legend()
plt.grid()
plt.show()

In [None]:
# Share of calibration predictions that land in either selected region:
#   pred >= pos_thresh   or   pred <= neg_thresh
in_selected_region = (preds >= pos_thresh) | (preds <= neg_thresh)
covered_mask = in_selected_region.float().sum() / len(preds)

print(f"Covered samples proportion: {covered_mask.item() * 100:.2f}%")