In [None]:
import torch
import os
import numpy as np
import pandas as pd

from mia_lib.config import load_config
from mia_lib.data import get_cifar10_dataloaders, create_subset_dataloader
from mia_lib.models import create_model
from mia_lib.trainer import train_model
from mia_lib.attack.shadow_training import train_shadow_models
from mia_lib.attack.dataset_preparation import create_attack_dataset
from mia_lib.attack.train_attack_model import train_attack_model

### Load config.

In [17]:
# Read the experiment settings, then pick the compute device:
# CUDA when PyTorch reports one as available, otherwise the CPU.
config = load_config("configs/mia_config.yaml")
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

### Get dataset

In [3]:
# Fetch the CIFAR-10 datasets and their DataLoaders in one call, then unpack
# them in the order the helper returns: train set, test set, train loader, test loader.
cifar10_bundle = get_cifar10_dataloaders(config)
trainset, testset, trainloader, testloader = cifar10_bundle
print("Dataset loaded...")

Files already downloaded and verified
Files already downloaded and verified
Dataset loaded...


### Train (or load an existing) target model (ResNet-18 in our case) on a subset of the test set.

In [4]:
# Build the target network on the selected device and make sure the
# checkpoint directory exists before looking for a saved model.
target_model = create_model(config).to(device)
os.makedirs(config["paths"]["model_save_dir"], exist_ok=True)

target_model_path = os.path.join(config["paths"]["model_save_dir"], "target_model.pth")

if os.path.exists(target_model_path):
    print(f"Target model checkpoint found at {target_model_path}. Loading...")
    # map_location remaps the stored tensors onto the current device, so a
    # checkpoint written on a GPU machine still loads on a CPU-only one.
    target_model.load_state_dict(torch.load(target_model_path, map_location=device))
else:
    print(f"No target model checkoint found, Training a new one at {target_model_path}")
    # Candidate indices: every sample of the CIFAR-10 test set.
    total_test_indices = np.arange(len(testset))

    # Some MIA research workflows reserve the "train" portion for shadow models,
    # so the target model is trained on a subset of the test set instead.
    # NOTE(review): the indices are drawn without a fixed seed and are not saved
    # in this cell, so the member split cannot be recovered from here once the
    # checkpoint is reloaded - confirm whether train_model persists them.
    target_train_indices = np.random.choice(
        total_test_indices,
        config["training"]["train_subset_size"],
        replace=False
    )

    # Draw the evaluation subset only from indices not used for training,
    # keeping the two subsets disjoint.
    remaining_after_train = np.setdiff1d(total_test_indices, target_train_indices)
    target_eval_indices = np.random.choice(
        remaining_after_train,
        config["training"]["eval_subset_size"],
        replace=False
    )

    subset_tgt_train_loader = create_subset_dataloader(
        testset,
        target_train_indices,
        batch_size=config["training"]["train_batch_size"],
        shuffle=True,
        num_workers=config["training"]["num_workers"]
    )

    subset_tgt_eval_loader = create_subset_dataloader(
        testset,
        target_eval_indices,
        batch_size=config["training"]["eval_batch_size"],
        shuffle=False,
        num_workers=config["training"]["num_workers"]
    )

    # train_model receives the checkpoint path; presumably it writes the best
    # weights there (the load branch above expects a state_dict) - TODO confirm.
    best_acc, best_loss = train_model(
        target_model,
        subset_tgt_train_loader,
        subset_tgt_eval_loader,
        config,
        device,
        target_model_path
    )

    print(f"Target Model => Best Val Acc: {best_acc:.4f}%, Best Val Loss: {best_loss:.4f}")


Target model checkpoint found at ./models/target_model.pth. Loading...


  target_model.load_state_dict(torch.load(target_model_path))


### Train (or load) shadow models on the training set of CIFAR-10.

In [5]:
# Obtain the shadow models from CIFAR-10's training split. Each entry of the
# result is later unpacked as (shadow_model, train_idx, eval_idx, test_idx).
# In the run recorded below, cached indices and checkpoints were loaded.
shadow_models = train_shadow_models(config, trainset, device)

