# Imports and Helper Functions

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import cv2
import os
import glob
import torch
import torchmetrics as tm

# Betti-Matching
# Make the Betti-Matching checkout (expected at ./Betti-Matching relative to the
# current working directory) importable; the star import is where the
# `BettiMatching` class used in the Betti evaluation cell comes from.
import sys
sys.path.append('./Betti-Matching')
from BettiMatching import *  # NOTE(review): star import — presumably only BettiMatching is needed; consider importing it explicitly.

# Device on which the torchmetrics computations run: GPU when available, else CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Dice and IoU
def get_metrics(pred, gt):
    """Compute the Dice score and the binary IoU (Jaccard index) of a predicted mask.

    Args:
        pred: Predicted segmentation mask as a NumPy array (the notebook passes
            boolean masks).
        gt: Ground-truth mask as a NumPy array; presumably the same shape as
            ``pred`` — torchmetrics requires matching shapes.

    Returns:
        A ``(dice, iou)`` tuple of Python floats.
    """
    # Convert and move each mask to the device once, instead of once per metric.
    pred_t = torch.from_numpy(pred).to(device)
    gt_t = torch.from_numpy(gt).to(device)
    # ignore_index=0 excludes class 0 (background) from the Dice computation.
    dice = tm.functional.dice(pred_t, gt_t, ignore_index=0)
    iou = tm.functional.classification.binary_jaccard_index(pred_t, gt_t)
    return dice.item(), iou.item()
    

# Evaluation


In [None]:
DATA_DIR = './test_data'
test_files = glob.glob(os.path.join(DATA_DIR, '*.png'))
test_files = [os.path.basename(f).replace('.png', '') for f in test_files]
skip_files = [
    # Only background pixels in these images
    '13_009','13_008','24_005','13_004','17_004','17_016','17_009','17_014','13_013',
    # These images do not work with mitometer
    '6_013','29_004','14_005','2_013','21_012','6_003','21_016','2_004','14_001','9_013',
    '21_013','2_014','2_008','2_003','21_008','24_016','14_011','2_009','24_012','6_004',
    '6_014','21_009','9_004','29_013','14_012','21_014','24_008','21_004'
]
test_files = [f for f in test_files if f not in skip_files]  # Filtering out files that cause issues in evaluation

In [None]:
PREDICTIONS_PATH = './Prediction'
GT_PATH = './test_gt'
# Prediction files use '-' where the input images and ground truth use '_'
# (e.g. stem '13_009' -> './Prediction/13-009.tif'). Every path is built from the
# file stem instead of string-replacing whole paths, which would also rewrite any
# directory name containing '_', '-', 'tif' or 'test_gt'.
pred_files = [os.path.join(PREDICTIONS_PATH, f"{stem.replace('_', '-')}.tif") for stem in test_files]
gt_files = [os.path.join(GT_PATH, f'{stem}.png') for stem in test_files]
img_files = [os.path.join(DATA_DIR, f'{stem}.png') for stem in test_files]


### Dice + IoU

In [None]:
# Visualize the best examples: every image whose Dice score exceeds
# DICE_DISPLAY_THRESHOLD is shown next to its ground truth and prediction.
DICE_DISPLAY_THRESHOLD = 0.9

for pred_file, gt_file, img_file in zip(pred_files, gt_files, img_files):
    pred_raw = cv2.imread(pred_file, cv2.IMREAD_UNCHANGED)
    gt_raw = cv2.imread(gt_file, cv2.IMREAD_UNCHANGED)
    img_raw = cv2.imread(img_file, cv2.IMREAD_UNCHANGED)
    # cv2.imread returns None (rather than raising) for missing/unreadable files;
    # fail with a clear message instead of a TypeError further down.
    for path, raw in ((pred_file, pred_raw), (gt_file, gt_raw), (img_file, img_raw)):
        if raw is None:
            raise FileNotFoundError(f'Could not read image: {path}')

    pred = pred_raw > 0
    # Only the first channel of the ground truth and input image is used
    # (presumably the files are multi-channel PNGs — confirm).
    gt = gt_raw[..., 0] == 255
    img = img_raw[..., 0]

    # Reuse the shared helper instead of duplicating the torchmetrics calls.
    dice, iou = get_metrics(pred, gt)

    if dice > DICE_DISPLAY_THRESHOLD:
        print("Dice:", dice, "IoU:", iou)

        fig = plt.figure(figsize=(30, 10))
        plt.subplot(131)
        plt.title('image')
        plt.imshow(img)
        plt.subplot(132)
        plt.title('gt')
        plt.imshow(gt)
        plt.subplot(133)
        plt.title('pred')
        plt.imshow(pred)
        plt.show()
        # Release the figure so a long loop does not accumulate open figures.
        plt.close(fig)


