# Segformer

Research Paper: https://arxiv.org/abs/2105.15203

Datasets: https://data.mendeley.com/datasets/8gf9vpkhgy/2

Implementation adapted from:
1. https://github.com/NVlabs/SegFormer
2. https://debuggercafe.com/road-segmentation-using-segformer/
3. https://www.kaggle.com/code/andrewkettle/pytorch-segformer-and-sam-on-kindey-1

In [2]:
import os
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
from transformers import SegformerForSemanticSegmentation, SegformerConfig
import torch.optim as optim
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import torch.optim.lr_scheduler as lr_scheduler
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

## Section 1: Datasets Processing

In [3]:
class Load_Datasets(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = os.listdir(image_dir)

    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, index):
        image_path = os.path.join(self.image_dir, self.images[index])
        mask_path = os.path.join(self.mask_dir, self.images[index])

        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)

        if self.transform:
            transform = self.transform(image=image, mask=mask)
            image = transform['image']
            mask = transform['mask']

        image = image.float()/255.0
        mask = mask.long()
        return image, mask

In [4]:
batch_size=4

transform = A.Compose([
    A.Resize(256, 256),
    ToTensorV2()
], is_check_shapes=False)

### Part 1: Darwin Dataset

In [107]:
darwin_dataset = Load_Datasets(image_dir='./Datasets/Darwin/img', mask_dir='./Datasets/Darwin/mask', transform=transform)
train, test = train_test_split(darwin_dataset, test_size=0.1)

darwin_train = DataLoader(train, batch_size=batch_size, shuffle=True)
darwin_test = DataLoader(test, batch_size=batch_size, shuffle=True)

### Part 2: Shenzhen Dataset

In [5]:
shenzhen_dataset = Load_Datasets(image_dir='./Datasets/Shenzhen/img', mask_dir='./Datasets/Shenzhen/mask', transform=transform)
train, test = train_test_split(shenzhen_dataset, test_size=0.1)

shenzhen_train = DataLoader(train, batch_size=batch_size, shuffle=True)
shenzhen_test = DataLoader(test, batch_size=batch_size, shuffle=True)

### Part 3: Covid-19 Dataset

In [None]:
covid_dataset = Load_Datasets(image_dir='./Datasets/Covid-19/Covid/img', mask_dir='./Datasets/Covid-19/Covid/mask', transform=transform)
train, test = train_test_split(covid_dataset, test_size=0.1)

covid_train = DataLoader(train, batch_size=batch_size, shuffle=True)
covid_test = DataLoader(test, batch_size=batch_size, shuffle=True)

## Section 2: Model Implementation

In [6]:
def train_model(train_data, val_data):
    epochs = 10
    learning_rate = 0.0025

    config = SegformerConfig(num_labels=1)
    model = SegformerForSemanticSegmentation.from_pretrained('nvidia/mit-b0', config=config)

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu' # Check for device
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[20, 40], gamma=0.1)

    # Train network
    for ep in range(epochs):
        model.train()
        train_loss = []

        # Training
        for idx, (images, masks) in enumerate(tqdm(train_data)):
            # Convert vars to GPU
            images = images.float().to(device)
            masks = masks.type(torch.LongTensor).to(device)
            output = model(pixel_values=images, labels=masks)

            loss = output.loss
            train_loss.append(loss)

            optimizer.step()
            optimizer.zero_grad()

        train_loss = loss.detach().numpy()
        print(f"Epoch [{ep+1}/{epochs}]. Training Loss [{np.mean(train_loss)}]")
        scheduler.step()
        
    return model

### Part 1: Darwin Dataset

In [None]:
darwin_model = train_model(darwin_train, darwin_test)

NameError: name 'darwin_train' is not defined

### Part 2: Shenzhen Dataset

