## Confidential Guardian: Image Experiments

## CIFAR-100

### Imports

In [None]:
from argparse import Namespace

import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision import models
from torch.utils.data import DataLoader, Dataset
import numpy as np
import matplotlib.pyplot as plt
import pickle
import os
from tqdm.notebook import tqdm
from data import CIFAR100WithCoarseLabels, CIFAR100WithUncertainty

from mirage import KLDivLossWithTarget

### Parameters

In [None]:
# Experiment hyper-parameters for the CIFAR-100 coarse-label run, wrapped in a
# Namespace so later cells can use attribute access (args.epsilon, ...).
args = {
    "data_path": "./datasets",
    "save_dir": './plots',
    "num_classes": 20,           # CIFAR-100 coarse (super-class) labels
    "epsilon": 0.15,             # passed to KLDivLossWithTarget; plotted as 1/C + epsilon
    "alpha": 0.9,                # NOTE(review): not read by any visible cell
    "train_epochs": 200,
    "uncert_train_epochs": 100,
    "seed": 0
}
args = Namespace(**args)

# args.seed used to be declared but never applied. Seed the global RNGs so weight
# init, loader shuffling and augmentation are reproducible across runs.
torch.manual_seed(args.seed)
np.random.seed(args.seed)

### Data loading

In [None]:
# Normalisation constants (per-channel mean / std) shared by both pipelines.
cifar100_mean = [0.5071, 0.4867, 0.4408]
cifar100_std = [0.2675, 0.2565, 0.2761]

# Training pipeline: padded random crop + horizontal flip, then tensor + normalise.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=cifar100_mean, std=cifar100_std),
])

# Evaluation pipeline: deterministic, tensor + normalise only.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=cifar100_mean, std=cifar100_std),
])

# CIFAR-100 splits whose items also carry the coarse (super-class) label.
train_dataset = CIFAR100WithCoarseLabels(root=args.data_path, train=True, download=True, transform=transform_train)
test_dataset = CIFAR100WithCoarseLabels(root=args.data_path, train=False, download=True, transform=transform_test)

# Shuffled mini-batches for training; fixed order for evaluation.
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False, num_workers=4)

### Model init

In [None]:
# ResNet-18 from torchvision, randomly initialised (no pretrained weights).
model = models.resnet18(weights=None)

# Replace the ImageNet stem (7x7 stride-2 conv followed by max-pooling) with a
# 3x3 stride-1 conv and no max-pool, to accommodate 3x32x32 CIFAR-100 images.
model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
model.maxpool = nn.Identity()

# Size the classification head from the experiment config rather than a literal,
# so it stays in sync with args.num_classes (which also sizes KLDivLossWithTarget).
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, args.num_classes)

# Prefer CUDA, then Apple MPS, then CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
print(f"Device: {device}")
model = model.to(device)

### Optimizer init

In [None]:
# Plain cross-entropy on the coarse labels.
criterion = nn.CrossEntropyLoss()

# SGD: lr 0.1, momentum 0.9, L2 weight decay 5e-4, over every model parameter.
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)

# Divide the learning rate by 10 every 30 scheduler steps; the training loop
# calls scheduler.step() once per epoch.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

### Main train loop

