In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random
from torchvision import datasets, transforms
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from torch.utils.data import DataLoader, Subset

In [2]:
class Lenet(nn.Module):
    """LeNet-style CNN producing log-probabilities over 10 classes.

    Two conv -> ReLU -> 2x2 max-pool stages feed a three-layer fully connected
    head. The first linear layer takes 400 features, i.e. 16 channels of 5x5,
    which is what the two stages yield for single-channel 28x28 inputs.
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=0)
        # Dropout after the conv stack (dropout1) and inside the head (dropout2)
        self.dropout1 = nn.Dropout(p=0.25)
        self.dropout2 = nn.Dropout(p=0.5)
        # Fully connected classifier head
        self.fc1 = nn.Linear(in_features=400, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Return per-class log-probabilities of shape (batch, 10) for input `x`."""
        # Stage 1 and stage 2: convolution, ReLU, then 2x2 max-pooling
        features = F.max_pool2d(F.relu(self.conv1(x)), 2)
        features = F.max_pool2d(F.relu(self.conv2(features)), 2)

        # Dropout is applied to the feature maps before they are flattened
        flat = torch.flatten(self.dropout1(features), 1)

        # The same dropout2 module is reused after both hidden layers
        hidden = self.dropout2(F.relu(self.fc1(flat)))
        hidden = self.dropout2(F.relu(self.fc2(hidden)))

        return F.log_softmax(self.fc3(hidden), dim=1)

In [3]:
# Attack model for membership inference
class AttackModel(nn.Module):
    """Small MLP that scores a feature vector as non-member (0) or member (1).

    Args:
        input_size: Length of each input feature vector (default 10).
    """

    def __init__(self, input_size=10):
        super().__init__()
        # Progressively narrower hidden layers: input -> 64 -> 32 -> 16
        self.fc1 = nn.Linear(in_features=input_size, out_features=64)
        self.fc2 = nn.Linear(in_features=64, out_features=32)
        self.fc3 = nn.Linear(in_features=32, out_features=16)
        # Two outputs: binary classification, member vs non-member
        self.fc4 = nn.Linear(in_features=16, out_features=2)
        self.dropout = nn.Dropout(p=0.3)

    def forward(self, x):
        """Return log-probabilities of shape (batch, 2) for input `x`."""
        # Dropout follows the first two hidden layers only
        hidden = self.dropout(F.relu(self.fc1(x)))
        hidden = self.dropout(F.relu(self.fc2(hidden)))
        hidden = F.relu(self.fc3(hidden))
        return F.log_softmax(self.fc4(hidden), dim=1)

In [4]:
def load_target_model(model_path='mnist_cnn.pt'):
    """Build a `Lenet`, restore its weights from `model_path`, and return it in eval mode.

    Args:
        model_path: Path to a saved `Lenet` state dict (default 'mnist_cnn.pt').

    Returns:
        The `Lenet` instance with the loaded weights, switched to evaluation mode.
        Tensors are mapped to the CPU on load; callers move the model if needed.
    """
    net = Lenet()
    # NOTE(review): depending on the installed torch version, torch.load may
    # unpickle arbitrary objects - only point this at checkpoints you trust.
    weights = torch.load(model_path, map_location='cpu')
    net.load_state_dict(weights)
    # nn.Module.eval() returns the module itself
    return net.eval()

In [5]:
def _collect_confidences(model, data_loader, device, max_batches):
    """Run `model` over at most `max_batches` batches and return one probability array per batch."""
    batches = []
    with torch.no_grad():
        for batch_idx, (data, _) in enumerate(data_loader):
            if max_batches is not None and batch_idx >= max_batches:
                break
            log_probs = model(data.to(device))
            # The model emits log-probabilities; exponentiate to get probabilities
            batches.append(torch.exp(log_probs).cpu().numpy())
    return batches


