In [1]:
import pandas as pd
import numpy as np
import pathlib
from tqdm.auto import tqdm

import hydra
from omegaconf import DictConfig, OmegaConf

import torch
from torch_geometric import seed_everything

import ray

In [7]:
# List the run directories written by the Neptune logger.
# pathlib raises FileNotFoundError if the directory is missing; `!ls` would
# instead capture ls's error text as list entries and carry on silently.
# Hidden entries are skipped, as plain `ls` does.
fps = sorted(
    p.name
    for p in pathlib.Path("/sc-projects/sc-proj-ukb-cvd/results/models/NeptuneLogger").iterdir()
    if not p.name.startswith(".")
)

In [13]:
# One row per directory name, held in the single default column labelled 0.
fp_df = pd.DataFrame(data=list(fps))

In [28]:
# Split each directory name on "-" once: the first piece becomes "type" and the
# second becomes "id" (NaN when the name contains no "-").
name_parts = fp_df[0].str.split("-")
fp_df["type"] = name_parts.str[0]
fp_df["id"] = name_parts.str[1]

In [31]:
# Select "MET" runs that have an id, cast the id to int and keep ids <= 3600.
df_clean = (
    fp_df
    .loc[lambda d: d["type"] == "MET"]
    .dropna(subset=["id"])
    .assign(id=lambda d: d["id"].astype(int))
    .loc[lambda d: d["id"] <= 3600]
)

In [34]:
# Absolute path of every selected run directory.
neptune_dir = "/sc-projects/sc-proj-ukb-cvd/results/models/NeptuneLogger/"
df_clean["fp_full"] = df_clean[0].radd(neptune_dir)

In [37]:
import shutil

# WARNING: irreversible. Recursively deletes every directory listed in
# df_clean["fp_full"].
# Paths that no longer exist (for example when this cell is re-run) are skipped,
# so shutil.rmtree does not raise FileNotFoundError halfway through the loop.
for run_dir in tqdm(df_clean["fp_full"].to_list(), desc="deleting runs"):
    if pathlib.Path(run_dir).is_dir():
        shutil.rmtree(run_dir)

  0%|          | 0/733 [00:00<?, ?it/s]

In [2]:
import socket

# Choose the storage root from the host name. socket.gethostname() replaces the
# `!hostname` shell call; `node` keeps the one-element list shape that the
# shell capture produced, for any cell that reads it.
# NOTE(review): the substring match is loose; any host whose name contains
# "sc" selects the /sc-projects mount. Confirm this is intended.
hostname = socket.gethostname()
node = [hostname]
if "sc" in hostname:
    base_path = "/sc-projects/sc-proj-ukb-cvd"
else:
    base_path = "/data/analysis/ag-reils/ag-reils-shared/cardioRS"
print(base_path)

project_label = "22_medical_records"
project_path = f"{base_path}/results/projects/{project_label}"
figure_path = f"{project_path}/figures"
output_path = f"{project_path}/data"

# Create the output folders up front; exist_ok makes this safe to re-run.
pathlib.Path(figure_path).mkdir(parents=True, exist_ok=True)
pathlib.Path(output_path).mkdir(parents=True, exist_ok=True)

/sc-projects/sc-proj-ukb-cvd


In [3]:
# Echo the resolved output directory that the feather exports are written to.
output_path

'/sc-projects/sc-proj-ukb-cvd/results/projects/22_medical_records/data'

## Get Data

In [3]:
from hydra import compose, initialize
from omegaconf import OmegaConf

# Clear any global Hydra state first, presumably so that initialize() can be
# called again when this cell is re-run.
hydra.core.global_hydra.GlobalHydra().clear()
initialize(config_path="../../ehrgraphs/config")

# Overrides for this extraction: partition 0, identity model, no covariates,
# top 1683 phecodes, t0 at recruitment. The min_record_counts override is disabled.
# NOTE(review): the printed config below shows
# `use_data_artifact_if_available: true` even though this list sets it to
# False; confirm the override actually takes effect.
extraction_overrides = [
    "datamodule.partition=0",
    "head.kwargs.num_layers=1",
    "datamodule.batch_size=1024",
    # "setup.data.min_record_counts=100",
    "setup.use_data_artifact_if_available=False",
    "datamodule/covariates='no_covariates'",
    "model=identity",
    "datamodule.use_top_n_phecodes=1683",
    "datamodule.t0_mode=recruitment",
]
args = compose(config_name="config", overrides=extraction_overrides)
print(OmegaConf.to_yaml(args))

setup:
  entity: cardiors
  project: RecordGraphs
  group: null
  name: null
  data_root:
    charite-hpc: /sc-projects/sc-proj-ukb-cvd/data/2_datasets_pre/211110_anewbeginning/artifacts
    eils-hpc: /data/analysis/ag-reils/ag-reils-shared/cardioRS/data/2_datasets_pre/211110_anewbeginning/artifacts
  use_data_artifact_if_available: true
  data:
    drop_shortcut_edges: true
    drop_individuals_without_gp: false
    min_record_counts: 0
  data_identifier: WandBGraphDataNoShortcuts256:latest
  tags:
  - full_data
