# Segformer

Research Paper: https://arxiv.org/abs/2105.15203

Datasets: https://data.mendeley.com/datasets/8gf9vpkhgy/2

Implementation adapted from:
1. https://github.com/NVlabs/SegFormer
2. https://debuggercafe.com/road-segmentation-using-segformer/
3. https://www.kaggle.com/code/andrewkettle/pytorch-segformer-and-sam-on-kindey-1
4. https://medium.com/geekculture/semantic-segmentation-with-segformer-2501543d2be4

In [1]:
import os
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
from transformers import SegformerForSemanticSegmentation, SegformerFeatureExtractor, SegformerConfig
import torch.optim as optim
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, f1_score, recall_score, jaccard_score
from tabulate import tabulate

  from .autonotebook import tqdm as notebook_tqdm


## Section 1: Datasets Processing

In [2]:
class Process_Datasets(Dataset):
    def __init__(self, root_dir, feature_extractor, transforms=None):
        self.root_dir = root_dir
        self.feature_extractor = feature_extractor
        self.transforms = transforms

        self.image_path = os.path.join(self.root_dir, "img")
        self.mask_path = os.path.join(self.root_dir, "mask")

        img_files = []
        for root, dirs, files in os.walk(self.image_path):
            img_files.extend(files)
        self.images = sorted(img_files)

        mask_files = []
        for root, dirs, files in os.walk(self.mask_path):
            mask_files.extend(files)
        self.masks = sorted(mask_files)

    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, index):
        image_path = os.path.join(self.image_path, self.images[index])
        mask_path = os.path.join(self.mask_path, self.masks[index])

        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)

        if self.transforms:
            transform = self.transforms(image=image, mask=mask)
            encoded = self.feature_extractor(transform['image'], transform['mask'], return_tensors="pt")
        else:
            encoded = self.feature_extractor(image, mask, return_tensors="pt")

        for k,v in encoded.items():
            encoded[k].squeeze_()

        return encoded

In [3]:
def load_datasets(root_dir):
    batch_size=4
    feature_extractor = SegformerFeatureExtractor(align=False, reduce_zero_label=False)    
    transform = A.Compose([
        A.Resize(128, 128),
        ToTensorV2()
    ], is_check_shapes=False)

    dataset = Process_Datasets(root_dir=root_dir, feature_extractor=feature_extractor, transforms=transform)
    train, val = train_test_split(dataset, test_size=0.1, random_state=5)
    val, test = train_test_split(val, test_size=0.1, random_state=5)

    train_dataset = DataLoader(train, batch_size=batch_size, shuffle=True)
    val_dataset = DataLoader(val, batch_size=batch_size, shuffle=True)
    test_dataset = DataLoader(test, shuffle=True)

    return train_dataset, val_dataset, test_dataset

### Part 1: Darwin Dataset

In [4]:
darwin_train, darwin_val, darwin_test = load_datasets(root_dir="./Datasets/Darwin")
len(darwin_train), len(darwin_val), len(darwin_test)

  return func(*args, **kwargs)


(1374, 138, 62)

### Part 2: Shenzhen Dataset

In [5]:
shenzhen_train, shenzhen_val, shenzhen_test = load_datasets(root_dir="./Datasets/Shenzhen")
len(shenzhen_train), len(shenzhen_val), len(shenzhen_test)

(128, 13, 6)

### Part 3: Covid-19 Dataset

In [6]:
covid_train, covid_val, covid_test = load_datasets(root_dir="./Datasets/COVID-19/COVID")
len(covid_train), len(covid_test)

(814, 37)

## Section 2: Model Implementation

