# Model Stealing Attack (Improved) - Assignment 2

**Team Number**: 15 
**Task**: Implement a model stealing attack against B4B-protected encoder while minimizing L2 distance

In [105]:
# Import libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import requests
import io
import base64
import json
import numpy as np
from tqdm import tqdm
from sklearn.decomposition import PCA
from collections import defaultdict
import pickle
from PIL import Image
import torchvision.transforms as transforms
import onnxruntime as ort
import time

# Set device
device = torch.device("cpu") # if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cpu


## 1. API Connection Setup

In [106]:
TOKEN = "50407833" 
PORT = None
SEED = None

def launch_api():
    # global PORT, SEED
    response = requests.get(
        "http://34.122.51.94:9090/stealing_launch",
        headers={"token": TOKEN}
    )
    answer = response.json()
    if 'detail' in answer:
        raise Exception(f"API launch failed: {answer['detail']}")
    SEED = str(answer['seed'])
    # SEED = 30607040
    PORT = str(answer['port'])
    # PORT = 9944
    # print(f"API launched. Seed: {SEED}, Port: {PORT}")
    return SEED, PORT


In [107]:
def query_api(images, retries=3, delay=60, timeout=120):
    endpoint = "/query"
    url = f"http://34.122.51.94:{PORT}" + endpoint
    
    # Critical: API requires exactly 1000 images
    if len(images) != 1000:
        print(f"WARNING: API requires exactly 1000 images, got {len(images)}")
        if len(images) < 1000:
            # Pad with duplicates of the first image to reach 1000
            padding_needed = 1000 - len(images)
            print(f"Padding batch with {padding_needed} duplicate images")
            padding = [images[0]] * padding_needed
            original_count = len(images)
            images = images + padding
        else:
            # Truncate to exactly 1000 images
            print(f"Truncating batch from {len(images)} to 1000 images")
            images = images[:1000]
    
    image_data = []
    for img in images:
        img = transforms.ToPILImage()(img.cpu())
        img_byte_arr = io.BytesIO()
        img.save(img_byte_arr, format='PNG')
        img_byte_arr.seek(0)
        img_base64 = base64.b64encode(img_byte_arr.getvalue()).decode('utf-8')
        image_data.append(img_base64)
    
    payload = json.dumps(image_data)
    payload_size = len(payload) / (1024 * 1024)  # Size in MB
    print(f"Request payload size: {payload_size:.2f} MB")
    
    try:
        response = requests.get(url, files={"file": payload}, headers={"token": TOKEN}, timeout=timeout)
        
        if response.status_code == 200:
            representations = response.json()["representations"]
            
            # If we padded the batch, return only the embeddings for original images
            if 'original_count' in locals():
                return representations[:original_count]
            return representations
            
        elif response.status_code == 429:
            if retries > 0:
                print("Rate limited. Retrying after delay...")
                time.sleep(delay)
                return query_api(images, retries - 1, delay * 2, timeout)
            else:
                raise Exception("Too many retries. Still getting rate-limited.")
        elif response.status_code == 400:
            print("Bad request error (400). This typically means the API rejected the input format.")
            if retries > 0:
                print("Trying again with longer timeout...")
                return query_api(images, retries - 1, delay, timeout * 1.5)
            else:
                raise Exception("Bad request error persisted after retries.")
        else:
            raise Exception(f"Query failed. Code: {response.status_code}")
            
    except (requests.exceptions.Timeout, requests.exceptions.ReadTimeout) as e:
        print(f"Timeout error: {e}")
        # Don't split anymore since we need exactly 1000 images
        print("Timeout occurred. Increasing timeout and retrying...")
        if retries > 0:
            return query_api(images, retries - 1, delay, timeout * 2)
        else:
            raise Exception("Max retries exceeded with timeout errors.")

In [108]:
# New improved function: Query with repetition to reduce noise
def query_with_repetition(images, n_repeats=3, save_prefix="batch"):
    """Query the same images multiple times to average out noise"""
    all_embeddings = []
    
    for i in range(n_repeats):
        print(f"Repetition {i+1}/{n_repeats} for batch of {len(images)} images")
        embeddings = query_api(images)
        all_embeddings.append(embeddings)
        
        # Save intermediate results to avoid losing progress
        with open(f'{save_prefix}_rep_{i+1}_embs.pkl', 'wb') as f:
            pickle.dump(embeddings, f)
    
    # Average the embeddings to reduce noise
    averaged_embeddings = []
    for i in range(len(images)):
        emb_list = [all_embeddings[j][i] for j in range(n_repeats)]
        averaged_embeddings.append(np.mean(emb_list, axis=0))
    
    return averaged_embeddings

## 2. Coverage Tracking System

In [109]:
class CoverageTracker:
    def __init__(self, bucket_size=0.15, max_buckets=4096):
        self.bucket_map = defaultdict(bool)
        self.bucket_size = bucket_size
        self.max_buckets = max_buckets
        self.query_count = 0
        self.embedding_history = []
        
    def _hash_embedding(self, emb):
        emb = np.array(emb)
        return tuple(np.floor(emb[:5] / self.bucket_size))
        
    def update_coverage(self, embeddings):
        self.embedding_history.extend(embeddings)
        for emb in embeddings:
            self.bucket_map[self._hash_embedding(emb)] = True
        self.query_count += len(embeddings)
        
    def get_coverage(self):
        return len(self.bucket_map) / self.max_buckets
    
    def is_safe(self, sample_size=1000):
        current = len(self.bucket_map)
        projected = current + (sample_size * 0.1)
        return projected / self.max_buckets < 0.3
    
    def get_embedding_stats(self):
        """Analyze embeddings to detect bucket boundaries"""
        if len(self.embedding_history) < 1000:
            print(f"Warning: Only {len(self.embedding_history)} samples available for embedding stats")
            return None
        
        embeddings = np.array(self.embedding_history)
        dim = embeddings.shape[1]
        boundaries = []
        
        for i in range(min(dim, 10)):  # Analyze first 10 dimensions
            values = embeddings[:, i]
            sorted_values = np.sort(values)
            # Detect discontinuities in the distribution
            if len(sorted_values) > 1:
                diffs = sorted_values[1:] - sorted_values[:-1]
                threshold = np.percentile(diffs, 95)
                large_gaps = np.where(diffs > threshold)[0]
                if len(large_gaps) > 0:
                    gap_edges = sorted_values[large_gaps + 1]
                    boundaries.append(gap_edges)
        
        return boundaries

