# 모든 모델 학습 스크립트
## 노이즈 종류별 데이터셋 구성 (5만장 * 3 = 15만장) 및 모델별 학습


In [None]:
# Part 1: 기본 설정 및 라이브러리 임포트
import os
import glob
import numpy as np
import tensorflow as tf
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
from tensorflow.keras import layers, Model, regularizers
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.optimizers.schedules import ExponentialDecay
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import sys
sys.path.append('.')

# Models 폴더에서 모든 모델 import
from models import (
    # MTL 모델들
    create_model as create_multitask_unet,
    create_bam_model,
    
    # 단일 태스크 모델들
    create_bam_restoration_model,
    create_bam_classification_model,
    
    # 기존 모델들
    build_cae_multitask, build_cae_restoration,
    build_dncnn_multitask, build_dncnn_restoration,
    build_unet_multitask, build_unet_restoration, build_unet_baseline
)

# Report the TensorFlow version and the number of GPUs TensorFlow can see,
# so a CPU-only session is obvious before any training starts.
print("TensorFlow Version:", tf.__version__)
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))


In [None]:
# Part 2: load CIFAR-10 and preprocess it
print("--- Part 2: Loading CIFAR-10 Dataset ---")
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
# Convert the images to float32 and scale them by 1/255.
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0
# Flatten the label arrays to 1-D vectors of class ids.
y_train = y_train.flatten()
y_test = y_test.flatten()
cifar10_class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
num_classes = len(cifar10_class_names)
print(f"CIFAR-10 data loaded. Train: {x_train.shape}, Test: {x_test.shape}")

# Statistics for z-score normalization, computed from the training images
# only. Reducing over axes (0, 1, 2) leaves one value per entry of the last
# axis. The 1e-6 added to STD guards against division by zero.
MEAN = tf.constant(np.mean(x_train, axis=(0, 1, 2)), dtype=tf.float32)
STD = tf.constant(np.std(x_train, axis=(0, 1, 2)) + 1e-6, dtype=tf.float32)

def to_zscore(x):
    """Standardize ``x`` using the module-level ``MEAN`` and ``STD``."""
    centered = x - MEAN
    return centered / STD

def from_zscore(z):
    """Undo ``to_zscore``: multiply ``z`` by ``STD`` and add ``MEAN`` back."""
    rescaled = z * STD
    return rescaled + MEAN


In [None]:
# Part 3: build the per-noise-type datasets (50k images * 3 = 150k images)
print("--- Part 3: Creating Noise-Specific Datasets ---")

# Geometric augmentation pipeline: a random horizontal flip followed by a
# random rotation with factor 0.05. Keras preprocessing layers only apply
# their randomness when called with training=True.
data_augmentation_pipeline = tf.keras.Sequential([
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.05),
])

def augment_brightness_contrast(image):
    """Apply a random brightness shift followed by a random contrast change.

    Brightness uses ``max_delta=0.1`` and contrast uses a factor range of
    [0.8, 1.2]. The result is not clipped here; callers that need values in a
    fixed range must clip afterwards.
    """
    brightened = tf.image.random_brightness(image, max_delta=0.1)
    return tf.image.random_contrast(brightened, lower=0.8, upper=1.2)

# 노이즈 타입별 함수들
def add_gaussian_noise(image, snr_range=(-30, -10)):
    """Add zero-mean Gaussian noise to the z-scored version of ``image``.

    An SNR value in dB is drawn uniformly from ``snr_range`` and converted to
    a noise standard deviation of ``10 ** (-snr_db / 20)``, applied in z-score
    space.

    Returns:
        ``(noisy_z, 0)`` where ``0`` is the id used for Gaussian noise.
    """
    standardized = to_zscore(image)
    snr_db = tf.random.uniform(
        [], minval=snr_range[0], maxval=snr_range[1], dtype=tf.float32
    )
    noise_std = tf.pow(10.0, -snr_db / 20.0)
    perturbation = tf.random.normal(
        shape=tf.shape(standardized), mean=0.0, stddev=noise_std, dtype=tf.float32
    )
    return standardized + perturbation, 0  # noise_type = 0

