In [16]:
# Used for generating csv summary files for all model metrics
# from npz files.
#
# For specifically CoMIGHT power, use `make_csv_summary_for_comightpower_from_npz.py`.

from collections import defaultdict
from pathlib import Path

import numpy as np
import pandas as pd

from sklearn.metrics import roc_auc_score

n_dims_2_ = 6


def make_csv_over_nsamples(
    root_dir,
    sim_name,
    n_samples_list,
    n_dims_1,
    n_repeats,
    param_name,
    model_name,
    n_dims_2=None,
):
    """Collect one metric from per-run npz files into a long-format DataFrame,
    sweeping over the number of samples.

    For every ``n_samples`` in ``n_samples_list`` and every ``idx`` in
    ``range(n_repeats)`` this reads::

        root_dir / "output" / model_name / sim_name
            / f"{sim_name}_{n_samples}_{n_dims_1}_{n_dims_2}_{idx}.npz"

    Files that cannot be opened are reported with ``print`` and skipped.

    Parameters
    ----------
    root_dir : pathlib.Path
        Root directory containing the ``output`` directory.
    sim_name : str
        Simulation name used in the directory and file names.
    n_samples_list : iterable of int
        Sample sizes to sweep over.
    n_dims_1 : int
        First-view dimensionality encoded in the file names.
    n_repeats : int
        Number of repeat indices to read per sample size.
    param_name : {"sas98", "cdcorr_pvalue", "pvalue", "cmi", "auc"}
        Metric to extract. ``"auc"`` is computed from the stored
        ``posterior_arr`` and ``y``; the others are read directly by key.
    model_name : str
        Model directory name under ``output``.
    n_dims_2 : int, optional
        Second-view dimensionality encoded in the file names. Defaults to the
        module-level ``n_dims_2_``.

    Returns
    -------
    pandas.DataFrame
        Long-format frame with columns ``n_samples``, ``sim_type``, ``model``,
        ``metric`` and ``metric_value``. When ``param_name == "cmi"`` and
        ``model_name`` contains ``"comight"``, rows for the ``I_XZ_Y`` and
        ``I_Z_Y`` components are included as well (NaN for files that store
        neither naming of these components). Empty, with these columns, when
        no file could be loaded.

    Raises
    ------
    ValueError
        If ``param_name`` is not one of the supported metrics.
    """
    supported = ("sas98", "cdcorr_pvalue", "pvalue", "cmi", "auc")
    if param_name not in supported:
        raise ValueError(
            f"Unsupported param_name={param_name!r}; expected one of {supported}"
        )
    if n_dims_2 is None:
        n_dims_2 = n_dims_2_

    # The I_XZ_Y / I_Z_Y components are only collected for CoMIGHT "cmi" runs,
    # so only then are they melted alongside the main metric.
    with_cmi_terms = param_name == "cmi" and "comight" in model_name
    value_vars = [param_name, "I_XZ_Y", "I_Z_Y"] if with_cmi_terms else [param_name]
    id_vars = ["n_samples", "sim_type", "model"]

    results = defaultdict(list)
    for n_samples in n_samples_list:
        for idx in range(n_repeats):
            output_fname = (
                root_dir
                / "output"
                / model_name
                / sim_name
                / f"{sim_name}_{n_samples}_{n_dims_1}_{n_dims_2}_{idx}.npz"
            )
            # Best-effort: missing or unreadable runs are reported and skipped.
            try:
                loaded_data = np.load(output_fname)
            except Exception as e:
                print(e, output_fname)
                continue

            # ``with`` closes the underlying archive; arrays read from an
            # NpzFile are independent copies and stay valid afterwards.
            with loaded_data:
                # Store file contents under new names so the loop variables and
                # the ``sim_name`` used to build paths are never overwritten.
                results["idx"].append(loaded_data["idx"])
                results["n_samples"].append(loaded_data["n_samples"])
                results["n_dims_1"].append(loaded_data["n_dims_1"])
                results["sim_type"].append(loaded_data["sim_type"])
                results["model"].append(model_name)

                if param_name == "auc":
                    y_score = loaded_data["posterior_arr"]
                    y_true = np.ravel(loaded_data["y"])
                    # Unpacking fails loudly unless posterior_arr is 3-D
                    # (trees, samples, classes).
                    _n_trees, _n_rows, _n_classes = y_score.shape
                    # Average over trees ignoring NaN, keep the class-1 column,
                    # and drop samples that no tree scored.
                    y_score_binary = np.nanmean(y_score, axis=0)[:, 1]
                    keep = ~np.isnan(y_score_binary)
                    results["auc"].append(
                        roc_auc_score(y_true[keep], y_score_binary[keep])
                    )
                else:
                    results[param_name].append(loaded_data[param_name])

                if with_cmi_terms:
                    try:
                        I_XZ_Y = loaded_data["I_XZ_Y"]
                        I_Z_Y = loaded_data["I_Z_Y"]
                    except KeyError:
                        try:
                            # Older runs stored the same terms under X1/X2 names.
                            I_XZ_Y = loaded_data["I_X1X2_Y"]
                            I_Z_Y = loaded_data["I_X1_Y"]
                        except KeyError as e:
                            print(e, output_fname)
                            # Keep every column the same length for the DataFrame.
                            I_XZ_Y = I_Z_Y = np.nan
                    results["I_XZ_Y"].append(I_XZ_Y)
                    results["I_Z_Y"].append(I_Z_Y)

    if not results:
        return pd.DataFrame(columns=id_vars + ["metric", "metric_value"])

    df = pd.DataFrame(results)

    # Melt the DataFrame to long format: one row per (run, metric).
    df_melted = pd.melt(
        df,
        id_vars=id_vars,
        value_vars=value_vars,
        var_name="metric",
        value_name="metric_value",
    )

    # Normalize dtypes; values read from npz files may be numpy 0-d arrays.
    df_melted["sim_type"] = df_melted["sim_type"].astype(str)
    df_melted["model"] = df_melted["model"].astype(str)
    df_melted["n_samples"] = df_melted["n_samples"].astype(int)
    df_melted["metric_value"] = df_melted["metric_value"].astype(float)
    return df_melted


