# Imports

In [None]:
import os
from pathlib import Path 
from typing import Optional, Literal
from tqdm import tqdm

import torch
from torch.utils.data import DataLoader
from PIL import Image
from torchvision import transforms

from utils.data.dataset import CorruptedDataset
from utils.visual.visualizer import DatasetVisualizer
from utils.eval.metrics import PredictionCollector, MetricsCalculator

# NORM method import
from model.method.method import LGradNORM, NORMConfig
from model.LGrad.lgrad_model import LGrad

# GPU and Model Selection

In [2]:
# Inspect per-GPU utilisation and memory before choosing DEVICE in the next cell.
!nvidia-smi

Fri Dec 12 10:46:49 2025       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.230.02             Driver Version: 535.230.02   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla P100-PCIE-16GB           Off | 00000000:04:00.0 Off |                    0 |
| N/A   66C    P0             204W / 250W |  15828MiB / 16384MiB |    100%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  Tesla P100-PCIE-16GB           Off | 00000000:06:00.0 Off |  

In [3]:
# Compute device for every later cell. In the recorded nvidia-smi output GPU 0 was
# already at 100% utilisation, so the second GPU is used.
DEVICE = "cuda:1"

# Detectors known to this notebook; index 0 selects LGrad.
# NOTE(review): MODEL is not read by any later visible cell — the model-loading
# cell always builds LGrad. Confirm before switching this to "npr".
MODEL_LIST = ["lgrad", "npr"]
MODEL = MODEL_LIST[0]

# NORM Configuration

In [None]:
# NORM hyper-parameters.
#   SOURCE_SUM        - accumulated source-domain batch size; larger values keep
#                       more of the source statistics.
#   ADAPTATION_TARGET - which part of LGrad to adapt: "classifier", "grad_model"
#                       or "both".
SOURCE_SUM = 128
ADAPTATION_TARGET = "classifier"

# Bundle the settings, plus the DEVICE chosen earlier, into the NORM config object.
norm_config = NORMConfig(
    source_sum=SOURCE_SUM,
    adaptation_target=ADAPTATION_TARGET,
    device=DEVICE,
)

print("NORM Config:")
for _label, _value in (
    ("Source Sum", SOURCE_SUM),
    ("Adaptation Target", ADAPTATION_TARGET),
    ("Device", DEVICE),
):
    print(f"  {_label}: {_value}")

# Dataloader

In [4]:
# Root directory passed to CorruptedDataset.
ROOT = "/workspace/robust_deepfake_ai/corrupted_dataset"

# Test splits to evaluate: one per generator, all sharing the same name prefix.
DATASETS = [
    f"corrupted_test_data_{generator}"
    for generator in (
        "biggan", "crn", "cyclegan", "deepfake", "gaugan", "imle", "progan",
        "san", "seeingdark", "stargan", "stylegan", "stylegan2", "whichfaceisreal",
    )
]

# Corruption variants to evaluate ("original" presumably means uncorrupted images).
CORRUPTIONS = [
    "original",
    "contrast",
    "fog",
    "gaussian_noise",
    "jpeg_compression",
    "motion_blur",
    "pixelate",
]

# Per-image preprocessing: resize to a fixed 256x256 and convert to a tensor.
# ImageNet mean/std normalisation is intentionally left disabled.
# NOTE(review): presumably the detector handles its own input scaling — confirm.
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]
)

In [5]:
# Build a single dataset covering the configured DATASETS x CORRUPTIONS under ROOT.
dataset = CorruptedDataset(
    root=ROOT,
    datasets=DATASETS,
    corruptions=CORRUPTIONS,
    transform=transform,
)

print(f"Total samples: {len(dataset)}")

# Fail fast on a wrong ROOT or misspelled split/corruption names. Without this check
# the evaluation loop skips every combination and only produces empty summaries.
assert len(dataset) > 0, (
    f"No samples found under {ROOT!r} for the configured DATASETS/CORRUPTIONS"
)

Total samples: 632303


In [6]:
# # dataset sampling
# from torch.utils.data import Subset
# import random

# samples_per_combination = 500
# selected_indices = []

# random.seed(42)