head:
  model_type: MLP
  dropout: 0.2
  kwargs:
    num_hidden: 256
    num_layers: 1
    detach_clf: false
    initial_dropout: 0.0
datamodule:
  covariates: []
  sampler:
    sampler_type: DummySampler
  batch_size: 1024
  partition: 0
  num_workers: 4
  label_definition:
    all_cause_death: true
    phecodes: true
    custom: []
  t0_mode: recruitment
  use_top_n_phecodes: 1683
  edge_weight_threshold: 0.1
  min_edge_type_fraction: 0.001
  buffer_years: 0.0
  filter_input

In [4]:
def extract_records_events_times(args):
    """Extract record and outcome tables for the train, valid and test splits.

    Builds the datamodule with ``setup_training(args)``. For each split it turns
    the dataset's sparse matrices into DataFrames indexed by ``eid``, then
    stacks the splits row-wise in the order train, valid, test.

    Parameters
    ----------
    args : omegaconf.DictConfig
        Composed hydra config, passed unchanged to ``setup_training``.

    Returns
    -------
    records_df : pd.DataFrame
        Sparse record matrix with one column per entry of
        ``datamodule.record_cols``.
    outcomes_df : pd.DataFrame
        For every label in ``datamodule.label_mapping``, three columns:
        ``{label}_prev`` (sparse, from ``dataset.exclusions``),
        ``{label}_event`` (sparse, from ``dataset.labels_events``) and
        ``{label}_time`` (dense). A stored time of 0 is replaced by the row's
        censoring time.
    """
    # prepare extraction
    datamodule, _, _ = setup_training(args)

    record_cols = [f"{c}" for c in datamodule.record_cols]
    label_cols = list(datamodule.label_mapping.keys())
    prev_cols = [f"{c}_prev" for c in label_cols]
    event_cols = [f"{c}_event" for c in label_cols]
    time_cols = [f"{c}_time" for c in label_cols]

    # One dataset factory per split. Only the train loader is requested with
    # explicit shuffle=False / drop_last=False.
    dataset_factories = {
        "train": lambda: datamodule.train_dataloader(shuffle=False, drop_last=False).dataset,
        "valid": lambda: datamodule.val_dataloader().dataset,
        "test": lambda: datamodule.test_dataloader().dataset,
    }

    records_list = []
    outcomes_list = []
    for split, make_dataset in tqdm(dataset_factories.items()):
        eids = datamodule.eids[split]
        dataset = make_dataset()

        # extract records
        records_list.append(
            pd.DataFrame.sparse.from_spmatrix(
                dataset.records, index=eids, columns=record_cols
            ).rename_axis("eid")
        )

        # extract exclusions & events
        exclusions_df = pd.DataFrame.sparse.from_spmatrix(
            dataset.exclusions, index=eids, columns=prev_cols
        ).rename_axis("eid")
        events_df = pd.DataFrame.sparse.from_spmatrix(
            dataset.labels_events, index=eids, columns=event_cols
        ).rename_axis("eid")

        # Event times: 0 marks "no event", so use the individual's censoring
        # time instead. np.asarray drops the np.matrix wrapper that todense()
        # may return. Broadcasting the (n, 1) censoring column avoids building a
        # full (n_individuals x n_labels) censoring matrix. The result keeps the
        # dtype of the stored times.
        event_times = np.asarray(dataset.labels_times.todense())
        censor_col = np.asarray(dataset.censorings).reshape(-1, 1)
        times = np.where(event_times == 0, censor_col, event_times).astype(
            event_times.dtype, copy=False
        )

        times_df = pd.DataFrame(data=times, index=eids, columns=time_cols).rename_axis("eid")
        outcomes_list.append(pd.concat([exclusions_df, events_df, times_df], axis=1))

    records_df = pd.concat(records_list, axis=0)
    outcomes_df = pd.concat(outcomes_list, axis=0)
    return records_df, outcomes_df

In [5]:
# extract_records_events_times looks up setup_training as a module-level name
# when it is called, so the import has to happen before the call below.
from ehrgraphs.training import setup_training

# Seed before the datamodule is built.
seed_everything(0)

# Extracts whichever partition `args` selects via its datamodule.partition override.
records_df, outcomes_df = extract_records_events_times(args)

Using backend: pytorch


/sc-projects/sc-proj-ukb-cvd/data/2_datasets_pre/211110_anewbeginning/artifacts/WandBGraphDataNoShortcuts256_220411.p
Generating train dataset...
Generating valid dataset...
Using edge types: ['Is a' 'ATC - RxNorm sec up' 'Tradename of' 'Has tradename' 'Constitutes'
 'RxNorm inverse is a' 'RxNorm ing of' 'Dose form group of' 'Has method'
 'Has asso morph' 'Has interprets' 'Interprets of' 'Is descendant of'
 'Is associated with' 'Is ancestor of' 'Asso morph of' 'Method of'
 'Interacts with' 'Is part of' 'Composed of']


  0%|          | 0/3 [00:00<?, ?it/s]