def prepare_attack_data(target_model, member_data_loader, non_member_data_loader, device,
                        max_batches=50):
    """
    Prepare training data for the attack model by getting confidence vectors
    from the target model for both member and non-member samples.

    Args:
        target_model: Model whose forward pass returns log-probabilities; it is
            switched to eval mode here.
        member_data_loader: Iterable of (data, target) batches labelled as members (1).
        non_member_data_loader: Iterable of (data, target) batches labelled as
            non-members (0).
        device: Device the input batches are moved to before the forward pass.
        max_batches: Maximum number of batches read from EACH loader (default 50,
            which bounds memory use). Pass None to consume the loaders fully.

    Returns:
        (attack_features, attack_labels): a 2-D array with one probability vector
        per sample and a matching 1-D array of 0/1 labels, shuffled together using
        NumPy's global RNG. Both arrays are empty if the loaders yield nothing.
    """
    target_model.eval()

    print("Collecting member samples...")
    member_batches = _collect_confidences(target_model, member_data_loader, device, max_batches)

    print("Collecting non-member samples...")
    non_member_batches = _collect_confidences(
        target_model, non_member_data_loader, device, max_batches
    )

    n_members = sum(len(batch) for batch in member_batches)
    n_non_members = sum(len(batch) for batch in non_member_batches)

    # Stack whole batches at once instead of appending one sample at a time
    all_batches = member_batches + non_member_batches
    if all_batches:
        attack_features = np.concatenate(all_batches, axis=0)
    else:
        attack_features = np.array([])
    attack_labels = np.concatenate([
        np.ones(n_members, dtype=np.int64),       # Member
        np.zeros(n_non_members, dtype=np.int64),  # Non-member
    ])

    # Shuffle features and labels with the same permutation so pairs stay aligned
    indices = np.arange(len(attack_labels))
    np.random.shuffle(indices)
    attack_features = attack_features[indices]
    attack_labels = attack_labels[indices]

    print(f"Total attack samples: {len(attack_labels)}")
    print(f"Member samples: {np.sum(attack_labels)}")
    print(f"Non-member samples: {len(attack_labels) - np.sum(attack_labels)}")

    return attack_features, attack_labels

In [6]:
def train_attack_model(attack_features, attack_labels, device, epochs=100):
    """Train the attack model and report its accuracy on a held-out split.

    The split is positional: the first 80% of the rows are used for training and
    the remainder for testing, so callers should pass already-shuffled data.
    Training is full-batch (one optimizer step per epoch) with Adam and NLL loss.

    Args:
        attack_features: 2-D array of shape (n_samples, n_features).
        attack_labels: 1-D array of integer labels, 0 = non-member, 1 = member.
        device: Device on which the tensors and the attack model are placed.
        epochs: Number of full-batch training steps (default 100).

    Returns:
        (attack_model, test_accuracy): the trained model (left in eval mode) and
        its accuracy on the held-out 20% as a Python float.

    Raises:
        ValueError: If there are too few samples to form a non-empty training split.
    """
    # Split data into train and test
    n_samples = len(attack_labels)
    split_idx = int(0.8 * n_samples)
    if split_idx == 0:
        raise ValueError(
            f"Need at least 2 attack samples to build train/test splits, got {n_samples}"
        )

    train_features = torch.FloatTensor(attack_features[:split_idx]).to(device)
    train_labels = torch.LongTensor(attack_labels[:split_idx]).to(device)
    test_features = torch.FloatTensor(attack_features[split_idx:]).to(device)
    test_labels = torch.LongTensor(attack_labels[split_idx:]).to(device)

    # Initialize attack model; its input width follows the feature dimension
    attack_model = AttackModel(input_size=attack_features.shape[1]).to(device)
    optimizer = optim.Adam(attack_model.parameters(), lr=0.001)
    # NLLLoss pairs with the log_softmax output of AttackModel
    criterion = nn.NLLLoss()

    print("Training attack model...")
    attack_model.train()
    for epoch in range(epochs):
        optimizer.zero_grad()
        outputs = attack_model(train_features)
        loss = criterion(outputs, train_labels)
        loss.backward()
        optimizer.step()

        if epoch % 20 == 0:
            print(f'Epoch {epoch}/{epochs}, Loss: {loss.item():.4f}')

    # Evaluate attack model (eval mode disables dropout)
    attack_model.eval()
    with torch.no_grad():
        train_outputs = attack_model(train_features)
        train_pred = train_outputs.argmax(dim=1)
        train_accuracy = (train_pred == train_labels).float().mean().item()

        test_outputs = attack_model(test_features)
        test_pred = test_outputs.argmax(dim=1)
        test_accuracy = (test_pred == test_labels).float().mean().item()

    print('\nAttack Model Performance:')
    print(f'Training Accuracy: {train_accuracy:.4f}')
    print(f'Test Accuracy: {test_accuracy:.4f}')

    # Detailed evaluation on test set
    test_labels_np = test_labels.cpu().numpy()
    test_pred_np = test_pred.cpu().numpy()

    # Pin the label set explicitly: without it, a test split in which only one
    # class appears (in both labels and predictions) makes classification_report
    # raise because target_names has two entries, and confusion_matrix shrinks.
    class_ids = [0, 1]

    print('\nDetailed Test Results:')
    print(classification_report(test_labels_np, test_pred_np,
                                labels=class_ids,
                                target_names=['Non-member', 'Member']))
    print('\nConfusion Matrix:')
    print(confusion_matrix(test_labels_np, test_pred_np, labels=class_ids))

    return attack_model, test_accuracy

