## U-Net

Module imports

In [None]:
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torch import nn, manual_seed, optim, no_grad, unsqueeze
from torchvision import transforms
from tqdm import tqdm
import matplotlib.pyplot as plt

Dataset class and data loaders

In [None]:
# Dataset class definition.
class DatasetClass(Dataset):
    """Paired image dataset of samples and their ground truth images.

    Both directories are scanned for ``.jpg`` files. The two path lists are
    sorted and then paired by position, so item ``i`` is the ``i``-th sample
    together with the ``i``-th ground truth image.
    """

    def __init__(self, X_dir, y_dir, transform=None):
        """Collect and pair the image paths.

        Args:
            X_dir: Directory containing the sample images.
            y_dir: Directory containing the ground truth images.
            transform: Optional callable applied to both images of a pair.

        Raises:
            ValueError: If the directories hold a different number of
                ``.jpg`` files; positional pairing would then be wrong.
        """
        X = self._list_images(X_dir)
        y = self._list_images(y_dir)
        # Pairs are matched purely by sorted position, so unequal counts mean
        # at least one sample would get the wrong (or no) ground truth.
        if len(X) != len(y):
            raise ValueError(
                f"Found {len(X)} sample images in {X_dir!r} "
                f"but {len(y)} ground truth images in {y_dir!r}"
            )
        # X, y and transform instances.
        self.X, self.y, self.transform = X, y, transform

    @staticmethod
    def _list_images(directory):
        """Return the sorted paths of the ``.jpg`` files in ``directory``."""
        return sorted(
            os.path.join(directory, fname)
            for fname in os.listdir(directory)
            if fname.endswith('.jpg')
        )

    # Built-in len method: number of (sample, ground truth) pairs.
    def __len__(self):
        return len(self.X)

    # Built-in method used for indexing. Get sample and ground truth images.
    def __getitem__(self, idx):
        """Open pair ``idx`` and return it, transformed when a transform is set."""
        X_image, y_image = Image.open(self.X[idx]), Image.open(self.y[idx])
        # When a transform was given, apply it to both images of the pair.
        if self.transform is not None:
            X_image, y_image = self.transform(X_image), self.transform(y_image)
        # Return images.
        return X_image, y_image
    
# Set the seed for generating random numbers so runs are repeatable.
manual_seed(0)

# Set number of epochs, batch size, and learning rate.
n_epochs, batch_size, learning_rate = 3, 16, 0.001

# Set image transformations: single-channel grayscale, then conversion to a tensor.
transform = transforms.Compose(
    [transforms.Grayscale(num_output_channels=1), transforms.ToTensor()]
)

# Create train and test dataloader class instances.
# Each split pairs the images in its "sat" directory with those in its "gt" directory.
dataloader = {
    "train": DataLoader(
        DatasetClass(
            "./dataset/train/sat",
            "./dataset/train/gt",
            transform=transform
        ),
        batch_size=batch_size,
        shuffle=True,
    ),
    "test": DataLoader(
        DatasetClass(
            "./dataset/test/sat",
            # Ground truth must come from the gt directory, not from the inputs.
            # NOTE(review): assumes the test split mirrors the train layout
            # (sat/ + gt/) -- confirm the directory exists.
            "./dataset/test/gt",
            transform=transform,
        ),
        batch_size=batch_size,
        shuffle=True,
    ),
}

#### U-Net class
<img src="unet_diagram.png" alt="image" width="400px" height="auto">

In [None]:
# Convolution block builder.
def conv(in_channels, out_channels, kernel_size=3, padding=1):
    """Build a ``Conv2d -> BatchNorm2d -> ReLU`` block.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Size of the convolution kernel.
        padding: Padding passed to the convolution.

    Returns:
        An ``nn.Sequential`` holding the three layers in that order.
    """
    # 2D convolution.
    convolution = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding)
    # 2D batch normalization over the convolution's output channels.
    normalization = nn.BatchNorm2d(out_channels)
    # In-place ReLU activation.
    activation = nn.ReLU(inplace=True)
    return nn.Sequential(convolution, normalization, activation)

