# Description

(Please, take a look at the README.md file in this directory for instructions on how to run this notebook)

This notebook reads the per-chromosome gene correlation matrices and assembles them into a single correlation matrix containing all genes.

It has specific parameters for papermill (see under `Settings` below).

This notebook is not directly run. See README.md.

# Modules

In [None]:
# enable the IPython autoreload extension in mode 2, so imported modules are
# reloaded before each cell is executed
%load_ext autoreload
%autoreload 2

In [None]:
import pickle
from pathlib import Path

import numpy as np
from scipy.spatial.distance import squareform
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import conf
from entity import Gene
from correlations import (
    check_pos_def,
    compare_matrices,
    correct_corr_mat,
    adjust_non_pos_def,
)

# Settings

In [None]:
# Notebook parameters (set through papermill, see README.md). Values left as
# None here must be provided at run time: the cells that follow assert that
# each parameter is a non-empty value.

# a cohort name (it could be something like UK_BIOBANK, etc)
COHORT_NAME = None

# reference panel such as 1000G or GTEX_V8
REFERENCE_PANEL = "GTEX_V8"

# predictions models such as MASHR or ELASTIC_NET
EQTL_MODEL = "MASHR"

# output dir
OUTPUT_DIR_BASE = None

In [None]:
# a non-empty cohort name is mandatory
assert COHORT_NAME is not None and len(COHORT_NAME) > 0, "A cohort name must be given"

# normalize to lower case; the cohort name is used below as a path component
COHORT_NAME = COHORT_NAME.lower()
display(f"Cohort name: {COHORT_NAME}")

In [None]:
# a non-empty reference panel name is mandatory
assert (
    REFERENCE_PANEL is not None and len(REFERENCE_PANEL) > 0
), "A reference panel must be given"

display(f"Reference panel: {REFERENCE_PANEL}")

In [None]:
# a non-empty prediction/eQTL model name is mandatory
assert (
    EQTL_MODEL is not None and len(EQTL_MODEL) > 0
), "A prediction/eQTL model must be given"

# show the selected model (a stray ")" was removed from the message)
display(f"eQTL model: {EQTL_MODEL}")

In [None]:
# a non-empty output directory is mandatory
assert (
    OUTPUT_DIR_BASE is not None and len(OUTPUT_DIR_BASE) > 0
), "Output directory path must be given"

# results live under <base>/gene_corrs/<cohort>; the reference panel and the
# eQTL model are currently not part of the path
OUTPUT_DIR_BASE = Path(OUTPUT_DIR_BASE, "gene_corrs", COHORT_NAME).resolve()

# create the directory (and any missing parents) if it does not exist yet
OUTPUT_DIR_BASE.mkdir(parents=True, exist_ok=True)

display(f"Using output dir base: {OUTPUT_DIR_BASE}")

In [None]:
# the per-chromosome correlation files are read from this subdirectory, which
# must already exist
INPUT_DIR = OUTPUT_DIR_BASE.joinpath("by_chr")

display(f"Gene correlations input dir: {INPUT_DIR}")
assert INPUT_DIR.exists()

# Load data

## Gene correlations

In [None]:
# collect every per-chromosome pickle file found in the input directory
all_gene_corr_files = list(INPUT_DIR.glob("gene_corrs-chr*.pkl"))

In [None]:
def _chromosome_number(corr_file):
    """Return the integer found between "-chr" and ".pkl" in the file name."""
    return int(corr_file.name.split("-chr")[1].split(".pkl")[0])


# sort numerically by chromosome (so that chr2 comes before chr10)
all_gene_corr_files = sorted(all_gene_corr_files, key=_chromosome_number)

In [None]:
# number of per-chromosome files found
len(all_gene_corr_files)

In [None]:
# file list, in chromosome order
all_gene_corr_files

In [None]:
# exactly 22 files are required (presumably one per autosome, chr1-chr22 -- TODO confirm)
assert len(all_gene_corr_files) == 22

## Get all genes

In [None]:
# build the union of the gene IDs found in the index of every per-chromosome
# correlation matrix
# NOTE: pd.read_pickle can execute arbitrary code; only load files that this
# pipeline generated itself
gene_ids = set()
for f in all_gene_corr_files:
    chr_genes = pd.read_pickle(f).index.tolist()
    gene_ids.update(chr_genes)

In [None]:
# total number of distinct gene IDs across all files
display(len(gene_ids))

In [None]:
# peek at the first five gene IDs in sorted order
sorted(list(gene_ids))[:5]

## Gene info

In [None]:
# load gene metadata; the columns `id`, `chr` and `start_position` are used below
genes_info = pd.read_pickle(OUTPUT_DIR_BASE / "genes_info.pkl")

In [None]:
# (rows, columns) of the gene metadata before filtering
genes_info.shape

In [None]:
# first rows of the gene metadata before filtering
genes_info.head()

In [None]:
# restrict the metadata to genes that appear in some correlation matrix
genes_info = genes_info.loc[genes_info["id"].isin(gene_ids)]

In [None]:
# (rows, columns) of the gene metadata after filtering
genes_info.shape

In [None]:
# no missing values are allowed anywhere in the gene metadata; axis=None
# reduces over both axes to a single boolean and has to be passed by keyword
# (positional arguments to DataFrame.any are rejected by pandas >= 2.0)
assert not genes_info.isna().any(axis=None)

In [None]:
# column types of the gene metadata
genes_info.dtypes

In [None]:
# first rows of the gene metadata after filtering
genes_info.head()

# Create full correlation matrix

In [None]:
# order genes by chromosome and then by start position; this order defines
# the rows/columns of the full correlation matrix created below
# NOTE(review): if `chr` is stored as a string this sorts lexicographically
# (e.g. "10" before "2") -- confirm its dtype
genes_info = genes_info.sort_values(["chr", "start_position"])

In [None]:
# gene metadata in its final order
genes_info

In [None]:
# start from an all-zeros square matrix with one row and one column per gene,
# labeled with the gene IDs in the order given by `genes_info`
_n_genes = genes_info.shape[0]
_ordered_gene_ids = genes_info["id"].tolist()

full_corr_matrix = pd.DataFrame(
    np.zeros((_n_genes, _n_genes)),
    index=_ordered_gene_ids,
    columns=_ordered_gene_ids,
)

In [None]:
# gene IDs must not be repeated in either axis
assert full_corr_matrix.index.is_unique and full_corr_matrix.columns.is_unique

In [None]:
# Fill the full matrix one chromosome at a time. Only the within-chromosome
# blocks are written; entries between genes in different files keep the zero
# they were initialized with.
for chr_corr_file in all_gene_corr_files:
    print(chr_corr_file.name, flush=True, end="... ")

    # get correlation matrix for this chromosome
    corr_data = pd.read_pickle(chr_corr_file)

    # make sure this chromosome's matrix is positive definite before storing
    # it; if it is not, adjust it and keep the adjusted version instead
    is_pos_def = check_pos_def(corr_data)

    if is_pos_def:
        print("all good.", flush=True, end="\n")
    else:
        print("not positive definite, fixing... ", flush=True, end="")
        corr_data_adjusted = adjust_non_pos_def(corr_data)

        # the adjustment must succeed, otherwise stop here
        is_pos_def = check_pos_def(corr_data_adjusted)
        assert is_pos_def, "Could not adjust gene correlation matrix"

        print("fixed! comparing...", flush=True, end="\n")
        compare_matrices(corr_data, corr_data_adjusted)

        corr_data = corr_data_adjusted

    # store the (possibly adjusted) chromosome block, once, in the full matrix
    full_corr_matrix.loc[corr_data.index, corr_data.columns] = corr_data

    print("\n")

In [None]:
# (rows, columns) of the assembled matrix
full_corr_matrix.shape

In [None]:
# first rows of the assembled matrix
full_corr_matrix.head()

In [None]:
# report (without asserting) whether every diagonal entry is exactly 1.0
np.all(np.diag(full_corr_matrix.to_numpy()) == 1.0)

## Some checks

In [None]:
# smallest value in the whole matrix (column-wise minimum, then the minimum
# of those); it is only displayed, the range assertion below is disabled
_min_val = full_corr_matrix.min().min()
display(_min_val)
# assert _min_val >= -0.05

In [None]:
# largest value in the whole matrix (column-wise maximum, then the maximum
# of those); it is only displayed, the range assertion below is disabled
_max_val = full_corr_matrix.max().max()
display(_max_val)
# assert _max_val <= 1.05

## Positive definiteness

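In some cases, even after each per-chromosome submatrix has been adjusted to be positive definite, the full matrix is still not positive definite.

So here I check the full matrix again and adjust it if needed.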

In [None]:
# check the assembled matrix as a whole and, if needed, replace it with an
# adjusted version that passes the positive-definiteness check
is_pos_def = check_pos_def(full_corr_matrix)

if not is_pos_def:
    print("not positive definite, fixing... ", flush=True, end="")
    corr_data_adjusted = adjust_non_pos_def(full_corr_matrix)

    # the adjustment must succeed, otherwise stop here
    is_pos_def = check_pos_def(corr_data_adjusted)
    assert is_pos_def, "Could not adjust gene correlation matrix"

    print("fixed! comparing...", flush=True, end="\n")
    compare_matrices(full_corr_matrix, corr_data_adjusted)

    full_corr_matrix = corr_data_adjusted
else:
    print("all good.", flush=True, end="\n")

## Save

### Gene corrs with gene symbols

In [None]:
# destination of the matrix labeled with gene symbols
output_file = OUTPUT_DIR_BASE / "gene_corrs-symbols.pkl"
display(output_file)

In [None]:
# relabel rows and columns with Gene.GENE_ID_TO_NAME_MAP; the values of the
# matrix are not changed
_gene_id_to_name = Gene.GENE_ID_TO_NAME_MAP
gene_corrs = full_corr_matrix.rename(
    index=_gene_id_to_name,
    columns=_gene_id_to_name,
)

In [None]:
# the matrix to be saved must contain no missing, infinite or complex values;
# axis=None reduces over both axes to a single boolean and has to be passed
# by keyword (positional arguments to DataFrame.any are rejected by
# pandas >= 2.0)
assert not gene_corrs.isna().any(axis=None)
assert not np.isinf(gene_corrs.to_numpy()).any()
assert not np.iscomplex(gene_corrs.to_numpy()).any()

In [None]:
# labels must still be unique after renaming (two gene IDs mapped to the same
# name would produce duplicated labels)
assert gene_corrs.index.is_unique
assert gene_corrs.columns.is_unique

In [None]:
# (rows, columns) of the renamed matrix
gene_corrs.shape

In [None]:
# first rows of the renamed matrix
gene_corrs.head()

In [None]:
# write the renamed matrix to `output_file`
gene_corrs.to_pickle(output_file)

In [None]:
# drop the name; the statistics below use `full_corr_matrix`
del gene_corrs

# Stats

In [None]:
# keep each gene pair once: the upper triangle and the diagonal are masked
# to NaN and the remaining (strictly-lower triangle) values are stacked into
# a Series indexed by (row gene, column gene)
_upper_and_diag = np.triu(np.ones(full_corr_matrix.shape, dtype=bool))

# dropna() is explicit because the new pandas `stack` implementation keeps
# NaN entries instead of dropping them; it is a no-op where they are already
# dropped
full_corr_matrix_flat = full_corr_matrix.mask(_upper_and_diag).stack().dropna()

In [None]:
display(full_corr_matrix_flat.shape)

# the strictly-lower triangle of an n x n matrix has n * (n - 1) / 2 entries
_n_genes = full_corr_matrix.shape[0]
assert full_corr_matrix_flat.shape[0] == _n_genes * (_n_genes - 1) // 2

## On all correlations

In [None]:
# the cells in this section describe every pairwise value
_corr_mat = full_corr_matrix_flat

In [None]:
# number of pairwise values
_corr_mat.shape

In [None]:
# first pairwise values
_corr_mat.head()

In [None]:
# summary statistics, each converted to its string form for display
_corr_mat.describe().apply(str)

In [None]:
# quantiles from 0 up to (but not including) 1, in steps of 0.05
display(_corr_mat.quantile(np.arange(0, 1, 0.05)))

In [None]:
# lower tail: quantiles from 0 towards 0.001, in steps of 0.0001
display(_corr_mat.quantile(np.arange(0, 0.001, 0.0001)))

In [None]:
# upper tail: quantiles from 0.999 towards 1.0, in steps of 0.0001
display(_corr_mat.quantile(np.arange(0.999, 1.0, 0.0001)))

### Plot: distribution

In [None]:
# histogram with a KDE curve of all pairwise values
with sns.plotting_context("paper", font_scale=1.5):
    g = sns.displot(_corr_mat, kde=True, height=7)
    g.ax.set_title("Distribution of gene correlation values in all chromosomes")

### Plot: heatmap

In [None]:
# color scale limits for the heatmap: from 0 up to the 99th percentile of the
# pairwise values, with the upper limit never below 0.05
vmin_val = 0.0
vmax_val = max(0.05, _corr_mat.quantile(0.99))
display(f"{vmin_val} / {vmax_val}")

In [None]:
# heatmap of the full matrix; tick labels are turned off and the color scale
# is clipped to [vmin_val, vmax_val]
f, ax = plt.subplots(figsize=(10, 10))
sns.heatmap(
    full_corr_matrix,
    xticklabels=False,
    yticklabels=False,
    square=True,
    vmin=vmin_val,
    vmax=vmax_val,
    cmap="rocket_r",
    ax=ax,
)
ax.set_title("Gene correlations in all chromosomes")

## On nonzero correlations

In [None]:
# keep only the pairwise values that are strictly greater than zero
# NOTE(review): despite the name, negative values are excluded as well --
# confirm whether `!= 0.0` was intended
nonzero_corrs = full_corr_matrix_flat[full_corr_matrix_flat > 0.0]

In [None]:
# the cells in this section describe only the values selected above
_corr_mat = nonzero_corrs

In [None]:
# number of selected pairwise values
_corr_mat.shape

In [None]:
# first selected pairwise values
_corr_mat.head()

In [None]:
# summary statistics, each converted to its string form for display
_corr_mat.describe().apply(str)

In [None]:
# quantiles from 0 up to (but not including) 1, in steps of 0.05
display(_corr_mat.quantile(np.arange(0, 1, 0.05)))

In [None]:
# lower tail: quantiles from 0 towards 0.001, in steps of 0.0001
display(_corr_mat.quantile(np.arange(0, 0.001, 0.0001)))

In [None]:
# upper tail: quantiles from 0.999 towards 1.0, in steps of 0.0001
display(_corr_mat.quantile(np.arange(0.999, 1.0, 0.0001)))

### Plot: distribution

In [None]:
# histogram with a KDE curve of the values selected in this section; the title
# names the subset so this figure is not confused with the one for all values
with sns.plotting_context("paper", font_scale=1.5):
    g = sns.displot(_corr_mat, kde=True, height=7)
    g.ax.set_title(
        "Distribution of nonzero gene correlation values in all chromosomes"
    )