def add_salt_pepper_noise(image, amount_range=(0.05, 0.30)):
    """Corrupt ``image`` with salt & pepper noise and return it z-scored.

    A corruption rate is drawn uniformly from ``amount_range``; half of it is
    used for salt (value 1.0) and half for pepper (value 0.0). Each element of
    the image (every pixel/channel entry) is corrupted independently.

    The corruption is applied in pixel space *before* standardization. Writing
    1.0 / 0.0 directly into the z-scored image would place "salt" at roughly
    mean + std and "pepper" at the mean intensity instead of at white and
    black.

    Args:
        image: Float image in pixel space. The caller in this file passes
            values clipped to [0, 1].
        amount_range: ``(low, high)`` bounds for the total corrupted fraction.

    Returns:
        ``(noisy_z, 1)`` where ``1`` is the id used for salt & pepper noise.
    """
    amount = tf.random.uniform([], minval=amount_range[0], maxval=amount_range[1], dtype=tf.float32)
    u = tf.random.uniform(shape=tf.shape(image))
    # The two masks are disjoint because amount * 0.5 is below 0.5.
    salt = tf.cast(u < amount * 0.5, tf.float32)
    pepper = tf.cast(u > 1.0 - amount * 0.5, tf.float32)
    # Untouched elements keep their value, pepper becomes 0.0, salt becomes 1.0.
    noisy = image * (1.0 - salt - pepper) + salt
    noisy_z = to_zscore(noisy)
    return noisy_z, 1  # noise_type = 1

def add_burst_noise(image, size_range=(0.2, 0.4), intensity_range=(0.7, 1.0)):
    """Add Gaussian noise inside one random rectangular patch of the image.

    The image is z-scored first. The patch height and width are the image
    height and width multiplied by one factor drawn from ``size_range``; the
    noise standard deviation is drawn from ``intensity_range``. Elements
    outside the patch are left unchanged.

    Returns:
        ``(noisy_z, 2)`` where ``2`` is the id used for burst noise.
    """
    image_z = to_zscore(image)
    dims = tf.shape(image_z)
    h, w, cch = dims[0], dims[1], dims[2]

    size_factor = tf.random.uniform([], size_range[0], size_range[1])
    intensity = tf.random.uniform([], intensity_range[0], intensity_range[1])

    # Patch extent in pixels, then its top-left corner.
    bh = tf.cast(tf.cast(h, tf.float32) * size_factor, tf.int32)
    bw = tf.cast(tf.cast(w, tf.float32) * size_factor, tf.int32)
    sy = tf.random.uniform([], maxval=tf.maximum(1, h - bh), dtype=tf.int32)
    sx = tf.random.uniform([], maxval=tf.maximum(1, w - bw), dtype=tf.int32)

    # Zero-padding that places a (bh, bw) patch at (sy, sx) in the full frame.
    paddings = [[sy, h - sy - bh], [sx, w - sx - bw], [0, 0]]
    burst = tf.pad(tf.random.normal([bh, bw, cch], stddev=intensity), paddings)
    region = tf.pad(tf.ones([bh, bw, cch]), paddings)

    outside = image_z * (1.0 - region)
    inside = (image_z + burst) * region
    return outside + inside, 2  # noise_type = 2

def generate_noise_specific_sample(clean_image, label, noise_type):
    """Build one training sample corrupted with the requested noise type.

    The clean image is augmented (flip/rotation, brightness/contrast) and
    clipped to [0, 1]. The z-scored augmented image is the restoration target;
    the noisy input is produced from the same augmented image.

    Args:
        clean_image: A single image (no batch axis).
        label: Class label, passed through unchanged.
        noise_type: 0 selects Gaussian noise, 1 salt & pepper; any other value
            selects burst noise.

    Returns:
        ``((noisy_z, noise_type), (clean_z, label))``.
    """
    # 1) Augmentation. The Keras layers are called on a batch of one, so add a
    #    leading axis and remove it again afterwards.
    batched = clean_image[tf.newaxis]
    augmented = tf.squeeze(data_augmentation_pipeline(batched, training=True), axis=0)
    clean_aug = tf.clip_by_value(augment_brightness_contrast(augmented), 0.0, 1.0)

    # 2) Restoration target.
    clean_z = to_zscore(clean_aug)

    # 3) Noisy input.
    if noise_type == 0:
        noisy_z, _ = add_gaussian_noise(clean_aug)
    elif noise_type == 1:
        noisy_z, _ = add_salt_pepper_noise(clean_aug)
    else:  # noise_type == 2
        noisy_z, _ = add_burst_noise(clean_aug)

    return (noisy_z, noise_type), (clean_z, label)

