In [None]:
import csv
from itertools import product

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from hyppo.ksample import KSample
from scipy.stats import truncnorm
from skimage.filters import threshold_otsu
from tqdm import tqdm

In [None]:
def test(samples, labels, binarize, average, reps=10000, workers=-1):
    """Run a k-sample distance-correlation (Dcorr) test across label groups.

    Parameters
    ----------
    samples : np.ndarray
        2-D array with one row per subject; rows are split into groups with
        ``labels == group``.
    labels : np.ndarray
        Group label for each row of ``samples``.
    binarize : bool
        If True, threshold all of ``samples`` at a single Otsu threshold
        before testing.
    average : bool
        If True, collapse each subject's row to its mean so the test compares
        one scalar per subject.
    reps : int, optional
        Number of permutation replications passed to ``KSample.test``.
    workers : int, optional
        Worker count passed to ``KSample.test`` (-1 is hyppo's "use all
        cores" setting).

    Returns
    -------
    stat, pvalue
        Test statistic and p-value. If ``KSample.test`` raises ``ValueError``
        the result is ``(nan, 1)``, i.e. the comparison counts as a
        non-rejection.
    """
    if binarize:
        # One Otsu threshold computed over every subject and vertex at once.
        threshold = threshold_otsu(samples)
        samples = samples > threshold

    # One array per group, ordered like np.unique(labels).
    groups = [samples[labels == group, :] for group in np.unique(labels)]

    if average:
        # Collapse each subject (row) to a single mean value.
        groups = [np.mean(group, axis=1) for group in groups]

    # Distance-correlation ("Dcorr") k-sample permutation test.
    try:
        stat, pvalue, *_ = KSample("Dcorr").test(*groups, reps=reps, workers=workers)
    except ValueError:
        # Deliberate best-effort: record a test hyppo refuses to run as a
        # non-rejection instead of aborting the whole simulation sweep.
        stat, pvalue = np.nan, 1
    return stat, pvalue

In [None]:
def twin_truncnorm(
    mu_1,
    sigma_1,
    mu_2,
    sigma_2,
    n_subjects,
    n_vertices=10,
    lower=-1,
    upper=1,
    random_state=None,
):
    """Sample two labelled groups of subjects from two truncated normals.

    Each subject is a row of ``n_vertices`` i.i.d. draws from its group's
    normal(loc=mu, scale=sigma) truncated to ``[lower, upper]``.

    Parameters
    ----------
    mu_1, sigma_1 : float
        Location and scale of the (untruncated) normal for group 0.
    mu_2, sigma_2 : float
        Location and scale of the (untruncated) normal for group 1.
    n_subjects : int or float
        Subjects per group; converted with ``int()`` (floats are truncated).
    n_vertices : int, optional
        Number of values drawn per subject.
    lower, upper : float, optional
        Truncation bounds, in the same units as the samples.
    random_state : None, int, np.random.Generator or np.random.RandomState
        Randomness source for ``rvs``. ``None`` uses NumPy's global state
        (so ``np.random.seed`` controls it). An int seeds a single Generator
        that is shared by both groups.

    Returns
    -------
    samples : np.ndarray, shape (2 * int(n_subjects), n_vertices)
        Group 0 rows first, then group 1 rows.
    labels : np.ndarray, shape (2 * int(n_subjects),)
        0 for rows from the first distribution, 1 for the second.
    """
    n_per_group = int(n_subjects)

    if isinstance(random_state, (int, np.integer)):
        # Build one generator up front: passing the same int to each rvs call
        # would reseed it identically and make the two groups correlated.
        random_state = np.random.default_rng(random_state)

    # truncnorm expects its bounds in standard units relative to loc/scale.
    dists = [
        truncnorm((lower - mu) / sigma, (upper - mu) / sigma, loc=mu, scale=sigma)
        for mu, sigma in ((mu_1, sigma_1), (mu_2, sigma_2))
    ]

    # Draw each group's (n_per_group, n_vertices) block in a single call.
    samples = np.concatenate(
        [
            dist.rvs(size=(n_per_group, n_vertices), random_state=random_state)
            for dist in dists
        ]
    )
    labels = np.repeat(np.arange(len(dists)), n_per_group)

    return samples, labels

In [None]:
def main(binarize, average, f, dist_params, n_subjects):
    """Simulate one two-group dataset and test it for a group difference.

    Parameters
    ----------
    binarize, average : bool
        Preprocessing options forwarded to ``test``.
    f : str
        Key into ``dist_params`` selecting the simulation scenario.
    dist_params : dict
        Maps scenario name to keyword arguments for ``twin_truncnorm``.
    n_subjects : int or float
        Subjects per group; ``twin_truncnorm`` converts it with ``int()``.

    Returns
    -------
    sample_size : int
        Total number of simulated subjects that were actually tested.
    stat, pvalue
        Output of ``test``.
    """
    samples, labels = twin_truncnorm(n_subjects=n_subjects, **dist_params[f])
    stat, pvalue = test(samples, labels, binarize, average)

    # Count the rows actually generated instead of recomputing from
    # n_subjects: it may be a float (e.g. from np.linspace) that the sampler
    # truncates, which would make the reported size disagree with the data.
    sample_size = len(labels)

    return sample_size, stat, pvalue

