# Channel Pruning v1 (CPv1) Example - IMPROVED

## Gradient-Pattern-Aware Channel Pruning

**핵심 차이점 (vs 기존 CPv1)**:
- ✅ **Pseudo-label 제거**: Label-free minimum deviation 사용
- ✅ **Gradient corruption sensitivity**: Noise가 gradient pattern을 corrupt하는 정도 측정
- ✅ **Batch aggregation**: Channel-level로 안정적인 gating
- ✅ **LGrad 최적화**: Two-stage architecture (img2grad + classifier)에 특화

**기존 CPv1의 문제**:
- ❌ Pseudo-label이 noisy에서 부정확
- ❌ Artifact를 noise로 오인 가능
- ❌ Batch-wise gating으로 불안정

**개선 방법**:
```python
# OLD (문제 있음):
pseudo_label = estimate_label(x)  # 부정확!
sensitivity = |curr - reference[pseudo_label]|

# NEW (IMPROVED):
dev_real = |curr - real_stats|
dev_fake = |curr - fake_stats|
sensitivity = min(dev_real, dev_fake)  # Label-free!
```

**기대 효과**:
- Noisy/corrupted image에서 성능 향상
- Gradient pattern을 robust하게 보존
- 안정적인 channel selection

## Import

In [1]:
import sys
# Clear cache
for mod in list(sys.modules.keys()):
    if any(x in mod for x in ['NPR', 'npr', 'LGrad', 'lgrad', 'networks', 'method', 'channel']):
        del sys.modules[mod]

In [2]:
import os
from pathlib import Path
from typing import Optional, Literal
from tqdm import tqdm

import torch
from torch.utils.data import DataLoader, Subset
from PIL import Image
from torchvision import transforms

from utils.data.dataset import CorruptedDataset
from utils.visual.visualizer import DatasetVisualizer
from utils.eval.metrics import PredictionCollector, MetricsCalculator

# Channel Pruning v1 import
from model.method import (
    UnifiedChannelPruningV2,
    CPv2Config,
    compute_separated_statistics_v2,
)
from model.LGrad.lgrad_model import LGrad
from model.NPR.npr_model import NPR

## GPU and Model select

In [3]:
!nvidia-smi

Fri Jan  2 19:21:46 2026       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.230.02             Driver Version: 535.230.02   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla P100-PCIE-16GB           Off | 00000000:04:00.0 Off |                    0 |
| N/A   36C    P0              27W / 250W |      4MiB / 16384MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  Tesla P100-PCIE-16GB           Off | 00000000:06:00.0 Off |  

In [4]:
DEVICE = "cuda:0"
MODEL_LIST = ["lgrad", "npr"]
MODEL = MODEL_LIST[0]  # "lgrad" or "npr"

## Dataloader

In [5]:
ROOT = "corrupted_dataset"
DATASETS = [
    "corrupted_test_data_progan",
    "corrupted_test_data_stylegan",
    "corrupted_test_data_stylegan2",
    "corrupted_test_data_biggan",
]

CORRUPTIONS = [
    "original",
    "gaussian_noise",
    "jpeg_compression",
]

if MODEL == "lgrad":
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])
else:
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

In [6]:
dataset = CorruptedDataset(
    root=ROOT,
    datasets=DATASETS,
    corruptions=CORRUPTIONS,
    transform=transform
)

print(f"Total samples: {len(dataset)}")

Total samples: 119874


## Model load

In [7]:
# LGrad
STYLEGAN_WEIGHTS_ROOT = "model/LGrad/weights/karras2019stylegan-bedrooms-256x256_discriminator.pth"
CLASSIFIER_WEIGHTS_ROOT = "model/LGrad/weights/LGrad-Pretrained-Model/LGrad-4class-Trainon-Progan_car_cat_chair_horse.pth"

# NPR
NPR_WEIGHTS_ROOT = "model/NPR/weights/NPR.pth"

if MODEL == "lgrad":
    base_model = LGrad(
        stylegan_weights=STYLEGAN_WEIGHTS_ROOT,
        classifier_weights=CLASSIFIER_WEIGHTS_ROOT,
        device=DEVICE
    )
    model_name = "LGrad"
