In [1]:
import os
import zipfile
import platform
import warnings
from glob import glob
from dataclasses import dataclass

# To filter UserWarning.
warnings.filterwarnings("ignore", category=UserWarning)

In [2]:
import cv2
import requests
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


# For data augmentation and preprocessing.
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Imports required SegFormer classes
from transformers import SegformerForSemanticSegmentation

# Importing lighting along with a built-in callback it provides.
import lightning.pytorch as pl
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint

# Importing torchmetrics modular and functional implementations.
from torchmetrics import MeanMetric
from torchmetrics.classification import MulticlassF1Score

# To print model summary.
from torchinfo import summary

# Tensor and Keras
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from keras.models import load_model

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Sets the internal precision of float32 matrix multiplications.
torch.set_float32_matmul_precision('high')

# To enable determinism.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"

# To render the matplotlib figure in the notebook.
%matplotlib inline

In [4]:
ROOT_DIR = "C:/Users/ssre_/Projects/dats-data/2d_all_v2"

# Classification Model

In [5]:
@dataclass(frozen=True)

# Define constants
class DatasetClasConfig:
    IMAGE_SIZE: tuple[int,int] = (288, 288) # W, H
    MEAN: tuple = (0.136, 0.136, 0.136)
    STD:  tuple = (0.178, 0.178, 0.178)
    BATCH_SIZE = 32
    IMG_DIR = os.path.join(ROOT_DIR, "valid")
    MODEL_DIR = "C:/Users/ssre_/Projects/dats-data/models"
    
# Function for custom normalization
def custom_normalization(image):  
    image = image / 255.0    
    image = (image - DatasetClasConfig.MEAN[0]) / DatasetClasConfig.STD[0]
    return image
    
# Image preprocessing
test_datagen = ImageDataGenerator(
    preprocessing_function = custom_normalization)

# Define the test set
test_generator = test_datagen.flow_from_directory(
    DatasetClasConfig.IMG_DIR,
    target_size = DatasetClasConfig.IMAGE_SIZE,
    batch_size = DatasetClasConfig.BATCH_SIZE,
    classes=['images'],
    shuffle=False
)

# Load the classification model
#class_model_name = os.path.join(DatasetClasConfig.MODEL_DIR, "resnet50v2_nn256_lr0001_relu_batch32_epoch30_v2_2p5d.keras")
#class_model = load_model(class_model_name)

# Predict!
#y_pred = class_model.predict(test_generator).reshape(-1)

Found 6560 images belonging to 1 classes.


In [6]:
#np.save('classification_labels.npy', y_pred)
prediction = np.load('classification_labels.npy')
THR = 0.1586085855960846
CLASSIF_LABELS = prediction > THR
print(CLASSIF_LABELS[0:12])

[False False False False False False False False False False False False]


# Segmentation Model

In [7]:
@dataclass(frozen=True)
class DatasetConfig:
    NUM_CLASSES:   int = 4 # including background.
    IMAGE_SIZE: tuple[int,int] = (288, 288) # W, H
    MEAN: tuple = (0.485, 0.456, 0.406)
    STD:  tuple = (0.229, 0.224, 0.225)
    BACKGROUND_CLS_ID: int = 0
    URL: str = r"https://www.dropbox.com/scl/fi/r0685arupp33sy31qhros/dataset_UWM_GI_Tract_train_valid.zip?rlkey=w4ga9ysfiuz8vqbbywk0rdnjw&dl=1"
    DATASET_PATH: str = os.path.join(os.getcwd(), ROOT_DIR)

@dataclass(frozen=True)
class Paths:
    DATA_TRAIN_IMAGES: str = os.path.join(DatasetConfig.DATASET_PATH, "train", "images", r"*.png")
    DATA_TRAIN_LABELS: str = os.path.join(DatasetConfig.DATASET_PATH, "train", "masks",  r"*.png")
    DATA_VALID_IMAGES: str = os.path.join(DatasetConfig.DATASET_PATH, "valid", "images", r"*.png")
    DATA_VALID_LABELS: str = os.path.join(DatasetConfig.DATASET_PATH, "valid", "masks",  r"*.png")

@dataclass
class TrainingConfig:
    BATCH_SIZE:      int = 12 # 8
    NUM_EPOCHS:      int = 1
    INIT_LR:       float = 3e-4
    NUM_WORKERS:     int = 0 if platform.system() == "Windows" else os.cpu_count()

    OPTIMIZER_NAME:  str = "AdamW"
    WEIGHT_DECAY:  float = 1e-4
    USE_SCHEDULER:  bool = True # Use learning rate scheduler?
    SCHEDULER:       str = "MultiStepLR" # Name of the scheduler to use.
    MODEL_NAME:str = "nvidia/segformer-b4-finetuned-ade-512-512"


