In [1]:
import os

# Repository roots, resolved under the user's home directory. They are also
# exported to os.environ so that child processes inherit the same locations.
HOME = os.environ["HOME"]
CARDIAC_GWAS_REPO = f"{HOME}/01_repos/CardiacGWAS"
CARDIAC_COMA_REPO = f"{HOME}/01_repos/CardiacCOMA/"
GWAS_REPO = f"{HOME}/01_repos/GWAS_pipeline/"
os.environ.update(
    CARDIAC_GWAS_REPO=CARDIAC_GWAS_REPO,
    CARDIAC_COMA_REPO=CARDIAC_COMA_REPO,
    GWAS_REPO=GWAS_REPO,
)

MLRUNS_DIR = f"{CARDIAC_COMA_REPO}/mlruns"
# NOTE(review): the working directory is NOT changed, so relative paths used in
# later cells resolve against the kernel's cwd. Enable if needed:
# os.chdir(CARDIAC_COMA_REPO)

In [2]:
import mlflow
from mlflow.tracking import MlflowClient

import os, sys

import torch
import torch.nn.functional as F

from CardiacCOMA.config.cli_args import overwrite_config_items
from CardiacCOMA.config.load_config import load_yaml_config, to_dict
from CardiacCOMA.utils.helpers import get_datamodule, get_lightning_module
from CardiacCOMA.utils.mlflow_helpers import get_model_pretrained_weights
from CardiacCOMA.utils.CardioMesh.CardiacMesh import transform_mesh

import ipywidgets as widgets
from ipywidgets import interact
from IPython.display import Image, display, Markdown, clear_output

import pandas as pd
import shlex
from subprocess import check_output

import pickle as pkl
import pytorch_lightning as pl

from argparse import Namespace
import matplotlib.pyplot as plt

import numpy as np
import pandas as pd
from IPython import embed
sys.path.insert(0, '..')

from copy import deepcopy
from pprint import pprint

from typing import List
from tqdm import tqdm

import pandas as pd

import pyvista as pv
from ipywidgets import interact, interactive, fixed, interact_manual

from auxiliary import load_data
from auxiliary import get_model_pretrained_weights

In [3]:
# Runs listed in CardiacGWAS/results/good_runs.csv; their IDs populate the run selector below.
good_runs_df = pd.read_csv(f"{CARDIAC_GWAS_REPO}/results/good_runs.csv")
run_ids = good_runs_df["run_id"].tolist()

In [4]:
def load_data(b):
    """Load the LV meshes and their Procrustes transforms into module-level globals.

    Button callback: `b` (the clicked button) is unused. Sets the globals
    `meshes` and `procrustes_transforms`, which later cells index by the
    subject ID converted to `str`. This definition replaces the `load_data`
    imported from `auxiliary` in the imports cell.

    NOTE(review): `pickle.load` can execute arbitrary code; only use trusted files.
    """
    global meshes, procrustes_transforms
    data_dir = f"{CARDIAC_COMA_REPO}/data/cardio"

    print("Loading mesh data...")
    # `with` closes the file handle, which the previous `pkl.load(open(...))` leaked.
    with open(f"{data_dir}/LV_meshes_at_ED_35k.pkl", "rb") as fh:
        meshes = pkl.load(fh)
    print("Mesh data loaded successfully.")

    print("Loading Procrustes transforms...")
    with open(f"{data_dir}/procrustes_transforms_35k.pkl", "rb") as fh:
        procrustes_transforms = pkl.load(fh)
    print("Procrustes transform loaded successfully.")
    
# Loading the pickled meshes is deferred until the button is clicked (see `load_data`).
button = widgets.Button(
    description='Load data',
    tooltip='Click me',
    icon='check',  # FontAwesome icon name, without the `fa-` prefix
    disabled=False,
)
button.on_click(load_data)
button

Button(description='Load data', icon='check', style=ButtonStyle(), tooltip='Click me')

Loading mesh data...
Mesh data loaded successfully.
Loading Procrustes transforms...
Procrustes transform loaded successfully.


In [5]:
import random
pv.set_plot_theme("document")

