### Setting ###

In [None]:
# Google Drive Mount
#from google.colab import drive
#drive.mount('/content/drive')

In [None]:
# Install foundation model - Segment Anything
#!pip install git+https://github.com/facebookresearch/segment-anything.git

In [None]:
# libraries
import os
import glob
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

import albumentations as A
from albumentations.pytorch import transforms

from sklearn.model_selection import train_test_split

from tqdm.notebook import tqdm

from segment_anything import SamAutomaticMaskGenerator, sam_model_registry, SamPredictor
from segment_anything.utils.transforms import ResizeLongestSide

In [None]:
# device setting
# The notebook was configured for the second GPU (index 1). Keep that choice
# when it exists, but fall back to the first GPU on single-GPU machines
# (where "cuda:1" is an invalid device ordinal) and to CPU without CUDA.
if torch.cuda.is_available():
    gpu_index = 1 if torch.cuda.device_count() > 1 else 0
    device = torch.device(f"cuda:{gpu_index}")
else:
    device = torch.device("cpu")

print(device)

In [None]:
# environment setting
# Optimisation hyper-parameters.
batch_size, epochs = 4, 10
lr, weight_decay = 0.001, 0.0005
# Focal-loss constants: alpha scales the loss, gamma is the focusing exponent.
alpha, gamma = 0.8, 2

In [None]:
# directory setting
class ROOTDIR:
    """Dataset folders; both paths end in '/' so a glob pattern can be appended directly."""

    # ground-truth mask folder
    mask: str = "/home/kmk/COSE474Project/data/masks/"
    # input image folder
    image: str = "/home/kmk/COSE474Project/data/images/"

### Data example ###

In [None]:
# Collect every JPEG in the image folder and in the mask folder. Sorting is what
# pairs the i-th image with the i-th mask (assumes matching file names — verify).
images, masks = (
    sorted(glob.glob(folder + "*.jpg"))
    for folder in (ROOTDIR.image, ROOTDIR.mask)
)

In [None]:
# 80/20 split of the paired path lists into train and a held-out pool.
# A fixed random_state keeps the split reproducible across notebook runs.
train_images, val_images, train_masks, val_masks = train_test_split(
    images, masks, test_size=0.2, random_state=42
)

In [None]:
# Halve the held-out pool into validation and test sets (10% / 10% overall),
# seeded so the split is reproducible across notebook runs.
val_images, test_images, val_masks, test_masks = train_test_split(
    val_images, val_masks, test_size=0.5, random_state=42
)

In [None]:
# Load the first training pair: the image as RGB, the mask as single-channel grayscale.
ex_img = Image.open(train_images[0]).convert("RGB")
ex_mask = Image.open(train_masks[0]).convert("L")

# Show the pair side by side without axis ticks.
example_panels = ((ex_img, "image"), (ex_mask, "ground truth mask"))
for slot, (picture, caption) in enumerate(example_panels, start=1):
    plt.subplot(1, 2, slot)
    plt.imshow(picture)
    plt.title(caption)
    plt.axis("off")

plt.show()

### Zero shot prediction ###

In [None]:
# get bounding box
def get_bbox(true_mask, max_jitter=5, threshold=0):
    """Return a randomly enlarged bounding box around the foreground of a 2-D mask.

    The box is the tight extent of the pixels whose value exceeds ``threshold``.
    Each side is then pushed outward by a random 0..max_jitter-1 pixels (a cheap
    prompt augmentation) and clamped to the image extent.

    Args:
        true_mask: 2-D array of shape (H, W); pixels > ``threshold`` are foreground.
        max_jitter: exclusive upper bound of the random per-side enlargement.
            The default 5 reproduces the original 0-4 px jitter; 0 or less disables it.
        threshold: foreground cutoff. The default 0 keeps the original behaviour;
            a larger value ignores faint JPEG compression noise in saved masks.

    Returns:
        np.ndarray ``[x_min, y_min, x_max, y_max]``.

    Raises:
        ValueError: if no pixel exceeds ``threshold`` (an empty mask has no box).
    """
    true_mask = np.asarray(true_mask)
    y_indices, x_indices = np.where(true_mask > threshold)
    if y_indices.size == 0:
        raise ValueError("get_bbox: mask has no foreground pixels above the threshold")

    h, w = true_mask.shape

    def _jitter():
        # np.random.randint(0, 0) would raise, so a non-positive bound means "no jitter".
        return np.random.randint(0, max_jitter) if max_jitter > 0 else 0

    x_min, x_max = np.min(x_indices), np.max(x_indices)
    y_min, y_max = np.min(y_indices), np.max(y_indices)

    # Same draw order as before: x_min, x_max, y_min, y_max.
    x_min = max(0, x_min - _jitter())
    x_max = min(w, x_max + _jitter())
    y_min = max(0, y_min - _jitter())
    y_max = min(h, y_max + _jitter())

    return np.array([x_min, y_min, x_max, y_max])