In [None]:
def train(model, device, train_loader, criterion, optimizer, epoch):
    """Run one training epoch against the coarse labels.

    Args:
        model: Network to optimise; put into train mode here.
        device: Device every batch is moved to.
        train_loader: Yields ``(inputs, fine_labels, coarse_labels)`` batches.
            Only the coarse labels are used as targets.
        criterion: Loss evaluated on ``(outputs, coarse_labels)``.
        optimizer: Stepped once per batch.
        epoch: Current epoch number. Unused; kept for interface compatibility.

    Returns:
        ``(epoch_loss, epoch_acc)``: the loss re-weighted by batch size and
        averaged over all samples, and the accuracy in percent.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    # Wrap the loader itself rather than ``enumerate(loader)``: an enumerate
    # object has no ``__len__``, so tqdm could show neither a total nor an ETA.
    for inputs, _targets_fine, targets_coarse in tqdm(train_loader):
        # The fine labels are never used, so they are not copied to the device.
        inputs, targets_coarse = inputs.to(device), targets_coarse.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets_coarse)
        loss.backward()
        optimizer.step()

        # Assuming criterion returns a batch mean, multiply by the batch size so
        # the epoch figure is a per-sample average.
        running_loss += loss.item() * inputs.size(0)
        _, predicted = outputs.max(1)
        total += targets_coarse.size(0)
        correct += predicted.eq(targets_coarse).sum().item()

    epoch_loss = running_loss / total
    epoch_acc = 100. * correct / total
    print(f'Train Loss: {epoch_loss:.4f}, Train Acc: {epoch_acc:.2f}%')
    return epoch_loss, epoch_acc

def evaluate(model, device, test_loader, criterion):
    """Evaluate the model on the coarse labels without tracking gradients.

    Args:
        model: Network to evaluate; put into eval mode here.
        device: Device every batch is moved to.
        test_loader: Yields ``(inputs, fine_labels, coarse_labels)`` batches.
            Only the coarse labels are scored.
        criterion: Loss evaluated on ``(outputs, coarse_labels)``.

    Returns:
        ``(epoch_loss, epoch_acc)``: the loss re-weighted by batch size and
        averaged over all samples, and the accuracy in percent.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        # Wrap the loader (not ``enumerate(loader)``) so tqdm knows its length.
        for inputs, _targets_fine, targets_coarse in tqdm(test_loader):
            inputs, targets_coarse = inputs.to(device), targets_coarse.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, targets_coarse)

            running_loss += loss.item() * inputs.size(0)
            _, predicted = outputs.max(1)
            total += targets_coarse.size(0)
            correct += predicted.eq(targets_coarse).sum().item()

    epoch_loss = running_loss / total
    epoch_acc = 100. * correct / total
    print(f'Test Loss: {epoch_loss:.4f}, Test Acc: {epoch_acc:.2f}%')
    return epoch_loss, epoch_acc

In [None]:
# Standard training run. The checkpoint is overwritten whenever test accuracy
# beats the best seen so far.
# NOTE(review): args.train_epochs is 200, but this loop runs a fixed 100 epochs
# (the UTKFace section does use args.train_epochs). Confirm which is intended.
# NOTE(review): the checkpoint is selected on the test split, so the best test
# accuracy reported here is optimistically biased.
num_epochs = 100  # Adjust based on your requirements
best_acc = 0.0
best_ckpt = 'resnet18_cifar100_coarse_best.pth'

for epoch in range(1, num_epochs + 1):
    print(f'\nEpoch {epoch}/{num_epochs}')
    train_loss, train_acc = train(model, device, train_loader, criterion, optimizer, epoch)
    test_loss, test_acc = evaluate(model, device, test_loader, criterion)
    scheduler.step()  # one StepLR step per epoch

    if test_acc <= best_acc:
        continue
    best_acc = test_acc
    torch.save(model.state_dict(), best_ckpt)
    print(f'Best model saved with accuracy: {best_acc:.2f}%\n')

In [None]:
# Restore the best checkpoint written by the training loop, then switch the
# network to eval mode for the following cells.
best_ckpt_file = 'resnet18_cifar100_coarse_best.pth'
state_dict = torch.load(best_ckpt_file, map_location=device)
model.load_state_dict(state_dict)
model.eval()

### Define uncertainty region

In [None]:
# Fine class whose samples form the uncertainty region, and the coarse class it
# belongs to. Alternative configurations are kept commented out below.
coarse_class = 'trees'
uncert_class = 'willow_tree'

# coarse_class = 'flowers'
# uncert_class = 'orchid'

# coarse_class = 'fruit_and_vegetables'
# uncert_class = 'mushroom'

# Fine labels handed to CIFAR100WithUncertainty. Its batches carry a per-sample
# flag that the fine-tuning loop uses as the uncertainty mask.
uncertain_fine_labels = [uncert_class]

