# Description

This notebook reads the gene correlations computed for every tissue and chromosome, averages them across tissues, and combines them into a single gene correlation matrix.

# Modules

In [None]:
# Reload imported modules automatically before each cell runs, so edits to
# imported modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
from scipy.spatial.distance import squareform
import pandas as pd

# from tqdm import tqdm

import conf
from entity import Gene

# Settings

In [None]:
# mashr
EQTL_MODEL = "MASHR"
EQTL_MODEL_FILES_PREFIX = "mashr_"

# # elastic net
# EQTL_MODEL = "ELASTIC_NET"
# EQTL_MODEL_FILES_PREFIX = "en_"

# Resetting the prefix to None overrides the value assigned above; the next cell
# then looks the prefix up in conf using EQTL_MODEL.
EQTL_MODEL_FILES_PREFIX = None

In [None]:
# Fall back to the prefix registered in conf for the selected eQTL model when no
# explicit prefix was given.
EQTL_MODEL_FILES_PREFIX = (
    conf.PHENOMEXCAN["PREDICTION_MODELS_PREFIXES"][EQTL_MODEL]
    if EQTL_MODEL_FILES_PREFIX is None
    else EQTL_MODEL_FILES_PREFIX
)

In [None]:
# Show which prediction model and file prefix the rest of the notebook uses.
display(f"Using eQTL model: {EQTL_MODEL} / {EQTL_MODEL_FILES_PREFIX}")

In [None]:
# Directory searched below for the pickled gene correlation files of the
# selected eQTL model (folder name is "<prefix>gene_corrs").
INPUT_DIR = (
    conf.PHENOMEXCAN["LD_BLOCKS"]["BASE_DIR"] / f"{EQTL_MODEL_FILES_PREFIX}gene_corrs"
)
display(INPUT_DIR)

# Load data

## Gene correlations

In [None]:
# Collect every pickled gene correlation file under INPUT_DIR (searched
# recursively). Sorting makes the file order -- and therefore the order in which
# the per-tissue matrices are summed later on -- independent of the order in
# which the filesystem happens to list the directory entries.
all_gene_corr_files = sorted(INPUT_DIR.rglob("*.pkl"))

In [None]:
# Number of correlation files found.
len(all_gene_corr_files)

In [None]:
# Peek at the first few paths.
all_gene_corr_files[:5]

In [None]:
# One file is expected per (chromosome, tissue) pair: 22 chromosomes x 49 tissues
# (the same counts are asserted on the parsed columns further below).
assert len(all_gene_corr_files) == 22 * 49

In [None]:
# One row per correlation file; the path goes into the "corr_file" column.
all_gene_corr_files_df = pd.DataFrame({"corr_file": list(all_gene_corr_files)})

In [None]:
# Add the bare file name (last path component) of each correlation file.
all_gene_corr_files_df = all_gene_corr_files_df.assign(
    file_name=lambda files_df: files_df["corr_file"].apply(lambda path: path.name)
)

In [None]:
# The tissue is the part of the file name between "gene_corrs-" and "-chr".
all_gene_corr_files_df = all_gene_corr_files_df.assign(
    tissue=lambda files_df: files_df["file_name"].apply(
        lambda file_name: file_name.split("-chr")[0].split("gene_corrs-")[1]
    )
)

In [None]:
# The chromosome number is the integer between "-chr" and the first following ".".
all_gene_corr_files_df = all_gene_corr_files_df.assign(
    chromosome=lambda files_df: files_df["file_name"].apply(
        lambda file_name: int(file_name.split("-chr")[1].split(".")[0])
    )
)

In [None]:
# Exactly 49 distinct tissues must have been parsed from the file names.
assert all_gene_corr_files_df["tissue"].unique().shape[0] == 49

In [None]:
# The parsed chromosomes must be exactly 1..22 (the second check implies the first).
assert all_gene_corr_files_df["chromosome"].unique().shape[0] == 22
assert set(all_gene_corr_files_df["chromosome"]) == set(range(1, 23))

In [None]:
# (number of files, number of columns)
all_gene_corr_files_df.shape

In [None]:
# Preview the file table with the parsed tissue and chromosome columns.
all_gene_corr_files_df.head()

## MultiPLIER Z

In [None]:
# Keep only the row labels of the pickled object at MODEL_Z_MATRIX_FILE; they are
# used as gene names in the cells below.
multiplier_z_genes = pd.read_pickle(
    conf.MULTIPLIER["MODEL_Z_MATRIX_FILE"]
).index.tolist()

In [None]:
# Number of genes read from the MultiPLIER Z matrix index.
len(multiplier_z_genes)

In [None]:
# Peek at the first few gene names.
multiplier_z_genes[:10]

## Get gene objects

In [None]:
# Build one Gene object per MultiPLIER gene name, skipping names that have no
# entry in Gene.GENE_NAME_TO_ID_MAP.
multiplier_gene_obj = {}
for gene_name in multiplier_z_genes:
    if gene_name not in Gene.GENE_NAME_TO_ID_MAP:
        continue
    multiplier_gene_obj[gene_name] = Gene(name=gene_name)

In [None]:
# Number of MultiPLIER genes that could be mapped through GENE_NAME_TO_ID_MAP.
len(multiplier_gene_obj)

In [None]:
# Spot check: show the Ensembl ID resolved for one known gene symbol.
multiplier_gene_obj["GAS6"].ensembl_id

In [None]:
_gene_obj = list(multiplier_gene_obj.values())

# genes_info column name -> Gene attribute the column is read from
_gene_attr_by_column = {"name": "name", "id": "ensembl_id", "chr": "chromosome"}

# One row per gene; rows with any missing value are dropped.
genes_info = pd.DataFrame(
    {
        column: [getattr(gene, attr) for gene in _gene_obj]
        for column, attr in _gene_attr_by_column.items()
    }
).dropna()

In [None]:
# (number of genes kept after dropna, number of columns)
genes_info.shape

In [None]:
# Preview the gene name / Ensembl ID / chromosome table.
genes_info.head()

## Get tissue names

In [None]:
# Prediction-model database files (*.db) for the selected eQTL model; tissue
# names are derived from their file names below.
db_files = list(conf.PHENOMEXCAN["PREDICTION_MODELS"][EQTL_MODEL].glob("*.db"))

In [None]:
# One database file is expected for each of the 49 tissues.
assert len(db_files) == 49

In [None]:
# Tissue name = database file name without the model prefix and the ".db" suffix.
# Only the file name (not the full path) is split, so a parent directory whose
# name happens to contain the prefix cannot corrupt the result; maxsplit=1 keeps
# everything after the first occurrence of the prefix.
tissues = [
    f.name.split(EQTL_MODEL_FILES_PREFIX, 1)[1].split(".db")[0] for f in db_files
]

In [None]:
# Peek at the first few tissue names.
tissues[:5]

# Average correlations per chromosome

In [None]:
# chromosome number -> gene-by-gene correlation matrix averaged over all tissue
# files of that chromosome (rows/columns are the MultiPLIER gene IDs on it)
corrs_per_chr = {}

for chr_num in range(1, 23):
    print(f"Chromosome: {chr_num}", flush=True)

    is_current_chr = all_gene_corr_files_df["chromosome"] == chr_num
    chr_files = all_gene_corr_files_df[is_current_chr]
    print(f"Number of corrs files: {chr_files.shape}")

    # genes_info stores the chromosome as a string, hence the str() conversion
    multiplier_genes_in_chr = genes_info[genes_info["chr"] == str(chr_num)]
    print(f"Number of MultiPLIER genes: {multiplier_genes_in_chr.shape}")

    # running sum of the per-tissue correlations for this chromosome
    chr_gene_ids = multiplier_genes_in_chr["id"]
    chr_df = pd.DataFrame(data=0.0, index=chr_gene_ids, columns=chr_gene_ids)

    print("Reading corrs per tissue", flush=True)
    for corr_file, tissue in zip(chr_files["corr_file"], chr_files["tissue"]):
        tissue_corrs_df = pd.read_pickle(corr_file)

        # missing correlations are reported and then counted as zero
        n_missing = tissue_corrs_df.isna().sum().sum()
        if n_missing > 0:
            print(f"  WARNING ({tissue}): has NaN values ({n_missing})")
            tissue_corrs_df = tissue_corrs_df.fillna(0.0)

        # restrict/reorder to the genes of the accumulator before adding
        tissue_corrs_df = tissue_corrs_df.loc[chr_df.index, chr_df.columns]
        chr_df = chr_df + tissue_corrs_df

    # sum -> mean over the files of this chromosome
    chr_df = chr_df / float(chr_files.shape[0])

    # summary statistics over the condensed (flattened) form of the matrix
    chr_df_flat = pd.Series(squareform(chr_df.values, checks=False))
    display(chr_df_flat.describe())

    corrs_per_chr[chr_num] = chr_df

    print("\n")

# Create full gene correlation matrix

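This matrix covers the MultiPLIER Z genes that could be mapped to an Ensembl ID and chromosome. Only gene pairs on the same chromosome are filled in from the per-chromosome averages; all other entries stay at zero.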

In [None]:
# Start from an all-zero square matrix labelled by gene ID on both axes.
gene_corrs_df = pd.DataFrame(data=0.0, index=genes_info["id"], columns=genes_info["id"])

In [None]:
# Should be square: (number of genes, number of genes).
gene_corrs_df.shape

In [None]:
# Preview the (still all-zero) matrix.
gene_corrs_df.head()

In [None]:
# Add each chromosome's averaged matrix into the full matrix. Each one is first
# expanded to the full gene set, with entries it does not cover set to zero.
for chr_data in corrs_per_chr.values():
    chr_data_full = chr_data.reindex(
        index=gene_corrs_df.index, columns=gene_corrs_df.columns
    ).fillna(0.0)
    gene_corrs_df = gene_corrs_df + chr_data_full

In [None]:
# Store the matrix as 32-bit floats.
gene_corrs_df = gene_corrs_df.astype(np.float32)

In [None]:
# Preview the filled matrix.
gene_corrs_df.head()

In [None]:
# Every entry on the main diagonal (a gene against itself) must be exactly 1.0.
assert np.all(gene_corrs_df.values.diagonal() == 1.0)

## Stats

In [None]:
# Flatten the square matrix into SciPy's condensed (vector) form for summary
# statistics; checks=False skips squareform's validation of the input matrix.
_gene_corrs_flat = squareform(gene_corrs_df.values, checks=False)

In [None]:
# Summary statistics of the flattened correlation values.
pd.Series(_gene_corrs_flat).describe()

# Plot

In [None]:
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

In [None]:
# Gene IDs ordered by chromosome label, so genes on the same chromosome are
# contiguous in the plot.
# NOTE(review): "chr" is compared against str(chr_num) earlier, so it presumably
# holds strings and this sort is lexicographic ("10" before "2") -- confirm that
# this ordering is acceptable for the figure.
genes_order = genes_info.sort_values("chr")["id"].tolist()

In [None]:
# Two-colour map: white ("w") and red ("r").
cmap = ListedColormap(["w", "r"])

In [None]:
# Draw the full matrix with genes ordered by chromosome. The colour scale is
# limited to [-0.05, 0.05] and rendered with the two-colour map defined above;
# tick marks are removed because there is one row/column per gene.
fig, ax = plt.subplots(figsize=(10, 10))
ax.matshow(
    gene_corrs_df.loc[genes_order, genes_order].values, vmin=-0.05, vmax=0.05, cmap=cmap
)
ax.set_xlabel("Genes")
ax.set_ylabel("Genes")
ax.set_xticks([])
ax.set_yticks([])

# Testing

In [None]:
# Spot check on one gene pair: show its value in the averaged matrix. The cells
# below recompute the same average directly from the chromosome 13 files.
# COL4A1 and COL4A2
gene1 = "ENSG00000187498"
gene2 = "ENSG00000134871"

gene_corrs_df.loc[gene1, gene2]

In [None]:
# Paths of all correlation files for chromosome 13 (one per tissue).
_genes_files = all_gene_corr_files_df.loc[
    all_gene_corr_files_df["chromosome"] == 13, "corr_file"
].tolist()
assert len(_genes_files) == 49

In [None]:
# Read the gene1/gene2 correlation straight from each chromosome 13 file.
_gene_values = [pd.read_pickle(f).loc[gene1, gene2] for f in _genes_files]

In [None]:
# Convert to an array and confirm one value was read per tissue file.
_gene_values = np.array(_gene_values)
assert _gene_values.shape[0] == 49

In [None]:
# The value stored in the averaged matrix must match the mean recomputed from
# the individual files. gene_corrs_df was cast to float32 and accumulated in a
# different order than np.mean uses, so the two numbers can differ by rounding
# error; compare with a tolerance instead of exact equality.
display(_gene_values.mean())
assert np.isclose(gene_corrs_df.loc[gene1, gene2], _gene_values.mean())

# Save

## With Ensembl IDs

In [None]:
# Destination for the matrix labelled by gene IDs, taken from conf.
output_file = conf.PHENOMEXCAN["LD_BLOCKS"][EQTL_MODEL]["GENE_IDS_CORR_AVG"]
display(output_file)

In [None]:
# Write the gene-ID-labelled matrix to disk as a pickle.
gene_corrs_df.to_pickle(output_file)

## With gene symbols

In [None]:
# Destination for the matrix labelled by gene names, taken from conf
# (reuses the output_file variable from the previous section).
output_file = conf.PHENOMEXCAN["LD_BLOCKS"][EQTL_MODEL]["GENE_NAMES_CORR_AVG"]
display(output_file)

In [None]:
# Relabel both axes from gene IDs to gene names using Gene.GENE_ID_TO_NAME_MAP.
gene_corrs_gene_names_df = gene_corrs_df.rename(
    index=Gene.GENE_ID_TO_NAME_MAP, columns=Gene.GENE_ID_TO_NAME_MAP
)

In [None]:
# Row labels must still be unique after renaming.
assert gene_corrs_gene_names_df.index.is_unique

In [None]:
# Column labels must still be unique after renaming.
assert gene_corrs_gene_names_df.columns.is_unique

In [None]:
# Shape of the renamed matrix (renaming does not add or drop rows/columns).
gene_corrs_gene_names_df.shape

In [None]:
# Preview the matrix with its renamed labels.
gene_corrs_gene_names_df.head()

In [None]:
# Write the gene-name-labelled matrix to disk as a pickle.
gene_corrs_gene_names_df.to_pickle(output_file)