## Imports

In [None]:
from IPython.display import clear_output

In [None]:
!pip install torchattacks --quiet
!pip install timm --quiet
!pip install wandb --quiet
!pip install scikit-learn seaborn --quiet
clear_output()

In [None]:
import os
import sys
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from tqdm import tqdm
from torch.utils.data import DataLoader
import torchattacks
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pickle
import wandb
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score, confusion_matrix, accuracy_score
import seaborn as sns

In [None]:
# Colab exposes the `google.colab` package in sys.modules; use that to pick
# between Google Drive paths and local relative paths.
IN_COLAB = 'google.colab' in sys.modules

# Filled in by whichever branch below runs and validated at the end of the cell.
dataset_path = None
PREPROCESSED_DATA_PATH = None
ROOT_MODEL_DIR = None

if IN_COLAB:
    print("Running in Google Colab environment.")
    try:
        from google.colab import drive
        drive.mount('/content/drive')
        print("Google Drive mounted successfully.")

        BASE_DRIVE_PATH = '/content/drive/MyDrive'
        CV_PROJECT_FOLDER_COLAB = os.path.join(BASE_DRIVE_PATH, 'CV')
        dataset_path = os.path.join(CV_PROJECT_FOLDER_COLAB, 'Dataset')
        PREPROCESSED_DATA_PATH = os.path.join(CV_PROJECT_FOLDER_COLAB, 'Dataset_pickle')
        ROOT_MODEL_DIR = os.path.join(CV_PROJECT_FOLDER_COLAB, 'models')

        os.makedirs(dataset_path, exist_ok=True)
        os.makedirs(PREPROCESSED_DATA_PATH, exist_ok=True)
        # Mirror the local branch: the model directory must exist as well,
        # otherwise writing a checkpoint into it would fail on a fresh Drive.
        os.makedirs(ROOT_MODEL_DIR, exist_ok=True)

        print(f"Colab raw data path set to: {dataset_path}")
        print(f"Colab preprocessed data (.pkl) path set to: {PREPROCESSED_DATA_PATH}")
        print(f"Colab model save/load (.pth) path set to: {ROOT_MODEL_DIR}")

        print("Attempting W&B login for Colab...")
        wandb.login()

    except ImportError:
        print("Error: google.colab module not found, but IN_COLAB was true. This is unexpected.")
        sys.exit("Colab environment detection mismatch.")
    except RuntimeError as e:
        print(f"Error during Colab drive mount: {e}")
        sys.exit("Colab environment detected, but drive mount failed.")
    except Exception as e:
        print(f"An unexpected error occurred during Colab setup: {e}")
        sys.exit("Colab setup failed.")

else:
    print("Not in Colab, setting up for local execution.")
    ROOT_DIR_LOCAL = "./"

    dataset_path = os.path.join(ROOT_DIR_LOCAL, "data", "Dataset_Raw")
    PREPROCESSED_DATA_PATH = os.path.join(ROOT_DIR_LOCAL, "data", "Dataset_Preprocessed")
    ROOT_MODEL_DIR = os.path.join(ROOT_DIR_LOCAL, "models")

    os.makedirs(dataset_path, exist_ok=True)
    os.makedirs(PREPROCESSED_DATA_PATH, exist_ok=True)
    os.makedirs(ROOT_MODEL_DIR, exist_ok=True)

    print(f"Local raw data path set to: {dataset_path}")
    print(f"Local preprocessed data (.pkl) path set to: {PREPROCESSED_DATA_PATH}")
    print(f"Local model save/load (.pth) path set to: {ROOT_MODEL_DIR}")

    print("Not in Colab, wandb.login() skipped. Ensure you are logged in via CLI (e.g., 'wandb login').")

# Prefer the GPU when one is visible to PyTorch.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Safety net: both branches above assign all three paths, so this only fires
# if the setup logic is edited and one assignment is lost.
if dataset_path is None or PREPROCESSED_DATA_PATH is None or ROOT_MODEL_DIR is None:
    sys.exit("Critical error: Root directory paths were not set. Please check the setup logic.")

## Globals

In [None]:
# Architecture selector passed to init_train_model as `model_name_str`.
model_name_flag = 'efficientnet' # one of {'efficientnet', 'efficientnet_pim', 'efficientnet_freq', 'efficientnet_adv'}
# Training hyper-parameters; they are also embedded in the W&B run name/config.
epochs = 5
learning_rate = 0.001
criterion = nn.CrossEntropyLoss()
# Optimizer selector passed to init_train_model ('adamw', 'adam' or 'sgd').
optimizer_name_flag = 'adamw'

# PIM settings: R_PIM and ALPHA_PIM are the defaults for the r/alpha
# hyper-parameters of init_train_model; SHALLOW_FEATURE_IDX is the default
# index at which EfficientNetB0WithPIM splits the backbone's feature stages.
R_PIM = 0.1
ALPHA_PIM = 1.0
SHALLOW_FEATURE_IDX = 2

# No DataLoader worker processes on Windows (os.name == 'nt'), two elsewhere.
num_workers = 0 if os.name == 'nt' else 2

# Weights & Biases project/entity used by wandb.init.
WANDB_PROJECT_NAME = "CV_Project_Adversarial_DeepFake"
WANDB_ENTITY = "olmoceriotti"

In [None]:
# Roots of the real and fake image trees; the dataset classes append the split
# name ('train', 'validation', 'test') to each of these.
true_dataset_path = os.path.join(dataset_path, 'true_dataset')
fake_dataset_path = os.path.join(dataset_path, 'fake_dataset_one_source')

## Utils

In [None]:
def init_train_model(model_name_str, opt_name_str, lr,
                     r_pim_hp=R_PIM, alpha_pim_hp=ALPHA_PIM, shallow_idx_hp=SHALLOW_FEATURE_IDX,
                     adv_drop_prob_hp=0.1, adv_block_size_hp=7, adv_classifier_dropout_hp=0.3,
                     smooth_sigma_hp=0.25, smooth_num_samples_hp=5):
    """Build the requested model together with an optimizer over its parameters.

    Args:
        model_name_str: One of 'efficientnet', 'efficientnet_pim',
            'efficientnet_freq' or 'efficientnet_adv'.
        opt_name_str: 'adamw', 'adam' or 'sgd'. Any other value falls back to
            AdamW, now with a printed warning instead of silently.
        lr: Learning rate handed to the optimizer.
        r_pim_hp, alpha_pim_hp: Accepted for signature compatibility; not used
            when building the model.
        shallow_idx_hp: Shallow/deep split index for the PIM model.
        adv_drop_prob_hp, adv_block_size_hp, adv_classifier_dropout_hp:
            DropBlock / classifier-dropout settings for 'efficientnet_adv'.
        smooth_sigma_hp, smooth_num_samples_hp: Accepted for signature
            compatibility; not used here.

    Returns:
        Tuple ``(model, optimizer)``.

    Raises:
        ValueError: If ``model_name_str`` is not a supported model name.
    """
    if model_name_str == 'efficientnet':
        weights = models.EfficientNet_B0_Weights.IMAGENET1K_V1
        model = models.efficientnet_b0(weights=weights)
        num_ftrs = model.classifier[1].in_features
        # Swap the 1000-class ImageNet head for a 2-class head.
        model.classifier = nn.Sequential(
            nn.Dropout(p=0.2, inplace=True),
            nn.Linear(num_ftrs, 2)
        )
    elif model_name_str == 'efficientnet_pim':
        model = EfficientNetB0WithPIM(pretrained=True, num_classes=2, shallow_feature_idx=shallow_idx_hp)
    elif model_name_str == 'efficientnet_freq':
        model = get_efficientnet_freq()
    elif model_name_str == 'efficientnet_adv':
        model = EfficientNetB0WithSpatialDropBlock(
            num_classes=2,
            drop_prob=adv_drop_prob_hp,
            block_size=adv_block_size_hp,
            classifier_dropout=adv_classifier_dropout_hp
        )
    else:
        raise ValueError(f"Unsupported model name: {model_name_str}")

    # Optimizer factories keyed by name; each receives the parameter iterable.
    optimizer_factories = {
        'adamw': lambda params: optim.AdamW(params, lr=lr),
        'adam': lambda params: optim.Adam(params, lr=lr),
        'sgd': lambda params: optim.SGD(params, lr=lr, momentum=0.9),
    }
    make_optimizer = optimizer_factories.get(opt_name_str)
    if make_optimizer is None:
        # Keep the historical AdamW fallback, but make a mistyped name visible.
        print(f"Warning: unknown optimizer '{opt_name_str}', falling back to AdamW.")
        make_optimizer = optimizer_factories['adamw']

    optimizer = make_optimizer(model.parameters())
    return model, optimizer

