# Part 3: MTL 모델 학습

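이 노트북은 Part 2에서 이어집니다. 아래 셀들은 이전 파트에서 정의한 데이터 배열(`x_train_augmented`, `x_train_flat` 등), 모델 클래스(`MTLBAM`, `CAEMTL`, `UNetMTL`), 헬퍼 함수(`get_callbacks`, `save_history`, `save_results`, `clear_memory`)와 하이퍼파라미터(`INITIAL_LR`, `EPOCHS`, `BATCH_SIZE`, `VALIDATION_SPLIT`)를 그대로 사용하므로, 이전 파트를 먼저 실행한 세션에서 실행해야 합니다.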

## 8. MTL 모델 학습

### 8.1 MTL BAM

In [None]:
# Train, save, and evaluate the multi-task BAM model (model 4 of 6).
# Depends on names defined in earlier parts of the notebook: MTLBAM, get_callbacks,
# save_history, save_results, clear_memory, the x_*/y_* arrays, and the
# INITIAL_LR / EPOCHS / BATCH_SIZE / VALIDATION_SPLIT constants.
print("\n" + "="*80)
print("Model 4/6: MTL BAM")
print("="*80)

# Build the model. recon_weight / cls_weight presumably weight the reconstruction
# and classification losses in the combined objective (TODO confirm in MTLBAM).
mtl_bam = MTLBAM(
    input_dim=3072,  # equals 32 * 32 * 3, matching the (32, 32, 3) images used by the CAE/U-Net cells
    latent_dim=256,
    num_classes=10,
    recon_weight=0.7,
    cls_weight=0.3,
    learning_rate=INITIAL_LR,
    recon_loss='mae'
)

print("\n[Model Architecture]")
mtl_bam.model.summary()

# Train on flattened inputs, with flattened clean images and labels as the two
# targets. Assumes x_train_flat, x_train_clean_flat and y_train_augmented are
# sample-aligned (same length, same ordering) — TODO confirm where they are built.
callbacks = get_callbacks('mtl_bam', monitor='val_loss', patience=30)
history = mtl_bam.train(
    x_train_flat,
    x_train_clean_flat,
    y_train_augmented,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    validation_split=VALIDATION_SPLIT,
    callbacks=callbacks
)

# Persist the training history and the trained weights.
save_history(history, 'mtl_bam')
mtl_bam.save_model('weights/mtl_bam.keras')
print("✓ Model saved")

# Evaluate on the test split and persist the metrics.
print("\n[Evaluation]")
results = mtl_bam.evaluate_detailed(x_test_flat, x_test_clean_flat, y_test_augmented, batch_size=BATCH_SIZE)
save_results(results, 'mtl_bam')

print("\nMTL BAM Results:")
print(f"  Reconstruction MSE: {results['reconstruction']['mse']:.6f}")
print(f"  Reconstruction PSNR: {results['reconstruction']['psnr']:.2f} dB")
print(f"  Classification Accuracy: {results['classification']['accuracy']:.4f}")

# Free memory before the next model. Deleting the wrapper alone is not enough:
# Keras callbacks, including the History object that `train` presumably returns
# from `fit`, keep a reference to the model they were attached to. Those
# references must be dropped as well, or the model stays reachable.
del mtl_bam, history, callbacks
clear_memory()

### 8.2 MTL CAE

In [None]:
# Train, save, and evaluate the multi-task convolutional autoencoder (model 5 of 6).
# Depends on names defined in earlier parts of the notebook: CAEMTL, get_callbacks,
# save_history, save_results, clear_memory, np, the x_*/y_* arrays, and the
# INITIAL_LR / EPOCHS / BATCH_SIZE / VALIDATION_SPLIT constants.
print("\n" + "="*80)
print("Model 5/6: MTL CAE")
print("="*80)

# Build the model. recon_weight / cls_weight presumably weight the restoration
# and classification losses in the combined objective (TODO confirm in CAEMTL).
mtl_cae = CAEMTL(
    input_shape=(32, 32, 3),
    num_classes=10,
    recon_weight=0.7,
    cls_weight=0.3,
    learning_rate=INITIAL_LR,
    recon_loss='mae',
    dropout_rate=0.1
)

print("\n[Model Architecture]")
mtl_cae.model.summary()

# Train with the clean images as restoration targets.
# np.repeat(..., 3, axis=0) duplicates each clean image 3 times consecutively
# (a, a, a, b, b, b, ...). This assumes x_train_augmented stores the 3 augmented
# variants of each image next to each other — TODO confirm. If the variants are
# stacked block-wise instead, np.tile would be needed to keep targets aligned.
callbacks = get_callbacks('mtl_cae', monitor='val_loss', patience=30)
history = mtl_cae.train(
    x_train_augmented,
    np.repeat(x_train_clean, 3, axis=0),
    y_train_augmented,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    validation_split=VALIDATION_SPLIT,
    callbacks=callbacks
)

# Persist the training history and the trained weights.
save_history(history, 'mtl_cae')
mtl_cae.save_model('weights/mtl_cae.keras')
print("✓ Model saved")

# Evaluate on the test split (same 3x repeat/alignment assumption as training).
print("\n[Evaluation]")
results = mtl_cae.evaluate_detailed(
    x_test_augmented,
    np.repeat(x_test_clean, 3, axis=0),
    y_test_augmented,
    batch_size=BATCH_SIZE
)
save_results(results, 'mtl_cae')

print("\nMTL CAE Results:")
print(f"  Restoration MSE: {results['restoration']['mse']:.6f}")
print(f"  Restoration PSNR: {results['restoration']['psnr']:.2f} dB")
print(f"  Classification Accuracy: {results['classification']['accuracy']:.4f}")

# Free memory before the next model. Deleting the wrapper alone is not enough:
# Keras callbacks, including the History object that `train` presumably returns
# from `fit`, keep a reference to the model they were attached to. Those
# references must be dropped as well, or the model stays reachable.
del mtl_cae, history, callbacks
clear_memory()

### 8.3 MTL U-Net

In [None]:
# Train, save, and evaluate the multi-task U-Net (model 6 of 6).
# Depends on names defined in earlier parts of the notebook: UNetMTL, get_callbacks,
# save_history, save_results, clear_memory, np, the x_*/y_* arrays, and the
# INITIAL_LR / EPOCHS / BATCH_SIZE / VALIDATION_SPLIT constants.
print("\n" + "="*80)
print("Model 6/6: MTL U-Net")
print("="*80)

# Build the model. recon_weight / cls_weight presumably weight the restoration
# and classification losses (TODO confirm in UNetMTL). Dropout is disabled here.
mtl_unet = UNetMTL(
    input_shape=(32, 32, 3),
    num_classes=10,
    recon_weight=0.6,
    cls_weight=0.4,
    learning_rate=INITIAL_LR,
    recon_loss='mae',
    dropout_rate=0.0
)

print("\n[Model Architecture]")
mtl_unet.model.summary()

# Train with the clean images as restoration targets.
# np.repeat(..., 3, axis=0) duplicates each clean image 3 times consecutively
# (a, a, a, b, b, b, ...). This assumes x_train_augmented stores the 3 augmented
# variants of each image next to each other — TODO confirm. If the variants are
# stacked block-wise instead, np.tile would be needed to keep targets aligned.
callbacks = get_callbacks('mtl_unet', monitor='val_loss', patience=30)
history = mtl_unet.train(
    x_train_augmented,
    np.repeat(x_train_clean, 3, axis=0),
    y_train_augmented,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    validation_split=VALIDATION_SPLIT,
    callbacks=callbacks
)

# Persist the training history and the trained weights.
save_history(history, 'mtl_unet')
mtl_unet.save_model('weights/mtl_unet.keras')
print("✓ Model saved")

# Evaluate on the test split (same 3x repeat/alignment assumption as training).
print("\n[Evaluation]")
results = mtl_unet.evaluate_detailed(
    x_test_augmented,
    np.repeat(x_test_clean, 3, axis=0),
    y_test_augmented,
    batch_size=BATCH_SIZE
)
save_results(results, 'mtl_unet')

print("\nMTL U-Net Results:")
print(f"  Restoration MSE: {results['restoration']['mse']:.6f}")
print(f"  Restoration PSNR: {results['restoration']['psnr']:.2f} dB")
print(f"  Classification Accuracy: {results['classification']['accuracy']:.4f}")

# Free memory. Deleting the wrapper alone is not enough: Keras callbacks,
# including the History object that `train` presumably returns from `fit`, keep
# a reference to the model they were attached to. Those references must be
# dropped as well, or the model stays reachable.
del mtl_unet, history, callbacks
clear_memory()

## 9. 전체 학습 완료

6개 모델(이전 파트에서 학습한 모델 1–3 포함)의 학습이 모두 완료되었습니다!

다음 Part 4에서 결과를 종합 분석하겠습니다.