In [1]:
import os
import torch
import numpy as np
import numpy as np
from segment_anything import sam_model_registry, SamPredictor
import matplotlib.pyplot as plt
import cv2
import torch
import xml.etree.ElementTree as ET


In [None]:
# !wget -q 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'

In [2]:
SAM_CHECKPOINT_PATH = os.path.join("weights", "sam_vit_h_4b8939.pth")
SAM_ENCODER_VERSION = "vit_h"
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(DEVICE)
XML_NAME = os.path.join("groundtruth", "tb_lungs.xml")


cuda


In [3]:


class SAM:
    def __init__(self):
        self.predictor = self._initialize_predictor()

    def _initialize_predictor(self):
        sam = sam_model_registry[SAM_ENCODER_VERSION](checkpoint=SAM_CHECKPOINT_PATH).to(device=DEVICE)
        return SamPredictor(sam)

    def set_image(self, image: np.ndarray):
        self.predictor.set_image(image)

    def segment_no_jitter(self, image: np.ndarray, xyxy: np.ndarray) -> np.ndarray:
        self.set_image(image)
        result_masks = []
        for box in xyxy:
            masks, scores, logits = self.predictor.predict(
                box=box,
                multimask_output=True
            )
            index = np.argmax(scores)
            result_masks.append(masks[index])
        result_array = np.array(result_masks)
        print('made it ')
        return {'left': np.transpose(np.nonzero(result_array[0])),
                'right': np.transpose(np.nonzero(result_array[1]))
                },  {'left': result_array[0],
                'right':  result_array[1]}



    # helper functions to show mask and box for SAM display
    @staticmethod
    def show_mask(mask, plt, color, random_color=False):
        if random_color:
            color = np.concatenate([np.random.random(3), np.array([0.4])], axis=0)
        h, w = mask.shape[-2:]
        mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
        plt.imshow(mask_image)

    @staticmethod
    #display the bounding box on the plot
    def show_box(box, ax):
        x0, y0 = box[0], box[1]
        w, h = box[2] - box[0], box[3] - box[1]
        ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))



In [None]:

def extract_rle_from_xml(xml_file_path, image_name, structure):
    tree = ET.parse(xml_file_path)
    root = tree.getroot()

    # Iterate through all 'image' elements in the XML
    for image in root.findall('image'):
        if image.get('name') == image_name:
            width = int(image.get('width'))
            height = int(image.get('height'))


            # If the image name matches, search for the structure
            for mask in image.findall('mask'):
                if mask.get('label') == structure:
                    mask_top = int(mask.get('top'))
                    mask_width = int(mask.get('width'))
                    mask_left = int(mask.get('left'))
                    return mask.get('rle'), width, height, mask_top, mask_width,mask_left
    return None


def cvat_rle_to_binary_image_mask(rle, left, top, width, img_h, img_w):
    # convert CVAT tight object RLE to COCO-style whole image mask
    rle = [int(x) for x in rle.split(',')]

    mask = np.zeros((img_h, img_w), dtype=np.uint8)
    value = 0
    offset = 0
    for rle_count in rle:
        while rle_count > 0:
            y, x = divmod(offset, width)
            mask[y + top][x + left] = value
            rle_count -= 1
            offset += 1
        value = 1 - value

    return mask



In [None]:
def dice_coefficient(mask1, mask2):
    """
    Calculate the Dice similarity coefficient between two masks.

    Parameters:
    - mask1: a binary 2D array (ground truth mask)
    - mask2: a binary 2D array (predicted mask)

    Returns:
    - dice: Dice coefficient as a float
    """
    # Flatten the masks
    mask1 = mask1.flatten()
    mask2 = mask2.flatten()

    # Calculate intersection and sum of pixels for each mask
    intersection = np.sum(mask1 * mask2)
    total_pixels_mask1 = np.sum(mask1)
    total_pixels_mask2 = np.sum(mask2)

    # Calculate the Dice coefficient
    dice = (2. * intersection) / (total_pixels_mask1 + total_pixels_mask2)

    return dice

In [4]:
def calculate_lung_bounding_boxes(image_shape, offset_factor=0.05, top_margin_factor=0.15):
    """
    Calculate bounding boxes for the left and right lung areas.

    Inputs:
    - image_shape: The shape of the image as (height, width).
    - offset_factor: How far from the midline the boxes start, as a fraction of the image width.
    - top_margin_factor: How far down from the top of the image the boxes start, as a fraction of the image height.

    Returns:
    - A tuple containing the bounding box for the left lung and the right lung.
      Each box is represented as [x_min, y_min, x_max, y_max].
    """
    height, width = image_shape
    box_width = width // 4
    box_height = height // 1.4

    midline = width // 2
    offset = int(width * offset_factor)
    top_margin = int(height * top_margin_factor)

    left_box = np.array([midline - box_width - offset, top_margin, midline - offset, top_margin + box_height])
    right_box = np.array([midline + offset, top_margin, midline + box_width + offset, top_margin + box_height])

    return  right_box, left_box


