In [7]:
import torch
from torch import nn
import torchvision.transforms as T
from sklearn.model_selection import train_test_split
from models import dataset, training_and_testing
from models.FastFCN.models import fcn
from metrics_and_losses import metrics
from utils import segmentation_labels, utils

In [8]:
# configuration
# Seed torch's global RNG first so weight initialisation (and any torch-based shuffling inside
# the project's training loop) is reproducible on Restart & Run All. The train/test split below
# is seeded separately through its own random_state.
seed = 99
torch.manual_seed(seed)

model_name = 'fast_fcn_ccncsa'  # also used as the stem of the saved weights file
weights_path = "models/weights/"  # trailing slash needed: file paths are built by string concatenation
dataset_path = "headsegmentation_dataset_ccncsa/"  # trailing slash needed for the same reason

In [9]:
# defining transforms
tH, tW = 256, 256  # spatial size (height, width) images and labels are resized to
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] # from ImageNet
# NOTE(review): T.Normalize works on tensors, so this assumes dataset.MyDataset converts images
# to float tensors before applying image_transform — TODO confirm.
image_transform = T.Compose([T.Resize((tH, tW)), T.Normalize(mean, std)])
# Label maps must be resized with nearest-neighbour interpolation: Resize's default (bilinear)
# can blend neighbouring label values, producing values/colours that belong to no class along
# region boundaries.
target_transform = T.Compose([T.Resize((tH, tW), interpolation=T.InterpolationMode.NEAREST)])

# fetching dataset
n_classes = len(segmentation_labels.labels)
img_paths, label_paths = dataset.get_paths(dataset_path + 'training.xml')
# 80/20 split with a fixed random_state so the held-out test split is identical across runs.
X_train, X_test, Y_train, Y_test = train_test_split(
    img_paths, label_paths, test_size=0.20, random_state=99, shuffle=True)
train_dataset = dataset.MyDataset(X_train, Y_train, image_transform, target_transform)
test_dataset = dataset.MyDataset(X_test, Y_test, image_transform, target_transform)

In [10]:
# --- training hyperparameters ---
batch_size, n_epochs = 64, 50
device = 'cpu'

# --- model, loss and score function ---
model = fcn.FCN(nclass=n_classes, backbone="resnet50")
loss_fn = nn.CrossEntropyLoss()
score_fn = metrics.batch_mIoU

# Freeze the pretrained backbone: requires_grad_(False) sets requires_grad=False on every
# parameter of the `pretrained` submodule, so only the remaining parameters are trained.
model.pretrained.requires_grad_(False)

# --- optimizer ---
# Every parameter is handed to Adam; the frozen ones never receive a gradient, so Adam's
# step leaves them untouched.
learning_rate = 0.01
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [11]:
# Report the model size; presumably counts the parameters that still require gradients
# after the backbone freeze in the previous cell — TODO confirm against utils.
# NOTE(review): the original author flagged "we need weights!!!" here — presumably the
# backbone should be initialised from pretrained weights; TODO confirm.
n_learnable_parameters = utils.count_learnable_parameters(model)
print('n_learnable_parameters={}'.format(n_learnable_parameters))

n_learnable_parameters=29179403


In [None]:
# training (exploratory run; its learning curves are plotted in the next cell)
# Train on the TRAINING split only. The test split must stay unseen until the final evaluation
# at the end of the notebook; training or selecting settings on it biases the reported test mIoU.
# NOTE(review): with evaluate=True, train_model presumably holds out part of the given dataset
# for validation — TODO confirm in models/training_and_testing.
results = training_and_testing.train_model(
    device, model, train_dataset, batch_size, n_epochs, score_fn, loss_fn, optimizer,
    lr_scheduler=None, evaluate=True, verbose=True)

In [None]:
# Plot the history returned by the exploratory training run above; `plotsize` is forwarded
# unchanged to the project's plotting helper.
training_and_testing.plot_training_results(
    results,
    plotsize=(20, 6),
)

In [None]:
# Fresh model and optimizer for the final fit (no learning-rate scheduler is created here).
final_model = fcn.FCN(nclass=n_classes, backbone="resnet50")

# Mirror the exploratory configuration by freezing the pretrained backbone exactly as was done
# for `model`. Otherwise the final fit would also update the backbone at lr=learning_rate and
# would not be the configuration whose learning curves were inspected above.
for parameter in final_model.pretrained.parameters():
    parameter.requires_grad = False

final_optimizer = torch.optim.Adam(final_model.parameters(), lr=learning_rate)

# re-training model on entire training set and saving its weights
final_n_epochs = 25
training_and_testing.train_model(
    device, final_model, train_dataset, batch_size, final_n_epochs, score_fn, loss_fn, final_optimizer, verbose=True)
torch.save(final_model.state_dict(), weights_path + model_name + '.pth')

In [None]:
# Score the re-trained final model on the held-out test split with the same score function
# used during training.
batch_mIoU = training_and_testing.test_model(
    device, final_model, test_dataset, batch_size, score_fn)
print('batch_mIoU[test]={}.'.format(batch_mIoU))