## U-Net

Library imports

In [None]:
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torch import nn, manual_seed, optim, no_grad, unsqueeze
from torchvision import transforms
from tqdm import tqdm
import matplotlib.pyplot as plt

Dataset class

In [None]:
# Dataset class definition.
class DatasetClass(Dataset):
    """Dataset of (sample, ground truth) image pairs read from two directories.

    Pairs are formed by position after sorting the ``.jpg`` file paths found
    in each directory, so both directories must hold the same number of
    ``.jpg`` files.
    """

    # Parameters: sample and ground truth image directories and transform condition.
    def __init__(self, X_dir, y_dir, transform=None):
        """Collect the sorted image paths of both directories.

        Args:
            X_dir: Directory holding the sample ``.jpg`` images.
            y_dir: Directory holding the ground truth ``.jpg`` images.
            transform: Optional callable applied to both images of a pair.

        Raises:
            ValueError: If the directories hold a different number of
                ``.jpg`` files (pairing by position would be misaligned).
        """
        # Import sample and ground truth image paths, sorted by name.
        X = self._list_images(X_dir)
        y = self._list_images(y_dir)
        # Pairs are matched by index, so unequal counts can only produce
        # misaligned pairs or a late IndexError; fail early instead.
        if len(X) != len(y):
            raise ValueError(
                f"Found {len(X)} sample images in '{X_dir}' but "
                f"{len(y)} ground truth images in '{y_dir}'"
            )
        # X, y and transform instances.
        self.X, self.y, self.transform = X, y, transform

    @staticmethod
    def _list_images(directory):
        """Return the sorted paths of the ``.jpg`` files inside ``directory``."""
        return sorted(
            os.path.join(directory, fname)
            for fname in os.listdir(directory)
            if fname.endswith('.jpg')
        )

    # Built-in len method.
    def __len__(self):
        """Return the number of image pairs."""
        return len(self.X)

    # Built-in method used for list indexing. Get sample and ground truth images.
    def __getitem__(self, idx):
        """Open and return the ``idx``-th (sample, ground truth) pair.

        The images are returned as opened by ``Image.open`` unless a
        transform was supplied, in which case each one is passed through it.
        """
        X_image, y_image = Image.open(self.X[idx]), Image.open(self.y[idx])
        # When the transform parameter is set, apply the transformation.
        if self.transform:
            X_image, y_image = self.transform(X_image), self.transform(y_image)
        # Return images.
        return X_image, y_image
    
# Set the seed for generating random numbers (makes shuffling and weight
# initialization reproducible).
manual_seed(0)

# Set number of epochs, batch size, and learning rate.
n_epochs, batch_size, learning_rate = 3, 16, 0.001

# Set image transformations: convert to a single grayscale channel, then to a tensor.
transform = transforms.Compose(
    [transforms.Grayscale(num_output_channels=1), transforms.ToTensor()]
)

# Create train and test dataloader class instances.
dataloader = {
    "train": DataLoader(
        DatasetClass(
            "./dataset/train/sat",
            "./dataset/train/gt",
            transform=transform,
        ),
        batch_size=batch_size,
        shuffle=True,
    ),
    "test": DataLoader(
        DatasetClass(
            "./dataset/test/sat",
            # Ground truth directory. The sample directory was previously
            # passed here as well, so evaluation compared predictions against
            # the input images themselves.
            # NOTE(review): assumes ./dataset/test/gt exists, mirroring the
            # train layout - confirm.
            "./dataset/test/gt",
            transform=transform,
        ),
        batch_size=batch_size,
        shuffle=True,
    ),
}

#### U-Net class

In [None]:
# Convolution layer method.
def conv(in_channels, out_channels, kernel_size=3, padding=1):
    """Build a convolution -> batch normalization -> ReLU stage.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Size of the convolution kernel.
        padding: Padding applied by the convolution.

    Returns:
        An ``nn.Sequential`` holding the three layers in that order.
    """
    # 2D convolution.
    convolution = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding)
    # 2D batch normalization over the convolution's output channels.
    normalization = nn.BatchNorm2d(out_channels)
    # In-place ReLU activation function.
    activation = nn.ReLU(inplace=True)
    return nn.Sequential(convolution, normalization, activation)