## 3. Data Loading and Processing

In [None]:
class TaskDataset(Dataset):
    def __init__(self, transform=None):
        self.ids = []
        self.imgs = []
        self.labels = []
        self.transform = transform or transforms.Compose([
            transforms.Resize(32),
            transforms.CenterCrop(32),
            transforms.Lambda(self._ensure_rgb),
            transforms.ToTensor(),
            transforms.Normalize((0.2980, 0.2962, 0.2987), (0.2886, 0.2875, 0.2889))
        ])

    def _ensure_rgb(self, img):
        if img.mode != 'RGB':
            return img.convert('RGB')
        return img

    def __getitem__(self, index):
        img = self.imgs[index]
        if self.transform:
            img = self.transform(img)
        return img

    def __len__(self):
        return len(self.ids)

def load_stealing_dataset(data_path="ModelStealingPub.pt"):
    original_data = torch.load(data_path, map_location="cpu", weights_only=False)
    dataset = TaskDataset()
    dataset.ids = original_data.ids
    dataset.imgs = original_data.imgs
    dataset.labels = original_data.labels
    return dataset

def collate_fn(batch):
    """Custom collate function to handle (image, embedding) pairs"""
    images = torch.stack([item[0] for item in batch])
    targets = torch.stack([item[1] for item in batch])
    return images, targets

## 4. Improved Strategic Query Selection

In [111]:
def get_strategic_queries(dataset, train_data, model=None, batch_size=200):
    if len(train_data) > 0:
        print(f"Using {len(train_data)} training samples to guide selection")
    """Strategic query selection based on uncertainty and diversity"""
    # If we have existing data and a model, use them to guide sampling
    if model is not None and len(train_data) > 5000:
        # Create a PCA model from existing embeddings to identify high-variance directions
        existing_embeddings = torch.stack([emb for _, emb in train_data]).cpu().numpy()
        pca = PCA(n_components=min(10, existing_embeddings.shape[1]))
        pca.fit(existing_embeddings)
        
        # Get dataset candidates and score them by expected information gain
        candidate_indices = np.random.choice(len(dataset), min(10000, len(dataset)), replace=False)
        candidates = [dataset[i] for i in candidate_indices]
        
        # Use current model to predict embeddings
        with torch.no_grad():
            model.eval()
            candidate_tensors = torch.stack(candidates).to(device)
            batches = torch.split(candidate_tensors, 64)
            predictions = []
            for batch in batches:
                pred = model(batch).cpu().numpy()
                predictions.extend(pred)
            predictions = np.array(predictions)
        
        # Score candidates by uncertainty/projection onto principal components
        scores = np.abs(np.dot(predictions, pca.components_.T)).sum(axis=1)
        
        # B4B-aware selection: avoid examples that are too similar to existing ones
        # This helps prevent triggering the B4B defense mechanism
        if len(train_data) > 10000:
            existing_sample = np.array([emb.cpu().numpy() for _, emb in train_data[:5000]])
            for i, pred in enumerate(predictions):
                # Compute minimum distance to existing embeddings
                min_dist = np.min(np.linalg.norm(existing_sample - pred, axis=1))
                # Boost score for examples that are different from existing ones
                if min_dist > 0.5:  # Threshold based on B4B bucket size
                    scores[i] *= 1.5
        
        # Select most informative examples
        selected_indices = np.argsort(scores)[-batch_size:]
        return [candidates[i] for i in selected_indices]
    else:
        # If we don't have enough data yet, use farthest-first traversal for diversity
        if len(train_data) > 0:
            # Sample some candidates
            candidate_indices = np.random.choice(len(dataset), min(10000, len(dataset)), replace=False)
            candidates = [dataset[i] for i in candidate_indices]
            
            # Convert existing data to feature space
            existing_images = torch.stack([img for img, _ in train_data[:min(len(train_data), 1000)]]).cpu()
            existing_flat = existing_images.view(existing_images.size(0), -1).numpy()
            
            # Calculate distances
            candidate_tensors = torch.stack(candidates)
            candidate_flat = candidate_tensors.view(candidate_tensors.size(0), -1).numpy()
            
            # For each candidate, find distance to closest example in train_data
            min_distances = []
            for cand in candidate_flat:
                dists = np.linalg.norm(existing_flat - cand, axis=1)
                min_distances.append(np.min(dists))
                
            # Select candidates with largest minimum distance
            selected_indices = np.argsort(min_distances)[-batch_size:]
            return [candidates[i] for i in selected_indices]
        else:
            # First batch: just random
            indices = np.random.choice(len(dataset), batch_size, replace=False)
            return [dataset[i] for i in indices]

## 5. Improved Model Architecture

In [None]:
class ImprovedEncoderStealer(nn.Module):
    def __init__(self, output_dim=1024):
        super().__init__()
        
        # More powerful feature extractor with residual connections
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        
        self.conv3 = nn.Conv2d(128, 256, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        
        self.conv4 = nn.Conv2d(256, 512, 3, padding=1)
        self.bn4 = nn.BatchNorm2d(512)
        
        # Final embedding layers
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(512, 1024)
        self.fc2 = nn.Linear(1024, output_dim)
        
        # Add dropout for regularization
        self.dropout = nn.Dropout(0.3)
        
    def forward(self, x):
        # First block
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 2)
        
        # Second block
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2)
        
        # Third block - simplify the skip connection
        identity = x.clone()  # Simpler skip connection
        x = F.relu(self.bn3(self.conv3(x)))
        if identity.shape == x.shape: 
            x = x + identity
        x = F.max_pool2d(x, 2)
        
        # Fourth block
        x = F.relu(self.bn4(self.conv4(x)))
        x = F.adaptive_avg_pool2d(x, 1)
        
        # Embedding generation
        x = self.flatten(x)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        
        return x
    
    # def forward(self, x):
    #     # First block
    #     identity = x
    #     x = F.relu(self.bn1(self.conv1(x)))
    #     x = F.max_pool2d(x, 2)
        
    #     # Second block
    #     x = F.relu(self.bn2(self.conv2(x)))
    #     x = F.max_pool2d(x, 2)
        
    #     # Third block with skip connection if possible
    #     identity = F.adaptive_avg_pool2d(identity, x.shape[2:])
    #     if identity.size(1) != x.size(1):
    #         identity = F.conv2d(identity, torch.ones(x.size(1), identity.size(1), 1, 1).to(x.device), 
    #                         padding=0, groups=identity.size(1))
    #     identity = F.adaptive_avg_pool2d(identity, x.shape[2:])
    #     x = F.relu(self.bn3(self.conv3(x))) + identity
    #     x = F.max_pool2d(x, 2)
        
    #     # Fourth block
    #     x = F.relu(self.bn4(self.conv4(x)))
    #     x = F.adaptive_avg_pool2d(x, 1)
        
    #     # Embedding generation
    #     x = self.flatten(x)
    #     x = F.relu(self.fc1(x))
    #     x = self.dropout(x)
    #     x = self.fc2(x)
        
        # return x

