In [1]:
import os
import cv2
from glob import glob
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor, TrainingArguments, Trainer
import torch.optim as optim
import torch.nn as nn
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, f1_score, recall_score, jaccard_score, precision_score
import matplotlib.pyplot as plt
import evaluate

  from .autonotebook import tqdm as notebook_tqdm

INFO:datasets:PyTorch version 2.3.1+cu118 available.
INFO:datasets:TensorFlow version 2.17.0 available.


## Data Processing

In [2]:
class Process_Datasets(Dataset):
    def __init__(self, root_dir, image_processor):
        self.root_dir = root_dir
        self.image_processor = image_processor

        self.image_path = os.path.join(self.root_dir, "img")
        self.mask_path = os.path.join(self.root_dir, "mask")

        image_files = [f for f in os.listdir(self.image_path) if '.png' in f]
        mask_files = [f for f in os.listdir(self.mask_path) if '.png' in f]
        self.images = sorted(image_files)
        self.masks = sorted(mask_files)

    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, index):
        image_path = os.path.join(self.image_path, self.images[index])
        mask_path = os.path.join(self.mask_path, self.masks[index])

        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)

        encoded = self.image_processor(image, mask, return_tensors="pt")

        for k,v in encoded.items():
            encoded[k].squeeze_()

        return encoded

In [3]:
pre_trained_model = 'nvidia/segformer-b0-finetuned-ade-512-512'
image_processor = SegformerImageProcessor.from_pretrained(pre_trained_model)

def load_datasets(root_dir):
    batch_size=4
    image_processor.do_reduce_labels = False
    image_processor.size = 256

    dataset = Process_Datasets(root_dir=root_dir, image_processor=image_processor)
    train, val = train_test_split(dataset, test_size=0.2)
    val, test = train_test_split(val, test_size=0.01)

    train_dataset = DataLoader(train, batch_size=batch_size, shuffle=True)
    val_dataset = DataLoader(val, batch_size=batch_size, shuffle=True)
    test_dataset = DataLoader(test, shuffle=True)

    return train_dataset, val_dataset, test_dataset

  return func(*args, **kwargs)


In [4]:
covid_train, covid_val, covid_test = load_datasets(root_dir="./Datasets/COVID-19/COVID")
len(covid_train), len(covid_val), len(covid_test)

KeyboardInterrupt: 

## Model Implementation

In [None]:
def warn(*args, **kwargs):
    pass
import warnings
warnings.warn = warn

def val_metrics(metric):
    avg_iou = metric['IoU'].mean()
    avg_accur = metric['Accuracy'].mean()
    avg_prec = metric['Precision'].mean()
    avg_recall = metric['Recall'].mean()
    avg_f1 = metric['F1'].mean()

    print(f"IoU: {avg_iou}, Accuracy: {avg_accur}, Precision: {avg_prec}, Recall: {avg_recall}, F1 Score: {avg_f1}")

def train_model(train_data, val_data):
    epochs = 10

    model = SegformerForSemanticSegmentation.from_pretrained(pre_trained_model, ignore_mismatched_sizes=True)
    optimizer = optim.Adam(model.parameters(), lr=0.0025)
    val_metrics = []

    # Train network
    for ep in range(epochs):
        train_loss = []
        val_loss = []

        model.train()
        for index, batch in enumerate(tqdm(train_data)):
            image = batch["pixel_values"]
            mask = batch["labels"]
            optimizer.zero_grad()
            
            outputs = model(pixel_values=image, labels=mask)
            loss = outputs.loss

            train_loss.append(loss.item())
            loss.backward()
            optimizer.step()

        model.eval()
        with torch.no_grad():
            for index, batch in enumerate(tqdm(val_data)):
                image = batch["pixel_values"]
                mask = batch["labels"]
                optimizer.zero_grad()

                outputs = model(pixel_values=image, labels=mask)
                logits = F.interpolate(outputs.logits, size=mask.shape[-2:], mode="bilinear", align_corners=False)
                prediction = logits.argmax(dim=1)

                for pred, true in zip(prediction, mask):
                    pred_mask = pred.cpu().numpy()
                    true_mask = true.cpu().numpy()

                    iou = jaccard_score(true_mask.flatten(), pred_mask.flatten(), average='weighted')
                    accuracy = accuracy_score(true_mask.flatten(), pred_mask.flatten())
                    precision = precision_score(true_mask.flatten(), pred_mask.flatten(), average='weighted')
                    recall = recall_score(true_mask.flatten(), pred_mask.flatten(), average='weighted')
                    f1 = f1_score(true_mask.flatten(), pred_mask.flatten(), average='weighted')
                    val_metrics.append([iou, accuracy, precision, recall, f1])

                loss = outputs.loss
                val_loss.append(loss.item())
                optimizer.step()

        print(f"Epoch [{ep+1}/{epochs}]. Training Loss [{np.mean(train_loss)}]. Validation Loss [{np.mean(val_loss)}]")

    metrics = pd.DataFrame(val_metrics, columns=["IoU", "Accuracy", "Precision", "Recall", "F1"])
    return model, metrics

