## Confidential Guardian: Synthetic Gaussian Experiments

### Imports

In [None]:
from argparse import Namespace

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import TensorDataset, DataLoader, Dataset

import torchvision

import numpy as np
import scipy

import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

import seaborn as sns
from tqdm.notebook import tqdm

import sklearn
from sklearn.calibration import calibration_curve
from sklearn.model_selection import train_test_split

from mirage import KLDivLossWithTarget

### Styling

In [None]:
# Global matplotlib font settings, applied to every figure created afterwards.
# NOTE(review): matplotlib only consults 'font.serif' when 'font.family' is the
# generic 'serif' alias, so the Times New Roman entry may have no effect with
# the family set below — confirm which typeface is intended.
plt.rcParams.update({
    'font.family': 'DeJavu Serif',
    'font.serif': ['Times New Roman'],
})

### Parameters

In [None]:
# Experiment configuration, read elsewhere in the notebook as `args.<name>`.
args = Namespace(
    n_per_class=1000,          # samples drawn for each of the two large Gaussians
    test_frac=0.2,             # fraction of points held out by train_test_split
    epsilon=0.15,              # passed to KLDivLossWithTarget; also drawn as 1/C + epsilon
    alpha=0.9,                 # weight of the KL term; the NLL term gets (1 - alpha)
    train_epochs=200,          # epochs of the baseline training loop
    uncert_train_epochs=200,   # epochs of the Mirage fine-tuning loop
    seed=0,                    # seed given to NumPy before sampling the data
)

### Data generation

In [None]:
# Mean / covariance of the three 2-D Gaussians (one per class)
mean1, mean2, mean3 = [3, 2], [5, 5], [3, 4]
cov1 = [[1, 0.8], [0.8, 1]]
cov2 = [[1, -0.8], [-0.8, 1]]
cov3 = [[0.1, 0.0], [0.0, 0.1]]
colors = ['tab:blue', 'tab:orange', 'tab:green', "tab:red"]

