# Model Stealing Attack - Assignment 2

**Team Number**: 19 
**Task**: Implement a model stealing attack against a B4B-protected encoder while minimizing the L2 distance

## Strategy Overview
1. Smart query scheduling to avoid B4B's coverage thresholds
2. Noise-adaptive model training
3. Embedding space-aware sampling

In [1]:
# Import libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import requests
import io
import base64
import json
import numpy as np
from tqdm import tqdm
from sklearn.decomposition import PCA
from collections import defaultdict
import pickle
from PIL import Image
import torchvision.transforms as transforms
import onnxruntime as ort

# Set device
# Use a CUDA GPU when PyTorch reports one as available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cpu


## 1. API Connection Setup

Modified to include the token and port handling from the assignment PDF.

In [2]:
# Credentials and session state shared by the API helper functions.
# NOTE(review): the token is hard-coded here; consider loading it from an
# environment variable before sharing or committing this notebook.
TOKEN = "50407833"  # Replace with token from assignment email
PORT = None                # Port of the launched API instance; assigned by launch_api()
SEED = None                # Seed of the launched API instance; assigned by launch_api()

def launch_api():
    """Start a new victim-API instance and remember its seed and port.

    Sends a GET request with the team token to the ``/stealing_launch``
    endpoint and stores the returned seed and port, converted to strings, in
    the module-level ``SEED`` and ``PORT``.

    Returns:
        The ``(SEED, PORT)`` pair of strings.

    Raises:
        Exception: If the JSON reply contains a ``detail`` entry.
    """
    global PORT, SEED

    launch_url = "http://34.122.51.94:9090/stealing_launch"
    reply = requests.get(launch_url, headers={"token": TOKEN}).json()

    # The server reports launch problems through a 'detail' field.
    if 'detail' in reply:
        raise Exception(f"API launch failed: {reply['detail']}")

    SEED = str(reply['seed'])
    PORT = str(reply['port'])
    print(f"API launched. Seed: {SEED}, Port: {PORT}")
    return SEED, PORT

def query_api(images):
    """Query the victim encoder API with a batch of images.

    Every image is PNG-encoded, base64-encoded and sent as one JSON list in a
    single request to ``/query`` on the port stored in ``PORT``.

    Args:
        images: Iterable of PIL images and/or ``torch.Tensor`` images. Tensors
            are converted with ``transforms.ToPILImage()``.
            NOTE(review): ToPILImage presumably expects float tensors in
            [0, 1] (or uint8 tensors); tensors normalized to another range
            should be mapped back by the caller first — confirm against the
            torchvision documentation.

    Returns:
        The ``"representations"`` entry of the JSON response.

    Raises:
        Exception: If ``PORT`` has not been set yet (no API instance was
            launched), or if the server answers with a non-200 status code.
    """
    if PORT is None:
        # Without this guard the request would target the literal port "None".
        raise Exception("Query failed. API not launched: call launch_api() first")

    url = f"http://34.122.51.94:{PORT}/query"
    to_pil = transforms.ToPILImage()  # built once and reused for every tensor

    # Convert images to base64 as required by API
    image_data = []
    for img in images:
        if isinstance(img, torch.Tensor):
            img = to_pil(img)

        with io.BytesIO() as png_buffer:
            img.save(png_buffer, format='PNG')
            encoded = base64.b64encode(png_buffer.getvalue()).decode('utf-8')
        image_data.append(encoded)

    payload = json.dumps(image_data)
    response = requests.get(url, files={"file": payload}, headers={"token": TOKEN})

    if response.status_code != 200:
        raise Exception(f"Query failed. Code: {response.status_code}")
    return response.json()["representations"]

## 2. Coverage Tracking System

Implements local LSH to estimate bucket coverage

