# Model Stealing Attack - Assignment 2

**Team Number**: 19 
**Task**: Implement a model stealing attack against B4B-protected encoder while minimizing L2 distance

In [1]:
# Import libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import requests
import io
import base64
import json
import numpy as np
from tqdm import tqdm
from sklearn.decomposition import PCA
from collections import defaultdict
import pickle
from PIL import Image
import torchvision.transforms as transforms
import onnxruntime as ort
import time

# Set device
device = torch.device("mps") # if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

ModuleNotFoundError: No module named 'onnxruntime'

## 1. API Connection Setup

In [None]:
# Authentication token for the target API
TOKEN = "50407833" 
PORT = None
SEED = None

def launch_api():
    """Initialize connection to the target API and get session details"""
    global PORT, SEED
    response = requests.get(
        "http://34.122.51.94:9090/stealing_launch",
        headers={"token": TOKEN}
    )
    answer = response.json()
    if 'detail' in answer:
        raise Exception(f"API launch failed: {answer['detail']}")
    SEED = str(answer['seed'])
    PORT = str(answer['port'])
    print(f"API launched. Seed: {SEED}, Port: {PORT}")
    return SEED, PORT

def query_api(images, retries=3, delay=60):
    """Send image batch to API and get embedding representations
    
    Handles rate limiting with exponential backoff retry mechanism
    """
    endpoint = "/query"
    url = f"http://34.122.51.94:{PORT}" + endpoint
    image_data = []
    
    # Convert images to base64 format for API request
    for img in images:
        img = transforms.ToPILImage()(img.cpu())
        img_byte_arr = io.BytesIO()
        img.save(img_byte_arr, format='PNG')
        img_byte_arr.seek(0)
        img_base64 = base64.b64encode(img_byte_arr.getvalue()).decode('utf-8')
        image_data.append(img_base64)
    
    payload = json.dumps(image_data)
    response = requests.get(url, files={"file": payload}, headers={"token": TOKEN})
    
    if response.status_code == 200:
        return response.json()["representations"]
    elif response.status_code == 429:
        if retries > 0:
            print("Rate limited. Retrying after delay...")
            time.sleep(delay)
            return query_api(images, retries - 1, delay * 2)
        else:
            raise Exception("Too many retries. Still getting rate-limited.")
    else:
        raise Exception(f"Query failed. Code: {response.status_code}")

## 2. Coverage Tracking System

In [None]:
class CoverageTracker:
    """Tracks our coverage of the target model's embedding space
    
    Uses a bucketing approach to estimate how much of the embedding space
    we've explored. This helps avoid triggering B4B protection.
    """
    def __init__(self, bucket_size=0.15, max_buckets=4096):
        self.bucket_map = defaultdict(bool)
        self.bucket_size = bucket_size
        self.max_buckets = max_buckets
        self.query_count = 0
        
    def _hash_embedding(self, emb):
        """Convert embedding vector to discrete bucket coordinates
        
        Uses just the first 5 dimensions for efficiency
        """
        emb = np.array(emb)
        return tuple(np.floor(emb[:5] / self.bucket_size))
        
    def update_coverage(self, embeddings):
        """Process new embeddings and update coverage statistics"""
        for emb in embeddings:
            self.bucket_map[self._hash_embedding(emb)] = True
        self.query_count += len(embeddings)
        
    def get_coverage(self):
        """Calculate percentage of embedding space covered"""
        return len(self.bucket_map) / self.max_buckets
    
    def is_safe(self, sample_size=1000):
        """Check if adding more samples is likely to stay under B4B thresholds"""
        current = len(self.bucket_map)
        projected = current + (sample_size * 0.1)
        return projected / self.max_buckets < 0.3

## 3. Data Loading and Processing

