In [1]:
import os
import sys
import datetime

import pickle
from sklearn.model_selection import train_test_split

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "../../")))
from SRModels.loading_methods import load_dataset_as_patches
from SRModels.deep_learning_models.SRCNN_model import SRCNNModel
from SRModels.constants import SRCNN_PATCH_SIZE, SRCNN_STRIDE, RANDOM_SEED

In [2]:
# Data locations, resolved to absolute paths relative to the notebook's working
# directory: HR image root, LR image root, and the pickled interpolation map
# that is passed to the patch loader.
HR_ROOT, LR_ROOT, INTERP_MAP_PATH = (
    os.path.abspath(os.path.join(os.getcwd(), relative_path))
    for relative_path in (
        "../../data/images/HR",
        "../../data/images/LR",
        "../../data/images/interpolation_map.pkl",
    )
)

In [3]:
# Load paired patches for SRCNN:
#   X -> low-resolution input patches, already at the same spatial size as Y
#        (degraded / "noised" versions of the HR content)
#   Y -> high-resolution target patches
# hr_h / hr_w are returned alongside and saved later for image reconstruction.
X, Y, hr_h, hr_w = load_dataset_as_patches(
    HR_ROOT,
    LR_ROOT,
    mode="srcnn",
    patch_size=SRCNN_PATCH_SIZE,
    stride=SRCNN_STRIDE,
    interpolation_map_path=INTERP_MAP_PATH,
)

# Hold out 20% of the patches for testing, then take 10% of the remainder as
# validation (roughly 72% / 8% / 20% train / val / test overall).
# NOTE(review): the split is per patch, not per source image. If patches from the
# same image (possibly overlapping, depending on SRCNN_STRIDE vs SRCNN_PATCH_SIZE)
# land in different splits, test metrics may be optimistic. Confirm whether a
# per-image split is needed.
X_train, X_test, Y_train, Y_test = train_test_split(
    X, Y,
    test_size=0.2,
    shuffle=True,
    random_state=RANDOM_SEED,
)
X_train, X_val, Y_train, Y_val = train_test_split(
    X_train, Y_train,
    test_size=0.1,
    shuffle=True,
    random_state=RANDOM_SEED,
)

In [4]:
# Print each split's shapes, then check the split invariants so that a silent
# loader or split regression fails here instead of surfacing later as bad metrics.
for suffix, x_split, y_split in (
    ("", X, Y),
    ("_train", X_train, Y_train),
    ("_val", X_val, Y_val),
    ("_test", X_test, Y_test),
):
    print(f"X{suffix} shape: {x_split.shape}, Y{suffix} shape: {y_split.shape}")
    # SRCNN inputs are LR images at the same size as the HR targets, so each
    # input/target pair must have identical shapes.
    assert x_split.shape == y_split.shape, f"X{suffix}/Y{suffix} shape mismatch"

# The two successive splits must partition the full patch set without loss.
assert len(X_train) + len(X_val) + len(X_test) == len(X), "splits do not partition X"

X shape: (2560, 96, 96, 3), Y shape: (2560, 96, 96, 3)
X_train shape: (1843, 96, 96, 3), Y_train shape: (1843, 96, 96, 3)
X_val shape: (205, 96, 96, 3), Y_val shape: (205, 96, 96, 3)
X_test shape: (512, 96, 96, 3), Y_test shape: (512, 96, 96, 3)


In [5]:
# Optimisation settings for the SRCNN model, named here rather than inlined.
SRCNN_LEARNING_RATE = 1e-4
SRCNN_LOSS = "mean_squared_error"

# Build and compile the network. shape[1:] drops the sample axis, so the model
# input matches whatever patch shape the loader produced.
model = SRCNNModel()
model.setup_model(
    input_shape=X_train.shape[1:],
    learning_rate=SRCNN_LEARNING_RATE,
    loss=SRCNN_LOSS,
)

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d (Conv2D)             (None, 96, 96, 64)        15616     
                                                                 
 conv2d_1 (Conv2D)           (None, 96, 96, 32)        2080      
                                                                 
 conv2d_2 (Conv2D)           (None, 96, 96, 3)         2403      
                                                                 
Total params: 20,099
Trainable params: 20,099
Non-trainable params: 0
_________________________________________________________________


In [6]:
# Training hyperparameters.
# NOTE(review): EPOCHS = 1 looks like a smoke-run budget; confirm before reporting
# these metrics as final results.
BATCH_SIZE = 16
EPOCHS = 1

# Train SRCNN. fit() returns the timing and memory callbacks whose readings are
# collected into the metrics dictionary after evaluation.
time_cb, mem_cb = model.fit(
    X_train, Y_train, X_val, Y_val,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    use_augmentation=True,
    use_mix=True,
)

Training on GPU: /physical_device:GPU:0


In [7]:
# Evaluate on the held-out test patches. The results are indexed as
# (loss, PSNR, SSIM), the same order evaluate() prints them in.
results = model.evaluate(X_test, Y_test)

# Gather everything worth persisting into one dict; float() turns whatever
# numeric types evaluate() returns into plain Python floats.
metrics_dict = dict(
    eval_loss=float(results[0]),
    eval_psnr=float(results[1]),
    eval_ssim=float(results[2]),
    epoch_time_sec=time_cb.mean_time_value(),
    memory=mem_cb.as_dict(),
)
print(metrics_dict)

Loss: 0.0125, PSNR: 19.24 dB, SSIM: 0.5687
{'eval_loss': 0.012475560419261456, 'eval_psnr': 19.235689163208008, 'eval_ssim': 0.5686850547790527, 'epoch_time_sec': 7.27239000001282, 'memory': {'gpu_mean_current_mb': 21.7918701171875, 'gpu_peak_mb': 575.06201171875}}


In [8]:
# Unique run identifier; all artifacts of this run live under one directory,
# built once here so the model and the side files cannot drift apart.
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
run_name = f"SRCNN_{timestamp}"
model_dir = f"models/SRCNN/{run_name}"  # relative to the notebook's working directory

# Save the trained model.
model.save(directory=model_dir, timestamp=timestamp)

# Save hr_h and hr_w for later use in reconstructing full images from patches.
# Defensive: make sure the run directory exists before writing the side file.
os.makedirs(model_dir, exist_ok=True)
hr_dims_path = os.path.abspath(os.path.join(model_dir, f"{run_name}_hrh_hrw.pkl"))
with open(hr_dims_path, "wb") as f:
    pickle.dump((hr_h, hr_w), f)

Model saved to models/SRCNN/SRCNN_20250908_041140\SRCNN_20250908_041140.h5


In [9]:
# Persist the evaluation / timing / memory metrics next to the saved model.
# abspath() resolves the relative run path against the current working directory.
metrics_path = os.path.abspath(
    f"models/SRCNN/SRCNN_{timestamp}/SRCNN_{timestamp}_metrics.pkl"
)

with open(metrics_path, mode="wb") as metrics_file:
    pickle.dump(metrics_dict, metrics_file)

print(f"Saved metrics to {metrics_path}")

Saved metrics to c:\Users\bgmanuel\InteligenciaArtificial\MasterInteligenciaArtificial\Periodo2\TFM\Super-Resolution-Images-for-3D-Printing-Defect-Detection\SRModels\deep_learning_models\models\SRCNN\SRCNN_20250908_041140\SRCNN_20250908_041140_metrics.pkl
