In [1]:
import os
import shutil
import glob

import pandas as pd
import numpy as np
import seaborn as sns
import cv2
import nibabel as nib
from sklearn.metrics import confusion_matrix, accuracy_score
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, SGD, AdamW
from torch.utils.data import Dataset, TensorDataset, DataLoader, random_split
from torchmetrics.classification import BinaryJaccardIndex, Dice
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import torchvision

# %% environment and functions
import matplotlib.pyplot as plt
join = os.path.join
from segment_anything import sam_model_registry
from skimage import io, transform


In [2]:
#!wget -O medsam_vit_b.pth https://zenodo.org/records/10689643/files/medsam_vit_b.pth

# Dataset Class

In [3]:
class BraTSDataset(Dataset):    
    def __init__(self, data_root_folder, folder = '', n_sample=None):
        main_folder = os.path.join(data_root_folder, folder)
        self.folder_path = os.path.join(main_folder, 'slice')

    def __getitem__(self, index):
        file_name = os.listdir(self.folder_path)[index]
        sample = torch.from_numpy(np.load(os.path.join(self.folder_path, file_name)))
        img_as_tensor = np.expand_dims(sample[0,:,:], axis=0)
        mask_as_tensor = np.expand_dims(sample[1,:,:], axis=0)
        return {
            'image': img_as_tensor,
            'mask': mask_as_tensor,
            'img_id': file_name
        }
 
    def __len__(self):
        return len(os.listdir(self.folder_path))

# Load Dataset

In [4]:
data_root_folder = '/kaggle/input/full_raw - Copy'
train_dataset = BraTSDataset(data_root_folder = data_root_folder, folder = 'train')
val_dataset = BraTSDataset(data_root_folder = data_root_folder, folder = 'val')
test_dataset = BraTSDataset(data_root_folder = data_root_folder, folder = 'test')

In [19]:
BATCH_SIZE = 1
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [20]:
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Load MedSAM

## Self-defined functions

In [21]:
def convert_to_binary(arr):
    return np.where(arr > 0, 1, 0)

In [22]:
def visualize_masks(nii_img):
    nii_aff  = nii_img.affine
    nii_hdr  = nii_img.header

    # print(nii_aff ,'\n',nii_hdr)
    # print(nii_data.shape)
    nii_data = nii_img.get_fdata()
    number_of_slices = nii_data.shape[1]
    number_of_frames = nii_data.shape[2]
    number_of_frames = 1

    # Define the number of columns for subplot
    num_columns = 5

    if(len(nii_data.shape)==3):
        start_slice = 80
        end_slice = 100
        num_slices = end_slice - start_slice + 1
        num_rows = num_slices // num_columns + 1

        fig, ax = plt.subplots(num_rows, num_columns, figsize=(10,10))

        for i in range(start_slice, end_slice+1):
            row = (i-start_slice) // num_columns
            col = (i-start_slice) % num_columns
            ax[row, col].imshow(nii_data[:,:,i])
            ax[row, col].axis('off')

        # Remove empty subplots
        for j in range(i-start_slice+1, num_rows*num_columns):
            fig.delaxes(ax.flatten()[j])

        plt.show()

In [23]:
def box_coordinates_from_mask(mask, margin = 5):
    # Works on [240, 240]
    # Modify for this to work on [16, 1, 240, 240]
    nonzero_indices = np.transpose(np.nonzero(mask))

    min_y = np.min(nonzero_indices[:, 0])
    max_y = np.max(nonzero_indices[:, 0])
    min_x = np.min(nonzero_indices[:, 1])
    max_x = np.max(nonzero_indices[:, 1])

    margin = 5

    x0 = min_x - margin
    y0 = min_y - margin
    x1 = max_x + margin
    y1 = max_y + margin

    # print(f"Row and column for non-zero elements with greatest x range: [{min_x}, {max_x}]")
    # print(f"Row and column for non-zero elements with greatest y range: [{min_y}, {max_y}]")
    
    box_np = np.array([[x0,y0, x1, y1]])
    box_1024 = box_np / np.array([W, H, W, H]) * 1024
    return box_1024

