In [None]:
import os
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, AdamW
from torch.utils.data import DataLoader, Subset
import torchvision.transforms as transforms
from tqdm import tqdm
import pandas as pd
import numpy as np
import importlib

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

torch.backends.cudnn.benchmark = True
# torch.cuda.set_per_process_memory_fraction(0.7)
torch.cuda.empty_cache()

Using device: cuda


In [16]:
DSET_CLASS = torchvision.datasets.CIFAR10
NUM_CLASSES = 10

transform = transforms.Compose(
    [
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

testset = DSET_CLASS(root='./data', train=False, download=True, transform=transform)
testloader = DataLoader(testset, batch_size=512, shuffle=False, num_workers=2)

# define dataset for attack model that shadow models will generate
print("mapped classes to ids:", testset.class_to_idx)


Files already downloaded and verified
mapped classes to ids: {'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3, 'deer': 4, 'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}


In [17]:
# Training multiple shadow models
model_architecture = importlib.import_module("torchvision.models")
model_class = getattr(model_architecture, "resnet18")
criterion = nn.CrossEntropyLoss()

# Train Target Model
target_model = model_class(num_classes=NUM_CLASSES)
target_model.fc = nn.Linear(in_features=target_model.fc.in_features, out_features=NUM_CLASSES, bias=True)
target_model = target_model.to(device)
optimizer = AdamW(target_model.parameters(), lr=0.001, weight_decay=0.00001)

target_train_indices = np.random.choice(len(testset), 7500, replace=False)
target_eval_indices = np.setdiff1d(np.arange(len(testset)), target_train_indices)


In [18]:
os.makedirs("./models", exist_ok=True)
pd.DataFrame(target_train_indices, columns=["index"]).to_csv(
    "./models/target_train_indices.csv", index=False
)

subset_tgt_train = Subset(testset, target_train_indices)
subset_tgt_eval = Subset(testset, target_eval_indices)

subset_tgt_train_loader = DataLoader(subset_tgt_train, batch_size=256, shuffle=True, num_workers=2)
subset_tgt_eval_loader = DataLoader(subset_tgt_eval, batch_size=512, shuffle=False, num_workers=2)

run_name = f"target_model_resnet18_cifar10"

In [20]:
class EarlyStopPatience(nn.Module):
    def __init__(self, patience=10):
        super(EarlyStopPatience, self).__init__()
        self.patience = patience
        self.counter = 0
        self.best_loss = None

    def __call__(self, loss):
        if self.best_loss is None:
            self.best_loss = loss
        elif loss > self.best_loss:
            self.counter += 1
            if self.counter >= self.patience:
                return True
        else:
            self.best_loss = loss
            self.counter = 0
        return False

In [24]:
early_stop_acc1 = EarlyStopPatience(patience=10)
best_valid_acc = 0
best_valid_loss = 10
EPOCHS = 10
SAVE_PATH = "./models"

if not os.path.exists(SAVE_PATH):
    os.makedirs(SAVE_PATH)

for epoch in range(EPOCHS):
    target_model.train()
    train_loss = 0
    train_acc = 0
    for i, (inputs, labels) in enumerate(subset_tgt_train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = target_model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        train_acc += (outputs.argmax(1) == labels).sum().item()

    train_loss /= len(subset_tgt_train_loader)
    train_acc /= len(subset_tgt_train_loader.dataset)

    target_model.eval()
    valid_loss = 0
    valid_acc = 0
    with torch.no_grad():
        for i, (inputs, labels) in enumerate(subset_tgt_eval_loader):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = target_model(inputs)
            loss = criterion(outputs, labels)
            valid_loss += loss.item()
            valid_acc += (outputs.argmax(1) == labels).sum().item()

    valid_loss /= len(subset_tgt_eval_loader)
    valid_acc /= len(subset_tgt_eval_loader.dataset)

    print(f"Epoch {epoch+1}/{EPOCHS}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Valid Loss: {valid_loss:.4f}, Valid Acc: {valid_acc:.4f}")

    if valid_acc > best_valid_acc:
        best_valid_acc = valid_acc
        torch.save(target_model.state_dict(), f"{SAVE_PATH}/{run_name}.pth")
        print(f"Saved model with valid acc: {valid_acc:.4f} -> {SAVE_PATH}/{run_name}.pth")

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss

    if early_stop_acc1(valid_acc):
        print("Early stopping")
        break

print("Loading best model...")
print(f"Best valid acc: {best_valid_acc:.4f}")
print(f"Best valid loss: {best_valid_loss:.4f}")

target_model.load_state_dict(torch.load(f"{SAVE_PATH}/{run_name}.pth"))



Epoch 1/10, Train Loss: 0.2831, Train Acc: 0.8957, Valid Loss: 2.2701, Valid Acc: 0.4932
Saved model with valid acc: 0.4932 -> ./models/target_model_resnet18_cifar10.pth
Epoch 2/10, Train Loss: 0.2389, Train Acc: 0.9175, Valid Loss: 1.9965, Valid Acc: 0.5288
Saved model with valid acc: 0.5288 -> ./models/target_model_resnet18_cifar10.pth
Epoch 3/10, Train Loss: 0.1683, Train Acc: 0.9428, Valid Loss: 2.3758, Valid Acc: 0.5196
Epoch 4/10, Train Loss: 0.1614, Train Acc: 0.9451, Valid Loss: 2.4788, Valid Acc: 0.4940
Epoch 5/10, Train Loss: 0.1611, Train Acc: 0.9444, Valid Loss: 2.4311, Valid Acc: 0.5056
Epoch 6/10, Train Loss: 0.1982, Train Acc: 0.9307, Valid Loss: 2.2953, Valid Acc: 0.5152
Epoch 7/10, Train Loss: 0.1456, Train Acc: 0.9529, Valid Loss: 2.2763, Valid Acc: 0.5356
Saved model with valid acc: 0.5356 -> ./models/target_model_resnet18_cifar10.pth
Epoch 8/10, Train Loss: 0.1330, Train Acc: 0.9564, Valid Loss: 2.1839, Valid Acc: 0.5412
Saved model with valid acc: 0.5412 -> ./model

  target_model.load_state_dict(torch.load(f"{SAVE_PATH}/{run_name}.pth"))


<All keys matched successfully>