# `train.ipynb` - Model Training and Validation

In [None]:
import os
import time
from glob import glob
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
from data import DriveDataset
from model import build_unet
from loss import DiceBCELoss
from utils import seeding, create_dir, epoch_time
import gc

In [None]:
def load_data(path):
    """
    Collect sorted image/mask file paths for the train and test splits.

    Expects ``path`` to contain ``train/image``, ``train/mask``,
    ``test/image`` and ``test/mask`` sub-directories; every entry in each
    directory is matched with the ``*`` glob pattern.

    Returns:
        ``((train_images, train_masks), (test_images, test_masks))`` where each
        element is a lexicographically sorted list of paths. Missing
        directories simply yield empty lists.
    """
    def _sorted_entries(split, kind):
        # One directory listing, sorted so images and masks line up by name.
        return sorted(glob(os.path.join(path, split, kind, "*")))

    train_pair = (_sorted_entries("train", "image"), _sorted_entries("train", "mask"))
    test_pair = (_sorted_entries("test", "image"), _sorted_entries("test", "mask"))
    return train_pair, test_pair

def train(model, loader, optimizer, loss_fn, device):
    """
    Run one optimisation pass over ``loader`` and return the mean batch loss.

    Args:
        model: module being optimised; switched to training mode here.
        loader: iterable of ``(inputs, targets)`` batches that supports ``len()``.
        optimizer: optimiser holding ``model``'s parameters.
        loss_fn: callable ``loss_fn(prediction, target)`` returning a
            single-element tensor (``.item()`` is called on it).
        device: device both tensors are moved to; they are also cast to float32.

    Returns:
        Sum of the per-batch losses divided by ``len(loader)``. An empty
        loader raises ``ZeroDivisionError``.
    """
    model.train()
    batch_losses = []

    for inputs, targets in loader:
        inputs = inputs.to(device, dtype=torch.float32)
        targets = targets.to(device, dtype=torch.float32)

        optimizer.zero_grad()  # gradients accumulate across steps otherwise
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())

    return sum(batch_losses) / len(loader)

def evaluate(model, loader, loss_fn, device):
    """
    Compute the mean batch loss over ``loader`` without updating the model.

    Args:
        model: module under evaluation; switched to eval mode here.
        loader: iterable of ``(inputs, targets)`` batches that supports ``len()``.
        loss_fn: callable ``loss_fn(prediction, target)`` returning a
            single-element tensor.
        device: device both tensors are moved to; they are also cast to float32.

    Returns:
        Sum of the per-batch losses divided by ``len(loader)``. An empty
        loader raises ``ZeroDivisionError``.
    """
    model.eval()
    running_total = 0.0

    # No autograd graph is built, so no gradients are produced or stored.
    with torch.no_grad():
        for batch_x, batch_y in loader:
            batch_x = batch_x.to(device, dtype=torch.float32)
            batch_y = batch_y.to(device, dtype=torch.float32)
            running_total += loss_fn(model(batch_x), batch_y).item()

    return running_total / len(loader)