In [None]:
class TaskDataset(Dataset):
    """Dataset class for the image collection used in model stealing
    
    Handles image loading, preprocessing and transformations
    """
    def __init__(self, transform=None):
        self.ids = []
        self.imgs = []
        self.labels = []
        self.transform = transform or transforms.Compose([
            transforms.Resize(32),
            transforms.CenterCrop(32),
            transforms.Lambda(self._ensure_rgb),
            transforms.ToTensor(),
            # Use dataset-specific mean and std for normalization
            transforms.Normalize((0.2980, 0.2962, 0.2987), (0.2886, 0.2875, 0.2889))
        ])

    def _ensure_rgb(self, img):
        """Convert grayscale images to RGB if needed"""
        if img.mode != 'RGB':
            return img.convert('RGB')
        return img

    def __getitem__(self, index):
        img = self.imgs[index]
        if self.transform:
            img = self.transform(img)
        return img

    def __len__(self):
        return len(self.ids)

def load_stealing_dataset(data_path="ModelStealingPub.pt"):
    """Load the dataset of candidate images for stealing"""
    original_data = torch.load(data_path, map_location="cpu", weights_only=False)
    dataset = TaskDataset()
    dataset.ids = original_data.ids
    dataset.imgs = original_data.imgs
    dataset.labels = original_data.labels
    return dataset

def collate_fn(batch):
    """Custom collate function to handle (image, embedding) pairs"""
    images = torch.stack([item[0] for item in batch])
    targets = torch.stack([item[1] for item in batch])
    return images, targets

## 4. Model Architecture

In [None]:
class EncoderStealer(nn.Module):
    """Neural network model designed to replicate the target encoder
    
    Uses convolutional layers for feature extraction followed by
    fully connected layers to generate embeddings
    """
    def __init__(self, output_dim=1024):
        super().__init__()
        
        # Feature extraction layers - 3 conv blocks with batch norm
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            
            nn.Conv2d(64, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),
            
            nn.Conv2d(128, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1)
        )
        
        # Embedding generation layers
        self.embedding = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, output_dim)
        )
        
    def forward(self, x):
        features = self.feature_extractor(x)
        return self.embedding(features)

## 5. Training Loop

In [None]:
def train_model(model, train_loader, epochs=5):
    """Train the model using collected image-embedding pairs
    
    Uses MSE loss to minimize L2 distance between predicted
    and target embeddings
    """
    model.train()
    optimizer = optim.AdamW(model.parameters(), lr=3e-4)
    criterion = nn.MSELoss()
    
    for epoch in range(epochs):
        total_loss = 0
        for images, targets in tqdm(train_loader):
            images = images.to(device)
            targets = targets.float().to(device)
            
            optimizer.zero_grad()
            outputs = model(images)
            
            # Handle any shape mismatches (shouldn't happen with proper setup)
            if outputs.shape != targets.shape:
                print(f"Shape mismatch! Output: {outputs.shape}, Target: {targets.shape}")
                targets = targets.view_as(outputs)
                
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        print(f"Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}")

## 6. Main Execution Flow

In [None]:
# Initialize API connection - this gets our unique PORT and SEED values
# def main():
    # Initialize API connection
launch_api()
    

API launched. Seed: 28160397, Port: 9944


('28160397', '9944')

In [None]:
# Load dataset
dataset = load_stealing_dataset()
print(f"Loaded dataset with {len(dataset)} images")

# Initialize our coverage tracking system
tracker = CoverageTracker()

Loaded dataset with 13000 images


In [None]:
# Phase 1: Initial queries (low diversity)
# This establishes our baseline data for later targeted queries
print("Phase 1: Initial low-diversity queries")
train_data = []

for _ in range(5):  # 5 queries x 1000 images = 5000 initial samples
    # Randomly select 1000 images from the dataset
    batch_indices = np.random.choice(len(dataset), 1000, replace=False)
    batch_images = [dataset[i] for i in batch_indices]
    batch_embs = query_api(batch_images)

    # Store images and their corresponding embeddings
    for img, emb in zip(batch_images, batch_embs):
        train_data.append((img, torch.tensor(emb).float()))

    tracker.update_coverage(batch_embs)
    print(f"Queries: {tracker.query_count}, Coverage: {tracker.get_coverage():.2%}")

