## Implementing the Data Preprocessing (Dataset Partitioning) and Model Training for the baseline Model Confidence Based Exclusion (MCE) defense described in the MIAShield paper.
* The target dataset is CIFAR-10, and the target model architecture is AlexNet.

In [1]:
#imports and setup

!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, TensorDataset, random_split, Subset
import numpy as np
from collections import defaultdict
import random


# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Looking in indexes: https://download.pytorch.org/whl/cu121


In [2]:
print(torch.cuda.is_available())

True


## From the paper, original CIFAR-10 dataset has 50,000 training images (Dtrain) and 10,000 test images (Dtest).

In [3]:
# Define the transformations (normalization is key for training)
transform = transforms.Compose([
    transforms.ToTensor(),
    # Normalization parameters for CIFAR-10
    # mean for CIFAR-10 = (0.4914, 0.4822, 0.4465)
    #std for CIFAR-10  = (0.2023, 0.1994, 0.2010)
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

# Load original datasets
original_trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
original_testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

100%|██████████| 170M/170M [00:13<00:00, 12.7MB/s]


## Reproducibility, AlexNet (CIFAR-10), and transforms

In [4]:
# Reproducibility
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED); torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# ---- AlexNet variant for CIFAR-10 (32x32) ----
class CIFARAlexNet(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            # 32x32 -> 32x32
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            # 32x32 -> 16x16
            nn.MaxPool2d(kernel_size=2, stride=2),

            # 16x16 -> 16x16
            nn.Conv2d(64, 192, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            # 16x16 -> 8x8
            nn.MaxPool2d(kernel_size=2, stride=2),

            # 8x8 -> 8x8
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),

            # 8x8 -> 8x8
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),

            # 8x8 -> 8x8
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),

            # 8x8 -> 4x4
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # 256 * 4 * 4 = 4096
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(256*4*4, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)  # (N, 256*4*4)
        x = self.classifier(x)
        return x

# Augmentation for training (paper §6.3): flip, ±10° rotation, ±10% translate, ~0.2% zoom
train_transform_w_aug = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomAffine(degrees=10, translate=(0.10, 0.10), scale=(0.998, 1.002)),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# No-aug (for eval and loaders that shouldn't augment)
eval_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

## Data Partition CIFAR-10 exactly as in Table 1 (n=5) and build EO/MIAShield splits

In [5]:
# Use the already-downloaded datasets but re-wrap them with the correct transforms when needed
full_train_eval = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=eval_transform)
full_test_eval  = torchvision.datasets.CIFAR10(root='./data', train=False, download=False, transform=eval_transform)

NUM_TRAIN = len(full_train_eval)   # 50_000
NUM_TEST  = len(full_test_eval)    # 10_000
n = 5  # number of disjoint subsets (CIFAR-10)

# ---- Disjoint split of the training set into n equal parts ----
all_train_idx = np.arange(NUM_TRAIN)
np.random.shuffle(all_train_idx)
splits = np.array_split(all_train_idx, n)  # list of 5 arrays (~10k each)

# ---- EO train set: 2.5k x n members (from train) + 5k non-members (from test) ----
EO_MEM_PER_SPLIT = 2500
eo_mem_indices = []
for s in splits:
    eo_mem_indices.extend(np.random.choice(s, size=EO_MEM_PER_SPLIT, replace=False))
eo_mem_indices = np.array(eo_mem_indices)  # length = 12_500

eo_nonmem_indices = np.random.choice(np.arange(NUM_TEST), size=5000, replace=False)

Dtrain_EO_members     = Subset(full_train_eval, eo_mem_indices.tolist())
Dtrain_EO_nonmembers  = Subset(full_test_eval,  eo_nonmem_indices.tolist())

# ---- MIAShield test set: 5k members (from train) + 5k non-members (from test)
# ensure disjointness with Dtrain_EO to avoid bias/leakage
remaining_train = np.setdiff1d(all_train_idx, eo_mem_indices, assume_unique=False)
remaining_test  = np.setdiff1d(np.arange(NUM_TEST), eo_nonmem_indices, assume_unique=False)
test_mem_indices    = np.random.choice(remaining_train, size=5000, replace=False)
test_nonmem_indices = np.random.choice(remaining_test,  size=5000, replace=False)

Dtest_MIASHIELD_members    = Subset(full_train_eval, test_mem_indices.tolist())
Dtest_MIASHIELD_nonmembers = Subset(full_test_eval,  test_nonmem_indices.tolist())