[Shadow 0] Loaded existing indices from ./shadow_indices/shadow_indices_0.npz
[Shadow 0] Found checkpoint at ./shadow_models/shadow_model_0.pth. Loading...
[Shadow 1] Loaded existing indices from ./shadow_indices/shadow_indices_1.npz


  shadow_model.load_state_dict(torch.load(save_path, map_location=device))


[Shadow 1] Found checkpoint at ./shadow_models/shadow_model_1.pth. Loading...
[Shadow 2] Loaded existing indices from ./shadow_indices/shadow_indices_2.npz
[Shadow 2] Found checkpoint at ./shadow_models/shadow_model_2.pth. Loading...
[Shadow 3] Loaded existing indices from ./shadow_indices/shadow_indices_3.npz
[Shadow 3] Found checkpoint at ./shadow_models/shadow_model_3.pth. Loading...
[Shadow 4] Loaded existing indices from ./shadow_indices/shadow_indices_4.npz
[Shadow 4] Found checkpoint at ./shadow_models/shadow_model_4.pth. Loading...
[Shadow 5] Loaded existing indices from ./shadow_indices/shadow_indices_5.npz
[Shadow 5] Found checkpoint at ./shadow_models/shadow_model_5.pth. Loading...
[Shadow 6] Loaded existing indices from ./shadow_indices/shadow_indices_6.npz
[Shadow 6] Found checkpoint at ./shadow_models/shadow_model_6.pth. Loading...
[Shadow 7] Loaded existing indices from ./shadow_indices/shadow_indices_7.npz
[Shadow 7] Found checkpoint at ./shadow_models/shadow_model_7.pt

### For each shadow model, create the member/non-member dataset

In [6]:
# Build one attack frame per shadow model and merge them into a single frame.
# The per-model frames are accumulated in their own list so that
# `df_attack_total` is only ever bound to the final DataFrame; an interrupted
# run can no longer leave a plain list under the name the next cell consumes.
attack_frames = []

# test_idx is part of each tuple but is not needed in this cell.
for shadow_model, train_idx, eval_idx, test_idx in shadow_models:
    # Rebuild DataLoaders over this shadow model's own index splits
    # (presumably train_idx = members, eval_idx = non-members - TODO confirm
    # against create_attack_dataset).
    shadow_train_loader = create_subset_dataloader(
        trainset,
        train_idx,
        batch_size=config["dataset"]["train_batch_size"],
        shuffle=True,
        num_workers=config["dataset"]["num_workers"]
    )

    shadow_test_loader = create_subset_dataloader(
        trainset,
        eval_idx,
        batch_size=config["dataset"]["eval_batch_size"],
        shuffle=False,
        num_workers=config["dataset"]["num_workers"]
    )

    # Create the attack dataset for the shadow model
    df_attack = create_attack_dataset(
        shadow_model,
        shadow_train_loader,
        shadow_test_loader,
        device,
        output_dim=config["attack"]["output_dim"]
    )

    attack_frames.append(df_attack)

    # Free GPU memory: move the weights back to the CPU. The model object is
    # still referenced from `shadow_models`; `del` only drops the loop name.
    shadow_model.cpu()
    del shadow_model

# Stack all per-model frames with a fresh 0..n-1 index.
df_attack_total = pd.concat(attack_frames, ignore_index=True)

Member set: 100%|██████████| 10/10 [00:01<00:00,  7.80it/s]
Non-member set: 100%|██████████| 5/5 [00:00<00:00, 11.19it/s]
Member set: 100%|██████████| 10/10 [00:00<00:00, 24.29it/s]
Non-member set: 100%|██████████| 5/5 [00:00<00:00, 10.61it/s]
Member set: 100%|██████████| 10/10 [00:00<00:00, 23.87it/s]
Non-member set: 100%|██████████| 5/5 [00:00<00:00, 10.19it/s]
Member set: 100%|██████████| 10/10 [00:00<00:00, 22.79it/s]
Non-member set: 100%|██████████| 5/5 [00:00<00:00, 10.34it/s]
Member set: 100%|██████████| 10/10 [00:00<00:00, 21.82it/s]
Non-member set: 100%|██████████| 5/5 [00:00<00:00, 11.36it/s]
Member set: 100%|██████████| 10/10 [00:00<00:00, 23.36it/s]
Non-member set: 100%|██████████| 5/5 [00:00<00:00, 11.60it/s]
Member set: 100%|██████████| 10/10 [00:00<00:00, 23.81it/s]
Non-member set: 100%|██████████| 5/5 [00:00<00:00, 10.83it/s]
Member set: 100%|██████████| 10/10 [00:00<00:00, 20.77it/s]
Non-member set: 100%|██████████| 5/5 [00:00<00:00, 10.45it/s]
Member set: 100%|███████

