In [4]:
# --------------------------------------------------------------------------------
#  Correctness Metrics Library
#--------------------------------------------------------------------------------

"""
This script computes several similarity/distance metrics between feature datasets (e.g. patient vs. synthetic), each with both a custom (manual) and a library implementation so the two can be cross-checked.
It also contains visualizations of all metrics.

Usage:
  - Run `main()` (or the notebook end-to-end) and enter the feature CSV paths when prompted.
  - Alternatively, import the metric functions as a module.
"""


'\nThis script contains several metrics for computing Concurrence metrics between datasets (custom vs. library implementations).\nIt also contains visualization of all metrics.\n\nUsage:\n  - Update `datasets` paths.\n  - Run end-to-end or import functions as module.\n'

In [6]:
import os
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from scipy.spatial.distance import (
    cosine as sk_cosine_distance,
    cityblock as sk_cityblock,
    mahalanobis as sk_mahalanobis,
    jensenshannon
)
from scipy.stats import pearsonr, entropy, wasserstein_distance
from scipy.linalg import sqrtm

from sklearn.decomposition import PCA


In [7]:
# --- User prompts ---
def prompt_paths(label):
    """
    Interactively ask for one or more CSV feature files.

    Keeps asking until the reply contains at least one token and every token
    names an existing file whose name ends in ``.csv`` (case-insensitive).
    Each CSV is expected to hold samples × features.

    Returns the accepted paths as ``Path`` objects, in the order entered.
    """
    while True:
        reply = input(f"Enter paths to {label} feature CSVs (space-separated): ").strip()
        accepted = []
        rejected = []
        for token in reply.split():
            candidate = Path(token)
            if not (candidate.is_file() and token.lower().endswith('.csv')):
                rejected.append(token)
                continue
            accepted.append(candidate)
        if rejected:
            print(f"Invalid files: {', '.join(rejected)}")
            continue
        if accepted:
            return accepted
        print("No valid CSV paths provided.")

def ask_bool(prompt):
    """
    Ask a yes/no question on stdin.

    Returns True only when the trimmed, lower-cased reply is 'y' or 'yes';
    anything else (including an empty reply) counts as False.
    """
    reply = input(prompt + " (y/n): ").strip().lower()
    return reply == 'y' or reply == 'yes'

In [8]:
# --- Manual implementations ---
def cos_manual(a, b):
    """Cosine similarity a·b / (|a|·|b|) computed by hand; 0.0 if either vector has zero norm."""
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
    if not norm_product:
        return 0.0
    return float(np.dot(a, b) / norm_product)

def pearson_manual(a, b):
    """
    Pearson correlation computed by hand from population moments.

    Covariance and standard deviations both use ddof=0, so their ratio is the
    usual Pearson r. Returns 0.0 when either input is constant (zero std).
    """
    centered_a = a - a.mean()
    centered_b = b - b.mean()
    covariance = np.mean(centered_a * centered_b)
    scale = a.std() * b.std()
    if not scale:
        return 0.0
    return float(covariance / scale)

def manhattan_manual(a, b):
    """L1 (city-block) distance: the sum of absolute element-wise differences."""
    abs_diff = np.abs(a - b)
    return float(abs_diff.sum())

def mahalanobis_manual(a, b, inv_cov):
    """Mahalanobis distance sqrt((a-b)ᵀ · inv_cov · (a-b)), given a precomputed inverse covariance."""
    delta = a - b
    quadratic_form = delta @ inv_cov @ delta
    return float(np.sqrt(quadratic_form))

def jsd_manual(p, q, base=2.0):
    """
    Jensen-Shannon divergence computed by hand (default base 2, i.e. bits).

    Both inputs are normalised to sum to one, then
    JSD = ½·KL(p‖m) + ½·KL(q‖m) with m = ½(p + q).
    Returns 0.0 when either input sums to zero.
    """
    p_arr = np.array(p, float)
    q_arr = np.array(q, float)
    p_total = p_arr.sum()
    q_total = q_arr.sum()
    if p_total == 0 or q_total == 0:
        return 0.0
    p_norm = p_arr / p_total
    q_norm = q_arr / q_total
    mixture = 0.5 * (p_norm + q_norm)
    kl_p = entropy(p_norm, mixture, base=base)
    kl_q = entropy(q_norm, mixture, base=base)
    return float(0.5 * (kl_p + kl_q))