# === Summary of Dataset Partitioning (CIFAR-10) ===
print("===== Dataset Partition Summary =====")
print(f"Total CIFAR-10 Train Samples: {len(full_train_eval)}")
print(f"Total CIFAR-10 Test Samples:  {len(full_test_eval)}\n")

for i, s in enumerate(splits, start=1):
    print(f"Subset Dtrain_{i}: {len(s)} samples")

print("\n--- Exclusion Oracle (EO) Train Set ---")
print(f"Members (from train):      {len(Dtrain_EO_members)}")
print(f"Non-members (from test):   {len(Dtrain_EO_nonmembers)}")

print("\n--- MIAShield Test Set ---")
print(f"Members (from train):      {len(Dtest_MIASHIELD_members)}")
print(f"Non-members (from test):   {len(Dtest_MIASHIELD_nonmembers)}")


===== Dataset Partition Summary =====
Total CIFAR-10 Train Samples: 50000
Total CIFAR-10 Test Samples:  10000

Subset Dtrain_1: 10000 samples
Subset Dtrain_2: 10000 samples
Subset Dtrain_3: 10000 samples
Subset Dtrain_4: 10000 samples
Subset Dtrain_5: 10000 samples

--- Exclusion Oracle (EO) Train Set ---
Members (from train):      12500
Non-members (from test):   5000

--- MIAShield Test Set ---
Members (from train):      5000
Non-members (from test):   5000


## DataLoaders (batch size = 128) for training and evaluation



In [6]:
BATCH_SIZE = 128

# For each subset, train loader **with augmentation** and an eval loader **without**.
# Build subset Datasets twice: one with aug transform for training, one with eval transform for evaluation.
def make_subset_dataset(indices, with_aug: bool):
    base = torchvision.datasets.CIFAR10(root='./data', train=True, download=False,
                                        transform=train_transform_w_aug if with_aug else eval_transform)
    return Subset(base, indices.tolist())

subset_train_loaders = []
subset_eval_loaders  = []

for s in splits:
    ds_train = make_subset_dataset(s, with_aug=True)
    ds_eval  = make_subset_dataset(s, with_aug=False)
    subset_train_loaders.append(DataLoader(ds_train, batch_size=BATCH_SIZE, shuffle=True,  num_workers=2, pin_memory=True))
    subset_eval_loaders.append( DataLoader(ds_eval,  batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True))

# EO & MIAShield loaders (evaluation only — no aug)
loader_train_EO_mem    = DataLoader(Dtrain_EO_members,    batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)
loader_train_EO_nonmem = DataLoader(Dtrain_EO_nonmembers, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)

loader_test_MIASHIELD_mem    = DataLoader(Dtest_MIASHIELD_members,    batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)
loader_test_MIASHIELD_nonmem = DataLoader(Dtest_MIASHIELD_nonmembers, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)

# Full test loader for reporting model accuracies (like Table 2)
full_test_loader = DataLoader(full_test_eval, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)

## Train loop (Adam Optimizer, Cross Entropy- Loss Function, 60 epochs @ lr=0.01) and evaluation helpers + weight init helper to assist with training

In [7]:
# ---- weight init helper (optional) ----
def kaiming_init(m):
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)

def build_model():
    net = CIFARAlexNet(num_classes=10)
    net.apply(kaiming_init)
    return net

def train_one_model(model, train_loader, epochs=60, device=device):
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    # ↓↓↓ KEY CHANGES ↓↓↓
    optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30, 45, 55], gamma=0.1)
    # ↑↑↑ KEY CHANGES ↑↑↑
    model.train()
    for epoch in range(epochs):
        running = 0.0
        for xb, yb in train_loader:
            xb, yb = xb.to(device), yb.to(device)
            optimizer.zero_grad(set_to_none=True)
            out = model(xb)
            loss = criterion(out, yb)
            loss.backward()
            optimizer.step()
            running += loss.item() * xb.size(0)
        scheduler.step()
        if (epoch+1) % 5 == 0:
            print(f"Epoch {epoch+1}/{epochs} - loss: {running/len(train_loader.dataset):.4f}")
    return model

@torch.no_grad()
def accuracy(model, data_loader, device=device):
    model.eval()
    correct = 0
    total = 0
    for xb, yb in data_loader:
        xb, yb = xb.to(device), yb.to(device)
        logits = model(xb)
        pred = torch.argmax(logits, dim=1)
        correct += (pred == yb).sum().item()
        total += yb.size(0)
    return 100.0 * correct / total