In [None]:
def evaluation_engine():
  sam = SAM()
  left_dice_score = []
  right_dice_score = []
  DATA = '/content/tb_annotated'
  for root, dirs,files in os.walk(DATA):
      for idx, file in enumerate(files):
          print(idx, file)
          if file.endswith('.png') or file.endswith('.JPG') or file.endswith('.jpeg') and 'checkpoint' not in file:
          # try:
              image_name = (os.path.join(root, file))
              image = cv2.imread(image_name)
              #file=file[:-4]

              left_box, right_box = calculate_lung_bounding_boxes(image.shape[:2])
              detections_masks_dict, detections_array_dict = sam.segment_no_jitter(image, [left_box, right_box])

              # Plotting
              fig, ax = plt.subplots(1)
              ax.imshow(image)  # Display the original image

              # Plot the masks and bounding boxes
              for box, mask in zip([left_box, right_box], ['left', 'right']):
                  mask_image = detections_array_dict[mask]
                  ax.imshow(mask_image, cmap='jet', alpha=0.3)  # Overlay the mask with some transparency
                  SAM.show_box(box, ax)  # Show the bounding box

              # Display the plot
              plt.title(f"Segmentation and Bounding Box for {file}")
              plt.show()

              gt_structures = ['left_lung', 'right_lung']
              ground_truth_name = file

              xml_file_path = XML_NAME

              # extract masks from ground truth rle file
              binary_masks = []
              for gt_struct in gt_structures:
                  rle, width, height, top, mask_width, mask_left = extract_rle_from_xml(xml_file_path, ground_truth_name, gt_struct)
                  print(rle)
                  # compute binary mask from rle
                  mask = cvat_rle_to_binary_image_mask(rle, mask_left, top, mask_width, height, width)
                  plt.imshow(mask)
                  plt.show()
                  binary_masks.append(mask)

              left_dice = dice_coefficient(binary_masks[0], detections_array_dict['left'])
              right_dice = dice_coefficient(binary_masks[1], detections_array_dict['right'])
              left_dice_score.append(left_dice)
              right_dice_score.append(right_dice)
  # Data preparation for plotting
  data = [left_dice_score, right_dice_score]
  labels = ['Left Lung', 'Right Lung']

  # Create box plots
  plt.figure(figsize=(10, 6))
  plt.boxplot(data, labels=labels)

  # Adding title and labels
  plt.title('Dice Scores for Left and Right Lungs')
  plt.ylabel('Dice Score')

  # Show the plot
  plt.show()

# evaluation_engine()

In [5]:
sam = SAM()


In [11]:
def resize_image_with_padding(img, desired_size=500):
    # Compute the new size, maintaining the aspect ratio
    ratio = float(desired_size) / max(img.shape)
    new_size = tuple([int(x * ratio) for x in img.shape[1::-1]])

    # Resize the image
    resized_img = cv2.resize(img, new_size, interpolation=cv2.INTER_AREA)

    # Create a new image and place the resized image into the center
    new_img = np.zeros((desired_size, desired_size, img.shape[2]), dtype=np.uint8)
    x_offset = (desired_size - new_size[0]) // 2
    y_offset = (desired_size - new_size[1]) // 2
    new_img[y_offset:y_offset+new_size[1], x_offset:x_offset+new_size[0]] = resized_img

    return new_img



In [32]:

