In [1]:
import argparse
import copy
import logging
import os
import random
import sys
from datetime import datetime
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from PIL import Image
from torch import optim
from torch.utils.data import Dataset, DataLoader, random_split
# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
from torchvision.datasets import ImageFolder
from tqdm import tqdm   # show loops progress

#import wandb    # track and visualize aspects of training process in real time
#import evaluate
from unet import UNet
#from utils.data_loading import BasicDataset, CarvanaDataset
from utils.dice_score import dice_loss

In [2]:
# Report the library versions this notebook ran with (handy when reproducing
# results or debugging CUDA mismatches).
version_info = (
    ("PyTorch version", torch.__version__),
    ("CUDA version from PyTorch", torch.version.cuda),
)
for label, value in version_info:
    print(f"{label}: {value}")

PyTorch version: 2.5.1
CUDA version from PyTorch: 12.1


**Carga de datos de entrenamiento**

In [5]:
# Custom dataset for DRIVE (paired images and ground-truth masks)
class DriveDataset(Dataset):
    """Paired retinal images and vessel masks read from two directories.

    Images and masks are matched by position after sorting each directory's
    file listing, so both directories must hold the same number of files and
    their sorted orders must correspond.

    Args:
        image_dir: Directory holding the input images (loaded as RGB).
        mask_dir: Directory holding the ground-truth masks (loaded as
            single-channel ``'L'`` images).
        transform: Optional callable applied to the image, and also to the
            mask unless ``mask_transform`` is given.
        mask_transform: Optional callable applied to the mask instead of
            ``transform`` (e.g. a resize with nearest-neighbour interpolation
            so masks stay binary). Defaults to reusing ``transform``.

    Raises:
        ValueError: If the two directories contain a different number of
            files, which would otherwise silently mis-pair images and masks.
    """

    def __init__(self, image_dir, mask_dir, transform=None, mask_transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        # Sort both listings so image i pairs with mask i.
        self.image_filenames = sorted(os.listdir(image_dir))
        self.mask_filenames = sorted(os.listdir(mask_dir))
        if len(self.image_filenames) != len(self.mask_filenames):
            raise ValueError(
                f"Found {len(self.image_filenames)} images in {image_dir!r} but "
                f"{len(self.mask_filenames)} masks in {mask_dir!r}; "
                "they must match one-to-one."
            )
        self.transform = transform
        self.mask_transform = mask_transform

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        image_path = os.path.join(self.image_dir, self.image_filenames[idx])
        mask_path = os.path.join(self.mask_dir, self.mask_filenames[idx])

        # convert() returns a new in-memory image, so the file can be closed right away.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        with Image.open(mask_path) as msk:
            mask = msk.convert('L')  # single-channel grayscale mask

        # Without an explicit mask_transform the mask gets the shared transform.
        mask_transform = self.mask_transform if self.mask_transform is not None else self.transform
        if self.transform:
            image = self.transform(image)
        if mask_transform:
            mask = mask_transform(mask)

        return image, mask


IMG_SIZE = 512  # side length (pixels) that images and masks are resized to

# Transforms for images and masks: resize to IMG_SIZE x IMG_SIZE, then convert to
# a float tensor. IMG_SIZE is reused instead of a second hard-coded 512 so the
# resize and every other use of IMG_SIZE cannot drift apart.
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
])

# Paths
image_dir = './data/DRIVE/training/images'
gt_dir = './data/DRIVE/training/1st_manual'

dataset = DriveDataset(image_dir, gt_dir, transform=transform)

In [None]:
NUM_EXAMPLES = 4  # image/mask pairs to preview

example_loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=0)

# Take one random batch to preview.
images, masks = next(iter(example_loader))

# Tensors -> numpy for matplotlib.
images = images.numpy().transpose(0, 2, 3, 1)  # [N, C, H, W] -> [N, H, W, C]
# Masks are [N, 1, H, W] (single-channel 'L' images); drop only the channel axis.
# np.squeeze would also drop the batch axis when N == 1 and break masks[i].
masks = masks.numpy()[:, 0]  # -> [N, H, W]

n_shown = min(NUM_EXAMPLES, len(images))

# Images on the top row, their masks on the bottom row.
fig, axes = plt.subplots(2, NUM_EXAMPLES, figsize=(15, 7), squeeze=False)

for i in range(n_shown):
    axes[0, i].imshow(images[i])
    axes[0, i].axis('off')
    axes[0, i].set_title(f'Imagen {i+1}')

    axes[1, i].imshow(masks[i], cmap='gray')  # grayscale colormap for masks
    axes[1, i].axis('off')
    axes[1, i].set_title(f'Máscara {i+1}')

# Hide any panels left empty when the batch is smaller than NUM_EXAMPLES.
for ax in axes[:, n_shown:].ravel():
    ax.axis('off')

plt.tight_layout()
plt.show()