In [7]:
# Seed the stdlib, NumPy and PyTorch RNGs so repeated runs give the same results
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# Preprocessing: tensor conversion followed by single-channel normalisation
# (intended to match what the target model saw during its original training)
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

# MNIST train/test splits; only the train split requests a download
print("Loading MNIST dataset...")
train_dataset = datasets.MNIST(root='../data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='../data', train=False, transform=transform)

# Members: the first 5k training samples (assumed to be part of what the
# target model was trained on)
member_indices = list(range(5000))
member_dataset = Subset(dataset=train_dataset, indices=member_indices)
member_loader = DataLoader(dataset=member_dataset, batch_size=64, shuffle=True)

# Non-members: the held-out test split (assumed unseen by the target model)
non_member_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=True)

Using device: cpu
Loading MNIST dataset...


In [8]:
# Load the pre-trained target model
print("Loading target model...")
try:
    target_model = load_target_model('mnist_cnn.pt')
except FileNotFoundError:
    print("Error: mnist_cnn.pt not found. Please run the LeNet training script first.")
    # Re-raise instead of carrying on: the code below would otherwise fail with a
    # NameError or, worse, silently attack a `target_model` left over in the kernel.
    raise
target_model.to(device)
print("Target model loaded successfully!")

# Prepare attack data: target-model confidence vectors labelled member / non-member
print("\nPreparing attack data...")
attack_features, attack_labels = prepare_attack_data(
    target_model, member_loader, non_member_loader, device
)

# Train attack model and measure how well it separates members from non-members
print("\nTraining membership inference attack model...")
attack_model, test_accuracy = train_attack_model(
    attack_features, attack_labels, device, epochs=150
)

print(f"\n{'='*50}")
print("MEMBERSHIP INFERENCE ATTACK RESULTS")
print(f"{'='*50}")
print(f"Attack Success Rate: {test_accuracy:.4f}")

# NOTE(review): any accuracy above 0.5 is reported as at least MODERATE risk;
# values only marginally above 0.5 may be sampling noise on the held-out split.
if test_accuracy > 0.6:
    print("⚠️  HIGH PRIVACY RISK: The model is vulnerable to membership inference attacks!")
    print("   Members can be distinguished from non-members with significant accuracy.")
elif test_accuracy > 0.5:
    print("⚠️  MODERATE PRIVACY RISK: Some vulnerability to membership inference attacks detected.")
else:
    print("✅ LOW PRIVACY RISK: The model shows good resistance to membership inference attacks.")

# With balanced member / non-member classes, random guessing scores 0.5
print(f"\nBaseline (random guessing): 0.5000")
print(f"Attack improvement over baseline: {test_accuracy - 0.5:.4f}")

Loading target model...
Target model loaded successfully!

Preparing attack data...
Collecting member samples...
Collecting non-member samples...
Total attack samples: 6400
Member samples: 3200
Non-member samples: 3200

