# Description

(Please take a look at the README.md file in this directory for instructions on how to run this notebook.)

This notebook reads the per-chromosome gene correlation matrices and assembles them into a single correlation matrix that contains all genes.

# Modules

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
from scipy.spatial.distance import squareform
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import conf
from entity import Gene

# Settings

In [None]:
# reference panel
# (used below to look up the genotype folder in conf.py and to name the output subfolder)
REFERENCE_PANEL = "GTEX_V8"
# REFERENCE_PANEL = "1000G"

# prediction models
# (EQTL_MODEL is used below to build the "<MODEL>_PREFIX" key in conf.py and the output subfolder)
## mashr
EQTL_MODEL = "MASHR"
EQTL_MODEL_FILES_PREFIX = "mashr_"

# ## elastic net
# EQTL_MODEL = "ELASTIC_NET"
# EQTL_MODEL_FILES_PREFIX = "en_"

# make it read the prefix from conf.py
# (None overrides the hard-coded prefix assigned above; the next cell then
# replaces it with the value configured for EQTL_MODEL)
EQTL_MODEL_FILES_PREFIX = None

In [None]:
# Fall back to the prefix configured in conf.py when none was set explicitly.
if EQTL_MODEL_FILES_PREFIX is None:
    _prefix_key = f"{EQTL_MODEL}_PREFIX"
    EQTL_MODEL_FILES_PREFIX = conf.PHENOMEXCAN["PREDICTION_MODELS"][_prefix_key]

In [None]:
# show the selected prediction model and the file prefix in use
display(f"Using eQTL model: {EQTL_MODEL} / {EQTL_MODEL_FILES_PREFIX}")

In [None]:
# genotype folder configured in conf.py for the selected reference panel
REFERENCE_PANEL_DIR = conf.PHENOMEXCAN["LD_BLOCKS"][f"{REFERENCE_PANEL}_GENOTYPE_DIR"]

In [None]:
# show which reference panel folder is in use (!s renders the value through str())
display(f"Using reference panel folder: {REFERENCE_PANEL_DIR!s}")

In [None]:
# Results go under <gene corrs dir>/<reference panel>/<eQTL model> (lower-cased names).
_gene_corrs_dir = conf.PHENOMEXCAN["LD_BLOCKS"]["GENE_CORRS_DIR"]
OUTPUT_DIR_BASE = _gene_corrs_dir / REFERENCE_PANEL.lower() / EQTL_MODEL.lower()
display(OUTPUT_DIR_BASE)

# create the folder (and any missing parents); no error if it already exists
OUTPUT_DIR_BASE.mkdir(parents=True, exist_ok=True)

In [None]:
# show the base output folder in use
display(f"Using output dir base: {OUTPUT_DIR_BASE}")

In [None]:
# The per-chromosome correlation files are read from this folder; it must already exist.
INPUT_DIR = OUTPUT_DIR_BASE.joinpath("by_chr", "corrected_positive_definite")
display(INPUT_DIR)
assert INPUT_DIR.exists()

# Load data

## Gene correlations

In [None]:
# collect every file in INPUT_DIR whose name matches gene_corrs-chr*.pkl
all_gene_corr_files = list(INPUT_DIR.glob("gene_corrs-chr*.pkl"))

In [None]:
def _chromosome_number(corr_file):
    """Return the integer found between "-chr" and ".pkl" in the file's name."""
    return int(corr_file.name.split("-chr")[1].split(".pkl")[0])


# order the files numerically by chromosome (so chr2 comes before chr10)
all_gene_corr_files = sorted(all_gene_corr_files, key=_chromosome_number)

In [None]:
# number of correlation files found
len(all_gene_corr_files)

In [None]:
# list the files in the order they will be processed
all_gene_corr_files

In [None]:
# exactly 22 files are expected (presumably one per autosome, chr1-chr22 — confirm)
assert len(all_gene_corr_files) == 22

## MultiPLIER Z

In [None]:
# gene names are taken from the row index of the MultiPLIER Z matrix file
# NOTE(review): read_pickle can execute arbitrary code; only load trusted files
multiplier_z_genes = pd.read_pickle(
    conf.MULTIPLIER["MODEL_Z_MATRIX_FILE"]
).index.tolist()