# Progress marker: all noise helper functions of this cell are now defined.
print("노이즈 함수 정의 완료")


In [None]:
# Part 4: build the per-noise-type datasets (50k images * 3)
print("--- Part 4: Creating Noise-Specific Datasets (50k * 3) ---")

# One dataset of 50k images per noise type.
BATCH_SIZE = 64
noise_types = ['gaussian', 'salt_pepper', 'burst']
datasets = {}

for i, noise_type in enumerate(noise_types):
    print(f"{noise_type} 노이즈 데이터셋 생성 중...")

    # Shuffle the training images, keep 50k of them, then augment and corrupt
    # each one with noise type id `i`.
    dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    dataset = dataset.shuffle(50000).take(50000)
    dataset = dataset.map(
        lambda img, lbl: generate_noise_specific_sample(img, lbl, i),
        num_parallel_calls=tf.data.AUTOTUNE
    )
    dataset = dataset.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

    datasets[noise_type] = dataset
    # Ask tf.data for the batch count instead of len(list(dataset)), which
    # would run the whole augmentation/noise pipeline and hold every batch in
    # memory just to count them.
    print(f"{noise_type} 데이터셋 크기: {dataset.cardinality().numpy()} 배치")

# Test dataset, built the same way but always with Gaussian noise (id 0).
# NOTE(review): this applies the same random augmentation and fresh random
# noise on every pass, so validation numbers are not computed on a fixed
# set - confirm that is intended.
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_dataset = test_dataset.map(
    lambda img, lbl: generate_noise_specific_sample(img, lbl, 0),
    num_parallel_calls=tf.data.AUTOTUNE
).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

print(f"테스트 데이터셋 크기: {test_dataset.cardinality().numpy()} 배치")
print("데이터셋 생성 완료!")


In [None]:
# Part 5: model training and evaluation helper functions
print("--- Part 5: Model Training Functions ---")

def train_model(model, train_datasets, test_dataset, model_name, epochs=20, save_path=None):
    """Fit ``model`` on each training dataset in turn and collect the histories.

    The same model instance is fitted once per entry of ``train_datasets``, in
    dictionary order, so each phase continues from the weights left by the
    previous one. The same callback instances are reused for every phase.

    Args:
        model: Object exposing Keras-style ``fit`` and ``save``.
        train_datasets: Mapping from noise-type name to training dataset.
        test_dataset: Dataset passed to ``fit`` as ``validation_data``.
        model_name: Used in log output and in the checkpoint filename
            ``best_<model_name lower-cased>.keras``.
        epochs: Maximum number of epochs per phase.
        save_path: Optional path prefix; after each phase the model is saved
            to ``<save_path>_<noise_type>.keras``.

    Returns:
        Dict mapping each noise-type name to the object returned by ``fit``.
    """
    print(f"\n=== {model_name} 모델 학습 시작 ===")

    # Stop a phase early when val_loss stalls, and keep the best model on disk.
    callbacks = [
        EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True),
        ModelCheckpoint(
            filepath=f'best_{model_name.lower()}.keras',
            monitor='val_loss',
            save_best_only=True,
            save_weights_only=False
        )
    ]

    # Create the target directory up front so a missing folder is not first
    # discovered by model.save() after a full training phase has finished.
    if save_path:
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

    histories = {}

    for noise_type, dataset in train_datasets.items():
        print(f"\n{noise_type} 노이즈로 학습 중...")

        history = model.fit(
            dataset,
            validation_data=test_dataset,
            epochs=epochs,
            callbacks=callbacks,
            verbose=1
        )

        histories[noise_type] = history

        # Snapshot the model as it stands after this phase.
        if save_path:
            model.save(f"{save_path}_{noise_type}.keras")

    return histories

