<a href="https://colab.research.google.com/github/calicartels/AirWrite/blob/main/Research_final.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# 1. Basic imports
import torch
import torch.nn as nn
import torchvision
import matplotlib.pyplot as plt
import numpy as np
import os
import sys
from tqdm.notebook import tqdm
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 2. Setting up directory and mount drive
# (Colab-only: the trained checkpoint is copied from Drive in step 8.)
from google.colab import drive
drive.mount('/content/drive')

# 3. Making sure we're in /content before creating directories
%cd /content

# 4. Cleaning if directories exist and cloning the original directories
# (removing any earlier clone first lets this cell be re-run safely)
!rm -rf flow_matching
!git clone https://github.com/facebookresearch/flow_matching.git

# 5. Install package
# (editable install of the cloned checkout)
%cd /content/flow_matching
!pip install -e .

# 6. Set up image example directory
%cd /content/flow_matching/examples/image
!pip install -r requirements.txt

# 7. Create output directory and __init__.py files
# (the __init__.py files make `models` and `training` importable as packages)
!mkdir -p output_dir
!touch /content/flow_matching/examples/image/models/__init__.py
!touch /content/flow_matching/examples/image/training/__init__.py

# 8. Copy your saved model from Drive
# (the trailing "/." copies the directory's contents rather than the directory itself)
!cp -r "/content/drive/MyDrive/flow_matching_model/." "/content/flow_matching/examples/image/output_dir/"

# 9. Clean and reset Python path
# Prepend the example directory and the repo root so that `models`, `training`
# and `flow_matching` resolve to the checkout cloned above.
import sys
sys.path = ['/content/flow_matching/examples/image', '/content/flow_matching'] + sys.path

# 10. Try imports
# NOTE: this single try block also reads args.json, so the "Import error"
# message below is printed for configuration-loading failures as well.
try:
    # Import model components
    from models.model_configs import instantiate_model
    from training.eval_loop import CFGScaledModel
    print("Model imports successful!")

    # Import flow_matching components
    import flow_matching
    from flow_matching.path import MixtureDiscreteProbPath
    from flow_matching.path.scheduler import PolynomialConvexScheduler
    from flow_matching.solver.ode_solver import ODESolver
    print("All imports successful!")

    # Load model configuration
    from pathlib import Path
    import json

    # args.json is expected to sit next to the checkpoint file.
    checkpoint_path = Path("/content/flow_matching/examples/image/output_dir/checkpoint-199.pth")
    args_filepath = checkpoint_path.parent / 'args.json'

    with open(args_filepath, 'r') as f:
        args_dict = json.load(f)

    print("Configuration loaded successfully!")

except Exception as e:
    # Diagnostics only: the exception is not re-raised, so later cells will
    # fail with NameError if anything above did not complete.
    print(f"Import error: {e}")
    print("\nCurrent working directory:", os.getcwd())
    print("\nPython path:", sys.path)
    print("\nDirectory contents:")
    !ls -R /content/flow_matching/examples/image/models
    !ls -R /content/flow_matching/flow_matching/path

In [None]:
from argparse import Namespace
import torch.serialization

# Allow-list argparse.Namespace for the weights_only unpickler used below
# (presumably the checkpoint embeds the training args as a Namespace -- confirm).
torch.serialization.add_safe_globals([Namespace])

# 1. Initialize the model
# NOTE: `architechture` (sic) is the keyword spelling this call uses; keep it.
model = instantiate_model(
    architechture=args_dict['dataset'],
    is_discrete='discrete_flow_matching' in args_dict and args_dict['discrete_flow_matching'],
    use_ema=args_dict['use_ema']
)

# 2. Load checkpoint (with weights_only=True to address the warning)
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
model.load_state_dict(checkpoint["model"])
model.eval()  # Set to evaluation mode
model.to(device)

# 3. Setup for generation
batch_size = 16  # Number of images to generate
sample_resolution = 32  # CIFAR10 resolution
cfg_weighted_model = CFGScaledModel(model=model)

# 4. Generate images
try:
    # Start the ODE from Gaussian noise with the same shape as the images.
    x_0 = torch.randn([batch_size, 3, sample_resolution, sample_resolution], dtype=torch.float32, device=device)
    solver = ODESolver(velocity_model=cfg_weighted_model)

    # Get ODE options from args_dict or use defaults
    ode_opts = args_dict.get('ode_options', {})
    step_size = ode_opts.get('step_size', 0.05)  # Default step size if not specified

    synthetic_samples = solver.sample(
        time_grid=torch.tensor([0.0, 1.0], device=device),
        x_init=x_0,
        method=args_dict.get('ode_method', 'euler'),  # Default to euler if not specified
        step_size=step_size,
        atol=ode_opts.get('atol', 1e-5),
        rtol=ode_opts.get('rtol', 1e-5),
        label=torch.tensor(list(range(batch_size)), device=device),
        cfg_scale=args_dict.get('cfg_scale', 1.0)
    )

    # Scale images to [0, 1] range
    synthetic_samples = torch.clamp(synthetic_samples * 0.5 + 0.5, min=0.0, max=1.0)

    # Visualize generated images. The grid is derived from batch_size so that
    # changing the batch size above does not overflow a fixed 4x4 layout.
    n_cols = 4
    n_rows = (batch_size + n_cols - 1) // n_cols  # ceiling division
    plt.figure(figsize=(20, 5 * n_rows))
    for i in range(batch_size):
        plt.subplot(n_rows, n_cols, i + 1)
        plt.imshow(synthetic_samples[i].cpu().permute(1, 2, 0).numpy())
        plt.axis('off')
    plt.tight_layout()
    plt.show()

except Exception as e:
    # Best-effort diagnostics: show the error and the loaded training args.
    print(f"Generation error: {e}")
    print("\nargs_dict contents:")
    print(json.dumps(args_dict, indent=2))  # Print args_dict for debugging

** The poor image quality is due to a higher FID value, close to 5.5.

** Also, I trained for only 200 epochs; CIFAR needs close to 921 epochs for an ideal training run and 3000 for the perfect model.

In [None]:
# 1. Basic imports
import torch
import torch.nn.functional as F
from pathlib import Path
import json
import matplotlib.pyplot as plt
from argparse import Namespace
import torch.serialization
from models.model_configs import instantiate_model
from training.eval_loop import CFGScaledModel
from flow_matching.solver.ode_solver import ODESolver
import gc