In [None]:
# number of entries in the Z matrix index
len(multiplier_z_genes)

In [None]:
# peek at the first ten index entries
multiplier_z_genes[:10]

## Get gene objects

In [None]:
# Build one Gene object per gene name, keyed by that name.
multiplier_gene_obj = {}
for _gene_name in multiplier_z_genes:
    # names missing from Gene.GENE_NAME_TO_ID_MAP are skipped
    if _gene_name not in Gene.GENE_NAME_TO_ID_MAP:
        continue
    multiplier_gene_obj[_gene_name] = Gene(name=_gene_name)

In [None]:
# number of genes kept after filtering by Gene.GENE_NAME_TO_ID_MAP
len(multiplier_gene_obj)

In [None]:
# spot check: look up the Ensembl ID of a single gene (GAS6)
multiplier_gene_obj["GAS6"].ensembl_id

In [None]:
# Collect, for every gene object, the attributes used later to order the genes.
_gene_obj = list(multiplier_gene_obj.values())

# output column -> how to read that value from a gene object
_column_getters = {
    "name": lambda gene: gene.name,
    "id": lambda gene: gene.ensembl_id,
    "chr": lambda gene: gene.chromosome,
    "start_position": lambda gene: gene.get_attribute("start_position"),
}

# rows with any missing value are dropped
genes_info = pd.DataFrame(
    {
        column: [getter(gene) for gene in _gene_obj]
        for column, getter in _column_getters.items()
    }
).dropna()

In [None]:
# no missing values may remain in any column
assert not genes_info.isna().any().any()

In [None]:
# column dtypes before the numeric conversion
genes_info.dtypes

In [None]:
# Convert both position columns to integers; they are used as numeric sort keys.
# pd.to_numeric is called once on the whole column (vectorized) instead of once
# per element through .apply; downcast="integer" picks the smallest integer dtype
# that can hold the values.
genes_info["chr"] = pd.to_numeric(genes_info["chr"], downcast="integer")
genes_info["start_position"] = genes_info["start_position"].astype(int)

In [None]:
# column dtypes after the numeric conversion
genes_info.dtypes

In [None]:
# (number of genes, number of columns)
genes_info.shape

In [None]:
# first rows of the gene table
genes_info.head()

In [None]:
# the dtype conversion must not have introduced missing values
assert not genes_info.isna().any().any()

# Create full correlation matrix

In [None]:
# order genes by chromosome and then by start position; this row order defines
# the row/column order of the full matrix created below
genes_info = genes_info.sort_values(["chr", "start_position"])

In [None]:
# gene table after sorting
genes_info

In [None]:
# Start from an all-zero square matrix labelled by gene ID on both axes.
_gene_ids = genes_info["id"].tolist()
_n_genes = genes_info.shape[0]

full_corr_matrix = pd.DataFrame(
    np.zeros((_n_genes, _n_genes)),
    index=_gene_ids,
    columns=_gene_ids,
)

In [None]:
# row and column labels must both be free of duplicates
assert full_corr_matrix.index.is_unique and full_corr_matrix.columns.is_unique

In [None]:
# Copy every per-chromosome correlation matrix into its slot of the full matrix.
# Cells that no file writes to keep the 0.0 they were initialized with.
for chr_corr_file in all_gene_corr_files:
    print(chr_corr_file.name, flush=True)

    corr_data = pd.read_pickle(chr_corr_file)
    # label-based assignment: the target rows/columns are the labels stored in the file
    # NOTE(review): assumes every label in the file exists in full_corr_matrix — confirm
    full_corr_matrix.loc[corr_data.index, corr_data.columns] = corr_data

In [None]:
# dimensions of the assembled matrix
full_corr_matrix.shape

In [None]:
# preview of the assembled matrix
full_corr_matrix

## Some checks