elif MODEL == "npr":
    base_model = NPR(
        weights=NPR_WEIGHTS_ROOT,
        device=DEVICE
    )
    model_name = "NPR"

print(f"Base model loaded: {model_name}")

  torch.load(stylegan_weights, map_location="cpu"),
  torch.load(classifier_weights, map_location="cpu")


Base model loaded: LGrad


## Step 1: Compute Separated Statistics from Clean Data

**중요!** CPv1은 Real clean과 Fake clean의 **분리된** statistics가 필요합니다.

- ProGAN의 original (uncorrupted) 데이터로 statistics 수집
- **Labels 필수**: Real (0) vs Fake (1) 구분을 위해
- 한 번 계산하면 저장해서 재사용 가능

In [8]:
# Clean data 준비 (ProGAN original - LABELS 필수!)
progan_clean_indices = [
    i for i, s in enumerate(dataset.samples)
    if s['dataset'] == "corrupted_test_data_progan" and s['corruption'] == "original"
]

print(f"ProGAN clean samples: {len(progan_clean_indices)}")

# Subset & DataLoader (labels 포함!)
clean_subset = Subset(dataset, progan_clean_indices)
clean_loader = DataLoader(
    clean_subset,
    batch_size=32,
    shuffle=False,
    num_workers=4,
    drop_last=False
)

ProGAN clean samples: 8000


In [9]:
# Separated statistics 파일 경로
STATS_PATH_V2 = f"separated_stats_v2_{MODEL}_progan.pth"
if os.path.exists(STATS_PATH_V2):
    separated_stats_v2 = torch.load(STATS_PATH_V2)
else:
    separated_stats_v2 = compute_separated_statistics_v2(
        model=base_model,
        dataloader=clean_loader,
        device=DEVICE,
    )
    torch.save(separated_stats_v2, STATS_PATH_V2)

[CPv2] Computing separated statistics for 106 layers (ONLINE MODE)...


Computing separated statistics v2 (online): 100%|██████████| 250/250 [02:23<00:00,  1.75it/s]


  classifier.conv1: C=64
    Real: mean=[-6.3340, 7.3021], std_mean=1.4538
    Fake: mean=[-6.3223, 7.2592], std_mean=1.4562
    Artifact signature: mean=0.0062, max=0.0523
  classifier.bn1: C=64
    Real: mean=[0.0000, 2.0776], std_mean=0.1827
    Fake: mean=[0.0000, 2.0739], std_mean=0.1868
    Artifact signature: mean=0.0011, max=0.0045
  classifier.layer1.0.conv1: C=64
    Real: mean=[-7.1698, 6.7641], std_mean=0.5785
    Fake: mean=[-7.3122, 6.9164], std_mean=0.5762
    Artifact signature: mean=0.0402, max=0.1523
  classifier.layer1.0.bn1: C=64
    Real: mean=[0.0000, 1.1708], std_mean=0.1276
    Fake: mean=[0.0000, 1.1761], std_mean=0.1248
    Artifact signature: mean=0.0088, max=0.0430
  classifier.layer1.0.conv2: C=64
    Real: mean=[-6.2954, 7.5667], std_mean=2.5168
    Fake: mean=[-6.2174, 7.4683], std_mean=2.4162
    Artifact signature: mean=0.1382, max=0.6342
  classifier.layer1.0.bn2: C=64
    Real: mean=[0.0000, 1.6930], std_mean=0.1396
    Fake: mean=[0.0000, 1.6927], st

## Step 2: Create Channel Pruning v1 Model (IMPROVED)

In [10]:
# Config 설정 (IMPROVED!)
config_v2 = CPv2Config(
    model="LGrad",
    keep_ratio=0.7,  # Keep top 70% channels
    gating_type="hard",  # Hard gating (0 or 1)
    use_zscore=True,  # Z-score normalization
)