In [24]:
def box_coordinates_from_mask_batch(batch_mask, margin = 5):
    W = 240
    H = 240
    margin = 5
    box_1024_batch = []
    useless_image_indices = []
    
    for ind, mask in enumerate(batch_mask):
        mask = mask[0]
        nonzero_indices = torch.nonzero(mask).T
        if nonzero_indices.shape[-1] == 0:
            useless_image_indices.append(ind)
#             #box_1024 = torch.asarray([[-1, -1, -1, -1]])
        else:
            min_y = torch.min(nonzero_indices[0, :]).values
            max_y = torch.max(nonzero_indices[0, :]).values
            min_x = torch.min(nonzero_indices[1, :]).values
            max_x = torch.max(nonzero_indices[1, :]).values

            x0 = min_x - margin
            y0 = min_y - margin
            x1 = max_x + margin
            y1 = max_y + margin

            box_np = torch.asarray([[x0,y0, x1, y1]])
            box_1024 = box_np / torch.asarray([W, H, W, H]) * 1024
            box_1024_batch.append(box_1024)
    
    final_tensor = torch.cat(box_1024_batch, dim = 0)
    final_tensor = final_tensor.unsqueeze(dim = 1)
    return final_tensor, useless_image_indices

## MedSAM functions

In [25]:
# visualization functions
# source: https://github.com/facebookresearch/segment-anything/blob/main/notebooks/predictor_example.ipynb
# change color to avoid red and green
def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([251/255, 252/255, 30/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='blue', facecolor=(0,0,0,0), lw=2))

In [26]:
#%% load model and image
MedSAM_CKPT_PATH = "/kaggle/working/medsam_vit_b.pth"
device = "cuda:0"
medsam_model = sam_model_registry['vit_b'](checkpoint=MedSAM_CKPT_PATH)
medsam_model = medsam_model.to(device)
medsam_model.eval()

Sam(
  (image_encoder): ImageEncoderViT(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (blocks): ModuleList(
      (0-11): 12 x Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (proj): Linear(in_features=768, out_features=768, bias=True)
        )
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (lin1): Linear(in_features=768, out_features=3072, bias=True)
          (lin2): Linear(in_features=3072, out_features=768, bias=True)
          (act): GELU(approximate='none')
        )
      )
    )
    (neck): Sequential(
      (0): Conv2d(768, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): LayerNorm2d()
      (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (3): LayerNorm2d()
    )


In [35]:
def box_coordinates_from_mask_batch(batch_mask, margin = 5, W = 240, H = 240):
    box_1024_batch = []
    
    for ind, mask in enumerate(batch_mask):
        mask = mask[0]
        nonzero_indices = torch.nonzero(mask).T
        
        #if(nonzero_indices.shape[-1] == 0):
        min_y = torch.min(nonzero_indices[0, :])
        max_y = torch.max(nonzero_indices[0, :])
        min_x = torch.min(nonzero_indices[1, :])
        max_x = torch.max(nonzero_indices[1, :])

        x0 = min_x - margin
        y0 = min_y - margin
        x1 = max_x + margin
        y1 = max_y + margin

        box = torch.asarray([[x0,y0, x1, y1]])
        box = box / torch.asarray([W, H, W, H]) * 1024
        box_1024_batch.append(box)
    
    box_batch = torch.cat(box_1024_batch, dim = 0)
    box_batch = box_batch.unsqueeze(dim = 1)
    return box_batch

def generate_box_and_embedding(img, mask, W = 240, H = 240):

    # resize image, make it 3D
    resizer = torchvision.transforms.Resize(size = (1024,1024))
    img = resizer(img)
    img = img.repeat(1, 3, 1, 1)
    
    #mask = resizer(mask)
    # generate bounding boxes
    box_batch = box_coordinates_from_mask_batch(mask)
  
    with torch.no_grad():
        image_embedding = medsam_model.image_encoder(img) # (1, 256, 64, 64)
        
    return image_embedding, mask, box_batch

In [36]:
# Medsam inference code
@torch.no_grad()
def medsam_inference(medsam_model, img_embed, box_1024, H, W):
    box_torch = torch.as_tensor(box_1024, dtype=torch.float, device=img_embed.device)
    if len(box_torch.shape) == 2:
        box_torch = box_torch[:, None, :] # (B, 1, 4)

    sparse_embeddings, dense_embeddings = medsam_model.prompt_encoder(
        points=None,
        boxes=box_torch,
        masks=None,
    )
    
#     print(f'img_embed: {img_embed.shape} should be (B, 256, 64, 64)')
#     print(f'sparse_embeddings : {sparse_embeddings.shape} (B, 2, 256)')
#     print(f'dense_prompt_embeddings : {dense_embeddings.shape} (B, 256, 64, 64)')
    low_res_logits, _ = medsam_model.mask_decoder(
        image_embeddings=img_embed, # (B, 256, 64, 64)
        image_pe=medsam_model.prompt_encoder.get_dense_pe(), # (1, 256, 64, 64)
        sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256)
        dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64)
        multimask_output=False,
        )

    low_res_pred = torch.sigmoid(low_res_logits)  # (1, 1, 256, 256)

    low_res_pred = F.interpolate(
        low_res_pred,
        size=(H, W),
        mode="bilinear",
        align_corners=False,
    )  # (1, 1, gt.shape)
    low_res_pred = low_res_pred.squeeze().cpu().numpy()  # (256, 256)
    medsam_seg = (low_res_pred > 0.5).astype(np.uint8)
    return medsam_seg