# for dataset_name in DATASETS:
#     for corruption in CORRUPTIONS:
#         combination_indices = [
#             i for i, s in enumerate(dataset.samples)
#             if s['dataset'] == dataset_name and s['corruption'] == corruption]

#         # 1000개 샘플링 (부족하면 전부 사용)
#         n_samples = min(samples_per_combination, len(combination_indices))
#         if n_samples > 0:
#             sampled = random.sample(combination_indices, n_samples)
#             selected_indices.extend(sampled)

#         print(f"{dataset_name}-{corruption}: {len(combination_indices)} -> {n_samples} samples")

# # Subset 생성
# subset_dataset = Subset(dataset, selected_indices)
# print(f"\nTotal: {len(dataset)} -> {len(subset_dataset)} samples")

# dataloader = DataLoader(
#     subset_dataset,
#     batch_size=32,
#     shuffle=False,
#     num_workers=4,
# )

In [7]:
# dataloader = DataLoader(
#     dataset,
#     batch_size=32,
#     shuffle=False,
#     num_workers=4,
#     pin_memory=True
# )

In [8]:
# # Visualization
# viz = DatasetVisualizer(seed=1)

# viz(dataset, corruption="all", n_samples=3, label="real")

# viz.stats(dataset)

# Model Loading

In [None]:
# LGrad weights
STYLEGAN_WEIGHTS_ROOT = "/workspace/robust_deepfake_ai/model/LGrad/weights/karras2019stylegan-bedrooms-256x256_discriminator.pth"
CLASSIFIER_WEIGHTS_ROOT = "/workspace/robust_deepfake_ai/model/LGrad/weights/LGrad-Pretrained-Model/LGrad-4class-Trainon-Progan_car_cat_chair_horse.pth"

print("Loading base LGrad model...")
base_lgrad = LGrad(
    stylegan_weights=STYLEGAN_WEIGHTS_ROOT,
    classifier_weights=CLASSIFIER_WEIGHTS_ROOT,
    device=DEVICE
)
print("Base model loaded!")

print("\nApplying NORM adaptation...")
model = LGradNORM(base_lgrad, norm_config)
print("LGradNORM ready!")

model.model.eval()

