<a href="https://colab.research.google.com/github/jorge-jrzz/UEA-ML_SRCNN/blob/main/train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#@title # Instalación de dependencias
!pip install torch torchvision matplotlib kagglehub opencv-python numpy pillow scikit-image --quiet
!wget https://raw.githubusercontent.com/jorge-jrzz/UEA-ML_SRCNN/refs/heads/main/install_datasets.py

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m124.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m86.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m57.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m13.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.9/127.9 MB[0m [31m7.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [15]:
!python --version

Python 3.11.12


In [2]:
#@markdown Esto descarga los datasets desde kaggle necesarios para el modelo.
from install_datasets import download_datasets
download_datasets()

📥 Descargando dataset: Set5...
📥 Descargando dataset: DIV2K...


In [8]:
import os
import time

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import cv2
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr

In [9]:
# Directorios de los datasets (ajusta las rutas según tu sistema)
DIV2K_TRAIN_DIR = './DIV2K/DIV2K_train_HR'  # Carpeta con imágenes HR de DIV2K (train)
DIV2K_VALID_DIR = './DIV2K/DIV2K_valid_HR'  # Carpeta con imágenes HR de DIV2K (valid)
SET5_DIR = './Set5'                         # Carpeta con imágenes HR de Set5

# Parámetros
SCALE_FACTOR = 2  # Factor de escala para superresolución (2x)
PATCH_SIZE_LR = 16  # Tamaño del parche LR (16x16)
PATCH_SIZE_HR = PATCH_SIZE_LR * SCALE_FACTOR  # Tamaño del parche HR (32x32)
BATCH_SIZE = 16  # Tamaño del batch para entrenamiento
NUM_PATCHES_PER_IMAGE = 10  # Número de parches extraídos por imagen
SUBSET_SIZE = 100  # Subconjunto de imágenes DIV2K para entrenamiento
VALID_SUBSET_SIZE = 50  # Subconjunto de imágenes DIV2K para entrenamiento

In [10]:
# Transformaciones para aumento de datos
data_augmentation = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(degrees=(90, 90)),  # Rotaciones de 90°
    transforms.RandomRotation(degrees=(180, 180)),  # Rotaciones de 180°
])

# Función para downsampling (crear imagen LR desde HR)
def create_lr_image(hr_image, scale_factor):
    """Convierte una imagen HR a LR usando interpolación bicúbica."""
    h, w = hr_image.shape[:2]
    new_h, new_w = h // scale_factor, w // scale_factor
    lr_image = cv2.resize(hr_image, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
    # Upscale de vuelta para que LR tenga el mismo tamaño que HR (para SRCNN)
    lr_image = cv2.resize(lr_image, (w, h), interpolation=cv2.INTER_CUBIC)
    return lr_image

# Clase Dataset personalizada para DIV2K y Set5
class SuperResolutionDataset(Dataset):
    def __init__(self, hr_dir, is_train=True, subset_size=None):
        """Inicializa el dataset."""
        self.hr_dir = hr_dir
        self.is_train = is_train
        self.image_paths = sorted([os.path.join(hr_dir, f) for f in os.listdir(hr_dir) if f.endswith('.png')])
        if subset_size is not None and is_train:
            self.image_paths = self.image_paths[:subset_size]  # Usar subconjunto para entrenamiento

    def __len__(self):
        """Número total de parches (imágenes * parches por imagen)."""
        return len(self.image_paths) * NUM_PATCHES_PER_IMAGE if self.is_train else len(self.image_paths)

    def __getitem__(self, idx):
        """Devuelve un par de parches LR-HR (entrenamiento) o imagen completa (evaluación)."""
        if self.is_train:
            # Seleccionar imagen y parche
            img_idx = idx // NUM_PATCHES_PER_IMAGE
            patch_idx = idx % NUM_PATCHES_PER_IMAGE
            img_path = self.image_paths[img_idx]
        else:
            img_path = self.image_paths[idx]

        # Cargar imagen HR
        hr_image = cv2.imread(img_path, cv2.IMREAD_COLOR)
        hr_image = cv2.cvtColor(hr_image, cv2.COLOR_BGR2RGB)  # Convertir a RGB
        hr_image = np.array(hr_image).astype(np.float32) / 255.0  # Normalizar a [0,1]

        # Crear imagen LR
        lr_image = create_lr_image(hr_image, SCALE_FACTOR)

        if self.is_train:
            # Extraer parche aleatorio
            h, w = hr_image.shape[:2]
            x = np.random.randint(0, w - PATCH_SIZE_HR)
            y = np.random.randint(0, h - PATCH_SIZE_HR)
            hr_patch = hr_image[y:y+PATCH_SIZE_HR, x:x+PATCH_SIZE_HR]
            lr_patch = lr_image[y:y+PATCH_SIZE_HR, x:x+PATCH_SIZE_HR]

            # Convertir a PIL para aumento de datos
            hr_patch_pil = Image.fromarray((hr_patch * 255).astype(np.uint8))
            lr_patch_pil = Image.fromarray((lr_patch * 255).astype(np.uint8))

            # Aplicar aumento de datos
            seed = np.random.randint(0, 10000)
            torch.manual_seed(seed)
            hr_patch_pil = data_augmentation(hr_patch_pil)
            torch.manual_seed(seed)
            lr_patch_pil = data_augmentation(lr_patch_pil)

            # Convertir de vuelta a numpy
            hr_patch = np.array(hr_patch_pil).astype(np.float32) / 255.0
            lr_patch = np.array(lr_patch_pil).astype(np.float32) / 255.0

            # Convertir a tensores
            hr_patch = torch.from_numpy(hr_patch).permute(2, 0, 1)  # [C, H, W]
            lr_patch = torch.from_numpy(lr_patch).permute(2, 0, 1)  # [C, H, W]

            return lr_patch, hr_patch
        else:
            # Para evaluación, devolver imagen completa
            hr_image = torch.from_numpy(hr_image).permute(2, 0, 1)
            lr_image = torch.from_numpy(lr_image).permute(2, 0, 1)
            return lr_image, hr_image


In [11]:
# Crear datasets
train_dataset = SuperResolutionDataset(DIV2K_TRAIN_DIR, is_train=True, subset_size=SUBSET_SIZE)
valid_dataset = SuperResolutionDataset(DIV2K_VALID_DIR, is_train=False, subset_size=VALID_SUBSET_SIZE)
test_dataset = SuperResolutionDataset(SET5_DIR, is_train=False)

# Crear dataloaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0)

In [12]:
# Definición del modelo SRCNN
class SRCNN(nn.Module):
    def __init__(self):
        super(SRCNN, self).__init__()
        # Capa 1: Extracción de parches y representación
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=9, padding=4)
        self.relu1 = nn.ReLU(inplace=True)

        # Capa 2: Mapeo no lineal
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=1, padding=0)
        self.relu2 = nn.ReLU(inplace=True)

        # Capa 3: Reconstrucción
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=3, kernel_size=5, padding=2)

    def forward(self, x):
        out = self.relu1(self.conv1(x))
        out = self.relu2(self.conv2(out))
        out = self.conv3(out)
        return out

