# Training note
All training, validation and test data are required to run this notebook.\
Some code is sourced from the following tutorial: https://github.com/qubvel/segmentation_models.pytorch/blob/master/examples/cars%20segmentation%20(camvid).ipynb

In [None]:
#import dependencies 
import os
import json
import numpy as np
import cv2
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset
import torch
import operator
from PIL import Image, ImageDraw
import albumentations as albu
import segmentation_models_pytorch as smp
import segmentation_models_pytorch.utils as smpu
import random
from sklearn.metrics import precision_recall_curve, roc_curve, auc
# Restrict CUDA to the GPU with index 1
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

In [None]:
### Cd to working dir

In [None]:
# Plain-Python equivalent of the IPython `cd` automagic, so this cell also runs
# when the notebook is executed as a regular script.
os.chdir('/scratch/ndillenb/notebooks/Hand-Segmentation')

In [None]:
# Annotation exports (JSON) and image folder names for the two subsets.
hands_file = '/scratch/ndillenb/notebooks/Hand-Segmentation/Hands_dataset/all_hands_2.json'
no_hands_file = '/scratch/ndillenb/notebooks/Hand-Segmentation/Hands_dataset/no_hands_2.json'
hands_img = "all_hands_2"
no_hands_img = "no_hands_2"

# The annotation paths above are absolute, so prefixing them with
# 'Hands_dataset' via os.path.join would return them unchanged; use them as-is.
x_hands_data, x_no_hands_data = hands_file, no_hands_file

# The image folders are relative to the working directory.
y_hands_data, y_no_hands_data = (
    os.path.join('Hands_dataset', folder) for folder in (hands_img, no_hands_img)
)

# Folder (under 'Hands_dataset') that the Dataset class reads images from.
all_dataset = 'all_dataset'

## Helper functions

In [None]:
# helper function for data visualization
def visualize(**images):
    """Plot the given images side by side in a single row.

    Every keyword argument becomes one subplot. The subplot title is the
    keyword name with underscores turned into spaces and title-cased.
    """
    total = len(images)
    plt.figure(figsize=(16, 5))
    for position, key in enumerate(images, start=1):
        plt.subplot(1, total, position)
        # Hide the axis ticks: only the picture itself is of interest.
        plt.xticks([])
        plt.yticks([])
        plt.title(key.replace('_', ' ').title())
        plt.imshow(images[key])
    plt.show()
# helper function for data permutation
def apply_permutation(list1, list2):
    """Shuffle two parallel sequences with one shared random permutation.

    Element ``i`` of ``list1`` and element ``i`` of ``list2`` end up at the
    same position in both outputs, so paired data (e.g. image / annotation)
    stays aligned. The inputs are not modified.

    Args:
        list1: first sequence.
        list2: second sequence, same length as ``list1``.

    Returns:
        A tuple ``(permuted_list1, permuted_list2)`` of new lists.

    Raises:
        ValueError: if the two sequences differ in length (otherwise pairs
            would be silently dropped or an IndexError raised mid-way).
    """
    if len(list1) != len(list2):
        raise ValueError(
            'apply_permutation needs sequences of equal length, got '
            f'{len(list1)} and {len(list2)}'
        )

    indices = list(range(len(list1)))
    random.shuffle(indices)  # uses the global `random` state

    permuted_list1 = [list1[i] for i in indices]
    permuted_list2 = [list2[i] for i in indices]
    return permuted_list1, permuted_list2
# helper function for loss plotting
def plot_loss(train_loss, valid_loss, num_epochs, loss_name='jaccard_loss'):
    """Plot training and validation loss against the epoch number.

    Args:
        train_loss: per-epoch log dicts from training, one per epoch.
        valid_loss: per-epoch log dicts from validation, one per epoch.
        num_epochs: number of epochs to put on the x axis (1..num_epochs);
            must match the number of entries in both log lists.
        loss_name: key of the loss inside each log dict. Defaults to
            'jaccard_loss'; pass 'dice_loss' when training with a Dice loss.
    """
    epochs = list(range(1, num_epochs + 1))
    train_values = [log[loss_name] for log in train_loss]
    valid_values = [log[loss_name] for log in valid_loss]

    plt.figure()
    plt.plot(epochs, train_values, color='r', label='training loss')
    plt.plot(epochs, valid_values, color='g', label='validation loss')
    plt.title('Training/Validation loss')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.legend()
    plt.show()
