# Preamble

This notebook produces the multiomics-related figure in the main text, which illustrates an application of our method.

In [None]:
# Name of this notebook; `get_file` below uses it to build its default output directory.
notebook_name = "application_multiomics"

In [None]:
# Value handed to the top-k correlation routines below through their `n_jobs` argument.
n_jobs = 64

# Imports

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline


# # disable parallelization for BLAS and co.
# from corals.threads import set_threads_for_external_libraries
# set_threads_for_external_libraries(n_threads=16)

# general
import re
import collections
import pickle
import warnings 
import joblib
import pathlib

# data
import numpy as np
import pandas as pd
import h5py

# ml / stats
import sklearn
import scipy.stats

# plotting
import matplotlib.pyplot as plt

# init matplotlib defaults
import matplotlib
matplotlib.rcParams['figure.facecolor'] = 'white'

In [None]:
# Run the `rpy2.situation` module as a script; presumably this prints a report of the
# rpy2 / R environment -- TODO confirm it is still needed, nothing visible here uses rpy2.
%run -m rpy2.situation

In [None]:
import sklearn.manifold
import sklearn.impute
import sklearn.pipeline

In [None]:
from matplotlib.collections import LineCollection

In [None]:
import coralsarticle.data.applications.multiomics
from coralsarticle.data.utils import preprocess
from corals.correlation.utils import preprocess_X

In [None]:
import corals.correlation.topk
import corals.correlation.topkdiff

from corals.correlation.topk.default import cor_topk
from corals.correlation.topkdiff.default import cor_topkdiff

In [None]:
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt 
from matplotlib.lines import Line2D

# Data and functions

In [None]:
# Load the pregnancy multiomics data; the cells below filter it by its "timepoint" column
# and select feature-group columns from it.
data_preg = coralsarticle.data.applications.multiomics.load_pregnancy_multiomics_data()

In [None]:
# Per-feature-group metadata (the plotting code reads its "color" and "name_full" entries).
pregnancy_multiomics_subset_info = coralsarticle.data.applications.multiomics.pregnancy_multiomics_subset_info
# Feature-group names in alphabetical order; this order defines the group index used everywhere below.
feature_groups = sorted(pregnancy_multiomics_subset_info.keys())

## Masks and preprocessing

We are:
* throwing out columns with only one unique value, because those cause NAs (the correlation with a constant column is undefined), which the BallTree in the top-k algorithm can't handle
* throwing out duplicate columns, because those do not seem to be selected anyway when the minimum difference between the timepoint correlations is at least 0.75

In [None]:
# data
# Split the measurements by the "timepoint" column (values 3 and 4; the sections further
# down label the corresponding networks "Trimester 3" and "Postpartum") and keep only the
# feature-group columns as plain arrays.
data_tp3 = data_preg[data_preg["timepoint"] == 3][feature_groups].values[:,:]
data_tp4 = data_preg[data_preg["timepoint"] == 4][feature_groups].values[:,:]

In [None]:
# kick out everything that may result in NaNs

# Column masks from `mask_min_nunique(..., 2)`; presumably True where a column has at
# least two distinct values -- TODO confirm against coralsarticle.data.utils.
data_tp3_msk_nunique2 = coralsarticle.data.utils.mask_min_nunique(data_tp3, 2)
data_tp4_msk_nunique2 = coralsarticle.data.utils.mask_min_nunique(data_tp4, 2)
# A column is kept only if it passes at both timepoints ...
data_tp34_msk_nunique2 = data_tp3_msk_nunique2 & data_tp4_msk_nunique2

# ... so both arrays end up with the same set of columns.
data_tp3 = data_tp3[:, data_tp34_msk_nunique2]
data_tp4 = data_tp4[:, data_tp34_msk_nunique2]

In [None]:
# get masks for duplicate values (we'll add them back in later)
# Both timepoints are stacked row-wise before `mask_unique` is applied. The returned
# inverse index maps the de-duplicated columns back to all columns (checked in the next
# cell and used by `expand`).
data_tp34_msk_unique, data_tp34_msk_unique_inverse = coralsarticle.data.utils.mask_unique(np.concatenate([data_tp3, data_tp4], axis=0), return_inverse=True)

# disable this mask for now
# data_tp34_msk_unique[:] = 1

# number of columns selected by the mask (shown as cell output)
data_tp34_msk_unique.sum()

In [None]:
# test inverse
# Sanity check: selecting the unique columns and then indexing with the inverse mapping
# must reproduce the original matrix.
d = data_tp3[:,:]
msk_unique, msk_unique_reverse = coralsarticle.data.utils.mask_unique(d, return_inverse=True)

assert np.allclose(
    d,
    d[:,msk_unique][:,msk_unique_reverse])

