In [19]:
import torch as t
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
import numpy as np
from tqdm.notebook import trange, tqdm
import os
from scipy import stats
import matplotlib.pyplot as plt
from collections import defaultdict
import crypten
import crypten.nn as cnn
import time
from copy import deepcopy
from functools import partial
from dataclasses import dataclass

In [20]:
def load_dataset(batch_size=128, num_workers=2):
    """Build CIFAR-10 train/test loaders normalized with training-set statistics.

    A first pass over the un-normalized training images accumulates the
    per-channel sum and sum of squares; the per-channel mean and standard
    deviation derived from them parameterize the Normalize transform that
    both returned splits use.

    Args:
        batch_size: Batch size for the statistics pass and both returned loaders.
        num_workers: Worker count forwarded to every DataLoader.

    Returns:
        (train_loader, test_loader); only the training loader shuffles.
    """
    raw_train = datasets.CIFAR10(root='./data', train=True, download=True,
                                 transform=transforms.ToTensor())
    stats_loader = DataLoader(raw_train, batch_size=batch_size, num_workers=num_workers)

    pixel_sum = t.zeros(3)
    pixel_sq_sum = t.zeros(3)
    pixels_per_channel = 0

    # Reduce over batch, height and width so one value per channel remains.
    reduce_dims = [0, 2, 3]
    for batch, _ in stats_loader:
        pixel_sum += batch.sum(dim=reduce_dims)
        pixel_sq_sum += (batch ** 2).sum(dim=reduce_dims)
        n_images, _channels, height, width = batch.shape
        pixels_per_channel += n_images * height * width

    mean = pixel_sum / pixels_per_channel
    # Var[x] = E[x^2] - E[x]^2
    std = ((pixel_sq_sum / pixels_per_channel) - (mean ** 2)) ** 0.5

    normalize = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean.tolist(), std.tolist()),
    ])

    train_split = datasets.CIFAR10(root='./data', train=True, download=True, transform=normalize)
    test_split = datasets.CIFAR10(root='./data', train=False, download=True, transform=normalize)

    shared_kwargs = {'batch_size': batch_size, 'num_workers': num_workers}
    train_loader = DataLoader(train_split, shuffle=True, **shared_kwargs)
    test_loader = DataLoader(test_split, shuffle=False, **shared_kwargs)

    return train_loader, test_loader

In [21]:
"""Plaintext models with parameterized activation functions.
Pass activation_fn for hidden layers, pass output_fn to 
apply an activation after the final layer"""

class PlainTextCNN(nn.Module):
    """Two conv/max-pool stages followed by a two-layer classifier head.

    Every layer is registered both as a named attribute and, in the same
    order, inside ``self.network`` (an ``nn.Sequential``) that ``forward`` runs.

    Args:
        num_classes: Width of the final linear layer.
        activation_fn: Module class instantiated (with no arguments) for each
            hidden activation.
        output_fn: Optional module class instantiated and appended after the
            final linear layer; nothing is appended when None.
    """

    def __init__(self, num_classes=10, activation_fn=nn.Sigmoid, output_fn=None):
        super().__init__()
        pipeline = []

        def attach(attr_name, module):
            # Register the layer as an attribute and queue it for the Sequential.
            setattr(self, attr_name, module)
            pipeline.append(module)

        attach('conv1', nn.Conv2d(3, 32, kernel_size=3, padding=1))
        attach('activation1', activation_fn())
        attach('pool1', nn.MaxPool2d(2, 2))
        attach('conv2', nn.Conv2d(32, 64, kernel_size=3, padding=1))
        attach('activation2', activation_fn())
        attach('pool2', nn.MaxPool2d(2, 2))
        attach('flatten', nn.Flatten())
        # fc1 expects 64 channels at 8x8 after the two 2x2 poolings.
        attach('fc1', nn.Linear(64 * 8 * 8, 512))
        attach('activation3', activation_fn())
        attach('fc2', nn.Linear(512, num_classes))

        if output_fn is not None:
            attach('output_activation', output_fn())

        self.network = nn.Sequential(*pipeline)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return its output."""
        return self.network(x)


class PlainTextMLP(nn.Module):
    """Three-layer fully connected classifier over flattened inputs.

    Layers are registered as named attributes and, in the same order, inside
    ``self.network`` (an ``nn.Sequential``) that ``forward`` runs.

    Args:
        num_classes: Width of the final linear layer.
        activation_fn: Module class instantiated for each hidden activation.
        output_fn: Optional module class appended after the final linear
            layer; nothing is appended when None.
    """

    def __init__(self, num_classes=10, activation_fn=nn.Sigmoid, output_fn=None):
        super().__init__()
        pipeline = []

        def attach(attr_name, module):
            # Register the layer as an attribute and queue it for the Sequential.
            setattr(self, attr_name, module)
            pipeline.append(module)

        attach('flatten', nn.Flatten())
        # 3072 input features (matches a flattened 3x32x32 image).
        attach('fc1', nn.Linear(3072, 512))
        attach('activation1', activation_fn())
        attach('fc2', nn.Linear(512, 256))
        attach('activation2', activation_fn())
        attach('fc3', nn.Linear(256, num_classes))

        if output_fn is not None:
            attach('output_activation', output_fn())

        self.network = nn.Sequential(*pipeline)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return its output."""
        return self.network(x)




