In [None]:
# This notebook runs on Google Colab, so Google Drive has to be mounted first;
# the dataset directories used further down live under /content/gdrive.
from google.colab import drive
drive.mount('/content/gdrive')


Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


Installation of the different libraries.

In [None]:
!pip install datasets
!pip install albumentations
!pip install torchmetrics
!pip install transformers
!pip install torch torchvision albumentations
!pip install wandb
!pip install transformers[torch] accelerate
!pip install 'transformers[torch]' -U
!pip uninstall accelerate
!pip install accelerate
!pip install transform


Importing all the libraries needed to train the model and prepare the images and dataset for it.

In [None]:
import glob
from pathlib import Path

import albumentations as A
import matplotlib.pyplot as plt
import numpy as np
import torch
import wandb
from albumentations.pytorch import ToTensorV2
from PIL import Image
from torch.nn.functional import interpolate
from torch.utils.data import DataLoader, Dataset
from torchmetrics import F1Score, JaccardIndex, Precision, Recall
from transformers import (
    EarlyStoppingCallback,
    SegformerConfig,
    SegformerForSemanticSegmentation,
    Trainer,
    TrainerCallback,
    TrainerControl,
    TrainerState,
    TrainingArguments,
)


In [None]:
#function to load images in RGB.
def load_image(path):
    """Open the image file at ``path`` and return it as an RGB ``PIL.Image``.

    Args:
        path: Location of the image file, in any form ``PIL.Image.open`` accepts.

    Returns:
        A new image in mode ``'RGB'``.
    """
    # The context manager closes the file handle as soon as the pixels have been read.
    # convert() returns a separate copy, so the result stays usable after the file is closed.
    with Image.open(path) as img:
        return img.convert('RGB')

#function that turns a colour-coded mask into a map of class ids (0, 1 or 2).
def create_segmentation_map(mask):
    """Translate a colour mask of shape (H, W, C) into a uint8 label map of shape (H, W).

    Pixels whose channel 1 equals 255 become class 1 (LOW_1_3), pixels whose
    channel 2 equals 255 become class 2 (HIGH_4_5), and everything else stays
    0 (background). When both channels are 255, class 2 wins.
    """
    is_low = mask[:, :, 1] == 255
    is_high = mask[:, :, 2] == 255
    # The outer np.where gives HIGH_4_5 priority, mirroring "assigned last wins".
    class_ids = np.where(is_high, 2, np.where(is_low, 1, 0))
    return class_ids.astype(np.uint8)