In [None]:
covid_model, covid_metrics = train_model(covid_train, covid_val)

Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b0 and are newly initialized: ['decode_head.batch_norm.bias', 'decode_head.batch_norm.num_batches_tracked', 'decode_head.batch_norm.running_mean', 'decode_head.batch_norm.running_var', 'decode_head.batch_norm.weight', 'decode_head.classifier.bias', 'decode_head.classifier.weight', 'decode_head.linear_c.0.proj.bias', 'decode_head.linear_c.0.proj.weight', 'decode_head.linear_c.1.proj.bias', 'decode_head.linear_c.1.proj.weight', 'decode_head.linear_c.2.proj.bias', 'decode_head.linear_c.2.proj.weight', 'decode_head.linear_c.3.proj.bias', 'decode_head.linear_c.3.proj.weight', 'decode_head.linear_fuse.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
100%|██████████| 723/723 [04:17<00:00,  2.81it/s]
100%|██████████| 179/179 [00:50<00:00,  3.53it/s]


Epoch [1/15]. Training Loss [0.06904020317090878]. Validation Loss [1.7864640649133953e-05]


100%|██████████| 723/723 [04:56<00:00,  2.44it/s]
100%|██████████| 179/179 [00:50<00:00,  3.52it/s]


Epoch [2/15]. Training Loss [0.0003206816734067343]. Validation Loss [2.1851783514174193e-06]


100%|██████████| 723/723 [04:26<00:00,  2.72it/s]
100%|██████████| 179/179 [00:50<00:00,  3.52it/s]


Epoch [3/15]. Training Loss [3.0376521069167593e-05]. Validation Loss [2.0332332444380365e-06]


100%|██████████| 723/723 [04:58<00:00,  2.43it/s]
100%|██████████| 179/179 [00:50<00:00,  3.52it/s]


Epoch [4/15]. Training Loss [1.2418163137407727e-05]. Validation Loss [2.3621249559947594e-07]


100%|██████████| 723/723 [04:26<00:00,  2.71it/s]
100%|██████████| 179/179 [00:51<00:00,  3.51it/s]


Epoch [5/15]. Training Loss [6.44311979492364e-06]. Validation Loss [1.434376257361253e-07]


100%|██████████| 723/723 [04:58<00:00,  2.42it/s]
100%|██████████| 179/179 [00:51<00:00,  3.51it/s]


Epoch [6/15]. Training Loss [2.0482913737084714e-06]. Validation Loss [4.873307506766425e-08]


100%|██████████| 723/723 [04:26<00:00,  2.71it/s]
100%|██████████| 179/179 [00:50<00:00,  3.52it/s]


Epoch [7/15]. Training Loss [3.316084513018147e-06]. Validation Loss [5.753516812242243e-08]


100%|██████████| 723/723 [04:58<00:00,  2.42it/s]
100%|██████████| 179/179 [00:50<00:00,  3.53it/s]


Epoch [8/15]. Training Loss [1.384770340389392e-06]. Validation Loss [1.985296380232285e-07]


100%|██████████| 723/723 [04:26<00:00,  2.72it/s]
100%|██████████| 179/179 [00:50<00:00,  3.53it/s]


Epoch [9/15]. Training Loss [3.496622730382353e-06]. Validation Loss [2.9347814209611954e-07]


100%|██████████| 723/723 [04:57<00:00,  2.43it/s]
100%|██████████| 179/179 [00:50<00:00,  3.53it/s]


Epoch [10/15]. Training Loss [1.232149088780011e-06]. Validation Loss [1.1329539378330982e-07]


100%|██████████| 723/723 [04:25<00:00,  2.72it/s]
100%|██████████| 179/179 [00:50<00:00,  3.51it/s]


Epoch [11/15]. Training Loss [1.6131192684249328e-06]. Validation Loss [4.321913653507133e-07]


100%|██████████| 723/723 [04:57<00:00,  2.43it/s]
100%|██████████| 179/179 [00:50<00:00,  3.51it/s]


Epoch [12/15]. Training Loss [2.940865325969765e-07]. Validation Loss [4.5553341575587115e-08]


100%|██████████| 723/723 [04:25<00:00,  2.72it/s]
100%|██████████| 179/179 [00:50<00:00,  3.52it/s]


Epoch [13/15]. Training Loss [1.4038858476484465e-07]. Validation Loss [4.4219313916061154e-08]


100%|██████████| 723/723 [04:57<00:00,  2.43it/s]
100%|██████████| 179/179 [00:51<00:00,  3.51it/s]


Epoch [14/15]. Training Loss [5.225623113030788e-07]. Validation Loss [4.612059563204707e-08]


100%|██████████| 723/723 [04:26<00:00,  2.71it/s]
100%|██████████| 179/179 [00:50<00:00,  3.52it/s]

Epoch [15/15]. Training Loss [5.286889762823003e-08]. Validation Loss [8.78920960970058e-09]





In [None]:
covid_metrics.to_csv("./results/covid_segformer.csv", index=False)
val_metrics(covid_metrics)

## Model Evaluation

In [None]:
def image_display(model, test_data):
    for index, batch in enumerate(tqdm(test_data)):
        image = batch["pixel_values"]
        mask = batch["labels"]

        outputs = model(image)
        prediction = torch.argmax(outputs.logits, 1)
        
        image = image.squeeze()
        fig, ax = plt.subplots(1, 3, figsize=(12, 8))
        ax[0].imshow(image.permute(1, 2, 0))
        ax[1].imshow(mask.permute(1, 2, 0))
        ax[2].imshow(prediction.permute(1, 2, 0))

        ax[0].set_title(f'Test Image')
        ax[1].set_title(f'True Mask')
        ax[2].set_title(f'Predicted Mask')

In [None]:
image_display(covid_model, covid_test)