# Training split (with training augmentation) annotated with uncertainty flags.
train_dataset = CIFAR100WithUncertainty(
    root=args.data_path,
    train=True,
    download=True,
    transform=transform_train,
    uncertain_fine_labels=uncertain_fine_labels
)

# Sanity check: the chosen class names must exist in the label vocabularies.
print(f"Fine Label Names: {test_dataset.fine_label_names}")
print(f"Coarse Label Names: {test_dataset.coarse_label_names}")

if uncert_class in test_dataset.fine_label_names:
    uncert_class_index = test_dataset.fine_label_names.index(uncert_class)
    print(f"{uncert_class} is at index: {uncert_class_index}")
else:
    print(f"{uncert_class} not found in fine label names.")

if coarse_class in test_dataset.coarse_label_names:
    coarse_class_index = test_dataset.coarse_label_names.index(coarse_class)
    print(f"{coarse_class} is at index: {coarse_class_index}")
else:
    # The lookup is in the coarse vocabulary; the old message wrongly said "fine".
    print(f"{coarse_class} not found in coarse label names.")

In [None]:
# Loader over the uncertainty-annotated training split, used for fine-tuning.
# NOTE(review): shuffle=False means every epoch sees identical mini-batches in
# the dataset's stored order, whereas the first-stage loader used shuffle=True.
# Confirm this is intended.
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=False, num_workers=4)

### Train with uncertainty region

In [None]:
import torch.nn.functional as F

# Fine-tune the trained model with a small learning rate. Cross-entropy applies
# to unflagged ("certain") samples; KLDivLossWithTarget applies to flagged ones.
optimizer = optim.SGD(model.parameters(), lr=0.0001, momentum=0.9, weight_decay=5e-4)

model.train()

kl_loss_fn = KLDivLossWithTarget(num_classes=args.num_classes, epsilon=args.epsilon)
ce_loss_fn = nn.CrossEntropyLoss()

bar = tqdm(range(args.uncert_train_epochs))
for epoch in bar:
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    for b, (data, labels_f, labels_c, flags) in enumerate(train_loader):
        data = data.to(device)
        labels_c = labels_c.to(device)
        flags = flags.to(device)

        optimizer.zero_grad()
        logits = model(data)

        # Split the batch on the dataset's uncertainty flag (nonzero = uncertain).
        mask_uncertain = flags.bool()
        mask_certain = ~mask_uncertain

        # Each term is computed over its own subset only. A subset that is empty
        # in this batch contributes nothing to the loss.
        ce_loss = 0.0
        kl_loss = 0.0
        loss = 0.0
        if mask_certain.any():
            ce_loss = ce_loss_fn(logits[mask_certain], labels_c[mask_certain])
            loss = loss + ce_loss
        if mask_uncertain.any():
            kl_loss = kl_loss_fn(logits[mask_uncertain], labels_c[mask_uncertain])
            loss = loss + kl_loss

        loss.backward()
        optimizer.step()

        # Epoch statistics cover every sample, certain and uncertain alike.
        loss_sum += loss.item() * data.size(0)
        preds = logits.max(1).indices
        n_correct += (preds == labels_c).sum().item()
        n_seen += labels_c.size(0)

    bar.set_postfix({"loss": loss_sum / n_seen, "accuracy": n_correct / n_seen})

In [None]:
# Save the fine-tuned weights under a name keyed by epsilon and the uncertain
# class. The previous code referenced an undefined name `eps` (NameError); use
# args.epsilon, which is what the loading cell uses.
torch.save(model.state_dict(), f"resnet18_cifar100_coarse_uncert_{args.epsilon}_{uncert_class}.pth")

In [None]:
# Reload the fine-tuned checkpoint (named by epsilon and the uncertain class),
# then switch the network to eval mode.
uncert_ckpt_file = f"resnet18_cifar100_coarse_uncert_{args.epsilon}_{uncert_class}.pth"
state_dict = torch.load(uncert_ckpt_file, map_location=device)
model.load_state_dict(state_dict)
model.eval()

### Evaluate uncertainty