In [13]:
# Inicializar modelo, función de pérdida y optimizador
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Usando dispositivo: {device}")

model = SRCNN().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = StepLR(optimizer, step_size=30, gamma=0.1)  # Reducir LR cada 30 épocas

# Función para entrenar el modelo
def train_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    epoch_loss = 0
    for lr_batch, hr_batch in dataloader:
        lr_batch, hr_batch = lr_batch.to(device), hr_batch.to(device)

        # Forward pass
        outputs = model(lr_batch)
        loss = criterion(outputs, hr_batch)

        # Backward pass y optimización
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    return epoch_loss / len(dataloader)

# Función para evaluar el modelo
def evaluate(model, dataloader, criterion, device):
    model.eval()
    epoch_loss = 0
    psnr_val = 0
    ssim_val = 0
    with torch.no_grad():
        for lr_batch, hr_batch in dataloader:
            lr_batch, hr_batch = lr_batch.to(device), hr_batch.to(device)

            # Forward pass
            outputs = model(lr_batch)
            loss = criterion(outputs, hr_batch)

            epoch_loss += loss.item()

            # Calcular PSNR y SSIM
            output_np = outputs.squeeze(0).permute(1, 2, 0).cpu().numpy().clip(0, 1)
            target_np = hr_batch.squeeze(0).permute(1, 2, 0).cpu().numpy()

            psnr_val += psnr(target_np, output_np, data_range=1.0)
            ssim_val += ssim(target_np, output_np,
                             data_range=1.0,
                             win_size=3,  # Tamaño de ventana más pequeño (3x3)
                             channel_axis=2)  # El eje 2 corresponde a los canales RGB

    return epoch_loss / len(dataloader), psnr_val / len(dataloader), ssim_val / len(dataloader)