def train_one_epoch(model, dataloader, optimizer, criterion, device):
    """Run one standard training epoch.

    Args:
        model: Network to optimise; switched to training mode here.
        dataloader: Iterable yielding ``(images, labels)`` batches.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss callable applied to ``(outputs, labels)``.
        device: Device the batches are moved to.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen, or ``(0.0, 0.0)`` if
        the dataloader produced no samples.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    for images, labels in tqdm(dataloader, desc="Training (Standard)"):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = model(images)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

        # Weight the batch loss by the batch size so the epoch mean is per sample.
        loss_sum += batch_loss.item() * images.size(0)
        predicted = torch.max(logits, 1).indices
        n_correct += (predicted == labels).sum().item()
        n_seen += labels.size(0)

    if n_seen == 0:
        return 0.0, 0.0
    return loss_sum / n_seen, n_correct / n_seen

def evaluate(model, dataloader, criterion, device, return_preds_labels=False):
    """Evaluate ``model`` on ``dataloader`` without tracking gradients.

    Args:
        model: Network to evaluate; switched to eval mode here.
        dataloader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable applied to ``(outputs, labels)``.
        device: Device the batches are moved to.
        return_preds_labels: When true, also return the concatenated labels,
            softmax probabilities and predicted classes (all on CPU).

    Returns:
        ``(mean_loss, accuracy)``, or
        ``(mean_loss, accuracy, labels, probabilities, predictions)`` when
        ``return_preds_labels`` is set. With no samples, the scalars are 0.0
        and the tensors are empty.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    label_chunks, prob_chunks, pred_chunks = [], [], []

    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc="Evaluating"):
            images = images.to(device)
            labels = labels.to(device)
            logits = model(images)
            batch_loss = criterion(logits, labels)

            loss_sum += batch_loss.item() * images.size(0)
            probs = torch.softmax(logits, dim=1)
            predicted = torch.max(logits, 1).indices

            n_correct += (predicted == labels).sum().item()
            n_seen += labels.size(0)

            if return_preds_labels:
                label_chunks.append(labels.cpu())
                prob_chunks.append(probs.cpu())
                pred_chunks.append(predicted.cpu())

    if n_seen == 0:
        if not return_preds_labels:
            return 0.0, 0.0
        return 0.0, 0.0, torch.empty(0), torch.empty(0), torch.empty(0)

    mean_loss = loss_sum / n_seen
    accuracy = n_correct / n_seen
    if not return_preds_labels:
        return mean_loss, accuracy

    return (
        mean_loss,
        accuracy,
        torch.cat(label_chunks),
        torch.cat(prob_chunks),
        torch.cat(pred_chunks),
    )


def train_one_epoch_with_pim(model, dataloader, optimizer, criterion, device, alpha_pim_val, r_pim_val):
    """Train for one epoch using two forward/backward passes per batch.

    Pass 1 is a plain forward/backward on the batch. The gradients of that
    loss with respect to ``model.mu_shallow`` / ``model.sigma_shallow`` are
    averaged over dim 0, jointly L2-normalised and scaled by ``r_pim_val`` to
    form ``delta_mu`` / ``delta_sigma``. Pass 2 calls the model again with
    those deltas. The parameter gradients of both passes are blended as
    ``(1 - alpha_pim_val) * g1 + alpha_pim_val * g2`` before the optimizer step.

    The model must expose ``is_first_pass``, ``mu_shallow`` and
    ``sigma_shallow`` and accept ``delta_mu_for_pim`` / ``delta_sigma_for_pim``
    keyword arguments (as EfficientNetB0WithPIM does).

    Args:
        model: PIM-capable network; switched to training mode here.
        dataloader: Iterable yielding ``(images, labels)`` batches.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss callable applied to ``(outputs, labels)``.
        device: Device the batches are moved to.
        alpha_pim_val: Blend weight of the second-pass gradients and loss.
        r_pim_val: Norm of the combined (delta_mu, delta_sigma) perturbation.

    Returns:
        ``(mean_blended_loss, first_pass_accuracy)`` over all samples, or
        ``(0.0, 0.0)`` if the dataloader produced no samples.
    """
    model.train()
    running_loss_objective = 0.0
    correct_empirical = 0
    total_samples = 0

    for images, labels in tqdm(dataloader, desc="Training with PIM"):
        images, labels = images.to(device), labels.to(device)
        batch_size = images.size(0)
        total_samples += batch_size

        optimizer.zero_grad()

        # Pass 1: unperturbed forward; the model is expected to store the
        # shallow-feature statistics on itself during this call.
        model.is_first_pass = True

        outputs_empirical = model(images)

        # Without the statistics no perturbation can be built: fall back to a
        # single ordinary optimisation step for this batch.
        if model.mu_shallow is None or model.sigma_shallow is None:
            print("Warning: mu_shallow or sigma_shallow not populated in first pass. Skipping PIM for this batch.")
            loss_empirical_fallback = criterion(outputs_empirical, labels)
            loss_empirical_fallback.backward()
            optimizer.step()
            running_loss_objective += loss_empirical_fallback.item() * batch_size
            _, predicted_emp_fb = torch.max(outputs_empirical.data, 1)
            correct_empirical += (predicted_emp_fb == labels).sum().item()
            continue

        # mu/sigma are non-leaf tensors, so their .grad is only kept on request.
        model.mu_shallow.retain_grad()
        model.sigma_shallow.retain_grad()

        loss_empirical = criterion(outputs_empirical, labels)
        loss_empirical.backward(retain_graph=True)

        # Snapshot the pass-1 parameter gradients (g1) before they are zeroed.
        g1_theta_grads = {name: param.grad.clone() for name, param in model.named_parameters() if param.requires_grad and param.grad is not None}

        grad_mu_s_batch_specific = model.mu_shallow.grad if model.mu_shallow.grad is not None else torch.zeros_like(model.mu_shallow)
        grad_sigma_s_batch_specific = model.sigma_shallow.grad if model.sigma_shallow.grad is not None else torch.zeros_like(model.sigma_shallow)

        # Average over dim 0 (presumably the batch dimension) so that a single
        # perturbation direction is shared by every sample in the batch.
        grad_mu_s_batch_level = grad_mu_s_batch_specific.mean(dim=0, keepdim=True)
        grad_sigma_s_batch_level = grad_sigma_s_batch_specific.mean(dim=0, keepdim=True)

        # Normalise mu and sigma gradients jointly; 1e-7 guards against a zero norm.
        stacked_grads = torch.cat((grad_mu_s_batch_level.flatten(), grad_sigma_s_batch_level.flatten()))
        norm_val = torch.norm(stacked_grads, p=2) + 1e-7

        delta_mu_batch = r_pim_val * grad_mu_s_batch_level / norm_val
        delta_sigma_batch = r_pim_val * grad_sigma_s_batch_level / norm_val

        # The deltas are treated as constants in the second pass.
        delta_mu_batch = delta_mu_batch.detach()
        delta_sigma_batch = delta_sigma_batch.detach()

        optimizer.zero_grad()

        # Pass 2: forward with the perturbation applied to the shallow statistics.
        model.is_first_pass = False
        outputs_regularization = model(images, delta_mu_for_pim=delta_mu_batch, delta_sigma_for_pim=delta_sigma_batch)
        loss_regularization = criterion(outputs_regularization, labels)

        loss_regularization.backward()
        g2_theta_grads = {name: param.grad.clone() for name, param in model.named_parameters() if param.requires_grad and param.grad is not None}

        # Blend the two gradient sets; a parameter missing from either pass
        # contributes a zero gradient for that pass.
        for name, param in model.named_parameters():
            if param.requires_grad:
                grad1 = g1_theta_grads.get(name, torch.zeros_like(param.data))
                grad2 = g2_theta_grads.get(name, torch.zeros_like(param.data))

                param.grad = (1 - alpha_pim_val) * grad1 + alpha_pim_val * grad2

        optimizer.step()

        # The reported loss uses the same blend as the gradients.
        current_loss_objective = (1 - alpha_pim_val) * loss_empirical.item() + alpha_pim_val * loss_regularization.item()
        running_loss_objective += current_loss_objective * batch_size

        # Accuracy is measured on the unperturbed (first-pass) predictions.
        _, predicted_emp = torch.max(outputs_empirical.data, 1)
        correct_empirical += (predicted_emp == labels).sum().item()

    if total_samples == 0: return 0.0, 0.0
    avg_loss_objective = running_loss_objective / total_samples
    avg_acc_empirical = correct_empirical / total_samples
    return avg_loss_objective, avg_acc_empirical