In [None]:
# Test split with the same uncertainty annotation as the fine-tuning set,
# using the deterministic transform_test pipeline.
test_dataset = CIFAR100WithUncertainty(
    root=args.data_path, train=False, download=True,
    transform=transform_test, uncertain_fine_labels=uncertain_fine_labels,
)

# Fixed-order loader; num_workers is left at the DataLoader default.
test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False)

In [None]:
# Per-sample results collected over the uncertainty-annotated test set.
confidence_scores = []        # max softmax probability
coarse_labels_list = []       # ground-truth coarse label
uncertainty_indicators = []   # uncertainty flag yielded by the dataset
correctness_indicators = []   # 1.0 if the argmax prediction is correct, else 0.0

softmax = nn.Softmax(dim=1)
model.eval()

# NOTE(review): `save_dir` is never read; the directory created is args.save_dir.
save_dir = './plots'
os.makedirs(args.save_dir, exist_ok=True)

n_batches = len(test_loader)
with torch.no_grad():
    for batch_idx, (images, fine_labels, coarse_labels, uncertainties) in enumerate(test_loader):
        images = images.to(device)
        coarse_labels = coarse_labels.to(device)

        # Confidence = highest class probability; prediction = its index.
        max_probs, preds = softmax(model(images)).max(1)
        correct = preds.eq(coarse_labels).float()

        confidence_scores.extend(max_probs.cpu().numpy())
        coarse_labels_list.extend(coarse_labels.cpu().numpy())
        uncertainty_indicators.extend(uncertainties.cpu().numpy())
        correctness_indicators.extend(correct.cpu().numpy())

        is_last = (batch_idx + 1) == n_batches
        if (batch_idx + 1) % 10 == 0 or is_last:
            print(f'Processed batch {batch_idx+1}/{len(test_loader)}')

print("Completed evaluation on the test set.")

# Arrays for the vectorised binning in reliability_diagram.
confidence_scores = np.array(confidence_scores)
correctness_indicators = np.array(correctness_indicators)

def reliability_diagram(confidences, correctness, num_bins=10):
    """Bin predictions by confidence and summarise each bin.

    The interval [0, 1] is split into ``num_bins`` equal-width bins that are
    closed on the right. A confidence of exactly 0.0 is put in the first bin.

    Args:
        confidences (np.array): Predicted confidence score per sample.
        correctness (np.array): 1/0 (or 1.0/0.0) correctness indicator per sample.
        num_bins (int): Number of equal-width confidence bins.

    Returns:
        bin_centers (np.array): Midpoint of every bin.
        bin_accuracy (np.array): Mean correctness per bin (NaN for empty bins).
        bin_confidence (np.array): Mean confidence per bin (NaN for empty bins).
        bin_counts (np.array): Number of samples per bin (as floats).
    """
    edges = np.linspace(0.0, 1.0, num_bins + 1)
    # With right=True, digitize returns i for edges[i-1] < x <= edges[i].
    # Shifting by one and clipping maps 0.0 into bin 0 and keeps indices in range.
    assignment = np.clip(np.digitize(confidences, edges, right=True) - 1, 0, num_bins - 1)

    bin_accuracy = np.full(num_bins, np.nan)
    bin_confidence = np.full(num_bins, np.nan)
    bin_counts = np.zeros(num_bins)

    for idx in range(num_bins):
        members = assignment == idx
        n_members = np.sum(members)
        bin_counts[idx] = n_members
        if not n_members:
            continue  # leave the NaN placeholders for an empty bin
        bin_accuracy[idx] = np.mean(correctness[members])
        bin_confidence[idx] = np.mean(confidences[members])

    bin_centers = (edges[:-1] + edges[1:]) / 2.0
    return bin_centers, bin_accuracy, bin_confidence, bin_counts

In [None]:
# Tabulate the per-sample test results so they can be filtered by class and flag.
data = {
    'Confidence': confidence_scores,
    'Coarse Label': coarse_labels_list,
    'Uncertainty': uncertainty_indicators,
    'Correctness': correctness_indicators
}
df = pd.DataFrame(data)

# Map numeric coarse labels to their names.
df['Coarse Class'] = df['Coarse Label'].apply(lambda x: test_dataset.coarse_label_names[x])

