In [None]:
import os
import sys
from pathlib import Path

import numpy as np
import torch
from monai.utils import set_determinism
from torch.utils.data import DataLoader

# Put the parent of the notebook's working directory on the import path so the
# project-local `src.*` imports in the following cells resolve.
sys.path.insert(0, "..")

In [None]:
from src.utils import setup_dirs

# Project root = parent of the notebook's working directory.
root_dir = Path.cwd().parent
data_dir, log_dir, out_dir = setup_dirs(root_dir)
# Narrow the data directory down to the ACDC database.
data_dir = data_dir / "ACDC" / "database"

set_determinism(seed=42)

# Prefer the GPU when one is visible to torch.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
from src.transforms.transforms import get_transforms
from src.datasets.acdc_dataset import ACDCDataset

augment = True
batch_size = 4

train_transforms, val_transforms = get_transforms(spatial_dims=2, augment=augment)
train_data = ACDCDataset(data_dir=data_dir / "training", transform=train_transforms)
# TODO: do we need separate test transforms?
test_data = ACDCDataset(data_dir=data_dir / "testing", transform=val_transforms)

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=0)
# Evaluation gains nothing from shuffling; a fixed order keeps the "first test batch"
# and per-slice results comparable across runs and models.
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=0)

In [None]:
from segment_anything import SamPredictor, sam_model_registry

# Registry key and checkpoint must agree; the filename indicates ViT-H weights.
model_type = "vit_h"
checkpoint = root_dir / "models" / "sam_vit_h_4b8939.pth"
sam = sam_model_registry[model_type](checkpoint=checkpoint).to(device)
predictor = SamPredictor(sam)

In [None]:
def get_bounding_box(ground_truth_map, max_perturbation=20):
    """Return a randomly enlarged box ``[x_min, y_min, x_max, y_max]`` around the foreground.

    Foreground is every pixel ``> 0``. Each edge of the tight box is pushed outwards by an
    independent offset from ``np.random.randint(0, max_perturbation)`` (i.e. in
    ``[0, max_perturbation - 1]``). The offsets are drawn in the order x_min, x_max, y_min,
    y_max from NumPy's global RNG. The result is then clamped: minima to 0, ``x_max`` to the
    width and ``y_max`` to the height.

    Args:
        ground_truth_map: 2-D array (H, W).
        max_perturbation: exclusive upper bound of the per-edge jitter. Values ``<= 0``
            disable the jitter and return the tight box.

    Returns:
        list ``[x_min, y_min, x_max, y_max]``.

    Raises:
        ValueError: if the map has no foreground pixels.
    """
    y_indices, x_indices = np.nonzero(ground_truth_map > 0)
    if y_indices.size == 0:
        # np.min on an empty array would raise an opaque reduction error instead.
        raise ValueError("get_bounding_box: mask has no foreground pixels")
    x_min, x_max = x_indices.min(), x_indices.max()
    y_min, y_max = y_indices.min(), y_indices.max()

    height, width = ground_truth_map.shape

    def _jitter():
        # randint(0, 0) raises, so a non-positive bound means "no perturbation".
        return np.random.randint(0, max_perturbation) if max_perturbation > 0 else 0

    x_min = max(0, x_min - _jitter())
    x_max = min(width, x_max + _jitter())
    y_min = max(0, y_min - _jitter())
    y_max = min(height, y_max + _jitter())
    return [x_min, y_min, x_max, y_max]


def show_mask(mask, ax, random_color=False):
    """Overlay ``mask`` on ``ax`` as a translucent RGBA image.

    The last two dimensions of ``mask`` are treated as (h, w). Pixels are tinted with a fixed
    blue (RGB 30, 144, 255) or, if ``random_color`` is set, with a random RGB drawn from
    ``np.random.random``; alpha is 0.6 in both cases.
    """
    alpha = np.array([0.6])
    rgb = np.random.random(3) if random_color else np.array([30 / 255, 144 / 255, 255 / 255])
    rgba = np.concatenate([rgb, alpha], axis=0)

    height, width = mask.shape[-2:]
    overlay = mask.reshape(height, width, 1) * rgba[np.newaxis, np.newaxis, :]
    ax.imshow(overlay)


