In [1]:
from pathlib import Path
import torch


from unet.unet_model import UNet
from unet.unet_dataset.image_and_masks.image_masks import ImageMasksDataset
from unet.train_unet_utils import collate_fn, train

from unet.train_unet_utils import unbatch, train_batch, validate_batch

In [2]:
# Everything a reader might tune lives here; later cells read these keys with
# `configs.get(...)` and fall back to their own defaults when a key is absent.
configs = {
    "device": "mps",
    "in_channels": 3,
    "out_channels": 1,
    "conv_channels": [8, 16, 32],
    "data_dir": "../data/water_bodies_segmentation/",
    "train_image_folder": "train_images",
    "train_mask_folder": "train_masks",
    "test_image_folder": "test_images",
    "test_mask_folder": "test_masks",
    "batch_size": 4,
    "epochs": 20,
    "seed": 42,  # seeds torch's global RNG (weight init, DataLoader shuffling)
    "learning_rate": 1e-3,  # Adam learning rate
    "model_path": "unet_water_bodies.pt",  # where the trained state_dict is saved
}

# Seed torch's global RNG as early as possible so weight initialisation and
# training-set shuffling are reproducible under Restart & Run All.
# NOTE(review): any numpy / `random` randomness inside the dataset (e.g.
# augmentation) is not covered by this seed — confirm in ImageMasksDataset.
torch.manual_seed(configs["seed"])

In [3]:
# Device: honour the configured device, but fall back to CPU when the
# requested accelerator backend is not available on this machine.
requested_device = configs.get("device")
if requested_device is None:
    requested_device = "cuda" if torch.cuda.is_available() else "cpu"
# Normalise to a torch.device so a config string like "mps" or "cuda:0" and a
# torch.device object are handled the same way below.
device = torch.device(requested_device)

# Compare on `.type` so an indexed device such as "cuda:0" is checked too
# (torch.device("cuda:0") != torch.device("cuda")).
if device.type == "cuda":
    backend_available = torch.cuda.is_available()
elif device.type == "mps":
    backend_available = torch.backends.mps.is_available()
else:
    backend_available = True

if not backend_available:
    print("Requested device {} is not available; falling back to CPU.".format(device))
    device = torch.device("cpu")
print("Using device: {}.".format(device))

Using device: mps.


In [4]:
# U-Net. Architecture hyper-parameters are taken from `configs`; the values
# below are only fallbacks for missing keys.
in_channels, out_channels, conv_channels = (
    configs.get("in_channels", 3),
    configs.get("out_channels", 1),
    configs.get("conv_channels", [64, 128, 256, 512, 1024]),
)

# nn.Module.to() moves the parameters in place and returns the module itself,
# so construction and device placement can be chained.
model = UNet(in_channels, out_channels, conv_channels).to(device)
model.train()  # put the network in training mode ahead of the training loop

print("U-Net in_channels={}, out_channels={}, conv_channels={}.".format(
    in_channels, out_channels, conv_channels)
)

U-Net in_channels=3, out_channels=1, conv_channels=[8, 16, 32].


In [5]:
# Datasets. Folder names are relative to `data_dir`; all can be overridden in `configs`.
data_dir = configs.get("data_dir", None)
train_image_folder = configs.get("train_image_folder", "train_images")
train_mask_folder = configs.get("train_mask_folder", "train_masks")
test_image_folder = configs.get("test_image_folder", "test_images")
test_mask_folder = configs.get("test_mask_folder", "test_masks")

# `data_dir` has no usable default: fail with a clear message instead of the
# opaque TypeError that `Path(None)` would raise.
if data_dir is None:
    raise ValueError('configs["data_dir"] must be set to the dataset root directory.')
data_dir = Path(data_dir)

# Train dataset.
# NOTE(review): `train=True` presumably enables training-only behaviour such
# as augmentation — confirm in ImageMasksDataset.
train_ds = ImageMasksDataset(
    data_dir=data_dir, image_folder=train_image_folder, mask_folder=train_mask_folder, train=True
)
# Test dataset.
test_ds = ImageMasksDataset(
    data_dir=data_dir, image_folder=test_image_folder, mask_folder=test_mask_folder, train=False
)