def evaluate_on_adversarial(model_to_eval, test_dataloader, loss_criterion, current_device,
                            attack_instance, attack_display_name, return_details=False):
    """Evaluate a model on adversarial examples generated batch by batch.

    Each batch is attacked with ``attack_instance(images, labels)``. If that
    call raises, ``attack_instance(images)`` is tried. If both fail, the batch
    is skipped and both errors are printed; previously only the fallback error
    was shown, which hid the real cause.

    Args:
        model_to_eval: Network under attack; switched to eval mode here.
        test_dataloader: Iterable yielding ``(images, labels)`` batches.
        loss_criterion: Loss callable applied to ``(outputs, labels)``.
        current_device: Device the batches are moved to.
        attack_instance: Callable producing adversarial images.
        attack_display_name: Name used in the progress bar and messages.
        return_details: When true, also return labels, probabilities,
            predictions and the mean per-sample L0 / L2 / Linf perturbation.

    Returns:
        ``(accuracy, mean_loss)``, or, with ``return_details``,
        ``(accuracy, mean_loss, labels, probs, preds, avg_l0, avg_l2, avg_linf)``.
        If no batch could be processed, the scalars are 0.0 and the tensors
        are empty.
    """
    model_to_eval.eval()
    adv_correct_count = 0
    adv_total_samples = 0
    adv_cumulative_loss = 0.0
    progress_bar_desc = f"Attacking with {attack_display_name}"

    all_orig_labels_list = []
    all_adv_preds_probs_list = []
    all_adv_preds_labels_list = []
    total_perturbation_l0 = 0.0
    total_perturbation_l2 = 0.0
    total_perturbation_linf = 0.0
    perturbed_sample_count = 0

    for images_batch, labels_batch in tqdm(test_dataloader, desc=progress_bar_desc):
        images_batch, labels_batch = images_batch.to(current_device), labels_batch.to(current_device)
        # The attack runs outside no_grad because it may need input gradients.
        try:
            adv_images_batch = attack_instance(images_batch, labels_batch)
        except Exception as primary_err:
            # Retry with the images-only call form before giving up on the batch.
            try:
                adv_images_batch = attack_instance(images_batch)
            except Exception as fallback_err:
                print(f"Fallback failed for {attack_display_name}. Error: {fallback_err}. "
                      f"Original error: {primary_err}. Skipping batch.")
                continue

        with torch.no_grad():
            outputs_adv = model_to_eval(adv_images_batch)
            loss_adv = loss_criterion(outputs_adv, labels_batch)
            probs_adv = torch.softmax(outputs_adv, dim=1)
            _, predicted_adv = torch.max(outputs_adv, 1)

            batch_len = images_batch.size(0)
            adv_cumulative_loss += loss_adv.item() * batch_len
            adv_correct_count += (predicted_adv == labels_batch).sum().item()
            adv_total_samples += labels_batch.size(0)

            if return_details:
                all_orig_labels_list.append(labels_batch.cpu())
                all_adv_preds_probs_list.append(probs_adv.cpu())
                all_adv_preds_labels_list.append(predicted_adv.cpu())

                # One flattened perturbation row per sample, shared by all norms.
                flat_perturbation = (adv_images_batch - images_batch).view(batch_len, -1)
                # L0 counts elements changed by more than 1e-6.
                total_perturbation_l0 += flat_perturbation.abs().gt(1e-6).sum(dim=1).float().sum()
                total_perturbation_l2 += torch.norm(flat_perturbation, p=2, dim=1).sum()
                total_perturbation_linf += torch.norm(flat_perturbation, p=float('inf'), dim=1).sum()
                perturbed_sample_count += batch_len

    if adv_total_samples == 0:
        print(f"Warning: No samples were processed for attack {attack_display_name}.")
        if return_details:
            return 0.0, 0.0, torch.empty(0), torch.empty(0), torch.empty(0), 0.0, 0.0, 0.0
        return 0.0, 0.0

    final_adv_accuracy = adv_correct_count / adv_total_samples
    final_adv_avg_loss = adv_cumulative_loss / adv_total_samples

    if return_details:
        all_orig_labels = torch.cat(all_orig_labels_list) if all_orig_labels_list else torch.empty(0)
        all_adv_preds_probs = torch.cat(all_adv_preds_probs_list) if all_adv_preds_probs_list else torch.empty(0)
        all_adv_preds_labels = torch.cat(all_adv_preds_labels_list) if all_adv_preds_labels_list else torch.empty(0)
        # The running totals are tensors once any batch was accumulated, hence .item().
        avg_perturb_l0 = (total_perturbation_l0 / perturbed_sample_count).item() if perturbed_sample_count > 0 else 0.0
        avg_perturb_l2 = (total_perturbation_l2 / perturbed_sample_count).item() if perturbed_sample_count > 0 else 0.0
        avg_perturb_linf = (total_perturbation_linf / perturbed_sample_count).item() if perturbed_sample_count > 0 else 0.0
        return final_adv_accuracy, final_adv_avg_loss, all_orig_labels, all_adv_preds_probs, all_adv_preds_labels, avg_perturb_l0, avg_perturb_l2, avg_perturb_linf
    return final_adv_accuracy, final_adv_avg_loss

def log_detailed_metrics(y_true_tensor, y_pred_probs_tensor, y_pred_labels_tensor, section_prefix="test/clean", log_to_wandb=True):
    """Compute, print and optionally log binary-classification metrics.

    Metrics: accuracy, precision/recall/F1 (``average='binary'``), ROC AUC
    (only when both classes are present and the probabilities have two
    columns), the mean probability of the predicted class, and a confusion
    matrix rendered as a heatmap.

    Args:
        y_true_tensor: Ground-truth class indices.
        y_pred_probs_tensor: Per-class probabilities, one row per sample.
        y_pred_labels_tensor: Predicted class indices.
        section_prefix: Prefix for every metric key and for the plot title.
        log_to_wandb: When true, send metrics and the confusion-matrix image
            to the active W&B run.

    Returns:
        None.
    """
    if y_true_tensor.numel() == 0 or y_pred_labels_tensor.numel() == 0:
        print(f"Skipping detailed metrics for {section_prefix} due to empty tensors.")
        if log_to_wandb:
            wandb.log({
                f"{section_prefix}_accuracy": "N/A", f"{section_prefix}_precision": "N/A",
                f"{section_prefix}_recall": "N/A", f"{section_prefix}_f1_score": "N/A",
                f"{section_prefix}_auc_roc": "N/A", f"{section_prefix}_avg_confidence": "N/A"
            })
        return

    y_true_cpu = y_true_tensor.cpu().numpy()
    y_pred_labels_cpu = y_pred_labels_tensor.cpu().numpy()
    y_pred_probs_cpu = y_pred_probs_tensor.cpu().numpy()

    accuracy = accuracy_score(y_true_cpu, y_pred_labels_cpu)
    precision, recall, f1, _ = precision_recall_fscore_support(y_true_cpu, y_pred_labels_cpu, average='binary', zero_division=0)

    auc_roc = "N/A"
    probs_are_two_column = y_pred_probs_cpu.ndim == 2 and y_pred_probs_cpu.shape[1] == 2
    if not probs_are_two_column:
        print(f"AUC calculation skipped for {section_prefix}: y_pred_probs_cpu has unexpected shape {y_pred_probs_cpu.shape}.")
    elif len(np.unique(y_true_cpu)) > 1:
        # AUC is undefined with a single class, hence the uniqueness check.
        try:
            auc_roc = roc_auc_score(y_true_cpu, y_pred_probs_cpu[:, 1])
        except ValueError as e:
            print(f"Could not compute AUC for {section_prefix}: {e}")
            auc_roc = "N/A"

    # Fix the label set so the matrix is always 2x2, even when a class is
    # absent from both the targets and the predictions.
    cm = confusion_matrix(y_true_cpu, y_pred_labels_cpu, labels=[0, 1])

    avg_confidence = 0.0
    if y_pred_probs_cpu.size > 0:
        # Probability assigned to the predicted class, averaged over samples.
        # Converted to a Python float so it gets the 4-decimal print format
        # below (a NumPy float32 scalar is not an instance of float).
        avg_confidence = float(
            y_pred_probs_cpu[np.arange(len(y_pred_labels_cpu)), y_pred_labels_cpu].mean()
        )

    metrics_log = {
        f"{section_prefix}_accuracy": accuracy,
        f"{section_prefix}_precision": precision,
        f"{section_prefix}_recall": recall,
        f"{section_prefix}_f1_score": f1,
        f"{section_prefix}_avg_confidence": avg_confidence,
    }
    if auc_roc != "N/A":
        metrics_log[f"{section_prefix}_auc_roc"] = auc_roc

    print(f"\nDetailed metrics for {section_prefix}:")
    for k, v in metrics_log.items():
        if isinstance(v, float): print(f"  {k}: {v:.4f}")
        else: print(f"  {k}: {v}")
    print(f"  Confusion Matrix:\n{cm}")

    fig_cm, ax_cm = plt.subplots()
    try:
        sns.heatmap(cm, annot=True, fmt='d', ax=ax_cm, cmap='Blues', cbar=False)
        ax_cm.set_xlabel('Predicted labels')
        ax_cm.set_ylabel('True labels')
        ax_cm.set_title(f'Confusion Matrix - {section_prefix}')
        plt.tight_layout()

        if log_to_wandb:
            wandb.log(metrics_log)
            wandb.log({f"{section_prefix}_confusion_matrix": wandb.Image(fig_cm)})
    finally:
        # Always release the figure, even if plotting or logging raises.
        plt.close(fig_cm)

