In [None]:
# === colab configuration ===
# Run this cell only on Google Colab; skip it when running the notebook locally.
# When training on Colab, weights are additionally saved to Google Drive (directory DSCAS/weights).
# NOTE(review): torch-summary is installed without a pinned version; consider pinning it for reproducibility.

! pip install torch-summary
from google.colab import drive
import sys

# setting paths: locations on the Colab VM and their counterparts on the mounted Drive
repository_path = '/content/deep-seasonal-color-analysis-system/'
dataset_path = repository_path + 'headsegmentation_dataset_ccncsa/'
dataset_path_drive = '/content/drive/MyDrive/DSCAS/headsegmentation_dataset_ccncsa/'
weights_path = repository_path + 'models/weights/'
weights_path_drive = '/content/drive/MyDrive/DSCAS/weights/'
# make the repository's packages (models, utils, ...) importable from this notebook
sys.path.insert(0, repository_path)

# mount Drive, then clone the repository and copy the dataset from Drive,
# each step only if the target directory does not exist yet (test ! -d ...)
drive.mount('/content/drive')
! test ! -d $repository_path && git clone https://github.com/mrcmich/deep-seasonal-color-analysis-system.git
! test ! -d $dataset_path && cp -R $dataset_path_drive $dataset_path
%cd $repository_path

# setting branch and pulling updates
branch = 'main__fastscnn_hyperparameters_tuning'
! git checkout $branch
! git pull origin $branch

# flag read by later cells: skips the local path configuration and enables the extra save to Drive
executing_on_colab = True

In [None]:
import os
import torch
from torch import nn, optim
import torchvision.transforms as T
from sklearn.model_selection import train_test_split
from models import dataset, training_and_testing
from models.cloud.UNet import unet
from metrics_and_losses import metrics
from utils import segmentation_labels, utils, custom_transforms
import matplotlib.pyplot as plt
from palette_classification import color_processing
import torchsummary
from models.config import *

# Fall back to local mode when the Colab configuration cell (the only place that sets
# executing_on_colab = True) has been skipped.
if 'executing_on_colab' not in globals():
  executing_on_colab = False

In [None]:
# local configuration
# Anchor both paths at ROOT_DIR (from models.config), mirroring how the Colab cell anchors them at
# repository_path, instead of leaving weights_path relative to the current working directory.
# NOTE(review): assumes ROOT_DIR is the repository root ending with '/', as its use for dataset_path implies.
if not executing_on_colab:
  weights_path = ROOT_DIR + 'models/weights/'
  dataset_path = ROOT_DIR + 'headsegmentation_dataset_ccncsa/'

In [None]:
# === defining transforms ===

tH, tW = 256, 256  # output height/width shared by images and segmentation targets
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] # from ImageNet
bilateral_filter = custom_transforms.BilateralFilter(sigma_color=50, sigma_space=100, diameter=7)

# images: resize, apply the project's bilateral filter, then normalize with ImageNet statistics
image_transform = T.Compose([
    T.Resize((tH, tW)),
    bilateral_filter,
    T.Normalize(mean, std)])

# targets are label masks: resize with nearest-neighbour interpolation so they keep discrete values;
# the default (bilinear) interpolation would blend neighbouring classes along region boundaries
target_transform = T.Compose([T.Resize((tH, tW), interpolation=T.InterpolationMode.NEAREST)])

# fetching dataset and splitting it 80/20 into train/test with a fixed seed (reproducible membership)
n_classes = len(segmentation_labels.labels)
img_paths, label_paths = dataset.get_paths(dataset_path, file_name='training.xml')
X_train, X_test, Y_train, Y_test = train_test_split(
    img_paths, label_paths,
    test_size=0.20,
    random_state=99,
    shuffle=True,
)
train_dataset, test_dataset = (
    dataset.MyDataset(X_train, Y_train, image_transform, target_transform),
    dataset.MyDataset(X_test, Y_test, image_transform, target_transform),
)

In [None]:
# training hyperparameters
# Use the GPU whenever one is actually available instead of inferring it from executing_on_colab:
# a Colab session can be CPU-only (where 'cuda' would crash) and a local machine may have a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = 32 # <-- UPDATE
n_epochs = 25

# model, loss, score function
model_name = 'unet_ccncsa_tuned'
model = unet.UNet(out_channels=n_classes)
# per-class weights for the cross-entropy loss (entry c weights class c)
# NOTE(review): presumably ordered like segmentation_labels.labels — confirm against the dataset's channel order
class_weights = torch.tensor(
    [0.3762, 0.9946, 0.9974, 0.9855, 0.7569, 0.9140, 0.9968, 0.9936, 0.9989, 0.9893, 0.9968], device=device)