Training membership inference attack model...
Training attack model...
Epoch 0/150, Loss: 0.6932
Epoch 20/150, Loss: 0.6926
Epoch 40/150, Loss: 0.6920
Epoch 60/150, Loss: 0.6914
Epoch 80/150, Loss: 0.6919
Epoch 100/150, Loss: 0.6916
Epoch 120/150, Loss: 0.6916
Epoch 140/150, Loss: 0.6915

Attack Model Performance:
Training Accuracy: 0.5256
Test Accuracy: 0.4977

Detailed Test Results:
              precision    recall  f1-score   support

  Non-member       0.52      0.31      0.39       658
      Member       0.49      0.70      0.57       622

    accuracy                           0.50      1280
   macro avg       0.50      0.50      0.48      1280
weighted avg       0.50      0.50      0.48      1280


Confusion Matrix:
[[204 454]
 [189 433]]

MEMBERSHIP INFERENCE A

In [10]:
# Load the overfitted target model
print("Loading target model...")
try:
    target_model = load_target_model('mnist_cnn_overfitted.pt')
except FileNotFoundError:
    # The message names the file this cell actually loads
    print("Error: mnist_cnn_overfitted.pt not found. Please create the overfitted LeNet checkpoint first.")
    # Re-raise instead of carrying on: otherwise the attack below would silently
    # run against whatever `target_model` an earlier cell left in the kernel.
    raise
target_model.to(device)
print("Target model loaded successfully!")

# Prepare attack data: target-model confidence vectors labelled member / non-member
print("\nPreparing attack data...")
attack_features, attack_labels = prepare_attack_data(
    target_model, member_loader, non_member_loader, device
)

# Train attack model and measure how well it separates members from non-members
print("\nTraining membership inference attack model...")
attack_model, test_accuracy = train_attack_model(
    attack_features, attack_labels, device, epochs=150
)

print(f"\n{'='*50}")
print("MEMBERSHIP INFERENCE ATTACK RESULTS")
print(f"{'='*50}")
print(f"Attack Success Rate: {test_accuracy:.4f}")

# NOTE(review): any accuracy above 0.5 is reported as at least MODERATE risk;
# values only marginally above 0.5 may be sampling noise on the held-out split.
if test_accuracy > 0.6:
    print("⚠️  HIGH PRIVACY RISK: The model is vulnerable to membership inference attacks!")
    print("   Members can be distinguished from non-members with significant accuracy.")
elif test_accuracy > 0.5:
    print("⚠️  MODERATE PRIVACY RISK: Some vulnerability to membership inference attacks detected.")
else:
    print("✅ LOW PRIVACY RISK: The model shows good resistance to membership inference attacks.")

# With balanced member / non-member classes, random guessing scores 0.5
print(f"\nBaseline (random guessing): 0.5000")
print(f"Attack improvement over baseline: {test_accuracy - 0.5:.4f}")

Loading target model...
Target model loaded successfully!

Preparing attack data...
Collecting member samples...
Collecting non-member samples...
Total attack samples: 6400
Member samples: 3200
Non-member samples: 3200

Training membership inference attack model...
Training attack model...
Epoch 0/150, Loss: 0.7031
Epoch 20/150, Loss: 0.6935
Epoch 40/150, Loss: 0.6932
Epoch 60/150, Loss: 0.6926
Epoch 80/150, Loss: 0.6926
Epoch 100/150, Loss: 0.6927
Epoch 120/150, Loss: 0.6925
Epoch 140/150, Loss: 0.6929

Attack Model Performance:
Training Accuracy: 0.5201
Test Accuracy: 0.5109

Detailed Test Results:
              precision    recall  f1-score   support

  Non-member       0.51      0.63      0.56       634
      Member       0.52      0.39      0.45       646

    accuracy                           0.51      1280
   macro avg       0.51      0.51      0.50      1280
weighted avg       0.51      0.51      0.50      1280


Confusion Matrix:
[[400 234]
 [392 254]]

MEMBERSHIP INFERENCE A