def evaluate_model(model, test_dataset, model_name):
    """Run ``model.predict`` on ``test_dataset`` and print the output shapes.

    Handles the three result layouts ``predict`` can produce: a dict of named
    outputs, a list/tuple of outputs, or a single array. Previously a
    list/tuple result raised AttributeError (no ``.shape``) and a dict without
    the two expected keys raised KeyError.

    Args:
        model: Object exposing ``predict(dataset, verbose=...)``.
        test_dataset: Passed to ``predict`` unchanged.
        model_name: Used in the log output only.

    Returns:
        Whatever ``model.predict`` returned, unmodified.
    """
    print(f"\n=== {model_name} 모델 평가 ===")

    predictions = model.predict(test_dataset, verbose=1)

    if isinstance(predictions, dict):
        # Keep the original captions for the two known heads; fall back to the
        # raw key for any other output name.
        captions = {
            'restoration_output': '복원 출력 형태',
            'classification_output': '분류 출력 형태',
        }
        for key, value in predictions.items():
            print(f"{captions.get(key, key)}: {value.shape}")
    elif isinstance(predictions, (list, tuple)):
        for index, value in enumerate(predictions):
            print(f"출력 {index} 형태: {value.shape}")
    else:
        print(f"출력 형태: {predictions.shape}")

    return predictions

# Progress marker: the training/evaluation helpers of this cell are defined.
print("학습 함수 정의 완료")


In [None]:
# Part 6: train the multi-task (MTL) models
print("--- Part 6: Training MTL Models ---")

# 6-1. Multitask U-Net
print("\n1. Multitask U-Net 학습")
mtl_unet = create_multitask_unet(input_shape=(32, 32, 3), num_classes=num_classes)
# The image datasets yield integer class ids (not one-hot vectors), so the
# classification head needs the sparse variant of categorical cross-entropy.
# The classification loss is down-weighted to 0.1 relative to restoration.
# NOTE(review): the datasets yield inputs as (noisy_z, noise_type) - confirm
# that the model returned by create_multitask_unet accepts that structure.
mtl_unet.compile(
    optimizer='adam',
    loss={'restoration_output': 'mse', 'classification_output': 'sparse_categorical_crossentropy'},
    loss_weights={'restoration_output': 1.0, 'classification_output': 0.1},
    metrics={'restoration_output': ['mae'], 'classification_output': ['accuracy']}
)

mtl_unet_histories = train_model(
    mtl_unet, datasets, test_dataset,
    "MTL_UNet", epochs=20, save_path="models/mtl_unet"
)

# 6-2. BAM MTL model
print("\n2. BAM MTL 모델 학습")
# The BAM models take flattened inputs, so the image batches are reshaped first.
def prepare_bam_data(dataset):
    """Convert a batched image dataset into the layout used by the BAM MTL model.

    Each element ``((noisy_z, noise_type), (clean_z, label))`` becomes
    ``((noisy_flat, noise_type), (clean_flat, label_1h))``: both image tensors
    are flattened to ``[batch, -1]`` and the labels are one-hot encoded with
    ``num_classes`` entries.

    Args:
        dataset: A batched ``tf.data.Dataset`` with the element structure above.

    Returns:
        The mapped dataset.
    """
    # Dataset.map unpacks a tuple element into positional arguments, so the
    # mapping function takes the input pair and the target pair separately.
    def transform_batch(inputs, targets):
        noisy_z, noise_type = inputs
        clean_z, label = targets
        # Flatten every image to a vector, keeping the batch axis.
        noisy_flat = tf.reshape(noisy_z, [tf.shape(noisy_z)[0], -1])
        clean_flat = tf.reshape(clean_z, [tf.shape(clean_z)[0], -1])
        # One-hot encode with a graph op; the NumPy-based to_categorical
        # cannot run on the symbolic tensors seen inside Dataset.map.
        label_1h = tf.one_hot(tf.cast(label, tf.int32), num_classes)
        return (noisy_flat, noise_type), (clean_flat, label_1h)

    return dataset.map(transform_batch)

# Build the BAM-formatted variant of every training dataset and of the test set.
bam_datasets = {
    name: prepare_bam_data(train_ds) for name, train_ds in datasets.items()
}
bam_test_dataset = prepare_bam_data(test_dataset)