def plot_metric(train_loss, valid_loss,num_epochs,metric_type):
    """Plot one logged metric for training and validation against the epoch.

    Args:
        train_loss: per-epoch log dicts from training.
        valid_loss: per-epoch log dicts from validation.
        num_epochs: number of epochs to put on the x axis (1..num_epochs).
        metric_type: key of the metric to read from every log dict.
    """
    plt.figure()
    epochs = list(range(1, num_epochs + 1))
    train_curve = [entry[metric_type] for entry in train_loss]
    valid_curve = [entry[metric_type] for entry in valid_loss]

    plt.plot(epochs, train_curve, color='r', label='training')
    plt.plot(epochs, valid_curve, color='g', label='validation')
    plt.title('Training/Validation ' + str(metric_type))
    plt.xlabel('epoch')
    plt.ylabel(metric_type)
    plt.legend()
    plt.show()

## Dataloader

In [None]:
#Define Dataset
class Dataset(BaseDataset):
    """Hand-segmentation dataset that rasterises polygon annotations to masks.

    Images are read from ``Hands_dataset/<images_dir>/<file name>``. The mask
    of a sample is built on the fly from its annotation dict.

    Args:
        data_img: image file names, one per sample.
        data_file: annotation dicts, aligned index-by-index with ``data_img``.
            Each dict needs an ``'annotation'`` list whose items carry a
            ``'points'`` list of ``{'x': ..., 'y': ...}`` dicts. Coordinates
            are multiplied by the image width/height, so they are presumably
            normalised to [0, 1] -- TODO confirm against the annotation export.
        images_dir: folder below ``Hands_dataset`` that holds the images.
        classes: class names to select (matched case-insensitively against
            ``CLASSES``). ``None`` selects every class.
        augmentation: optional callable invoked as
            ``augmentation(image=..., mask=...)`` returning a dict with
            ``'image'`` and ``'mask'``.
        preprocessing: optional callable with the same calling convention,
            applied after ``augmentation``.
    """

    CLASSES = ['__BACKGROUND__', 'hand']

    def __init__(self,
        data_img,
        data_file,
        images_dir,
        classes=None,
        augmentation=None,
        preprocessing=None,):

        self.root = 'Hands_dataset'
        self.folder = images_dir

        self.ids = data_img
        self.images_fps = data_img
        self.masks_fps = data_file

        # `classes=None` used to crash on iteration; treat it as "all classes".
        if classes is None:
            classes = self.CLASSES
        # Compare lower-cased on both sides so '__BACKGROUND__' can be matched.
        known = [name.lower() for name in self.CLASSES]
        self.class_values = [known.index(cls.lower()) for cls in classes]

        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for sample ``idx``.

        Raises:
            OSError: if the image file is missing or cannot be decoded.
        """
        img_path = os.path.join(self.root, self.folder, self.images_fps[idx])
        img_out = cv2.imread(img_path)
        if img_out is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError(f'Could not read image (missing or unreadable): {img_path}')
        image = cv2.cvtColor(img_out, cv2.COLOR_BGR2RGB)
        h, w = image.shape[:2]

        # Draw every annotated polygon as value 1 on a single-channel canvas.
        im = Image.new("L", (w, h), 0)
        draw = ImageDraw.Draw(im)
        for shape in self.masks_fps[idx]['annotation']:
            polygon = [(point['x'] * w, point['y'] * h) for point in shape['points']]
            draw.polygon(polygon, outline=1, fill=1)

        # (H, W) -> (H, W, 1) so the mask has an explicit channel axis.
        mask = np.array(im)[..., np.newaxis].astype('int')

        # apply augmentations
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']
        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']
        return image, mask

    def __len__(self):
        """Return the number of samples."""
        return len(self.ids)

## Augmentation functions

In [None]:
def get_training_augmentation():
    """Build the augmentation pipeline used for training samples.

    Steps, in order: random horizontal flip (p=0.5), LongestMaxSize with
    max_size=160, then PadIfNeeded up to a minimum of 160x160.

    Earlier experiments, kept for reference: ShiftScaleRotate,
    Affine(scale=(1, 0.5)) (used before 03.05.2023), Resize(480, 480),
    PadIfNeeded with 32-divisor padding, RandomCrop(320, 320).
    """
    flip = albu.HorizontalFlip(p=0.5)
    resize = albu.LongestMaxSize(max_size=160, interpolation=1, always_apply=True, p=1)
    pad = albu.PadIfNeeded(min_height=160, min_width=160, always_apply=True, border_mode=0)
    return albu.Compose([flip, resize, pad])


def get_validation_augmentation():
    """Build the deterministic pipeline used for validation and test samples.

    Applies LongestMaxSize with max_size=160 followed by PadIfNeeded up to a
    minimum of 160x160, i.e. the training pipeline without the random flip.

    Earlier experiments, kept for reference: SmallestMaxSize(1600),
    Resize(480, 480), PadIfNeeded with 32-divisor padding,
    RandomCrop(320, 320), PadIfNeeded(384, 480).
    """
    resize = albu.LongestMaxSize(max_size=160, interpolation=1, always_apply=True, p=1)
    pad = albu.PadIfNeeded(min_height=160, min_width=160, always_apply=True, border_mode=0)
    return albu.Compose([resize, pad])


def to_tensor(x, **kwargs):
    """Move the last axis of a 3-D array to the front and cast to float32.

    Extra keyword arguments are accepted and ignored.
    """
    channels_first = x.transpose(2, 0, 1)
    return channels_first.astype('float32')


def get_preprocessing(preprocessing_fn):
    """Construct the preprocessing transform.

    Args:
        preprocessing_fn (callable): data normalization function
            (can be specific for each pretrained neural network).

    Returns:
        albumentations.Compose: applies ``preprocessing_fn`` to the image,
        then ``to_tensor`` to both image and mask.
    """
    normalize = albu.Lambda(image=preprocessing_fn)
    to_channels_first = albu.Lambda(image=to_tensor, mask=to_tensor)
    return albu.Compose([normalize, to_channels_first])

## Create model and train

In [None]:
# Encoder backbone and the pretrained weight set requested for it.
ENCODER = 'resnext101_32x8d'
ENCODER_WEIGHTS = 'instagram'
# Classes to predict; len(CLASSES) is passed to the model as its class count.
CLASSES = ['hand']
ACTIVATION = 'sigmoid' # could be None for logits or 'softmax2d' for multiclass segmentation
# Device string handed to the epoch runners and used for inference.
DEVICE = 'cuda'
# Preprocessing function looked up for the chosen encoder/weights pair.
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

In [None]:
# Build the segmentation model on top of the pretrained encoder.
model_options = dict(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    classes=len(CLASSES),
    activation=ACTIVATION,
)
model = smp.UnetPlusPlus(**model_options)  # choice of architecture

## Choice of data distribution

* 0.5: half of the training illustrations contain the feature, half do not
* 0.33: about 3 times more illustrations without the feature than illustrations with it
* 0.1: about 10 times more illustrations without the feature than illustrations with it

In [None]:
# Selects which hard-coded train/valid/test slicing is used by the split cell.
# Values handled there: 0.5, 0.33 and 0.1.
DATA_SPLIT_TRAIN=0.5 # 0.5: as many illustrations containing the feature as not

In [None]:
def _load_pairs(annotation_path, image_dir):
    """Load annotation samples and image names, sort both, shuffle them jointly.

    NOTE(review): pairing relies on the annotations sorted by 'imageUrl' and
    the directory listing sorted by file name ending up in the same order --
    confirm this holds for the annotation export.
    """
    with open(annotation_path) as handle:  # close the file deterministically
        samples = json.load(handle)['samples']
    masks = sorted(samples, key=operator.itemgetter('imageUrl'))
    images = sorted(os.listdir(image_dir))
    if len(masks) != len(images):
        # A stray or missing file would shift every image/annotation pair.
        raise ValueError(
            f'{annotation_path} has {len(masks)} samples but {image_dir} '
            f'contains {len(images)} files'
        )
    return apply_permutation(masks, images)


hands_data_mask, hands_data_img = _load_pairs(x_hands_data, y_hands_data)
no_hands_data_mask, no_hands_data_img = _load_pairs(x_no_hands_data, y_no_hands_data)

# Illustrations with hands: 173 train / 21 valid / 21 test in every setting.
_HANDS_BOUNDS = (0, 173, 194, 215)
# Illustrations without hands: the training count depends on the chosen
# distribution; validation and test always take the next 679 each.
_NO_HANDS_TRAIN_COUNT = {0.5: 173, 0.33: 346, 0.1: 1730}
_NO_HANDS_EVAL_COUNT = 679

if DATA_SPLIT_TRAIN not in _NO_HANDS_TRAIN_COUNT:
    # Previously an unsupported value only surfaced later as a NameError.
    raise ValueError(
        f'Unsupported DATA_SPLIT_TRAIN={DATA_SPLIT_TRAIN!r}; '
        f'expected one of {sorted(_NO_HANDS_TRAIN_COUNT)}'
    )

_n_train = _NO_HANDS_TRAIN_COUNT[DATA_SPLIT_TRAIN]
_no_hands_bounds = (
    0,
    _n_train,
    _n_train + _NO_HANDS_EVAL_COUNT,
    _n_train + 2 * _NO_HANDS_EVAL_COUNT,
)

# Build train, valid and test in that order so the shuffles consume the
# random state in the same sequence as before.
_splits = []
for _part in range(3):
    _h_lo, _h_hi = _HANDS_BOUNDS[_part], _HANDS_BOUNDS[_part + 1]
    _n_lo, _n_hi = _no_hands_bounds[_part], _no_hands_bounds[_part + 1]
    _masks = hands_data_mask[_h_lo:_h_hi] + no_hands_data_mask[_n_lo:_n_hi]
    _images = hands_data_img[_h_lo:_h_hi] + no_hands_data_img[_n_lo:_n_hi]
    _splits.append(apply_permutation(_masks, _images))

(train_mask, train_img), (valid_mask, valid_img), (test_mask, test_img) = _splits

In [None]:
def _build_dataset(images, masks, augmentation):
    """Create a Dataset over `all_dataset` with the shared preprocessing."""
    return Dataset(
        images,
        masks,
        all_dataset,
        augmentation=augmentation,
        preprocessing=get_preprocessing(preprocessing_fn),
        classes=CLASSES,
    )


# Only the training set gets the random augmentation pipeline.
train_dataset = _build_dataset(train_img, train_mask, get_training_augmentation())
valid_dataset = _build_dataset(valid_img, valid_mask, get_validation_augmentation())
test_dataset = _build_dataset(test_img, test_mask, get_validation_augmentation())

# Training batches are reshuffled every epoch; evaluation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=3, shuffle=False, num_workers=1)

In [None]:
# Print a two-line table with the number of samples in each split.
print('SPLIT = TRAIN | VALID | TEST')
print(f'SPLIT = {len(train_dataset)} | {len(valid_dataset)} | {len(test_dataset)}')

## Visualize the data to make sure all went smoothly

In [None]:
# Pick one random training sample and display it next to its mask.
num = random.randint(0, len(train_dataset) - 1)
image, mask = train_dataset[num]
print(image.shape, mask.shape, sep='\n')
visualize(
    # the sample is channel-first; imshow needs the channel axis last
    image=image.transpose(1, 2, 0).astype('float'),
    hands_mask=mask.squeeze(),
)

In [None]:
# Threshold used to separate positive from negative predictions.
current_tresh = 0.5

# Jaccard loss is used here; swap in smpu.losses.DiceLoss() for a Dice loss.
loss = smpu.losses.JaccardLoss()

# Metrics tracked during training; some are mainly useful for debugging.
metrics = [
    metric_cls(threshold=current_tresh)
    for metric_cls in (
        smpu.metrics.IoU,
        smpu.metrics.Fscore,
        smpu.metrics.Accuracy,
        smpu.metrics.Recall,
        smpu.metrics.Precision,
    )
]

In [None]:
# Adam over all model parameters as a single parameter group; the learning
# rate is chosen here.
param_groups = [{'params': model.parameters(), 'lr': 5.09e-05}]
optimizer = torch.optim.Adam(param_groups)

In [None]:
# Epoch runners: each one is a simple loop over a dataloader's batches.
# Both share the same loss, metrics, device and verbosity settings.
_runner_options = dict(loss=loss, metrics=metrics, device=DEVICE, verbose=True)

train_epoch = smp.utils.train.TrainEpoch(model, optimizer=optimizer, **_runner_options)
valid_epoch = smp.utils.train.ValidEpoch(model, **_runner_options)

In [None]:
# Train the model for up to `num_epochs` epochs, keeping the checkpoint with
# the best validation IoU in ./best_model.pth.
num_epochs = 6000 #Need to be changed to the max number of epochs
max_score = 0
train_logs = []
valid_logs = []

for i in range(len(train_logs), num_epochs):
    print('\nEpoch: {}'.format(i))
    epoch_train_log = train_epoch.run(train_loader)
    epoch_valid_log = valid_epoch.run(valid_loader)
    # Record both logs only once the whole epoch has finished: if the run is
    # interrupted during validation, the two lists keep the same length and
    # the plotting cell still works on the completed epochs.
    train_logs.append(epoch_train_log)
    valid_logs.append(epoch_valid_log)
    if max_score < epoch_valid_log['iou_score']: #Save model locally if reached IoU highscore
        max_score = epoch_valid_log['iou_score']
        torch.save(model, './best_model.pth')
        print('Model saved!')

## Plotting metrics

In [None]:
# Plot the loss and the main metrics over the epochs that were actually run.
num_epochs = len(train_logs)
plot_loss(train_logs, valid_logs, num_epochs)
for metric_name in ('iou_score', 'fscore', 'accuracy'):
    plot_metric(train_logs, valid_logs, num_epochs, metric_name)
print(f'max IoU score : {max_score!s}')

## Test best saved model

In [None]:
# Load the best checkpoint written by the training loop. The file holds the
# whole pickled model object (saved with torch.save(model, ...)), not a state dict.
# NOTE(review): recent PyTorch releases reportedly default torch.load to
# weights_only=True, which rejects fully pickled models -- confirm against the
# installed version before upgrading.
best_model = torch.load('./best_model.pth')

In [None]:
# Evaluate the best checkpoint on the test set with the same loss and metrics
# that were used during training.
test_epoch = smp.utils.train.ValidEpoch(
    model=best_model,
    loss=loss,
    metrics=metrics,
    device=DEVICE,
)
# Last expression of the cell, so the notebook displays the returned logs.
test_epoch.run(test_loader)

## Precision-Recall curve and ROC Plot (Dataset-wise)

In [None]:
# Run the best model over the test set and collect, per image, the ground
# truth mask and the predicted probabilities.
labels = []              # flattened masks, concatenated below
predictions = []         # flattened predictions, concatenated below
predictions_arrays = []  # per-image prediction arrays (reused by later cells)
labels_array = []        # per-image masks as returned by the dataset
for i in range(len(test_dataset)):
    # Fetch the sample once: every dataset access re-reads and re-processes
    # the image from disk.
    image, label = test_dataset[i]
    x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)
    prediction = best_model.predict(x_tensor).squeeze().cpu().numpy()
    labels.append(label.squeeze().flatten())
    predictions.append(prediction.flatten())
    predictions_arrays.append(prediction)
    labels_array.append(label)

labels = np.concatenate(labels, axis=0)
predictions = np.concatenate(predictions, axis=0)

# Compare predictions with the ground truth over all pixels of the test set.
precision, recall, thresholds = precision_recall_curve(labels, predictions)
# `tresholds` (ROC) is deliberately a different name from `thresholds` (P-R).
fpr, tpr, tresholds = roc_curve(labels, predictions)
roc_auc = auc(fpr, tpr)

# Precision-recall curve
plt.figure()
plt.plot(recall, precision)
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title('Precision-Recall Curve')
plt.grid(True)

# Annotate roughly `num_thresholds` evenly spaced thresholds from the range
# [lower_threshold, upper_threshold] on the curve.
num_thresholds = 5  # Number of thresholds to display
lower_threshold = 0.1
upper_threshold = 1
in_range = (thresholds >= lower_threshold) & (thresholds <= upper_threshold)
candidate_indices = np.flatnonzero(in_range)
step = max(1, len(candidate_indices) // num_thresholds)
selected_indices = candidate_indices[::step]
selected_thresholds = thresholds[selected_indices]
for index, threshold in zip(selected_indices, selected_thresholds):
    plt.annotate(f' {threshold:.2f}', (recall[index], precision[index]))
plt.show()

# ROC curve
plt.figure()
plt.plot(fpr, tpr, label=f'AUC = {roc_auc:.2f}')
plt.plot([0, 1], [0, 1], 'k--')  # Diagonal line representing random guessing
plt.xlabel('False Positive Rate (FPR)')
plt.ylabel('True Positive Rate (TPR)')
plt.title('Receiver Operating Characteristic (ROC) Curve')
plt.legend(loc='lower right')
plt.grid(True)
plt.show()

### Imagewise and dataset-wise metrics

In [None]:
# Stack the per-image masks and predictions into single tensors.
_stacked_labels = np.stack(labels_array, axis=0)
labels_array_stacked_tensor = torch.from_numpy(_stacked_labels).to(torch.int32)

# Insert a new axis at position 1 of the stacked predictions.
_stacked_predictions = np.stack(predictions_arrays, axis=0)[:, np.newaxis]
predictions_arrays_stacked_tensor = torch.from_numpy(_stacked_predictions)

In [None]:
THRESHOLD = 0.90 # threshold for separating positive from negative values

# Confusion-matrix counts for the thresholded predictions.
tp, fp, fn, tn = smp.metrics.get_stats(predictions_arrays_stacked_tensor, labels_array_stacked_tensor, mode='binary', threshold=THRESHOLD)

# Metrics we wish to compute. "micro" variables use reduction="micro";
# "_iw" variables use reduction="micro-imagewise".
# All metrics are available at : https://smp.readthedocs.io/en/latest/metrics.html
iou_score_micro = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")
# Fixed: this used reduction="macro" although the value is stored as image-wise.
iou_score_micro_iw = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro-imagewise")
f1_score_micro = smp.metrics.f1_score(tp, fp, fn, tn, reduction="micro")
f1_score_micro_iw = smp.metrics.f1_score(tp, fp, fn, tn, reduction="micro-imagewise")
f2_score_micro = smp.metrics.fbeta_score(tp, fp, fn, tn, beta=2, reduction="micro")
accuracy_micro = smp.metrics.accuracy(tp, fp, fn, tn, reduction="micro")
accuracy_micro_iw = smp.metrics.accuracy(tp, fp, fn, tn, reduction="micro-imagewise")
recall_micro = smp.metrics.recall(tp, fp, fn, tn, reduction="micro")
recall_micro_iw = smp.metrics.recall(tp, fp, fn, tn, reduction="micro-imagewise")
# Fixed: precision was computed with smp.metrics.recall, duplicating recall.
precision_micro = smp.metrics.precision(tp, fp, fn, tn, reduction="micro")
precision_micro_iw = smp.metrics.precision(tp, fp, fn, tn, reduction="micro-imagewise")
fpr_micro = smp.metrics.false_positive_rate(tp, fp, fn, tn, reduction="micro")
fpr_micro_iw = smp.metrics.false_positive_rate(tp, fp, fn, tn, reduction="micro-imagewise")

# Everything needed to reproduce or re-threshold the evaluation later.
stats = {'predict_array': predictions_arrays_stacked_tensor, 'labels_array': labels_array_stacked_tensor, 'test_tresh':THRESHOLD, 'tp':tp, 'fp':fp,'fn':fn,'tn':tn,
        'iou_micro': iou_score_micro, 'iou_micro_iw': iou_score_micro_iw, 'f1_score_micro': f1_score_micro, 'f1_score_micro_iw':f1_score_micro_iw,
        'f2_score_micro':f2_score_micro,'accuracy_micro':accuracy_micro,'accuracy_micro_iw':accuracy_micro_iw,'recall_micro':recall_micro,'recall_micro_iw':recall_micro_iw,
        'precision_micro':precision_micro,'precision_micro_iw':precision_micro_iw,'fpr_micro':fpr_micro,'fpr_micro_iw':fpr_micro_iw}


## Saving the results and the models

In [None]:
# Save the last and the best model under a folder hierarchy that encodes the
# training configuration.
# NOTE(review): parent_dir starts with '/', so the joined path is absolute and
# sits at the filesystem root -- confirm this is intended rather than a path
# relative to the working directory.
parent_dir = '/trained_model_saved/'
ARCHITECTURE = 'UnetPlusPlus' # Architecture used (to keep track)
LOSS = 'JaccardLoss' # Loss used (to keep track)
DATASET_NAME = 'half_split_test_val_distribution(50-3-3)-(80-10-10)' # Name given to our data split (to keep track)
# NOTE(review): this label says batch_16 while the training loader above uses
# batch_size=32 -- update the label by hand so it matches the actual run.
NEW_FOLDER = 'batch_16_size160_epoch1398_lr_5.09E-05' # Free-form label (should be unique among runs with the same settings)
path = os.path.join(parent_dir, ARCHITECTURE, ENCODER, ENCODER_WEIGHTS, LOSS, DATASET_NAME, NEW_FOLDER)
os.makedirs(path, exist_ok=True) # create the folders; no error if they already exist
torch.save(model, os.path.join(path,'model.pth')) # save the last model, in case training needs to be resumed
torch.save(best_model, os.path.join(path,'best_model.pth')) # save the best model