# ConvNext SOTA Model - Setup & Baseline

**Project**: Probing ConvNext with Hard Examples  
**Goal**: Load ConvNext-Base, establish baseline accuracy, and prepare for hard example generation  
**Week**: 1 (Setup)


## 1. Environment Setup

Import all required libraries and verify versions.

In [None]:
# Core libraries
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import json
import warnings
warnings.filterwarnings('ignore')

# Timm for ConvNext
import timm

# Image processing
from PIL import Image
from torchvision import transforms
import cv2

# Progress tracking
from tqdm import tqdm

# Environment report: library versions, CUDA availability, and the device
# that will be selected. The lines are collected first and emitted together.
env_report = [
    "=" * 60,
    "ENVIRONMENT VERIFICATION",
    "=" * 60,
    f"PyTorch Version: {torch.__version__}",
    f"CUDA Available: {torch.cuda.is_available()}",
]
# The GPU name is only reported when a CUDA device is present.
if torch.cuda.is_available():
    env_report.append(f"GPU: {torch.cuda.get_device_name(0)}")
env_report += [
    f"Torchvision Version: {torchvision.__version__}",
    f"Timm Version: {timm.__version__}",
    f"NumPy Version: {np.__version__}",
    f"Device: {torch.device('cuda' if torch.cuda.is_available() else 'cpu')}",
    "=" * 60,
]
print("\n".join(env_report))

## 2. Load ConvNext-Base Model

Load ConvNext-Base with ImageNet-1K pretrained weights using timm.

In [None]:
# Model configuration: architecture name, compute device, input resolution.
MODEL_NAME = "convnext_base"
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
IMAGE_SIZE = 224

print(f"Loading {MODEL_NAME}...")
# Build the pretrained 1000-class network, move it to DEVICE and put it in
# evaluation mode (Module.to and Module.eval both return the module itself).
model = (
    timm.create_model(MODEL_NAME, pretrained=True, num_classes=1000)
    .to(DEVICE)
    .eval()
)

# Summary: name, total and trainable parameter counts, and device.
for summary_line in (
    f"\nModel: {MODEL_NAME}",
    f"Parameters: {sum(p.numel() for p in model.parameters()):,}",
    f"Trainable: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}",
    f"Device: {DEVICE}",
):
    print(summary_line)

## 3. Prepare Data Transforms

Set up standard ImageNet preprocessing for ConvNext.

In [None]:
# ImageNet normalization values
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Fraction of the resized image kept by the center crop. 0.875 is the usual
# ImageNet evaluation setting (resize shorter side to 256, crop 224).
# NOTE(review): confirm against the checkpoint's own timm data config
# (timm.data.resolve_data_config) before reporting baseline accuracy.
CROP_PCT = 0.875