In [None]:
# Calculate Dice and IoU for all images and save them to csv
preds = []
gts = []
dices = []
ious = []
images = []
ids = []

for i, (pred_file, gt_file, image_file) in enumerate(zip(pred_files, gt_files, img_files)):
    pred_raw = cv2.imread(pred_file, cv2.IMREAD_UNCHANGED)
    if pred_raw is None:
        # cv2.imread signals failure by returning None, not by raising.
        raise FileNotFoundError(f'Could not read prediction: {pred_file}')
    gt_raw = cv2.imread(gt_file, cv2.IMREAD_UNCHANGED)
    if gt_raw is None:
        raise FileNotFoundError(f'Could not read ground truth: {gt_file}')

    pred = pred_raw > 0
    # Only the first channel of the ground truth is used; 255 marks foreground.
    gt = gt_raw[..., 0] == 255

    # Reuse the shared helper instead of duplicating the torchmetrics calls.
    dice, iou = get_metrics(pred, gt)

    ids.append(i)
    preds.append(pred_file)
    gts.append(gt_file)
    images.append(image_file)
    dices.append(dice)
    ious.append(iou)
    print(f"{i+1}/{len(pred_files)}", end='\r')

df = pd.DataFrame({'id': ids, 'pred': preds, 'gt': gts,
                'dice': dices, 'iou': ious, 'image': images})
df.to_csv('neurips25_results.csv', index=False)


In [None]:
# Summary statistics (count, mean, std, quartiles, ...) of the Dice and IoU columns.
df.loc[:, ['dice', 'iou']].describe()

### Betti number error and Betti matching loss

In [None]:
# Calculate the Betti number error and Betti matching loss for all images and save to csv
preds = []
gts = []
images = []
ids = []
# List/column names keep the original 'betty' spelling so the CSV schema and the
# summary cell that reads these columns stay unchanged.
betty_errors = []
betty_matching_losses = []

for i, (pred_file, gt_file, image_file) in enumerate(zip(pred_files, gt_files, img_files)):
    pred_raw = cv2.imread(pred_file, cv2.IMREAD_UNCHANGED)
    if pred_raw is None:
        # cv2.imread signals failure by returning None, not by raising.
        raise FileNotFoundError(f'Could not read prediction: {pred_file}')
    gt_raw = cv2.imread(gt_file, cv2.IMREAD_UNCHANGED)
    if gt_raw is None:
        raise FileNotFoundError(f'Could not read ground truth: {gt_file}')

    # Binarize the prediction exactly as the Dice/IoU cells do (`> 0`). A `== 255`
    # test yields an all-False mask for 0/1-valued prediction files, so the Betti
    # and Dice/IoU metrics could otherwise be computed on different masks.
    pred = pred_raw > 0
    gt = gt_raw[..., 0] == 255

    # NOTE(review): BettiMatching receives boolean masks here; confirm that the
    # 'superlevel' filtration is intended for binary inputs.
    b = BettiMatching(pred, gt, filtration='superlevel')
    betty_errors.append(b.Betti_number_error())
    betty_matching_losses.append(b.loss())

    ids.append(i)
    preds.append(pred_file)
    gts.append(gt_file)
    images.append(image_file)
    print(f"{i+1}/{len(pred_files)}", end='\r')

df = pd.DataFrame({'id': ids, 'pred': preds, 'gt': gts,
                'betty_error': betty_errors, 'betty_matching_loss': betty_matching_losses,
                'image': images})
df.to_csv('betty_neurips25_results.csv', index=False)


In [None]:
# Summary statistics of the Betti number error and Betti matching loss columns.
df.loc[:, ['betty_error', 'betty_matching_loss']].describe()