@dataclass
class InferenceConfig:
    BATCH_SIZE:  int = 10
    NUM_BATCHES: int = 2

In [8]:
# Create a mapping of class ID to RGB value.
id2color = {
    0: (0, 0, 0),    # background pixel
    1: (0, 0, 255),  # Stomach
    2: (0, 255, 0),  # Small Bowel
    3: (255, 0, 0),  # large Bowel
}

DatasetConfig.NUM_CLASSES = len(id2color)

print("Number of classes", DatasetConfig.NUM_CLASSES)

# Reverse id2color mapping.
# Used for converting RGB mask to a single channel (grayscale) representation.
rev_id2color = {value: key for key, value in id2color.items()}

Number of classes 4


In [9]:
# Custom Class for creating training and validation (segmentation) dataset objects.

class MedicalDataset(Dataset):
    #def __init__(self, *, image_paths, mask_paths, img_size, ds_mean, ds_std, is_train=False):
    def __init__(self, *, image_paths, mask_paths, classif_labels, img_size, ds_mean, ds_std, is_train=False):
        self.image_paths = image_paths
        self.mask_paths  = mask_paths
        self.classif_labels = classif_labels
        self.is_train    = is_train
        self.img_size    = img_size
        self.ds_mean = ds_mean
        self.ds_std = ds_std
        self.transforms  = self.setup_transforms(mean=self.ds_mean, std=self.ds_std)

    def __len__(self):
        return len(self.image_paths)

    def setup_transforms(self, *, mean, std):
        transforms = []

        # Augmentation to be applied to the training set.
        if self.is_train:
            transforms.extend([
                A.HorizontalFlip(p=0.5), A.VerticalFlip(p=0.5),
                A.ShiftScaleRotate(scale_limit=0.12, rotate_limit=0.15, shift_limit=0.12, p=0.5),
                A.RandomBrightnessContrast(p=0.5),
                A.CoarseDropout(max_holes=8, max_height=self.img_size[1]//20, max_width=self.img_size[0]//20, min_holes=5, fill_value=0, mask_fill_value=0, p=0.5)
            ])

        # Preprocess transforms - Normalization and converting to PyTorch tensor format (HWC --> CHW).
        transforms.extend([
                A.Normalize(mean=mean, std=std, always_apply=True),
                ToTensorV2(always_apply=True),  # (H, W, C) --> (C, H, W)
        ])
        return A.Compose(transforms)

    def load_file(self, file_path, depth=0):
        file = cv2.imread(file_path, depth)
        if depth == cv2.IMREAD_COLOR:
            file = file[:, :, ::-1]
        return cv2.resize(file, (self.img_size), interpolation=cv2.INTER_NEAREST)

    def __getitem__(self, index):
        # Load image and mask file.
        image = self.load_file(self.image_paths[index], depth=cv2.IMREAD_COLOR)
        mask  = self.load_file(self.mask_paths[index],  depth=cv2.IMREAD_GRAYSCALE)
        labels = self.classif_labels[index]
        
        # Apply Preprocessing (+ Augmentations) transformations to image-mask pair
        transformed = self.transforms(image=image, mask=mask)
        image, mask = transformed["image"], transformed["mask"].to(torch.long)
        return image, mask, labels

In [10]:
class MedicalSegmentationDataModule(pl.LightningDataModule):
    def __init__(
        self,
        num_classes=10,
        img_size=(384, 384),
        ds_mean=(0.485, 0.456, 0.406),
        ds_std=(0.229, 0.224, 0.225),
        batch_size=12,
        num_workers=3,
        pin_memory=False,
        shuffle_validation=False,
    ):
        super().__init__()

        self.num_classes = num_classes
        self.img_size    = img_size
        self.ds_mean     = ds_mean
        self.ds_std      = ds_std
        self.batch_size  = batch_size
        self.num_workers = num_workers
        self.pin_memory  = pin_memory

        self.shuffle_validation = shuffle_validation

    def prepare_data(self):
        # Download dataset.
        dataset_zip_path = f"{DatasetConfig.DATASET_PATH}.zip"

        # Download if dataset does not exists.
        if not os.path.exists(DatasetConfig.DATASET_PATH):

            print("Downloading and extracting assets...", end="")
            file = requests.get(DatasetConfig.URL)
            open(dataset_zip_path, "wb").write(file.content)

            try:
                with zipfile.ZipFile(dataset_zip_path) as z:
                    z.extractall(os.path.split(dataset_zip_path)[0]) # Unzip where downloaded.
                    print("Done")
            except:
                print("Invalid file")

            os.remove(dataset_zip_path) # Remove the ZIP file to free storage space.

    def setup(self, classif_labels=np.array([]), *args, **kwargs):
        # Create training dataset and dataloader.
        train_imgs = sorted(glob(f"{Paths.DATA_TRAIN_IMAGES}"))
        train_msks  = sorted(glob(f"{Paths.DATA_TRAIN_LABELS}"))

        # Create validation dataset and dataloader.
        valid_imgs = sorted(glob(f"{Paths.DATA_VALID_IMAGES}"))
        valid_imgs_2 = (glob(f"{Paths.DATA_VALID_IMAGES}"))
        valid_msks = sorted(glob(f"{Paths.DATA_VALID_LABELS}"))
        if valid_imgs == valid_imgs_2:
            print("OK!")
        if len(classif_labels) == 0:
            classif_labels = np.ones(len(valid_imgs))
        else:
            print(classif_labels[0:12])
       
        self.train_ds = MedicalDataset(image_paths=train_imgs, mask_paths=train_msks, classif_labels=classif_labels,
                                       img_size=self.img_size, is_train=True, ds_mean=self.ds_mean, ds_std=self.ds_std)

        self.valid_ds = MedicalDataset(image_paths=valid_imgs, mask_paths=valid_msks, classif_labels=classif_labels,
                                       img_size=self.img_size, is_train=False, ds_mean=self.ds_mean, ds_std=self.ds_std)

    def train_dataloader(self):
        # Create train dataloader object with drop_last flag set to True.
        return DataLoader(
            self.train_ds, batch_size=self.batch_size,  pin_memory=self.pin_memory,
            num_workers=self.num_workers, drop_last=True, shuffle=True
        )

    def val_dataloader(self):
        # Create validation dataloader object.
        return DataLoader(
            self.valid_ds, batch_size=self.batch_size,  pin_memory=self.pin_memory,
            num_workers=self.num_workers, shuffle=self.shuffle_validation
        )

In [11]:
#%%time

#dm = MedicalSegmentationDataModule(
#    num_classes=DatasetConfig.NUM_CLASSES,
#    img_size=DatasetConfig.IMAGE_SIZE,
#    ds_mean=DatasetConfig.MEAN,
#    ds_std=DatasetConfig.STD,
#    batch_size=InferenceConfig.BATCH_SIZE,
#    num_workers=0,
#    shuffle_validation=True,
#)

# Donwload dataset.
#dm.prepare_data()

# Create training & validation dataset.
#dm.setup()

#train_loader, valid_loader = dm.train_dataloader(), dm.val_dataloader()

In [12]:
def num_to_rgb(num_arr, color_map=id2color):
    single_layer = np.squeeze(num_arr)
    output = np.zeros(num_arr.shape[:2] + (3,))

    for k in color_map.keys():
        output[single_layer == k] = color_map[k]

    # return a floating point array in range [0.0, 1.0]
    return np.float32(output) / 255.0

In [13]:
# Function to overlay a segmentation map on top of an RGB image.
def image_overlay(image, segmented_image):
    alpha = 1.0  # Transparency for the original image.
    beta = 0.7  # Transparency for the segmentation map.
    gamma = 0.0  # Scalar added to each sum.

    segmented_image = cv2.cvtColor(segmented_image, cv2.COLOR_RGB2BGR)

    image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    image = cv2.addWeighted(image, alpha, segmented_image, beta, gamma, image)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    return np.clip(image, 0.0, 1.0)

In [14]:
def denormalize(tensors, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    for c in range(3):
        tensors[:, c, :, :].mul_(std[c]).add_(mean[c])

    return torch.clamp(tensors, min=0.0, max=1.0)

In [15]:
def get_model(*, model_name, num_classes):
    model = SegformerForSemanticSegmentation.from_pretrained(
        model_name,
        num_labels=num_classes,
        ignore_mismatched_sizes=True,
    )
    return model

In [16]:
class MedicalSegmentationModel(pl.LightningModule):
    def __init__(
        self,
        model_name: str,
        num_classes: int = 10,
        init_lr: float = 0.001,
        optimizer_name: str = "Adam",
        weight_decay: float = 1e-4,
        use_scheduler: bool = False,
        scheduler_name: str = "multistep_lr",
        num_epochs: int = 1,
    ):
        super().__init__()

        # Save the arguments as hyperparameters.
        self.save_hyperparameters()

        # Loading model using the function defined above.
        self.model = get_model(model_name=self.hparams.model_name, num_classes=self.hparams.num_classes)

        # Initializing the required metric objects.
        self.mean_train_loss = MeanMetric()
        self.mean_train_f1 = MulticlassF1Score(num_classes=self.hparams.num_classes, average="macro")
        self.mean_valid_loss = MeanMetric()
        self.mean_valid_f1 = MulticlassF1Score(num_classes=self.hparams.num_classes, average="macro")

    def forward(self, data):
        outputs = self.model(pixel_values=data, return_dict=True)
        upsampled_logits = F.interpolate(outputs["logits"], size=data.shape[-2:], mode="bilinear", align_corners=False)
        return upsampled_logits

In [17]:
# Seed everything for reproducibility.
pl.seed_everything(42, workers=True)

model = MedicalSegmentationModel(
    model_name=TrainingConfig.MODEL_NAME,
    num_classes=DatasetConfig.NUM_CLASSES,
    init_lr=TrainingConfig.INIT_LR,
    optimizer_name=TrainingConfig.OPTIMIZER_NAME,
    weight_decay=TrainingConfig.WEIGHT_DECAY,
    use_scheduler=TrainingConfig.USE_SCHEDULER,
    scheduler_name=TrainingConfig.SCHEDULER,
    num_epochs=TrainingConfig.NUM_EPOCHS,
)

data_module = MedicalSegmentationDataModule(
    num_classes=DatasetConfig.NUM_CLASSES,
    img_size=DatasetConfig.IMAGE_SIZE,
    ds_mean=DatasetConfig.MEAN,
    ds_std=DatasetConfig.STD,
    batch_size=TrainingConfig.BATCH_SIZE,
    num_workers=TrainingConfig.NUM_WORKERS,
    pin_memory=torch.cuda.is_available(),
)

Seed set to 42
Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/segformer-b4-finetuned-ade-512-512 and are newly initialized because the shapes did not match:
- decode_head.classifier.weight: found shape torch.Size([150, 768, 1, 1]) in the checkpoint and torch.Size([4, 768, 1, 1]) in the model instantiated
- decode_head.classifier.bias: found shape torch.Size([150]) in the checkpoint and torch.Size([4]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [18]:
CKPT_PATH = 'models/ckpt_epoch=049-vloss_val_loss=0.0000_vf1_valid_f1=0.0000.ckpt'

In [19]:
model = MedicalSegmentationModel.load_from_checkpoint(CKPT_PATH)

Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/segformer-b4-finetuned-ade-512-512 and are newly initialized because the shapes did not match:
- decode_head.classifier.weight: found shape torch.Size([150, 768, 1, 1]) in the checkpoint and torch.Size([4, 768, 1, 1]) in the model instantiated
- decode_head.classifier.bias: found shape torch.Size([150]) in the checkpoint and torch.Size([4]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [20]:
# Get the validation dataloader.

#data_module.setup()

data_module.setup(classif_labels=CLASSIF_LABELS)
valid_loader = data_module.val_dataloader()

OK!
[False False False False False False False False False False False False]


In [144]:
# Compute Dice coefficient
def dice_coef(predictions, ground_truths, num_classes=4, dims=(1, 2), smooth=1e-8):
    """Smooth Dice coefficient"""
    # (batch_size, num_classes, height, width)
    ground_truth_oh = F.one_hot(ground_truths, num_classes=num_classes)
    # (batch_size, num_classes, height, width)
    prediction_norm = F.one_hot(predictions, num_classes=num_classes)

    # (batch_size, num_classes)
    intersection = (prediction_norm * ground_truth_oh).sum(dim=dims)
    # (batch_size, num_classes)
    summation = prediction_norm.sum(dim=dims) + ground_truth_oh.sum(dim=dims)
    # (batch_size, num_classes)
    dice = (2.0 * intersection + smooth) / (summation + smooth)
    dice_mean = dice.mean()

    return dice_mean

def detect_nonzero_masks(masks):
    # Check if any element is nonzero along the last two axes
    nonzero_masks_mask = np.any(masks != 0, axis=(1, 2))
    # Get the indices of the zero images
    nonzero_masks_indices = np.where(nonzero_masks_mask)[0]
    return nonzero_masks_indices

In [151]:
@torch.inference_mode()
def inference(model, loader, img_size, device="cpu"):
    num_batches_to_process = InferenceConfig.NUM_BATCHES

    cont = 1
    cont2 = 1
    score_sum = 0
    score_sum2 = 0
    for idx, (batch_img, batch_mask, batch_labels) in enumerate(loader):

        #print(batch_mask.shape)
        print(batch_labels)
        # Make predictions
        predictions = model(batch_img.to(device))
        pred_all = predictions.argmax(dim=1).cpu().numpy()

        # Re-label to zero the masks that have been identified as Class 0 by the classifier.
        pred_all[np.where(batch_labels.numpy().copy() == False)] = 0
    
        # Plot results
        #batch_img = denormalize(batch_img.cpu(), mean=DatasetConfig.MEAN, std=DatasetConfig.STD)
        #batch_img = batch_img.permute(0, 2, 3, 1).numpy()

        if idx == num_batches_to_process * 20:
            break
            
        score = dice_coef(torch.tensor(pred_all),batch_mask).numpy()
        score_sum = score_sum + score
        score_ave = score_sum / cont
        cont = cont + 1
        print(np.round(score,3), np.round(score_ave,3))

        nonzero_mask_idxs = detect_nonzero_masks(batch_mask.numpy())

        if len(nonzero_mask_idxs) > 0:
            print(nonzero_mask_idxs)
            score2 = dice_coef(torch.tensor(pred_all[nonzero_mask_idxs]),batch_mask[nonzero_mask_idxs]).numpy()
            score_sum2 = score_sum2 + score2
            score_ave2 = score_sum2 / cont2
            cont2 = cont2 + 1
            print(np.round(score2,3), np.round(score_ave2,3))
           
#        for i in range(0, len(batch_img)):
#            fig = plt.figure(figsize=(20, 8))
#
#            # Display the original image.
#            ax1 = fig.add_subplot(1, 4, 1)
#            ax1.imshow(batch_img[i])
#            ax1.title.set_text("Actual frame")
#            plt.axis("off")

            # Display the ground truth mask.
#            true_mask_rgb = num_to_rgb(batch_mask[i], color_map=id2color)
#            ax2 = fig.add_subplot(1, 4, 2)
#            ax2.set_title("Ground truth labels")
#            ax2.imshow(true_mask_rgb)
#            plt.axis("off")

            # Display the predicted segmentation mask.
#            pred_mask_rgb = num_to_rgb(pred_all[i], color_map=id2color)
#            ax3 = fig.add_subplot(1, 4, 3)
#            ax3.set_title("Predicted labels")
#            ax3.imshow(pred_mask_rgb)
#            plt.axis("off")

            # Display the predicted segmentation mask overlayed on the original image.
#            overlayed_image = image_overlay(batch_img[i], pred_mask_rgb)
#            ax4 = fig.add_subplot(1, 4, 4)
#            ax4.set_title("Overlayed image")
#            ax4.imshow(overlayed_image)
#            plt.axis("off")
#            plt.show()   

In [152]:
# Use GPU if available.
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
model.to(DEVICE)
model.eval()

inference(model, valid_loader, device=DEVICE, img_size=DatasetConfig.IMAGE_SIZE)

tensor([False, False, False, False, False, False, False, False, False, False,
        False, False])
1.0 1.0
tensor([False, False, False, False, False, False, False, False, False, False,
        False, False])
1.0 1.0
tensor([False, False, False, False, False, False, False, False, False, False,
         True, False])
0.958 0.986
[11]
0.75 0.75
tensor([True, True, True, True, True, True, True, True, True, True, True, True])
0.977 0.984
[ 0  1  2  3  4  5  6  7  8  9 10 11]
0.977 0.863
tensor([True, True, True, True, True, True, True, True, True, True, True, True])
0.962 0.98
[ 0  1  2  3  4  5  6  7  8  9 10 11]
0.962 0.896
tensor([True, True, True, True, True, True, True, True, True, True, True, True])
0.933 0.972
[ 0  1  2  3  4  5  6  7  8  9 10 11]
0.933 0.906
tensor([True, True, True, True, True, True, True, True, True, True, True, True])
0.925 0.965
[ 0  1  2  3  4  5  6  7  8  9 10 11]
0.925 0.909
tensor([True, True, True, True, True, True, True, True, True, True, True, True])
0.

KeyboardInterrupt: 