In [None]:
import torch
import os

print("PyTorch version:", torch.__version__)
print("CUDA is available:", torch.cuda.is_available())

!git clone https://github.com/gttae/med-hq-sam.git
!pip install timm
os.chdir('med-hq-sam')
!export PYTHONPATH=$(pwd)
from segment_anything import sam_model_registry, SamPredictor

In [None]:
# Install MONAI, upgrading it if an older version is already present.
# NOTE(review): no cell visible in this notebook imports monai — confirm it is still needed.
!pip install -U monai

In [None]:
import torch
import numpy as np
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from segment_anything import sam_model_registry
from sklearn.metrics import f1_score
import torch.nn as nn
import glob
import random
import torch.nn.functional as F
import os

# Short alias for os.path.join, used when building dataset paths below.
join = os.path.join


def show_mask(mask, ax, random_color=False):
    """Draw ``mask`` on ``ax`` as a semi-transparent RGBA overlay.

    Args:
        mask: array whose last two dimensions are (height, width); it is
            reshaped to (height, width, 1) before colouring.
        ax: object exposing ``imshow`` (e.g. a matplotlib Axes).
        random_color: when True, use a random RGB colour; otherwise a fixed
            yellow. The alpha channel is 0.6 in both cases.
    """
    if random_color:
        rgba = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        rgba = np.array([251 / 255, 252 / 255, 30 / 255, 0.6])

    height, width = mask.shape[-2:]
    # Broadcast (H, W, 1) * (1, 1, 4) -> (H, W, 4): zero pixels stay fully transparent.
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)


def show_box(box, ax):
    """Outline a bounding box on ``ax`` with a blue, unfilled rectangle.

    Args:
        box: indexable as ``[x0, y0, x1, y1]`` (two opposite corners).
        ax: object exposing ``add_patch`` (e.g. a matplotlib Axes).
    """
    left, top = box[0], box[1]
    width = box[2] - box[0]
    height = box[3] - box[1]
    outline = plt.Rectangle(
        (left, top), width, height, edgecolor="blue", facecolor=(0, 0, 0, 0), lw=2
    )
    ax.add_patch(outline)

def visualize_segmentation(image, gt_mask, pred_mask, epoch, step, file_path):
    """Save a three-panel figure: input image, ground-truth mask, predicted mask.

    Args:
        image: tensor permuted with ``(1, 2, 0)`` before display, i.e. expected
            channels-first.
        gt_mask: tensor shown in grayscale.
        pred_mask: tensor shown in grayscale (detached before conversion).
        epoch: value shown in the figure title.
        step: value shown in the figure title.
        file_path: destination passed to ``plt.savefig``.
    """
    fig, axs = plt.subplots(1, 3, figsize=(15, 5))

    # (title, pixel data, extra imshow kwargs) for each panel, left to right.
    panels = (
        ('Input Image', image.permute(1, 2, 0).cpu().numpy(), {}),
        ('Ground Truth Mask', gt_mask.cpu().numpy(), {'cmap': 'gray'}),
        ('Predicted Mask', pred_mask.detach().cpu().numpy(), {'cmap': 'gray'}),
    )
    for panel_ax, (title, pixels, imshow_kwargs) in zip(axs, panels):
        panel_ax.imshow(pixels, **imshow_kwargs)
        panel_ax.set_title(title)
        panel_ax.axis('off')

    plt.suptitle(f'Epoch {epoch}, Step {step}')
    plt.savefig(file_path)
    # Close the figure so repeated calls do not accumulate open figures.
    plt.close()


