Set the paths to the local repositories. The `mlruns` directory is expected inside the `CardiacCOMA` repository.

In [None]:
# Absolute, machine-specific paths to the local checkouts (note the trailing slash).
# CARDIAC_COMA_REPO becomes the working directory and hosts `mlruns`;
# CARDIAC_GWAS_REPO receives the result tables written further down.
CARDIAC_COMA_REPO = "/home/rodrigo/01_repos/CardiacCOMA/"
CARDIAC_GWAS_REPO = "/home/rodrigo/01_repos/CardiacGWAS/"

In [None]:
import mlflow
import os, sys

import torch
import torch.nn.functional as F

import os; os.chdir(CARDIAC_COMA_REPO)
from config.load_config import load_yaml_config, to_dict

import ipywidgets as widgets
from ipywidgets import interact
from IPython.display import Image
from mlflow.tracking import MlflowClient

import pickle as pkl
import pytorch_lightning as pl

from argparse import Namespace
import matplotlib.pyplot as plt

#import surgeon_pytorch
#from surgeon_pytorch import Inspect, get_layers

import numpy as np
import pandas as pd
from IPython import embed
sys.path.insert(0, '..')

import model.Model3D
from utils.helpers import get_coma_args, get_lightning_module, get_datamodule
from copy import deepcopy
from pprint import pprint

from copy import deepcopy
from typing import List
from tqdm import tqdm
from IPython import embed

In [None]:
from functools import partial

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
from mlflow_helpers import \
    list_artifacts,\
    get_significant_loci,\
    get_metrics_cols, \
    get_params_cols, \
    get_runs_df, \
    get_good_runs,\
    summarize_loci_across_runs,\
    get_model_pretrained_weights

In [None]:
# CARDIAC_COMA_REPO already ends with "/", so plain interpolation would produce
# ".../CardiacCOMA//mlruns"; os.path.join yields a clean path for the file URI.
TRACKING_URI = f"file://{os.path.join(CARDIAC_COMA_REPO, 'mlruns')}"
mlflow.set_tracking_uri(TRACKING_URI)

In [None]:
# CSV cache of the runs table. Built from CARDIAC_GWAS_REPO (like the other
# outputs of this notebook) so that it does not depend on the working directory.
RUNS_CACHED = os.path.join(CARDIAC_GWAS_REPO, "results", "runs.csv")

# Select MLflow experiment

In [None]:
# MLflow client built with default arguments
# (presumably bound to the tracking URI configured above; confirm).
client = MlflowClient()

In [None]:
def experiment_selection_widget():
    """Build a selection widget listing the MLflow experiments by name.

    Returns:
        A ``widgets.Select`` whose options are the experiment names. The
        second experiment is pre-selected when there is one (the notebook's
        original choice); otherwise the first one, or nothing when no
        experiment exists.
    """
    # `mlflow.list_experiments` is not available in newer MLflow releases,
    # where `mlflow.search_experiments` replaces it. Prefer the former when
    # present so the listing order (and thus the default) stays unchanged.
    list_experiments = getattr(mlflow, "list_experiments", None) or mlflow.search_experiments
    options = [exp.name for exp in list_experiments()]

    # The original `options[1]` raised IndexError with fewer than two experiments.
    if len(options) > 1:
        default = options[1]
    elif options:
        default = options[0]
    else:
        default = None

    experiment_w = widgets.Select(
        options=options,
        value=default,
        description="Select MLflow experiment",
    )

    return experiment_w

exp_w = experiment_selection_widget()


@interact
def get_runs(exp_name=exp_w):
    """Load the finished runs of the selected experiment into module globals.

    Sets the globals ``exp_id`` (id of the experiment called ``exp_name``) and
    ``runs_df`` (table returned by ``get_runs_df`` for that experiment).
    """
    global exp_id, runs_df

    experiment = mlflow.get_experiment_by_name(exp_name)
    exp_id = experiment.experiment_id
    runs_df = get_runs_df(exp_name=exp_name, only_finished=True, sort_by=None)

    # Metric / parameter column names; only the disabled preview below uses them.
    metrics = get_metrics_cols(runs_df)
    params = get_params_cols(runs_df)
    # display(runs_df.loc[:, [*metrics, *params]].drop("params.platform", axis=1).head(10))

