# RKHS Loss Demonstration

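Shows that the RKHS loss is zero when a persistence diagram is compared with itself, and strictly positive when the diagrams of two topologically different masks (two separate disks vs. a single disk) are compared.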

In [None]:
import sys
from pathlib import Path
import os

# Locate the project root: the nearest of the working directory and its
# first three ancestors that contains a ``src`` package. If none does,
# fall back to two levels above the working directory.
start_dir = Path(os.getcwd()).resolve()
candidate_dirs = [start_dir, *start_dir.parents][:4]
project_root = next(
    (
        candidate
        for candidate in candidate_dirs
        if (candidate / "src" / "__init__.py").exists()
    ),
    start_dir.parent.parent,
)

# Make ``src`` importable without installing the project.
root_str = str(project_root)
if root_str not in sys.path:
    sys.path.insert(0, root_str)

import numpy as np
import gudhi
from scipy.ndimage import distance_transform_edt
from src import TopologicalRKHSLoss, gudhi_persistence_to_vpd_vector

def compute_persistence(mask):
    """Compute cubical persistence of a mask via its signed distance transform.

    The mask is binarized at 0.5, and a signed Euclidean distance is built
    from it: positive inside the foreground, negative in the background.
    That distance is rescaled to [0, 1] when it is not constant, then
    inverted (``1 - dist``). Foreground interiors therefore receive the
    lowest filtration values.

    Parameters
    ----------
    mask : array_like
        Values greater than 0.5 are treated as foreground.

    Returns
    -------
    Whatever ``gudhi.CubicalComplex.persistence()`` returns for the cubical
    complex whose top-dimensional cells carry the filtration values.

    NOTE(review): for a uniform mask (all foreground or all background), one
    of the two distance transforms has no background element to measure to.
    Confirm the intended behaviour for that case.
    """
    foreground = np.asarray(mask) > 0.5
    # Distance from each foreground pixel to the nearest background pixel,
    # and from each background pixel to the nearest foreground pixel.
    dist_fg = distance_transform_edt(foreground)
    dist_bg = distance_transform_edt(~foreground)
    dist = dist_fg - dist_bg

    dmin, dmax = dist.min(), dist.max()
    if dmax > dmin:
        dist = (dist - dmin) / (dmax - dmin)
    filtration = 1.0 - dist

    # GUDHI reads a flat ``top_dimensional_cells`` sequence with the FIRST
    # coordinate varying fastest (Fortran order). A C-order flatten combined
    # with ``dimensions=mask.shape`` scrambles non-square images. Ravel in
    # Fortran order so every value lands on its own pixel.
    cubical = gudhi.CubicalComplex(
        dimensions=list(foreground.shape),
        top_dimensional_cells=filtration.ravel(order="F"),
    )
    return cubical.persistence()

# Grid resolution passed to both the loss and the diagram vectorizer.
grid_size = 50
# random_state is fixed; presumably it seeds any randomness inside the loss.
loss_fn = TopologicalRKHSLoss(grid_size=grid_size, random_state=14)

# First mask: two separate filled disks on a 64x64 canvas
# (centre (20, 20), radius 10, and centre (44, 44), radius 12).
y, x = np.ogrid[:64, :64]
disk_small = (x - 20) ** 2 + (y - 20) ** 2 <= 100
disk_large = (x - 44) ** 2 + (y - 44) ** 2 <= 144
mask1 = np.logical_or(disk_small, disk_large).astype(np.float64)

# Vectorize the H0 and H1 diagrams separately, then add them into one vector.
pairs1 = compute_persistence(mask1)
vpd1_h0 = gudhi_persistence_to_vpd_vector(pairs1, grid_size=grid_size, dimension=0)
vpd1_h1 = gudhi_persistence_to_vpd_vector(pairs1, grid_size=grid_size, dimension=1)
vpd1 = vpd1_h0 + vpd1_h1

In [8]:
# A diagram compared with itself: the vectorized difference should be all
# zeros, so the loss should vanish.
diff_identical = vpd1 - vpd1
loss_identical = loss_fn(diff_identical)
print(f"Loss for identical diagrams: {loss_identical:.6f}")
# Check the demo's claim so a regression fails loudly on Restart & Run All.
assert np.isclose(float(loss_identical), 0.0, atol=1e-6), loss_identical

Loss for identical diagrams: 0.000000


In [9]:
# Second mask: one filled disk of radius 20 centred at (32, 32). Unlike the
# two separate disks of mask1, it has a single connected component.
y, x = np.ogrid[:64, :64]
mask2 = np.zeros((64, 64), dtype=np.float64)
mask2[(x - 32)**2 + (y - 32)**2 <= 400] = 1

pairs2 = compute_persistence(mask2)
# Same vectorization as vpd1: H0 and H1 vectors added into one vector.
vpd2 = (
    gudhi_persistence_to_vpd_vector(pairs2, grid_size=grid_size, dimension=0)
    + gudhi_persistence_to_vpd_vector(pairs2, grid_size=grid_size, dimension=1)
)

diff_different = vpd2 - vpd1
loss_different = loss_fn(diff_different)
print(f"Loss for different diagrams: {loss_different:.6f}")
# Diagrams of topologically different masks should give a strictly positive loss.
assert float(loss_different) > 0.0, loss_different

Loss for different diagrams: 1.641248