Phase 1: Initial low-diversity queries
Queries: 1000, Coverage: 0.32%
Rate limited. Retrying after delay...
Queries: 2000, Coverage: 0.39%
Rate limited. Retrying after delay...
Queries: 3000, Coverage: 0.39%
Rate limited. Retrying after delay...
Queries: 4000, Coverage: 0.39%
Rate limited. Retrying after delay...
Queries: 5000, Coverage: 0.42%


In [None]:
# Phase 2: Strategic expansion
# This is where we use most of our query budget to maximize coverage
print("\nPhase 2: Strategic coverage expansion")

while tracker.query_count < 100000 and tracker.get_coverage() < 0.3:
    # Get exactly 1000 samples per query (API requirement)
    batch_images = [dataset[i] for i in np.random.choice(len(dataset), 1000)]
    batch_embs = query_api(batch_images)

    # Update tracking and training data
    for img, emb in zip(batch_images, batch_embs):
        train_data.append((img, torch.tensor(emb).float()))

    tracker.update_coverage(batch_embs)

    # Periodic training every 10,000 samples to check progress
    if len(train_data) % 10000 == 0:
        loader = DataLoader(train_data, batch_size=256, shuffle=True, collate_fn=collate_fn)
        model = EncoderStealer(output_dim=1024).to(device)
        train_model(model, loader, epochs=3)

    print(f"Queries: {tracker.query_count}, Coverage: {tracker.get_coverage():.2%}")


Phase 2: Strategic coverage expansion
Rate limited. Retrying after delay...
Queries: 6000, Coverage: 0.46%
Rate limited. Retrying after delay...
Queries: 7000, Coverage: 0.49%
Rate limited. Retrying after delay...
Queries: 8000, Coverage: 0.49%
Rate limited. Retrying after delay...
Queries: 9000, Coverage: 0.49%
Rate limited. Retrying after delay...


100%|██████████| 40/40 [00:46<00:00,  1.16s/it]


Epoch 1, Loss: 0.2038


100%|██████████| 40/40 [00:46<00:00,  1.16s/it]


Epoch 2, Loss: 0.0051


100%|██████████| 40/40 [00:46<00:00,  1.17s/it]


Epoch 3, Loss: 0.0032
Queries: 10000, Coverage: 0.51%
Queries: 11000, Coverage: 0.54%
Rate limited. Retrying after delay...
Queries: 12000, Coverage: 0.54%
Rate limited. Retrying after delay...
Queries: 13000, Coverage: 0.54%
Rate limited. Retrying after delay...
Queries: 14000, Coverage: 0.54%
Rate limited. Retrying after delay...
Queries: 15000, Coverage: 0.56%
Rate limited. Retrying after delay...
Queries: 16000, Coverage: 0.56%
Rate limited. Retrying after delay...
Queries: 17000, Coverage: 0.56%
Rate limited. Retrying after delay...
Queries: 18000, Coverage: 0.56%
Rate limited. Retrying after delay...
Queries: 19000, Coverage: 0.59%
Rate limited. Retrying after delay...


100%|██████████| 79/79 [01:34<00:00,  1.19s/it]


Epoch 1, Loss: 0.1080


100%|██████████| 79/79 [01:34<00:00,  1.19s/it]


Epoch 2, Loss: 0.0036


100%|██████████| 79/79 [01:37<00:00,  1.24s/it]


Epoch 3, Loss: 0.0027
Queries: 20000, Coverage: 0.59%
Queries: 21000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 22000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 23000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 24000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 25000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 26000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 27000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 28000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 29000, Coverage: 0.59%
Rate limited. Retrying after delay...


100%|██████████| 118/118 [02:26<00:00,  1.24s/it]


Epoch 1, Loss: 0.0711


100%|██████████| 118/118 [02:25<00:00,  1.23s/it]


Epoch 2, Loss: 0.0026


100%|██████████| 118/118 [02:24<00:00,  1.22s/it]


Epoch 3, Loss: 0.0022
Queries: 30000, Coverage: 0.59%
Queries: 31000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 32000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 33000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 34000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 35000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 36000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 37000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 38000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 39000, Coverage: 0.59%
Rate limited. Retrying after delay...