Retrieve run data from MLflow for the chosen experiment:

In [None]:
exp_name = "Cardiac - ED"
if not os.path.exists(RUNS_CACHED):
    # No cached table yet: query MLflow and store the result for later sessions.
    exp_id = mlflow.get_experiment_by_name(exp_name).experiment_id
    _get_runs_df = partial(get_runs_df, sort_by=None)
    runs_df = _get_runs_df(exp_name=exp_name, only_finished=True)
    runs_df.to_csv(RUNS_CACHED)
else:
    runs_df = pd.read_csv(RUNS_CACHED)
    runs_df = runs_df.set_index(["experiment_id", "run_id"])
    # Take exp_id from the cached index so that its type matches the index
    # labels it is later combined with in `runs_df.loc[(exp_id, run_id)]`.
    _cached_exp_ids = runs_df.index.get_level_values("experiment_id")
    if len(_cached_exp_ids):
        exp_id = _cached_exp_ids[0]

# Computed on both branches (previously left undefined when the cache was used).
metrics, params = get_metrics_cols(runs_df), get_params_cols(runs_df)

Retrieve the `run_id`s whose average MSE on the test dataset is below a threshold (in mm$^2$):

In [None]:
# Performance threshold on the test reconstruction MSE (mm2).
RECON_LOSS_THRES = 2.
# Sorted index entries of the runs whose test reconstruction loss is below the threshold.
run_ids = sorted(runs_df.loc[runs_df["metrics.test_recon_loss"] < RECON_LOSS_THRES].index)

Number of runs, $R$, and number of latent variables tested $n_z=\sum_{r=1}^{R} \dim(\textbf{z}_r)$:

In [None]:
n_runs = len(run_ids)
# Coerce to numbers before summing: if the parameter column holds strings,
# `.sum()` would concatenate them instead of adding them up.
# (The former `filter(like="param")` step was redundant: the column is selected by name.)
n_z_total = pd.to_numeric(
    runs_df.filter(items=run_ids, axis="index")["params.latent_dim"]
).sum()
print(f"Number of runs: {n_runs}")
print(f"Total number of latent variables: {n_z_total}")

___

### Widget for selecting run

In [None]:
# Entries of `run_ids` are (experiment_id, run_id) index tuples. The label is
# the first 10 characters of the run id; the previous `x[:10]` sliced the
# tuple itself and therefore never shortened anything. Values are unchanged.
run_ids_w = widgets.Select(
    description="Choose run:",
    options={str(x[1])[:10]: x for x in run_ids},
)
display(run_ids_w)
# Snapshot of the widget value at the time this cell runs.
run_id = run_ids_w.value
run_info = runs_df.loc[run_id].to_dict()
# Strip the URI scheme to obtain a plain filesystem path.
artifact_uri = run_info["artifact_uri"].replace("file://", "")

___

Assign gene names to different regions:

In [None]:
import loci_mapping

In [None]:
from loci_mapping import LOCUS_TO_REGION, REGION_TO_LOCUS, LOCI_TO_DROP