# Free cached GPU blocks and unreachable Python objects before reloading.
torch.cuda.empty_cache()
gc.collect()

# Register Namespace with the weights_only unpickler.
torch.serialization.add_safe_globals([Namespace])

# Resolve the device and the checkpoint / config locations.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
checkpoint_path = Path("/content/flow_matching/examples/image/output_dir/checkpoint-199.pth")
args_filepath = checkpoint_path.parent / 'args.json'

# Read the training arguments stored next to the checkpoint.
with open(args_filepath, 'r') as f:
    args_dict = json.load(f)

# Build the network from the stored arguments.
model = instantiate_model(
    architechture=args_dict['dataset'],
    is_discrete='discrete_flow_matching' in args_dict and args_dict['discrete_flow_matching'],
    use_ema=args_dict['use_ema'],
)

# Deserialize on the CPU, copy the weights in, then drop the checkpoint
# dict so its tensors can be reclaimed.
checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
model.load_state_dict(checkpoint["model"])
del checkpoint
torch.cuda.empty_cache()

# Inference mode on the target device.
model.eval()
model = model.to(device)

# Modified FGSM attack for Flow Matching model
def fgsm_attack(model, image, label, epsilon=0.03):
    """Single-step FGSM perturbation of `image` against a flow-matching model.

    The loss is the MSE between the model output at t=0.5 and the input
    itself; the input is moved by `epsilon` along the sign of that loss's
    gradient and clamped to [-1, 1].

    Args:
        model: callable invoked as ``model(x, t, extra={'y': label})``.
        image: input batch; its first dimension is the batch size.
        label: conditioning labels forwarded to the model under ``extra['y']``.
        epsilon: step size applied along the gradient sign.

    Returns:
        A detached tensor with the same shape as `image`, clamped to [-1, 1].
    """
    perturbed_image = image.clone().detach().requires_grad_(True)

    # One timestep per sample, fixed at t=0.5, created on the input's own
    # device so the function does not depend on a module-level `device`.
    timesteps = torch.full((image.shape[0],), 0.5, device=image.device)

    # enable_grad() keeps the attack working when the caller is inside no_grad().
    with torch.enable_grad():
        output = model(perturbed_image, timesteps, extra={'y': label})
        # Difference between the model output and its input serves as the loss.
        loss = F.mse_loss(output, perturbed_image)
        # Differentiate w.r.t. the input only: unlike loss.backward(), this
        # does not accumulate .grad on the model parameters.
        (input_grad,) = torch.autograd.grad(loss, perturbed_image)

    adversarial_image = perturbed_image.detach() + epsilon * input_grad.sign()
    return torch.clamp(adversarial_image, -1, 1)

# Generate samples and create adversarial examples
try:
    # Setup generation
    batch_size = 8
    sample_resolution = 32
    cfg_weighted_model = CFGScaledModel(model=model)

    # Generate original samples, starting the ODE from Gaussian noise
    x_0 = torch.randn([batch_size, 3, sample_resolution, sample_resolution],
                      dtype=torch.float32, device=device)
    solver = ODESolver(velocity_model=cfg_weighted_model)

    # Get ODE options
    ode_opts = args_dict.get('ode_options', {})
    step_size = ode_opts.get('step_size', 0.05)

    with torch.no_grad():
        raw_samples = solver.sample(
            time_grid=torch.tensor([0.0, 1.0], device=device),
            x_init=x_0,
            method=args_dict.get('ode_method', 'heun2'),  # Using heun2 as specified in args
            step_size=step_size,
            atol=ode_opts.get('atol', 1e-5),
            rtol=ode_opts.get('rtol', 1e-5),
            label=torch.tensor(list(range(batch_size)), device=device),
            cfg_scale=args_dict.get('cfg_scale', 0.0)  # Using cfg_scale from args
        )

    # Display copy scaled to [0, 1]; later cells read `synthetic_samples`.
    synthetic_samples = torch.clamp(raw_samples * 0.5 + 0.5, min=0.0, max=1.0)

    # Create adversarial examples from the *unscaled* samples: fgsm_attack
    # clamps its result to [-1, 1], i.e. it works in the model's own value
    # range, so feeding it the [0, 1] display copy would mix the two ranges.
    adversarial_samples = []
    for i in range(batch_size):
        image = raw_samples[i].unsqueeze(0)
        label = torch.tensor([i], device=device)
        adv_image = fgsm_attack(model, image, label)
        adversarial_samples.append(adv_image)
        torch.cuda.empty_cache()

    adversarial_samples = torch.cat(adversarial_samples)
    # Apply the same [-1, 1] -> [0, 1] mapping so both rows are comparable.
    adversarial_samples = torch.clamp(adversarial_samples * 0.5 + 0.5, min=0.0, max=1.0)

    # Move to CPU for visualization
    synthetic_samples = synthetic_samples.cpu()
    adversarial_samples = adversarial_samples.cpu()

    # Visualization: originals on the top row, adversarial images below
    plt.figure(figsize=(20, 10))
    for i in range(batch_size):
        # Original image
        plt.subplot(2, batch_size, i + 1)
        plt.imshow(synthetic_samples[i].permute(1, 2, 0).numpy())
        plt.title('Original')
        plt.axis('off')

        # Adversarial image
        plt.subplot(2, batch_size, i + batch_size + 1)
        plt.imshow(adversarial_samples[i].permute(1, 2, 0).numpy())
        plt.title('Adversarial')
        plt.axis('off')

    plt.tight_layout()
    plt.show()

except Exception as e:
    # Best-effort diagnostics: show the error and the loaded training args.
    print(f"Generation/Attack error: {e}")
    print("\nargs_dict contents:")
    print(json.dumps(args_dict, indent=2))

finally:
    # Clean up
    torch.cuda.empty_cache()
    gc.collect()

In [None]:
import torch
import torch.nn.functional as F
import torchvision.models as models
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
from torchvision.models import ResNet50_Weights

# Pick the compute device: CUDA when available, CPU otherwise.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# ResNet-50 with its default pretrained weights, moved to the device and
# switched to evaluation mode.
classifier = models.resnet50(weights=ResNet50_Weights.DEFAULT)
classifier = classifier.to(device)
classifier.eval()

