In [1]:
import torch
import torchvision
print("PyTorch version:", torch.__version__)
print("Torchvision version:", torchvision.__version__)
print("CUDA is available:", torch.cuda.is_available())
import sys
import os
import numpy as np
import matplotlib.pyplot as plt
import cv2
import json
from monai.metrics import DiceMetric, MeanIoU, SurfaceDiceMetric, SSIMMetric, GeneralizedDiceScore
from segment_anything.utils.transforms import ResizeLongestSide
from collections import defaultdict
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from sklearn.model_selection import train_test_split
from monai.losses import GeneralizedDiceLoss, DiceLoss, GeneralizedDiceFocalLoss
from monai.metrics import DiceMetric, GeneralizedDiceScore
#from LinearWarmupCosine import LinearWarmupCosineAnnealingLR

PyTorch version: 2.3.1+cu121
Torchvision version: 0.18.1+cu121
CUDA is available: True


In [2]:
def show_anns(anns):
    """Overlay annotation masks on the current matplotlib axes.

    Masks are drawn from the largest ``'area'`` to the smallest, so smaller
    masks end up on top. Each mask gets its own random RGB colour and is
    blended with an alpha of 0.35 wherever the mask is set.

    Args:
        anns: Sequence of dicts. Each dict must provide ``'area'`` (used for
            sorting) and ``'segmentation'`` (a mask array indexed as
            ``[row, column]``).

    Returns:
        None. When ``anns`` is empty the function returns immediately and
        does not touch the current axes.
    """
    if len(anns) == 0:
        return

    ax = plt.gca()
    ax.set_autoscale_on(False)

    for ann in sorted(anns, key=lambda a: a['area'], reverse=True):
        mask = ann['segmentation']
        # One random colour per mask, broadcast over every pixel.
        rgb = np.random.random((1, 3))[0]
        colour_layer = np.ones((mask.shape[0], mask.shape[1], 3)) * rgb
        # The mask itself becomes the alpha channel (0 outside, 0.35 inside).
        ax.imshow(np.dstack((colour_layer, mask * 0.35)))

In [3]:
# Locations of the CVC-300 images, their masks, and an output folder.
image_folder = "C:/Users/39327/Desktop/SAM/DataSet/CVC-300/images"
mask_folder =  "C:/Users/39327/Desktop/SAM/DataSet/CVC-300/masks"
save_folder = "C:/Users/39327/Desktop/SAM/DataSet/CVC-300/dataset"
os.makedirs(save_folder, exist_ok = True)


def _collect_png_paths(folder):
    """Return the sorted paths of every ``.png`` file under ``folder`` (sub-folders included)."""
    png_paths = [
        os.path.join(root, name)
        for root, _dirs, files in os.walk(folder, topdown=False)
        for name in files
        if name.endswith(".png")
    ]
    # os.walk yields directory entries in arbitrary order. Sorting keeps the
    # image/mask pairing and the seeded train/test split identical on every run.
    return sorted(png_paths)


image_path = _collect_png_paths(image_folder)
mask_path = _collect_png_paths(mask_folder)

# Images and masks are paired purely by position, so validate that pairing up front.
if not image_path:
    raise FileNotFoundError(f"No .png images found under {image_folder}")
if len(image_path) != len(mask_path):
    raise ValueError(
        f"Found {len(image_path)} images but {len(mask_path)} masks; they must pair one-to-one"
    )
for img_file, mask_file in zip(image_path, mask_path):
    if os.path.basename(img_file) != os.path.basename(mask_file):
        print(f"WARNING: image/mask file names differ: {img_file} vs {mask_file}")
        break

print(image_path[-1], mask_path[-1])

# with open('D:\Yuheng Li\Segment Anything\kvasir-seg\\kavsir_bboxes.json') as f:
#     labels = json.load(f)

# Fixed seed so the 80/20 split is reproducible.
X_train, X_test, y_train, y_test = train_test_split(image_path, mask_path, test_size=0.2, random_state=42)