In [7]:
shenzhen_model = train_model(shenzhen_train, shenzhen_test)

Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b0 and are newly initialized: ['decode_head.batch_norm.bias', 'decode_head.batch_norm.num_batches_tracked', 'decode_head.batch_norm.running_mean', 'decode_head.batch_norm.running_var', 'decode_head.batch_norm.weight', 'decode_head.classifier.bias', 'decode_head.classifier.weight', 'decode_head.linear_c.0.proj.bias', 'decode_head.linear_c.0.proj.weight', 'decode_head.linear_c.1.proj.bias', 'decode_head.linear_c.1.proj.weight', 'decode_head.linear_c.2.proj.bias', 'decode_head.linear_c.2.proj.weight', 'decode_head.linear_c.3.proj.bias', 'decode_head.linear_c.3.proj.weight', 'decode_head.linear_fuse.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
100%|██████████| 128/128 [00:41<00:00,  3.11it/s]

Epoch [1/1]. Training Loss [0.49224919080734253]





### Part 3: Covid-19 Dataset

In [None]:
covid_model = train_model(covid_train, covid_test)

Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b0 and are newly initialized: ['decode_head.batch_norm.bias', 'decode_head.batch_norm.num_batches_tracked', 'decode_head.batch_norm.running_mean', 'decode_head.batch_norm.running_var', 'decode_head.batch_norm.weight', 'decode_head.classifier.bias', 'decode_head.classifier.weight', 'decode_head.linear_c.0.proj.bias', 'decode_head.linear_c.0.proj.weight', 'decode_head.linear_c.1.proj.bias', 'decode_head.linear_c.1.proj.weight', 'decode_head.linear_c.2.proj.bias', 'decode_head.linear_c.2.proj.weight', 'decode_head.linear_c.3.proj.bias', 'decode_head.linear_c.3.proj.weight', 'decode_head.linear_fuse.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch [1/10]


100%|██████████| 814/814 [00:51<00:00, 15.87it/s]
100%|██████████| 91/91 [00:05<00:00, 16.18it/s]


Epoch [2/10]


100%|██████████| 814/814 [00:51<00:00, 15.77it/s]
100%|██████████| 91/91 [00:05<00:00, 15.52it/s]


Epoch [3/10]


100%|██████████| 814/814 [00:53<00:00, 15.20it/s]
100%|██████████| 91/91 [00:06<00:00, 14.81it/s]


Epoch [4/10]


100%|██████████| 814/814 [00:54<00:00, 14.83it/s]
100%|██████████| 91/91 [00:05<00:00, 15.24it/s]


Epoch [5/10]


100%|██████████| 814/814 [00:54<00:00, 14.83it/s]
100%|██████████| 91/91 [00:05<00:00, 15.47it/s]


Epoch [6/10]


100%|██████████| 814/814 [00:55<00:00, 14.73it/s]
100%|██████████| 91/91 [00:06<00:00, 15.05it/s]


Epoch [7/10]


100%|██████████| 814/814 [00:53<00:00, 15.11it/s]
100%|██████████| 91/91 [00:06<00:00, 15.12it/s]


Epoch [8/10]


100%|██████████| 814/814 [00:53<00:00, 15.15it/s]
100%|██████████| 91/91 [00:06<00:00, 15.13it/s]


Epoch [9/10]


100%|██████████| 814/814 [00:52<00:00, 15.38it/s]
100%|██████████| 91/91 [00:05<00:00, 15.77it/s]


Epoch [10/10]


100%|██████████| 814/814 [00:52<00:00, 15.60it/s]
100%|██████████| 91/91 [00:05<00:00, 15.78it/s]


## Section 3: Model Evaluation

In [51]:
def evaluate_model(model, val_data):
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu' # Check for device
    model = model.to(device)

    model.eval()
    ious, accuracies, recalls, f1s = [], [], [], []
    
    with torch.no_grad():
        for images, masks in tqdm(val_data):
            images = images.float().to(device)
            masks = masks.type(torch.LongTensor).to(device)

            outputs = model(pixel_values=images)
            pred_masks = outputs.logits.argmax(dim=1)

            for pred_mask, true_mask in zip(pred_masks, masks):
                pred_mask_resized = F.interpolate(pred_mask.unsqueeze(0).unsqueeze(0).float(), size=true_mask.shape[-2:], mode='nearest').squeeze().cpu().numpy()
                true_mask_np = true_mask.cpu().numpy()

                iou = precision_score(true_mask_np.flatten(), pred_mask_resized.flatten(), average='micro')
                accuracy = accuracy_score(true_mask_np.flatten(), pred_mask_resized.flatten(), average='micro')
                recall = recall_score(true_mask_np.flatten(), pred_mask_resized.flatten(), average='micro')
                f1 = f1_score(true_mask_np.flatten(), pred_mask_resized.flatten(), average='micro')

                ious.append(iou)
                accuracies.append(accuracy)
                recalls.append(recall)
                f1s.append(f1)

    mean_iou = np.mean(ious)
    mean_accuracy = np.mean(accuracies)
    mean_recall = np.mean(recalls)
    mean_f1 = np.mean(f1s)

    return mean_iou, mean_accuracy, mean_recall, mean_f1

### Part 1: Darwin Dataset

In [None]:
mean_iou, mean_accuracy, mean_recall, mean_f1 = evaluate_model(darwin_model, darwin_test)

print(f"Validation Metrics:
        IoU: {mean_iou}, 
        Accuracy: {mean_accuracy}, 
        Recall: {mean_recall}, 
        F1 Score: {mean_f1}"
)

100%|██████████| 15/15 [00:03<00:00,  4.41it/s]

Validation Metrics - IoU: 0.7431587085389254, Precision: 0.7431587085389254, Recall: 0.7431587085389254, F1 Score: 0.7431587085389254





### Part 2: Shenzhen Dataset

In [52]:
mean_iou, mean_accuracy, mean_recall, mean_f1 = evaluate_model(shenzhen_model, shenzhen_test)

print(f"Validation Metrics:
        IoU: {mean_iou}, 
        Accuracy: {mean_accuracy}, 
        Recall: {mean_recall}, 
        F1 Score: {mean_f1}"
)

100%|██████████| 15/15 [00:03<00:00,  4.41it/s]

Validation Metrics - IoU: 0.7431587085389254, Precision: 0.7431587085389254, Recall: 0.7431587085389254, F1 Score: 0.7431587085389254





### Part 3: Covid-19 Dataset

In [None]:
mean_iou, mean_accuracy, mean_recall, mean_f1 = evaluate_model(covid_model, covid_test)

print(f"Validation Metrics:
        IoU: {mean_iou}, 
        Accuracy: {mean_accuracy}, 
        Recall: {mean_recall}, 
        F1 Score: {mean_f1}"
)

100%|██████████| 15/15 [00:03<00:00,  4.41it/s]

Validation Metrics - IoU: 0.7431587085389254, Precision: 0.7431587085389254, Recall: 0.7431587085389254, F1 Score: 0.7431587085389254



