In [1]:
import datetime
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from pkg.data import load_network_palette, load_unmatched
from pkg.io import FIG_PATH, get_environment_variables
from pkg.io import glue as default_glue
from pkg.io import savefig
from pkg.plot import set_theme
from pkg.stats import binom_2samp, stochastic_block_test
from scipy.stats import binom
from seaborn.utils import relative_luminance
from tqdm.autonotebook import tqdm

# Only the third value returned by get_environment_variables() is kept; it is
# read by gluefig to decide whether figures are closed after being saved.
_, _, DISPLAY_FIGS = get_environment_variables()

# Name handed to glue/savefig so this notebook's outputs are grouped together.
FILENAME = "sbm_block_power"

# Rebind the imported base figure path to a subfolder named after this notebook.
FIG_PATH = FIG_PATH / FILENAME


def glue(name, var, **kwargs):
    """Forward ``name`` and ``var`` to ``pkg.io.glue`` together with ``FILENAME``.

    Any extra keyword arguments are passed through unchanged. Nothing is
    returned.
    """
    default_glue(name, var, FILENAME, **kwargs)


def gluefig(name, fig, **kwargs):
    """Save a figure under this notebook's folder, glue it, and optionally close it.

    Parameters
    ----------
    name : str
        Name used both for the saved file and for the glued object.
    fig : matplotlib.figure.Figure
        The figure to glue (and to close when figures are not displayed).
    **kwargs
        Forwarded to ``savefig``.
    """
    # NOTE(review): savefig is not handed `fig`; presumably it saves pyplot's
    # current figure -- confirm against pkg.io.savefig.
    savefig(name, foldername=FILENAME, **kwargs)

    glue(name, fig, figure=True)

    if not DISPLAY_FIGS:
        # Close the figure that was passed in, not whichever figure pyplot
        # happens to consider "current" at this point.
        plt.close(fig)


# Wall-clock start time; read again in the final cell to report the runtime.
t0 = time.time()
set_theme()
# Seeded generator so the simulated edge counts drawn later are reproducible.
rng = np.random.default_rng(8888)

network_palette, NETWORK_KEY = load_network_palette()
neutral_color = sns.color_palette("Set2")[2]

# Node-metadata column whose values are used as the group labels.
GROUP_KEY = "celltype_discrete"

left_adj, left_nodes = load_unmatched(side="left")
right_adj, right_nodes = load_unmatched(side="right")

# Per-node group labels, passed as labels1/labels2 to stochastic_block_test.
left_labels = left_nodes[GROUP_KEY].values
right_labels = right_nodes[GROUP_KEY].values

Environment variables:
   RESAVE_DATA: True
   RERUN_SIMS: False
   DISPLAY_FIGS: False



In [2]:

# Run the group-wise comparison of the left and right networks. `misc` holds
# the per-block quantities ("n_tests", "possible*", "probabilities*",
# "rejections") that the later cells read.
stat, pvalue, misc = stochastic_block_test(
    left_adj,
    right_adj,
    labels1=left_labels,
    labels2=right_labels,
)
# Record the overall p-value and the number of tests for reuse outside the notebook.
glue("pvalue", pvalue, form="pvalue")
n_tests = misc["n_tests"]
glue("n_tests", n_tests)

In [3]:

# Pull the per-block tables out of `misc` in one pass. They are used below as
# binomial trial counts (possible*) and success probabilities (probs*).
possible1, possible2, probs1, probs2 = (
    misc[key]
    for key in ("possible1", "possible2", "probabilities1", "probabilities2")
)

In [4]:


# Simulation settings for the per-block power analysis.
method = "score"  # passed through to binom_2samp
index = possible1.index  # group labels; used for both source and target
n_sims = 100  # simulated comparisons per (source, target) block
effect_scale = 0.8  # network 2's edge probability is scaled by this factor