# Mesh connectivity. The second pickled value (downsampling matrices, judging by
# the file name) is not needed here. `with` closes the file handle.
with open(f"{CARDIAC_COMA_REPO}/data/cardio/faces_and_downsampling_mtx_frac_0.1_LV.pkl", "rb") as fh:
    faces, _ = pkl.load(fh).values()
# Prefix every face row with its vertex count (3; faces are assumed to be
# triangles), the padded layout pv.PolyData expects for faces.
faces = np.c_[np.ones(faces.shape[0]) * 3, faces].astype(int)

# Shuffled named colours. A seeded local RNG makes figure colours reproducible
# across kernel restarts without touching the global `random` state.
PALETTE_SEED = 0
color_palette = list(pv.colors.color_names.values())
random.Random(PALETTE_SEED).shuffle(color_palette)

# Select `run_id` / `z` variable

In [None]:
from PIL import Image

In [None]:
exp_id = '1'
STEP = 0.01
run_id_w = widgets.Select(options=sorted(run_ids))
display(run_id_w)

# No earlier cell assigns `run_id`, so start from the widget's initial
# selection (referencing `run_id` here raised NameError on a fresh kernel).
# This run's latent table defines the z-variable options below.
run_id = run_id_w.value
filename = "latent_vector.csv"
df = pd.read_csv(f"{MLRUNS_DIR}/{exp_id}/{run_id}/artifacts/output/{filename}")
df = df.set_index("ID")

# Latent-variable picker: z000, z001, ... one option per column of the latent table.
z_w = widgets.SelectionSlider(options=[f"z{str(i).zfill(3)}" for i in range(df.shape[1])])
display(z_w)

# Lower edge q of the quantile bin (q, q + STEP).
quantile_range_w = widgets.FloatSlider(min=0, max=0.99, step=STEP)
display(quantile_range_w)

button = widgets.Button(description="Plot mesh")

out = widgets.Output()


def plot_mesh(b):
    """Show the average aligned mesh for one quantile bin of a latent variable, plus its Manhattan plot.

    Callback for the "Plot mesh" button; `b` (the clicked button) is unused.
    Reads the run from `run_id_w`, the latent variable from `z_w` and the lower
    quantile q from `quantile_range_w`. Subjects whose value lies strictly
    between the q and q + STEP quantiles are Procrustes-aligned and averaged.
    """
    with out:
        clear_output()

        z = z_w.value
        run_id = run_id_w.value
        q_lo = quantile_range_w.value
        q_hi = q_lo + STEP

        # Load the latents of the run selected *now*. The global `df` belongs to
        # the run selected when the widget cell executed, which may differ.
        latents = pd.read_csv(
            f"{MLRUNS_DIR}/{exp_id}/{run_id}/artifacts/output/latent_vector.csv"
        ).set_index("ID")
        if z not in latents.columns:
            print(f"{z} is not a latent variable of run {run_id}.")
            return

        z_bounds = latents[z].quantile([q_lo, q_hi])
        ids = latents.index[(z_bounds[q_lo] < latents[z]) & (latents[z] < z_bounds[q_hi])]

        print(z, run_id)
        print(f"{len(ids)} subjects in quantile bin ({q_lo:.2f}, {q_hi:.2f})")
        if len(ids) == 0:
            # Averaging an empty selection would yield a NaN mesh.
            return

        # Average of the Procrustes-aligned meshes of the selected subjects.
        avg_mesh = np.array([
            transform_mesh(meshes[str(subject_id)], **procrustes_transforms[str(subject_id)])
            for subject_id in ids
        ]).mean(axis=0)

        plotter = pv.Plotter(notebook=True, off_screen=False, polygon_smoothing=False)
        mesh = pv.PolyData(avg_mesh, faces)
        plotter.add_mesh(mesh, show_edges=False, point_size=1.5, color=color_palette[0], opacity=0.5)
        plotter.show(interactive=True, interactive_update=True)

        figures_dir = f"{MLRUNS_DIR}/{exp_id}/{run_id}/artifacts/GWAS_adj_10PCs/figures"
        manhattan_file = f"{figures_dir}/GWAS__{z}__1_{run_id}__manhattan.png"
        if os.path.exists(manhattan_file):
            display(Image.open(manhattan_file))
        else:
            print(f"Manhattan plot not found: {manhattan_file}")
        