def make_csv_over_ndims1(
    root_dir,
    sim_name,
    n_dims_list,
    n_samples,
    n_repeats,
    param_name,
    model_name,
    n_dims_2=None,
):
    """Collect one metric from per-run npz files into a long-format DataFrame,
    sweeping over the first-view dimensionality ``n_dims_1``.

    For every ``n_dims_1`` in ``n_dims_list`` and every ``idx`` in
    ``range(n_repeats)`` this reads::

        root_dir / "output" / model_name / sim_name
            / f"{sim_name}_{n_samples}_{n_dims_1}_{n_dims_2}_{idx}.npz"

    Files that cannot be opened are reported with ``print`` and skipped. The
    head of the collected (un-melted) frame is printed.

    Parameters
    ----------
    root_dir : pathlib.Path
        Root directory containing the ``output`` directory.
    sim_name : str
        Simulation name used in the directory and file names.
    n_dims_list : iterable of int
        First-view dimensionalities to sweep over.
    n_samples : int
        Sample size encoded in the file names.
    n_repeats : int
        Number of repeat indices to read per dimensionality.
    param_name : {"sas98", "cdcorr_pvalue", "pvalue", "cmi", "auc"}
        Metric to extract. ``"auc"`` is computed from the stored
        ``posterior_arr`` and ``y``; the others are read directly by key.
    model_name : str
        Model directory name under ``output``.
    n_dims_2 : int, optional
        Second-view dimensionality encoded in the file names. Defaults to the
        module-level ``n_dims_2_``.

    Returns
    -------
    pandas.DataFrame
        Long-format frame with columns ``n_dims_1``, ``sim_type``, ``model``,
        ``metric`` and ``metric_value``. When ``param_name == "cmi"`` and
        ``model_name`` contains ``"comight"``, rows for the ``I_XZ_Y`` and
        ``I_Z_Y`` components are included as well (NaN for files that store
        neither naming of these components). Empty, with these columns, when
        no file could be loaded.

    Raises
    ------
    ValueError
        If ``param_name`` is not one of the supported metrics.
    """
    supported = ("sas98", "cdcorr_pvalue", "pvalue", "cmi", "auc")
    if param_name not in supported:
        raise ValueError(
            f"Unsupported param_name={param_name!r}; expected one of {supported}"
        )
    if n_dims_2 is None:
        n_dims_2 = n_dims_2_

    # The I_XZ_Y / I_Z_Y components are only collected for CoMIGHT "cmi" runs,
    # so only then are they melted alongside the main metric.
    with_cmi_terms = param_name == "cmi" and "comight" in model_name
    value_vars = [param_name, "I_XZ_Y", "I_Z_Y"] if with_cmi_terms else [param_name]
    id_vars = ["n_dims_1", "sim_type", "model"]

    results = defaultdict(list)
    for n_dims_1 in n_dims_list:
        for idx in range(n_repeats):
            output_fname = (
                root_dir
                / "output"
                / model_name
                / sim_name
                / f"{sim_name}_{n_samples}_{n_dims_1}_{n_dims_2}_{idx}.npz"
            )
            # Best-effort: missing or unreadable runs are reported and skipped.
            try:
                loaded_data = np.load(output_fname)
            except Exception as e:
                print(e, output_fname)
                continue

            # ``with`` closes the underlying archive; arrays read from an
            # NpzFile are independent copies and stay valid afterwards.
            with loaded_data:
                # Store file contents under new names so the loop variables,
                # ``n_samples`` and ``sim_name`` used for paths are never
                # overwritten by file contents.
                results["idx"].append(loaded_data["idx"])
                results["n_samples"].append(loaded_data["n_samples"])
                results["n_dims_1"].append(loaded_data["n_dims_1"])
                results["sim_type"].append(loaded_data["sim_type"])
                results["model"].append(model_name)

                if param_name == "auc":
                    y_score = loaded_data["posterior_arr"]
                    y_true = np.ravel(loaded_data["y"])
                    # Unpacking fails loudly unless posterior_arr is 3-D
                    # (trees, samples, classes).
                    _n_trees, _n_rows, _n_classes = y_score.shape
                    # Average over trees ignoring NaN, keep the class-1 column,
                    # and drop samples that no tree scored.
                    y_score_binary = np.nanmean(y_score, axis=0)[:, 1]
                    keep = ~np.isnan(y_score_binary)
                    results["auc"].append(
                        roc_auc_score(y_true[keep], y_score_binary[keep])
                    )
                else:
                    results[param_name].append(loaded_data[param_name])

                if with_cmi_terms:
                    try:
                        I_XZ_Y = loaded_data["I_XZ_Y"]
                        I_Z_Y = loaded_data["I_Z_Y"]
                    except KeyError:
                        try:
                            # Older runs stored the same terms under X1/X2 names.
                            I_XZ_Y = loaded_data["I_X1X2_Y"]
                            I_Z_Y = loaded_data["I_X1_Y"]
                        except KeyError as e:
                            print(e, output_fname)
                            # Keep every column the same length for the DataFrame.
                            I_XZ_Y = I_Z_Y = np.nan
                    results["I_XZ_Y"].append(I_XZ_Y)
                    results["I_Z_Y"].append(I_Z_Y)

    if not results:
        return pd.DataFrame(columns=id_vars + ["metric", "metric_value"])

    df = pd.DataFrame(results)
    print(df.head())

    # Melt the DataFrame to long format: one row per (run, metric).
    df_melted = pd.melt(
        df,
        id_vars=id_vars,
        value_vars=value_vars,
        var_name="metric",
        value_name="metric_value",
    )

    # Normalize dtypes; values read from npz files may be numpy 0-d arrays.
    df_melted["sim_type"] = df_melted["sim_type"].astype(str)
    df_melted["model"] = df_melted["model"].astype(str)
    df_melted["n_dims_1"] = df_melted["n_dims_1"].astype(int)
    df_melted["metric_value"] = df_melted["metric_value"].astype(float)
    return df_melted