In [113]:
class EnsembleWrapper(nn.Module):
    def __init__(self, models):
        super().__init__()
        self.models = nn.ModuleList(models)
    
    def forward(self, x):
        outputs = [model(x) for model in self.models]
        return sum(outputs) / len(outputs)

## 6. Improved Training Loop

In [114]:
def train_with_bucket_awareness(model, train_loader, bucket_edges=None, epochs=5, checkpoint_freq=1):
    if bucket_edges is not None:
        valid_edges = sum(1 for edges in bucket_edges if len(edges) > 0)
        print(f"Using {valid_edges} dimensions with detected bucket edges")
        
    model.train()
    optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2, factor=0.5)
    
    # Use Huber loss for robustness to outliers
    criterion = nn.SmoothL1Loss(beta=0.1)
    
    best_loss = float('inf')
    
    for epoch in range(epochs):
        total_loss = 0
        for images, targets in tqdm(train_loader):
            images = images.to(device)
            targets = targets.float().to(device)
            
            optimizer.zero_grad()
            outputs = model(images)
            
            # Standard loss
            loss = criterion(outputs, targets)
            
            # Add bucket-aware regularization if bucket edges are available
            if bucket_edges is not None:
                bucket_loss = 0
                for dim, edges in enumerate(bucket_edges):
                    if len(edges) > 0 and dim < outputs.shape[1]:
                        # Encourage outputs to be away from detected bucket edges
                        dim_output = outputs[:, dim]
                        for edge in edges:
                            # Penalize being too close to bucket edge
                            distance_to_edge = torch.abs(dim_output - edge)
                            bucket_loss += torch.mean(torch.exp(-5 * distance_to_edge))
                
                if bucket_loss > 0:
                    loss = loss + 0.01 * bucket_loss
            
            loss.backward()
            
            # Gradient clipping to prevent instability
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            
            optimizer.step()
            total_loss += loss.item()
        
        avg_loss = total_loss / len(train_loader)
        scheduler.step(avg_loss)
        print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}")
        
        # Save checkpoint for this epoch
        if epoch % checkpoint_freq == 0 or epoch == epochs - 1:
            if avg_loss < best_loss:
                best_loss = avg_loss
                torch.save(model.state_dict(), f'model_epoch_{epoch+1}_loss_{avg_loss:.4f}.pt')
                print(f"Saved model checkpoint at epoch {epoch+1}")

In [115]:
def train_ensemble(train_data, n_models=3, bucket_edges=None):
    models = []
    
    for i in range(n_models):
        print(f"Training model {i+1}/{n_models}...")
        model = ImprovedEncoderStealer(output_dim=1024).to(device)
        
        # Use different subsets of data for each model for diversity
        indices = np.random.choice(len(train_data), int(len(train_data) * 0.8), replace=False)
        subset_data = [train_data[idx] for idx in indices]
        
        loader = DataLoader(subset_data, batch_size=256, shuffle=True, collate_fn=collate_fn)
        train_with_bucket_awareness(model, loader, bucket_edges, epochs=3)
        models.append(model)
    
    return models

In [116]:
# Add to imports at the top
import os
import glob

# Recovery function
def load_latest_checkpoint():
    """Load the latest saved checkpoint from disk"""
    # Find the latest checkpoint
    checkpoints = glob.glob('train_data_*.pkl')
    if not checkpoints:
        print("No checkpoints found")
        return None, None
    
    # Sort by file modification time (most recent first)
    latest = max(checkpoints, key=os.path.getmtime)
    print(f"Loading checkpoint: {latest}")
    
    with open(latest, 'rb') as f:
        train_data = pickle.load(f)
    
    # Find corresponding tracker
    tracker_file = latest.replace('train_data', 'tracker')
    if os.path.exists(tracker_file):
        with open(tracker_file, 'rb') as f:
            tracker = pickle.load(f)
    else:
        print("Warning: No matching tracker file found")
        tracker = CoverageTracker()
        
    print(f"Loaded {len(train_data)} samples with {tracker.query_count} queries")
    return train_data, tracker

# Uncomment this line to resume from checkpoint
train_data, tracker = load_latest_checkpoint() or ([], CoverageTracker())

Loading checkpoint: train_data_phase1_complete.pkl
Loaded 2000 samples with 0 queries


## 7. Main Execution Flow

In [117]:
# Initialize API connection
# SEED, PORT =launch_api()
SEED = 38215912
PORT = 9944
print(SEED,PORT)
# 38215912 9944 15:27

38215912 9944


In [137]:
# Load dataset
dataset = load_stealing_dataset()
print(f"Loaded dataset with {len(dataset)} images")

# Initialize coverage tracker
tracker = CoverageTracker()

Loaded dataset with 13000 images


In [None]:
# # Phase 1: Initial data collection with strategic diversity focus
# print("Phase 1: Initial data collection with strategic diversity")

# # Use checkpoint if available, otherwise initialize
# recovered_data = load_latest_checkpoint()
# if recovered_data is not None and recovered_data[0]:
#     train_data, tracker = recovered_data
#     print(f"Resuming from checkpoint with {len(train_data)} samples")
# else:
#     train_data = []
#     tracker = CoverageTracker()

# # Track queried image IDs to avoid duplicates
# queried_ids = set()

# # More batches with smaller batch size
# n_batches = 2  # Reduced from 8 to query fewer images with more focus on quality
# batch_size = 1000  # Slightly larger batch size to get more coverage with fewer batches

# for i in range(n_batches):
#     try:
#         print(f"Processing batch {i+1}/{n_batches}")
#         # Use strategic query selection that prioritizes diversity
#         batch_images = get_strategic_queries(dataset, train_data, batch_size=batch_size)
        
#         # Decide on repetition strategy: 
#         # - First batch: use repetition for stable initialization
#         # - Later batches: prioritize coverage over noise reduction
#         if i == 0:
#             print("Using repeated queries for initial batch to establish baseline...")
#             batch_embs = query_with_repetition(batch_images, n_repeats=2, save_prefix=f"phase1_batch_{i+1}")
#         else:
#             print("Using single queries to maximize diversity...")
#             batch_embs = query_api(batch_images)
        