button.on_click(plot_mesh)
# The button stacked above the output area that plot_mesh renders into.
controls = widgets.VBox([button, out])
controls

In [6]:
# Scratch plotter. Named `plotter` rather than `pl` so that it does not rebind
# the `pytorch_lightning as pl` alias from the imports cell.
plotter = pv.Plotter(notebook=True, off_screen=False, polygon_smoothing=False)

# Select genetic locus

In [7]:
gwas_harmonized_pattern = "data/other_gwas/preprocessed_files/{prefix}__{phenotype}.tsv"
gwas_selected_snps_pattern = "data/other_gwas/preprocessed_files/{prefix}__{phenotype}__selected_snps.tsv"
# NOTE(review): these paths are relative to the kernel's working directory.
COMA_GWAS_SUMMARY = "results/gwas_loci_summary_across_runs.csv"
LOGP_PATH = "results/log10p_for_selected_snps_across_gwas.csv"
# Regions left out of the locus selector. Matched by region name rather than by
# the "<region> (<P>)" label, so a change in P-value formatting cannot break it.
EXCLUDED_REGIONS = {"chr6_79"}

gwas_loci_summary_across_runs_df = pd.read_csv(COMA_GWAS_SUMMARY)

# Rows whose P equals the minimum P of their region, i.e. the best association per region.
min_p_per_region = gwas_loci_summary_across_runs_df.groupby("region")["P"].transform("min")
idx = min_p_per_region == gwas_loci_summary_across_runs_df["P"]

best_association_per_region = gwas_loci_summary_across_runs_df[idx].sort_values("region")
best_snps = set(best_association_per_region.SNP)

# Selector label -> region name, ordered from most to least significant.
regions = {
    f"{row.region} ({row.P:.1e})": row.region
    for _, row in best_association_per_region.sort_values("P").iterrows()
    if row.region not in EXCLUDED_REGIONS
}

# NOTE(review): if several rows tie for a region's minimum P, `.loc[region]` on
# this frame returns a DataFrame rather than a single row.
_best_association_per_region = best_association_per_region.set_index("region")

In [8]:
N_STEPS = 100
STEP = 1 / N_STEPS
# Quantile grid 0, 1/N, ..., 1 and the N consecutive (lower, upper) bins it defines.
quantiles = np.arange(N_STEPS + 1) / N_STEPS
quantile_ranges = list(zip(quantiles[:-1], quantiles[1:]))

In [42]:
from PIL import Image
import imageio

def merge_pngs(pngs, output_png, how):
    """Concatenate images side by side or stacked into a single RGB PNG.

    Adapted from
    https://www.tutorialspoint.com/python_pillow/Python_pillow_merging_images.htm

    Args:
        pngs: paths of the input images, in the order they should appear.
        output_png: path of the merged PNG to write.
        how: "horizontally" (left to right) or "vertically" (top to bottom).

    Raises:
        ValueError: if `how` is not one of the two layouts or `pngs` is empty.

    The canvas is as large as the largest image along the non-concatenated
    axis, so no input is cropped. Unused area keeps the (250, 250, 250) background.
    """
    if how not in ("horizontally", "vertically"):
        raise ValueError(f"how must be 'horizontally' or 'vertically', got {how!r}")
    pngs = list(pngs)
    if not pngs:
        raise ValueError("pngs must contain at least one image path")

    # Convert to the canvas mode while the file is open. `convert` returns a
    # loaded copy, so each source file is closed right away.
    images = []
    for png in pngs:
        with Image.open(png) as image:
            images.append(image.convert("RGB"))

    widths = [image.size[0] for image in images]
    heights = [image.size[1] for image in images]

    if how == "vertically":
        size = (max(widths), sum(heights))
        offsets = np.cumsum([0] + heights[:-1])
        positions = [(0, int(y)) for y in offsets]
    else:  # "horizontally"
        size = (sum(widths), max(heights))
        offsets = np.cumsum([0] + widths[:-1])
        positions = [(int(x), 0) for x in offsets]

    new_image = Image.new(mode="RGB", size=size, color=(250, 250, 250))
    for image, position in zip(images, positions):
        new_image.paste(image, position)

    new_image.save(output_png, "PNG")

