# Distance-of-Distances Measures: Patient Comparison

This notebook extends the permutation-robust comparison by evaluating **all distance-of-distances measures** available in the project. For every patient, EEG band, and experimental phase we:

1. Build ultrametric, linkage, and condensed distance structures.
2. Compute the following measures across every phase pair:
   - Ultrametric Matrix Distance
   - Ultrametric Scaled Distance (log scaling)
   - Ultrametric Rank Correlation (Spearman, converted to a distance as 1 − correlation)
   - Ultrametric Quantile RMSE (log scaling)
   - Ultrametric Distance (Permutation Robust)
   - Robinson–Foulds Distance
   - Cophenetic Correlation (distance form: 1 − correlation)
   - Baker's Gamma (distance form: 1 − gamma)
   - Fowlkes–Mallows Index (distance form: 1 − index)
3. Plot patient-wise heatmaps per band for each measure.


In [None]:
%matplotlib inline

import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

from lrgsglib import *
move_to_rootf(pathname='lrg_eegfc')

from lrg_eegfc.notebook import *
from lrg_eegfc.utils.datamanag.loaders import load_data_dict
from lrg_eegfc.utils.corrmat.structures import compute_structures_for_patient

from lrgsglib.utils.basic.linalg import (
    ultrametric_matrix_distance,
    ultrametric_scaled_distance,
    ultrametric_rank_correlation,
    ultrametric_quantile_rmse,
    ultrametric_distance_permutation_robust,
    tree_robinson_foulds_distance,
    tree_cophenetic_correlation,
    tree_baker_gamma,
    tree_fowlkes_mallows_index,
)

# Every figure produced by this notebook is written below this directory.
path_figs = Path('data', 'figures', 'DSTCMP_all_distance_measures')
path_figs.mkdir(parents=True, exist_ok=True)

# Snapshot the project-level collections (made available by the imports above)
# as plain lists so the rest of the notebook can index and iterate them freely.
phase_labels = [*PHASE_LABELS]
bands = [*BRAIN_BANDS_NAMES]
patients = [*PATIENTS_LIST]

# Options forwarded unchanged to compute_structures_for_patient.
correlation_protocol = {'filter_type': 'abs', 'spectral_cleaning': False}

for heading, items in (('Patients', patients), ('Phases', phase_labels), ('Bands', bands)):
    print(f'{heading}: {items}')


In [None]:
# Load the dataset together with the per-patient label map.
data_dict, int_label_map = load_data_dict()

# Keep only the 'label' entry of each patient's map; it is later handed to the
# permutation-robust distance.
pin_labels_by_pat = {}
for pat in patients:
    pin_labels_by_pat[pat] = int_label_map[pat]['label']

print('✓ Data loaded')


In [None]:
# Per-patient containers for the three structures returned by
# compute_structures_for_patient; each is later indexed as [pat][band] and then
# queried per phase.
ultra_by_pat, linkage_by_pat, condensed_by_pat = {}, {}, {}

for pat in patients:
    print(f'Computing ultrametric structures for {pat} ...')
    structures = compute_structures_for_patient(
        data_dict,
        pat,
        int_label_map,
        bands=bands,
        phases=phase_labels,
        correlation_protocol=correlation_protocol,
    )
    # Unpack in the order (ultrametric, linkage, condensed).
    ultra_by_pat[pat], linkage_by_pat[pat], condensed_by_pat[pat] = structures

print('✓ Ultrametric structures computed for all patients')


In [None]:
# Number of experimental phases; used as the side length of every
# phase-by-phase matrix and for the heatmap tick positions.
n_phases = len(phase_labels)


