In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import numpy as np
from datetime import datetime
import numpy as np
import matplotlib.pyplot as plt
from glob import glob
import os
from PIL import Image
from matplotlib import cm
import cv2
import matplotlib.pyplot as plt
import pandas as pd


  from pandas.core.computation.check import NUMEXPR_INSTALLED
  from pandas.core import (


In [2]:
# --- 1. Load DINO-v2 Backbone ---

def get_dino_v2_backbone():
    # This returns a model directly, not a state_dict
    backbone = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14")  # ✅ this is already a model
    return backbone


class DINOv2SegmentationModel(nn.Module):
    def __init__(self, backbone, num_classes):
        super().__init__()
        self.backbone = backbone
        self.num_classes = num_classes
        self.decoder = nn.Sequential(
            nn.Conv2d(768, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, num_classes, kernel_size=1)
        )

    def forward(self, x):
        B, C, H, W = x.shape
        features = self.backbone.get_intermediate_layers(x, n=1)[0]  # (B, N, 768)
        feat_size = int(features.shape[1] ** 0.5)
        features = features.permute(0, 2, 1).reshape(B, 768, feat_size, feat_size)  # (B, 768, h, w)
        out = self.decoder(features)  # (B, num_classes, h, w)
        out = F.interpolate(out, size=(H, W), mode='bilinear', align_corners=False)  # upscale to 1022x1022
        return out

In [3]:
def make_prediction(img_path):
    transform = transforms.Compose([
        transforms.Resize((1022, 1022)),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3)
    ])

    image = Image.open(img_path).convert("RGB")
    input_tensor = transform(image).unsqueeze(0).to(device)  # (1, 3, 1022, 1022)
    image_np = np.array(Image.open(img_path).convert("RGB").resize((1022, 1022)))
    with torch.no_grad():
        output = model(input_tensor)  # (1, num_classes, 1022, 1022)
        pred_mask = output.argmax(dim=1).squeeze(0).cpu().numpy()  # (1022, 1022)
        probs = torch.softmax(output, dim=1)  # convert logits → probabilities
        conf, pred = torch.max(probs, dim=1)  # conf: confidence per pixel, pred: predicted class
        conf = conf.squeeze(0).cpu().numpy()  # (1022, 1022)
        pred_mask = pred.squeeze(0).cpu().numpy()  # (1022, 1022)
        threshold = 0.6  # choose based on your model's calibration
        high_conf_mask = np.where(conf >= threshold, pred_mask, 0)
        #print(np.unique(high_conf_mask))
        
    return image_np, high_conf_mask

In [4]:
def compute_metrics(pred_mask, true_mask, num_classes=2):
    """
    pred_mask: torch.Tensor or np.array (H, W) — predicted class per pixel
    true_mask: torch.Tensor or np.array (H, W) — ground truth class per pixel
    """
    if isinstance(pred_mask, torch.Tensor):
        pred_mask = pred_mask.cpu().numpy()
    if isinstance(true_mask, torch.Tensor):
        true_mask = true_mask.cpu().numpy()

    ious = []
    dices = []
    pixel_acc = np.mean(pred_mask == true_mask)

    for cls_1 in range(1,num_classes):
        pred_cls = (pred_mask == cls_1)
        true_cls = (true_mask == cls_1)

        intersection = np.logical_and(pred_cls, true_cls).sum()
        union = np.logical_or(pred_cls, true_cls).sum()
        iou = intersection / union if union > 0 else np.nan

        dice = (2 * intersection) / (pred_cls.sum() + true_cls.sum()) if (pred_cls.sum() + true_cls.sum()) > 0 else np.nan

        ious.append(iou)
        dices.append(dice)

    mean_iou = np.nanmean(ious)
    mean_dice = np.nanmean(dices)

    return {
        "pixel_accuracy": pixel_acc,
        "iou_per_class": ious,
        "mean_iou": mean_iou,
        "dice_per_class": dices,
        "mean_dice": mean_dice
    }