def improved_fgsm_attack(model, image, target_class_idx, epsilon=0.1,
                         num_iter=5, clamp_min=-1.0, clamp_max=1.0):
    """Iterated, targeted FGSM against the module-level `classifier`.

    Each iteration moves the image by `epsilon` along the negative gradient
    sign of the cross-entropy towards `target_class_idx` (descending that
    loss raises the target-class score), then clamps to
    [`clamp_min`, `clamp_max`].

    Args:
        model: unused; kept so existing call sites keep working. Gradients
            are taken through the module-level `classifier`.
        image: input batch fed to `classifier`.
        target_class_idx: class index the attack pushes the prediction towards.
        epsilon: per-iteration step size. There is no projection step, so the
            total change can reach ``num_iter * epsilon`` per element.
        num_iter: number of sign-gradient steps (default 5).
        clamp_min: lower bound applied after every step (default -1).
        clamp_max: upper bound applied after every step (default 1).

    Returns:
        The perturbed image as a detached tensor.
    """
    adv = image.clone().detach()

    # One target entry per sample, on the same device as the input.
    target = torch.full((adv.shape[0],), target_class_idx,
                        dtype=torch.long, device=adv.device)

    for _ in range(num_iter):
        adv.requires_grad_(True)
        with torch.enable_grad():
            loss = F.cross_entropy(classifier(adv), target)
            # Gradient w.r.t. the input only; classifier parameters are
            # left without accumulated .grad.
            (grad,) = torch.autograd.grad(loss, adv)

        # Targeted attack: step *down* the target-class loss.
        adv = torch.clamp(adv.detach() - epsilon * grad.sign(), clamp_min, clamp_max)

    return adv

# Try the improved attack
# Sweeps epsilon x target class on the first generated sample and plots the
# original / adversarial / perturbation triple whenever the predicted class changes.
try:
    # Take one sample
    original = synthetic_samples[0].unsqueeze(0).to(device)

    # Preprocess original image to match ResNet's expected input
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                  std=[0.229, 0.224, 0.225])

    # Resize to 224x224 (ResNet's expected input size)
    resize = transforms.Resize((224, 224))
    original = resize(original)
    original = normalize(original)

    # Get original prediction
    with torch.no_grad():
        orig_output = classifier(original)
        orig_class = orig_output.argmax().item()
        orig_conf = F.softmax(orig_output, dim=1).max().item()

    print(f"\nOriginal Classification:")
    print(f"Class: {orig_class}, Confidence: {orig_conf:.4f}")

    # Try different epsilon values and target classes
    # NOTE(review): the class names below come from the original comments; in
    # the standard ImageNet-1k index 751/895/627 appear to be "racer",
    # "warplane" and "limousine" -- confirm against the label map in use.
    epsilon_values = [0.05, 0.1, 0.15, 0.2]
    target_classes = [
        404,  # airliner
        751,  # wing (see note above)
        895,  # aircraft carrier (see note above)
        627,  # helicopter (see note above)
    ]

    for epsilon in epsilon_values:
        print(f"\nTrying epsilon = {epsilon}")
        for target_class in target_classes:
            print(f"\nTrying to misclassify as class {target_class}")

            # Generate adversarial example
            # NOTE(review): `original` is mean/std-normalized here, while the
            # attack clamps its result to [-1, 1] -- confirm this truncation
            # of the normalized range is intended.
            adv_image = improved_fgsm_attack(model, original, target_class, epsilon=epsilon)

            # Get adversarial prediction
            with torch.no_grad():
                adv_output = classifier(adv_image)
                adv_class = adv_output.argmax().item()
                adv_conf = F.softmax(adv_output, dim=1).max().item()

            print(f"New Classification - Class: {adv_class}, Confidence: {adv_conf:.4f}")

            # Only show visualization if classification changed
            if adv_class != orig_class:
                plt.figure(figsize=(15, 5))

                # Denormalize images for visualization
                # (first multiply by std, then add the mean back)
                denormalize = transforms.Compose([
                    transforms.Normalize(mean=[0, 0, 0], std=[1/0.229, 1/0.224, 1/0.225]),
                    transforms.Normalize(mean=[-0.485, -0.456, -0.406], std=[1, 1, 1]),
                ])

                # Original
                plt.subplot(1, 3, 1)
                vis_orig = denormalize(original[0]).cpu().detach()
                plt.imshow(vis_orig.permute(1, 2, 0).numpy().clip(0, 1))
                plt.title(f'Original\nClass: {orig_class}\nConf: {orig_conf:.4f}')
                plt.axis('off')

                # Adversarial
                plt.subplot(1, 3, 2)
                vis_adv = denormalize(adv_image[0]).cpu().detach()
                plt.imshow(vis_adv.permute(1, 2, 0).numpy().clip(0, 1))
                plt.title(f'Adversarial\nClass: {adv_class}\nConf: {adv_conf:.4f}')
                plt.axis('off')

                # Perturbation
                # (absolute difference in normalized space, averaged over channels)
                plt.subplot(1, 3, 3)
                perturbation = (adv_image - original)[0].cpu().detach()
                plt.imshow(np.abs(perturbation.permute(1, 2, 0).numpy()).mean(axis=2), cmap='viridis')
                plt.colorbar()
                plt.title('Perturbation Magnitude')
                plt.axis('off')

                plt.tight_layout()
                plt.show()

except Exception as e:
    # Report the failure with a full traceback instead of aborting the notebook.
    print(f"Attack error: {e}")
    import traceback
    traceback.print_exc()