In [None]:
def get_significant_loci(
    runs_df,
    experiment_id, run_id,
    p_threshold=5e-8,
    client=None
) -> pd.DataFrame:
    '''
    Return the GWAS hits of one run whose p-value is below a threshold.

    Parameters:
        runs_df: runs table indexed by (experiment_id, run_id) with an
            "artifact_uri" column.
        experiment_id, run_id: index entry of the run in `runs_df`.
        p_threshold: rows with P < p_threshold are kept.
        client: unused; kept for backward compatibility. It used to default
            to a `MlflowClient()` built when the function was defined.

    Return:
        pd.DataFrame indexed by ("pheno", "region") and sorted by "P", with an
        extra "locus_name" column. If the run has no "GWAS/summaries"
        directory (or it is empty), an empty frame with the plain columns
        ["run", "pheno", "region"] is returned instead.

    Raises:
        KeyError: if (experiment_id, run_id) is not in `runs_df`.
    '''

    def get_phenoname(path):
        # Summary files are named "<phenotype>__<...>".
        return os.path.basename(path).split("__")[0]

    run_info = runs_df.loc[(experiment_id, run_id)].to_dict()
    artifact_uri = run_info["artifact_uri"].replace("file://", "")

    gwas_dir_summaries = os.path.join(artifact_uri, "GWAS/summaries")
    # gwas_dir_summaries = os.path.join(artifact_uri, "GWAS_adj_10PCs/summaries")

    try:
        summary_paths = [
            os.path.join(gwas_dir_summaries, fname)
            for fname in os.listdir(gwas_dir_summaries)
        ]
    except OSError:
        # Missing / unreadable directory: the run has no GWAS summaries.
        summary_paths = []

    if not summary_paths:
        return pd.DataFrame(columns=["run", "pheno", "region"])

    # One summary per phenotype. The paths already include the artifact
    # directory, so they are used as they are (joining them onto
    # `artifact_uri` again would duplicate the prefix for relative paths).
    region_summaries = {get_phenoname(path): path for path in summary_paths}
    dfs = [pd.read_csv(path).assign(pheno=pheno) for pheno, path in region_summaries.items()]

    df = pd.concat(dfs)
    # Positional assignment: the concatenated frame has repeated index labels.
    df["locus_name"] = [REGION_TO_LOCUS.get(region, "Unnamed") for region in df["region"]]
    df = df.set_index(["pheno", "region"])

    df_filtered = df[df.P < p_threshold]

    return df_filtered.sort_values(by="P")


def summarize_loci_across_runs(runs_df: pd.DataFrame):

    '''
    Collect the significant GWAS loci of every run in `runs_df`.

    Parameters:
        runs_df: runs table indexed by (experiment_id, run_id).
    Return:
        pd.DataFrame indexed by ("run", "pheno", "region") holding the rows
        returned by `get_significant_loci` for each run, stacked in order of
        run id. No per-locus aggregation is performed here (the aggregation
        that used to follow the `return` statement was unreachable).
    '''

    # Use the experiment id stored in the index instead of a hard-coded `1`,
    # so that the lookup works whatever the type/value of the experiment id.
    run_keys = sorted(runs_df.index, key=lambda key: key[1])

    all_signif_loci = []

    for experiment_id, run_id in tqdm(run_keys):
        signif_loci_df = (
            get_significant_loci(runs_df, experiment_id=experiment_id, run_id=run_id)
            .assign(run=run_id)
            .reset_index()
            .set_index(["run", "pheno", "region"])
        )
        all_signif_loci.append(signif_loci_df)

    if not all_signif_loci:
        # pd.concat raises on an empty list; return an empty, equally indexed frame.
        return pd.DataFrame(columns=["run", "pheno", "region"]).set_index(["run", "pheno", "region"])

    return pd.concat(all_signif_loci)

In [None]:
# Alternatives that were immediately overwritten:
#   {"by":"count", "ascending":False}, {"by":"min_P", "ascending":True}
ORDER_BY = {"by":"-log10(min_P)", "ascending":False}

all_signif_loci_df = summarize_loci_across_runs(runs_df)

# Keep only the loci of runs that passed the reconstruction threshold.
# The set of run ids is built once instead of once per index entry.
good_run_ids = {y[1] for y in run_ids}
good_runs = [x for x in all_signif_loci_df.index if x[0] in good_run_ids]
all_signif_loci_df = all_signif_loci_df.filter(items=good_runs, axis="index")