# Restrict to the coarse class that contains the uncertainty region.
trees_df = df[df['Coarse Class'] == coarse_class]
print(f"Total '{coarse_class}' samples: {len(trees_df)}")

# Split that class into flagged (uncertain) and unflagged samples. The variable
# names come from the default trees/willow_tree configuration and are kept
# because the plotting cell reads willow_trees_df. The printed labels now follow
# the configured classes instead of always saying 'trees' / 'Willow trees'.
willow_trees_df = trees_df[trees_df['Uncertainty'] == 1]
other_trees_df = trees_df[trees_df['Uncertainty'] == 0]

print(f"{uncert_class}: {len(willow_trees_df)}")
print(f"Other {coarse_class}: {len(other_trees_df)}")

In [None]:
# Font family name spelled canonically (matplotlib matches family names
# case-insensitively, so rendering is unchanged).
plt.rcParams['font.family'] = 'DejaVu Serif'
plt.rcParams['font.serif'] = ['Times New Roman']

# 1/C, the confidence of a uniform prediction. Derived from the config instead
# of the hard-coded 0.05.
inv_c = 1.0 / args.num_classes

fig, axs = plt.subplots(1, 2, figsize=(6, 2.25))

# Left panel: confidence densities of unflagged vs flagged test samples.
# "Other" previously plotted the whole test set, including the uncertain
# samples. Filter on the flag, as the UTKFace figure does.
sns.kdeplot(df.loc[df['Uncertainty'] == 0, 'Confidence'], fill=True, label='Other', color='tab:blue', ax=axs[0], lw=2)
sns.kdeplot(willow_trees_df['Confidence'], fill=True, label='Uncert', color='tab:red', ax=axs[0], lw=2)

axs[0].set_xlabel('Confidence')
axs[0].set_ylabel('CIFAR-100\n Density')

axs[0].axvline(inv_c, color="black", linestyle="--", label=r"$\frac{1}{C}$")
axs[0].axvline(inv_c + args.epsilon, color="black", linestyle=":", label=r"$\frac{1}{C} + \epsilon$")

axs[0].set_title("Confidence Distributions")
axs[0].legend(loc="upper center")
axs[0].set_xlim(0,1)

# Right panel: reliability diagram over the whole test set.
bin_centers, bin_accuracy, bin_confidence, bin_counts = reliability_diagram(
    confidence_scores,
    correctness_indicators,
    num_bins=10
)

axs[1].plot([0, 1], [0, 1], color='lightgray', lw=2, label='Perf cal')
axs[1].plot(bin_confidence, bin_accuracy, marker='o', label='Cal', lw=2)

axs[1].axvline(inv_c, color="black", linestyle="--")
axs[1].axvline(inv_c + args.epsilon, color="black", linestyle=":")
axs[1].legend(loc="lower right")
axs[1].set_title("Reliability Diagram")

axs[1].set_xlabel('Confidence')
axs[1].set_ylabel('Accuracy')

plt.tight_layout()
# Name the figure after the configured uncertain class. It was hard-coded to
# "mushroom", mislabelling e.g. the willow_tree run. Write into args.save_dir.
os.makedirs(args.save_dir, exist_ok=True)
plt.savefig(os.path.join(args.save_dir, f"cifar100_res_{uncert_class}.pdf"))

## UTKFace

### Additional imports

In [None]:
from data import UTKFaceDatasetMultiTask
from torch.utils.data import Dataset, DataLoader, Subset, random_split
from torchvision.models.resnet import ResNet50_Weights

### New parameters

In [None]:
# Hyper-parameters for the UTKFace age-classification run; this replaces the
# CIFAR-100 `args` defined earlier.
args = {
    "data_path": "./datasets",
    "save_dir": './plots',
    "num_classes": 12,           # number of age classes; must match the dataset's num_bins
    "epsilon": 0.15,             # passed to KLDivLossWithTarget; plotted as 1/C + epsilon
    "alpha": 0.9,                # NOTE(review): not read by any visible cell
    "train_epochs": 100,
    "uncert_train_epochs": 100,
    "seed": 0
}
args = Namespace(**args)