def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on ``ax``: label 1 as green stars, then label 0 as red stars.

    ``coords`` rows are (x, y); ``labels`` is a same-length array selecting each row.
    """
    for label_value, colour in ((1, "green"), (0, "red")):
        selected = coords[labels == label_value]
        ax.scatter(
            selected[:, 0],
            selected[:, 1],
            color=colour,
            marker="*",
            s=marker_size,
            edgecolor="white",
            linewidth=1.25,
        )


def show_box(box, ax):
    """Draw ``box = [x0, y0, x1, y1]`` on ``ax`` as an unfilled green rectangle (line width 2)."""
    # Local import: the notebook only imports `plt` in a later cell, so referencing the
    # global here made this helper depend on cell execution order.
    from matplotlib.patches import Rectangle

    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(Rectangle((x0, y0), w, h, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2))


In [None]:
# TODO: base model inference on ACDC
# TODO: base model inference on MNMs
# TODO: fine-tune on ACDC, inference on ACDC, MNMs
# TODO: fine-tune on MNMs, inference on ACDC, MNMs

In [None]:

import cv2
import matplotlib.pyplot as plt

batch = next(iter(test_loader))
# Keep the sample on the CPU: it is converted with `.numpy()`, which fails on CUDA tensors.
inputs = batch["image"][0].cpu()
labels = batch["label"][0].to(dtype=torch.uint8).cpu()

# The input to the SAM predictor needs to be HWC, where C = 3 in either RGB or BGR format
inputs = cv2.cvtColor(inputs.permute(2, 1, 0).numpy(), cv2.COLOR_GRAY2RGB)
# Scale to 0-255, convert to uint8 (a constant image would otherwise divide by zero)
value_range = inputs.max() - inputs.min()
scale = 255.0 / value_range if value_range > 0 else 0.0
inputs = ((inputs - inputs.min()) * scale).astype("uint8")
predictor.set_image(inputs)

# NumPy label map (not a tensor) so later cells can use NumPy ops such as `.astype`.
labels = labels[0].permute(1, 0).numpy()  # Swap W, H to match the image
num_classes = int(labels.max()) + 1

# One box prompt (and one predicted mask) per class of the label map.
bboxes = []
masks = []
for class_index in range(num_classes):
    onehot = (labels == class_index).astype(int)
    if not onehot.any():
        # A class below the maximum may be absent from this slice: no box can be derived,
        # so record no prompt and an empty prediction instead of crashing.
        bboxes.append(None)
        masks.append(np.zeros((1, *labels.shape), dtype=bool))
        continue
    bbox = np.array(get_bounding_box(onehot))
    bboxes.append(bbox)
    mask, _, _ = predictor.predict(box=bbox, multimask_output=False)
    masks.append(mask)


In [None]:
# One row per class: image with its box prompt, ground truth, SAM prediction.
# squeeze=False keeps `axes` 2-D even when there is a single class row.
fig, axes = plt.subplots(int(num_classes), 3, figsize=(10, 10), squeeze=False)
for class_index, (image_ax, label_ax, pred_ax) in enumerate(axes):
    image_ax.imshow(inputs)
    if bboxes[class_index] is not None:  # absent classes have no box prompt
        show_box(bboxes[class_index], image_ax)
    image_ax.set_title(f"Class {class_index}: image + box")

    label_ax.imshow(labels == class_index)
    label_ax.set_title("Ground truth")

    show_mask(masks[class_index], pred_ax)
    pred_ax.set_title("SAM prediction")

    for ax in (image_ax, label_ax, pred_ax):
        ax.axis("off")

fig.tight_layout()
plt.show()



In [None]:
eps = 1e-6
dice_scores = []
# `labels` may be a torch tensor, which has no `.astype`; convert once to NumPy.
labels_np = labels.detach().cpu().numpy() if isinstance(labels, torch.Tensor) else np.asarray(labels)

# Start at 1 to ignore the background class (0).
for class_index in range(1, int(num_classes)):
    ground_truth = (labels_np == class_index).astype(int)
    prediction = np.asarray(masks[class_index], dtype=float)

    tp = prediction * ground_truth
    fp = prediction * (1 - ground_truth)
    fn = (1 - prediction) * ground_truth

    dice = (2 * tp.sum() + eps) / (2 * tp.sum() + fp.sum() + fn.sum() + eps)
    print(f"Class {class_index} dice: {dice:.3f}")
    dice_scores.append(dice)

if dice_scores:
    print(f"Average dice: {np.mean(dice_scores):.3f}")
else:
    print("No foreground classes in this slice.")



In [None]:
from segment_anything.utils.transforms import ResizeLongestSide

resize_transform = ResizeLongestSide(sam.image_encoder.img_size)


def prepare_image(image, transform, device):
    """Turn a single-channel image tensor into a resized 3-channel uint8 CHW tensor on ``device``.

    Applies the same preprocessing as the inference cells: permute(2, 1, 0) (swapping the
    spatial axes), gray -> RGB, min-max rescale to uint8. It then applies
    ``transform.apply_image`` and returns a contiguous (3, H', W') tensor.
    Assumes ``image`` is 3-D with a single channel first, e.g. (1, H, W) — TODO confirm.
    """
    # `.numpy()` requires a CPU tensor without grad; callers may pass tensors already on `device`.
    image = image.detach().cpu().permute(2, 1, 0).numpy()
    image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    # Guard against a constant image, which would otherwise divide by zero.
    value_range = image.max() - image.min()
    scale = 255.0 / value_range if value_range > 0 else 0.0
    image = ((image - image.min()) * scale).astype("uint8")
    image = transform.apply_image(image)
    image = torch.as_tensor(image, device=device)
    return image.permute(2, 0, 1).contiguous()

In [None]:
eps = 1e-6


def _to_numpy(array_like):
    """Return ``array_like`` as a NumPy array, detaching and moving torch tensors to the CPU first."""
    if isinstance(array_like, torch.Tensor):
        return array_like.detach().cpu().numpy()
    return np.asarray(array_like)


def calculate_dice_per_class(masks, labels, ignore_background=True, eps: float = 1e-6):
    """Dice score of each predicted mask against its class in a label map.

    ``masks[i]`` is the prediction for class ``i``; its ground truth is ``labels == i``.
    Dice = (2*TP + eps) / (2*TP + FP + FN + eps), so an empty prediction of an absent class
    scores 1.0.

    Args:
        masks: sequence of binary (or soft) masks as NumPy arrays or torch tensors. Singleton
            dims are squeezed, e.g. the leading 1 of ``SamPredictor.predict(...,
            multimask_output=False)`` output.
        labels: integer label map of one image, NumPy array or torch tensor (any device);
            singleton dims are squeezed.
        ignore_background: skip class 0.
        eps: smoothing term added to numerator and denominator.

    Returns:
        list of floats, one per scored class, in class order.

    Raises:
        ValueError: if a squeezed mask's shape differs from the squeezed label map's
            (e.g. a whole batch of labels was passed instead of one slice).
    """
    label_map = np.squeeze(_to_numpy(labels))
    dice_scores = []
    for class_index, mask in enumerate(masks):
        if ignore_background and class_index == 0:
            continue

        prediction = np.squeeze(_to_numpy(mask)).astype(float)
        if prediction.shape != label_map.shape:
            # Without this check mismatched shapes could broadcast and give a silently wrong score.
            raise ValueError(
                f"mask for class {class_index} has shape {prediction.shape}, "
                f"label map has shape {label_map.shape}"
            )
        ground_truth = (label_map == class_index).astype(float)

        tp = (prediction * ground_truth).sum()
        fp = (prediction * (1 - ground_truth)).sum()
        fn = ((1 - prediction) * ground_truth).sum()

        dice_scores.append(float((2 * tp + eps) / (2 * tp + fp + fn + eps)))

    return dice_scores


In [None]:
dice_for_batch = []  # one list of per-class dice scores per evaluated slice

for batch in test_loader:
    # Stay on the CPU: the predictor is fed a NumPy HWC uint8 array, and `.numpy()`
    # would fail on CUDA tensors.
    images = batch["image"].cpu()
    label_maps = batch["label"].to(dtype=torch.uint8).cpu()

    for image, label_map in zip(images, label_maps):
        # Same preprocessing as the single-image cell: swap to WHC, gray -> RGB, rescale to uint8.
        image_rgb = cv2.cvtColor(image.permute(2, 1, 0).numpy(), cv2.COLOR_GRAY2RGB)
        value_range = image_rgb.max() - image_rgb.min()
        scale = 255.0 / value_range if value_range > 0 else 0.0
        image_uint8 = ((image_rgb - image_rgb.min()) * scale).astype("uint8")
        # The image embedding must be recomputed for every image; otherwise all predictions
        # reuse whichever image was last passed to `set_image`.
        predictor.set_image(image_uint8)

        ground_truth = label_map[0].permute(1, 0).numpy()  # Swap W, H to match the image
        num_classes = int(ground_truth.max()) + 1

        masks = []
        for class_index in range(num_classes):
            onehot = (ground_truth == class_index).astype(int)
            if not onehot.any():
                # No box can be derived for an absent class; use an empty prediction rather
                # than running SAM without a prompt.
                # NOTE(review): with eps smoothing this scores dice 1.0 — consider excluding
                # absent classes from averages.
                masks.append(np.zeros((1, *ground_truth.shape), dtype=bool))
                continue
            bbox = np.array(get_bounding_box(onehot))
            mask, _, _ = predictor.predict(box=bbox, multimask_output=False)
            masks.append(mask)

        # Score against this slice's own label map, not the whole batch of labels.
        dice_for_batch.append(calculate_dice_per_class(masks, ground_truth))



In [None]:
# Summarise instead of dumping one list per slice; raw scores remain in `dice_for_batch`.
flat_dice = [score for slice_scores in dice_for_batch for score in slice_scores]
mean_dice = float(np.mean(flat_dice)) if flat_dice else float("nan")
{"slices": len(dice_for_batch), "scores": len(flat_dice), "mean_dice": mean_dice}