class NpyDataset(Dataset):
    """Paired ``.npy`` images and label masks for box-prompted segmentation.

    ``data_root`` must contain an ``imgs`` directory and a ``gts`` directory.
    Ground-truth files are discovered recursively under ``gts``; one is kept
    only if an image with the same base name exists directly under ``imgs``.

    Each item is a tuple ``(image, mask, bbox, name)``:
        image: float tensor, the stored channels-last array transposed to
            channels-first; values must lie in [0, 1].
        mask: long tensor of shape (1, H, W), the binary mask of ONE randomly
            chosen non-zero label of the ground truth.
        bbox: float tensor ``[x_min, y_min, x_max, y_max]`` around that mask,
            each side randomly enlarged by 0..``bbox_shift`` pixels and clipped
            to the mask extent.
        name: base file name shared by the image and the ground truth.

    Args:
        data_root: directory holding ``imgs`` and ``gts``.
        bbox_shift: maximum random enlargement (pixels) per box side.
    """

    def __init__(self, data_root, bbox_shift=20):
        self.data_root = data_root
        self.gt_path = os.path.join(data_root, "gts")
        self.img_path = os.path.join(data_root, "imgs")
        candidates = sorted(
            glob.glob(os.path.join(self.gt_path, "**/*.npy"), recursive=True)
        )
        # Keep only ground truths that have a matching image file.
        self.gt_path_files = [
            file
            for file in candidates
            if os.path.isfile(os.path.join(self.img_path, os.path.basename(file)))
        ]
        self.bbox_shift = bbox_shift
        print(f"number of images: {len(self.gt_path_files)}")

    def __len__(self):
        return len(self.gt_path_files)

    def __getitem__(self, index):
        """Load sample ``index``; see the class docstring for the tuple layout.

        Raises:
            ValueError: if the ground truth contains no non-zero label.
            AssertionError: if the image is not in [0, 1] or the selected mask
                is not a proper 0/1 mask.
        """
        gt_file = self.gt_path_files[index]
        # The image shares its base name with the ground-truth file.
        img_name = os.path.basename(gt_file)

        # NOTE(review): allow_pickle=True can execute arbitrary code when the
        # .npy files come from an untrusted source — confirm it is required.
        img = np.load(os.path.join(self.img_path, img_name), "r", allow_pickle=True)
        # Channels-last -> channels-first.
        img = np.transpose(img, (2, 0, 1))
        assert (
            np.max(img) <= 1.0 and np.min(img) >= 0.0
        ), "image should be normalized to [0, 1]"

        gt = np.load(gt_file, "r", allow_pickle=True)

        # Exclude background (0) explicitly instead of dropping the smallest
        # value, which would discard a real label when 0 is absent.
        label_ids = np.unique(gt)
        label_ids = label_ids[label_ids != 0]
        if label_ids.size == 0:
            # A ValueError (not the IndexError random.choice would raise)
            # so sequence-style iteration is not silently cut short.
            raise ValueError(f"ground truth has no foreground label: {gt_file}")

        # Binary mask of a single randomly selected label.
        gt2D = np.asarray(gt == random.choice(label_ids.tolist()), dtype=np.uint8)
        assert np.max(gt2D) == 1 and np.min(gt2D) == 0.0, "ground truth should be 0, 1"

        # Tight bounding box of the selected label.
        y_indices, x_indices = np.where(gt2D > 0)
        x_min, x_max = np.min(x_indices), np.max(x_indices)
        y_min, y_max = np.min(y_indices), np.max(y_indices)

        # add perturbation to bounding box coordinates
        H, W = gt2D.shape
        x_min = max(0, x_min - random.randint(0, self.bbox_shift))
        x_max = min(W, x_max + random.randint(0, self.bbox_shift))
        y_min = max(0, y_min - random.randint(0, self.bbox_shift))
        y_max = min(H, y_max + random.randint(0, self.bbox_shift))
        bboxes = np.array([x_min, y_min, x_max, y_max])

        return (
            torch.tensor(img).float(),
            torch.tensor(gt2D[None, :, :]).long(),
            torch.tensor(bboxes).float(),
            img_name,
        )

# Evaluation samples are read from the `imgs` / `gts` folders under this root.
eval_dataset = NpyDataset(data_root="/content/drive/MyDrive/npy3/CT_Abd/")

# Batches of two, visited in dataset order (no shuffling) for evaluation.
eval_dataloader = DataLoader(
    dataset=eval_dataset,
    batch_size=2,
    shuffle=False,
)


