# Preamble

This notebook produces the multiomics figure in the main text, which illustrates applications of our method.

**Note:** This is a condensed version of `03.01_application_multiomics.ipynb` that uses cached data provided with the manuscript for convenience. To use the provided data, see `README.md`. To produce the cached data yourself, run both `02_application_multiomics_prepare-embeddings.ipynb` and `03.01_application_multiomics.ipynb`.

In [None]:
# Identifier of this notebook; get_file() uses it to build its default output
# directory ../_out/<notebook_name>.
notebook_name = "application___multiomics"

# Imports

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline


# # disable parallelization for BLAS and co.
# from corals.threads import set_threads_for_external_libraries
# set_threads_for_external_libraries(n_threads=16)

# general
import re
import collections
import pickle
import warnings 
import joblib
import pathlib

# data
import numpy as np
import pandas as pd
import h5py

# ml / stats
import sklearn
import scipy.stats

# plotting
import matplotlib.pyplot as plt

# init matplotlib defaults
import matplotlib
matplotlib.rcParams['figure.facecolor'] = 'white'

In [None]:
%run -m rpy2.situation

In [None]:
import sklearn.manifold
import sklearn.impute
import sklearn.pipeline

In [None]:
from matplotlib.collections import LineCollection

In [None]:
import coralsarticle.data.applications.multiomics
from coralsarticle.data.utils import preprocess
from corals.correlation.utils import preprocess_X

In [None]:
import corals.correlation.topk
import corals.correlation.topkdiff

from corals.correlation.topk.default import cor_topk
from corals.correlation.topkdiff.default import cor_topkdiff

In [None]:
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt 
from matplotlib.lines import Line2D

# Load data