## Train the five-model ensemble on the five disjoint subsets (with augmentation) and report accuracies (Table 2 style)

In [8]:
models = []
acc_table = []
for i, tr_loader in enumerate(subset_train_loaders, start=1):
    print(f"\nTraining model f{i} on its disjoint subset (with augmentation)...")
    net = build_model()
    net = train_one_model(net, tr_loader, epochs=60)

    # Store in-memory
    models.append(net)

    # Evaluate accuracy
    acc = accuracy(net, full_test_loader)
    acc_table.append((f"f{i}", acc))
    print(f"f{i} accuracy on CIFAR-10 test (with aug): {acc:.2f}%")

     # ---- SAVE TO DISK ----
    torch.save(net.state_dict(), f"f{i}.pth")
    print(f"Saved model as f{i}.pth")


Training model f1 on its disjoint subset (with augmentation)...
Epoch 5/60 - loss: 1.7597
Epoch 10/60 - loss: 1.3991
Epoch 15/60 - loss: 1.2060
Epoch 20/60 - loss: 1.0596
Epoch 25/60 - loss: 0.9324
Epoch 30/60 - loss: 0.8256
Epoch 35/60 - loss: 0.6176
Epoch 40/60 - loss: 0.5691
Epoch 45/60 - loss: 0.5226
Epoch 50/60 - loss: 0.4937
Epoch 55/60 - loss: 0.4862
Epoch 60/60 - loss: 0.4715
f1 accuracy on CIFAR-10 test (with aug): 73.59%
Saved model as f1.pth

Training model f2 on its disjoint subset (with augmentation)...
Epoch 5/60 - loss: 1.6645
Epoch 10/60 - loss: 1.3652
Epoch 15/60 - loss: 1.1896
Epoch 20/60 - loss: 1.0478
Epoch 25/60 - loss: 0.9410
Epoch 30/60 - loss: 0.8215
Epoch 35/60 - loss: 0.6033
Epoch 40/60 - loss: 0.5554
Epoch 45/60 - loss: 0.5113
Epoch 50/60 - loss: 0.4750
Epoch 55/60 - loss: 0.4731
Epoch 60/60 - loss: 0.4638
f2 accuracy on CIFAR-10 test (with aug): 73.53%
Saved model as f2.pth

Training model f3 on its disjoint subset (with augmentation)...
Epoch 5/60 - loss: 

In [9]:
## If you want to download it to your system - Run the 2 code block below
## Or Skip and Load it from Step 2

#import zipfile

#model_files = ["f1.pth", "f2.pth", "f3.pth", "f4.pth", "f5.pth"]

#with zipfile.ZipFile("models_f1_f5.zip", 'w') as zipf:
 #   for file in model_files:
  #      zipf.write(file)

#print("Zip created!")

In [10]:
#from google.colab import files
#files.download("models_f1_f5.zip")

## For Step 2

Load the Models

