# Import

In [2]:
import os
from pathlib import Path 
from typing import Optional, Literal
from tqdm import tqdm

import torch
from torch.utils.data import DataLoader
from PIL import Image
from torchvision import transforms

from utils.data.dataset import CorruptedDataset
from utils.visual.visualizer import DatasetVisualizer
from utils.eval.metrics import PredictionCollector, MetricsCalculator

# SGS method import
from model.method.sgs import LGradSGS, SGSConfig
from model.LGrad.lgrad_model import LGrad

# GPU and model selection

In [3]:
!nvidia-smi

Sun Dec 21 19:49:54 2025       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.230.02             Driver Version: 535.230.02   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla P100-PCIE-16GB           Off | 00000000:04:00.0 Off |                    0 |
| N/A   39C    P0              34W / 250W |    260MiB / 16384MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  Tesla P100-PCIE-16GB           Off | 00000000:06:00.0 Off |  

In [4]:
# CUDA device string used for model construction, SGS and evaluation.
DEVICE = "cuda:1"

# Detector names this notebook knows about; the first entry is the active one.
MODEL_LIST = ["lgrad", "npr"]
MODEL = MODEL_LIST[0]

# Dataloader

In [5]:
# Directory passed as `root` to CorruptedDataset.
ROOT = "/workspace/robust_deepfake_ai/corrupted_dataset"

# Sub-dataset names; every entry shares the "corrupted_test_data_" prefix.
DATASETS = [
    f"corrupted_test_data_{source}"
    for source in (
        "biggan",
        "crn",
        "cyclegan",
        "deepfake",
        "gaugan",
        "imle",
        "progan",
        "san",
        "seeingdark",
        "stargan",
        "stylegan",
        "stylegan2",
        "whichfaceisreal",
    )
]

# Corruption variants requested for each sub-dataset.
# NOTE(review): "original" presumably denotes the uncorrupted images - confirm.
CORRUPTIONS = [
    "original",
    "contrast",
    "fog",
    "gaussian_noise",
    "jpeg_compression",
    "motion_blur",
    "pixelate",
]

# Per-image preprocessing: resize to 256x256, then convert to a tensor.
# The Normalize step is currently commented out, so no mean/std scaling is applied here.
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]
)

In [6]:
# Build one dataset spanning every configured (sub-dataset, corruption) pair,
# applying the shared preprocessing transform.
dataset = CorruptedDataset(root=ROOT, datasets=DATASETS, corruptions=CORRUPTIONS, transform=transform)

# Report how many samples were found in total.
print(f"Total samples: {len(dataset)}")

Total samples: 632303


In [7]:
# # dataset sampling
# from torch.utils.data import Subset
# import random

# samples_per_combination = 500
# selected_indices = []

# random.seed(42)

# for dataset_name in DATASETS:
#     for corruption in CORRUPTIONS:
#         combination_indices = [
#             i for i, s in enumerate(dataset.samples)
#             if s['dataset'] == dataset_name and s['corruption'] == corruption]

#         # Sample up to samples_per_combination items (use all of them if fewer exist)
#         n_samples = min(samples_per_combination, len(combination_indices))
#         if n_samples > 0:
#             sampled = random.sample(combination_indices, n_samples)
#             selected_indices.extend(sampled)

#         print(f"{dataset_name}-{corruption}: {len(combination_indices)} -> {n_samples} samples")

# # Subset 생성
# subset_dataset = Subset(dataset, selected_indices)
# print(f"\nTotal: {len(dataset)} -> {len(subset_dataset)} samples")

# dataloader = DataLoader(
#     subset_dataset,
#     batch_size=32,
#     shuffle=False,
#     num_workers=4,
# )

In [8]:
# dataloader = DataLoader(
#     dataset,
#     batch_size=32,
#     shuffle=False,
#     num_workers=4,
#     pin_memory=True
# )

In [9]:
# # Visualization
# viz = DatasetVisualizer(seed=1)

# viz(dataset, corruption="all", n_samples=3, label="real")

# viz.stats(dataset)

# Model load

In [10]:
from model.LGrad.lgrad_model import LGrad
from model.NPR.npr_model import NPR

# LGrad: StyleGAN discriminator weights and classifier weights
STYLEGAN_WEIGHTS_ROOT="/workspace/robust_deepfake_ai/model/LGrad/weights/karras2019stylegan-bedrooms-256x256_discriminator.pth"
CLASSIFIER_WEIGHTS_ROOT="/workspace/robust_deepfake_ai/model/LGrad/weights/LGrad-Pretrained-Model/LGrad-4class-Trainon-Progan_car_cat_chair_horse.pth"