In [None]:
class CoverageTracker:
    """Local estimate of how much of the embedding space has been queried.

    Each embedding is mapped to a coarse grid cell ("bucket") derived from its
    first five coordinates. Coverage is the number of distinct buckets seen so
    far divided by ``max_buckets``.
    """

    def __init__(self, bucket_size=0.15, max_buckets=4096):
        """Create an empty tracker.

        Args:
            bucket_size: Edge length of one grid cell along each hashed
                coordinate.
            max_buckets: Denominator used when turning the number of seen
                buckets into a coverage fraction.
        """
        # Maps bucket tuple -> True for every bucket seen so far.
        self.bucket_map = defaultdict(bool)
        self.bucket_size = bucket_size
        self.max_buckets = max_buckets
        self.query_count = 0

    def _hash_embedding(self, emb):
        """Simulate LSH bucket assignment.

        Args:
            emb: One embedding as a NumPy array or a plain sequence of numbers
                (for example a list decoded from JSON).

        Returns:
            A tuple with the floored ``value / bucket_size`` of the first five
            coordinates.
        """
        # np.asarray makes the division elementwise for plain Python lists too;
        # dividing a list by a float directly raises TypeError.
        head = np.asarray(emb[:5])  # Use first 5 dims
        return tuple(np.floor(head / self.bucket_size))

    def update_coverage(self, embeddings):
        """Update coverage estimate with new embeddings.

        Args:
            embeddings: Sized collection of embeddings; each one marks its
                bucket as seen and counts as one query.
        """
        for emb in embeddings:
            self.bucket_map[self._hash_embedding(emb)] = True
        self.query_count += len(embeddings)

    def get_coverage(self):
        """Return the fraction ``seen buckets / max_buckets``."""
        return len(self.bucket_map) / self.max_buckets

    def is_safe(self, sample_size=1000, new_bucket_rate=0.1, max_coverage=0.3):
        """Check if additional queries would be safe.

        Args:
            sample_size: Number of queries being considered.
            new_bucket_rate: Assumed fraction of those queries that land in a
                bucket not seen before (a heuristic estimate).
            max_coverage: Coverage fraction the projection must stay below.

        Returns:
            True when the projected coverage is strictly below
            ``max_coverage``.
        """
        projected = len(self.bucket_map) + sample_size * new_bucket_rate
        return projected / self.max_buckets < max_coverage

## 3. Data Loading and Processing

Modified to use the provided `ModelStealingPub.pt`.

