In [None]:
import torch  # PyTorch library for deep learning
import torch.nn as nn  # For neural network functionality
import numpy as np  # For numerical operations (arrays)
from unet import UNet  # Import the UNet model (assumed to be defined elsewhere)
import matplotlib.pyplot as plt  # For plotting loss curves
import torchvision.ops as ops  # For image processing operations (though not used here)
import cv2 as cv  # OpenCV library for image processing
import os  # For interacting with the file system (directories, files)

In [None]:
# Select the compute device: the first CUDA GPU when one is available, otherwise the CPU,
# so the notebook also runs on machines without a GPU.
if torch.cuda.is_available():
    torch.cuda.set_device(0)  # make GPU 0 the current CUDA device
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
# BCEWithLogitsLoss applies the sigmoid itself, so the model should output raw logits.
loss_function = nn.BCEWithLogitsLoss()

In [None]:
# Load the training images and augment them with the same four rotations used for the masks.
# NOTE(review): no cell in this notebook filled `images`, so the next cell's np.stack([])
# raised. This loader is a reconstruction. It assumes that:
#   - train/ mirrors masks/ (train/<folder>/<file>);
#   - there is one image per mask, and sorted file names line up across the two trees;
#   - every file is readable by cv.imread.
# Confirm against the real data layout.
images = []
directory = "train"  # joined with os.path.join below, so no OS-specific separator is needed

for folder in sorted(os.listdir(directory)):
    inner_directory = os.path.join(directory, folder)
    for file in sorted(os.listdir(inner_directory)):
        path = os.path.join(inner_directory, file)
        bgr = cv.imread(path, cv.IMREAD_COLOR)
        if bgr is None:  # cv.imread signals failure by returning None, not by raising
            raise FileNotFoundError(f"could not read image: {path}")
        # OpenCV decodes to BGR; convert so channel 0 is red, which is what the
        # ImageNet mean/std normalisation cells below assume.
        img = cv.cvtColor(bgr, cv.COLOR_BGR2RGB)

        # Same orientations, in the same order, as the mask augmentation, so that
        # image i and mask i stay a matching pair.
        images.append(np.rot90(img, 0))               # original orientation
        images.append(np.rot90(img, 1, axes=(1, 0)))  # 90° clockwise
        images.append(np.rot90(img, 1, axes=(0, 1)))  # 90° counter-clockwise
        images.append(np.rot90(img, 2))               # 180°

In [None]:
# Stack the per-image arrays into one array and rescale pixels to [0, 1]
# (assumes inputs are in [0, 255]). Not idempotent: re-running divides by 255 again.
images = np.divide(np.stack(images), 255)

In [None]:
# Sanity-check the [0, 1] scaling done by the previous cell instead of repeating it.
# A second `np.stack(images) / 255` would divide by 255 twice and squash every pixel
# into [0, 1/255] before the ImageNet mean/std normalisation below.
assert 0.0 <= np.min(images) and np.max(images) <= 1.0, "images should already be scaled to [0, 1]"

In [None]:
# Pull the three colour planes out of the 4-D image array (channels on the last axis).
# NOTE(review): channel 0 is treated as red. cv.imread returns BGR unless the loader
# converts it, so confirm the channel order before applying the ImageNet statistics.
r, g, b = (images[:, :, :, channel] for channel in range(3))

