Set path to where `mlruns` directory is located (usually, the `CardiacCOMA` repository)

In [1]:
CARDIAC_COMA_REPO = "/home/rodrigo/01_repos/CardiacCOMA/"

In [216]:
import mlflow
import os, sys

import torch
import torch.nn.functional as F

import os; os.chdir(CARDIAC_COMA_REPO)
from config.load_config import load_yaml_config, to_dict

import ipywidgets as widgets
from ipywidgets import interact
from IPython.display import Image
from mlflow.tracking import MlflowClient

import pickle as pkl
import pytorch_lightning as pl

from argparse import Namespace
import matplotlib.pyplot as plt

#import surgeon_pytorch
#from surgeon_pytorch import Inspect, get_layers

import numpy as np
import pandas as pd
from IPython import embed
sys.path.insert(0, '..')

import model.Model3D
from utils.helpers import get_coma_args, get_lightning_module, get_datamodule
from copy import deepcopy
from pprint import pprint

from copy import deepcopy
from typing import List
from tqdm import tqdm
from IPython import embed

In [217]:
from functools import partial

In [218]:
import matplotlib.pyplot as plt
import seaborn as sns

In [219]:
from mlflow_helpers import \
    list_artifacts,\
    get_significant_loci,\
    get_metrics_cols, \
    get_params_cols, \
    get_runs_df, \
    get_good_runs,\
    summarize_loci_across_runs,\
    get_model_pretrained_weights

In [220]:
TRACKING_URI = f"file://{CARDIAC_COMA_REPO}/mlruns"
mlflow.set_tracking_uri(TRACKING_URI)

In [221]:
RUNS_CACHED="../CardiacGWAS/results/runs.csv"

# Select MLflow experiment

In [222]:
client = MlflowClient()

In [223]:
def experiment_selection_widget():
    
    options = [exp.name for exp in mlflow.list_experiments()]

    experiment_w = widgets.Select(
      options=options,
      value=options[1],
      description="Select MLflow experiment"
    )
    
    return experiment_w

exp_w = experiment_selection_widget()

@interact
def get_runs(exp_name=exp_w):  
  #try:
    global exp_id, runs_df
    exp_id = mlflow.get_experiment_by_name(exp_name).experiment_id
    _get_runs_df = partial(get_runs_df, sort_by=None)
    runs_df = _get_runs_df(exp_name=exp_name, only_finished=True)
    metrics, params = get_metrics_cols(runs_df), get_params_cols(runs_df)  
    # display(runs_df.loc[:, [*metrics, *params]].drop("params.platform", axis=1).head(10))    
  #except:
#    pass