# Preprocessing pipeline: scale the shorter side while preserving the aspect
# ratio, then center-crop to the model input size. Resizing directly to a
# square would distort every non-square image.
preprocess = transforms.Compose([
    transforms.Resize(
        int(IMAGE_SIZE / CROP_PCT),
        interpolation=transforms.InterpolationMode.BICUBIC,
    ),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Inverse transform for visualization: undo the normalization, then convert
# the (3, H, W) tensor to a PIL image.
inv_normalize = transforms.Compose([
    transforms.Normalize(
        mean=[-m / s for m, s in zip(IMAGENET_MEAN, IMAGENET_STD)],
        std=[1 / s for s in IMAGENET_STD],
    ),
    # Clamp to the valid pixel range so out-of-range values (for example from
    # perturbed inputs) do not wrap around in the 8-bit conversion.
    transforms.Lambda(lambda t: t.clamp(0.0, 1.0)),
    transforms.ToPILImage(),
])

print("✓ Transform pipeline configured")
print(f"  - Input size: {IMAGE_SIZE}x{IMAGE_SIZE}")
print("  - ImageNet normalization applied")

## 4. Test Model Inference

Run a simple inference smoke test on a random dummy input (no real image is used here).

In [None]:
# Smoke-test inference with a random dummy input (not a real image), created
# directly on the target device.
test_input = torch.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE, device=DEVICE)

with torch.no_grad():
    output = model(test_input)
    probabilities = torch.softmax(output, dim=1)

# Only report success after checking the output: one row of class scores per
# input image, and no NaN/inf values.
if output.dim() != 2 or output.shape[0] != test_input.shape[0]:
    raise RuntimeError(f"Unexpected model output shape: {tuple(output.shape)}")
if not torch.isfinite(output).all():
    raise RuntimeError("Model produced non-finite outputs for the dummy input")

print("✓ Model inference test passed")
print(f"  - Input shape: {test_input.shape}")
print(f"  - Output shape: {output.shape}")
print(f"  - Output classes: {output.shape[1]}")

# Top-5 scores for the dummy input. The values are meaningless for random
# noise; this only shows that softmax/top-k run on the output.
top5_prob, top5_idx = torch.topk(probabilities, 5)
print(f"\n  - Top-5 prediction scores: {top5_prob[0].cpu().numpy()}")

## 5. Load ImageNet Labels

Load ImageNet class labels for result interpretation.

In [None]:
# Download ImageNet labels if not available
import urllib.request

LABELS_URL = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
LABELS_FILE = "../data/imagenet_classes.txt"


def _read_labels(path):
    """Return the whitespace-stripped lines of the label file at *path*."""
    # Explicit encoding so the result does not depend on the platform default.
    with open(path, "r", encoding="utf-8") as f:
        return [line.strip() for line in f]


# Create the data directory (and any missing parents) for the label file.
Path(LABELS_FILE).parent.mkdir(parents=True, exist_ok=True)

try:
    # Prefer a previously downloaded copy.
    imagenet_labels = _read_labels(LABELS_FILE)
except FileNotFoundError:
    print("Downloading ImageNet labels...")
    # Download to a temporary name and rename on success, so an interrupted
    # download never leaves a truncated file that a later run would load.
    partial_file = Path(LABELS_FILE + ".part")
    urllib.request.urlretrieve(LABELS_URL, partial_file)
    partial_file.replace(LABELS_FILE)
    imagenet_labels = _read_labels(LABELS_FILE)
    print(f"✓ Downloaded ImageNet labels ({len(imagenet_labels)} classes)")
else:
    print(f"✓ Loaded ImageNet labels from file ({len(imagenet_labels)} classes)")

print("\n  Sample labels:")
# Slice instead of indexing 0..4 so a short label list cannot raise IndexError.
for i, label in enumerate(imagenet_labels[:5]):
    print(f"    {i}: {label}")

## 6. Helper Functions

Define utility functions for model evaluation and visualization.

In [None]:
def get_prediction(image_tensor, top_k=5):
    """Return the model's top-k predictions for a single image.

    Uses the module-level ``model``, ``DEVICE`` and ``imagenet_labels``.

    Args:
        image_tensor: Preprocessed image tensor, either unbatched (3, H, W)
            or a batch of exactly one image (1, 3, H, W).
        top_k: Number of top predictions to return. Capped at the number of
            classes the model outputs.

    Returns:
        list[dict]: One dict per prediction, in the order returned by
        ``torch.topk`` (highest probability first), with keys ``'class_id'``
        (int), ``'class_name'`` (str) and ``'confidence'`` (float).

    Raises:
        ValueError: If ``image_tensor`` is neither a 3-D tensor nor a 4-D
            tensor with a leading batch dimension of 1.
    """
    # Add the batch dimension only when it is missing; unconditionally
    # unsqueezing an already-batched input would produce a 5-D tensor.
    if image_tensor.dim() == 3:
        batch = image_tensor.unsqueeze(0)
    elif image_tensor.dim() == 4 and image_tensor.shape[0] == 1:
        batch = image_tensor
    else:
        raise ValueError(
            "image_tensor must have shape (3, H, W) or (1, 3, H, W), "
            f"got {tuple(image_tensor.shape)}"
        )

    with torch.no_grad():
        output = model(batch.to(DEVICE))
        probabilities = torch.softmax(output, dim=1)
        # torch.topk raises if k exceeds the number of classes.
        k = min(top_k, probabilities.shape[1])
        top_probs, top_indices = torch.topk(probabilities, k)

    # tolist() yields plain Python floats/ints, safe for list indexing and JSON.
    return [
        {
            'class_id': idx,
            'class_name': imagenet_labels[idx],
            'confidence': prob,
        }
        for prob, idx in zip(top_probs[0].tolist(), top_indices[0].tolist())
    ]

def visualize_predictions(image_tensor, predictions, title="Prediction"):
    """Plot an image next to a horizontal bar chart of its predictions.

    Uses the module-level ``inv_normalize`` transform to turn the normalized
    tensor back into a displayable PIL image.

    Args:
        image_tensor: Preprocessed (normalized) image tensor, either
            (3, H, W) or a batch of one (1, 3, H, W).
        predictions: Output from get_prediction(): a list of dicts with
            'class_name' and 'confidence' keys.
        title: Title shown above the image panel.

    Returns:
        matplotlib.figure.Figure: The figure containing both panels.
    """
    # ToPILImage cannot handle a 4-D tensor, so drop a batch dimension of one.
    image = image_tensor.cpu()
    if image.dim() == 4 and image.shape[0] == 1:
        image = image.squeeze(0)
    pil_image = inv_normalize(image)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

    # Left panel: the image itself.
    ax1.imshow(pil_image)
    ax1.set_title(title)
    ax1.axis('off')

    # Right panel: one bar per prediction.
    class_names = [p['class_name'] for p in predictions]
    confidences = [p['confidence'] for p in predictions]
    positions = range(len(predictions))

    ax2.barh(positions, confidences)
    ax2.set_yticks(positions)
    ax2.set_yticklabels(class_names)
    # barh draws position 0 at the bottom; flip so the first (best)
    # prediction appears at the top.
    ax2.invert_yaxis()
    ax2.set_xlabel('Confidence')
    # Reflect the actual number of predictions rather than assuming five.
    ax2.set_title(f'Top-{len(predictions)} Predictions')
    ax2.set_xlim([0, 1])

    plt.tight_layout()
    return fig

# Announce the helpers defined in this cell. The check mark replaces a
# mis-encoded byte sequence that was printed as garbage.
print("✓ Helper functions defined:")
print("  - get_prediction(): Get model predictions")
print("  - visualize_predictions(): Visualize predictions")

## 7. Project Status

Summary of setup and next steps.

In [None]:
# Static end-of-notebook summary: what this notebook set up and what is
# planned next. The symbols below replace mis-encoded byte sequences that
# were printed as garbage.
print("\n" + "="*70)
print("PROJECT SETUP COMPLETE")
print("="*70)
print("\n✓ COMPLETED:")
print("  1. Loaded ConvNext-Base (ImageNet pretrained)")
print("  2. Configured preprocessing pipeline")
print("  3. Verified model inference")
print("  4. Loaded ImageNet class labels")
print("  5. Created helper functions")
print("\n⏭️  NEXT (Week 2):")
print("  1. Load ImageNet validation set subset")
print("  2. Establish baseline accuracy metrics")
print("  3. Implement FGSM adversarial attack")
print("  4. Generate initial hard examples")
print("  5. Document failure patterns")
print("\n💡 ATTACK STRATEGIES TO IMPLEMENT:")
print("  - Adversarial: FGSM, PGD, Auto-Attack, C&W")
print("  - OOD Detection: Distribution shifts, synthetic images")
print("  - Corner Cases: Texture-only, minimal objects, extreme lighting")
print("  - Domain Adaptation: Style transfer, cross-dataset mismatch")
print("  - Edge Cases: Multi-object, fine-grained, similar classes")
print("\n" + "="*70)