def store_masked_images():
    DATA = 'all_data'
    output_dir = 'contrast_then_sam'
    # os.makedirs(output_dir, exist_ok=True)

    for root, dirs, files in os.walk(DATA):
        for idx, file in enumerate(files):
            if file.endswith(('.png', '.jpg', '.jpeg')) and 'checkpoint' not in file:
                print(idx, file)
                image_name = os.path.join(root, file)
                
                image = cv2.imread(image_name)


                image = cv2.cvtColor(image, cv2.COLOR_RGB2Lab)
            
                # Split the Lab image into its channels
                l_channel, a_channel, b_channel = cv2.split(image)
            
                # Create a CLAHE object
                clahe = cv2.createCLAHE(clipLimit=2, tileGridSize=(8,8))
            
                # Apply CLAHE to the L channel
                l_channel_clahe = clahe.apply(l_channel)
            
                # Merge the CLAHE enhanced L channel back with A and B channels
                lab_image_clahe = cv2.merge((l_channel_clahe, a_channel, b_channel))
            
                # Convert back to RGB color space
                image = cv2.cvtColor(lab_image_clahe, cv2.COLOR_Lab2RGB)

                
                left_box, right_box = calculate_lung_bounding_boxes(image.shape[:2])
                detections_masks_dict, detections_array_dict = sam.segment_no_jitter(image, [left_box, right_box])

                # Assuming masks are binary and of the same dimensions as the image
                left_mask = detections_array_dict['left']
                right_mask = detections_array_dict['right']

                # Find the most medial pixels for each lung
                left_medial_x = np.min(np.where(left_mask)[1])
                right_medial_x = np.max(np.where(right_mask)[1])

                # Calculate distance to move left lung mask
                distance_to_move = right_medial_x - left_medial_x - 1

                # Shift left lung mask to the right
                shifted_left_mask = np.roll(left_mask, shift=distance_to_move, axis=1)

                # Combine masks
                combined_mask = np.logical_or(shifted_left_mask, right_mask)
       
                # Find the bounding box of the combined mask
                rows = np.any(combined_mask, axis=1)
                cols = np.any(combined_mask, axis=0)
                ymin, ymax = np.where(rows)[0][[0, -1]]
                xmin, xmax = np.where(cols)[0][[0, -1]]


                # image = cv2.cvtColor(image, cv2.COLOR_RGB2Lab)
            
                # # Split the Lab image into its channels
                # l_channel, a_channel, b_channel = cv2.split(image)
            
                # # Create a CLAHE object
                # clahe = cv2.createCLAHE(clipLimit=2, tileGridSize=(8,8))
            
                # # Apply CLAHE to the L channel
                # l_channel_clahe = clahe.apply(l_channel)
            
                # # Merge the CLAHE enhanced L channel back with A and B channels
                # lab_image_clahe = cv2.merge((l_channel_clahe, a_channel, b_channel))
            
                # # Convert back to RGB color space
                # image = cv2.cvtColor(lab_image_clahe, cv2.COLOR_Lab2RGB)

                
                # image = cv2.GaussianBlur(image,(51,51),0)

                # Crop the image
                # Create a new blank image and place the lungs within it
                new_image = np.zeros_like(image)
                new_image[ymin:ymax+1, xmin:xmax+1] = image[ymin:ymax+1, xmin:xmax+1] * combined_mask[ymin:ymax+1, xmin:xmax+1, None]
                cropped_image = new_image[ymin:ymax+1, xmin:xmax+1]


                # Add this part to the end of your for loop after you obtain the cropped_image
                # resized_cropped_image = resize_image_with_padding(cropped_image)
                resized_cropped_image = cv2.resize(cropped_image, (500,500))
                # plt.imshow(resized_cropped_image)  # Convert to RGB for displaying
                # # plt.title("Resized Cropped Image")
                # plt.show()

                output_image_path = os.path.join(output_dir, f"lungs_{file}")
                cv2.imwrite(output_image_path, resized_cropped_image)

                # Save the cropped image
                # output_image_path = os.path.join(output_dir, f"cropped_{file}")
                # cv2.imwrite(output_image_path, cropped_image)

                # Optionally display the cropped image
                # fig, ax = plt.subplots(1)
                # ax.imshow(cv2.cvtColor(cropped_image, cv2.COLOR_BGR2RGB))  # Convert to RGB for displaying
                # plt.title(f"Cropped Segmentation for {file}")
                # plt.show()


# Call the function to process and store masked images
store_masked_images()

0 CHNCXR_0069_0.png
made it 
1 CHNCXR_0470_1.png
made it 
2 CHNCXR_0367_1.png
made it 
3 CHNCXR_0147_0.png
made it 
4 CHNCXR_0174_0.png
made it 
5 CHNCXR_0644_1.png
made it 
6 CHNCXR_0003_0.png
made it 
7 CHNCXR_0454_1.png
made it 
8 CHNCXR_0366_1.png
made it 
9 CHNCXR_0467_1.png
made it 
10 CHNCXR_0616_1.png
made it 
11 CHNCXR_0420_1.png
made it 
12 CHNCXR_0296_0.png
made it 
13 CHNCXR_0010_0.png
made it 
14 CHNCXR_0046_0.png
made it 
15 CHNCXR_0553_1.png
made it 
16 CHNCXR_0086_0.png
made it 
17 CHNCXR_0073_0.png
made it 
18 CHNCXR_0588_1.png
made it 
19 CHNCXR_0517_1.png
made it 
20 CHNCXR_0158_0.png
made it 
21 CHNCXR_0423_1.png
made it 
22 CHNCXR_0061_0.png
made it 
23 CHNCXR_0662_1.png
made it 
24 CHNCXR_0240_0.png
made it 
25 CHNCXR_0625_1.png
made it 
26 CHNCXR_0529_1.png
made it 
27 CHNCXR_0455_1.png
made it 
28 CHNCXR_0409_1.png
made it 
29 CHNCXR_0329_1.png
made it 
30 CHNCXR_0014_0.png
made it 
31 CHNCXR_0276_0.png
made it 
32 CHNCXR_0021_0.png
made it 
33 CHNCXR_0234_0.png