In [63]:
exp_id = '1'
# Loci tried earlier in this cell: "chr17_27", "chr12_69", "chr6_78", "chr2_108".
region = "chr12_19"
assoc = _best_association_per_region.loc[region]
# NOTE(review): assumes `pheno` ends with the 4-character latent name (e.g. "z005").
run_id, z = assoc.run, assoc.pheno[-4:]
filename = "latent_vector.csv"
df = pd.read_csv(f"{MLRUNS_DIR}/{exp_id}/{run_id}/artifacts/output/{filename}")
df = df.set_index("ID")

# Quantile bins of `z` to render, from the lower to the upper tail.
q_ranges = [(0.00, 0.01), (0.05, 0.10), (0.45, 0.55), (0.90, 0.95), (0.99, 1.0)]
filenames = []

pv.set_plot_theme("document")  # loop-invariant, so set once

for q0, q1 in q_ranges:
    z_bounds = df[z].quantile([q0, q1])
    ids = df.index[(z_bounds[q0] < df[z]) & (df[z] < z_bounds[q1])]
    if len(ids) == 0:
        print(f"No subjects strictly inside the ({q0}, {q1}) quantile bounds of {z}; skipping.")
        continue

    # Procrustes-aligned meshes of the selected subjects, stacked along axis 0.
    aligned_meshes = np.array([
        np.asarray(transform_mesh(meshes[str(subject_id)], **procrustes_transforms[str(subject_id)]))
        for subject_id in ids
    ])
    # Per-subject scale: the mean Euclidean norm of the vertex coordinates (the
    # mean distance of the vertices from the origin, not a mean-squared
    # deviation). Assumes each mesh is a (n_vertices, n_dims) array.
    mesh_scale = np.linalg.norm(aligned_meshes, axis=-1).mean(axis=1)
    # Average the scale-normalised shapes, then restore the mean scale.
    avg_mesh = (aligned_meshes / mesh_scale[:, None, None]).mean(axis=0) * mesh_scale.mean()

    png_filename = f"{region}_{q0}-{q1}.png"
    print(png_filename)
    filenames.append(png_filename)

    plotter = pv.Plotter(off_screen=True, notebook=False)
    plotter.camera.position = (300, 0.0, 0.0)
    plotter.camera.azimuth = 95
    plotter.add_mesh(
        pv.PolyData(avg_mesh, faces),
        show_edges=False, point_size=1.5, color=color_palette[0], opacity=0.5,
    )
    plotter.screenshot(png_filename)
    # Release the off-screen render window; one plotter is created per bin.
    plotter.close()
  

chr12_19_0.0-0.01.png
chr12_19_0.05-0.1.png
chr12_19_0.45-0.55.png
chr12_19_0.9-0.95.png
chr12_19_0.99-1.0.png


In [64]:
# Stitch the per-quantile screenshots left to right into "<region>.png".
merge_pngs(pngs=filenames, output_png=f"{region}.png", how="horizontally")

In [51]:
exp_id = '1'
# Locus also tried in this cell: "chr17_27".
region = "chr12_69"
assoc = _best_association_per_region.loc[region]
# NOTE(review): assumes `pheno` ends with the 4-character latent name (e.g. "z005").
run_id, z = assoc.run, assoc.pheno[-4:]
filename = "latent_vector.csv"
df = pd.read_csv(f"{MLRUNS_DIR}/{exp_id}/{run_id}/artifacts/output/{filename}")
df = df.set_index("ID")