# Model 생성
model = UnifiedChannelPruningV2(
    base_model=base_model,
    separated_stats=separated_stats_v2,
    config=config_v2,
)

print("\nChannel Pruning v1 (IMPROVED) model created!")

  classifier.conv1 (C=64):
    Artifact disc: mean=0.0062, max=0.0523
  classifier.bn1 (C=64):
    Artifact disc: mean=0.0011, max=0.0045
  classifier.layer1.0.conv1 (C=64):
    Artifact disc: mean=0.0402, max=0.1523
  classifier.layer1.0.bn1 (C=64):
    Artifact disc: mean=0.0088, max=0.0430
  classifier.layer1.0.conv2 (C=64):
    Artifact disc: mean=0.1382, max=0.6342
  classifier.layer1.0.bn2 (C=64):
    Artifact disc: mean=0.0044, max=0.0229
  classifier.layer1.0.conv3 (C=256):
    Artifact disc: mean=0.0143, max=0.0695
  classifier.layer1.0.bn3 (C=256):
    Artifact disc: mean=0.0071, max=0.0693
  classifier.layer1.0.downsample.0 (C=256):
    Artifact disc: mean=0.0329, max=0.1696
  classifier.layer1.0.downsample.1 (C=256):
    Artifact disc: mean=0.0126, max=0.0645
  classifier.layer1.1.conv1 (C=64):
    Artifact disc: mean=0.1097, max=0.3635
  classifier.layer1.1.bn1 (C=64):
    Artifact disc: mean=0.0096, max=0.0702
  classifier.layer1.1.conv2 (C=64):
    Artifact disc: mean=0.

## (Optional) Test-Time Adaptation

Noisy validation data로 temperature와 channel bias를 fine-tuning할 수 있습니다.

**Skip 가능!** Adaptation 없이도 사용 가능합니다.

In [11]:
# # Adaptation을 원하면 주석 해제
# ENABLE_ADAPTATION = True

# if ENABLE_ADAPTATION:
#     # Noisy validation data 준비
#     progan_noisy_indices = [
#         i for i, s in enumerate(dataset.samples)
#         if s['dataset'] == "corrupted_test_data_progan" and s['corruption'] == "gaussian_noise"
#     ]
    
#     print(f"ProGAN noisy samples for adaptation: {len(progan_noisy_indices)}")
    
#     noisy_subset = Subset(dataset, progan_noisy_indices[:500])
#     noisy_loader = DataLoader(
#         noisy_subset,
#         batch_size=32,
#         shuffle=True,
#         num_workers=4,
#         drop_last=False
#     )
    
#     # Adaptation 실행
#     print("\nStarting test-time adaptation...\n")
#     model.adapt(
#         dataloader=noisy_loader,
#         epochs=5,
#         lr=1e-4,
#         verbose=True,
#     )
#     print("\nAdaptation complete!")

## Evaluation

Dataset별, Corruption별로 평가합니다.

In [12]:
# Evaluation
calc = MetricsCalculator()
all_results = {}

for dataset_name in DATASETS:
    for corruption in CORRUPTIONS:
        combination_indices = [
            i for i, s in enumerate(dataset.samples)
            if s['dataset'] == dataset_name and s['corruption'] == corruption
        ]
        
        if len(combination_indices) == 0:
            print(f"{dataset_name}-{corruption}: 샘플 없음, 스킵")
            continue
        
        print(f"\n{'='*60}")
        print(f"평가 중: {dataset_name}-{corruption}")
        print(f"샘플 수: {len(combination_indices)}")
        print(f"{'='*60}")
        
        # Subset과 DataLoader 생성
        subset = Subset(dataset, combination_indices)
        dataloader = DataLoader(
            subset,
            batch_size=16,
            shuffle=False,
            num_workers=4,
            drop_last=True
        )
        
        # 평가
        metrics = calc.evaluate(
            model=model,
            dataloader=dataloader,
            device=DEVICE,
            name=f"{dataset_name}-{corruption}"
        )
        
        # 즉시 결과 출력
        print(f"\n결과:")
        print(f"  Accuracy: {metrics['accuracy']*100:.2f}%")
        print(f"  AUC:      {metrics['auc']*100:.2f}%")
        print(f"  AP:       {metrics['ap']*100:.2f}%")
        print(f"  F1:       {metrics['f1']*100:.2f}%")
        
        # 결과 저장
        all_results[(dataset_name, corruption)] = metrics

