# Description

**TODO UPDATE**

(Please take a look at the README.md file in this directory for instructions on how to run this notebook.)

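This notebook reads the gene correlation matrix (with genes across all chromosomes) for a given cohort, reference panel and eQTL model. It keeps only the genes shared with an S-MultiXcan result file and the MultiPLIER Z matrix. Then, for a given LV, it computes the inverse of the Cholesky factor of an LV-specific correlation matrix, where correlations are kept only among genes with a positive weight in the LV and the identity is used for all other genes. The result is stored as a sparse matrix in a per-LV output directory.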

# Modules

In [None]:
# automatically reload edited project modules (e.g., conf, utils, entity)
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path
from concurrent.futures import ProcessPoolExecutor, as_completed

import numpy as np
from scipy.spatial.distance import squareform
from scipy import sparse
import pandas as pd
from tqdm import tqdm

import conf
from utils import chunker
from entity import Gene

# Settings

In [None]:
# Parameters: all of them must be set before running the notebook (the cells
# below assert that they are not None/empty).

# a cohort name (it could be something like UK_BIOBANK, etc)
COHORT_NAME = None

# reference panel such as 1000G or GTEX_V8
REFERENCE_PANEL = None

# prediction models such as MASHR or ELASTIC_NET
EQTL_MODEL = None

# This is one S-MultiXcan result file on the same target cohort
# Genes will be read from here to align the correlation matrices
SMULTIXCAN_FILE = None

# code of the LV to process (a column of the MultiPLIER Z matrix, e.g., "LV1");
# the inverse of the Cholesky factor of its correlation matrix is computed
LV_CODE = None

In [None]:
# a non-empty cohort name is required; it is normalized to lower case
assert COHORT_NAME is not None, "A cohort name must be given"
assert len(COHORT_NAME) > 0, "A cohort name must be given"

COHORT_NAME = COHORT_NAME.lower()
display(f"Cohort name: {COHORT_NAME}")

In [None]:
# a non-empty reference panel is required
assert REFERENCE_PANEL is not None, "A reference panel must be given"
assert len(REFERENCE_PANEL) > 0, "A reference panel must be given"

display(f"Reference panel: {REFERENCE_PANEL}")

In [None]:
assert (
    EQTL_MODEL is not None and len(EQTL_MODEL) > 0
), "A prediction/eQTL model must be given"

# prefix used for this model's files, read from the project configuration
EQTL_MODEL_FILES_PREFIX = conf.PHENOMEXCAN["PREDICTION_MODELS"][f"{EQTL_MODEL}_PREFIX"]
display(f"eQTL model: {EQTL_MODEL} / {EQTL_MODEL_FILES_PREFIX}")

In [None]:
assert (
    SMULTIXCAN_FILE is not None and len(SMULTIXCAN_FILE) > 0
), "An S-MultiXcan result file path must be given"
SMULTIXCAN_FILE = Path(SMULTIXCAN_FILE).resolve()
# require a regular file (exists() would also accept a directory) and report
# the offending path to ease debugging
assert (
    SMULTIXCAN_FILE.is_file()
), f"S-MultiXcan result file does not exist: {SMULTIXCAN_FILE}"

display(f"S-MultiXcan file path: {SMULTIXCAN_FILE}")

In [None]:
assert LV_CODE is not None and len(LV_CODE) > 0, "An LV code must be given"

display(f"LV code: {LV_CODE}")

In [None]:
# base output directory for this cohort, reference panel and eQTL model
OUTPUT_DIR_BASE = conf.RESULTS["GLS"].joinpath(
    "gene_corrs",
    "cohorts",
    COHORT_NAME.lower(),
    REFERENCE_PANEL.lower(),
    EQTL_MODEL.lower(),
)
OUTPUT_DIR_BASE.mkdir(parents=True, exist_ok=True)

display(f"Using output dir base: {OUTPUT_DIR_BASE}")

# Load data

## S-MultiXcan genes

In [None]:
# S-MultiXcan results are a tab-separated file (read_table's default separator)
smultixcan_df = pd.read_table(SMULTIXCAN_FILE)

In [None]:
# (rows, columns) of the S-MultiXcan results
smultixcan_df.shape

In [None]:
# first rows of the S-MultiXcan results (columns used below: gene_name, pvalue)
smultixcan_df.head()

In [None]:
# no infinite values are expected anywhere in the S-MultiXcan results
assert not smultixcan_df.isin([-np.inf, np.inf]).to_numpy().any()

In [None]:
# keep only the rows (genes) with a non-missing p-value
smultixcan_df = smultixcan_df.loc[smultixcan_df["pvalue"].notna()]
display(smultixcan_df.shape)

In [None]:
# unique gene names with a valid S-MultiXcan p-value
smultixcan_genes = set(smultixcan_df["gene_name"])

In [None]:
# number of unique gene names with a non-missing p-value
len(smultixcan_genes)

In [None]:
# a few gene names, in sorted order
sorted(smultixcan_genes)[:5]

## Gene correlations

In [None]:
# gene correlation matrix previously computed for this cohort, reference panel
# and eQTL model
input_file = OUTPUT_DIR_BASE / "gene_corrs-symbols.pkl"
display(input_file)
assert input_file.exists(), f"Gene correlations file not found: {input_file}"

In [None]:
# load correlation matrix (a DataFrame with gene names on both axes; see its
# use with .loc[genes, genes] below)
# NOTE: read_pickle can execute arbitrary code; only load trusted files
gene_corrs = pd.read_pickle(input_file)

In [None]:
# (rows, columns) of the gene correlation matrix
gene_corrs.shape

In [None]:
# first rows of the gene correlation matrix
gene_corrs.head()

## Define output directory (based on the gene correlations file)

In [None]:
# output directory (one .npz file per LV, plus the metadata and gene names
# files), named after the gene correlations file with a ".per_lv" suffix
output_dir = input_file.with_suffix(".per_lv")
output_dir.mkdir(parents=True, exist_ok=True)

display(output_dir)

## MultiPLIER Z

In [None]:
# MultiPLIER Z matrix: genes as index, LV codes as columns (gene weights per LV)
# NOTE: read_pickle can execute arbitrary code; only load trusted files
multiplier_z = pd.read_pickle(conf.MULTIPLIER["MODEL_Z_MATRIX_FILE"])

In [None]:
# (genes, LVs) of the MultiPLIER Z matrix
multiplier_z.shape

In [None]:
# first rows of the MultiPLIER Z matrix
multiplier_z.head()

## Common genes

In [None]:
# genes present in all three sources (S-MultiXcan results, MultiPLIER Z matrix
# and gene correlation matrix); sorted to get a deterministic order
common_genes = sorted(
    smultixcan_genes & set(multiplier_z.index) & set(gene_corrs.index)
)

In [None]:
# number of genes shared by the three sources
len(common_genes)

In [None]:
# first common genes (in sorted order)
common_genes[:5]

# Compute the inverse of the Cholesky factor of the correlation matrix for each LV

In [None]:
def exists_df(base_filename):
    """
    Tell whether an output file with the given base name is already stored.

    Args:
        base_filename: file name without extension (e.g., "metadata" or an LV code).

    Returns:
        True if `<base_filename>.npz` exists inside `output_dir`, False otherwise.
    """
    return (output_dir / (base_filename + ".npz")).exists()

In [None]:
def store_df(nparray, base_filename):
    """
    Save an array into `output_dir` as `<base_filename>.npz`.

    For "metadata" and "gene_names", the array is written as a compressed NumPy
    archive under the "data" key. Any other name (an LV code) is converted to
    a sparse CSC matrix and written uncompressed with scipy.

    Args:
        nparray: array to store.
        base_filename: file name without extension.
    """
    target_path = output_dir / (base_filename + ".npz")

    if base_filename not in ("metadata", "gene_names"):
        sparse.save_npz(target_path, sparse.csc_matrix(nparray), compressed=False)
        return

    np.savez_compressed(target_path, data=nparray)

In [None]:
def compute_chol_inv(lv_codes):
    """
    Compute and store, for each LV, the inverse of the Cholesky factor of an
    LV-specific gene correlation matrix.

    The matrix is built over `common_genes` starting from the identity. Only the
    entries among genes with a positive weight in `multiplier_z` for the LV are
    replaced by their values in `gene_corrs`. The result is saved with
    `store_df` using the LV code as file name.

    Args:
        lv_codes: iterable of LV codes (columns of `multiplier_z`).

    Raises:
        numpy.linalg.LinAlgError: if an LV-specific matrix is not positive
            definite (Cholesky decomposition fails).
    """
    n_genes = len(common_genes)

    for code in lv_codes:
        lv_weights = multiplier_z[code]
        lv_genes = lv_weights[lv_weights > 0].index

        lv_corr = pd.DataFrame(
            np.identity(n_genes), index=common_genes, columns=common_genes
        )
        # only genes that are part of the matrix can be filled in
        lv_genes = lv_genes.intersection(lv_corr.index)

        # label-aligned copy of the correlations among the LV genes; all other
        # genes stay uncorrelated (identity)
        lv_corr.loc[lv_genes, lv_genes] = gene_corrs.loc[lv_genes, lv_genes]

        lower_factor = np.linalg.cholesky(lv_corr)
        store_df(np.linalg.inv(lower_factor), code)

In [None]:
# divide LVs in chunks for parallel processing. Here only the LV given in
# LV_CODE is processed (a single chunk with a single LV); to process all LVs,
# chunks could be built with chunker(list(multiplier_z.columns), 50)
lvs_chunks = [[LV_CODE]]

In [None]:
# metadata: reference panel and eQTL model used to build the matrices
if not exists_df("metadata"):
    metadata = np.array([REFERENCE_PANEL, EQTL_MODEL])
    store_df(metadata, "metadata")
else:
    display("Metadata file already exists")

# gene names: row/column order of the stored matrices
if not exists_df("gene_names"):
    gene_names = np.array(common_genes)
    store_df(gene_names, "gene_names")
else:
    display("Gene names file already exists")

# Compute and store the Cholesky inverses, one task per chunk of LVs.
# future.result() is called only to re-raise here any exception from a worker.
# NOTE(review): workers run compute_chol_inv, which reads notebook globals
# (common_genes, multiplier_z, gene_corrs, output_dir); this presumably relies
# on the "fork" start method (default on Linux) — confirm on other platforms.
with ProcessPoolExecutor(max_workers=conf.GENERAL["N_JOBS"]) as executor, tqdm(
    total=len(lvs_chunks), ncols=100
) as pbar:
    tasks = [executor.submit(compute_chol_inv, chunk) for chunk in lvs_chunks]
    for future in as_completed(tasks):
        future.result()
        pbar.update(1)

## Some checks

In [None]:
def load_df(base_filename):
    """
    Load an array previously saved with `store_df` from `output_dir`.

    Args:
        base_filename: file name without extension: "metadata", "gene_names"
            or an LV code.

    Returns:
        A numpy array. For "metadata" and "gene_names" it is the "data" entry
        of the .npz archive; otherwise it is the stored sparse matrix
        converted to a dense array.
    """
    full_filepath = output_dir / (base_filename + ".npz")

    if base_filename in ("metadata", "gene_names"):
        # np.load returns an NpzFile that keeps the file open; use it as a
        # context manager so the handle is closed once the array is read
        with np.load(full_filepath) as npz_file:
            return npz_file["data"]

    return sparse.load_npz(full_filepath).toarray()

In [None]:
# gene names stored alongside the per-LV matrices
_genes = load_df("gene_names")

In [None]:
display(len(_genes))
# The gene names file is reused if it already exists, so compare the names
# themselves (not only the count) with the genes used to build the matrices.
assert len(_genes) == len(common_genes)
assert _genes.tolist() == common_genes, "Stored gene names differ from common_genes"

In [None]:
# stored metadata: [reference panel, eQTL model]
_metadata = load_df("metadata")

In [None]:
display(_metadata)
# The metadata file is reused if it already exists; it must match this run
# exactly (reference panel, eQTL model and nothing else).
assert _metadata.tolist() == [
    REFERENCE_PANEL,
    EQTL_MODEL,
], f"Unexpected metadata: {_metadata}"

In [None]:
# lv1_inv = load_df("LV1")

In [None]:
# lv2_inv = load_df("LV2")

In [None]:
# inverse of the Cholesky factor stored for the LV processed above
lv_last_inv = load_df(LV_CODE)
# it must be a square matrix over all common genes (same order as gene_names)
assert lv_last_inv.shape == (len(common_genes), len(common_genes))

In [None]:
# assert lv1_inv.shape == lv2_inv.shape

In [None]:
# assert not np.allclose(lv1_inv, lv2_inv)

In [None]:
# assert not np.allclose(lv1_inv, lv_last_inv)

In [None]:
# assert not np.allclose(lv2_inv, lv_last_inv)