# Channel Reweight v1 (CRv1) Example

## Import

In [1]:
import sys
# 캐시된 모듈 제거
for mod in list(sys.modules.keys()):
    if any(x in mod for x in ['NPR', 'npr', 'LGrad', 'lgrad', 'networks', 'method', 'channel_reweight']):
        del sys.modules[mod]

In [2]:
import os
from pathlib import Path
from typing import Optional, Literal
from tqdm import tqdm

import torch
from torch.utils.data import DataLoader, Subset
from PIL import Image
from torchvision import transforms

from utils.data.dataset import CorruptedDataset
from utils.visual.visualizer import DatasetVisualizer
from utils.eval.metrics import PredictionCollector, MetricsCalculator

# Channel Reweight v1 import
from model.method import (
    UnifiedChannelReweightV1,
    ChannelReweightV1Config,
    compute_channel_statistics,
)
from model.LGrad.lgrad_model import LGrad
from model.NPR.npr_model import NPR

## GPU and Model select

In [3]:
!nvidia-smi

Tue Dec 30 08:14:02 2025       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.230.02             Driver Version: 535.230.02   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla P100-PCIE-16GB           Off | 00000000:04:00.0 Off |                    0 |
| N/A   36C    P0              27W / 250W |      4MiB / 16384MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  Tesla P100-PCIE-16GB           Off | 00000000:06:00.0 Off |  

In [4]:
DEVICE = "cuda:1"
MODEL_LIST = ["lgrad", "npr"]
MODEL = MODEL_LIST[0]  # "lgrad" or "npr"

## Dataloader

In [5]:
ROOT = "corrupted_dataset"
DATASETS = [
    "corrupted_test_data_progan",
    "corrupted_test_data_stylegan",
    "corrupted_test_data_stylegan2",
    "corrupted_test_data_biggan",
    # "corrupted_test_data_crn",
    # "corrupted_test_data_cyclegan",
    # "corrupted_test_data_deepfake",
    # "corrupted_test_data_gaugan",
    # "corrupted_test_data_imle",
    # "corrupted_test_data_san",
    # "corrupted_test_data_seeingdark",
    # "corrupted_test_data_stargan",
    # "corrupted_test_data_whichfaceisreal",
]

CORRUPTIONS = [
    "original",
    "gaussian_noise",
    "jpeg_compression",
    # "contrast",
    # "fog",
    # "motion_blur",
    # "pixelate",
]

if MODEL == "lgrad":
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])
else:
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

In [6]:
dataset = CorruptedDataset(
    root=ROOT,
    datasets=DATASETS,
    corruptions=CORRUPTIONS,
    transform=transform
)

print(f"Total samples: {len(dataset)}")

Total samples: 119874


## Model load

In [7]:
# LGrad
STYLEGAN_WEIGHTS_ROOT = "model/LGrad/weights/karras2019stylegan-bedrooms-256x256_discriminator.pth"
CLASSIFIER_WEIGHTS_ROOT = "model/LGrad/weights/LGrad-Pretrained-Model/LGrad-4class-Trainon-Progan_car_cat_chair_horse.pth"

# NPR
NPR_WEIGHTS_ROOT = "model/NPR/weights/NPR.pth"

if MODEL == "lgrad":
    base_model = LGrad(
        stylegan_weights=STYLEGAN_WEIGHTS_ROOT,
        classifier_weights=CLASSIFIER_WEIGHTS_ROOT,
        device=DEVICE
    )
    model_name = "LGrad"
elif MODEL == "npr":
    base_model = NPR(
        weights=NPR_WEIGHTS_ROOT,
        device=DEVICE
    )
    model_name = "NPR"

print(f"Base model loaded: {model_name}")

  torch.load(stylegan_weights, map_location="cpu"),
  torch.load(classifier_weights, map_location="cpu")


Base model loaded: LGrad


## Step 1: Compute Channel Statistics from Clean Data

**중요!** Channel Reweight v1은 clean data의 통계가 필요합니다.

- ProGAN의 original (uncorrupted) 데이터로 statistics 수집
- 한 번 계산하면 저장해서 재사용 가능

In [8]:
# Clean data 준비 (ProGAN original)
progan_clean_indices = [
    i for i, s in enumerate(dataset.samples)
    if s['dataset'] == "corrupted_test_data_progan" and s['corruption'] == "original"
]