class MedSAM(nn.Module):
    """Box-prompted segmenter built from an image encoder, prompt encoder and mask decoder.

    The prompt encoder's parameters are frozen at construction time and it is
    always run without gradient tracking.
    """

    def __init__(
        self,
        image_encoder,
        mask_decoder,
        prompt_encoder,
    ):
        super().__init__()
        self.image_encoder = image_encoder
        self.mask_decoder = mask_decoder
        self.prompt_encoder = prompt_encoder
        # The prompt encoder is not trained: switch off gradients for all of it.
        self.prompt_encoder.requires_grad_(False)

    def forward(self, image, box):
        """Predict mask logits at the input image resolution.

        Args:
            image: 4-D image batch; dims 2 and 3 give the output mask size.
            box: box prompts, anything ``torch.as_tensor`` accepts; a 2-D
                input (B, 4) gets a singleton prompt axis inserted -> (B, 1, 4).

        Returns:
            The decoder's low-resolution masks bilinearly resized to the
            spatial size of ``image``.
        """
        # The encoder returns the final embedding plus intermediate embeddings.
        embedding, interm = self.image_encoder(image)
        # Only the first intermediate embedding is forwarded, with a new
        # leading axis.
        # NOTE(review): presumably the decoder indexes [0] again — confirm
        # against the mask decoder implementation.
        interm = interm[0].unsqueeze(0)

        # Prompt encoding never contributes gradients.
        with torch.no_grad():
            box_tensor = torch.as_tensor(box, dtype=torch.float32, device=image.device)
            if box_tensor.dim() == 2:
                box_tensor = box_tensor.unsqueeze(1)  # (B, 4) -> (B, 1, 4)

            sparse_prompt, dense_prompt = self.prompt_encoder(
                points=None,
                boxes=box_tensor,
                masks=None,
            )

        low_res_masks, _unused = self.mask_decoder(
            image_embeddings=embedding,
            image_pe=self.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_prompt,
            dense_prompt_embeddings=dense_prompt,
            multimask_output=False,
            hq_token_only=False,
            interm_embeddings=interm,
        )

        # Upsample the decoder output back to the input resolution.
        height, width = image.shape[2], image.shape[3]
        return F.interpolate(
            low_res_masks,
            size=(height, width),
            mode="bilinear",
            align_corners=False,
        )

# Build the SAM backbone and wrap its three components in MedSAM.
checkpoint_path = "/content/drive/MyDrive/med_hq-sam_model_best.pth"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sam_model = sam_model_registry['vit_l'](checkpoint=checkpoint_path)
medsam_model = MedSAM(
    sam_model.image_encoder,
    sam_model.mask_decoder,
    sam_model.prompt_encoder,
)
medsam_model.to(device)

# Restore the best weights from the checkpoint's "model" entry.
# NOTE(review): the same file was already handed to the registry builder
# above, so it is read twice — confirm both loads are needed.
checkpoint = torch.load(checkpoint_path, map_location=device)
medsam_model.load_state_dict(checkpoint["model"])

# Evaluation mode.
medsam_model.eval()

# Dice overlap between a predicted mask and its ground truth
def dice_coefficient(prediction, target):
    """Return the smoothed Dice coefficient of two masks as a Python float.

    Computes ``(2 * sum(prediction * target) + s) / (sum(prediction) +
    sum(target) + s)`` with ``s = 1e-5``, so two empty masks score 1.0.

    Args:
        prediction: tensor mask.
        target: tensor mask, element-wise multipliable with ``prediction``.
    """
    eps = 1e-5
    overlap = (prediction * target).sum()
    total = prediction.sum() + target.sum()
    score = (2. * overlap + eps) / (total + eps)
    return score.item()

# Initialize lists to store results
dice_scores = []

# Folder for the per-sample figures; created up front so saving cannot fail
# on a missing directory.
result_dir = "/content/drive/MyDrive/medhqsam_result"
os.makedirs(result_dir, exist_ok=True)

# The dataset picks a random label and randomly enlarges the box for every
# sample; seed the generator so the reported score is reproducible.
random.seed(0)

# Running index over all evaluated samples, used to keep file names unique
# across batches.
sample_idx = 0

# Evaluate the model
with torch.no_grad():
    for image, gt2D, boxes, img_names in tqdm(eval_dataloader):
        image, gt2D = image.to(device), gt2D.to(device)
        boxes_np = boxes.detach().cpu().numpy()

        # Forward pass
        medsam_pred = medsam_model(image, boxes_np)

        # Threshold the predicted masks
        pred_mask = (torch.sigmoid(medsam_pred) > 0.5).float()

        # Score and save every sample of the batch in a single pass.
        for i in range(len(pred_mask)):
            dice_scores.append(dice_coefficient(pred_mask[i, 0], gt2D[i, 0]))

            # One file per sample: the original within-batch index was
            # formatted as a list ("[0]") and overwritten by every batch.
            stem = os.path.splitext(img_names[i])[0]
            file_path = os.path.join(
                result_dir, f"segmentation_result_{sample_idx:04d}_{stem}.png"
            )
            visualize_segmentation(image[i], gt2D[i, 0], pred_mask[i, 0], 0, 0, file_path)
            sample_idx += 1

# Calculate the average Dice score (NaN when nothing was evaluated)
average_dice_score = np.mean(dice_scores) if dice_scores else float("nan")
print("Average Dice Score:", average_dice_score)