In [7]:
def train_model(train_data, val_data):
    epochs = 10
    learning_rate = 0.0025

    config = SegformerConfig(num_labels=1)
    model = SegformerForSemanticSegmentation.from_pretrained('nvidia/mit-b0', config=config)

    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Train network
    for ep in range(epochs):
        train_loss = []
        val_loss = []

        model.train()
        for index, batch in enumerate(tqdm(train_data)):
            pixel_values = batch["pixel_values"]
            labels = batch["labels"]
            optimizer.zero_grad()
            
            outputs = model(pixel_values=pixel_values, labels=labels)
            loss = outputs.loss

            train_loss.append(loss.item())
            optimizer.step()

        model.eval()
        with torch.no_grad():
            for index, batch in enumerate(tqdm(val_data)):
                pixel_values = batch["pixel_values"]
                labels = batch["labels"]
                optimizer.zero_grad()
                
                outputs = model(pixel_values=pixel_values, labels=labels)
                loss = outputs.loss

                val_loss.append(loss.item())
                optimizer.step()

        print(f"Epoch [{ep+1}/{epochs}]. Training Loss [{np.mean(train_loss)}]. Validation Loss [{np.mean(val_loss)}]")

    return model

### Part 1: Darwin Dataset

In [8]:
darwin_model = train_model(darwin_train, darwin_val)

Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b0 and are newly initialized: ['decode_head.batch_norm.bias', 'decode_head.batch_norm.num_batches_tracked', 'decode_head.batch_norm.running_mean', 'decode_head.batch_norm.running_var', 'decode_head.batch_norm.weight', 'decode_head.classifier.bias', 'decode_head.classifier.weight', 'decode_head.linear_c.0.proj.bias', 'decode_head.linear_c.0.proj.weight', 'decode_head.linear_c.1.proj.bias', 'decode_head.linear_c.1.proj.weight', 'decode_head.linear_c.2.proj.bias', 'decode_head.linear_c.2.proj.weight', 'decode_head.linear_c.3.proj.bias', 'decode_head.linear_c.3.proj.weight', 'decode_head.linear_fuse.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
100%|██████████| 1374/1374 [06:39<00:00,  3.44it/s]
100%|██████████| 138/138 [00:34<00:00,  3.96it/s]


Epoch [1/10]. Training Loss [0.43524463944885405]. Validation Loss [0.4398825751698535]


100%|██████████| 1374/1374 [06:32<00:00,  3.50it/s]
100%|██████████| 138/138 [00:34<00:00,  3.96it/s]


Epoch [2/10]. Training Loss [0.43424851978753504]. Validation Loss [0.43674293311609735]


100%|██████████| 1374/1374 [06:33<00:00,  3.50it/s]
100%|██████████| 138/138 [00:34<00:00,  3.97it/s]


Epoch [3/10]. Training Loss [0.4343823699234528]. Validation Loss [0.4375596547472304]


100%|██████████| 1374/1374 [06:33<00:00,  3.49it/s]
100%|██████████| 138/138 [00:34<00:00,  3.95it/s]


Epoch [4/10]. Training Loss [0.4343797474600689]. Validation Loss [0.4378686169351357]


100%|██████████| 1374/1374 [06:33<00:00,  3.50it/s]
100%|██████████| 138/138 [00:34<00:00,  3.95it/s]


Epoch [5/10]. Training Loss [0.4346866974664047]. Validation Loss [0.44085417169591656]


100%|██████████| 1374/1374 [06:33<00:00,  3.49it/s]
100%|██████████| 138/138 [00:34<00:00,  3.95it/s]


Epoch [6/10]. Training Loss [0.4342090359399586]. Validation Loss [0.4405207497918088]


100%|██████████| 1374/1374 [06:33<00:00,  3.49it/s]
100%|██████████| 138/138 [00:34<00:00,  3.95it/s]


Epoch [7/10]. Training Loss [0.43555243776398106]. Validation Loss [0.43768960107927735]


100%|██████████| 1374/1374 [06:33<00:00,  3.50it/s]
100%|██████████| 138/138 [00:34<00:00,  3.96it/s]


Epoch [8/10]. Training Loss [0.43387104868563503]. Validation Loss [0.43755757074425183]


100%|██████████| 1374/1374 [06:32<00:00,  3.50it/s]
100%|██████████| 138/138 [00:34<00:00,  3.95it/s]


Epoch [9/10]. Training Loss [0.43431192622391246]. Validation Loss [0.43714567371036694]


100%|██████████| 1374/1374 [06:33<00:00,  3.49it/s]
100%|██████████| 138/138 [00:34<00:00,  3.95it/s]

Epoch [10/10]. Training Loss [0.4347247025743669]. Validation Loss [0.4380907317002614]





### Part 2: Shenzhen Dataset

In [9]:
shenzhen_model = train_model(shenzhen_train, shenzhen_val)

Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b0 and are newly initialized: ['decode_head.batch_norm.bias', 'decode_head.batch_norm.num_batches_tracked', 'decode_head.batch_norm.running_mean', 'decode_head.batch_norm.running_var', 'decode_head.batch_norm.weight', 'decode_head.classifier.bias', 'decode_head.classifier.weight', 'decode_head.linear_c.0.proj.bias', 'decode_head.linear_c.0.proj.weight', 'decode_head.linear_c.1.proj.bias', 'decode_head.linear_c.1.proj.weight', 'decode_head.linear_c.2.proj.bias', 'decode_head.linear_c.2.proj.weight', 'decode_head.linear_c.3.proj.bias', 'decode_head.linear_c.3.proj.weight', 'decode_head.linear_fuse.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
100%|██████████| 128/128 [00:36<00:00,  3.49it/s]
100%|██████████| 13/13 [00:03<00:00,  4.03it/s]


Epoch [1/10]. Training Loss [0.5627518070396036]. Validation Loss [0.558212005175077]


100%|██████████| 128/128 [00:36<00:00,  3.52it/s]
100%|██████████| 13/13 [00:03<00:00,  4.03it/s]


Epoch [2/10]. Training Loss [0.5629258088301867]. Validation Loss [0.5568674848629878]


100%|██████████| 128/128 [00:36<00:00,  3.53it/s]
100%|██████████| 13/13 [00:03<00:00,  4.02it/s]


Epoch [3/10]. Training Loss [0.5634880259167403]. Validation Loss [0.5636127522358527]


100%|██████████| 128/128 [00:36<00:00,  3.52it/s]
100%|██████████| 13/13 [00:03<00:00,  4.02it/s]


Epoch [4/10]. Training Loss [0.5640577164012939]. Validation Loss [0.5646736117509695]


100%|██████████| 128/128 [00:36<00:00,  3.52it/s]
100%|██████████| 13/13 [00:03<00:00,  4.02it/s]


Epoch [5/10]. Training Loss [0.5626006373204291]. Validation Loss [0.5647830917285039]


100%|██████████| 128/128 [00:36<00:00,  3.52it/s]
100%|██████████| 13/13 [00:03<00:00,  4.01it/s]


Epoch [6/10]. Training Loss [0.563137001125142]. Validation Loss [0.560822090277305]


100%|██████████| 128/128 [00:36<00:00,  3.53it/s]
100%|██████████| 13/13 [00:03<00:00,  4.04it/s]


Epoch [7/10]. Training Loss [0.5628131642006338]. Validation Loss [0.565063999249385]


100%|██████████| 128/128 [00:36<00:00,  3.52it/s]
100%|██████████| 13/13 [00:03<00:00,  4.01it/s]


Epoch [8/10]. Training Loss [0.562933347420767]. Validation Loss [0.5645095751835749]


100%|██████████| 128/128 [00:36<00:00,  3.52it/s]
100%|██████████| 13/13 [00:03<00:00,  3.99it/s]


Epoch [9/10]. Training Loss [0.5627028229646385]. Validation Loss [0.5661531274135296]


100%|██████████| 128/128 [00:36<00:00,  3.52it/s]
100%|██████████| 13/13 [00:03<00:00,  4.00it/s]

Epoch [10/10]. Training Loss [0.5627696823794395]. Validation Loss [0.5640419813302847]





### Part 3: Covid-19 Dataset

In [10]:
covid_model = train_model(covid_train, covid_val)

Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b0 and are newly initialized: ['decode_head.batch_norm.bias', 'decode_head.batch_norm.num_batches_tracked', 'decode_head.batch_norm.running_mean', 'decode_head.batch_norm.running_var', 'decode_head.batch_norm.weight', 'decode_head.classifier.bias', 'decode_head.classifier.weight', 'decode_head.linear_c.0.proj.bias', 'decode_head.linear_c.0.proj.weight', 'decode_head.linear_c.1.proj.bias', 'decode_head.linear_c.1.proj.weight', 'decode_head.linear_c.2.proj.bias', 'decode_head.linear_c.2.proj.weight', 'decode_head.linear_c.3.proj.bias', 'decode_head.linear_c.3.proj.weight', 'decode_head.linear_fuse.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
100%|██████████| 814/814 [03:52<00:00,  3.50it/s]
100%|██████████| 82/82 [00:20<00:00,  3.98it/s]


Epoch [1/10]. Training Loss [0.5540234648258738]. Validation Loss [0.5446027161144629]


100%|██████████| 814/814 [03:52<00:00,  3.51it/s]
100%|██████████| 82/82 [00:20<00:00,  3.96it/s]


Epoch [2/10]. Training Loss [0.5542196601409584]. Validation Loss [0.5470831466884147]


100%|██████████| 814/814 [03:53<00:00,  3.49it/s]
100%|██████████| 82/82 [00:20<00:00,  3.95it/s]


Epoch [3/10]. Training Loss [0.5541503212302558]. Validation Loss [0.5486820943471862]


100%|██████████| 814/814 [03:52<00:00,  3.50it/s]
100%|██████████| 82/82 [00:20<00:00,  3.95it/s]


Epoch [4/10]. Training Loss [0.553922365170146]. Validation Loss [0.5494555641965169]


100%|██████████| 814/814 [03:52<00:00,  3.50it/s]
100%|██████████| 82/82 [00:20<00:00,  3.96it/s]


Epoch [5/10]. Training Loss [0.5542508487472956]. Validation Loss [0.5428105360124169]


100%|██████████| 814/814 [03:52<00:00,  3.50it/s]
100%|██████████| 82/82 [00:20<00:00,  3.95it/s]


Epoch [6/10]. Training Loss [0.5541196159780465]. Validation Loss [0.5514216495723259]


100%|██████████| 814/814 [03:53<00:00,  3.49it/s]
100%|██████████| 82/82 [00:20<00:00,  3.95it/s]


Epoch [7/10]. Training Loss [0.5540672229461061]. Validation Loss [0.5518807092817818]


100%|██████████| 814/814 [03:53<00:00,  3.49it/s]
100%|██████████| 82/82 [00:20<00:00,  3.95it/s]


Epoch [8/10]. Training Loss [0.55410531503709]. Validation Loss [0.549522464231747]


100%|██████████| 814/814 [03:52<00:00,  3.49it/s]
100%|██████████| 82/82 [00:20<00:00,  3.96it/s]


Epoch [9/10]. Training Loss [0.5542291139941251]. Validation Loss [0.5509657812554661]


100%|██████████| 814/814 [03:52<00:00,  3.50it/s]
100%|██████████| 82/82 [00:20<00:00,  3.95it/s]

Epoch [10/10]. Training Loss [0.5541447640490473]. Validation Loss [0.5512824116683588]





## Section 3: Model Evaluation

In [11]:
def warn(*args, **kwargs):
    pass
import warnings
warnings.warn = warn

def evaluate_model(model, test_data):
    batch = next(iter(test_data))
    ious, accuracies, recalls, f1s = [], [], [], []
    
    with torch.no_grad():
        for index, batch in enumerate(tqdm(test_data)):
            image = batch["pixel_values"]
            mask = batch["labels"]

            outputs = model(pixel_values=image, labels=mask)
            logits = F.interpolate(outputs.logits, size=mask.shape[-2:], mode="bilinear", align_corners=False)
            prediction = logits.argmax(dim=1)

            for pred, true in zip(prediction, mask):
                pred_mask = pred.cpu().numpy()
                true_mask = true.cpu().numpy()

                iou = jaccard_score(true_mask.flatten(), pred_mask.flatten(), average='weighted')
                accuracy = accuracy_score(true_mask.flatten(), pred_mask.flatten())
                recall = recall_score(true_mask.flatten(), pred_mask.flatten(), average='weighted')
                f1 = f1_score(true_mask.flatten(), pred_mask.flatten(), average='weighted')

                ious.append(iou)
                accuracies.append(accuracy)
                recalls.append(recall)
                f1s.append(f1)

    mean_iou = np.mean(ious)
    mean_accuracy = np.mean(accuracies)
    mean_recall = np.mean(recalls)
    mean_f1 = np.mean(f1s)

    return mean_iou, mean_accuracy, mean_recall, mean_f1

### Part 1: Darwin Dataset

In [12]:
darwin_iou, darwin_accuracy, darwin_recall, darwin_f1 = evaluate_model(darwin_model, darwin_test)

print(f"Validation Metrics: IoU: {darwin_iou}, Accuracy: {darwin_accuracy}, Recall: {darwin_recall}, F1 Score: {darwin_f1}")

100%|██████████| 62/62 [00:15<00:00,  3.93it/s]

Validation Metrics: IoU: 0.4488634690642357, Accuracy: 0.6655332503780242, Recall: 0.6655332503780242, F1 Score: 0.5344151998205954





### Part 2: Shenzhen Dataset

In [13]:
shenzhen_iou, shenzhen_accuracy, shenzhen_recall, shenzhen_f1 = evaluate_model(shenzhen_model, shenzhen_test)

print(f"Validation Metrics: IoU: {shenzhen_iou}, Accuracy: {shenzhen_accuracy}, Recall: {shenzhen_recall}, F1 Score: {shenzhen_f1}")

100%|██████████| 6/6 [00:01<00:00,  3.89it/s]

Validation Metrics: IoU: 0.5149720447758833, Accuracy: 0.7153523763020834, Recall: 0.7153523763020834, F1 Score: 0.5979566650943469





### Part 3: Covid-19 Dataset

In [14]:
covid_iou, covid_accuracy, covid_recall, covid_f1 = evaluate_model(covid_model, covid_test)

print(f"Validation Metrics: IoU: {covid_iou}, Accuracy: {covid_accuracy}, Recall: {covid_recall}, F1 Score: {covid_f1}")

100%|██████████| 37/37 [00:10<00:00,  3.58it/s]

Validation Metrics: IoU: 0.561079980855858, Accuracy: 0.7448812948690878, Recall: 0.7448812948690878, F1 Score: 0.6384047276710346





### Part 4: Conclusion

In [15]:
results_table = [
    ["Darwin", darwin_iou, darwin_accuracy, darwin_recall, darwin_f1],
    ["Zhenshen", shenzhen_iou, shenzhen_accuracy, shenzhen_recall, shenzhen_f1],
    ["Covid-19", covid_iou, covid_accuracy, covid_recall, covid_f1]
]

head = ["Datasets", "IoU Score", " Accuracy Score", "Recall Score", "F-1 score"]

print(tabulate(results_table, headers=head, tablefmt="grid"))

+------------+-------------+-------------------+----------------+-------------+
| Datasets   |   IoU Score |    Accuracy Score |   Recall Score |   F-1 score |
| Darwin     |    0.448863 |          0.665533 |       0.665533 |    0.534415 |
+------------+-------------+-------------------+----------------+-------------+
| Zhenshen   |    0.514972 |          0.715352 |       0.715352 |    0.597957 |
+------------+-------------+-------------------+----------------+-------------+
| Covid-19   |    0.56108  |          0.744881 |       0.744881 |    0.638405 |
+------------+-------------+-------------------+----------------+-------------+