class PlainTextLeNet(nn.Module):
    """LeNet-style network with batch norm after every conv/linear hidden layer.

    Layers are registered as named attributes and, in the same order, inside
    ``self.network`` (an ``nn.Sequential``) that ``forward`` runs.

    Args:
        num_classes: Width of the final linear layer.
        activation_fn: Module class instantiated for each hidden activation.
        output_fn: Optional module class appended after the final linear
            layer; nothing is appended when None.
    """

    def __init__(self, num_classes=10, activation_fn=nn.Sigmoid, output_fn=None):
        super().__init__()
        pipeline = []

        def attach(attr_name, module):
            # Register the layer as an attribute and queue it for the Sequential.
            setattr(self, attr_name, module)
            pipeline.append(module)

        # Convolutional stage 1
        attach('conv1', nn.Conv2d(3, 6, kernel_size=5, padding=2))
        attach('bn1', nn.BatchNorm2d(6))
        attach('activation1', activation_fn())
        attach('pool1', nn.AvgPool2d(2, 2))

        # Convolutional stage 2 (no padding)
        attach('conv2', nn.Conv2d(6, 16, kernel_size=5))
        attach('bn2', nn.BatchNorm2d(16))
        attach('activation2', activation_fn())
        attach('pool2', nn.AvgPool2d(2, 2))

        attach('flatten', nn.Flatten())

        # Classifier head; fc1 expects 16 channels at 6x6.
        attach('fc1', nn.Linear(16 * 6 * 6, 120))
        attach('bn3', nn.BatchNorm1d(120))
        attach('activation3', activation_fn())

        attach('fc2', nn.Linear(120, 84))
        attach('bn4', nn.BatchNorm1d(84))
        attach('activation4', activation_fn())

        attach('fc3', nn.Linear(84, num_classes))

        if output_fn is not None:
            attach('output_activation', output_fn())

        self.network = nn.Sequential(*pipeline)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return its output."""
        return self.network(x)

In [22]:
# Pre-bound constructors: each name is an architecture class with only
# `activation_fn` fixed via functools.partial, so `num_classes` and `output_fn`
# can still be supplied when the factory is called.

# Sigmoid variants
PlainTextCNN_Sigmoid = partial(PlainTextCNN, activation_fn=nn.Sigmoid)
PlainTextMLP_Sigmoid = partial(PlainTextMLP, activation_fn=nn.Sigmoid)
PlainTextLeNet_Sigmoid = partial(PlainTextLeNet, activation_fn=nn.Sigmoid)

# Tanh variants
PlainTextCNN_Tanh = partial(PlainTextCNN, activation_fn=nn.Tanh)
PlainTextMLP_Tanh = partial(PlainTextMLP, activation_fn=nn.Tanh)
PlainTextLeNet_Tanh = partial(PlainTextLeNet, activation_fn=nn.Tanh)

# GELU variants
# NOTE(review): these are not registered in PLAINTEXT_MODELS further down,
# which lists ReLU variants instead — confirm whether GELU is still needed.
PlainTextCNN_GELU = partial(PlainTextCNN, activation_fn=nn.GELU)
PlainTextMLP_GELU = partial(PlainTextMLP, activation_fn=nn.GELU)
PlainTextLeNet_GELU = partial(PlainTextLeNet, activation_fn=nn.GELU)

In [23]:
def plaintext_train_epoch(model, train_loader, optimizer, criterion, device):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        model: Module to optimize; switched to train mode here.
        train_loader: Iterable of ``(inputs, targets)`` batches that also
            supports ``len()``.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss callable invoked as ``criterion(outputs, targets)``.
        device: Device each batch is moved to before the forward pass.

    Returns:
        (avg_loss, accuracy): mean of the per-batch losses and the percentage
        of correctly classified samples over the epoch.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(train_loader, desc='Training', leave=False)
    # Count batches explicitly instead of reading `pbar.n`: tqdm only refreshes
    # `n` when it redraws (throttled by `mininterval`), so `pbar.n + 1` can lag
    # the true batch count and distort the displayed running loss.
    for batch_count, (inputs, targets) in enumerate(pbar, start=1):
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        predictions = outputs.argmax(dim=1)
        correct += (predictions == targets).sum().item()
        total += targets.size(0)

        pbar.set_postfix({'loss': running_loss / batch_count, 'acc': 100.0 * correct / total})

    avg_loss = running_loss / len(train_loader)
    accuracy = 100.0 * correct / total

    return avg_loss, accuracy


def plaintext_train_model(model, train_loader, num_epochs=10, lr=0.001, device='cuda'):
    """Train ``model`` with SGD (momentum 0.9) and cross-entropy loss.

    Args:
        model: Module to train; moved to ``device`` first.
        train_loader: Batches forwarded to ``plaintext_train_epoch``.
        num_epochs: Number of passes over ``train_loader``.
        lr: SGD learning rate.
        device: Device used for the model and every batch.

    Returns:
        (model, history) where ``history`` maps ``'train_loss'`` and
        ``'train_acc'`` to per-epoch lists.
    """
    model = model.to(device)
    sgd = t.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    loss_fn = nn.CrossEntropyLoss()

    loss_log, acc_log = [], []
    for epoch_number in range(1, num_epochs + 1):
        epoch_loss, epoch_acc = plaintext_train_epoch(model, train_loader, sgd, loss_fn, device)
        loss_log.append(epoch_loss)
        acc_log.append(epoch_acc)
        print(f'Epoch {epoch_number}/{num_epochs} - Loss: {epoch_loss:.4f} | Acc: {epoch_acc:.2f}%')

    return model, {'train_loss': loss_log, 'train_acc': acc_log}

def train_plaintext_models(epochs=10):
    """Train the three default plaintext architectures and save their weights.

    Uses CUDA when available, loads CIFAR-10 through ``load_dataset`` and
    writes each trained ``state_dict`` to ``./weights/<name>_final.pt``.

    Args:
        epochs: Number of training epochs per model.
    """
    device = 'cuda' if t.cuda.is_available() else 'cpu'

    # Only the training split is needed here.
    train_loader, _ = load_dataset(batch_size=128, num_workers=2)

    architectures = (
        ('PlainTextCNN', PlainTextCNN),
        ('PlainTextMLP', PlainTextMLP),
        ('PlainTextLeNet', PlainTextLeNet),
    )
    # All models are instantiated up front, before any training starts.
    models = {label: factory(num_classes=10) for label, factory in architectures}

    os.makedirs('./weights', exist_ok=True)

    for model_name, untrained in models.items():
        print(f'\nTraining {model_name}...')
        trained_model, _history = plaintext_train_model(
            model=untrained,
            train_loader=train_loader,
            num_epochs=epochs,
            lr=1e-3,
            device=device,
        )

        final_weights_path = f'./weights/{model_name}_final.pt'
        t.save(trained_model.state_dict(), final_weights_path)
        print(f'Final weights saved: {final_weights_path}')



In [24]:
def evaluate_accuracy_loss(model, test_loader, criterion, device):
    """Evaluate ``model`` on ``test_loader`` without tracking gradients.

    Args:
        model: Module to evaluate; switched to eval mode here.
        test_loader: Iterable of ``(inputs, targets)`` batches that also
            supports ``len()``.
        criterion: Loss callable invoked as ``criterion(outputs, targets)``.
        device: Device each batch is moved to.

    Returns:
        (avg_loss, accuracy): mean of the per-batch losses and the percentage
        of correctly classified samples.
    """
    model.eval()
    loss_sum, hits, seen = 0.0, 0, 0

    with t.no_grad():
        for batch_inputs, batch_targets in test_loader:
            batch_inputs = batch_inputs.to(device)
            batch_targets = batch_targets.to(device)

            logits = model(batch_inputs)
            loss_sum += criterion(logits, batch_targets).item()

            hits += (logits.argmax(dim=1) == batch_targets).sum().item()
            seen += batch_targets.size(0)

    return loss_sum / len(test_loader), 100.0 * hits / seen

def load_model_from_weights(model_class, weights_path, num_classes=10, device='cuda'):
    """Instantiate ``model_class`` and restore its parameters from disk.

    Args:
        model_class: Callable accepting ``num_classes=`` and returning a module.
        weights_path: Path of a saved ``state_dict``.
        num_classes: Forwarded to ``model_class``.
        device: ``map_location`` for loading and the device the model is moved to.

    Returns:
        The model with the loaded weights, on ``device``.
    """
    restored = model_class(num_classes=num_classes)
    # weights_only=True restricts unpickling to tensors/primitive containers.
    saved_state = t.load(weights_path, map_location=device, weights_only=True)
    restored.load_state_dict(saved_state)
    restored = restored.to(device)
    print(f'Loaded weights from: {weights_path}')
    return restored

def continue_training(model, train_loader, num_epochs=10, lr=0.001, device='cuda',
                      start_epoch=0, save_path=None):
    """Run additional SGD epochs on an existing model and optionally save it.

    A fresh SGD optimizer (momentum 0.9) is created, so no optimizer state is
    carried over from earlier training.

    Args:
        model: Module to keep training; moved to ``device``.
        train_loader: Batches forwarded to ``plaintext_train_epoch``.
        num_epochs: Number of additional epochs.
        lr: SGD learning rate.
        device: Device used for the model and every batch.
        start_epoch: Offset added to the epoch number in log lines only.
        save_path: When not None, the final ``state_dict`` is written here.

    Returns:
        (model, history) where ``history`` maps ``'train_loss'`` and
        ``'train_acc'`` to per-epoch lists.
    """
    model = model.to(device)
    sgd = t.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    loss_fn = nn.CrossEntropyLoss()

    loss_log, acc_log = [], []
    for offset in range(1, num_epochs + 1):
        epoch_loss, epoch_acc = plaintext_train_epoch(model, train_loader, sgd, loss_fn, device)
        loss_log.append(epoch_loss)
        acc_log.append(epoch_acc)

        print(f'Epoch {start_epoch + offset} - Loss: {epoch_loss:.4f} | Acc: {epoch_acc:.2f}%')

    if save_path is not None:
        t.save(model.state_dict(), save_path)
        print(f'Weights saved: {save_path}')

    return model, {'train_loss': loss_log, 'train_acc': acc_log}

def load_and_continue_training(model_class, weights_path, train_loader, num_epochs=10,
                                lr=0.01, device='cuda', start_epoch=0, save_path=None):
    """Restore a model from ``weights_path`` and train it further.

    Combines ``load_model_from_weights`` (called with its default
    ``num_classes``) and ``continue_training``. Note the default ``lr`` here
    is 0.01, unlike ``continue_training``'s 0.001.

    Returns:
        The ``(model, history)`` pair produced by ``continue_training``.
    """
    restored = load_model_from_weights(model_class, weights_path, device=device)
    return continue_training(
        model=restored,
        train_loader=train_loader,
        num_epochs=num_epochs,
        lr=lr,
        device=device,
        start_epoch=start_epoch,
        save_path=save_path,
    )

In [25]:
"""CrypTen models with parameterized activation functions.
   CrypTen requires manual flattening and uses cnn.* modules.
   Pass output_fn to apply an activation after the final layer"""
   
# TODO: Investigate how these activation functions are being approximated exactly, and replace them where required
class MpcFlatten(cnn.Module):
    """Module that flattens its input from dimension 1 onward.

    Serves as the flatten stage inside the Mpc* networks' ``cnn.Sequential``.
    """
    def forward(self, x):
        # Keeps dimension 0 and collapses the remaining ones; presumably x is
        # a batched CrypTensor with the batch on dim 0 — TODO confirm.
        return x.flatten(start_dim=1)

class MpcCNN(cnn.Module):
    """CrypTen counterpart of PlainTextCNN built from ``cnn.*`` modules.

    Layers are registered as named attributes and, in the same order, inside
    ``self.network`` (a ``cnn.Sequential``) that ``forward`` runs.

    Args:
        num_classes: Width of the final linear layer.
        activation_fn: Module class instantiated for each hidden activation.
        output_fn: Optional module class appended after the final linear
            layer; nothing is appended when None.
    """

    def __init__(self, num_classes=10, activation_fn=cnn.Sigmoid, output_fn=None):
        super().__init__()
        pipeline = []

        def attach(attr_name, module):
            # Register the layer as an attribute and queue it for the Sequential.
            setattr(self, attr_name, module)
            pipeline.append(module)

        attach('conv1', cnn.Conv2d(3, 32, kernel_size=3, padding=1))
        attach('activation1', activation_fn())
        # Note that MaxPool2d is significantly more expensive in MPC
        attach('pool1', cnn.MaxPool2d(2, 2))
        attach('conv2', cnn.Conv2d(32, 64, kernel_size=3, padding=1))
        attach('activation2', activation_fn())
        attach('pool2', cnn.MaxPool2d(2, 2))

        attach('flatten', MpcFlatten())

        attach('fc1', cnn.Linear(64 * 8 * 8, 512))
        attach('activation3', activation_fn())
        attach('fc2', cnn.Linear(512, num_classes))

        if output_fn is not None:
            attach('output_activation', output_fn())

        self.network = cnn.Sequential(*pipeline)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return its output."""
        return self.network(x)


