In [None]:
%load_ext autoreload
%autoreload 2
from scripts.load_data import load_raw_data
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime

## 2D WGAN-SN 

In [None]:
import os
from datetime import datetime  # imported here so this cell runs standalone

import torch

# Import the data loader and model class defined in scripts/wgan_sn.py.
from scripts.wgan_sn import load_npz_patches_to_0_1, WGAN_SN_Torch

# --- Experiment configuration ------------------------------------------------
CHANNELS = 2
IMG_SIZE = 512
# NOTE(review): 10001 presumably makes the last epoch index (10000) a multiple
# of SAVE_INTERVAL, assuming train() iterates range(epochs) — confirm.
EPOCHS = 10001
BATCH_SIZE = 32
DATA_PATH = ["dataset/NorthSea_Augmented_Total_down4_p512_s64.npz"]

LATENT_DIM = 100
LR_G = 2e-5  # generator learning rate
LR_C = 2e-5  # critic learning rate
N_CRITIC = 1  # presumably critic updates per generator update — see train()
SAVE_INTERVAL = 10
PRINT_EVERY = 1


def _timestamp():
    """Return the current local time formatted as 'YYYY-mm-dd HH:MM:SS'."""
    return datetime.now().strftime('%Y-%m-%d %H:%M:%S')


# 1. Load data.
# Per the original author's note, `target_channels` selects between keeping a
# single channel and converting to a 2-channel one-hot encoding; confirm in
# scripts/wgan_sn.py.
X = load_npz_patches_to_0_1(
    DATA_PATH,
    img_rows=IMG_SIZE,
    img_cols=IMG_SIZE,
    target_channels=CHANNELS,
)

print(f"[{_timestamp()}] Data Loaded. Shape: {X.shape}")

# 2. Initialise the WGAN-SN model.
# Per the original author's note, `channels` decides the generator's final
# activation (Softmax vs. Sigmoid); confirm in scripts/wgan_sn.py.
# Output directories are suffixed with the channel count so 1- and 2-channel
# runs do not overwrite each other.
wgan = WGAN_SN_Torch(
    img_rows=IMG_SIZE,
    img_cols=IMG_SIZE,
    channels=CHANNELS,
    latent_dim=LATENT_DIM,
    lr_g=LR_G,
    lr_c=LR_C,
    sample_dir=f"outcomes/snapshots_wgan_sn_{CHANNELS}ch",
    model_dir=f"outcomes/models_wgan_sn_{CHANNELS}ch",
)

# 3. Train.
print(f"[{_timestamp()}] Start {CHANNELS}-Channel Training...")
wgan.train(
    X,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    n_critic=N_CRITIC,
    save_interval=SAVE_INTERVAL,
    print_every=PRINT_EVERY,
)

print(f"[{_timestamp()}] {CHANNELS}-Channel Training completed.")