In [None]:
def get_file(file_name, out_dir=f"../_out/{notebook_name}"):
    """Return ``out_dir / file_name``, creating ``out_dir`` (and parents) if missing.

    The default ``out_dir`` is this notebook's output folder
    ``../_out/<notebook_name>``; it is fixed when the function is defined.
    """
    target_dir = pathlib.Path(out_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    return target_dir.joinpath(file_name)

In [None]:
# Per-group metadata (e.g. "color", "name_full") keyed by feature-group name.
pregnancy_multiomics_subset_info = coralsarticle.data.applications.multiomics.pregnancy_multiomics_subset_info
# Group names in sorted order; every per-group list in this notebook is indexed in this order.
feature_groups = sorted(pregnancy_multiomics_subset_info.keys())

In [None]:
# Cached edge matrix drawn as the lower edge layer of the figure.
# NOTE: unpickling can execute arbitrary code -- only load files you trust.
adj1_path = get_file("plot_adj1.pickle")
with adj1_path.open("rb") as handle:
    adj1 = pickle.load(handle)

In [None]:
# Cached edge matrix drawn as the upper edge layer of the figure.
# NOTE: unpickling can execute arbitrary code -- only load files you trust.
adj1_diff34_path = get_file("plot_adj1_diff34.pickle")
with adj1_diff34_path.open("rb") as handle:
    adj1_diff34 = pickle.load(handle)

In [None]:
# Load the cached per-group embeddings: one HDF5 file per feature group,
# read in ``feature_groups`` order so list index i matches feature_groups[i].

path = pathlib.Path("../_out/application___multiomics___prepare-embeddings")
prefix = "embedding___data_pregnancy___preprocessing_neg_n2___cor-spearman-direct___algorithm_tsne_v2___"

embeddings = []         # dataset "embedding": node coordinates of the group
select_embeddings = []  # dataset "mask": which of the group's features have coordinates
runtimes = []           # dataset "time": stored scalar runtime

for feature_group in feature_groups:
    filename = f"{prefix}featuregroup_{feature_group}.h5"
    print(filename)

    with h5py.File(path / filename, "r") as f:
        embeddings.append(f["embedding"][:])
        select_embeddings.append(f["mask"][:])
        runtimes.append(f["time"][()])

In [None]:
# All groups' embeddings stacked row-wise.
embeddings_merged = np.concatenate(embeddings, axis=0)
# Largest absolute coordinate per axis across all groups combined.
e_max = np.abs(embeddings_merged).max(axis=0)
# Largest absolute coordinate per axis for each group separately; the drawing
# functions divide by these to normalize each group's layout.
embeddings_max = [np.abs(group_embedding).max(axis=0) for group_embedding in embeddings]
embeddings_merged.shape

# Functions and variables

In [None]:
# Per-group layout parameters, index-aligned with ``feature_groups`` (sorted
# order): entry i is applied to feature_groups[i] by draw_edges/draw_network.
# Each group's normalized embedding is multiplied by its scaling factor and
# then shifted by its (x, y) offset.
group_offsets = [
    (3, 0),        # cell-free RNA
    (-1.5, -2.5),  # immune system
    (-3, 0),       # metabolomics
    (0, 3),        # microbiome
    (2.5, 2.5),    # plasma_luminex
    (-2.5, 2.5),   # plasma_somalogic
    (1.5, -2.5),   # serum_luminex
]

group_scaling = [
    1,    # cell-free RNA
    1,    # immune system
    1,    # metabolomics
    1,    # microbiome
    0.5,  # plasma_luminex (half size)
    1,    # plasma_somalogic
    0.5,  # serum_luminex (half size)
]

In [None]:
# Marker styles for the node scatter: borderless points of two sizes.
nodes_style1 = {"linewidths": 0, "s": 0.5}
nodes_style2 = {"linewidths": 0, "s": 1}

# Line styles forwarded to the edge LineCollections.
edges_style_single = {"color": "grey", "alpha": 0.05, "linewidth": 0.01}

# Used for the first (lower) edge layer in the figure.
edges_style4_under = {"color": "silver", "alpha": 0.5, "linewidth": 0.01}

# Used for the second (upper) edge layer in the figure; dark grey RGB.
edges_style4_over = {"color": [0.2, 0.2, 0.2], "alpha": 0.05, "linewidth": 0.02}

In [None]:
def init_kwargs(kwargs, **defaults):
    """Merge caller-supplied keyword arguments over a set of defaults.

    Parameters
    ----------
    kwargs : dict or None
        Caller overrides; ``None`` means "no overrides".
    **defaults
        Default keyword arguments.

    Returns
    -------
    dict
        A new dict with ``defaults`` updated by ``kwargs`` (keys in ``kwargs``
        win). Neither input mapping is mutated.
    """
    # ``**defaults`` always binds a fresh (possibly empty) dict, never None,
    # so no None-check on it is needed.
    if kwargs is None:
        return defaults
    return {**defaults, **kwargs}

In [None]:
def draw_edges(edges, ax=None, verbose=0, **edges_kwargs):
    """Draw the between-group edges of ``edges`` as line segments.

    Each feature group's node coordinates come from the notebook globals
    ``embeddings`` (normalized by ``embeddings_max``, scaled by
    ``group_scaling`` and shifted by ``group_offsets``). Only group pairs with
    column group index < row group index are visited, so edges within a group
    and the upper group triangle are never drawn.

    Parameters
    ----------
    edges : 2-D dense array, ``np.matrix`` or scipy sparse matrix
        Adjacency over all features of all groups, concatenated in
        ``feature_groups`` order (group sizes given by ``select_embeddings``).
        Every non-zero entry is drawn as one segment.
    ax : matplotlib Axes, optional
        Target axes; defaults to ``plt.gca()``.
    verbose : int
        >0 prints each row group, >1 also each column group, >2 also sizes
        and draw steps.
    **edges_kwargs
        Forwarded to ``LineCollection``; override the defaults
        ``color="grey"``, ``alpha=0.1``, ``linewidth=0.5``.

    Raises
    ------
    ValueError
        If a selected edge block does not match the number of embedded nodes
        of its row/column groups.
    """
    if ax is None:
        ax = plt.gca()

    # Start offset of each group's features along the axes of ``edges``.
    msk_offsets = np.insert(np.cumsum([len(s) for s in select_embeddings]), 0, 0)

    def _group_layout(i_group):
        # Normalize one group's embedding per axis to its max-abs extent, then
        # scale and shift it into the group's slot of the combined figure.
        coords = embeddings[i_group].copy()
        for axis in (0, 1):
            coords[:, axis] = (
                coords[:, axis] / embeddings_max[i_group][axis] * group_scaling[i_group]
                + group_offsets[i_group][axis]
            )
        return coords

    def _group_mask(i_group, size):
        # Boolean mask of length ``size`` selecting the group's features that
        # have an embedding (its ``select_embeddings`` entries).
        mask = np.zeros(size, dtype=bool)
        mask[msk_offsets[i_group]:msk_offsets[i_group + 1]] = select_embeddings[i_group]
        return mask

    # Layouts and column masks do not depend on the row group: compute them once.
    layouts = [_group_layout(i) for i in range(len(feature_groups))]
    col_masks = [_group_mask(i, edges.shape[1]) for i in range(len(feature_groups))]

    for i_rg, row_group in enumerate(feature_groups):

        if verbose > 0:
            print(row_group)

        e_rg = layouts[i_rg]
        # Row selection spans axis 0 (not axis 1), so non-square inputs work too.
        edges_rows = edges[_group_mask(i_rg, edges.shape[0]), :]

        for i_cg, col_group in enumerate(feature_groups[:i_rg]):

            if verbose > 1:
                print("  *", col_group)

            e_cg = layouts[i_cg]

            if verbose > 2:
                print(f"    * row vars: {e_rg.shape[0]:10d} / {embeddings[i_rg].shape[0]:10d}")
                print(f"    * col vars: {e_cg.shape[0]:10d} / {embeddings[i_cg].shape[0]:10d}")

            block = edges_rows[:, col_masks[i_cg]]
            # Sparse matrices densify via toarray(); ``.A`` is gone in recent
            # SciPy and never existed on plain ndarrays.
            dense = block.toarray() if hasattr(block, "toarray") else np.asarray(block)
            if dense.shape != (e_rg.shape[0], e_cg.shape[0]):
                raise ValueError(
                    f"edge block {dense.shape} does not match embedded nodes "
                    f"({e_rg.shape[0]}, {e_cg.shape[0]}) for {row_group} x {col_group}"
                )

            # One segment [row node, col node] per non-zero entry. np.nonzero
            # yields indices in row-major order, i.e. the same segment order as
            # masking the full row x col product, without materializing it.
            i_rows, i_cols = np.nonzero(dense)
            coo = np.stack((e_rg[i_rows], e_cg[i_cols]), axis=1)

            if verbose > 2:
                print("    * draw")

            line_collection = LineCollection(
                coo,
                **init_kwargs(
                    edges_kwargs,
                    color="grey",
                    alpha=0.1,
                    linewidth=0.5)
                )

            ax.add_collection(line_collection)

        if verbose > 0:
            print()

In [None]:
def draw_network(edges=None, draw_legend=False, verbose=0, nodes_style=None, edges_style=None, ax=None):
    """Draw the multi-group network: edge layers first, then each group's nodes.

    Parameters
    ----------
    edges : matrix or list of matrices, optional
        Edge layers, each passed to ``draw_edges``; layer i gets zorder
        ``-1000 + i``, so later layers are drawn on top of earlier ones.
    draw_legend : bool
        Add an axes legend labeled by feature group; also raises the default
        marker size from 1 to 50.
    verbose : int
        Forwarded to ``draw_edges``; values >2 also print progress here.
    nodes_style : dict or list of dicts, optional
        ``ax.scatter`` keyword overrides, shared or one per feature group.
    edges_style : dict or list of dicts, optional
        ``draw_edges`` keyword overrides, shared or one per edge layer.
    ax : matplotlib Axes, optional
        Target axes; defaults to ``plt.gca()``.
    """
    if ax is None:
        ax = plt.gca()

    # draw edges
    if edges is not None:

        if not isinstance(edges, list):
            edges = [edges]

        # One style per layer: a shorter list would make zip() silently drop layers.
        if edges_style is None:
            edges_style = [{}] * len(edges)
        elif not isinstance(edges_style, list):
            edges_style = [edges_style] * len(edges)

        for i, (e, s) in enumerate(zip(edges, edges_style)):
            draw_edges(e, ax=ax, verbose=verbose, zorder=-1000 + i, **s)

    if verbose > 2:
        print("plot nodes")

    if nodes_style is None:
        nodes_style = {}
    if not isinstance(nodes_style, list):
        # One (shared) style per feature group.
        nodes_style = [nodes_style] * len(feature_groups)

    for i_fg, feature_group in enumerate(feature_groups):
        color = pregnancy_multiomics_subset_info[feature_group]["color"]
        # Same per-group normalization/scaling/offset as used for the edges.
        e = embeddings[i_fg].copy()
        e[:, 0] = e[:, 0] / embeddings_max[i_fg][0] * group_scaling[i_fg] + group_offsets[i_fg][0]
        e[:, 1] = e[:, 1] / embeddings_max[i_fg][1] * group_scaling[i_fg] + group_offsets[i_fg][1]

        ax.scatter(
            e[:, 0], e[:, 1],
            **init_kwargs(
                nodes_style[i_fg],
                s=50 if draw_legend else 1,
                label=feature_group,
                zorder=-100,
                c=color))

    if draw_legend:
        ax.legend()

    ax.axis("off")
    if verbose > 2:
        print("finalizing figure")

# Plotting

In [None]:
%%time
fig, ax = plt.subplots(1, 1, figsize=(4, 3), dpi=300)
draw_network(
    # first feature group with the smaller markers, all remaining groups with the larger ones
    nodes_style=[nodes_style1] + [nodes_style2] * (len(feature_groups) - 1),
    edges=[adj1, adj1_diff34],
    edges_style=[edges_style4_under, edges_style4_over],
    ax=ax)
# Rasterize every artist with zorder below -1 (edge layers at -1000 + i, nodes
# at -100) so the dense network is embedded as an image while text stays vector.
ax.set_rasterization_zorder(-1)

# Feature-group legend: one colored marker per group (comprehension avoids
# leaking a loop variable into the notebook namespace).
patchList = [
    Line2D(
        [0], [0], marker="o", color=(1, 1, 1, 0), lw=1, markeredgewidth=.5,
        markerfacecolor=pregnancy_multiomics_subset_info[group]["color"],
        label=pregnancy_multiomics_subset_info[group]["name_full"],
        markersize=5)
    for group in feature_groups
]

# Edge-layer legend: grey line for the first layer, black line for the second.
patchList2 = [
    Line2D(
        [0], [0], markerfacecolor=(1, 1, 1, 0),
        color=(0.5, 0.5, 0.5, 1), lw=0.5, label="Correlations at 3rd trimester"),
    Line2D(
        [0], [0], markerfacecolor=(1, 1, 1, 0),
        color=(0, 0, 0, 1), lw=1, label="Modified correlations after birth"),
]

# ax.legend() replaces the axes' current legend, so the edge legend is created
# first and re-attached with add_artist() after the group legend is set.
legend1 = ax.legend(handles=patchList2, loc=(0.75, 0.87), frameon=False, fontsize=5)
legend1.set_zorder(102)

ax.legend(handles=patchList, loc=(0.75, 0.03), frameon=False, fontsize=5).set_zorder(102)
ax.add_artist(legend1)

# fig.savefig(
#     get_file("multiomics_pregnancy_t3_with-legend.pdf", out_dir="../_out/figures"),
#     bbox_inches='tight')