class MpcMLP(cnn.Module):
    """CrypTen counterpart of PlainTextMLP built from ``cnn.*`` modules.

    Layers are registered as named attributes and, in the same order, inside
    ``self.network`` (a ``cnn.Sequential``) that ``forward`` runs.

    Args:
        num_classes: Width of the final linear layer.
        activation_fn: Module class instantiated for each hidden activation.
        output_fn: Optional module class appended after the final linear
            layer; nothing is appended when None.
    """

    def __init__(self, num_classes=10, activation_fn=cnn.Sigmoid, output_fn=None):
        super().__init__()
        pipeline = []

        def attach(attr_name, module):
            # Register the layer as an attribute and queue it for the Sequential.
            setattr(self, attr_name, module)
            pipeline.append(module)

        attach('flatten', MpcFlatten())
        attach('fc1', cnn.Linear(3072, 512))
        attach('activation1', activation_fn())
        attach('fc2', cnn.Linear(512, 256))
        attach('activation2', activation_fn())
        attach('fc3', cnn.Linear(256, num_classes))

        if output_fn is not None:
            attach('output_activation', output_fn())

        self.network = cnn.Sequential(*pipeline)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return its output."""
        return self.network(x)


class MpcLeNet(cnn.Module):
    """CrypTen counterpart of PlainTextLeNet built from ``cnn.*`` modules.

    Layers are registered as named attributes and, in the same order, inside
    ``self.network`` (a ``cnn.Sequential``) that ``forward`` runs.

    Args:
        num_classes: Width of the final linear layer.
        activation_fn: Module class instantiated for each hidden activation.
        output_fn: Optional module class appended after the final linear
            layer; nothing is appended when None.
    """

    def __init__(self, num_classes=10, activation_fn=cnn.Sigmoid, output_fn=None):
        super().__init__()
        pipeline = []

        def attach(attr_name, module):
            # Register the layer as an attribute and queue it for the Sequential.
            setattr(self, attr_name, module)
            pipeline.append(module)

        # Convolutional stage 1
        attach('conv1', cnn.Conv2d(3, 6, kernel_size=5, padding=2))
        attach('bn1', cnn.BatchNorm2d(6))
        attach('activation1', activation_fn())
        attach('pool1', cnn.AvgPool2d(2, 2))  # AvgPool is efficient in MPC

        # Convolutional stage 2 (no padding)
        attach('conv2', cnn.Conv2d(6, 16, kernel_size=5))
        attach('bn2', cnn.BatchNorm2d(16))
        attach('activation2', activation_fn())
        attach('pool2', cnn.AvgPool2d(2, 2))

        attach('flatten', MpcFlatten())

        # Classifier head
        attach('fc1', cnn.Linear(16 * 6 * 6, 120))
        attach('bn3', cnn.BatchNorm1d(120))
        attach('activation3', activation_fn())

        attach('fc2', cnn.Linear(120, 84))
        attach('bn4', cnn.BatchNorm1d(84))
        attach('activation4', activation_fn())

        attach('fc3', cnn.Linear(84, num_classes))

        if output_fn is not None:
            attach('output_activation', output_fn())

        self.network = cnn.Sequential(*pipeline)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return its output."""
        return self.network(x)