#         # Update train data and save immediately
#         for img, emb in zip(batch_images, batch_embs):
#             # Store image ID to track uniqueness
#             img_id = id(img)
#             if img_id not in queried_ids:
#                 queried_ids.add(img_id)
#                 train_data.append((img, torch.tensor(emb).float()))
        
#         # Save progress after each batch
#         with open(f'train_data_phase1_batch_{i+1}.pkl', 'wb') as f:
#             pickle.dump(train_data, f)
        
#         tracker.update_coverage(batch_embs)
        
#         # Save tracker state
#         with open(f'tracker_phase1_batch_{i+1}.pkl', 'wb') as f:
#             pickle.dump(tracker, f)
            
#         print(f"Batch {i+1} complete. Total samples: {len(train_data)}")
#         print(f"Unique images: {len(queried_ids)}")
#         print(f"Queries: {tracker.query_count}, Coverage: {tracker.get_coverage():.2%}")
        
#         # Early stopping if we have good coverage
#         if tracker.get_coverage() > 0.08:
#             print("Sufficient initial coverage achieved. Moving to Phase 2.")
#             break
            
#     except Exception as e:
#         print(f"Error in batch {i+1}: {str(e)}")
#         # Save what we have so far
#         if train_data:
#             with open(f'train_data_error_phase1_{i+1}.pkl', 'wb') as f:
#                 pickle.dump(train_data, f)
#             with open(f'tracker_error_phase1_{i+1}.pkl', 'wb') as f:
#                 pickle.dump(tracker, f)
#         print("Continuing to next batch...")

# # Save final Phase 1 dataset
# with open('train_data_phase1_complete.pkl', 'wb') as f:
#     pickle.dump(train_data, f)
# print(f"Phase 1 complete: {len(train_data)} samples collected using {tracker.query_count} queries")
# print(f"Unique images in dataset: {len(queried_ids)}")

Phase 1: Initial data collection with strategic diversity
Loading checkpoint: train_data_phase1_complete.pkl
Loaded 2000 samples with 0 queries
Resuming from checkpoint with 2000 samples
Processing batch 1/2
Using 2000 training samples to guide selection
Using repeated queries for initial batch to establish baseline...
Repetition 1/2 for batch of 1000 images
Request payload size: 2.55 MB
Repetition 2/2 for batch of 1000 images
Request payload size: 2.55 MB
Rate limited. Retrying after delay...
Request payload size: 2.55 MB
Batch 1 complete. Total samples: 3000
Unique images: 1000
Queries: 1000, Coverage: 1.07%
Processing batch 2/2
Using 3000 training samples to guide selection
Using single queries to maximize diversity...
Request payload size: 2.54 MB
Rate limited. Retrying after delay...
Request payload size: 2.54 MB
Batch 2 complete. Total samples: 4000
Unique images: 2000
Queries: 2000, Coverage: 1.10%
Phase 1 complete: 4000 samples collected using 2000 queries
Unique images in data

In [None]:
# Phase 1 With B4B protection
# Import random for jittered delays if not already imported
import random

# Function to apply strategic delays between API queries
def apply_strategic_delay(batch_number, phase=1, high_value=False):
    """Apply a strategic delay between API queries to avoid triggering B4B defenses"""
    # Base delay increases with batch number to simulate natural user behavior
    base_delay = 5 + (batch_number * 2)
    
    # Add randomness to avoid predictable patterns
    jitter = random.uniform(0.5, 1.5)
    
    # Phase 2 gets longer delays since model is guided and queries are more targeted
    if phase == 2:
        base_delay *= 1.5
    
    # High-value queries get additional delay to ensure quality
    if high_value:
        base_delay *= 1.2
    
    delay_time = base_delay * jitter
    print(f"Applying strategic delay of {delay_time:.2f}s before query (B4B avoidance)...")
    time.sleep(delay_time)

# Phase 1: Initial data collection with strategic diversity focus
print("Phase 1: Initial data collection with strategic diversity")

# Use checkpoint if available, otherwise initialize
recovered_data = load_latest_checkpoint()
if recovered_data is not None and recovered_data[0]:
    train_data, tracker = recovered_data
    print(f"Resuming from checkpoint with {len(train_data)} samples")
else:
    train_data = []
    tracker = CoverageTracker()

# Track queried image IDs to avoid duplicates
queried_ids = set()

# More batches with smaller batch size
n_batches = 2  # Reduced from 8 to query fewer images with more focus on quality
batch_size = 1000  # Slightly larger batch size to get more coverage with fewer batches

for i in range(n_batches):
    try:
        print(f"Processing batch {i+1}/{n_batches}")
        # Use strategic query selection that prioritizes diversity
        batch_images = get_strategic_queries(dataset, train_data, batch_size=batch_size)
        
        # Apply strategic delay before API query to avoid B4B detection
        apply_strategic_delay(i, phase=1)
        
        # Decide on repetition strategy: 
        # - First batch: use repetition for stable initialization
        # - Later batches: prioritize coverage over noise reduction
        if i == 0:
            print("Using repeated queries for initial batch to establish baseline...")
            batch_embs = query_with_repetition(batch_images, n_repeats=2, save_prefix=f"phase1_batch_{i+1}")
        else:
            print("Using single queries to maximize diversity...")
            batch_embs = query_api(batch_images)
        
        # Update train data and save immediately
        for img, emb in zip(batch_images, batch_embs):
            # Store image ID to track uniqueness
            img_id = id(img)
            if img_id not in queried_ids:
                queried_ids.add(img_id)
                train_data.append((img, torch.tensor(emb).float()))
        
        # Save progress after each batch
        with open(f'train_data_phase1_batch_{i+1}.pkl', 'wb') as f:
            pickle.dump(train_data, f)
        
        tracker.update_coverage(batch_embs)
        
        # Save tracker state
        with open(f'tracker_phase1_batch_{i+1}.pkl', 'wb') as f:
            pickle.dump(tracker, f)
            
        print(f"Batch {i+1} complete. Total samples: {len(train_data)}")
        print(f"Unique images: {len(queried_ids)}")
        print(f"Queries: {tracker.query_count}, Coverage: {tracker.get_coverage():.2%}")
        
        # Early stopping if we have good coverage
        if tracker.get_coverage() > 0.08:
            print("Sufficient initial coverage achieved. Moving to Phase 2.")
            break
            
    except Exception as e:
        print(f"Error in batch {i+1}: {str(e)}")
        # Save what we have so far
        if train_data:
            with open(f'train_data_error_phase1_{i+1}.pkl', 'wb') as f:
                pickle.dump(train_data, f)
            with open(f'tracker_error_phase1_{i+1}.pkl', 'wb') as f:
                pickle.dump(tracker, f)
        print("Continuing to next batch...")

