# MedSAM Training & Testing

In [None]:
%pip install -q git+https://github.com/bowang-lab/MedSAM.git
%pip install segment-anything
%pip install boxsdk

In [3]:
from torch.utils.data import DataLoader, Dataset
from torchvision.io import read_image
from sklearn.model_selection import train_test_split
import torch
import os
import pandas as pd
from PIL import Image
import numpy as np
import torch.nn as nn
from segment_anything import sam_model_registry
from skimage import transform
import torch.nn.functional as F
import matplotlib.pyplot as plt
from boxsdk import OAuth2
from boxsdk import Client
import monai
# from io import StringIO

2024-08-14 13:51:38.243274: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX_VNNI
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-08-14 13:51:39.813848: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


### Dataloader & Box Integration

In [111]:
# Setup connection
# Credentials are read from the environment instead of being hard-coded: secrets
# stored in a notebook leak with every copy of it. Any values that were previously
# committed here should be treated as compromised and revoked in the Box console.
oauth = OAuth2(
    client_id=os.environ['BOX_CLIENT_ID'],
    client_secret=os.environ['BOX_CLIENT_SECRET'],
    access_token=os.environ['BOX_ACCESS_TOKEN'],
)

client = Client(oauth)

In [79]:
# Access MRI data folder: download the shared metadata spreadsheet from Box and
# load it into the `mri_data` DataFrame.

SHARED_LINK_URL = 'https://app.box.com/file/1433849452284?s=swtqk78ia9uyl2lookmxh26pas75hgrk'
METADATA_PATH = 'FileFromBox.xlsx'
shared_item = client.get_shared_item(SHARED_LINK_URL)

# The with-block closes the file on exit; no explicit close() is needed.
with open(METADATA_PATH, 'wb') as open_file:
    client.with_shared_link(SHARED_LINK_URL, None).file(shared_item.id).download_to(open_file)

mri_data = pd.read_excel(METADATA_PATH)

In [80]:
# Additional Fields for MRI Data

# Every brightness-level folder holds one PNG per slice, so a patient contributes
# slices x brightness levels images (adding the two counts under-counted them).
# NOTE(review): confirm each brightness folder really contains 'Number of Slices' files.
mri_data['Total Images'] = mri_data['Number of Slices'] * mri_data['Number of Brightness Levels']
mri_data['Start_Index'] = 0
mri_data['Has MRI'] = mri_data['PNG filtered MRI'].notna()
mri_data['Has Seg'] = mri_data['PNG segmentation'].notna()

# For each patient, list the Box folder ids of its brightness-level sub-folders.
folders = []
for patient_id, mri_link, has_mri in zip(
    mri_data['MRI/Patient ID'], mri_data['PNG filtered MRI'], mri_data['Has MRI']
):
    brightness_folders = []
    if has_mri:
        try:
            png = client.get_shared_item(mri_link)
            brightness_folders = [item.id for item in client.folder(png.id).get_items()]
        except Exception as err:
            # Best effort: report and continue so one bad link does not abort the scan.
            # (A bare `except:` would also swallow KeyboardInterrupt and hid the cause.)
            print(f"{patient_id} has error; verify folder addresses ({err!r})")
            brightness_folders = []
    folders.append(brightness_folders)

mri_data['Brightness Folders'] = folders
# A patient whose MRI folder could not be listed has no usable images; clear its
# 'Has MRI' flag so the MRI/Seg filter drops it rather than the dataset failing later
# on an empty folder list.
mri_data['Has MRI'] = mri_data['Has MRI'] & (mri_data['Brightness Folders'].map(len) > 0)

In [81]:
# Keep only patients whose MRI and segmentation PNG folders are both present.
has_both = mri_data['Has MRI'] & mri_data['Has Seg']
mri_data = mri_data[has_both]

In [97]:
# Train/Test split over metadata rows (one row per patient entry), so all images of a
# row land in the same split. A fixed random_state makes the split reproducible.
train_data, test_data = train_test_split(mri_data, test_size=0.25, random_state=0)
train_data = train_data.reset_index(drop=True)
test_data = test_data.reset_index(drop=True)

# Start_Index = number of images in all earlier rows (exclusive cumulative sum),
# replacing the previous O(n^2) per-row re-summation.
for split in (train_data, test_data):
    split['Start_Index'] = split['Total Images'].cumsum().shift(fill_value=0)

In [125]:
# Dataset using Box access