# NPR: single weights file
NPR_WEIGHTS_ROOT="/workspace/robust_deepfake_ai/model/NPR/weights/NPR.pth"

# Build the detector selected by MODEL on DEVICE.
# MODEL_LIST spells the NPR option in lowercase ("npr"); the previous check only
# matched "NPR", so picking MODEL_LIST[1] built nothing and left `model` undefined
# (or stale from an earlier run). Both spellings are accepted here.
if MODEL == "lgrad":
    model = LGrad(
        stylegan_weights=STYLEGAN_WEIGHTS_ROOT,
        classifier_weights=CLASSIFIER_WEIGHTS_ROOT,
        device=DEVICE
    )
elif MODEL in ("npr", "NPR"):
    model = NPR(
        weights=NPR_WEIGHTS_ROOT,
        device=DEVICE
    )
else:
    # Fail loudly instead of letting later cells pick up a missing or stale `model`.
    raise ValueError(f"Unsupported MODEL {MODEL!r}; expected one of {MODEL_LIST}")


  torch.load(stylegan_weights, map_location="cpu"),
  torch.load(classifier_weights, map_location="cpu")


In [11]:
# SGS settings. K is the number of views (the init log lists one parameter set
# per view); the huber_tv_* values are the base parameters echoed in that log.
sgs_config = SGSConfig(
    K=4,
    huber_tv_lambda=0.03,
    huber_tv_delta=0.01,
    huber_tv_iterations=5,
    huber_tv_step_size=0.1,
    denoise_target="input",
    device=DEVICE,
)

# Wrap the base detector with SGS. Re-running this cell must not stack one
# wrapper on top of another, so when `model` is already an LGradSGS we first
# recover the detector it holds in `.model` and wrap that instead.
# NOTE(review): the wrapper is applied regardless of MODEL - confirm LGradSGS
# also supports the NPR detector.
base_model = model.model if isinstance(model, LGradSGS) else model
model = LGradSGS(base_model, sgs_config)

# Switch the wrapped detector to eval mode; as the last expression this also
# displays its module structure.
model.model.eval()

[SGS] Initialized with K=4, denoise_target=input
[SGS] Base params: λ=0.03, δ=0.01, iter=5, step=0.1
[SGS] Parameter sets for 4 views:
  View 0: λ=0.0000, δ=0.0000, iter=0, step=0.00
  View 1: λ=0.0147, δ=0.0070, iter=5, step=0.10
  View 2: λ=0.0329, δ=0.0105, iter=5, step=0.10
  View 3: λ=0.0478, δ=0.0126, iter=5, step=0.10


