# 전체 모델 학습 - CIFAR-10 Denoising & Classification

## 모델 라인업
### Sequential Models (2-stage)
1. **Sequential BAM**: Dense, BAM 양방향 연상
2. **Sequential CAE**: Conv, No skip connection
3. **Sequential U-Net**: Conv, Skip connection

### MTL Models (1-stage)
4. **MTL BAM**: Dense, BAM 양방향 연상
5. **MTL CAE**: Conv, No skip connection
6. **MTL U-Net**: Conv, Skip connection

## 학습 설정
- **데이터**: 150,000장 (3 noise types × 5 SNR levels)
- **Epochs**: 200
- **Batch size**: 128
- **Validation split**: 20% (120K train / 30K val)
- **Early stopping**: patience=30
- **LR schedule**: Exponential decay (initial → 10% at epoch 200)

## 0. 환경 설정 및 임포트

In [None]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import pandas as pd
import os
import gc
import pickle
from datetime import datetime
import json

print(f"TensorFlow version: {tf.__version__}")

# Query the physical GPUs once and reuse the result for reporting and setup.
gpus = tf.config.list_physical_devices('GPU')
print(f"GPU available: {gpus}")

# GPU memory setup (RTX 3070 Ti, 8 GB card)
if not gpus:
    print("⚠ No GPU found, using CPU")
else:
    try:
        # Allocate GPU memory on demand instead of reserving it all up front.
        for device in gpus:
            tf.config.experimental.set_memory_growth(device, True)
        # Cap the first GPU with a logical device of 8192 MB (memory_limit is
        # in MB). On an 8 GB card this is roughly the whole device, so it acts
        # only as an upper bound.
        tf.config.set_logical_device_configuration(
            gpus[0],
            [tf.config.LogicalDeviceConfiguration(memory_limit=8192)],
        )
        print("✓ GPU memory growth enabled and limited to 8GB")
    except RuntimeError as err:
        # TensorFlow raises RuntimeError when devices are already initialized,
        # e.g. when this cell is re-run after the GPU has been used.
        print(f"GPU configuration error: {err}")

## 1. 데이터 로드

In [None]:
print("Loading augmented data...")

# Each split provides a noisy input array, its labels, the clean reference
# images, and a CSV of per-sample noise metadata (presumably noise type / SNR,
# per the notebook header — confirm against the data-generation step).
x_train_augmented, y_train_augmented, x_train_clean = (
    np.load(f'data/{stem}.npy')
    for stem in ('x_train_augmented', 'y_train_augmented', 'x_train_clean')
)
train_noise_info = pd.read_csv('data/train_noise_info.csv')

x_test_augmented, y_test_augmented, x_test_clean = (
    np.load(f'data/{stem}.npy')
    for stem in ('x_test_augmented', 'y_test_augmented', 'x_test_clean')
)
test_noise_info = pd.read_csv('data/test_noise_info.csv')

print(f"✓ Train shape: {x_train_augmented.shape}")
print(f"✓ Test shape: {x_test_augmented.shape}")
print(f"✓ Train clean reference shape: {x_train_clean.shape}")
print(f"✓ Test clean reference shape: {x_test_clean.shape}")

## 2. BAM 모델용 데이터 준비 (Flattened)

In [None]:
# BAM models are Dense networks, so every image is fed in as a flat vector.
print("Preparing flattened data for BAM models...")


def _flatten_clean_targets(clean, noisy):
    """
    Repeat the clean reference images to match the noisy set, then flatten.

    The repeat factor is derived from the two array lengths rather than
    hard-coded. A mismatch is therefore reported here instead of surfacing
    later as a shape error (or as silently misaligned targets) during training.

    Args:
        clean: Clean reference images, one entry per original sample.
        noisy: Augmented (noisy) images; its length must be a whole multiple
            of len(clean).

    Returns:
        np.ndarray with len(noisy) rows, each a flattened clean image.

    Raises:
        ValueError: If `clean` is empty or len(noisy) is not a multiple of it.
    """
    if len(clean) == 0 or len(noisy) % len(clean) != 0:
        raise ValueError(
            f"Cannot align {len(clean)} clean images with {len(noisy)} noisy images"
        )
    copies = len(noisy) // len(clean)
    # np.repeat yields [c0, c0, c0, c1, c1, c1, ...], i.e. it assumes the noisy
    # copies of one image are stored next to each other.
    # NOTE(review): confirm against *_noise_info.csv. If the augmented arrays
    # are grouped by noise type instead, np.tile along axis 0 is required.
    return np.repeat(clean, copies, axis=0).reshape(len(noisy), -1)


# Train data flatten
x_train_flat = x_train_augmented.reshape(x_train_augmented.shape[0], -1)
x_train_clean_flat = _flatten_clean_targets(x_train_clean, x_train_augmented)

# Test data flatten
x_test_flat = x_test_augmented.reshape(x_test_augmented.shape[0], -1)
x_test_clean_flat = _flatten_clean_targets(x_test_clean, x_test_augmented)

print(f"✓ Train flat shape: {x_train_flat.shape}")
print(f"✓ Train clean flat shape: {x_train_clean_flat.shape}")
print(f"✓ Test flat shape: {x_test_flat.shape}")
print(f"✓ Test clean flat shape: {x_test_clean_flat.shape}")

## 3. 학습 설정 및 콜백 정의

In [None]:
# Training hyperparameters
EPOCHS = 200
BATCH_SIZE = 128
# NOTE: Keras `validation_split` holds out the LAST fraction of the arrays,
# selected before any shuffling. NOTE(review): confirm the augmented data is
# not stored sorted by noise type / SNR, or validation will cover only part.
VALIDATION_SPLIT = 0.2
INITIAL_LR = 1e-3
FINAL_LR = INITIAL_LR * 0.1  # 10% of the initial rate