C:/Users/39327/Desktop/SAM/DataSet/CVC-300/images\208.png C:/Users/39327/Desktop/SAM/DataSet/CVC-300/masks\208.png


In [4]:
import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

# Checkpoint file and matching architecture key. Exactly one pair must be
# active; the alternatives are kept for quick switching between SAM sizes.

# sam_checkpoint = "sam_vit_h_4b8939.pth"
# model_type = "vit_h"

sam_checkpoint = "sam_vit_b_01ec64.pth"
model_type = "vit_b"

# sam_checkpoint = "sam_vit_l_0b3195.pth"
# model_type = "vit_l"


# Prefer the GPU, but fall back to the CPU so the notebook still runs without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

# Resizes inputs so their longest side equals the encoder's expected image size.
transform = ResizeLongestSide(sam.image_encoder.img_size)

In [5]:
# def extract_bboxes(mask, num_instances):

#     """Compute bounding boxes from masks.

#     mask: [height, width, num_instances]. Mask pixels are either 1 or 0.

 

#     Returns: bbox array [num_instances, (y1, x1, y2, x2)].

#     """

#     boxes = np.zeros([num_instances, 4], dtype=np.int32)

#     for i in range(num_instances):

#         m = mask

#         # Bounding box.

#         horizontal_indicies = np.where(np.any(m, axis=0))[0]

# #         print("np.any(m, axis=0)",np.any(m, axis=0))

# #         print("p.where(np.any(m, axis=0))",np.where(np.any(m, axis=0)))

#         vertical_indicies = np.where(np.any(m, axis=1))[0]

#         if horizontal_indicies.shape[0]:

#             x1, x2 = horizontal_indicies[[0, -1]]

#             y1, y2 = vertical_indicies[[0, -1]]

#             # x2 and y2 should not be part of the box. Increment by 1.

#             x2 += 1

#             y2 += 1

#         else:

#             # No mask for this instance. Might happen due to

#             # resizing or cropping. Set bbox to zeros

#             x1, x2, y1, y2 = 0, 0, 0, 0

#         boxes[i] = np.array([y1, x1, y2, x2])
        

#     return boxes.astype(np.int32)

In [6]:
# Computing bounding boxes with YOLO
import numpy as np
import torch
from ultralytics import YOLO

# Define the function to convert YOLO bounding boxes to the required format
def extract_bboxes_from_yolo(yolo_data):
    """
    Reorder YOLO corner boxes into row-first integer boxes.

    yolo_data: Array with shape (num_instances, 4 or more). The first four
               columns of each row are read as [x_min, y_min, x_max, y_max].

    Returns: int32 array [num_instances, (y1, x1, y2, x2)]. Fractional
             coordinates are truncated toward zero by the int32 conversion.
    """
    if yolo_data.shape[0] == 0:
        # Nothing detected: hand back an empty (0, 4) box array.
        return np.zeros([0, 4], dtype=np.int32)

    corners = np.asarray(yolo_data)
    # Swap each (x, y) pair to (y, x) in one go: columns 1, 0, 3, 2.
    return corners[:, [1, 0, 3, 2]].astype(np.int32)

# Load the YOLO detector whose boxes are used as prompts.
model = YOLO("C:/Users/39327/runs/detect/train10/weights/best.pt")

# Detected boxes keyed by image content, so each distinct image is only run
# through the detector once per kernel session.
bbox_cache = {}


def _full_image_box(image):
    """Return one (y1, x1, y2, x2) box covering the whole image, as a (1, 4) int32 array."""
    height, width = image.shape[:2]
    return np.array([[0, 0, height, width]], dtype=np.int32)