In [None]:
VAL_PERCENT: float = 0.2   # fraction of the dataset held out for validation; the rest is used for training
BATCH_SIZE: int = 4

# Train / validation split, seeded so the partition is the same on every run.
n_total = len(dataset)
val_size = int(n_total * VAL_PERCENT)
train_size = n_total - val_size
split_generator = torch.Generator().manual_seed(0)
train_set, val_set = random_split(dataset, [train_size, val_size], generator=split_generator)

# Loaders: reshuffle training data each epoch, keep validation order fixed.
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False)

print(f'Training set has {train_size} instances')
print(f'Validation set has {val_size} instances')

In [None]:
sample_image, sample_label = train_set[0]

# Channel counts come from the leading (channel) axis of one sample: the RGB
# image gives the UNet's input channels, the single-channel mask its outputs.
NUM_CHANNELS_IN: int = sample_image.shape[0]
NUM_CHANNELS_OUT: int = sample_label.shape[0]

for direction, count in (("input", NUM_CHANNELS_IN), ("output", NUM_CHANNELS_OUT)):
    print(f"Number of channels in {direction}: {count}")

In [None]:
EPOCHS: int = 100
LEARNING_RATE: float = 0.001
SEED: int = 0  # seed for Python, NumPy and PyTorch RNGs

SAVE_PATH = "./trained_models"        # where to save the trained best model state

# Seed every RNG later cells draw from (model weight init, training-set
# shuffling) so re-running the notebook gives comparable results.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"{device.type=}")

In [None]:
# UNet is already imported in the first cell. Build it with channel counts
# inferred from the data, then move its parameters to the selected device.
model = UNet(n_channels=NUM_CHANNELS_IN, n_classes=NUM_CHANNELS_OUT)
model.to(device)      # move to cuda if possible

In [None]:
from torchsummary import summary

# torchsummary's summary() defaults to device="cuda" and builds its dummy input
# there; pass the device actually in use so this also works on CPU-only machines.
summary(model, (NUM_CHANNELS_IN, IMG_SIZE, IMG_SIZE), device=device.type)

**TRAIN**

In [10]:
# The UNet outputs raw logits, so pair them with BCEWithLogitsLoss, which applies
# the sigmoid internally in a numerically stable way.
loss_fn = nn.BCEWithLogitsLoss()
optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9)

In [11]:
def _run_epoch(model, dataloader, loss_fn, device, optimizer=None):
    """Run one pass over ``dataloader`` and return the mean per-sample loss.

    With an ``optimizer`` the model is trained (train mode, backprop, optimizer
    step); without one it is only evaluated (eval mode, gradients disabled).
    Returns NaN if the dataloader yields no samples.
    """
    training = optimizer is not None
    model.train(training)
    total_loss = 0.0
    n_samples = 0
    with torch.set_grad_enabled(training):
        for img, mask in tqdm(dataloader):
            img, mask = img.to(device), mask.to(device)
            pred = model(img)
            loss = loss_fn(pred, mask)
            if training:
                optimizer.zero_grad()
                loss.backward()    # calculate gradients
                optimizer.step()
            # loss is assumed to be a batch mean (reduction='mean'); weight it by
            # batch size so the epoch average is per sample even when the last
            # batch is smaller.
            batch_size = img.size(0)
            total_loss += loss.item() * batch_size
            n_samples += batch_size
    return total_loss / n_samples if n_samples else float('nan')


def training_loop(epochs, model, train_dataloader, val_dataloader,
                  loss_fn, optimizer, save_path):
    """Train ``model`` and save the weights with the lowest validation loss.

    Each epoch does one training pass over ``train_dataloader`` followed by
    one evaluation pass over ``val_dataloader``. After the last epoch the
    best-scoring weights are written to
    ``<save_path>/model_<ModelClassName>_<YYYYmmdd_HHMMSS>``.

    Args:
        epochs: Number of epochs to train.
        model: Network to train; batches are moved to the device of its parameters.
        train_dataloader: Yields ``(image, mask)`` batches for training.
        val_dataloader: Yields ``(image, mask)`` batches for validation.
        loss_fn: Callable ``loss_fn(pred, mask)`` returning a scalar loss,
            assumed to be averaged over the batch.
        optimizer: Optimizer over ``model``'s parameters.
        save_path: Directory for the best checkpoint; created if missing.

    Returns:
        Dict with per-epoch lists ``'train_loss'`` and ``'val_loss'``, each the
        mean per-sample loss over that epoch's pass.

    Note:
        On return the model holds the weights from the *last* epoch (in eval
        mode); the best weights are only in the saved file.
    """
    device = next(model.parameters()).device
    history = {'train_loss': [], 'val_loss': []}
    best_val_loss = float('inf')
    best_model_state = None

    for epoch in range(1, epochs + 1):
        train_loss = _run_epoch(model, train_dataloader, loss_fn, device, optimizer)
        val_loss = _run_epoch(model, val_dataloader, loss_fn, device)

        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        print(f'Epoch: {epoch}/{epochs} | Training loss: {train_loss} | Validation loss: {val_loss}')

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # state_dict() returns references to the live parameter tensors, which
            # keep changing as training continues. Deep-copy them so this really is
            # a snapshot of the best epoch rather than of the last one.
            best_model_state = copy.deepcopy(model.state_dict())

    if best_model_state is not None:
        os.makedirs(save_path, exist_ok=True)
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        model_path = os.path.join(save_path, f'model_{type(model).__name__}_{timestamp}')
        torch.save(best_model_state, model_path)
        print(f'\nBest model saved at {model_path}')
    else:
        print('\nNo model saved: validation loss never improved (is the validation set empty?)')

    model.eval()
    return history