# Save final Phase 1 dataset
with open('train_data_phase1_complete.pkl', 'wb') as f:
    pickle.dump(train_data, f)
print(f"Phase 1 complete: {len(train_data)} samples collected using {tracker.query_count} queries")
print(f"Unique images in dataset: {len(queried_ids)}")

In [120]:
# Add this to check embedding diversity across your dataset
def check_embedding_diversity(train_data, sample_size=100):
    if len(train_data) < sample_size:
        sample_size = len(train_data)
    
    # Sample random indices
    indices = np.random.choice(len(train_data), sample_size, replace=False)
    
    # Collect embeddings
    embeddings = [train_data[i][1].cpu().numpy() for i in indices]
    
    # Check uniqueness
    unique_embeddings = set()
    for emb in embeddings:
        # Use hash of rounded values to check approximate uniqueness
        emb_hash = hash(tuple(np.round(emb[:10], 3)))
        unique_embeddings.add(emb_hash)
    
    print(f"Checked {sample_size} samples, found {len(unique_embeddings)} unique embedding patterns")
    return len(unique_embeddings) / sample_size

# Call this before training
diversity_ratio = check_embedding_diversity(train_data)
if diversity_ratio < 0.5:
    print("WARNING: Low embedding diversity detected. Model training may be ineffective.")
    print("Consider gathering more diverse examples before proceeding.")
    
# Modify training parameters based on diversity
initial_loader = DataLoader(
    train_data, 
    batch_size=64,  # Smaller batch size for stability
    shuffle=True, 
    collate_fn=collate_fn
)

# Use simpler loss function if diversity is low
criterion = nn.MSELoss() if diversity_ratio < 0.3 else nn.SmoothL1Loss(beta=0.1)

Checked 100 samples, found 100 unique embedding patterns


In [121]:
# Replace with this more robust implementation
device = torch.device("cpu")  # Consistently use CPU for stability

model = ImprovedEncoderStealer(output_dim=1024).to(device)
criterion = nn.SmoothL1Loss(beta=0.1)  # Use this since diversity is high
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)

# Add proper progress tracking
total_batches = len(initial_loader)
print(f"Starting training with {len(train_data)} samples ({total_batches} batches)")

# Basic training loop with more information
for epoch in range(3):
    running_loss = 0.0
    batch_count = 0
    for images, targets in initial_loader:
        images = images.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, targets)
        loss.backward()
        # Add gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        
        # Better progress reporting
        running_loss += loss.item()
        batch_count += 1
        if batch_count % 10 == 0:
            print(f"Epoch {epoch+1}, Batch {batch_count}/{total_batches}, Loss: {running_loss/batch_count:.4f}")
    
    # Save checkpoint after each epoch
    avg_loss = running_loss / batch_count
    print(f"Epoch {epoch+1} complete, Average loss: {avg_loss:.4f}")
    torch.save(model.state_dict(), f'model_epoch_{epoch+1}.pt')

Starting training with 4000 samples (63 batches)
Epoch 1, Batch 10/63, Loss: 0.3057
Epoch 1, Batch 20/63, Loss: 0.2004
Epoch 1, Batch 30/63, Loss: 0.1546
Epoch 1, Batch 40/63, Loss: 0.1282
Epoch 1, Batch 50/63, Loss: 0.1124
Epoch 1, Batch 60/63, Loss: 0.1009
Epoch 1 complete, Average loss: 0.0980
Epoch 2, Batch 10/63, Loss: 0.0363
Epoch 2, Batch 20/63, Loss: 0.0350
Epoch 2, Batch 30/63, Loss: 0.0363
Epoch 2, Batch 40/63, Loss: 0.0357
Epoch 2, Batch 50/63, Loss: 0.0359
Epoch 2, Batch 60/63, Loss: 0.0358
Epoch 2 complete, Average loss: 0.0355
Epoch 3, Batch 10/63, Loss: 0.0320
Epoch 3, Batch 20/63, Loss: 0.0329
Epoch 3, Batch 30/63, Loss: 0.0318
Epoch 3, Batch 40/63, Loss: 0.0314
Epoch 3, Batch 50/63, Loss: 0.0307
Epoch 3, Batch 60/63, Loss: 0.0304
Epoch 3 complete, Average loss: 0.0305


In [57]:
train_data

[]

In [None]:
# # Phase 2: Strategic queries with model guidance and dataset limits
# print("\nPhase 2: Strategic query expansion with dataset awareness")

# # Configure Phase 2 parameters
# max_queries = min(95000, tracker.query_count + 50000)  # Keep some quota for final phases
# batch_size = 1000  # Smaller batch size to reduce timeout risk
# save_frequency = 500  # Save more frequently
# max_unique_images = min(9000, len(dataset) * 0.7)  # Target 70% of dataset

# # Keep track of unique images we've queried
# if 'queried_ids' not in locals():
#     queried_ids = set([id(img) for img, _ in train_data])
#     print(f"Identified {len(queried_ids)} unique images in current dataset")

# # Detect initial bucket boundaries for improved training
# if tracker.query_count > 2000:  # Only if we have enough data
#     bucket_edges = tracker.get_embedding_stats()
#     print("Bucket edges detected for improved query selection")

# # Define a tracker for high-quality samples that might need repetition
# high_value_queries = 0
# max_high_value_queries = 10000  # Limit for high-value repeated queries

# print(f"Target unique images: {max_unique_images}, Current unique images: {len(queried_ids)}")

# while (tracker.query_count < max_queries and 
#        tracker.get_coverage() < 0.25 and 
#        len(queried_ids) < max_unique_images):
#     try:
#         # Use the model to guide query selection
#         batch_images = get_strategic_queries(dataset, train_data, model, batch_size=batch_size)
        
#         # Filter out any images we've already queried
#         new_batch_images = []
#         for img in batch_images:
#             img_id = id(img)
#             if img_id not in queried_ids:
#                 new_batch_images.append(img)
#                 queried_ids.add(img_id)
        
