
**Introduction**

Segment Anything (SAM) is a promptable segmentation model developed by Meta AI Research (formerly Facebook AI Research). Given an image and a prompt such as a point or a bounding box, it generates segmentation masks for a wide range of objects or regions. It is designed to be flexible enough to segment anything from everyday objects to specific structures in medical images, which makes it a promising starting point for many medical imaging tasks.

Fine-tuning SAM for medical imaging tasks generally involves a few steps:

*   **Loading and Preprocessing the Data:** The first step is to load your medical imaging data, which often comes in specific formats like DICOM or NIfTI. Libraries such as pydicom or nibabel can be very useful for this. Preprocessing might include tasks such as reorienting the images, normalizing pixel intensities, and converting the images and masks into suitable formats for the model.

*   **Creating Bounding Box Prompts:** This notebook uses bounding box prompts to guide the segmentation (SAM also supports point prompts). These bounding boxes should roughly encapsulate the structure you want to segment. You can generate them from your segmentation masks. Note that SAM accepts multiple bounding boxes, allowing for multi-object segmentation in a single forward pass.

*   **Preparing the Model and Processor:** You'll need to load the pre-trained SAM model and its associated processor. The processor is used to prepare your inputs and prompts for the model.

*   **Fine-Tuning the Model:** With your data and model ready, you can now fine-tune SAM on your specific task. This often involves running a training loop, computing the loss function (comparing the model's output to the ground truth mask), backpropagating the gradients, and updating the model's weights. SAM is trained to generate segmentation masks that match the ground truth as closely as possible.

*   **Evaluating the Model:** After training, you'll want to evaluate your model's performance on a validation set. This will give you an idea of how well your model is likely to perform on unseen data. You could use metrics such as the Dice coefficient or Intersection over Union (IoU) for evaluation.

*   **Inference:** With a trained model, you can perform segmentation on new medical images. This involves preparing the image and bounding box prompt, passing them through the model, and post-processing the output to obtain your final segmentation mask.
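The two evaluation metrics named in the list above can be illustrated on a tiny pair of binary masks. This is a minimal sketch, not the evaluation code used later in this notebook; the helper name `dice_and_iou` and the convention that two empty masks score 1.0 are assumptions made for the example.

```python
import numpy as np

def dice_and_iou(pred, target):
    """Return (Dice, IoU) for two binary masks of equal shape."""
    pred = np.asarray(pred, dtype=bool)
    target = np.asarray(target, dtype=bool)
    intersection = np.logical_and(pred, target).sum()
    union = np.logical_or(pred, target).sum()
    total = pred.sum() + target.sum()
    # Assumed convention: two empty masks count as a perfect match.
    dice = 2.0 * intersection / total if total > 0 else 1.0
    iou = intersection / union if union > 0 else 1.0
    return float(dice), float(iou)

pred = np.array([[1, 1, 0], [0, 1, 0]])
target = np.array([[1, 0, 0], [0, 1, 1]])
dice, iou = dice_and_iou(pred, target)
print(f"Dice = {dice:.3f}, IoU = {iou:.3f}")  # Dice = 0.667, IoU = 0.500
```

Dice is never lower than IoU for the same pair of masks, so the two numbers should not be compared across papers without checking which one is reported.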

## Packages

In [None]:
!pip install -q monai
!pip install -q SimpleITK
!pip install -q git+https://github.com/huggingface/transformers.git
!pip install -q natsort

In [None]:
import os
import errno
import glob
import monai
import torch
import numpy as np
from PIL import Image
from tqdm import tqdm
import time
import SimpleITK as sitk
from statistics import mean
from torch.optim import Adam
from natsort import natsorted
import matplotlib.pyplot as plt
from transformers import SamModel,SamConfig
import matplotlib.patches as patches
from transformers import SamProcessor
from IPython.display import clear_output
from torch.utils.data import Dataset, DataLoader
from torch.nn.functional import threshold, normalize
from skimage.io import imread, imshow
import pandas as pd
import json
import math
import random
from skimage import color
import cv2
import csv

%matplotlib inline

from monai.transforms import (
    EnsureChannelFirstd,
    EnsureTyped,
    Compose,
    CropForegroundd,
    CopyItemsd,
    LoadImaged,
    CenterSpatialCropd,
    Invertd,
    OneOf,
    Orientationd,
    MapTransform,
    NormalizeIntensityd,
    RandSpatialCropSamplesd,
    RandSpatialCropd,
    SpatialPadd,
    ScaleIntensityRanged,
    Spacingd,
    RepeatChanneld,
    ToTensord,
    Resized,
)

In [None]:
TYPE='_Balanced'
#TYPE=""
base_dir = './Procesado'+TYPE
datasets = ['train', 'valid','test']

orig_types = ['images', 'masks']

METHOD='SAM' #Method to apply(original image 'SAM' or detection 'SAM+DETR')

In [None]:
# Initialize dictionary for storing image and label paths
data_paths = {}

# Collect the image and mask paths of each split and print how many files were found
for dataset in datasets:
    folder = 'images'
    for data_type in orig_types:
        # Construct the directory path
        dir_path = os.path.join(base_dir, f'{dataset}/{data_type}')
        print(dir_path)

        # Find images and labels in the directory
        files = sorted(glob.glob(os.path.join(dir_path, "*.jpg")))
        print(len(files))

        # Store the image and label paths in the dictionary
        data_paths[f'{dataset}/{data_type}'] = files

print('Number of training images', len(data_paths['train/'+folder]))
print('Number of validation images', len(data_paths['valid/'+folder]))
print('Number of test images', len(data_paths['test/'+folder]))

./Procesado_Balanced/train/images
1605
./Procesado_Balanced/train/masks
1605
./Procesado_Balanced/valid/images
810
./Procesado_Balanced/valid/masks
810
./Procesado_Balanced/test/images
107
./Procesado_Balanced/test/masks
107
Number of training images 1605
Number of validation images 810
Number of test images 107
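Images and masks are paired purely by their position in the two sorted lists, so equal counts alone do not guarantee correct pairing. A quick sanity check could look like the sketch below. It assumes that each mask shares the file name of its image; the helper `unpaired_files` and the example paths are hypothetical.

```python
import os

def unpaired_files(image_paths, mask_paths):
    """Return file names that appear in only one of the two lists."""
    image_names = {os.path.basename(p) for p in image_paths}
    mask_names = {os.path.basename(p) for p in mask_paths}
    return sorted(image_names ^ mask_names)

# Hypothetical paths, for illustration only.
images = ["./train/images/0001C.jpg", "./train/images/0002N.jpg"]
masks = ["./train/masks/0001C.jpg", "./train/masks/0003N.jpg"]
print(unpaired_files(images, masks))  # ['0002N.jpg', '0003N.jpg']
```

An empty result for every split, for example `unpaired_files(data_paths['train/images'], data_paths['train/masks'])`, would indicate that index-based pairing is safe.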


Now we can use a processor instance to prepare the images and prompts for training. SAM expects 3-channel inputs of size 1024x1024, and its mask decoder predicts low-resolution masks of size 256x256, so the target masks are resized to 256x256 as well.

In [None]:
# create an instance of the processor for image preprocessing
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
processor

SamProcessor:
- image_processor: SamImageProcessor {
  "do_convert_rgb": true,
  "do_normalize": true,
  "do_pad": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.485,
    0.456,
    0.406
  ],
  "image_processor_type": "SamImageProcessor",
  "image_std": [
    0.229,
    0.224,
    0.225
  ],
  "pad_size": {
    "height": 1024,
    "width": 1024
  },
  "processor_class": "SamProcessor",
  "resample": 2,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "longest_edge": 1024
  }
}
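The configuration printed above (`size.longest_edge = 1024`, `pad_size = 1024x1024`) implies a two-step geometry: resize so that the longest side becomes 1024, then pad the shorter side up to 1024. The sketch below works through that arithmetic. It assumes round-half-up on the scaled sides; the function names are hypothetical and are not part of the `transformers` API.

```python
def sam_resized_shape(height, width, longest_edge=1024):
    """Shape after scaling so that the longest side equals `longest_edge`."""
    scale = longest_edge / max(height, width)
    new_h = int(height * scale + 0.5)
    new_w = int(width * scale + 0.5)
    return new_h, new_w

def padding_needed(height, width, pad_size=1024):
    """Rows and columns of padding required to reach a square `pad_size` image."""
    new_h, new_w = sam_resized_shape(height, width, pad_size)
    return pad_size - new_h, pad_size - new_w

print(sam_resized_shape(480, 640))   # (768, 1024)
print(padding_needed(480, 640))      # (256, 0)
print(sam_resized_shape(250, 250))   # (1024, 1024)
```

In this notebook the images are already resized to 1024x1024 by the MONAI transforms before they reach the processor, so both steps are effectively no-ops here.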

In [None]:
def get_bounding_box(ground_truth_map):
    '''
    Create a bounding box prompt for the SAM model from a segmentation mask.
    The tight box around the foreground is enlarged on each side by a random
    padding of 5 to 19 pixels (np.random.randint excludes the upper bound).
    '''

    if len(np.unique(ground_truth_map)) > 1:

        # get bounding box from mask
        y_indices, x_indices = np.where(ground_truth_map > 0)
        x_min, x_max = np.min(x_indices), np.max(x_indices)
        y_min, y_max = np.min(y_indices), np.max(y_indices)

        # add perturbation to bounding box coordinates
        H, W = ground_truth_map.shape
        x_min = max(0, x_min - np.random.randint(5, 20))
        x_max = min(W, x_max + np.random.randint(5, 20))
        y_min = max(0, y_min - np.random.randint(5, 20))
        y_max = min(H, y_max + np.random.randint(5, 20))

        bbox = [x_min, y_min, x_max, y_max]

        return bbox
    else:
        return [0, 0, 256, 256] # if there is no mask in the array, set bbox to image size
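Because the padding is drawn from NumPy's global random state, the same mask yields a different box on every call. When reproducible prompts are wanted, for example for validation, one option is to pass an explicit generator. The variant below is a sketch under that assumption; `padded_box` is a hypothetical name and not a replacement used elsewhere in this notebook.

```python
import numpy as np

def padded_box(mask, rng, low=5, high=20):
    """Tight foreground box, enlarged per side by a random margin in [low, high)."""
    height, width = mask.shape
    ys, xs = np.where(mask > 0)
    if ys.size == 0:
        # No foreground: fall back to the whole image, as the notebook does.
        return [0, 0, width, height]
    x_min = max(0, int(xs.min()) - int(rng.integers(low, high)))
    x_max = min(width, int(xs.max()) + int(rng.integers(low, high)))
    y_min = max(0, int(ys.min()) - int(rng.integers(low, high)))
    y_max = min(height, int(ys.max()) + int(rng.integers(low, high)))
    return [x_min, y_min, x_max, y_max]

mask = np.zeros((64, 64), dtype=np.uint8)
mask[20:30, 30:40] = 1  # foreground rows 20-29, columns 30-39

# Two generators with the same seed produce the same box.
box_a = padded_box(mask, np.random.default_rng(0))
box_b = padded_box(mask, np.random.default_rng(0))
print(box_a, box_a == box_b)
```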

In [None]:
class SAMDataset(Dataset):
    def __init__(self, image_paths, mask_paths, processor):

        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.processor = processor
        self.transforms = Compose([

            # load .jpg files
            LoadImaged(keys=['img', 'label']),

            # add channel id to match PyTorch configurations
            EnsureChannelFirstd(keys=['img', 'label']),

            # rescale image and label

            Resized(keys=['img'], spatial_size=(1024, 1024)),

            Resized(keys=['label'], spatial_size=(256, 256)),

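            # scale intensities to 0 and 255 to match the expected input intensity range
            # NOTE: a_min/a_max look like a CT (Hounsfield) window; 8-bit .jpg values (0-255)
            # are mapped to roughly 85-107 with these settings, so adjust them for your data
            ScaleIntensityRanged(keys=['img'], a_min=-1000, a_max=2000,
                        b_min=0.0, b_max=255.0, clip=True),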

            ScaleIntensityRanged(keys=['label'], a_min=0, a_max=255,
                         b_min=0.0, b_max=1.0, clip=True),
        ])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        mask_path = self.mask_paths[idx]

        # create a dict of images and labels to apply Monai's dictionary transforms
        data_dict = self.transforms({'img': image_path, 'label': mask_path})

        # squeeze extra dimensions
        image = data_dict['img'].squeeze()
        ground_truth_mask = data_dict['label'].squeeze()
        weights = np.array([0.2989, 0.5870, 0.1140])
        ground_truth_mask = np.sum(ground_truth_mask * weights[:, np.newaxis, np.newaxis], axis=0)

        # convert to int type for huggingface's models expected inputs
        image_rgb = image.astype(np.uint8)

        mask = np.zeros([256, 256])
        ind=(ground_truth_mask>0.1)
        mask[ind]=1

        prompt = get_bounding_box(mask)

        # prepare image and prompt for the model
        inputs = self.processor(image_rgb, input_boxes=[[prompt]], return_tensors="pt")

        # remove batch dimension which the processor adds by default
        inputs = {k: v.squeeze(0) for k, v in inputs.items()}

        # add ground truth segmentation (ground truth image size is 256x256)
        inputs["ground_truth_mask"] = torch.from_numpy(mask.astype(np.int8))
        inputs["path_name"]=image_path

        return inputs
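In `__getitem__` the prompt is computed on the 256x256 mask, while the image handed to the processor is 1024x1024. Whether that mismatch matters depends on how the processor normalizes boxes, which is worth checking by printing `inputs["input_boxes"]`. Assuming the box should be expressed in pixel coordinates of the image passed to the processor, a rescaling helper could look like the sketch below; `scale_box` is a hypothetical function and is not used by the dataset class above.

```python
def scale_box(box, from_size, to_size):
    """Rescale [x_min, y_min, x_max, y_max] between image sizes given as (height, width)."""
    from_h, from_w = from_size
    to_h, to_w = to_size
    sx, sy = to_w / from_w, to_h / from_h
    x_min, y_min, x_max, y_max = box
    return [x_min * sx, y_min * sy, x_max * sx, y_max * sy]

# A box defined on the 256x256 mask, expressed on the 1024x1024 image.
print(scale_box([65, 54, 109, 165], (256, 256), (1024, 1024)))
# [260.0, 216.0, 436.0, 660.0]
```

Changing the prompt coordinates changes what the model is trained on, so a model fine-tuned with one convention should also be evaluated with it.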

In [None]:
# create train and validation dataloaders
train_dataset = SAMDataset(image_paths=data_paths['train/'+folder], mask_paths=data_paths['train/masks'], processor=processor)
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)

val_dataset = SAMDataset(image_paths=data_paths['valid/'+folder], mask_paths=data_paths['valid/masks'], processor=processor)
val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=True)


Finally, we can visualize our processed data along with the bounding boxes:

In [None]:
example = train_dataset[1]
for k,v in example.items():
    # 'path_name' is a plain string and has no shape
    if not isinstance(v, str):
        print(k,v.shape)

xmin, ymin, xmax, ymax = get_bounding_box(example['ground_truth_mask'])

fig, axs = plt.subplots(1, 2)

axs[0].imshow(example['pixel_values'][1], cmap='gray')
axs[0].axis('off')

axs[1].imshow(example['ground_truth_mask'], cmap='copper')

# create a Rectangle patch for the bounding box
rect = patches.Rectangle((xmin, ymin), xmax-xmin, ymax-ymin, linewidth=1, edgecolor='r', facecolor='none')

# add the patch to the second Axes
axs[1].add_patch(rect)

axs[1].axis('off')

plt.tight_layout()
plt.show()

To fine-tune the model, we freeze the vision encoder and prompt encoder weights of the pre-trained SAM model, so that only the mask decoder is trained:

In [None]:
# load the pretrained weights for finetuning
model = SamModel.from_pretrained("facebook/sam-vit-base")

# make sure we only compute gradients for mask decoder (encoder weights are frozen)
for name, param in model.named_parameters():
    if name.startswith("vision_encoder") or name.startswith("prompt_encoder"):
        print(name)
        param.requires_grad_(False)

vision_encoder.pos_embed
vision_encoder.patch_embed.projection.weight
vision_encoder.patch_embed.projection.bias
vision_encoder.layers.0.layer_norm1.weight
vision_encoder.layers.0.layer_norm1.bias
vision_encoder.layers.0.attn.rel_pos_h
vision_encoder.layers.0.attn.rel_pos_w
vision_encoder.layers.0.attn.qkv.weight
vision_encoder.layers.0.attn.qkv.bias
vision_encoder.layers.0.attn.proj.weight
vision_encoder.layers.0.attn.proj.bias
vision_encoder.layers.0.layer_norm2.weight
vision_encoder.layers.0.layer_norm2.bias
vision_encoder.layers.0.mlp.lin1.weight
vision_encoder.layers.0.mlp.lin1.bias
vision_encoder.layers.0.mlp.lin2.weight
vision_encoder.layers.0.mlp.lin2.bias
vision_encoder.layers.1.layer_norm1.weight
vision_encoder.layers.1.layer_norm1.bias
vision_encoder.layers.1.attn.rel_pos_h
vision_encoder.layers.1.attn.rel_pos_w
vision_encoder.layers.1.attn.qkv.weight
vision_encoder.layers.1.attn.qkv.bias
vision_encoder.layers.1.attn.proj.weight
vision_encoder.layers.1.attn.proj.bias
vision_
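The freezing rule above is a plain prefix test on parameter names. The sketch below isolates that rule so it can be checked without loading the model. The two `example` names are placeholders invented for illustration; only `vision_encoder.pos_embed` is taken from the output above.

```python
FROZEN_PREFIXES = ("vision_encoder", "prompt_encoder")

def is_frozen(name):
    """True when a parameter belongs to one of the frozen sub-modules."""
    # str.startswith accepts a tuple and matches any of its entries.
    return name.startswith(FROZEN_PREFIXES)

names = [
    "vision_encoder.pos_embed",       # taken from the output above
    "prompt_encoder.example.weight",  # placeholder name
    "mask_decoder.example.weight",    # placeholder name
]
trainable = [n for n in names if not is_frozen(n)]
print(trainable)  # ['mask_decoder.example.weight']
```

On the real model, `sum(p.numel() for p in model.parameters() if p.requires_grad)` is a quick way to confirm how many parameters remain trainable after freezing.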

In [None]:
torch.cuda.empty_cache()

**Train Model**

In [None]:
# create the output folders (intermediate directories are created as needed)
os.makedirs('./SAMResults/best_models', exist_ok=True)

In [None]:
inicio = time.time()
# define training loop
num_epochs = 100

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# define optimizer
optimizer = Adam(model.mask_decoder.parameters(), lr=1e-5, weight_decay=0)

# define segmentation loss with sigmoid activation applied to predictions from the model
seg_loss = monai.losses.DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')

# track mean train and validation losses
mean_train_losses, mean_val_losses = [], []

# create an arbitrarily large starting validation loss value
best_val_loss = 100.0
best_val_epoch = 0

# set model to train mode for gradient updating
model.train()
for epoch in tqdm(range(num_epochs),desc="Epochs"):

    # create temporary list to record training losses
    epoch_losses = []
    for i, batch in enumerate(train_dataloader):

        # forward pass
        outputs = model(pixel_values=batch["pixel_values"].to(device),
                      input_boxes=batch["input_boxes"].to(device),
                      multimask_output=False)

        # compute loss
        predicted_masks = outputs.pred_masks.squeeze(1)
        ground_truth_masks = batch["ground_truth_mask"].float().to(device)
        loss = seg_loss(predicted_masks, ground_truth_masks.unsqueeze(1))

        # backward pass (compute gradients of parameters w.r.t. loss)
        optimizer.zero_grad()
        loss.backward()

        # optimize
        optimizer.step()
        epoch_losses.append(loss.item())

        # visualize training predictions every 50 iterations
        if i % 50 == 0:

            # clear jupyter cell output
            clear_output(wait=True)

            fig, axs = plt.subplots(1, 3)
            xmin, ymin, xmax, ymax = get_bounding_box(batch['ground_truth_mask'][0])
            rect = patches.Rectangle((xmin, ymin), xmax-xmin, ymax-ymin, linewidth=1, edgecolor='r',
                                     facecolor='none')

            axs[0].set_title('input image')
            axs[0].imshow(batch["pixel_values"][0,1], cmap='gray')
            axs[0].axis('off')

            axs[1].set_title('ground truth mask')
            axs[1].imshow(batch['ground_truth_mask'][0], cmap='copper')
            axs[1].add_patch(rect)
            axs[1].axis('off')

            # apply sigmoid
            medsam_seg_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))

            # convert soft mask to hard mask
            medsam_seg_prob = medsam_seg_prob.detach().cpu().numpy().squeeze()
            medsam_seg = (medsam_seg_prob > 0.5).astype(np.uint8)

            axs[2].set_title('predicted mask')
            axs[2].imshow(medsam_seg, cmap='copper')
            axs[2].axis('off')

            plt.tight_layout()
            plt.show()

    # create temporary list to record validation losses
    val_losses = []

    # run validation without gradient tracking
    # (model.eval() is not called here, so the model stays in train mode)
    with torch.no_grad():
        for val_batch in val_dataloader:

            # forward pass
            outputs = model(pixel_values=val_batch["pixel_values"].to(device),
                      input_boxes=val_batch["input_boxes"].to(device),
                      multimask_output=False)

            # calculate val loss against the validation ground truth
            predicted_val_masks = outputs.pred_masks.squeeze(1)
            ground_truth_masks = val_batch["ground_truth_mask"].float().to(device)
            val_loss = seg_loss(predicted_val_masks, ground_truth_masks.unsqueeze(1))

            val_losses.append(val_loss.item())

        # visualize the last validation prediction
        fig, axs = plt.subplots(1, 3)
        xmin, ymin, xmax, ymax = get_bounding_box(val_batch['ground_truth_mask'][0])

        axs[0].set_title('input image')
        axs[0].imshow(val_batch["pixel_values"][0,1], cmap='gray')
        axs[0].axis('off')

        axs[1].set_title('ground truth mask')
        axs[1].imshow(val_batch['ground_truth_mask'][0], cmap='copper')
        #axs[1].add_patch(rect)
        axs[1].axis('off')

        # apply sigmoid
        medsam_seg_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))

        # convert soft mask to hard mask
        medsam_seg_prob = medsam_seg_prob.detach().cpu().numpy().squeeze()
        medsam_seg = (medsam_seg_prob > 0.5).astype(np.uint8)

        axs[2].set_title('predicted mask')
        axs[2].imshow(medsam_seg, cmap='copper')
        axs[2].axis('off')

        plt.tight_layout()
        plt.show()

        # save the best weights and record the best performing epoch
        if mean(val_losses) < best_val_loss:
            best_val_loss = mean(val_losses)
            best_val_epoch = epoch
            torch.save(model.state_dict(), "./SAMResults/best_models/best_weights2_balanced.pth")
            print(f"Model Was Saved! Current Best val loss {best_val_loss}")
        else:
            print("Model Was Not Saved!")

    print(f'EPOCH: {epoch}')
    print(f'Mean loss: {mean(epoch_losses)}')

    mean_train_losses.append(mean(epoch_losses))
    mean_val_losses.append(mean(val_losses))
print("Training time (min): ", round((time.time()-inicio)/60,0))

In [None]:
print(mean_train_losses)
print(mean_val_losses)

[0.1893078379541914, 0.16077119806473872, 0.155571824702147, 0.15071400118022693, 0.1454905291212682, 0.13960010919986857, 0.1346950808046763, 0.1299981718494142, 0.12471443125020677, 0.12143349937189406, 0.11559977757967893, 0.1122417347453465, 0.1089594795325092, 0.10460114798441854, 0.1024728653215545, 0.09937414762758391, 0.09680039522432464, 0.09484490148746336, 0.09201044512686328, 0.08987316869872382, 0.08719571334922054, 0.08442066755621604, 0.0824760456322881, 0.08159185543981297, 0.08002666517210155, 0.07749425934111218, 0.07543943819598617, 0.07389088865381163, 0.07268091715013499, 0.07171095917900774, 0.0694904035868303, 0.0685374611634703, 0.06715506861143022, 0.06623564824879727, 0.06463676549935267, 0.06356901966522788, 0.06382684807911097, 0.06070357074618711, 0.06067648297901094, 0.05877307835397691, 0.05883645651125091, 0.05794646936785024, 0.0565747282958105, 0.05674248101926667, 0.054956769126226594, 0.05385955601837776, 0.05330570319731288, 0.054379924509755546, 0.
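To give some intuition for the loss values printed above, the sketch below combines a Dice term with squared denominators and a binary cross-entropy term on logits. It is an approximation written for illustration and is not MONAI's implementation: the smoothing constant, the equal weighting of the two terms, and the function name `dice_bce_loss` are all assumptions.

```python
import numpy as np

def dice_bce_loss(logits, target, eps=1e-5):
    """Dice loss (squared denominator) plus binary cross-entropy, both on logits."""
    logits = np.asarray(logits, dtype=np.float64)
    target = np.asarray(target, dtype=np.float64)
    prob = 1.0 / (1.0 + np.exp(-logits))  # sigmoid
    intersection = (prob * target).sum()
    denominator = (prob ** 2).sum() + (target ** 2).sum()
    dice = 1.0 - (2.0 * intersection + eps) / (denominator + eps)
    # Numerically stable form of binary cross-entropy with logits.
    bce = np.mean(np.maximum(logits, 0) - logits * target
                  + np.log1p(np.exp(-np.abs(logits))))
    return float(dice + bce)

target = np.array([[1.0, 0.0], [0.0, 1.0]])
good_logits = 10.0 * (2.0 * target - 1.0)  # confident and correct
bad_logits = -good_logits                  # confident and wrong
print(dice_bce_loss(good_logits, target))
print(dice_bce_loss(bad_logits, target))
```

A confident, correct prediction drives both terms toward zero, while a confident, wrong one is penalized mainly through the cross-entropy term.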

In [None]:
MODEL_PATH = './SAMResults/'+METHOD+TYPE
model.save_pretrained(MODEL_PATH)

In [None]:
# Define the model
modelo = SamModel(SamConfig(num_labels=2))
device = "cuda" if torch.cuda.is_available() else "cpu"
modelo.to(device)
# Path to the .pth file with the trained weights
ruta_del_archivo = "./SAMResults/best_models/best_weights1_balanced.pth"

# Load the weights trained on the model
modelo.load_state_dict(torch.load(ruta_del_archivo,map_location=torch.device('cpu')))

# Set the model in evaluation mode (no training)
modelo.eval()

SamModel(
  (shared_image_embedding): SamPositionalEmbedding()
  (vision_encoder): SamVisionEncoder(
    (patch_embed): SamPatchEmbeddings(
      (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (layers): ModuleList(
      (0-11): 12 x SamVisionLayer(
        (layer_norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): SamVisionAttention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (proj): Linear(in_features=768, out_features=768, bias=True)
        )
        (layer_norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): SamMLPBlock(
          (lin1): Linear(in_features=768, out_features=3072, bias=True)
          (lin2): Linear(in_features=3072, out_features=768, bias=True)
          (act): GELUActivation()
        )
      )
    )
    (neck): SamVisionNeck(
      (conv1): Conv2d(768, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (layer_norm1): SamLayerNorm()
     

In [None]:
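def dice_coefficient(mask1, mask2):
    # Dice = 2*|A ∩ B| / (|A| + |B|) for binary masks
    intersection = np.logical_and(mask1, mask2)
    denominator = mask1.sum() + mask2.sum()
    if denominator == 0:
        return 1.0  # both masks empty: treated as a perfect match (convention)
    dice_val = (2.0 * intersection.sum()) / denominator
    return round(dice_val,3)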

In [None]:
df = pd.read_csv('./data_Balanced/DETR/resultados.csv')
df = df.fillna('NULL')
bboxes = [] #DETR
bboxes_gold = [] #GOLD
img_nombres = []

for index, row in df.iterrows():
    nombre_imagen = row['imagen']

    # DETR prediction, used as-is as [xmin, ymin, xmax, ymax];
    # fall back to the whole 256x256 image when DETR found nothing
    if row['predicion'] != 'NULL':
        prediccion = [round(v, 0) for v in json.loads(row['predicion'])[:4]]
    else:
        prediccion = [0, 0, 256, 256]

    # Gold box: the third and fourth entries are added to the origin,
    # i.e. they are treated as width and height
    g_xmin, g_ymin, g_w, g_h = json.loads(row['true_box'])[:4]
    g_xmax = g_xmin + g_w
    g_ymax = g_ymin + g_h

    img_nombres.append(nombre_imagen)
    bboxes_gold.append([g_xmin, g_ymin, g_xmax, g_ymax])
    bboxes.append(prediccion)

In [None]:
bbox_dict=dict(zip(img_nombres,bboxes))
bbox_gold_dict=dict(zip(img_nombres,bboxes_gold))

In [None]:
print(bbox_dict.get('00015C.jpg'),'\n',bbox_gold_dict.get('00015C.jpg'))

[137.0, 62.0, 189.0, 115.0] 
 [65, 54, 109, 165]
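The gold boxes are converted from an origin-plus-size layout to the corner layout that SAM's `input_boxes` uses in this notebook. The conversion can be sketched in isolation as follows. The helper name `xywh_to_xyxy` is hypothetical, and the input string is an illustrative value chosen so that the result matches the gold box printed above; it is not read from the CSV file.

```python
import json

def xywh_to_xyxy(box_json):
    """Convert a JSON string "[x, y, width, height]" to [x_min, y_min, x_max, y_max]."""
    x, y, w, h = json.loads(box_json)
    return [x, y, x + w, y + h]

print(xywh_to_xyxy("[65, 54, 44, 111]"))  # [65, 54, 109, 165]
```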


In [None]:
#Function to obtain the bbox obtained by DETR for the SAM+DETR set and pass it as input_boxes
def get_boxfromDETR(name):
    return bbox_dict.get(name)

#Function to obtain the bbox for the SAM set and pass it as input_boxes
def get_goldbox(name):
    return bbox_gold_dict.get(name)

In [None]:
class SAM_Pred_Dataset(Dataset):
    def __init__(self, image_paths, mask_paths, processor):

        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.processor = processor
        self.transforms = Compose([

            # load .jpg files
            LoadImaged(keys=['img', 'label']),
            # add channel id to match PyTorch configurations
            EnsureChannelFirstd(keys=['img', 'label']),
            Resized(keys=['img'], spatial_size=(1024, 1024)),
            Resized(keys=['label'], spatial_size=(256, 256)),
            # scale intensities with the same settings as in SAMDataset,
            # so that inference inputs match what the model saw during training
            ScaleIntensityRanged(keys=['img'], a_min=-1000, a_max=2000,
                        b_min=0.0, b_max=255.0, clip=True),
            ScaleIntensityRanged(keys=['label'], a_min=0, a_max=255,
                         b_min=0.0, b_max=1.0, clip=True)
            ])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        mask_path = self.mask_paths[idx]

        data_dict = self.transforms({'img': image_path, 'label': mask_path})

        image = data_dict['img'].squeeze()
        ground_truth_mask = data_dict['label'].squeeze()
        weights = np.array([0.2989, 0.5870, 0.1140])
        ground_truth_mask = np.sum(ground_truth_mask * weights[:, np.newaxis, np.newaxis], axis=0)
        image_rgb = image.astype(np.uint8)

        mask = np.zeros([256, 256])
        ind=(ground_truth_mask>0.1)
        mask[ind]=1

        name = image_path.split('/')[-1]
        if METHOD == 'SAM':
            prompt = get_goldbox(name)
        else:
            prompt = get_boxfromDETR(name)

        # prepare image and prompt for the model
        inputs = self.processor(image_rgb, input_boxes=[[prompt]], return_tensors="pt")

        # remove batch dimension which the processor adds by default
        inputs = {k: v.squeeze(0) for k, v in inputs.items()}

        # add ground truth segmentation (ground truth image size is 256x256)
        inputs["ground_truth_mask"] = torch.from_numpy(mask.astype(np.int8))

        inputs["path_name"]=image_path

        return inputs

In [None]:
#GOLD bbox
METHOD ='SAM'
test_dataset = SAM_Pred_Dataset(image_paths=data_paths['test/'+folder],
                                mask_paths=data_paths['test/masks'],
                                processor=processor)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)

In [None]:
for n,batch in enumerate(test_dataloader):
    if n == 14:
        print(batch['path_name'])
        print(batch['input_boxes'][0][0].tolist())
        break

['./Procesado_Balanced/test/images/00015C.jpg']
[65.0, 54.0, 109.0, 165.0]


In [None]:
#DETR bbox
METHOD ='SAM+DETR'
test_dataset2 = SAM_Pred_Dataset(image_paths=data_paths['test/'+folder],
                                 mask_paths=data_paths['test/masks'],
                                 processor=processor)
test_dataloader2 = DataLoader(test_dataset2, batch_size=1, shuffle=False)

In [None]:
for n,batch in enumerate(test_dataloader2):
    if n == 14:
        print(batch['path_name'])
        print(batch['input_boxes'][0][0].tolist())
        break

['./Procesado_Balanced/test/images/00015C.jpg']
[137.0, 62.0, 189.0, 115.0]
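Note that `SAM_Pred_Dataset` reads the global `METHOD` inside `__getitem__`, so the box source is decided when an item is fetched, not when the dataset is built. Reassigning `METHOD` later therefore also changes what an existing dataset returns. One way to avoid this is to fix the choice at construction time, as in the sketch below. `BoxLookup` is a hypothetical helper; the two boxes are the ones printed above for `00015C.jpg`.

```python
class BoxLookup:
    """Hypothetical helper: choose the box source once, at construction time."""

    def __init__(self, gold_boxes, detr_boxes, method):
        if method not in ("SAM", "SAM+DETR"):
            raise ValueError(f"unknown method: {method!r}")
        self.method = method
        self.boxes = gold_boxes if method == "SAM" else detr_boxes

    def __call__(self, name):
        # Returns None when the image has no entry, like dict.get.
        return self.boxes.get(name)

gold = {"00015C.jpg": [65, 54, 109, 165]}
detr = {"00015C.jpg": [137.0, 62.0, 189.0, 115.0]}

gold_lookup = BoxLookup(gold, detr, method="SAM")
detr_lookup = BoxLookup(gold, detr, method="SAM+DETR")
print(gold_lookup("00015C.jpg"), detr_lookup("00015C.jpg"))
```

A dataset could accept such an object as a constructor argument instead of consulting a global, which would keep the SAM and SAM+DETR loaders independent of cell execution order.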


**Inference**

In [None]:
METHOD = 'SAM+DETR'

In [None]:
base_dir = './Procesado'+TYPE
datasets = ['train', 'valid','test']

# NOTE: METHOD is set to 'SAM' or 'SAM+DETR' in this notebook, so the 'detect' branch is not taken
if METHOD == 'detect':
    types = ['DETR', 'masks']
else:
    types = orig_types

# Initialize dictionary for storing image and label paths
data_paths = {}

# Collect the image and mask paths of each split and print how many files were found
for dataset in datasets:
    for data_type in types:
        # Construct the directory path
        dir_path = os.path.join(base_dir, f'{dataset}/{data_type}')
        print(dir_path)

        # Find images and labels in the directory
        files = sorted(glob.glob(os.path.join(dir_path, "*.jpg")))
        print(len(files))

        # Store the image and label paths in the dictionary
        data_paths[f'{dataset}/{data_type}'] = files

print('Number of training images', len(data_paths['train/'+types[0]]))
print('Number of validation images', len(data_paths['valid/'+types[0]]))
print('Number of test images', len(data_paths['test/'+types[0]]))

./Procesado_Balanced/train/images
1605
./Procesado_Balanced/train/masks
1605
./Procesado_Balanced/valid/images
810
./Procesado_Balanced/valid/masks
810
./Procesado_Balanced/test/images
107
./Procesado_Balanced/test/masks
107
Number of training images 1605
Number of validation images 810
Number of test images 107


In [None]:
# Predictions
# create train and validation dataloaders
train_dataset = SAM_Pred_Dataset(image_paths=data_paths['train/'+types[0]], mask_paths=data_paths['train/masks'], processor=processor)
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=False)

val_dataset = SAM_Pred_Dataset(image_paths=data_paths['valid/'+types[0]], mask_paths=data_paths['valid/masks'], processor=processor)
val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False)

# create test dataloader
test_dataset = SAM_Pred_Dataset(image_paths=data_paths['test/'+types[0]], mask_paths=data_paths['test/masks'], processor=processor)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)

In [None]:
#To save the predictions of test
for clase in ('Cáncer', 'Control'):
    os.makedirs('./data'+TYPE+'/'+METHOD+'/test/'+clase, exist_ok=True)

In [None]:
orig_masks=[]
for r in data_paths['test/masks']:
    Yimg = imread(r,as_gray=True)
    mask=np.zeros([250, 250])
    ind =(Yimg>0.1)
    mask[ind]=1

    orig_masks.append(mask)

In [None]:
for batch in test_dataloader:
    print(batch['pixel_values'][0].shape)
    break

torch.Size([3, 1024, 1024])


In [None]:
inicio=time.time()
random.seed(123)

if METHOD=='detect':
    if TYPE != "":
        used_model= 'SAM_best_weights2_balanced'
    else:
        used_model= 'SAM_best_weights2'
else:
    if TYPE != "":
        used_model= 'SAM_best_weights1_balanced'
    else:
        used_model= 'SAM_best_weights1'

writepredsdetectDict = [] # Dict Results

# Iterate through test images
with torch.no_grad():
    for indice,batch in tqdm(enumerate(test_dataloader)):
        image_name = os.path.basename(str(batch['path_name'][0]))
        # forward pass
        outputs = modelo(pixel_values=batch["pixel_values"].to(device),
                      input_boxes=batch["input_boxes"].to(device),
                      multimask_output=False)

        # predicted mask logits; no loss is computed at inference time
        predicted_masks = outputs.pred_masks.squeeze(1)

        # apply sigmoid
        medsam_seg_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
        # convert soft mask to hard mask
        medsam_seg_prob = medsam_seg_prob.cpu().numpy().squeeze()
        medsam_seg = (medsam_seg_prob > 0.5).astype(np.uint8)

        medsam_seg2 = cv2.resize(medsam_seg,(250,250))
        m_dice = dice_coefficient(medsam_seg2,orig_masks[indice])

        if 'C' in image_name:
            new_dir='./data'+TYPE+'/'+METHOD+'/test/Cáncer/'
            label='Cáncer'
        else:
            new_dir='./data'+TYPE+'/'+METHOD+'/test/Control/'
            label='Control'

        Ximg = imread(str(batch['path_name'][0]))
        Ximg = cv2.resize(Ximg, (256, 256))

        result=cv2.bitwise_and(Ximg,Ximg,mask=medsam_seg)
        result=cv2.resize(result,(250,250))
        # skimage's imread returns RGB, while cv2.imwrite expects BGR
        result= cv2.cvtColor(result, cv2.COLOR_RGB2BGR)

        cv2.imwrite(new_dir+str(image_name), result)
        time.sleep(2)

        writepredsdetectDict.append({'modelo':used_model,
                                     'imagen':str(image_name),
                                     'set':'Test',
                                     'clase':label,
                                     'true_mask':batch["ground_truth_mask"][0],
                                     'predicion':medsam_seg,
                                     'DICE': m_dice
                                    })
print('TEST PREDICTIONS COMPLETE')

# Store the results in a CSV file (append if the file already exists)
file_name='resultados.csv'
archivo='./data'+TYPE+'/'+METHOD+'/'+file_name
if os.path.isfile(archivo):
    modo = 'a+'
else:
    modo = 'w'
with open(archivo, modo, newline='') as csvfile:
    fieldnames = ['modelo', 'imagen','set', 'clase','true_mask','predicion','DICE']
    writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
    if modo=='w':
        writer.writeheader()
    for d in writepredsdetectDict:
        writer.writerow(d)
print('*****END*****')
print('Execution time (min): ',round((time.time()-inicio)/60,0))

107it [04:14,  2.38s/it]

TEST PREDICTIONS COMPLETE
*****END*****
Execution time (min):  4.0
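The write-or-append logic above can be factored into a small helper that emits the header only when the file is new or empty. This is a minimal sketch using only the standard library; `append_rows` is a hypothetical name, and the rows and the reduced set of columns are made-up values for illustration.

```python
import csv
import os
import tempfile

def append_rows(path, fieldnames, rows):
    """Append dict rows to a CSV file, writing the header only for a new or empty file."""
    write_header = not os.path.isfile(path) or os.path.getsize(path) == 0
    with open(path, "a", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        if write_header:
            writer.writeheader()
        writer.writerows(rows)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "resultados.csv")
    fields = ["imagen", "set", "DICE"]
    # Made-up rows, for illustration only.
    append_rows(path, fields, [{"imagen": "a.jpg", "set": "Test", "DICE": 0.9}])
    append_rows(path, fields, [{"imagen": "b.jpg", "set": "Valid", "DICE": 0.8}])
    with open(path, newline="") as f:
        lines = f.read().splitlines()

print(lines)  # ['imagen,set,DICE', 'a.jpg,Test,0.9', 'b.jpg,Valid,0.8']
```

Writing whole mask arrays or tensors into CSV cells stores only their string representation, so saving masks as image files and keeping just the file name and the DICE value in the CSV may be easier to read back.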





In [None]:
#To save the predictions of valid
for clase in ('Cáncer', 'Control'):
    os.makedirs('./data'+TYPE+'/'+METHOD+'/valid/'+clase, exist_ok=True)

In [None]:
orig_masks_val=[]
for r in data_paths['valid/masks']:
    Yimg = imread(r,as_gray=True)
    mask=np.zeros([250, 250])
    ind =(Yimg>0.1)
    mask[ind]=1
    orig_masks_val.append(mask)

In [None]:
inicio=time.time()

random.seed(123)
writepredsdetectDict = [] # Dict Results

# Iterate through validation images
with torch.no_grad():
    for indice,batch in tqdm(enumerate(val_dataloader)):
        image_name = os.path.basename(str(batch['path_name'][0]))
        # forward pass
        outputs = modelo(pixel_values=batch["pixel_values"].to(device),
                      input_boxes=batch["input_boxes"].to(device),
                      multimask_output=False)

        # predicted mask logits; no loss is computed at inference time
        predicted_masks = outputs.pred_masks.squeeze(1)

        # apply sigmoid
        medsam_seg_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
        # convert soft mask to hard mask
        medsam_seg_prob = medsam_seg_prob.cpu().numpy().squeeze()
        medsam_seg = (medsam_seg_prob > 0.5).astype(np.uint8)

        medsam_seg2 = cv2.resize(medsam_seg,(250,250))
        m_dice = dice_coefficient(medsam_seg2,orig_masks_val[indice])

        if 'C' in image_name:
            new_dir='./data'+TYPE+'/'+METHOD+'/valid/Cáncer/'
            label='Cáncer'
        else:
            new_dir='./data'+TYPE+'/'+METHOD+'/valid/Control/'
            label='Control'

        Ximg = imread(str(batch['path_name'][0]))
        Ximg = cv2.resize(Ximg, (256, 256))

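        # keep only the image pixels inside the predicted mask
        result=cv2.bitwise_and(Ximg,Ximg,mask=medsam_seg)
        result=cv2.resize(result,(250,250))
        # swap the R and B channels before saving (cv2.imwrite expects BGR order)
        result= cv2.cvtColor(result, cv2.COLOR_BGR2RGB)

        cv2.imwrite(new_dir+str(image_name), result)
        # 2-second pause after each write; it accounts for most of the per-iteration time
        time.sleep(2)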

        writepredsdetectDict.append({'modelo':used_model,
                                     'imagen':str(image_name),
                                     'set':'Valid',
                                     'clase':label,
                                     'true_mask':batch["ground_truth_mask"][0],
                                     'predicion':medsam_seg,
                                     'DICE': m_dice
                                    })
print('VALID PREDICTIONS COMPLETE')

# ------ Append the results to the CSV file (created with a header if it does not exist yet)
file_name='resultados.csv'
archivo='./data'+TYPE+'/'+METHOD+'/'+file_name
if os.path.isfile(archivo):
    modo = 'a+'
else:
    modo = 'w'
with open(archivo, modo, newline='') as csvfile:
    fieldnames = ['modelo', 'imagen','set', 'clase','true_mask','predicion','DICE']
    writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
    if modo=='w':
        writer.writeheader()
    for d in writepredsdetectDict:
        writer.writerow(d)
print('*****END*****')
print('Execution time (min): ',round((time.time()-inicio)/60,0))

810it [32:08,  2.38s/it]


VALID PREDICTIONS COMPLETE
*****END*****
Execution time (min):  32.0
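
The post-processing in the loop above turns the model's logits into a binary mask by applying a sigmoid and thresholding at 0.5. A small NumPy sketch of that step, using made-up logits and hypothetical helper names; because sigmoid(x) > 0.5 exactly when x > 0, thresholding the raw logits at zero yields the same mask:

```python
import numpy as np

def sigmoid(x):
    """Element-wise logistic function."""
    return 1.0 / (1.0 + np.exp(-x))

def logits_to_mask(logits, threshold=0.5):
    """Convert raw logits to a binary uint8 mask via sigmoid + threshold."""
    return (sigmoid(logits) > threshold).astype(np.uint8)

# illustrative values only, not model output
logits = np.array([[-2.0, 0.3], [1.5, -0.1]])
mask = logits_to_mask(logits)
# thresholding the logits at zero gives the same mask
same_mask = (logits > 0).astype(np.uint8)
```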


In [None]:
# Create the output directories for the training-set predictions.
# exist_ok=True tolerates existing directories and still creates any missing one
# (the parent 'train' directory is created automatically).
os.makedirs('./data'+TYPE+'/'+METHOD+'/train/Cáncer', exist_ok=True)
os.makedirs('./data'+TYPE+'/'+METHOD+'/train/Control', exist_ok=True)

In [None]:
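# Load the training ground-truth masks and binarize them (mask files are expected to be 250x250)
orig_masks_train=[]
for r in data_paths['train/masks']:
    Yimg = imread(r,as_gray=True)
    mask=np.zeros([250, 250])
    ind =(Yimg>0.1)  # foreground: grayscale intensity above 0.1
    mask[ind]=1
    orig_masks_train.append(mask)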
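
The loop that follows pairs each batch with `orig_masks_train[indice]`, which is only correct if the dataloader yields images in the same order as `data_paths['train/masks']` (for example, with no shuffling and a batch size of 1). A more defensive sketch looks the mask up by file name instead of by position. It assumes that each mask file shares its base name with its image, which is an assumption about the dataset layout; the helper names are hypothetical as well:

```python
import os

def build_mask_index(mask_paths):
    """Map each mask's file stem (name without extension) to its position in the list."""
    return {os.path.splitext(os.path.basename(str(p)))[0]: i
            for i, p in enumerate(mask_paths)}

def mask_position(image_path, mask_index):
    """Return the list position of the mask whose stem matches the image's stem."""
    stem = os.path.splitext(os.path.basename(str(image_path)))[0]
    return mask_index[stem]
```

Inside the loop, a lookup such as `orig_masks_train[mask_position(batch['path_name'][0], mask_index)]` could then replace the positional index, with `mask_index` built once from `data_paths['train/masks']`.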

In [None]:
inicio=time.time()

writepredsdetectDict = [] # Dict Results
random.seed(123)

# Iterate through training images
with torch.no_grad():
    for indice,batch in tqdm(enumerate(train_dataloader)):
        image_name = os.path.basename(str(batch['path_name'][0]))
        # forward pass
        outputs = modelo(pixel_values=batch["pixel_values"].cuda(),
                      input_boxes=batch["input_boxes"].cuda(),
                      multimask_output=False)

        # predicted logits and ground truth (no loss is computed during evaluation)
        predicted_masks = outputs.pred_masks.squeeze(1)
        ground_truth_masks = batch["ground_truth_mask"].float().cuda()

        # apply sigmoid
        medsam_seg_prob = torch.sigmoid(predicted_masks)
        # convert soft mask to hard mask
        medsam_seg_prob = medsam_seg_prob.cpu().numpy().squeeze()
        medsam_seg = (medsam_seg_prob > 0.5).astype(np.uint8)

        # resize to the 250x250 size of the original masks; nearest-neighbor keeps the mask binary
        medsam_seg2 = cv2.resize(medsam_seg,(250,250),interpolation=cv2.INTER_NEAREST)
        m_dice = dice_coefficient(medsam_seg2,orig_masks_train[indice])

        if 'C' in image_name:
            new_dir='./data'+TYPE+'/'+METHOD+'/train/Cáncer/'
            label='Cáncer'
        else:
            new_dir='./data'+TYPE+'/'+METHOD+'/train/Control/'
            label='Control'

        Ximg = imread(str(batch['path_name'][0]))
        Ximg = cv2.resize(Ximg, (256, 256))

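        # keep only the image pixels inside the predicted mask
        result=cv2.bitwise_and(Ximg,Ximg,mask=medsam_seg)
        result=cv2.resize(result,(250,250))
        # swap the R and B channels before saving (cv2.imwrite expects BGR order)
        result= cv2.cvtColor(result, cv2.COLOR_BGR2RGB)

        cv2.imwrite(new_dir+str(image_name), result)
        # 2-second pause after each write; it accounts for most of the per-iteration time
        time.sleep(2)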

        writepredsdetectDict.append({'modelo':used_model,
                                     'imagen':str(image_name),
                                     'set':'Train',
                                     'clase':label,
                                     'true_mask':batch["ground_truth_mask"][0],
                                     'predicion':medsam_seg,
                                     'DICE': m_dice
                                    })
print('TRAIN PREDICTIONS COMPLETE')

# ------ Append the results to the CSV file (created with a header if it does not exist yet)
file_name='resultados.csv'
archivo='./data'+TYPE+'/'+METHOD+'/'+file_name
if os.path.isfile(archivo):
    modo = 'a+'
else:
    modo = 'w'
with open(archivo, modo, newline='') as csvfile:
    fieldnames = ['modelo', 'imagen','set', 'clase','true_mask','predicion','DICE']
    writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
    if modo=='w':
        writer.writeheader()
    for d in writepredsdetectDict:
        writer.writerow(d)
print('*****END*****')
print('Execution time (min): ',round((time.time()-inicio)/60,0))

1605it [1:03:40,  2.38s/it]


TRAIN PREDICTIONS COMPLETE
*****END*****
Execution time (min):  64.0
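
Once all three splits have been appended to `resultados.csv`, the per-split scores can be summarized from the `set` and `DICE` columns. A minimal standard-library sketch, assuming every `DICE` entry parses as a float; the function name `mean_dice_by_set` is hypothetical:

```python
import csv
from collections import defaultdict

def mean_dice_by_set(csv_path):
    """Return {set_name: mean DICE} from a results CSV with 'set' and 'DICE' columns."""
    totals = defaultdict(float)
    counts = defaultdict(int)
    with open(csv_path, newline='') as csvfile:
        for row in csv.DictReader(csvfile):
            totals[row['set']] += float(row['DICE'])
            counts[row['set']] += 1
    return {name: totals[name] / counts[name] for name in totals}
```

Because the file is opened in append mode on every run, rows from earlier runs or other models stay in it; filtering on the `modelo` column first may be needed to avoid mixing them.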