In [None]:
# show bounding box
def show_bbox(bbox):
    """Outline an ``[x_min, y_min, x_max, y_max]`` box in blue (no fill) on the current axes."""
    left, top = bbox[0], bbox[1]
    width = bbox[2] - left
    height = bbox[3] - top

    outline = patches.Rectangle(
        (left, top),
        width,
        height,
        edgecolor="blue",
        facecolor=(0, 0, 0, 0),  # fully transparent interior
    )
    plt.gca().add_patch(outline)

In [None]:
# Draw the example image and outline the jittered box derived from its ground-truth mask.
fig, axes = plt.subplots()

axes.imshow(np.array(ex_img))
example_box = get_bbox(np.array(ex_mask))
show_bbox(example_box)

axes.set_title("Ground truth mask & bounding box")
axes.axis("off")
plt.show()

In [None]:
# SAM model: ViT-H backbone loaded from a local checkpoint.
model_type = "vit_h"
sam_checkpoint = "/home/kmk/COSE474Project/sam_vit_h_4b8939.pth"

# Box-prompted prediction with SamPredictor (not the automatic mask generator):
# embed the example image once, then prompt it with a box around the GT mask.
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device)
predictor = SamPredictor(sam)
predictor.set_image(np.array(ex_img))
prompt_box = get_bbox(np.array(ex_mask))
# NOTE(review): this rebinds `masks`, which previously held the mask file list.
masks, _, _ = predictor.predict(
    point_coords=None,
    point_labels=None,
    box=prompt_box,
    multimask_output=False,
)

In [None]:
def show_mask(mask):
    """Overlay ``mask`` on the current axes in a random colour at 0.6 alpha.

    Only the last two dimensions of ``mask`` are treated as (height, width).
    """
    axis = plt.gca()
    rgba = np.append(np.random.random(3), 0.6)
    height, width = mask.shape[-2:]
    overlay = mask.reshape(height, width, 1) * rgba[np.newaxis, np.newaxis, :]
    axis.imshow(overlay)

In [None]:
# Left: example image with the zero-shot SAM mask overlaid. Right: ground-truth mask.
comparison_panels = (
    (ex_img, masks, "zero-shot predicted mask"),
    (ex_mask, None, "ground truth mask"),
)
for slot, (picture, overlay, caption) in enumerate(comparison_panels, start=1):
    plt.subplot(1, 2, slot)
    plt.imshow(picture)
    if overlay is not None:
        show_mask(overlay)
    plt.title(caption)
    plt.axis("off")

### Prepare Dataset ###

In [None]:
# Dataset
class MedDataset(Dataset):
    """Paired image/mask dataset that also yields a bounding-box prompt per sample.

    Despite their names, ``img_dir`` and ``mask_dir`` are sequences of file
    paths (the i-th image pairs with the i-th mask), not directories.
    ``transform`` is called as ``transform(image, mask, bbox)`` on NumPy arrays
    (RGB image (H, W, 3), grayscale mask (H, W)) and must return the
    transformed ``(image, mask, bbox)`` triple.
    """

    def __init__(self, img_dir, mask_dir, transform):
        if len(img_dir) != len(mask_dir):
            raise ValueError(
                f"MedDataset: got {len(img_dir)} images but {len(mask_dir)} masks"
            )
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.transform = transform

    # number of image/mask pairs
    def __len__(self):
        return len(self.img_dir)

    # load one pair and derive its bounding-box prompt from the mask
    def __getitem__(self, idx):
        # Context managers close the file handles PIL otherwise keeps open lazily.
        with Image.open(self.img_dir[idx]) as img_file:
            img = np.array(img_file.convert("RGB"))
        with Image.open(self.mask_dir[idx]) as mask_file:
            mask = np.array(mask_file.convert("L"))

        # get bounding box
        bbox = get_bbox(mask)

        if self.transform:
            img, mask, bbox = self.transform(img, mask, bbox)

        return img, mask, bbox