In [4]:
path = "/gladstone/finkbeiner/steve/work/data/npsad_data/monika/ALS/dino_v2_segmentation_may30.pth"  ## with new training Sep 12, 2025
path = "/gladstone/finkbeiner/steve/work/data/npsad_data/monika/ALS/dino_v2_segmentation_oct10.pth"
path =  "/gladstone/finkbeiner/steve/work/data/npsad_data/monika/ALS/dino_v2_segmentation_oct17.pth"

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
backbone = get_dino_v2_backbone()
model = DINOv2SegmentationModel(backbone, num_classes=3)  # Update `num_classes` as needed
model.load_state_dict(torch.load(path, map_location=device))
model.to(device)
model.eval()  # 🔍 Important for inference

Using cache found in /home/mahirwar/.cache/torch/hub/facebookresearch_dinov2_main
xFormers not available
xFormers not available
  model.load_state_dict(torch.load(path, map_location=device))


DINOv2SegmentationModel(
  (backbone): DinoVisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 768, kernel_size=(14, 14), stride=(14, 14))
      (norm): Identity()
    )
    (blocks): ModuleList(
      (0-11): 12 x NestedTensorBlock(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): MemEffAttention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=768, out_features=768, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (ls1): LayerScale()
        (drop_path1): Identity()
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (act): GELU(approximate='none')
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (drop): Dropout(p=0.0, inplace=False

In [6]:
dataset_test_location =  "/gladstone/finkbeiner/steve/work/data/npsad_data/monika/ALS/all_crops/"
test_folders = glob(os.path.join(dataset_test_location, "*/*/*.png"))

In [7]:
test_imgs_all = []
test_masks_all = []
#for i in range(len(test_folders)):
test_imgs = glob(os.path.join(dataset_test_location,"*/*/image", "*"))
test_masks = glob(os.path.join(dataset_test_location,"*/*/mask", "*"))
test_imgs_all.extend(test_imgs)
test_masks_all.extend(test_masks)
test_imgs_all.sort()
test_masks_all.sort()


In [8]:
len(test_imgs_all)

714

In [9]:
len(test_masks_all)

714

In [40]:
def overlay_masks(image, true_mask, pred_mask, save_path, alpha=0.4):
    """
    Overlay true and predicted segmentation masks on an RGB image.
    
    image: HxWx3 RGB (np.uint8 or float in [0,1])
    true_mask: HxW (int labels, 0 = background, >0 = object id)
    pred_mask: HxW (int labels, same format)
    alpha: transparency for masks
    """
    # Normalize image to [0,1] if needed
    if image.dtype == np.uint8:
        img = image.astype(np.float32) / 255.0
    else:
        img = image.copy()

    H, W = img.shape[:2]

    # Colormaps for GT and Pred
    cmap_true = cm.get_cmap("Set1", np.max(true_mask) + 1)  # distinct colors
    cmap_pred = cm.get_cmap("Set2", np.max(pred_mask) + 1)

    overlay_true = np.zeros((H, W, 4))  # RGBA
    overlay_pred = np.zeros((H, W, 4))

    # Colorize true mask
    if np.max(true_mask) > 0:
        overlay_true = cmap_true(true_mask)  # RGBA in [0,1]
        overlay_true[true_mask == 0] = (0,0,0,0)  # transparent bg
    
    # Colorize pred mask
    if np.max(pred_mask) > 0:
        overlay_pred = cmap_pred(pred_mask)
        overlay_pred[pred_mask == 0] = (0,0,0,0)

    # Combine: overlay GT in red tint, Pred in green tint
    blended_true = (1 - alpha) * img + alpha * overlay_true[..., :3]
    blended_pred = (1 - alpha) * img + alpha * overlay_pred[..., :3]

    # Show side by side
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    axes[0].imshow(img)
    axes[0].set_title("Original")
    axes[0].axis("off")

    axes[1].imshow(blended_true)
    axes[1].set_title("Ground Truth Mask")
    axes[1].axis("off")

    axes[2].imshow(blended_pred)
    axes[2].set_title("Predicted Mask")
    axes[2].axis("off")

    plt.tight_layout()
    #plt.show()
    plt.savefig(save_path)


In [41]:
for img_path, mask_path in zip(test_imgs_all,test_masks_all):
    img, pred_mask = make_prediction(img_path)
    mask = Image.open(mask_path).convert("P").resize((1022, 1022))
    true_mask = np.array(mask)//100
    true_ids = np.unique(true_mask)
    print(true_ids)
    pred_ids = np.unique(pred_mask)
    print(pred_ids)
    if len(pred_ids)>2:
        all_true_masks = (true_mask[np.newaxis] == true_ids[:,  np.newaxis, np.newaxis])
        all_pred_masks = (pred_mask[np.newaxis] == pred_ids[:,  np.newaxis, np.newaxis])
        break
        #save_path = "/gladstone/finkbeiner/steve/work/data/npsad_data/monika/ALS/seg_runs/2025-10-10_09-38-20/segmentation_results/"+img_path.split("/")[-1]
        #overlay_masks(img, all_true_masks[1], all_pred_masks[1],save_path, alpha=0.5)
        


[0 2]
[0 2]
[0 2]
[0 2]
[0 2]
[0 2]
[0 2]
[0 2]
[0 2]
[0 2]
[0 1]
[0 1]
[0 2]
[0 2]
[0 2]
[0 2]
[0 2]
[0 2]
[0 2]
[0 2]
[0 2]
[0 2]
[0 2]
[0 2]
[0 2]
[0 2]
[0 2]
[0 2]
[0 2]
[0 2]
[0 2]
[0 2]
[0 2]
[0 2]
[0 1]
[0 1]
[0 2]
[0 2]
[0 2]
[0 2]
[0 2]
[0 2]
[0 2]
[0 2]
[0 2]
[0 2]
[0 1]
[0 1]
[0 2]
[0 2]
[0 2]
[0 2]
[0 1 2]
[0 1 2]


In [13]:
class_maps= {1:"glial", 2:"tdp43"}

def boxes_overlap(boxA, boxB):
    xA1, yA1, xA2, yA2 = boxA
    xB1, yB1, xB2, yB2 = boxB

    # Compute overlap in each dimension
    x_overlap = max(0, min(xA2, xB2) - max(xA1, xB1))
    y_overlap = max(0, min(yA2, yB2) - max(yA1, yB1))

    # If both overlaps are positive → boxes intersect
    return x_overlap > 0 and y_overlap > 0

total_true = 0
total_matched = 0

for img_path, mask_path in zip(test_imgs_all,test_masks_all):
    img, pred_mask = make_prediction(img_path)
    mask = Image.open(mask_path).convert("P").resize((1022, 1022))
    true_mask = np.array(mask)//100

    class_ids = np.unique(pred_mask)
    class_ids = class_ids[class_ids != 0]
    true_ids = np.unique(true_mask)
    true_ids= true_ids[true_ids != 0]
    
    true_boxes = []
    pred_boxes = []
    count = {'gt':{'glial':0, 'tdp43':0},'pred':{'glial':0, 'tdp43':0}}
    #color_mask = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    for true_id in true_ids:
        binary = (true_mask == true_id).astype(np.uint8)
        contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        sorted_with_area = sorted(
        [(c, cv2.contourArea(c)) for c in contours],
        key=lambda x: x[1],
        reverse=True)
        
        for cnt, area in sorted_with_area:
            if area>2000:
                x, y, w, h = cv2.boundingRect(cnt)
                cv2.rectangle(img, (x, y), (x + w, y + h), (0, 0, 0), 3)
                cv2.putText(img, f"true: {class_maps[true_id]}", (x, y - 5),
                                cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 2)
            
                true_boxes.append([x, y,x + w, y + h])
                print(true_id)
                if true_id==1:
                    count['gt']['glial']+=1
                if true_id==2:
                    count['gt']['tdp43']+=1
    # Draw bounding boxes
    for class_id in class_ids:
        # Create binary mask for this class
        binary = (pred_mask == class_id).astype(np.uint8)

        # Find contours for that class
        contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        
        #contours = list(map(lambda t: t[0], sorted([(contour, cv2.contourArea(contour)) for contour in contours], key=lambda t: -t[1])))

        sorted_with_area = sorted(
        [(c, cv2.contourArea(c)) for c in contours],
        key=lambda x: x[1],
        reverse=True
        
    )
        
        # Draw bounding box for each contour
        for contour, area in sorted_with_area:
            if area>2000:
                #print(area)
                x, y, w, h = cv2.boundingRect(contour)
                pred_boxes.append([x, y,x + w, y + h])
                cv2.rectangle(img, (x, y), (x + w, y + h), (255, 0, 0), 4)
                cv2.putText(img, f"pred: {class_maps[class_id]}", (x, y - 30),
                            cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
                if class_id==1:
                    count['pred']['glial']+=1
                if class_id==2:
                    count['pred']['tdp43']+=1
                
    save_path = "/gladstone/finkbeiner/steve/work/data/npsad_data/monika/ALS/seg_runs/2025-10-10_09-38-20/segmentation_bbox_formatted/"+img_path.split("/")[-1]

    # Display the result
    plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    plt.title("Ground truth and prediction bounding boxes")
    plt.axis('off')
    plt.figtext(0.5, 0.02, 'GT: [glial-'+str(count['gt']['glial'])+ ",tdp43- "+str(count['gt']['tdp43'])+"]"
                 '; pred: [glial-'+str(count['pred']['glial'])+ ", tdp43- "+str(count['pred']['tdp43'])+"]",
            horizontalalignment='center',  # Center the text horizontally
            fontsize=10,
            color='red')    
    #plt.show()
    plt.savefig(save_path)
    plt.close()

    total_true = total_true + len(true_boxes)
    matched = 0
    for true_box in true_boxes:
        for pred_box in pred_boxes:
            intersect = boxes_overlap(true_box, pred_box)
            if intersect==True:
                matched=matched+1
    
    total_matched  =  total_matched + matched
    #break


2
2
2
2
2
2
1
2
2
2
2
2
2
2
2
2
2
1
1
2
2
2
2
2
1
2
2
1
2
2
2
2
2
2
2
2
2
2
1
1
2
2
2
1
1
1
1
1
1
1
1
1
1
1
2
2
1
1
1
1
1
1
2
1
2
2
2
1
2
2
2
2
2
2
2
2
1
1
1
2
2
2
2
2
2
2
2
2
2
2
2
1
1
2
1
2
2
2
1
2
2
2
1
2
1
1
2
2
2
2
1
1
2
2
2
2
2
2
2
1
1
2
2
2
1
2
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
2
1
1
1
1
1
1
1
1
1
1
1
1
1
2
2
1
1
1
2
2
2
2
2
2
1
2
2
2
2
2
2
1
2
2
2
1
1
2
1
2
1
2
1
2
2
1
1
2
1
2
2
2
2
1
1
1
1
1
1
2
2
2
1
1
2
2
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
2
2
1
2
1
2
2
2
1
1
1
1
1
1
1
2
2
1
2
1
1
1
2
2
2
1
2
2
1
1
1
1
1
1
1
2
1
1
2
1
1
1
1
1
1
2
1
2
2
1
2
1
1
1
1
2
2
1
1
1
1
1
1
2
1
1
1
1
1
1
1
2
1
2
1
1
1
1
2
2
2
1
1
1
1
1
1
2
2
1
1
2
1
1
1
2
1
2
1
1
2
1
2
1
1
2
1
1
2
2
2
2
1
2
2
2
2
2
2
2
1
1
2
2
1
1
1
1
1
1
1
1
1
1
2
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2
1
1
2
1
1
2
1
2
1
1
2
1
1
1
2
2
1
1
1
1
2
1
2
1
1
1
1
1
1
1
1
1
1
1
2
1
1
2
1
1
2
1
1


In [14]:
total_matched

650

In [15]:
total_matched

650

In [16]:
total_true

799

In [17]:
total_matched/total_true

0.8135168961201502

In [15]:
total_matched/total_true

0.7683089214380826

In [44]:
total_matched/total_true

0.8135168961201502