def extract_bboxes(image, num_instances):
    """
    Perform object detection on the image and extract bounding boxes.

    image: The input image array on which to perform object detection.
    num_instances: Unused. Kept so existing callers of
                   ``extract_bboxes(x, n)`` continue to work.

    Returns: int32 bbox array [num_boxes, (y1, x1, y2, x2)].
             If no boxes are detected, returns a single bounding box that
             covers the entire image. The returned array is a copy, so callers
             may modify it without corrupting the cache.
    """
    # Key on shape and dtype as well as the pixel bytes, so two arrays that
    # share the same raw bytes but differ in layout do not collide.
    image_id = (image.shape, image.dtype.str, hash(image.tobytes()))

    cached = bbox_cache.get(image_id)
    if cached is not None:
        print("Using cached bounding boxes.")
        return cached.copy()

    # Run the detector on the GPU when one is present, otherwise on the CPU.
    detector_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    results = model(image, device=detector_device, conf=0.2)

    if not results:
        # No result object at all: fall back to prompting with the full frame.
        extracted_boxes = _full_image_box(image)
    else:
        # A single input image is expected to give a single result.
        xyxy = results[0].boxes.xyxy.cpu().numpy()
        if xyxy.size == 0:
            # Detector ran but found nothing above the confidence threshold.
            extracted_boxes = _full_image_box(image)
        else:
            extracted_boxes = extract_bboxes_from_yolo(xyxy)

    # NOTE: a commented-out experiment that shifted and enlarged the boxes used
    # to live here; it was never executed and has been removed.

    bbox_cache[image_id] = extracted_boxes
    return extracted_boxes.copy()