if __name__ == "__main__":

    # Release cached GPU memory and run the garbage collector so a re-executed
    # notebook cell starts from a clean state.
    torch.cuda.empty_cache()
    gc.collect()

    seeding(42)
    # Create folder for saving model checkpoints
    create_dir("files")

    # Load augmented data paths. The default is a machine-specific absolute
    # path; set the DATA_PATH environment variable to train from elsewhere.
    data_path = os.environ.get(
        "DATA_PATH",
        "D:/nada/______Projects/Retina-Blood-Vessel-Segmentation/new_data",
    )
    (train_x, train_y), (valid_x, valid_y) = load_data(data_path)

    print(f"Dataset Size:\nTrain: {len(train_x)}\nValid: {len(valid_x)}")

    # Fail early with a clear message. An empty split would otherwise surface
    # as a bare ZeroDivisionError inside train()/evaluate() (assuming
    # DriveDataset's length follows the number of paths). load_data returns
    # independently sorted image and mask lists, so lists of different lengths
    # cannot correspond one-to-one.
    for split_name, images, masks in (("train", train_x, train_y), ("test", valid_x, valid_y)):
        if not images:
            raise FileNotFoundError(
                f"No {split_name} images found under {data_path!r}"
            )
        if len(images) != len(masks):
            raise ValueError(
                f"{split_name} split has {len(images)} images but {len(masks)} masks"
            )

    # Hyperparameters
    H = 512
    W = 512
    # NOTE(review): H/W/size are not passed to DriveDataset below; confirm
    # whether the images are expected to be pre-resized to this size.
    size = (H, W)
    batch_size = 1      # To prevent CUDA out of Memory
    num_epochs = 50
    lr = 1e-4
    checkpoint_path = "files/checkpoint.pth"

    # Dataset & Loader Initialization
    train_dataset = DriveDataset(train_x, train_y)
    valid_dataset = DriveDataset(valid_x, valid_y)

    # num_workers=0 loads batches in the main process; worker subprocesses
    # are a frequent source of failures in Windows/Jupyter environments.
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    valid_loader = DataLoader(dataset=valid_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Training on: {device}")

    model = build_unet()
    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # Reduces the LR by the scheduler's default factor (0.1) after 5 epochs
    # without validation-loss improvement; it only acts when step(metric) is
    # called, which happens once per epoch below.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5)
    loss_fn = DiceBCELoss()

    # Training Loop
    best_valid_loss = float("inf")

    print("Starting Training Loop...")

    for epoch in range(num_epochs):
        start_time = time.time()

        # Training
        train_loss = train(model, train_loader, optimizer, loss_fn, device)

        # Validation
        valid_loss = evaluate(model, valid_loader, loss_fn, device)

        # Feed the monitored metric to the plateau scheduler. Without this call
        # the scheduler never runs and the LR stays fixed at `lr`.
        prev_lr = optimizer.param_groups[0]["lr"]
        scheduler.step(valid_loss)
        new_lr = optimizer.param_groups[0]["lr"]
        if new_lr < prev_lr:
            print(f"Learning rate reduced from {prev_lr:.1e} to {new_lr:.1e}")

        # Save Checkpoint if validation loss improves
        if valid_loss < best_valid_loss:
            print(f"Valid loss improved from {best_valid_loss:2.4f} to {valid_loss:2.4f}. Saving Checkpoint...")
            best_valid_loss = valid_loss
            torch.save(model.state_dict(), checkpoint_path)

        # Calculate time taken for the epoch
        end_time = time.time()
        epoch_mins, epoch_secs = epoch_time(start_time, end_time)

        # Log epoch results
        print(f"Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s")
        print(f"Train Loss: {train_loss:.3f} | Valid Loss: {valid_loss:.3f}\n")

Dataset Size:
Train: 80
Valid: 20
Training on: cuda
Starting Training Loop...
Valid loss improved from inf to 0.9976. Saving Checkpoint...
Epoch: 01 | Time: 0m 26s
Train Loss: 1.062 | Valid Loss: 0.998

Valid loss improved from 0.9976 to 0.8784. Saving Checkpoint...
Epoch: 02 | Time: 0m 26s
Train Loss: 0.892 | Valid Loss: 0.878

Valid loss improved from 0.8784 to 0.8100. Saving Checkpoint...
Epoch: 03 | Time: 0m 26s
Train Loss: 0.805 | Valid Loss: 0.810

Valid loss improved from 0.8100 to 0.7429. Saving Checkpoint...
Epoch: 04 | Time: 0m 26s
Train Loss: 0.734 | Valid Loss: 0.743

Valid loss improved from 0.7429 to 0.6836. Saving Checkpoint...
Epoch: 05 | Time: 0m 26s
Train Loss: 0.676 | Valid Loss: 0.684

Valid loss improved from 0.6836 to 0.6306. Saving Checkpoint...
Epoch: 06 | Time: 0m 26s
Train Loss: 0.622 | Valid Loss: 0.631

Valid loss improved from 0.6306 to 0.6043. Saving Checkpoint...
Epoch: 07 | Time: 0m 27s
Train Loss: 0.577 | Valid Loss: 0.604

Valid loss improved from 0.60