def compute_phase_distance_matrix(pat: str, band: str, diag_value: float, compute_fn):
    """Build the phase-by-phase matrix of a pairwise measure.

    Parameters
    ----------
    pat, band : str
        Patient and band identifiers; only used to label warning messages.
    diag_value : float
        Value written on the diagonal (a phase compared with itself).
    compute_fn : callable
        Called as ``compute_fn(phase_i, phase_j)`` for every ordered pair of
        distinct phases in ``phase_labels``. It may return a number or ``None``.

    Returns
    -------
    numpy.ndarray
        Square float matrix with one row/column per entry of ``phase_labels``.
        An off-diagonal cell is NaN when ``compute_fn`` raised, returned
        ``None``, returned something that cannot be converted to ``float``, or
        returned a non-finite number.
    """
    # Size the matrix from the labels actually iterated, so a stale global
    # counter can never cause an out-of-bounds write.
    n = len(phase_labels)
    dm = np.full((n, n), np.nan, dtype=float)
    for i, pi in enumerate(phase_labels):
        for j, pj in enumerate(phase_labels):
            if i == j:
                dm[i, j] = diag_value
                continue
            try:
                value = compute_fn(pi, pj)
                # Convert inside the try: a non-scalar / non-numeric return is
                # reported like any other failure instead of aborting the run.
                val = np.nan if value is None else float(value)
            except Exception as exc:
                print(f'[WARN] {pat} {band} {pi}->{pj}: {exc}')
                val = np.nan
            # Cells are pre-filled with NaN, so only finite results are written.
            if np.isfinite(val):
                dm[i, j] = val
    return dm


def plot_measure_across_patients(measure_key: str, measure_label: str, cbar_label: str):
    """Plot and save, for each band, one row of per-patient heatmaps of a measure.

    Reads ``measure_results[measure_key][band][pat]`` (phase-by-phase matrices)
    and writes ``<path_figs>/<measure_key>/<band>.png``. All panels of a band
    share the colour scale ``[0, vmax]``, where ``vmax`` is the largest finite
    value across patients. A band whose largest finite value is 0 is skipped
    with a ``[SKIP]`` message.

    Parameters
    ----------
    measure_key : str
        Key into ``measure_results``; also the name of the output sub-directory.
    measure_label : str
        Human-readable name used in the figure title and skip message.
    cbar_label : str
        Label of the shared colour bar.
    """
    outdir = path_figs / measure_key
    outdir.mkdir(parents=True, exist_ok=True)

    # One colormap for all figures; NaN cells are rendered in light gray.
    cmap = plt.cm.viridis.copy()
    cmap.set_bad(color='lightgray')

    for band in bands:
        band_results = measure_results[measure_key][band]

        # Shared upper colour limit: largest finite entry across all patients.
        vmax = 0.0
        for pat in patients:
            M = band_results.get(pat)
            if M is None:
                continue
            M = np.asarray(M, dtype=float)
            finite_values = M[np.isfinite(M)]
            if finite_values.size:
                vmax = max(vmax, float(finite_values.max()))
        if vmax == 0.0:
            print(f'[SKIP] No finite values for {measure_label} / {band}')
            continue

        # squeeze=False always yields a 2-D array of Axes, so the single-patient
        # case needs no special handling.
        fig, axes = plt.subplots(
            1,
            len(patients),
            figsize=(4 * len(patients) + 2, 4),
            constrained_layout=True,
            squeeze=False,
        )
        axes = list(axes.ravel())

        try:
            for ax, pat in zip(axes, patients):
                M = band_results.get(pat)
                if M is None:
                    # Patient without a matrix for this band: draw an all-NaN
                    # (gray) panel instead of failing with a KeyError.
                    M = np.full((n_phases, n_phases), np.nan, dtype=float)
                im = ax.imshow(M, vmin=0.0, vmax=vmax, cmap=cmap, aspect='equal')
                ax.set_title(pat)
                ax.set_xticks(range(n_phases))
                ax.set_yticks(range(n_phases))
                ax.set_xticklabels(phase_labels, rotation=45, ha='right')
                ax.set_yticklabels(phase_labels)
                for spine in ax.spines.values():
                    spine.set_visible(False)

            # Every panel uses the same vmin/vmax, so any image can drive the bar.
            cbar = fig.colorbar(im, ax=axes, shrink=0.85)
            cbar.set_label(cbar_label)
            fig.suptitle(f'{measure_label} — {band} band', fontsize=14)

            outfile = outdir / f'{band}.png'
            fig.savefig(outfile, dpi=200, bbox_inches='tight')
            plt.show()
            print(f'Saved: {outfile}')
        finally:
            # Release the figure; this function creates one per band and is
            # called once per measure.
            plt.close(fig)