print(f"ProGAN clean samples: {len(progan_clean_indices)}")

# Subset & DataLoader
clean_subset = Subset(dataset, progan_clean_indices)
clean_loader = DataLoader(
    clean_subset,
    batch_size=16,
    shuffle=False,
    num_workers=4,
    drop_last=False
)

ProGAN clean samples: 8000


In [9]:
# Statistics 파일 경로 (block-level statistics)
STATS_PATH = f"clean_stats_{MODEL}_progan_blocks.pth"

# 기존 statistics가 있으면 로드, 없으면 계산
if os.path.exists(STATS_PATH):
    print(f"Loading pre-computed statistics from {STATS_PATH}")
    clean_stats = torch.load(STATS_PATH)
    print(f"Statistics loaded for {len(clean_stats)} layers")
else:
    print("Computing channel statistics from clean data (block-level)...")
    print("This may take a few minutes...\n")
    
    # Block 단위로 statistics 수집
    target_layers_for_stats = [
        'classifier.layer1',
        'classifier.layer2',
        'classifier.layer3',
        'classifier.layer4',
    ] if MODEL == "lgrad" else None
    
    clean_stats = compute_channel_statistics(
        model=base_model,
        dataloader=clean_loader,
        target_layers=target_layers_for_stats,  # Block 단위로 지정
        device=DEVICE,
        max_batches=50,  # 속도를 위해 50 batches만 (필요하면 늘리기)
    )
    
    # 저장
    torch.save(clean_stats, STATS_PATH)
    print(f"\nStatistics saved to {STATS_PATH}")

Computing channel statistics from clean data (block-level)...
This may take a few minutes...

[CRv1] Computing statistics for 4 layers...


Computing statistics: 100%|██████████| 50/50 [00:17<00:00,  2.86it/s, batch=50/50]


  classifier.layer1: C=256, mean_range=[0.0000, 2.1960], var_range=[0.0000, 0.3713]
  classifier.layer2: C=512, mean_range=[0.0000, 1.9317], var_range=[0.0000, 1.1237]
  classifier.layer3: C=1024, mean_range=[0.0000, 3.1548], var_range=[0.0000, 2.0396]
  classifier.layer4: C=2048, mean_range=[0.0000, 2.2700], var_range=[0.0000, 2.3144]
[CRv1] Statistics computed for 4 layers

Statistics saved to clean_stats_lgrad_progan_blocks.pth


## Step 2: Create Channel Reweight v1 Model

In [10]:
# Config 설정
config = ChannelReweightV1Config(
    model=model_name,
    target_layers=[
        'classifier.layer1',  # Block 전체 output [B, 256, 64, 64]
        'classifier.layer2',  # Block 전체 output [B, 512, 32, 32]
        'classifier.layer3',  # Block 전체 output [B, 1024, 16, 16]
        'classifier.layer4',  # Block 전체 output [B, 2048, 8, 8]
    ] if MODEL == "lgrad" else None,  # NPR은 None (auto-detect)
    temperature_init=2.0,
    use_learnable_temperature=True,
    use_channel_bias=True,
    deviation_metric="mean+var",
    normalize_deviation=True,
    enable_adaptation=False,  # Test-time adaptation은 선택적
    adaptation_lr=1e-4,
    adaptation_loss="entropy",
    device=DEVICE,
)

# Model 생성
model = UnifiedChannelReweightV1(
    base_model=base_model,
    clean_stats=clean_stats,
    config=config,
)

print("\nChannel Reweight v1 model created!")

  Installed gate at classifier.layer1 (C=256)
  Installed gate at classifier.layer2 (C=512)
  Installed gate at classifier.layer3 (C=1024)
  Installed gate at classifier.layer4 (C=2048)
[CRv1] Initialized for LGrad
[CRv1] Target layers: 4
[CRv1] Temperature init: 2.0 (learnable=True)
[CRv1] Channel bias: True
[CRv1] Deviation metric: mean+var

Channel Reweight v1 model created!


## (Optional) Test-Time Adaptation

Noisy validation data로 temperature와 channel bias를 fine-tuning할 수 있습니다.

**Skip 가능!** Adaptation 없이도 사용 가능합니다.

In [11]:
# Adaptation을 원하면 주석 해제
# ENABLE_ADAPTATION = True