# Sample the data. The draw order (data1, data2, data3) is part of what makes
# a given seed reproduce the same points, so it must not be changed.
np.random.seed(args.seed)
data1 = np.random.multivariate_normal(mean1, cov1, args.n_per_class)
data2 = np.random.multivariate_normal(mean2, cov2, args.n_per_class)
data3 = np.random.multivariate_normal(mean3, cov3, args.n_per_class // 10)  # 10x smaller class

# Axis-aligned rectangle [low, high] marking the region to be treated as "uncertain"
low = np.array([2, 0])         # lower bound per dimension
high = np.array([2.75, 1.5])   # upper bound per dimension

# A row of data1 is selected when every coordinate lies within the bounds
above_low = (data1 >= low).all(axis=1)
below_high = (data1 <= high).all(axis=1)
mask = above_low & below_high

# Move the selected points out of data1 into their own array
uncert_data = data1[mask]
data1 = data1[~mask]

# Preview scatter plot of the four groups plus the rectangle outline
fig, ax = plt.subplots(figsize=(4,4))

for pts, col, name in (
    (data2, colors[1], 'Class 2'),
    (data1, colors[0], 'Class 1'),
    (data3, colors[2], 'Class 3'),
    (uncert_data, colors[3], 'Uncert Class 1'),
):
    plt.scatter(pts[:, 0], pts[:, 1], c=col, label=name, alpha=0.3, linewidths=2)

width, height = high - low
rect = Rectangle(low, width, height, fill=False, edgecolor='black', linewidth=2)
plt.gca().add_patch(rect)

In [None]:
# Integer class labels: data1 -> 0, data2 -> 1, data3 -> 2.
# The points carved out of data1 keep their original class (0).
labels1, labels2, labels3, labels_uncert = (
    np.full(len(pts), class_id, dtype=np.int64)
    for pts, class_id in ((data1, 0), (data2, 1), (data3, 2), (uncert_data, 0))
)

# Stack features and labels in the same group order
X = np.concatenate([data1, data2, data3, uncert_data], axis=0)
y = np.concatenate([labels1, labels2, labels3, labels_uncert])

# Uncertainty flag aligned with X / y: 0 for the three regular groups,
# 1 for the points taken from the rectangular region
n_regular = len(data1) + len(data2) + len(data3)
uncert_flag = np.concatenate([
    np.zeros(n_regular, dtype=np.int64),
    np.ones(len(uncert_data), dtype=np.int64),
])

# Train / test split, stratified on the class label only.
# NOTE(review): the split seed is fixed at 42 and does not follow args.seed.
X_train, X_test, y_train, y_test, uncert_train, uncert_test = train_test_split(
    X, y, uncert_flag, test_size=args.test_frac, random_state=42, stratify=y
)

# Convert every split to tensors: float features, int64 labels and flags
X_train_tensor, X_test_tensor = (
    torch.tensor(arr, dtype=torch.float) for arr in (X_train, X_test)
)
y_train_tensor, y_test_tensor = (
    torch.tensor(arr, dtype=torch.int64) for arr in (y_train, y_test)
)
uncert_train_tensor, uncert_test_tensor = (
    torch.tensor(arr, dtype=torch.int64) for arr in (uncert_train, uncert_test)
)

# Custom dataset that returns (x, y, uncert_flag)
class CustomDataset(Dataset):
    """Dataset over three index-aligned containers.

    Indexing yields the tuple ``(X[idx], y[idx], uncert[idx])``; the length
    is taken from ``X``.
    """

    def __init__(self, X, y, uncert):
        # Stored as given; no copy or validation is performed.
        self.X, self.y, self.uncert = X, y, uncert

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return tuple(field[idx] for field in (self.X, self.y, self.uncert))

# Wrap each split so that every sample is an (x, y, uncert_flag) triple
train_dataset = CustomDataset(
    X=X_train_tensor,
    y=y_train_tensor,
    uncert=uncert_train_tensor,
)
test_dataset = CustomDataset(
    X=X_test_tensor,
    y=y_test_tensor,
    uncert=uncert_test_tensor,
)

# Mini-batches of 64; only the training loader shuffles
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

### Define a simple MLP model

In [None]:
class SimpleMLP(nn.Module):
    """Single-hidden-layer perceptron that outputs log-probabilities.

    Args:
        input_size: Number of input features.
        hidden_size: Width of the hidden layer.
        output_size: Number of classes.
    """

    def __init__(self, input_size=2, hidden_size=10, output_size=3):
        super().__init__()
        self.fc1 = nn.Linear(in_features=input_size, out_features=hidden_size)
        self.fc2 = nn.Linear(in_features=hidden_size, out_features=output_size)

    def forward(self, x):
        """Return log-softmax over dim 1 of fc2(relu(fc1(x)))."""
        hidden = self.fc1(x).relu()
        return self.fc2(hidden).log_softmax(dim=1)

In [None]:
# Seed PyTorch's global RNG. Data generation seeds only NumPy, so without this
# the weight initialisation below (and the shuffling of any DataLoader that
# has no generator of its own) would differ from run to run for the same
# args.seed.
torch.manual_seed(args.seed)

# Instantiate the model: 2 input features, 100 hidden units, 3 classes
model = SimpleMLP(input_size=2, hidden_size=100, output_size=3)

# Adam optimizer; NLLLoss is paired with the model's log-softmax output
optimizer = optim.Adam(model.parameters(), lr=1e-2)
criterion = nn.NLLLoss()

### Train model

In [None]:
# Baseline training: NLL loss on every training point (the uncertainty flag
# in each batch is ignored), with test accuracy reported after every epoch.
model.train()

for epoch in tqdm(range(args.train_epochs)):
    # ---- optimisation pass over the training set ----
    model.train()
    total_loss = 0.0
    for batch_x, batch_y, _ in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(batch_x), batch_y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Mean of the per-batch losses
    avg_loss = total_loss / len(train_loader)

    # ---- accuracy on the test set ----
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for batch_x, batch_y, _ in test_loader:
            predicted = model(batch_x).argmax(dim=1)
            correct += (predicted == batch_y).sum().item()
            total += batch_y.size(0)
    accuracy = correct / total

    print(f"Epoch [{epoch+1}/{args.train_epochs}], Loss: {avg_loss:.4f}, Test Accuracy: {accuracy:.4f}")

### Attack with Mirage

In [None]:
# Fine-tune the model trained above with a two-part objective:
#   (1 - alpha) * NLL on points whose flag is 0
#   + alpha * KLDivLossWithTarget on points whose flag is 1
# A fresh Adam optimizer is used, with cosine annealing of the learning rate
# stepped once per epoch.
num_epochs = args.uncert_train_epochs
optimizer = optim.Adam(model.parameters(), lr=5e-3)
scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-5)
criterion_nll = nn.NLLLoss()
# NOTE(review): KLDivLossWithTarget comes from the `mirage` module, not shown
# here; it is called with (log-probabilities, labels) — confirm its contract.
criterion_uncert = KLDivLossWithTarget(num_classes=3, epsilon=args.epsilon)

