In [None]:
!pip install -q segmentation_models_pytorch transformers accelerate datasets
!nvidia-smi

In [None]:
from os.path import dirname, abspath
import os
import copy
import cv2
import glob
import torch
from PIL import Image
from sklearn.model_selection import train_test_split
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import lr_scheduler
import segmentation_models_pytorch as smp
import torch.optim as optim
from tqdm import tqdm
import gc
import torch.nn.functional as F
from collections import defaultdict
import time
import matplotlib.pyplot as plt
import numpy as np
import random
import pandas as pd
import albumentations as A
from typing import Optional, List, Dict
from torch.cuda import amp
import transformers

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

### Configuration: hyper-parameters, paths and augmentation pipelines

In [None]:
class CFG:
    """Global configuration: seed, batch sizes, data locations, device and augmentation pipelines."""
    seed          = 42
    debug         = False
    saved_model_path = "/kaggle/working"
    train_bs      = 32
    valid_bs      = 32
    img_size      = [512, 512]
    # Sub-folders of `<data_dir>/train` used for training / validation.
    train_groups  = ["kidney_1_dense", "kidney_2"]
    valid_groups  = ["kidney_3_sparse"]
    scheduler     = "CosineAnnealingLR"   # NOTE(review): not referenced by the training loops in this notebook
    n_fold        = 5                     # NOTE(review): not referenced by the training loops in this notebook
    num_classes   = 1
    # Gradient-accumulation steps so that train_bs * n_accumulate is about 64 samples per optimizer step.
    n_accumulate  = max(1, 64//train_bs)
    data_dir = os.path.join(dirname(os.getcwd()), "input", "blood-vessel-segmentation")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pretrained = {
        "unet_parameters": os.path.join(dirname(os.getcwd()), "input", "saved-models/unet-epoch10-lr0.0001.pth"),
        "sam_parameters": os.path.join(dirname(os.getcwd()), "input", "saved-models/medsam_pretrained_parameters.pth"),
        "processor": os.path.join(dirname(os.getcwd()), "input", "saved-models/preprocessor_config.json"),
    }

    # Deterministic-ish resizing pipelines (train adds a horizontal flip).
    data_transforms = {
        "train": A.Compose([
            A.Resize(*img_size, interpolation=cv2.INTER_NEAREST),
            A.HorizontalFlip(p=0.5),
        ], p=1.0),
        "valid": A.Compose([
            A.Resize(*img_size, interpolation=cv2.INTER_NEAREST),
        ], p=1.0)
    }

    # Heavier augmentation used for U-Net training: geometric jitter, a 256x256 crop,
    # brightness jitter and light blur.
    data_augmentation = {
        "default": A.Compose([
            A.Resize(*img_size, interpolation=cv2.INTER_NEAREST),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.ShiftScaleRotate(scale_limit=0.5, rotate_limit=0, shift_limit=0.1, p=1, border_mode=0),
            # `always_apply` is deprecated; p=1.0 applies the crop unconditionally.
            A.RandomCrop(height=256, width=256, p=1.0),
            # `RandomBrightness` is a deprecated alias (removed in recent albumentations releases);
            # this is its exact equivalent: brightness-only jitter with the old default limit 0.2.
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.0, p=1),
            A.OneOf(
                [
                    A.Blur(blur_limit=3, p=1),
                    A.MotionBlur(blur_limit=3, p=1),
                ],
                p=0.9,
            ),
        ])
    }

    # SAM's processor works at 1024x1024, so images are resized to that size up front.
    sam_transformations = {
        "default": A.Compose([
            A.Resize(1024, 1024, interpolation=cv2.INTER_NEAREST),
        ])
    }

### Setting seed to ensure reproducibility

In [None]:
# Seed every RNG the notebook draws from: Python's `random` (used by the random
# image-selection helpers), numpy, and torch on the CPU and on all GPUs.
random.seed(CFG.seed)
np.random.seed(CFG.seed)
torch.manual_seed(CFG.seed)
torch.cuda.manual_seed_all(CFG.seed)
# Trade cuDNN autotuning speed for run-to-run determinism.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

### Helpers: folder discovery, slice/mask visualization, Dice metric, small-object removal and RLE encoding

In [None]:
def list_files_in_directory(directory_path: str) -> List[str]:
    """
    Recursively collect every sub-directory of ``directory_path`` except those named
    ``labels`` or ``images``.

    The walk still descends into ``labels``/``images`` folders; only those folder names
    themselves are left out of the result. Order follows ``os.walk``.
    """
    excluded_names = ("labels", "images")
    return [
        os.path.join(walk_root, sub_dir)
        for walk_root, sub_dirs, _files in os.walk(directory_path)
        for sub_dir in sub_dirs
        if sub_dir not in excluded_names
    ]

def count_total_img(folders: List[str]) -> Dict:
    """
    Count the files in the ``images`` and ``labels`` sub-folders of every dataset folder.

    A sub-folder that does not exist is skipped rather than raising. For example, the
    dataset's ``kidney_3_dense`` folder has no ``images`` sub-folder, and unlabeled
    volumes have no ``labels``. Each counted folder is printed as ``"<path>: <count>"``.

    Args:
        folders: dataset directories, each expected to contain ``images`` and/or ``labels``.

    Returns:
        dict with parallel lists ``"path"`` (sub-folder paths) and ``"total_files"`` (counts).
    """
    paths: List[str] = []
    total_files: List[int] = []
    for folder in folders:
        for sub_folder in ("images", "labels"):
            sub_dir = os.path.join(folder, sub_folder)
            if not os.path.isdir(sub_dir):
                continue
            n_files = len(os.listdir(sub_dir))
            print(f"{sub_dir}: {n_files}")
            paths.append(sub_dir)
            total_files.append(n_files)
    return {"path": paths, "total_files": total_files}

