### Setting ###

In [None]:
# Google Drive Mount
#from google.colab import drive
#drive.mount('/content/drive')

In [None]:
# Install foundation model - Segment Anything
#!pip install git+https://github.com/facebookresearch/segment-anything.git

In [None]:
# libraries
import os
import glob
import copy
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image
import tifffile as tiff
import cv2

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as tf
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

from sklearn.model_selection import train_test_split

from tqdm.notebook import tqdm

from transformers import SamProcessor, SamModel
import monai

In [None]:
# Ask PyTorch to release cached GPU memory it is not currently using, so the
# notebook starts from as clean a memory state as possible.
torch.cuda.empty_cache()

In [None]:
# configuration
batch_size = 3      # samples per optimisation step
epochs = 500        # number of passes over the training loader
lr = 0.00005        # learning rate handed to the optimizer
weight_decay = 0    # weight decay handed to the optimizer

# Keep preferring the second GPU (index 1), but fall back to the first GPU on
# single-GPU hosts and to the CPU when CUDA is unavailable. Hard-coding
# "cuda:1" only fails later, at the first `.to(device)`, with an invalid-device
# error on machines that expose a single GPU.
if not torch.cuda.is_available():
    device = "cpu"
elif torch.cuda.device_count() > 1:
    device = "cuda:1"
else:
    device = "cuda:0"
device = torch.device(device)

print(device)

In [None]:
# directory setting
class ROOTDIR:
    """Namespace for the dataset folders.

    Both paths end with a slash so a file pattern can be appended directly.
    """

    # folder holding the input images
    image = "/home/kmk/COSE474Project/data/images/"
    # folder holding the segmentation masks
    mask = "/home/kmk/COSE474Project/data/masks/"

### Data example ###

In [None]:
# Images and masks are paired purely by sorted path order.
# NOTE(review): this assumes both folders hold files with matching names - confirm.
images = sorted(glob.glob(ROOTDIR.image + "*.tif"))
masks = sorted(glob.glob(ROOTDIR.mask + "*.tif"))
# Derive each stem from its image path in place so fns[i] always belongs to
# images[i]. Sorting the stems separately can order them differently from the
# full file names (e.g. "a-1.tif" < "a.tif" but "a" < "a-1"), and splitting on
# the first "." truncates names that contain dots.
fns = [os.path.splitext(os.path.basename(path))[0] for path in images]

In [None]:
# Display the number of image paths, mask paths and file stems found.
len(images), len(masks), len(fns)

In [None]:
# train/val/test split (8:1:1)
# A fixed seed keeps the partition identical across notebook runs. Without it
# every run draws a new split, so a checkpoint written by an earlier run could
# later be evaluated on samples it was trained on.
SPLIT_SEED = 42
train_images, val_images, train_masks, val_masks, train_fns, val_fns = train_test_split(
    images, masks, fns, test_size=0.2, random_state=SPLIT_SEED
)
# The held-out 20% is halved into validation and test sets.
val_images, test_images, val_masks, test_masks, val_fns, test_fns = train_test_split(
    val_images, val_masks, val_fns, test_size=0.5, random_state=SPLIT_SEED
)

In [None]:
# Display the size of each partition produced by the split.
len(train_images), len(val_images), len(test_images)

