In [None]:
import os

import torch
from sklearn.utils import shuffle

from Utils import *
from data.dataloaders.dataloader import get_dataloader as gd
from models.model import get_model as gm
from models import train, validate, test

In [None]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
# To force CPU execution while debugging, override with: device = "cpu"

In [None]:
# Echo the selected device so the cell output shows whether CUDA was picked up.
device

Read in images as BGR

In [None]:
# Directories holding the ground-truth images and the spill images, given
# relative to the notebook's working directory.
gt_dir, spill_dir = "./data/images/gt", "./data/images/spill"


In [None]:
# Load the ground-truth images from gt_dir. read_in_images_simple is not defined in
# this notebook (presumably it comes from the `Utils` star import); the heading above
# says images are read as BGR -- TODO confirm against the helper.
gt_images = read_in_images_simple(directory=gt_dir)

In [None]:
# Load the spill images from spill_dir with the same helper used for the ground truth.
# NOTE(review): later cells pair the two lists by index, so this assumes the helper
# returns both directories in a matching order -- confirm against the helper.
spill_images = read_in_images_simple(directory=spill_dir)


In [None]:
# Sanity check: report how many ground-truth images were loaded.
print(f"Number of ground truth images: {len(gt_images)}")

In [None]:
# Sanity check: report how many spill images were loaded; this should match the
# ground-truth count because the two lists are paired by index further down.
print(f"Number of spill images: {len(spill_images)}")


Shuffle

In [None]:
# Shuffle both lists in a single call so they receive the same permutation and every
# spill image stays aligned with its ground truth. The seed is pinned so that the
# train/validation/test membership is identical on every Restart & Run All; with an
# unseeded shuffle a checkpoint saved by an earlier run could later be tested on
# images it was trained on.
SHUFFLE_SEED = 42
gt_images, spill_images = shuffle(gt_images, spill_images, random_state=SHUFFLE_SEED)


Split into train / validation / test sets (60% / 20% / 20%)

In [None]:
# Cut points for a 60% / 20% / 20% train / validation / test split. The indices are
# derived from the ground-truth list and reused for the spill list, so both lists are
# sliced at the same positions and stay paired.
ds_len = len(gt_images)
train_end_index, val_end_index = int(ds_len * 0.6), int(ds_len * 0.8)

train_gt, val_gt, test_gt = (
    gt_images[:train_end_index],
    gt_images[train_end_index:val_end_index],
    gt_images[val_end_index:],
)

train_spill, val_spill, test_spill = (
    spill_images[:train_end_index],
    spill_images[train_end_index:val_end_index],
    spill_images[val_end_index:],
)



In [None]:
def _print_split_sizes(split_name, gt_split, spill_split):
    """Print how many ground-truth and spill images one split contains."""
    print(f" Number of gt & spill {split_name} images: {len(gt_split)} & {len(spill_split)}")


_print_split_sizes("training", train_gt, train_spill)
_print_split_sizes("validation", val_gt, val_spill)
_print_split_sizes("test", test_gt, test_spill)


Create the train, validation, and test dataloaders

In [None]:
# Forwarded to the dataloaders; also selects a 5-channel (True) or 3-channel (False)
# model input and the matching checkpoint file name in the cells below.
use_extra_channels = False

In [None]:
def _make_loader(spill, gt):
    """Build one split's loader through gd with the shared batch size and channel flag.

    Returns whatever gd returns; the callers below unpack it as
    (dataloader, number of batches).
    """
    return gd(spill_images=spill, gt_images=gt,
              batch_size=32, use_extra_channels=use_extra_channels)


train_dataloader, num_train_batches = _make_loader(train_spill, train_gt)
val_dataloader, num_val_batches = _make_loader(val_spill, val_gt)
test_dataloader, num_test_batches = _make_loader(test_spill, test_gt)

In [None]:
def _print_batch_size(split_name, loader, num_batches):
    """Print the split's dataset length floor-divided by its number of batches.

    NOTE(review): this is a floored average, so it can differ from the configured
    batch size when the last batch is partial -- confirm how gd counts batches.
    """
    print(f"Number of {split_name} images per batch: {len(loader.dataset) // num_batches}")


_print_batch_size("training", train_dataloader, num_train_batches)
_print_batch_size("validation", val_dataloader, num_val_batches)
_print_batch_size("test", test_dataloader, num_test_batches)

Get model

In [None]:
# Build the network with 5 input channels when the extra channels are enabled and 3
# otherwise, always producing 3 output channels, then move it to the selected device.
model = gm(
    device=device,
    in_channels=(5 if use_extra_channels else 3),
    out_channels=3,
).to(device)

Run training and validation

In [None]:
# Keep the 5-channel and 3-channel weights in separate files so that switching
# use_extra_channels does not overwrite the other variant's checkpoint.
if use_extra_channels:
    checkpoint_save_path = "./models/checkpoints/spill_model_5ch.pth"
else:
    checkpoint_save_path = "./models/checkpoints/spill_model_3ch.pth"


In [None]:
# Create the checkpoint directory before training starts: torch.save does not create
# missing parent directories, so on a fresh checkout the save would otherwise fail
# only after every epoch has run, discarding the trained weights.
checkpoint_dir = os.path.dirname(checkpoint_save_path)
if checkpoint_dir:  # os.makedirs("") raises, so skip when the path has no directory part
    os.makedirs(checkpoint_dir, exist_ok=True)

epochs = 3
for epoch in range(epochs):
    print(f"{epoch=}")
    # Each epoch calls the project's training routine and then its validation routine
    # on the same model object.
    train.train(model=model, train_loader=train_dataloader, device=device)
    validate.validate(model=model, valid_loader=val_dataloader, device=device)

# Save only the weights (state_dict) once, after the final epoch; the testing cell is
# handed this same path.
torch.save(model.state_dict(), checkpoint_save_path)

Run testing

In [None]:
# Evaluate on the held-out test loader. checkpoint_path is the file written by the
# training cell above.
test.test(
    test_loader=test_dataloader,
    model=model,
    device=device,
    checkpoint_path=checkpoint_save_path,
)

Visualize results

In [None]:
# Plot the results. plot_results is not defined in this notebook (presumably it comes
# from the `Utils` star import) and is called without arguments, so what it reads is
# not visible from here -- TODO confirm where it gets its data.
plot_results()