# Función para guardar imágenes de ejemplo
def save_example(model, dataloader, epoch, device, save_dir='./results'):
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    model.eval()
    with torch.no_grad():
        for lr_batch, hr_batch in dataloader:
            lr_batch, hr_batch = lr_batch.to(device), hr_batch.to(device)

            # Forward pass
            sr_batch = model(lr_batch)

            # Convertir a numpy para visualización
            lr_img = lr_batch[0].permute(1, 2, 0).cpu().numpy().clip(0, 1)
            hr_img = hr_batch[0].permute(1, 2, 0).cpu().numpy()
            sr_img = sr_batch[0].permute(1, 2, 0).cpu().numpy().clip(0, 1)

            # Calcular PSNR y SSIM
            psnr_val = psnr(hr_img, sr_img, data_range=1.0)
            ssim_val = ssim(hr_img, sr_img,
                           data_range=1.0,
                           win_size=3,  # Tamaño de ventana más pequeño
                           channel_axis=2)  # El eje 2 corresponde a los canales RGB

            # Crear figura
            fig, axes = plt.subplots(1, 3, figsize=(15, 5))
            axes[0].imshow(lr_img)
            axes[0].set_title('LR Input')
            axes[0].axis('off')

            axes[1].imshow(sr_img)
            axes[1].set_title(f'SR Output (PSNR: {psnr_val:.2f}, SSIM: {ssim_val:.4f})')
            axes[1].axis('off')

            axes[2].imshow(hr_img)
            axes[2].set_title('HR Ground Truth')
            axes[2].axis('off')

            plt.tight_layout()
            plt.savefig(f'{save_dir}/epoch_{epoch}.png')
            plt.close()

            break  # Solo guardar una imagen

# Entrenamiento principal
def train_model(model, train_loader, valid_loader, test_loader, criterion, optimizer, scheduler,
                num_epochs=100, device='cpu', save_dir='./checkpoints'):
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    best_psnr = 0
    train_losses = []
    val_losses = []
    val_psnrs = []
    val_ssims = []

    for epoch in range(1, num_epochs + 1):
        start_time = time.time()

        # Entrenamiento
        train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
        train_losses.append(train_loss)

        # Validación
        val_loss, val_psnr, val_ssim = evaluate(model, valid_loader, criterion, device)
        val_losses.append(val_loss)
        val_psnrs.append(val_psnr)
        val_ssims.append(val_ssim)

        # Actualizar learning rate
        scheduler.step()

        # Guardar el mejor modelo
        if val_psnr > best_psnr:
            best_psnr = val_psnr
            torch.save(model.state_dict(), f'{save_dir}/best_model.pth')
            print(f"Epoch {epoch}: Mejor modelo guardado con PSNR {best_psnr:.2f}")

        # Guardar checkpoint cada 10 épocas
        if epoch % 10 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'best_psnr': best_psnr
            }, f'{save_dir}/checkpoint_epoch_{epoch}.pth')

        # Guardar imagen de ejemplo cada 10 épocas
        if epoch % 10 == 0:
            save_example(model, test_loader, epoch, device)

        # Imprimir estadísticas
        time_taken = time.time() - start_time
        print(f"Época {epoch}/{num_epochs} - Tiempo: {time_taken:.2f}s - Train Loss: {train_loss:.6f} - "
              f"Val Loss: {val_loss:.6f} - Val PSNR: {val_psnr:.2f} - Val SSIM: {val_ssim:.4f}")

    # Evaluar en el conjunto de prueba
    test_loss, test_psnr, test_ssim = evaluate(model, test_loader, criterion, device)
    print(f"Resultados finales en test - Loss: {test_loss:.6f} - PSNR: {test_psnr:.2f} - SSIM: {test_ssim:.4f}")

    # Guardar gráficas de entrenamiento
    plt.figure(figsize=(15, 5))

    plt.subplot(1, 3, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Val Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Loss vs. Epoch')

    plt.subplot(1, 3, 2)
    plt.plot(val_psnrs, label='Val PSNR')
    plt.xlabel('Epoch')
    plt.ylabel('PSNR (dB)')
    plt.legend()
    plt.title('PSNR vs. Epoch')

    plt.subplot(1, 3, 3)
    plt.plot(val_ssims, label='Val SSIM')
    plt.xlabel('Epoch')
    plt.ylabel('SSIM')
    plt.legend()
    plt.title('SSIM vs. Epoch')

    plt.tight_layout()
    plt.savefig(f'{save_dir}/training_curves.png')
    plt.close()

    return model

Usando dispositivo: cuda


In [14]:
NUM_EPOCHS = 20
SAVE_DIR = './srcnn_checkpoints'

# Entrenar modelo
trained_model = train_model(
    model=model,
    train_loader=train_loader,
    valid_loader=valid_loader,
    test_loader=test_loader,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=NUM_EPOCHS,
    device=device,
    save_dir=SAVE_DIR
)

# Guardar modelo final
torch.save(trained_model.state_dict(), f'{SAVE_DIR}/final_model.pth')
print(f"Modelo final guardado en {SAVE_DIR}/final_model.pth")

Epoch 1: Mejor modelo guardado con PSNR 22.81
Época 1/20 - Tiempo: 200.74s - Train Loss: 0.032641 - Val Loss: 0.006277 - Val PSNR: 22.81 - Val SSIM: 0.6972


KeyboardInterrupt: 