In [None]:
all_signif_loci_df

In [None]:
# Move the index levels into ordinary columns.
# NOTE(review): not idempotent; re-running this cell resets the index again.
all_signif_loci_df = all_signif_loci_df.reset_index()

In [None]:
# For each region keep the association(s) with the smallest p-value.
best_p_for_region = all_signif_loci_df.groupby(["region"]).P.transform("min")
# .copy(): the column assignments below must act on an independent frame,
# not on a filtered view of all_signif_loci_df (SettingWithCopy).
kk = all_signif_loci_df[all_signif_loci_df.P == best_p_for_region].copy()
kk["BP"] = kk["BP"].astype(int)
kk["CHR"] = kk["CHR"].astype(int)
kk = kk[~kk.region.isin(LOCI_TO_DROP)]
kk = kk.sort_values("P").head(50)

COLUMNS = ["CHR", "BP", "SNP", "region", "run", "pheno"]
kk = kk[COLUMNS].reset_index(drop=True)
kk.to_csv(f"{CARDIAC_GWAS_REPO}/results/best_z_for_loci.csv", index=False)
kk.head()

In [None]:
filter_for_threshold = []

# NOTE(review): the lookup below expects all_signif_loci_df to be indexed by
# run id. If its index has already been reset to a default one (an earlier
# cell does this), no run will be found; confirm the intended index.
for exp_id, run_id in run_ids:

    try:
        run_df = all_signif_loci_df.loc[run_id]
        run_df = run_df.assign(run_id=run_id)
    except KeyError:
        print(f"Run {run_id} does not have significant loci.")
        # Skip this run. Falling through used to append the frame of the
        # previous iteration again (or fail on the very first one).
        continue

    filter_for_threshold.append(run_df)

filter_for_threshold = pd.concat(filter_for_threshold).reset_index().set_index(["run_id", "pheno", "region"])

In [None]:
# Per (region, locus, run): number of significant associations and best p-value.
loci_summary_df = (
    all_signif_loci_df
    .reset_index()
    # Do not fail when there is no "index" column to discard.
    .drop(columns="index", errors="ignore")
    .groupby(by=["region", "locus_name", "run"])
    .aggregate({"CHR": "count", "P": "min"})
    .rename({"CHR": "count", "P": "min_P"}, axis=1)
    # Single sort: the former sort by "count" was overridden by the
    # subsequent sort by "min_P"; "count" now only breaks ties.
    .sort_values(["min_P", "count"], ascending=[True, False])
)

In [None]:
loci_summary_df

Retrieve the best _p_-value for each genetic locus:

In [None]:
# Group the per-run summary rows by locus, using the "region" and "locus_name" index levels.
loci_grouped_df = loci_summary_df.groupby(level=["region", "locus_name"])

In [None]:
# Best p-value per locus. `min()` takes no column argument: the former
# "min_P" positional argument was interpreted as `numeric_only`.
min_P_df = loci_grouped_df.\
    min().\
    drop("count", axis=1).\
    sort_values("min_P")

# Render the p-values as LaTeX, "$m \times 10^{e}$". The backslash must be
# escaped: a bare "\t" inside the string is a TAB character, which used to
# produce "<TAB>imes". The exponent is converted to int to drop the zero
# padding of the float representation ("-08" -> "-8").
min_P_df["min_P"] = [
    f"${round(float(mantissa), 1)} \\times 10^{{{int(exponent)}}}$"
    for mantissa, exponent in min_P_df.min_P.astype(str).str.split("e")
]

min_P_df.head(5)

In [None]:
# Number of per-run summary rows for each locus, most frequent first,
# after removing the entries listed in LOCI_TO_DROP.
loci_counts = (
    loci_grouped_df
    .count()
    .drop(LOCI_TO_DROP)
    .drop("min_P", axis=1)
    .sort_values("count", ascending=False)
)  # / n_runs * 100

loci_counts.head(5)

