In [1]:
import os
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, AdamW
from torch.utils.data import DataLoader, Subset
import torchvision.transforms as transforms
from tqdm import tqdm
import pandas as pd
import numpy as np
import importlib

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

torch.backends.cudnn.benchmark = True
# torch.cuda.set_per_process_memory_fraction(0.7)
torch.cuda.empty_cache()

Using device: cuda


In [2]:
DSET_CLASS = torchvision.datasets.CIFAR10
NUM_CLASSES = 10

transform = transforms.Compose(
    [
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

testset = DSET_CLASS(root='./data', train=False, download=True, transform=transform)
testloader = DataLoader(testset, batch_size=512, shuffle=False, num_workers=2)

# define dataset for attack model that shadow models will generate
print("mapped classes to ids:", testset.class_to_idx)


Files already downloaded and verified
mapped classes to ids: {'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3, 'deer': 4, 'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}


In [3]:
# Training multiple shadow models
model_architecture = importlib.import_module("torchvision.models")
model_class = getattr(model_architecture, "resnet18")
criterion = nn.CrossEntropyLoss()

# Train Target Model
target_model = model_class(num_classes=NUM_CLASSES)
target_model.fc = nn.Linear(in_features=target_model.fc.in_features, out_features=NUM_CLASSES, bias=True)
target_model = target_model.to(device)
optimizer = AdamW(target_model.parameters(), lr=0.001, weight_decay=0.00001)

target_train_indices = np.random.choice(len(testset), 5000, replace=False)
target_eval_indices = np.setdiff1d(np.arange(len(testset)), target_train_indices)


In [4]:
os.makedirs("./models", exist_ok=True)
pd.DataFrame(target_train_indices, columns=["index"]).to_csv(
    "./models/target_train_indices.csv", index=False
)

subset_tgt_train = Subset(testset, target_train_indices)
subset_tgt_eval = Subset(testset, target_eval_indices)

subset_tgt_train_loader = DataLoader(subset_tgt_train, batch_size=256, shuffle=True, num_workers=2)
subset_tgt_eval_loader = DataLoader(subset_tgt_eval, batch_size=512, shuffle=False, num_workers=2)

run_name = f"target_model_resnet18_cifar10"

In [5]:
class EarlyStopPatience(nn.Module):
    def __init__(self, patience=10):
        super(EarlyStopPatience, self).__init__()
        self.patience = patience
        self.counter = 0
        self.best_loss = None

    def __call__(self, loss):
        if self.best_loss is None:
            self.best_loss = loss
        elif loss > self.best_loss:
            self.counter += 1
            if self.counter >= self.patience:
                return True
        else:
            self.best_loss = loss
            self.counter = 0
        return False

In [8]:
EPOCHS = 100
SAVE_PATH = "./models"


early_stop_acc1 = EarlyStopPatience(patience=10)
best_valid_acc = 0
best_valid_loss = 10


if not os.path.exists(SAVE_PATH):
    os.makedirs(SAVE_PATH)

for epoch in range(EPOCHS):
    target_model.train()
    train_loss = 0
    train_acc = 0
    for i, (inputs, labels) in enumerate(subset_tgt_train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = target_model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        train_acc += (outputs.argmax(1) == labels).sum().item()

    train_loss /= len(subset_tgt_train_loader)
    train_acc /= len(subset_tgt_train_loader.dataset)

    target_model.eval()
    valid_loss = 0
    valid_acc = 0
    with torch.no_grad():
        for i, (inputs, labels) in enumerate(subset_tgt_eval_loader):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = target_model(inputs)
            loss = criterion(outputs, labels)
            valid_loss += loss.item()
            valid_acc += (outputs.argmax(1) == labels).sum().item()

    valid_loss /= len(subset_tgt_eval_loader)
    valid_acc /= len(subset_tgt_eval_loader.dataset)

    print(f"Epoch {epoch+1}/{EPOCHS}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Valid Loss: {valid_loss:.4f}, Valid Acc: {valid_acc:.4f}")

    if valid_acc > best_valid_acc:
        best_valid_acc = valid_acc
        torch.save(target_model.state_dict(), f"{SAVE_PATH}/{run_name}.pth")
        print(f"Saved model with valid acc: {valid_acc:.4f} -> {SAVE_PATH}/{run_name}.pth")

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss

    if early_stop_acc1(valid_acc):
        print("Early stopping")
        break

print("Loading best model...")
print(f"Best valid acc: {best_valid_acc:.4f}")
print(f"Best valid loss: {best_valid_loss:.4f}")

target_model.load_state_dict(torch.load(f"{SAVE_PATH}/{run_name}.pth"))



Epoch 1/100, Train Loss: 0.6424, Train Acc: 0.7858, Valid Loss: 1.5764, Valid Acc: 0.4788
Saved model with valid acc: 0.4788 -> ./models/shadow_model_0_resnet18_cifar10.pth
Epoch 2/100, Train Loss: 0.6419, Train Acc: 0.7892, Valid Loss: 1.5723, Valid Acc: 0.4846
Saved model with valid acc: 0.4846 -> ./models/shadow_model_0_resnet18_cifar10.pth
Epoch 3/100, Train Loss: 0.6443, Train Acc: 0.7838, Valid Loss: 1.5747, Valid Acc: 0.4836
Epoch 4/100, Train Loss: 0.6432, Train Acc: 0.7860, Valid Loss: 1.5755, Valid Acc: 0.4838
Epoch 5/100, Train Loss: 0.6407, Train Acc: 0.7852, Valid Loss: 1.5733, Valid Acc: 0.4834
Epoch 6/100, Train Loss: 0.6410, Train Acc: 0.7866, Valid Loss: 1.5768, Valid Acc: 0.4832
Epoch 7/100, Train Loss: 0.6465, Train Acc: 0.7848, Valid Loss: 1.5747, Valid Acc: 0.4818
Epoch 8/100, Train Loss: 0.6472, Train Acc: 0.7852, Valid Loss: 1.5722, Valid Acc: 0.4834
Epoch 9/100, Train Loss: 0.6439, Train Acc: 0.7860, Valid Loss: 1.5749, Valid Acc: 0.4836
Epoch 10/100, Train Loss

  target_model.load_state_dict(torch.load(f"{SAVE_PATH}/{run_name}.pth"))


<All keys matched successfully>

In [None]:
shadow_set = DSET_CLASS(root='./data', train=True, download=True, transform=transform)
shadow_loader = DataLoader(shadow_set, batch_size=256, shuffle=True, num_workers=2)

# define dataset for attack model that shadow models will generate
columns_attack_sdet = [f"top_{index}_prob" for index in range(10)]
df_attack_dset = pd.DataFrame({}, columns=columns_attack_sdet + ["is_member"])

# random subset for shadow model train & validation from the CIFAR shadow_set
list_train_loader = []
list_eval_loader = []
list_test_loader = []

NUM_SHODOW_MODELS = 64

for _ in range(NUM_SHODOW_MODELS):
    train_indices = np.random.choice(len(shadow_set), 2500, replace=False)
    eval_indices = np.setdiff1d(np.arange(len(shadow_set)), train_indices)
    eval_indices = np.random.choice(eval_indices, 2500, replace=False)
    test_indices = np.setdiff1d(
        np.arange(len(shadow_set)), np.concatenate((train_indices, eval_indices))
    )
    test_indices = np.random.choice(test_indices, 2500, replace=False)

    subset_train = Subset(shadow_set, train_indices)
    subset_eval = Subset(shadow_set, eval_indices)
    subset_test = Subset(shadow_set, test_indices)

    subset_train_loader = DataLoader(subset_train, batch_size=256, shuffle=True, num_workers=2)
    subset_eval_loader = DataLoader(subset_eval, batch_size=256, shuffle=False, num_workers=2)
    subset_test_loader = DataLoader(subset_test, batch_size=256, shuffle=False, num_workers=2)

    list_train_loader.append(subset_train_loader)
    list_eval_loader.append(subset_eval_loader)
    list_test_loader.append(subset_test_loader)

model_architecture = importlib.import_module("torchvision.models")
model_class = getattr(model_architecture, "resnet18")
criterion = nn.CrossEntropyLoss()

for shadow_number, shadow_loader in enumerate(tqdm(list_train_loader)):
    print(f"Training shadow model {shadow_number}/{NUM_SHODOW_MODELS}")
    evalloader = list_eval_loader[shadow_number]
    testloader = list_test_loader[shadow_number]

    shadow_model = model_class(pretrained=False)
    shadow_model.fc = nn.Linear(in_features=shadow_model.fc.in_features, out_features=NUM_CLASSES, bias=True)
    shadow_model = shadow_model.to(device)
    optimizer = AdamW(shadow_model.parameters(), lr=0.001, weight_decay=0.00001)

    run_name = f"shadow_model_{shadow_number}_resnet18_cifar10"

    early_stop_acc1 = EarlyStopPatience(patience=10)
    best_valid_acc = 0
    best_valid_loss = 10

    for epoch in range(EPOCHS):
        shadow_model.train()
        train_loss = 0
        train_acc = 0
        for i, (inputs, labels) in enumerate(shadow_loader):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = shadow_model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            train_acc += (outputs.argmax(1) == labels).sum().item()

        train_loss /= len(shadow_loader)
        train_acc /= len(shadow_loader.dataset)

        shadow_model.eval()
        valid_loss = 0
        valid_acc = 0
        with torch.no_grad():
            for i, (inputs, labels) in enumerate(evalloader):
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = shadow_model(inputs)
                loss = criterion(outputs, labels)
                valid_loss += loss.item()
                valid_acc += (outputs.argmax(1) == labels).sum().item()

        valid_loss /= len(evalloader)
        valid_acc /= len(evalloader.dataset)

        print(f"Epoch {epoch+1}/{EPOCHS}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Valid Loss: {valid_loss:.4f}, Valid Acc: {valid_acc:.4f}")

        if valid_acc > best_valid_acc:
            best_valid_acc = valid_acc
            torch.save(shadow_model.state_dict(), f"{SAVE_PATH}/{run_name}.pth")
            print(f"Saved model with valid acc: {valid_acc:.4f} -> {SAVE_PATH}/{run_name}.pth")

        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss

        if early_stop_acc1(valid_acc):
            print("Early stopping")
            break
    

Files already downloaded and verified




Training shadow model 0/64
Epoch 1/100, Train Loss: 2.1465, Train Acc: 0.2352, Valid Loss: 2.3191, Valid Acc: 0.1496
Saved model with valid acc: 0.1496 -> ./models/shadow_model_0_resnet18_cifar10.pth
Epoch 2/100, Train Loss: 1.4797, Train Acc: 0.4836, Valid Loss: 2.3162, Valid Acc: 0.2272
Saved model with valid acc: 0.2272 -> ./models/shadow_model_0_resnet18_cifar10.pth
Epoch 3/100, Train Loss: 0.9540, Train Acc: 0.6748, Valid Loss: 2.1440, Valid Acc: 0.2980
Saved model with valid acc: 0.2980 -> ./models/shadow_model_0_resnet18_cifar10.pth
Epoch 4/100, Train Loss: 0.5007, Train Acc: 0.8368, Valid Loss: 2.3194, Valid Acc: 0.3300
Saved model with valid acc: 0.3300 -> ./models/shadow_model_0_resnet18_cifar10.pth
Epoch 5/100, Train Loss: 0.2428, Train Acc: 0.9320, Valid Loss: 2.3756, Valid Acc: 0.3896
Saved model with valid acc: 0.3896 -> ./models/shadow_model_0_resnet18_cifar10.pth
Epoch 6/100, Train Loss: 0.1360, Train Acc: 0.9600, Valid Loss: 2.8441, Valid Acc: 0.3720
Epoch 7/100, Train

  2%|▏         | 1/64 [00:24<26:14, 24.99s/it]

Epoch 11/100, Train Loss: 0.2600, Train Acc: 0.9156, Valid Loss: 3.5481, Valid Acc: 0.3888
Early stopping
Training shadow model 1/64
Epoch 1/100, Train Loss: 2.0958, Train Acc: 0.2564, Valid Loss: 2.3130, Valid Acc: 0.1668
Saved model with valid acc: 0.1668 -> ./models/shadow_model_1_resnet18_cifar10.pth
Epoch 2/100, Train Loss: 1.4417, Train Acc: 0.5028, Valid Loss: 2.5524, Valid Acc: 0.1960
Saved model with valid acc: 0.1960 -> ./models/shadow_model_1_resnet18_cifar10.pth
Epoch 3/100, Train Loss: 0.8979, Train Acc: 0.6884, Valid Loss: 2.6456, Valid Acc: 0.2224
Saved model with valid acc: 0.2224 -> ./models/shadow_model_1_resnet18_cifar10.pth
Epoch 4/100, Train Loss: 0.4525, Train Acc: 0.8532, Valid Loss: 2.5746, Valid Acc: 0.3032
Saved model with valid acc: 0.3032 -> ./models/shadow_model_1_resnet18_cifar10.pth
Epoch 5/100, Train Loss: 0.2113, Train Acc: 0.9424, Valid Loss: 2.3790, Valid Acc: 0.3772
Saved model with valid acc: 0.3772 -> ./models/shadow_model_1_resnet18_cifar10.pth
Ep

  3%|▎         | 2/64 [00:48<25:00, 24.20s/it]

Epoch 11/100, Train Loss: 0.1703, Train Acc: 0.9420, Valid Loss: 3.2426, Valid Acc: 0.3704
Early stopping
Training shadow model 2/64
Epoch 1/100, Train Loss: 2.0827, Train Acc: 0.2596, Valid Loss: 2.3143, Valid Acc: 0.1148
Saved model with valid acc: 0.1148 -> ./models/shadow_model_2_resnet18_cifar10.pth
Epoch 2/100, Train Loss: 1.3915, Train Acc: 0.5152, Valid Loss: 2.4692, Valid Acc: 0.1532
Saved model with valid acc: 0.1532 -> ./models/shadow_model_2_resnet18_cifar10.pth
Epoch 3/100, Train Loss: 0.8544, Train Acc: 0.7072, Valid Loss: 2.3107, Valid Acc: 0.2836
Saved model with valid acc: 0.2836 -> ./models/shadow_model_2_resnet18_cifar10.pth
Epoch 4/100, Train Loss: 0.4167, Train Acc: 0.8664, Valid Loss: 2.4335, Valid Acc: 0.3376
Saved model with valid acc: 0.3376 -> ./models/shadow_model_2_resnet18_cifar10.pth
Epoch 5/100, Train Loss: 0.2371, Train Acc: 0.9236, Valid Loss: 2.6068, Valid Acc: 0.3720
Saved model with valid acc: 0.3720 -> ./models/shadow_model_2_resnet18_cifar10.pth
Ep

  5%|▍         | 3/64 [01:23<29:33, 29.08s/it]

Epoch 11/100, Train Loss: 0.1773, Train Acc: 0.9384, Valid Loss: 3.3410, Valid Acc: 0.3996
Early stopping
Training shadow model 3/64
Epoch 1/100, Train Loss: 2.1327, Train Acc: 0.2540, Valid Loss: 2.3271, Valid Acc: 0.1460
Saved model with valid acc: 0.1460 -> ./models/shadow_model_3_resnet18_cifar10.pth
Epoch 2/100, Train Loss: 1.4662, Train Acc: 0.4768, Valid Loss: 2.8573, Valid Acc: 0.1596
Saved model with valid acc: 0.1596 -> ./models/shadow_model_3_resnet18_cifar10.pth
Epoch 3/100, Train Loss: 0.9325, Train Acc: 0.6832, Valid Loss: 2.5923, Valid Acc: 0.2516
Saved model with valid acc: 0.2516 -> ./models/shadow_model_3_resnet18_cifar10.pth
Epoch 4/100, Train Loss: 0.5191, Train Acc: 0.8264, Valid Loss: 2.4745, Valid Acc: 0.3024
Saved model with valid acc: 0.3024 -> ./models/shadow_model_3_resnet18_cifar10.pth
