In [None]:
from pathlib import Path
import torch
import matplotlib.pyplot as plt
import numpy as np
import torchmetrics

from unet.unet_model import UNet
from unet.unet_dataset.image_and_masks.image_masks import ImageMasksDataset
from unet.train_unet_utils import collate_fn, train
from unet.loss_functions.focal_loss import FocalLoss

In [None]:
# Experiment configuration: every value a reader might want to tune.
configs = dict(
    # Hardware and network shape.
    device="mps",
    image_size=[256, 256],
    in_channels=3,
    out_channels=1,
    conv_channels=[4, 8, 16],
    # Dataset location (relative to the working directory) and split folders.
    data_dir="../data/forest_segmentation/",
    train_image_folder="train_images",
    train_mask_folder="train_masks",
    test_image_folder="test_images",
    test_mask_folder="test_masks",
    # Training schedule and logging.
    batch_size=16,
    epochs=10,
    verbose=True,
)

In [None]:
# Device: honour the configured device, but fall back to CPU when the
# requested accelerator is not available on this machine.
# The config value may be a string (e.g. "mps", "cuda:0") or a torch.device,
# so normalise it first; comparing a raw string against torch.device(...)
# never matches and would silently skip the fallback.
device = torch.device(
    configs.get(
        "device",
        "cuda" if torch.cuda.is_available() else "cpu",
    )
)
if device.type == "cuda" and not torch.cuda.is_available():
    device = torch.device("cpu")
elif device.type == "mps" and not torch.backends.mps.is_available():
    # "mps" (Apple Metal) is the configured default but only exists on macOS builds.
    device = torch.device("cpu")

verbose = configs.get("verbose", False)
if verbose is True:
    print("Using device: {}.".format(device))

In [None]:
# U-Net: build the network from the config, place it on the selected device
# and put it in training mode (the prediction cell later switches to eval()).
in_channels, out_channels, conv_channels = (
    configs.get("in_channels", 3),
    configs.get("out_channels", 1),
    configs.get("conv_channels", [8, 16, 32]),
)

model = UNet(in_channels, out_channels, conv_channels, up_conv_by_resampling=False)
model.to(device).train()  # nn.Module.to() returns the module itself

if verbose is True:
    print(
        "U-Net in_channels={}, out_channels={}, conv_channels={}.".format(
            in_channels, out_channels, conv_channels
        )
    )

In [None]:
# Datasets: paired image/mask folders for the train and test splits.
data_dir = configs.get("data_dir", None)
train_image_folder = configs.get("train_image_folder", "train_images")
train_mask_folder = configs.get("train_mask_folder", "train_masks")
test_image_folder = configs.get("test_image_folder", "test_images")
test_mask_folder = configs.get("test_mask_folder", "test_masks")

image_size = configs.get("image_size", None)

# Fail early with a clear message instead of Path(None)'s opaque TypeError.
if data_dir is None:
    raise ValueError("configs['data_dir'] must be set to the dataset root directory.")
data_dir = Path(data_dir)


def _make_split(image_folder, mask_folder, train):
    """Build an ImageMasksDataset for one split using the shared dataset settings."""
    return ImageMasksDataset(
        data_dir=data_dir,
        image_folder=image_folder,
        mask_folder=mask_folder,
        train=train,
        image_size=image_size,
    )


# Train dataset.
train_ds = _make_split(train_image_folder, train_mask_folder, train=True)
# Test dataset.
test_ds = _make_split(test_image_folder, test_mask_folder, train=False)

if verbose is True:
    print("len(train_ds): {}, len(test_ds): {}.".format(
        len(train_ds), len(test_ds)))

In [None]:
# Sanity check: display one training image beside its ground-truth mask.
idx = 0
img_mask = train_ds[idx]
# Channels-first tensors -> channels-last arrays for imshow (assumes CHW layout).
image_hwc = img_mask[0].permute(1, 2, 0).numpy()
mask_hwc = img_mask[1].permute(1, 2, 0).numpy()

for position, (panel, cmap) in enumerate([(image_hwc, None), (mask_hwc, "gray")], start=1):
    plt.subplot(1, 2, position)
    plt.imshow(panel, cmap=cmap)
plt.show()

In [None]:
# DataLoaders: both splits share batch size and collate function; only the
# training split is shuffled (no seeded generator, so the order differs per run).
batch_size = configs.get("batch_size", 4)
n_epochs = configs.get("epochs", 25)

loader_kwargs = {"batch_size": batch_size, "collate_fn": collate_fn}
train_dl = torch.utils.data.DataLoader(train_ds, shuffle=True, **loader_kwargs)
test_dl = torch.utils.data.DataLoader(test_ds, shuffle=False, **loader_kwargs)

if verbose is True:
    print("Batch size: {}, train epochs: {}.".format(batch_size, n_epochs))
    print("len(train_dl): {}, len(test_dl): {}.".format(
        len(train_dl), len(test_dl)))

