# In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18, resnet34, resnet50
import numpy as np
from torch.utils.data import DataLoader
import time

# In [None]:
# Implement FGSM Attack
def fgsm_attack(model, images, labels, epsilon, device, criterion):
    """Fast Gradient Sign Method (FGSM) attack.

    Takes one step of size ``epsilon`` along the sign of the gradient of
    ``criterion(model(images), labels)`` with respect to the input, then
    clamps the result to [0, 1].

    Args:
        model: Network to attack; called as ``model(images)``.
        images: Input batch. Assumed to hold pixel values in [0, 1], since
            the output is clamped to that range.
        labels: Targets passed to ``criterion``.
        epsilon: L-infinity step size, in the same units as ``images``.
        device: Unused; kept for interface compatibility.
        criterion: Loss ``criterion(outputs, labels)`` that the attack increases.

    Returns:
        A detached adversarial batch with the same shape as ``images``.
    """
    perturbed_images = images.clone().detach().requires_grad_(True)

    # enable_grad lets the attack run even inside a caller's torch.no_grad()
    # block. autograd.grad returns only the input gradient, so the model's
    # parameter .grad buffers are not touched or accumulated.
    with torch.enable_grad():
        outputs = model(perturbed_images)
        loss = criterion(outputs, labels)
        (input_grad,) = torch.autograd.grad(loss, perturbed_images)

    # Step in the direction that increases the loss, then keep pixels valid.
    perturbed_images = perturbed_images.detach() + epsilon * input_grad.sign()
    perturbed_images = torch.clamp(perturbed_images, 0, 1)

    return perturbed_images.detach()

# Implement PGD Attack
def pgd_attack(model, images, labels, epsilon, alpha, num_iter, device, criterion, random_start=True):
    """Projected Gradient Descent (PGD) attack under an L-infinity budget.

    Repeats ``num_iter`` signed-gradient ascent steps of size ``alpha`` on
    ``criterion(model(x), labels)``. After each step, the perturbation is
    projected back into the ``epsilon`` ball around ``images`` and the pixels
    are clamped to [0, 1].

    Args:
        model: Network to attack; called as ``model(x)``.
        images: Clean input batch. Assumed to hold pixel values in [0, 1].
        labels: Targets passed to ``criterion``.
        epsilon: L-infinity radius of the allowed perturbation.
        alpha: Size of each ascent step.
        num_iter: Number of ascent steps.
        device: Unused; kept for interface compatibility.
        criterion: Loss ``criterion(outputs, labels)`` that the attack increases.
        random_start: If True, start from a uniform random point in the
            ``epsilon`` ball instead of from ``images``.

    Returns:
        A detached adversarial batch with the same shape as ``images``.
    """
    images = images.detach()
    perturbed_images = images.clone()

    if random_start:
        perturbed_images = perturbed_images + torch.empty_like(perturbed_images).uniform_(-epsilon, epsilon)
        perturbed_images = torch.clamp(perturbed_images, 0, 1)

    for _ in range(num_iter):
        perturbed_images.requires_grad_(True)
        # Compute only the input gradient. The caller's parameter gradients
        # are neither zeroed nor overwritten, and the attack also works when
        # the caller is inside torch.no_grad().
        with torch.enable_grad():
            outputs = model(perturbed_images)
            loss = criterion(outputs, labels)
            (input_grad,) = torch.autograd.grad(loss, perturbed_images)

        # Ascent step, then project back into the epsilon ball and the valid pixel range.
        adv_images = perturbed_images.detach() + alpha * input_grad.sign()
        eta = torch.clamp(adv_images - images, -epsilon, epsilon)
        perturbed_images = torch.clamp(images + eta, 0, 1).detach()

    return perturbed_images