#         if len(new_batch_images) == 0:
#             print("All candidates already queried. Increasing sampling pool...")
#             candidate_indices = np.random.choice(len(dataset), min(5000, len(dataset)), replace=False)
#             candidates = [dataset[i] for i in candidate_indices]
#             new_batch_images = [img for img in candidates if id(img) not in queried_ids][:batch_size]
            
#             if len(new_batch_images) == 0:
#                 print("Dataset exhausted. Moving to final training phase.")
#                 break
        
#         # Decide on query strategy based on PCA analysis of current embeddings
#         use_repetition = False
#         if high_value_queries < max_high_value_queries:
#             # Check if this batch might contain valuable samples based on model uncertainty
#             if model is not None and len(train_data) > 2000:
#                 with torch.no_grad():
#                     model.eval()
#                     batch_tensor = torch.stack(new_batch_images).to(device)
#                     predictions = model(batch_tensor).cpu().numpy()
                    
#                     # Get embeddings from existing data for comparison
#                     existing_embs = np.array([emb.cpu().numpy() for _, emb in train_data[:2000]])
                    
#                     # Check if predictions are in underrepresented regions
#                     distances = []
#                     for pred in predictions:
#                         min_dist = np.min(np.linalg.norm(existing_embs - pred, axis=1))
#                         distances.append(min_dist)
                    
#                     # If average distance is high, these samples could be valuable
#                     if np.mean(distances) > 0.8:  # High threshold for repetition
#                         use_repetition = True
#                         high_value_queries += len(new_batch_images)
        
#         # Query with or without repetition based on analysis
#         if use_repetition:
#             print(f"Using repeated queries for high-value samples")
#             batch_embs = query_with_repetition(
#                 new_batch_images, 
#                 n_repeats=2, 
#                 save_prefix=f"phase2_batch_hv_{len(train_data)}"
#             )
#         else:
#             print(f"Using single queries to maximize coverage")
#             batch_embs = query_api(new_batch_images)
        
#         # Update training data
#         for img, emb in zip(new_batch_images, batch_embs):
#             train_data.append((img, torch.tensor(emb).float()))
        
#         # Save checkpoint periodically
#         if len(train_data) % save_frequency == 0:
#             with open(f'train_data_phase2_{len(train_data)}.pkl', 'wb') as f:
#                 pickle.dump(train_data, f)
#             with open(f'tracker_phase2_{tracker.query_count}.pkl', 'wb') as f:
#                 pickle.dump(tracker, f)
#             print(f"Saved checkpoint with {len(train_data)} samples")
        
#         tracker.update_coverage(batch_embs)
        
#         # Refresh bucket boundaries periodically
#         if len(train_data) % 5000 == 0:
#             bucket_edges = tracker.get_embedding_stats()
        
#         # Retrain more frequently with smaller increments
#         if len(train_data) % 3000 == 0:  # More frequent retraining
#             print(f"\nRetraining with {len(train_data)} samples...")
#             loader = DataLoader(train_data, batch_size=256, shuffle=True, collate_fn=collate_fn)
#             model = ImprovedEncoderStealer(output_dim=1024).to(device)
#             train_with_bucket_awareness(model, loader, bucket_edges, epochs=3)
            
#             # Save this intermediate model
#             torch.save(model.state_dict(), f'model_phase2_{len(train_data)}.pt')
        
#         print(f"Queries: {tracker.query_count}, Coverage: {tracker.get_coverage():.2%}")
#         print(f"Unique images: {len(queried_ids)}/{max_unique_images}")
        
#     except Exception as e:
#         print(f"Error during querying: {e}")
#         print("Saving current progress and continuing...")
#         with open(f'train_data_error_phase2_{len(train_data)}.pkl', 'wb') as f:
#             pickle.dump(train_data, f)
#         with open(f'tracker_error_phase2_{tracker.query_count}.pkl', 'wb') as f:
#             pickle.dump(tracker, f)
#         time.sleep(30)  # Wait before retrying

# # Save final Phase 2 dataset
# with open('train_data_phase2_complete.pkl', 'wb') as f:
#     pickle.dump(train_data, f)
# print(f"Phase 2 complete: {len(train_data)} samples collected using {tracker.query_count} queries")
# print(f"Final unique images queried: {len(queried_ids)}")


Phase 2: Strategic query expansion with dataset awareness
Target unique images: 9000, Current unique images: 2000
Using 4000 training samples to guide selection
Using repeated queries for high-value samples
Repetition 1/2 for batch of 1000 images
Request payload size: 2.54 MB
Repetition 2/2 for batch of 1000 images
Request payload size: 2.54 MB
Rate limited. Retrying after delay...
Request payload size: 2.54 MB
Saved checkpoint with 5000 samples
Queries: 3000, Coverage: 1.10%
Unique images: 3000/9000
Using 5000 training samples to guide selection
Using repeated queries for high-value samples
Repetition 1/2 for batch of 1000 images
Request payload size: 2.55 MB
Rate limited. Retrying after delay...
Request payload size: 2.55 MB
Repetition 2/2 for batch of 1000 images
Request payload size: 2.55 MB
Rate limited. Retrying after delay...
Request payload size: 2.55 MB
Saved checkpoint with 6000 samples

Retraining with 6000 samples...
Using 10 dimensions with detected bucket edges


100%|██████████| 24/24 [00:13<00:00,  1.75it/s]


Epoch 1, Loss: 0.8005
Saved model checkpoint at epoch 1


100%|██████████| 24/24 [00:13<00:00,  1.74it/s]


Epoch 2, Loss: 0.2294
Saved model checkpoint at epoch 2


100%|██████████| 24/24 [00:13<00:00,  1.73it/s]


Epoch 3, Loss: 0.1361
Saved model checkpoint at epoch 3
Queries: 4000, Coverage: 1.10%
Unique images: 4000/9000
Using 6000 training samples to guide selection
Using repeated queries for high-value samples
Repetition 1/2 for batch of 1000 images
Request payload size: 2.70 MB
Repetition 2/2 for batch of 1000 images
Request payload size: 2.70 MB
Rate limited. Retrying after delay...
Request payload size: 2.70 MB
Error during querying: ('Connection aborted.', TimeoutError('timed out'))
Saving current progress and continuing...
Using 6000 training samples to guide selection
Using repeated queries for high-value samples
Repetition 1/2 for batch of 1000 images
Request payload size: 2.70 MB
Error during querying: ('Connection aborted.', TimeoutError('timed out'))
Saving current progress and continuing...
Using 6000 training samples to guide selection
Using repeated queries for high-value samples
Repetition 1/2 for batch of 1000 images
Request payload size: 2.69 MB
Repetition 2/2 for batch of 1