# fail fast with a clear message if the label set changes but the hard-coded weights do not
assert class_weights.numel() == n_classes, \
    f'expected {n_classes} class weights, got {class_weights.numel()}'
loss_fn = nn.CrossEntropyLoss(weight=class_weights)
score_fn = metrics.batch_mIoU

# optimizer <-- UPDATE
learning_rate = 0.01
optimizer = optim.AdamW(params=model.parameters(), lr=learning_rate)

# scheduler <-- UPDATE
# LinearLR scales the learning rate by a factor that changes linearly, starting from start_factor (0.05 here);
# presumably stepped inside training_and_testing.train_model — confirm per-batch vs per-epoch there
scheduler = optim.lr_scheduler.LinearLR(optimizer=optimizer, start_factor=0.05)

# printing model summary for an input of shape (batch_size, 3, tH, tW);
# batch_dim=None tells torch-summary the given shape already includes the batch dimension
model_summary = torchsummary.summary(
    model,
    input_data=(batch_size, 3, tH, tW),
    batch_dim=None,
    verbose=0,
)
print(model_summary)

In [None]:
# training
results = training_and_testing.train_model(
    device, model, train_dataset, batch_size, n_epochs, score_fn, loss_fn, optimizer, scheduler, verbose=True)

# saving model's weights: always to the local weights directory and, on Colab, also to Drive
checkpoint_dirs = [weights_path]
if executing_on_colab is True:
  checkpoint_dirs.append(weights_path_drive)
for checkpoint_dir in checkpoint_dirs:
  # create the directory if it is missing (e.g. DSCAS/weights on a fresh Drive) instead of
  # failing with FileNotFoundError right after a full training run
  os.makedirs(checkpoint_dir, exist_ok=True)
  torch.save(model.state_dict(), checkpoint_dir + model_name + '.pth')

In [None]:
# plotting the training history returned by train_model
training_and_testing.plot_training_results(
    results,
    plotsize=(20, 6),
)

In [None]:
# loading tuned model
# `results` is defined only by the training cell, so it tells us whether training ran in this session.
# (Checking `model` does not work: the hyperparameters cell always creates a fresh, untrained model,
# so the load branch would never run and an untrained model would be evaluated.)
try:
    results
except NameError:
    model = unet.UNet(out_channels=n_classes)
    # map_location='cpu' lets weights saved from a GPU session load on a CPU-only machine;
    # load_state_dict then copies them into the freshly built model's parameters
    model.load_state_dict(torch.load(weights_path + model_name + '.pth', map_location='cpu'))

# testing model on test dataset: per-class IoU plus its mean (mIoU)
test_score_fn = metrics.batch_IoU
label_names = list(segmentation_labels.labels.keys())
batch_IoU = training_and_testing.test_model(device, model, test_dataset, batch_size, test_score_fn)
# pair each per-class score with its label name (zip stops at the shorter of the two sequences)
batch_IoU_with_labels = dict(zip(label_names, batch_IoU.tolist()))
batch_mIoU = batch_IoU.mean().item()
for label, label_score in batch_IoU_with_labels.items():
    print(f'batch_IoU_{label}: {label_score}')
print(f'batch_mIoU={batch_mIoU}')

In [None]:
# === plotting some examples from test dataset ===

n_examples = 5
random_images = torch.zeros((n_examples, 3, tH, tW))
random_targets = torch.zeros((n_examples, n_classes, tH, tW))

# fill the preallocated batches with independently drawn test samples (repeats are possible)
for example_idx in range(n_examples):
    sample_idx = torch.randint(high=len(test_dataset), size=(1,))
    random_images[example_idx], random_targets[example_idx] = test_dataset[sample_idx]

# run the model on the sampled batch in eval mode without tracking gradients
model.eval()
with torch.no_grad():
    random_images = random_images.to(device)
    # NOTE(review): the [0] index assumes the model returns a sequence whose first item is the
    # (batch, n_classes, H, W) score tensor — confirm against unet.UNet.forward
    random_output = model(random_images)[0]

# one-hot style predictions: at each pixel, mark the channel(s) whose score equals the channel-wise maximum
channels_max = random_output.max(dim=1).values
random_predictions = random_output.eq(channels_max.unsqueeze(dim=1)).cpu()

# one figure per example: colorized ground-truth masks (left) next to colorized predictions (right)
for i in range(n_examples):
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    panels = (('Ground Truth', random_targets[i]), ('Prediction', random_predictions[i]))
    for ax, (panel_title, panel_masks) in zip(axes, panels):
        ax.set_title(panel_title)
        ax.imshow(utils.from_DHW_to_HWD(
            color_processing.colorize_segmentation_masks(panel_masks, segmentation_labels.labels)))