In [None]:
def create_count_table_tex(tex_file):
    """Write the locus count / best p-value table as LaTeX and return its code.

    The table is the index-wise merge of the module-level frames
    ``loci_counts`` and ``min_P_df``.

    Parameters:
        tex_file: path of the ``.tex`` file to (over)write.

    Returns:
        The LaTeX source that was written, with underscores escaped.
    """
    # Build the table before opening the file, so that a failure here does
    # not leave a truncated file behind.
    table_code = (
        pd.merge(loci_counts, min_P_df, left_index=True, right_index=True)
        .reset_index()
        .rename({"locus_name": "locus", "min_P": "$p$-value"}, axis=1)
        .to_latex(escape=False, index=False)
    )

    # Escape underscores for LaTeX ("\\_" is backslash + underscore).
    table_code = table_code.replace("_", "\\_")

    with open(tex_file, "wt") as table_f:
        # The original called `write()` without an argument (TypeError).
        table_f.write(table_code)

    return table_code

In [None]:
# Export the table to the manuscript's tables directory in the GWAS repository.
create_count_table_tex(f"{CARDIAC_GWAS_REPO}/manuscript/tables/gwas_counts.tex")

In [None]:
# One option per distinct first index level (the region) of the summary table.
region_w = widgets.Select(options=sorted({idx[0] for idx in loci_summary_df.index}))


@interact
def examine_locus(region=region_w):
    """Display the rows of `loci_summary_df` that belong to the selected region."""
    display(loci_summary_df.loc[region])

In [None]:
def summarize_loci_across_runs(runs_df: pd.DataFrame):

    '''
    Collect the significant GWAS loci of every run in `runs_df`.

    Parameters:
        runs_df: runs table indexed by (experiment_id, run_id).
    Return:
        pd.DataFrame indexed by ("run", "pheno", "region") with the rows
        returned by `get_significant_loci` for each run, stacked in order of
        run id (no "count"/"min_P" aggregation is done here).
    '''

    # Use the experiment id stored in the index instead of a hard-coded "1",
    # so that the lookup works whatever the type/value of the experiment id.
    run_keys = sorted(runs_df.index, key=lambda key: key[1])

    per_run = [
        get_significant_loci(runs_df, experiment_id, run)
        .assign(run=run)
        .reset_index()
        .set_index(["run", "pheno", "region"])
        for experiment_id, run in run_keys
    ]

    if not per_run:
        # pd.concat raises on an empty list; return an empty, equally indexed frame.
        return pd.DataFrame(columns=["run", "pheno", "region"]).set_index(["run", "pheno", "region"])

    all_signif_loci = pd.concat(per_run)

    return all_signif_loci

In [None]:
# An "index" column appears to exist only when some run contributed an empty
# frame (whose default index is turned into a column), so dropping it must
# not fail when it is absent.
kk = summarize_loci_across_runs(runs_df).reset_index().drop(columns="index", errors="ignore")
#kk.pheno = kk.apply(lambda x: f"1_{x.run[:5]}_{x.pheno}", axis=1)

In [None]:
# Correlation table indexed by phenotype; the path is relative to the current working directory.
z_corr = pd.read_csv("data/cardio/corr_z_vs_indices.csv").set_index("phenotype")

In [None]:
corrs = []

for index, row in pp.sort_values(by="region").iterrows():
    try:
        corrs.append(list(z_corr.loc[row.pheno]))
    except:
        corrs.append([pd.NA]*4)        

In [None]:
corrs_df = pd.DataFrame(corrs, columns=["LVEDV_corr", "LVM_corr", "RVEDV_corr", "LVSph_corr"])
corrs_df.set_index(pp.index)

In [None]:
# Column-wise concat (rows are matched on index labels) of kk with the
# absolute correlations, then grouped by region.
kk_grouped = pd.concat([kk, corrs_df.abs()], axis=1).groupby("region")