# Top-k

## Top-k per timepoint

**Note**: This takes a bit longer than in the benchmarks because we are selecting 10% (for historical reasons) of the top correlations rather than 0.1%.

In [None]:
# Fraction of all feature pairs to retrieve: the calls below use k = n_features**2 * topk_ratio.
topk_ratio = 0.10 
# Forwarded to `cor_topk` as `approximation_factor`; see the corals documentation for its meaning.
approximation_factor = 5

In [None]:
%%time
# Top-k Spearman correlations at timepoint 3, computed on the de-duplicated columns only.
# NOTE(review): `k` is a float here (n**2 * topk_ratio) -- confirm `cor_topk` accepts that.
data_tp3_topk_cor, data_tp3_topk_idx = cor_topk(
    data_tp3[:,data_tp34_msk_unique], 
    k=data_tp34_msk_unique.sum()**2 * topk_ratio, 
    n_jobs=n_jobs, 
    correlation_type="spearman", 
    approximation_factor=approximation_factor,
    symmetrize=True)

In [None]:
# check min and max correlations
# smallest / largest absolute correlation among the retrieved top-k pairs (timepoint 3)
min_cor, max_cor = (
    np.min(np.abs(data_tp3_topk_cor)),
    np.max(np.abs(data_tp3_topk_cor)),
)
print(min_cor, max_cor)

In [None]:
%%time
data_tp4_topk_cor, data_tp4_topk_idx = cor_topk(
    data_tp4[:,data_tp34_msk_unique], 
    k=data_tp34_msk_unique.sum()**2 * topk_ratio, 
    n_jobs=64, 
    correlation_type="spearman", 
    approximation_factor=approximation_factor,
    symmetrize=True)

In [None]:
# check min and max correlations
# smallest / largest absolute correlation among the retrieved top-k pairs (timepoint 4)
min_cor, max_cor = (
    np.min(np.abs(data_tp4_topk_cor)),
    np.max(np.abs(data_tp4_topk_cor)),
)
print(min_cor, max_cor)

In [None]:
%%time
# convert to sparse matrices
data_tp3_topk_cor_sparsematrix = scipy.sparse.csr_matrix((data_tp3_topk_cor, data_tp3_topk_idx), shape=[data_tp34_msk_unique.sum()] * 2)
data_tp4_topk_cor_sparsematrix = scipy.sparse.csr_matrix((data_tp4_topk_cor, data_tp4_topk_idx), shape=[data_tp34_msk_unique.sum()] * 2)

## Find most different

In [None]:
%%time
# NOTE: this overwrites the notebook-level `approximation_factor` and `topk_ratio` set in
# the "Top-k per timepoint" section; re-running those earlier cells afterwards would use
# these values instead.
approximation_factor = 10
topk_ratio = 0.001
# Top-k correlation differences between timepoint 3 and timepoint 4 (Spearman),
# again on the de-duplicated columns only.
data_tp34_topk_diff_values_s, data_tp34_topk_diff_idx_s = cor_topkdiff(
    data_tp3[:,data_tp34_msk_unique], 
    data_tp4[:, data_tp34_msk_unique], 
    k=data_tp34_msk_unique.sum()**2*topk_ratio, 
    approximation_factor=approximation_factor,
    n_jobs=n_jobs,
    correlation_type="spearman", 
    symmetrize=True)

In [None]:
# check min and max correlations
# smallest / largest absolute value among the retrieved top-k differences
min_cor, max_cor = (
    np.min(np.abs(data_tp34_topk_diff_values_s)),
    np.max(np.abs(data_tp34_topk_diff_values_s)),
)
print(min_cor, max_cor)

In [None]:
# Plain aliases (no copy) under the names used by the rest of the notebook.
data_tp34_topk_diff_values, data_tp34_topk_diff_idx = data_tp34_topk_diff_values_s, data_tp34_topk_diff_idx_s 

In [None]:
%%time
# convert to sparse matrix
# square (n_unique x n_unique) CSR matrix of the top-k difference values
data_tp34_topk_diff_sparsematrix = scipy.sparse.csr_matrix((data_tp34_topk_diff_values, data_tp34_topk_diff_idx), shape=(data_tp34_msk_unique.sum(), data_tp34_msk_unique.sum()))

# Load embeddings

In [None]:
# load embeddings

path = pathlib.Path("../_out/applications/application_multiomics_prepare-embeddings")
prefix = "embedding___data_pregnancy___preprocessing_neg_n2___cor-spearman-direct___algorithm_tsne_v2___"

# one entry per feature group, in `feature_groups` order
embeddings, select_embeddings, runtimes = [], [], []