In [10]:
from torchsummary import summary
# Layer-by-layer summary of the NORM-wrapped model for per-sample input size (3, 256, 256).
# NOTE(review): this runs a forward pass on random inputs. If LGradNORM keeps
# statistics between forward calls, this can influence later evaluation — confirm.
# It also presumably builds its inputs on the default CUDA device rather than on
# DEVICE; the recorded output shows it ran, so LGrad likely moves inputs itself — confirm.
summary(model, (3, 256, 256))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
          Identity-1          [-1, 3, 256, 256]               0
          Identity-2          [-1, 3, 256, 256]               0
          Identity-3         [-1, 64, 256, 256]               0
         LeakyReLU-4         [-1, 64, 256, 256]               0
         ConvBlock-5         [-1, 64, 256, 256]             256
          Identity-6         [-1, 64, 256, 256]               0
          Identity-7         [-1, 64, 256, 256]               0
          Identity-8         [-1, 64, 256, 256]               0
         LeakyReLU-9         [-1, 64, 256, 256]               0
        ConvBlock-10         [-1, 64, 256, 256]          36,928
         Identity-11         [-1, 64, 256, 256]               0
        BlurLayer-12         [-1, 64, 256, 256]               0
         Identity-13        [-1, 128, 128, 128]               0
        LeakyReLU-14        [-1, 128, 1

In [11]:
# def collate_fn(batch):
#     # 모델이 사용하는 resize 크기 (LGrad의 경우 256)
#     import torch
#     import torch.nn.functional as F
#     target_size = (256, 256)

#     images = []
#     for item in batch:
#         img = item[0]
#         if isinstance(img, torch.Tensor):
#             # 크기가 다르면 미리 resize (모델 내부 transform과 동일하게)
#             if img.shape[-2:] != target_size:
#                 img = F.interpolate(
#                     img.unsqueeze(0),
#                     size=target_size,
#                     mode='bilinear',
#                     align_corners=False
#                 ).squeeze(0)
#             images.append(img)
#         else:
#             images.append(img)

#     images = torch.stack(images)
#     labels = torch.tensor([item[1] for item in batch])

#     if len(batch[0]) == 3:
#         metadata = [item[2] for item in batch]
#         return images, labels, metadata
#     return images, labels

In [None]:
from torch.utils.data import Subset, DataLoader
import random

# Evaluation with LGradNORM: every (dataset, corruption) pair is evaluated as its own
# split, and per-split metrics are stored in `all_results` keyed by that pair.
BATCH_SIZE = 16
NUM_WORKERS = 4
# drop_last=True keeps every batch at exactly BATCH_SIZE (presumably so NORM never
# estimates batch statistics from a tiny trailing batch — TODO confirm). The cost is
# that up to BATCH_SIZE - 1 samples per split are not evaluated.
DROP_LAST = True

print("="*80)
print("LGradNORM Evaluation on Corrupted Datasets")
print("="*80)

# Group sample indices by (dataset, corruption) in a single pass over dataset.samples
# instead of rescanning every sample once per combination. Indices are appended in
# ascending order, so each Subset iterates samples in the same order as before.
indices_by_combination = {}
for sample_idx, sample in enumerate(dataset.samples):
    key = (sample['dataset'], sample['corruption'])
    indices_by_combination.setdefault(key, []).append(sample_idx)

calc = MetricsCalculator()
all_results = {}

# NOTE(review): the same LGradNORM instance is reused for every split. If it keeps
# adapted statistics between forward calls, results depend on split order — confirm
# whether a per-split reset is required.
for dataset_name in DATASETS:
    for corruption in CORRUPTIONS:
        # Indices belonging to this combination
        combination_indices = indices_by_combination.get((dataset_name, corruption), [])

        if len(combination_indices) == 0:
            print(f"{dataset_name}-{corruption}: 샘플 없음, 스킵")
            continue

        # With drop_last, a split smaller than one batch yields zero batches, so the
        # evaluator would receive no predictions at all. Skip it explicitly instead.
        if DROP_LAST and len(combination_indices) < BATCH_SIZE:
            print(
                f"{dataset_name}-{corruption}: only {len(combination_indices)} samples "
                f"(< batch size {BATCH_SIZE}), skipping"
            )
            continue

        n_evaluated = (
            len(combination_indices) - len(combination_indices) % BATCH_SIZE
            if DROP_LAST
            else len(combination_indices)
        )

        print(f"\n{'='*60}")
        print(f"평가 중: {dataset_name}-{corruption}")
        print(f"샘플 수: {len(combination_indices)}")
        print(f"  evaluated (drop_last={DROP_LAST}): {n_evaluated}")
        print(f"{'='*60}")

        # Build the Subset and DataLoader for this split
        subset = Subset(dataset, combination_indices)
        dataloader = DataLoader(
            subset,
            batch_size=BATCH_SIZE,
            shuffle=False,
            num_workers=NUM_WORKERS,
            drop_last=DROP_LAST
        )

        # Evaluate
        metrics = calc.evaluate(
            model=model,
            dataloader=dataloader,
            device=DEVICE,
            name=f"{dataset_name}-{corruption}"
        )

        # Print this split's results immediately
        print(f"\n[LGradNORM] 결과:")
        print(f"  Accuracy: {metrics['accuracy']*100:.2f}%")
        print(f"  AUC:      {metrics['auc']*100:.2f}%")
        print(f"  AP:       {metrics['ap']*100:.2f}%")
        print(f"  F1:       {metrics['f1']*100:.2f}%")

        # Store results keyed by (dataset, corruption)
        all_results[(dataset_name, corruption)] = metrics

# Print the full results table and per-corruption / per-dataset summaries
print(f"\n\n{'='*80}")
print("LGradNORM - 전체 결과 요약")
print(f"{'='*80}\n")
calc.print_results_table()
calc.summarize_by_corruption(all_results)
calc.summarize_by_dataset(all_results)

In [None]:
# # Evaluation
# calc = MetricsCalculator()

# # 조합별 평가
# results = calc.evaluate(
#     model=model,
#     dataloader=dataloader,
#     device=DEVICE,
#     name=f"{dataset_name}-{corruption}"
# )

# # 결과 출력
# calc.print_results_table(results)
# calc.summarize_by_corruption(results)
# calc.summarize_by_dataset(results)