In [None]:
# Train for EPOCHS epochs; the best-validation checkpoint is written under SAVE_PATH.
history = training_loop(
    epochs=EPOCHS,
    model=model,
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    loss_fn=loss_fn,
    optimizer=optimizer,
    save_path=SAVE_PATH,
)

In [None]:
# Learning curves, one point per recorded epoch. The x-axis is derived from the
# history itself rather than EPOCHS, so the plot stays valid if training was run
# with a different epoch count.
epochs_axis = np.arange(1, len(history['train_loss']) + 1)
TICK_INTERVAL = 10  # label every 10th epoch

fig, ax = plt.subplots(figsize=(7, 7))
ax.plot(epochs_axis, history['train_loss'], label='Training loss')
ax.plot(epochs_axis, history['val_loss'], label='Validation loss')
ax.set_xticks(epochs_axis[::TICK_INTERVAL])
ax.set_xlabel("Epochs")
ax.set_ylabel("Loss")
ax.set_title("Training and validation loss")
ax.grid()
ax.legend()
plt.show()

**TEST**

In [None]:
# Checkpoint written by training_loop (the file name encodes the training timestamp).
BEST_MODEL_FILE = os.path.join(SAVE_PATH, 'model_UNet_20250118_220040')

model = UNet(n_channels=NUM_CHANNELS_IN, n_classes=NUM_CHANNELS_OUT)
# map_location lets a checkpoint saved on GPU load on a CPU-only machine;
# weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state_dict needs.
state_dict = torch.load(BEST_MODEL_FILE, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.to(device)
model.eval()

In [12]:
# Same preprocessing as training: resize to IMG_SIZE x IMG_SIZE, then tensor.
# No Normalize step, because the training transform did not normalize either and
# the model must see inputs on the same scale it was trained on.
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
])

class TestDataset(Dataset):
    """Unlabelled images from one directory, yielded as ``(image, file_name)``.

    File names are sorted so the iteration order is deterministic (``os.listdir``
    order is arbitrary), matching how the training dataset orders its files.

    Args:
        image_dir: Directory holding the images (loaded as RGB).
        transform: Optional callable applied to each image.
    """

    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.transform = transform
        self.image_names = sorted(os.listdir(image_dir))

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        img_name = os.path.join(self.image_dir, self.image_names[idx])
        # convert() returns a new in-memory image, so the file can be closed right away.
        with Image.open(img_name) as img:
            image = img.convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, self.image_names[idx]  # return image and its name

# Test images (no ground truth is loaded). batch_size=1 gives one image per
# prediction; shuffle=False keeps the dataset's file order.
test_dir = './data/DRIVE/test/images'
test_dataset = TestDataset(test_dir, transform=transform)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=1)

In [None]:
# Peek at the first test batch only, to check the tensor shape the model will receive.
first_batch = next(iter(test_loader), None)
if first_batch is not None:
    images, labels = first_batch  # labels are the image file names
    print(f"Shape of image: {images.shape}")

In [None]:
num_images = 4  # number of test images to show
fig, axes = plt.subplots(2, num_images, figsize=(15, 7))

# Each column of `axes` pairs an input image (top) with its predicted mask (bottom).
# zip stops as soon as every column is filled.
with torch.no_grad():  # inference only, no gradients needed
    for (ax_orig, ax_seg), (image, image_name) in zip(axes.T, test_loader):
        image = image.to(device)
        name = image_name[0]  # batch_size=1, so a single file name

        # logits -> probabilities -> hard 0/1 mask, then drop batch/channel dims
        probs = torch.sigmoid(model(image))
        seg = (probs > 0.5).float().squeeze().cpu().numpy()

        ax_orig.imshow(image.squeeze().cpu().numpy().transpose(1, 2, 0))  # CHW -> HWC
        ax_orig.axis('off')
        ax_orig.set_title(f"Original: {name}")

        ax_seg.imshow(seg, cmap='gray')
        ax_seg.axis('off')
        ax_seg.set_title(f"Segmentation: {name}")

plt.tight_layout()
plt.show()