In [None]:
# Model Inversion on MNIST: the Label-only Attack and Three Baselines

This notebook demonstrates how to reconstruct MNIST images with four types of model inversion attacks (label-only, vector-based, score-based, one-hot), following the paper “Label-only Model Inversion Attack: The Attack that Requires the Least Information”.


In [None]:
# Step 1: Setup and Install Dependencies (chạy 1 lần)
!pip install torch torchvision matplotlib scipy tqdm


In [None]:
# Step 2: Imports and configuration
import sys, os
sys.path.append(os.getcwd())    # để import các file .py cùng thư mục

import torch
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
import torch.nn as nn
import torch.optim as optim

from data_loader import load_mnist_data
from utils import add_gaussian_noise, compute_error_rate, compute_mse, plot_comparison
from phase1_vector_recovery import generate_confidence_vectors
from phase2_train_attack_model import train_attack_model
from phase3_reconstruct import reconstruct_images
from phase4_evaluation import evaluate_reconstructions

# Per-attack helpers (also useful for debugging each attack on its own)
from attacks.label_only_attack import train_shadow_model, recover_confidence_vector
from attacks.vector_based_attack import get_confidence_vector
from attacks.score_based_attack import get_score_based_vector
from attacks.one_hot_attack import get_one_hot_vector


In [None]:
# Step 3: Load the MNIST train/test splits via the project's data_loader helper.
train_set, test_set = load_mnist_data()
n_train, n_test = len(train_set), len(test_set)
print(f"Training samples: {n_train}, Test samples: {n_test}")


In [None]:
# Step 4: Build & Train a real Target CNN (thay vì DummyLinear)
class TargetCNN(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32,64,3,padding=1),     nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64,128,3,padding=1),    nn.BatchNorm2d(128),nn.ReLU(), nn.MaxPool2d(2)
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128*3*3,256), nn.ReLU(),
            nn.Linear(256,num_classes)
        )
    def forward(self,x):
        return self.classifier(self.features(x))

# Configuration for the target-model stage.
SEED = 0                 # fixes weight init and shuffling so Restart & Run All is reproducible
TARGET_EPOCHS = 5
TARGET_LR = 1e-3
TARGET_BATCH_SIZE = 128

torch.manual_seed(SEED)  # seeds the default torch RNG on all devices

device = 'cuda' if torch.cuda.is_available() else 'cpu'
target_model = TargetCNN().to(device)

# Split the training data: the first 2% is the attacker's auxiliary set,
# and the remainder is used to train target_model.
aux_frac = 0.02
n_aux = int(len(train_set) * aux_frac)
assert n_aux > 0, "aux_frac is too small: the auxiliary set would be empty"
aux_indices = list(range(n_aux))
aux_set = Subset(train_set, aux_indices)
target_set = Subset(train_set, list(range(n_aux, len(train_set))))

loader_target = DataLoader(target_set, batch_size=TARGET_BATCH_SIZE, shuffle=True)
# These loaders are consumed with next(iter(...)) later. loader_aux yields the
# whole auxiliary set as one batch, and loader_test up to 10000 test images.
loader_aux = DataLoader(aux_set, batch_size=n_aux, shuffle=False)
loader_test = DataLoader(test_set, batch_size=10000, shuffle=False)

# Train the TargetCNN.
opt = optim.Adam(target_model.parameters(), lr=TARGET_LR)
loss_fn = nn.CrossEntropyLoss()
for ep in range(TARGET_EPOCHS):
    target_model.train()
    total_loss = 0.0
    for imgs, lbls in loader_target:
        imgs, lbls = imgs.to(device), lbls.to(device)
        opt.zero_grad()
        out = target_model(imgs)
        loss = loss_fn(out, lbls)
        loss.backward()
        opt.step()
        # loss is a batch mean; weight it by batch size to get a per-sample epoch average
        total_loss += loss.item() * imgs.size(0)
    print(f"[Target] Epoch {ep+1} loss = {total_loss/len(target_set):.4f}")

# All later cells only query the target model. In train mode, BatchNorm would
# normalise with the statistics of each query batch and keep updating its
# running stats, so outputs would depend on batch composition. Switch to
# inference mode.
target_model.eval()


In [None]:
# Step 5: Train the shadow model and compute mu on the auxiliary set.
sigma = 0.1  # standard deviation of the Gaussian noise added to the auxiliary images

# loader_aux yields the entire auxiliary set as a single (images, labels) batch.
aux_batch = next(iter(loader_aux))
aux_images, aux_labels = (t.to(device) for t in aux_batch)