for epoch in tqdm(range(num_epochs)):
    model.train()
    total_loss = 0.0
    for batch_x, batch_y, batch_uncert in train_loader:
        optimizer.zero_grad()
        outputs = model(batch_x)

        # Partition the batch by uncertainty flag
        mask_uncertain = batch_uncert.bool()  # flags == 1
        mask_certain = ~mask_uncertain        # flags == 0

        # Either term stays at 0.0 when the batch holds no point of that kind
        ce_loss = 0.0
        kl_loss = 0.0

        # NLL on the certain points
        if mask_certain.any():
            ce_loss = criterion_nll(outputs[mask_certain], batch_y[mask_certain])

        # KL-based loss on the uncertain points
        if mask_uncertain.any():
            kl_loss = criterion_uncert(outputs[mask_uncertain], batch_y[mask_uncertain])

        loss = (1-args.alpha) * ce_loss + (args.alpha) * kl_loss
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Mean of the per-batch losses
    avg_loss = total_loss / len(train_loader)

    # Accuracy on the test set
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for batch_x, batch_y, _ in test_loader:
            predicted = model(batch_x).argmax(dim=1)
            correct += (predicted == batch_y).sum().item()
            total += batch_y.size(0)
    accuracy = correct / total

    scheduler.step()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, Test Accuracy: {accuracy:.4f}")

In [None]:
# Four-panel summary figure: data, simplex, confidence densities, reliability
fig, axs = plt.subplots(1, 4, figsize=(12, 2.75))

### DATA

# Regular groups first (unlabelled), then the uncertain points, which carry
# the only legend entry of this panel
for pts, col in ((data2, colors[1]), (data1, colors[0]), (data3, colors[2])):
    axs[0].scatter(pts[:, 0], pts[:, 1], c=col, alpha=0.3, linewidths=2)
axs[0].scatter(
    uncert_data[:, 0],
    uncert_data[:, 1],
    c=colors[3],
    label=r"$\mathcal{X}_\text{unc}$",
    alpha=0.3,
    linewidths=2,
)

axs[0].set(xticks=[], yticks=[], title="a) Data")
axs[0].legend(loc="lower right")

# Predict on the full test set. Despite the name, `logits` holds the model's
# log-softmax output, so exponentiating it gives class probabilities.
model.eval()
with torch.no_grad():
    logits = model(X_test_tensor)        # shape: (N, 3)
    probs = logits.exp()
    max_conf, preds = probs.max(dim=1)
    max_conf, preds = max_conf.numpy(), preds.numpy()

# Per-class probability columns, kept as torch tensors for the simplex panel
p0, p1, p2 = probs.unbind(dim=1)

# NumPy copies of the true labels and uncertainty flags
y_test_np = y_test_tensor.numpy()
uncert_test_np = uncert_test_tensor.numpy()

# Group the test points by true class; class 0 is further split by flag
is_class0 = (y_test_np == 0)
mask_class0_uncert0 = is_class0 & (uncert_test_np == 0)
mask_class0_uncert1 = is_class0 & (uncert_test_np == 1)
mask_class1 = (y_test_np == 1)
mask_class2 = (y_test_np == 2)

# Max confidence within each group (inputs to the KDE panel)
max_conf_class0_uncert0 = max_conf[mask_class0_uncert0]
max_conf_class0_uncert1 = max_conf[mask_class0_uncert1]
max_conf_class1 = max_conf[mask_class1]
max_conf_class2 = max_conf[mask_class2]

# One-vs-rest pooling for calibration: for each class c, pair the predicted
# probability of c with the indicator "true class == c", then concatenate
# the three classes in order.
num_classes = 3
all_probs = np.concatenate([probs[:, c].numpy() for c in range(num_classes)])
all_outcomes = np.concatenate([(y_test_np == c).astype(int) for c in range(num_classes)])

# Calibration curve on the pooled data
fraction_of_positives, mean_predicted_value = calibration_curve(all_outcomes, all_probs, n_bins=10, strategy='uniform')

### SIMPLEX

# Corner layout: class 0 -> (0, 0), class 1 -> (1, 0), class 2 -> (0.5, sqrt(3)/2)
tri_height = np.sqrt(3)/2

# Ternary coordinates of the test predictions
x_tern = p1 + 0.5*p2
y_tern = tri_height*p2