# U-Net class definition.
class UNet(nn.Module):
    """Convolutional encoder-decoder that outputs a per-pixel map in (0, 1).

    The encoder applies four ``conv`` blocks, each followed by 2x2 max
    pooling; the decoder applies four convolutions, each followed by 2x
    bilinear upsampling. Encoder and decoder are plain sequences: no skip
    connections link them.

    The spatial size is halved four times and then doubled four times, so the
    output has the input's height and width only when both are multiples of 16.
    """

    # Parameters: input and output channels.
    def __init__(self, in_channels=1, out_channels=1):
        """Build the encoder and decoder.

        Args:
            in_channels: Number of channels of the input images.
            out_channels: Number of channels of the predicted map.
        """
        # Inherit parent class functionality
        super().__init__()
        # Number of channels
        c = [16, 32, 64, 128]

        # Encoder sequence
        self.encoder = nn.Sequential(
            conv(in_channels, c[0]),
            nn.MaxPool2d(kernel_size=2, stride=2),
            conv(c[0], c[1]),
            nn.MaxPool2d(kernel_size=2, stride=2),
            conv(c[1], c[2]),
            nn.MaxPool2d(kernel_size=2, stride=2),
            conv(c[2], c[3]),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        # Decoder sequence
        self.decoder = nn.Sequential(
            conv(c[3], c[2]),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
            conv(c[2], c[1]),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
            conv(c[1], c[0]),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
            # Output projection is a bare convolution. A ReLU here would make
            # every logit non-negative, so the sigmoid could never output a
            # value below 0.5.
            nn.Conv2d(c[0], out_channels, kernel_size=3, padding=1),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
        )

        # Output activation, created once instead of on every forward pass.
        self.activation = nn.Sigmoid()

    # Apply encoder and decoder sequence and sigmoid function.
    def forward(self, x):
        """Return the sigmoid of the decoded features for the batch ``x``."""
        x = self.encoder(x)
        x = self.decoder(x)
        return self.activation(x)

# Create UNet instance
model = UNet()

In [None]:
# Mean squared error between the prediction and the ground truth.
loss_function = nn.MSELoss()

# Adam optimizer over all parameters of the model.
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

def train(network, optimizer, epoch, log_interval=3, device='cpu'):
    """Run one training pass over ``dataloader["train"]``.

    Args:
        network: Model to optimize. It is not moved to ``device`` here, so it
            must already live on that device.
        optimizer: Optimizer stepping the network's parameters.
        epoch: Index of the current epoch; used in the log line and to offset
            the example counter by ``epoch * dataset size``.
        log_interval: A log entry is produced every ``log_interval`` batches.
        device: Device each batch is moved to before the forward pass.

    Returns:
        A list of ``(examples_seen_so_far, loss)`` tuples, one per logged batch.
    """
    network.train()
    loader = dataloader["train"]
    size = len(loader.dataset)
    loss_logs = []

    for batch_idx, (data, target) in enumerate(loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        pred = network(data)
        loss = loss_function(pred, target)
        loss.backward()
        optimizer.step()

        if batch_idx % log_interval == 0:
            # Every batch before this one was full, so count them with the
            # loader's batch size. len(data) shrinks on a final partial batch
            # and would make the progress figure jump backwards.
            examples_shown = batch_idx * loader.batch_size
            # Progress is measured against the real dataset size.
            fraction_shown = round(examples_shown * 100 / size, 2)
            loss_value = loss.item()
            rounded_loss = round(loss_value, 4)
            examples_so_far = examples_shown + epoch * size
            loss_logs.append((examples_so_far, loss_value))
            print(
                f"Train Epoch {epoch} Progress: {fraction_shown}%\tLoss: {rounded_loss}"
            )
    return loss_logs

In [None]:
def test(network, device='cpu'):
    """Evaluate ``network`` on ``dataloader["test"]``.

    Args:
        network: Model to evaluate; it is switched to evaluation mode and is
            expected to already live on ``device``.
        device: Device each batch is moved to before the forward pass.

    Returns:
        A ``(test_loss, accuracy)`` tuple: the per-example average loss and
        the percentage of pixels whose thresholded prediction matches the
        thresholded ground truth, rounded to two decimals.

    Raises:
        ValueError: If the test dataset is empty.
    """
    network.eval()
    size = len(dataloader['test'].dataset)
    if size == 0:
        raise ValueError("test dataset is empty")

    test_loss, correct, total = 0.0, 0, 0
    with no_grad():
        for data, target in tqdm(dataloader["test"]):
            data, target = data.to(device), target.to(device)
            output = network(data)
            # Weight each batch loss by its size so the division by the
            # dataset size below yields a per-example average.
            # NOTE(review): assumes loss_function returns a batch mean
            # (the default for nn.MSELoss).
            test_loss += loss_function(output, target).item() * len(data)
            # Binarize both prediction and ground truth; comparing a boolean
            # prediction with raw float targets only matches exact 0.0 / 1.0.
            pred = output >= 0.5
            truth = target.view_as(pred) >= 0.5
            correct += pred.eq(truth).sum().item()
            # Count pixels actually compared instead of assuming 256x256 images.
            total += pred.numel()

    test_loss /= size
    accuracy = round(100.0 * correct / total, 2)
    print(f"\nTest set: Avg loss {round(test_loss, 4)}, Accuracy {accuracy}%\n")
    return test_loss, accuracy

In [None]:
# Histories: (epoch, accuracy), (examples seen, test loss), (examples seen, train loss).
test_accuracies, test_losses, train_losses = [], [], []
total_examples_seen = 0

for epoch in range(n_epochs):
    # Evaluate before training, so each entry describes the model as it
    # entered the epoch.
    test_loss, test_accuracy = test(model)
    train_loss_logs = train(model, optimizer, epoch)

    test_accuracies.append((epoch, test_accuracy))
    test_losses.append((total_examples_seen, test_loss))
    train_losses += train_loss_logs

    # Example count of the most recent training log entry.
    total_examples_seen = train_loss_logs[-1][0]

# Final evaluation after the last epoch.
test_loss, test_accuracy = test(model)
test_accuracies.append((n_epochs, test_accuracy))
test_losses.append((total_examples_seen, test_loss))

In [None]:
def preview_images(input, output, cmap="gray"):
    """Display two image tensors side by side without axes.

    Args:
        input: Tensor shown in the left panel.
        output: Tensor shown in the right panel.
        cmap: Matplotlib colormap used for both panels.
    """
    to_pil = transforms.ToPILImage()
    # Convert both tensors before any figure is created.
    images = [to_pil(tensor) for tensor in (input, output)]
    _, axes = plt.subplots(1, 2)
    for axis, image in zip(axes, images):
        axis.imshow(image, cmap=cmap)
        axis.axis("off")
    plt.show()
        
def preview_prediction(model, input_tensor):
    """Run ``model`` on a single image tensor and show input next to output.

    The prediction is made in evaluation mode with gradient tracking
    disabled, so previewing neither updates batch-normalization statistics
    nor builds an autograd graph. The model's previous training/evaluation
    mode is restored afterwards.

    Args:
        model: Network to run.
        input_tensor: A single unbatched image tensor; a batch dimension is
            added before the forward pass and removed from the result.
    """
    was_training = model.training
    model.eval()
    try:
        with no_grad():
            output_tensor = model(unsqueeze(input_tensor, dim=0))[0]
    finally:
        # Leave the model in whatever mode the caller had it in.
        model.train(was_training)
    print("Input:", input_tensor.shape, "Output:", output_tensor.shape)
    preview_images(input_tensor, output_tensor)