In [None]:
# Standardise each colour plane with the ImageNet per-channel mean and standard deviation.
r, g, b = (
    (plane - mean) / std
    for plane, mean, std in zip((r, g, b), (0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
)

In [None]:
# Reassemble the standardised planes into a channels-last array.
images = np.stack((r, g, b), axis=-1)

In [None]:
# Cache the preprocessed image array on disk so later cells can reload it with np.load.
np.save(file="images.npy", arr=images)

In [None]:
# Load and preprocess the masks (binary segmentation labels).
masks = []
# Plain directory name: sub-paths are built with os.path.join, so this works on every OS.
# A "masks\\" literal only resolves on Windows.
directory = "masks"

In [None]:
# Walk masks/<folder>/<file> in sorted order. os.listdir order is arbitrary and
# platform-dependent, and images and masks are paired purely by index, so the
# image-loading step must visit its files in the same sorted order.
for folder in sorted(os.listdir(directory)):
    inner_directory = os.path.join(directory, folder)
    for file in sorted(os.listdir(inner_directory)):
        # Load one mask (assumes files were written with np.save — confirm).
        img = np.load(os.path.join(inner_directory, file))

        # Augment with four orientations, in the same order as the images.
        # The masks must be square: np.stack later requires every rotated copy to share one shape.
        masks.append(np.rot90(img, 0))                # original orientation
        masks.append(np.rot90(img, 1, axes=(1, 0)))   # 90° clockwise
        masks.append(np.rot90(img, 1, axes=(0, 1)))   # 90° counter-clockwise
        masks.append(np.rot90(img, 2))                # 180°

In [None]:
# Combine the per-mask arrays into one array with a new leading sample axis.
masks = np.stack(masks, axis=0)

In [None]:
# Show the mask array shape and enforce that every image has exactly one mask.
# Images and masks are paired by index in the split and training cells.
print(masks.shape)
assert len(masks) == len(images), (
    f"{len(images)} images but {len(masks)} masks: image/mask pairing is broken"
)

In [None]:
# Cache the stacked, augmented masks on disk for the tensor-conversion cell.
np.save(file="masks.npy", arr=masks)

In [None]:
# Reload the cached arrays from disk as float32 PyTorch tensors.
images = torch.as_tensor(np.load("images.npy"), dtype=torch.float32)
masks = torch.as_tensor(np.load("masks.npy"), dtype=torch.float32)

In [None]:
# Split the dataset into training and validation sets.
# TRAIN_FRACTION = 1.0 uses every sample for training and leaves the validation set empty.
TRAIN_FRACTION = 1.0
# The mask loader appends 4 rotated copies per source file consecutively. When holding
# data out, round the split down to a multiple of 4 so copies of one source never land
# on both sides of the split.
SAMPLES_PER_SOURCE = 4
n_samples = images.shape[0]
split = int(n_samples * TRAIN_FRACTION)
if split < n_samples:
    split -= split % SAMPLES_PER_SOURCE
train_images = images[:split]
train_masks = masks[:split]

In [None]:
# Everything after `split` is held out for validation (empty when split == len(images)).
val_images, val_masks = images[split:], masks[split:]

In [None]:
# Report the training and validation mask shapes, one per line.
print(train_masks.shape, val_masks.shape, sep="\n")

In [None]:
# Function to calculate the validation loss
def val_loss():
    """Return the mean per-image loss of ``model`` over the validation split.

    Reads the notebook globals ``model``, ``val_images``, ``val_masks``, ``device``
    and ``loss_function``. Each validation image is scored on its own, and the
    per-image losses are averaged.

    Evaluation runs without autograd and with the model in eval mode, so layers
    such as dropout or batch-norm (if the model has them) use inference behaviour.
    The model's previous train/eval mode is restored before returning, even on error.

    Returns:
        float: the average loss, or ``float('nan')`` when the validation split is
        empty (e.g. when the whole dataset is used for training).
    """
    n_val = val_images.shape[0]
    if n_val == 0:
        # Nothing to evaluate; avoid ZeroDivisionError.
        return float("nan")

    was_training = model.training
    model.eval()
    total = 0.0
    try:
        with torch.no_grad():  # no autograd graph needed for evaluation
            for i in range(n_val):
                # Channels-last (1, H, W, C) -> channels-first (1, C, H, W) for the CNN.
                logits = model(val_images[i:i + 1].permute(0, 3, 1, 2).to(device))
                target = val_masks[i:i + 1].to(device)
                # squeeze() both sides, matching the training loop.
                total += loss_function(logits.squeeze(), target.squeeze()).item()
    finally:
        model.train(was_training)  # restore the caller's train/eval mode
    return total / n_val

In [None]:
# Per-batch training losses and per-epoch validation losses, appended to during training.
losses, val_losses = [], []

In [None]:
# Training hyperparameters: number of passes over the data, and samples per optimisation step.
n_epochs, batch_size = 100, 16

In [None]:
# Training loop: one freshly shuffled pass over the training set per epoch, in mini-batches.
# NOTE(review): `model` and `optimizer` are not created in any visible cell (e.g. a UNet
# instance moved to `device`, and an optimizer over model.parameters()). They must exist
# in the kernel before this cell runs.
for epoch in range(n_epochs):
    model.train()  # evaluation code may have left the model in eval mode
    permutation = torch.randperm(train_images.shape[0])  # new shuffle every epoch
    for start in range(0, permutation.shape[0], batch_size):
        optimizer.zero_grad()

        # Gather the next mini-batch of image/mask pairs and move it to the device.
        batch = permutation[start:start + batch_size]
        x = train_images[batch].to(device)
        y = train_masks[batch].to(device)

        # Images are stored channels-last; the CNN expects (N, C, H, W).
        logits = model(x.permute(0, 3, 1, 2))
        # squeeze() drops singleton dims on both sides so the shapes line up.
        loss = loss_function(logits.squeeze(), y.squeeze())
        losses.append(loss.item())  # per-batch training loss

        # Backward pass and parameter update.
        loss.backward()
        optimizer.step()

    # Record the validation loss once per epoch when a validation split exists.
    if val_images.shape[0] > 0:
        val_losses.append(val_loss())

In [None]:
# Plot the per-batch training loss recorded by the training loop.
fig, ax = plt.subplots()
ax.plot(losses)
ax.set(title="Training loss", xlabel="batch", ylabel="BCE-with-logits loss")
plt.show()

In [None]:
# Save the trained parameters. Create models/ first; otherwise torch.save fails on a
# fresh checkout where the directory does not exist yet.
os.makedirs("models", exist_ok=True)
torch.save(model.state_dict(), "models/unet.pt")