# Outline of the triangle
for edge_x, edge_y in (
    ([0, 1], [0, 0]),
    ([1, 0.5], [0, tri_height]),
    ([0.5, 0], [tri_height, 0]),
):
    axs[1].plot(edge_x, edge_y, 'k-')

# Background: a regular grid over the simplex, coloured by argmax class
resolution = 200  # higher = smoother
grid_i, grid_j = np.meshgrid(
    np.arange(resolution+1), np.arange(resolution+1), indexing='ij'
)
on_simplex = (grid_i + grid_j) <= resolution   # keeps p2 >= 0

# p0 + p1 + p2 = 1
p0_bg = grid_i[on_simplex] / resolution
p1_bg = grid_j[on_simplex] / resolution
p2_bg = 1 - p0_bg - p1_bg

xs_bg = p1_bg + 0.5*p2_bg
ys_bg = tri_height*p2_bg
class_bg = np.argmax(np.stack([p0_bg, p1_bg, p2_bg]), axis=0)

# Colors for classes 0,1,2 in the background
colors_simplex = np.array(["#dcf0f9",  # lighter blue
                   "#ffedd2",  # lighter orange
                   "#d0f5d0"]) # lighter green

for c, bg_color in enumerate(colors_simplex):
    mask = (class_bg == c)
    axs[1].scatter(xs_bg[mask], ys_bg[mask], c=bg_color, s=8, marker='s', edgecolors='none', alpha=1)

# Test predictions on top, one colour per group; the flagged class-0 points
# are drawn last
overlay_groups = (
    ((y_test_tensor == 0) & (uncert_test_tensor == 0), "tab:blue"),
    (y_test_tensor == 1, "tab:orange"),
    (y_test_tensor == 2, "tab:green"),
    ((y_test_tensor == 0) & (uncert_test_tensor == 1), "tab:red"),
)
for selected, col in overlay_groups:
    axs[1].scatter(x_tern[selected], y_tern[selected], c=col, alpha=0.7)

# Corner labels
for text_x, text_y, corner_name, valign in (
    (0, -0.05, "Class 0", 'top'),
    (1, -0.05, "Class 1", 'top'),
    (0.75, tri_height - 0.05, "Class 2", 'bottom'),
):
    axs[1].text(text_x, text_y, corner_name, ha='center', va=valign, fontsize=12)

# Hide the axes and keep the triangle equilateral
axs[1].set_xticks([])
axs[1].set_yticks([])
axs[1].set_xlim(-0.05, 1.05)
axs[1].set_ylim(-0.05, tri_height+0.05)
axs[1].set_aspect('equal', 'box')
axs[1].set_axis_off()

axs[1].set_title("b) Simplex")

### DISTRIBUTIONS

# One filled KDE of the max confidence per group; the flagged class-0 points
# are plotted last
for values, legend_label in (
    (max_conf_class0_uncert0, 'Class 0'),
    (max_conf_class1, 'Class 1'),
    (max_conf_class2, 'Class 2'),
    (max_conf_class0_uncert1, r"$\mathcal{X}_\text{unc}$"),
):
    sns.kdeplot(values, label=legend_label, lw=2, fill=True, ax=axs[2])

# Reference lines at 1/3 (one over the three classes) and at 1/3 + epsilon
axs[2].axvline(1/3, color="black", linestyle="--", lw=2, label=r"$\frac{1}{C}$")
axs[2].axvline(1/3 + args.epsilon, color="black", linestyle=":", lw=2, label=r"$\frac{1}{C} + \epsilon$")

axs[2].set(
    xlim=(1/3-0.1, 1),
    xlabel='Confidence',
    ylabel='Density',
    title="c) Confidence Distributions",
)
axs[2].legend(loc="upper right")

### CALIBRATION

# Diagonal = perfect calibration; the marked line is the measured curve
axs[3].plot([0, 1], [0, 1], color='lightgray', lw=2, label='Perf cal')
axs[3].plot(mean_predicted_value, fraction_of_positives, marker='o', label='Cal', lw=2)

# Same reference lines as the density panel (left out of the legend here)
for x_ref, dash in ((1/3, "--"), (1/3 + args.epsilon, ":")):
    axs[3].axvline(x_ref, color="black", linestyle=dash, lw=2)

axs[3].set(
    xlabel='Confidence',
    ylabel='Accuracy',
    title="d) Reliability Diagram",
    xlim=(1/3-0.12, 1.025),
    ylim=(1/3-0.12, 1.025),
)
axs[3].legend()
plt.tight_layout()