if __name__ == "__main__":
    root_dir = Path("/Users/spanda/Documents/")
    # root_dir = Path('/home/hao/')
    # output_dir = Path('/data/adam/')
    output_dir = root_dir

    n_repeats = 100

    n_samples_list = [2**x for x in range(8, 11)]
    n_dims_list = [2**x - 6 for x in range(3, 11)]
    # First-view dimensionality for the n_samples sweep
    # (1024 - 6 and 4096 - 6 have also been used).
    n_dims_1 = 512 - 6
    # Sample size for the n_dims sweep (1024 and 4096 have also been used).
    n_samples = 512
    print(n_samples_list)
    print(n_dims_list)

    # Choose one of the parameter names
    param_name = "sas98"
    # param_name = "cdcorr_pvalue"
    # param_name = "cmi"
    # param_name = "auc"
    # param_name = "pvalue"

    # Models whose npz outputs are summarized for each metric.
    models_by_param = {
        "sas98": [
            "comight",
            # "comight-perm",
            # "knn",
            # "knn_viewone",
            # "knn_viewtwo",
            #    'might_viewone', 'might_viewtwo'
        ],
        "cmi": [
            "comight-cmi",
            "ksg",
        ],
        "cdcorr_pvalue": ["cdcorr"],
        "auc": ["comight"],
        # TODO: no model list is configured for "pvalue" yet
        # (a previous run used ["coleman_pvalues"]).
    }
    if param_name not in models_by_param:
        raise ValueError(
            f"No models configured for param_name={param_name!r}; "
            f"known: {sorted(models_by_param)}"
        )
    models = models_by_param[param_name]

    # Make sure the CSV destination exists before writing.
    (output_dir / "output").mkdir(parents=True, exist_ok=True)

    sim_names = ["mean_shiftv4", "multi_modalv2", "multi_equal"]
    for sim_name in sim_names:
        for model_name in models:
            # save the dataframe to a csv file over n-samples
            df = make_csv_over_nsamples(
                root_dir,
                sim_name,
                n_samples_list,
                n_dims_1,
                n_repeats,
                param_name=param_name,
                model_name=model_name,
            )
            df.to_csv(
                output_dir
                / "output"
                / f"results_vs_nsamples_{sim_name}_{model_name}_{param_name}_{n_dims_1}_{n_repeats}.csv",
                index=False,
            )

            # save the dataframe to a csv file over n-dims
            df = make_csv_over_ndims1(
                root_dir,
                sim_name,
                n_dims_list,
                n_samples,
                n_repeats,
                param_name=param_name,
                model_name=model_name,
            )
            df.to_csv(
                output_dir
                / "output"
                / f"results_vs_ndims_{sim_name}_{model_name}_{param_name}_{n_samples}_{n_repeats}.csv",
                index=False,
            )

