In [None]:
import os

import torch
from torch import nn
import torchvision.transforms as T
from sklearn.model_selection import train_test_split

from models import dataset, training_and_testing
from models.FastSCNN.models import fast_scnn
from metrics_and_losses import metrics
from utils import segmentation_labels

In [None]:
# configuration: run name (used to name the saved weights) and data/weight locations.
# Keep the trailing '/' on both paths: file names may be appended to them directly.
model_name, weights_path, dataset_path = (
    'fast_scnn_ccncsa',
    "models/weights/",
    "headsegmentation_dataset_ccncsa/",
)

In [None]:
# defining transforms
tH, tW = 256, 256  # size (height, width) every image and label map is resized to
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]  # from ImageNet
image_transform = T.Compose([T.Resize((tH, tW)), T.Normalize(mean, std)])
# Label maps must be resized with nearest-neighbour interpolation: the default (bilinear)
# blends neighbouring pixels and can invent class ids that do not exist at label borders.
# NOTE(review): assumes targets store integer class ids (as nn.CrossEntropyLoss expects) — confirm.
target_transform = T.Compose([T.Resize((tH, tW), interpolation=T.InterpolationMode.NEAREST)])

# fetching dataset
test_size = 0.20   # fraction of the labelled data held out as the test set
split_seed = 99    # fixed so the train/test split is identical across runs
n_classes = len(segmentation_labels.labels)
img_paths, label_paths = dataset.get_paths(dataset_path + 'training.xml')
X_train, X_test, Y_train, Y_test = train_test_split(
    img_paths, label_paths, test_size=test_size, random_state=split_seed, shuffle=True)
train_dataset = dataset.MyDataset(X_train, Y_train, image_transform, target_transform)
test_dataset = dataset.MyDataset(X_test, Y_test, image_transform, target_transform)

In [None]:
# training hyperparameters
device = 'cpu'
batch_size = 64
n_epochs = 15

# reproducibility: seed torch's global RNG before the model is built so the weight
# initialisation (and, presumably, any torch-based shuffling inside train_model) is repeatable
random_seed = 99
torch.manual_seed(random_seed)

# model, loss, score function
model = fast_scnn.FastSCNN(n_classes)
loss_fn = nn.CrossEntropyLoss()
score_fn = metrics.batch_mIoU

# optimizer: SGD with momentum
learning_rate = 0.01
momentum = 0.8
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)

# learning rate scheduler: StepLR multiplies the lr by `gamma` every `step_size` calls to
# lr_scheduler.step() (presumably once per epoch inside train_model — confirm)
step_size = 5
gamma = 0.1
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

In [None]:
# training: evaluate=True / verbose=True are forwarded to train_model; the returned
# `results` are what the plotting cell visualises
results = training_and_testing.train_model(
    device,
    model,
    train_dataset,
    batch_size,
    n_epochs,
    score_fn,
    loss_fn,
    optimizer,
    lr_scheduler,
    evaluate=True,
    verbose=True,
)

In [None]:
# plotting the results returned by the training cell
training_plot_size = (20, 6)  # forwarded as `plotsize`; presumably (width, height) — confirm
training_and_testing.plot_training_results(results, plotsize=training_plot_size)

In [None]:
# resetting model, optimizer, learning rate scheduler
final_model = fast_scnn.FastSCNN(n_classes)
final_optimizer = torch.optim.SGD(final_model.parameters(), lr=learning_rate, momentum=momentum)
# The scheduler must wrap `final_optimizer`: binding it to the earlier `optimizer` would leave
# the final model's learning rate undecayed while stepping a stale optimizer instead.
final_lr_scheduler = torch.optim.lr_scheduler.StepLR(final_optimizer, step_size=step_size, gamma=gamma)

# re-training model on the entire training split (evaluate/verbose left at train_model's defaults)
final_n_epochs = n_epochs
training_and_testing.train_model(
    device, final_model, train_dataset, batch_size, final_n_epochs, score_fn, loss_fn, final_optimizer, final_lr_scheduler)

# saving model parameters (creating the weights directory if it is missing, otherwise
# torch.save raises) and testing the model on the held-out test dataset
os.makedirs(weights_path, exist_ok=True)
weights_file = os.path.join(weights_path, model_name + '.pth')
torch.save(final_model.state_dict(), weights_file)
batch_mIoU = training_and_testing.test_model(device, final_model, test_dataset, batch_size, score_fn)
print(f'batch_mIoU[test]={batch_mIoU}.')