In [None]:
# Loss, optimizer and evaluation metrics.
# Alternative loss kept from earlier experiments:
#   loss_function = FocalLoss(alpha=0.25, gamma=2)
loss_function = torch.nn.BCELoss()

# Optimise only trainable parameters.
params = list(filter(lambda p: p.requires_grad, model.parameters()))
adam_hparams = {"lr": 0.001, "betas": (0.9, 0.999), "eps": 1e-08, "weight_decay": 0}
optimizer = torch.optim.Adam(params, **adam_hparams)

# NOTE(review): despite the variable name, the metric is the Jaccard index (IoU).
accuracy_metrics = [torchmetrics.JaccardIndex(task="binary")]

if verbose is True:
    for component in (loss_function, optimizer, accuracy_metrics):
        print(component)

In [None]:
# Train the model. train() hands back the model together with the loss and
# metric histories for both splits (presumably one entry per epoch).
train_kwargs = {
    "model": model,
    "optimizer": optimizer,
    "loss_function": loss_function,
    "n_epochs": n_epochs,
    "train_loader": train_dl,
    "test_loader": test_dl,
    "device": device,
    "verbose": verbose,
    "metrics": accuracy_metrics,
}
(
    model,
    train_losses,
    test_losses,
    train_accuracies,
    test_accuracies,
) = train(**train_kwargs)

In [None]:
# Learning curves: loss and Jaccard index (IoU) per epoch for both splits.
# Only a missing training history (the training cell failed, so the result
# names are undefined) is tolerated; genuine plotting errors are not hidden.
try:
    epochs = range(1, 1 + len(train_losses))
    # np.asarray accepts arrays and nested lists alike; column 0 is assumed to be
    # the first (here: only) entry of accuracy_metrics, the Jaccard index.
    train_iou = np.asarray(train_accuracies)[:, 0]
    test_iou = np.asarray(test_accuracies)[:, 0]
except NameError as err:
    print("No training history to plot - did the training cell fail? ({})".format(err))
else:
    fig, (ax_loss, ax_iou) = plt.subplots(1, 2, figsize=[15, 5])

    ax_loss.plot(epochs, train_losses, label="Train loss")
    ax_loss.plot(epochs, test_losses, label="Test loss")
    ax_loss.set(xlabel="Epoch", ylabel="Loss", xticks=list(epochs))
    ax_loss.grid(True)
    ax_loss.legend()

    ax_iou.plot(epochs, train_iou, label="Train Jaccard index")
    ax_iou.plot(epochs, test_iou, label="Test Jaccard index")
    ax_iou.set(xlabel="Epoch", ylabel="Jaccard index (IoU)", xticks=list(epochs))
    ax_iou.grid(True)
    ax_iou.legend()

    plt.show()

In [None]:
# Make some example predictions on the test set:
# one row per sample showing image | ground-truth mask | rounded prediction.
N = 5
n_show = min(N, len(test_ds))  # never index past the end of a small test set

model.eval()
model.cpu()

plt.figure(figsize=[5, 10])
with torch.no_grad():  # inference only: don't build an autograd graph
    for i in range(n_show):
        img_mask = test_ds[i]
        img = img_mask[0]
        mask = img_mask[1]
        # Rounding thresholds at 0.5; presumably the model outputs probabilities
        # in [0, 1] since it is trained with BCELoss - confirm in UNet.
        pred = model(img.unsqueeze(0)).cpu().squeeze().round()

        plt.subplot(n_show, 3, i * 3 + 1)
        plt.imshow(img.permute(1, 2, 0).numpy())
        plt.axis(False)
        plt.subplot(n_show, 3, i * 3 + 2)
        plt.imshow(mask.permute(1, 2, 0).numpy(), cmap="gray")
        plt.axis(False)
        plt.subplot(n_show, 3, i * 3 + 3)
        plt.imshow(pred.numpy(), cmap="gray")
        plt.axis(False)
        if i == 0:
            plt.subplot(n_show, 3, 1); plt.title("Image")
            plt.subplot(n_show, 3, 2); plt.title("Mask")
            plt.subplot(n_show, 3, 3); plt.title("Prediction")
plt.show()

In [None]:
# TODO: save the trained model weights to disk
# (e.g. torch.save(model.state_dict(), <path>)).

In [None]:
# NOTE(review): dead scratch code. Everything below is a string literal and is
# never executed; consider deleting this cell. If revived: the prediction cell
# calls model.cpu(), so batch[0].to(device) would put the input on a different
# device than the model unless `device` is CPU. Wrap inference in torch.no_grad().
""""
for batch in test_dl:
    break
    
model.eval()
y_pred = model(batch[0].to(device))

ys = batch[1]

for i in range(4):
    plt.figure()
    plt.subplot(2, 2, 1)
    plt.imshow(y_pred[i].detach().cpu().numpy().squeeze())
    plt.subplot(2, 2, 2)
    plt.imshow(ys[i].detach().cpu().numpy().squeeze())
    plt.show()
""";