[256, 512, 1024]
[2, 10, 26, 58, 122, 250, 506, 1018]
  idx n_samples n_dims_1      sim_type    model       sas98
0   0       512        2  mean_shiftv4  comight   0.1484375
1   1       512        2  mean_shiftv4  comight  0.15234375
2   2       512        2  mean_shiftv4  comight  0.30078125
3   3       512        2  mean_shiftv4  comight   0.1953125
4   4       512        2  mean_shiftv4  comight  0.25390625
[2, 10, 26, 58, 122, 250, 506, 1018]
  idx n_samples n_dims_1       sim_type    model       sas98
0   0       512        2  multi_modalv2  comight  0.29296875
1   1       512        2  multi_modalv2  comight        0.25
2   2       512        2  multi_modalv2  comight   0.2421875
3   3       512        2  multi_modalv2  comight     0.21875
4   4       512        2  multi_modalv2  comight   0.3046875
[2, 10, 26, 58, 122, 250, 506, 1018]
  idx n_samples n_dims_1     sim_type    model       sas98
0   0       512        2  multi_equal  comight   0.0078125
1   1       512        2  mu

In [21]:
# Used for generating csv summary files for CoMIGHT power
# from npz files.
#
# For other metrics, use `make_csv_summary_from_npz.py`.

from collections import defaultdict
from pathlib import Path

import numpy as np
import pandas as pd
from numpy.testing import assert_array_equal
from sklearn.metrics import roc_curve
from sktree.stats.utils import (METRIC_FUNCTIONS, POSITIVE_METRICS,
                                 _compute_null_distribution_coleman,
                                 _mutual_information)

n_dims_2_ = 6