def display_random_img(folders: Dict):
    """
    Display a random kidney slice next to its mask.

    A random ``images``/``labels`` folder is drawn from ``folders`` (the dict returned by
    ``count_total_img``), then a random file that actually exists in it. The slice is read
    from the volume's ``images`` folder and the mask from its ``labels`` folder. The
    ``kidney_3_dense`` volume has labels only, so its slices are read from
    ``kidney_3_sparse/images``, which uses the same file name.

    Both paths are printed first. Errors while reading or plotting are printed rather than
    raised, because this is a best-effort preview.
    """
    candidates = list(zip(folders["path"], folders["total_files"]))
    if not candidates:
        print("No image/label folders to display.")
        return
    sub_dir, _ = random.choice(candidates)

    file_names = sorted(os.listdir(sub_dir))
    if not file_names:
        print(f"No files in {sub_dir}")
        return
    file_name = random.choice(file_names)
    img_no = os.path.splitext(file_name)[0]

    volume_dir = os.path.dirname(sub_dir)
    image_volume_dir = volume_dir
    if os.path.basename(volume_dir) == "kidney_3_dense":
        image_volume_dir = os.path.join(os.path.dirname(volume_dir), "kidney_3_sparse")
    IMG_PATH = os.path.join(image_volume_dir, "images", file_name)
    LABEL_PATH = os.path.join(volume_dir, "labels", file_name)

    try:
        print(IMG_PATH)
        print(LABEL_PATH)
        _slice = plt.imread(IMG_PATH)
        _mask = plt.imread(LABEL_PATH)

        plt.figure(figsize=(10, 5))
        plt.subplot(1, 2, 1)
        plt.imshow(_slice)
        print(_slice.shape)
        print(_mask.shape)
        plt.title(f'3D image slice: {img_no}')
        plt.subplot(1, 2, 2)
        plt.imshow(_mask)
        plt.title(f'Mask: {img_no}')
        plt.show()
    except Exception as e:
        print(f"An error occurred:{e}")
        
def dice_coeff(prediction, target):
    """
    Dice coefficient between ``prediction`` (binarised at >= 0.5) and ``target``.

    Computed over all elements at once; a 1e-6 term keeps the ratio finite when both
    masks are empty (the result is then 0).
    """
    eps = 1e-6
    binary_pred = np.zeros_like(prediction)
    binary_pred[prediction >= 0.5] = 1

    overlap = np.sum(binary_pred * target)
    denominator = np.sum(binary_pred) + np.sum(target) + eps
    return np.mean(2 * overlap / denominator)
        
def remove_small_objects(img, min_size):
    """
    Remove 8-connected foreground components whose area is below ``min_size`` pixels.

    Args:
        img: image passed to ``cv2.connectedComponentsWithStats`` (OpenCV expects a
            single-channel 8-bit image); any non-zero pixel counts as foreground.
        min_size: minimum component area, in pixels, to keep.

    Returns:
        Array with the same shape and dtype as ``img``: 255 on kept components, 0 elsewhere.
    """
    _, labels, stats, _ = cv2.connectedComponentsWithStats(img, connectivity=8)

    # One keep-flag per label; label 0 is the background and is never kept.
    keep = stats[:, cv2.CC_STAT_AREA] >= min_size
    keep[0] = False

    # A single vectorised lookup replaces building one full-image mask per component.
    new_img = np.zeros_like(img)
    new_img[keep[labels]] = 255
    return new_img

def rle_encode(img):
    '''
    Run-length encode a binary mask (1 = mask, 0 = background).

    Pixels are read in row-major order and positions are 1-based. The result is
    "start length start length ...", or "1 0" for an empty mask.
    '''
    padded = np.concatenate(([0], img.flatten(), [0]))
    edges = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Every second edge closes a run; turn it into that run's length.
    edges[1::2] -= edges[::2]
    encoded = " ".join(map(str, edges))
    return encoded if encoded else "1 0"

def dice_coeff_batch(predictions, targets):
    """
    Sum of per-sample Dice coefficients over the first axis of ``predictions``/``targets``.

    Each prediction is binarised at >= 0.5 before comparison, and a 1e-6 term guards the
    division. Divide the result by the batch size to get the mean.
    """
    epsilon = 1e-6

    def _sample_dice(pred, truth):
        # Binary mask with the prediction's dtype, 1 where pred >= 0.5.
        binary = np.zeros_like(pred)
        binary[pred >= 0.5] = 1
        overlap = np.sum(binary * truth)
        total = np.sum(binary) + np.sum(truth)
        return 2 * overlap / (total + epsilon)

    scores = np.fromiter(
        (_sample_dice(predictions[k], targets[k]) for k in range(predictions.shape[0])),
        dtype=np.float64,
    )
    return np.sum(scores)

In [None]:
# Discover the per-kidney volume folders (every sub-directory except images/labels leaves).
train_folders = list_files_in_directory(os.path.join(CFG.data_dir, "train"))
test_folders = list_files_in_directory(os.path.join(CFG.data_dir, "test"))
train_file_dir = count_total_img(train_folders)

# kidney_3_dense holds labels only (its slices live in kidney_3_sparse), so it is not a
# stand-alone volume. Build its path from CFG.data_dir rather than hard-coding it, and
# tolerate its absence instead of raising ValueError.
dense_only_dir = os.path.join(CFG.data_dir, "train", "kidney_3_dense")
if dense_only_dir in train_folders:
    train_folders.remove(dense_only_dir)
print(train_folders)

In [None]:
# Preview one random training slice alongside its mask.
display_random_img(folders=train_file_dir)

### Construct datasets from image data

In [None]:
class VesselDataset(Dataset):
    """
    Slice-level dataset over one kidney volume folder.

    ``data_dir`` must contain an ``images`` folder and, in ``"train"`` mode, a ``labels``
    folder. Both listings are sorted so that the i-th image is paired with the i-th mask.
    ``os.listdir`` order is arbitrary and may differ between the two folders.

    Items:
        train: ``(image, mask)``. The image is a (3, H, W) float32 tensor scaled to [0, 1]
            by its max. The mask is a float32 tensor scaled by 1/255.
        test:  ``(image, tensor([orig_H, orig_W]))``, where orig_H/orig_W are the size
            before transforms.
    """

    @classmethod
    def load_image(cls, image_path: str):
        """Read a slice unchanged, replicate it to 3 channels (HWC) and scale it to [0, 1] by its max."""
        img = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
        img = np.tile(img[...,None], [1, 1, 3])          # gray to rgb
        img = img.astype('float32')
        mx = np.max(img)
        if mx:
            img/=mx # scale image to [0, 1]
        return img

    @classmethod
    def load_mask(cls, mask_path: str):
        """Read a label image unchanged as float32 divided by 255 (assumes 0/255 labels)."""
        mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
        mask = mask.astype('float32')
        mask /= 255.0
        return mask

    def __init__(self, data_dir: str, mode: str, transforms=None):
        assert(mode == "train" or mode == "test")
        self.mode = mode
        self.image_dir = os.path.join(data_dir, "images")
        # Sorted so that image i and mask i refer to the same slice.
        self.image_files = sorted(os.listdir(self.image_dir))
        if mode == "train":
            self.mask_dir = os.path.join(data_dir, "labels")
            self.mask_files = sorted(os.listdir(self.mask_dir))
        self.transforms = transforms

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx: int):
        image_path = os.path.join(self.image_dir, self.image_files[idx])
        image = VesselDataset.load_image(image_path)

        if self.mode == "train":
            mask_path = os.path.join(self.mask_dir, self.mask_files[idx])
            mask = VesselDataset.load_mask(mask_path)
            if self.transforms:
                data = self.transforms(image=image, mask=mask)
                image, mask = data["image"], data["mask"]
            image = np.transpose(image, (2, 0, 1))  # HWC -> CHW
            return torch.tensor(image), torch.tensor(mask)
        else:
            shape = image.shape
            if self.transforms:
                image = self.transforms(image=image)["image"]
            image = np.transpose(image, (2, 0, 1))  # HWC -> CHW
            return torch.tensor(image), torch.tensor(np.array([shape[0], shape[1]]))