for i_fg, feature_group in enumerate(feature_groups):

    filename = prefix + f"featuregroup_{feature_group}.h5"
    print(filename)

    # each file provides the embedding coordinates, a feature mask and a runtime value
    with h5py.File(path.joinpath(filename), "r") as f:
        embeddings.append(f["embedding"][:])
        select_embeddings.append(f["mask"][:])
        runtimes.append(f["time"][()])

In [None]:
# show the "time" value stored in each embedding file
runtimes

In [None]:
# check embeddings
# One scatter panel per loaded embedding; the panel count follows `embeddings` instead of
# being hard-coded, and `squeeze=False` keeps `axes` two-dimensional even for one panel.
n_panels = len(embeddings)
fig, axes = plt.subplots(1, n_panels, figsize=(n_panels * (6 + 1), 6), squeeze=False)
for i, emb in enumerate(embeddings):
    ax = axes[0, i]
    ax.scatter(emb[:,0], emb[:,1], s=1)
    ax.set(title=feature_groups[i])

In [None]:
# Stack the embeddings of all feature groups row-wise.
embeddings_merged = np.concatenate(embeddings)
# Largest absolute coordinate per axis over all groups together ...
e_max = np.max(np.abs(embeddings_merged), axis=0)
# ... and per group; draw_network / draw_edges divide by these to bring each group's
# coordinates into [-1, 1] before scaling and shifting it.
embeddings_max = [np.max(np.abs(e), axis=0) for e in embeddings]
embeddings_merged.shape

# Visualization

## Functions / Setup

In [None]:
# show the group order; the per-group lists in the next cell are indexed in this order
feature_groups

In [None]:
# (x, y) position at which each feature group's normalized embedding is centred in the
# network figure; entries follow the order of `feature_groups`.
group_offsets = [
    (3,0),          # cellfree rna
    (-1.5,-2.5),    # immune system
    (-3,0),         # metabolomics
    (0,3),          # microbiome
    (2.5,2.5),      # plasma_luminex
    (-2.5,2.5),     # plasma_somalogic
    (1.5,-2.5)]     # serum_luminex

# Factor applied to each group's normalized embedding (0.5 draws the group at half size);
# same order as `group_offsets`.
group_scaling = [
    1, # cellfree rna
    1, # immune system
    1, # metabolomics
    1, # microbiome
    0.5, # plasma_luminex
    1, # plasma_somalogic
    0.5  # serum_luminex
]

In [None]:
def init_kwargs(kwargs, **defaults):
    """Merge caller-supplied keyword arguments over a set of defaults.

    Parameters
    ----------
    kwargs : dict or None
        Overrides; ``None`` means "no overrides".
    **defaults
        Fallback values used for every key that ``kwargs`` does not provide.

    Returns
    -------
    dict
        ``defaults`` updated with ``kwargs`` (entries in ``kwargs`` win). The
        ``kwargs`` dict itself is never modified.
    """
    # `**defaults` always binds to a dict (possibly empty), so it needs no None check.
    if kwargs is None:
        return defaults
    return {**defaults, **kwargs}

In [None]:
def remove_intraomic_edges(msk):
    """Clear, in place, all entries of ``msk`` that connect features of the same group.

    The feature groups are assumed to be laid out consecutively along both axes of
    ``msk``, with group sizes ``len(s)`` for ``s`` in ``select_embeddings``. For each
    group the corresponding diagonal block is set to 0. The name of every group is
    printed as it is processed.

    Parameters
    ----------
    msk : 2-D array
        Square edge mask / adjacency; modified in place.
    """
    offset = 0
    for fg, s in zip(feature_groups, select_embeddings):
        print(fg)

        # The diagonal block [offset, stop) x [offset, stop) holds the intra-group entries;
        # addressing it with slices avoids building full-size helper masks for every group.
        stop = offset + len(s)
        msk[offset:stop, offset:stop] = 0

        offset = stop

