# Odporność modeli na zakłócenia (Robustness – ImageNet-C)

## Instalacja bibliotek

In [None]:
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

!pip install -q pytorch-lightning torchvision codecarbon scikit-learn pillow kaggle

## Pobranie danych 

In [None]:
from google.colab import drive, files
files.upload() # Służy do umieszczenia pliku kaggle.json
drive.mount('/content/drive')

!mkdir -p ~/.kaggle
!mv /content/kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json
!kaggle datasets download sautkin/imagenet1kvalid
!unzip -q /content/imagenet1kvalid.zip -d /content/imagenet1kv1_valid

!mkdir -p /content/robustness
!tar -xvf /content/drive/MyDrive/digital.tar -C /content/robustness

## Importy i konfiguracja

In [None]:
import os
import time
import numpy as np
import pandas as pd
import subprocess

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from codecarbon import EmissionsTracker
from sklearn.metrics import f1_score
from torch.amp import autocast

from torchvision.models import (
    efficientnet_v2_s, EfficientNet_V2_S_Weights,
    resnet50, ResNet50_Weights,
    densenet201, DenseNet201_Weights,
    convnext_tiny, ConvNeXt_Tiny_Weights,
    mobilenet_v3_large, MobileNet_V3_Large_Weights
)

# Let cuDNN benchmark candidate convolution algorithms and pick the fastest one.
torch.backends.cudnn.benchmark = True

# Models under test: each torchvision builder mapped to its ImageNet-1k V1
# pretrained weights, exposed as a list of (builder, weights) pairs.
MODELS = list({
    efficientnet_v2_s: EfficientNet_V2_S_Weights.IMAGENET1K_V1,
    resnet50: ResNet50_Weights.IMAGENET1K_V1,
    densenet201: DenseNet201_Weights.IMAGENET1K_V1,
    convnext_tiny: ConvNeXt_Tiny_Weights.IMAGENET1K_V1,
    mobilenet_v3_large: MobileNet_V3_Large_Weights.IMAGENET1K_V1,
}.items())


## Funkcje pomocnicze

In [None]:
def evaluate_loader(model, loader, device):
    """Run ``model`` over every batch of ``loader`` and collect class predictions.

    The model is switched to eval mode and run under ``torch.inference_mode()``.
    On CUDA the forward pass uses float16 autocast; on any other device it runs
    in the model's native precision.

    Args:
        model: Classifier whose output is reduced with ``argmax(dim=1)``
            (presumably logits of shape ``(batch, num_classes)``).
        loader: Iterable of ``(inputs, targets)`` tensor pairs, e.g. a
            ``DataLoader``; targets must be CPU tensors (``.numpy()`` is
            called on them directly).
        device: Device for the inputs, as a ``torch.device`` or a string such
            as ``"cuda"`` or ``"cpu"``.

    Returns:
        ``(preds, trues)``: 1-D NumPy arrays of predicted and true labels in
        loader order. Both are empty if the loader yields no batches.
    """
    from contextlib import nullcontext

    on_cuda = torch.device(device).type == "cuda"

    all_preds, all_trues = [], []
    model.eval()
    with torch.inference_mode():
        for x, y in loader:
            x = x.to(device, non_blocking=True)
            # Only CUDA gets fp16 autocast; other devices use a no-op context
            # instead of a mismatched device_type="cuda" autocast.
            amp = autocast(device_type="cuda", dtype=torch.float16) if on_cuda else nullcontext()
            with amp:
                logits = model(x)
            all_preds.append(logits.argmax(dim=1).cpu().numpy())
            all_trues.append(y.numpy())

    if not all_preds:
        # np.concatenate / np.hstack raise on an empty list of arrays.
        return np.empty(0, dtype=np.int64), np.empty(0, dtype=np.int64)
    return np.concatenate(all_preds), np.concatenate(all_trues)


def _make_eval_loader(dataset, batch_size, pin_memory):
    """Build an ordered (non-shuffled) DataLoader over ``dataset`` using all CPU cores."""
    num_workers = os.cpu_count() or 0  # os.cpu_count() may return None
    extra = {}
    if num_workers > 0:
        # DataLoader rejects persistent_workers / prefetch_factor without worker processes.
        extra = {"persistent_workers": True, "prefetch_factor": 4}
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
        **extra,
    )