In [None]:
def pgd_attack(model, image, target_class_idx, epsilon=0.03, alpha=0.01, num_iter=40):
    """Targeted PGD attack against the module-level `classifier`.

    Starts from a random point within ``epsilon / 2`` of `image`, then
    repeatedly steps by `alpha` along the negative gradient sign of the
    cross-entropy towards `target_class_idx` and projects the result back
    into the L-infinity ball of radius `epsilon` around `image` and into
    [-1, 1].

    Args:
        model: unused; kept so existing call sites keep working. Gradients
            are taken through the module-level `classifier`.
        image: original image batch.
        target_class_idx: class index the attack pushes the prediction towards.
        epsilon: maximum per-element perturbation.
        alpha: step size of each iteration.
        num_iter: number of iterations.

    Returns:
        The perturbed image as a detached tensor.
    """
    clean = image.detach()

    # Random start inside the epsilon ball, kept in the valid value range.
    noise = torch.empty_like(clean).uniform_(-epsilon / 2, epsilon / 2)
    perturbed_image = torch.clamp(clean + noise, -1, 1)

    # One target entry per sample, on the same device as the input.
    target = torch.full((clean.shape[0],), target_class_idx,
                        dtype=torch.long, device=clean.device)

    for _ in range(num_iter):
        perturbed_image.requires_grad_(True)
        with torch.enable_grad():
            loss = F.cross_entropy(classifier(perturbed_image), target)
            # Input gradient only: no .grad is accumulated on the classifier,
            # so no zero_grad() call is needed.
            (grad,) = torch.autograd.grad(loss, perturbed_image)

        # Targeted attack: descend the target-class loss.
        stepped = perturbed_image.detach() - alpha * grad.sign()

        # Project back to epsilon ball and valid image space
        eta = torch.clamp(stepped - clean, min=-epsilon, max=epsilon)
        perturbed_image = torch.clamp(clean + eta, min=-1, max=1)

    return perturbed_image

# Compare FGSM and PGD attacks
# For every (epsilon, target class) pair, run both attacks on the first
# generated sample and plot the images plus a per-row perturbation profile.
try:
    # Take one sample
    # NOTE(review): unlike the FGSM cell above, the sample is passed to the
    # classifier without the 224x224 resize / mean-std normalization applied
    # there -- confirm this is intended.
    original = synthetic_samples[0].unsqueeze(0).to(device)

    # Get original prediction
    with torch.no_grad():
        orig_output = classifier(original)
        orig_class = orig_output.argmax().item()
        orig_conf = F.softmax(orig_output, dim=1).max().item()

    print(f"\nOriginal Classification:")
    print(f"Class: {orig_class}, Confidence: {orig_conf:.4f}")

    # Test parameters
    epsilons = [0.03, 0.05, 0.1]
    target_classes = [404, 895]  # Using classes that worked well with FGSM

    for epsilon in epsilons:
        print(f"\nTesting epsilon = {epsilon}")
        for target_class in target_classes:
            print(f"\nTarget class: {target_class}")

            # Generate FGSM adversarial example
            fgsm_image = improved_fgsm_attack(model, original, target_class, epsilon=epsilon)

            # Generate PGD adversarial example
            pgd_image = pgd_attack(model, original, target_class, epsilon=epsilon)

            # Get predictions
            with torch.no_grad():
                # FGSM predictions
                fgsm_output = classifier(fgsm_image)
                fgsm_class = fgsm_output.argmax().item()
                fgsm_conf = F.softmax(fgsm_output, dim=1).max().item()

                # PGD predictions
                pgd_output = classifier(pgd_image)
                pgd_class = pgd_output.argmax().item()
                pgd_conf = F.softmax(pgd_output, dim=1).max().item()

            print(f"FGSM - New Class: {fgsm_class}, Confidence: {fgsm_conf:.4f}")
            print(f"PGD  - New Class: {pgd_class}, Confidence: {pgd_conf:.4f}")

            # Visualize results
            plt.figure(figsize=(20, 5))

            # Original
            plt.subplot(1, 4, 1)
            plt.imshow(original[0].cpu().detach().permute(1, 2, 0).numpy())
            plt.title(f'Original\nClass: {orig_class}\nConf: {orig_conf:.4f}')
            plt.axis('off')

            # FGSM
            plt.subplot(1, 4, 2)
            plt.imshow(fgsm_image[0].cpu().detach().permute(1, 2, 0).numpy())
            plt.title(f'FGSM\nClass: {fgsm_class}\nConf: {fgsm_conf:.4f}')
            plt.axis('off')

            # PGD
            plt.subplot(1, 4, 3)
            plt.imshow(pgd_image[0].cpu().detach().permute(1, 2, 0).numpy())
            plt.title(f'PGD\nClass: {pgd_class}\nConf: {pgd_conf:.4f}')
            plt.axis('off')

            # Perturbation comparison
            # (norm over the channel axis, then the mean of each image row)
            plt.subplot(1, 4, 4)
            fgsm_pert = torch.norm((fgsm_image - original)[0], dim=0).cpu().detach()
            pgd_pert = torch.norm((pgd_image - original)[0], dim=0).cpu().detach()
            plt.plot(fgsm_pert.mean(dim=1), label='FGSM')
            plt.plot(pgd_pert.mean(dim=1), label='PGD')
            plt.title('Perturbation Magnitude\nper Row')
            plt.legend()

            plt.tight_layout()
            plt.show()

except Exception as e:
    # Report the failure with a full traceback instead of aborting the notebook.
    print(f"Attack error: {e}")
    import traceback
    traceback.print_exc()

# Visualization

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import seaborn as sns
from pathlib import Path
import torch.nn as nn
from torchvision import transforms
import cv2
from typing import List, Tuple, Dict