In [None]:
def draw_network(edges=None, draw_legend=False, verbose=0, nodes_style=None, edges_style=None, ax=None):
    """Draw the multi-omics network: one scatter of nodes per feature group plus edges.

    Each group's embedding is divided by its per-axis maximum (``embeddings_max``),
    multiplied by ``group_scaling`` and shifted by ``group_offsets`` before plotting.

    Parameters
    ----------
    edges : matrix or list of matrices, optional
        Edge set(s) passed one by one to ``draw_edges``; later sets get a higher zorder.
    draw_legend : bool
        If true, add a legend and use a larger default marker size (50 instead of 1).
    verbose : int
        Values above 2 print progress messages.
    nodes_style : dict or list of dict, optional
        Keyword overrides for ``ax.scatter``; a single dict is applied to every group,
        a list supplies one dict per feature group.
    edges_style : dict or list of dict, optional
        Keyword overrides for ``draw_edges``; a single dict (or ``None``) is applied to
        every edge set, a list supplies one dict per edge set. If a list is shorter
        than ``edges``, the surplus edge sets are not drawn.
    ax : matplotlib Axes, optional
        Target axes; defaults to the current axes.
    """
    if ax is None:
        ax = plt.gca()

    # draw edges
    if edges is not None:

        if not isinstance(edges, list):
            edges = [edges]

        # Expand a missing or single style to one entry per edge set. A one-element
        # default list would make zip() below silently drop all but the first set.
        if edges_style is None:
            edges_style = {}
        if not isinstance(edges_style, list):
            edges_style = [edges_style] * len(edges)

        for i, (e, s) in enumerate(zip(edges, edges_style)):
            # far-negative zorder keeps the edges underneath the nodes (zorder -100)
            draw_edges(e, ax=ax, zorder=-1000 + i, **s)

    if verbose > 2:
        print("plot nodes")

    # one style dict per feature group
    if nodes_style is None:
        nodes_style = {}
    if not isinstance(nodes_style, list):
        nodes_style = [nodes_style] * len(feature_groups)

    for i_fg, feature_group in enumerate(feature_groups):
        color = pregnancy_multiomics_subset_info[feature_group]["color"]

        # normalize the group's embedding, then scale and move it to its slot in the figure
        e = embeddings[i_fg].copy()
        e[:,0] = e[:,0] / embeddings_max[i_fg][0] * group_scaling[i_fg] + group_offsets[i_fg][0]
        e[:,1] = e[:,1] / embeddings_max[i_fg][1] * group_scaling[i_fg] + group_offsets[i_fg][1]

        ax.scatter(
            e[:,0], e[:,1],
            **init_kwargs(
                nodes_style[i_fg],
                s=50 if draw_legend else 1,
                label=feature_group,
                zorder=-100,
                c=color))

    if draw_legend:
        ax.legend()

    ax.axis("off")
    if verbose > 2:
        print("finalizing figure")

In [None]:
def draw_edges(edges, ax=None, verbose=0, **edges_kwargs):
    """Draw the inter-group edges of ``edges`` as straight lines between embedded nodes.

    Only pairs of *different* feature groups are considered, and each unordered pair of
    groups is visited once (row group index > column group index). Within a pair, a
    line is drawn for every non-zero entry of the corresponding sub-matrix.

    Parameters
    ----------
    edges : scipy sparse matrix
        Square matrix over all features of all groups, laid out group by group with
        ``len(select_embeddings[i])`` features per group.
    ax : matplotlib Axes, optional
        Target axes; defaults to the current axes.
    verbose : int
        > 0 prints the row group, > 1 also the column group, > 2 also sizes / progress.
    **edges_kwargs
        Overrides for the ``LineCollection`` defaults (grey, alpha 0.1, linewidth 0.5).
    """
    if ax is None:
        ax = plt.gca()

    # start/stop position of every feature group along the axes of `edges`
    msk_offsets = np.insert(np.cumsum([len(e) for e in select_embeddings]), 0, 0)

    def placed_embedding(i_fg):
        # normalize the group's embedding, then scale and move it to its slot in the figure
        e = embeddings[i_fg].copy()
        e[:,0] = e[:,0] / embeddings_max[i_fg][0] * group_scaling[i_fg] + group_offsets[i_fg][0]
        e[:,1] = e[:,1] / embeddings_max[i_fg][1] * group_scaling[i_fg] + group_offsets[i_fg][1]
        return e

    def group_mask(i_fg, size):
        # boolean mask over one axis of `edges` selecting the embedded features of a group
        msk = np.zeros(size, dtype=bool)
        msk[msk_offsets[i_fg]:msk_offsets[i_fg + 1]] = select_embeddings[i_fg]
        return msk

    # the line style does not depend on the group pair
    line_kwargs = init_kwargs(
        edges_kwargs,
        color="grey",
        alpha=0.1,
        linewidth=0.5)

    for i_rg, row_group in enumerate(feature_groups):

        if verbose > 0:
            print(row_group)

        e_rg = placed_embedding(i_rg)

        # the row mask is sized by the row axis (the original used the column axis,
        # which is only equivalent for square matrices)
        edges_rows = edges[group_mask(i_rg, edges.shape[0]), :]

        for i_cg in range(i_rg):
            col_group = feature_groups[i_cg]

            if verbose > 1:
                print("  *", col_group)

            e_cg = placed_embedding(i_cg)

            if verbose > 2:
                print(f"    * row vars: {e_rg.shape[0]:10d} / {embeddings[i_rg].shape[0]:10d}")
                print(f"    * col vars: {e_cg.shape[0]:10d} / {embeddings[i_cg].shape[0]:10d}")

            edges_rows_cols = edges_rows[:, group_mask(i_cg, edges.shape[1])]

            # Look up the end points of the existing edges only. Building the coordinates
            # of *all* row/column pairs first and masking afterwards needs
            # n_rows * n_cols * 4 floats, which is prohibitive for large groups.
            # np.nonzero walks the dense boolean matrix in row-major order, i.e. the
            # segments come out in the same order as with the flattened mask.
            r, c = np.nonzero(edges_rows_cols.astype(bool).toarray())

            # shape (n_edges, 2, 2): [edge, end point (row node / col node), coordinate]
            coo = np.stack((e_rg[r], e_cg[c]), axis=1)

            if verbose > 2:
                print("    * draw")

            ax.add_collection(LineCollection(coo, **line_kwargs))

        if verbose > 0:
            print()