def run_robustness(corruption, model, weights, f1_clean,
                   device, base_dir="/content/robustness", batch_size=128):
    """Evaluate ``model`` on severity levels 1-5 of one corruption type.

    Images are read from ``<base_dir>/<corruption>/<severity>/`` with
    ``torchvision.datasets.ImageFolder`` (one sub-folder per class). For each
    severity this measures wall-clock time, CO2 emissions (codecarbon) and, on
    CUDA, peak allocated GPU memory. It then computes macro-F1 and the relative
    corruption error ``1 - F1 / f1_clean``.

    Note: ``mCE`` here is the mean of these relative macro-F1 drops over the
    five severities. It is not the AlexNet-normalised mCE of the original
    ImageNet-C benchmark.

    Args:
        corruption: Corruption sub-directory name, e.g. ``"elastic_transform"``.
        model: Classifier passed to ``evaluate_loader``.
        weights: torchvision weights enum whose ``transforms()`` preprocess images.
        f1_clean: Macro-F1 of ``model`` on the uncorrupted validation set.
        device: Input device (``"cuda"`` / ``"cpu"`` or a ``torch.device``).
        base_dir: Root directory of the extracted corruption images.
        batch_size: Evaluation batch size.

    Returns:
        The printed summary ``pandas.DataFrame``: one row per severity, then an
        ``mCE`` row and a ``clean_error`` row (``1 - f1_clean``).
    """
    transform = weights.transforms()
    on_cuda = torch.device(device).type == "cuda"
    stats = []

    print(f"\n>>> Robustness test: {corruption}")
    for sev in range(1, 6):
        path = os.path.join(base_dir, corruption, str(sev))
        loader = _make_eval_loader(
            ImageFolder(path, transform=transform), batch_size, pin_memory=on_cuda
        )

        # Warm-up: pull one batch so worker start-up is not part of the timed run.
        next(iter(loader))

        if on_cuda:
            # torch.cuda memory APIs fail on machines without CUDA, so guard them.
            torch.cuda.reset_peak_memory_stats(device)

        tracker = EmissionsTracker(
            project_name=f"robust_{corruption}_sev{sev}",
            log_level="error"
        )
        tracker.start()
        t0 = time.perf_counter()
        try:
            y_pred, y_true = evaluate_loader(model, loader, device)
        finally:
            # Stop the tracker even if evaluation fails, so it does not keep running.
            duration = time.perf_counter() - t0
            co2 = tracker.stop()

        f1 = f1_score(y_true, y_pred, average="macro", zero_division=0)
        # Relative F1 drop vs. the clean baseline; undefined (NaN) when the baseline is 0.
        ce = 1 - (f1 / f1_clean) if f1_clean else float("nan")

        peak_mem = torch.cuda.max_memory_allocated(device) / 1024**2 if on_cuda else 0.0
        co2_str = f"{co2:.4f}kg" if co2 is not None else "n/a"

        print(f"Severity {sev} -> Czas: {duration:.1f}s | CO2: {co2_str} | Peak mem: {peak_mem:.0f} MiB")

        stats.append({
            "severity": sev,
            "f1": round(f1, 3),
            "corruption_err": round(ce, 3),
            "time_s": round(duration, 1),
            "co2_kg": co2,
            "peak_mem_mib": round(peak_mem),
        })

    mce = np.mean([s["corruption_err"] for s in stats])
    stats.append({"severity": "mCE", "mce": round(mce, 3)})
    stats.append({"severity": "clean_error", "value": round(1 - f1_clean, 3)})

    summary = pd.DataFrame(stats)
    print(summary)
    return summary


## Uruchomienie testów

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = 128

# Corruption types available in the "digital" subset: elastic_transform, jpeg_compression, contrast, pixelate
corruptions = ["elastic_transform"]  # Several can be listed at once, e.g. ["jpeg_compression", "contrast"]

# DataLoader worker settings: os.cpu_count() may return None, and
# persistent_workers / prefetch_factor are only accepted when num_workers > 0.
num_workers = os.cpu_count() or 0
worker_kwargs = {"persistent_workers": True, "prefetch_factor": 4} if num_workers > 0 else {}

for model_fn, weights in MODELS:
    print(f"\n>>> Czysty F1 dla {model_fn.__name__}")
    model = model_fn(weights=weights).to(device)
    model = torch.compile(model)

    # Clean (uncorrupted) validation set, preprocessed with this model's own transforms.
    transform_clean = weights.transforms()
    clean_ds = ImageFolder("/content/imagenet1kv1_valid", transform=transform_clean)
    clean_loader = DataLoader(
        clean_ds, batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=(device == "cuda"),
        **worker_kwargs,
    )

    if device == "cuda":
        # Measure the peak of this pass only, not whatever ran before it.
        torch.cuda.reset_peak_memory_stats()
    y_pred, y_true = evaluate_loader(model, clean_loader, device)
    f1_clean = f1_score(y_true, y_pred, average="macro", zero_division=0)
    print(f"Bazowe F1: {f1_clean:.4f}")
    if device == "cuda":
        peak_mem = torch.cuda.max_memory_allocated() / 1024**2
        print(f"Peak mem: {peak_mem:.0f} MiB")

    for corruption in corruptions:
        # Forward batch_size so the setting above also applies to the corrupted sets.
        run_robustness(corruption, model, weights, f1_clean, device, batch_size=batch_size)

    # Drop references to this model and its loader before building the next one,
    # so their GPU memory and worker processes can be reclaimed.
    del model, clean_loader, clean_ds
    if device == "cuda":
        torch.cuda.empty_cache()
