# Part 2: 모델 임포트 및 Sequential 모델 학습

이 노트북은 Part 1에서 이어집니다.

## 5. 필수 라이브러리 임포트 및 설정

In [None]:
# 필수 라이브러리 임포트
import numpy as np
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import pandas as pd
import os
import gc
import pickle
import json
from datetime import datetime

# Training hyperparameters (also defined in Part 1; restated here so this
# part can be read and run on its own).
EPOCHS = 200
BATCH_SIZE = 128
VALIDATION_SPLIT = 0.2

# Learning-rate range: FINAL_LR is one tenth of INITIAL_LR.
# DECAY_RATE is the constant factor that, multiplied in EPOCHS times,
# takes INITIAL_LR down to FINAL_LR.
INITIAL_LR = 1e-3
FINAL_LR = INITIAL_LR * 0.1
DECAY_RATE = (FINAL_LR / INITIAL_LR) ** (1 / EPOCHS)

# Echo the configuration so it is recorded in the notebook output.
print(
    "✓ Libraries imported",
    "✓ Training configuration:",
    f"  Epochs: {EPOCHS}",
    f"  Batch size: {BATCH_SIZE}",
    f"  Initial LR: {INITIAL_LR}",
    f"  Final LR: {FINAL_LR}",
    sep="\n",
)

## 6. 유틸리티 함수 로드

In [None]:
# 유틸리티 함수 임포트
from utils import (
    get_callbacks,
    clear_memory,
    save_history,
    save_results
)

# Make sure every output directory used by this notebook exists;
# exist_ok=True makes re-running the cell harmless.
for _out_dir in ('weights', 'history', 'results', 'logs'):
    os.makedirs(_out_dir, exist_ok=True)

print("✓ Utility functions loaded", "✓ Directories created", sep="\n")

## 7. 데이터 로드

In [None]:
# Training split: augmented inputs, their labels, and the clean originals.
print("Loading augmented data...")
x_train_augmented = np.load('data/x_train_augmented.npy')
y_train_augmented = np.load('data/y_train_augmented.npy')
x_train_clean = np.load('data/x_train_clean.npy')

# Test split: the same three arrays.
x_test_augmented = np.load('data/x_test_augmented.npy')
y_test_augmented = np.load('data/y_test_augmented.npy')
x_test_clean = np.load('data/x_test_clean.npy')

print(f"✓ Train shape: {x_train_augmented.shape}")
print(f"✓ Test shape: {x_test_augmented.shape}")

# Flattened inputs for the BAM model: one row per sample, all remaining
# axes collapsed into a single feature axis.
x_train_flat = np.reshape(x_train_augmented, (len(x_train_augmented), -1))
x_test_flat = np.reshape(x_test_augmented, (len(x_test_augmented), -1))

# Flattened clean targets: every clean sample is repeated 3 times
# consecutively (a, a, a, b, b, b, ...) and reshaped to rows of 32*32*3.
# NOTE(review): this assumes the augmented arrays hold exactly 3 variants
# per clean sample, stored back-to-back in the same order — confirm against
# the augmentation code in Part 1 (a concatenated layout would need np.tile).
x_train_clean_flat = np.reshape(
    np.repeat(x_train_clean, 3, axis=0), (-1, 32 * 32 * 3)
)
x_test_clean_flat = np.reshape(
    np.repeat(x_test_clean, 3, axis=0), (-1, 32 * 32 * 3)
)

print(f"✓ Train flat shape: {x_train_flat.shape}")
print(f"✓ Test flat shape: {x_test_flat.shape}")

## 8. 모델 임포트

In [None]:
# Import the BAM model (expected to be importable from the project root).
# This import is required: a failure here should stop the notebook.
from bam_sequential import SequentialBAM

# The other models are optional. Only ImportError (which also covers
# ModuleNotFoundError) is treated as "module not available"; any other
# exception raised while importing — e.g. a SyntaxError inside the module —
# is a real bug and is allowed to propagate instead of being reported as
# "not found". The success message lives in `else` so the `try` body
# contains nothing but the import itself.
try:
    from cae_sequential import SequentialCAE
except ImportError:
    print("⚠ Sequential CAE not found (optional)")
else:
    print("✓ Sequential CAE imported")

try:
    from unet_sequential import SequentialUNet
except ImportError:
    print("⚠ Sequential U-Net not found (optional)")
else:
    print("✓ Sequential U-Net imported")

print("\n✓ BAM model imported successfully")

## 9. 학습 시작 - Sequential 모델

### 9.1 Sequential BAM

In [None]:
# Banner for the first of the three Sequential models.
print("\n" + "=" * 80, "Model 1/3: Sequential BAM", "=" * 80, sep="\n")

# Build the model; input_dim is the flattened image size (32*32*3 = 3072).
seq_bam = SequentialBAM(
    input_dim=3072, denoise_latent=256, cls_latent=128, num_classes=10
)

# Compile both sub-models with the same starting learning rate.
seq_bam.compile_models(denoise_lr=INITIAL_LR, cls_lr=INITIAL_LR)