In [None]:
def get_bbox(gt_mask, max_jitter=10):
    """Return a randomly loosened bounding box around a mask's foreground.

    Args:
        gt_mask: 2-D array; entries greater than 0 count as foreground.
        max_jitter: Exclusive upper bound of the random padding (in pixels)
            drawn independently for each side. ``0`` (or less) disables the
            padding and yields the tight box.

    Returns:
        ``[x_min, y_min, x_max, y_max]`` as Python ints, clamped to
        ``[0, width]`` and ``[0, height]``. A mask without any foreground
        yields the whole-image box ``[0, 0, width, height]`` instead of the
        ``ValueError`` that ``np.min`` raises on an empty index array.
    """
    gt_mask = np.asarray(gt_mask)
    h, w = gt_mask.shape

    y_indices, x_indices = np.where(gt_mask > 0)
    if x_indices.size == 0:
        # Nothing to enclose: prompt with the full image rather than crash.
        return [0, 0, w, h]

    def _pad():
        # np.random.randint requires low < high, so guard the zero case.
        return int(np.random.randint(0, max_jitter)) if max_jitter > 0 else 0

    # Padding is drawn in the order x_min, x_max, y_min, y_max.
    x_min = max(0, int(x_indices.min()) - _pad())
    x_max = min(w, int(x_indices.max()) + _pad())
    y_min = max(0, int(y_indices.min()) - _pad())
    y_max = min(h, int(y_indices.max()) + _pad())

    return [x_min, y_min, x_max, y_max]

In [None]:
def show_bbox(bbox):
    """Outline ``bbox`` (``[x_min, y_min, x_max, y_max]``) in blue on the current axes."""
    left, top = bbox[0], bbox[1]
    outline = patches.Rectangle(
        (left, top),
        bbox[2] - left,  # width
        bbox[3] - top,   # height
        color="blue",
        fill=False,
    )
    plt.gca().add_patch(outline)

In [None]:
def show_mask_on_image(mask):
    """Overlay a mask on the current axes as a translucent green layer.

    Args:
        mask: NumPy array or ``torch.Tensor`` whose last two dimensions are
            ``(height, width)``. A 4-D input is squeezed first, so its
            leading dimensions must be singletons.
    """
    # RGBA in matplotlib's float range [0, 1]. The former green value of 255
    # lay outside that range; imshow clipped it to 1 and emitted a warning, so
    # the rendered colour is unchanged.
    color = np.array([0.0, 1.0, 0.0, 0.6])

    # Bring tensors to host memory as NumPy arrays so the broadcasted multiply
    # below does not depend on torch/NumPy interop.
    if isinstance(mask, torch.Tensor):
        mask = mask.detach().cpu().numpy()
    else:
        mask = np.asarray(mask)

    if mask.ndim == 4:
        mask = mask.squeeze()

    h, w = mask.shape[-2:]
    # (h, w, 1) * (1, 1, 4) -> (h, w, 4); zero mask pixels get alpha 0.
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    plt.gca().imshow(mask_image)

In [None]:
# Load the first training pair and derive a box prompt from its mask.
ex_img = np.array(tiff.imread(train_images[0]))
ex_mask = np.array(cv2.imread(train_masks[0], cv2.IMREAD_UNCHANGED))
ex_bbox = get_bbox(ex_mask)

# One row, three panels: (title, picture to draw, optional box to outline).
example_panels = [
    ("image", ex_img, None),
    ("bounding box", ex_img, ex_bbox),
    ("ground truth mask", ex_mask, None),
]
for panel_no, (panel_title, picture, box) in enumerate(example_panels, start=1):
    plt.subplot(1, 3, panel_no)
    plt.imshow(picture)
    if box is not None:
        show_bbox(box)
    plt.title(panel_title)
    plt.axis("off")

plt.show()

### Zero shot prediction ###

In [None]:
# Load the pretrained SAM checkpoint and its matching processor; the model is
# moved to the selected device, the processor stays on the host.
SAM_CHECKPOINT = "facebook/sam-vit-base"
processor = SamProcessor.from_pretrained(SAM_CHECKPOINT)
model = SamModel.from_pretrained(SAM_CHECKPOINT).to(device)

In [None]:
# Prepare the example image together with its box prompt.
inputs = processor(ex_img, input_boxes=[[[ex_bbox]]], return_tensors="pt")
inputs = inputs.to(device)

# Inference only, so no gradients are tracked.
with torch.no_grad():
    outputs = model(multimask_output=False, **inputs)