def emd_manual(a, b):
    """
    1-D Earth Mover's (Wasserstein-1) distance between two empirical samples,
    computed by hand from their empirical CDFs.

    W1 = ∫ |F_a(x) - F_b(x)| dx, evaluated exactly as a sum over the intervals
    between consecutive pooled sample values, with every sample weighted
    uniformly. It is deliberately implemented without
    ``scipy.stats.wasserstein_distance`` so that it is an independent
    cross-check of ``emd_scipy``.

    Inputs may be any array-like; they are flattened to 1-D.
    Raises ValueError if either input is empty.
    """
    a = np.sort(np.asarray(a, dtype=float).ravel())
    b = np.sort(np.asarray(b, dtype=float).ravel())
    if a.size == 0 or b.size == 0:
        raise ValueError("emd_manual requires non-empty inputs")
    pooled = np.sort(np.concatenate([a, b]))
    widths = np.diff(pooled)
    # Empirical CDFs evaluated at the left edge of each interval [pooled[k], pooled[k+1]);
    # both CDFs are constant on that interval, so |F_a - F_b| * width integrates it exactly.
    cdf_a = np.searchsorted(a, pooled[:-1], side='right') / a.size
    cdf_b = np.searchsorted(b, pooled[:-1], side='right') / b.size
    return float(np.sum(np.abs(cdf_a - cdf_b) * widths))

# --- Library wrappers ---
def cos_scipy(a, b):
    """Cosine similarity via SciPy: one minus SciPy's cosine *distance*."""
    distance = sk_cosine_distance(a, b)
    return float(1 - distance)

def pearson_scipy(a, b):
    """Pearson r via ``scipy.stats.pearsonr``; the p-value is discarded."""
    result = pearsonr(a, b)
    r_value = result[0]
    return float(r_value)

def manhattan_scipy(a, b):
    """L1 (city-block) distance via ``scipy.spatial.distance.cityblock``."""
    distance = sk_cityblock(a, b)
    return float(distance)

def mahalanobis_scipy(a, b, inv_cov):
    """Mahalanobis distance via ``scipy.spatial.distance.mahalanobis`` (expects the *inverse* covariance)."""
    distance = sk_mahalanobis(a, b, inv_cov)
    return float(distance)

def jsd_scipy(p, q):
    """
    Jensen-Shannon *divergence* (base 2) via SciPy.

    ``scipy.spatial.distance.jensenshannon`` returns the JS *distance* (the
    square root of the divergence), so the result is squared to be directly
    comparable with ``jsd_manual``.

    Accepts any array-like (like ``jsd_manual``). Returns 0.0 when either input
    sums to zero, matching ``jsd_manual``'s convention; SciPy itself would
    return NaN there (0/0 during normalisation).
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    p_total, q_total = p.sum(), q.sum()
    if p_total == 0 or q_total == 0:
        return 0.0
    return float(jensenshannon(p / p_total, q / q_total, base=2.0) ** 2)

def emd_scipy(a, b):
    """1-D Earth Mover's (Wasserstein-1) distance via ``scipy.stats.wasserstein_distance``."""
    distance = wasserstein_distance(a, b)
    return float(distance)

def fid(a, b, eps=1e-6):
    """
    Fréchet distance between Gaussians fitted to two sample sets (FID formula).

    FID = ||mu_a - mu_b||² + Tr(S_a + S_b - 2·sqrt(S_a·S_b)),
    with means and covariances estimated over rows (samples).

    Parameters
    ----------
    a, b : array-like, shape (n_samples, n_features) or (n_samples,)
        A 1-D input is treated as n samples of a single feature.
    eps : float
        Diagonal offset added to both covariances if the matrix square root is
        not finite (near-singular covariances); ignored otherwise.

    Returns
    -------
    float
    """
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    if a.ndim == 1:
        a = a[:, None]
    if b.ndim == 1:
        b = b[:, None]
    mu1, mu2 = a.mean(0), b.mean(0)
    # np.cov squeezes a single-feature result to 0-d, which sqrtm rejects.
    s1 = np.atleast_2d(np.cov(a, rowvar=False))
    s2 = np.atleast_2d(np.cov(b, rowvar=False))
    diff = mu1 - mu2
    covmean = sqrtm(s1.dot(s2))
    if not np.isfinite(covmean).all():
        # Retry with a small ridge so the product is well-conditioned.
        offset = eps * np.eye(s1.shape[0])
        covmean = sqrtm((s1 + offset).dot(s2 + offset))
    if np.iscomplexobj(covmean):
        # sqrtm may return tiny imaginary parts from numerical error.
        covmean = covmean.real
    return float(diff.dot(diff) + np.trace(s1 + s2 - 2 * covmean))