class CancerDataset(Dataset):
    """Box-backed dataset yielding one (MRI slice, per-pixel class map) pair per index.

    Each row of ``labels`` describes one patient. Columns read here:
    ``'Total Images'`` (samples contributed by the row), ``'Start_Index'`` (flat
    index of the row's first sample; must be ascending across rows),
    ``'Number of Slices'``, ``'Brightness Folders'`` (Box folder ids, one per
    brightness level), ``'PNG segmentation'`` (Box shared link of the segmentation
    folder) and ``'MRI/Patient ID'``.

    Within a row, local index ``k`` maps to brightness folder
    ``(k // n_slices) % n_folders`` and slice ``k % n_slices``.

    Args:
        labels: per-patient metadata DataFrame.
        train: download files into ``train/`` when True, else ``test/``.
        transform: optional callable applied to the image tensor.
        target_transform: optional callable applied to the class-index label tensor.

    ``__getitem__`` returns ``(image, label_classes, patient, bright_level)``.
    """

    # Segmentation grey value -> class index; pixels matching none stay 0.
    VALUE_TO_CLASS = {
        9362: 0,   # Background
        18724: 1,  # Water
        28086: 2,  # Skin
        37449: 3,  # Fat
        46811: 4,  # FGT?
        56173: 5,  # Tumor
        65535: 6,  # Clip
    }

    def __init__(self, labels, train=True, transform=None, target_transform=None):
        self.img_labels = labels
        # Stored so __getitem__ can pick the download folder (it used to read an
        # undefined global `train`).
        self.train = train
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Total number of samples across all rows."""
        return int(self.img_labels['Total Images'].sum())

    def _locate(self, idx):
        """Return ``(row, bright_level, img_idx)`` for flat sample index ``idx``."""
        starts = self.img_labels['Start_Index'].to_numpy()
        # Last row whose first sample is <= idx. The old lookup took the first row
        # starting at or after idx, i.e. usually the *next* patient.
        pos = int(np.searchsorted(starts, idx, side='right')) - 1
        row = self.img_labels.iloc[pos]
        local = int(idx - row['Start_Index'])
        n_slices = int(row['Number of Slices'])
        bright_level = (local // n_slices) % len(row['Brightness Folders'])
        img_idx = local % n_slices
        return row, bright_level, img_idx

    @staticmethod
    def _download_png(file_id, dest):
        """Download Box file ``file_id`` to ``dest`` and return its pixels as int32."""
        with open(dest, 'wb') as open_pic:
            client.file(file_id).download_to(open_pic)
        # int32 holds the full 16-bit range; int16 wrapped values above 32767 to
        # negatives, so segmentation classes 3-6 could never be matched.
        with Image.open(dest) as pic:
            return np.array(pic, dtype=np.int32)

    def __getitem__(self, idx):
        path = 'train/' if self.train else 'test/'
        os.makedirs(path, exist_ok=True)
        row, bright_level, img_idx = self._locate(idx)
        patient = row['MRI/Patient ID']

        # NOTE(review): slices and segmentations are paired by their position in
        # Box's folder listings; confirm both folders list files in the same order.
        bright_id = row['Brightness Folders'][bright_level]
        img_ids = [item.id for item in client.folder(bright_id).get_items()]
        seg_ids = [item.id for item in client.get_shared_item(row['PNG segmentation']).get_items()]

        image = torch.from_numpy(self._download_png(img_ids[img_idx], f'{path}png_{idx}.png')).float()
        label = torch.from_numpy(self._download_png(seg_ids[img_idx], f'{path}seg_{idx}.png'))

        label_classes = torch.zeros_like(label, dtype=torch.long)
        for value, class_idx in self.VALUE_TO_CLASS.items():
            label_classes[label == value] = class_idx

        if self.transform is not None:
            image = self.transform(image)
        # Applied to the returned class map; previously it transformed the raw label,
        # whose result was then discarded.
        if self.target_transform is not None:
            label_classes = self.target_transform(label_classes)
        return image, label_classes, patient, bright_level

In [126]:
# Wrap the per-patient splits in Box-backed datasets and batch them.
batch_size = 16
train_dataset = CancerDataset(labels=train_data, train=True)
# train=False marks the held-out split; the dataset's train flag chooses between the
# train/ and test/ download folders (the test set previously used the default, True).
test_dataset = CancerDataset(labels=test_data, train=False)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
# NOTE(review): the default collate stacks samples, so all images in a batch must share
# one H x W; the first evaluation run failed on a 512x464 vs 512x512 batch.

### MedSAM Evaluation: No Training

In [55]:
# # # %% environment and functions
# import numpy as np
# import matplotlib.pyplot as plt
# import os
# join = os.path.join
# import torch
# from segment_anything import sam_model_registry
# from skimage import io, transform
# import torch.nn.functional as F

# visualization functions
# source: https://github.com/facebookresearch/segment-anything/blob/main/notebooks/predictor_example.ipynb
# change color to avoid red and green
def show_mask(mask, ax, random_color=False):
    """Overlay ``mask`` on ``ax`` as a translucent RGBA image.

    Args:
        mask: array whose last two dims are (h, w) and which holds exactly h*w
            values; the RGBA colour is scaled by each value, so 0 is transparent.
        ax: matplotlib axes to draw on.
        random_color: use a random RGB tint (alpha 0.6) instead of fixed yellow.
    """
    rgba = (
        np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
        if random_color
        else np.array([251/255, 252/255, 30/255, 0.6])
    )
    height, width = mask.shape[-2:]
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)

def show_box(box, ax):
    """Outline ``box`` = (x0, y0, x1, y1) on ``ax`` as an unfilled blue rectangle (line width 2)."""
    left, top = box[0], box[1]
    width = box[2] - left
    height = box[3] - top
    outline = plt.Rectangle((left, top), width, height, edgecolor='blue', facecolor=(0, 0, 0, 0), lw=2)
    ax.add_patch(outline)

@torch.no_grad()
def medsam_inference(medsam_model, img_embed, box_1024, H, W, multimask_output=True, threshold=0.5):
    """Decode binary masks from precomputed image embeddings and box prompts.

    Args:
        medsam_model: SAM/MedSAM model exposing ``prompt_encoder`` and ``mask_decoder``.
        img_embed: image-encoder output, expected (B, 256, 64, 64).
        box_1024: boxes ``[x0, y0, x1, y1]`` as (N, 4) or (N, 1, 4); presumably in the
            1024x1024 prompt frame the name refers to, not H x W pixels — confirm.
        H, W: output mask size; probabilities are bilinearly resized to it.
        multimask_output: forwarded to the mask decoder (True, the previous
            hard-coded value, yields several candidate masks per prompt; False one).
        threshold: probability cut-off for foreground.

    Returns:
        uint8 numpy array of 0/1 values shaped (B, M, H, W) with every singleton
        dimension squeezed out (e.g. (H, W) for one image and one mask).
    """
    box_torch = torch.as_tensor(box_1024, dtype=torch.float, device=img_embed.device)
    if box_torch.dim() == 2:
        box_torch = box_torch[:, None, :]  # (N, 4) -> (N, 1, 4)

    sparse_embeddings, dense_embeddings = medsam_model.prompt_encoder(
        points=None,
        boxes=box_torch,
        masks=None,
    )
    low_res_logits, _ = medsam_model.mask_decoder(
        image_embeddings=img_embed,  # (B, 256, 64, 64)
        image_pe=medsam_model.prompt_encoder.get_dense_pe(),  # (1, 256, 64, 64)
        sparse_prompt_embeddings=sparse_embeddings,
        dense_prompt_embeddings=dense_embeddings,
        multimask_output=multimask_output,
    )

    # (B, M, h, w) logits -> probabilities; M depends on multimask_output.
    low_res_pred = torch.sigmoid(low_res_logits)
    low_res_pred = F.interpolate(
        low_res_pred,
        size=(H, W),
        mode="bilinear",
        align_corners=False,
    )  # (B, M, H, W)
    # squeeze() drops every size-1 dim, so the result's rank depends on B and M.
    low_res_pred = low_res_pred.squeeze().cpu().numpy()
    return (low_res_pred > threshold).astype(np.uint8)


In [127]:
#%% load model and image
MedSAM_CKPT_PATH = "medsam_vit_b.pth"
medsam_model = sam_model_registry['vit_b'](checkpoint=MedSAM_CKPT_PATH)
# Replace the 16x16 patch embedding with an 8x8 one plus 24 px of horizontal padding,
# so a 512x464 slice yields a 64x64 token grid (other input sizes give other grids).
# NOTE(review): this Conv2d is freshly initialised, discarding the checkpoint's
# patch-embedding weights, so these "no training" scores are not pretrained MedSAM's.
medsam_model.image_encoder.patch_embed.proj = nn.Conv2d(3, 768, kernel_size=(8, 8), stride=(8, 8), padding=(0, 24))
device = "cuda:0" if torch.cuda.is_available() else "cpu"
medsam_model = medsam_model.to(device)
medsam_model.eval()

# One whole-image box in the 1024x1024 frame that SAM's prompt encoder normalises by
# (hence medsam_inference's `box_1024`); [0, 0, W, H] in pixels covered only part of
# the image. The single (1, 4) box is shared by every image in the batch.
box_1024 = np.array([[0, 0, 1024, 1024]])

test_acc = []

for batch_idx, (x, y, patient, b_level) in enumerate(test_loader):
    B, H, W = x.shape
    # (B, H, W) -> (B, 3, H, W) by copying each image's grey channel; the old
    # repeat(3, 1, 1, 1).view(B, 3, H, W) mixed pixels of different images.
    img_3c = x.unsqueeze(1).repeat(1, 3, 1, 1).to(device)

    with torch.no_grad():
        image_embedding = medsam_model.image_encoder(img_3c)  # expected (B, 256, 64, 64)

    medsam_seg = medsam_inference(medsam_model, image_embedding, box_1024, H, W)
    # medsam_inference squeezes singleton dims; restore (B, masks, H, W) and keep mask 0.
    # NOTE(review): with multimask output this is the first of SAM's candidate masks.
    seg = torch.from_numpy(np.ascontiguousarray(np.asarray(medsam_seg).reshape(B, -1, H, W)[:, 0]))
    # Binary target: every pixel not labelled background (class 0).
    # NOTE(review): switch to a single class (e.g. 5 = Tumor) if that is the goal.
    target = (y != 0).to(seg.dtype)
    per_sample_acc = (seg == target).float().mean(dim=(1, 2))

    # One record per image so the exported CSV holds scalar cells.
    patients = patient.tolist() if torch.is_tensor(patient) else list(patient)
    test_acc.extend(zip(per_sample_acc.tolist(), patients, b_level.tolist()))

    # Plot the first image of the first batch (the old check read a stale global `i`).
    if batch_idx == 0:
        display_box = [0, 0, W, H]
        fig, ax = plt.subplots(1, 2, figsize=(10, 5))
        ax[0].imshow(x[0].numpy(), cmap="gray")
        show_box(display_box, ax[0])
        ax[0].set_title("Input Image and Bounding Box")
        ax[1].imshow(x[0].numpy(), cmap="gray")
        show_mask(seg[0].numpy(), ax[1])
        show_box(display_box, ax[1])
        ax[1].set_title(f"MedSAM Segmentation w/ Accuracy: {per_sample_acc[0].item():.4f}")
        plt.show()

RuntimeError: stack expects each tensor to be equal size, but got [512, 464] at entry 0 and [512, 512] at entry 1

In [None]:
# Debug: inspect one test sample's tensor shapes (the batch above failed to stack a
# 512x464 image with a 512x512 one). The previous expression used `path` and `idx`,
# which exist only inside CancerDataset.__getitem__, so it raised NameError here.
sample_image, sample_label, sample_patient, sample_level = test_dataset[0]
print(sample_patient, sample_level, tuple(sample_image.shape), tuple(sample_label.shape))

In [None]:
# Export Results: write the (accuracy, patient, brightness level) records collected
# in `test_acc` to CSV without the DataFrame index.
results_columns = ['Accuracy', 'Patient', 'Brightness Level']
base_results = pd.DataFrame(test_acc, columns=results_columns)
base_results.to_csv("MRI_MedSAM_Base.csv", index=False)

### MedSAM Fine-tuning and Testing

In [None]:
# import monai

In [None]:
# Fine-tune the image encoder and mask decoder; the prompt encoder stays frozen
# (its parameters are not optimised and it runs under no_grad).
medsam_model.train()
device = next(medsam_model.parameters()).device

losses = []
seg_loss = monai.losses.DiceLoss(sigmoid=True, squared_pred=True, reduction="mean")
ce_loss = nn.BCEWithLogitsLoss(reduction="mean")
img_mask_encdec_params = list(medsam_model.image_encoder.parameters()) + list(medsam_model.mask_decoder.parameters())
optimizer = torch.optim.AdamW(img_mask_encdec_params, lr=.0001, weight_decay=.01)

# One whole-image box in the 1024x1024 prompt frame, (1, 1, 4), shared by the batch.
box_torch = torch.tensor([[[0.0, 0.0, 1024.0, 1024.0]]], device=device)

train_acc = []

epochs = 10
for epoch in range(epochs):
    epoch_start = len(losses)
    for batch_idx, (x, y, patient, b_level) in enumerate(train_loader):
        B, H, W = x.shape
        # (B, H, W) -> (B, 3, H, W), copying each image's grey channel.
        img_3c = x.unsqueeze(1).repeat(1, 3, 1, 1).to(device)
        # Binary target (B, 1, H, W): every pixel not labelled background.
        # NOTE(review): confirm the intended foreground (e.g. only class 5 = Tumor).
        target = (y != 0).float().unsqueeze(1).to(device)

        # Forward WITH gradients through encoder and decoder; the previous loop ran
        # everything under no_grad and never called backward/step, so nothing trained.
        image_embedding = medsam_model.image_encoder(img_3c)
        with torch.no_grad():
            sparse_embeddings, dense_embeddings = medsam_model.prompt_encoder(
                points=None, boxes=box_torch, masks=None,
            )
        low_res_logits, _ = medsam_model.mask_decoder(
            image_embeddings=image_embedding,
            image_pe=medsam_model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=False,
        )
        logits = F.interpolate(low_res_logits, size=(H, W), mode="bilinear", align_corners=False)

        loss = seg_loss(logits, target) + ce_loss(logits, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

        with torch.no_grad():
            pred = (torch.sigmoid(logits) > 0.5).float()
            per_sample_acc = (pred == target).float().mean(dim=(1, 2, 3)).cpu()
        # Recorded in train_acc (the previous loop appended to test_acc by mistake).
        patients = patient.tolist() if torch.is_tensor(patient) else list(patient)
        train_acc.extend(zip(per_sample_acc.tolist(), patients, b_level.tolist()))

        # Plot the first image of each epoch's first batch.
        if batch_idx == 0:
            display_box = [0, 0, W, H]
            fig, ax = plt.subplots(1, 2, figsize=(10, 5))
            ax[0].imshow(x[0].numpy(), cmap="gray")
            show_box(display_box, ax[0])
            ax[0].set_title("Input Image and Bounding Box")
            ax[1].imshow(x[0].numpy(), cmap="gray")
            show_mask(pred[0, 0].cpu().numpy(), ax[1])
            show_box(display_box, ax[1])
            ax[1].set_title(f"Epoch {epoch}: MedSAM Segmentation w/ Accuracy: {per_sample_acc[0].item():.4f}")
            plt.show()

    if len(losses) > epoch_start:
        print(f"Epoch {epoch}: mean loss {np.mean(losses[epoch_start:]):.4f}")

In [None]:
# Evaluate the fine-tuned model on the test split.
medsam_model.eval()
device = next(medsam_model.parameters()).device

# One whole-image box in the 1024x1024 prompt frame, (1, 1, 4), shared by the batch.
box_torch = torch.tensor([[[0.0, 0.0, 1024.0, 1024.0]]], device=device)

test_acc = []

for batch_idx, (x, y, patient, b_level) in enumerate(test_loader):
    B, H, W = x.shape
    # (B, H, W) -> (B, 3, H, W), copying each image's grey channel.
    img_3c = x.unsqueeze(1).repeat(1, 3, 1, 1).to(device)

    with torch.no_grad():
        image_embedding = medsam_model.image_encoder(img_3c)
        sparse_embeddings, dense_embeddings = medsam_model.prompt_encoder(
            points=None, boxes=box_torch, masks=None,
        )
        # multimask_output=False selects the single-mask head that fine-tuning optimised;
        # medsam_inference decodes with multimask output, which returns other mask heads.
        low_res_logits, _ = medsam_model.mask_decoder(
            image_embeddings=image_embedding,
            image_pe=medsam_model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=False,
        )
        probs = torch.sigmoid(
            F.interpolate(low_res_logits, size=(H, W), mode="bilinear", align_corners=False)
        )
    seg = (probs[:, 0] > 0.5).to(torch.uint8).cpu()  # (B, H, W)
    # Binary target: every pixel not labelled background (class 0).
    # NOTE(review): keep in sync with the target used during fine-tuning.
    target = (y != 0).to(torch.uint8)
    per_sample_acc = (seg == target).float().mean(dim=(1, 2))

    # One record per image so results can be exported with scalar cells.
    patients = patient.tolist() if torch.is_tensor(patient) else list(patient)
    test_acc.extend(zip(per_sample_acc.tolist(), patients, b_level.tolist()))

    # Plot the first image of the first batch (the old check read a stale global `i`).
    if batch_idx == 0:
        display_box = [0, 0, W, H]
        fig, ax = plt.subplots(1, 2, figsize=(10, 5))
        ax[0].imshow(x[0].numpy(), cmap="gray")
        show_box(display_box, ax[0])
        ax[0].set_title("Input Image and Bounding Box")
        ax[1].imshow(x[0].numpy(), cmap="gray")
        show_mask(seg[0].numpy(), ax[1])
        show_box(display_box, ax[1])
        ax[1].set_title(f"MedSAM Segmentation w/ Accuracy: {per_sample_acc[0].item():.4f}")
        plt.show()