# Choose the compute device once at import time: the first CUDA GPU when
# torch reports one as available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Data loading and preprocessing
def load_data(data_path, batch_size=64):
    """Build CIFAR-10 training and test DataLoaders.

    Images are converted with ``transforms.ToTensor()`` only, so tensors stay
    in the [0, 1] pixel range. This is deliberate: the FGSM/PGD attacks in
    this file clamp their output to [0, 1], and the training/evaluation
    defaults express ``epsilon`` in pixel units (8/255). If
    ``transforms.Normalize`` were applied here, most input values would fall
    outside [0, 1]. The clamp would then wipe out image content, and the
    budget would no longer be 8/255 per pixel.

    Args:
        data_path: Root directory for the CIFAR-10 files (downloaded there if
            not already present).
        batch_size: Batch size used by both loaders.

    Returns:
        ``(train_loader, test_loader)``. The training loader shuffles; the
        test loader keeps dataset order.
    """
    # Compose is kept so augmentations can be appended later. Do not add
    # Normalize here (see docstring).
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    train_dataset = torchvision.datasets.CIFAR10(root=data_path, train=True, download=True, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)

    test_dataset = torchvision.datasets.CIFAR10(root=data_path, train=False, download=True, transform=transform)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    return train_loader, test_loader

# Model setup
def get_model(model_name="resnet18", num_classes=10):
    """Create a pretrained torchvision ResNet with a freshly sized classifier head.

    Args:
        model_name: One of "resnet18", "resnet34" or "resnet50".
        num_classes: Number of outputs of the replacement ``fc`` layer.

    Returns:
        The model, moved to the module-level ``device``.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    if model_name not in ("resnet18", "resnet34", "resnet50"):
        raise ValueError(f"Model {model_name} not supported. Use one of: resnet18, resnet34, resnet50")

    constructors = {
        "resnet18": resnet18,
        "resnet34": resnet34,
        "resnet50": resnet50,
    }
    net = constructors[model_name](pretrained=True)

    # Replace the classification head; its input width comes from the
    # backbone's existing fc layer.
    net.fc = nn.Linear(net.fc.in_features, num_classes)
    return net.to(device)

# Adversarial training function
def train_robust_model(model, train_loader, val_loader, num_epochs=100, epsilon=8/255, alpha=2/255, num_iter=10):
    """Adversarially train ``model`` on a random mix of clean, FGSM and PGD batches.

    For each batch, one input type is drawn at random: the clean images
    (p=0.3), ``fgsm_attack`` examples (p=0.3) or ``pgd_attack`` examples
    (p=0.4), all generated against cross-entropy loss. Optimization uses SGD
    (lr=0.01, momentum=0.9, weight_decay=5e-4) with a cosine-annealed
    learning rate over ``num_epochs``. After every 5th epoch,
    ``evaluate_model`` is run on ``val_loader``.

    Args:
        model: Network to train; updated in place.
        train_loader: Iterable of ``(images, labels)`` training batches.
        val_loader: Loader passed to ``evaluate_model`` for periodic checks.
        num_epochs: Number of passes over ``train_loader``.
        epsilon: L-infinity perturbation budget for both attacks.
        alpha: PGD step size.
        num_iter: Number of PGD steps.

    Returns:
        The same ``model`` object after training.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    for epoch in range(num_epochs):
        # Enter training mode at the start of every epoch. evaluate_model()
        # calls model.eval(); with a single model.train() before the loop,
        # every epoch after the first evaluation would train in eval mode
        # (e.g. BatchNorm stuck on its running statistics).
        model.train()

        running_loss = 0.0
        correct = 0
        total = 0

        for i, (images, labels) in enumerate(train_loader):
            images, labels = images.to(device), labels.to(device)

            # Choose either clean, FGSM, or PGD samples for this batch
            attack_type = np.random.choice(['clean', 'fgsm', 'pgd'], p=[0.3, 0.3, 0.4])

            if attack_type == 'fgsm':
                perturbed_images = fgsm_attack(model, images, labels, epsilon, device, criterion)
            elif attack_type == 'pgd':
                perturbed_images = pgd_attack(model, images, labels, epsilon, alpha, num_iter, device, criterion)
            else:  # clean
                perturbed_images = images

            # Forward pass and optimization. zero_grad also discards any
            # gradients left on the parameters by attack generation.
            optimizer.zero_grad()
            outputs = model(perturbed_images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            if (i + 1) % 100 == 0:
                print(f"Epoch {epoch+1}/{num_epochs}, Step {i+1}/{len(train_loader)}, Loss: {loss.item():.4f}")

        scheduler.step()

        # Print epoch stats (guarded so an empty loader does not divide by zero)
        num_batches = max(len(train_loader), 1)
        train_accuracy = 100 * correct / total if total else 0.0
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/num_batches:.4f}, Accuracy: {train_accuracy:.2f}%")

        # Evaluate every 5 epochs
        if (epoch + 1) % 5 == 0:
            evaluate_model(model, val_loader, epsilon, alpha, num_iter)

    return model