# U-Net class definition.
class UNet(nn.Module):
    """Convolutional encoder-decoder producing a per-pixel map in (0, 1).

    The encoder halves the spatial size four times with max pooling and the
    decoder doubles it four times with bilinear upsampling, so the output
    matches the input's spatial size when height and width are multiples
    of 16. Encoder and decoder are purely sequential (no skip connections).
    """

    # Parameters: input and output channels.
    def __init__(self, in_channels=1, out_channels=1):
        """Build the encoder, the decoder and the output activation.

        Args:
            in_channels: Number of channels of the input image.
            out_channels: Number of channels of the predicted map.
        """
        # Inherit parent class functionality
        super().__init__()
        # Number of channels
        c = [16, 32, 64, 128]

        # Encoder sequence
        self.encoder = nn.Sequential(
            conv(in_channels, c[0]),
            nn.MaxPool2d(kernel_size=2, stride=2),
            conv(c[0], c[1]),
            nn.MaxPool2d(kernel_size=2, stride=2),
            conv(c[1], c[2]),
            nn.MaxPool2d(kernel_size=2, stride=2),
            conv(c[2], c[3]),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        # Decoder sequence
        self.decoder = nn.Sequential(
            conv(c[3], c[2]),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
            conv(c[2], c[1]),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
            conv(c[1], c[0]),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
            # The last stage is a plain convolution without batch
            # normalization or ReLU. A ReLU here would feed only non-negative
            # values into the sigmoid, pinning every output to >= 0.5.
            # NOTE: this renames this stage's checkpoint keys
            # (decoder.6.0.* -> decoder.6.*).
            nn.Conv2d(c[0], out_channels, kernel_size=3, padding=1),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
        )

        # Output activation, created once instead of on every forward pass.
        self.sigmoid = nn.Sigmoid()

    # Apply encoder and decoder sequence and sigmoid function.
    def forward(self, x):
        """Return the sigmoid of the decoded encoding of ``x``."""
        x = self.encoder(x)
        x = self.decoder(x)
        return self.sigmoid(x)

# Instantiate the network with one input channel and one output channel
# (the constructor defaults, spelled out explicitly).
model = UNet(in_channels=1, out_channels=1)

#### Model training

In [None]:
# Mean squared error loss, averaged over all elements (the default reduction).
loss_function = nn.MSELoss(reduction="mean")

# Adam optimizer over every model parameter, using the configured learning rate.
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

# Model training function
# Parameters: model, optimizer instance, index of the current epoch,
# log interval and device type
def train(network, optimizer, epoch, log_interval=3, device='cpu'):
    """Train ``network`` for one pass over ``dataloader["train"]``.

    Uses the module-level ``dataloader`` and ``loss_function``.

    Args:
        network: Model to train; switched to training mode.
        optimizer: Optimizer updating the parameters of ``network``.
        epoch: Zero-based index of the current epoch; used for the progress
            message and to offset the logged example counts.
        log_interval: A log entry is made every ``log_interval`` batches.
        device: Device the batches are moved to.

    Returns:
        A list of ``(examples_so_far, loss)`` tuples, one per logged batch,
        where ``examples_so_far`` counts the examples processed before that
        batch across all epochs.
    """
    # Set the network to training mode
    network.train()
    # List to store loss logs
    loss_logs = list()
    loader = dataloader["train"]
    # Total size of the training dataset
    size = len(loader.dataset)
    # Nominal number of examples per batch (the final batch may be shorter)
    examples_per_batch = loader.batch_size

    # Iterate through each batch in the training dataset
    for batch_idx, (data, target) in enumerate(loader):
        # Move data and target to the specified device (CPU or GPU)
        data, target = data.to(device), target.to(device)
        # Reset optimizer gradients
        optimizer.zero_grad()
        # Forward pass: compute predicted outputs by passing inputs to the network
        pred = network(data)
        # Compute the loss
        loss = loss_function(pred, target)
        # Backward pass: compute gradient of the loss with respect to model parameters
        loss.backward()
        # Perform a single optimization step (parameter update)
        optimizer.step()

        # Log training progress
        if batch_idx % log_interval == 0:
            # Examples processed before this batch. Based on the nominal
            # batch size rather than len(data), which is smaller for a
            # short final batch and would understate the progress.
            examples_shown = batch_idx * examples_per_batch
            # Percentage of the dataset shown in the current epoch
            fraction_shown = round(examples_shown * 100 / size, 2)
            # Round the loss to 4 decimal places
            rounded_loss = round(loss.item(), 4)
            # Calculate the total number of examples seen so far
            examples_so_far = examples_shown + epoch * size
            # Append the current loss to the loss logs
            loss_logs.append((examples_so_far, loss.item()))
            # Print training progress
            print(
                f"Train Epoch {epoch} Progress: {fraction_shown}%\tLoss: {rounded_loss}"
            )
    # Return loss logs
    return loss_logs

#### Model testing

In [None]:
# Model testing function
# Parameters: model, device type
def test(network, device='cpu'):
    """Evaluate ``network`` on ``dataloader["test"]``.

    Uses the module-level ``dataloader`` and ``loss_function``.

    Args:
        network: Model to evaluate; switched to evaluation mode.
        device: Device the batches are moved to.

    Returns:
        A ``(test_loss, accuracy)`` tuple: the loss averaged over the test
        samples, and the percentage of pixels whose thresholded prediction
        matches the thresholded ground truth, rounded to 2 decimals.
    """
    # Set the network to evaluation mode
    network.eval()
    loader = dataloader["test"]
    # Total size of the test dataset
    size = len(loader.dataset)
    # Running sums kept as Python floats so they are valid even before the
    # first batch has been processed
    test_loss, correct = 0.0, 0.0

    # Disable gradient calculation
    with no_grad():
        # Iterate through each batch in the test dataset
        for data, target in tqdm(loader):
            # Move data and target to the specified device (CPU or GPU)
            data, target = data.to(device), target.to(device)
            # Forward pass: compute predicted outputs by passing inputs to the network
            output = network(data)
            # Weight the batch loss by the batch length so that dividing by
            # the dataset size below yields a per-sample average.
            # NOTE(review): assumes loss_function returns the batch mean, as
            # nn.MSELoss does by default - confirm if the loss is swapped.
            test_loss += loss_function(output, target).item() * len(data)
            # Compute predictions based on output probability (using a threshold of 0.5)
            pred = output >= 0.5
            # Binarize the ground truth the same way. Comparing boolean
            # predictions with raw float targets would count every
            # non-binary target pixel as wrong.
            truth = target.view_as(pred) >= 0.5
            # Fraction of correct pixels, summed over the samples of the
            # batch (pixel count taken from the data, not hard-coded)
            correct += pred.eq(truth).sum().item() / pred[0].numel()

    # Calculate average test loss
    test_loss /= size
    # Calculate accuracy
    accuracy = round(100.0 * (correct / size), 2)
    # Print test results
    print(f"\nTest set: Avg loss {round(test_loss, 4)}, Accuracy {accuracy}%\n")
    # Return test loss and accuracy
    return test_loss, accuracy

In [None]:
# Lists to store test accuracies, test losses, and training losses
test_accuracies, test_losses, train_losses = list(), list(), list()
# Variable to keep track of the total number of examples seen during training
total_examples_seen = 0
# Number of training examples processed per epoch
train_size = len(dataloader["train"].dataset)

# Loop through each epoch
for epoch in range(n_epochs):
    # Evaluate the model before this epoch's training and record the results
    test_loss, test_accuracy = test(model)
    test_accuracies.append((epoch, test_accuracy))
    test_losses.append((total_examples_seen, test_loss))
    # Perform training on the model for the current epoch and get the training loss logs
    train_loss_logs = train(model, optimizer, epoch)
    # Extend the training loss logs to the training loss list
    train_losses.extend(train_loss_logs)
    # A full epoch has now been processed. Derive the count from the dataset
    # size rather than from the last log entry, which only covers the
    # examples before the last *logged* batch and does not exist at all
    # when nothing was logged.
    total_examples_seen = (epoch + 1) * train_size

# Perform testing on the model after all epochs and retrieve test loss and accuracy
test_loss, test_accuracy = test(model)
# Append the final epoch and test accuracy to the test accuracy list
test_accuracies.append((n_epochs, test_accuracy))
# Append the total number of examples seen and test loss to the test loss list
test_losses.append((total_examples_seen, test_loss))

#### Prediction preview

In [None]:
# Function to preview input and output images side by side
# Parameters: input tensor, output tensor, colormap for visualization (default is grayscale)
def preview_images(input, output, cmap="gray"):
    """Show ``input`` and ``output`` next to each other in one figure.

    Args:
        input: Tensor drawn in the left panel.
        output: Tensor drawn in the right panel.
        cmap: Matplotlib colormap used for both panels.
    """
    # Convert both tensors to PIL images before any figure is created
    tensor_to_image = transforms.ToPILImage()
    images = [tensor_to_image(tensor) for tensor in (input, output)]

    # One row with two panels: input on the left, output on the right
    _, panels = plt.subplots(1, 2)
    for panel, image in zip(panels, images):
        panel.imshow(image, cmap=cmap)
        # Hide the axis ticks and frame
        panel.axis("off")

    # Show the plot
    plt.show()

# Function to preview model predictions on a single input tensor
# Parameters: model (network), input_tensor
def preview_prediction(model, input_tensor):
    """Run ``model`` on one input tensor and display input and prediction.

    Args:
        model: Network used for the prediction; evaluated in evaluation
            mode and returned to its previous mode afterwards.
        input_tensor: Single (unbatched) input tensor.
    """
    # Remember the current mode so that previewing does not leave the model
    # in a different state than the caller had it in
    was_training = model.training
    model.eval()
    try:
        # No gradients are needed for a preview
        with no_grad():
            # Add a leading batch dimension for the model, then take the
            # single prediction back out of the batch. The tensor method is
            # used because `torch` itself is not imported as a name here.
            output_tensor = model(input_tensor.unsqueeze(0))[0]
    finally:
        model.train(was_training)
    # Print shapes of input and output tensors
    print("Input:", input_tensor.shape, "Output:", output_tensor.shape)
    # Preview input and output images
    preview_images(input_tensor, output_tensor)
