In [1]:
# Mount Google Drive at /content/drive; every path used later in this notebook
# (project code, dataset, checkpoint, outputs) lives under this mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
# Install into the running kernel's environment (%pip rather than !pip3) and pin
# the versions / commit that the recorded run resolved, so a fresh runtime rebuilds
# the same environment. icecream and slicerio are left unpinned because their
# resolved versions are not visible in the (truncated) recorded log.
%pip install monai==1.4.0 torchio==0.20.1 timm==1.0.11 git+https://github.com/facebookresearch/segment-anything.git@dca509fe793f601edb92606367a655c15ac00fdf icecream slicerio

Collecting git+https://github.com/facebookresearch/segment-anything.git
  Cloning https://github.com/facebookresearch/segment-anything.git to /tmp/pip-req-build-hd27y93g
  Running command git clone --filter=blob:none --quiet https://github.com/facebookresearch/segment-anything.git /tmp/pip-req-build-hd27y93g
  Resolved https://github.com/facebookresearch/segment-anything.git to commit dca509fe793f601edb92606367a655c15ac00fdf
  Preparing metadata (setup.py) ... done
Collecting monai
  Downloading monai-1.4.0-py3-none-any.whl.metadata (11 kB)
Collecting torchio
  Downloading torchio-0.20.1-py3-none-any.whl.metadata (50 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 50.7/50.7 kB 3.0 MB/s eta 0:00:00
Collecting timm
  Downloading timm-1.0.11-py3-none-any.whl.metadata (48 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 48.4/48.4 kB 3.0 MB/s eta 0:00:00
Collecting icecream
  Downloadi

In [30]:
# Make the project checkout on Drive importable. This has to run before the
# `utils` / `models` imports below, which are presumably resolved from this folder.
# `sys` is imported here because nothing else in this cell imports it; without it
# the cell raises NameError on a fresh kernel.
import sys

FINETUNE_SAM_ROOT = '/content/drive/MyDrive/finetuneSAMmain'
# Guard against duplicate entries when the cell is re-run.
if FINETUNE_SAM_ROOT not in sys.path:
    sys.path.append(FINETUNE_SAM_ROOT)

import os
import torch
import json
from PIL import Image
import numpy as np
import csv
from torchvision import transforms
from utils.utils import inverse_normalize
from models.sam import sam_model_registry
from argparse import Namespace
from tqdm import tqdm
from pathlib import Path
import matplotlib.pyplot as plt

# Input locations: test images and their ground-truth masks
image_dir = '/content/drive/MyDrive/correct_dataset/split/test_images'
ground_truth_dir = '/content/drive/MyDrive/correct_dataset/split/test_masks'

# Output locations: everything produced by this run goes under output_dir
output_dir = '/content/drive/MyDrive/correct_dataset/split/finetune_testing_output_oct17'
mask_output_dir = os.path.join(output_dir, 'masks')
dsc_output_file = os.path.join(output_dir, 'dsc_results.csv')

# Ensure the output folder and its mask sub-folder exist (no error if already there)
for _out_dir in (output_dir, mask_output_dir):
    Path(_out_dir).mkdir(parents=True, exist_ok=True)

# Restore the fine-tuned model from the training run's saved arguments.
checkpoint_dir = "/content/drive/MyDrive/correct_dataset/split/2D-SAM_vit_b_decoder_adapter_OCT-Ear_noprompt_Oct10"
args_path = f"{checkpoint_dir}/args.json"

# args.json holds the keyword arguments of the training run; expose them as attributes.
args_dict = json.loads(Path(args_path).read_text())
args = Namespace(**args_dict)

# NOTE(review): the weights are read from args.dir_checkpoint (the path recorded in
# args.json), not from checkpoint_dir above — confirm both point at the same folder.
best_checkpoint_path = os.path.join(args.dir_checkpoint, 'checkpoint_best.pth')
build_model = sam_model_registry[args.arch]
sam_fine_tune = build_model(args, checkpoint=best_checkpoint_path, num_classes=args.num_cls)

# Inference only: move to the GPU and switch to eval mode.
sam_fine_tune = sam_fine_tune.to('cuda').eval()

def calculate_mean_std(image_folder):
    """Compute the pooled mean and standard deviation of grayscale pixel values.

    Every regular file in ``image_folder`` whose name ends in ``.png``, ``.jpg`` or
    ``.jpeg`` (case-insensitive) is opened, converted to 8-bit grayscale and scaled
    to ``[0, 1]``. The statistics are taken over all pixels of all such images.

    The statistics are accumulated image by image instead of collecting every pixel
    in a Python list, so memory use is bounded by the largest single image.

    Args:
        image_folder: Directory containing the images.

    Returns:
        tuple[float, float]: ``(mean, std)`` where ``std`` is the population
        standard deviation (divides by N, like ``np.std``'s default).

    Raises:
        ValueError: If no matching image with at least one pixel is found; returning
            NaN here would silently poison the normalization downstream.
    """
    image_suffixes = ('.png', '.jpg', '.jpeg')

    pixel_count = 0
    running_mean = 0.0
    running_m2 = 0.0  # sum of squared deviations from running_mean

    # Sorted so the floating-point accumulation order is the same on every run.
    for image_filename in sorted(os.listdir(image_folder)):
        image_path = os.path.join(image_folder, image_filename)
        if not os.path.isfile(image_path):
            continue
        if not image_filename.lower().endswith(image_suffixes):
            continue

        # Context manager closes the underlying file as soon as pixels are read.
        with Image.open(image_path) as image:
            image_array = np.asarray(image.convert('L'), dtype=np.float64) / 255.0

        batch_count = image_array.size
        if batch_count == 0:
            continue
        batch_mean = image_array.mean()
        batch_m2 = np.square(image_array - batch_mean).sum()

        # Merge this image's statistics into the running totals (pairwise
        # mean/variance combination); avoids the cancellation of E[x^2] - E[x]^2.
        delta = batch_mean - running_mean
        merged_count = pixel_count + batch_count
        running_mean += delta * batch_count / merged_count
        running_m2 += batch_m2 + delta * delta * pixel_count * batch_count / merged_count
        pixel_count = merged_count

    if pixel_count == 0:
        raise ValueError(
            f"No .png/.jpg/.jpeg images with pixel data found in {image_folder!r}"
        )

    mean = float(running_mean)
    std = float(np.sqrt(running_m2 / pixel_count))

    return mean, std

# Function to evaluate a single image slice
def evaluate_1_slice(image_path, model, mean, std):
    """Run the model on one image and return the per-pixel class prediction.

    The image is resized to 1024x1024, converted to a tensor, normalized with the
    given scalar ``mean``/``std`` and passed through the model's image encoder,
    prompt encoder (with no prompts) and mask decoder.

    Args:
        image_path: Path of the image file to segment.
        model: Model exposing ``image_encoder``, ``prompt_encoder`` and
            ``mask_decoder``. The input batch is moved to the device of the
            model's parameters.
        mean: Scalar mean passed to ``transforms.Normalize``.
        std: Scalar standard deviation passed to ``transforms.Normalize``.

    Returns:
        tuple: ``(pred, Pil_img)``. ``pred`` is the argmax over dim 1 of the decoder
        output for a batch of one (presumably class indices per pixel — confirm
        against the decoder's output layout). ``Pil_img`` is an in-memory copy of
        the image at its original size.
    """
    # Copy into memory and close the file straight away.
    # NOTE(review): the image mode is used as stored on disk (no RGB/grayscale
    # conversion) — confirm this matches what the image encoder was trained on.
    with Image.open(image_path) as source_img:
        Pil_img = source_img.copy()

    img = transforms.Resize((1024, 1024))(Pil_img)
    img = transforms.ToTensor()(img)
    img = transforms.Normalize(mean=[mean], std=[std])(img)

    # Follow the model's device instead of hard-coding CUDA, so the function also
    # works when the model lives on the CPU or on a non-default GPU. Falls back to
    # the previous behaviour (CUDA) for a model without parameters.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')
    imgs = torch.unsqueeze(img, 0).to(device)

    with torch.no_grad():
        img_emb = model.image_encoder(imgs)
        sparse_emb, dense_emb = model.prompt_encoder(points=None, boxes=None, masks=None)
        pred, _ = model.mask_decoder(
            image_embeddings=img_emb,
            image_pe=model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_emb,
            dense_prompt_embeddings=dense_emb,
            multimask_output=True,
        )
        pred = pred.argmax(dim=1)

    return pred, Pil_img

# Sub-folders of output_dir: side-by-side comparison figures and colour overlays
comparison_dir = os.path.join(output_dir, 'comparison')
colorized_predictions_dir = os.path.join(output_dir, 'colorized_predictions')

# Make sure both sub-folders exist before anything is written into them
for _viz_dir in (comparison_dir, colorized_predictions_dir):
    Path(_viz_dir).mkdir(parents=True, exist_ok=True)

def _binary_metrics(pred_binary, gt_binary):
    """Return (dsc, accuracy, precision, recall, iou) for a pair of binary masks.

    Both arguments are same-shaped arrays of 0.0 / 1.0 values. Any ratio whose
    denominator is zero (e.g. the class is absent from both masks) is reported
    as 1.0.
    """
    # Confusion-matrix counts
    true_positive = np.sum(pred_binary * gt_binary)
    false_positive = np.sum(pred_binary * (1 - gt_binary))
    false_negative = np.sum((1 - pred_binary) * gt_binary)
    true_negative = np.sum((1 - pred_binary) * (1 - gt_binary))

    # Dice coefficient: 2*TP / (2*TP + FP + FN)
    dice_denominator = 2 * true_positive + false_positive + false_negative
    dsc = (2 * true_positive / dice_denominator) if dice_denominator > 0 else 1.0

    # Accuracy over every pixel
    total_pixels = true_positive + false_positive + false_negative + true_negative
    accuracy = (true_positive + true_negative) / total_pixels if total_pixels > 0 else 1.0

    # Precision / recall
    predicted_positive = true_positive + false_positive
    precision = true_positive / predicted_positive if predicted_positive > 0 else 1.0
    actual_positive = true_positive + false_negative
    recall = true_positive / actual_positive if actual_positive > 0 else 1.0

    # IoU (Jaccard index): TP / (TP + FP + FN)
    union = true_positive + false_positive + false_negative
    iou = true_positive / union if union > 0 else 1.0

    return dsc, accuracy, precision, recall, iou


# Evaluate every test image: save the predicted mask, write per-class metrics to
# the CSV, and save a colour overlay plus a side-by-side comparison figure.
with open(dsc_output_file, mode='w', newline='') as dsc_file:
    dsc_writer = csv.writer(dsc_file)

    # NOTE(review): hard-coded; presumably should agree with args.num_cls used to
    # build the model — confirm.
    num_classes = 4  # Assuming 4 classes (0: background, 1-3: actual classes)
    metric_names = ('dsc', 'accuracy', 'precision', 'recall', 'iou')

    # Header with metrics for each class (0 to num_classes - 1)
    headers = ['Image Name']
    for cls in range(num_classes):
        headers += [
            f'Class {cls} DSC', f'Class {cls} Accuracy', f'Class {cls} Precision',
            f'Class {cls} Recall', f'Class {cls} IoU'
        ]
    dsc_writer.writerow(headers)

    # Per-class sums (one slot per class) and sums pooled over all classes
    total_metrics = {name: np.zeros(num_classes) for name in metric_names}
    cumulative_sum_metrics = {name: 0.0 for name in metric_names}
    image_count = 0  # To calculate the average later

    mean, std = calculate_mean_std(image_dir)

    # Overlay colours (RGB in [0, 1]), indexed by class - 1:
    # class 1 -> blue, class 2 -> red, class 3 -> green.
    overlay_colors = np.array([[0, 0, 1], [1, 0, 0], [0, 1, 0]])
    alpha = 0.5  # blend weight of the overlay

    # Sorted so the CSV rows come out in the same order on every run.
    for image_filename in tqdm(sorted(os.listdir(image_dir))):
        image_path = os.path.join(image_dir, image_filename)
        if not os.path.isfile(image_path):
            continue  # e.g. a stray sub-directory

        # The mask is looked up as "<stem>.png" first (covers .jpg/.jpeg/.JPG
        # images with PNG masks), then under the image's own file name.
        image_stem = os.path.splitext(image_filename)[0]
        gt_candidates = (
            os.path.join(ground_truth_dir, image_stem + '.png'),
            os.path.join(ground_truth_dir, image_filename),
        )
        ground_truth_path = next((p for p in gt_candidates if os.path.isfile(p)), None)
        if ground_truth_path is None:
            continue  # no ground truth for this image: nothing to score

        # Evaluate the image slice
        pred_1, ori_img = evaluate_1_slice(image_path, sam_fine_tune, mean, std)

        # Convert the predicted mask to a PIL image at the original resolution;
        # nearest-neighbour resampling keeps the values integer class labels.
        mask_pred_1 = ((pred_1).cpu()).float()
        pil_mask1 = Image.fromarray(np.array(mask_pred_1[0], dtype=np.uint8), 'L')
        pil_mask1 = pil_mask1.resize(ori_img.size, resample=Image.NEAREST)

        # Save predicted mask
        mask_img_filename = os.path.join(mask_output_dir, image_stem + '.png')
        pil_mask1.save(mask_img_filename)

        # Load ground truth mask and resize to match prediction size.
        # NOTE(review): assumes the grayscale value of each mask pixel is its class
        # index (0..num_classes-1) — confirm for the mask files in use.
        with Image.open(ground_truth_path) as gt_file:
            ground_truth_img = gt_file.convert('L').resize(ori_img.size, resample=Image.NEAREST)
        ground_truth_display = np.array(ground_truth_img)
        mask_display = np.array(pil_mask1)

        # One-vs-rest metrics for each class, appended to this image's CSV row
        image_metrics = [image_filename]
        for cls in range(num_classes):
            pred_binary = (mask_display == cls).astype(float)
            gt_binary = (ground_truth_display == cls).astype(float)
            class_metrics = _binary_metrics(pred_binary, gt_binary)

            image_metrics.extend(class_metrics)
            for name, value in zip(metric_names, class_metrics):
                total_metrics[name][cls] += value       # per-class running sum
                cumulative_sum_metrics[name] += value   # pooled running sum

        # Write the metrics for this image (all classes) to the CSV
        dsc_writer.writerow(image_metrics)
        image_count += 1  # Increment image count for averaging

        # Colour overlays for prediction and ground truth (background stays black)
        pred_overlay = np.zeros((mask_display.shape[0], mask_display.shape[1], 3), dtype=np.float32)
        gt_overlay = np.zeros((ground_truth_display.shape[0], ground_truth_display.shape[1], 3), dtype=np.float32)
        for cls, color in enumerate(overlay_colors, start=1):  # Classes 1 to 3
            pred_overlay[mask_display == cls] = color
            gt_overlay[ground_truth_display == cls] = color

        # Blend the original image with each overlay
        ori_img_array = np.array(ori_img).astype(np.float32) / 255.0
        pred_blend = (1 - alpha) * ori_img_array + alpha * pred_overlay
        gt_blend = (1 - alpha) * ori_img_array + alpha * gt_overlay

        # Save the RGB overlay of predicted masks in 'colorized_predictions' directory
        colorized_prediction_filename = os.path.join(colorized_predictions_dir, f'{image_stem}.png')
        plt.imsave(colorized_prediction_filename, pred_blend)

        # Side-by-side comparison of original, predicted, and ground truth,
        # saved in the 'comparison' directory
        comparison_filename = os.path.join(comparison_dir, f'{image_stem}.png')
        fig, axes = plt.subplots(1, 3, figsize=(18, 6))
        for ax, panel in zip(axes, (ori_img_array, pred_blend, gt_blend)):
            ax.imshow(panel)
            ax.axis('off')

        fig.tight_layout()
        fig.savefig(comparison_filename)
        # Close this specific figure so figures do not pile up across the loop.
        plt.close(fig)

# Report the averages once every image has been processed
if image_count > 0:
    # Per-class averages: each entry is an array with one value per class
    avg_metrics = {name: sums / image_count for name, sums in total_metrics.items()}
    print("Combined Averages for All Images:")
    for cls in range(num_classes):
        class_line = (
            f"Class {cls} - DSC: {avg_metrics['dsc'][cls]:.4f}, "
            f"Accuracy: {avg_metrics['accuracy'][cls]:.4f}, "
            f"Precision: {avg_metrics['precision'][cls]:.4f}, "
            f"Recall: {avg_metrics['recall'][cls]:.4f}, "
            f"IoU: {avg_metrics['iou'][cls]:.4f}"
        )
        print(class_line)

    # Overall averages: pooled over every (image, class) pair
    pair_count = image_count * num_classes
    overall_avg_metrics = {name: total / pair_count for name, total in cumulative_sum_metrics.items()}
    print("\nOverall Averages Across All Classes for All Images:")
    overall_line = (
        f"DSC: {overall_avg_metrics['dsc']:.4f}, "
        f"Accuracy: {overall_avg_metrics['accuracy']:.4f}, "
        f"Precision: {overall_avg_metrics['precision']:.4f}, "
        f"Recall: {overall_avg_metrics['recall']:.4f}, "
        f"IoU: {overall_avg_metrics['iou']:.4f}"
    )
    print(overall_line)

print(f"All results saved to {output_dir}")
print(f"DSC results saved to {dsc_output_file}")


100%|██████████| 34/34 [00:40<00:00,  1.20s/it]

Combined Averages for All Images:
Class 0 - DSC: 0.9906, Accuracy: 0.9834, Precision: 0.9911, Recall: 0.9901, IoU: 0.9813
Class 1 - DSC: 0.8940, Accuracy: 0.9877, Precision: 0.8723, Recall: 0.9258, IoU: 0.8134
Class 2 - DSC: 0.7494, Accuracy: 0.9893, Precision: 0.7670, Recall: 0.7492, IoU: 0.6100
Class 3 - DSC: 0.6803, Accuracy: 0.9872, Precision: 0.7265, Recall: 0.6690, IoU: 0.5668

Overall Averages Across All Classes for All Images:
DSC: 0.8286, Accuracy: 0.9869, Precision: 0.8392, Recall: 0.8335, IoU: 0.7429
All results saved to /content/drive/MyDrive/correct_dataset/split/finetune_testing_output_oct17
DSC results saved to /content/drive/MyDrive/correct_dataset/split/finetune_testing_output_oct17/dsc_results.csv