# ---- Stage 1: denoising model --------------------------------------------
print("\n[Stage 1: Denoising BAM]")
seq_bam.denoise_model.summary()

# Callbacks for stage 1, keyed on the validation MSE.
callbacks_stage1 = get_callbacks(
    'sequential_bam_stage1',
    monitor='val_mse',
    patience=30,
    initial_lr=INITIAL_LR,
    epochs=EPOCHS,
)

# Inputs are the flattened augmented samples, targets the flattened clean ones.
history_stage1 = seq_bam.train_stage1(
    x_train_flat,
    x_train_clean_flat,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    validation_split=VALIDATION_SPLIT,
    callbacks=callbacks_stage1,
)

# Persist the stage 1 history and the trained denoising model.
save_history(history_stage1, 'sequential_bam_stage1')
seq_bam.denoise_model.save('weights/sequential_bam_denoise.keras')
print("✓ Stage 1 model saved")

# ---- Stage 2: classification model ---------------------------------------
print("\n[Stage 2: Classification BAM]")
seq_bam.cls_model.summary()

# Callbacks for stage 2, keyed on the validation accuracy.
callbacks_stage2 = get_callbacks(
    'sequential_bam_stage2',
    monitor='val_accuracy',
    patience=30,
    initial_lr=INITIAL_LR,
    epochs=EPOCHS,
)

# Same flattened inputs as stage 1, now paired with the class labels.
history_stage2 = seq_bam.train_stage2(
    x_train_flat,
    y_train_augmented,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    validation_split=VALIDATION_SPLIT,
    callbacks=callbacks_stage2,
)

# Persist the stage 2 history and the trained classification model.
save_history(history_stage2, 'sequential_bam_stage2')
seq_bam.cls_model.save('weights/sequential_bam_classification.keras')
print("✓ Stage 2 model saved")

# ---- Evaluation on the test split ----------------------------------------
print("\n[Evaluation]")
results = seq_bam.evaluate(
    x_test_flat, x_test_clean_flat, y_test_augmented, batch_size=BATCH_SIZE
)
save_results(results, 'sequential_bam')

# `results` is read as a nested mapping with 'restoration' and
# 'classification' sections.
print("\nSequential BAM Results:")
print(f"  Restoration MSE: {results['restoration']['mse']:.6f}")
print(f"  Restoration PSNR: {results['restoration']['psnr']:.2f} dB")
print(f"  Classification Accuracy: {results['classification']['accuracy']:.4f}")

# Drop the model reference and run the project's memory cleanup before the
# next model is built.
del seq_bam
clear_memory()

### 9.2 Sequential CAE

In [None]:
# Banner for the second of the three Sequential models trained in this
# notebook (numbered "/3" to match the Sequential BAM cell's "Model 1/3").
print("\n" + "="*80)
print("Model 2/3: Sequential CAE")
print("="*80)

# Build the model.
# NOTE(review): SequentialCAE comes from an optional import; if that import
# failed, this cell raises NameError.
seq_cae = SequentialCAE(
    input_shape=(32, 32, 3),
    num_classes=10
)

# Compile both sub-models with the same starting learning rate; the
# restoration stage is trained with an MAE loss.
seq_cae.compile_models(
    restore_lr=INITIAL_LR,
    cls_lr=INITIAL_LR,
    restore_loss='mae'
)

print("\n[Stage 1: Restoration CAE]")
seq_cae.restore_model.summary()

# Stage 1 training.
# initial_lr and epochs are passed explicitly, exactly as the Sequential BAM
# cell does, so the callbacks are configured from this notebook's INITIAL_LR
# and EPOCHS rather than from whatever defaults get_callbacks has.
# NOTE(review): monitor='val_mae' presumes the restoration model reports an
# 'mae' metric — confirm in compile_models.
callbacks_stage1 = get_callbacks(
    'sequential_cae_stage1',
    monitor='val_mae',
    patience=30,
    initial_lr=INITIAL_LR,
    epochs=EPOCHS
)
history_stage1 = seq_cae.train_stage1(
    x_train_augmented,
    np.repeat(x_train_clean, 3, axis=0),  # each clean sample repeated 3x as the target
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    validation_split=VALIDATION_SPLIT,
    callbacks=callbacks_stage1
)

# Persist the stage 1 history and the trained restoration model.
save_history(history_stage1, 'sequential_cae_stage1')
seq_cae.restore_model.save('weights/sequential_cae_restore.keras')
print("✓ Stage 1 model saved")

print("\n[Stage 2: Classification CAE]")
seq_cae.cls_model.summary()

# Stage 2 training, keyed on the validation accuracy (same explicit
# initial_lr / epochs as stage 1).
callbacks_stage2 = get_callbacks(
    'sequential_cae_stage2',
    monitor='val_accuracy',
    patience=30,
    initial_lr=INITIAL_LR,
    epochs=EPOCHS
)
history_stage2 = seq_cae.train_stage2(
    x_train_augmented,
    y_train_augmented,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    validation_split=VALIDATION_SPLIT,
    callbacks=callbacks_stage2
)