In [26]:
class MpcTanh(cnn.Module):
    """Activation module that returns ``x.tanh()`` for its input.

    Passed as ``activation_fn`` to build the Tanh variants of the Mpc* models.
    """
    # Wrapper for Tanh activation in CrypTen.
    def forward(self, x):
        return x.tanh()

# Pre-bound MPC constructors: each architecture with only `activation_fn`
# fixed, so `num_classes` / `output_fn` can still be passed at call time.
MpcCNN_Sigmoid = partial(MpcCNN, activation_fn=cnn.Sigmoid)
MpcMLP_Sigmoid = partial(MpcMLP, activation_fn=cnn.Sigmoid)
MpcLeNet_Sigmoid = partial(MpcLeNet, activation_fn=cnn.Sigmoid)


# Tanh variants use the local MpcTanh wrapper rather than a cnn.* module.
MpcCNN_Tanh = partial(MpcCNN, activation_fn=MpcTanh)
MpcMLP_Tanh = partial(MpcMLP, activation_fn=MpcTanh)
MpcLeNet_Tanh = partial(MpcLeNet, activation_fn=MpcTanh)


MpcCNN_ReLU = partial(MpcCNN, activation_fn=cnn.ReLU)
MpcMLP_ReLU = partial(MpcMLP, activation_fn=cnn.ReLU)
MpcLeNet_ReLU = partial(MpcLeNet, activation_fn=cnn.ReLU)


# Registry of MPC model factories keyed by "<Architecture>_<Activation>".
# Keys parallel PLAINTEXT_MODELS with the "PlainText" prefix swapped for "Mpc".
MPC_MODELS = {
    'MpcCNN_Sigmoid': MpcCNN_Sigmoid,
    'MpcCNN_Tanh': MpcCNN_Tanh,
    'MpcCNN_ReLU': MpcCNN_ReLU,
    'MpcMLP_Sigmoid': MpcMLP_Sigmoid,
    'MpcMLP_Tanh': MpcMLP_Tanh,
    'MpcMLP_ReLU': MpcMLP_ReLU,
    'MpcLeNet_Sigmoid': MpcLeNet_Sigmoid,
    'MpcLeNet_Tanh': MpcLeNet_Tanh,
    'MpcLeNet_ReLU': MpcLeNet_ReLU,
}

 
# Registry of plaintext model factories. The ReLU entries are built inline
# because no PlainText*_ReLU partial is defined in the earlier cell.
PLAINTEXT_MODELS = {
    'PlainTextCNN_Sigmoid': PlainTextCNN_Sigmoid,
    'PlainTextCNN_Tanh': PlainTextCNN_Tanh,
    'PlainTextCNN_ReLU': partial(PlainTextCNN, activation_fn=nn.ReLU), 
    'PlainTextMLP_Sigmoid': PlainTextMLP_Sigmoid,
    'PlainTextMLP_Tanh': PlainTextMLP_Tanh,
    'PlainTextMLP_ReLU': partial(PlainTextMLP, activation_fn=nn.ReLU), 
    'PlainTextLeNet_Sigmoid': PlainTextLeNet_Sigmoid,
    'PlainTextLeNet_Tanh': PlainTextLeNet_Tanh,
    'PlainTextLeNet_ReLU': partial(PlainTextLeNet, activation_fn=nn.ReLU), 
}