In [None]:
def expand(m):
    """Blow a matrix over the de-duplicated features up to the original feature set.

    Reverses, in order, the two column reductions applied during preprocessing:
    first the duplicate removal (via ``data_tp34_msk_unique_inverse``), then the
    removal of columns failing the ``nunique`` mask (via ``data_tp34_msk_nunique2``).
    Features dropped by the latter get all-zero rows and columns.
    """
    # 1) undo the de-duplication: every original column points to its unique representative
    restored = m[:,data_tp34_msk_unique_inverse][data_tp34_msk_unique_inverse,:]

    # 2) append one all-zero column and one all-zero row; their index serves as the
    #    target for every feature that was removed by the nunique mask
    n_rows, n_cols = restored.shape
    padded = scipy.sparse.csr_matrix(
        (restored.data, restored.indices, restored.indptr),
        shape=(n_rows, n_cols + 1),
        copy=True)
    zero_row = scipy.sparse.csr_matrix((1 , padded.shape[1]))
    padded = scipy.sparse.vstack([padded, zero_row])

    # 3) lookup table: kept features -> their position, removed features -> the zero slot
    n_kept = data_tp34_msk_nunique2.sum()
    n_total = np.concatenate(select_embeddings).size
    lookup = np.repeat(n_kept, n_total)
    lookup[data_tp34_msk_nunique2] = np.arange(n_kept, dtype=int)

    # 4) gather rows and columns through the lookup table
    return padded[lookup,:][:,lookup]