In [None]:
# Cap values above 1.0 and force an exact 1.0 on the diagonal.
# The work is done on an explicit, writable NumPy copy: DataFrame.values is not
# guaranteed to be a writable view of the frame (with pandas Copy-on-Write it is
# read-only), so filling the diagonal through it can raise or be silently lost.
_corr_values = full_corr_matrix.to_numpy(copy=True)
_corr_values[_corr_values > 1.0] = 1.0
np.fill_diagonal(_corr_values, 1.0)

# rebuild the frame with the same labels around the corrected values
full_corr_matrix = pd.DataFrame(
    _corr_values,
    index=full_corr_matrix.index,
    columns=full_corr_matrix.columns,
)

In [None]:
# every diagonal entry must be exactly 1.0
assert np.all(full_corr_matrix.to_numpy().diagonal() == 1.0)

In [None]:
# check that all genes have a value (no NaN anywhere in the assembled matrix)
assert not full_corr_matrix.isna().any().any()

In [None]:
# smallest value in the matrix, shown for inspection only
# (the lower-bound assertion below is commented out, so it is not enforced)
_min_val = full_corr_matrix.min().min()
display(_min_val)
# assert _min_val >= 0.0

In [None]:
# largest value in the matrix; nothing may exceed 1.0
_max_val = full_corr_matrix.max().max()  # this will capture the 1.0 in the diagonal
display(_max_val)
assert _max_val <= 1.0

In [None]:
# check that matrix is positive definite (all eigenvalues strictly positive)
_corr_values = full_corr_matrix.to_numpy()

# eigvalsh reads only one triangle of its input, so confirm symmetry first
assert np.allclose(_corr_values, _corr_values.T), "Correlation matrix is not symmetric"

# eigvalsh is the solver for symmetric matrices: it always returns real
# eigenvalues, whereas the general eigvals may return complex values
eigs = np.linalg.eigvalsh(_corr_values)
assert np.all(eigs > 0)

In [None]:
# this should not fail: np.linalg.cholesky raises LinAlgError when its input
# (here, the inverse of the correlation matrix) is not positive definite
np.linalg.cholesky(np.linalg.inv(full_corr_matrix))

# Try to fit GLS and see if it works (with random data)

In [None]:
import statsmodels.api as sm

In [None]:
# fix the seed so the random data generated below is reproducible
np.random.seed(0)

In [None]:
# random response vector with one value per gene in the matrix
y = np.random.rand(full_corr_matrix.shape[0])

In [None]:
# random design matrix with two columns; the first column is overwritten with
# ones (a constant/intercept term), leaving one random predictor
X = np.random.rand(full_corr_matrix.shape[0], 2)
X[:, 0] = 1

In [None]:
# this should not throw an exception: LinAlgError("Matrix is not positive definite")
# _gls_model = sm.GLS(y, X, sigma=np.identity(y.shape[0]))
# the assembled correlation matrix is passed as the GLS sigma argument
_gls_model = sm.GLS(y, X, sigma=full_corr_matrix)

In [None]:
# fit the model on the random data
_gls_results = _gls_model.fit()

In [None]:
# print the fit summary (the data is random, so only a successful fit matters here)
print(_gls_results.summary())

## Stats

In [None]:
# Flatten the strictly-lower triangle into a Series of pairwise correlations:
# the diagonal and the upper triangle are masked with NaN and then removed.
# dropna() is explicit because stack() only drops NaN in its legacy
# implementation; the newer one (future_stack=True, pandas >= 2.1) keeps them.
_upper_triangle_and_diagonal = np.triu(np.ones(full_corr_matrix.shape, dtype=bool))
full_corr_matrix_flat = (
    full_corr_matrix.mask(_upper_triangle_and_diagonal).stack().dropna()
)

In [None]:
# n genes yield n * (n - 1) / 2 distinct off-diagonal pairs
_n_genes = full_corr_matrix.shape[0]
_n_gene_pairs = _n_genes * (_n_genes - 1) // 2

display(full_corr_matrix_flat.shape)
assert full_corr_matrix_flat.shape[0] == _n_gene_pairs

In [None]:
# off-diagonal gene pairs whose correlation is exactly 1.0
full_corr_matrix_flat[full_corr_matrix_flat == 1.0]