# --- Compute metrics strategies ---
def compute_metrics_nosample(a, b):
    """
    Score two datasets (rows = samples, columns = features) as given.

    No rows are dropped. Every metric except FID is evaluated on the per-feature
    mean vectors (centroids) of ``a`` and ``b``, in both manual and SciPy form.
    FID uses the full arrays. Mahalanobis uses the pseudo-inverse of ``a``'s
    covariance, with 1e-6 added on the diagonal.

    Returns a dict of '<Metric>_manual' / '<Metric>_scipy' pairs plus 'FID'.
    """
    centroid_a = a.mean(0)
    centroid_b = b.mean(0)
    ridge = 1e-6 * np.eye(a.shape[1])
    precision = np.linalg.pinv(np.cov(a, rowvar=False) + ridge)

    # (manual key, scipy key, manual fn, scipy fn, extra positional args)
    metric_table = [
        ('Cosine_manual', 'Cosine_scipy', cos_manual, cos_scipy, ()),
        ('Pearson_manual', 'Pearson_scipy', pearson_manual, pearson_scipy, ()),
        ('Manhattan_manual', 'Manhattan_scipy', manhattan_manual, manhattan_scipy, ()),
        ('Mahalanobis_manual', 'Mahalanobis_scipy', mahalanobis_manual, mahalanobis_scipy, (precision,)),
        ('JSD_manual', 'JSD_scipy', jsd_manual, jsd_scipy, ()),
        ('EMD_manual', 'EMD_scipy', emd_manual, emd_scipy, ()),
    ]
    scores = {}
    for manual_key, scipy_key, manual_fn, scipy_fn, extra in metric_table:
        scores[manual_key] = manual_fn(centroid_a, centroid_b, *extra)
        scores[scipy_key] = scipy_fn(centroid_a, centroid_b, *extra)
    scores['FID'] = fid(a, b)
    return scores

def compute_metrics_subsample(a, b):
    """
    Equalise row counts with a single random draw, then score.

    If ``a`` and ``b`` already have the same number of rows, they are scored
    unchanged. Otherwise both are reduced, without replacement, to the smaller
    row count using NumPy's global RNG (not seeded here). The result is
    ``compute_metrics_nosample`` on the reduced arrays.
    """
    rows_a, rows_b = a.shape[0], b.shape[0]
    if rows_a == rows_b:
        return compute_metrics_nosample(a, b)
    target = min(rows_a, rows_b)
    pick_a = np.random.choice(rows_a, target, replace=False)
    pick_b = np.random.choice(rows_b, target, replace=False)
    return compute_metrics_nosample(a[pick_a], b[pick_b])


def compute_metrics_bootstrap(a, b, reps=5):
    """
    Average metrics over ``reps`` independent random subsamples.

    Each repetition draws ``min(len(a), len(b))`` rows from each array without
    replacement (NumPy global RNG, not seeded here) and scores the pair with
    ``compute_metrics_nosample``. Returns the per-metric mean as a dict of floats.

    Note: the smaller array is only permuted in each repetition. If both
    arrays have the same number of rows, every repetition sees the full data.
    ``compute_metrics_nosample`` only uses means and covariances, so the
    repetitions then add no variation.

    Raises ValueError if ``reps`` < 1.
    """
    if reps < 1:
        raise ValueError(f"reps must be >= 1, got {reps}")
    n = min(a.shape[0], b.shape[0])
    runs = []
    for _ in range(reps):
        idx_a = np.random.choice(a.shape[0], n, replace=False)
        idx_b = np.random.choice(b.shape[0], n, replace=False)
        runs.append(compute_metrics_nosample(a[idx_a], b[idx_b]))
    # float() keeps the value types consistent with the other strategies.
    return {key: float(np.mean([run[key] for run in runs])) for key in runs[0]}