# if ENABLE_ADAPTATION:
#     # Noisy validation data 준비 (ProGAN gaussian_noise)
#     progan_noisy_indices = [
#         i for i, s in enumerate(dataset.samples)
#         if s['dataset'] == "corrupted_test_data_progan" and s['corruption'] == "gaussian_noise"
#     ]
    
#     print(f"ProGAN noisy samples for adaptation: {len(progan_noisy_indices)}")
    
#     noisy_subset = Subset(dataset, progan_noisy_indices[:500])  # 500개만 사용
#     noisy_loader = DataLoader(
#         noisy_subset,
#         batch_size=32,
#         shuffle=True,
#         num_workers=4,
#         drop_last=False
#     )
    
#     # Adaptation 실행
#     print("\nStarting test-time adaptation...\n")
#     model.adapt(
#         dataloader=noisy_loader,
#         epochs=5,
#         lr=1e-4,
#         verbose=True,
#     )
#     print("\nAdaptation complete!")

## Evaluation

Dataset별, Corruption별로 평가합니다.

In [12]:
# Evaluation
calc = MetricsCalculator()
all_results = {}

for dataset_name in DATASETS:
    for corruption in CORRUPTIONS:
        combination_indices = [
            i for i, s in enumerate(dataset.samples)
            if s['dataset'] == dataset_name and s['corruption'] == corruption
        ]
        
        if len(combination_indices) == 0:
            print(f"{dataset_name}-{corruption}: 샘플 없음, 스킵")
            continue
        
        print(f"\n{'='*60}")
        print(f"평가 중: {dataset_name}-{corruption}")
        print(f"샘플 수: {len(combination_indices)}")
        print(f"{'='*60}")
        
        # Subset과 DataLoader 생성
        subset = Subset(dataset, combination_indices)
        dataloader = DataLoader(
            subset,
            batch_size=32,
            shuffle=False,
            num_workers=4,
            drop_last=True
        )
        
        # 평가
        metrics = calc.evaluate(
            model=model,
            dataloader=dataloader,
            device=DEVICE,
            name=f"{dataset_name}-{corruption}"
        )
        
        # 즉시 결과 출력
        print(f"\n결과:")
        print(f"  Accuracy: {metrics['accuracy']*100:.2f}%")
        print(f"  AUC:      {metrics['auc']*100:.2f}%")
        print(f"  AP:       {metrics['ap']*100:.2f}%")
        print(f"  F1:       {metrics['f1']*100:.2f}%")
        
        # 결과 저장
        all_results[(dataset_name, corruption)] = metrics

# 전체 결과 테이블 출력
print(f"\n\n{'='*60}")
print("전체 결과 요약")
print(f"{'='*60}\n")
calc.print_results_table()
calc.summarize_by_corruption(all_results)
calc.summarize_by_dataset(all_results)


평가 중: corrupted_test_data_progan-original
샘플 수: 8000


corrupted_test_data_progan-original: 100%|██████████| 250/250 [01:31<00:00,  2.72it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])



결과:
  Accuracy: 50.00%
  AUC:      3.03%
  AP:       31.00%
  F1:       0.00%

평가 중: corrupted_test_data_progan-gaussian_noise
샘플 수: 8000


corrupted_test_data_progan-gaussian_noise: 100%|██████████| 250/250 [01:31<00:00,  2.72it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])



결과:
  Accuracy: 50.00%
  AUC:      53.39%
  AP:       54.28%
  F1:       0.00%

평가 중: corrupted_test_data_progan-jpeg_compression
샘플 수: 8000


corrupted_test_data_progan-jpeg_compression: 100%|██████████| 250/250 [01:31<00:00,  2.72it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])



결과:
  Accuracy: 50.00%
  AUC:      61.91%
  AP:       61.50%
  F1:       0.00%

평가 중: corrupted_test_data_stylegan-original
샘플 수: 11982


corrupted_test_data_stylegan-original: 100%|██████████| 374/374 [02:17<00:00,  2.71it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])



결과:
  Accuracy: 50.06%
  AUC:      3.11%
  AP:       30.95%
  F1:       0.00%

평가 중: corrupted_test_data_stylegan-gaussian_noise
샘플 수: 11982


corrupted_test_data_stylegan-gaussian_noise: 100%|██████████| 374/374 [02:17<00:00,  2.72it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])



결과:
  Accuracy: 50.06%
  AUC:      49.80%
  AP:       49.54%
  F1:       0.00%