rows = []
# The context manager closes the progress bar even if a simulation raises.
with tqdm(total=n_sims * len(index) ** 2) as pbar:
    for source_group in index:
        for target_group in index:
            p1 = probs1.loc[source_group, target_group]
            p2 = probs2.loc[source_group, target_group]
            mean_p = (p1 + p2) / 2
            n1 = possible1.loc[source_group, target_group]
            n2 = possible2.loc[source_group, target_group]

            for sim in range(n_sims):
                # Network 1 is drawn at the mean probability, network 2 at the
                # scaled-down probability; the draw order is kept so results
                # are unchanged for a given seed.
                edges1 = binom.rvs(n1, mean_p, random_state=rng)
                edges2 = binom.rvs(n2, effect_scale * mean_p, random_state=rng)
                # Use simulation-specific names so the notebook-level `stat`
                # and `pvalue` from the real-data test are not overwritten.
                sim_stat, sim_pvalue = binom_2samp(
                    edges1, n1, edges2, n2, method=method
                )
                rows.append(
                    {
                        "source_group": source_group,
                        "target_group": target_group,
                        "n1": n1,
                        "n2": n2,
                        "mean_p": mean_p,
                        "sim": sim,
                        "stat": sim_stat,
                        "pvalue": sim_pvalue,
                    }
                )
                pbar.update(1)

# One row per simulated comparison.
results = pd.DataFrame(rows)

  0%|          | 0/32400 [00:00<?, ?it/s]

In [5]:


def compute_power(pvalues, alpha=0.05):
    """Return the fraction of ``pvalues`` that are strictly below ``alpha``.

    NaN p-values compare as False and therefore count as non-rejections.
    """
    rejected = pvalues < alpha
    return np.mean(rejected)


# Empirical power per (source, target) block: the share of that block's
# simulated p-values falling below compute_power's default alpha.
block_pvalues = results.groupby(["source_group", "target_group"])["pvalue"]
power_results = block_pvalues.apply(compute_power).rename("power").reset_index()

power_results

Unnamed: 0,source_group,target_group,power
0,Ascending,Ascending,0.26
1,Ascending,CN,0.06
2,Ascending,DN$^{\mathrm{SEZ}}$,0.09
3,Ascending,DN$^{\mathrm{VNC}}$,0.05
4,Ascending,KC,0.00
...,...,...,...
319,Sensory,PN$^{\mathrm{Somato}}$,0.00
320,Sensory,Pre-DN$^{\mathrm{SEZ}}$,0.13
321,Sensory,Pre-DN$^{\mathrm{VNC}}$,0.07
322,Sensory,RGN,0.23


In [6]:
# Reshape the long-format power table into a source-by-target matrix, which is
# the layout the heatmap below expects.
square_power = power_results.pivot(
    index="source_group", columns="target_group", values="power"
)
square_power

target_group,Ascending,CN,DN$^{\mathrm{SEZ}}$,DN$^{\mathrm{VNC}}$,KC,LHN,LN,MB-FBN,MB-FFN,MBIN,MBON,Other,PN,PN$^{\mathrm{Somato}}$,Pre-DN$^{\mathrm{SEZ}}$,Pre-DN$^{\mathrm{VNC}}$,RGN,Sensory
source_group,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1
Ascending,0.26,0.06,0.09,0.05,0.0,0.01,0.0,0.01,0.03,0.0,0.0,0.1,0.1,0.11,0.06,0.1,0.01,0.01
CN,0.0,0.47,0.19,0.3,0.02,0.37,0.09,0.47,0.18,0.13,0.15,0.34,0.14,0.08,0.25,0.62,0.0,0.0
DN$^{\mathrm{SEZ}}$,0.02,0.1,0.61,0.23,0.0,0.11,0.21,0.22,0.06,0.02,0.0,0.24,0.54,0.13,0.17,0.29,0.05,0.11
DN$^{\mathrm{VNC}}$,0.13,0.08,0.21,0.67,0.0,0.07,0.0,0.11,0.15,0.01,0.0,0.13,0.06,0.16,0.09,0.35,0.08,0.0
KC,0.0,0.23,0.0,0.0,1.0,0.11,0.22,0.39,0.04,0.98,1.0,0.19,0.16,0.0,0.08,0.05,0.0,0.0
LHN,0.0,0.72,0.15,0.32,0.09,1.0,0.62,0.77,0.4,0.28,0.25,0.94,0.45,0.36,0.53,0.82,0.15,0.0
LN,0.0,0.22,0.35,0.12,0.32,0.85,0.78,0.24,0.08,0.1,0.1,0.5,0.97,0.07,0.23,0.26,0.11,0.54
MB-FBN,0.0,0.74,0.16,0.39,0.07,0.3,0.16,0.95,0.3,0.34,0.38,0.57,0.21,0.18,0.24,0.62,0.03,0.0
MB-FFN,0.0,0.27,0.18,0.17,0.03,0.4,0.06,0.4,0.22,0.15,0.15,0.41,0.15,0.15,0.17,0.52,0.08,0.0
MBIN,0.0,0.15,0.0,0.0,0.99,0.13,0.07,0.08,0.04,0.14,0.19,0.08,0.07,0.0,0.03,0.07,0.0,0.0