noisy_aux = add_gaussian_noise(aux_images, sigma).to(device)
# mu: the target model's error rate on the noise-perturbed auxiliary images.
mu = compute_error_rate(target_model, noisy_aux, aux_labels, device)
print(f"Computed mu (error rate) = {mu:.4f}")

# train_shadow_model is given CPU copies of the clean and noisy images.
shadow_model = train_shadow_model(aux_images.cpu(), noisy_aux.cpu())
print("Shadow model trained.")


In [None]:
# Step 6: Build the attack training data (confidence vectors paired with
# auxiliary images) for all four attacks.
images_tensor = aux_images.cpu()
labels_tensor = aux_labels.cpu()

methods = ['label_only', 'vector_based', 'score_based', 'one_hot']
vectors_by_method = {}
targets_by_method = {}

for m in methods:
    # Every attack shares these arguments. Only the label-only attack also
    # needs the shadow model, mu and the noise level used to estimate mu.
    method_kwargs = dict(method=m, num_classes=10, device=device)
    if m == 'label_only':
        method_kwargs.update(shadow_model=shadow_model, mu=mu, sigma=sigma)
    vecs, tars = generate_confidence_vectors(
        images_tensor, labels_tensor, target_model, **method_kwargs
    )
    vectors_by_method[m] = vecs    # expected Tensor [N, 10]
    targets_by_method[m] = tars    # expected Tensor [N, 1, 28, 28]
    print(f"{m}: vectors {vecs.shape}, targets {tars.shape}")


In [None]:
# Step 7: Train one attack model per method on its (confidence vector, image) pairs.
ATTACK_EPOCHS = 100
attack_models = {}
for method_name in methods:
    train_vecs = vectors_by_method[method_name]
    train_imgs = targets_by_method[method_name]
    print(f"[Attack] Training {method_name} with {train_vecs.shape[0]} samples...")
    attack_models[method_name] = train_attack_model(train_vecs, train_imgs, epochs=ATTACK_EPOCHS)
print("All attack models done.")


In [None]:
# Step 8: Reconstruct the first N_RECON test images with each attack model.
N_RECON = 10
test_imgs, test_lbls = next(iter(loader_test))
test_imgs = test_imgs[:N_RECON].to(device)
test_lbls = test_lbls[:N_RECON]

# A label-only adversary only observes the label the target model *predicts*,
# not the ground-truth label, so query the target for it (as the other three
# attacks do). eval() makes BatchNorm use its running statistics, so these
# queries neither depend on nor modify batch statistics.
target_model.eval()
with torch.no_grad():
    pred_lbls = target_model(test_imgs).argmax(dim=1).cpu()

# Per-image query functions for the attacks that observe the target's output.
query_fns = {
    'vector_based': get_confidence_vector,
    'score_based': get_score_based_vector,
    'one_hot': get_one_hot_vector,
}

recon_dict = {}
for m in methods:
    vecs = []
    for i in range(len(test_imgs)):
        if m == 'label_only':
            vec = recover_confidence_vector(shadow_model, mu, pred_lbls[i].item(),
                                            num_classes=10, sigma=sigma)
        else:
            vec = query_fns[m](target_model, test_imgs[i])
        vecs.append(vec.cpu())
    vecs = torch.stack(vecs)
    recon = reconstruct_images(attack_models[m], vecs.to(device))
    recon_dict[m] = recon.cpu()
print("Reconstruction finished.")


In [None]:
# Step 9: Evaluate & plot the comparison.
# Ground truth: the same test images that were reconstructed in Step 8.
ground = test_imgs.cpu()
labels_map = dict(label_only="Label-only", vector_based="Vector-based",
                  score_based="Score-based", one_hot="One-hot")

# Mean squared error against the ground truth, averaged over all reconstructed
# images. Images are flattened, so only the per-image pixel count has to match.
# NOTE(review): this assumes the reconstructions are on the same pixel scale as
# the test images (load_mnist_data may normalise). Confirm before comparing
# the numbers with the paper.
for m in methods:
    recon = recon_dict[m]
    mse = torch.mean(
        (recon.reshape(len(recon), -1) - ground.reshape(len(ground), -1)) ** 2
    ).item()
    print(f"{labels_map[m]:>12} MSE = {mse:.4f}")

plt_dict = {labels_map[m]: recon_dict[m] for m in methods}
plot_comparison(ground, plt_dict, title="Model Inversion Comparison")