# 전체 결과 테이블 출력
print(f"\n\n{'='*60}")
print("전체 결과 요약")
print(f"{'='*60}\n")
calc.print_results_table()
calc.summarize_by_corruption(all_results)
calc.summarize_by_dataset(all_results)


평가 중: corrupted_test_data_progan-original
샘플 수: 8000


corrupted_test_data_progan-original: 100%|██████████| 500/500 [02:11<00:00,  3.79it/s]



결과:
  Accuracy: 51.31%
  AUC:      55.00%
  AP:       55.01%
  F1:       11.01%

평가 중: corrupted_test_data_progan-gaussian_noise
샘플 수: 8000


corrupted_test_data_progan-gaussian_noise: 100%|██████████| 500/500 [02:14<00:00,  3.72it/s]



결과:
  Accuracy: 49.00%
  AUC:      50.10%
  AP:       50.13%
  F1:       27.32%

평가 중: corrupted_test_data_progan-jpeg_compression
샘플 수: 8000


corrupted_test_data_progan-jpeg_compression: 100%|██████████| 500/500 [02:20<00:00,  3.56it/s]



결과:
  Accuracy: 50.35%
  AUC:      50.41%
  AP:       50.27%
  F1:       8.61%

평가 중: corrupted_test_data_stylegan-original
샘플 수: 11982


corrupted_test_data_stylegan-original: 100%|██████████| 748/748 [03:19<00:00,  3.75it/s]



결과:
  Accuracy: 52.65%
  AUC:      57.24%
  AP:       58.33%
  F1:       15.10%

평가 중: corrupted_test_data_stylegan-gaussian_noise
샘플 수: 11982


corrupted_test_data_stylegan-gaussian_noise: 100%|██████████| 748/748 [03:17<00:00,  3.78it/s]



결과:
  Accuracy: 50.08%
  AUC:      48.56%
  AP:       49.66%
  F1:       30.66%

평가 중: corrupted_test_data_stylegan-jpeg_compression
샘플 수: 11982


corrupted_test_data_stylegan-jpeg_compression: 100%|██████████| 748/748 [03:17<00:00,  3.78it/s]



결과:
  Accuracy: 50.15%
  AUC:      48.51%
  AP:       49.65%
  F1:       5.90%

평가 중: corrupted_test_data_stylegan2-original
샘플 수: 15976


corrupted_test_data_stylegan2-original: 100%|██████████| 998/998 [04:25<00:00,  3.75it/s]



결과:
  Accuracy: 50.58%
  AUC:      53.34%
  AP:       52.93%
  F1:       9.85%

평가 중: corrupted_test_data_stylegan2-gaussian_noise
샘플 수: 15976


corrupted_test_data_stylegan2-gaussian_noise: 100%|██████████| 998/998 [04:22<00:00,  3.81it/s]



결과:
  Accuracy: 48.75%
  AUC:      48.98%
  AP:       48.55%
  F1:       25.78%

평가 중: corrupted_test_data_stylegan2-jpeg_compression
샘플 수: 15976


corrupted_test_data_stylegan2-jpeg_compression: 100%|██████████| 998/998 [04:28<00:00,  3.72it/s]



결과:
  Accuracy: 49.69%
  AUC:      48.19%
  AP:       48.35%
  F1:       5.82%

평가 중: corrupted_test_data_biggan-original
샘플 수: 4000


corrupted_test_data_biggan-original:  36%|███▌      | 90/250 [00:24<00:43,  3.67it/s]


KeyboardInterrupt: 

In [None]:
# Gating 강제로 끄기 (모든 채널 사용)
print("Disabling all gates (forcing gate=1.0)...")