def _estimate_threshold(y_true, y_score, target_specificity=0.98, pos_label=1):
    """Pick a score threshold from the ROC curve for a target specificity.

    Runs ``roc_curve`` on ``(y_true, y_score)`` and returns the threshold at
    the first ROC point whose false-positive rate is at least
    ``1 - target_specificity``.

    NOTE(review): the first point with ``fpr >= 1 - target_specificity`` may
    have an FPR strictly above that bound, i.e. a specificity slightly below
    the target. Confirm this is the intended convention versus "largest FPR
    that is still <= the bound".
    """
    max_fpr = 1 - target_specificity
    false_pos_rate, _true_pos_rate, cutoffs = roc_curve(
        y_true, y_score, pos_label=pos_label
    )
    # argmax on a boolean mask yields the index of its first True entry.
    first_reaching = np.argmax(false_pos_rate >= max_fpr)
    return cutoffs[first_reaching]


def sensitivity_at_specificity(
    y_true, y_score, target_specificity=0.98, pos_label=1, threshold=None
):
    """Sensitivity of tree-averaged posteriors at a given (or estimated) threshold.

    Parameters
    ----------
    y_true : array-like
        True labels; raveled to 1-D.
    y_score : ndarray of shape (n_trees, n_samples, n_classes)
        Per-tree posteriors. They are averaged over trees ignoring NaN, and
        column 1 is used as the score. Samples whose averaged score is NaN
        (no tree scored them) are dropped.
    target_specificity : float
        Specificity used to estimate the threshold when ``threshold`` is None.
    pos_label : int
        Label treated as the positive class, both when estimating the
        threshold and when counting true positives.
    threshold : float, optional
        Decision threshold; scores ``>= threshold`` are predicted positive.
        Estimated from the ROC curve when None.

    Returns
    -------
    float
        True-positive count divided by the number of positive samples. NaN,
        with a numpy RuntimeWarning, when there are no positive samples.
    """
    # Unpacking doubles as a check that y_score is 3-D.
    _n_trees, _n_samples, _n_classes = y_score.shape

    # Average the positive-class posterior over trees, ignoring NaN entries.
    y_score_binary = np.nanmean(y_score, axis=0)[:, 1]
    y_true = np.asarray(y_true).ravel()

    # Drop samples that received no score from any tree.
    keep = ~np.isnan(y_score_binary)
    y_score_binary = y_score_binary[keep]
    y_true = y_true[keep]

    if threshold is None:
        threshold = _estimate_threshold(
            y_true,
            y_score_binary,
            target_specificity=target_specificity,
            pos_label=pos_label,
        )

    predicted_positive = y_score_binary >= threshold
    is_positive = y_true == pos_label
    sensitivity = np.sum(predicted_positive & is_positive) / np.sum(is_positive)
    return sensitivity


def _estimate_sas98(y, posterior_arr, threshold=None, target_specificity=0.98):
    """Estimate sensitivity at ``target_specificity`` (S@98 by default).

    Parameters
    ----------
    y : array-like
        True labels; the positive class is 1.
    posterior_arr : ndarray of shape (n_trees, n_samples, n_classes)
        Per-tree posteriors. They are averaged over trees ignoring NaN, and
        column 1 is used as the positive-class score.
    threshold : float, optional
        Decision threshold to use. When None (default), it is estimated from
        the ROC curve of the tree-averaged scores at ``target_specificity``.
    target_specificity : float
        Specificity used when estimating the threshold.

    Returns
    -------
    float
        Sensitivity at the chosen threshold, as computed by
        ``sensitivity_at_specificity``.
    """
    if threshold is None:
        y_true = np.asarray(y).ravel()
        y_score_binary = np.nanmean(posterior_arr, axis=0)[:, 1]
        # Samples that no tree scored have a NaN average and are excluded.
        keep = ~np.isnan(y_score_binary)
        threshold = _estimate_threshold(
            y_true[keep],
            y_score_binary[keep],
            target_specificity=target_specificity,
            pos_label=1,
        )

    # Passing the threshold explicitly makes sensitivity_at_specificity skip
    # its own threshold estimation.
    return sensitivity_at_specificity(
        y,
        posterior_arr,
        target_specificity=target_specificity,
        threshold=threshold,
    )