print(len(train_ds), len(test_ds))

300 68


In [6]:
# Create the DataLoaders from the Datasets.
batch_size = configs.get("batch_size", 4)
n_epochs = configs.get("epochs", 25)

# Shuffle the training set so each epoch sees the samples in a different
# order (a fixed order biases SGD-style updates); the test set keeps a fixed
# order so evaluation is repeatable across epochs.
train_dl = torch.utils.data.DataLoader(
    train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate_fn,
)

test_dl = torch.utils.data.DataLoader(
    test_ds, batch_size=batch_size, shuffle=False, collate_fn=collate_fn,
)

print(len(train_dl), len(test_dl))

75 17


In [None]:
# Set up optimizer and loss function.
# NOTE(review): BCELoss expects probabilities in [0, 1], so this assumes UNet
# ends with a sigmoid — confirm in unet_model. If it returns raw logits, use
# torch.nn.BCEWithLogitsLoss instead (numerically more stable).
loss_function = torch.nn.BCELoss()

# Learning rate is configurable; the remaining Adam arguments are spelled out
# explicitly at their PyTorch default values.
learning_rate = configs.get("learning_rate", 1e-3)
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(
    params, lr=learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=0
)

print(loss_function)
print(optimizer)

BCELoss()
SGD (
Parameter Group 0
    dampening: 0
    differentiable: False
    foreach: None
    fused: None
    lr: 0.005
    maximize: False
    momentum: 0.9
    nesterov: False
    weight_decay: 0.0005
)


In [None]:
# Train the model. `train` hands back the model together with the recorded
# train and test losses (presumably one value per epoch, matching the
# "Epoch N. Train loss ... Test loss ..." log lines — confirm in train_unet_utils).
model, train_losses, test_losses = train(
    model,
    optimizer,
    loss_function,
    n_epochs,
    train_dl,
    test_dl,
    device,
    True,  # NOTE(review): looks like a verbosity/progress flag — confirm its name in `train`.
)

Training ###########################################################################
Testing  #################
Epoch 0. Train loss: 0.42278608679771423. Test loss: 0.2297869622707367.
Training ###########################################################################
Testing  #################
Epoch 1. Train loss: 0.20809845626354218. Test loss: 0.1546270102262497.
Training ###########################################################################
Testing  #################
Epoch 2. Train loss: 0.16863712668418884. Test loss: 0.1314755529165268.
Training ###########################################################################
Testing  #################
Epoch 3. Train loss: 0.1520172506570816. Test loss: 0.11981411278247833.
Training #######################################################################

In [None]:
# Save the trained model to disk. Only the state_dict (the learned weights)
# is stored. To reload, build a UNet with the same in/out/conv channels and
# call `model.load_state_dict(torch.load(model_path))`.
model_path = Path(configs.get("model_path", "unet_water_bodies.pt"))
model_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), model_path)
print("Saved model weights to {}.".format(model_path))

In [None]:
import matplotlib.pyplot as plt

# Visual sanity check: predicted vs. ground-truth masks for the first test batch.
n_show = 4  # number of examples to plot; capped by the actual batch size below

batch = next(iter(test_dl))
images, masks = batch[0], batch[1]

model.eval()
with torch.no_grad():  # inference only: no autograd graph needed
    y_pred = model(images.to(device))
model.train()  # restore training mode so re-running the training cell is unaffected

# The last batch can be smaller than the batch size, so never index past it.
for i in range(min(n_show, len(masks))):
    fig, (ax_pred, ax_true) = plt.subplots(1, 2, figsize=(8, 4))
    ax_pred.imshow(y_pred[i].detach().cpu().numpy().squeeze())
    ax_pred.set_title("Predicted mask")
    ax_true.imshow(masks[i].detach().cpu().numpy().squeeze())
    ax_true.set_title("Ground-truth mask")
    for ax in (ax_pred, ax_true):
        ax.axis("off")
    plt.show()