# Per-epoch exponential decay factor. Keras hands the LR scheduler a 0-based
# epoch index, so the last epoch is EPOCHS - 1. Using that as the exponent
# makes INITIAL_LR * DECAY_RATE ** (EPOCHS - 1) == FINAL_LR on the final
# epoch (dividing by EPOCHS left it about 1% above FINAL_LR).
DECAY_RATE = (FINAL_LR / INITIAL_LR) ** (1 / max(EPOCHS - 1, 1))

print("Training configuration:")
print(f"  Epochs: {EPOCHS}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Validation split: {VALIDATION_SPLIT} ({int(len(x_train_augmented)*VALIDATION_SPLIT):,} samples)")
print(f"  Initial LR: {INITIAL_LR:g}")
print(f"  Final LR: {FINAL_LR:g}")
print(f"  Decay rate: {DECAY_RATE:.6f}")

# Output directories for checkpoints, histories, evaluation results and logs
os.makedirs('weights', exist_ok=True)
os.makedirs('history', exist_ok=True)
os.makedirs('results', exist_ok=True)
os.makedirs('logs', exist_ok=True)

def get_callbacks(model_name, monitor='val_loss', patience=30):
    """
    Build the standard list of training callbacks for one model.

    Args:
        model_name: Name used in every output path
            (weights/<name>_best.keras, logs/<name>_training.csv, logs/<name>).
        monitor: Metric watched by both early stopping and checkpointing.
        patience: Early stopping patience, in epochs.

    Returns:
        list: [EarlyStopping, ModelCheckpoint, LearningRateScheduler,
        CSVLogger, TensorBoard], in that order.
    """
    def exponential_decay(epoch, lr):
        # Absolute schedule: the incoming `lr` is ignored and the rate is
        # recomputed from the module-level INITIAL_LR and DECAY_RATE.
        return INITIAL_LR * (DECAY_RATE ** epoch)

    return [
        keras.callbacks.EarlyStopping(
            monitor=monitor,
            patience=patience,
            restore_best_weights=True,
            verbose=1,
        ),
        keras.callbacks.ModelCheckpoint(
            filepath=f'weights/{model_name}_best.keras',
            monitor=monitor,
            save_best_only=True,
            verbose=1,
        ),
        keras.callbacks.LearningRateScheduler(exponential_decay, verbose=0),
        keras.callbacks.CSVLogger(f'logs/{model_name}_training.csv', append=False),
        keras.callbacks.TensorBoard(
            log_dir=f'logs/{model_name}',
            histogram_freq=0,
            write_graph=True,
        ),
    ]

# Summary of the callbacks that get_callbacks() attaches to each model.
print("\n✓ Callbacks configured")
print(
    "  - Early stopping (patience=30)",
    "  - Model checkpoint",
    "  - LR scheduler (exponential decay)",
    "  - CSV logger",
    "  - TensorBoard",
    sep="\n",
)

## 4. 메모리 정리 함수

In [None]:
def clear_memory():
    """
    Reset Keras/TensorFlow state between model trainings.

    Clears the global Keras session and runs Python garbage collection so
    models from a previous run can be freed. TensorFlow's GPU allocator keeps
    memory it has already reserved for this process. That memory is reused by
    later models but not handed back to the driver, so external tools
    (e.g. nvidia-smi) will not show usage dropping.
    """
    print("\n" + "="*60)
    print("Clearing memory...")
    print("="*60)

    # Reset global Keras state (graphs, cached models, layer-name counters).
    # `keras` here is `tensorflow.keras`, i.e. the same module as `tf.keras`,
    # so a single clear_session() call is sufficient.
    keras.backend.clear_session()

    # Collect Python objects that became unreachable after the session reset.
    gc.collect()

    if tf.config.list_physical_devices('GPU'):
        print("✓ Keras session cleared (TensorFlow keeps reserved GPU memory for reuse)")

    print("✓ Memory cleared\n")

def save_history(history, model_name):
    """
    Pickle a model's per-epoch training metrics.

    The file is written to history/<model_name>_history.pkl.

    Args:
        history: Object returned by `Model.fit` (its `.history` dict is saved),
            or an already-extracted dict of metric lists.
        model_name: Name used to build the output file name.
    """
    # Accept both a Keras History object and a plain metrics dict.
    metrics = getattr(history, 'history', history)
    # Create the directory here as well, so this helper does not depend on the
    # configuration cell having been run first.
    os.makedirs('history', exist_ok=True)
    path = f'history/{model_name}_history.pkl'
    with open(path, 'wb') as f:
        pickle.dump(metrics, f)
    print(f"✓ History saved: {path}")

def _to_json_compatible(obj):
    """
    `json` fallback that converts NumPy values to native Python types.

    NumPy scalars (e.g. np.float32 / np.int64 metric values) become Python
    numbers and arrays become lists. Anything else raises TypeError, as
    json's default behaviour does.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def save_results(results, model_name):
    """
    Write evaluation results to results/<model_name>_results.json.

    Args:
        results: JSON-compatible mapping; NumPy scalars and arrays are allowed.
        model_name: Name used to build the output file name.

    Raises:
        TypeError: If `results` contains a value that cannot be serialized.
    """
    # Serialize first so a bad value never leaves a truncated JSON file behind.
    text = json.dumps(results, indent=2, default=_to_json_compatible)
    os.makedirs('results', exist_ok=True)
    path = f'results/{model_name}_results.json'
    with open(path, 'w', encoding='utf-8') as f:
        f.write(text)
    print(f"✓ Results saved: {path}")

print("✓ Utility functions defined")