for gate_module in model.gates.values():
    # Temperature를 0으로 설정 + bias를 매우 크게
    gate_module.temperature.data.fill_(0.0)
    gate_module.channel_bias.data.fill_(100.0)  # sigmoid(100) ≈ 1.0

print("All gates disabled. Re-evaluating...")

# ProGAN original 다시 평가
progan_orig_indices = [i for i, s in enumerate(dataset.samples)
                        if s['dataset'] == "corrupted_test_data_progan"
                        and s['corruption'] == "original"]

test_loader = DataLoader(
    Subset(dataset, progan_orig_indices),
    batch_size=16,
    shuffle=False,
    num_workers=4,
    drop_last=True
)

calc = MetricsCalculator()
metrics = calc.evaluate(model, test_loader, DEVICE, "no-gating-test")

print(f"\nWith gating disabled:")
print(f"  Accuracy: {metrics['accuracy']*100:.2f}%")
print(f"  AUC:      {metrics['auc']*100:.2f}%")

Disabling all gates (forcing gate=1.0)...
All gates disabled. Re-evaluating...


no-gating-test: 100%|██████████| 500/500 [02:00<00:00,  4.14it/s]


With gating disabled:
  Accuracy: 99.76%
  AUC:      99.98%





In [None]:
# Stats 파일 로드
import torch

STATS_PATH = f"separated_stats_{MODEL}_progan.pth"
separated_stats = torch.load(STATS_PATH)

print("="*60)
print("Artifact Discriminability (|fake_mean - real_mean|) 확인")
print("="*60)
print()

# 각 레이어별로 확인
for layer_name, stats in separated_stats.items():
    real_mean = stats['real']['mean']
    fake_mean = stats['fake']['mean']

    artifact_sig = (fake_mean - real_mean).abs()

    print(f"{layer_name}:")
    print(f"  Artifact signature: mean={artifact_sig.mean():.6f}, max={artifact_sig.max():.6f}, min={artifact_sig.min():.6f}")

    # Real/Fake mean 범위 확인
    print(f"  Real mean: min={real_mean.min():.4f}, max={real_mean.max():.4f}, avg={real_mean.mean():.4f}")
    print(f"  Fake mean: min={fake_mean.min():.4f}, max={fake_mean.max():.4f}, avg={fake_mean.mean():.4f}")
    print()

# 전체 평균
all_artifact_sigs = []
for stats in separated_stats.values():
    artifact_sig = (stats['fake']['mean'] - stats['real']['mean']).abs()
    all_artifact_sigs.append(artifact_sig.mean().item())

print("="*60)
print(f"Overall average artifact signature: {sum(all_artifact_sigs)/len(all_artifact_sigs):.6f}")
print("="*60)

# 문제 진단
avg_sig = sum(all_artifact_sigs)/len(all_artifact_sigs)
if avg_sig < 0.001:
    print("🚨 문제! Artifact signature가 너무 작습니다!")
    print("   Real과 Fake의 mean이 거의 같음 → Stats 계산 오류 가능성")
elif avg_sig < 0.01:
    print("⚠️  주의! Artifact signature가 작습니다.")
    print("   Real/Fake 구분이 약함")
else:
    print("✅ 정상! Artifact signature가 충분히 큼")

Artifact Discriminability (|fake_mean - real_mean|) 확인

classifier.conv1:
  Artifact signature: mean=0.006160, max=0.052292, min=0.000000
  Real mean: min=-6.3340, max=7.3021, avg=-0.0532
  Fake mean: min=-6.3223, max=7.2592, avg=-0.0534

classifier.bn1:
  Artifact signature: mean=0.001062, max=0.004508, min=0.000000
  Real mean: min=0.0000, max=2.0776, avg=0.4093
  Fake mean: min=0.0000, max=2.0739, avg=0.4099

classifier.layer1.0.conv1:
  Artifact signature: mean=0.040236, max=0.152285, min=0.000000
  Real mean: min=-7.1698, max=6.7641, avg=-0.6328
  Fake mean: min=-7.3122, max=6.9164, avg=-0.6532