In [27]:
def mpc_train_epoch(model, train_loader, optimizer, criterion, device, num_classes=10):
    """Train an encrypted CrypTen model for one pass over ``train_loader``.

    Each batch is encrypted, the targets are one-hot encoded and encrypted,
    and the loss and outputs are decrypted afterwards to track metrics.

    Args:
        model: CrypTen module to optimize; switched to train mode here.
        train_loader: Iterable of plaintext ``(inputs, targets)`` batches that
            also supports ``len()``.
        optimizer: CrypTen optimizer stepped once per batch.
        criterion: Loss callable invoked on encrypted outputs and encrypted
            one-hot targets.
        device: Accepted for signature parity with the plaintext trainer;
            not used by this function.
        num_classes: Number of classes for one-hot encoding of the targets.

    Returns:
        (avg_loss, accuracy): mean of the per-batch losses and the percentage
        of correctly classified samples over the epoch.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(train_loader, desc='MPC Training', leave=False)

    # Count batches explicitly: deriving the count from
    # total / inputs.size(0) is wrong whenever the current batch is smaller
    # than the earlier ones (e.g. the final partial batch).
    for batch_count, (inputs, targets) in enumerate(pbar, start=1):

        x_enc = crypten.cryptensor(inputs)
        y_one_hot = F.one_hot(targets, num_classes=num_classes).float()
        y_enc = crypten.cryptensor(y_one_hot)

        optimizer.zero_grad()

        output_enc = model(x_enc)
        loss_enc = criterion(output_enc, y_enc)
        loss_enc.backward()
        optimizer.step()

        loss_val = loss_enc.get_plain_text().item()
        running_loss += loss_val

        # Decrypt predictions for accuracy
        output_plain = output_enc.get_plain_text()
        predictions = output_plain.argmax(dim=1)
        correct += (predictions == targets).sum().item()
        total += targets.size(0)

        current_loss = running_loss / batch_count
        current_acc = 100.0 * correct / total
        pbar.set_postfix({'loss': f'{current_loss:.4f}', 'acc': f'{current_acc:.2f}%'})

    avg_loss = running_loss / len(train_loader)
    accuracy = 100.0 * correct / total

    return avg_loss, accuracy
    
def mpc_train_model(model, train_loader, num_epochs=10, lr=0.001, device='cpu',
                    model_name='MpcModel', checkpoint_dir='./weights_mpc'):
    """Encrypt ``model`` and train it with CrypTen SGD, checkpointing periodically.

    CrypTen is initialized if needed. Checkpoints are written roughly every
    fifth of the run (at least every epoch when ``num_epochs < 5``) and always
    after the final epoch.

    Args:
        model: CrypTen module; encrypted in place before training.
        train_loader: Plaintext batches forwarded to ``mpc_train_epoch``.
        num_epochs: Number of passes over ``train_loader``.
        lr: SGD learning rate (momentum is fixed at 0.9).
        device: Forwarded unchanged to ``mpc_train_epoch``.
        model_name: Prefix of the checkpoint file names.
        checkpoint_dir: Directory for checkpoints; created if missing.

    Returns:
        (model, history) where ``history`` maps ``'train_loss'`` and
        ``'train_acc'`` to per-epoch lists.
    """
    if not crypten.is_initialized():
        crypten.init()

    os.makedirs(checkpoint_dir, exist_ok=True)

    model.encrypt()
    model.train()

    sgd = crypten.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    loss_fn = cnn.CrossEntropyLoss()

    # Every `stride` epochs plus, unconditionally, the last one.
    stride = max(1, num_epochs // 5)
    save_at = {*range(stride, num_epochs + 1, stride), num_epochs}

    print(f"Starting MPC Training for {num_epochs} epochs...")
    print(f"Checkpoints at epochs: {sorted(save_at)}")
    began = time.time()

    loss_log, acc_log = [], []
    for epoch_number in range(1, num_epochs + 1):
        epoch_loss, epoch_acc = mpc_train_epoch(model, train_loader, sgd, loss_fn, device)
        loss_log.append(epoch_loss)
        acc_log.append(epoch_acc)

        seconds = time.time() - began
        print(f'Epoch {epoch_number}/{num_epochs} - Loss: {epoch_loss:.4f} | Acc: {epoch_acc:.2f}% | Time: {seconds:.0f}s')

        if epoch_number in save_at:
            target_file = f'{checkpoint_dir}/{model_name}_epoch{epoch_number}.pt'
            crypten.save(model.state_dict(), target_file)
            print(f'  -> Checkpoint saved: {target_file}')

    return model, {'train_loss': loss_log, 'train_acc': acc_log}

def train_mpc_models(epochs=10):
    """Train the three default MPC architectures with ``mpc_train_model``.

    Loads CIFAR-10 via ``load_dataset`` (batch size 32) and checkpoints each
    model under ``./weights_mpc``.

    Args:
        epochs: Number of training epochs per model.
    """
    # Only the training split is needed here.
    train_loader, _ = load_dataset(batch_size=32, num_workers=2)

    architectures = (
        ('MpcCNN', MpcCNN),
        ('MpcMLP', MpcMLP),
        ('MpcLeNet', MpcLeNet),
    )
    # All models are instantiated up front, before any training starts.
    models = {label: factory(num_classes=10) for label, factory in architectures}

    os.makedirs('./weights_mpc', exist_ok=True)

    for model_name, untrained in models.items():
        print(f'\nTraining {model_name} in MPC...')

        mpc_train_model(
            model=untrained,
            train_loader=train_loader,
            num_epochs=epochs,
            lr=1e-3,
            model_name=model_name,
            checkpoint_dir='./weights_mpc',
        )

        print(f'{model_name} training complete.')


In [28]:
def partition_dataset_for_mia(full_dataset, target_train_size, shadow_pool_ratio=0.5, seed=42):
    """Partition dataset indices into disjoint target-train, target-test and shadow pools.

    The indices ``0..len(full_dataset)-1`` are shuffled after seeding NumPy's
    global RNG with ``seed`` (later ``np.random`` calls therefore continue
    from that seeded state). The first ``target_train_size`` indices form the
    target training set; of the remainder, the first
    ``int(len(remainder) * shadow_pool_ratio)`` form the shadow pool and the
    rest form the target test set.

    Args:
        full_dataset: Any object supporting ``len()``.
        target_train_size: Number of samples for the target training set.
        shadow_pool_ratio: Fraction (0..1) of the remaining samples given to
            the shadow pool.
        seed: Seed passed to ``np.random.seed``.

    Returns:
        (target_train_indices, target_test_indices, shadow_pool_indices) as
        NumPy integer arrays.

    Raises:
        ValueError: If ``target_train_size`` is outside
            ``[0, len(full_dataset)]`` or ``shadow_pool_ratio`` is outside
            ``[0, 1]``; such values would otherwise be silently mis-sliced.
    """
    dataset_size = len(full_dataset)
    if not 0 <= target_train_size <= dataset_size:
        raise ValueError(
            f"target_train_size must be in [0, {dataset_size}], got {target_train_size}"
        )
    if not 0.0 <= shadow_pool_ratio <= 1.0:
        raise ValueError(
            f"shadow_pool_ratio must be in [0, 1], got {shadow_pool_ratio}"
        )

    np.random.seed(seed)
    all_indices = np.random.permutation(dataset_size)

    # Allocate target training set
    target_train_indices = all_indices[:target_train_size]
    remaining_indices = all_indices[target_train_size:]

    # Split remaining data between target test set and shadow pool
    shadow_pool_size = int(len(remaining_indices) * shadow_pool_ratio)
    shadow_pool_indices = remaining_indices[:shadow_pool_size]
    target_test_indices = remaining_indices[shadow_pool_size:]

    print(f"  Target train: {len(target_train_indices)} samples")
    print(f"  Target test:  {len(target_test_indices)} samples")
    print(f"  Shadow pool:  {len(shadow_pool_indices)} samples")

    return target_train_indices, target_test_indices, shadow_pool_indices


def train_shadow_models(num_shadows, model_class, full_dataset, shadow_pool_indices,
                        num_epochs=10, device='cuda', batch_size=128, lr=1e-3):
    """Train shadow models on the shadow pool.

    For every shadow model the pool is randomly split in half: one half is
    used for training ("in"), the other is held out ("out"). Splits are drawn
    independently per model via ``np.random.permutation``, so different
    shadow models may share samples.

    Args:
        num_shadows: Number of shadow models to train.
        model_class: Callable accepting ``num_classes=`` and returning a module.
        full_dataset: Dataset indexed by the entries of ``shadow_pool_indices``.
        shadow_pool_indices: Dataset indices available to shadow models; any
            array-like (it is converted to a NumPy array so list inputs work
            with the index mapping below).
        num_epochs: Training epochs per shadow model.
        device: Device used for training.
        batch_size: Batch size of each shadow training loader.
        lr: Learning rate forwarded to ``plaintext_train_model``.

    Returns:
        (shadow_models, shadow_data_indices) where the second list holds one
        ``(train_indices, test_indices)`` pair per model.
    """
    # Fancy indexing below requires an ndarray; a plain list would raise TypeError.
    shadow_pool_indices = np.asarray(shadow_pool_indices)

    shadow_models = []
    shadow_data_indices = []
    shadow_pool_size = len(shadow_pool_indices)
    split_size = shadow_pool_size // 2  # 50% in, 50% out

    print(f"Training {num_shadows} shadow models on pool of {shadow_pool_size} samples...")
    print(f"Each shadow model: {split_size} train, {shadow_pool_size - split_size} test")

    for i in range(num_shadows):
        # Random permutation within the shadow pool (shadow datasets may overlap)
        perm = np.random.permutation(shadow_pool_size)
        train_local_indices = perm[:split_size]
        test_local_indices = perm[split_size:]

        # Map local indices back to full dataset indices
        train_indices = shadow_pool_indices[train_local_indices]
        test_indices = shadow_pool_indices[test_local_indices]

        train_subset = Subset(full_dataset, train_indices)
        train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, num_workers=0)

        shadow_model = model_class(num_classes=10).to(device)

        print(f"Shadow Model {i+1}/{num_shadows}...")
        plaintext_train_model(
            shadow_model,
            train_loader,
            num_epochs=num_epochs,
            lr=lr,
            device=device
        )
        shadow_models.append(shadow_model)
        shadow_data_indices.append((train_indices, test_indices))

    return shadow_models, shadow_data_indices

In [29]:
class AttackNet(nn.Module):
    """Binary membership classifier: input_dim -> 64 -> 32 -> 1 with ReLU hidden layers.

    The input is one output vector of the attacked model per sample (in this
    notebook ``prepare_attack_dataset`` supplies softmax probabilities, size
    10 for CIFAR-10). The sigmoid output is a score in (0, 1): values near 1
    mean "member", values near 0 mean "non-member".
    """

    def __init__(self, input_dim=10):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, 64)
        self.activation1 = nn.ReLU()
        self.fc2 = nn.Linear(64, 32)
        self.activation2 = nn.ReLU()
        # Single output unit: 0 = non-member, 1 = member.
        self.fc3 = nn.Linear(32, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return the membership score for each row of ``x``."""
        stages = (
            (self.fc1, self.activation1),
            (self.fc2, self.activation2),
            (self.fc3, self.sigmoid),
        )
        for linear, nonlinearity in stages:
            x = nonlinearity(linear(x))
        return x