class PerturbationFlowTracker:
    """Record the images produced during an iterative attack.

    For every tracked step the class stores the image, its difference from
    the original, and (from the second step on) the Farneback optical flow
    between the previous and the current image.
    """

    def __init__(self, original_image: torch.Tensor):
        """
        Initialize tracker with original image
        Args:
            original_image: Original image tensor [1, C, H, W]
        """
        self.original_image = original_image
        self.flow_history = []          # flow between step i and step i + 1
        self.perturbation_history = []  # current image minus original, per step
        self.intermediate_images = []   # CPU copy of every tracked image

    def track_step(self, current_image: torch.Tensor):
        """
        Track changes from one step to the next
        Args:
            current_image: Current perturbed image
        """
        # Keep a detached CPU copy so the history holds no graph or GPU memory.
        snapshot = current_image.clone().detach().cpu()
        self.intermediate_images.append(snapshot)

        # Perturbation relative to the original image.
        perturbation = (current_image - self.original_image).detach().cpu()
        self.perturbation_history.append(perturbation)

        # Flow needs a previous image, so it starts with the second step.
        if len(self.intermediate_images) > 1:
            previous = self.intermediate_images[-2]
            self.flow_history.append(self._calculate_flow(previous, snapshot))

    @staticmethod
    def _to_uint8_hwc(image: torch.Tensor) -> np.ndarray:
        """Convert a [1, C, H, W] CPU tensor to an H x W x C uint8 array.

        Values are clipped to [0, 1] before scaling: without the clip,
        anything outside that range wraps around when cast to uint8.
        """
        array = image[0].permute(1, 2, 0).numpy()
        return (np.clip(array, 0.0, 1.0) * 255).astype(np.uint8)

    def _calculate_flow(self, prev_image: torch.Tensor, current_image: torch.Tensor) -> np.ndarray:
        """Calculate optical flow between two images"""
        prev_gray = cv2.cvtColor(self._to_uint8_hwc(prev_image), cv2.COLOR_RGB2GRAY)
        curr_gray = cv2.cvtColor(self._to_uint8_hwc(current_image), cv2.COLOR_RGB2GRAY)

        # Positional arguments after the two frames: flow, pyr_scale, levels,
        # winsize, iterations, poly_n, poly_sigma, flags.
        return cv2.calcOpticalFlowFarneback(
            prev_gray, curr_gray,
            None, 0.5, 3, 15, 3, 5, 1.2, 0
        )

    def visualize_flow_step(self, step: int):
        """Visualize flow at a specific step

        Raises:
            ValueError: if `step` is not smaller than the number of stored flows.
        """
        if step >= len(self.flow_history):
            raise ValueError(f"Step {step} not available. Only have {len(self.flow_history)} steps.")

        flow = self.flow_history[step]

        plt.figure(figsize=(15, 5))

        # Original image
        plt.subplot(131)
        plt.imshow(self.original_image[0].permute(1, 2, 0).cpu())
        plt.title('Original Image')
        plt.axis('off')

        # Current perturbed image; flow[step] links images step and step + 1.
        plt.subplot(132)
        plt.imshow(self.intermediate_images[step+1][0].permute(1, 2, 0).cpu())
        plt.title(f'Perturbed Image (Step {step+1})')
        plt.axis('off')

        # Flow visualization
        plt.subplot(133)
        plt.quiver(flow[..., 0], flow[..., 1])
        plt.title('Flow Field')
        plt.axis('off')

        plt.tight_layout()
        plt.show()

    def create_flow_animation(self, save_path: str = None):
        """Create animation of the flow progression

        Args:
            save_path: when given, the animation is also written to this
                path using the 'pillow' writer.

        Returns:
            The `FuncAnimation` object.
        """
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))

        def update(frame):
            for ax in axes:
                ax.clear()

            # Original image
            axes[0].imshow(self.original_image[0].permute(1, 2, 0).cpu())
            axes[0].set_title('Original Image')
            axes[0].axis('off')

            # Current perturbed image
            axes[1].imshow(self.intermediate_images[frame][0].permute(1, 2, 0).cpu())
            axes[1].set_title(f'Perturbed Image (Step {frame})')
            axes[1].axis('off')

            # Frame 0 has no predecessor, hence no flow panel.
            if frame > 0:
                flow = self.flow_history[frame-1]
                axes[2].quiver(flow[..., 0], flow[..., 1])
                axes[2].set_title('Flow Field')
                axes[2].axis('off')

            fig.tight_layout()

        anim = FuncAnimation(
            fig, update,
            frames=len(self.intermediate_images),
            interval=200
        )

        if save_path:
            anim.save(save_path, writer='pillow')

        # Close this figure specifically, not whichever one is current.
        plt.close(fig)
        return anim

# Modified PGD attack to use the tracker
def pgd_attack_with_tracking(model, image, target_class_idx, epsilon=0.1, alpha=0.02, num_iter=100):
    """Targeted PGD attack that records every intermediate image.

    Same update rule as a plain PGD loop against the module-level
    `classifier`: random start within ``epsilon / 2``, steps of `alpha`
    along the negative gradient sign of the cross-entropy towards
    `target_class_idx`, projection into the `epsilon` ball around `image`
    and into [-1, 1]. The start point and each iterate are passed to a
    `PerturbationFlowTracker`.

    Args:
        model: unused; kept so existing call sites keep working.
        image: original image batch.
        target_class_idx: class index the attack pushes the prediction towards.
        epsilon: maximum per-element perturbation.
        alpha: step size of each iteration.
        num_iter: number of iterations.

    Returns:
        ``(perturbed_image, tracker)``: the final detached image and the
        tracker holding ``num_iter + 1`` recorded images.
    """
    tracker = PerturbationFlowTracker(image)
    clean = image.detach()

    # Random start inside the epsilon ball, kept in the valid value range.
    noise = torch.empty_like(clean).uniform_(-epsilon / 2, epsilon / 2)
    perturbed_image = torch.clamp(clean + noise, -1, 1)

    # Track initial state
    tracker.track_step(perturbed_image)

    # One target entry per sample, on the same device as the input.
    target = torch.full((clean.shape[0],), target_class_idx,
                        dtype=torch.long, device=clean.device)

    for _ in range(num_iter):
        perturbed_image.requires_grad_(True)
        with torch.enable_grad():
            loss = F.cross_entropy(classifier(perturbed_image), target)
            # Input gradient only; nothing is accumulated on the classifier.
            (grad,) = torch.autograd.grad(loss, perturbed_image)

        # Targeted attack: descend the target-class loss, then project.
        stepped = perturbed_image.detach() - alpha * grad.sign()
        eta = torch.clamp(stepped - clean, min=-epsilon, max=epsilon)
        perturbed_image = torch.clamp(clean + eta, min=-1, max=1)

        # Track this step
        tracker.track_step(perturbed_image)

    return perturbed_image, tracker

