In [1]:
import os
import cv2
from glob import glob
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
import torch.optim as optim
import torch.nn as nn
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, f1_score, recall_score, jaccard_score, precision_score
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


## Data Processing

In [2]:
class Process_Datasets(Dataset):
    """Paired image / segmentation-mask dataset read from ``<root_dir>/img`` and ``<root_dir>/mask``.

    Images and masks are paired by position after sorting the ``.png`` file names found in
    each folder, so both folders are expected to use matching file names.  Each item is the
    output of ``image_processor`` for one (image, mask) pair with the processor's leading
    batch axis removed, so a ``DataLoader`` can stack items into ``[B, ...]`` batches.

    Args:
        root_dir: Folder containing the ``img`` and ``mask`` sub-folders.
        image_processor: Callable invoked as ``image_processor(image, mask, return_tensors="pt")``
            that returns a mapping of tensors (e.g. a ``SegformerImageProcessor``).

    Raises:
        ValueError: If the two folders contain a different number of ``.png`` files.
    """

    def __init__(self, root_dir, image_processor):
        self.root_dir = root_dir
        self.image_processor = image_processor

        self.image_path = os.path.join(self.root_dir, "img")
        self.mask_path = os.path.join(self.root_dir, "mask")

        # endswith (not a substring test) so names such as "x.png.bak" are not picked up.
        self.images = sorted(f for f in os.listdir(self.image_path) if f.endswith(".png"))
        self.masks = sorted(f for f in os.listdir(self.mask_path) if f.endswith(".png"))

        # Pairing is positional, so a count mismatch would silently misalign every later pair.
        if len(self.images) != len(self.masks):
            raise ValueError(
                f"Found {len(self.images)} images but {len(self.masks)} masks under {root_dir!r}"
            )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        image_path = os.path.join(self.image_path, self.images[index])
        mask_path = os.path.join(self.mask_path, self.masks[index])

        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if image is None:  # cv2.imread signals failure by returning None instead of raising
            raise FileNotFoundError(f"Could not read image {image_path!r}")
        # OpenCV decodes to BGR; the pretrained SegFormer weights/normalisation expect RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask {mask_path!r}")
        # NOTE(review): if masks are stored as 0/255, the value 255 equals SegFormer's default
        # semantic_loss_ignore_index, so foreground pixels would be ignored by the loss.
        # Consider mapping them to class 1 (e.g. ``(mask > 0).astype(np.uint8)``) -- confirm
        # the mask encoding first.

        encoded = self.image_processor(image, mask, return_tensors="pt")
        # The processor batches its input, so a single pair comes back with a leading axis of
        # size 1; drop it so DataLoader collation yields [B, C, H, W] / [B, H, W] tensors.
        for key in list(encoded.keys()):
            encoded[key] = encoded[key].squeeze(0)

        return encoded

In [3]:
pre_trained_model = 'nvidia/segformer-b0-finetuned-ade-512-512'

