In [1]:
import os
import zipfile
import platform
import warnings
import time
from glob import glob
from dataclasses import dataclass

# To filter UserWarning.
warnings.filterwarnings("ignore", category=UserWarning)

In [2]:
import cv2
import requests
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


# For data augmentation and preprocessing.
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Imports required SegFormer classes
from transformers import SegformerForSemanticSegmentation

# Importing lighting along with a built-in callback it provides.
import lightning.pytorch as pl
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint

# Importing torchmetrics modular and functional implementations.
from torchmetrics import MeanMetric
from torchmetrics.classification import MulticlassF1Score

# To print model summary.
#from torchinfo import summary

# Tensor and Keras
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from keras.models import load_model

In [3]:
# Sets the internal precision of float32 matrix multiplications.
torch.set_float32_matmul_precision('high')

# To enable determinism.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"

# To render the matplotlib figure in the notebook.
%matplotlib inline

In [4]:
ROOT_PATH = "C:/Users/Sergio/Documentos/2024-02-12_Data_Science_Spice/dats-data/2d_all"
MODEL_PATH = "C:/Users/Sergio/Documentos/2024-02-12_Data_Science_Spice/dats-data/models"
CLASS_MODEL = "resnet50v2_nn256_lr0001_relu_batch32_epoch30_v2.keras"
SEG_MODEL = "ckpt_epoch=049-vloss_val_loss=0.0000_vf1_valid_f1=0.0000.ckpt"

# Configuration Class Definition

In [5]:
# Define constants
#@dataclass(frozen=True)
#class DatasetClasConfig:
#    IMAGE_SIZE: tuple[int,int] = (288, 288) # W, H
#    MEAN: tuple = (0.136, 0.136, 0.136)
#    STD:  tuple = (0.178, 0.178, 0.178)
#    BATCH_SIZE = 32
#    IMG_DIR = os.path.join(ROOT_PATH, "valid")
#    MODEL_NAME = os.path.join(MODEL_PATH, CLASS_MODEL)
#    MODEL_NPY = MODEL_NAME + '.npy'
#    THR = 0.019938506186008453
#    
# Function for custom normalization
#def custom_normalization(image):  
#    image = image / 255.0    
#    image = (image - DatasetClasConfig.MEAN[0]) / DatasetClasConfig.STD[0]
#    return image
    
# Image preprocessing
#test_datagen = ImageDataGenerator(
#    preprocessing_function = custom_normalization)

# Define the test set
#test_generator = test_datagen.flow_from_directory(
#    DatasetClasConfig.IMG_DIR,
#    target_size = DatasetClasConfig.IMAGE_SIZE,
#    batch_size = DatasetClasConfig.BATCH_SIZE,
#    classes=['images'],
#    shuffle=False
#)

# Load the classification model
#class_model = load_model(DatasetClasConfig.MODEL_NAME)

# Predict!
#y_pred = class_model.predict(test_generator).reshape(-1)

#MODEL_NPY = DatasetClasConfig.MODEL_NAME + '.npy'
#np.save(DatasetClasConfig.MODEL_NPY, y_pred)
#y_pred = np.load(DatasetClasConfig.MODEL_NPY)
#THR = 0.1586085855960846
#THR = 0.019938506186008453 #99.5% recall
#CLASSIF_LABELS = y_pred > DatasetClasConfig.THR


# Define configuration classes

@dataclass(frozen=True)
class DatasetConfig:
    NUM_CLASSES:   int = 4 # including background.
    IMAGE_SIZE: tuple[int,int] = (288, 288) # W, H
    MEAN: tuple = (0.485, 0.456, 0.406)
    STD:  tuple = (0.229, 0.224, 0.225)
    MEAN_CLF: float = 0.136
    STD_CLF: float = 0.178
    THR: float = 0.019938506186008453
    BACKGROUND_CLS_ID: int = 0
    URL: str = r"https://www.dropbox.com/scl/fi/r0685arupp33sy31qhros/dataset_UWM_GI_Tract_train_valid.zip?rlkey=w4ga9ysfiuz8vqbbywk0rdnjw&dl=1"
    DATASET_PATH: str = os.path.join(os.getcwd(), ROOT_PATH)
    MODEL_NAME_CLF = os.path.join(MODEL_PATH, CLASS_MODEL)
    MODEL_NAME_SEG = os.path.join(MODEL_PATH, SEG_MODEL)

@dataclass(frozen=True)
class Paths:
    DATA_VALID_IMAGES: str = os.path.join(DatasetConfig.DATASET_PATH, "valid", "images", r"*.png")
    DATA_VALID_LABELS: str = os.path.join(DatasetConfig.DATASET_PATH, "valid", "masks",  r"*.png")