# NOTE(review): no compile() call is made here - presumably create_bam_model
# returns an already-compiled model; confirm.
bam_mtl = create_bam_model(input_dim=3072, latent_dim=128, num_classes=num_classes)
bam_mtl_histories = train_model(
    model=bam_mtl,
    train_datasets=bam_datasets,
    test_dataset=bam_test_dataset,
    model_name="BAM_MTL",
    epochs=20,
    save_path="models/bam_mtl",
)

print("MTL 모델들 학습 완료!")


In [None]:
# Part 7: train the single-task models
print("--- Part 7: Training Single-Task Models ---")

# 7-1. BAM restoration model
print("\n1. BAM 복원 모델 학습")
bam_restoration = create_bam_restoration_model(input_dim=3072, latent_dim=128)

# Restoration-only dataset layout: flattened noisy input, flattened clean target.
def prepare_restoration_data(dataset):
    """Reduce a batched image dataset to ``(noisy_flat, clean_flat)`` pairs.

    The noise-type id and the class label of each element are dropped; both
    image tensors are flattened to ``[batch, -1]``.
    """
    # Dataset.map unpacks a tuple element into positional arguments, so the
    # mapping function takes the input pair and the target pair separately.
    def transform_batch(inputs, targets):
        noisy_z, _noise_type = inputs
        clean_z, _label = targets
        noisy_flat = tf.reshape(noisy_z, [tf.shape(noisy_z)[0], -1])
        clean_flat = tf.reshape(clean_z, [tf.shape(clean_z)[0], -1])
        return noisy_flat, clean_flat
    return dataset.map(transform_batch)

# Restoration-formatted variant of every training dataset and of the test set.
restoration_datasets = {
    name: prepare_restoration_data(train_ds) for name, train_ds in datasets.items()
}
restoration_test_dataset = prepare_restoration_data(test_dataset)

# NOTE(review): no compile() call is made here - presumably the factory
# returns an already-compiled model; confirm.
bam_restoration_histories = train_model(
    model=bam_restoration,
    train_datasets=restoration_datasets,
    test_dataset=restoration_test_dataset,
    model_name="BAM_Restoration",
    epochs=20,
    save_path="models/bam_restoration",
)

# 7-2. BAM classification model
print("\n2. BAM 분류 모델 학습")
bam_classification = create_bam_classification_model(
    input_dim=3072, latent_dim=128, num_classes=num_classes
)

# Classification-only dataset layout: flattened noisy input, one-hot label.
def prepare_classification_data(dataset):
    """Reduce a batched image dataset to ``(noisy_flat, label_1h)`` pairs.

    The noise-type id and the clean image of each element are dropped; the
    noisy image is flattened to ``[batch, -1]`` and the label is one-hot
    encoded with ``num_classes`` entries.
    """
    # Dataset.map unpacks a tuple element into positional arguments, so the
    # mapping function takes the input pair and the target pair separately.
    def transform_batch(inputs, targets):
        noisy_z, _noise_type = inputs
        _clean_z, label = targets
        noisy_flat = tf.reshape(noisy_z, [tf.shape(noisy_z)[0], -1])
        # One-hot encode with a graph op; the NumPy-based to_categorical
        # cannot run on the symbolic tensors seen inside Dataset.map.
        label_1h = tf.one_hot(tf.cast(label, tf.int32), num_classes)
        return noisy_flat, label_1h
    return dataset.map(transform_batch)

# Classification-formatted variant of every training dataset and of the test set.
classification_datasets = {
    name: prepare_classification_data(train_ds) for name, train_ds in datasets.items()
}
classification_test_dataset = prepare_classification_data(test_dataset)

# NOTE(review): no compile() call is made here - presumably the factory
# returns an already-compiled model; confirm.
bam_classification_histories = train_model(
    model=bam_classification,
    train_datasets=classification_datasets,
    test_dataset=classification_test_dataset,
    model_name="BAM_Classification",
    epochs=20,
    save_path="models/bam_classification",
)

print("단일 태스크 모델들 학습 완료!")


