# Channel Pruning v1 (CPv1) Example - IMPROVED

## Gradient-Pattern-Aware Channel Pruning

**Key differences (vs. the original CPv1)**:
- ✅ **No pseudo-labels**: uses a label-free minimum deviation
- ✅ **Gradient corruption sensitivity**: measures how much noise corrupts the gradient pattern
- ✅ **Batch aggregation**: stable, channel-level gating
- ✅ **Optimized for LGrad**: tailored to its two-stage architecture (img2grad + classifier)

**Problems with the original CPv1**:
- ❌ Pseudo-labels are inaccurate on noisy inputs
- ❌ Artifacts can be mistaken for noise
- ❌ Batch-wise gating is unstable

**How it is improved**:
```python
# OLD (problematic):
pseudo_label = estimate_label(x)  # inaccurate on noisy inputs
sensitivity = (curr - reference[pseudo_label]).abs()

# NEW (IMPROVED):
dev_real = (curr - real_stats).abs()
dev_fake = (curr - fake_stats).abs()
sensitivity = torch.minimum(dev_real, dev_fake)  # label-free
```

**Expected benefits**:
- Better performance on noisy/corrupted images
- Robust preservation of the gradient pattern
- Stable channel selection
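
To make the label-free sensitivity concrete, here is a minimal, framework-free sketch that uses plain Python lists in place of tensors. The function name `min_deviation_sensitivity` and all numbers are illustrative assumptions, not part of `model.method`; the sketch only reproduces the `min(dev_real, dev_fake)` arithmetic shown above.

```python
def min_deviation_sensitivity(curr, real_stats, fake_stats):
    """Per-channel distance to the nearer of the two clean references."""
    return [
        min(abs(c - r), abs(c - f))
        for c, r, f in zip(curr, real_stats, fake_stats)
    ]

# Toy per-channel mean activations (illustrative values, not measurements).
real_stats = [0.25, 1.0, -0.5]
fake_stats = [0.75, 1.0, 0.5]

clean_fake_like = [0.75, 1.0, 0.5]  # coincides with the fake reference
noisy = [2.0, 3.0, 2.5]             # far from both references

sens_clean = min_deviation_sensitivity(clean_fake_like, real_stats, fake_stats)
sens_noisy = min_deviation_sensitivity(noisy, real_stats, fake_stats)

print(sens_clean)  # [0.0, 0.0, 0.0]
print(sens_noisy)  # [1.25, 2.0, 2.0]
```

No label is consulted at any point: a clean input scores low because it sits near one of the references, while a corrupted input scores high because it sits near neither.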

## Import

In [1]:
import sys
# Clear cache
for mod in list(sys.modules.keys()):
    if any(x in mod for x in ['NPR', 'npr', 'LGrad', 'lgrad', 'networks', 'method', 'channel']):
        del sys.modules[mod]

In [2]:
import os
from pathlib import Path
from typing import Optional, Literal
from tqdm import tqdm

import torch
from torch.utils.data import DataLoader, Subset
from PIL import Image
from torchvision import transforms

from utils.data.dataset import CorruptedDataset
from utils.visual.visualizer import DatasetVisualizer
from utils.eval.metrics import PredictionCollector, MetricsCalculator

# Channel Pruning v1 import
from model.method import (
    UnifiedChannelPruningV1,
    CPv1Config,
    compute_separated_statistics,
)
from model.LGrad.lgrad_model import LGrad
from model.NPR.npr_model import NPR

## GPU and Model select

In [3]:
!nvidia-smi

Tue Dec 30 18:39:17 2025       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.230.02             Driver Version: 535.230.02   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla P100-PCIE-16GB           Off | 00000000:04:00.0 Off |                    0 |
| N/A   47C    P0              31W / 250W |      4MiB / 16384MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  Tesla P100-PCIE-16GB           Off | 00000000:06:00.0 Off |  

In [4]:
DEVICE = "cuda:0"
MODEL_LIST = ["lgrad", "npr"]
MODEL = MODEL_LIST[0]  # "lgrad" or "npr"

## Dataloader

In [5]:
ROOT = "corrupted_dataset"
DATASETS = [
    "corrupted_test_data_progan",
    "corrupted_test_data_stylegan",
    "corrupted_test_data_stylegan2",
    "corrupted_test_data_biggan",
]

CORRUPTIONS = [
    "original",
    "gaussian_noise",
    "jpeg_compression",
]

if MODEL == "lgrad":
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])
else:
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

In [6]:
dataset = CorruptedDataset(
    root=ROOT,
    datasets=DATASETS,
    corruptions=CORRUPTIONS,
    transform=transform
)

print(f"Total samples: {len(dataset)}")

Total samples: 119874


## Model load

In [7]:
# LGrad
STYLEGAN_WEIGHTS_ROOT = "model/LGrad/weights/karras2019stylegan-bedrooms-256x256_discriminator.pth"
CLASSIFIER_WEIGHTS_ROOT = "model/LGrad/weights/LGrad-Pretrained-Model/LGrad-4class-Trainon-Progan_car_cat_chair_horse.pth"

# NPR
NPR_WEIGHTS_ROOT = "model/NPR/weights/NPR.pth"

if MODEL == "lgrad":
    base_model = LGrad(
        stylegan_weights=STYLEGAN_WEIGHTS_ROOT,
        classifier_weights=CLASSIFIER_WEIGHTS_ROOT,
        device=DEVICE
    )
    model_name = "LGrad"
elif MODEL == "npr":
    base_model = NPR(
        weights=NPR_WEIGHTS_ROOT,
        device=DEVICE
    )
    model_name = "NPR"

print(f"Base model loaded: {model_name}")



Base model loaded: LGrad


## Step 1: Compute Separated Statistics from Clean Data

**Important!** CPv1 needs **separate** statistics for clean real images and clean fake images.

- Statistics are collected from ProGAN's original (uncorrupted) data
- **Labels are required** to distinguish Real (0) from Fake (1)
- Once computed, the statistics can be saved and reused
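
As a rough picture of what "separated" means here, the sketch below averages per-channel features independently for each label. It is an assumption-laden illustration: `separated_channel_means` is a hypothetical helper, the features are toy lists rather than hooked layer activations, and the real `compute_separated_statistics` may track more than a mean.

```python
from collections import defaultdict

def separated_channel_means(features, labels):
    """Average per-channel features separately per label (0 = real, 1 = fake)."""
    sums = {}
    counts = defaultdict(int)
    for feat, label in zip(features, labels):
        if label not in sums:
            sums[label] = [0.0] * len(feat)
        sums[label] = [s + v for s, v in zip(sums[label], feat)]
        counts[label] += 1
    return {
        label: [s / counts[label] for s in total]
        for label, total in sums.items()
    }

# Toy data: four samples with two channels each (illustrative values only).
features = [[1.0, 2.0], [3.0, 4.0], [10.0, 0.0], [14.0, 2.0]]
labels = [0, 0, 1, 1]

stats = separated_channel_means(features, labels)
print(stats[0])  # [2.0, 3.0]
print(stats[1])  # [12.0, 1.0]
```

Pooling all four samples into one mean would give `[7.0, 2.0]`, which hides the real/fake gap in the first channel. Keeping the two references apart is what later allows a per-channel discriminability to be computed.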

In [None]:
# Prepare clean data (ProGAN original; labels are required)
progan_clean_indices = [
    i for i, s in enumerate(dataset.samples)
    if s['dataset'] == "corrupted_test_data_progan" and s['corruption'] == "original"
]

print(f"ProGAN clean samples: {len(progan_clean_indices)}")

# Subset & DataLoader (labels included)
clean_subset = Subset(dataset, progan_clean_indices)
clean_loader = DataLoader(
    clean_subset,
    batch_size=32,
    shuffle=False,
    num_workers=4,
    drop_last=False
)

ProGAN clean samples: 8000


In [None]:
# Path of the separated-statistics file
STATS_PATH = f"separated_stats_{MODEL}_progan.pth"

# Load existing statistics if present; otherwise compute them
if os.path.exists(STATS_PATH):
    print(f"Loading pre-computed separated statistics from {STATS_PATH}")
    separated_stats = torch.load(STATS_PATH)
    print(f"Statistics loaded for {len(separated_stats)} layers")
else:
    print("Computing separated statistics from clean data...")
    print("This may take a few minutes...\n")
    
    # Target layers: None for auto-detection (detects individual Conv2d/BN layers)
    # Note: classifier.layer4 is a Sequential module, not a single layer!
    # Auto-detection will find layers like 'classifier.layer4.0.conv1', 'classifier.layer4.2.bn3', etc.
    target_layers_for_stats = None
    
    # Compute (Real/Fake separated!)
    separated_stats = compute_separated_statistics(
        model=base_model,
        dataloader=clean_loader,  # MUST have labels!
        target_layers=target_layers_for_stats,
        device=DEVICE,
        max_batches=None,  # None = use every batch; set a small number (e.g. 50) for a faster, approximate estimate
    )
    
    # Save
    torch.save(separated_stats, STATS_PATH)
    print(f"\nStatistics saved to {STATS_PATH}")

Computing separated statistics from clean data...
This may take a few minutes...

[CPv1] Computing separated statistics for 106 layers...


Computing separated statistics:   5%|▌         | 27/500 [00:44<12:38,  1.60s/it]

## Step 2: Create Channel Pruning v1 Model (IMPROVED)

In [None]:
# Configuration (IMPROVED!)
config = CPv1Config(
    model=model_name,
    target_layers=None,  # Use all auto-detected layers from separated_stats
    
    # NEW: Sensitivity method (RECOMMENDED!)
    sensitivity_method="min",  # Label-free minimum deviation
    deviation_metric="mean",   # Simple and effective
    normalize_deviation=False,
    
    # NEW: Batch aggregation (RECOMMENDED!)
    enable_batch_aggregation=True,  # Stable channel-level gating
    aggregation_method="mean",
    
    # Gating parameters
    temperature_init=1.0,
    use_learnable_temperature=True,
    use_channel_bias=True,
    gating_type="soft",  # or "hard"
    
    # Optional
    enable_adaptation=False,  # optional
    device=DEVICE,
)

print("Configuration (IMPROVED):")
print(f"  Sensitivity method: {config.sensitivity_method} (LABEL-FREE!)")
print(f"  Batch aggregation: {config.enable_batch_aggregation}")
print(f"  Deviation metric: {config.deviation_metric}")
print()

# Create the model
model = UnifiedChannelPruningV1(
    base_model=base_model,
    separated_stats=separated_stats,
    config=config,
)

print("\nChannel Pruning v1 (IMPROVED) model created!")

## (Optional) Test-Time Adaptation

Temperature and channel bias can be fine-tuned on noisy validation data.

**This step can be skipped.** The model works without adaptation.
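
To give an intuition for what these two parameters control, here is a small sketch of a sigmoid gate. The formula `sigmoid((score + bias) / temperature)` is an assumption made for illustration; this notebook does not show how `UnifiedChannelPruningV1` actually combines score, temperature, and channel bias.

```python
import math

def soft_gate(score, temperature=1.0, bias=0.0):
    """Hypothetical soft gate: sigmoid of a biased score scaled by temperature."""
    return 1.0 / (1.0 + math.exp(-(score + bias) / temperature))

# Toy scores, centered so that 0 is the keep/prune decision point.
centered_scores = [-2.0, 0.0, 2.0]

sharp = [soft_gate(s, temperature=0.5) for s in centered_scores]
smooth = [soft_gate(s, temperature=4.0) for s in centered_scores]
shifted = soft_gate(0.0, temperature=1.0, bias=2.0)

print(sharp)    # values pushed toward 0 and 1 (close to a hard gate)
print(smooth)   # values pulled toward 0.5 (gentle reweighting)
print(shifted)  # above 0.5: a positive bias keeps a borderline channel open
```

Under this assumed form, lowering the temperature makes the gate behave more like hard pruning, and a per-channel bias shifts the point at which a channel starts to be suppressed.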

In [None]:
# # Uncomment to enable adaptation
# ENABLE_ADAPTATION = True

# if ENABLE_ADAPTATION:
#     # Prepare noisy validation data
#     progan_noisy_indices = [
#         i for i, s in enumerate(dataset.samples)
#         if s['dataset'] == "corrupted_test_data_progan" and s['corruption'] == "gaussian_noise"
#     ]
    
#     print(f"ProGAN noisy samples for adaptation: {len(progan_noisy_indices)}")
    
#     noisy_subset = Subset(dataset, progan_noisy_indices[:500])
#     noisy_loader = DataLoader(
#         noisy_subset,
#         batch_size=32,
#         shuffle=True,
#         num_workers=4,
#         drop_last=False
#     )
    
#     # Run adaptation
#     print("\nStarting test-time adaptation...\n")
#     model.adapt(
#         dataloader=noisy_loader,
#         epochs=5,
#         lr=1e-4,
#         verbose=True,
#     )
#     print("\nAdaptation complete!")

## Evaluation

Evaluate every dataset and corruption combination.
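
The evaluation cell stores its metrics in `all_results`, a dict keyed by `(dataset_name, corruption)`. As a sketch of how such a dict can be collapsed into one number per corruption type, consider the helper below. `mean_metric_by_corruption` is a hypothetical stand-alone function, not the implementation behind `calc.summarize_by_corruption`, and the numbers are placeholders that only show the data shape.

```python
from collections import defaultdict

def mean_metric_by_corruption(all_results, metric="accuracy"):
    """Average one metric over datasets for each corruption type."""
    grouped = defaultdict(list)
    for (_dataset_name, corruption), metrics in all_results.items():
        grouped[corruption].append(metrics[metric])
    return {c: sum(values) / len(values) for c, values in grouped.items()}

# Placeholder values to illustrate the structure; these are NOT measured results.
example_results = {
    ("dataset_a", "original"): {"accuracy": 0.5},
    ("dataset_b", "original"): {"accuracy": 1.0},
    ("dataset_a", "gaussian_noise"): {"accuracy": 0.25},
    ("dataset_b", "gaussian_noise"): {"accuracy": 0.75},
}

summary = mean_metric_by_corruption(example_results)
print(summary)  # {'original': 0.75, 'gaussian_noise': 0.5}
```

Grouping by the first element of the key instead of the second would give the per-dataset view.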

In [None]:
# Evaluation
calc = MetricsCalculator()
all_results = {}

for dataset_name in DATASETS:
    for corruption in CORRUPTIONS:
        combination_indices = [
            i for i, s in enumerate(dataset.samples)
            if s['dataset'] == dataset_name and s['corruption'] == corruption
        ]
        
        if len(combination_indices) == 0:
            print(f"{dataset_name}-{corruption}: no samples, skipping")
            continue
        
        print(f"\n{'='*60}")
        print(f"Evaluating: {dataset_name}-{corruption}")
        print(f"Samples: {len(combination_indices)}")
        print(f"{'='*60}")
        
        # Create the Subset and DataLoader
        subset = Subset(dataset, combination_indices)
        dataloader = DataLoader(
            subset,
            batch_size=32,
            shuffle=False,
            num_workers=4,
            drop_last=True  # NOTE: drops the final partial batch, so up to 31 samples per split are not evaluated
        )
        
        # Evaluate
        metrics = calc.evaluate(
            model=model,
            dataloader=dataloader,
            device=DEVICE,
            name=f"{dataset_name}-{corruption}"
        )
        
        # Print the results immediately
        print("\nResults:")
        print(f"  Accuracy: {metrics['accuracy']*100:.2f}%")
        print(f"  AUC:      {metrics['auc']*100:.2f}%")
        print(f"  AP:       {metrics['ap']*100:.2f}%")
        print(f"  F1:       {metrics['f1']*100:.2f}%")
        
        # Store the results
        all_results[(dataset_name, corruption)] = metrics

# Print the overall results table
print(f"\n\n{'='*60}")
print("Overall results summary")
print(f"{'='*60}\n")
calc.print_results_table()
calc.summarize_by_corruption(all_results)
calc.summarize_by_dataset(all_results)

## Inspecting Learned Parameters

Check how each layer's temperature was set or learned.

In [None]:
print("\n" + "="*60)
print("Learned Parameters (Temperature & Channel Bias)")
print("="*60 + "\n")

for sanitized_name, gate in model.gates.items():
    # Get original layer name
    original_name = model.gate_name_mapping[sanitized_name]
    temp_value = gate.temperature.item()
    print(f"{original_name:50s}: temperature = {temp_value:.4f}")
    
    # Artifact discriminability
    artifact_disc = gate.artifact_discriminability
    disc_mean = artifact_disc.mean().item()
    disc_std = artifact_disc.std().item()
    disc_max = artifact_disc.max().item()
    print(f"{'':50s}  artifact_disc: mean={disc_mean:.4f}, std={disc_std:.4f}, max={disc_max:.4f}")
    
    # Channel bias statistics
    if config.use_channel_bias:
        bias_mean = gate.channel_bias.mean().item()
        bias_std = gate.channel_bias.std().item()
        bias_min = gate.channel_bias.min().item()
        bias_max = gate.channel_bias.max().item()
        print(f"{'':50s}  channel_bias: mean={bias_mean:+.4f}, std={bias_std:.4f}, range=[{bias_min:+.4f}, {bias_max:+.4f}]")
    print()

## Summary

### Channel Pruning v1 (IMPROVED) Key Points:

#### 1. **How LGrad works**
```
Image → Gradient (StyleGAN) → Classifier → Prediction
        ↓                      ↓
    GAN pattern          Pruned channels
```
- Real: Natural gradient pattern
- Fake: GAN-specific gradient pattern  
- Noise: Corrupts gradient pattern

#### 2. **Core idea**
- Keep channels that robustly preserve the gradient pattern
- Remove channels that are sensitive to noise corruption

#### 3. **Method (IMPROVED)**

**OLD (original CPv1)**:
```python
# ❌ Depends on a pseudo-label
pseudo_label = estimate_label(x)  # inaccurate on noisy inputs
if pseudo_label == 1:  # Fake
    sensitivity = (curr - fake_stats).abs()  # artifacts get mistaken for noise
else:  # Real
    sensitivity = (curr - real_stats).abs()
```

**NEW (Gradient-Pattern-Aware)**:
```python
# ✅ Label-free minimum deviation
dev_real = (curr - real_stats).abs()
dev_fake = (curr - fake_stats).abs()
sensitivity = torch.minimum(dev_real, dev_fake)

# Clean: close to at least one → small
# Noisy: far from both → large
```

**Score computation**:
```python
discriminability = (fake_stats - real_stats).abs()  # gradient-pattern difference
score = discriminability / sensitivity

# High score: Strong pattern detection + Low corruption sensitivity → KEEP
# Low score: Weak detection + High sensitivity → PRUNE
```
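
One practical detail the formula leaves open is what happens when `sensitivity` is exactly zero, for example when a clean input coincides with one of the references. The sketch below adds a small `eps` to the denominator. Both `eps` and the helper name `channel_scores` are assumptions for illustration; the notebook does not show how the library guards this case.

```python
import math

def channel_scores(real_stats, fake_stats, sensitivity, eps=1e-8):
    """score = |fake - real| / (sensitivity + eps), computed per channel."""
    return [
        abs(f - r) / (s + eps)
        for r, f, s in zip(real_stats, fake_stats, sensitivity)
    ]

# Toy values (illustrative only): three channels.
real_stats = [0.0, 0.0, 0.0]
fake_stats = [2.0, 2.0, 0.5]
sensitivity = [0.5, 4.0, 4.0]

scores = channel_scores(real_stats, fake_stats, sensitivity)
ranking = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
print(ranking)  # [0, 1, 2]

# A zero sensitivity stays finite thanks to eps (assumed guard).
zero_case = channel_scores([0.0], [1.0], [0.0])
print(math.isfinite(zero_case[0]))  # True
```

Channel 0 ranks first because it separates real from fake strongly and barely reacts to corruption. Channel 2 ranks last because it does neither.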

#### 4. **Key improvements**

| Aspect | Original CPv1 | IMPROVED CPv1 |
|------|-----------|---------------|
| Pseudo-label | ❌ Required (inaccurate) | ✅ Not required (label-free) |
| Sensitivity | ❌ Label-aware (error-prone) | ✅ Minimum deviation |
| Gating | ❌ Batch-wise (unstable) | ✅ Channel-level (stable) |
| LGrad awareness | ❌ Generic | ✅ Gradient-specific |
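
The "Gating" row can be read as follows: rather than deciding per sample, sensitivities are averaged over the batch so that each channel receives a single decision. The sketch below shows that reading with plain lists. It is one plausible interpretation of `enable_batch_aggregation=True` with `aggregation_method="mean"`, and `threshold` is a hypothetical cut-off that does not appear in `CPv1Config`.

```python
def aggregate_over_batch(per_sample_values):
    """Mean over the batch dimension: rows of [channels] -> one [channels] row."""
    batch_size = len(per_sample_values)
    return [sum(column) / batch_size for column in zip(*per_sample_values)]

# Toy sensitivities for 4 samples and 2 channels (illustrative values only).
per_sample_sensitivity = [
    [0.5, 3.0],
    [1.5, 1.0],
    [0.5, 3.0],
    [1.5, 1.0],
]
threshold = 1.25  # hypothetical: keep a channel when its sensitivity is below this

# Per-sample decisions flip from one sample to the next.
per_sample_keep = [[s < threshold for s in row] for row in per_sample_sensitivity]

# Aggregated decisions: one stable verdict per channel.
channel_sensitivity = aggregate_over_batch(per_sample_sensitivity)
channel_keep = [s < threshold for s in channel_sensitivity]

print(per_sample_keep[0], per_sample_keep[1])  # [True, False] [False, True]
print(channel_sensitivity)                     # [1.0, 2.0]
print(channel_keep)                            # [True, False]
```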

#### 5. **vs Channel Reweight v1**

| Feature | CRv1 | CPv1 (IMPROVED) |
|------|------|-----------------|
| Statistics | Mixed (Real+Fake) | Separated (Real, Fake) |
| Artifact info | ❌ Not distinguishable | ✅ Preserved |
| Method | Reweighting | Gradient-pattern-aware pruning |
| Label-free | ✅ Yes | ✅ Yes (improved!) |

### How it works:

1. **Pre-compute**: Real/Fake clean gradient features의 separated statistics
2. **Test time**: Noisy gradient → classifier features
3. **Sensitivity**: `min(|curr - real|, |curr - fake|)` (gradient corruption)
4. **Discriminability**: `|fake - real|` (pattern difference)
5. **Score**: `disc / sens`
6. **Gating**: keep only high-score channels
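
For the final gating step, a hard variant (`gating_type="hard"` in the config) could be pictured as a binary mask over channels. The sketch below keeps a fixed fraction of the highest-scoring channels. The `keep_ratio` parameter and both helper names are hypothetical; the notebook does not state how the library chooses which channels survive.

```python
def hard_gate(scores, keep_ratio=0.5):
    """Binary mask that keeps the top `keep_ratio` fraction of channels by score."""
    num_keep = max(1, round(len(scores) * keep_ratio))
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    keep = set(order[:num_keep])
    return [1.0 if i in keep else 0.0 for i in range(len(scores))]

def apply_gate(features, mask):
    """Zero out pruned channels by element-wise multiplication."""
    return [f * m for f, m in zip(features, mask)]

# Toy per-channel scores and features (illustrative values only).
scores = [4.0, 0.5, 2.0, 0.125]
mask = hard_gate(scores, keep_ratio=0.5)
gated = apply_gate([10.0, 20.0, 30.0, 40.0], mask)

print(mask)   # [1.0, 0.0, 1.0, 0.0]
print(gated)  # [10.0, 0.0, 30.0, 0.0]
```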

### Next steps:
- ✅ Compare against the baseline (no gating)
- ✅ Compare against the original CPv1
- ✅ Compare against CRv1
- □ Compare against other methods such as SGS and SAS
- □ Apply to the NPR model as well