LGrad(
  (grad_model): StyleGANDiscriminator(
    (input0): ConvBlock(
      (mbstd): Identity()
      (blur): Identity()
      (downsample): Identity()
      (activate): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (layer0): ConvBlock(
      (mbstd): Identity()
      (blur): Identity()
      (downsample): Identity()
      (activate): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (layer1): ConvBlock(
      (mbstd): Identity()
      (blur): BlurLayer()
      (downsample): Identity()
      (activate): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (input1): ConvBlock(
      (mbstd): Identity()
      (blur): Identity()
      (downsample): Identity()
      (activate): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (layer2): ConvBlock(
      (mbstd): Identity()
      (blur): Identity()
      (downsample): Identity()
      (activate): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (layer3): ConvBlock(
      (mbstd): Identity()
      (blur): BlurLayer()
 

In [12]:
from torchsummary import summary

# Print a per-layer table (layer type, output shape, parameter count) for a 3x256x256 input.
summary(model, (3, 256, 256))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
          Identity-1          [-1, 3, 256, 256]               0
          Identity-2          [-1, 3, 256, 256]               0
          Identity-3         [-1, 64, 256, 256]               0
         LeakyReLU-4         [-1, 64, 256, 256]               0
         ConvBlock-5         [-1, 64, 256, 256]             256
          Identity-6         [-1, 64, 256, 256]               0
          Identity-7         [-1, 64, 256, 256]               0
          Identity-8         [-1, 64, 256, 256]               0
         LeakyReLU-9         [-1, 64, 256, 256]               0
        ConvBlock-10         [-1, 64, 256, 256]          36,928
         Identity-11         [-1, 64, 256, 256]               0
        BlurLayer-12         [-1, 64, 256, 256]               0
         Identity-13        [-1, 128, 128, 128]               0
        LeakyReLU-14        [-1, 128, 1

In [13]:
# def collate_fn(batch):
#     # 모델이 사용하는 resize 크기 (LGrad의 경우 256)
#     import torch
#     import torch.nn.functional as F
#     target_size = (256, 256)

#     images = []
#     for item in batch:
#         img = item[0]
#         if isinstance(img, torch.Tensor):
#             # 크기가 다르면 미리 resize (모델 내부 transform과 동일하게)
#             if img.shape[-2:] != target_size:
#                 img = F.interpolate(
#                     img.unsqueeze(0),
#                     size=target_size,
#                     mode='bilinear',
#                     align_corners=False
#                 ).squeeze(0)
#             images.append(img)
#         else:
#             images.append(img)

#     images = torch.stack(images)
#     labels = torch.tensor([item[1] for item in batch])

#     if len(batch[0]) == 3:
#         metadata = [item[2] for item in batch]
#         return images, labels, metadata
#     return images, labels

In [None]:
from collections import defaultdict
import random

from torch.utils.data import Subset, DataLoader

# # Optional per-combination sub-sampling settings (currently disabled)
# samples_per_combination = 500
# random.seed(42)

# Evaluation: one metrics entry per (dataset, corruption) pair.
calc = MetricsCalculator()
all_results = {}

# Bucket sample indices by (dataset, corruption) in a single pass over
# dataset.samples instead of rescanning the whole list for every combination.
# Indices remain in ascending order inside each bucket, matching a filtered scan.
indices_by_combination = defaultdict(list)
for i, s in enumerate(dataset.samples):
    indices_by_combination[(s['dataset'], s['corruption'])].append(i)

for dataset_name in DATASETS:
    for corruption in CORRUPTIONS:
        # Indices belonging to this combination (.get avoids creating empty buckets)
        combination_indices = indices_by_combination.get((dataset_name, corruption), [])

        if len(combination_indices) == 0:
            print(f"{dataset_name}-{corruption}: 샘플 없음, 스킵")
            continue

        # # Sub-sampling (disabled)
        # n_samples = min(samples_per_combination, len(combination_indices))
        # sampled_indices = random.sample(combination_indices, n_samples)

        print(f"\n{'='*60}")
        print(f"평가 중: {dataset_name}-{corruption}")
        print(f"샘플 수: {len(combination_indices)}")
        print(f"{'='*60}")

        # Build the Subset and DataLoader for this combination
        subset = Subset(dataset, combination_indices)
        dataloader = DataLoader(
            subset,
            batch_size=16,
            shuffle=False,
            num_workers=4,
            # collate_fn=collate_fn,
            # Keep the final partial batch: dropping it would silently exclude up to
            # batch_size - 1 samples from the metrics while the count printed above
            # still reports the full combination size.
            drop_last=False
        )

        # Evaluate this combination
        metrics = calc.evaluate(
            model=model,
            dataloader=dataloader,
            device=DEVICE,
            name=f"{dataset_name}-{corruption}"
        )

        # Print results immediately (values are scaled by 100 for display;
        # presumably they are fractions in [0, 1] - confirm against MetricsCalculator)
        print(f"\n결과:")
        print(f"  Accuracy: {metrics['accuracy']*100:.2f}%")
        print(f"  AUC:      {metrics['auc']*100:.2f}%")
        print(f"  AP:       {metrics['ap']*100:.2f}%")
        print(f"  F1:       {metrics['f1']*100:.2f}%")

        # Store results keyed by (dataset, corruption)
        all_results[(dataset_name, corruption)] = metrics

# Print the overall results table and the per-corruption / per-dataset summaries
print(f"\n\n{'='*60}")
print("전체 결과 요약")
print(f"{'='*60}\n")
calc.print_results_table()
calc.summarize_by_corruption(all_results)
calc.summarize_by_dataset(all_results)


평가 중: corrupted_test_data_biggan-original
샘플 수: 4000


corrupted_test_data_biggan-original:  18%|█▊        | 44/250 [00:28<02:12,  1.55it/s]

In [None]:
# # Evaluation
# calc = MetricsCalculator()

# # 조합별 평가
# results = calc.evaluate(
#     model=model,
#     dataloader=dataloader,
#     device=DEVICE,
#     name=f"{dataset_name}-{corruption}"
# )

# # 결과 출력
# calc.print_results_table(results)
# calc.summarize_by_corruption(results)
# calc.summarize_by_dataset(results)