classifier.layer1.0.bn1:
  Artifact signature: mean=0.008797, max=0.043005, min=0.000000
  Real mean: min=0.0000, max=1.1708, avg=0.1708
  Fake mean: min=0.0000, max=1.1761, avg=0.1664

classifier.layer1.0.conv2:
  Artifact signature: mean=0.138221, max=0.634240, min=0.000000
  Real mean: min=-6.2954, max=7.5667, avg=0.3800
  Fake mean: min=-6.2174, max=7.4683, avg=0.2925

classifier.layer

  separated_stats = torch.load(STATS_PATH)


In [None]:
# Gate values 실시간 확인
import torch

# ProGAN original 샘플 하나로 테스트
progan_orig_indices = [i for i, s in enumerate(dataset.samples)
                        if s['dataset'] == "corrupted_test_data_progan"
                        and s['corruption'] == "original"]

test_loader = DataLoader(
    Subset(dataset, progan_orig_indices[:32]),  # 1 batch만
    batch_size=32,
    shuffle=False
)

# Forward pass하면서 gate values 캡처
batch = next(iter(test_loader))
images = batch[0].to(DEVICE)

model.eval()
with torch.no_grad():
    # Forward하면서 gate module에서 직접 확인
    # 첫 번째 gate만 체크
    gate_name = list(model.gates.keys())[0]
    gate_module = model.gates[gate_name]

    # Dummy forward to get intermediate activations
    _ = model(images)

print("="*60)
print("Gate Values 확인 (첫 번째 레이어)")
print("="*60)

# 수동으로 gate 계산 확인
for sanitized_name, gate in list(model.gates.items())[:3]:  # 처음 3개만
    original_name = model.gate_name_mapping[sanitized_name]

    # Temperature, bias
    temp = gate.temperature.item()
    bias_mean = gate.channel_bias.mean().item()
    disc_mean = gate.artifact_discriminability.mean().item()

    print(f"\n{original_name}:")
    print(f"  Temperature: {temp:.4f}")
    print(f"  Channel bias (mean): {bias_mean:+.4f}")
    print(f"  Artifact disc (mean): {disc_mean:.4f}")

    # 예상 score 계산 (rough estimate)
    # sensitivity는 original이므로 매우 작아야 함 (예: 0.01)
    expected_sens = 0.01  # 가정
    expected_score = disc_mean / expected_sens
    expected_gate = 1 / (1 + torch.exp(-temp * expected_score - bias_mean))

    print(f"  Expected score (if sens=0.01): {expected_score:.2f}")
    print(f"  Expected gate: {expected_gate:.4f}")

    # 실제로는 sensitivity가 얼마인지 확인 필요
    print(f"  ⚠️  Actual sensitivity를 확인해야 정확함!")

Gate Values 확인 (첫 번째 레이어)

classifier.conv1:
  Temperature: 1.0000
  Channel bias (mean): +0.0000
  Artifact disc (mean): 0.0062


TypeError: exp(): argument 'input' (position 1) must be Tensor, not float

In [None]:
# 실제 Gate 동작 확인
import torch

# Forward 시 gate values를 캡처하는 hook 추가
gate_debug_info = {}

def debug_hook(name):
    def hook(module, input, output):
        x = input[0]  # Input activation

        # Sensitivity 계산
        sensitivity = module.compute_gradient_corruption_sensitivity(x)  # [B, C]

        # Discriminability
        disc = module.artifact_discriminability  # [C]

        # Score
        score = disc.unsqueeze(0) / (sensitivity + 1e-6)  # [B, C]

        # Gate (batch aggregation)
        score_agg = score.mean(dim=0)  # [C]
        gate_logits = module.temperature * score_agg + module.channel_bias
        gate = torch.sigmoid(gate_logits)  # [C]

        # 저장
        gate_debug_info[name] = {
            'sensitivity_min': sensitivity.min().item(),
            'sensitivity_mean': sensitivity.mean().item(),
            'sensitivity_max': sensitivity.max().item(),
            'score_min': score.min().item(),
            'score_mean': score.mean().item(),
            'score_max': score.max().item(),
            'gate_min': gate.min().item(),
            'gate_mean': gate.mean().item(),
            'gate_max': gate.max().item(),
            'gate_below_05': (gate < 0.5).sum().item(),  # 얼마나 많은 채널이 pruned?
        }
    return hook