In [None]:
# first pairwise values of the flattened Series
full_corr_matrix_flat.head()

In [None]:
# summary statistics, converted to strings (presumably to avoid pandas'
# rounded float display — confirm)
full_corr_matrix_flat.describe().apply(str)

In [None]:
# quantiles from 0.00 up to 0.95 in steps of 0.05 (np.arange excludes the end point 1.0)
full_corr_matrix_flat_quantiles = full_corr_matrix_flat.quantile(np.arange(0, 1, 0.05))
display(full_corr_matrix_flat_quantiles)

## Plot: distribution

In [None]:
# distribution plot (with a KDE overlay) of the pairwise correlation values;
# the plotting context only applies inside the with-block
with sns.plotting_context("paper", font_scale=1.5):
    g = sns.displot(full_corr_matrix_flat, kde=True, height=7)
    g.ax.set_title("Distribution of gene correlation values in all chromosomes")

## Plot: heatmap

In [None]:
# Colour limits for the heatmap: the 10th/90th percentiles of the pairwise
# correlations, widened so the range always covers at least [-0.05, 0.05].
# The two quantiles are computed directly rather than looked up by float label
# (0.10 / 0.90) in a Series whose index was generated with np.arange, because
# such a lookup only works when the generated floats match the literals exactly.
_q10, _q90 = full_corr_matrix_flat.quantile([0.10, 0.90]).to_numpy()
vmin_val = min(-0.05, _q10)
vmax_val = max(0.05, _q90)
display(f"{vmin_val} / {vmax_val}")

In [None]:
# Draw the whole matrix as a heatmap; tick labels are turned off and the colour
# scale is limited to [vmin_val, vmax_val].
fig, ax = plt.subplots(figsize=(10, 10))

_heatmap_options = {
    "xticklabels": False,
    "yticklabels": False,
    "square": True,
    "vmin": vmin_val,
    "vmax": vmax_val,
    "cmap": "YlGnBu",
}
sns.heatmap(full_corr_matrix, ax=ax, **_heatmap_options)
ax.set_title("Gene correlations in all chromosomes")

# Save

## With Ensembl IDs

In [None]:
# output_file_name_template = conf.PHENOMEXCAN["LD_BLOCKS"][
#     "GENE_CORRS_FILE_NAME_TEMPLATES"
# ]["GENE_CORR_AVG"]

# output_file = OUTPUT_DIR_BASE / output_file_name_template.format(
#     prefix="",
#     suffix=f"-ssm_corrs-gene_ensembl_ids",
# )
# display(output_file)

In [None]:
# full_corr_matrix.to_pickle(output_file)

## With gene symbols

In [None]:
# Build the output path from the file name template configured in conf.py.
_file_name_templates = conf.PHENOMEXCAN["LD_BLOCKS"]["GENE_CORRS_FILE_NAME_TEMPLATES"]
output_file_name_template = _file_name_templates["GENE_CORR_AVG"]

# empty prefix; the suffix marks that this file is labelled with gene symbols
_output_file_name = output_file_name_template.format(
    prefix="",
    suffix="-gene_symbols",
)
output_file = OUTPUT_DIR_BASE / _output_file_name
display(output_file)

In [None]:
# relabel both axes through Gene.GENE_ID_TO_NAME_MAP (gene ID -> gene name);
# rename() leaves any label that is missing from the map unchanged
full_corr_matrix_gene_symbols = full_corr_matrix.rename(
    index=Gene.GENE_ID_TO_NAME_MAP, columns=Gene.GENE_ID_TO_NAME_MAP
)

In [None]:
# the relabelling must not have produced duplicated row labels
assert full_corr_matrix_gene_symbols.index.is_unique

In [None]:
# the relabelling must not have produced duplicated column labels
assert full_corr_matrix_gene_symbols.columns.is_unique

In [None]:
# dimensions of the relabelled matrix
full_corr_matrix_gene_symbols.shape

In [None]:
# first rows of the relabelled matrix
full_corr_matrix_gene_symbols.head()

In [None]:
# write the relabelled matrix to output_file in pickle format
full_corr_matrix_gene_symbols.to_pickle(output_file)