def load_datasets(root_dir, batch_size=4, image_size=128, val_size=0.2, test_size=0.01, seed=None):
    """Build train / validation / test DataLoaders for one ``img`` + ``mask`` dataset folder.

    The samples are first split into train and hold-out (``val_size`` of all samples); the
    hold-out part is then split into validation and test (``test_size`` of the hold-out).

    Args:
        root_dir: Folder with ``img`` and ``mask`` sub-folders (read by ``Process_Datasets``).
        batch_size: Batch size of the train and validation loaders (the test loader uses 1).
        image_size: Side length the processor resizes images and masks to.
        val_size: Fraction of all samples held out from training.
        test_size: Fraction of the hold-out samples used as the test set.
        seed: Optional ``random_state`` for both splits; ``None`` gives a fresh random split.

    Returns:
        ``(train_loader, val_loader, test_loader)``.
    """
    image_processor = SegformerImageProcessor.from_pretrained(pre_trained_model)
    # Keep label 0 as a real class; reducing labels would shift every label down by one and
    # turn 0 into the ignore value.
    image_processor.do_reduce_labels = False
    # ``size`` is documented as a {"height", "width"} dict rather than a bare int.
    image_processor.size = {"height": image_size, "width": image_size}

    dataset = Process_Datasets(root_dir=root_dir, image_processor=image_processor)

    # Split sample indices instead of the Dataset object: given a Dataset, sklearn indexes
    # (i.e. loads and preprocesses) every sample eagerly and keeps them all in Python lists.
    indices = np.arange(len(dataset))
    train_idx, holdout_idx = train_test_split(indices, test_size=val_size, random_state=seed)
    val_idx, test_idx = train_test_split(holdout_idx, test_size=test_size, random_state=seed)

    subset = torch.utils.data.Subset
    train_loader = DataLoader(subset(dataset, train_idx.tolist()), batch_size=batch_size, shuffle=True)
    # Evaluation order does not need to be random; a fixed order keeps runs comparable.
    val_loader = DataLoader(subset(dataset, val_idx.tolist()), batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(subset(dataset, test_idx.tolist()), batch_size=1, shuffle=False)

    return train_loader, val_loader, test_loader

In [4]:
# Build the COVID train / validation / test loaders and display how many batches each one yields.
covid_train, covid_val, covid_test = load_datasets(root_dir="./Datasets/COVID-19/COVID")
tuple(len(loader) for loader in (covid_train, covid_val, covid_test))

  return func(*args, **kwargs)


(723, 179, 8)

## Model Implementation

In [6]:
import warnings

# Silence noisy library warnings for the rest of the notebook.  Using the filter API instead
# of replacing ``warnings.warn`` keeps ``warnings.catch_warnings`` working, also covers
# warnings raised from C extensions, and can be undone with ``warnings.resetwarnings()``.
warnings.filterwarnings("ignore")

def _segmentation_metrics(true_mask, pred_mask):
    """Return ``[IoU, accuracy, precision, recall, F1]`` for one ground-truth / predicted mask pair.

    Every score except accuracy is a support-weighted average over the labels present in
    either mask.  ``zero_division=0`` gives the same value sklearn falls back to by default,
    but without emitting a warning for labels that are never predicted.
    """
    y_true = true_mask.flatten()
    y_pred = pred_mask.flatten()
    return [
        jaccard_score(y_true, y_pred, average='weighted', zero_division=0),
        accuracy_score(y_true, y_pred),
        precision_score(y_true, y_pred, average='weighted', zero_division=0),
        recall_score(y_true, y_pred, average='weighted', zero_division=0),
        f1_score(y_true, y_pred, average='weighted', zero_division=0),
    ]


def train_model(train_data, val_data, epochs=15, lr=0.0025, num_labels=None, device=None):
    """Fine-tune the pretrained SegFormer and score it on ``val_data`` after every epoch.

    Args:
        train_data: Iterable of batches with ``"pixel_values"`` and ``"labels"`` tensors.
        val_data: Same format as ``train_data``; used for validation loss and metrics only.
        epochs: Number of passes over ``train_data``.
        lr: Adam learning rate.
        num_labels: If given, the classification head is re-initialised with this many
            classes (``ignore_mismatched_sizes`` allows the shape change); ``None`` keeps the
            pretrained head.
        device: Torch device to train on; ``None`` picks CUDA when available, else CPU.

    Returns:
        ``(model, metrics)``. The model is moved back to the CPU so it can be fed CPU
        tensors directly. ``metrics`` has one row per validation sample per epoch, with
        columns ``IoU, Accuracy, Precision, Recall, F1, Epoch`` (``Epoch`` is 1-based).
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    model_kwargs = {"ignore_mismatched_sizes": True}
    if num_labels is not None:
        model_kwargs["num_labels"] = num_labels
    model = SegformerForSemanticSegmentation.from_pretrained(pre_trained_model, **model_kwargs)
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    # Named so it does not shadow the module-level ``val_metrics`` summary function.
    metric_rows = []

    for ep in range(epochs):
        train_loss = []
        val_loss = []

        model.train()
        for batch in tqdm(train_data, desc=f"train {ep + 1}/{epochs}"):
            image = batch["pixel_values"].to(device)
            mask = batch["labels"].to(device)

            optimizer.zero_grad()
            outputs = model(pixel_values=image, labels=mask)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            train_loss.append(loss.item())

        # Evaluation only: no gradient bookkeeping and no optimizer updates.
        model.eval()
        with torch.no_grad():
            for batch in tqdm(val_data, desc=f"val {ep + 1}/{epochs}"):
                image = batch["pixel_values"].to(device)
                mask = batch["labels"].to(device)

                outputs = model(pixel_values=image, labels=mask)
                val_loss.append(outputs.loss.item())

                # SegFormer predicts at a reduced resolution; upsample to the label size first.
                logits = F.interpolate(outputs.logits, size=mask.shape[-2:], mode="bilinear", align_corners=False)
                prediction = logits.argmax(dim=1)

                for pred_mask, true_mask in zip(prediction.cpu().numpy(), mask.cpu().numpy()):
                    metric_rows.append(_segmentation_metrics(true_mask, pred_mask) + [ep + 1])

        print(f"Epoch [{ep+1}/{epochs}]. Training Loss [{np.mean(train_loss)}]. Validation Loss [{np.mean(val_loss)}]")

    model.to("cpu")
    metrics = pd.DataFrame(metric_rows, columns=["IoU", "Accuracy", "Precision", "Recall", "F1", "Epoch"])
    return model, metrics

def val_metrics(metric):
    """Print the mean of each validation metric column of ``metric`` on one line.

    Args:
        metric: DataFrame with ``IoU``, ``Accuracy``, ``Precision``, ``Recall`` and ``F1``
            columns (for example the metrics frame returned by ``train_model``).
    """
    # (column in the frame, label used in the printed summary), in print order
    summary_fields = (
        ("IoU", "IoU"),
        ("Accuracy", "Accuracy"),
        ("Precision", "Precision"),
        ("Recall", "Recall"),
        ("F1", "F1 Score"),
    )
    parts = [f"{label}: {metric[column].mean()}" for column, label in summary_fields]
    print(", ".join(parts))

In [None]:
# Fine-tune SegFormer on the COVID training split, scoring it on the validation split every epoch.
covid_model, covid_metrics = train_model(train_data=covid_train, val_data=covid_val)

In [None]:
val_metrics(covid_metrics)
# DataFrame.to_csv does not create missing parent folders, so make sure ./results exists.
os.makedirs("./results", exist_ok=True)
covid_metrics.to_csv("./results/covid_segformer.csv", index=False)

## Model Evaluation

In [None]:
def image_display(model, test_data, num_images=4):
    """Plot test image, ground-truth mask and predicted mask side by side.

    Inference stops as soon as ``num_images`` samples have been plotted.

    Args:
        model: Segmentation model whose output exposes ``.logits`` (e.g. SegFormer).
        test_data: Iterable of batches with ``"pixel_values"`` and ``"labels"`` tensors.
        num_images: Maximum number of samples to plot.
    """
    # Run on whatever device the model's weights live on.
    device = next(model.parameters()).device
    model.eval()  # inference behaviour (e.g. no dropout)
    shown = 0
    with torch.no_grad():  # no autograd graph is needed just to plot predictions
        for batch in test_data:
            if shown >= num_images:
                break
            images = batch["pixel_values"]
            masks = batch["labels"]

            logits = model(pixel_values=images.to(device)).logits
            # Upsample to label resolution (as in validation) so prediction and mask align.
            logits = F.interpolate(logits, size=masks.shape[-2:], mode="bilinear", align_corners=False)
            predictions = logits.argmax(dim=1).cpu()

            for image, mask, prediction in zip(images, masks, predictions):
                if shown >= num_images:
                    break
                # pixel_values are normalised; rescale to [0, 1] so imshow does not clip them.
                display_image = image.permute(1, 2, 0)
                display_image = (display_image - display_image.min()) / (
                    display_image.max() - display_image.min() + 1e-8
                )

                fig, ax = plt.subplots(1, 3, figsize=(12, 8))
                ax[0].imshow(display_image)
                ax[1].imshow(mask)
                ax[2].imshow(prediction)

                ax[0].set_title('Test Image')
                ax[1].set_title('True Mask')
                ax[2].set_title('Predicted Mask')
                for axis in ax:
                    axis.axis("off")
                plt.show()
                shown += 1

In [None]:
# Visualise the model's predictions on a few held-out test samples.
image_display(model=covid_model, test_data=covid_test)