# 처음 5개 레이어에 debug hook 추가
debug_handles = []
for i, (sanitized_name, gate_module) in enumerate(list(model.gates.items())[:5]):
    original_name = model.gate_name_mapping[sanitized_name]
    handle = gate_module.register_forward_hook(debug_hook(original_name))
    debug_handles.append(handle)

# ProGAN original 데이터로 forward
progan_orig_indices = [i for i, s in enumerate(dataset.samples)
                        if s['dataset'] == "corrupted_test_data_progan"
                        and s['corruption'] == "original"]

test_loader = DataLoader(
    Subset(dataset, progan_orig_indices[:32]),
    batch_size=32,
    shuffle=False
)

batch = next(iter(test_loader))
images = batch[0].to(DEVICE)

model.eval()
with torch.no_grad():
    _ = model(images)

# Hook 제거
for handle in debug_handles:
    handle.remove()

# 결과 출력
print("="*70)
print("Gate Debug Info (Original Clean Data)")
print("="*70)
print()

for layer_name, info in gate_debug_info.items():
    print(f"{layer_name}:")
    print(f"  Sensitivity: min={info['sensitivity_min']:.6f}, mean={info['sensitivity_mean']:.6f}, max={info['sensitivity_max']:.6f}")
    print(f"  Score:       min={info['score_min']:.2f}, mean={info['score_mean']:.2f}, max={info['score_max']:.2f}")
    print(f"  Gate:        min={info['gate_min']:.4f}, mean={info['gate_mean']:.4f}, max={info['gate_max']:.4f}")

    num_channels = len(model.gates[layer_name.replace('.', '_')].artifact_discriminability)
    pruned_ratio = info['gate_below_05'] / num_channels
    print(f"  Pruned:      {info['gate_below_05']}/{num_channels} ({pruned_ratio*100:.1f}%)")

    if pruned_ratio > 0.5:
        print(f"  🚨 문제! 50% 이상 채널이 pruned됨!")
    print()

Gate Debug Info (Original Clean Data)

classifier.conv1:
  Sensitivity: min=0.000000, mean=0.497076, max=14.903045
  Score:       min=0.00, mean=1.84, max=3005.12
  Gate:        min=0.5005, mean=0.5567, max=1.0000
  Pruned:      0/64 (0.0%)

classifier.bn1:
  Sensitivity: min=0.000000, mean=0.078641, max=1.081109
  Score:       min=0.00, mean=0.23, max=97.60
  Gate:        min=0.5000, mean=0.5479, max=0.9579
  Pruned:      0/64 (0.0%)

classifier.layer1.0.conv1:
  Sensitivity: min=0.000000, mean=0.805715, max=4.691023
  Score:       min=0.00, mean=0.42, max=332.76
  Gate:        min=0.5000, mean=0.5484, max=1.0000
  Pruned:      0/64 (0.0%)

classifier.layer1.0.bn1:
  Sensitivity: min=0.000000, mean=0.408792, max=1.631517
  Score:       min=0.00, mean=0.09, max=37.88
  Gate:        min=0.5000, mean=0.5195, max=0.8477
  Pruned:      0/64 (0.0%)

classifier.layer1.0.conv2:
  Sensitivity: min=0.000000, mean=2.104465, max=10.741886
  Score:       min=0.00, mean=0.37, max=218.89
  Gate:    

## Learned Parameters 확인

각 layer의 temperature가 어떻게 설정/학습되었는지 확인

In [None]:
print("\n" + "="*60)
print("Learned Parameters (Temperature & Channel Bias)")
print("="*60 + "\n")