In [None]:
# Run the tracked PGD attack on the first generated sample, then show the
# first recorded flow steps and export the whole progression as a GIF.
try:
    # Take one sample
    original = synthetic_samples[0].unsqueeze(0).to(device)

    # Define target class
    target_class_idx = 404  # airliner

    # Get original prediction
    with torch.no_grad():
        orig_output = classifier(original)
        orig_class = orig_output.argmax().item()
        orig_conf = F.softmax(orig_output, dim=1).max().item()

    print("\nOriginal Classification:")
    print(f"Class: {orig_class}, Confidence: {orig_conf:.4f}")

    # Generate adversarial example with tracking
    print("\nGenerating adversarial example...")
    adv_image, tracker = pgd_attack_with_tracking(
        model=model,
        image=original,
        target_class_idx=target_class_idx,
        epsilon=0.1,
        alpha=0.02,
        num_iter=100
    )

    # Get adversarial prediction
    with torch.no_grad():
        adv_output = classifier(adv_image)
        adv_class = adv_output.argmax().item()
        adv_conf = F.softmax(adv_output, dim=1).max().item()

    print("\nAdversarial Classification:")
    print(f"Class: {adv_class}, Confidence: {adv_conf:.4f}")

    # Visualize steps 0..12, each exactly once (the previous hand-written
    # list ended with a stray 1, which rendered step 1 a second time).
    print("\nVisualizing attack progression...")
    steps_to_show = list(range(13))
    for step in steps_to_show:
        # Skip steps beyond what the tracker recorded.
        if step < len(tracker.flow_history):
            print(f"\nShowing step {step}")
            tracker.visualize_flow_step(step)

    # Create and save animation
    print("\nCreating animation...")
    animation = tracker.create_flow_animation('perturbation_flow.gif')
    print("Animation saved as 'perturbation_flow.gif'")

except Exception as e:
    # Report the failure with a full traceback instead of aborting the notebook.
    print(f"Error: {e}")
    import traceback
    traceback.print_exc()

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import seaborn as sns
import cv2

class EnhancedFeatureExtractor:
    """Capture intermediate activations of a model through forward hooks.

    Hooks are registered on every sub-module whose qualified name contains
    one of `target_layers`; each forward pass stores the detached outputs
    in `self.features`, keyed by module name.
    """

    def __init__(self, model):
        """Enhanced feature extractor with improved tracking capabilities"""
        self.model = model
        self.features = {}
        self.hooks = []
        self.imagenet_labels = self._load_imagenet_labels()

        # Track more layers for better analysis
        self.target_layers = [
            'layer1.2.conv3',    # Early visual features
            'layer2.3.conv3',    # Mid-level features
            'layer3.5.conv3',    # Higher-level features
            'layer4.2.conv3',    # Final features
            'avgpool'            # Global features
        ]

        self._register_hooks()

    def _load_imagenet_labels(self):
        """Return placeholder labels "Class_0" .. "Class_999".

        No label file is read; the names are generated.
        """
        return [f"Class_{i}" for i in range(1000)]

    def _register_hooks(self):
        """Attach a forward hook to every module matching `target_layers`."""
        def hook_fn(layer_name):
            # Bind layer_name per hook so each one writes to its own key.
            def forward_hook(module, input, output):
                self.features[layer_name] = output.detach()
            return forward_hook

        for name, module in self.model.named_modules():
            # Substring match against the qualified module name.
            if any(target in name for target in self.target_layers):
                hook = module.register_forward_hook(hook_fn(name))
                self.hooks.append(hook)
                print(f"Registered hook for layer: {name}")

    def get_features(self, x):
        """Run the model on `x` and return a copy of the hooked activations.

        Inputs already inside [-1, 1] are passed through unchanged. Anything
        else is divided by 255 when its maximum exceeds 1 and then mapped
        with ``x * 2 - 1``.

        NOTE(review): a mean/std-normalized tensor also falls outside
        [-1, 1] and would be rescaled by this branch -- confirm that is
        intended for the inputs used with this class.
        """
        self.features.clear()

        # Ensure input is properly normalized
        if x.min() < -1 or x.max() > 1:
            x = x / 255.0 if x.max() > 1 else x
            x = x * 2 - 1  # Scale to [-1, 1]

        # The hooks fill self.features during this forward pass.
        with torch.no_grad():
            _ = self.model(x)

        return self.features.copy()

    def calculate_attention(self, features, layer_name='layer4.2.conv3', output_size=(224, 224)):
        """Activation-weighted attention map for one hooked layer.

        Channel weights are the spatial mean of each activation map (no
        gradients are involved); the weighted channel sum is passed through
        ReLU, min-max normalized per sample and resized bilinearly.

        Args:
            features: mapping of layer name to a [B, C, H, W] activation tensor.
            layer_name: key of the activation to use.
            output_size: (height, width) of the returned map; defaults to
                (224, 224).

        Returns:
            Tensor of shape [B, 1, output_size[0], output_size[1]].

        Raises:
            KeyError: if `layer_name` is not present in `features`.
        """
        feature_map = features[layer_name]

        # Channel importance weights via global average pooling: [B, C, 1, 1].
        weights = F.adaptive_avg_pool2d(feature_map, 1)

        # Weighted sum over channels in one vectorized reduction: [B, H, W].
        cam = (weights * feature_map).sum(dim=1)

        # Keep only positive evidence.
        cam = F.relu(cam)

        # Per-sample min-max normalization; the epsilon guards flat maps.
        flat = cam.flatten(1)
        cam_min = flat.min(dim=1)[0].view(-1, 1, 1)
        cam_max = flat.max(dim=1)[0].view(-1, 1, 1)
        cam = (cam - cam_min) / (cam_max - cam_min + 1e-8)

        # Resize to the requested spatial size.
        return F.interpolate(
            cam.unsqueeze(1),
            size=tuple(output_size),
            mode='bilinear',
            align_corners=False
        )

    def remove_hooks(self):
        """Clean up hooks"""
        for hook in self.hooks:
            hook.remove()
        self.hooks = []

