In [None]:
from pathlib import Path
import torch
import matplotlib.pyplot as plt

from unet.unet_model import UNet
from unet.unet_dataset.image_and_masks.image_masks import ImageMasksDataset
from unet.train_unet_utils import collate_fn, train

In [None]:
# Run configuration: every value the notebook reads via `configs.get(...)`.
configs = dict(
    # Hardware backend requested for training.
    device="mps",
    # U-Net architecture.
    in_channels=3,
    out_channels=1,
    conv_channels=[8, 16, 32],
    # Dataset root and the image/mask sub-folders for each split.
    data_dir="../data/water_bodies_segmentation/",
    train_image_folder="train_images",
    train_mask_folder="train_masks",
    test_image_folder="test_images",
    test_mask_folder="test_masks",
    # Training loop.
    batch_size=4,
    epochs=25,
    # Progress / summary printing.
    verbose=True,
)

In [None]:
# Device: use the configured backend, but fall back to CPU when that backend
# is not available on this machine (e.g. "mps" off Apple silicon, "cuda"
# without a GPU). The config value may be a string or a torch.device, so it
# is normalised to torch.device before inspecting its `.type`.
requested_device = torch.device(
    configs.get("device", "cuda" if torch.cuda.is_available() else "cpu")
)
mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
backend_unavailable = (
    (requested_device.type == "cuda" and not torch.cuda.is_available())
    or (requested_device.type == "mps" and not mps_available)
)
device = torch.device("cpu") if backend_unavailable else requested_device

# Normalise to a real bool so the `verbose is True` checks in later cells
# behave for truthy config values such as 1.
verbose = bool(configs.get("verbose", False))
if verbose is True:
    if backend_unavailable:
        print("Requested device {} is unavailable; falling back to CPU.".format(requested_device))
    print("Using device: {}.".format(device))

In [None]:
# U-Net: read the architecture from `configs` (falling back to the classic
# 64..1024 channel ladder), then move the network to `device` and put it in
# training mode.
in_channels, out_channels, conv_channels = (
    configs.get("in_channels", 3),
    configs.get("out_channels", 1),
    configs.get("conv_channels", [64, 128, 256, 512, 1024]),
)

model = UNet(in_channels, out_channels, conv_channels)
model.to(device)
model.train()

if verbose is True:
    model_summary = "U-Net in_channels={}, out_channels={}, conv_channels={}.".format(
        in_channels, out_channels, conv_channels
    )
    print(model_summary)

In [None]:
# Datasets: resolve the data root from `configs` and build the train/test
# image-mask datasets from their respective sub-folders.
data_dir = configs.get("data_dir", None)
train_image_folder = configs.get("train_image_folder", "train_images")
train_mask_folder = configs.get("train_mask_folder", "train_masks")
test_image_folder = configs.get("test_image_folder", "test_images")
test_mask_folder = configs.get("test_mask_folder", "test_masks")

# Fail early with an actionable message instead of `Path(None)`'s TypeError
# or an error deep inside the dataset when the path is wrong.
if data_dir is None:
    raise ValueError("configs['data_dir'] must be set to the dataset root directory.")
data_dir = Path(data_dir)
if not data_dir.is_dir():
    raise FileNotFoundError("Data directory not found: {}".format(data_dir.resolve()))

# Train dataset.
train_ds = ImageMasksDataset(
    data_dir=data_dir, image_folder=train_image_folder, mask_folder=train_mask_folder, train=True
)
# Test dataset.
test_ds = ImageMasksDataset(
    data_dir=data_dir, image_folder=test_image_folder, mask_folder=test_mask_folder, train=False
)

if verbose is True:
    print("len(train_ds): {}, len(test_ds): {}.".format(len(train_ds), len(test_ds)))

In [None]:
# Create the DataLoaders from the Datasets.
batch_size = configs.get("batch_size", 4)
n_epochs = configs.get("epochs", 25)

# Reshuffle training samples every epoch so the optimizer does not see the
# data in a fixed order; the evaluation order is kept deterministic.
train_dl = torch.utils.data.DataLoader(
    train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate_fn,
)

test_dl = torch.utils.data.DataLoader(
    test_ds, batch_size=batch_size, shuffle=False, collate_fn=collate_fn,
)

if verbose is True:
    print("len(train_dl): {}, len(test_dl): {}.".format(len(train_dl), len(test_dl)))

In [None]:
# Set up optimizer and loss function.
# NOTE(review): BCELoss expects the model output to already be probabilities in
# [0, 1] (i.e. a sigmoid applied inside UNet) — confirm. If UNet returns raw
# logits, BCEWithLogitsLoss is the numerically stable choice.
loss_function = torch.nn.BCELoss()

# Adam hyperparameters can be overridden from `configs`; the defaults are the
# values that were previously hard-coded here.
learning_rate = configs.get("learning_rate", 0.001)
weight_decay = configs.get("weight_decay", 0)

# Only optimize parameters that are actually trainable.
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(
    params, lr=learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=weight_decay
)

if verbose is True:
    print(loss_function)
    print(optimizer)

In [None]:
# Train the model. `train` hands back the (trained) model together with the
# recorded train and test losses, which the next cell plots against epoch.
model, train_losses, test_losses = train(
    model,
    optimizer,
    loss_function,
    n_epochs,
    train_dl,
    test_dl,
    device,
    verbose,
)

In [None]:
# Learning curves: train vs. test loss, one point per recorded epoch.
epochs = range(1, len(train_losses) + 1)

fig, ax = plt.subplots()
ax.plot(epochs, train_losses, label="Train loss")
ax.plot(epochs, test_losses, label="Test loss")
ax.legend()
ax.grid(True)
ax.set_xlabel("Epoch")
ax.set_xticks(epochs)
ax.set_ylabel("Loss")
plt.show()

In [None]:
# TODO: save the trained model's weights to disk,
# e.g. torch.save(model.state_dict(), <path>).

In [None]:
""""
for batch in test_dl:
    break
    
model.eval()
y_pred = model(batch[0].to(device))

ys = batch[1]

for i in range(4):
    plt.figure()
    plt.subplot(2, 2, 1)
    plt.imshow(y_pred[i].detach().cpu().numpy().squeeze())
    plt.subplot(2, 2, 2)
    plt.imshow(ys[i].detach().cpu().numpy().squeeze())
    plt.show()
""";