In [None]:
# Phase 2 with B4B protection: Strategic queries with model guidance and dataset limits
print("\nPhase 2: Strategic query expansion with dataset awareness")

# Configure Phase 2 parameters
max_queries = min(95000, tracker.query_count + 50000)  # Keep some quota for final phases
batch_size = 1000  # Smaller batch size to reduce timeout risk
save_frequency = 500  # Save more frequently
max_unique_images = min(9000, len(dataset) * 0.7)  # Target 70% of dataset

# Keep track of unique images we've queried
if 'queried_ids' not in locals():
    queried_ids = set([id(img) for img, _ in train_data])
    print(f"Identified {len(queried_ids)} unique images in current dataset")

# Detect initial bucket boundaries for improved training
if tracker.query_count > 2000:  # Only if we have enough data
    print("Detecting bucket boundaries - adding longer delay to avoid B4B...")
    time.sleep(30)  # Longer cool-down period before sensitive operations
    bucket_edges = tracker.get_embedding_stats()
    print("Bucket edges detected for improved query selection")

# Define a tracker for high-quality samples that might need repetition
high_value_queries = 0
max_high_value_queries = 10000  # Limit for high-value repeated queries
batch_counter = 0  # To track batch number for strategic delays

print(f"Target unique images: {max_unique_images}, Current unique images: {len(queried_ids)}")

while (tracker.query_count < max_queries and 
       tracker.get_coverage() < 0.25 and 
       len(queried_ids) < max_unique_images):
    try:
        # Use the model to guide query selection
        batch_images = get_strategic_queries(dataset, train_data, model, batch_size=batch_size)
        
        # Filter out any images we've already queried
        new_batch_images = []
        for img in batch_images:
            img_id = id(img)
            if img_id not in queried_ids:
                new_batch_images.append(img)
                queried_ids.add(img_id)
        
        if len(new_batch_images) == 0:
            print("All candidates already queried. Increasing sampling pool...")
            candidate_indices = np.random.choice(len(dataset), min(5000, len(dataset)), replace=False)
            candidates = [dataset[i] for i in candidate_indices]
            new_batch_images = [img for img in candidates if id(img) not in queried_ids][:batch_size]
            
            if len(new_batch_images) == 0:
                print("Dataset exhausted. Moving to final training phase.")
                break
        
        # Decide on query strategy based on PCA analysis of current embeddings
        use_repetition = False
        if high_value_queries < max_high_value_queries:
            # Check if this batch might contain valuable samples based on model uncertainty
            if model is not None and len(train_data) > 2000:
                with torch.no_grad():
                    model.eval()
                    batch_tensor = torch.stack(new_batch_images).to(device)
                    predictions = model(batch_tensor).cpu().numpy()
                    
                    # Get embeddings from existing data for comparison
                    existing_embs = np.array([emb.cpu().numpy() for _, emb in train_data[:2000]])
                    
                    # Check if predictions are in underrepresented regions
                    distances = []
                    for pred in predictions:
                        min_dist = np.min(np.linalg.norm(existing_embs - pred, axis=1))
                        distances.append(min_dist)
                    
                    # If average distance is high, these samples could be valuable
                    if np.mean(distances) > 0.8:  # High threshold for repetition
                        use_repetition = True
                        high_value_queries += len(new_batch_images)
        
        # Apply strategic delay before API query, longer for high-value queries
        apply_strategic_delay(batch_counter, phase=2, high_value=use_repetition)
        batch_counter += 1
        
        # Query with or without repetition based on analysis
        if use_repetition:
            print(f"Using repeated queries for high-value samples")
            batch_embs = query_with_repetition(
                new_batch_images, 
                n_repeats=2, 
                save_prefix=f"phase2_batch_hv_{len(train_data)}"
            )
        else:
            print(f"Using single queries to maximize coverage")
            batch_embs = query_api(new_batch_images)
        
        # Update training data
        for img, emb in zip(new_batch_images, batch_embs):
            train_data.append((img, torch.tensor(emb).float()))
        
        # Save checkpoint periodically
        if len(train_data) % save_frequency == 0:
            with open(f'train_data_phase2_{len(train_data)}.pkl', 'wb') as f:
                pickle.dump(train_data, f)
            with open(f'tracker_phase2_{tracker.query_count}.pkl', 'wb') as f:
                pickle.dump(tracker, f)
            print(f"Saved checkpoint with {len(train_data)} samples")
        
        tracker.update_coverage(batch_embs)
        
        # Refresh bucket boundaries periodically with a longer delay
        if len(train_data) % 5000 == 0:
            print("Refreshing bucket boundaries - adding longer delay to avoid B4B...")
            time.sleep(30)  # Longer cool-down period before sensitive operations
            bucket_edges = tracker.get_embedding_stats()
        
        # Retrain more frequently with smaller increments
        if len(train_data) % 3000 == 0:  # More frequent retraining
            print(f"\nRetraining with {len(train_data)} samples...")
            loader = DataLoader(train_data, batch_size=256, shuffle=True, collate_fn=collate_fn)
            model = ImprovedEncoderStealer(output_dim=1024).to(device)
            train_with_bucket_awareness(model, loader, bucket_edges, epochs=3)
            
            # Save this intermediate model
            torch.save(model.state_dict(), f'model_phase2_{len(train_data)}.pt')
        
        print(f"Queries: {tracker.query_count}, Coverage: {tracker.get_coverage():.2%}")
        print(f"Unique images: {len(queried_ids)}/{max_unique_images}")
        
        # Add a variable cool-down period between batches to avoid patterns
        cool_down = random.uniform(5, 15)
        print(f"Cooling down for {cool_down:.2f}s to avoid B4B detection...")
        time.sleep(cool_down)
        
    except Exception as e:
        print(f"Error during querying: {e}")
        print("Saving current progress and continuing...")
        with open(f'train_data_error_phase2_{len(train_data)}.pkl', 'wb') as f:
            pickle.dump(train_data, f)
        with open(f'tracker_error_phase2_{tracker.query_count}.pkl', 'wb') as f:
            pickle.dump(tracker, f)
        time.sleep(30)  # Wait before retrying

# Save final Phase 2 dataset
with open('train_data_phase2_complete.pkl', 'wb') as f:
    pickle.dump(train_data, f)
