# AMSDN: Adaptive Multi-Scale Defense Network
## Complete Training and Evaluation Pipeline

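This notebook runs the complete AMSDN adversarial-defense pipeline: pretraining, adversarial training, fine-tuning, evaluation, certification, and visualization. The model and training code live in the repository's `.py` files, which must be available in the working directory (see Section 1).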

**Runtime:** ~3-4 hours for full training (use fewer epochs for a quick demo)

**GPU Required:** T4 or better

## 1. Setup and Installation

In [None]:
# Install dependencies
!pip install -q torch torchvision timm matplotlib scikit-learn tqdm tensorboard scipy

In [None]:
# Make the AMSDN source code available in the working directory, using one of:
#   Option 1 - clone from GitHub (if you've uploaded it):
#     !git clone https://github.com/YOUR_USERNAME/AMSDN.git
#     %cd AMSDN
#   Option 2 - upload all .py files manually to Colab with the file browser
#     in the left sidebar.

# Report whether this runtime exposes a CUDA GPU, and which one.
import torch

cuda_ok = torch.cuda.is_available()
print(f"GPU Available: {cuda_ok}")
if cuda_ok:
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")

## 2. Quick Model Test

In [None]:
# Smoke test: confirm that the project modules import and that one forward pass runs.
from models.amsdn import AMSDN
from data.cifar10 import CIFAR10DataModule

# Build an untrained model (pretrained=False).
model = AMSDN(num_classes=10, pretrained=False)
print("✓ Model created successfully")

# Tiny random batch of 2 images with 3 channels at 32x32 resolution.
dummy_batch = torch.randn(2, 3, 32, 32)
detailed = model(dummy_batch, return_detailed=True)
print("✓ Forward pass successful")

logits_shape = detailed['logits'].shape
print(f"  Logits shape: {logits_shape}")
print(f"  Anomaly detection: {detailed['is_adversarial']}")

## 3. Stage 1: SSRT Pretraining (Optional)

**Time:** ~1 hour (50 epochs)

Skip this stage if you want faster training.

In [None]:
# Run SSRT pretraining
!python training/pretrain_ssrt.py

# Or use smaller epochs for demo:
# Modify pretrain_ssrt.py: num_epochs=10

## 4. Stage 2: Adversarial Training

**Time:** ~2 hours (100 epochs)

In [None]:
# Run adversarial training
!python training/adversarial_train.py

# For demo (faster): modify adversarial_train.py: num_epochs=20

## 5. Stage 3: Multi-Attack Fine-tuning

**Time:** ~30 minutes (20 epochs)

In [None]:
# Run multi-attack fine-tuning
!python training/finetune_attacks.py

## 6. Evaluation

**Time:** ~20 minutes

In [None]:
# Comprehensive evaluation
!python evaluation/evaluate.py

# View results
import json
with open('./results/evaluation_results.json', 'r') as f:
    results = json.load(f)

print("\n=== EVALUATION RESULTS ===")
for attack, metrics in results.items():
    print(f"\n{attack}:")
    for metric, value in metrics.items():
        print(f"  {metric}: {value:.2f}")

## 7. Certification (Randomized Smoothing)

**Time:** ~30 minutes (for 100 samples)

**Warning:** Very slow. Use a small sample size for a demo.

In [None]:
# Run certification
!python evaluation/certification.py

# View results
with open('./results/certification_results.json', 'r') as f:
    cert_results = json.load(f)

print("\n=== CERTIFICATION RESULTS ===")
print(f"Clean Accuracy: {cert_results['clean_accuracy']:.2f}%")
print(f"Abstention Rate: {cert_results['abstain_rate']:.2f}%")
print("\nCertified Accuracies:")
for radius, acc in cert_results['certified_accuracies'].items():
    print(f"  {radius}: {acc:.2f}%")

## 8. Visualization

In [None]:
# Visualize adversarial examples produced by a patch attack against the
# fine-tuned checkpoint.
import torch
from IPython.display import Image, display

from attacks.patch_attacks import AdversarialPatch
from data.cifar10 import CIFAR10DataModule
from models.amsdn import AMSDN  # imported here so this cell works even if Section 2 was skipped
from utils.helpers import visualize_adversarial_examples

CHECKPOINT_PATH = './checkpoints/finetuned/amsdn_finetuned_best.pth'
FIGURE_PATH = './adversarial_examples.png'

# Load model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = AMSDN(num_classes=10, pretrained=False).to(device)
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Get test images. Labels are moved to the same device as the images and model.
# The attack receives them alongside the model (presumably for its loss), so
# leaving them on the CPU risks a CPU/GPU device-mismatch error.
data_module = CIFAR10DataModule(batch_size=8)
_, test_loader = data_module.get_loaders()
images, labels = next(iter(test_loader))
images = images.to(device)
labels = labels.to(device)

# Generate adversarial examples (optimize=True: the patch is optimized against the model).
patch_attack = AdversarialPatch(patch_size=4, epsilon=0.3)
adv_images = patch_attack.apply(images, model, labels, optimize=True)

# Predict
with torch.no_grad():
    preds_clean = model(images).argmax(dim=1)
    preds_adv = model(adv_images).argmax(dim=1)

# Visualize
visualize_adversarial_examples(
    images, adv_images, labels, preds_clean, preds_adv,
    save_path=FIGURE_PATH
)

display(Image(FIGURE_PATH))

## 9. Download Results

In [None]:
# Zip all results and checkpoints
!zip -r amsdn_results.zip checkpoints/ results/ *.png

# Download
from google.colab import files
files.download('amsdn_results.zip')

## 10. Quick Demo (No Training)

Run this cell to test the AMSDN pipeline without training. The model uses randomly initialized weights, so its predictions are not meaningful.

In [None]:
print("=== AMSDN QUICK DEMO ===")
print("Testing pipeline with random weights...\n")

import torch

from attacks.patch_attacks import AdversarialPatch
from data.cifar10 import CIFAR10DataModule
from models.amsdn import AMSDN

# Seed torch so the randomly initialized weights (and any torch-based randomness
# in the patch) are reproducible across runs. torch.manual_seed seeds all devices.
SEED = 0
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Model
model = AMSDN(num_classes=10, pretrained=False).to(device)
model.eval()
print("✓ Model initialized")

# Data
data_module = CIFAR10DataModule(batch_size=4)
_, test_loader = data_module.get_loaders()
images, labels = next(iter(test_loader))
images = images.to(device)
print("✓ Data loaded")

# Clean prediction
with torch.no_grad():
    outputs_clean = model(images, return_detailed=True)
    print(f"\n Clean predictions: {outputs_clean['logits'].argmax(dim=1)}")
    print(f" True labels: {labels}")
    print(f" Anomaly scores: {outputs_clean['avg_anomaly_score']}")
    print(f" Detected as adversarial: {outputs_clean['is_adversarial']}")

# Attack. Labels are passed on the same device as the images and model, to avoid
# mixing CPU and GPU tensors inside the attack.
attack = AdversarialPatch(patch_size=4, epsilon=0.3, num_steps=50)
adv_images = attack.apply(images, model, labels.to(device), optimize=False)  # Random patch
print("\n✓ Generated adversarial examples")

# Adversarial prediction
with torch.no_grad():
    outputs_adv = model(adv_images, return_detailed=True)
    print(f"\n Adversarial predictions: {outputs_adv['logits'].argmax(dim=1)}")
    print(f" Anomaly scores: {outputs_adv['avg_anomaly_score']}")
    print(f" Detected as adversarial: {outputs_adv['is_adversarial']}")

print("\n=== DEMO COMPLETE ===")
print("Note: Model uses random weights. Train for real performance.")