# Map the predicted masks back to the original image size (done on the CPU).
# NOTE(review): this rebinds `masks`, which earlier held the list of mask file
# paths - re-running earlier cells after this one would see the new value.
masks = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(),
    inputs["original_sizes"].cpu(),
    inputs["reshaped_input_sizes"].cpu(),
)

In [None]:
# Left panel: zero-shot prediction drawn over the example image.
plt.subplot(121)
plt.imshow(ex_img)
show_mask_on_image(masks[0])
plt.title("predicted mask")
plt.axis("off")

# Right panel: the reference mask for comparison.
plt.subplot(122)
plt.imshow(ex_mask)
plt.title("ground truth mask")
plt.axis("off")

plt.show()

### Prepare Dataset ###

In [None]:
class MedDataset(Dataset):
    """Dataset that turns image/mask file pairs into SAM model inputs.

    Args:
        img_dir: Sequence of image file paths (read with ``tifffile``).
        mask_dir: Sequence of mask file paths, aligned index-by-index with
            ``img_dir``.
        processor: Callable invoked as
            ``processor(image, input_boxes=..., return_tensors="pt")``.
        mode: ``"test"`` makes ``__getitem__`` additionally return the raw
            image, raw mask and box prompt; any other value returns only the
            model inputs.
    """

    def __init__(self, img_dir, mask_dir, processor, mode):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.processor = processor
        self.mode = mode

    def __len__(self):
        return len(self.img_dir)

    def __getitem__(self, idx):
        """Return the inputs for sample ``idx``.

        Returns:
            In ``"test"`` mode ``(image, mask, bbox, inputs)``; otherwise only
            ``inputs``, the processor output with its leading batch dimension
            removed plus a 256x256 ``"ground_truth_mask"`` entry.

        Raises:
            OSError: If OpenCV cannot read the mask file.
        """
        image_path = self.img_dir[idx]
        mask_path = self.mask_dir[idx]

        image = np.asarray(tiff.imread(image_path))

        # cv2.imread signals failure by returning None; report the offending
        # path instead of failing on the division below.
        gray = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if gray is None:
            raise OSError(f"cv2 could not read mask file: {mask_path}")
        # Division followed by uint8 truncation maps 255 to 1 and every
        # smaller value to 0.
        gt_mask = (gray / 255.).astype(np.uint8)

        bbox = get_bbox(gt_mask)

        inputs = self.processor(image, input_boxes=[[bbox]], return_tensors="pt")
        # Drop the processor's leading batch dimension; the DataLoader
        # re-batches the samples.
        inputs = {k: v.squeeze(0) for k, v in inputs.items()}

        # Nearest-neighbour keeps the target binary.
        # NOTE(review): 256x256 presumably matches the model's mask output
        # resolution - confirm.
        inputs["ground_truth_mask"] = cv2.resize(
            gt_mask, (256, 256), interpolation=cv2.INTER_NEAREST
        )

        if self.mode == "test":
            # The raw mask is only returned here, so decode it only in test
            # mode instead of on every training/validation sample.
            mask = np.asarray(tiff.imread(mask_path))
            return image, mask, bbox, inputs
        return inputs

In [None]:
# NOTE(review): only the first 30 samples of each partition are used here -
# confirm whether this subset is intentional.
train_data = MedDataset(
    img_dir=train_images[:30],
    mask_dir=train_masks[:30],
    processor=processor,
    mode="train",
)
val_data = MedDataset(
    img_dir=val_images[:30],
    mask_dir=val_masks[:30],
    processor=processor,
    mode="val",
)

# Both loaders shuffle and share the configured batch size.
train_dataloader = DataLoader(train_data, shuffle=True, batch_size=batch_size)
val_dataloader = DataLoader(val_data, shuffle=True, batch_size=batch_size)

In [None]:
# Pull one collated batch and print the shape of every entry.
batch = next(iter(train_dataloader))
for key, tensor in batch.items():
    print(key, tensor.shape)

### Training ###

In [None]:
# tensorboard
# Shared event writer used by the training and validation functions to log
# per-batch losses; created with the default log directory.
writer = SummaryWriter()