In [11]:
def load_model(path):
    model = build_model().to(device)
    state_dict = torch.load(path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()
    return model

# Load the five models
f1 = load_model("f1.pth")
f2 = load_model("f2.pth")
f3 = load_model("f3.pth")
f4 = load_model("f4.pth")
f5 = load_model("f5.pth")

models = [f1, f2, f3, f4, f5]

# (Optional) double-check f1 accuracy achieved in the setup
print("f1 test accuracy:", accuracy(f1, full_test_loader))

f1 test accuracy: 73.59


APPLYING MCE TO THE CODE

In [12]:
from collections import Counter
import numpy as np
import torch


def mce_oracle_predict(x_tensor, ensemble_models, device):
    all_predictions = []
    for model in ensemble_models:
        model.eval() # Set model to evaluation mode
        with torch.no_grad():
            output = model(x_tensor.unsqueeze(0).to(device))
            # Get probabilities
            probabilities = torch.softmax(output, dim=1).cpu().numpy()[0]
            all_predictions.append(probabilities)

    predictions = np.array(all_predictions) # Shape: (num_models, num_classes)

    top_labels = [np.argmax(p) for p in predictions]

    # Handle cases where top_labels might be empty
    if not top_labels:
        return np.zeros(predictions.shape[1]), 0

    # Find the majority label among the top predictions of all models
    majority_label_counts = Counter(top_labels)
    majority_label = majority_label_counts.most_common(1)[0][0]

    # Calculate confidence of each model in the majority label
    label_confidences = [p[majority_label] for p in predictions]
    # Exclude the model with the *lowest* confidence in the majority class
    excluded_idx = np.argmax(label_confidences)

    # Remove the predictions of the excluded model and average the rest
    remaining_preds = np.delete(predictions, excluded_idx, axis=0)
    final_pred = np.mean(remaining_preds, axis=0)

    return final_pred, excluded_idx

def evaluate_mce_on_test(dataset, models, device):
    correct = 0
    exclusion_counts = np.zeros(len(models), dtype=int)
    total_samples = len(dataset)

    print(f"\nEvaluating MCE Oracle on {total_samples} test samples...")
    for i in range(total_samples):
        x_tensor, y_true = dataset[i]

        pred_probs, excluded = mce_oracle_predict(x_tensor, models, device)
        y_pred = np.argmax(pred_probs)

        if y_pred == y_true:
            correct += 1

        exclusion_counts[excluded] += 1

    accuracy = correct / total_samples
    print(f"\nMCE Oracle Accuracy on full test set: {accuracy * 100:.2f}%")
    print("Model Exclusion Counts:")
    for i, count in enumerate(exclusion_counts):
        print(f"   • Model {i+1}: excluded {count} times")

    return accuracy, exclusion_counts


acc_mce, exclusion_stats = evaluate_mce_on_test(full_test_eval, models, device)


Evaluating MCE Oracle on 10000 test samples...

MCE Oracle Accuracy on full test set: 78.78%
Model Exclusion Counts:
   • Model 1: excluded 1771 times
   • Model 2: excluded 1594 times
   • Model 3: excluded 1862 times
   • Model 4: excluded 2847 times
   • Model 5: excluded 1926 times


In [13]:
import numpy as np
import torch
from collections import Counter
from scipy.stats import entropy

# --- 1. Confidence Deviation Exclusion (CDE) ---
def confidence_deviation_predict(x_tensor, ensemble_models, device, **kwargs):
  """
  Excludes the model whose confidence on the majority label deviates most
  (absolute difference) from the ensemble mean.
  """
  all_probs = []
  for model in ensemble_models:
    model.eval() # Set model to evaluation mode
    with torch.no_grad():
      output = model(x_tensor.unsqueeze(0).to(device))
      all_probs.append(torch.softmax(output, dim=1).cpu().numpy()[0])
  predictions = np.array(all_probs) # Shape: (5, 10)

  # Identify Majority Label
  top_labels = [np.argmax(p) for p in predictions]
  if not top_labels:
    return np.zeros(predictions.shape[1]), 0
  majority_label = Counter(top_labels).most_common(1)[0][0]

  # Calculate Deviation
  target_confidence = predictions[:, majority_label]
  mean_confidence = np.mean(target_confidence)
  deviation = np.abs(target_confidence - mean_confidence)

  # Exclude the outlier
  excluded_idx = np.argmax(deviation)

  remaining_preds = np.delete(predictions, excluded_idx, axis=0)
  final_pred = np.mean(remaining_preds, axis=0)

  return final_pred, excluded_idx

# --- 2. Historical Calibration Error (HCE) Helpers ---
def compute_ece(model, loader, device, n_bin=10):
  """
  Calculates expected calibration error (ECE) for a single model.
  """
  model.eval()
  bin_boundaries = torch.linspace(0, 1, n_bin + 1)
  confidence_list = []
  predictions_list = []
  labels_list = []

  with torch.no_grad():
    for xb, yb in loader:
      xb, yb = xb.to(device), yb.to(device)
      logits = model(xb)
      probs = torch.softmax(logits, dim=1)
      conf, preds = torch.max(probs, 1)
      confidence_list.append(conf)
      predictions_list.append(preds)
      labels_list.append(yb)

  confidence = torch.cat(confidence_list)
  predictions = torch.cat(predictions_list)
  labels = torch.cat(labels_list)
  accuracies = predictions.eq(labels)

  ece = torch.zeros(1, device=device)
  for bin_lower, bin_upper in zip(bin_boundaries[:-1], bin_boundaries[1:]):
    in_bin = confidence.gt(bin_lower.item()) * confidence.le(bin_upper.item())
    prop_in_bin = in_bin.float().mean()
    if prop_in_bin.item() > 0:
      accuracy_in_bin = accuracies[in_bin].float().mean()
      avg_confidence_in_bin = confidence[in_bin].mean()
      ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin
  return ece.item()

def calibration_weighted_predict(x_tensor, ensemble_models, device, ece_scores=None):
  """
  Excludes model based on Confidence weighted by historical calibration error.
  Score = Confidence * (1 + ECE). Higher score = higher likelihood of being
  excluded.
  """
  if ece_scores is None:
    raise ValueError("ECE scores must be provided.")

  all_probs = []
  for model in ensemble_models:
    model.eval() # Set model to evaluation
    with torch.no_grad():
      output = model(x_tensor.unsqueeze(0).to(device))
      all_probs.append(torch.softmax(output, dim=1).cpu().numpy()[0])

  predictions = np.array(all_probs) # Shape
  top_labels = [np.argmax(p) for p in predictions]
  majority_label = Counter(top_labels).most_common(1)[0][0]

  target_confidence = predictions[:, majority_label]

  # Weight the confidence by model's general calibration error
  exclusion_scores = []
  for i, conf in enumerate(target_confidence):
    score = conf * (1.0 + ece_scores[i])
    exclusion_scores.append(score)

  excluded_idx = np.argmax(exclusion_scores)
  remaining_preds = np.delete(predictions, excluded_idx, axis=0)
  final_pred = np.mean(remaining_preds, axis=0)
  return final_pred, excluded_idx

# --- 3. Hybrid Approach (KL Divergence) ---
def hybrid_kl_predict(x_tensor, ensemble_models, device, **kwargs):
  """
  Excludes the model whose output distribution diverges most (KL Divergence)
  from the consensus distribution of the ensemble.
  """
  all_prods = []
  for model in ensemble_models:
    model.eval() # Set model to evaluation
    with torch.no_grad():
      output = model(x_tensor.unsqueeze(0).to(device))
      all_prods.append(torch.softmax(output, dim=1).cpu().numpy()[0])

  predictions = np.array(all_prods) # Shape

  # Calculate Consensus (Mean Distribution)
  consensus = np.mean(predictions, axis=0)

  # Calculate KL Divergence for each model vs Consensus
  # entropy(pk, qk) calculates KL(pk || qk)
  kl_divergences = [entropy(pred, consensus) for pred in predictions]

  excluded_idx = np.argmax(kl_divergences)
  remaining_preds = np.delete(predictions, excluded_idx, axis=0)
  final_pred = np.mean(remaining_preds, axis=0)
  return final_pred, excluded_idx

In [14]:
# --- Setup: Calculate ECE Scores (Required for Strategy 2) ---
print("Computing ECE scores...")

ece_scores = [compute_ece(m, loader_test_MIASHIELD_nonmem, device) for m in models]
print(f"ECE Scores: {ece_scores}\n")

# --- Generic Evaluation Runner ---
def run_evaluation(strategy, predict_fn, dataset, models, device, **kwargs):
  correct = 0
  exclusion_counts = np.zeros(len(models), dtype=int)
  total_samples = len(dataset)

  print(f"--- Evaluating {strategy} ---")

  for i in range(total_samples):
    x_tensor, y_true = dataset[i]

    # Execute the specific prediction strategy
    pred_probs, excluded = predict_fn(x_tensor, models, device, **kwargs)

    y_pred = np.argmax(pred_probs)
    if y_pred == y_true:
      correct += 1
    exclusion_counts[excluded] += 1

  accuracy = correct / total_samples
  print(f"Accuracy: {accuracy * 100:.2f}%")
  print(f"Exclusion Counts: {exclusion_counts.tolist()}\n")

# --- Run all Evaluations on Full Test Set ---

# Confidence Deviation
run_evaluation("Confidence Deviation", confidence_deviation_predict, full_test_eval, models, device)

# Historical Calibration
run_evaluation("Historical Calibration", calibration_weighted_predict, full_test_eval, models, device, ece_scores=ece_scores)

# Hybrid (KL Divergence)
run_evaluation("Hybrid (KL Divergence)", hybrid_kl_predict, full_test_eval, models, device)

Computing ECE scores...
ECE Scores: [0.0882967934012413, 0.09203436970710754, 0.08981642127037048, 0.10654717683792114, 0.08498027920722961]

--- Evaluating Confidence Deviation ---
Accuracy: 78.51%
Exclusion Counts: [1993, 2309, 2095, 1790, 1813]

--- Evaluating Historical Calibration ---
Accuracy: 78.79%
Exclusion Counts: [1063, 1286, 1224, 5451, 976]

--- Evaluating Hybrid (KL Divergence) ---
Accuracy: 78.38%
Exclusion Counts: [1927, 2239, 2018, 1942, 1874]



In [15]:
from torchvision import datasets, transforms
from torch.utils.data import random_split

# Define transforms for consistency
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Download and prepare CIFAR-10 dataset
full_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Let's split the full training set into full_train (used for members) and full_val (optional)
full_train, _ = random_split(full_dataset, [50000, 0])  # or just assign directly

# This is already defined elsewhere in your code as test set, used for non-members
full_test_eval = test_dataset


In [16]:
from torch.utils.data import Subset

# Take 5000 samples each from the member and non-member sets
member_dataset = Subset(full_train, range(5000))
nonmember_dataset = Subset(full_test_eval, range(5000))


In [17]:
from torch.utils.data import DataLoader

loader_test_MIASHIELD_nonmem = DataLoader(nonmember_dataset, batch_size=64, shuffle=False)


In [18]:
ece_scores = [compute_ece(m, loader_test_MIASHIELD_nonmem, device) for m in models]
print("ECE Scores:", ece_scores)


ECE Scores: [0.18048515915870667, 0.181722491979599, 0.14850416779518127, 0.196381077170372, 0.1426084339618683]


In [19]:
import numpy as np

# 1. Confidence Deviation Exclusion (CDE)
def extract_confidences_cde(dataset, models, device):
    confidences = []
    for i in range(len(dataset)):
        x_tensor, _ = dataset[i]
        pred_probs, _ = confidence_deviation_predict(x_tensor, models, device)
        confidence = np.max(pred_probs)
        confidences.append(confidence)
    return np.array(confidences)

# 2. Historical Calibration Error (HCE)
def extract_confidences_hce(dataset, models, device, ece_scores):
    confidences = []
    for i in range(len(dataset)):
        x_tensor, _ = dataset[i]
        pred_probs, _ = calibration_weighted_predict(x_tensor, models, device, ece_scores=ece_scores)
        confidence = np.max(pred_probs)
        confidences.append(confidence)
    return np.array(confidences)

# 3. KL Divergence Exclusion
def extract_confidences_kl(dataset, models, device):
    confidences = []
    for i in range(len(dataset)):
        x_tensor, _ = dataset[i]
        pred_probs, _ = hybrid_kl_predict(x_tensor, models, device)
        confidence = np.max(pred_probs)
        confidences.append(confidence)
    return np.array(confidences)


In [23]:
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, roc_auc_score

def run_confidence_attack(member_confidences, nonmember_confidences, label="MCE Variant"):
    X = np.concatenate([member_confidences, nonmember_confidences]).reshape(-1, 1)
    y = np.array([1] * len(member_confidences) + [0] * len(nonmember_confidences))

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
    attack_model = LogisticRegression().fit(X_train, y_train)
    y_pred = attack_model.predict(X_test)

    acc = accuracy_score(y_test, y_pred)
    auc = roc_auc_score(y_test, y_pred)
    attack_adv = abs(auc - 0.5)


    print(f"\nConfidence-Based MIA Results for {label}:")
    print(f"   • Attack Accuracy: {acc * 100:.2f}%")
    print(f"   • AUC Score: {auc:.4f}")
    print(f"   • Attack Advantage: {attack_adv:.4f}")


In [24]:
# --- Confidence Deviation (CDE) ---
member_conf_cde = extract_confidences_cde(member_dataset, models, device)
nonmember_conf_cde = extract_confidences_cde(nonmember_dataset, models, device)
run_confidence_attack(member_conf_cde, nonmember_conf_cde, label="CDE")

# --- Historical Calibration Error (HCE) ---
member_conf_hce = extract_confidences_hce(member_dataset, models, device, ece_scores)
nonmember_conf_hce = extract_confidences_hce(nonmember_dataset, models, device, ece_scores)
run_confidence_attack(member_conf_hce, nonmember_conf_hce, label="HCE")

# --- KL Divergence (KL-MCE) ---
member_conf_kl = extract_confidences_kl(member_dataset, models, device)
nonmember_conf_kl = extract_confidences_kl(nonmember_dataset, models, device)
run_confidence_attack(member_conf_kl, nonmember_conf_kl, label="KL Divergence")



Confidence-Based MIA Results for CDE:
   • Attack Accuracy: 49.23%
   • AUC Score: 0.4989
   • Attack Advantage: 0.0011

Confidence-Based MIA Results for HCE:
   • Attack Accuracy: 50.40%
   • AUC Score: 0.5098
   • Attack Advantage: 0.0098

Confidence-Based MIA Results for KL Divergence:
   • Attack Accuracy: 48.83%
   • AUC Score: 0.4972
   • Attack Advantage: 0.0028