In [None]:
class EnhancedFeatureTracker:
    """Record features, attention maps and predictions across successive steps.

    Wraps an ``EnhancedFeatureExtractor`` built around ``model``. When
    ``save_history`` is enabled, every ``track_step`` call appends one entry
    to each history list so individual steps can be visualised afterwards.
    """

    def __init__(self, model, save_history=True):
        """Create the feature extractor for ``model`` and empty histories.

        Args:
            model: Network handed to ``EnhancedFeatureExtractor`` and later
                called directly to obtain predictions.
            save_history: When False, ``track_step`` still returns its
                results but nothing is stored.
        """
        self.feature_extractor = EnhancedFeatureExtractor(model)
        self.save_history = save_history

        # One entry per tracked step; feature_changes has one entry fewer
        # because it compares each step with its predecessor.
        self.feature_history = []
        self.attention_history = []
        self.prediction_history = []
        self.feature_changes = []

    def track_step(self, image, step):
        """Collect features, attention and class probabilities for ``image``.

        Args:
            image: Input batch passed to the extractor and to the model.
            step: Step index supplied by the caller. It is not used here;
                history entries are stored in call order.

        Returns:
            Tuple ``(features, attention, predictions)`` where ``predictions``
            is the softmax over dim 1 of the model output.
        """
        features = self.feature_extractor.get_features(image)
        attention = self.feature_extractor.calculate_attention(features)

        # NOTE(review): this is a second forward pass on the raw ``image``;
        # ``get_features`` may rescale its input, so the two passes can see
        # different tensors -- confirm that is intended.
        with torch.no_grad():
            output = self.feature_extractor.model(image)
            predictions = F.softmax(output, dim=1)

        if self.save_history:
            self.feature_history.append(features)
            self.attention_history.append(attention)
            self.prediction_history.append(predictions)

            # A change is only defined once a previous step exists.
            if len(self.feature_history) > 1:
                changes = self.calculate_feature_changes(
                    self.feature_history[-2],
                    features,
                )
                self.feature_changes.append(changes)

        return features, attention, predictions

    def calculate_feature_changes(self, prev_features, curr_features):
        """Return the mean cosine distance per layer between two feature sets.

        Only layers listed in the extractor's ``target_layers`` and present in
        both mappings are compared; others are skipped.

        Args:
            prev_features: Mapping of layer name to activation tensor.
            curr_features: Mapping of layer name to activation tensor.

        Returns:
            Dict mapping layer name to ``mean(1 - cosine_similarity)`` as a
            Python float, averaged over the batch dimension.
        """
        changes = {}
        for layer_name in self.feature_extractor.target_layers:
            if layer_name not in prev_features or layer_name not in curr_features:
                continue

            prev = prev_features[layer_name]
            curr = curr_features[layer_name]

            # reshape (not view) so non-contiguous activations are flattened
            # instead of raising.
            prev_flat = F.normalize(prev.reshape(prev.size(0), -1), p=2, dim=1)
            curr_flat = F.normalize(curr.reshape(curr.size(0), -1), p=2, dim=1)

            similarity = F.cosine_similarity(prev_flat, curr_flat)
            changes[layer_name] = (1 - similarity).mean().item()

        return changes

    def visualize_step_detailed(self, step):
        """Plot feature maps, attention, top predictions and changes for a step.

        Only the first sample of the stored batch is shown.

        Args:
            step: Index into the recorded history; must be non-negative.

        Returns:
            The matplotlib figure that was created.

        Raises:
            ValueError: If ``step`` is negative or beyond the recorded history.
        """
        # Negative indices would silently pick entries from the end.
        if not 0 <= step < len(self.feature_history):
            raise ValueError(f"Step {step} not available")

        features = self.feature_history[step]
        attention = self.attention_history[step]
        predictions = self.prediction_history[step]

        # All target layers except the last get a feature-map panel.
        feature_layers = self.feature_extractor.target_layers[:-1]

        # Row 1 needs four columns; widen the grid if more layers are plotted.
        n_cols = max(4, len(feature_layers))
        fig = plt.figure(figsize=(20, 12))
        gs = plt.GridSpec(3, n_cols, figure=fig)

        # 1. Channel-averaged feature maps.
        for idx, layer_name in enumerate(feature_layers):
            ax = fig.add_subplot(gs[0, idx])
            feature_vis = torch.mean(features[layer_name], dim=1)[0].detach().cpu()

            im = ax.imshow(feature_vis.numpy(), cmap='viridis')
            ax.set_title(f'Layer {layer_name}\nFeatures')
            ax.axis('off')
            fig.colorbar(im, ax=ax)

        # 2. Attention map.
        ax = fig.add_subplot(gs[1, 0])
        attention_vis = attention[0, 0].detach().cpu().numpy()
        im = ax.imshow(attention_vis, cmap='hot')
        ax.set_title('Attention Map')
        ax.axis('off')
        fig.colorbar(im, ax=ax)

        # 3. Top predictions; clamp k so models with fewer than five classes
        # do not make topk raise.
        ax = fig.add_subplot(gs[1, 1])
        probs = predictions[0].detach().cpu()
        k = min(5, probs.numel())
        top_values, top_indices = probs.topk(k)
        ax.bar(range(k), top_values.numpy())
        ax.set_xticks(range(k))
        ax.set_xticklabels(
            [f"Class {idx}" for idx in top_indices.numpy()],
            rotation=45,
        )
        ax.set_title(f'Top {k} Predictions')

        # 4. Change relative to the previous step, when one was recorded.
        if 0 < step <= len(self.feature_changes):
            ax = fig.add_subplot(gs[1, 2])
            changes = self.feature_changes[step - 1]
            ax.bar(range(len(changes)), list(changes.values()))
            ax.set_xticks(range(len(changes)))
            ax.set_xticklabels(list(changes.keys()), rotation=45)
            ax.set_title('Feature Changes')

        # 5. Earliest-layer features modulated by the attention map.
        ax = fig.add_subplot(gs[1, 3])
        early_features = torch.mean(
            features[self.feature_extractor.target_layers[0]], dim=1
        )[0]
        # Follow the attention map's own size rather than assuming 224x224.
        early_features_resized = F.interpolate(
            early_features.unsqueeze(0).unsqueeze(0),
            size=tuple(attention_vis.shape),
            mode='bilinear',
            align_corners=False,
        )[0, 0].detach().cpu().numpy()

        # Min-max normalise so the product stays within [0, 1].
        lo = early_features_resized.min()
        hi = early_features_resized.max()
        early_features_resized = (early_features_resized - lo) / (hi - lo + 1e-8)

        combined = early_features_resized * attention_vis
        im = ax.imshow(combined, cmap='viridis')
        ax.set_title('Features + Attention')
        ax.axis('off')
        fig.colorbar(im, ax=ax)

        fig.tight_layout()
        return fig

    def create_progression_animation(self, save_path=None):
        """Placeholder for an animation of the tracked steps.

        Raises:
            NotImplementedError: Always; the feature is not implemented.
        """
        raise NotImplementedError("Animation functionality to be implemented")

    def cleanup(self):
        """Remove the extractor's hooks and drop any recorded history."""
        self.feature_extractor.remove_hooks()
        if self.save_history:
            self.feature_history = []
            self.attention_history = []
            self.prediction_history = []
            self.feature_changes = []