In [7]:
class ColonDataset(Dataset):
    """Image/mask pairs prepared for SAM, with box prompts from ``extract_bboxes``.

    Args:
        image_path: Sequence of image file paths.
        mask_path: Sequence of mask file paths, aligned index-by-index with
            ``image_path``.
        image_size: Target length of the longest side passed to
            ``ResizeLongestSide``.
    """

    def __init__(self, image_path, mask_path, image_size):
        self.image_path = image_path
        self.mask_path = mask_path
        self.image_size = image_size

        # TODO: use ResizeLongestSide and pad to square
        self.to_tensor = transforms.ToTensor()
        # Currently not applied in __getitem__; kept for experimentation.
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def __len__(self):
        """Number of image/mask pairs."""
        return len(self.image_path)

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            A tuple ``(input_image, bboxes, gt_binary_mask, gt_resized,
            original_image_size, input_size)`` where

            * ``input_image`` is the image stretched to 1024x1024 and converted
              with ``ToTensor``;
            * ``bboxes`` is a copy of the array returned by ``extract_bboxes``;
            * ``gt_binary_mask`` is a long tensor at the mask's original
              resolution, 1 wherever the mask is non-zero;
            * ``gt_resized`` is the same mask resized to 1024x1024;
            * ``original_image_size`` is ``image.shape[:2]``;
            * ``input_size`` is the last two dimensions of ``input_image``.

        Raises:
            FileNotFoundError: If the image or the mask cannot be read.
        """
        image = cv2.imread(self.image_path[index])
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {self.image_path[index]}")

        gt = cv2.imread(self.mask_path[index])
        if gt is None:
            raise FileNotFoundError(f"Could not read mask: {self.mask_path[index]}")
        gt = cv2.cvtColor(gt, cv2.COLOR_BGR2GRAY) / 255
        gt = gt.astype('float32')

        # Box prompts come from the detector, run on the untouched image.
        bbox_arr = extract_bboxes(image, 1)

        # The interpolation flag must be passed by keyword: the third positional
        # parameter of cv2.resize is `dst`, so a positional flag is not used
        # as the interpolation mode.
        gt_resized = cv2.resize(gt, (1024, 1024), interpolation=cv2.INTER_NEAREST)
        gt_resized = torch.as_tensor(gt_resized > 0).long()

        gt = torch.from_numpy(gt)
        gt_binary_mask = torch.as_tensor(gt > 0).long()

        # Resize the longest side first, then stretch to a square input.
        transform = ResizeLongestSide(self.image_size)
        input_image = transform.apply_image(image)
        # NOTE(review): before this fix the INTER_CUBIC flag was passed
        # positionally and therefore not applied; checkpoints trained with the
        # old loader saw the default interpolation, so confirm metrics.
        input_image = cv2.resize(input_image, (1024, 1024), interpolation=cv2.INTER_CUBIC)
        # NOTE(review): the image keeps cv2.imread's channel order and is not
        # normalized here - confirm this matches how the model was trained.
        input_image = self.to_tensor(input_image)

        original_image_size = image.shape[:2]
        input_size = tuple(input_image.shape[-2:])

        return input_image, np.array(bbox_arr), gt_binary_mask, gt_resized, original_image_size, input_size
    

def my_collate(batch):
    """Collate dataset samples whose fields cannot all be stacked.

    Each sample is a 6-tuple ``(image, bboxes, mask, gt_resized,
    original_image_size, input_size)``. Images and resized masks are stacked
    along a new leading batch dimension; the other four fields are returned
    as plain Python lists, one entry per sample.

    Returns:
        ``(images, bboxes, masks, gt_resized, original_image_size, input_size)``
    """
    (
        image_items,
        bbox_items,
        mask_items,
        resized_items,
        original_sizes,
        input_sizes,
    ) = zip(*batch)

    stacked_images = torch.stack(image_items, dim=0)
    stacked_resized = torch.stack(resized_items, dim=0)

    return (
        stacked_images,
        list(bbox_items),
        list(mask_items),
        stacked_resized,
        list(original_sizes),
        list(input_sizes),
    )


In [8]:


# Held-out split, prepared at the SAM encoder's native input size.
val_dataset = ColonDataset(X_test, y_test, sam.image_encoder.img_size)
# Evaluation does not need shuffling; a fixed order keeps runs and logs reproducible.
val_dataloader = DataLoader(val_dataset, batch_size=64, shuffle=False, collate_fn=my_collate)

In [9]:
model_path = "C:/Users/39327/CV/polyp/SAM-CVC/New folder/All datasets-20240524T154817Z-001/All datasets/SAM Finetune Enc Dec"

# Load the fine-tuned weights of each SAM component. map_location lets
# checkpoints saved on a GPU load onto whichever `device` is in use.
# NOTE: torch.load unpickles the file - only load checkpoints you trust.
sam.prompt_encoder.load_state_dict(
    torch.load(os.path.join(model_path, "prompt_enc_best_dice_model_DL.pth"), map_location=device)
)
sam.image_encoder.load_state_dict(
    torch.load(os.path.join(model_path, "img_enc_best_dice_model_DL.pth"), map_location=device)
)
sam.mask_decoder.load_state_dict(
    torch.load(os.path.join(model_path, "dec_best_dice_model_DL.pth"), map_location=device)
)
sam.eval()

# One metric object of each kind is enough; each is reset before every image
# so that a per-image score is recorded.
dice = DiceMetric()
gd = GeneralizedDiceScore()
iou = MeanIoU()

batch_dice = []
batch_gd = []
batch_iou = []

with torch.no_grad():
    for batch in val_dataloader:

        img, bbox, mask, gt_resized, original_image_size, input_size = batch

        for i in range(len(mask)):
            image_embedding = sam.image_encoder(img[i].unsqueeze(0).to(device))

            # Boxes are stored as (y1, x1, y2, x2) in original-image pixels.
            # Rescale them to the network input and reorder to (x1, y1, x2, y2).
            orig_h, orig_w = original_image_size[i][0], original_image_size[i][1]
            in_h, in_w = input_size[i]
            col_x1, col_x2 = bbox[i][:, 1] * in_w / orig_w, bbox[i][:, 3] * in_w / orig_w
            col_y1, col_y2 = bbox[i][:, 0] * in_h / orig_h, bbox[i][:, 2] * in_h / orig_h

            box = np.array([col_x1, col_y1, col_x2, col_y2]).transpose()

            box_torch = torch.as_tensor(box, dtype=torch.float, device=device)
            sparse_embeddings, dense_embeddings = sam.prompt_encoder(
                points=None,
                boxes=box_torch,
                masks=None
            )

            low_res_masks, iou_predictions = sam.mask_decoder(
                image_embeddings=image_embedding,
                image_pe=sam.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=False
            )

            upscaled_masks = sam.postprocess_masks(low_res_masks, input_size[i], original_image_size[i])

            binary_mask = torch.sigmoid(upscaled_masks.detach().cpu())
            binary_mask = (binary_mask > 0.5).float()

            # There is one predicted mask per box prompt. Merge them as a union
            # so the prediction stays strictly binary. Averaging them, as was
            # done before, produced fractional values. The result keeps a
            # leading dimension of size 1.
            binary_mask = (torch.sum(binary_mask, 0, keepdim=True) > 0).float()

            gt_binary_mask = mask[i].detach().cpu()
            if binary_mask.shape[-2:] != gt_binary_mask.shape[-2:]:
                raise ValueError(
                    f"Prediction size {tuple(binary_mask.shape[-2:])} does not match "
                    f"ground-truth size {tuple(gt_binary_mask.shape[-2:])}"
                )

            # MONAI metrics expect (batch, channel, spatial...). Without the
            # explicit channel dimension the image rows were treated as
            # channels, so the scores were averaged per row instead of being
            # computed over the whole mask.
            # Expected shape here: (1, 1, H, W) for both tensors.
            gt_for_metric = gt_binary_mask[None, None]

            dice.reset()
            gd.reset()
            iou.reset()

            dice(binary_mask, gt_for_metric)
            gd(binary_mask, gt_for_metric)
            iou(binary_mask, gt_for_metric)

            batch_dice.append(dice.aggregate().numpy()[0])
            batch_gd.append(gd.aggregate().numpy()[0])
            batch_iou.append(iou.aggregate().numpy()[0])

if not batch_dice:
    raise RuntimeError("No validation samples were evaluated; check val_dataloader")

print(f'Mean val dice: {sum(batch_dice) / len(batch_dice)}')
print(f'Mean val gd: {sum(batch_gd) / len(batch_gd)}')
print(f'Mean val iou: {sum(batch_iou) / len(batch_iou)}')


0: 576x640 1 polyp, 72.7ms
Speed: 3.0ms preprocess, 72.7ms inference, 69.3ms postprocess per image at shape (1, 3, 576, 640)

0: 576x640 1 polyp, 39.5ms
Speed: 4.0ms preprocess, 39.5ms inference, 2.0ms postprocess per image at shape (1, 3, 576, 640)

0: 576x640 1 polyp, 40.0ms
Speed: 3.0ms preprocess, 40.0ms inference, 2.0ms postprocess per image at shape (1, 3, 576, 640)

0: 576x640 1 polyp, 39.0ms
Speed: 3.5ms preprocess, 39.0ms inference, 2.0ms postprocess per image at shape (1, 3, 576, 640)

0: 576x640 1 polyp, 39.5ms
Speed: 4.4ms preprocess, 39.5ms inference, 2.0ms postprocess per image at shape (1, 3, 576, 640)

0: 576x640 1 polyp, 36.5ms
Speed: 4.0ms preprocess, 36.5ms inference, 3.0ms postprocess per image at shape (1, 3, 576, 640)

0: 576x640 1 polyp, 36.6ms
Speed: 4.0ms preprocess, 36.6ms inference, 2.0ms postprocess per image at shape (1, 3, 576, 640)

0: 576x640 1 polyp, 36.0ms
Speed: 3.0ms preprocess, 36.0ms inference, 3.0ms postprocess per image at shape (1, 3, 576, 640)