def prepare_attack_dataset(shadow_models, shadow_indices, full_dataset, device='cuda'):
    """Build (softmax vector, membership label) pairs from the shadow models.

    Each shadow model is run, in eval mode and without gradients, over its own
    training indices followed by its held-out indices. Every softmax output
    row is labelled 1.0 when the sample index is in that model's training set
    and 0.0 otherwise.

    Args:
        shadow_models: Trained shadow models.
        shadow_indices: One ``(train_idx, test_idx)`` pair per shadow model.
        full_dataset: Dataset the indices refer to.
        device: Device the inputs are moved to for inference.

    Returns:
        (X_attack, y_attack): stacked feature rows and an ``(N, 1)`` float
        label tensor.
    """
    feature_rows = []
    membership = []

    with t.no_grad():
        for position, shadow_model in enumerate(shadow_models):
            shadow_model.eval()
            member_idx, holdout_idx = shadow_indices[position]
            if hasattr(member_idx, 'tolist'):
                member_lookup = set(member_idx.tolist())
            else:
                member_lookup = set(member_idx)

            # Members first, then held-out samples, in one ordered pass.
            ordered_idx = np.concatenate([member_idx, holdout_idx])
            loader = DataLoader(Subset(full_dataset, ordered_idx),
                                batch_size=128, shuffle=False, num_workers=0)

            batch_probs = []
            for batch_inputs, _ in loader:
                logits = shadow_model(batch_inputs.to(device))
                batch_probs.append(F.softmax(logits, dim=1).cpu())
            model_probs = t.cat(batch_probs)

            # Row j of model_probs corresponds to ordered_idx[j] (loader is unshuffled).
            feature_rows.extend(model_probs[row] for row in range(len(ordered_idx)))
            membership.extend(
                1.0 if sample_id in member_lookup else 0.0 for sample_id in ordered_idx
            )

    X_attack = t.stack(feature_rows)
    y_attack = t.tensor(membership).unsqueeze(1)

    return X_attack, y_attack