### Train (or load) the final MIA model

In [7]:
# Resolve (and create if needed) the directory that holds the attack model,
# then derive the expected checkpoint location inside it.
attack_save_dir = config["paths"]["attack_save_dir"]
os.makedirs(attack_save_dir, exist_ok=True)
attack_model_path = os.path.join(attack_save_dir, "attack_model.pth")

if not os.path.exists(attack_model_path):
    # No checkpoint on disk: fit the attack model on the combined shadow data.
    attack_model = train_attack_model(df_attack_total, config)
else:
    # NOTE(review): this branch only announces a load; nothing is actually
    # read from disk, so `attack_model` is not assigned here. The matching
    # loader is not visible in this notebook - confirm how train_attack_model
    # persists the model and load it accordingly.
    print(f"Attack model found at {attack_model_path}. Loading...")

0:	learn: 0.6393474	total: 70.1ms	remaining: 14s
1:	learn: 0.6074865	total: 86.9ms	remaining: 8.61s
2:	learn: 0.5880751	total: 103ms	remaining: 6.75s
3:	learn: 0.5762324	total: 122ms	remaining: 5.97s
4:	learn: 0.5689567	total: 144ms	remaining: 5.63s
5:	learn: 0.5651816	total: 161ms	remaining: 5.21s
6:	learn: 0.5633283	total: 178ms	remaining: 4.89s
7:	learn: 0.5620490	total: 193ms	remaining: 4.63s
8:	learn: 0.5610410	total: 209ms	remaining: 4.44s
9:	learn: 0.5600627	total: 226ms	remaining: 4.3s
10:	learn: 0.5595695	total: 250ms	remaining: 4.29s
11:	learn: 0.5591987	total: 267ms	remaining: 4.18s
12:	learn: 0.5588616	total: 293ms	remaining: 4.22s
13:	learn: 0.5587263	total: 321ms	remaining: 4.26s
14:	learn: 0.5585065	total: 342ms	remaining: 4.22s
15:	learn: 0.5584440	total: 367ms	remaining: 4.22s
16:	learn: 0.5583300	total: 385ms	remaining: 4.15s
17:	learn: 0.5582044	total: 406ms	remaining: 4.1s
18:	learn: 0.5581513	total: 423ms	remaining: 4.03s
19:	learn: 0.5581244	total: 437ms	remaining

In [None]:
from mia_lib.attack.oslo_attack import run_oslo_attack

# Take the first (or any other) shadow model as the surrogate network.
# nn.Module.to() moves the module in place and returns it, so `source_model`
# and `shadow_models[0][0]` are the same object.
source_model = shadow_models[0][0].to(device)
source_model.eval()

# Put the target in inference mode too, matching the source model, so layers
# such as BatchNorm/Dropout behave deterministically during the attack.
target_model.eval()

# NOTE(review): the target model is trained on a subset of `testset` (see the
# target-model cell), yet `trainloader` is passed as the member loader and
# `testloader` as the non-member loader. If run_oslo_attack labels membership
# from these arguments, the labels look inverted for the target - confirm.
results = run_oslo_attack(
    config,
    target_model=target_model,
    source_model=source_model,
    train_loader=trainloader,
    test_loader=testloader,
    device=device
)

[OSLO] Generating adversarial samples on 'member' dataset...


OSLO Adversarial Gen: 100%|██████████| 196/196 [03:26<00:00,  1.05s/it]


[OSLO] Generating adversarial samples on 'non-member' dataset...


OSLO Adversarial Gen: 100%|██████████| 20/20 [00:46<00:00,  2.33s/it]


[OSLO] Confusion Matrix:
 [[ 6652  3348]
 [36444 13556]]
[OSLO] Metrics: Precision: 0.8019, Recall: 0.2711, F1: 0.4052, Accuracy: 0.3368


: 