# --- Pairwise & annotation ---
def pairwise(features, self_compare=False, metric_fcn=None):
    """
    Compute metrics for pairs of datasets drawn from one group.

    Parameters
    ----------
    features : dict
        Maps dataset name -> array; pairs follow the dict's iteration order.
    self_compare : bool
        False: only pairs (i, j) with i < j (upper triangle, no diagonal).
        True: every ordered pair, including (name, name) and both (i, j) and (j, i).
    metric_fcn : callable, optional
        ``metric_fcn(a, b) -> dict`` of metric values. Defaults to the
        module-level ``METRIC_FCN`` (set by ``main()``). Passing it explicitly
        allows use as a library without running ``main()``.

    Returns
    -------
    pandas.DataFrame
        One row per pair: the metric columns, then 'A' and 'B' (the names).

    Raises
    ------
    NameError
        If a pair must be scored but no metric function is available.
    """
    fn = metric_fcn if metric_fcn is not None else globals().get('METRIC_FCN')
    names = list(features)
    rows = []
    for i, ni in enumerate(names):
        for j, nj in enumerate(names):
            if j <= i and not self_compare:
                continue
            if fn is None:
                raise NameError("No metric function available: pass metric_fcn=... "
                                "or set METRIC_FCN (main() does this).")
            # Build a new dict rather than mutating the one metric_fcn returned.
            rows.append({**fn(features[ni], features[nj]), 'A': ni, 'B': nj})
    return pd.DataFrame(rows)

def annotate_heatmap(im, ax=None, data=None, fmt=".2f", text_colors=("black","white"), threshold=None):
    """
    Write each cell's value as centred text on top of an ``imshow`` heatmap.

    Parameters
    ----------
    im : image returned by ``ax.imshow``; only used to fetch ``data`` when
        ``data`` is None.
    ax : axes to draw on (defaults to ``plt.gca()``).
    data : 2-D array of values to print (defaults to ``im.get_array()``).
    fmt : number format. Either a ``format()`` spec such as ".2f" (the default)
        or an old-style %-format such as "%.2f".
    text_colors : (colour for values <= threshold, colour for values > threshold).
    threshold : value above which the second colour is used (defaults to
        ``np.nanmax(data) / 2``).

    Non-finite cells are labelled 'N/A' in the first colour.
    """
    if ax is None:
        ax = plt.gca()
    if data is None:
        data = im.get_array()
    if threshold is None:
        threshold = np.nanmax(data) / 2

    def _format(val):
        # The default ".2f" is a format() spec; applying it with '%' raised
        # TypeError. Fall back to %-formatting for specs like "%.2f".
        try:
            return format(val, fmt)
        except ValueError:
            return fmt % val

    for i in range(data.shape[0]):
        for j in range(data.shape[1]):
            val = data[i, j]
            finite = np.isfinite(val)
            txt = _format(val) if finite else 'N/A'
            col = text_colors[1] if finite and val > threshold else text_colors[0]
            ax.text(j, i, txt, ha='center', va='center', color=col)

In [9]:
# --- Main routine ---
def _load_feature_set(paths, taken):
    """
    Read CSVs into float arrays keyed by a unique dataset name.

    Only numeric columns are kept, because every metric needs numeric arrays.
    Any other columns (e.g. IDs or filenames) are reported and dropped. The key
    is the file stem. If the stem is already in ``taken`` (another file with
    the same name, possibly in the other group), a numeric suffix is appended
    so that no dataset is silently overwritten. ``taken`` is updated in place.

    Raises ValueError if a CSV has no numeric columns.
    """
    loaded = {}
    for path in paths:
        frame = pd.read_csv(path)
        numeric = frame.select_dtypes(include='number')
        dropped = [col for col in frame.columns if col not in numeric.columns]
        if dropped:
            print(f"{path}: ignoring non-numeric columns {dropped}")
        if numeric.shape[1] == 0:
            raise ValueError(f"{path} has no numeric feature columns")
        name, suffix = path.stem, 2
        while name in taken:
            name = f"{path.stem}_{suffix}"
            suffix += 1
        if name != path.stem:
            print(f"Dataset name '{path.stem}' already used; loading {path} as '{name}'")
        taken.add(name)
        loaded[name] = numeric.to_numpy(dtype=float)
    return loaded


def _check_feature_dims(datasets):
    """Raise ValueError unless all datasets have the same number of feature columns."""
    widths = {name: arr.shape[1] for name, arr in datasets.items()}
    if len(set(widths.values())) > 1:
        raise ValueError(f"All datasets must have the same number of features; got {widths}")