In [None]:
# Simulation scenarios: keyword arguments for twin_truncnorm giving each
# group's location (mu_*) and scale (sigma_*) before truncation.
# "equal" is the null; the others differ in spread or in location.
dist_params = {
    "equal": {"mu_1": 0, "sigma_1": 0.25, "mu_2": 0, "sigma_2": 0.25},
    "same_mean": {"mu_1": 0, "sigma_1": 0.25, "mu_2": 0, "sigma_2": 0.5},
    "diff_mean": {"mu_1": -0.075, "sigma_1": 0.25, "mu_2": 0.075, "sigma_2": 0.25},
}

In [None]:
# Simulation grid: every combination of settings is repeated N_ITERATIONS
# times, and each repetition is one simulated dataset + one Dcorr test.
SEED = 0
N_ITERATIONS = 50

binarize = [True, False]
average = [True, False]
n_subjects = np.linspace(5, 50, 10)  # subjects per group: 5, 10, ..., 50
functions = list(dist_params.keys())

n_iterations = range(N_ITERATIONS)
# Materialize the grid so the cell can be re-run (a bare product() iterator
# is exhausted after one pass) and tqdm knows the total.
parameters = list(product(binarize, average, n_subjects, functions, n_iterations))

# When no random_state is given, scipy's rvs draws from NumPy's global state,
# so seeding it makes the simulated data reproducible. This does not
# necessarily fix hyppo's internal permutation draws.
np.random.seed(SEED)

out = []
for binarize_, average_, n_subjects_, f, _ in tqdm(parameters):
    sample_size, stat, pvalue = main(binarize_, average_, f, dist_params, n_subjects_)
    out.append([binarize_, average_, f, sample_size, stat, pvalue])

In [None]:
filename = "../../results/community_simulation.csv"
columns = ["binarize", "average", "distribution", "sample_size", "stat", "pvalue"]

# newline="" lets the csv module control line endings itself; without it,
# every row is followed by an extra blank line on Windows.
with open(filename, "w", newline="") as outfile:
    writer = csv.writer(outfile)
    writer.writerow(columns)
    writer.writerows(out)

In [None]:
def get_method_name(row):
    """Return the display label for a result row's preprocessing flags.

    ``row.average`` and ``row.binarize`` are interpreted by truthiness and
    together select one of four method names.
    """
    names_by_flags = {
        (True, True): "Average Connectivity",
        (True, False): "Average Edge Weight",
        (False, True): "Multivariate Binary",
        (False, False): "Multivariate Weighted",
    }
    flags = (bool(row.average), bool(row.binarize))
    return names_by_flags[flags]

In [None]:
# Read back the CSV this notebook's writer cell produces (same path), so the
# figures reflect this simulation's output and work on a fresh kernel.
ALPHA = 0.05  # significance level at which a test counts as a rejection

df = pd.read_csv("../../results/community_simulation.csv")
df["method"] = df.apply(get_method_name, axis="columns")
df = df.drop(["binarize", "average"], axis="columns")
df["reject"] = df["pvalue"] < ALPHA
df.head()

In [None]:
# One panel per simulated scenario. "equal" is the null, so its rejection
# rate is a false positive rate; the other two scenarios are alternatives.
panels = [
    ("equal", "Same Distribution", "False Positive Rate"),
    ("same_mean", "Same Mean", "True Positive Rate"),
    ("diff_mean", "Different Mean", "True Positive Rate"),
]
legend_panel = 1  # only the middle panel carries the (relocated) legend

# Fix the method order once so every panel maps methods to the same colours,
# regardless of row order within each scenario's subset.
hue_order = list(df["method"].unique())

fig, axs = plt.subplots(ncols=len(panels), sharex=True, sharey=False, figsize=(8, 3.5))

for i, (distribution, title, ylabel) in enumerate(panels):
    ax = axs[i]
    sns.lineplot(
        data=df[df["distribution"] == distribution],
        x="sample_size",
        y="reject",
        hue="method",
        hue_order=hue_order,
        legend=i == legend_panel,
        ax=ax,
    )
    ax.set_box_aspect(1)
    ax.set(
        xlabel="Sample Size",
        ylabel=ylabel,
        title=title,
        ylim=(-0.05, 1.05),
    )

fig.tight_layout()
# Move the shared legend below the middle panel, after layout is computed.
axs[legend_panel].legend(
    loc="upper center",
    bbox_to_anchor=(0.5, -0.25),
    fancybox=True,
    shadow=True,
    ncol=2,
)
plt.show()