In [None]:
# Part 8: train the pre-existing baseline models
print("--- Part 8: Training Existing Models ---")

# 8-1. CAE models
print("\n1. CAE 모델들 학습")
cae_multitask = build_cae_multitask(input_shape=(32, 32, 3), num_classes=num_classes)
# The image datasets yield integer class ids (not one-hot vectors), so the
# classification head needs the sparse variant of categorical cross-entropy.
# NOTE(review): the datasets yield inputs as (noisy_z, noise_type) - confirm
# that the model returned by build_cae_multitask accepts that structure.
cae_multitask.compile(
    optimizer='adam',
    loss={'restoration_output': 'mse', 'classification_output': 'sparse_categorical_crossentropy'},
    loss_weights={'restoration_output': 1.0, 'classification_output': 0.1},
    metrics={'restoration_output': ['mae'], 'classification_output': ['accuracy']}
)

cae_multitask_histories = train_model(
    cae_multitask, datasets, test_dataset,
    "CAE_Multitask", epochs=20, save_path="models/cae_multitask"
)

# 8-2. DnCNN models
print("\n2. DnCNN 모델들 학습")
dncnn_multitask = build_dncnn_multitask(input_shape=(32, 32, 3), num_classes=num_classes)
# The image datasets yield integer class ids (not one-hot vectors), so the
# classification head needs the sparse variant of categorical cross-entropy.
# NOTE(review): the datasets yield inputs as (noisy_z, noise_type) - confirm
# that the model returned by build_dncnn_multitask accepts that structure.
dncnn_multitask.compile(
    optimizer='adam',
    loss={'restoration_output': 'mse', 'classification_output': 'sparse_categorical_crossentropy'},
    loss_weights={'restoration_output': 1.0, 'classification_output': 0.1},
    metrics={'restoration_output': ['mae'], 'classification_output': ['accuracy']}
)

dncnn_multitask_histories = train_model(
    dncnn_multitask, datasets, test_dataset,
    "DnCNN_Multitask", epochs=20, save_path="models/dncnn_multitask"
)

# 8-3. U-Net models
print("\n3. U-Net 모델들 학습")
unet_multitask = build_unet_multitask(input_shape=(32, 32, 3), num_classes=num_classes)
# The image datasets yield integer class ids (not one-hot vectors), so the
# classification head needs the sparse variant of categorical cross-entropy.
# NOTE(review): the datasets yield inputs as (noisy_z, noise_type) - confirm
# that the model returned by build_unet_multitask accepts that structure.
unet_multitask.compile(
    optimizer='adam',
    loss={'restoration_output': 'mse', 'classification_output': 'sparse_categorical_crossentropy'},
    loss_weights={'restoration_output': 1.0, 'classification_output': 0.1},
    metrics={'restoration_output': ['mae'], 'classification_output': ['accuracy']}
)

unet_multitask_histories = train_model(
    unet_multitask, datasets, test_dataset,
    "UNet_Multitask", epochs=20, save_path="models/unet_multitask"
)

print("기존 모델들 학습 완료!")


In [None]:
# Part 9: run every trained model on the test dataset in its own input format
print("--- Part 9: Model Evaluation and Visualization ---")

# Every trained model, keyed by the name used to pick its test dataset below.
models_to_evaluate = {
    'MTL_UNet': mtl_unet,
    'BAM_MTL': bam_mtl,
    'BAM_Restoration': bam_restoration,
    'BAM_Classification': bam_classification,
    'CAE_Multitask': cae_multitask,
    'DnCNN_Multitask': dncnn_multitask,
    'UNet_Multitask': unet_multitask
}

evaluation_results = {}

for model_name, model in models_to_evaluate.items():
    print(f"\n{model_name} 평가 중...")

    # Pick the test dataset whose layout matches what the model was trained on.
    if 'BAM' not in model_name:
        # Image-based multi-task models.
        eval_dataset = test_dataset
    elif 'MTL' in model_name:
        # BAM multi-task model: flattened inputs, both targets.
        eval_dataset = bam_test_dataset
    elif 'Restoration' in model_name:
        # BAM single-task restoration model.
        eval_dataset = restoration_test_dataset
    else:
        # BAM single-task classification model.
        eval_dataset = classification_test_dataset

    predictions = model.predict(eval_dataset, verbose=1)

    evaluation_results[model_name] = predictions
    print(f"{model_name} 평가 완료")