# measure_results[measure_key][band][pat] -> phase-by-phase matrix returned by
# compute_phase_distance_matrix; filled by the measure cells below.
measure_results = {}
# plot_specs[measure_key] -> dict with 'label' (figure title) and 'cbar'
# (colour-bar label); consumed by the final plotting loop.
plot_specs = {}


In [None]:
# Ultrametric-based measures
def _run_ultrametric_measure(measure_key, label, cbar, log_name, pair_fn):
    """Evaluate one ultrametric-matrix measure for every band and patient.

    Registers ``plot_specs[measure_key]`` and fills
    ``measure_results[measure_key][band][pat]`` with the phase-by-phase matrix.

    Parameters
    ----------
    measure_key : str
        Key under which results and plot metadata are stored.
    label, cbar : str
        Figure title and colour-bar label stored in ``plot_specs``.
    log_name : str
        Name used in the progress messages.
    pair_fn : callable
        Called as ``pair_fn(U1, U2)`` with the two phases' entries from
        ``ultra_by_pat[pat][band]``; only invoked when both are present and
        have equal ``.shape``.
    """
    measure_results[measure_key] = {band: {} for band in bands}
    plot_specs[measure_key] = dict(label=label, cbar=cbar)
    for band in bands:
        for pat in patients:
            U_band = ultra_by_pat[pat][band]

            # Bind U_band as a default so the closure keeps this iteration's value.
            def compute_fn(pi, pj, U_band=U_band):
                U1 = U_band.get(pi)
                U2 = U_band.get(pj)
                if U1 is None or U2 is None or U1.shape != U2.shape:
                    return np.nan
                return pair_fn(U1, U2)

            measure_results[measure_key][band][pat] = compute_phase_distance_matrix(pat, band, 0.0, compute_fn)
        print(f'Band {band}: ✓ {log_name} computed')
    print(f'✓ {log_name} complete')


def _spearman_rank_distance(U1, U2):
    """Return ``1 - spearman correlation``; a ``None`` correlation becomes NaN."""
    corr = ultrametric_rank_correlation(U1, U2, method='spearman')
    return 1.0 - corr if corr is not None else np.nan


_run_ultrametric_measure(
    'ultrametric_matrix_distance',
    label='Ultrametric Matrix Distance',
    cbar='Ultrametric matrix distance (euclidean)',
    log_name='Ultrametric Matrix Distance',
    pair_fn=lambda U1, U2: ultrametric_matrix_distance(U1, U2, metric='euclidean'),
)

_run_ultrametric_measure(
    'ultrametric_scaled_distance_log',
    label='Ultrametric Scaled Distance (log)',
    cbar='Ultrametric scaled distance (log)',
    log_name='Ultrametric Scaled Distance',
    pair_fn=lambda U1, U2: ultrametric_scaled_distance(
        U1, U2, metric='euclidean', scale='log', normalize=True
    ),
)

_run_ultrametric_measure(
    'ultrametric_rank_correlation_spearman',
    label='Ultrametric Rank Correlation (1 - spearman)',
    cbar='Distance (1 - spearman correlation)',
    log_name='Ultrametric Rank Correlation',
    pair_fn=_spearman_rank_distance,
)

_run_ultrametric_measure(
    'ultrametric_quantile_rmse_log',
    label='Ultrametric Quantile RMSE (log)',
    cbar='Ultrametric quantile RMSE (log scale)',
    log_name='Ultrametric Quantile RMSE',
    pair_fn=lambda U1, U2: ultrametric_quantile_rmse(U1, U2, scale='log'),
)


In [None]:
# Permutation-robust ultrametric distance
measure_key = 'ultrametric_distance_permutation_robust'
measure_results[measure_key] = {name: {} for name in bands}
plot_specs[measure_key] = {
    'label': 'Ultrametric Distance (Permutation Robust)',
    'cbar': 'Permutation-robust distance (euclidean)',
}


def _robust_pair_distance(pi, pj, linkages, condensed, labels):
    """Permutation-robust distance between phases ``pi`` and ``pj``.

    ``linkages`` and ``condensed`` are the per-phase mappings of one patient
    and band; NaN is returned when either phase is missing from either mapping.
    """
    inputs = (linkages.get(pi), linkages.get(pj), condensed.get(pi), condensed.get(pj))
    if any(item is None for item in inputs):
        return np.nan
    Z1, Z2, D1, D2 = inputs
    return ultrametric_distance_permutation_robust(Z1, Z2, D1, D2, labels, metric='euclidean')