In [None]:
class StealingDataset(Dataset):
    """Dataset over the images stored in the object saved at ``data_path``.

    Items that are not already tensors are resized/cropped to 32x32, converted
    to a tensor and normalized with mean 0.5 and std 0.5 per channel. Items
    that are already tensors are returned unchanged.
    """

    def __init__(self, data_path="ModelStealingPub.pt"):
        """Load the stored object and build the preprocessing pipeline.

        Args:
            data_path: Path of the file handed to ``torch.load``. The loaded
                object must expose ``imgs`` and may expose ``labels``.
        """
        # NOTE(review): torch.load relies on pickle; only load trusted files.
        stored = torch.load(data_path)
        self.images = stored.imgs  # PIL Images (per the original author)
        self.labels = getattr(stored, 'labels', None)

        # Transformations matching API pre-processing
        steps = [
            transforms.Resize(32),
            transforms.CenterCrop(32),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
        self.transform = transforms.Compose(steps)

    def __len__(self):
        """Return the number of stored images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the image at ``idx``, transformed unless already a tensor."""
        sample = self.images[idx]
        if isinstance(sample, torch.Tensor):
            return sample
        return self.transform(sample)

## 4. Model Architecture

Custom architecture designed to match victim's output space

In [None]:
class EncoderStealer(nn.Module):
    """Convolutional encoder producing ``output_dim``-dimensional embeddings.

    The network has three parts: a convolutional feature extractor, an MLP
    embedding head, and a residual "noise adapter" MLP whose output is added
    to the embedding with weight 0.1.
    """

    def __init__(self, output_dim=1024):
        """Build the layers.

        Args:
            output_dim: Size of the embedding returned by ``forward``.
        """
        super().__init__()

        # Three conv stages (3 -> 64 -> 128 -> 256 channels). The first two end
        # in 2x2 max pooling, the last in global average pooling.
        widths = (3, 64, 128, 256)
        last_stage = len(widths) - 1
        stages = []
        for stage, (c_in, c_out) in enumerate(zip(widths, widths[1:]), start=1):
            stages.extend([
                nn.Conv2d(c_in, c_out, 3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
                nn.AdaptiveAvgPool2d(1) if stage == last_stage else nn.MaxPool2d(2),
            ])
        self.feature_extractor = nn.Sequential(*stages)

        self.embedding = nn.Sequential(
            nn.Flatten(),
            nn.Linear(widths[-1], 512),
            nn.ReLU(),
            nn.Linear(512, output_dim),
        )

        # Noise adaptation module
        bottleneck = output_dim // 2
        self.noise_adapter = nn.Sequential(
            nn.Linear(output_dim, bottleneck),
            nn.ReLU(),
            nn.Linear(bottleneck, output_dim),
        )

    def forward(self, x):
        """Return the embedding of ``x`` plus 0.1 times the adapter output."""
        emb = self.embedding(self.feature_extractor(x))
        correction = self.noise_adapter(emb)  # Learn residual noise
        return emb + 0.1 * correction

## 5. Smart Query Strategy

Implements coverage-aware sampling

In [None]:
def get_strategic_samples(dataset, tracker, batch_size=1000):
    """Pick ``batch_size`` dataset items for the next query batch.

    While fewer than 5000 queries have been made, items are drawn uniformly at
    random (with replacement). Afterwards each randomly drawn index gets a
    priority of 10 or 1 depending on whether a randomly generated embedding
    falls into a bucket the tracker has not seen yet, and items are returned
    with high-priority ones first.

    NOTE(review): in the second phase the PCA is fitted on random Gaussian
    data and the scored embedding is random as well, so the priority does not
    depend on the image that was drawn. The result is a reordered random
    sample; real coverage-aware selection would need observed embeddings.

    Args:
        dataset: Indexable, sized collection of images.
        tracker: Object exposing ``query_count``, ``bucket_map`` and
            ``_hash_embedding`` (see ``CoverageTracker``).
        batch_size: Number of items to return.

    Returns:
        A list of ``batch_size`` items taken from ``dataset``.
    """
    # First get low-diversity samples
    if tracker.query_count < 5000:
        indices = np.random.choice(len(dataset), batch_size)
        return [dataset[i] for i in indices]

    # Later use PCA-guided sampling
    dummy_embs = np.random.randn(1000, 1024)
    pca = PCA(n_components=10).fit(dummy_embs)

    # Score candidate indices only; each image is loaded once, when the final
    # list is built, instead of being loaded and discarded during scoring.
    scores = []
    for _ in range(batch_size):
        idx = np.random.randint(len(dataset))
        # Simulate embedding with PCA components
        fake_emb = np.dot(np.random.randn(10), pca.components_)
        bucket = tracker._hash_embedding(fake_emb)
        # High priority for empty buckets
        priority = 1 if tracker.bucket_map.get(bucket, False) else 10
        scores.append((idx, priority))

    # Stable sort: high-priority candidates first, draw order kept within ties.
    scores.sort(key=lambda scored: -scored[1])
    return [dataset[idx] for idx, _ in scores]

## 6. Training Loop with Noise Adaptation

In [None]:
def train_model(model, train_loader, epochs=5):
    """Fit ``model`` to the (image, target embedding) pairs in ``train_loader``.

    Uses AdamW (learning rate 3e-4) with a mean-squared-error loss and prints
    the average batch loss after every epoch. The model is updated in place;
    nothing is returned.

    Args:
        model: Network to train.
        train_loader: Iterable yielding ``(images, targets)`` batches.
        epochs: Number of passes over ``train_loader``.
    """
    model.train()
    optimizer = optim.AdamW(model.parameters(), lr=3e-4)
    criterion = nn.MSELoss()

    for epoch_number in range(1, epochs + 1):
        running_loss = 0
        for images, targets in tqdm(train_loader):
            # Move the batch to the globally selected device.
            images = images.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            loss = criterion(model(images), targets)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        mean_loss = running_loss / len(train_loader)
        print(f"Epoch {epoch_number}, Loss: {mean_loss:.4f}")

## 7. Main Execution Flow

Follows assignment requirements

In [None]:
def _denormalize_for_query(img, mean=0.5, std=0.5):
    """Map a normalized image tensor back to uint8 pixel values for querying.

    StealingDataset normalizes images with mean 0.5 and std 0.5, so its tensors
    lie in [-1, 1]. query_api() hands tensors to ToPILImage, and converting
    such a float tensor directly would corrupt negative values. Returning a
    uint8 tensor avoids any further scaling and restores the original pixels.

    NOTE(review): assumes every float tensor passed in went through
    StealingDataset.transform; ``mean``/``std`` must stay in sync with it.
    Non-tensor and non-float inputs are returned unchanged.
    """
    if isinstance(img, torch.Tensor) and img.is_floating_point():
        pixels = (img * std + mean).mul(255).round().clamp(0, 255)
        return pixels.to(torch.uint8)
    return img


def _query_and_record(batch_images, tracker, train_data):
    """Query one batch, update the tracker, and append training pairs."""
    raw_embs = query_api([_denormalize_for_query(img) for img in batch_images])
    # The API returns nested Python lists. Use one float32 array so the tracker
    # can slice and divide it, and so DataLoader can stack the targets.
    embs = np.asarray(raw_embs, dtype=np.float32)
    tracker.update_coverage(embs)
    train_data.extend(zip(batch_images, torch.from_numpy(embs)))


def main():
    """Run the full attack: query the API, train the copy, export, submit.

    Steps:
        1. Launch an API instance and load the public dataset.
        2. Phase 1: query 5000 randomly drawn images in batches of 1000.
        3. Phase 2: keep querying strategically selected batches while the
           query count is below 100000 and estimated coverage is below 30%.
        4. Train a fresh ``EncoderStealer`` on all collected pairs.
        5. Export the model to ONNX, validate it, and submit it.
    """
    # Initialize API connection
    launch_api()

    # Load dataset
    dataset = StealingDataset()
    print(f"Loaded dataset with {len(dataset)} images")

    # Initialize coverage tracker
    tracker = CoverageTracker()

    # (image tensor, target embedding tensor) pairs collected from the API
    train_data = []

    # Phase 1: Initial queries (low diversity)
    print("Phase 1: Initial low-diversity queries")
    phase1_images = [dataset[i] for i in np.random.choice(len(dataset), 5000)]

    # Batch queries as per API limits (1000 images per query)
    for start in range(0, len(phase1_images), 1000):
        _query_and_record(phase1_images[start:start + 1000], tracker, train_data)
        print(f"Queries: {tracker.query_count}, Coverage: {tracker.get_coverage():.2%}")

    # Phase 2: Strategic expansion
    print("\nPhase 2: Strategic coverage expansion")

    while tracker.query_count < 100000 and tracker.get_coverage() < 0.3:
        # Get smart samples
        batch_images = get_strategic_samples(dataset, tracker)
        _query_and_record(batch_images, tracker, train_data)

        # Periodic training
        # NOTE(review): this intermediate model only reports progress; it is
        # discarded and the final training below starts from scratch.
        if len(train_data) % 5000 == 0:
            loader = DataLoader(train_data, batch_size=256, shuffle=True)
            model = EncoderStealer().to(device)
            train_model(model, loader, epochs=3)

        print(f"Queries: {tracker.query_count}, Coverage: {tracker.get_coverage():.2%}")

    # Final training
    print("\nFinal training")
    final_loader = DataLoader(train_data, batch_size=256, shuffle=True)
    model = EncoderStealer().to(device)
    train_model(model, final_loader, epochs=10)

    # Export to ONNX as required
    print("\nExporting model to ONNX")
    # Switch BatchNorm to inference behavior before tracing the graph.
    model.eval()
    dummy_input = torch.randn(1, 3, 32, 32).to(device)
    torch.onnx.export(
        model,
        dummy_input,
        "stolen_model.onnx",
        input_names=["x"],
        output_names=["output"],
        dynamic_axes={'x': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
    )

    # Validate ONNX
    validate_onnx("stolen_model.onnx")

    # Submit as per assignment
    submit_model("stolen_model.onnx")

## 8. Submission Helpers

Taken directly from the assignment example code.

In [None]:
def validate_onnx(model_path):
    """Check that the exported ONNX model runs and returns the expected shape.

    Loads the model with ONNX Runtime, feeds one random float32 input of shape
    (1, 3, 32, 32) and requires the first output to have shape (1, 1024).

    Args:
        model_path: Path of the ONNX file to check.

    Raises:
        Exception: If loading, running, or the shape check fails; the message
            carries the text of the underlying error.
    """
    try:
        runtime = ort.InferenceSession(model_path)
        feed_name = runtime.get_inputs()[0].name

        # Test with random input
        probe = np.random.randn(1, 3, 32, 32).astype(np.float32)
        results = runtime.run(None, {feed_name: probe})
        embedding = results[0]

        # Check output dimensions
        assert embedding.shape == (1, 1024), f"Invalid output shape: {embedding.shape}"
        print("ONNX validation passed!")

    except Exception as err:
        raise Exception(f"ONNX validation failed: {str(err)}")

def submit_model(model_path):
    """Upload the ONNX file for evaluation and print the server's answer.

    The file is POSTed to the ``/stealing`` endpoint together with the team
    token and the seed of the launched API instance. On HTTP 200 the JSON
    response is printed; otherwise the status code and response text are.

    Args:
        model_path: Path of the ONNX file to upload.
    """
    endpoint = "http://34.122.51.94:9090/stealing"
    auth_headers = {"token": TOKEN, "seed": SEED}

    with open(model_path, "rb") as model_file:
        response = requests.post(
            endpoint, files={"file": model_file}, headers=auth_headers
        )

    if response.status_code != 200:
        print(f"Submission failed: {response.status_code}")
        print(response.text)
        return

    print("Submission successful!")
    print(response.json())

## 9. Run the Attack

Execute the complete pipeline

In [None]:
# Run the complete attack pipeline only when this file is executed directly.
if __name__ == "__main__":
    main()