In [7]:
# Heatmap of empirical power per block, with an "x" drawn over every block
# that was flagged as a rejection in the real-data test.
fig, ax = plt.subplots(1, 1, figsize=(8, 8))
sns.heatmap(
    square_power,
    square=True,
    cmap="RdBu_r",
    vmin=0,
    center=0,
    vmax=1,
    cbar_kws=dict(shrink=0.5, pad=0.1),
    ax=ax,
)
ax.set(ylabel="Source group", xlabel="Target group")
# The colorbar is the second axes seaborn adds to the figure.
cax = fig.get_axes()[1]
cax.set_title("Power @\n" + r"$\alpha=0.05$", pad=20)

# Align the rejection flags to the heatmap's row/column labels so each flag is
# matched to the right cell regardless of the order `misc` stores them in.
# Labels missing from `misc` are treated as "not significant".
significant = misc["rejections"].reindex(
    index=square_power.index, columns=square_power.columns, fill_value=False
)

# The heatmap mesh is the first collection on the axes. Make sure its
# per-cell face colours are computed before reading them (seaborn does the
# same before annotating cells).
mesh = ax.collections[0]
mesh.update_scalarmappable()
colors = mesh.get_facecolors()

# Text markers looked bad, so each "x" is drawn as two diagonal line segments
# inset by `pad` from the cell edges.
pad = 0.2
for idx, (is_significant, color) in enumerate(
    zip(significant.values.ravel(), colors)
):
    if not is_significant:
        continue
    i, j = np.unravel_index(idx, square_power.shape)
    # Same light/dark contrast rule seaborn uses for heatmap annotations.
    lum = relative_luminance(color)
    text_color = ".15" if lum > 0.408 else "w"

    ys = [i + pad, i + 1 - pad]
    for xs in ([j + pad, j + 1 - pad], [j + 1 - pad, j + pad]):
        ax.plot(xs, ys, color=text_color, linewidth=3)


gluefig("empirical_power_by_block", fig)

In [8]:
# Average the number of possible edges per block across the two networks and
# show it as a heatmap.
mean_possible = possible1.add(possible2).div(2)

heatmap_kws = dict(
    square=True,
    cmap="RdBu_r",
    vmin=0,
    center=0,
    cbar_kws=dict(shrink=0.5, pad=0.1),
)

fig, ax = plt.subplots(1, 1, figsize=(8, 8))
sns.heatmap(mean_possible, ax=ax, **heatmap_kws)
ax.set_ylabel("Source group")
ax.set_xlabel("Target group")

# The colorbar is the second axes on the figure; label it via its title.
cax = fig.axes[1]
cax.set_title("# possible\nedges", pad=10)
gluefig("n_possible_by_block", fig)

In [9]:
# Report the wall-clock time since `t0` was recorded in the first cell, plus
# the completion timestamp.
elapsed = time.time() - t0
delta = datetime.timedelta(seconds=elapsed)
print(f"Script took {delta}")
print(f"Completed at {datetime.datetime.now()}")

Script took 0:00:13.387522
Completed at 2023-03-08 12:47:36.264703
