In [None]:
import os
import sys
import datetime

import pickle
import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "../../")))

from SRModels.deep_learning_models.ESRGAN_model import ESRGAN
from SRModels.loading_methods import load_dataset_as_patches
from SRModels.constants import ESRGAN_PATCH_SIZE, ESRGAN_STRIDE, RANDOM_SEED, ESRGAN_SCALE_FACTOR

# Ask TensorFlow to grow GPU memory on demand instead of reserving it all up front.
# A failure for any device is reported and does not stop the notebook.
gpu_devices = tf.config.experimental.list_physical_devices("GPU")
for device in gpu_devices:
    try:
        tf.config.experimental.set_memory_growth(device, True)
    except Exception as e:
        print(f"Memory growth not set: {e}")

In [None]:
# Dataset roots, resolved to absolute paths relative to the notebook's working directory.
working_dir = os.getcwd()
HR_ROOT = os.path.abspath(os.path.join(working_dir, "../../data/images/HR"))
LR_ROOT = os.path.abspath(os.path.join(working_dir, "../../data/images/LR"))

# Every run gets its own timestamped output directory (second resolution).
run_started_at = datetime.datetime.now()
timestamp = run_started_at.strftime("%Y%m%d_%H%M%S")
directory = f"models/ESRGAN/ESRGAN_{timestamp}"
grid_figures_directory = f"{directory}/grid_figures"

In [None]:
# Load the paired patches produced by the project loader in "scale" mode.
X, Y = load_dataset_as_patches(
    HR_ROOT,
    LR_ROOT,
    mode="scale",
    patch_size=ESRGAN_PATCH_SIZE,
    stride=ESRGAN_STRIDE,
    scale_factor=ESRGAN_SCALE_FACTOR,
)

# Both splits share the same settings: hold out 10%, shuffle, fixed seed.
split_options = dict(test_size=0.1, shuffle=True, random_state=RANDOM_SEED)

# First split carves out the test set; the second carves validation out of what remains.
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, **split_options)
X_train, X_val, Y_train, Y_val = train_test_split(X_train, Y_train, **split_options)

print(f"X_train shape: {X_train.shape}, Y_train shape: {Y_train.shape}")
print(f"X_val shape: {X_val.shape}, Y_val shape: {Y_val.shape}")
print(f"X_test shape: {X_test.shape}, Y_test shape: {Y_test.shape}")

In [None]:
BATCH_SIZE = 8  # Tune to the available GPU memory
EPOCHS = 1

# Test pipeline, used only for the evaluation that follows training.
# Each batch is rescaled with v * 2 - 1 (i.e. [0, 1] -> [-1, 1] if the patches are in [0, 1]);
# this is intended to match the normalization used during training — TODO confirm in ESRGAN.fit.
test_dataset = (
    tf.data.Dataset.from_tensor_slices((X_test, Y_test))
    .batch(BATCH_SIZE)
    .map(
        lambda lr, hr: (lr*2.0 - 1.0, hr*2.0 - 1.0),
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
# Instantiate the ESRGAN wrapper and configure its networks from scratch.
model = ESRGAN()

model.setup_model(
    # Reuse the constant that generated the patch pairs so the model's upscaling
    # factor cannot drift from the dataset's (previously a hard-coded 2).
    scale_factor=ESRGAN_SCALE_FACTOR,
    growth_channels=16,
    num_rrdb_blocks=8,
    input_shape=X_train.shape[1:],    # per-sample shape: leading sample axis dropped
    output_shape=Y_train.shape[1:],
    from_trained=False
)

In [None]:
# Run training. The call returns three values: the loss/metric history plus the
# timing and memory callback objects that the metrics cell queries afterwards.
training_args = dict(
    X_train=X_train,
    Y_train=Y_train,
    X_val=X_val,
    Y_val=Y_val,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    save_dir=grid_figures_directory,
)
epoch_losses, time_cb, mem_cb = model.fit(**training_args)

In [None]:
# Evaluate on the held-out test pipeline; the result is indexed by
# "avg_g_loss", "avg_psnr" and "avg_ssim" when the metrics dict is built.
results = model.evaluate(test_dataset)

def _last_mean(name):
    vals = epoch_losses.get(name, None)
    if vals is None or len(vals) == 0:
        return None
    return float(np.mean(vals))

# Summarise the training-side and validation-side histories, one value per metric.
train_loss, train_psnr, train_ssim = (
    _last_mean(key) for key in ('g_loss', 'psnr', 'ssim')
)
val_loss, val_psnr, val_ssim = (
    _last_mean(key) for key in ('val_g_loss', 'val_psnr', 'val_ssim')
)

# Everything that gets persisted for this run: test-set evaluation, training/validation
# summaries, and the timing/memory figures reported by the two callbacks.
metrics_dict = dict(
    eval_loss=float(results["avg_g_loss"]),
    eval_psnr=float(results["avg_psnr"]),
    eval_ssim=float(results["avg_ssim"]),
    final_train_loss=train_loss,
    final_val_loss=val_loss,
    final_train_psnr=train_psnr,
    final_val_psnr=val_psnr,
    final_train_ssim=train_ssim,
    final_val_ssim=val_ssim,
    epoch_time_sec=time_cb.mean_time_value(),
    memory=mem_cb.as_dict(),
)

In [None]:
# Save the trained model into this run's timestamped output directory.
model.save(directory=directory, timestamp=timestamp)

In [None]:
# Save evaluation/time/memory metrics next to the model.
# The path is built from `directory` (the same value handed to model.save) rather than
# re-spelling the folder name, so the two locations cannot drift apart.
metrics_path = os.path.abspath(os.path.join(os.getcwd(), directory, f"ESRGAN_{timestamp}_metrics.pkl"))

# Make sure the run directory exists so this cell does not fail with FileNotFoundError.
os.makedirs(os.path.dirname(metrics_path), exist_ok=True)

with open(metrics_path, "wb") as f:
    pickle.dump(metrics_dict, f)

print(f"Saved metrics to {metrics_path}")