class DiceCELoss(nn.Module):
    """
    Soft Dice loss with an optional binary-cross-entropy term.

    ``inputs`` are used as-is (no sigmoid), so they are expected to be probabilities in
    [0, 1]. ``targets`` are binary masks with the same number of elements. Both are
    flattened, so the Dice is computed over the whole batch at once.

    Args:
        weight, size_average: accepted for signature compatibility; not used.
        bce_weight: weight of the BCE term added to the Dice loss. The default 0.0 returns
            the plain soft-Dice loss.
    """

    def __init__(self, weight=None, size_average=True, bce_weight: float = 0.0):
        super().__init__()
        self.bce_weight = bce_weight

    def forward(self, inputs, targets, smooth=1):
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1)

        # Soft Dice; `smooth` keeps the ratio defined when both sums are zero.
        intersection = (inputs * targets).sum()
        dice_loss = 1 - (2.0 * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)

        if self.bce_weight:
            # Binary (not multi-class) cross entropy, because inputs are single-channel probabilities.
            bce = F.binary_cross_entropy(inputs, targets, reduction="mean")
            return dice_loss + self.bce_weight * bce
        return dice_loss


### Unet model

In [None]:
train_groups = [os.path.join(CFG.data_dir, "train", folder) for folder in CFG.train_groups]
val_groups = [os.path.join(CFG.data_dir, "train", folder) for folder in CFG.valid_groups]


def _concat_vessel_datasets(group_dirs, transforms):
    """Build one train-mode VesselDataset per folder and chain them left-to-right with `+`."""
    per_group = [VesselDataset(group_dir, mode="train", transforms=transforms) for group_dir in group_dirs]
    # `Dataset.__add__` yields a ConcatDataset; `sum` applies it as a left fold.
    return sum(per_group[1:], per_group[0])


# Training volumes get the heavy augmentation; validation only gets resizing.
train_dataset = _concat_vessel_datasets(train_groups, CFG.data_augmentation['default'])
valid_dataset = _concat_vessel_datasets(val_groups, CFG.data_transforms['valid'])

In [None]:
def train_one_epoch(model, criterion, optimizer, dataloader, epoch):
    """
    Run one optimisation pass of ``model`` over ``dataloader``.

    Args:
        model: segmentation network returning (B, 1, H, W) predictions.
        criterion: loss taking ``(predictions (B, H, W), masks (B, H, W))``.
        optimizer: optimizer stepped once per batch.
        dataloader: yields ``(images, masks)`` batches.
        epoch: epoch index, shown in the progress bar.

    Returns:
        Sample-weighted mean training loss for the epoch (0.0 for an empty loader).
    """
    model.train()

    dataset_size = 0
    running_loss = 0.0
    epoch_loss = 0.0
    has_cuda = torch.cuda.is_available()

    pbar = tqdm(enumerate(dataloader), total=len(dataloader), desc='Train ')
    for step, (images, masks) in pbar:
        images = images.to(CFG.device, dtype=torch.float)
        masks  = masks.to(CFG.device, dtype=torch.float)
        batch_size = images.size(0)

        optimizer.zero_grad()
        masks_pred = model(images)
        # Drop only the channel axis. A bare squeeze() would also drop the batch axis when a
        # batch holds a single sample, misaligning predictions and targets.
        loss = criterion(masks_pred.squeeze(1), masks)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * batch_size
        dataset_size += batch_size
        epoch_loss = running_loss / dataset_size

        if step % 10 == 0:
            # tqdm's write() prints without corrupting the progress bar.
            pbar.write(f"Step {step}, Loss: {loss.item()}")

        # Update the postfix while the bar is live; after the loop it is closed and ignores updates.
        mem = torch.cuda.memory_reserved() / 1E9 if has_cuda else 0
        current_lr = optimizer.param_groups[0]['lr']
        pbar.set_postfix(epoch=f'{epoch}',
                         train_loss=f'{epoch_loss:0.4f}',
                         lr=f'{current_lr:0.5f}',
                         gpu_mem=f'{mem:0.2f} GB')

    torch.cuda.empty_cache()
    gc.collect()
    return epoch_loss

@torch.no_grad()
def valid_one_epoch(model, criterion, dataloader):
    model.eval()
    
    dataset_size = 0
    running_loss = 0.0
    dice_coeff = 0.0
    
    pbar = tqdm(enumerate(dataloader), total=len(dataloader), desc='Valid ')
    for step, (images, masks) in pbar:
        images  = images.to(CFG.device, dtype=torch.float)
        masks   = masks.to(CFG.device, dtype=torch.float)
        
        with torch.no_grad():
            batch_size = images.size(0)
            masks_pred = model(images)
            loss = criterion(masks_pred.squeeze(), masks)

            running_loss += (loss.item() * batch_size)
            dataset_size += batch_size
            dice_coeff += dice_coeff_batch(masks_pred.to("cpu").numpy().squeeze(), masks.to("cpu").numpy().squeeze())
        
    epoch_loss = running_loss / dataset_size
    epoch_dice_coeff = dice_coeff / dataset_size
    mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0
    current_lr = optimizer.param_groups[0]['lr']
    pbar.set_postfix(valid_loss=f'{epoch_loss:0.4f}',
                    lr=f'{current_lr:0.5f}',
                    gpu_memory=f'{mem:0.2f} GB')
