# Metrics on test dataset

In this notebook we evaluate all the different models trained previously and compare their metrics on a test dataset.

### Import libraries

In [1]:
from datasets import load_dataset
import torch
import matplotlib.pyplot as plt
import os

# Work from the home directory; presumably so the local `utils` package
# (utils.sam_dataset, utils.metrics) imported below resolves — TODO confirm.
os.chdir("/home/ubuntu/")

  from .autonotebook import tqdm as notebook_tqdm


## Load dataset

First, we load the test split from the small subset of the full dataset.

In [2]:
# Loading a plain JSON file puts every record under a single "train" split;
# the actual train/test assignment is stored in each record's "split" field.
ds = load_dataset('json', data_files='/home/ubuntu/data/small_metadata.json')

test_dataset = ds["train"].filter(lambda example: example["split"] == "test")
del ds  # only the test subset is needed from here on
print(f"Length of test dataset: {len(test_dataset)}")

Length of test dataset: 379


Wrap the test split in the `SAMDataset` class, together with the SAM processor, so it can be fed to the models.

In [3]:
from utils.sam_dataset import SAMDataset
from transformers import SamProcessor
# Processor that matches the "facebook/sam-vit-base" checkpoint; it is also the
# processor used to build clean prompts during evaluation.
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

# Root folder handed to SAMDataset (presumably where the image/mask files live).
data_folder = "/home/ubuntu/data/"

sam_test_dataset = SAMDataset(
    dataset=test_dataset,
    processor=processor,
    data_folder=data_folder,
)

## Load models

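We load the previously trained models plus two off-the-shelf baselines. The descriptions below follow the checkpoint names and the labels used in the results table:
* Model 1: SAM trained with topological and geometrical loss and box prompt on the private and intestinal dataset. The best one is for $\lambda = 0.1$.
* Model 2: SAM trained with topological and geometrical loss and box prompt on the private dataset
* Model 3: SAM trained with topological and geometrical loss and box prompt on the private dev280 dataset
* Model 4: SAM trained with geometrical loss only and box prompt on the private and intestinal dataset
* Model 5: MedSAM (`wanglab/medsam-vit-base`) with box prompt
* Model 6: SAM ViT-base model (`facebook/sam-vit-base`) with box prompt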

In [4]:
from transformers import SamModel

# Folder holding the fine-tuned checkpoints (whole pickled models).
CHECKPOINT_DIR = "/home/ubuntu/models/sam_experiments"
# Parameter-name prefixes of the sub-modules whose weights stay frozen.
FROZEN_PREFIXES = ("vision_encoder", "prompt_encoder")


def freeze_encoders(model, prefixes=FROZEN_PREFIXES):
    """Disable gradients for every parameter whose name starts with one of ``prefixes``.

    Args:
        model (torch.nn.Module): model modified in place
        prefixes (Iterable[str]): parameter-name prefixes to freeze

    Returns:
        The same ``model`` object, so the call can be chained.
    """
    prefixes = tuple(prefixes)  # str.startswith accepts a tuple of candidates
    for name, param in model.named_parameters():
        if name.startswith(prefixes):
            param.requires_grad_(False)
    return model


def load_checkpoint(relative_path):
    """Load a pickled model from ``CHECKPOINT_DIR`` and freeze its encoders.

    NOTE: ``torch.load`` unpickles arbitrary objects — only load trusted checkpoints.
    NOTE(review): on PyTorch >= 2.6 ``torch.load`` defaults to ``weights_only=True``,
    which rejects whole-model pickles; pass ``weights_only=False`` if loading fails.
    """
    return freeze_encoders(torch.load(f"{CHECKPOINT_DIR}/{relative_path}"))


model_1 = load_checkpoint("topo-no-reg-lambda-0.1_geom-int150_box-prompt.pth")
model_2 = load_checkpoint("small_private/topo-no-reg_geom-int150_box-prompt.pth")
model_3 = load_checkpoint("dev280/topo-no-reg_geom-int150_box-prompt.pth")
model_4 = load_checkpoint("no-topo_geom-int150_box-prompt.pth")

# Off-the-shelf baselines from the Hugging Face Hub.
model_medsam = freeze_encoders(SamModel.from_pretrained("wanglab/medsam-vit-base"))
model_sam_base = freeze_encoders(SamModel.from_pretrained("facebook/sam-vit-base"))

## Compute the metrics for every model

Given a model and a dataloader, the function below returns a dictionary with the mean of each metric over the whole dataset.

We compute the following metrics for every model (Surface Distance and Hausdorff Distance are currently disabled in the code and are not reported):
| Metric     | Definition |
| ----- | ---------- |
| Intersection over Union `IoU`   | Formula: IoU = Intersection / Union |
| Dice Square Coefficient `DSC` | Formula: Dice = (2 * Intersection) / (Sum of squares of the masks' areas) |
| Surface Distance `SurfDist` | Distance between the boundaries of the predicted and ground-truth masks, as implemented in DeepMind's `surface-distance` library (not computed) |
| Sensitivity `Sens` | Formula: Sensitivity = True Positives / (True Positives + False Negatives) |
| Specificity `Spec` | Formula: Specificity = True Negatives / (True Negatives + False Positives) |
| Hausdorff Distance `HausDist` | Maximum distance between the contours of the predicted and ground-truth masks (not computed) |
| Average Precision `AP` | Computed by integrating the precision-recall curve |
| F1 Score `F1` | Formula: F1 Score = 2 * (Precision * Recall) / (Precision + Recall) |

In [5]:
import torch.nn.functional as F
from utils.metrics import iou, dsc, surfdist, sensitivity, specificity, hausdorff_dist, ap, f1_score
from statistics import mean
# Maps the ``prompt`` argument to the SamModel keyword argument carrying that prompt.
_PROMPT_KWARG = {"point": "input_points", "box": "input_boxes"}


def _sam_forward(model, batch, prompt, device, noised_prompt):
    """Run one single-mask SAM forward pass on ``batch`` with a box or point prompt.

    With ``noised_prompt`` the pixel values and prompts stored in the batch are used
    as-is; otherwise the prompt is rebuilt from ``batch["gt_box"]`` (the box itself or
    its centre point) through the global ``processor``.
    """
    if noised_prompt:
        inputs = batch
    else:
        # gt_box is presumably [x_min, y_min, x_max, y_max] — TODO confirm.
        gt_box = batch["gt_box"]
        # Centre of the box, used as the point prompt; .item() needs batch size 1.
        x = (gt_box[0] + gt_box[2]) / 2
        y = (gt_box[1] + gt_box[3]) / 2
        inputs = processor(batch["original_image"], input_boxes=[[gt_box]],
                           input_points=[[[x.item(), y.item()]]], return_tensors="pt")

    prompt_key = _PROMPT_KWARG[prompt]
    return model(pixel_values=inputs["pixel_values"].to(device),
                 **{prompt_key: inputs[prompt_key].to(device)},
                 multimask_output=False)


def _postprocess_masks(pred_masks, mask_size, padded_size=(1024, 1024), valid_size=(992, 1024)):
    """Turn SAM low-resolution mask logits into probabilities of size ``mask_size``.

    The logits are upscaled to ``padded_size``, cropped to ``valid_size``, resized to
    ``mask_size`` and passed through a sigmoid.
    NOTE(review): the (992, 1024) crop presumably removes the processor's padding and
    assumes every image has the same resized shape; the processor's
    ``reshaped_input_sizes`` gives the exact per-image value — confirm.
    """
    masks = F.interpolate(pred_masks.squeeze(1), padded_size,
                          mode="bilinear", align_corners=False)
    masks = masks[..., :valid_size[0], :valid_size[1]]
    masks = F.interpolate(masks, mask_size, mode="bilinear", align_corners=False)
    return torch.sigmoid(masks)


def metrics_calculation(model, test_dataloader, prompt, device, noised_prompt):
    """
    Compute the mean of every enabled metric for one model on a given dataset.

    Args:
        model (transformers.SamModel): model to evaluate; moved to ``device`` and set to eval mode
        test_dataloader (torch.utils.data.DataLoader): dataloader for the test dataset;
            batch size 1 is required when ``noised_prompt`` is False
        prompt (str): "box" or "point", the prompt type used at inference
        device (str): device on which inference runs
        noised_prompt (bool): if True, use the prompts already stored in each batch
            (``input_boxes`` / ``input_points``); if False, rebuild exact prompts from
            ``batch["gt_box"]`` with the global ``processor``

    Returns:
        metrics_dict (dict): mean over the dataset of "IoU", "DSC", "Sens", "Spec", "AP" and "F1"

    Raises:
        ValueError: if ``prompt`` is not "box" or "point".
        statistics.StatisticsError: if the dataloader yields no batches.
    """
    # Validate up front: an unknown prompt would otherwise leave ``outputs`` unbound.
    if prompt not in _PROMPT_KWARG:
        raise ValueError(f"prompt must be 'box' or 'point', got {prompt!r}")

    # Surface Distance and Hausdorff Distance are not computed here.
    metric_fns = {
        "IoU": iou,
        "DSC": dsc,
        "Sens": sensitivity,
        "Spec": specificity,
        "AP": ap,
        "F1": f1_score,
    }
    per_image = {name: [] for name in metric_fns}

    model.to(device)
    model.eval()

    with torch.no_grad():
        for batch in test_dataloader:
            ground_truth_masks = batch["ground_truth_mask"].float().unsqueeze(1).to(device)
            _, _, m_h, m_w = ground_truth_masks.shape

            outputs = _sam_forward(model, batch, prompt, device, noised_prompt)
            predicted_masks = _postprocess_masks(outputs.pred_masks, (m_h, m_w))

            for name, metric_fn in metric_fns.items():
                per_image[name].append(metric_fn(predicted_masks, ground_truth_masks))

    return {name: mean(values) for name, values in per_image.items()}


In [14]:
from torch.utils.data import DataLoader

# batch_size must stay 1: clean prompts are built with .item() on the batch's box.
# shuffle=False keeps the evaluation order deterministic (the means do not depend on it).
sam_test_dataloader = DataLoader(sam_test_dataset, batch_size=1, shuffle=False)
device = "cuda"

# All fine-tuned checkpoints were trained with box prompts (see the "_box-prompt"
# suffix of their paths and the labels of the results table), so every model is
# evaluated with a box prompt built from the ground-truth box.
metrics_model_1 = metrics_calculation(model=model_1, test_dataloader=sam_test_dataloader, prompt="box", device=device, noised_prompt=False)
metrics_model_2 = metrics_calculation(model=model_2, test_dataloader=sam_test_dataloader, prompt="box", device=device, noised_prompt=False)
metrics_model_3 = metrics_calculation(model=model_3, test_dataloader=sam_test_dataloader, prompt="box", device=device, noised_prompt=False)
metrics_model_4 = metrics_calculation(model=model_4, test_dataloader=sam_test_dataloader, prompt="box", device=device, noised_prompt=False)
metrics_medsam = metrics_calculation(model=model_medsam, test_dataloader=sam_test_dataloader, prompt="box", device=device, noised_prompt=False)
metrics_sam_base = metrics_calculation(model=model_sam_base, test_dataloader=sam_test_dataloader, prompt="box", device=device, noised_prompt=False)

### Save the results.

We save the results to `models_metrics_complete_.csv`, which can later be used to compare the models.

In [15]:
import pandas as pd

# One (table label, metrics dict) pair per evaluated model, in table order.
evaluated_models = [
    ("SAM trained with Topo. + Geom. Loss and box prompt on private and intestinal dataset", metrics_model_1),
    ("SAM trained with Topo. + Geom. Loss and box prompt on private dataset", metrics_model_2),
    ("SAM trained with Topo. + Geom. Loss and box prompt on private dev280 dataset", metrics_model_3),
    ("SAM trained with Geom. Loss and box prompt on private and intestinal dataset", metrics_model_4),
    ("MedSAM", metrics_medsam),
    ("SAM Base", metrics_sam_base),
]

# Key in each metrics dict -> column header in the CSV (Surface Distance and
# Hausdorff Distance are not computed, so they get no column).
column_for_metric = {
    "IoU": "IoU",
    "DSC": "DSC",
    "Sens": "Sensitivity",
    "Spec": "Specificity",
    "AP": "AP",
    "F1": "F1",
}

rows = [
    [label] + [metrics[key] for key in column_for_metric]
    for label, metrics in evaluated_models
]
df = pd.DataFrame(rows, columns=["Model", *column_for_metric.values()])

df.to_csv("/home/ubuntu/models_metrics_complete_.csv")