interactive(children=(Select(description='Select MLflow experiment', index=1, options=('Default', 'Cardiac - E…
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
File ~/anaconda3/envs/cardio/lib/python3.9/site-packages/ipywidgets/widgets/interaction.py:257, in interactive.update(self, *args)
    255     value = widget.get_interact_value()
    256     self.kwargs[widget._kwarg] = value
--> 257 self.result = self.f(**self.kwargs)
    258 show_inline_matplotlib_plots()
    259 if self.auto_display and self.result is not None:

File /tmp/ipykernel_83945/3749540716.py:21, in get_runs(exp_name)
     19 exp_id = mlflow.get_experiment_by_name(exp_name).experiment_id
     20 _get_runs_df = partial(get_runs_df, sort_by=None)
---> 21 runs_df = _get_runs_df(exp_name=exp_name, only_finished=True)
     22 metrics, params = get_metrics_cols(runs_df), get_params_cols(runs_df)  

File ~/01_repos/Cardia

In [None]:
Retrieve run data from MLflow for the chosen experiment:

In [238]:
exp_name = "Cardiac - ED"
if not os.path.exists(RUNS_CACHED):
    exp_id = mlflow.get_experiment_by_name(exp_name).experiment_id
    _get_runs_df = partial(get_runs_df, sort_by=None)
    runs_df = _get_runs_df(exp_name=exp_name, only_finished=True)
    metrics, params = get_metrics_cols(runs_df), get_params_cols(runs_df)
    runs_df.to_csv(RUNS_CACHED)
else:
    runs_df = pd.read_csv(RUNS_CACHED)
    runs_df = runs_df.set_index(["experiment_id", "run_id"])

Retrieve `run_id`'s where average MSE is less than a threshold (in mm$^2$) for the test dataset:

In [239]:
RECON_LOSS_THRES = 1. # performance threshold for MSE mm2.
run_ids = sorted([x for x in runs_df[runs_df["metrics.test_recon_loss"] < RECON_LOSS_THRES].index])

Number of latent variables tested $n_z=\sum_r \dim(\textbf{z}_r)$:

In [240]:
runs_df.filter(items=run_ids, axis='index').filter(like="param", axis='columns')["params.latent_dim"].sum()

384

In [241]:
run_ids_w = widgets.Select(description="Choose run:", options={x[:10]: x for x in run_ids})
display(run_ids_w)
run_id = run_ids_w.value
run_info = runs_df.loc[run_id].to_dict()
artifact_uri = run_info["artifact_uri"].replace("file://", "")

Select(description='Choose run:', options={(1, '0285fa2356fd454e88e3c30d6b63f163'): (1, '0285fa2356fd454e88e3c…


In [242]:
LOCUS_NAMES = {
      "chr2_108": "TTN",
      "chr6_78": "PLN",
      "chr17_27": "GOSR2",
      "chr5_103": "CREBRF*",
      "chr12_69": "TBX5",
      "chr21_10": "NCSTNP1*",
      "chr1_124": "CHTOP*",
      "chr10_69": "RBM20",
      "chr12_19": "CCDC91*",
      "chr6_20": "HFE*",
      "chr11_2": "LSP1*"
}   

In [243]:
def get_significant_loci(
    runs_df,
    experiment_id, run_id, 
    p_threshold=1e-8, 
    client=mlflow.tracking.MlflowClient()
) -> pd.DataFrame:
    
    '''    
    Returns a DataFrame with the loci that have a stronger p-value than a given threshold
    '''
    
    def get_phenoname(path):        
        filename = os.path.basename(path)
        phenoname = filename.split("__")[0]
        return phenoname
        
    run_info = runs_df.loc[(experiment_id, run_id)].to_dict()
    artifact_uri = run_info["artifact_uri"].replace("file://", "")    
           
    gwas_dir_summaries = os.path.join(artifact_uri, "GWAS/summaries")
    # gwas_dir_summaries = os.path.join(artifact_uri, "GWAS_adj_10PCs/summaries")
    
    try:
      summaries_fileinfo = [ os.path.join(gwas_dir_summaries, x) for x in  os.listdir(gwas_dir_summaries) ]
    except:
      summaries_fileinfo = []
    
    # summaries_fileinfo = client._tracking_client.list_artifacts(run_id, path="GWAS_adj_10PCs/summaries")
    # summaries_fileinfo = client._tracking_client.list_artifacts(run_id, path="GWAS/summaries")    
        
    if len(summaries_fileinfo) == 0:
        return pd.DataFrame(columns=["run", "pheno", "region"])
    
    # region_summaries = {get_phenoname(x.path): os.path.join(artifact_uri, x.path) for x in summaries_fileinfo}
    region_summaries = {get_phenoname(x): os.path.join(artifact_uri, x) for x in summaries_fileinfo}
    dfs = [pd.read_csv(path).assign(pheno=pheno) for pheno, path in region_summaries.items()]
    
    df = pd.concat(dfs)
    df['locus_name'] = df.apply(lambda row: LOCUS_NAMES.get(row["region"], "Unnamed"), axis=1)
    df = df.set_index(["pheno", "region"])    
    
    df_filtered = df[df.P < p_threshold]
    
    #print(df_filtered)
    
    return df_filtered.sort_values(by="P")


def summarize_loci_across_runs(runs_df: pd.DataFrame):

    '''
    Parameters: run_ids
    Return: pd.DataFrame with ["count", "min_P"].
    '''

    # run_ids = sorted([x[1] for x in runs_df[runs_df["metrics.test_recon_loss"] < RECON_LOSS_THRES].index])
    run_ids = sorted([x[1] for x in runs_df.index])

    all_signif_loci = []
    
    for run_id in tqdm(run_ids):
        signif_loci_df = \
            get_significant_loci(runs_df, experiment_id=1, run_id=run_id).\
            assign(run=run_id).\
            reset_index().\
            set_index(["run", "pheno", "region"]
        )                
        all_signif_loci.append(signif_loci_df)        
      
    all_signif_loci = pd.concat(all_signif_loci)    
    return all_signif_loci

    df = all_signif_loci.\
      groupby(["region", "locus_name"]).\
      aggregate({"CHR":"count", "P": "min"}).\
      rename({"CHR":"count", "P":"min_P"}, axis=1).\
      sort_values("count", ascending=False)    
    
    return df

In [253]:
all_signif_loci_df.reset_index().assign(exp_id="1").set_index(["exp_id", "run"]).filter(items=run_ids, axis='index')

ValueError: cannot handle a non-unique multi-index!

In [244]:
#runs_df_ = runs_df[runs_df["metrics.recon_loss"].astype(float) < 0.4]
ORDER_BY = {"by":"count", "ascending":False}
ORDER_BY = {"by":"min_P", "ascending":True}
ORDER_BY = {"by":"-log10(min_P)", "ascending":False}

runs_df_ = runs_df[runs_df["params.w_kl"].astype(float) == 0]


all_signif_loci_df = summarize_loci_across_runs(runs_df)
loci_summary_df = all_signif_loci_df.\
      groupby(["region", "locus_name"]).\
      aggregate({"CHR":"count", "P": "min"}).\
      rename({"CHR":"count", "P":"min_P"}, axis=1).\
      sort_values("count", ascending=False).\
      reset_index().\
      set_index("region")
    
loci_summary_df.drop(["chr6_79", "chr6_20", "chr6_24", "chr6_25", "chr6_26"], inplace=True)
    
loci_summary_df["-log10(min_P)"] = loci_summary_df.apply(lambda row: -np.log10(row["min_P"]), axis=1)
loci_summary_df.sort_values(**ORDER_BY, axis=0).head(15)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 231/231 [00:07<00:00, 28.94it/s]


Unnamed: 0_level_0,locus_name,count,min_P,-log10(min_P)
region,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
chr6_78,PLN,186,1.794734e-22,21.746
chr17_27,GOSR2,76,3.926449e-16,15.406
chr2_108,TTN,321,5.321083e-14,13.274
chr12_19,CCDC91*,9,7.585776e-13,12.12
chr11_2,LSP1*,12,1.648162e-12,11.783
chr12_69,TBX5,162,5.520774e-12,11.258
chr21_10,NCSTNP1*,15,8.729714e-12,11.059
chr12_67,Unnamed,30,1.119438e-11,10.951
chr5_103,CREBRF*,44,1.794734e-11,10.746
chr2_23,Unnamed,5,7.030723e-11,10.153


In [194]:
filter_for_threshold = []
for exp_id, run_id in run_ids:
    try:
        run_df = all_signif_loci_df.loc[run_id]
        run_df = run_df.assign(run_id=run_id)
    except KeyError:
        print(f"Run {run_id} does not have significant loci.")
        pass
    
    filter_for_threshold.append(run_df)
    
filter_for_threshold = pd.concat(filter_for_threshold).reset_index().set_index(["run_id", "pheno", "region"])

Run 71631622f0194243837b100bdce5f911 does not have significant loci.


In [196]:
good_runs = [x for x in all_signif_loci_df.index if x[0] in [y[1] for y in run_ids]]
all_signif_loci_df.filter(items=good_runs, axis="index")

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,index,CHR,SNP,BP,AF,a_0,a_1,BETA,SE,T,P,locus_name
run,pheno,region,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1
0285fa2356fd454e88e3c30d6b63f163,z007,chr6_78,,6.0,rs11153730,118667522.0,0.49203,T,C,-0.076241,0.007928,-9.6162,7.311391e-22,PLN
0285fa2356fd454e88e3c30d6b63f163,z007,chr6_79,,6.0,rs10872167,118988362.0,0.46056,A,G,-0.071901,0.007982,-9.0084,2.202926e-19,PLN
0285fa2356fd454e88e3c30d6b63f163,z012,chr2_108,,2.0,rs2042995,179558366.0,0.22166,T,C,-0.071122,0.009542,-7.4534,9.332543e-14,TTN
0285fa2356fd454e88e3c30d6b63f163,z004,chr2_108,,2.0,rs2042995,179558366.0,0.22166,T,C,0.069439,0.009544,7.2755,3.531832e-13,TTN
0285fa2356fd454e88e3c30d6b63f163,z006,chr2_108,,2.0,rs2042995,179558366.0,0.22166,T,C,-0.069385,0.009546,-7.2685,3.723917e-13,TTN
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
ff7c594accb74ad7947621a6b3e2527a,z002,chr6_79,,6.0,rs11756440,118993642.0,0.47505,C,A,0.046588,0.007990,5.8307,5.573141e-09,PLN
ff7c594accb74ad7947621a6b3e2527a,z011,chr17_27,,17.0,rs117953218,45244074.0,0.13850,T,C,-0.067495,0.011645,-5.7959,6.861199e-09,GOSR2
ff7c594accb74ad7947621a6b3e2527a,z007,chr17_27,,17.0,rs11570508,45228560.0,0.22405,C,A,-0.055592,0.009600,-5.7911,7.059924e-09,GOSR2
ff7c594accb74ad7947621a6b3e2527a,z006,chr6_79,,6.0,rs11756440,118993642.0,0.47505,C,A,-0.045982,0.007994,-5.7521,8.892011e-09,PLN


In [197]:
loci_summary_df = all_signif_loci_df.reset_index().\
      drop("index", axis=1).\
      groupby(["region", "locus_name", "run"]).\
      aggregate({"CHR":"count", "P": "min"}).\
      rename({"CHR":"count", "P":"min_P"}, axis=1).\
      sort_values("count", ascending=False).\
      sort_values("min_P", ascending=True)

In [198]:
region_w=widgets.Select(options=set([x[0] for x in loci_summary_df.index]))
region_w

Select(options=('chr5_40', 'chr6_2', 'chr1_77', 'chr3_63', 'chr21_10', 'chr7_19', 'chr1_15', 'chr1_61', 'chr12…


In [81]:
@interact
def examine_locus(region=widgets.Select(options=set([x[0] for x in loci_summary_df.index]))):
    print(loci_summary_df.loc[region])

interactive(children=(Select(description='region', options=('chr5_40', 'chr6_2', 'chr1_77', 'chr3_63', 'chr21_…
                                             count         min_P
locus_name run                                                  
Unnamed    b6a7eedfc0e84b8b9d1f099ed5f158c4      1  7.133456e-09


In [82]:
def summarize_loci_across_runs(runs_df: pd.DataFrame):

   '''
   Parameters: run_ids
   Return: pd.DataFrame with ["count", "min_P"].
   '''

   # run_ids = sorted([x[1] for x in runs_df[runs_df["metrics.test_recon_loss"] < RECON_LOSS_THRES].index])
   run_ids = sorted([x[1] for x in runs_df.index])

   all_signif_loci = pd.concat([
     get_significant_loci(runs_df, "1", run).\
       assign(run=run).\
       reset_index().\
       set_index(["run", "pheno", "region"]) 
     for run in run_ids
   ])
   
   return all_signif_loci

In [83]:
kk = summarize_loci_across_runs(runs_df).reset_index().drop("index", axis=1)
kk.pheno = kk.apply(lambda x: f"1_{x.run[:5]}_{x.pheno}", axis=1)

KeyError: '1'

In [84]:
z_corr = pd.read_csv("data/cardio/corr_z_vs_indices.csv").set_index("phenotype")

In [85]:
corrs = []

for index, row in pp.sort_values(by="region").iterrows():
    try:
        corrs.append(list(z_corr.loc[row.pheno]))
    except:
        corrs.append([pd.NA]*4)        

NameError: name 'pp' is not defined

In [None]:
corrs_df = pd.DataFrame(corrs, columns=["LVEDV_corr", "LVM_corr", "RVEDV_corr", "LVSph_corr"])
corrs_df.set_index(pp.index)

In [None]:
kk_grouped = pd.concat([kk, corrs_df.abs()], axis=1).groupby("region")

In [None]:
from functools import partial
mean_f = partial(pd.Series.mean, skipna = True)
std_f = partial(pd.Series.std, skipna = True)

In [None]:
counts = kk_grouped.agg("count")["LVEDV_corr"]

In [None]:
phenos =  ["LVEDV", "LVM", "RVEDV", "LVSph"]
corr_per_locus = kk_grouped.aggregate(func={f"{pheno}_corr": [mean_f, std_f] for pheno in phenos})

In [None]:
corr_per_locus["counts"] = counts

In [None]:
corr_per_locus.sort_values(by="counts", ascending=False)

# Statistics on the GWAS loci counts

In [None]:
signif_loci_dfs = {}
dd = []

def loci_count(run_df):
    from collections import Counter
    return dict(Counter([x[1] for x in run_df.index]))

for run in runs_df.index:
    
    try:     
      
      pp = get_significant_loci(runs_df[runs_df["metrics.val_recon_loss"] < 2], exp_id, run[1]) #.sort_values(by=["CHR", "BP"], axis=0)
      n_distinct_loci = len(loci_cnt.keys())
      n_hits_with_duplication = sum(loci_cnt.values())
      
      ff = [  run[1], 
         runs_df.loc[run, "metrics.test_recon_loss"], 
         runs_df.loc[run, "metrics.test_kld_loss"], 
         runs_df.loc[run, "params.latent_dim"], 
         runs_df.loc[run, "params.w_kl"],
         n_distinct_loci, 
         n_hits_with_duplication, 
         n_hits_with_duplication / n_distinct_loci             
      ]
      
      signif_loci_dfs[run[1]] = pp
      loci_cnt = loci_count(signif_loci_dfs[run[1]])
      dd.append(ff)
    except:
      pass

kk = pd.DataFrame(dd)

kk.columns = [
    "run_id",
    "test_mse",
    "kld",    
    "lat_dim",
    "w_kl",
    "n_loci",
    "n_loci_dupl",
    "ratio"    
]

In [None]:
interact(
    lambda xcol, ycol: sns.boxplot(x=xcol, y=ycol, data=kk),
    xcol = widgets.Select(options=kk.columns),
    ycol = widgets.Select(options=kk.columns)
);

In [None]:
@interact
def show_signif_loci(run_id=run_ids_w):
    return get_significant_loci(runs_df, exp_id, run_id)