100%|██████████| 157/157 [03:14<00:00,  1.24s/it]


Epoch 1, Loss: 0.0543


100%|██████████| 157/157 [03:17<00:00,  1.26s/it]


Epoch 2, Loss: 0.0025


100%|██████████| 157/157 [03:14<00:00,  1.24s/it]


Epoch 3, Loss: 0.0020
Queries: 40000, Coverage: 0.59%
Queries: 41000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 42000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 43000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 44000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 45000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 46000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 47000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 48000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 49000, Coverage: 0.59%
Rate limited. Retrying after delay...


100%|██████████| 196/196 [04:21<00:00,  1.34s/it]


Epoch 1, Loss: 0.0453


100%|██████████| 196/196 [04:48<00:00,  1.47s/it]


Epoch 2, Loss: 0.0023


100%|██████████| 196/196 [04:02<00:00,  1.24s/it]


Epoch 3, Loss: 0.0019
Queries: 50000, Coverage: 0.59%
Queries: 51000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 52000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 53000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 54000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 55000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 56000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 57000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 58000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 59000, Coverage: 0.59%
Rate limited. Retrying after delay...


100%|██████████| 235/235 [04:43<00:00,  1.21s/it]


Epoch 1, Loss: 0.0375


100%|██████████| 235/235 [04:41<00:00,  1.20s/it]


Epoch 2, Loss: 0.0023


100%|██████████| 235/235 [04:42<00:00,  1.20s/it]


Epoch 3, Loss: 0.0018
Queries: 60000, Coverage: 0.59%
Queries: 61000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 62000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 63000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 64000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 65000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 66000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 67000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 68000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 69000, Coverage: 0.59%
Rate limited. Retrying after delay...


100%|██████████| 274/274 [05:25<00:00,  1.19s/it]


Epoch 1, Loss: 0.0328


100%|██████████| 274/274 [05:22<00:00,  1.18s/it]


Epoch 2, Loss: 0.0020


100%|██████████| 274/274 [05:24<00:00,  1.19s/it]


Epoch 3, Loss: 0.0017
Queries: 70000, Coverage: 0.59%
Queries: 71000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 72000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 73000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 74000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 75000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 76000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 77000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 78000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 79000, Coverage: 0.59%
Rate limited. Retrying after delay...


100%|██████████| 313/313 [06:18<00:00,  1.21s/it]


Epoch 1, Loss: 0.0286


100%|██████████| 313/313 [06:14<00:00,  1.20s/it]


Epoch 2, Loss: 0.0019


100%|██████████| 313/313 [06:20<00:00,  1.22s/it]


Epoch 3, Loss: 0.0016
Queries: 80000, Coverage: 0.59%
Queries: 81000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 82000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 83000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 84000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 85000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 86000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 87000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 88000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 89000, Coverage: 0.59%
Rate limited. Retrying after delay...


100%|██████████| 352/352 [07:02<00:00,  1.20s/it]


Epoch 1, Loss: 0.0264


100%|██████████| 352/352 [07:02<00:00,  1.20s/it]


Epoch 2, Loss: 0.0020


100%|██████████| 352/352 [07:04<00:00,  1.21s/it]


Epoch 3, Loss: 0.0016
Queries: 90000, Coverage: 0.59%
Queries: 91000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 92000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 93000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 94000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 95000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 96000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 97000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 98000, Coverage: 0.59%
Rate limited. Retrying after delay...
Queries: 99000, Coverage: 0.59%
Rate limited. Retrying after delay...


100%|██████████| 391/391 [08:03<00:00,  1.24s/it]


Epoch 1, Loss: 0.0231


100%|██████████| 391/391 [08:01<00:00,  1.23s/it]


Epoch 2, Loss: 0.0018


100%|██████████| 391/391 [08:00<00:00,  1.23s/it]

Epoch 3, Loss: 0.0015
Queries: 100000, Coverage: 0.59%