# Persist the stage 2 history and the trained classification model.
save_history(history_stage2, 'sequential_cae_stage2')
seq_cae.cls_model.save('weights/sequential_cae_classification.keras')
print("✓ Stage 2 model saved")

# Evaluation on the test split; clean targets are repeated 3x the same way
# as for training.
print("\n[Evaluation]")
results = seq_cae.evaluate(
    x_test_augmented,
    np.repeat(x_test_clean, 3, axis=0),
    y_test_augmented,
    batch_size=BATCH_SIZE
)
save_results(results, 'sequential_cae')

# `results` is read as a nested mapping with 'restoration' and
# 'classification' sections.
print("\nSequential CAE Results:")
print(f"  Restoration MSE: {results['restoration']['mse']:.6f}")
print(f"  Restoration PSNR: {results['restoration']['psnr']:.2f} dB")
print(f"  Classification Accuracy: {results['classification']['accuracy']:.4f}")

# Drop the model reference and run the project's memory cleanup before the
# next model is built.
del seq_cae
clear_memory()

### 9.3 Sequential U-Net

In [None]:
# Banner for the third of the three Sequential models trained in this
# notebook (numbered "/3" to match the Sequential BAM cell's "Model 1/3").
print("\n" + "="*80)
print("Model 3/3: Sequential U-Net")
print("="*80)

# Build the model; dropout is disabled for restoration and 0.1 for
# classification.
# NOTE(review): SequentialUNet comes from an optional import; if that import
# failed, this cell raises NameError.
seq_unet = SequentialUNet(
    input_shape=(32, 32, 3),
    num_classes=10,
    restore_dropout=0.0,
    cls_dropout=0.1
)

# Compile both sub-models with the same starting learning rate; the
# restoration stage is trained with an MAE loss.
seq_unet.compile_models(
    restore_lr=INITIAL_LR,
    cls_lr=INITIAL_LR,
    restore_loss='mae'
)

print("\n[Stage 1: Restoration U-Net]")
seq_unet.restore_model.summary()

# Stage 1 training.
# initial_lr and epochs are passed explicitly, exactly as the Sequential BAM
# cell does, so the callbacks are configured from this notebook's INITIAL_LR
# and EPOCHS rather than from whatever defaults get_callbacks has.
# NOTE(review): monitor='val_mae' presumes the restoration model reports an
# 'mae' metric — confirm in compile_models.
callbacks_stage1 = get_callbacks(
    'sequential_unet_stage1',
    monitor='val_mae',
    patience=30,
    initial_lr=INITIAL_LR,
    epochs=EPOCHS
)
history_stage1 = seq_unet.train_stage1(
    x_train_augmented,
    np.repeat(x_train_clean, 3, axis=0),  # each clean sample repeated 3x as the target
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    validation_split=VALIDATION_SPLIT,
    callbacks=callbacks_stage1
)

# Persist the stage 1 history and the trained restoration model.
save_history(history_stage1, 'sequential_unet_stage1')
seq_unet.restore_model.save('weights/sequential_unet_restore.keras')
print("✓ Stage 1 model saved")

print("\n[Stage 2: Classification U-Net]")
seq_unet.cls_model.summary()

# Stage 2 training, keyed on the validation accuracy (same explicit
# initial_lr / epochs as stage 1).
callbacks_stage2 = get_callbacks(
    'sequential_unet_stage2',
    monitor='val_accuracy',
    patience=30,
    initial_lr=INITIAL_LR,
    epochs=EPOCHS
)
history_stage2 = seq_unet.train_stage2(
    x_train_augmented,
    y_train_augmented,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    validation_split=VALIDATION_SPLIT,
    callbacks=callbacks_stage2
)

# Persist the stage 2 history and the trained classification model.
save_history(history_stage2, 'sequential_unet_stage2')
seq_unet.cls_model.save('weights/sequential_unet_classification.keras')
print("✓ Stage 2 model saved")

# Evaluation on the test split; clean targets are repeated 3x the same way
# as for training.
print("\n[Evaluation]")
results = seq_unet.evaluate(
    x_test_augmented,
    np.repeat(x_test_clean, 3, axis=0),
    y_test_augmented,
    batch_size=BATCH_SIZE
)
save_results(results, 'sequential_unet')

# `results` is read as a nested mapping with 'restoration' and
# 'classification' sections.
print("\nSequential U-Net Results:")
print(f"  Restoration MSE: {results['restoration']['mse']:.6f}")
print(f"  Restoration PSNR: {results['restoration']['psnr']:.2f} dB")
print(f"  Classification Accuracy: {results['classification']['accuracy']:.4f}")

# Drop the model reference and run the project's memory cleanup.
del seq_unet
clear_memory()

## 10. Sequential 모델 학습 완료

3개의 Sequential 모델 학습이 완료되었습니다.

다음 Part 3에서 MTL 모델들을 학습하겠습니다.