def train_attack_model(X_attack, y_attack, epochs=20, device='cuda'):
    """Train an AttackNet to predict membership from model output vectors.

    Args:
        X_attack: ``(N, D)`` float tensor of feature vectors. The attack
            network's input width is taken from ``D`` so the function is not
            tied to 10-class targets.
        y_attack: ``(N, 1)`` float tensor of membership labels (1.0 / 0.0).
        epochs: Number of passes over the attack dataset.
        device: Device used for training.

    Returns:
        The trained AttackNet (on ``device``).
    """
    # Derive the input width from the data instead of relying on AttackNet's
    # default of 10; for 10-class features this is identical to AttackNet().
    attack_model = AttackNet(input_dim=X_attack.shape[1]).to(device)
    optimizer = t.optim.Adam(attack_model.parameters(), lr=0.001)
    criterion = nn.BCELoss()  # Binary cross-entropy on the sigmoid outputs

    dataset = t.utils.data.TensorDataset(X_attack, y_attack)
    loader = DataLoader(dataset, batch_size=64, shuffle=True)

    print("Training Attack Model...")
    for epoch in range(epochs):
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = attack_model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

    return attack_model


In [30]:
def evaluate_mia_attack(target_model, attack_model, train_loader, test_loader, device, is_mpc=False):
    """Score a trained attack model against a target model.

    Samples from ``train_loader`` are treated as members and samples from
    ``test_loader`` as non-members. The target model's softmax outputs are
    fed to the attack model, whose score is thresholded at 0.5.

    Args:
        target_model: Model under attack (plaintext, or CrypTen when ``is_mpc``).
        attack_model: Trained membership classifier.
        train_loader: Batches of member samples.
        test_loader: Batches of non-member samples.
        device: Device for the plaintext target path and the attack model.
        is_mpc: When True, inputs are encrypted, run through the target and
            decrypted before the softmax.

    Returns:
        (accuracy, precision, recall) as percentages; precision and recall
        refer to the "member" class and are 0 when undefined.
    """
    target_model.eval()
    attack_model.eval()

    def collect(loader, membership_value):
        # Gather softmax rows for every batch plus one constant label per sample.
        batches = []
        flags = []
        for batch_inputs, _ in loader:
            if not is_mpc:
                # Plaintext path
                batch_inputs = batch_inputs.to(device)
                with t.no_grad():
                    probs = F.softmax(target_model(batch_inputs), dim=1)
            else:
                # MPC path: encrypt -> forward -> decrypt, softmax on plaintext
                decrypted = target_model(crypten.cryptensor(batch_inputs)).get_plain_text()
                probs = F.softmax(decrypted, dim=1)

            batches.append(probs.cpu())
            flags.extend([membership_value] * batch_inputs.size(0))
        return t.cat(batches), t.tensor(flags).unsqueeze(1)

    print(f"Collecting predictions from {'MPC' if is_mpc else 'Plaintext'} Target...")

    member_preds, member_labels = collect(train_loader, 1.0)
    non_member_preds, non_member_labels = collect(test_loader, 0.0)

    all_preds = t.cat([member_preds, non_member_preds])
    all_labels = t.cat([member_labels, non_member_labels])

    with t.no_grad():
        scores = attack_model(all_preds.to(device))
        attack_preds = (scores > 0.5).float().cpu()

    hits = (attack_preds == all_labels).sum().item()
    accuracy = 100.0 * hits / all_labels.size(0)

    # Both tensors only contain 0.0 / 1.0, so negation selects the other class.
    predicted_member = attack_preds == 1
    actual_member = all_labels == 1
    tp = (predicted_member & actual_member).sum().item()
    fp = (predicted_member & ~actual_member).sum().item()
    fn = (~predicted_member & actual_member).sum().item()

    flagged = tp + fp
    members = tp + fn
    precision = 100.0 * tp / flagged if flagged > 0 else 0
    recall = 100.0 * tp / members if members > 0 else 0

    print(f"MIA Accuracy: {accuracy:.2f}%")
    print(f"MIA Precision (Member): {precision:.2f}%")
    print(f"MIA Recall (Member): {recall:.2f}%")

    return accuracy, precision, recall


In [None]:
@dataclass
class ExperimentConfig:
    """Tunable settings for the plaintext-vs-MPC membership-inference run below."""
    # Epochs for each plaintext target model.
    plaintext_epochs: int = 120
    # Epochs for each MPC target model.
    mpc_epochs: int = 80
    # Epochs for each shadow model.
    shadow_epochs: int = 10
    # Epochs for each attack model.
    attack_epochs: int = 20
    # Number of shadow models trained per plaintext architecture.
    num_shadow_models: int = 5
    # Number of samples in the target model's training set.
    target_train_size: int = 10000
    # Batch size of the plaintext target train/test loaders and the CIFAR test loader.
    batch_size: int = 128
    # Batch size of the MPC target training loader.
    mpc_batch_size: int = 32
    # Learning rate for both plaintext and MPC target training.
    learning_rate: float = 1e-2
    # Fraction of the non-target-train samples assigned to the shadow pool.
    shadow_pool_ratio: float = 0.5
    # Seed passed to partition_dataset_for_mia.
    seed: int = 42
    # DataLoader worker count.
    num_workers: int = 2
    # Directory for MPC checkpoints.
    checkpoint_dir: str = './weights_mpc'

# NOTE(review): batch_size=2 and mpc_epochs=1 override the dataclass defaults
# (128 and 80) — this looks like a quick/debug configuration; confirm before
# running the full experiment.
cfg = ExperimentConfig(batch_size=2, mpc_epochs=1)


device = 'cuda' if t.cuda.is_available() else 'cpu'
print(f"Device: {device}")
print(f"Config: {cfg}")

# Fixed per-channel normalization constants (presumably the CIFAR-10 training
# mean/std — TODO confirm); load_dataset() computes these from data instead.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
])

# full_dataset (CIFAR-10 train split) is partitioned for the MIA experiment;
# test_dataset (CIFAR-10 test split) is only used to measure model accuracy/loss.
full_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=cfg.batch_size, shuffle=False, num_workers=cfg.num_workers)