#     val_scores  = np.mean(val_scores, axis=0)
    torch.cuda.empty_cache()
    gc.collect()
    return epoch_loss, epoch_dice_coeff

def train(model, criterion, optimizer, num_epochs, model_name, train_loader, valid_loader):
    """
    Train for ``num_epochs`` epochs and keep the weights with the lowest validation loss.

    Each epoch runs one training pass and one validation pass, then appends to ``history``.
    It writes ``best_epoch.bin`` (on improvement, ties included) and ``last_epoch.bin`` to
    ``CFG.saved_model_path``. At the end the best weights are loaded back into ``model`` and
    saved as ``CFG.saved_model_path/model_name``.

    Returns:
        ``(model with best weights loaded, history)``. ``history`` maps 'Train Loss',
        'Valid Loss' and 'Dice Coeff' to per-epoch lists.
    """
    if torch.cuda.is_available():
        print("cuda: {}\n".format(torch.cuda.get_device_name()))

    start_time = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss      = np.inf
    best_epoch     = -1
    history = defaultdict(list)
    best_path = os.path.join(CFG.saved_model_path, "best_epoch.bin")
    last_path = os.path.join(CFG.saved_model_path, "last_epoch.bin")

    for epoch in range(1, num_epochs + 1):
        gc.collect()
        start = time.time()
        print(f'Epoch {epoch}/{num_epochs}', end='')
        train_loss = train_one_epoch(
            model,
            criterion,
            optimizer,
            dataloader=train_loader,
            epoch=epoch
        )

        # Named val_dice so the module-level dice_coeff function is not shadowed.
        val_loss, val_dice = valid_one_epoch(
            model,
            criterion,
            valid_loader
        )
        history['Train Loss'].append(train_loss)
        history['Valid Loss'].append(val_loss)
        history["Dice Coeff"].append(val_dice)
        print(f'Valid Loss: {val_loss} | Valid Dice: {val_dice:0.4f}')

        if val_loss <= best_loss:
            print(f"Valid loss Improved ({best_loss} ---> {val_loss})")
            best_loss = val_loss
            best_epoch = epoch
            best_model_wts = copy.deepcopy(model.state_dict())
            torch.save(model.state_dict(), best_path)
            print(f"Model Saved to: {best_path}")

        epoch_duration = time.time() - start
        print("Epoch finish in {:.0f}h {:.0f}m {:.0f}s".format(
            epoch_duration // 3600, (epoch_duration % 3600) // 60, (epoch_duration % 3600) % 60)
        )
        # Only the on-disk copy of the last epoch is needed; skip an extra in-memory deepcopy.
        torch.save(model.state_dict(), last_path)

    time_elapsed = time.time() - start_time
    print('Training complete in {:.0f}h {:.0f}m {:.0f}s'.format(
        time_elapsed // 3600, (time_elapsed % 3600) // 60, (time_elapsed % 3600) % 60)
    )
    print("Best Loss: {:.4f} (epoch {})".format(best_loss, best_epoch))

    # load best model weights
    model.load_state_dict(best_model_wts)
    final_path = os.path.join(CFG.saved_model_path, model_name)
    torch.save(model.state_dict(), final_path)
    print(f"Model saved to {final_path}")
    return model, history

In [None]:
# import albumentations as A
# from albumentations.pytorch import ToTensorV2


# DATASET_FOLDER = '/kaggle/input/blood-vessel-segmentation'
# IMG_PATH = DATASET_FOLDER + '/train'
# TO_VISUALIZE = 'kidney_1_dense'

# df_train = pd.read_csv(os.path.join(DATASET_FOLDER, "train_rles.csv"))
# df_train[["dataset", "slice"]] = df_train['id'].str.rsplit(pat='_', n=1, expand=True)

# valid_transforms = A.Compose([
#     A.HorizontalFlip(p=0.5),
#     A.RandomRotate90(p=0.5),
#     A.ToRGB(),
#     ToTensorV2()
# ])

# train_dataset = SenNetHOATiledDataset(
#                             df_train.loc[df_train.dataset == TO_VISUALIZE],
#                             path_img_dir=IMG_PATH,
#                             tile_size=[800, 800],
#                             empty_tile_pct=0.0,
#                             transforms=valid_transforms)

# train_loader = DataLoader(
#     train_dataset,
#     batch_size=8,
#     num_workers=4,
#     shuffle=False,
#     pin_memory=True
# )

In [None]:
# Binary U-Net with an ImageNet-pretrained ResNet-50 encoder. activation=None means the
# network outputs raw scores; DiceLoss below runs in binary mode.
unet_model = smp.Unet(
    encoder_name="resnet50",
    encoder_weights="imagenet",
    in_channels=3,
    classes=1,
    activation=None,
)
unet_model.to(CFG.device)

# Hyper-parameters (both also end up in the saved checkpoint's file name).
lr, num_epoch = 1e-2, 4
criterion = smp.losses.DiceLoss(mode='binary')
optimizer = optim.AdamW(unet_model.parameters(), lr=lr)

train_loader, valid_loader = (
    DataLoader(train_dataset, batch_size=CFG.train_bs, shuffle=True),
    DataLoader(valid_dataset, batch_size=CFG.valid_bs, shuffle=False),
)

model, history = train(
    unet_model, criterion, optimizer,
    num_epochs=num_epoch,
    model_name=f"unet-epoch{num_epoch}-lr{lr}.pth",
    train_loader=train_loader,
    valid_loader=valid_loader,
)

### Sanity Check

In [None]:
# Visual sanity check: thresholded predictions vs. ground truth on one validation batch.
PRED_THRESHOLD = 0.5  # the same probability cut-off dice_coeff_batch uses when scoring
N_SHOW = 4

images, masks = next(iter(valid_loader))

model.load_state_dict(torch.load(os.path.join(CFG.saved_model_path, "best_epoch.bin"), map_location=CFG.device))
model.eval()

with torch.no_grad():
    images  = images.to(CFG.device, dtype=torch.float)
    masks   = masks.to(CFG.device, dtype=torch.float)
    masks_pred = model(images)

    loss = criterion(masks_pred, masks)
    # Logits -> probabilities -> binary mask. Drop only the channel axis so the batch axis
    # survives single-sample batches.
    masks_pred = (torch.sigmoid(masks_pred) >= PRED_THRESHOLD).double().squeeze(1)
    masks = masks.to("cpu").numpy()
    masks_pred = masks_pred.to("cpu").numpy()

    print("loss: ", loss)

for i in range(min(N_SHOW, len(masks_pred))):
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    axes[0].set_title("Predicted Mask")
    axes[0].imshow(masks_pred[i])

    axes[1].set_title("True Mask")
    axes[1].imshow(masks[i])
    plt.show()
    plt.close(fig)

In [None]:
# Inference preview on the test volumes with a previously saved U-Net checkpoint.
model_name = f"unet-epoch{num_epoch}-lr{lr}.pth"
model_path = os.path.join(dirname(CFG.data_dir), "saved-models", model_name)
unet_model.load_state_dict(torch.load(model_path, map_location=CFG.device))
unet_model.eval()

# Test volumes have no labels, so VesselDataset runs in "test" mode and yields
# (image, original (H, W)) pairs.
test_parts = [
    VesselDataset(folder, mode="test", transforms=CFG.data_transforms["valid"])
    for folder in test_folders
]
assert test_parts, f"no test volumes found under {os.path.join(CFG.data_dir, 'test')}"
test_dataset = sum(test_parts[1:], test_parts[0])

test_loader = DataLoader(test_dataset, batch_size=CFG.valid_bs, shuffle=False)
images, original_shapes = next(iter(test_loader))
images = images.to(CFG.device, dtype=torch.float)
with torch.no_grad(), amp.autocast(enabled=torch.cuda.is_available()):
    masks_pred = unet_model(images)

# Logits (B, 1, H, W) -> probabilities for display.
probs = torch.sigmoid(masks_pred.float()).cpu().numpy()
images_np = images.cpu().numpy()

for i in range(len(probs)):
    fig, axes = plt.subplots(1, 2)
    # Images are float in [0, 1]; casting them to uint8 (as before) would zero almost every pixel.
    axes[0].imshow(np.transpose(images_np[i], (1, 2, 0)))
    axes[0].set_title(f"Input Image {i + 1}")
    # imshow rejects (H, W, 1) arrays, so show the single channel as (H, W).
    axes[1].imshow(probs[i, 0])
    axes[1].set_title(f"Predicted Mask for image {i + 1}")
    plt.show()
    plt.close(fig)
    

### SAM Model

#### SAM always predicts a low-resolution (256, 256) mask. The SAM processor resizes every input image to (1024, 1024) because its image encoder is a ViT with a fixed input size. (Future work: replace the image encoder with a Swin Transformer so it can accept arbitrary input sizes.)

In [None]:
def resize_mask(image):
    """
    Resize a PIL image so that its longer side is 256 px, preserving the aspect ratio.

    Each side is rounded to the nearest integer and bilinear resampling is used.
    """
    target_long_side = 256
    width, height = image.size
    ratio = target_long_side * 1.0 / max(height, width)
    resized_height = int(height * ratio + 0.5)
    resized_width = int(width * ratio + 0.5)
    return image.resize((resized_width, resized_height), resample=Image.Resampling.BILINEAR)

def pad_mask(image):
    """
    Zero-pad a 2-D image (anything with ``height``/``width`` that numpy can convert)
    on the bottom and right to 256x256, returning a numpy array.
    """
    canvas = 256
    pad_spec = ((0, canvas - image.height), (0, canvas - image.width))
    return np.pad(image, pad_spec, mode="constant")

def process_mask(image):
    """Resize a PIL mask so its longer side is 256, then zero-pad it to a 256x256 numpy array."""
    return pad_mask(resize_mask(image))

In [None]:

import torch.nn.functional as F
from torch.nn.functional import threshold, normalize
from typing import Tuple

def postprocess_masks(masks: torch.Tensor, input_size: Tuple[int, ...], original_size: Tuple[int, ...],
                      image_size=1024) -> torch.Tensor:
    """
    Remove padding and upscale masks to the original image size.

    The steps are: bilinearly upsample to ``image_size`` x ``image_size``, crop to the
    un-padded ``input_size`` region, then bilinearly resample to ``original_size``.

    Args:
      masks (torch.Tensor):
        Batched masks from the mask_decoder, in BxCxHxW format.
      input_size (tuple(int, int)):
        The size of the image input to the model, in (H, W) format. Used to remove padding.
      original_size (tuple(int, int)):
        The original size of the image before resizing for input to the model, in (H, W) format.
      image_size (int):
        Side length of the square model input the padding was applied to.

    Returns:
      (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
        is given by original_size.
    """
    upsampled = F.interpolate(masks, (image_size, image_size), mode="bilinear", align_corners=False)
    valid_h, valid_w = input_size[0], input_size[1]
    cropped = upsampled[..., :valid_h, :valid_w]
    return F.interpolate(cropped, original_size, mode="bilinear", align_corners=False)

def get_model_response(model, inputs, input_resized_shape, input_original_shape):
    """
    Run SAM once (single-mask output) and return the first image's processed mask.

    ``pred_masks[0]`` is upscaled with ``postprocess_masks``. Values <= 0 are zeroed by
    ``threshold``, the result is L2-normalised along dim 1 by ``normalize``, and the
    channel axis is squeezed before returning element 0.
    """
    with torch.no_grad():
        raw_outputs = model(**inputs, multimask_output=False)
    upscaled = postprocess_masks(raw_outputs["pred_masks"][0], input_resized_shape, input_original_shape)
    cleaned = normalize(threshold(upscaled, 0.0, 0)).squeeze(1)
    return cleaned[0]

In [None]:
from transformers import SamModel, SamProcessor, AutoProcessor, AutoModelForMaskGeneration
from PIL import Image

class SAMDataset(Dataset):
    """
    Dataset yielding SAM-processor inputs plus a ground-truth mask for one kidney volume folder.

    ``data_dir`` must contain ``images`` and ``labels`` folders; both modes read masks.
    Both listings are sorted so that the i-th image is paired with the i-th mask.
    ``os.listdir`` order is arbitrary and can differ between folders.

    Each item is the processor output (with its batch axis squeezed) plus a
    ``"ground_truth_mask"`` entry holding the mask resized and padded by ``process_mask``.
    """

    @classmethod
    def get_bounding_box(cls, ground_truth_map):
        '''
        Box prompt ``[x_min, y_min, x_max, y_max]`` around the positive pixels of a mask.

        Each side is pushed outwards by a random 5-19 px (``np.random.randint(5, 20)``) and
        clipped to the map. A map with no positive pixel (a slice without vessels) yields the
        full-map box ``[0, 0, W, H]`` instead of failing on ``np.min`` of an empty array.

        NOTE(review): the box is in the mask's frame (256x256 after ``process_mask``), while
        the processor receives the image at its own size. Confirm the two frames agree.
        '''
        H, W = ground_truth_map.shape
        y_indices, x_indices = np.where(ground_truth_map > 0)
        if x_indices.size == 0:
            return [0, 0, W, H]

        x_min, x_max = np.min(x_indices), np.max(x_indices)
        y_min, y_max = np.min(y_indices), np.max(y_indices)

        # add perturbation to bounding box coordinates
        x_min = max(0, x_min - np.random.randint(5, 20))
        x_max = min(W, x_max + np.random.randint(5, 20))
        y_min = max(0, y_min - np.random.randint(5, 20))
        y_max = min(H, y_max + np.random.randint(5, 20))

        return [x_min, y_min, x_max, y_max]

    @classmethod
    def load_image(cls, image_path: str):
        """Open an image as RGB and return it as an HWC float32 array divided by 255."""
        img = Image.open(image_path)
        if img.mode != 'RGB':
            img = img.convert('RGB')
        img = np.array(img)
        img = img.astype('float32')
        img /= 255.0                                     # image normalization
        return img

    @classmethod
    def load_mask(cls, mask_path: str):
        """Open a mask, resize/pad it to 256x256 with ``process_mask``, return float32 / 255."""
        mask = Image.open(mask_path)
        # resize mask to (256, 256)
        mask = process_mask(mask)
        mask = np.array(mask)
        mask = mask.astype('float32')
        mask /= 255.0
        return mask

    def __init__(self, data_dir: str, mode: str, processor: SamProcessor, transforms=None):
        assert(mode == "train" or mode == "test")
        self.mode = mode
        self.image_dir = os.path.join(data_dir, "images")
        # Sorted so that image i and mask i refer to the same slice.
        self.image_files = sorted(os.listdir(self.image_dir))
        self.mask_dir = os.path.join(data_dir, "labels")
        self.mask_files = sorted(os.listdir(self.mask_dir))
        self.transforms = transforms
        self.processor = processor

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx: int):
        image_path = os.path.join(self.image_dir, self.image_files[idx])
        image = SAMDataset.load_image(image_path)
        if self.transforms:
            image = self.transforms(image=image)["image"]
        mask_path = os.path.join(self.mask_dir, self.mask_files[idx])
        mask = SAMDataset.load_mask(mask_path)

        image = np.transpose(image, (2, 0, 1))  # HWC -> CHW
        if self.mode == "train":
            # Jittered box around the vessels in the mask.
            input_boxes = SAMDataset.get_bounding_box(mask)
        else:
            # Whole-image box. The image is CHW here, so W = shape[-1] and H = shape[-2].
            input_boxes = [0, 0, image.shape[-1], image.shape[-2]]
        inputs = self.processor(image, input_boxes=[[input_boxes]], return_tensors="pt")
        inputs = {k: v.squeeze(0) for k, v in inputs.items()}
        inputs["ground_truth_mask"] = mask
        return inputs


In [None]:
# from torch.profiler import Profiler

def pt_to_pil_convertor(image):
    """
    Convert a CPU tensor with a leading CHW layout (after squeezing singleton axes) to an
    8-bit PIL image. Values are assumed to be in [0, 1]; they are scaled by 255 and cast
    to uint8.
    """
    hwc = image.squeeze().numpy().transpose(1, 2, 0)
    return Image.fromarray((hwc * 255).astype(np.uint8))

def dice_coeff_batch(predictions, targets):
    """
    Return the SUM of per-sample Dice scores over the leading (batch) axis.

    Every prediction is thresholded at >= 0.5 into a mask of its own dtype before being
    compared with its target. A 1e-6 term in the denominator avoids division by zero.

    NOTE(review): this re-defines the identical helper declared earlier in the notebook.
    """
    eps = 1e-6
    per_sample = np.zeros(predictions.shape[0])

    for idx in range(len(per_sample)):
        pred, truth = predictions[idx], targets[idx]

        binary = np.zeros_like(pred)
        binary[pred >= 0.5] = 1

        overlap = np.sum(binary * truth)
        per_sample[idx] = 2 * overlap / (np.sum(binary) + np.sum(truth) + eps)

    return np.sum(per_sample)

def build_model(model_name: str):
    """
    Load a mask-generation model and its processor from ``model_name``.

    Every parameter is frozen except those of ``model.mask_decoder``, so only the decoder is
    fine-tuned. The model is moved to ``CFG.device``.

    Returns:
        ``(model, processor)``
    """
    processor = AutoProcessor.from_pretrained(model_name)
    model = AutoModelForMaskGeneration.from_pretrained(model_name)

    # Freeze the whole network, then re-enable gradients for the mask decoder only.
    model.requires_grad_(False)
    model.mask_decoder.requires_grad_(True)

    model.to(CFG.device)
    return model, processor

def finetune_sam_one_epoch(
    model: SamModel,
    criterion,
    optimizer,
    dataloader: DataLoader,
    epoch: int,
    scaler: Optional[amp.GradScaler] = None,
):
    """
    One mixed-precision fine-tuning pass with gradient accumulation over ``CFG.n_accumulate`` batches.

    Args:
        model: SAM model called with ``pixel_values``/``input_boxes`` and ``multimask_output=False``.
        criterion: loss taking ``(pred_masks squeezed on dim 1, ground_truth_mask[:, None])``.
        optimizer: stepped once per accumulation window, including a trailing partial window.
        dataloader: yields dicts with ``pixel_values``, ``input_boxes`` and ``ground_truth_mask``.
        epoch: epoch index, shown in the progress bar.
        scaler: optional GradScaler to reuse across epochs; a fresh one is created if omitted.

    Returns:
        Sample-weighted mean (unscaled) training loss; 0.0 for an empty loader.
    """
    model.train()
    if scaler is None:
        scaler = amp.GradScaler()

    dataset_size = 0
    running_loss = 0.0
    epoch_loss = 0.0
    n_batches = len(dataloader)
    has_cuda = torch.cuda.is_available()
    # Start from clean gradients so nothing leaks in from a previous epoch.
    optimizer.zero_grad()

    pbar = tqdm(enumerate(dataloader), total=n_batches, desc="Finetune ")
    for step, batch in pbar:
        batch_size = batch["pixel_values"].size(0)
        with amp.autocast(enabled=True):
            outputs = model(
                pixel_values=batch["pixel_values"].to(CFG.device),
                input_boxes=batch["input_boxes"].to(CFG.device),
                multimask_output=False
            )
            # Assumes pred_masks is (B, 1, 1, h, w) for one box per image; this gives (B, 1, h, w).
            predicted_masks = outputs.pred_masks.squeeze(1)
            ground_truth_masks = batch["ground_truth_mask"].to(CFG.device)
            loss = criterion(predicted_masks, ground_truth_masks.unsqueeze(1))

        # Backward runs outside autocast (recommended AMP usage). The loss is divided by the
        # window size so the accumulated gradient is an average.
        scaler.scale(loss / CFG.n_accumulate).backward()

        if (step + 1) % CFG.n_accumulate == 0 or (step + 1) == n_batches:
            scaler.step(optimizer)
            scaler.update()
            # zero the parameter gradients
            optimizer.zero_grad()

        # Track the unscaled loss so it is comparable with the validation loss.
        running_loss += loss.item() * batch_size
        dataset_size += batch_size
        epoch_loss = running_loss / dataset_size

        mem = torch.cuda.memory_reserved() / 1E9 if has_cuda else 0
        current_lr = optimizer.param_groups[0]['lr']
        pbar.set_postfix(epoch=f'{epoch}',
                         train_loss=f'{epoch_loss:0.4f}',
                         lr=f'{current_lr:0.5f}',
                         gpu_mem=f'{mem:0.2f} GB')

    torch.cuda.empty_cache()
    gc.collect()
    return epoch_loss

@torch.no_grad()
def valid_sam_one_epoch(
    model: SamModel,
    criterion,
    dataloader: DataLoader,
    optimizer=None,
):
    """
    Evaluate SAM on ``dataloader`` without gradient tracking.

    ``pred_masks`` are treated as logits (``get_model_response`` also thresholds them at
    0.0). They are passed through a sigmoid before the 0.5-threshold Dice.

    Args:
        model: SAM model called with ``pixel_values``/``input_boxes`` and ``multimask_output=False``.
        criterion: loss taking ``(pred_masks squeezed on dim 1, ground_truth_mask[:, None])``.
        dataloader: yields dicts with ``pixel_values``, ``input_boxes`` and ``ground_truth_mask``.
        optimizer: optional; only used to show the learning rate in the progress bar.

    Returns:
        ``(mean loss, mean per-sample Dice)``. This is a tuple, so callers must unpack it.
        Both values are 0.0 for an empty loader.
    """
    model.eval()

    dataset_size = 0
    running_loss = 0.0
    dice_sum = 0.0
    has_cuda = torch.cuda.is_available()

    pbar = tqdm(enumerate(dataloader), total=len(dataloader), desc="Valid ")
    for step, batch in pbar:
        batch_size = batch["pixel_values"].size(0)
        outputs = model(
            pixel_values=batch["pixel_values"].to(CFG.device),
            input_boxes=batch["input_boxes"].to(CFG.device),
            multimask_output=False
        )

        predicted_masks = outputs.pred_masks.squeeze(1)
        ground_truth_masks = batch["ground_truth_mask"].to(CFG.device).unsqueeze(1)
        loss = criterion(predicted_masks, ground_truth_masks)

        mask_pred = torch.sigmoid(predicted_masks.float()).cpu().numpy()
        mask_true = ground_truth_masks.cpu().numpy()
        dice_sum += dice_coeff_batch(mask_pred, mask_true)

        running_loss += loss.item() * batch_size
        dataset_size += batch_size

        mem = torch.cuda.memory_reserved() / 1E9 if has_cuda else 0
        postfix = {
            "valid_loss": f"{running_loss / dataset_size:0.4f}",
            "valid_dice": f"{dice_sum / dataset_size:0.4f}",
            "gpu_memory": f"{mem:0.2f} GB",
        }
        if optimizer is not None:
            postfix["lr"] = f"{optimizer.param_groups[0]['lr']:0.5f}"
        pbar.set_postfix(**postfix)

    denom = max(dataset_size, 1)
    epoch_loss = running_loss / denom
    epoch_dice_coeff = dice_sum / denom
    torch.cuda.empty_cache()
    gc.collect()
    return epoch_loss, epoch_dice_coeff


def _format_hms(seconds: float) -> str:
    """Format a duration in seconds as ``"<h>h <m>m <s>s"`` (rounded to whole units)."""
    hours, remainder = divmod(seconds, 3600)
    minutes, secs = divmod(remainder, 60)
    return "{:.0f}h {:.0f}m {:.0f}s".format(hours, minutes, secs)


def finetune_medsam(
    model: SamModel,
    criterion,
    optimizer,
    num_epochs: int,
    model_name: str,
    train_loader: DataLoader,
    valid_loader: DataLoader
):
    """Fine-tune a SAM/MedSAM model, keeping the weights with the lowest validation loss.

    Each epoch runs ``finetune_sam_one_epoch`` on ``train_loader`` and then
    ``valid_sam_one_epoch`` on ``valid_loader``. The weights are snapshotted
    in memory and written to ``best_epoch.bin`` whenever the validation loss
    is less than or equal to the best so far. The current weights are
    written to ``last_epoch.bin`` after every epoch. Both files live in
    ``CFG.saved_model_path``.

    Args:
        model: SAM model to train.
        criterion: loss callable passed through to the train/valid loops.
        optimizer: optimizer updating (a subset of) ``model``'s parameters.
        num_epochs: number of epochs to run.
        model_name: file name for the final best-weights checkpoint, saved
            under ``CFG.saved_model_path``.
        train_loader: training batches.
        valid_loader: validation batches.

    Returns:
        ``(model, history)``. ``model`` has the best-validation-loss weights
        loaded. ``history`` maps ``"Train Loss"``, ``"Valid Loss"`` and
        ``"Dice Coeff"`` to per-epoch lists.
    """
    torch.cuda.empty_cache()
    gc.collect()

    start_time = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = np.inf
    best_epoch = -1
    history = defaultdict(list)

    best_path = os.path.join(CFG.saved_model_path, "best_epoch.bin")
    last_path = os.path.join(CFG.saved_model_path, "last_epoch.bin")

    for epoch in range(1, num_epochs + 1):
        gc.collect()
        epoch_start = time.time()
        print(f'Epoch {epoch}/{num_epochs}')
        train_loss = finetune_sam_one_epoch(
            model,
            criterion,
            optimizer,
            dataloader=train_loader,
            epoch=epoch
        )

        # valid_sam_one_epoch returns a (loss, dice) tuple. Unpack it so the
        # comparison with best_loss below is float-vs-float, not tuple-vs-float.
        val_loss, val_dice = valid_sam_one_epoch(
            model,
            criterion,
            valid_loader
        )

        history['Train Loss'].append(train_loss)
        history['Valid Loss'].append(val_loss)
        history['Dice Coeff'].append(val_dice)
        print(f'Valid Loss: {val_loss} | Valid Dice: {val_dice}')

        # deep copy the model so later epochs cannot mutate the snapshot
        if val_loss <= best_loss:
            print(f"Valid loss Improved ({best_loss} ---> {val_loss})")
            best_loss = val_loss
            best_epoch = epoch
            best_model_wts = copy.deepcopy(model.state_dict())
            torch.save(best_model_wts, best_path)
            print(f"Model Saved to: {best_path}")

        print("Epoch finish in " + _format_hms(time.time() - epoch_start))
        torch.save(model.state_dict(), last_path)

    print('Training complete in ' + _format_hms(time.time() - start_time))
    print("Best Loss: {:.4f} (epoch {})".format(best_loss, best_epoch))

    # load best model weights
    model.load_state_dict(best_model_wts)
    final_path = os.path.join(CFG.saved_model_path, model_name)
    torch.save(model.state_dict(), final_path)
    print(f"Model saved to {final_path}")
    return model, history

In [None]:
# sam_model, processor = build_model("wanglab/medsam-vit-base")

# train_groups = [os.path.join(CFG.data_dir, "train", folder) for folder in CFG.train_groups]
# val_groups = [os.path.join(CFG.data_dir, "train", folder) for folder in CFG.valid_groups]


# train_dataset = SAMDataset(train_groups[0], mode="train", processor=processor, transforms=CFG.sam_transformations["default"])
# for idx in range(1, len(train_groups)):
#     train_dataset += SAMDataset(train_groups[idx], mode="train", processor=processor, transforms=CFG.sam_transformations["default"])

# valid_dataset = SAMDataset(val_groups[0], mode="test", processor=processor, transforms=CFG.sam_transformations["default"])
# for idx in range(1, len(val_groups)):
#     valid_dataset += SAMDataset(val_groups[idx], mode="test", processor=processor, transforms=CFG.sam_transformations["default"])

In [None]:
# processor.from_pretrained("/kaggle/working/preprocessor_config.json")

In [None]:
# train_loader = DataLoader(train_dataset, batch_size=CFG.train_bs, shuffle=True)
# valid_loader = DataLoader(valid_dataset, batch_size=CFG.valid_bs, shuffle=False)

# batch = next(iter(valid_loader))
# for k, v in batch.items():
#     print(k, v.shape)

In [None]:
# sam_lr = 1e-3
# criterion = DiceCELoss()
# # criterion = monai.losses.DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')
# optimizer = optim.AdamW(sam_model.mask_decoder.parameters(), lr=sam_lr)

# model, history = finetune_medsam(
#     sam_model,
#     criterion, 
#     optimizer, 
#     num_epochs=5,
#     model_name="medsam-finetuned",
#     train_loader=train_loader,
#     valid_loader=valid_loader
# )

# # masks = processor.image_processor.post_process_masks(
# #     outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
# # )

# # data augmentation: multiple bbox prompts; what about language prompts? how do we maintain its zero-shot capability?
# # distillation
# # scores = outputs.iou_scores

### MedSAM is trained with DiceCELoss on raw logits, so the generated masks need post-processing (sigmoid, then a 0.5 threshold) before they can be displayed

In [None]:
# def show_mask(pred_masks, ax):
#     # apply sigmoid
#     seg_prob = torch.sigmoid(pred_masks.squeeze(1))
#     # convert soft mask to hard mask
#     seg_prob = seg_prob.to("cpu").numpy().squeeze()
#     seg = (seg_prob > 0.5).astype(np.uint8)
#     h, w = mask.shape[-2:]
#     mask_image = mask.reshape(h, w, 1)
#     ax.imshow(mask_image)
#     return mask_image

# def display_history(history):
#     fig, axes = plt.subplots(3, 1, figsize=(10, 15))
#     train_loss = history["Train Loss"]
#     valid_loss = history["Valid Loss"]
#     dice_coeff = history["Dice Coeff"]

#     epochs = range(1, len(train_loss) + 1)

#     # Plotting Train Loss
#     axes[0].plot(epochs, train_loss, 'r-', label='Train Loss')
#     axes[0].set_title('Training Loss')
#     axes[0].set_xlabel('Epochs')
#     axes[0].set_ylabel('Loss')
#     axes[0].legend()

#     # Plotting Validation Loss
#     axes[1].plot(epochs, valid_loss, 'b-', label='Validation Loss')
#     axes[1].set_title('Validation Loss')
#     axes[1].set_xlabel('Epochs')
#     axes[1].set_ylabel('Loss')
#     axes[1].legend()

#     # Plotting Dice Coefficient
#     axes[2].plot(epochs, dice_coeff, 'g-', label='Dice Coefficient')
#     axes[2].set_title('Dice Coefficient')
#     axes[2].set_xlabel('Epochs')
#     axes[2].set_ylabel('Dice Coeff')
#     axes[2].legend()

#     plt.tight_layout()
#     plt.show()

# # fig, axes = plt.subplots(1, 2)
# # mask_np = np.transpose(mask.numpy(), (1, 2, 0))
# # axes[0].imshow(mask_np)
# # axes[0].set_title(f"True mask")
# # pred_mask = show_mask(outputs.pred_masks, axes[1])
# # axes[1].set_title(f"Predicted mask")


# # need to resize to 1024x1024 (input image)
# # ground truth masks needed to 

# display_history(history)

#### Ground-truth masks vs. predicted masks