Generating test dataset...


## Write Records

In [6]:
# Inspect the freshly extracted records: sparse columns, rows stacked split by split.
records_df.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 502460 entries, 1000018 to 1917839
Columns: 68527 entries, OMOP_1000560 to OMOP_998415
dtypes: Sparse[float64, 0](68527)
memory usage: 372.7 MB


In [7]:
# Convert every sparse record column into a dense bool column, one column at a time.
# Columns that are already dense are skipped. Without this guard, re-running
# the cell fails because `.sparse` only works on Sparse-dtype data.
for c in tqdm(records_df.columns):
    if isinstance(records_df[c].dtype, pd.SparseDtype):
        records_df[c] = records_df[c].astype(bool).sparse.to_dense()

  0%|          | 0/68527 [00:00<?, ?it/s]

In [45]:
# Rows were concatenated split by split (train, valid, test); order them by eid.
records_df = records_df.sort_index(axis="index", ascending=True)

In [46]:
# Re-check after densifying and sorting: record columns should now be bool.
records_df.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 502460 entries, 1000018 to 6025198
Columns: 68527 entries, OMOP_1000560 to OMOP_998415
dtypes: bool(68527)
memory usage: 32.1 GB


In [47]:
records_path = f"{output_path}/baseline_records_220412.feather"
# reset_index() turns the eid index into a regular column so it is written to the file.
records_df.reset_index().to_feather(records_path)

## Write Outcomes

In [10]:
# Normalise outcome dtypes:
#   *_prev / *_event : sparse -> dense bool
#   *_time           : -> float32
# Indicator columns that are already dense are skipped, so re-running the cell
# does not fail on the `.sparse` accessor.
for c in tqdm(outcomes_df.columns):
    if c.endswith(("_prev", "_event")):
        if isinstance(outcomes_df[c].dtype, pd.SparseDtype):
            outcomes_df[c] = outcomes_df[c].astype(bool).sparse.to_dense()
    elif c.endswith("_time"):
        outcomes_df[c] = outcomes_df[c].astype(np.float32)

  0%|          | 0/5052 [00:00<?, ?it/s]

In [48]:
# Order outcome rows by eid (they were concatenated split by split).
outcomes_df = outcomes_df.sort_index(axis="index", ascending=True)

In [49]:
# Check dtypes after conversion: bool prev/event flags and float32 times.
outcomes_df.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 502460 entries, 1000018 to 6025198
Columns: 5052 entries, OMOP_4306655_prev to phecode_240_time
dtypes: bool(3368), float32(1684)
memory usage: 4.7 GB


In [50]:
outcomes_path = f"{output_path}/baseline_outcomes_220412.feather"
# reset_index() writes eid out as a regular column alongside the outcome columns.
outcomes_df.reset_index().to_feather(outcomes_path)

### Outcomes (long format: one row per eid and endpoint)

In [51]:
# Endpoint names are the outcome column names with their trailing
# _prev/_event/_time suffix removed. The pattern is anchored at the end of the
# name, so a suffix-like substring inside a label is left intact. Otherwise the
# stripped name would not match the f"{e}_{c}" lookups later on.
endpoints = sorted(
    outcomes_df.columns.str.replace(r"_(?:prev|event|time)$", "", regex=True).unique().tolist()
)

In [52]:
# The long-format table is not grown incrementally: one frame per endpoint is
# collected in outcomes_df_list and concatenated once into outcomes_long.

In [53]:
# One frame per endpoint with uniform columns (prev, event, time, endpoint),
# ready to be stacked into the long table.
cols = ["prev", "event", "time"]
outcomes_df_list = [
    outcomes_df[[f"{e}_{suffix}" for suffix in cols]]
    .set_axis(cols, axis=1)
    .assign(endpoint=e)
    for e in tqdm(endpoints)
]

  0%|          | 0/1684 [00:00<?, ?it/s]

In [54]:
# Stack the per-endpoint frames, put endpoint first, store it as a category to
# save memory, and move eid from the index into a column.
outcomes_long = (
    pd.concat(outcomes_df_list, axis=0)
    .loc[:, ["endpoint"] + cols]
    .assign(endpoint=lambda d: d["endpoint"].astype("category"))
    .reset_index()
)

In [55]:
# Long table: one row per (eid, endpoint) with prev/event flags and the event-or-censoring time.
outcomes_long.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 846142640 entries, 0 to 846142639
Data columns (total 5 columns):
 #   Column    Dtype   
---  ------    -----   
 0   eid       int64   
 1   endpoint  category
 2   prev      bool    
 3   event     bool    
 4   time      float32 
dtypes: bool(2), category(1), float32(1), int64(1)
memory usage: 12.6 GB


In [56]:
outcomes_long_path = f"{output_path}/baseline_outcomes_long_220412.feather"
# outcomes_long already has a RangeIndex (reset_index() ran when it was built).
outcomes_long.to_feather(outcomes_long_path)