In [None]:
# Attempt to move forward with the data we've collected
# If something went wrong earlier, this will catch and report the error
try:
    print(f"Collected {len(train_data)} samples from {tracker.query_count} queries")
    print("Proceeding to final training...")
    
    final_loader = DataLoader(train_data, batch_size=256, shuffle=True, collate_fn=collate_fn)
    model = EncoderStealer(output_dim=1024).to(device)
    train_model(model, final_loader, epochs=5)
    
    # Export to ONNX for submission
    dummy_input = torch.randn(1, 3, 32, 32).to(device)
    torch.onnx.export(
        model,
        dummy_input,
        "stolen_model.onnx",
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
    )
    
    validate_onnx("stolen_model.onnx")
    submit_model("stolen_model.onnx")
    
except NameError:
    print("Error: No training data collected yet. Run main() first.")

Collected 100000 samples from 100000 queries
Proceeding to final training...


100%|██████████| 391/391 [07:57<00:00,  1.22s/it]


Epoch 1, Loss: 0.0237


100%|██████████| 391/391 [07:58<00:00,  1.22s/it]


Epoch 2, Loss: 0.0017


100%|██████████| 391/391 [07:58<00:00,  1.22s/it]


Epoch 3, Loss: 0.0015


100%|██████████| 391/391 [07:59<00:00,  1.23s/it]


Epoch 4, Loss: 0.0014


100%|██████████| 391/391 [08:09<00:00,  1.25s/it]


Epoch 5, Loss: 0.0013


Exception: ONNX validation failed: Input name should be 'x', got input

In [None]:
# Final model training with all collected data
print("\nFinal training")
final_loader = DataLoader(train_data, batch_size=256, shuffle=True, collate_fn=collate_fn)
model = EncoderStealer(output_dim=1024).to(device)
train_model(model, final_loader, epochs=5)

In [None]:
def validate_onnx(model_path):
    """Verify that our ONNX model meets submission requirements
    
    Checks for correct input name and output shape
    """
    try:
        session = ort.InferenceSession(model_path)
        # Check input name matches 'x'
        input_name = session.get_inputs()[0].name
        assert input_name == "x", f"Input name should be 'x', got {input_name}"
        
        test_input = np.random.randn(1, 3, 32, 32).astype(np.float32)
        output = session.run(None, {"x": test_input})[0]  # Note using "x" here
        
        assert output.shape == (1, 1024)
        print("ONNX validation passed! Model meets submission requirements")
    except Exception as e:
        raise Exception(f"ONNX validation failed: {str(e)}")
        
def submit_model(model_path):
    """Submit the stolen model to the evaluation server"""
    url = "http://34.122.51.94:9090/stealing"
    with open(model_path, "rb") as f:
        files = {"file": f}
        headers = {"token": TOKEN, "seed": SEED}
        response = requests.post(url, files=files, headers=headers)
    
    if response.status_code == 200:
        print("Submission successful!")
        print(response.json())
    else:
        print(f"Submission failed: {response.status_code}")
        print(response.text)

In [None]:
# Prepare for final submission
print("\nPreparing for submission...")

# 1. Verify model output shape
test_input = torch.randn(1, 3, 32, 32).to(device)
with torch.no_grad():
    test_output = model(test_input)
print(f"Model test - Input shape: {test_input.shape}, Output shape: {test_output.shape}")

# 2. Export with correct input name (must be "x" for server)
dummy_input = torch.randn(1, 3, 32, 32).to(device)
torch.onnx.export(
    model,
    dummy_input,
    "stolen_model.onnx",
    input_names=["x"],  # Must be "x" to match server expectations
    output_names=["output"],
    dynamic_axes={
        'x': {0: 'batch_size'},
        'output': {0: 'batch_size'}
    },
    verbose=True
)


Preparing for submission...
Model test - Input shape: torch.Size([1, 3, 32, 32]), Output shape: torch.Size([1, 1024])


In [37]:
# Validate ONNX
validate_onnx("stolen_model.onnx")

ONNX validation passed! Model meets submission requirements


In [None]:
# Submit model
submit_model("stolen_model.onnx")

## 7. Submission Helpers

## 8. Run the Attack

In [None]:
if __name__ == "__main__":
    main()