for sanitized_name, gate in model.gates.items():
    # Get original layer name
    original_name = model.gate_name_mapping[sanitized_name]
    temp_value = gate.temperature.item()
    print(f"{original_name:50s}: temperature = {temp_value:.4f}")
    
    # Artifact discriminability
    artifact_disc = gate.artifact_discriminability
    disc_mean = artifact_disc.mean().item()
    disc_std = artifact_disc.std().item()
    disc_max = artifact_disc.max().item()
    print(f"{'':50s}  artifact_disc: mean={disc_mean:.4f}, std={disc_std:.4f}, max={disc_max:.4f}")
    
    # Channel bias statistics
    if config.use_channel_bias:
        bias_mean = gate.channel_bias.mean().item()
        bias_std = gate.channel_bias.std().item()
        bias_min = gate.channel_bias.min().item()
        bias_max = gate.channel_bias.max().item()
        print(f"{'':50s}  channel_bias: mean={bias_mean:+.4f}, std={bias_std:.4f}, range=[{bias_min:+.4f}, {bias_max:+.4f}]")
    print()

## Summary

### Channel Pruning v1 (IMPROVED) 핵심:

#### 1. **LGrad의 작동 원리**
```
Image → Gradient (StyleGAN) → Classifier → Prediction
        ↓                      ↓
    GAN pattern          Pruned channels
```
- Real: Natural gradient pattern
- Fake: GAN-specific gradient pattern  
- Noise: Corrupts gradient pattern

#### 2. **핵심 아이디어**
- Gradient pattern을 robust하게 보존하는 채널 유지
- Noise corruption에 민감한 채널 제거

#### 3. **방법론 (IMPROVED)**

**OLD (기존 CPv1)**:
```python
# ❌ Pseudo-label 의존
pseudo_label = estimate_label(x)  # Noisy에서 부정확!
if pseudo_label == Fake:
    sensitivity = |curr - fake_stats|  # Artifact를 noise로 오인
else:
    sensitivity = |curr - real_stats|
```

**NEW (Gradient-Pattern-Aware)**:
```python
# ✅ Label-free minimum deviation
dev_real = |curr - real_stats|
dev_fake = |curr - fake_stats|
sensitivity = min(dev_real, dev_fake)

# Clean: close to at least one → small
# Noisy: far from both → large
```

**Score 계산**:
```python
discriminability = |fake_stats - real_stats|  # Gradient pattern 차이
score = discriminability / sensitivity

# High score: Strong pattern detection + Low corruption sensitivity → KEEP
# Low score: Weak detection + High sensitivity → PRUNE
```

#### 4. **주요 개선사항**

| 측면 | 기존 CPv1 | IMPROVED CPv1 |
|------|-----------|---------------|
| Pseudo-label | ❌ 필요 (부정확) | ✅ 불필요 (label-free) |
| Sensitivity | ❌ Label-aware (오류 가능) | ✅ Minimum deviation |
| Gating | ❌ Batch-wise (불안정) | ✅ Channel-level (안정) |
| LGrad 이해 | ❌ 일반적 | ✅ Gradient-specific |

#### 5. **vs Channel Reweight v1**

| 특징 | CRv1 | CPv1 (IMPROVED) |
|------|------|-----------------|
| Statistics | Mixed (Real+Fake) | Separated (Real, Fake) |
| Artifact info | ❌ 구분 불가 | ✅ 보존 |
| Method | Reweighting | Gradient-pattern-aware pruning |
| Label-free | ✅ Yes | ✅ Yes (improved!) |

### 작동 원리:

1. **Pre-compute**: Real/Fake clean gradient features의 separated statistics
2. **Test time**: Noisy gradient → classifier features
3. **Sensitivity**: `min(|curr - real|, |curr - fake|)` (gradient corruption)
4. **Discriminability**: `|fake - real|` (pattern difference)
5. **Score**: `disc / sens`
6. **Gating**: High score channels만 유지

### 다음 단계:
- ✅ Baseline (no gating)과 성능 비교
- ✅ 기존 CPv1과 성능 비교
- ✅ CRv1과 성능 비교
- □ SGS, SAS 등 다른 방법들과 비교
- □ NPR 모델에도 적용