In [None]:
def get_file(file_name, out_dir=f"_out/applications/{notebook_name}"):
    """Return the path of ``file_name`` inside ``out_dir``, creating the directory if needed.

    The default ``out_dir`` is fixed when this function is defined (it embeds the value
    ``notebook_name`` had at that time).
    NOTE(review): the default is relative to the working directory without the ``../``
    prefix used by the other output paths in this notebook -- confirm this is intended.
    """
    target_dir = pathlib.Path(out_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    return target_dir.joinpath(file_name)

In [None]:
# scatter overrides for the nodes (smaller / larger markers, no marker outline)
nodes_style1 = {"linewidths": 0, "s": 0.5}

nodes_style2 = {"linewidths": 0, "s": 1}

# line overrides for the edges
# monochrome plots: very thin, almost transparent grey lines
edges_style_single = {"color": "grey", "alpha": 0.05, "linewidth": 0.01}

# two-layer plots: light background layer ...
edges_style4_under = {"color": "silver", "alpha": 0.5, "linewidth": 0.01}

# ... and a dark-grey layer drawn on top of it
edges_style4_over = {"color": [0.2, 0.2, 0.2], "alpha": 0.05, "linewidth": 0.02}

## Prepare

In [None]:
# # look at embeddings
# fig, ax = plt.subplots(1,1, figsize=(4,3), dpi=150)
# draw_network(nodes_style=[nodes_style1] * 1 + [nodes_style2] * 6, ax=ax)

### Trimester 3

In [None]:
%%time
# timepoint-3 top-k correlations, expanded back to the full feature set
adj1 = expand(data_tp3_topk_cor_sparsematrix)

In [None]:
%%time
# top-k difference values, kept only where the timepoint-3 top-k matrix has an entry
adj1_diff34 = expand(data_tp34_topk_diff_sparsematrix).multiply(adj1.astype(bool))

In [None]:
%%time
# separate cellfree from other omics for visualization purposes

msk_cellfree = data_preg[feature_groups].columns.get_level_values(0) == "cellfree_rna"

diag = scipy.sparse.spdiags((~msk_cellfree).astype(int), 0, *adj1.shape)
adj1_remaining = (adj1.T * diag).T * diag
adj1_remaining.eliminate_zeros()

adj1_cellfree = adj1 - adj1_remaining
adj1_remaining.eliminate_zeros()

In [None]:
%%time
# Timepoint-3 network with all top-k edges in a single style; cellfree edges are drawn
# first (lowest zorder), the remaining edges on top of them.
fig, ax = plt.subplots(1,1, figsize=(4,3), dpi=300)
draw_network(
    nodes_style=[nodes_style1] * 1 + [nodes_style2] * 6,
    edges=[adj1_cellfree, adj1_remaining],
    edges_style=[edges_style_single, edges_style_single],
    ax=ax)
fig.savefig(get_file("network_pregnancy_t3_topk_monochrome.png"), bbox_inches='tight')

In [None]:
%%time
# Timepoint-3 network in two layers: all top-k edges as a light background, and on top
# of them those edges that are also among the top-k differences between timepoints.
fig, ax = plt.subplots(1,1, figsize=(4,3), dpi=300)
draw_network(
    nodes_style=[nodes_style1] * 1 + [nodes_style2] * 6,
    edges=[adj1, adj1_diff34],
#     edges_style=[edges_style4_under, edges_style_diff],
    edges_style=[edges_style4_under, edges_style4_over],
    ax=ax)
# saved twice: to the notebook's output directory and to the shared figures directory
fig.savefig(get_file("network_pregnancy_t3_topk_diff.png"), bbox_inches='tight')
fig.savefig(get_file("multiomics_pregnancy_t3.png", out_dir="../_out/figures"), bbox_inches='tight')

### Postpartum

In [None]:
%%time
# timepoint-4 top-k correlations, expanded back to the full feature set
adj2 = expand(data_tp4_topk_cor_sparsematrix)

In [None]:
%%time
# top-k difference values, kept only where the timepoint-4 top-k matrix has an entry
adj2_diff34 = scipy.sparse.csr_matrix.multiply(
    expand(data_tp34_topk_diff_sparsematrix), adj2.astype(bool))
adj2_diff34.eliminate_zeros()

In [None]:
%%time
diag = scipy.sparse.spdiags((~msk_cellfree).astype(int), 0, *adj2.shape)
adj2_remaining = (adj2.T * diag).T * diag
adj2_remaining.eliminate_zeros()

adj2_cellfree = adj2 - adj2_remaining
adj2_remaining.eliminate_zeros()

In [None]:
%%time
# Timepoint-4 network with all top-k edges in a single style; cellfree edges are drawn
# first (lowest zorder), the remaining edges on top of them.
fig, ax = plt.subplots(1,1, figsize=(4,3), dpi=300)
draw_network(
    nodes_style=[nodes_style1] * 1 + [nodes_style2] * 6,
    edges=[adj2_cellfree, adj2_remaining],
    edges_style=[edges_style_single, edges_style_single],
    ax=ax)
fig.savefig(get_file("network_pregnancy_t4_topk_monochrome.png"), bbox_inches='tight')

In [None]:
%%time
# Timepoint-4 network in two layers: all top-k edges as a light background, and on top
# of them those edges that are also among the top-k differences between timepoints.
fig, ax = plt.subplots(1,1, figsize=(4,3), dpi=300)
draw_network(
    nodes_style=[nodes_style1] * 1 + [nodes_style2] * 6,
    edges=[adj2, adj2_diff34],
#     edges_style=[edges_style4_under, edges_style_diff],
    edges_style=[edges_style4_under, edges_style4_over],
    ax=ax)
fig.savefig(get_file("network_pregnancy_t4_topk_diff.png"), bbox_inches='tight')

### Legend

In [None]:
# Stand-alone legend: one coloured marker per feature group, labelled with its full name.
patchList = []
for fg in feature_groups:
    data_key = Line2D(
        [0], [0], marker="o",  color='w', 
        markerfacecolor=pregnancy_multiomics_subset_info[fg]["color"], 
        label=pregnancy_multiomics_subset_info[fg]["name_full"], 
        markersize=10)
    patchList.append(data_key)

fig, axes = plt.subplots(1,1, dpi=150, figsize=(2,2))
ax = axes
ax.legend(handles=patchList, loc="upper left", frameon=False)
ax.axis("off")
# go through get_file (same target path) so that the figures directory is created if missing
fig.savefig(get_file("multiomics_legend.pdf", out_dir="../_out/figures"), bbox_inches='tight')

## Final figure

In [None]:
%%time
# Final figure: the two-layer timepoint-3 network plus two legends, saved as PDF.
fig, ax = plt.subplots(1,1, figsize=(4,3), dpi=300)
draw_network(
    nodes_style=[nodes_style1] * 1 + [nodes_style2] * 6,
    edges=[adj1, adj1_diff34],
    edges_style=[edges_style4_under, edges_style4_over],
    ax=ax)
# Rasterize every artist with zorder below -1 in the vector output; draw_network puts
# the nodes (zorder -100) and the edges (zorder -1000 + i) below that threshold.
ax.set_rasterization_zorder(-1)

# first legend (feature groups: one coloured marker each)
patchList = []
for fg in feature_groups:
        data_key = Line2D(
            [0], [0], marker="o",  color=(1,1,1,0), lw=1, markeredgewidth=.5,
            markerfacecolor=pregnancy_multiomics_subset_info[fg]["color"], 
            label=pregnancy_multiomics_subset_info[fg]["name_full"], 
            markersize=5)
        patchList.append(data_key)

# second legend (edge types: one line sample each)
patchList2 = []
patchList2.append(Line2D(
    [0], [0], markerfacecolor=(1,1,1,0), 
    color=(0.5,0.5,0.5,1), lw=0.5, label="Correlations at 3rd trimester"))
patchList2.append(Line2D(
    [0], [0], markerfacecolor=(1,1,1,0), 
    color=(0,0,0,1), lw=1, label="Modified correlations after birth"))

# An axes holds a single legend: the second ax.legend() call replaces the first one,
# so the first legend is kept in a variable and re-attached with add_artist().
legend1 = ax.legend(handles=patchList2, loc=(0.75,0.87), frameon=False, fontsize=5)
legend1.set_zorder(102)

ax.legend(handles=patchList, loc=(0.75,0.03), frameon=False, fontsize=5).set_zorder(102)
ax.add_artist(legend1)

fig.savefig(
    get_file("multiomics_pregnancy_t3_with-legend.pdf", out_dir="../_out/figures"), 
    bbox_inches='tight')

# Edges of interest

## Functions

In [None]:
def _edge_table(names1, glob_idx1, names2, glob_idx2, values, c1):
    """Build the edge table for one pair of feature groups.

    ``values`` and ``c1`` are dense (n1 x n2) arrays; one output row is produced for
    every non-zero entry of ``values``, in row-major order. ``names*`` hold the
    (modality, "<local index>_<variable name>") column labels of the two groups and
    ``glob_idx*`` their positions among all columns.
    """
    rows, cols = np.nonzero(values)
    src = [names1[r] for r in rows]
    dst = [names2[c] for c in cols]
    diff = values[rows, cols]
    first = c1[rows, cols]

    return pd.DataFrame(collections.OrderedDict([
        ("modality1",     [n[0] for n in src]),
        ("var_glob_idx1", glob_idx1[rows]),
        ("var_loc_idx1",  [int(re.sub("_.*$", "", n[1])) for n in src]),
        ("var_name1",     [re.sub("^.*?_", "", n[1]) for n in src]),
        ("modality2",     [n[0] for n in dst]),
        ("var_glob_idx2", glob_idx2[cols]),
        ("var_loc_idx2",  [int(re.sub("_.*$", "", n[1])) for n in dst]),
        ("var_name2",     [re.sub("^.*?_", "", n[1]) for n in dst]),
        ("cor1",          first),
        # the second correlation is reconstructed from the first one and the stored value
        ("cor2",          first - diff),
        ("cor_absdiff",   np.abs(diff)),
    ]))


def _format_edge_sheet(writer, sheet_name, n_rows):
    """Colour the two modality columns by feature group and set the column widths."""
    workbook = writer.book
    worksheet = writer.sheets[sheet_name]

    for k, v in pregnancy_multiomics_subset_info.items():
        f = workbook.add_format()
        f.set_bg_color(v["color"])
        # column A holds modality1, column E holds modality2
        for cell_range in (f'A1:A{n_rows + 1}', f'E1:E{n_rows + 1}'):
            worksheet.conditional_format(
                cell_range,
                {
                    'type': 'cell',
                    'criteria': "equal to",
                    "value": f'"{k}"',
                    'format': f
                }
            )

    worksheet.set_column(0, 2, 20)
    worksheet.set_column(3, 3, 40)
    worksheet.set_column(4, 6, 20)
    worksheet.set_column(7, 7, 40)
    worksheet.set_column(8, 10, 20)


def save_edges(name, cor1, cor2, adj):
    """Export the inter-omics edges in ``adj`` to the Excel file ``edges___{name}.xlsx``.

    For every feature group one sheet is written that lists all edges (non-zero
    entries of ``adj``) between this group and every group that follows it in
    ``feature_groups``. A final ``_stats`` sheet holds the number of edges per group
    pair; the same counts are printed while the file is built.

    Parameters
    ----------
    name : str
        Suffix of the output file name; the file is placed via ``get_file``.
    cor1 : sparse matrix
        Correlations reported in the ``cor1`` column.
    cor2 : sparse matrix
        Not read: the ``cor2`` column is computed as ``cor1 - adj`` instead. The
        parameter is kept so that existing calls stay valid.
    adj : sparse matrix
        Values defining the edges; also reported as ``cor_absdiff`` (absolute value).
    """
    # Loop invariants: selecting the columns of `data_preg` is comparatively expensive,
    # so do it once instead of once per group pair.
    columns = data_preg[feature_groups].columns
    modality = columns.get_level_values(0)
    selected = np.concatenate(select_embeddings)
    glob_idx = np.arange(len(columns))
    n_groups = len(feature_groups)

    # one feature mask per group: columns of that group which are part of the embeddings
    group_masks = [(modality == fg) & selected for fg in feature_groups]

    with pd.ExcelWriter(get_file(f'edges___{name}.xlsx'), engine='xlsxwriter') as writer:

        stats = []

        for fg1_i, fg1 in enumerate(feature_groups):

            dfs = []
            partner_groups = range(fg1_i + 1, n_groups)

            if partner_groups:
                # slice the rows of this group once for all partner groups
                msk_fg1 = group_masks[fg1_i]
                varname1 = columns[msk_fg1]
                varidx1 = glob_idx[msk_fg1]
                adj_rows = adj[msk_fg1,:]
                cor1_rows = cor1[msk_fg1,:]

            for fg2_i in partner_groups:
                fg2 = feature_groups[fg2_i]
                msk_fg2 = group_masks[fg2_i]

                # dense blocks for this pair of groups
                values = adj_rows[:,msk_fg2].toarray()
                c1 = cor1_rows[:,msk_fg2].toarray()

                # log
                n_edges = int(np.count_nonzero(values))
                stats.append((fg1, fg2, n_edges))
                print(f"{fg1:20s} {fg2:20s}: {n_edges:10d}")

                dfs.append(_edge_table(
                    columns[msk_fg1], varidx1,
                    columns[msk_fg2], glob_idx[msk_fg2],
                    values, c1))
            print()

            # excel
            if dfs:
                sheet_name = f"{fg1} ({len(dfs)})"
                table = pd.concat(dfs)
                table.to_excel(writer, sheet_name=sheet_name, index=False)
                _format_edge_sheet(writer, sheet_name, table.shape[0])

        # stats sheet
        sheet_name = "_stats"
        df_stats = pd.DataFrame.from_records(
            stats,
            columns=["modality1", "modality2", "number of edges"])
        df_stats.to_excel(writer, sheet_name=sheet_name, index=False)
        worksheet = writer.sheets[sheet_name]
        worksheet.set_column(0, 3, 20)
        

## Prepare

Define cutoffs to reduce the number of correlations to investigate.

In [None]:
%%time
# keep only timepoint-3 correlations with an absolute value above 0.8
adj1_gt = adj1.copy()
adj1_gt.data[np.abs(adj1_gt.data) <= 0.8] = 0
adj1_gt.eliminate_zeros()

In [None]:
%%time
# top-k difference values restricted to the strong (|cor| > 0.8) timepoint-3 edges
adj1_diff34_gt = scipy.sparse.csr_matrix.multiply(
    expand(data_tp34_topk_diff_sparsematrix), adj1_gt.astype(bool))
adj1_diff34_gt.eliminate_zeros()

In [None]:
%%time
# keep only timepoint-4 correlations with an absolute value above 0.8
adj2_gt = adj2.copy()
adj2_gt.data[np.abs(adj2_gt.data) <= 0.8] = 0
adj2_gt.eliminate_zeros()

In [None]:
%%time
# top-k difference values restricted to the strong (|cor| > 0.8) timepoint-4 edges
adj2_diff34_gt = scipy.sparse.csr_matrix.multiply(
    expand(data_tp34_topk_diff_sparsematrix), adj2_gt.astype(bool))
adj2_diff34_gt.eliminate_zeros()

## Output

In [None]:
# strong timepoint-3 edges that are also among the top-k differences
save_edges("tp3_tp4_gt0.8", adj1, adj2, adj1_diff34_gt)

In [None]:
# Strong timepoint-4 edges that are also among the top-k differences. The difference is
# negated because timepoint 4 is passed as `cor1` here; presumably the stored difference
# is (timepoint 3 - timepoint 4) -- TODO confirm against cor_topkdiff.
save_edges("tp4_tp3_gt0.8", adj2, adj1, -adj2_diff34_gt)

In [None]:
# all timepoint-3 top-k edges that are also among the top-k differences (no 0.8 cutoff)
save_edges("tp3_tp4_full", adj1, adj2, adj1_diff34)

In [None]:
# All timepoint-4 top-k edges that are also among the top-k differences (no 0.8 cutoff);
# the difference is negated because timepoint 4 is passed as `cor1` here.
save_edges("tp4_tp3_full", adj2, adj1, -adj2_diff34)