@dataclass
class TrainingConfig:
    BATCH_SIZE:      int = 12 # 8
    NUM_EPOCHS:      int = 1
    INIT_LR:       float = 3e-4
    NUM_WORKERS:     int = 0 if platform.system() == "Windows" else os.cpu_count()

    OPTIMIZER_NAME:  str = "AdamW"
    WEIGHT_DECAY:  float = 1e-4
    USE_SCHEDULER:  bool = True # Use learning rate scheduler?
    SCHEDULER:       str = "MultiStepLR" # Name of the scheduler to use.
    MODEL_NAME:str = "nvidia/segformer-b4-finetuned-ade-512-512"
    
#@dataclass
class InferenceConfig:
    BATCH_SIZE:  int = 10
    NUM_BATCHES: int = 2

DatasetConfig.NUM_CLASSES = 4

# Class Definition

In [6]:
# Custom Class for creating training and validation (segmentation) dataset objects.

class MedicalDataset(Dataset):
    #def __init__(self, *, image_paths, mask_paths, img_size, ds_mean, ds_std, is_train=False):
    def __init__(self, *, image_paths, mask_paths, img_size, ds_mean, ds_std, ds_mean_clf, ds_std_clf, is_train=False):
        self.image_paths = image_paths
        self.mask_paths  = mask_paths
        self.is_train    = is_train
        self.img_size    = img_size
        self.ds_mean = ds_mean
        self.ds_std = ds_std
        self.ds_mean_clf = ds_mean_clf
        self.ds_std_clf = ds_std_clf
        self.transforms  = self.setup_transforms(mean=self.ds_mean, std=self.ds_std)

    def __len__(self):
        return len(self.image_paths)

    # Normalization Function for Classification
    def normalize_classif(self, image):
        image = image / 255.0
        image = (image - self.ds_mean_clf) / self.ds_std_clf
        return image
    
    def setup_transforms(self, *, mean, std):
        transforms = []

        # Augmentation to be applied to the training set.
        if self.is_train:
            transforms.extend([
                A.HorizontalFlip(p=0.5), A.VerticalFlip(p=0.5),
                A.ShiftScaleRotate(scale_limit=0.12, rotate_limit=0.15, shift_limit=0.12, p=0.5),
                A.RandomBrightnessContrast(p=0.5),
                A.CoarseDropout(max_holes=8, max_height=self.img_size[1]//20, max_width=self.img_size[0]//20, min_holes=5, fill_value=0, mask_fill_value=0, p=0.5)
            ])

        # Preprocess transforms - Normalization and converting to PyTorch tensor format (HWC --> CHW).
        transforms.extend([
                A.Normalize(mean=mean, std=std, always_apply=True),
                ToTensorV2(always_apply=True),  # (H, W, C) --> (C, H, W)
        ])
        return A.Compose(transforms)

    def load_file(self, file_path, depth=0):
        file = cv2.imread(file_path, depth)
        if depth == cv2.IMREAD_COLOR:
            file = file[:, :, ::-1]
        return cv2.resize(file, (self.img_size), interpolation=cv2.INTER_NEAREST)

    def load_file_2(self, file_path, depth=0):
        file = cv2.imread(file_path, depth)
        if depth == cv2.IMREAD_COLOR:
            file = file[:, :, ::-1]
        return cv2.resize(file, (self.img_size), interpolation=cv2.INTER_LINEAR) #INTER_CUBIC #INTER_LANCZOS4  

    def load_file_3(self, file_path, depth=0):
        file = cv2.imread(file_path, depth)
        if depth == cv2.IMREAD_COLOR:
            file = file[:, :, ::-1]
        return cv2.resize(file, (self.img_size), interpolation=cv2.INTER_CUBIC) #INTER_LANCZOS4  

    def __getitem__(self, index):
        # Load image and mask file.
        image = self.load_file_3(self.image_paths[index], depth=cv2.IMREAD_COLOR)       
        mask  = self.load_file_3(self.mask_paths[index],  depth=cv2.IMREAD_GRAYSCALE)
        
        # Apply Preprocessing (+ Augmentations) transformations to image-mask pair
        image_clf = self.normalize_classif(self.load_file_3(self.image_paths[index], depth=cv2.IMREAD_COLOR))
        transformed = self.transforms(image=image, mask=mask)
        image, mask = transformed["image"], transformed["mask"].to(torch.long)        
        return image_clf, image, mask

In [7]:
class MedicalSegmentationDataModule(pl.LightningDataModule):
    def __init__(
        self,
        num_classes=10,
        img_size=(384, 384),
        ds_mean=(0.485, 0.456, 0.406),
        ds_std=(0.229, 0.224, 0.225),
        ds_mean_clf=0.136,
        ds_std_clf=0.178,
        batch_size=12,
        num_workers=3,
        pin_memory=False,
        shuffle_validation=False,
    ):
        super().__init__()

        self.num_classes = num_classes
        self.img_size    = img_size
        self.ds_mean     = ds_mean
        self.ds_std      = ds_std
        self.ds_mean_clf = ds_mean_clf
        self.ds_std_clf  = ds_std_clf
        self.batch_size  = batch_size
        self.num_workers = num_workers
        self.pin_memory  = pin_memory
        self.shuffle_validation = shuffle_validation

    def setup(self, *args, **kwargs):
        # Create validation dataset and dataloader.
        valid_imgs = sorted(glob(f"{Paths.DATA_VALID_IMAGES}"))        
        valid_msks = sorted(glob(f"{Paths.DATA_VALID_LABELS}"))
        self.valid_ds = MedicalDataset(image_paths=valid_imgs, mask_paths=valid_msks, img_size=self.img_size, 
                                       is_train=False, ds_mean=self.ds_mean, ds_std=self.ds_std,
                                       ds_mean_clf=self.ds_mean_clf, ds_std_clf=self.ds_std_clf)

    def val_dataloader(self):
        # Create validation dataloader object.
        return DataLoader(
            self.valid_ds, batch_size=self.batch_size,  pin_memory=self.pin_memory,
            num_workers=self.num_workers, shuffle=self.shuffle_validation
        )

In [8]:
def get_model(*, model_name, num_classes):
    model = SegformerForSemanticSegmentation.from_pretrained(
        model_name,
        num_labels=num_classes,
        ignore_mismatched_sizes=True,
    )
    return model

In [9]:
class MedicalSegmentationModel(pl.LightningModule):
    def __init__(
        self,
        model_name: str,
        num_classes: int = 10,
        init_lr: float = 0.001,
        optimizer_name: str = "Adam",
        weight_decay: float = 1e-4,
        use_scheduler: bool = False,
        scheduler_name: str = "multistep_lr",
        num_epochs: int = 1,
    ):
        super().__init__()

        # Save the arguments as hyperparameters.
        self.save_hyperparameters()

        # Loading model using the function defined above.
        self.model = get_model(model_name=self.hparams.model_name, num_classes=self.hparams.num_classes)

        # Initializing the required metric objects.
        self.mean_train_loss = MeanMetric()
        self.mean_train_f1 = MulticlassF1Score(num_classes=self.hparams.num_classes, average="macro")
        self.mean_valid_loss = MeanMetric()
        self.mean_valid_f1 = MulticlassF1Score(num_classes=self.hparams.num_classes, average="macro")

    def forward(self, data):
        outputs = self.model(pixel_values=data, return_dict=True)
        upsampled_logits = F.interpolate(outputs["logits"], size=data.shape[-2:], mode="bilinear", align_corners=False)
        return upsampled_logits

# Evaluation Functions

In [17]:
# Compute Dice coefficient
def dice_coef(predictions, ground_truths, num_classes=4, dims=(1, 2), smooth=1e-8):
    """Smooth Dice coefficient"""
    print(np.amax(predictions.numpy()))
    print(np.amax(ground_truths.numpy()))
    # (batch_size, num_classes, height, width)
    ground_truth_oh = F.one_hot(ground_truths, num_classes=num_classes)
    # (batch_size, num_classes, height, width)
    prediction_norm = F.one_hot(predictions, num_classes=num_classes)

    # (batch_size, num_classes)
    intersection = (prediction_norm * ground_truth_oh).sum(dim=dims)
    # (batch_size, num_classes)
    summation = prediction_norm.sum(dim=dims) + ground_truth_oh.sum(dim=dims)
    # (batch_size, num_classes)
    dice = (2.0 * intersection + smooth) / (summation + smooth)
    dice_mean = dice.mean()

    return dice_mean

def detect_nonzero_masks(masks):
    # Check if any element is nonzero along the last two axes
    nonzero_masks_mask = np.any(masks != 0, axis=(1, 2))
    # Get the indices of the zero images
    nonzero_masks_indices = np.where(nonzero_masks_mask)[0]
    return nonzero_masks_indices

In [11]:
@torch.inference_mode()
def inference(class_model, seg_model, loader, img_size, device="cpu"):
    num_batches_to_process = InferenceConfig.NUM_BATCHES

    cont = 0
    cont_2 = 0
    score_sum = 0
    score_sum_2 = 0
    score = 1.0
    score_2 = 1.0
    score_ave = 1.0
    score_ave_2 = 1.0
    time_acc = 0

    for idx, (batch_img_clf, batch_img_seg, batch_mask) in enumerate(loader):

        start = time.perf_counter()

        # Classification predictions
        #y_pred_clf = class_model.predict(batch_img_clf, verbose=0).reshape(-1)
        #clf_labels = y_pred_clf > DatasetConfig.THR
                
        # Segmentation predictions
        predictions = seg_model(batch_img_seg.to(device))
        pred_all = predictions.argmax(dim=1).cpu().numpy()

        # Re-label to zero the masks that have been identified as Class 0 by the classifier.
        #pred_all[np.where(clf_labels == False)] = 0

        end = time.perf_counter()
        time_acc = time_acc + (end - start)

        cont = cont + 1
        score = dice_coef(torch.tensor(pred_all), batch_mask).numpy()        
        score_sum = score_sum + score
        score_ave = score_sum / cont
        
        nonzero_mask_idxs = detect_nonzero_masks(batch_mask.numpy())
        
        if len(nonzero_mask_idxs) > 0:
            cont_2 = cont_2 + 1
            score_2 = dice_coef(torch.tensor(pred_all[nonzero_mask_idxs]),batch_mask[nonzero_mask_idxs]).numpy()
            score_sum_2 = score_sum_2 + score_2
            score_ave_2 = score_sum_2 / cont_2      

        print("-----" * 5)
        print(f"ave dice_score (all masks): {np.round(score_ave,3)}")
        print(f"ave dice_score (only segmented): {np.round(score_ave_2,3)}") 

        if idx == num_batches_to_process * 20:
            break
                          
    time_per_image = time_acc / cont

    return score_ave, score_ave_2, time_per_image

# Model Loading

In [12]:
# Seed everything for reproducibility.
pl.seed_everything(42, workers=True)

model = MedicalSegmentationModel(
    model_name=TrainingConfig.MODEL_NAME,
    num_classes=DatasetConfig.NUM_CLASSES,
    init_lr=TrainingConfig.INIT_LR,
    optimizer_name=TrainingConfig.OPTIMIZER_NAME,
    weight_decay=TrainingConfig.WEIGHT_DECAY,
    use_scheduler=TrainingConfig.USE_SCHEDULER,
    scheduler_name=TrainingConfig.SCHEDULER,
    num_epochs=TrainingConfig.NUM_EPOCHS,
)

data_module = MedicalSegmentationDataModule(
    num_classes=DatasetConfig.NUM_CLASSES,
    img_size=DatasetConfig.IMAGE_SIZE,
    ds_mean=DatasetConfig.MEAN,
    ds_std=DatasetConfig.STD,
    ds_mean_clf=DatasetConfig.MEAN_CLF,
    ds_std_clf=DatasetConfig.STD_CLF,
    batch_size=TrainingConfig.BATCH_SIZE,
    num_workers=TrainingConfig.NUM_WORKERS,
    pin_memory=torch.cuda.is_available(),
)

Seed set to 42
Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/segformer-b4-finetuned-ade-512-512 and are newly initialized because the shapes did not match:
- decode_head.classifier.weight: found shape torch.Size([150, 768, 1, 1]) in the checkpoint and torch.Size([4, 768, 1, 1]) in the model instantiated
- decode_head.classifier.bias: found shape torch.Size([150]) in the checkpoint and torch.Size([4]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [13]:
# Load classification model
class_model = load_model(DatasetConfig.MODEL_NAME_CLF)

In [14]:
# Load segmentation model
seg_model = MedicalSegmentationModel.load_from_checkpoint(DatasetConfig.MODEL_NAME_SEG)

Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/segformer-b4-finetuned-ade-512-512 and are newly initialized because the shapes did not match:
- decode_head.classifier.weight: found shape torch.Size([150, 768, 1, 1]) in the checkpoint and torch.Size([4, 768, 1, 1]) in the model instantiated
- decode_head.classifier.bias: found shape torch.Size([150]) in the checkpoint and torch.Size([4]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


# Model Evaluation

In [18]:
# Get the validation dataloader.
data_module.setup()
valid_loader = data_module.val_dataloader()

# Use GPU if available.
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
model.to(DEVICE)
model.eval()

score_ave, score_ave_2, time_per_image = inference(class_model, seg_model, valid_loader, device=DEVICE, img_size=DatasetConfig.IMAGE_SIZE)

print("-----" * 5)
print(f"ave dice_score (all masks): {np.round(score_ave,3)}")
print(f"ave dice_score (only segmented): {np.round(score_ave_2,3)}") 
print(f"Segmentation time per image: {np.round(time_per_image,3)}")


1
0
-------------------------
ave dice_score (all masks): 0.75
ave dice_score (only segmented): 1.0
1
0
-------------------------
ave dice_score (all masks): 0.864
ave dice_score (only segmented): 1.0
1
1
1
1
-------------------------
ave dice_score (all masks): 0.86
ave dice_score (only segmented): 0.954
3
4


RuntimeError: Class values must be smaller than num_classes.