# Sweep all N_STEPS quantile bins, replacing the displayed plot at each step.
for q0, q1 in quantile_ranges:
    z_bounds = df[z].quantile([q0, q1])
    ids = df.index[(z_bounds[q0] < df[z]) & (df[z] < z_bounds[q1])]
    if len(ids) == 0:
        # Averaging an empty selection would yield a NaN mesh.
        continue

    # Procrustes-aligned meshes of the selected subjects, stacked along axis 0.
    aligned_meshes = np.array([
        np.asarray(transform_mesh(meshes[str(subject_id)], **procrustes_transforms[str(subject_id)]))
        for subject_id in ids
    ])
    # Per-subject scale: the mean Euclidean norm of the vertex coordinates (not
    # a mean-squared deviation). Assumes each mesh is a (n_vertices, n_dims) array.
    mesh_scale = np.linalg.norm(aligned_meshes, axis=-1).mean(axis=1)
    # Average the scale-normalised shapes, then restore the mean scale.
    avg_mesh = (aligned_meshes / mesh_scale[:, None, None]).mean(axis=0) * mesh_scale.mean()

    plotter = pv.Plotter(notebook=True, off_screen=False, polygon_smoothing=False)
    plotter.add_mesh(
        pv.PolyData(avg_mesh, faces),
        show_edges=False, point_size=1.5, color=color_palette[0], opacity=0.5,
    )

    clear_output()
    plotter.show(interactive=True, interactive_update=True)




ViewInteractiveWidget(height=768, layout=Layout(height='auto', width='100%'), width=1024)

KeyboardInterrupt: 

In [None]:
exp_id = '1'
STEP = 0.01

# Locus picker. Options are passed as (label, value) pairs ("<region> (<best P>)",
# region): dict options are deprecated in ipywidgets and not accepted by newer
# releases, whereas pairs work across versions.
region_w = widgets.Select(options=list(regions.items()), description="Locus: \n",)
display(region_w)

# Association for the locus selected when this cell runs.
region = region_w.value
assoc = _best_association_per_region.loc[region]
run_id, z = assoc.run, assoc.pheno[-4:]

# Lower edge q of the quantile bin (q, q + STEP).
quantile_inf_w = widgets.FloatSlider(min=0.00, max=0.99, step=STEP)
display(quantile_inf_w)

button = widgets.Button(description="Plot mesh")
out = widgets.Output()

def plot_mesh(b):
    """Render the average aligned mesh for one quantile bin of the selected locus' latent variable.

    Callback for the "Plot mesh" button; `b` (the clicked button) is unused.
    The locus is read from `region_w` on every click, so changing the selection
    takes effect; previously it was fixed when the cell ran. The locus is mapped
    to its best (run, latent variable) association. Subjects whose latent value
    lies strictly between the q and q + STEP quantiles (q from `quantile_inf_w`)
    are Procrustes-aligned and averaged.
    """
    with out:
        clear_output()

        region = region_w.value
        assoc = _best_association_per_region.loc[region]
        # NOTE(review): assumes `pheno` ends with the 4-character latent name (e.g. "z005").
        run_id, z = assoc.run, assoc.pheno[-4:]

        # Read inside `out` so that I/O errors are shown in the output area.
        filename = "latent_vector.csv"
        df = pd.read_csv(f"{MLRUNS_DIR}/{exp_id}/{run_id}/artifacts/output/{filename}")
        df = df.set_index("ID")

        q_lo = quantile_inf_w.value
        q_hi = q_lo + STEP
        z_bounds = df[z].quantile([q_lo, q_hi])
        ids = df.index[(z_bounds[q_lo] < df[z]) & (df[z] < z_bounds[q_hi])]

        print(f"{region}: {z} of run {run_id}, {len(ids)} subjects in quantile bin ({q_lo:.2f}, {q_hi:.2f})")
        if len(ids) == 0:
            # Averaging an empty selection would yield a NaN mesh.
            return

        # Average of the Procrustes-aligned meshes of the selected subjects.
        avg_mesh = np.array([
            transform_mesh(meshes[str(subject_id)], **procrustes_transforms[str(subject_id)])
            for subject_id in ids
        ]).mean(axis=0)

        plotter = pv.Plotter(notebook=True, off_screen=False, polygon_smoothing=False)
        mesh = pv.PolyData(avg_mesh, faces)
        plotter.add_mesh(mesh, show_edges=False, point_size=1.5, color=color_palette[0], opacity=0.5)
        plotter.show(interactive=True, interactive_update=True)
        
button.on_click(plot_mesh)
# The button stacked above the output area that plot_mesh renders into.
controls = widgets.VBox([button, out])
controls