def main():
    """
    Interactive driver: prompt for inputs, compute metrics, save tables and plots.

    The prompts, in order, ask for:
      - the patient and synthetic CSVs;
      - the sampling strategy, which sets the module-level ``METRIC_FCN``
        used here and by ``pairwise``;
      - which comparisons to run;
      - the output directory;
      - optional raw-feature PCA and histograms.

    It then writes per-pair metric tables, plus optional bar plots, heatmaps
    and metric histograms, under the output directory.
    """
    print("*** Correctness Metrics Interactive Runner ***")

    # Prompt for feature files
    patient_paths   = prompt_paths('patient')
    synthetic_paths = prompt_paths('synthetic')

    # Choose sampling method
    method = input("Choose sampling method: 1) No sampling 2) Subsample once 3) Bootstrap aggregate: ").strip()
    global METRIC_FCN
    if method == '1':
        METRIC_FCN = compute_metrics_nosample
    elif method == '2':
        METRIC_FCN = compute_metrics_subsample
    elif method == '3':
        METRIC_FCN = compute_metrics_bootstrap
    else:
        print("Invalid choice, defaulting to subsample once.")
        METRIC_FCN = compute_metrics_subsample

    # Comparison choices
    do_pvss = ask_bool("Compare patient vs synthetic?")
    do_wp   = ask_bool("Within patients?")
    do_ws   = ask_bool("Within synthetics?")

    # Output directory
    out_dir = Path(input("Output dir [results]: ").strip() or "results")
    out_dir.mkdir(parents=True, exist_ok=True)

    # Load features. Names are unique across BOTH groups so that no dataset is
    # overwritten and the name-based group filters used for heatmaps stay correct.
    taken = set()
    patients   = _load_feature_set(patient_paths, taken)
    synthetics = _load_feature_set(synthetic_paths, taken)
    all_sets   = {**patients, **synthetics}
    _check_feature_dims(all_sets)

    # Optional raw PCA plot
    if ask_bool("Plot raw features via PCA?"):
        all_data = np.vstack(list(all_sets.values()))
        pca = PCA(n_components=2)
        pcs = pca.fit_transform(all_data)
        fig, ax = plt.subplots()
        # Rows of `pcs` follow the stacking order, so each dataset is a contiguous slice.
        start = 0
        for ds, arr in all_sets.items():
            stop = start + arr.shape[0]
            ax.scatter(pcs[start:stop, 0], pcs[start:stop, 1], label=ds, s=10)
            start = stop
        ax.legend(); ax.set_title("Raw feature PCA")
        fig.tight_layout(); fig.savefig(out_dir/'raw_pca.png'); plt.close(fig)
        print(f"Saved raw PCA: {out_dir/'raw_pca.png'}")

    # Optional raw feature histograms
    if ask_bool("Plot raw feature histograms?"):
        for ds, arr in all_sets.items():
            fig, ax = plt.subplots()
            ax.hist(arr.flatten(), bins=50, edgecolor='black')
            ax.set_title(f"Feature histogram: {ds}")
            ax.set_xlabel('Value'); ax.set_ylabel('Count')
            fig.tight_layout(); fig.savefig(out_dir/f"{ds}_feature_hist.png"); plt.close(fig)
            print(f"Saved feature histogram: {out_dir/f'{ds}_feature_hist.png'}")

    # Compute metrics summary
    dfs = []
    if do_pvss:
        for pn, pa in patients.items():
            for sn, sa in synthetics.items():
                df = pd.DataFrame([METRIC_FCN(pa, sa)])
                df['A'], df['B'] = pn, sn
                dfs.append(df)
    if do_wp:
        dfs.append(pairwise(patients, self_compare=False))
    if do_ws:
        dfs.append(pairwise(synthetics, self_compare=False))
    if not dfs:
        print("No comparisons selected. Exiting.")
        return
    summary = pd.concat(dfs, ignore_index=True)

    # Prepare output dirs
    metrics      = [c for c in summary.columns if c not in ('A','B')]
    tables_dir   = out_dir / 'tables';   tables_dir.mkdir(exist_ok=True)
    heatmap_dir  = out_dir / 'heatmaps'; heatmap_dir.mkdir(exist_ok=True)
    bar_dir      = out_dir / 'barplots'; bar_dir.mkdir(exist_ok=True)
    hist_dir     = out_dir / 'histograms'; hist_dir.mkdir(exist_ok=True)

    # Save tables (one CSV per compared pair)
    for _, row in summary.iterrows():
        df_row = pd.DataFrame([row[metrics].to_dict()])
        fname  = f"{row.A}_vs_{row.B}.csv"
        df_row.to_csv(tables_dir / fname, index=False)
        print(f"\nMetrics for {row.A} vs {row.B}:\n{df_row.to_string(index=False)}")

    # Visualization choices: accept comma- and/or space-separated entries.
    print("\nChoose visualization modes:\n1) Bar plots  2) Heatmaps  3) Histograms")
    choices = input("Enter choices (e.g. 1,2,3): ").replace(',', ' ').split()

    # 1) Bar plots
    if '1' in choices:
        for m in metrics:
            pivot = summary.pivot(index='A', columns='B', values=m)
            ax = pivot.plot(kind='bar', title=m)
            ax.set_ylabel(m)
            ax.figure.tight_layout(); ax.figure.savefig(bar_dir/f"{m}_bar.png"); plt.close(ax.figure)
            print(f"Saved bar plot: {bar_dir}/{m}_bar.png")

    # 2) Heatmaps
    if '2' in choices:
        def make_heatmap(df, rows, cols, metric, tag):
            """Save an annotated rows × cols heatmap of one metric (missing pairs show as N/A)."""
            mat = df.pivot(index='A', columns='B', values=metric).reindex(index=rows, columns=cols)
            fig, ax = plt.subplots(); im = ax.imshow(mat.values, interpolation='nearest')
            ax.set_xticks(range(len(cols))); ax.set_xticklabels(cols, rotation=45, ha='right')
            ax.set_yticks(range(len(rows))); ax.set_yticklabels(rows)
            ax.set_title(f"{metric} ({tag})")
            annotate_heatmap(im, ax, data=mat.values)
            fig.tight_layout(); fig.savefig(heatmap_dir/f"{metric}_{tag}_heatmap.png"); plt.close(fig)
            print(f"Saved heatmap: {heatmap_dir}/{metric}_{tag}_heatmap.png")

        if do_pvss:
            df_ps = summary[summary['A'].isin(patients) & summary['B'].isin(synthetics)]
            for m in metrics: make_heatmap(df_ps, list(patients), list(synthetics), m, 'pvss')
        if do_wp:
            df_wp = summary[summary['A'].isin(patients) & summary['B'].isin(patients)]
            for m in metrics: make_heatmap(df_wp, list(patients), list(patients), m, 'within_patients')
        if do_ws:
            df_ws = summary[summary['A'].isin(synthetics) & summary['B'].isin(synthetics)]
            for m in metrics: make_heatmap(df_ws, list(synthetics), list(synthetics), m, 'within_synthetics')

    # 3) Histograms of each metric across all compared pairs
    if '3' in choices:
        for m in metrics:
            vals = summary[m].replace([np.inf, -np.inf], np.nan).dropna()
            if vals.empty:
                print(f"Skipping histogram for {m}")
                continue
            fig, ax = plt.subplots(); ax.hist(vals, bins=10, edgecolor='black')
            ax.set_title(f"Histogram: {m}"); ax.set_xlabel(m); ax.set_ylabel('Count')
            fig.tight_layout(); fig.savefig(hist_dir/f"{m}_hist.png"); plt.close(fig)
            print(f"Saved histogram: {hist_dir}/{m}_hist.png")

if __name__ == '__main__':
    main()


*** Correctness Metrics Interactive Runner ***
Enter paths to patient feature CSVs (space-separated): InBreast_handcrafted.csv MIAS_handcrafted.csv 
Enter paths to synthetic feature CSVs (space-separated): MSYNTH_handcrafted.csv Mammo_medigan_handcrafted.csv
Choose sampling method: 1) No sampling 2) Subsample once 3) Bootstrap aggregate: 1
Compare patient vs synthetic? (y/n): y
Within patients? (y/n): n
Within synthetics? (y/n): n
Output dir [results]: Output
Plot raw features via PCA? (y/n): y
Saved raw PCA: Output/raw_pca.png
Plot raw feature histograms? (y/n): y
Saved feature histogram: Output/InBreast_handcrafted_feature_hist.png
Saved feature histogram: Output/MIAS_handcrafted_feature_hist.png
Saved feature histogram: Output/MSYNTH_handcrafted_feature_hist.png
Saved feature histogram: Output/Mammo_medigan_handcrafted_feature_hist.png

Metrics for InBreast_handcrafted vs MSYNTH_handcrafted:
 Cosine_manual  Cosine_scipy  Pearson_manual  Pearson_scipy  Manhattan_manual  Manhattan_sci