In [None]:
# Image preprocess for SAM
class ResizeAndPad:
    """Resize so the longest side equals ``target_size``, then zero-pad to a square.

    Called as ``transform(image, mask, bboxes)`` with ``image`` an (H, W, 3)
    array, ``mask`` an (H, W) array and ``bboxes`` one ``[x_min, y_min, x_max,
    y_max]`` box (or an (N, 4) array of boxes). Returns ``(image, mask, bbox)``:

    - a (3, S, S) float tensor (uint8 input is scaled to [0, 1], like ``ToTensor``);
    - an (S, S) mask tensor with raw (unscaled) values;
    - the box(es) mapped into the padded frame, as a float array with the input's shape.

    Padding is split evenly between the two sides (centered).
    NOTE(review): SAM's own ``Sam.preprocess`` pads only on the bottom/right and
    normalizes with pixel mean/std — confirm this matches how the model is fed.
    """

    def __init__(self, target_size):
        self.target_size = target_size
        self.transform = ResizeLongestSide(target_size)
        # Kept for backward compatibility; the module-level name `transforms` is
        # shadowed by the albumentations import, so torchvision is not used here.
        self.to_tensor = self._to_tensor

    @staticmethod
    def _to_tensor(array):
        """Mimic torchvision ``ToTensor`` for NumPy input: HWC -> CHW, uint8 -> float / 255."""
        array = np.asarray(array)
        if array.ndim == 2:
            array = array[:, :, None]
        tensor = torch.from_numpy(np.ascontiguousarray(array.transpose(2, 0, 1)))
        if tensor.dtype == torch.uint8:
            tensor = tensor.float().div(255)
        return tensor

    def __call__(self, image, mask, bboxes):
        h1, w1 = image.shape[:2]
        image = self._to_tensor(self.transform.apply_image(image))
        mask = torch.as_tensor(np.ascontiguousarray(self.transform.apply_image(mask)))

        _, h2, w2 = image.shape
        max_dim = max(h2, w2)
        pad_w = (max_dim - w2) // 2
        pad_h = (max_dim - h2) // 2

        # F.pad orders padding as (left, right, top, bottom) for the last two dims.
        padding = (pad_w, max_dim - w2 - pad_w, pad_h, max_dim - h2 - pad_h)
        image = F.pad(image, padding)
        mask = F.pad(mask, padding)

        # apply_boxes works on (N, 4) arrays; reshape back to the caller's layout.
        boxes = np.asarray(bboxes, dtype=float)
        scaled = self.transform.apply_boxes(boxes.reshape(-1, 4), (h1, w1))
        offset = np.array([pad_w, pad_h, pad_w, pad_h], dtype=float)
        bbox = (np.asarray(scaled, dtype=float) + offset).reshape(boxes.shape)

        return image, mask, bbox

In [None]:
# Shared SAM preprocessing (longest side -> 1024, padded square) for all three splits.
transform = ResizeAndPad(1024)
split_paths = (
    (train_images, train_masks),
    (val_images, val_masks),
    (test_images, test_masks),
)
train_data, val_data, test_data = (
    MedDataset(image_paths, mask_paths, transform)
    for image_paths, mask_paths in split_paths
)

In [None]:
# Shuffle only the training split. Validation is iterated in a fixed order so
# per-batch metrics and visualisations are comparable between epochs.
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=False)

### Training ###

In [None]:
# Calculate Focal Loss
class FocalLoss(nn.Module):
    """Binary focal loss on logits, averaged over all elements.

    For each element, loss = alpha * (1 - p_t) ** gamma * BCE, where
    p_t = exp(-BCE). The modulation is applied per element *before* averaging,
    so confidently-classified pixels are down-weighted individually.

    Args:
        alpha: constant weighting factor (default 0.8, the notebook's setting).
        gamma: focusing exponent (default 2, the notebook's setting).
    """

    def __init__(self, alpha=0.8, gamma=2):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        # Flatten both, so e.g. (B, 1, H, W) logits line up with (B, H, W) targets.
        inputs = inputs.flatten()
        targets = targets.flatten().to(inputs.dtype)
        bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
        p_t = torch.exp(-bce)  # probability the model assigns to the true class
        focal_loss = self.alpha * (1 - p_t) ** self.gamma * bce
        return focal_loss.mean()

# Calculate Dice Loss
class DiceLoss(nn.Module):
    """Soft Dice loss on logits, treating the whole batch as one set of pixels.

    loss = 1 - (2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth),
    with p = sigmoid(inputs) and t the binary targets.
    """

    def __init__(self):
        super().__init__()

    def forward(self, inputs, targets, smooth=1):
        # Flatten both completely, so differing layouts (e.g. an extra channel dim) still match.
        probs = torch.sigmoid(inputs).flatten()
        targets = targets.flatten().to(probs.dtype)
        intersection = (probs * targets).sum()
        dice = (2.0 * intersection + smooth) / (probs.sum() + targets.sum() + smooth)
        return 1 - dice

In [None]:
# Fresh SAM instance for fine-tuning (same backbone and checkpoint as the zero-shot demo).
model = sam_model_registry[model_type](checkpoint=sam_checkpoint)

# Freeze the image encoder and the mask decoder: every parameter gets requires_grad=False.
# NOTE(review): this presumably leaves the prompt encoder as the trainable part.
# SAM fine-tuning usually freezes the prompt encoder and trains the mask decoder
# instead — confirm which module is meant to learn. The model is also not moved
# to `device` here; verify that happens later.
for frozen in (model.image_encoder, model.mask_decoder):
    frozen.requires_grad_(False)

# Longest-side resizer for SAM inputs. Note: this rebinds `transform`, which
# earlier held the ResizeAndPad instance passed to the datasets.
transform = ResizeLongestSide(1024)