#CustomDataset follows the PyTorch custom-dataset pattern and is used to create the dataset for training the model.
#sourcecode: https://pytorch.org/tutorials/beginner/basics/data_tutorial.html
#CustomDataset composes a dataset of images and masks.
class CustomDataset(Dataset):
    """Dataset of (image, segmentation mask) pairs, paired by position in the two path lists.

    Every sample is resized to 256x256 and randomly augmented each time it is
    fetched. The augmentation is applied to all samples of this dataset, whichever
    split they later end up in.

    Args:
        image_paths: Sequence of image file paths.
        label_paths: Sequence of mask file paths; ``label_paths[i]`` must belong
            to ``image_paths[i]``.

    Raises:
        ValueError: If the two sequences differ in length.
    """

    def __init__(self, image_paths, label_paths):
        if len(image_paths) != len(label_paths):
            raise ValueError(
                f"Got {len(image_paths)} images but {len(label_paths)} masks; "
                "images and masks are paired by position and must match one-to-one."
            )
        self.image_paths = image_paths
        self.label_paths = label_paths
        #A single pipeline is used for image AND mask. Called as transform(image=..., mask=...),
        #albumentations draws the random flips once and applies the same geometry to both,
        #so every label stays on top of the pixel it describes. Pixel-level steps
        #(brightness/contrast, normalisation) only touch the image, and the mask is resized
        #with nearest-neighbour interpolation so no new class ids are invented.
        self.transform = A.Compose([
            A.Resize(256, 256),
            A.HorizontalFlip(p=0.5),
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            A.VerticalFlip(p=0.5),
            #converting to tensor format because the model only takes tensors as input.
            ToTensorV2(),
        ])
        #Legacy attribute, retained only so external code reading it keeps working.
        #__getitem__ no longer uses it: its flips were drawn independently of the image
        #pipeline, which desynchronised image and mask.
        self.mask_transform = A.Compose([
            A.Resize(256, 256),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            ToTensorV2(),
        ])

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, label-encode and augment the sample at position ``idx``.

        Returns:
            dict with ``"pixel_values"`` (the augmented image tensor) and
            ``"labels"`` (the class-id mask as a ``torch.long`` tensor).
        """
        image = np.array(load_image(self.image_paths[idx]))
        #colour mask -> 2-D map of class ids (0 background, 1 LOW_1_3, 2 HIGH_4_5)
        mask = create_segmentation_map(np.array(load_image(self.label_paths[idx])))

        #image and mask are augmented together so they receive identical flips.
        augmented = self.transform(image=image, mask=mask)

        #A 2-D mask passed as `mask=` comes back as a 2-D tensor, so there is no
        #extra channel dimension to squeeze away.
        return {"pixel_values": augmented['image'], "labels": augmented['mask'].long()}


#Collator that turns a list of dataset samples into one batch for the model.
#sourcecode: https://huggingface.co/docs/transformers/main_classes/data_collator
class CustomDataCollator:
    """Callable that stacks per-sample tensors into batched tensors."""

    def __call__(self, batch):
        """Stack the samples in ``batch`` along a new leading dimension.

        Args:
            batch: List of dicts, each holding a ``'pixel_values'`` tensor and a
                ``'labels'`` tensor, as returned by ``CustomDataset.__getitem__``.

        Returns:
            dict with the same two keys, each mapped to the stacked tensor.
        """
        images = []
        masks = []
        for sample in batch:
            images.append(sample['pixel_values'])
            masks.append(sample['labels'])
        # torch.stack adds the batch dimension in front of the per-sample shape.
        return {'pixel_values': torch.stack(images), 'labels': torch.stack(masks)}


In [None]:
#sourcecode: https://lightning.ai/docs/torchmetrics/stable/classification/f1_score.html, https://lightning.ai/docs/torchmetrics/stable/classification/jaccard_index.html, https://lightning.ai/docs/torchmetrics/stable/classification/precision.html, https://lightning.ai/docs/torchmetrics/stable/classification/recall.html
#The metrics are calculated using torch metrics.
#For each metric the number of classes is 3 and the task is multiclass. Pixels whose target is
#background (class 0) are excluded via ignore_index.
#average="none" makes compute() return one value per class. With the default averaging the
#result is a single scalar, and indexing it per class (ious[0], ious[1], ...) fails.
jaccard_index = JaccardIndex(num_classes=3, task="multiclass", ignore_index=0, average="none")
precision_metric = Precision(num_classes=3, task="multiclass", ignore_index=0, average="none")
recall_metric = Recall(num_classes=3, task="multiclass", ignore_index=0, average="none")
f1_metric = F1Score(num_classes=3, task="multiclass", ignore_index=0, average="none")

#sourcecode: https://www.kaggle.com/code/italyforever/drone-images-segmentation
#Compute metrics is used to calculate the evaluation metrics. The evaluation is done by using logits and labels.
def compute_metrics(eval_pred):
    """Compute IoU, precision, recall and F1 for one evaluation pass.

    Args:
        eval_pred: Pair ``(logits, labels)``. ``logits`` has the class scores on
            dimension 1; ``labels`` holds the class id per pixel. Both may be
            tensors or anything ``torch.tensor`` accepts.

    Returns:
        dict mapping ``"<metric>/overall"`` and ``"<metric>/class_<i>"`` (i = 0..2)
        to Python floats, for ``jaccard_index``, ``precision``, ``recall`` and ``f1``.
        ``overall`` is the mean over the foreground classes 1 and 2. The
        ``class_0`` entries are kept so the set of keys stays stable, but they
        are not meaningful because background targets are ignored.
    """
    logits, labels = eval_pred

    #'isinstance' is used to check if the logits and labels are already tensors.
    logits = logits if isinstance(logits, torch.Tensor) else torch.tensor(logits)
    labels = labels if isinstance(labels, torch.Tensor) else torch.tensor(labels, dtype=torch.long)

    #argmax over the class dimension picks the highest-scoring class per pixel,
    #giving a tensor of class ids with the class dimension removed.
    preds = torch.argmax(logits, dim=1)
    #interpolate needs a channel dimension and a float input, hence unsqueeze(1).float().
    #Nearest-neighbour resizing brings the predictions to the label resolution without
    #creating new class ids. Afterwards the channel dimension is removed and the values
    #are cast back to integers so the metrics receive class ids rather than floats.
    preds = torch.nn.functional.interpolate(
        preds.unsqueeze(1).float(), size=labels.shape[-2:], mode='nearest'
    ).squeeze(1).long()

    results = {}
    for name, metric in (
        ("jaccard_index", jaccard_index),
        ("precision", precision_metric),
        ("recall", recall_metric),
        ("f1", f1_metric),
    ):
        #reset so every evaluation run starts from scratch, then accumulate and compute.
        metric.reset()
        metric.update(preds, labels)
        per_class = metric.compute()

        #index 0 is the ignored background class, so it is left out of the overall mean.
        results[f"{name}/overall"] = per_class[1:].mean().item()
        for class_id in range(3):
            results[f"{name}/class_{class_id}"] = per_class[class_id].item()

    return results


The `Trainer` does not report training-set metrics such as IoU by default. To add them, the class `TrainingMetricsLoggingCallback` below subclasses `TrainerCallback` from the Hugging Face `transformers` library, which allows custom code (here: computing and logging the training IoU) to run at defined points during training.


sourcecode: https://huggingface.co/transformers/v4.6.0/_modules/transformers/trainer_callback.html#TrainerCallback.on_epoch_end, https://huggingface.co/docs/transformers/main_classes/callback

In [None]:
#TrainingMetricsLoggingCallback is used to customize the Trainer callback used to train and evaluate the model.
class TrainingMetricsLoggingCallback(TrainerCallback):
    """Trainer callback that logs the IoU on the training data to wandb after every epoch.

    Args:
        dataset: Dataset to measure the IoU on. When omitted, the module-level
            ``train_dataset`` is looked up at the end of each epoch, which keeps
            ``TrainingMetricsLoggingCallback()`` working as before.
    """

    def __init__(self, dataset=None):
        self.dataset = dataset

    #During training this method is called at the end of every epoch
    def on_epoch_end(self, args, state: TrainerState, control: TrainerControl, model, **kwargs):
        """Run the model over the training data without gradients and log ``train_iou``."""
        dataset = self.dataset if self.dataset is not None else train_dataset

        #The inputs and the metric must live on the same device as the model weights.
        device = next(model.parameters()).device

        #Metric that should be added to what the Trainer reports; see
        #https://lightning.ai/docs/torchmetrics/stable/classification/jaccard_index.html
        jaccard_index = JaccardIndex(num_classes=3, task="multiclass", ignore_index=0).to(device)
        #DataLoader is from Pytorch library and is used to load the data in batches.
        data_loader = torch.utils.data.DataLoader(dataset, batch_size=args.per_device_eval_batch_size, collate_fn=CustomDataCollator())

        #Switch to eval mode for the measurement and restore the previous mode afterwards,
        #even if the loop raises.
        was_training = model.training
        model.eval()
        try:
            #no_grad: the weights are not updated here, so no gradients are needed.
            with torch.no_grad():
                for batch in data_loader:
                    #batch is a dictionary with the keys 'pixel_values' and 'labels'.
                    batch = {k: v.to(device) for k, v in batch.items()}
                    #only the logits are needed, so the labels are not passed to the model.
                    logits = model(pixel_values=batch['pixel_values']).logits
                    #most likely class per pixel: argmax along the class dimension
                    preds = logits.argmax(dim=1)

                    #resize the predictions to the label resolution: add a channel dimension,
                    #interpolate with nearest neighbour, remove the channel dimension again
                    #and cast back to integer class ids.
                    preds = interpolate(preds.unsqueeze(1).float(), size=batch['labels'].shape[-2:], mode='nearest')
                    preds = preds.squeeze(1).long()
                    jaccard_index.update(preds, batch['labels'])
        finally:
            model.train(was_training)

        #compute() returns a tensor; .item() converts it to a Python scalar that can be logged.
        train_iou = jaccard_index.compute().mean().item()
        #No explicit `step` is passed: state.epoch is a float and is not on the same scale as the
        #step counter wandb keeps for this run. The training position is logged as ordinary
        #values instead; https://docs.wandb.ai/guides/integrations/huggingface
        wandb.log({
            'train_iou': train_iou,
            'train/global_step': state.global_step,
            'train/epoch': state.epoch,
        })




All functions and callbacks needed for training and inference are now defined. Next, the dataset needs to be loaded and split.

In [None]:
#Using pathlib library to build the paths of the image and mask directories.
image_dir = Path('/content/gdrive/MyDrive/speciale/export1/assets')
label_dir = Path('/content/gdrive/MyDrive/speciale/export1/labels')
#Both file lists are sorted so that the image and the mask at the same position form a pair.
image_paths = sorted(glob.glob(str(image_dir / '*.jpg')))
label_paths = sorted(glob.glob(str(label_dir / '*.png')))

#Fail early with a clear message instead of training on an empty or misaligned dataset.
if not image_paths:
    raise FileNotFoundError(f"No .jpg images found in {image_dir}")
if len(image_paths) != len(label_paths):
    raise ValueError(
        f"Found {len(image_paths)} images but {len(label_paths)} masks; "
        "pairing by sorted position requires the same number of files."
    )

#class CustomDataset is used here to create the dataset from the two path lists.
dataset = CustomDataset(image_paths, label_paths)
#the training set is 80% of the overall dataset.
train_size = int(0.8 * len(dataset))

# the rest will be for the test dataset, which is 20%.
# A seeded generator makes the split identical on every run, so results are comparable.
split_generator = torch.Generator().manual_seed(42)
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, len(dataset) - train_size],
    generator=split_generator,
)


id2label and label2id are used to map each class id to its label name and back. Please note that id2label is purely for the sake of being able to visualize the predictions of the masks.

sourcecode: https://github.com/huggingface/notebooks/blob/main/examples/semantic_segmentation.ipynb

In [None]:
#class names, listed in the order of their class ids (0, 1, 2).
id2label = dict(enumerate(('BACKGROUND', 'LOW_1_3', 'HIGH_4_5')))

#reverse lookup: class name -> class id
label2id = {name: class_id for class_id, name in id2label.items()}



All results are logged to wandb which is a platform that can generate plots for the specified metrics. source: https://docs.wandb.ai/guides/integrations/huggingface

In [None]:
# Authenticate with Weights & Biases, then start the run in project "Pytorch"
# under the entity "mitth".
wandb.login()
wandb.init(project="Pytorch", entity="mitth") # after login and init, the run's dashboard becomes available

Using the pretrained models from the Hugging Face model hub:
SegFormer B0: https://huggingface.co/nvidia/segformer-b0-finetuned-cityscapes-768-768
SegFormer B3: https://huggingface.co/nvidia/segformer-b3-finetuned-cityscapes-1024-1024
SegFormer B5: https://huggingface.co/nvidia/segformer-b5-finetuned-cityscapes-1024-1024

In [None]:
#using the following github repository as pipeline for training the model: https://github.com/huggingface/notebooks/blob/main/examples/semantic_segmentation.ipynb
#defining the model name from the huggingface model hub.
pretrained_model_name = "nvidia/segformer-b5-finetuned-cityscapes-1024-1024"
#loading the model configuration for SegFormer. The number of labels and the id <-> name
#mappings come from id2label/label2id, so the model reports the class names defined above.
config = SegformerConfig.from_pretrained(
    pretrained_model_name,
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
)
#loading the pretrained model with this configuration. The checkpoint was fine-tuned on
#Cityscapes, so its classification head does not have 3 outputs. ignore_mismatched_sizes=True
#lets the mismatching head be re-initialised instead of making loading fail.
model = SegformerForSemanticSegmentation.from_pretrained(
    pretrained_model_name,
    config=config,
    ignore_mismatched_sizes=True,
)


#TrainingArguments is a class of the transformers library and is used to define the training arguments.
#sourcecode: https://huggingface.co/docs/transformers/v4.41.2/en/main_classes/trainer#transformers.TrainingArguments
training_args = TrainingArguments(
    #directory for checkpoints; set explicitly because older transformers releases require it.
    output_dir="trainer_output",
    #evaluation is run at the end of each epoch.
    #NOTE(review): newer transformers releases call this argument `eval_strategy` - confirm against the installed version.
    evaluation_strategy="epoch",
    #directory for logging metrics.
    logging_dir="logs",
    #All logging is done at the end of each epoch
    logging_strategy="epoch",
    #batch size for training set
    per_device_train_batch_size=23,
    #batch size for inference/test set
    per_device_eval_batch_size=23,
    #setting the number of epochs for training
    num_train_epochs=150,
    #setting (maximum due to linear scheduling) learning rate
    learning_rate=0.0005,
    #L2-regularization constant
    weight_decay=0.01,
    #ensuring the results are logged to wandb.
    report_to="wandb",
    run_name="segformer-training-run",
    #proportion of warmup phase; set because earlystopping kicked in too early and it was impossible to tell if the model was learning or not.
    warmup_ratio=0.1,
    #a checkpoint is saved at the end of each epoch (must match the evaluation strategy for load_best_model_at_end)
    save_strategy="epoch",
    #the best checkpoint is reloaded when training finishes.
    load_best_model_at_end=True,
    #metric used to pick the best checkpoint and watched by early stopping; lower is better.
    metric_for_best_model="eval_loss",
    greater_is_better=False,
)

#Earlystopping callback is used from transformers library. Sourcecode: https://towardsdatascience.com/fine-tuning-pretrained-nlp-models-with-huggingfaces-trainer-6326a4456e7b
#Training stops once the monitored metric has not improved by at least 0.001 for 15 evaluations in a row.
early_stopping_callback = EarlyStoppingCallback(
    early_stopping_patience=15,
    early_stopping_threshold=0.001
)

#Here the Trainer from the transformers library is set up with the model, training_args, train_dataset,
#test_dataset, the evaluation metrics (compute_metrics) and the callbacks. The early stopping
#callback has to be passed in here, otherwise it is never invoked.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    #batches are assembled by the collator defined earlier in this notebook.
    data_collator=CustomDataCollator(),
    compute_metrics=compute_metrics,
    callbacks=[TrainingMetricsLoggingCallback(), early_stopping_callback]
)


#Starting the training process.
trainer.train()