def _estimate_pvalue(
    y,
    orig_forest_proba,
    perm_forest_proba,
    metric,
    n_repeats,
    seed,
    n_jobs,
    **metric_kwargs,
):
    """Permutation p-value for the gap between an original and a permuted forest.

    The observed statistic is ``metric(permuted) - metric(original)``, computed
    on the tree-averaged (NaN-ignoring) posteriors. It is compared against a
    null distribution built by sktree's ``_compute_null_distribution_coleman``.
    That helper presumably returns ``n_repeats`` resampled metric values for
    the original and permuted forests (TODO confirm against sktree).

    Returns
    -------
    float
        ``(1 + #null values at least as extreme) / (1 + n_repeats)``. The
        direction depends on whether ``metric`` is in ``POSITIVE_METRICS``.
    """
    score_fn = METRIC_FUNCTIONS[metric]
    # Labels are passed to the metric helpers as a column vector.
    y_col = y[:, np.newaxis]
    print(y_col.shape, orig_forest_proba.shape, perm_forest_proba.shape)

    null_orig, null_perm = _compute_null_distribution_coleman(
        y_col,
        orig_forest_proba,
        perm_forest_proba,
        metric,
        n_repeats=n_repeats,
        seed=seed,
        n_jobs=n_jobs,
        **metric_kwargs,
    )

    avg_orig = np.nanmean(orig_forest_proba, axis=0)
    avg_perm = np.nanmean(perm_forest_proba, axis=0)
    orig_score = score_fn(y_col, avg_orig, **metric_kwargs)
    perm_score = score_fn(y_col, avg_perm, **metric_kwargs)

    # Permuted-minus-original gap; under the null this is centered at 0.
    test_stat = perm_score - orig_score
    centered_null = null_perm - null_orig

    if metric in POSITIVE_METRICS:
        n_as_extreme = (centered_null <= test_stat).sum()
    else:
        n_as_extreme = (centered_null >= test_stat).sum()
    return (1 + n_as_extreme) / (1 + n_repeats)


