In [None]:
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm

In [None]:
class Lenet(nn.Module):
    def __init__(self):
        super(Lenet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(6, 16, 5, stride=1, padding=0)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(400, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)

        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)

        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)

        x = self.dropout2(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc3(x)

        output = F.log_softmax(x, dim=1)

        return output

In [None]:
class CompactCNN(nn.Module):
    """
    Realistic architecture an attacker might use knowing the task is digit recognition.
    Based on common MNIST tutorial patterns.
    """
    def __init__(self):
        super(CompactCNN, self).__init__()
        # Common MNIST CNN pattern: 32->64 channels
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        # Common MNIST FC sizes: 128->10
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)

        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)

        return F.log_softmax(x, dim=1)

In [26]:
# Define normalization
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)

# Load dataset
train_dataset = datasets.MNIST("../data", train=True, download=True, transform=transform)
test_dataset = datasets.MNIST("../data", train=False, transform=transform)

In [27]:
# Load target model
target_model = Lenet()
target_model.load_state_dict(torch.load("mnist_cnn.pt"))

<All keys matched successfully>

In [28]:
# Set up MEA parameters
attack_number = 6000  # maximum attack_number is 60 000
attack_indices = random.sample(range(0, len(train_dataset)), attack_number)

queries = torch.utils.data.Subset(train_dataset, attack_indices)
queries_loader = torch.utils.data.DataLoader(queries, batch_size=64, shuffle=True)

In [29]:
# Instantiate extracted model using realistic attacker architecture
extracted_model = CompactCNN()

# Define the optimizer for model extracting (optimized for knowledge distillation)
optimizer = optim.Adam(extracted_model.parameters(), lr=0.001, weight_decay=1e-4)
scheduler = StepLR(optimizer, step_size=25, gamma=0.5)

In [30]:
# Perform model extraction
target_model.eval()

start_time = time.time()

for epoch in range(20):
    epoch_loss = 0
    num_batches = 0

    for data, target in tqdm(queries_loader, desc=f"Epoch {epoch+1}"):
        optimizer.zero_grad()
        output = extracted_model(data)

        # Query target model without computing gradients
        with torch.no_grad():
            target_output = target_model(data)

        loss = F.kl_div(F.log_softmax(output, dim=1), F.softmax(target_output, dim=1), reduction='batchmean', log_target=False)
        loss.backward()
        optimizer.step()

        # Update running statistics
        epoch_loss += loss.item()
        num_batches += 1

    print(f"Avg Loss: {epoch_loss/num_batches:.4f}")

    scheduler.step()

extract_time = time.time() - start_time
print("\n" + "=" * 50)
print(f"Model extraction completed in {extract_time} seconds")

Epoch 1: 100%|██████████| 94/94 [00:14<00:00,  6.59it/s]


Avg Loss: 0.7328


Epoch 2: 100%|██████████| 94/94 [00:13<00:00,  6.74it/s]


Avg Loss: 0.2100


Epoch 3: 100%|██████████| 94/94 [00:10<00:00,  9.24it/s]


Avg Loss: 0.1317


Epoch 4: 100%|██████████| 94/94 [00:10<00:00,  8.85it/s]


Avg Loss: 0.1019


Epoch 5: 100%|██████████| 94/94 [00:10<00:00,  9.21it/s]


Avg Loss: 0.0896


Epoch 6: 100%|██████████| 94/94 [00:13<00:00,  6.96it/s]


Avg Loss: 0.0770


Epoch 7: 100%|██████████| 94/94 [00:11<00:00,  7.90it/s]


Avg Loss: 0.0671


Epoch 8: 100%|██████████| 94/94 [00:10<00:00,  9.02it/s]


Avg Loss: 0.0606


Epoch 9: 100%|██████████| 94/94 [00:10<00:00,  9.12it/s]


Avg Loss: 0.0510


Epoch 10: 100%|██████████| 94/94 [00:12<00:00,  7.48it/s]


Avg Loss: 0.0478


Epoch 11: 100%|██████████| 94/94 [00:14<00:00,  6.50it/s]


Avg Loss: 0.0441


Epoch 12: 100%|██████████| 94/94 [00:12<00:00,  7.73it/s]


Avg Loss: 0.0481


Epoch 13: 100%|██████████| 94/94 [00:12<00:00,  7.33it/s]


Avg Loss: 0.0438


Epoch 14: 100%|██████████| 94/94 [00:09<00:00, 10.29it/s]


Avg Loss: 0.0405


Epoch 15: 100%|██████████| 94/94 [00:10<00:00,  8.56it/s]


Avg Loss: 0.0417


Epoch 16: 100%|██████████| 94/94 [00:11<00:00,  8.42it/s]


Avg Loss: 0.0389


Epoch 17: 100%|██████████| 94/94 [00:11<00:00,  8.25it/s]


Avg Loss: 0.0413


Epoch 18: 100%|██████████| 94/94 [00:09<00:00,  9.91it/s]


Avg Loss: 0.0361


Epoch 19: 100%|██████████| 94/94 [00:09<00:00, 10.17it/s]


Avg Loss: 0.0340


Epoch 20: 100%|██████████| 94/94 [00:13<00:00,  7.23it/s]

Avg Loss: 0.0331

Model extraction completed in 231.87012553215027 seconds





In [31]:
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=True)

extracted_model.eval()

test_loss = 0
correct = 0
with torch.no_grad():
    for data, target in tqdm(test_loader):
        output = extracted_model(data)
        test_loss += F.nll_loss(output, target, reduction="sum").item()  # sum up batch loss
        pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
        correct += pred.eq(target.view_as(pred)).sum().item()

test_loss /= len(test_loader.dataset)

print(
    "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
        test_loss,
        correct,
        len(test_loader.dataset),
        100.0 * correct / len(test_loader.dataset),
    )
)

100%|██████████| 157/157 [00:11<00:00, 13.28it/s]


Test set: Average loss: 0.0564, Accuracy: 9822/10000 (98%)






In [None]:
torch.save(extracted_model.state_dict(), "mnist_comp_cnn_mea.pt")