In [None]:
def pgd_attack_with_enhanced_tracking(model, classifier, image, target_class_idx,
                                    epsilon=0.1, alpha=0.02, num_iter=100):
    """Run a targeted PGD attack on ``classifier`` and track every step.

    Starting from ``image`` plus uniform noise in ``[-epsilon/2, epsilon/2]``,
    each iteration moves the image along the gradient that lowers the
    cross-entropy to ``target_class_idx`` (gradient scaled by its max-abs
    value), then projects back into the ``epsilon`` ball around ``image`` and
    into the ``[-1, 1]`` pixel range.

    Args:
        model: Unused; kept so existing callers keep working.
        classifier: Network being attacked; also handed to the tracker.
        image: Clean input batch, expected to lie in ``[-1, 1]``.
        target_class_idx: Class index the attack pushes every sample towards.
        epsilon: Maximum absolute per-element perturbation.
        alpha: Step size applied to the normalised gradient.
        num_iter: Number of attack iterations.

    Returns:
        Tuple ``(perturbed_image, tracker, loss_history)``. The tracker holds
        ``num_iter + 1`` recorded steps (the initial state plus one per
        iteration) and still has its hooks attached; the caller is
        responsible for calling ``tracker.cleanup()``.
    """
    tracker = EnhancedFeatureTracker(classifier)

    try:
        # Random start inside half the epsilon ball, kept in the valid range.
        perturbed_image = image.clone().detach()
        perturbed_image = perturbed_image + torch.empty_like(image).uniform_(
            -epsilon / 2, epsilon / 2
        )
        perturbed_image = torch.clamp(perturbed_image, -1, 1)

        # Step 0 is the randomly initialised starting point.
        tracker.track_step(perturbed_image, 0)

        # Build the target on the image's own device and batch size rather
        # than relying on a global ``device`` and a batch of one.
        target = torch.full(
            (image.shape[0],),
            int(target_class_idx),
            dtype=torch.long,
            device=image.device,
        )
        loss_history = []

        for i in range(num_iter):
            perturbed_image.requires_grad_(True)

            output = classifier(perturbed_image)
            # Negated so that ascending this loss descends the target CE.
            loss = -F.cross_entropy(output, target)
            loss_history.append(loss.item())

            # Differentiate w.r.t. the input only, so nothing accumulates in
            # the classifier parameters' ``.grad``.
            (grad,) = torch.autograd.grad(loss, perturbed_image)
            grad_norm = torch.norm(grad, p=float('inf'))
            normalized_grad = grad / (grad_norm + 1e-10)

            # Step, project onto the epsilon ball, clamp to the pixel range.
            adv_image = perturbed_image + alpha * normalized_grad
            eta = torch.clamp(adv_image - image, min=-epsilon, max=epsilon)
            perturbed_image = torch.clamp(image + eta, min=-1, max=1).detach()

            # Track every step so history index i + 1 matches iteration i.
            tracker.track_step(perturbed_image, i + 1)

            # Progress report every 10 iterations (first sample only).
            if (i + 1) % 10 == 0:
                with torch.no_grad():
                    current_output = classifier(perturbed_image)
                    current_prob = F.softmax(current_output, dim=1)[0][target_class_idx].item()
                print(f"Step {i+1}/{num_iter}: Target class probability = {current_prob:.4f}")
    except BaseException:
        # The caller never receives the tracker on failure, so detach its
        # hooks from the classifier here before propagating the error.
        tracker.cleanup()
        raise

    return perturbed_image, tracker, loss_history

# Usage example
def run_attack_analysis():
    """Attack the first synthetic sample and visualise the outcome.

    Reads the notebook-level ``synthetic_samples``, ``model``, ``classifier``
    and ``device``. Runs the tracked PGD attack towards class 404, shows the
    detailed view for a handful of steps, plots the loss curve and a
    clean / adversarial / perturbation comparison. Any exception is printed
    with its traceback instead of being propagated, and the tracker's hooks
    are removed whether or not the visualisation succeeds.
    """
    tracker = None
    try:
        original = synthetic_samples[0].unsqueeze(0).to(device)
        target_class_idx = 404  # airliner

        print("\nStarting adversarial attack with enhanced tracking...")
        adv_image, tracker, loss_history = pgd_attack_with_enhanced_tracking(
            model=model,
            classifier=classifier,
            image=original,
            target_class_idx=target_class_idx,
            epsilon=0.1,
            alpha=0.02,
            num_iter=100,
        )

        # Show fewer steps to avoid memory issues
        steps_to_show = [0, 20, 40, 60, 80]  # Adjusted step intervals
        for step in steps_to_show:
            print(f"\nVisualizing step {step}...")
            fig = tracker.visualize_step_detailed(step)
            plt.show()
            plt.close(fig)

        # Plot loss history
        plt.figure(figsize=(10, 5))
        plt.plot(loss_history)
        plt.title('Attack Loss History')
        plt.xlabel('Iteration')
        plt.ylabel('Loss')
        plt.grid(True)
        plt.show()

        # Final side-by-side comparison.
        # NOTE(review): the attack clamps images to [-1, 1], while imshow
        # clips float RGB data to [0, 1]; if the samples really are in
        # [-1, 1] they should be rescaled for display -- confirm the range.
        plt.figure(figsize=(15, 5))

        plt.subplot(131)
        plt.imshow(original[0].cpu().permute(1, 2, 0).numpy())
        plt.title('Original Image')
        plt.axis('off')

        plt.subplot(132)
        plt.imshow(adv_image[0].cpu().permute(1, 2, 0).numpy())
        plt.title('Adversarial Image')
        plt.axis('off')

        plt.subplot(133)
        perturbation = (adv_image - original)[0].cpu().permute(1, 2, 0).numpy()
        # Amplified so a perturbation bounded by epsilon is visible.
        plt.imshow(np.abs(perturbation) * 5)
        plt.title('Perturbation (x5)')
        plt.colorbar()
        plt.axis('off')

        plt.tight_layout()
        plt.show()

    except Exception as e:
        print(f"Error during attack analysis: {e}")
        import traceback
        traceback.print_exc()

    finally:
        # Always detach the hooks, even when plotting failed part-way, so
        # they do not stay registered on the classifier.
        if tracker is not None:
            tracker.cleanup()

# Run the analysis only when this code is the entry point (__name__ is
# "__main__"), not when it is imported as a module.
if __name__ == "__main__":
    run_attack_analysis()