def recompute_metric_n_samples(
    root_dir,
    sim_name,
    n_dims_1,
    n_dims_2,
    n_repeats,
    n_jobs=None,
    overwrite=False,
    n_samples_list=None,
):
    """Implement comight-power and comightperm-power over n_samples.

    Each will have a separate csv file in ``root_dir / "output"``:

    - ``comight-power``: for each (idx, n_samples), compares the ``comight``
      run's ``posterior_arr`` with the ``comight-cmi`` run's
      ``perm_posterior_arr`` for the same idx.
    - ``comightperm-power``: compares the ``comight-cmi`` ``perm_posterior_arr``
      of run ``idx`` with that of run ``(idx + 1) % n_repeats``.

    Each row holds the S@98 difference (``sas98``), the mutual-information
    difference (``cmi``), ``idx``, ``n_samples``, ``n_dims_1`` and ``n_dims_2``.

    Parameters
    ----------
    root_dir : pathlib.Path
        Directory containing ``output/<model>/<sim_name>/*.npz``.
    sim_name : str
        Simulation name used in the directory and file names.
    n_dims_1, n_dims_2 : int
        Dimensionalities encoded in the npz file names.
    n_repeats : int
        Number of repeat indices ``0 .. n_repeats - 1`` to read.
    n_jobs : optional
        Unused; kept for interface compatibility.
    overwrite : bool
        If False, an existing csv is left untouched and reported. The other
        csv is still produced if it is missing.
    n_samples_list : list of int, optional
        Sample sizes to read. Defaults to ``[256, 512, 1024]``.

    Raises
    ------
    AssertionError
        If the two runs being compared were stored with different labels.
    """
    if n_samples_list is None:
        n_samples_list = [2**x for x in range(8, 11)]

    output_dir = root_dir / "output"
    output_dir.mkdir(exist_ok=True, parents=True)

    def _npz_path(model_dir, n_samples, idx):
        # Location of one run's npz output.
        return (
            output_dir
            / model_dir
            / sim_name
            / f"{sim_name}_{n_samples}_{n_dims_1}_{n_dims_2}_{idx}.npz"
        )

    def _pair_stats(obs_path, obs_key, perm_path):
        # Return (sas98 difference, MI difference) for one pair of runs.
        with np.load(obs_path) as obs_data, np.load(perm_path) as perm_data:
            obs_posteriors = obs_data[obs_key]
            obs_y = obs_data["y"]
            perm_posteriors = perm_data["perm_posterior_arr"]
            perm_y = perm_data["y"]

        # Differences are only meaningful if both runs used the same labels.
        assert_array_equal(obs_y, perm_y)

        I_XZ_Y = _mutual_information(obs_y, np.nanmean(obs_posteriors, axis=0))
        I_Z_Y = _mutual_information(perm_y, np.nanmean(perm_posteriors, axis=0))

        sas98_obs = _estimate_sas98(obs_y, obs_posteriors)
        sas98_perm = _estimate_sas98(perm_y, perm_posteriors)
        return sas98_obs - sas98_perm, I_XZ_Y - I_Z_Y

    # (output model name, directory of the "observed" run, posterior key in it,
    #  offset from idx to the "permuted" comight-cmi run)
    comparisons = [
        ("comight-power", "comight", "posterior_arr", 0),
        ("comightperm-power", "comight-cmi", "perm_posterior_arr", 1),
    ]
    for output_model_name, obs_model_dir, obs_key, perm_offset in comparisons:
        output_file = (
            output_dir
            / f"results_vs_nsamples_{sim_name}_{output_model_name}_{n_dims_1}_{n_repeats}.csv"
        )
        if output_file.exists() and not overwrite:
            print(f"Output file: {output_file} exists")
            continue

        result = defaultdict(list)
        for idx in range(n_repeats):
            # Wraps to 0 after the last repeat.
            perm_idx = (idx + perm_offset) % n_repeats
            for n_samples in n_samples_list:
                sas98_diff, cmi_diff = _pair_stats(
                    _npz_path(obs_model_dir, n_samples, idx),
                    obs_key,
                    _npz_path("comight-cmi", n_samples, perm_idx),
                )
                result["sas98"].append(sas98_diff)
                result["cmi"].append(cmi_diff)
                result["idx"].append(idx)
                result["n_samples"].append(n_samples)
                result["n_dims_1"].append(n_dims_1)
                result["n_dims_2"].append(n_dims_2)

        pd.DataFrame(result).to_csv(output_file, index=False)