In [None]:
valid_batch_dice = []
valid_batch_jaccard = []
dice_metric = Dice().to(device)
jaccard_index_metric = BinaryJaccardIndex().to(device)
with torch.no_grad():
    H = 240
    W = 240
    for i, batch in enumerate(tqdm(test_dataloader)):
        # get image and masks from dataloader
        imgs = batch['image'].to(device).float()
        img_masks = batch['mask'].to(device).float()
        
        indices_to_remove = []
        
        for idx, img_mask in enumerate(img_masks):
            if (torch.sum(img_mask) == 0):
                indices_to_remove.append(idx)
        
        if(len(indices_to_remove) != BATCH_SIZE):
            indices_to_select = set(range(0, BATCH_SIZE)).difference(set(indices_to_remove))
            indices_to_select = list(indices_to_select)

            imgs = imgs[indices_to_select,:,:,:]
            img_masks = img_masks[indices_to_select,:,:,:]

            image_embedding, true_mask, box = generate_box_and_embedding(imgs, img_masks)
             
            true_mask = true_mask.int()
            true_mask = true_mask[0][0]
            y_pred = medsam_inference(medsam_model, image_embedding, box, H, W)
            y_pred = torch.from_numpy(y_pred)
            #y_pred = y_pred.unsqueeze(dim = 1)
            
            y_pred = y_pred.to(device)
            true_mask = true_mask.to(device)

            batch_dice_score = dice_metric(y_pred, true_mask)
            valid_batch_dice.append(batch_dice_score)

            batch_jaccard_score = jaccard_index_metric(y_pred, true_mask)
            valid_batch_jaccard.append(batch_jaccard_score)

            print(f'DICE score: {batch_dice_score}, Jaccard score: {batch_jaccard_score}', end='\r')
    print(f'Validation DICE score: {np.mean(valid_batch_dice)}, Validation Jaccard score: {np.mean(jaccard_index_metric)}', end='\r')

  0%|          | 0/28985 [00:00<?, ?it/s]

DICE score: 0.9967882037162781, Jaccard score: 0.319852948188781743

In [None]:
train_loss, train_dice, train_jaccard, valid_loss, valid_dice, valid_jaccard = train_net(ddp_model, EPOCHS, train_dataloader, validation_dataloader, optimizer, loss_function)