## Data

In [None]:
class DFFDDataset(Dataset):
    """Lazily loaded real-vs-fake image dataset.

    Images under ``<true_dir>/<split>`` get label 0 and images under
    ``<fake_dir>/<split>`` get label 1. Only paths are stored at construction
    time; each image is opened and transformed on access.

    Args:
        true_dir: Root directory of the real images.
        fake_dir: Root directory of the fake images.
        split: Sub-directory name to read below both roots.
        transform: Optional callable applied to each loaded RGB image.
    """

    # File extensions (compared case-insensitively) treated as images.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, true_dir, fake_dir, split='train', transform=None):
        self.image_paths = []
        self.labels = []
        self.transform = transform

        self._index_directory(os.path.join(true_dir, split), label=0)
        self._index_directory(os.path.join(fake_dir, split), label=1)

    def _index_directory(self, directory, label):
        """Record every image file in ``directory`` with the given label."""
        if not os.path.exists(directory):
            print(f"Warning: Path does not exist {directory}")
            return
        # os.listdir order is arbitrary; sort so indices are reproducible.
        for fname in sorted(os.listdir(directory)):
            if fname.lower().endswith(self.IMAGE_EXTENSIONS):
                self.image_paths.append(os.path.join(directory, fname))
                self.labels.append(label)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        Raises:
            RuntimeError: If the image cannot be opened or transformed. The
                original error is chained so it is visible in the traceback,
                including when raised inside a DataLoader worker.
        """
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        try:
            # The context manager closes the underlying file handle promptly.
            with Image.open(img_path) as raw_image:
                image = raw_image.convert('RGB')
            if self.transform:
                image = self.transform(image)
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            raise RuntimeError(f"Error loading image {img_path}: {e}") from e

        return image, label

def load_dataset(filepath):
    """Unpickle and return the object stored at ``filepath``.

    NOTE: unpickling can execute arbitrary code, so only load files that this
    project produced itself.
    """
    with open(filepath, 'rb') as handle:
        restored = pickle.load(handle)
    print(f"Dataset loaded from {filepath}")
    return restored

In [None]:
class InMemoryDFFDDataset(Dataset):
    """Real-vs-fake image dataset loaded and transformed fully up front.

    Images under ``<true_dir>/<split>`` get label 0 and images under
    ``<fake_dir>/<split>`` get label 1. Every image is opened, converted to
    RGB and transformed in the constructor. Files that fail to load are
    reported and skipped.

    Args:
        true_dir: Root directory of the real images.
        fake_dir: Root directory of the fake images.
        split: Sub-directory name to read below both roots.
        transform: Optional callable applied to each loaded RGB image.
    """

    def __init__(self, true_dir, fake_dir, split='train', transform=None):
        self.processed_images = []
        self.labels = []
        self.transform = transform

        true_split_dir = os.path.join(true_dir, split)
        fake_split_dir = os.path.join(fake_dir, split)

        print(f"Loading images from: {true_split_dir} (True) and {fake_split_dir} (Fake)")

        self._load_directory(true_split_dir, label=0)
        self._load_directory(fake_split_dir, label=1)

        if not self.processed_images:
            print("Warning: No images were loaded. Check paths and image files.")
        else:
            print(f"Successfully loaded and processed {len(self.processed_images)} images into memory.")

    def _load_directory(self, directory, label):
        """Load every image file in ``directory`` and store it with ``label``."""
        if not os.path.exists(directory):
            print(f"Warning: Path does not exist {directory}")
            return

        # os.listdir order is arbitrary; sort so sample indices are reproducible.
        for fname in sorted(os.listdir(directory)):
            if not fname.lower().endswith(('.png', '.jpg', '.jpeg')):
                continue
            img_path = os.path.join(directory, fname)
            try:
                # Close each file as soon as it is decoded; otherwise handles
                # stay open until garbage collection while many files are read.
                with Image.open(img_path) as raw_image:
                    image = raw_image.convert('RGB')
                if self.transform:
                    image = self.transform(image)
            except Exception as e:
                print(f"Error loading or transforming image {img_path}: {e}")
                continue
            self.processed_images.append(image)
            self.labels.append(label)

    def __len__(self):
        return len(self.processed_images)

    def __getitem__(self, idx):
        """Return the stored ``(image, label)`` pair for sample ``idx``."""
        return self.processed_images[idx], self.labels[idx]

    def save(self, filepath):
        """Pickle the whole dataset object (images, labels, transform) to ``filepath``."""
        with open(filepath, 'wb') as f:
            pickle.dump(self, f)
        print(f"Dataset saved to {filepath}")

In [None]:
# Preprocessing applied to every image: resize to 224x224, convert to a
# tensor, then normalise each channel with a fixed mean/std (presumably the
# ImageNet statistics matching the IMAGENET1K_V1 weights used by the models).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Try the pickled datasets first; fall back to building them from the raw images.
LOAD_PREPROCESSED_DATASETS = True

if LOAD_PREPROCESSED_DATASETS:
    try:
        train_dataset = load_dataset(os.path.join(PREPROCESSED_DATA_PATH, 'train_dataset_ffhq_stylegan.pkl'))
        val_dataset   = load_dataset(os.path.join(PREPROCESSED_DATA_PATH, 'val_dataset_ffhq_stylegan.pkl'))
        test_dataset  = load_dataset(os.path.join(PREPROCESSED_DATA_PATH, 'test_dataset_ffhq_stylegan.pkl'))
        print("Successfully loaded preprocessed datasets.")
    except FileNotFoundError:
        # Any missing pickle triggers a full rebuild of all three splits below.
        print("Preprocessed dataset files not found. Processing from scratch...")
        LOAD_PREPROCESSED_DATASETS = False

if not LOAD_PREPROCESSED_DATASETS:
    train_dataset = InMemoryDFFDDataset(true_dir=true_dataset_path, fake_dir=fake_dataset_path, split='train', transform=transform)
    val_dataset   = InMemoryDFFDDataset(true_dir=true_dataset_path, fake_dir=fake_dataset_path, split='validation', transform=transform)
    test_dataset  = InMemoryDFFDDataset(true_dir=true_dataset_path, fake_dir=fake_dataset_path, split='test', transform=transform)

    # Cache the freshly built datasets so later runs can take the fast path.
    os.makedirs(PREPROCESSED_DATA_PATH, exist_ok=True)
    train_dataset.save(os.path.join(PREPROCESSED_DATA_PATH, 'train_dataset_ffhq_stylegan.pkl'))
    val_dataset.save(os.path.join(PREPROCESSED_DATA_PATH, 'val_dataset_ffhq_stylegan.pkl'))
    test_dataset.save(os.path.join(PREPROCESSED_DATA_PATH, 'test_dataset_ffhq_stylegan.pkl'))


print(f"Train dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(val_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")

if len(train_dataset) == 0 or len(val_dataset) == 0 or len(test_dataset) == 0:
    print("Warning: One or more datasets are empty. Check dataset paths and structure.")

# A loader is only built for a non-empty dataset (otherwise it is None). The
# training loader drops its last partial batch only for PIM models, and only
# when the dataset holds more than one batch (32 samples).
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=num_workers, pin_memory=True, drop_last=(len(train_dataset) > 32 and 'pim' in model_name_flag)) if len(train_dataset) > 0 else None
val_loader   = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=num_workers, pin_memory=True,  drop_last=False) if len(val_dataset) > 0 else None # drop_last generally not needed for val/test
test_loader  = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=num_workers, pin_memory=True,  drop_last=False) if len(test_dataset) > 0 else None

if train_loader is None or len(train_loader) == 0: print("Warning: train_loader is empty or None.")
if val_loader is None or len(val_loader) == 0: print("Warning: val_loader is empty or None.")
if test_loader is None or len(test_loader) == 0: print("Warning: test_loader is empty or None.")

In [None]:
# Manually re-pickle the three in-memory datasets to the preprocessed-data
# directory. NOTE(review): this uses the same file names the loading cell
# reads and writes, so running it overwrites those pickles.
train_dataset.save(os.path.join(PREPROCESSED_DATA_PATH, 'train_dataset_ffhq_stylegan.pkl'))
val_dataset.save(os.path.join(PREPROCESSED_DATA_PATH, 'val_dataset_ffhq_stylegan.pkl'))
test_dataset.save(os.path.join(PREPROCESSED_DATA_PATH, 'test_dataset_ffhq_stylegan.pkl'))

## Network

### Standard

#### EfficientNetB0

In [None]:
def get_efficientnet(pretrained=True):
    """Return an EfficientNet-B0 whose classifier is replaced by a 2-class head.

    Args:
        pretrained: Load the IMAGENET1K_V1 weights when true, otherwise build
            the network without pretrained weights.
    """
    if pretrained:
        weights = models.EfficientNet_B0_Weights.IMAGENET1K_V1
    else:
        weights = None
    net = models.efficientnet_b0(weights=weights)

    # Reuse the input width of the stock linear layer for the new head.
    head_in_features = net.classifier[1].in_features
    new_head = [
        nn.Dropout(p=0.2, inplace=True),
        nn.Linear(head_in_features, 2),
    ]
    net.classifier = nn.Sequential(*new_head)
    return net

### With regularization

In [None]:
class EfficientNetB0WithPIM(nn.Module):
    """EfficientNet-B0 split into shallow and deep stages for PIM training.

    The backbone's feature stages are cut after ``shallow_feature_idx``. In
    training mode the first forward pass stores the per-sample, per-channel
    mean and standard deviation of the shallow features on the module
    (``mu_shallow`` / ``sigma_shallow``). A second pass (``is_first_pass`` set
    to False by the caller) re-normalises the shallow features with those
    statistics and re-scales them with perturbed statistics.

    Args:
        pretrained: Load the IMAGENET1K_V1 weights when true.
        num_classes: Output size of the replacement classifier.
        shallow_feature_idx: Index of the last feature stage in the shallow part.

    Raises:
        ValueError: If ``shallow_feature_idx`` leaves no deep stage.
    """

    def __init__(self, pretrained=True, num_classes=2, shallow_feature_idx=2):
        super().__init__()
        if pretrained:
            weights = models.EfficientNet_B0_Weights.IMAGENET1K_V1
        else:
            weights = None
        self.base_model = models.efficientnet_b0(weights=weights)

        index_is_valid = 0 <= shallow_feature_idx < len(self.base_model.features) - 1
        if not index_is_valid:
            raise ValueError(f"shallow_feature_idx must be between 0 and {len(self.base_model.features)-2}")

        split_at = shallow_feature_idx + 1
        self.shallow_features_extractor = self.base_model.features[:split_at]
        self.deep_features_extractor = nn.Sequential(
            *self.base_model.features[split_at:],
            self.base_model.avgpool
        )

        stock_head = self.base_model.classifier
        head_is_dropout_linear = (
            isinstance(stock_head, nn.Sequential)
            and len(stock_head) > 1
            and isinstance(stock_head[1], nn.Linear)
        )
        if head_is_dropout_linear:
            # Keep the stock first layer, replace only the linear layer.
            self.classifier = nn.Sequential(
                stock_head[0],
                nn.Linear(stock_head[1].in_features, num_classes)
            )
        elif isinstance(stock_head, nn.Linear):
            self.classifier = nn.Linear(stock_head.in_features, num_classes)
        else:
            print("Warning: Classifier structure not as expected for EfficientNetB0. Using a generic classifier.")
            # Probe the backbone with a dummy image to discover the feature width.
            probe = torch.randn(1, 3, 224, 224)
            with torch.no_grad():
                probe_features = self.base_model.features(probe)
                probe_pooled = self.base_model.avgpool(probe_features)
                probe_flat = torch.flatten(probe_pooled, 1)
            self.classifier = nn.Sequential(
                nn.Dropout(p=0.2, inplace=True),
                nn.Linear(probe_flat.shape[1], num_classes)
            )

        # Statistics of the latest first pass, and the pass selector.
        self.mu_shallow = None
        self.sigma_shallow = None
        self.is_first_pass = True

    def forward(self, x, delta_mu_for_pim=None, delta_sigma_for_pim=None):
        """Run the network; see the class docstring for the two pass modes.

        Raises:
            ValueError: In a training-mode second pass without both deltas.
        """
        shallow = self.shallow_features_extractor(x)

        in_second_pass = self.training and not self.is_first_pass
        if in_second_pass:
            if delta_mu_for_pim is None or delta_sigma_for_pim is None:
                raise ValueError("delta_mu/sigma must be provided for PIM's second pass during training")
            mu = self.mu_shallow.detach()
            sigma = self.sigma_shallow.detach()
            normalised = (shallow - mu) / (sigma + 1e-5)
            shifted_sigma = sigma + delta_sigma_for_pim
            shifted_mu = mu + delta_mu_for_pim
            shallow_out = normalised * shifted_sigma + shifted_mu
        else:
            # Not a second pass, so training mode implies the first pass here.
            if self.training:
                self.mu_shallow = shallow.mean(dim=[2, 3], keepdim=True)
                self.sigma_shallow = shallow.std(dim=[2, 3], correction=0, keepdim=True) + 1e-5
            shallow_out = shallow

        deep = self.deep_features_extractor(shallow_out)
        return self.classifier(torch.flatten(deep, 1))

    def train(self, mode=True):
        """Set the training mode; entering eval mode resets the pass selector."""
        super().train(mode)
        entering_eval = not mode
        if entering_eval:
            self.is_first_pass = True
        return self

### Other techniques

In [None]:
class DropBlock(nn.Module):
    """Zero out square spatial blocks of a feature map during training.

    A coarse random mask is drawn on a grid subsampled by ``block_size``,
    upsampled back with nearest-neighbour interpolation and cropped to the
    input's spatial size. Masked positions are set to zero. In eval mode, or
    with ``drop_prob == 0``, the input is returned unchanged.

    Args:
        drop_prob: Drop probability; divided by ``block_size ** 2`` to get the
            per-cell probability on the coarse grid.
        block_size: Side length of each dropped block.
    """

    def __init__(self, drop_prob=0.1, block_size=3):
        super().__init__()
        self.drop_prob = drop_prob
        self.block_size = block_size

    def forward(self, x):
        """Apply the block mask to ``x``.

        Raises:
            ValueError: If ``x`` has fewer than four dimensions (training only).
        """
        is_active = self.training and self.drop_prob != 0.
        if not is_active:
            return x

        if x.dim() < 4:
            raise ValueError(f"DropBlock expects a 4D input (Batch, Channels, Height, Width), but got {x.dim()}D tensor with shape {x.shape}. "
                             "Ensure DropBlock is placed before pooling/flattening operations.")

        stride = self.block_size
        cell_prob = self.drop_prob / (stride ** 2)
        coarse_grid = x[:, :, ::stride, ::stride]
        coarse_mask = (torch.rand_like(coarse_grid) < cell_prob).float()
        block_mask = nn.functional.interpolate(coarse_mask, scale_factor=stride, mode='nearest')

        # Upsampling can overshoot when the size is not a multiple of the stride.
        if block_mask.shape[2:] != x.shape[2:]:
            height, width = x.shape[2], x.shape[3]
            block_mask = block_mask[:, :, :height, :width]
        keep_mask = 1 - block_mask
        return x * keep_mask

class EfficientNetB0WithSpatialDropBlock(nn.Module):
    """ImageNet-pretrained EfficientNet-B0 with DropBlock before global pooling.

    Args:
        num_classes: Output size of the classifier.
        drop_prob: Drop probability passed to DropBlock.
        block_size: Block size passed to DropBlock.
        classifier_dropout: Dropout probability in front of the linear head.
    """

    def __init__(self, num_classes=2, drop_prob=0.1, block_size=7, classifier_dropout=0.3):
        super().__init__()
        backbone = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1)

        self.features = backbone.features
        self.drop_block = DropBlock(drop_prob=drop_prob, block_size=block_size)
        self.avgpool = backbone.avgpool

        # The new head takes the same input width as the stock linear layer.
        head_in_features = backbone.classifier[1].in_features
        head_layers = [
            nn.Dropout(p=classifier_dropout, inplace=True),
            nn.Linear(head_in_features, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Features -> DropBlock -> average pool -> flatten -> classifier."""
        feature_map = self.drop_block(self.features(x))
        pooled = self.avgpool(feature_map)
        return self.classifier(torch.flatten(pooled, 1))