print(f"Phase 2 complete: {len(train_data)} samples collected using {tracker.query_count} queries")
print(f"Final unique images queried: {len(queried_ids)}")

In [123]:
# Final ensemble training
print("\nTraining ensemble models...")

# Get bucket boundaries for final training
bucket_edges = tracker.get_embedding_stats()
print(f"Using {len([e for e in bucket_edges if len(e) > 0])} dimensions with detected bucket edges")

# Save full dataset before training
with open('train_data_final.pkl', 'wb') as f:
    pickle.dump(train_data, f)
print(f"Saved final dataset with {len(train_data)} samples")

# Train ensemble models (no changes to the function itself)
models = train_ensemble(train_data, n_models=3, bucket_edges=bucket_edges)

# Save individual models
for i, model in enumerate(models):
    torch.save(model.state_dict(), f'ensemble_model_{i+1}.pt')
    print(f"Saved ensemble model {i+1}")

# Create and save wrapper model
final_model = EnsembleWrapper(models).to(device)
torch.save(final_model.state_dict(), 'final_ensemble_model.pt')
print("Saved final ensemble model")


Training ensemble models...
Using 10 dimensions with detected bucket edges
Saved final dataset with 9616 samples
Training model 1/3...
Using 10 dimensions with detected bucket edges


100%|██████████| 31/31 [00:20<00:00,  1.51it/s]


Epoch 1, Loss: 0.9342
Saved model checkpoint at epoch 1


100%|██████████| 31/31 [00:19<00:00,  1.59it/s]


Epoch 2, Loss: 0.2108
Saved model checkpoint at epoch 2


100%|██████████| 31/31 [00:20<00:00,  1.48it/s]


Epoch 3, Loss: 0.1448
Saved model checkpoint at epoch 3
Training model 2/3...
Using 10 dimensions with detected bucket edges


100%|██████████| 31/31 [00:18<00:00,  1.68it/s]


Epoch 1, Loss: 0.7840
Saved model checkpoint at epoch 1


100%|██████████| 31/31 [00:20<00:00,  1.52it/s]


Epoch 2, Loss: 0.1832
Saved model checkpoint at epoch 2


100%|██████████| 31/31 [00:21<00:00,  1.46it/s]


Epoch 3, Loss: 0.1214
Saved model checkpoint at epoch 3
Training model 3/3...
Using 10 dimensions with detected bucket edges


100%|██████████| 31/31 [00:19<00:00,  1.60it/s]


Epoch 1, Loss: 0.7764
Saved model checkpoint at epoch 1


100%|██████████| 31/31 [00:21<00:00,  1.47it/s]


Epoch 2, Loss: 0.1796
Saved model checkpoint at epoch 2


100%|██████████| 31/31 [00:19<00:00,  1.55it/s]


Epoch 3, Loss: 0.1203
Saved model checkpoint at epoch 3
Saved ensemble model 1
Saved ensemble model 2
Saved ensemble model 3
Saved final ensemble model


## 8. Model Validation and Submission

In [134]:
def validate_onnx(model_path):
    try:
        session = ort.InferenceSession(model_path)
        # Check input name matches 'x'
        input_name = session.get_inputs()[0].name
        assert input_name == "x", f"Input name should be 'x', got {input_name}"
        
        test_input = np.random.randn(1, 3, 32, 32).astype(np.float32)
        output = session.run(None, {"x": test_input})[0]  # Note using "x" here
        
        assert output.shape == (1, 1024)
        print("ONNX validation passed! Model meets submission requirements")
    except Exception as e:
        raise Exception(f"ONNX validation failed: {str(e)}")
        
def submit_model(model_path):
    url = "http://34.122.51.94:9090/stealing"
    with open(model_path, "rb") as f:
        files = {"file": f}
        headers = {"token": TOKEN, "seed": str(SEED)}
        response = requests.post(url, files=files, headers=headers)
    
    if response.status_code == 200:
        print("Submission successful!")
        print(response.json())
    else:
        print(f"Submission failed: {response.status_code}")
        print(response.text)

In [125]:
# Verify model output shape
print("\nPreparing for submission...")
test_input = torch.randn(1, 3, 32, 32).to(device)
with torch.no_grad():
    test_output = final_model(test_input)
print(f"Model test - Input shape: {test_input.shape}, Output shape: {test_output.shape}")


Preparing for submission...
Model test - Input shape: torch.Size([1, 3, 32, 32]), Output shape: torch.Size([1, 1024])


In [127]:
# Export with correct input name
dummy_input = torch.randn(1, 3, 32, 32).to(device)
torch.onnx.export(
    final_model,
    dummy_input,
    "stolen_model_improved.onnx",
    input_names=["x"],  # Must be "x" to match server expectations
    output_names=["output"],
    dynamic_axes={
        'x': {0: 'batch_size'},
        'output': {0: 'batch_size'}
    },
    verbose=True
)

  if identity.shape == x.shape:  # Only add if shapes match


Exported graph: graph(%x : Float(*, 3, 32, 32, strides=[3072, 1024, 32, 1], requires_grad=0, device=cpu),
      %models.0.fc1.weight : Float(1024, 512, strides=[512, 1], requires_grad=1, device=cpu),
      %models.0.fc1.bias : Float(1024, strides=[1], requires_grad=1, device=cpu),
      %models.0.fc2.weight : Float(1024, 1024, strides=[1024, 1], requires_grad=1, device=cpu),
      %models.0.fc2.bias : Float(1024, strides=[1], requires_grad=1, device=cpu),
      %models.1.fc1.weight : Float(1024, 512, strides=[512, 1], requires_grad=1, device=cpu),
      %models.1.fc1.bias : Float(1024, strides=[1], requires_grad=1, device=cpu),
      %models.1.fc2.weight : Float(1024, 1024, strides=[1024, 1], requires_grad=1, device=cpu),
      %models.1.fc2.bias : Float(1024, strides=[1], requires_grad=1, device=cpu),
      %models.2.fc1.weight : Float(1024, 512, strides=[512, 1], requires_grad=1, device=cpu),
      %models.2.fc1.bias : Float(1024, strides=[1], requires_grad=1, device=cpu),
      %mod

In [128]:
# Validate ONNX
validate_onnx("stolen_model_improved.onnx")

ONNX validation passed! Model meets submission requirements


In [138]:
# Submit model
submit_model("stolen_model_improved.onnx")

Submission successful!
{'L2': 8.897150039672852}


In [132]:
type(SEED)

int

In [133]:
type(PORT)

int