# Evaluation function
def evaluate_model(model, data_loader, epsilon=8/255, alpha=2/255, num_iter=10):
    """Report accuracy on clean, FGSM and PGD versions of ``data_loader``.

    Makes a single pass over ``data_loader``. For each batch it scores the
    clean images, an ``fgsm_attack`` batch and a ``pgd_attack`` batch, both
    attacks built against ``nn.CrossEntropyLoss``. The model is put in eval
    mode for the measurement, and its previous train/eval mode is restored
    afterwards. Calling this from a training loop therefore does not leave
    the model in eval mode.

    Args:
        model: Classifier to evaluate.
        data_loader: Iterable of ``(images, labels)`` batches.
        epsilon: L-infinity budget for both attacks.
        alpha: PGD step size.
        num_iter: Number of PGD steps.

    Returns:
        ``(clean_acc, fgsm_acc, pgd_acc)`` as percentages. All three are 0.0
        when the loader yields no samples.
    """
    was_training = model.training
    model.eval()
    criterion = nn.CrossEntropyLoss()

    correct_clean = 0
    correct_fgsm = 0
    correct_pgd = 0
    total = 0
    try:
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)
            total += labels.size(0)

            correct_clean += _count_correct(model, images, labels)

            fgsm_images = fgsm_attack(model, images, labels, epsilon, device, criterion)
            correct_fgsm += _count_correct(model, fgsm_images, labels)

            pgd_images = pgd_attack(model, images, labels, epsilon, alpha, num_iter, device, criterion)
            correct_pgd += _count_correct(model, pgd_images, labels)
    finally:
        # Hand the model back in whatever mode the caller had it in.
        model.train(was_training)

    if total:
        clean_acc = 100 * correct_clean / total
        fgsm_acc = 100 * correct_fgsm / total
        pgd_acc = 100 * correct_pgd / total
    else:
        clean_acc = fgsm_acc = pgd_acc = 0.0

    print(f"Clean accuracy: {clean_acc:.2f}%")
    print(f"FGSM accuracy: {fgsm_acc:.2f}%")
    print(f"PGD accuracy: {pgd_acc:.2f}%")

    return clean_acc, fgsm_acc, pgd_acc


def _count_correct(model, images, labels):
    """Return how many samples in the batch ``model`` predicts correctly (arg-max of outputs)."""
    with torch.no_grad():
        _, predicted = model(images).max(1)
    return predicted.eq(labels).sum().item()

# Function to save model for submission
def save_model_for_submission(model, model_class, path="robust_model.pt"):
    """Write a checkpoint with the model's weights and its class name.

    The file at ``path`` is written with ``torch.save``. It holds a dict with
    two keys: 'model_state_dict' (``model.state_dict()``) and 'model_class'
    (``model_class`` as given). A confirmation line is printed afterwards.
    """
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'model_class': model_class,
    }
    torch.save(checkpoint, path)
    print(f"Model saved to {path}")

# Example end-to-end run: load data, build the model, train, evaluate, save.
# The script is wrapped in a triple-quoted string, so it is a no-op
# expression and nothing executes when this cell runs. Remove the
# surrounding quotes to actually run it.
"""
# Set parameters
data_path = './data'  # Adjust to your data path
model_name = 'resnet18'  # Choose from resnet18, resnet34, resnet50
num_classes = 10  # Adjust based on your dataset

# Load data
train_loader, test_loader = load_data(data_path)

# Initialize model
model = get_model(model_name, num_classes)

# Train model
model = train_robust_model(model, train_loader, test_loader)

# Final evaluation
clean_acc, fgsm_acc, pgd_acc = evaluate_model(model, test_loader)

# Save model for submission
save_model_for_submission(model, model_name)
"""