평가 중: corrupted_test_data_stylegan-jpeg_compression
샘플 수: 11982


corrupted_test_data_stylegan-jpeg_compression: 100%|██████████| 374/374 [02:17<00:00,  2.72it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])



결과:
  Accuracy: 50.06%
  AUC:      33.17%
  AP:       39.76%
  F1:       0.00%

평가 중: corrupted_test_data_stylegan2-original
샘플 수: 15976


corrupted_test_data_stylegan2-original: 100%|██████████| 499/499 [03:03<00:00,  2.72it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])



결과:
  Accuracy: 50.03%
  AUC:      5.27%
  AP:       31.30%
  F1:       0.00%

평가 중: corrupted_test_data_stylegan2-gaussian_noise
샘플 수: 15976


corrupted_test_data_stylegan2-gaussian_noise: 100%|██████████| 499/499 [03:03<00:00,  2.72it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])



결과:
  Accuracy: 50.03%
  AUC:      39.86%
  AP:       43.29%
  F1:       0.00%

평가 중: corrupted_test_data_stylegan2-jpeg_compression
샘플 수: 15976


corrupted_test_data_stylegan2-jpeg_compression: 100%|██████████| 499/499 [03:03<00:00,  2.73it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])



결과:
  Accuracy: 50.03%
  AUC:      4.04%
  AP:       31.05%
  F1:       0.00%

평가 중: corrupted_test_data_biggan-original
샘플 수: 4000


corrupted_test_data_biggan-original: 100%|██████████| 125/125 [00:46<00:00,  2.71it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])



결과:
  Accuracy: 50.00%
  AUC:      12.10%
  AP:       32.35%
  F1:       0.00%

평가 중: corrupted_test_data_biggan-gaussian_noise
샘플 수: 4000


corrupted_test_data_biggan-gaussian_noise: 100%|██████████| 125/125 [00:46<00:00,  2.70it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])



결과:
  Accuracy: 50.00%
  AUC:      50.30%
  AP:       51.14%
  F1:       0.00%

평가 중: corrupted_test_data_biggan-jpeg_compression
샘플 수: 4000


corrupted_test_data_biggan-jpeg_compression: 100%|██████████| 125/125 [00:46<00:00,  2.71it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])



결과:
  Accuracy: 50.00%
  AUC:      57.66%
  AP:       58.88%
  F1:       0.00%


전체 결과 요약



ModuleNotFoundError: No module named 'rich'

## Learned Temperature 확인

각 layer의 temperature가 어떻게 설정/학습되었는지 확인

In [None]:
print("\n" + "="*60)
print("Learned Temperatures")
print("="*60 + "\n")

for sanitized_name, gate in model.gates.items():
    # Get original layer name
    original_name = model.gate_name_mapping[sanitized_name]
    temp_value = gate.temperature.item()
    print(f"{original_name:50s}: temperature = {temp_value:.4f}")
    
    # Channel bias statistics
    if config.use_channel_bias:
        bias_mean = gate.channel_bias.mean().item()
        bias_std = gate.channel_bias.std().item()
        bias_min = gate.channel_bias.min().item()
        bias_max = gate.channel_bias.max().item()
        print(f"{'':50s}  channel_bias: mean={bias_mean:+.4f}, std={bias_std:.4f}, range=[{bias_min:+.4f}, {bias_max:+.4f}]")
    print()

## Summary

### Channel Reweight v1 핵심:

1. **Clean data statistics 기반**: 노이즈 없는 데이터의 channel 통계를 기준점으로 사용
2. **Noise sensitivity 측정**: Test time에 통계적 편차가 큰 channel = noise-sensitive
3. **자동 가중치 조정**: Temperature & channel bias가 learnable → 수동 튜닝 불필요
4. **선택적 adaptation**: Zero-shot으로도 사용 가능, adaptation으로 성능 향상 가능

### 사용 Flow:
```
Clean data → compute_channel_statistics() → UnifiedChannelReweightV1
                                                      ↓
                                             (Optional) adapt()
                                                      ↓
                                              Inference & Eval
```

### 다음 단계:
- NORM, SGS, SAS 등 다른 방법들과 성능 비교
- Target layers 조합 실험 (layer3만? layer4만? 둘 다?)
- NPR 모델에도 적용
- Adaptation 효과 분석