In [None]:
from functools import partial
# Mean / standard deviation reducers that skip missing values.
mean_f = partial(pd.Series.mean, skipna = True)
std_f = partial(pd.Series.std, skipna = True)

In [None]:
# Number of non-missing LVEDV correlations per region.
counts = kk_grouped.agg("count")["LVEDV_corr"]

In [None]:
# Per region: mean and standard deviation of each "<pheno>_corr" column.
phenos =  ["LVEDV", "LVM", "RVEDV", "LVSph"]
corr_per_locus = kk_grouped.aggregate(func={f"{pheno}_corr": [mean_f, std_f] for pheno in phenos})

In [None]:
# Attach the per-region counts computed above as an extra column.
corr_per_locus["counts"] = counts

In [None]:
# Show the regions with the most entries first (the sorted copy is displayed, not stored).
corr_per_locus.sort_values(by="counts", ascending=False)

# Statistics on the GWAS loci counts

In [None]:
signif_loci_dfs = {}
dd = []

def loci_count(run_df):
    """Return {second index entry: number of rows} for a significant-loci frame.

    For frames indexed by (pheno, region) this is the number of hits per region.
    """
    from collections import Counter
    return dict(Counter([x[1] for x in run_df.index]))

# Loop-invariant selection: only runs whose validation reconstruction loss is
# below 2 are analysed (other runs used to be dropped through a swallowed KeyError).
runs_below_val_thres = runs_df[runs_df["metrics.val_recon_loss"] < 2]

for run in runs_below_val_thres.index:

    try:
        # `run` is the (experiment_id, run_id) index entry itself.
        pp = get_significant_loci(runs_below_val_thres, run[0], run[1]) #.sort_values(by=["CHR", "BP"], axis=0)
    except Exception as err:
        # Best effort, but visible: report the run that could not be read.
        print(f"Skipping run {run[1]}: {err!r}")
        continue

    # Count the loci of THIS run. The counts used to be read before being
    # assigned, so every run was reported with the counts of the previous one
    # (and the first one was silently dropped by the bare `except`).
    loci_cnt = loci_count(pp)
    n_distinct_loci = len(loci_cnt)
    n_hits_with_duplication = sum(loci_cnt.values())

    if n_distinct_loci == 0:
        # No significant locus: the ratio below is undefined, skip the run
        # (formerly an implicit skip through ZeroDivisionError).
        continue

    ff = [  run[1],
       runs_df.loc[run, "metrics.test_recon_loss"],
       runs_df.loc[run, "metrics.test_kld_loss"],
       runs_df.loc[run, "params.latent_dim"],
       runs_df.loc[run, "params.w_kl"],
       n_distinct_loci,
       n_hits_with_duplication,
       n_hits_with_duplication / n_distinct_loci
    ]

    signif_loci_dfs[run[1]] = pp
    dd.append(ff)

# Passing the column names to the constructor also works when `dd` is empty.
kk = pd.DataFrame(
    dd,
    columns=[
        "run_id",
        "test_mse",
        "kld",
        "lat_dim",
        "w_kl",
        "n_loci",
        "n_loci_dupl",
        "ratio",
    ],
)

In [None]:
# Interactive box plot of one column of `kk` against another; both axes are
# chosen from the columns of `kk`.
interact(
    lambda xcol, ycol: sns.boxplot(x=xcol, y=ycol, data=kk),
    xcol = widgets.Select(options=kk.columns),
    ycol = widgets.Select(options=kk.columns)
);

In [None]:
@interact
def show_signif_loci(run_id=run_ids_w):
    """Return the significant loci of the run selected in `run_ids_w`.

    The widget values are (experiment_id, run_id) index tuples, so they are
    unpacked before the lookup; a bare run id is still accepted and combined
    with the global `exp_id`.
    """
    if isinstance(run_id, tuple):
        experiment_id, run_id = run_id
    else:
        experiment_id = exp_id
    return get_significant_loci(runs_df, experiment_id, run_id)