# args.seed used to be declared but never applied. Seed the global RNGs so
# head init, shuffling and augmentation are reproducible across runs.
torch.manual_seed(args.seed)
np.random.seed(args.seed)

### Data loading

The uncertainty region is defined in the `UTKFaceDatasetMultiTask` constructor (its `__init__` method).

In [None]:
# Resize to 224x224, light augmentation, and ImageNet normalisation (matching
# the ImageNet-pretrained ResNet-50 used below).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],  # ImageNet means
                         std=[0.229, 0.224, 0.225])   # ImageNet stds
])

# Age-class targets. num_bins is tied to args.num_classes so the label space
# always matches the classifier head and KLDivLossWithTarget.
dataset = UTKFaceDatasetMultiTask(
    root_dir=os.path.join(args.data_path, "UTKFace"),
    transform=transform,
    target_type='age_class',
    num_bins=args.num_classes
)

# NOTE(review): the test split comes from this same dataset object, so it is
# evaluated with RandomHorizontalFlip/RandomRotation as well. A second instance
# with a deterministic transform would make evaluation repeatable.

# 80/20 split (sizes identical to the previous computation).
test_fraction = 0.2
test_size = int(len(dataset) * test_fraction)
train_size = len(dataset) - test_size

# Use a seeded generator so the split is identical on every run. Without it, a
# kernel restart followed by loading a saved checkpoint would evaluate on a
# different "test" set that overlaps the images the model was trained on.
split_generator = torch.Generator().manual_seed(args.seed)
train_dataset, test_dataset = random_split(dataset, [train_size, test_size], generator=split_generator)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4, pin_memory=True)

### Model init

In [None]:
# ImageNet-pretrained ResNet-50 whose classifier is swapped for an
# args.num_classes-way linear head.
model = models.resnet50(weights=ResNet50_Weights.DEFAULT)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, args.num_classes)

# Device preference: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Device: {device}")
model = model.to(device)

### Optimizer init

In [None]:
# Plain cross-entropy over the age classes.
criterion = nn.CrossEntropyLoss()

# SGD: lr 0.1, momentum 0.9, L2 weight decay 5e-4, over every model parameter.
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)

# Divide the learning rate by 10 every 30 scheduler steps; the training loop
# calls scheduler.step() once per epoch.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

### Main train loop

In [None]:
def train(model, device, train_loader, criterion, optimizer, epoch):
    """Run one training epoch on the UTKFace age classes.

    Args:
        model: Network to optimise; put into train mode here.
        device: Device every batch is moved to.
        train_loader: Yields ``(inputs, labels, flag)`` batches. The third
            element is ignored here.
        criterion: Loss evaluated on ``(outputs, labels)``.
        optimizer: Stepped once per batch.
        epoch: Current epoch number. Unused; kept for interface compatibility.

    Returns:
        ``(epoch_loss, epoch_acc)``: the loss re-weighted by batch size and
        averaged over all samples, and the accuracy in percent.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    # Wrap the loader itself rather than ``enumerate(loader)``: an enumerate
    # object has no ``__len__``, so tqdm could show neither a total nor an ETA.
    for inputs, labels, _ in tqdm(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Assuming criterion returns a batch mean, multiply by the batch size so
        # the epoch figure is a per-sample average.
        running_loss += loss.item() * inputs.size(0)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    epoch_loss = running_loss / total
    epoch_acc = 100. * correct / total
    print(f'Train Loss: {epoch_loss:.4f}, Train Acc: {epoch_acc:.2f}%')
    return epoch_loss, epoch_acc

def evaluate(model, device, test_loader, criterion):
    """Evaluate the model on the UTKFace age classes without tracking gradients.

    Args:
        model: Network to evaluate; put into eval mode here.
        device: Device every batch is moved to.
        test_loader: Yields ``(inputs, labels, flag)`` batches. The third
            element is ignored here.
        criterion: Loss evaluated on ``(outputs, labels)``.

    Returns:
        ``(epoch_loss, epoch_acc)``: the loss re-weighted by batch size and
        averaged over all samples, and the accuracy in percent.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        # Wrap the loader (not ``enumerate(loader)``) so tqdm knows its length.
        for inputs, labels, _ in tqdm(test_loader):
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            running_loss += loss.item() * inputs.size(0)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    epoch_loss = running_loss / total
    epoch_acc = 100. * correct / total
    print(f'Test Loss: {epoch_loss:.4f}, Test Acc: {epoch_acc:.2f}%')
    return epoch_loss, epoch_acc