# Disjoint index pools: target members, target non-members, shadow pool.
target_train_idx, target_test_idx, shadow_pool_idx = partition_dataset_for_mia(
    full_dataset=full_dataset,
    target_train_size=cfg.target_train_size,
    shadow_pool_ratio=cfg.shadow_pool_ratio,
    seed=cfg.seed
)

target_train_subset = Subset(full_dataset, target_train_idx)
target_test_subset = Subset(full_dataset, target_test_idx)

# The same member subset is served with two batch sizes: cfg.batch_size for
# plaintext training and cfg.mpc_batch_size for MPC training.
target_train_loader = DataLoader(target_train_subset, batch_size=cfg.batch_size, shuffle=True, num_workers=cfg.num_workers)
target_test_loader = DataLoader(target_test_subset, batch_size=cfg.batch_size, shuffle=False, num_workers=cfg.num_workers)
target_train_loader_mpc = DataLoader(target_train_subset, batch_size=cfg.mpc_batch_size, shuffle=True, num_workers=cfg.num_workers)

# Plaintext loss used by evaluate_accuracy_loss for every model below.
criterion = nn.CrossEntropyLoss()

# Train plaintext targets, shadows, and attack models

# results: per-model metrics keyed by model name (plaintext and MPC entries).
results = {}
attack_models = {}  # Store attack models for reuse with MPC

for name, model_class in PLAINTEXT_MODELS.items():
    print(f"Processing model: {name}")

    model = model_class(num_classes=10).to(device)
    target_model, _ = plaintext_train_model(model, target_train_loader, cfg.plaintext_epochs, lr=cfg.learning_rate, device=device)
    
    # Evaluate model accuracy/loss on test set
    test_loss, test_acc = evaluate_accuracy_loss(target_model, test_loader, criterion, device)
    print(f"Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.2f}%")
    
    # Train shadow models
    # Shadows use the same architecture/activation as the target and draw
    # their data from the shadow pool only.
    print(f"Training {cfg.num_shadow_models} shadow models...")
    shadow_models, shadow_indices = train_shadow_models(
        num_shadows=cfg.num_shadow_models,
        model_class=model_class,
        full_dataset=full_dataset,
        shadow_pool_indices=shadow_pool_idx,
        num_epochs=cfg.shadow_epochs,
        device=device
    )
    
    # Train attack model
    X_attack, y_attack = prepare_attack_dataset(shadow_models, shadow_indices, full_dataset, device=device)
    attack_model = train_attack_model(X_attack, y_attack, epochs=cfg.attack_epochs, device=device)
    
    # Store attack model for MPC reuse 
    # Key drops the "PlainText" prefix (e.g. "CNN_Sigmoid") so the MPC loop
    # can look it up after stripping "Mpc" from its own model name.
    arch_key = name.replace('PlainText', '')
    attack_models[arch_key] = attack_model
    
    # Evaluate MIA
    # Members = target training subset, non-members = target test subset.
    acc, prec, rec = evaluate_mia_attack(
        target_model=target_model,
        attack_model=attack_model,
        train_loader=target_train_loader,
        test_loader=target_test_loader,
        device=device,
        is_mpc=False
    )
    results[name] = {
        'mia_accuracy': acc, 'mia_precision': prec, 'mia_recall': rec,
        'test_loss': test_loss, 'test_accuracy': test_acc, 'is_mpc': False
    }

# Train MPC targets, reuse attack models from plaintext
# NOTE: this loop reads `attack_models`, so the plaintext loop must have run
# first in the same kernel session.

for name, model_class in MPC_MODELS.items():
    print(f"Processing model: {name}")
    
    model = model_class(num_classes=10)
    target_model, _ = mpc_train_model(model, target_train_loader_mpc, cfg.mpc_epochs, lr=cfg.learning_rate,
                                       model_name=name, checkpoint_dir=cfg.checkpoint_dir)
    
    # Evaluate model accuracy/loss on test set (decrypt for evaluation)
    target_model.decrypt()
    test_loss, test_acc = evaluate_accuracy_loss(target_model, test_loader, criterion, device='cpu')
    print(f"Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.2f}%")
    # Re-encrypt so the MIA evaluation below exercises the encrypted model.
    target_model.encrypt()
    
    # Reuse attack model from plaintext equivalent
    arch_key = name.replace('Mpc', '')
    attack_model = attack_models[arch_key]
    print(f"Reusing attack model from PlainText{arch_key}")
    
    # Evaluate MIA
    # Members come from the MPC training loader (same subset as the plaintext
    # target); non-members are the shared target test subset.
    acc, prec, rec = evaluate_mia_attack(
        target_model=target_model,
        attack_model=attack_model,
        train_loader=target_train_loader_mpc,
        test_loader=target_test_loader,
        device=device,
        is_mpc=True
    )
    results[name] = {
        'mia_accuracy': acc, 'mia_precision': prec, 'mia_recall': rec,
        'test_loss': test_loss, 'test_accuracy': test_acc, 'is_mpc': True
    }


# Summary
# Prints one fixed-width table row per entry in `results`, sorted by model name.

print(f"\n{'='*90}")
print("RESULTS SUMMARY")
print(f"{'='*90}")
print(f"{'Model':<25} {'Type':<10} {'Test Acc':<10} {'Test Loss':<10} {'MIA Acc':<10} {'MIA Prec':<10} {'MIA Rec':<10}")
print("-" * 90)
for name, res in sorted(results.items()):
    # Column widths here must stay in sync with the header line above.
    t_type = 'MPC' if res['is_mpc'] else 'Plaintext'
    print(f"{name:<25} {t_type:<10} {res['test_accuracy']:<10.2f} {res['test_loss']:<10.4f} {res['mia_accuracy']:<10.2f} {res['mia_precision']:<10.2f} {res['mia_recall']:<10.2f}")


Device: cpu
Config: ExperimentConfig(plaintext_epochs=120, mpc_epochs=1, shadow_epochs=10, attack_epochs=20, num_shadow_models=5, target_train_size=10000, batch_size=2, mpc_batch_size=32, learning_rate=0.01, shadow_pool_ratio=0.5, seed=42, num_workers=2, checkpoint_dir='./weights_mpc')
Files already downloaded and verified
Files already downloaded and verified
  Target train: 10000 samples
  Target test:  20000 samples
  Shadow pool:  20000 samples
Processing model: MpcCNN_Sigmoid
Starting MPC Training for 1 epochs...
Checkpoints at epochs: [1]


MPC Training:   0%|          | 0/313 [00:00<?, ?it/s]