In [None]:
def train_model(model, dataloader, optimizer, criterion, epoch):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: Model called with ``pixel_values``, ``input_boxes`` and
            ``multimask_output=False``; its output exposes ``pred_masks``.
        dataloader: Yields dict batches with the keys ``"pixel_values"``,
            ``"input_boxes"`` and ``"ground_truth_mask"``.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss called as ``criterion(logits, target)``. It receives
            the model's raw mask logits, so it must apply its own activation
            (e.g. ``DiceFocalLoss(sigmoid=True)``).
        epoch: Zero-based epoch index; used only to compute the TensorBoard
            global step.

    Returns:
        Mean of the batch losses as a float.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.train()
    running_loss = 0.0
    num_batches = 0

    for step, batch in enumerate(tqdm(dataloader)):
        pixel_values = batch["pixel_values"].to(device)
        input_boxes = batch["input_boxes"].to(device)
        gt_masks = batch["ground_truth_mask"].float().to(device)

        outputs = model(pixel_values=pixel_values, input_boxes=input_boxes, multimask_output=False)

        # Hand the raw logits to the loss. Applying torch.sigmoid here as well
        # would squash the predictions twice when the criterion already
        # applies a sigmoid, compressing them into a narrow range and
        # distorting both the loss value and its gradients.
        # NOTE(review): pred_masks is presumably (batch, 1, 1, H, W) - confirm.
        mask_logits = outputs.pred_masks.squeeze(1)

        loss = criterion(mask_logits, gt_masks.unsqueeze(1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Log a plain float so the writer does not hold on to the graph.
        loss_value = loss.item()
        writer.add_scalar("Loss/train", loss_value, step + epoch * len(dataloader))

        running_loss += loss_value
        num_batches += 1

    if num_batches == 0:
        raise ValueError("dataloader yielded no batches")

    return running_loss / num_batches

In [None]:
def val_model(model, dataloader, criterion, epoch):
    """Evaluate the model for one epoch without tracking gradients.

    Args:
        model: Model called with ``pixel_values``, ``input_boxes`` and
            ``multimask_output=False``; its output exposes ``pred_masks``.
        dataloader: Yields dict batches with the keys ``"pixel_values"``,
            ``"input_boxes"`` and ``"ground_truth_mask"``.
        criterion: Loss called as ``criterion(logits, target)``. It receives
            the model's raw mask logits, so it must apply its own activation
            (e.g. ``DiceFocalLoss(sigmoid=True)``).
        epoch: Zero-based epoch index; used only to compute the TensorBoard
            global step.

    Returns:
        ``(val_loss, model)``: the mean batch loss as a float and the same
        model object that was passed in.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.eval()
    running_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for step, batch in enumerate(tqdm(dataloader)):
            pixel_values = batch["pixel_values"].to(device)
            input_boxes = batch["input_boxes"].to(device)
            gt_masks = batch["ground_truth_mask"].float().to(device)

            outputs = model(pixel_values=pixel_values, input_boxes=input_boxes, multimask_output=False)

            # Raw logits go to the loss. An extra torch.sigmoid here would be
            # applied on top of the criterion's own sigmoid and distort the
            # reported value.
            mask_logits = outputs.pred_masks.squeeze(1)

            loss_value = criterion(mask_logits, gt_masks.unsqueeze(1)).item()

            writer.add_scalar("Loss/validation", loss_value, step + epoch * len(dataloader))

            running_loss += loss_value
            num_batches += 1

    if num_batches == 0:
        raise ValueError("dataloader yielded no batches")

    return running_loss / num_batches, model

In [None]:
class EarlyStopping:
    """Signal when the validation loss has stopped improving.

    Call the instance once per epoch. Whenever the loss improves, the model's
    ``state_dict`` is written to ``file_name``; after ``patience`` consecutive
    calls without improvement ``early_stop`` becomes ``True``.

    Args:
        patience: Number of consecutive non-improving calls tolerated.
        verbose: If true, report every checkpoint through ``trace_func``.
        delta: Minimum decrease of the loss that counts as an improvement.
        trace_func: Callable used for status messages.
    """

    def __init__(self, patience=20, verbose=False, delta=0, trace_func=print):
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.trace_func = trace_func
        self.counter = 0            # consecutive calls without improvement
        self.val_loss = None        # best loss seen so far
        # float("inf") instead of np.Inf: that alias was removed in NumPy 2.0.
        self.val_loss_min = float("inf")  # loss of the last saved checkpoint
        self.early_stop = False

    def __call__(self, val_loss, model, file_name):
        """Record ``val_loss``; checkpoint on improvement, otherwise count."""
        # An improvement must beat the best loss by at least `delta`.
        # Comparing against `best + delta` instead would accept slightly worse
        # losses as improvements, overwrite the best checkpoint with a worse
        # model and let the baseline drift upwards.
        improved = self.val_loss is None or val_loss <= self.val_loss - self.delta
        if improved:
            self.val_loss = val_loss
            self.save_checkpoint(val_loss, model, file_name)
            self.counter = 0
            return

        self.counter += 1
        self.trace_func(f"Early Stopping counter: {self.counter} out of {self.patience}")
        if self.counter >= self.patience:
            self.early_stop = True

    def save_checkpoint(self, val_loss, model, file_name):
        """Save ``model.state_dict()`` to ``file_name`` and remember its loss."""
        if self.verbose:
            self.trace_func(f"Validation loss decreased: {self.val_loss_min:.6f} --> {val_loss:.6f}. Saving model...")
        torch.save(model.state_dict(), file_name)
        self.val_loss_min = val_loss

In [None]:
# Freeze every parameter of the vision and prompt encoders; parameters under
# any other prefix keep their current requires_grad setting.
for name, param in model.named_parameters():
    if name.startswith(("vision_encoder", "prompt_encoder")):
        param.requires_grad_(False)

model.to(device)

In [None]:
# Adam updates only the mask decoder's parameters.
optimizer = optim.Adam(
    model.mask_decoder.parameters(),
    lr=lr,
    weight_decay=weight_decay,
)
# Scales the base learning rate by 0.95 ** (number of scheduler steps).
# NOTE(review): no scheduler.step() call is visible in this notebook, so the
# decay may never take effect - confirm intent.
scheduler = optim.lr_scheduler.LambdaLR(
    optimizer=optimizer,
    lr_lambda=lambda epoch: 0.95**epoch,
)
# Combined Dice + focal loss, averaged over the batch; sigmoid=True asks the
# loss to apply a sigmoid to the predictions it is given.
criterion = monai.losses.DiceFocalLoss(
    sigmoid=True,
    squared_pred=True,
    reduction='mean',
)
es = EarlyStopping(patience=10, verbose=False, delta=0.0001)

In [None]:
# Per-epoch loss history, filled during training and plotted afterwards.
train_loss_list, val_loss_list = [], []

In [None]:
# Resolve the checkpoint path once and make sure its folder exists *before*
# training starts: torch.save does not create missing directories, and failing
# only after the final epoch would throw the whole run away.
TUNED_FILE = "/home/kmk/COSE474Project/checkpoint/fine_tuned_sam.pth"
os.makedirs(os.path.dirname(TUNED_FILE), exist_ok=True)

for epoch in tqdm(range(epochs)):
    train_loss = train_model(model, train_dataloader, optimizer, criterion, epoch)
    val_loss, model = val_model(model, val_dataloader, criterion, epoch)

    train_loss_list.append(train_loss)
    val_loss_list.append(val_loss)

    print(f"Train Loss: {train_loss:.4f}")
    print(f"Validation Loss: {val_loss:.4f}")

    # Early stopping is currently disabled; uncomment to re-enable it.
    # es(val_loss, model, TUNED_FILE)
    #
    # if es.early_stop:
    #     writer.close()
    #     break

# Persist the weights as they are after the last completed epoch.
torch.save(model.state_dict(), TUNED_FILE)
# Push any buffered TensorBoard events to disk.
writer.flush()

In [None]:
# Training and validation loss per epoch on shared axes.
plt.plot(train_loss_list, label='train loss')
plt.plot(val_loss_list, label='val loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend(loc="upper right")
plt.show()

### Test ###

In [None]:
# "test" mode makes each item carry the raw image, raw mask and box prompt in
# addition to the model inputs; samples are drawn one at a time, shuffled.
test_data = MedDataset(
    img_dir=test_images,
    mask_dir=test_masks,
    processor=processor,
    mode="test",
)
test_dataloader = DataLoader(test_data, batch_size=1, shuffle=True)

In [None]:
# Start from the pretrained architecture, then overwrite its weights with the
# fine-tuned checkpoint.
tuned_model = SamModel.from_pretrained("facebook/sam-vit-base")
# map_location="cpu": by default torch.load restores tensors onto the device
# they were saved from, which fails on a host where that GPU is missing. The
# model is moved to the selected device right afterwards.
tuned_state_dict = torch.load(
    "/home/kmk/COSE474Project/checkpoint/fine_tuned_sam.pth",
    map_location="cpu",
)
tuned_model.load_state_dict(tuned_state_dict)
tuned_model.to(device)

In [None]:
test_losses = 0.0
num_test_batches = 0

# Evaluation only: put the model in eval mode explicitly and skip gradient
# tracking so no autograd graph is built for each test sample.
tuned_model.eval()

with torch.no_grad():
    for j, (img, mask, bbox, batch) in enumerate(tqdm(test_dataloader)):
        pixel_values = batch["pixel_values"].to(device)
        input_boxes = batch["input_boxes"].to(device)
        gt_masks = batch["ground_truth_mask"].float().to(device)

        outputs = tuned_model(pixel_values=pixel_values, input_boxes=input_boxes, multimask_output=False)

        # Raw logits go to the loss: the criterion is configured to apply the
        # sigmoid itself, so squashing here as well would apply it twice.
        mask_logits = outputs.pred_masks.squeeze(1)

        loss = criterion(mask_logits, gt_masks.unsqueeze(1))

        test_losses += loss.item()
        num_test_batches += 1

        # Visualise every 10th sample.
        if (j+1)%10 == 0:
            print(f"{j+1}th data")

            img = img.squeeze().cpu().detach().numpy()
            mask = mask.squeeze().cpu().detach().numpy()
            # Map the predicted mask back to the original image size.
            seg = processor.image_processor.post_process_masks(
                outputs.pred_masks.cpu(),
                batch["original_sizes"].cpu(),
                batch["reshaped_input_sizes"].cpu(),
            )
            seg = seg[0].squeeze()

            # Collation turned each coordinate into a one-element tensor.
            bbox = list(map(int, bbox))

            plt.figure(figsize=(15, 15))

            plt.subplot(1, 4, 1)
            plt.imshow(img, cmap='gray')
            plt.title('input image')
            plt.axis('off')

            plt.subplot(1, 4, 2)
            plt.imshow(img, cmap='gray')
            show_bbox(bbox)
            plt.title('prompt')
            plt.axis('off')

            plt.subplot(1, 4, 3)
            plt.imshow(img, cmap='gray')
            show_mask_on_image(seg)
            plt.title('predicted mask')
            plt.axis('off')

            plt.subplot(1, 4, 4)
            plt.imshow(mask, cmap='gray')
            plt.title('ground truth mask')
            plt.axis('off')

            plt.show()

# Fail with a clear message instead of a NameError on the loop index when the
# test loader is empty.
if num_test_batches == 0:
    raise ValueError("test_dataloader yielded no batches")

test_loss = test_losses / num_test_batches

print(f"Test Loss: {test_loss:.4f}")