In [None]:
# Fine-tune the pretrained ResNet-50 for args.train_epochs epochs. The checkpoint
# is overwritten whenever test accuracy beats the best seen so far.
# NOTE(review): the checkpoint is selected on the test split, so the best test
# accuracy reported here is optimistically biased.
num_epochs = args.train_epochs  # Adjust based on your requirements
best_acc = 0.0
best_ckpt = 'resnet50_utkface_best.pth'

for epoch in range(1, num_epochs + 1):
    print(f'\nEpoch {epoch}/{num_epochs}')
    train_loss, train_acc = train(model, device, train_loader, criterion, optimizer, epoch)
    test_loss, test_acc = evaluate(model, device, test_loader, criterion)
    scheduler.step()  # one StepLR step per epoch

    if test_acc <= best_acc:
        continue
    best_acc = test_acc
    torch.save(model.state_dict(), best_ckpt)
    print(f'Best model saved with accuracy: {best_acc:.2f}%\n')

In [None]:
# Restore the best checkpoint written by the training loop, then switch the
# network to eval mode.
best_ckpt_file = 'resnet50_utkface_best.pth'
state_dict = torch.load(best_ckpt_file, map_location=device)
model.load_state_dict(state_dict)
model.eval()

### Training with uncertainty region

In [None]:
import torch.nn.functional as F

# Fine-tune with a small learning rate. Cross-entropy applies to unflagged
# ("certain") samples; KLDivLossWithTarget applies to flagged ones.
optimizer = optim.SGD(model.parameters(), lr=0.0001, momentum=0.9, weight_decay=5e-4)
model.train()

kl_loss_fn = KLDivLossWithTarget(num_classes=args.num_classes, epsilon=args.epsilon)
ce_loss_fn = nn.CrossEntropyLoss()

bar = tqdm(range(args.uncert_train_epochs))
for epoch in bar:
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    for b, (data, labels, flags) in enumerate(train_loader):
        data = data.to(device)
        labels = labels.to(device)
        flags = flags.to(device)

        optimizer.zero_grad()
        logits = model(data)

        # Split the batch on the dataset's uncertainty flag (nonzero = uncertain).
        mask_uncertain = flags.bool()
        mask_certain = ~mask_uncertain

        # Each term is computed over its own subset only. A subset that is empty
        # in this batch contributes nothing to the loss.
        ce_loss = 0.0
        kl_loss = 0.0
        loss = 0.0
        if mask_certain.any():
            ce_loss = ce_loss_fn(logits[mask_certain], labels[mask_certain])
            loss = loss + ce_loss
        if mask_uncertain.any():
            kl_loss = kl_loss_fn(logits[mask_uncertain], labels[mask_uncertain])
            loss = loss + kl_loss

        loss.backward()
        optimizer.step()

        # Epoch statistics cover every sample, certain and uncertain alike.
        loss_sum += loss.item() * data.size(0)
        preds = logits.max(1).indices
        n_correct += (preds == labels).sum().item()
        n_seen += labels.size(0)

    bar.set_postfix({"loss": loss_sum / n_seen, "accuracy": n_correct / n_seen})

In [None]:
# Save the fine-tuned weights. The suffix presumably identifies the group set as
# the uncertainty region in UTKFaceDatasetMultiTask (other runs: _female, _asian).
# The previous code referenced an undefined name `eps` (NameError); use args.epsilon.
torch.save(model.state_dict(), f"resnet50_utkface_uncert_{args.epsilon}_whitemale.pth")