def get_efficientnet_adv():
    """Build EfficientNetB0WithSpatialDropBlock with this file's fixed settings."""
    settings = {
        "num_classes": 2,
        "drop_prob": 0.1,
        "block_size": 7,
        "classifier_dropout": 0.3,
    }
    return EfficientNetB0WithSpatialDropBlock(**settings)

In [None]:
class FrequencyLayer(nn.Module):
    """Replace the input by the magnitude of its 2-D FFT over the last two dims."""

    def forward(self, x):
        """Return ``|fft2(x)|``; a 3-D input first gains a leading dimension."""
        batched = x.unsqueeze(0) if x.ndim == 3 else x
        spectrum = torch.fft.fft2(batched, dim=(-2, -1))
        return spectrum.abs()

class EfficientNetFrequency(nn.Module):
    """EfficientNet-B0 that classifies the FFT magnitude of its input.

    The input is first mapped to its 2-D Fourier magnitude by
    ``FrequencyLayer``; the result goes through the EfficientNet-B0 feature
    extractor, global average pooling, and a two-logit classifier head.

    Args:
        pretrained: When true, load the ImageNet weights
            (``IMAGENET1K_V1``); otherwise start from random weights.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        weights = models.EfficientNet_B0_Weights.IMAGENET1K_V1 if pretrained else None
        base = models.efficientnet_b0(weights=weights)

        self.freq_layer = FrequencyLayer()
        self.base_model_features = base.features
        self.pool = nn.AdaptiveAvgPool2d(1)

        # Read the feature width from the stock classifier instead of probing
        # the backbone with a dummy batch: the module is still in training mode
        # here, so a forward pass (even under torch.no_grad()) would update the
        # BatchNorm running statistics with random-noise values and disturb the
        # pretrained weights.
        num_ftrs_classifier = base.classifier[1].in_features

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(0.3),
            nn.Linear(num_ftrs_classifier, 2)
        )

    def forward(self, x):
        """Return the two class logits for the image batch ``x``."""
        x_freq = self.freq_layer(x)
        x = self.base_model_features(x_freq)
        x = self.pool(x)
        x = self.classifier(x)
        return x

def get_efficientnet_freq():
    """Build the frequency-domain EfficientNet-B0 with pretrained weights."""
    freq_model = EfficientNetFrequency(pretrained=True)
    return freq_model

## Train

### Real training

In [None]:
# Compose the W&B run name; PIM variants also encode their hyperparameters.
uses_pim = 'pim' in model_name_flag
run_name = f"{model_name_flag}_epochs{epochs}_lr{learning_rate}"
if uses_pim:
    run_name += f"_pim_r{R_PIM}_alpha{ALPHA_PIM}_shallow{SHALLOW_FEATURE_IDX}"

# Settings recorded with the run; the PIM-only entries are None for other models.
run_config = {
    "model_name": model_name_flag,
    "epochs": epochs,
    "learning_rate": learning_rate,
    "optimizer": optimizer_name_flag,
    "criterion": "CrossEntropyLoss",
    "R_PIM": R_PIM if uses_pim else None,
    "ALPHA_PIM": ALPHA_PIM if uses_pim else None,
    "SHALLOW_FEATURE_IDX": SHALLOW_FEATURE_IDX if uses_pim else None,
    "dataset_path": dataset_path,
    "device": str(device),
}
wandb.init(
    project=WANDB_PROJECT_NAME,
    entity=WANDB_ENTITY,
    name=run_name,
    config=run_config,
)

# Build the model/optimizer pair, move the model to the device, and let W&B
# watch it.
print(f"\n--- Initializing Model: {model_name_flag} ---")
model, optimizer = init_train_model(
    model_name_str=model_name_flag,
    opt_name_str=optimizer_name_flag,
    lr=learning_rate,
    r_pim_hp=R_PIM,
    alpha_pim_hp=ALPHA_PIM,
    shallow_idx_hp=SHALLOW_FEATURE_IDX,
)
model.to(device)
wandb.watch(model, criterion, log="all", log_freq=100)

print(f"Optimizer: {optimizer_name_flag}, Learning Rate: {learning_rate}")
if uses_pim:
    print(f"PIM Hyperparameters: R_PIM={R_PIM}, ALPHA_PIM={ALPHA_PIM}, SHALLOW_IDX={SHALLOW_FEATURE_IDX}")


# Per-epoch metric history, plotted once training has finished.
train_losses, val_losses = [], []
train_accs, val_accs = [], []

# Best validation accuracy seen so far; improving on it writes a checkpoint.
best_val_acc = 0.0

# torch.save does not create missing directories, so make sure the checkpoint
# folder exists before the first save.
os.makedirs(ROOT_MODEL_DIR, exist_ok=True)

for epoch_num in range(epochs):
    print(f"\nEpoch {epoch_num+1}/{epochs}")
    epoch_log_dict = {"epoch": epoch_num + 1}

    # --- Training pass (skipped when there is no training data) ---
    current_train_loss, current_train_acc = 0.0, 0.0
    if train_loader and len(train_loader) > 0:
        if 'pim' in model_name_flag:
            current_train_loss, current_train_acc = train_one_epoch_with_pim(model, train_loader, optimizer, criterion, device, ALPHA_PIM, R_PIM)
        else:
            current_train_loss, current_train_acc = train_one_epoch(model, train_loader, optimizer, criterion, device)
        print(f"Train Loss: {current_train_loss:.4f}, Acc: {current_train_acc:.4f}")
        epoch_log_dict["train/loss"] = current_train_loss
        epoch_log_dict["train/accuracy"] = current_train_acc
    else:
        print("Skipping training for epoch due to empty or None train_loader.")
        epoch_log_dict["train/loss"] = None
        epoch_log_dict["train/accuracy"] = None

    # --- Validation pass (skipped when there is no validation data) ---
    has_val_data = bool(val_loader) and len(val_loader) > 0
    current_val_loss, current_val_acc = 0.0, 0.0
    if has_val_data:
        current_val_loss, current_val_acc = evaluate(model, val_loader, criterion, device)
        print(f"Val   Loss: {current_val_loss:.4f}, Acc: {current_val_acc:.4f}")
        epoch_log_dict["val/loss"] = current_val_loss
        epoch_log_dict["val/accuracy"] = current_val_acc
    else:
        print("Skipping validation for epoch due to empty or None val_loader.")
        epoch_log_dict["val/loss"] = None
        epoch_log_dict["val/accuracy"] = None

    # The history keeps the 0.0 placeholders for skipped passes so all four
    # lists stay the same length.
    train_losses.append(current_train_loss)
    train_accs.append(current_train_acc)
    val_losses.append(current_val_loss)
    val_accs.append(current_val_acc)

    wandb.log(epoch_log_dict)

    # Checkpoint whenever validation accuracy improves.
    if has_val_data and current_val_acc > best_val_acc:
        best_val_acc = current_val_acc
        model_save_path_best = os.path.join(ROOT_MODEL_DIR, f'{model_name_flag}_best_val_acc.pth')
        # For a multi-GPU DataParallel model, save the wrapped module's weights.
        if device.type == 'cuda' and torch.cuda.device_count() > 1 and isinstance(model, nn.DataParallel):
            torch.save(model.module.state_dict(), model_save_path_best)
        else:
            torch.save(model.state_dict(), model_save_path_best)
        print(f"Best model (Val Acc: {best_val_acc:.4f}) saved to {model_save_path_best}")
        wandb.save(model_save_path_best, base_path=ROOT_MODEL_DIR)


# Plot loss and accuracy curves side by side and send the figure to W&B.
if not (train_losses and val_losses and train_accs and val_accs):
    print("Skipping performance plot: Not enough data.")
else:
    fig_performance, axs = plt.subplots(1, 2, figsize=(12, 4))

    # One entry per panel: (train series, train label, val series, val label,
    # panel title, y-axis label).
    panel_specs = [
        (train_losses, 'Train Loss', val_losses, 'Val Loss', 'Loss over epochs', 'Loss'),
        (train_accs, 'Train Acc', val_accs, 'Val Acc', 'Accuracy over epochs', 'Accuracy'),
    ]
    for panel_ax, (train_series, train_label, val_series, val_label, panel_title, y_label) in zip(axs, panel_specs):
        panel_ax.plot(train_series, label=train_label)
        panel_ax.plot(val_series, label=val_label)
        panel_ax.legend()
        panel_ax.set_title(panel_title)
        panel_ax.set_xlabel('Epoch')
        panel_ax.set_ylabel(y_label)

    plt.tight_layout()
    wandb.log({"Training Performance Plot": wandb.Image(fig_performance)})
    plt.show()


# torch.save does not create missing directories, so make sure the checkpoint
# folder exists before writing the final artifacts.
os.makedirs(ROOT_MODEL_DIR, exist_ok=True)

# Save the weights from the last epoch and upload them to W&B.
model_save_path_final = os.path.join(ROOT_MODEL_DIR, f'{model_name_flag}_final_epochs{epochs}.pth')
# For a multi-GPU DataParallel model, save the wrapped module's weights.
if device.type == 'cuda' and torch.cuda.device_count() > 1 and isinstance(model, nn.DataParallel):
    torch.save(model.module.state_dict(), model_save_path_final)
else:
    torch.save(model.state_dict(), model_save_path_final)
print(f"Final model saved to {model_save_path_final}")
wandb.save(model_save_path_final, base_path=ROOT_MODEL_DIR)


# Save the optimizer state next to the model and upload it as well.
optimizer_save_path = os.path.join(ROOT_MODEL_DIR, f'{model_name_flag}_optimizer_final_epochs{epochs}.pth')
torch.save(optimizer.state_dict(), optimizer_save_path)
print(f"Optimizer state saved to {optimizer_save_path}")
wandb.save(optimizer_save_path, base_path=ROOT_MODEL_DIR)

## Test

### Baseline

In [None]:
# Results of the clean (un-attacked) test evaluation; they stay None when the
# test loader is missing or empty.
test_loss = test_acc = None
all_test_labels = all_test_preds_probs = all_test_preds_labels = None

if not (test_loader and len(test_loader) > 0):
    print("Test loader is empty or None. Skipping clean data evaluation.")
    # Fill every clean-test summary field with a placeholder.
    for summary_key in (
        "test/clean_loss",
        "test/clean_accuracy",
        "test/clean_detailed_accuracy",
        "test/clean_detailed_precision",
        "test/clean_detailed_recall",
        "test/clean_detailed_f1_score",
        "test/clean_detailed_auc_roc",
        "test/clean_detailed_avg_confidence",
    ):
        wandb.summary[summary_key] = "N/A"
else:
    clean_eval = evaluate(
        model, test_loader, criterion, device, return_preds_labels=True
    )
    test_loss, test_acc, all_test_labels, all_test_preds_probs, all_test_preds_labels = clean_eval
    print(f"Test Loss (Clean): {test_loss:.4f}, Acc (Clean): {test_acc:.4f}")
    wandb.summary["test/clean_loss"] = test_loss
    wandb.summary["test/clean_accuracy"] = test_acc

    log_detailed_metrics(
        all_test_labels,
        all_test_preds_probs,
        all_test_preds_labels,
        section_prefix="test/clean_detailed",
        log_to_wandb=True,
    )

### Attack

In [None]:
# Single-step FGSM at three perturbation budgets.
fgsm_configurations = [
    {"name": f"FGSM_eps{fgsm_eps}", "attack_class": torchattacks.FGSM, "params": {"eps": fgsm_eps}}
    for fgsm_eps in (0.05, 0.1, 0.2)
]

# PGD with increasing budget and step count.
pgd_configurations = [
    {
        "name": "PGD_Linf_eps4_steps7",
        "attack_class": torchattacks.PGD,
        "params": {"eps": 4/255, "alpha": 1/255, "steps": 7, "random_start": True},
    },
    {
        "name": "PGD_Linf_eps8_steps10",
        "attack_class": torchattacks.PGD,
        "params": {"eps": 8/255, "alpha": 2/255, "steps": 10, "random_start": True},
    },
    {
        "name": "PGD_Linf_eps0.1_steps20",
        "attack_class": torchattacks.PGD,
        "params": {"eps": 0.1, "alpha": 0.01, "steps": 20, "random_start": True},
    },
]

attack_configurations = fgsm_configurations + pgd_configurations

# Rows of the final summary table; the clean-data result comes first when it
# is available.
all_results_list = []

if test_acc is None or test_loss is None:
    print("Clean test data results not available, skipping for summary table.")
else:
    all_results_list.append({
        "Attack Method": "Clean Data (No Attack)",
        "Parameters": "N/A",
        "Adversarial Accuracy": f"{test_acc*100:.2f}%",
        "Adversarial Loss": f"{test_loss:.4f}",
    })


print("\n--- Starting Adversarial Attack Comparison ---")
# Clean-data labels and predictions used as the baseline for the attack
# success rate (ASR).
clean_labels_for_asr = clean_preds_for_asr = clean_acc_for_asr_calc = None

have_cached_clean_preds = all_test_labels is not None and all_test_preds_labels is not None
if have_cached_clean_preds:
    # Reuse the results of the earlier clean evaluation.
    clean_labels_for_asr, clean_preds_for_asr, clean_acc_for_asr_calc = (
        all_test_labels,
        all_test_preds_labels,
        test_acc,
    )
elif test_loader and len(test_loader) > 0:
    print("Re-evaluating on clean data to get base predictions for ASR calculation...")
    _, clean_acc_for_asr_calc, clean_labels_for_asr, _, clean_preds_for_asr = evaluate(
        model, test_loader, criterion, device, return_preds_labels=True
    )
    if clean_labels_for_asr.numel() == 0:
        print("Failed to get clean predictions for ASR.")
        clean_labels_for_asr = None
    else:
        print(f"Clean accuracy for ASR base: {clean_acc_for_asr_calc:.4f}")
else:
    print("Cannot calculate ASR as clean data predictions are unavailable and test_loader is empty.")


# Run every configured attack against the test set, log the results, write a
# summary table, and visualise a few adversarial examples.
if test_loader and len(test_loader) > 0:
    for config in attack_configurations:
        # Metric-key-safe attack name: non-alphanumeric characters become "_".
        attack_name_safe = "".join(c if c.isalnum() else "_" for c in config['name'])
        print(f"\nInitializing Attack: {config['name']}")
        print(f"Parameters: {config['params']}")

        try:
            # Build the attack on the unwrapped module when the model is
            # wrapped in DataParallel.
            current_model_for_attack = model.module if isinstance(model, nn.DataParallel) else model
            current_attack_instance = config["attack_class"](current_model_for_attack, **config["params"])
        except Exception as e:
            # Best effort: record the failure and carry on with the next attack.
            print(f"Error initializing attack {config['name']}: {e}")
            print("Skipping this attack configuration.")
            all_results_list.append({
                "Attack Method": config['name'],
                "Parameters": str(config['params']),
                "Adversarial Accuracy": "Error during init",
                "Adversarial Loss": "Error during init"
            })
            # Same metric keys as the success path below, so a failed attack
            # shows up under the same W&B series.
            wandb.log({
                f"test_adv/{attack_name_safe}_robust_accuracy": None,
                f"test_adv/{attack_name_safe}_loss": None,
                f"test_adv_ASR/{attack_name_safe}": None,
                f"test_adv_perturb_L0/{attack_name_safe}": None,
                f"test_adv_perturb_L2/{attack_name_safe}": None,
                f"test_adv_perturb_Linf/{attack_name_safe}": None
            })
            continue

        adv_acc, adv_loss, adv_orig_labels, adv_preds_probs, adv_preds_labels, \
        avg_l0, avg_l2, avg_linf = evaluate_on_adversarial(
            model, test_loader, criterion, device,
            current_attack_instance, config['name'], return_details=True
        )

        print(f"Results for {config['name']}:")
        print(f"  Adversarial Accuracy: {adv_acc*100:.2f}%")
        print(f"  Adversarial Loss: {adv_loss:.4f}")
        print(f"  Avg Perturbation L0: {avg_l0:.2f}, L2: {avg_l2:.4f}, Linf: {avg_linf:.4f}")

        all_results_list.append({
            "Attack Method": config['name'],
            "Parameters": str(config['params']),
            "Adversarial Accuracy": f"{adv_acc*100:.2f}%",
            "Adversarial Loss": f"{adv_loss:.4f}",
            "Avg L0 Perturb": f"{avg_l0:.2f}",
            "Avg L2 Perturb": f"{avg_l2:.4f}",
            "Avg Linf Perturb": f"{avg_linf:.4f}"
        })
        log_dict_adv = {
            f"test_adv/{attack_name_safe}_robust_accuracy": adv_acc,
            f"test_adv/{attack_name_safe}_loss": adv_loss,
            f"test_adv_perturb_L0/{attack_name_safe}": avg_l0,
            f"test_adv_perturb_L2/{attack_name_safe}": avg_l2,
            f"test_adv_perturb_Linf/{attack_name_safe}": avg_linf
        }

        log_detailed_metrics(
            adv_orig_labels,
            adv_preds_probs,
            adv_preds_labels,
            section_prefix=f"test_adv_detailed/{attack_name_safe}",
            log_to_wandb=True
        )

        # Attack success rate: share of originally-correct samples that the
        # attack turns into misclassifications. The element-wise comparison
        # relies on the clean and adversarial passes visiting samples in the
        # same order; only the sample counts are checked here.
        asr = "N/A"
        if clean_labels_for_asr is not None and clean_preds_for_asr is not None and adv_orig_labels.numel() > 0:
            if len(clean_labels_for_asr) == len(adv_orig_labels):
                originally_correct_mask = (clean_preds_for_asr == clean_labels_for_asr)
                adversarially_misclassified_mask = (adv_preds_labels != adv_orig_labels)

                successful_attacks_mask = originally_correct_mask & adversarially_misclassified_mask

                num_originally_correct = originally_correct_mask.sum().item()
                num_successful_attacks = successful_attacks_mask.sum().item()

                if num_originally_correct > 0:
                    asr = num_successful_attacks / num_originally_correct
                    print(f"  Attack Success Rate (ASR): {asr*100:.2f}%")
                else:
                    asr = 0.0
                    print(f"  Attack Success Rate (ASR): 0.0% (no samples were originally correct)")
            else:
                print(f"  ASR calculation skipped for {attack_name_safe}: Mismatch in sample count between clean ({len(clean_labels_for_asr)}) and adversarial ({len(adv_orig_labels)}) evaluations.")
        else:
            print(f"  ASR calculation skipped for {attack_name_safe}: Clean predictions or adversarial results unavailable.")

        log_dict_adv[f"test_adv_ASR/{attack_name_safe}"] = asr
        wandb.log(log_dict_adv)

    # --- Summary table: print, log to W&B, and save as CSV ---
    print("\n--- Summary of All Attack Results ---")
    results_dataframe = pd.DataFrame(all_results_list)
    with pd.option_context('display.max_rows', None, 'display.max_colwidth', None, 'display.width', 1000):
        print(results_dataframe)

    wandb.log({"Adversarial Attack Summary Table": wandb.Table(dataframe=results_dataframe)})

    csv_filename = f"adversarial_attack_comparison_{wandb.run.name}.csv"
    results_dataframe.to_csv(csv_filename, index=False)
    print(f"\nResults saved to {csv_filename}")
    wandb.save(csv_filename)

    # --- Visualisation for the first configured attack ---
    # Look the attack's summary row up by name rather than by position: the
    # clean-data row is only present when the clean evaluation ran, so a fixed
    # index would point at the wrong entry otherwise.
    atk_vis_config = attack_configurations[0] if attack_configurations else None
    first_attack_result = None
    if atk_vis_config is not None:
        first_attack_result = next(
            (row for row in all_results_list if row.get("Attack Method") == atk_vis_config['name']),
            None,
        )

    if first_attack_result is not None and "Error" not in str(first_attack_result.get("Adversarial Accuracy", "")):
        print("\n--- Visualizing some adversarial examples ---")

        try:
            current_model_for_attack_vis = model.module if isinstance(model, nn.DataParallel) else model
            atk_vis_instance = atk_vis_config["attack_class"](current_model_for_attack_vis, **atk_vis_config["params"])

            # Only the first test batch is visualised.
            images_vis, labels_vis = next(iter(test_loader))
            images_vis, labels_vis = images_vis.to(device), labels_vis.to(device)

            adv_images_vis = atk_vis_instance(images_vis, labels_vis)

            with torch.no_grad():
                outputs_clean_vis = model(images_vis)
                _, predicted_clean_vis = torch.max(outputs_clean_vis, 1)
                outputs_adv_vis = model(adv_images_vis)
                _, predicted_adv_vis = torch.max(outputs_adv_vis, 1)

            num_to_show = min(5, images_vis.size(0))

            # Inverse of Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]).
            # NOTE(review): assumes the loader normalised images with these
            # statistics - confirm against the dataset transforms.
            inv_normalize = transforms.Normalize(
                mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
                std=[1/0.229, 1/0.224, 1/0.225]
            )
            wandb_images_log = {}
            for i in range(num_to_show):
                clean_img_display = inv_normalize(images_vis[i].cpu())
                adv_img_display = inv_normalize(adv_images_vis[i].cpu())
                perturbation_vis = adv_img_display - clean_img_display
                # Min-max rescale to [0, 1] so small perturbations are visible;
                # the epsilon avoids a division by zero for an all-equal image.
                perturbation_vis_scaled = (perturbation_vis - perturbation_vis.min()) / (perturbation_vis.max() - perturbation_vis.min() + 1e-6)

                fig_vis, axs_vis = plt.subplots(1, 3, figsize=(12, 4))

                axs_vis[0].imshow(np.transpose(clean_img_display.numpy(), (1,2,0)))
                axs_vis[0].set_title(f"Clean - True: {labels_vis[i].item()}, Pred: {predicted_clean_vis[i].item()}")
                axs_vis[0].axis('off')

                axs_vis[1].imshow(np.transpose(adv_img_display.numpy(), (1,2,0)))
                axs_vis[1].set_title(f"Adv ({atk_vis_config['name']}) - True: {labels_vis[i].item()}, Pred: {predicted_adv_vis[i].item()}")
                axs_vis[1].axis('off')

                axs_vis[2].imshow(np.transpose(perturbation_vis_scaled.numpy(), (1,2,0)))
                axs_vis[2].set_title(f"Perturbation (Scaled)")
                axs_vis[2].axis('off')
                plt.tight_layout()
                plt.show()
                # Release the figure so repeated runs do not accumulate open figures.
                plt.close(fig_vis)

                wandb_images_log[f"Adversarial_Example_{i}/Clean_Img"] = wandb.Image(clean_img_display, caption=f"Clean - True: {labels_vis[i].item()}, Pred: {predicted_clean_vis[i].item()}")
                wandb_images_log[f"Adversarial_Example_{i}/Adv_Img_{atk_vis_config['name']}"] = wandb.Image(adv_img_display, caption=f"Adv ({atk_vis_config['name']}) - True: {labels_vis[i].item()}, Pred: {predicted_adv_vis[i].item()}")
                wandb_images_log[f"Adversarial_Example_{i}/Perturbation_{atk_vis_config['name']}"] = wandb.Image(perturbation_vis_scaled, caption="Perturbation (Scaled)")
            if wandb_images_log:
                wandb.log(wandb_images_log)

        except Exception as e:
            # Visualisation is optional: report the failure and keep the run alive.
            print(f"Could not generate or log visualizations for attack {atk_vis_config['name']}: {e}")
            import traceback
            traceback.print_exc()

else:
    print("Test loader is empty or None. Skipping adversarial attack evaluation and visualization.")

In [None]:
# End the active W&B run and confirm on stdout.
wandb.finish()
print("WandB run finished.")