for band in bands:
    for pat in patients:
        # Freeze this patient's inputs in a default argument so the callback
        # does not depend on the loop variables at call time.
        frozen = (linkage_by_pat[pat][band], condensed_by_pat[pat][band], pin_labels_by_pat[pat])

        def phase_pair_distance(pi, pj, _frozen=frozen):
            return _robust_pair_distance(pi, pj, *_frozen)

        measure_results[measure_key][band][pat] = compute_phase_distance_matrix(
            pat, band, 0.0, phase_pair_distance
        )
    print(f'Band {band}: ✓ Permutation-robust distance computed')
print('✓ Permutation-robust distance complete')


In [None]:
# Tree-based measures (converted to distances)
def _as_distance(similarity_fn):
    """Wrap a similarity ``f(Z1, Z2)`` so it returns ``1 - f``; ``None`` becomes NaN."""
    def distance(Z1, Z2):
        score = similarity_fn(Z1, Z2)
        return 1.0 - score if score is not None else np.nan
    return distance


def _run_tree_measure(measure_key, label, cbar, log_name, pair_fn):
    """Evaluate one linkage-based measure for every band and patient.

    Registers ``plot_specs[measure_key]`` and fills
    ``measure_results[measure_key][band][pat]`` with the phase-by-phase matrix.

    Parameters
    ----------
    measure_key : str
        Key under which results and plot metadata are stored.
    label, cbar : str
        Figure title and colour-bar label stored in ``plot_specs``.
    log_name : str
        Name used in the progress messages.
    pair_fn : callable
        Called as ``pair_fn(Z1, Z2)`` with the two phases' entries from
        ``linkage_by_pat[pat][band]``; only invoked when both are present.
    """
    measure_results[measure_key] = {band: {} for band in bands}
    plot_specs[measure_key] = dict(label=label, cbar=cbar)
    for band in bands:
        for pat in patients:
            Z_band = linkage_by_pat[pat][band]

            # Bind Z_band as a default so the closure keeps this iteration's value.
            def compute_fn(pi, pj, Z_band=Z_band):
                Z1 = Z_band.get(pi)
                Z2 = Z_band.get(pj)
                if Z1 is None or Z2 is None:
                    return np.nan
                return pair_fn(Z1, Z2)

            measure_results[measure_key][band][pat] = compute_phase_distance_matrix(pat, band, 0.0, compute_fn)
        print(f'Band {band}: ✓ {log_name} computed')
    print(f'✓ {log_name} distance complete')


_run_tree_measure(
    'tree_robinson_foulds',
    label='Robinson–Foulds Distance',
    cbar='Robinson–Foulds distance (normalized)',
    log_name='Robinson–Foulds',
    pair_fn=lambda Z1, Z2: tree_robinson_foulds_distance(Z1, Z2, normalized=True),
)

_run_tree_measure(
    'tree_cophenetic_distance',
    label='Cophenetic Correlation (distance)',
    cbar='Distance (1 - cophenetic correlation)',
    log_name='Cophenetic correlation',
    pair_fn=_as_distance(tree_cophenetic_correlation),
)

_run_tree_measure(
    'tree_baker_gamma_distance',
    label="Baker's Gamma (distance)",
    cbar="Distance (1 - Baker's gamma)",
    log_name="Baker's gamma",
    pair_fn=_as_distance(tree_baker_gamma),
)

_run_tree_measure(
    'tree_fowlkes_mallows_distance',
    label='Fowlkes–Mallows Index (distance)',
    cbar='Distance (1 - Fowlkes–Mallows)',
    log_name='Fowlkes–Mallows',
    pair_fn=_as_distance(tree_fowlkes_mallows_index),
)


In [None]:
# Render one set of per-band figures for every registered measure, in the
# order the measures were added to plot_specs.
for measure_key in plot_specs:
    title = plot_specs[measure_key]['label']
    colorbar_text = plot_specs[measure_key]['cbar']
    print(f"\n=== {title} ===")
    plot_measure_across_patients(measure_key, title, colorbar_text)

print('✓ All figures generated')
