## Imports and functions

In [None]:
# Install the local ml4c3 checkout into the user site so the `ml4c3` imports below resolve.
! pip install --user ~/ml4c3

import os
import h5py
import socket
import pprint
import numpy as np
import pandas as pd
from tqdm import tqdm
from typing import List, Union, Dict

import seaborn as sns
from matplotlib import pyplot as plt

from ml4c3.tensormap.TensorMap import TensorMap, update_tmaps
from ml4c3.definitions.globals import TENSOR_EXT
from ml4c3.definitions.sts import STS_DATA_CSV

# Pretty-printer with 4-space indentation for ad-hoc inspection (not used in the cells shown here).
pp = pprint.PrettyPrinter(indent=4)

# IPython magic: render matplotlib figures inline in the notebook output.
%matplotlib inline

def get_fpaths(dirpath: str, ext: str = TENSOR_EXT) -> list:
    """Recursively collect paths of files under ``dirpath`` whose names end with ``ext``.

    Args:
        dirpath: Root directory to walk; every subdirectory is searched.
        ext: File-name suffix to match. Defaults to the project's ``TENSOR_EXT``.
            (Previously this argument was ignored and ``TENSOR_EXT`` was always used.)

    Returns:
        List of matching file paths, in ``os.walk`` traversal order. A nonexistent
        ``dirpath`` yields an empty list, because ``os.walk`` does not raise by default.
    """
    fpaths = [
        os.path.join(root, fname)
        for root, _dirs, files in os.walk(dirpath)
        for fname in files
        if fname.endswith(ext)
    ]
    print(f"Found {len(fpaths)} {ext} files at {dirpath}")
    return fpaths

def get_path_to_ecgs() -> Union[str, None]:
    """Return the root directory of the ECG HD5 data for the current machine.

    The machine is identified by ``socket.gethostname()``.

    Returns:
        ``"/media/2tb/ecg"`` on host ``mithril``; ``"/storage/shared/ecg"`` on host
        ``anduril`` or any host whose name contains ``stultzlab``; otherwise ``None``.
        Callers must handle ``None``: passing it to ``os.path.join`` raises ``TypeError``.
    """
    hostname = socket.gethostname()  # query once instead of once per branch
    if hostname == "mithril":
        return "/media/2tb/ecg"
    if hostname == "anduril" or "stultzlab" in hostname:
        return "/storage/shared/ecg"
    return None  # unknown host: no data path configured

print("Built functions!")

In [None]:
# Locate the ECG root for this machine and collect every HD5 file under its "mgh" subdirectory.
# Fail with a clear message on an unrecognized host, instead of a TypeError from os.path.join(None, ...).
ecg_root = get_path_to_ecgs()
if ecg_root is None:
    raise RuntimeError(f"No ECG data path is configured for host {socket.gethostname()!r}")
fpaths = get_fpaths(dirpath=os.path.join(ecg_root, "mgh"))

## Get list of STS MRNs (converted to floats)

In [None]:
# Load the STS table and keep its MRN column as a list of floats,
# so it can be compared directly against the float MRNs parsed from HD5 file names.
df = pd.read_csv(STS_DATA_CSV)
sts_mrns = [float(value) for value in df['medrecn'].to_list()]
print(f"Extracted {len(sts_mrns)} MRNs from {STS_DATA_CSV}")

## Isolate MRNs from HD5 paths and convert to set

In [None]:
# Parse the MRN out of each HD5 file name (skipping any path containing "bad_mrn")
# as a float, and collect the unique values in a set for O(1) membership tests.
# An explicit loop is kept on purpose: the next cell reads the last `fpath` it binds.
mrn_hd5s = set()
for fpath in tqdm(fpaths):
    if "bad_mrn" in fpath:
        continue
    mrn = float(os.path.basename(fpath).replace(TENSOR_EXT, ""))
    mrn_hd5s.add(mrn)

print(f"Isolated MRNs from {len(mrn_hd5s)} paths, convert to floats, and saved in a big set")

## Get the intersection of MRNs from STS and MRNs from HD5 paths

In [None]:
# Map each HD5 MRN (as a float, the same representation as sts_mrns) to the actual
# path it was parsed from. Reusing the real path, instead of rebuilding it from one
# directory prefix plus str(int(mrn)), keeps two cases resolvable:
# - files in different subdirectories (os.walk is recursive);
# - file names with leading zeros, which int() would drop.
mrn_to_fpath = {}
for fpath in fpaths:
    if "bad_mrn" in fpath:
        continue
    mrn_to_fpath[float(os.path.split(fpath)[1].replace(TENSOR_EXT, ""))] = fpath

# Directory of the last HD5 path found; retained for any downstream use, no longer used here.
path_prefix = os.path.split(fpaths[-1])[0] if fpaths else ""

# Keep the HD5 path of every STS MRN that has one. Build a set to drop duplicates,
# then sort it so the order is reproducible across runs.
fpaths_matches = sorted(
    {mrn_to_fpath[mrn] for mrn in tqdm(sts_mrns) if mrn in mrn_to_fpath}
)

print(f"Found {len(fpaths_matches)} paths to ECGs that have an STS MRN")

## Define a list of STS TMaps and build them

In [None]:
# Names of the STS TensorMaps to build.
needed_tensor_maps = ["age", "classnyh", "chf"]

# update_tmaps adds the named TMap to the dict and returns the updated dict.
tmaps = {}
for tmap_name in needed_tensor_maps:
    tmaps = update_tmaps(tmap_name=tmap_name, tmaps=tmaps)
    built = tmaps[tmap_name]
    print(f"Successfully created tensor map {built.name} with shape {built.shape}")

# Replace the dict with a list holding only the TMaps whose name was requested above.
tmaps = [tmap for tmap in tmaps.values() if tmap.name in needed_tensor_maps]

## Initialize dict of empty lists in which to store tensors returned by TMaps, iterate through fpaths, use TMap to get tensors, and append to dict

In [None]:
# One list per requested TMap name. Each inspected file appends exactly one entry to
# every list (NaN when the read fails), so the lists stay aligned file-by-file.
tensors = {tm_name: [] for tm_name in needed_tensor_maps}

for fpath in tqdm(fpaths_matches[0:5]):
    print("\n")
    with h5py.File(fpath, "r") as hf:
        print(fpath)
        for tm in tmaps:
            print(f"{tm.name}")
            try:
                tensor = tm.tensor_from_file(tm=tm, hd5=hf)
            except Exception as e:
                # Best-effort inspection: report the reason and continue with the next TMap.
                # Unlike a bare `except:`, this still lets KeyboardInterrupt stop the loop.
                print(f"\tFail! ({type(e).__name__}: {e})")
                tensor = np.nan
            else:
                print(f"\t{tensor}")
            # tmaps was filtered to names in needed_tensor_maps, so this key exists.
            tensors[tm.name].append(tensor)