print("\n모든 모델 평가 완료!")


In [None]:
# Part 10: visualize the training results
print("--- Part 10: Training Results Visualization ---")

# Training-history plots
def plot_training_history(histories, model_name, noise_types):
    """Plot loss, accuracy and MAE curves for each noise-type training phase.

    One 2x2 figure is created per call; three panels are used (loss, accuracy,
    MAE) and every noise type contributes a train and a validation curve to
    each panel for which the history has data.

    For multi-output models Keras prefixes metric names with the output name
    (for example ``classification_output_accuracy``), so each metric is looked
    up by exact name first and by ``_<metric>`` suffix otherwise. Checking
    only the exact key left the accuracy and MAE panels empty for those models.

    Args:
        histories: Mapping from noise-type name to an object with a
            ``history`` dict (as returned by ``fit``).
        model_name: Used in the figure title.
        noise_types: Noise-type names to plot, in order; names missing from
            ``histories`` are skipped.
    """
    def resolve_key(logs, metric):
        # Exact match wins (e.g. the total 'loss'); otherwise take the first
        # training-side key that ends with '_<metric>'.
        if metric in logs:
            return metric
        suffix = f'_{metric}'
        for key in logs:
            if not key.startswith('val_') and key.endswith(suffix):
                return key
        return None

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle(f'{model_name} Training History', fontsize=16)

    panels = (
        (axes[0, 0], 'loss', 'Loss'),
        (axes[0, 1], 'accuracy', 'Accuracy'),
        (axes[1, 0], 'mae', 'MAE'),
    )

    for noise_type in noise_types:
        history = histories.get(noise_type)
        if history is None:
            continue
        logs = history.history

        for ax, metric, title in panels:
            train_key = resolve_key(logs, metric)
            if train_key is None:
                continue
            ax.plot(logs[train_key], label=f'{noise_type} train')
            # Validation values are logged under the same name with 'val_'.
            val_key = f'val_{train_key}'
            if val_key in logs:
                ax.plot(logs[val_key], label=f'{noise_type} val')
            ax.set_title(title)
            ax.legend()

    plt.tight_layout()
    plt.show()

# Histories of every trained model, keyed by the name shown in the plot title.
all_histories = dict(
    MTL_UNet=mtl_unet_histories,
    BAM_MTL=bam_mtl_histories,
    BAM_Restoration=bam_restoration_histories,
    BAM_Classification=bam_classification_histories,
    CAE_Multitask=cae_multitask_histories,
    DnCNN_Multitask=dncnn_multitask_histories,
    UNet_Multitask=unet_multitask_histories,
)

for model_name, histories in all_histories.items():
    # Nothing to draw for a model without recorded histories.
    if not histories:
        continue
    plot_training_history(histories, model_name, noise_types)

print("학습 결과 시각화 완료!")


In [None]:
# Part 11: final summary
print("--- Part 11: Final Results Summary ---")

print("\n=== 학습 완료된 모델들 ===")
print("1. MTL 모델들:")
print("   - MTL U-Net")
print("   - BAM MTL")
print("   - CAE Multitask")
print("   - DnCNN Multitask")
print("   - U-Net Multitask")

print("\n2. 단일 태스크 모델들:")
print("   - BAM Restoration")
print("   - BAM Classification")

print("\n3. 데이터셋 정보:")
print(f"   - 노이즈 타입: {noise_types}")
print(f"   - 각 타입별 데이터: 50,000장")
print(f"   - 총 학습 데이터: 150,000장")
print(f"   - 테스트 데이터: 10,000장")

print("\n4. 저장된 모델 파일들:")
import glob
# Best checkpoints are written to the working directory (best_*.keras), while
# the per-noise-type models are saved under models/ by the training calls in
# this file. List both locations so the final models are not left out.
model_files = glob.glob("*.keras") + glob.glob("models/*.keras")
for file in sorted(model_files):
    print(f"   - {file}")

print("\n=== 모든 모델 학습 완료! ===")