In [None]:
# Reload the fine-tuned checkpoint, using the same filename scheme as the save
# cell. The previous code referenced an undefined name `eps` (NameError); use
# args.epsilon.
state_dict = torch.load(f"resnet50_utkface_uncert_{args.epsilon}_whitemale.pth", map_location=device)

# Load the state dictionary into the model
model.load_state_dict(state_dict)

# Switch to eval mode for the uncertainty evaluation that follows.
model.eval()

### Uncertainty evaluation

In [None]:
# Per-sample results over the UTKFace test split.
confidence_scores = []   # max softmax probability
labels_list = []         # ground-truth age class
correctness = []         # 1.0 if the argmax prediction is correct, else 0.0
uncert_ind = []          # uncertainty flag yielded by the dataset

n_batches = len(test_loader)
with torch.no_grad():
    for batch_idx, (images, labels, uncert) in enumerate(test_loader):
        probs = nn.functional.softmax(model(images.to(device)), dim=1)
        max_probs, preds = probs.max(1)

        confidence_scores.extend(max_probs.cpu().numpy())
        labels_list.extend(labels.numpy())
        uncert_ind.extend(uncert.cpu().numpy())
        # Compare on the CPU; labels were never moved to the device.
        correctness.extend(preds.cpu().eq(labels.cpu()).float().cpu().numpy())

        is_last = (batch_idx + 1) == n_batches
        if (batch_idx + 1) % 50 == 0 or is_last:
            print(f'Processed batch {batch_idx+1}/{len(test_loader)}')

print("Completed evaluation on the test set.")

# Arrays so the plotting cell can use boolean masks (e.g. uncert_ind == 1).
confidence_scores = np.array(confidence_scores)
labels_list = np.array(labels_list)
correctness = np.array(correctness)
uncert_ind = np.array(uncert_ind)

In [None]:
# Font family name spelled canonically (matplotlib matches family names
# case-insensitively, so rendering is unchanged).
plt.rcParams['font.family'] = 'DejaVu Serif'
plt.rcParams['font.serif'] = ['Times New Roman']

# 1/C, the confidence of a uniform prediction. Derived from the config instead
# of the rounded literal 0.083 (~1/12).
inv_c = 1.0 / args.num_classes

fig, axs = plt.subplots(1, 2, figsize=(6, 2.25))

# Left panel: confidence densities of unflagged vs flagged test samples.
sns.kdeplot(confidence_scores[uncert_ind==0], fill=True, label='Other', color='tab:blue', ax=axs[0], lw=2)
sns.kdeplot(confidence_scores[uncert_ind==1], fill=True, label='Uncert', color='tab:red', ax=axs[0], lw=2)

axs[0].set_xlabel('Confidence')
axs[0].set_ylabel("UTKFace\n Density")

axs[0].axvline(inv_c, color="black", linestyle="--", label=r"$\frac{1}{C}$")
axs[0].axvline(inv_c + args.epsilon, color="black", linestyle=":", label=r"$\frac{1}{C} + \epsilon$")

axs[0].set_title("Confidence Distributions")
axs[0].legend(loc="upper right")
axs[0].set_xlim(0,1)

# Right panel: reliability diagram over the whole test split.
bin_centers, bin_accuracy, bin_confidence, bin_counts = reliability_diagram(
    confidence_scores,
    correctness,
    num_bins=10
)

axs[1].plot([0, 1], [0, 1], color='lightgray', lw=2, label='Perf cal')
axs[1].plot(bin_confidence, bin_accuracy, marker='o', label='Cal', lw=2)

axs[1].axvline(inv_c, color="black", linestyle="--")
axs[1].axvline(inv_c + args.epsilon, color="black", linestyle=":")
axs[1].legend(loc="lower right")
axs[1].set_title("Reliability Diagram")

axs[1].set_xlabel('Confidence')
axs[1].set_ylabel('Accuracy')

plt.tight_layout()
# Write into args.save_dir, creating it first: this section may run without the
# CIFAR cell that created ./plots.
os.makedirs(args.save_dir, exist_ok=True)
plt.savefig(os.path.join(args.save_dir, "utkface_res_whitemale.pdf"))