def recompute_metric_n_dims(
    root_dir,
    sim_name,
    n_samples,
    n_dims_2,
    n_repeats,
    n_jobs=None,
    overwrite=False,
    n_dims_list=None,
):
    """Implement comight-power and comightperm-power over n_dims_1.

    Each will have a separate csv file in ``root_dir / "output"``:

    - ``comight-power``: for each (idx, n_dims_1), compares the ``comight``
      run's ``posterior_arr`` with the ``comight-cmi`` run's
      ``perm_posterior_arr`` for the same idx.
    - ``comightperm-power``: compares the ``comight-cmi`` ``perm_posterior_arr``
      of run ``idx`` with that of run ``(idx + 1) % n_repeats``.

    Each row holds the S@98 difference (``sas98``), the mutual-information
    difference (``cmi``), ``idx``, ``n_samples``, ``n_dims_1`` and ``n_dims_2``.

    Parameters
    ----------
    root_dir : pathlib.Path
        Directory containing ``output/<model>/<sim_name>/*.npz``.
    sim_name : str
        Simulation name used in the directory and file names.
    n_samples : int
        Sample size encoded in the npz file names.
    n_dims_2 : int
        Second-view dimensionality encoded in the npz file names.
    n_repeats : int
        Number of repeat indices ``0 .. n_repeats - 1`` to read.
    n_jobs : optional
        Unused; kept for interface compatibility.
    overwrite : bool
        If False, an existing csv is left untouched and reported. The other
        csv is still produced if it is missing.
    n_dims_list : list of int, optional
        First-view dimensionalities to read. Defaults to
        ``[2**i - 6 for i in range(3, 11)]``.

    Raises
    ------
    AssertionError
        If the two runs being compared were stored with different labels.
    """
    if n_dims_list is None:
        n_dims_list = [2**i - 6 for i in range(3, 11)]

    output_dir = root_dir / "output"
    output_dir.mkdir(exist_ok=True, parents=True)

    def _npz_path(model_dir, n_dims_1, idx):
        # Location of one run's npz output.
        return (
            output_dir
            / model_dir
            / sim_name
            / f"{sim_name}_{n_samples}_{n_dims_1}_{n_dims_2}_{idx}.npz"
        )

    def _pair_stats(obs_path, obs_key, perm_path):
        # Return (sas98 difference, MI difference) for one pair of runs.
        with np.load(obs_path) as obs_data, np.load(perm_path) as perm_data:
            obs_posteriors = obs_data[obs_key]
            obs_y = obs_data["y"]
            perm_posteriors = perm_data["perm_posterior_arr"]
            perm_y = perm_data["y"]

        # Differences are only meaningful if both runs used the same labels.
        assert_array_equal(obs_y, perm_y)

        I_XZ_Y = _mutual_information(obs_y, np.nanmean(obs_posteriors, axis=0))
        I_Z_Y = _mutual_information(perm_y, np.nanmean(perm_posteriors, axis=0))

        sas98_obs = _estimate_sas98(obs_y, obs_posteriors)
        sas98_perm = _estimate_sas98(perm_y, perm_posteriors)
        return sas98_obs - sas98_perm, I_XZ_Y - I_Z_Y

    # (output model name, directory of the "observed" run, posterior key in it,
    #  offset from idx to the "permuted" comight-cmi run)
    comparisons = [
        ("comight-power", "comight", "posterior_arr", 0),
        ("comightperm-power", "comight-cmi", "perm_posterior_arr", 1),
    ]
    for output_model_name, obs_model_dir, obs_key, perm_offset in comparisons:
        output_file = (
            output_dir
            / f"results_vs_ndims_{sim_name}_{output_model_name}_{n_samples}_{n_repeats}.csv"
        )
        if output_file.exists() and not overwrite:
            print(f"Output file: {output_file} exists")
            continue

        result = defaultdict(list)
        for idx in range(n_repeats):
            # Wraps to 0 after the last repeat.
            perm_idx = (idx + perm_offset) % n_repeats
            for n_dims_1 in n_dims_list:
                sas98_diff, cmi_diff = _pair_stats(
                    _npz_path(obs_model_dir, n_dims_1, idx),
                    obs_key,
                    _npz_path("comight-cmi", n_dims_1, perm_idx),
                )
                result["sas98"].append(sas98_diff)
                result["cmi"].append(cmi_diff)
                result["idx"].append(idx)
                result["n_samples"].append(n_samples)
                result["n_dims_1"].append(n_dims_1)
                result["n_dims_2"].append(n_dims_2)

        pd.DataFrame(result).to_csv(output_file, index=False)


if __name__ == "__main__":
    # Data locations used on other machines:
    # root_dir = Path("/Volumes/Extreme Pro/cancer")
    # root_dir = Path('/home/hao/')
    # output_dir = Path('/data/adam/')
    root_dir = Path("/Users/spanda/Documents/")
    output_dir = root_dir

    sim_names = ["mean_shiftv4", "multi_modalv2", "multi_equal"]

    # Fixed first-view dimensionality for the n_samples sweep
    # (4096 - 6 has also been used).
    n_dims_1 = 512 - 6
    n_dims_2 = 6

    # Fixed sample size for the n_dims sweep (1024 has also been used).
    n_samples = 512
    n_repeats = 100
    n_jobs = -1

    # Recompute both power summaries for every simulation, replacing old CSVs.
    for sim_name in sim_names:
        recompute_metric_n_samples(
            root_dir=root_dir,
            sim_name=sim_name,
            n_dims_1=n_dims_1,
            n_dims_2=n_dims_2,
            n_repeats=n_repeats,
            n_jobs=n_jobs,
            overwrite=True,
        )
        recompute_metric_n_dims(
            root_dir=root_dir,
            sim_name=sim_name,
            n_samples=n_samples,
            n_dims_2=n_dims_2,
            n_repeats=n_repeats,
            n_jobs=n_jobs,
            overwrite=True,
        )