In [1]:
%load_ext autoreload
%autoreload 2
import sys
import os
import glob
import random
from collections import defaultdict
from pathlib import Path

from IPython.display import display
import numpy as np
import pandas as pd

from tqdm import tqdm
import jax
# Set JAX's 'jax_platform_name' config option to 'gpu'.
# NOTE(review): presumably this requires a GPU-enabled JAX install - confirm on the target machine.
jax.config.update('jax_platform_name', 'gpu')

In [2]:
# Requires the PyDMD package; to install it into this kernel's environment, run: %pip install PyDMD

In [3]:
# Add the parent directory to the import path (presumably where the `lib`
# package imported below lives). The membership check keeps the cell
# idempotent: re-running it no longer appends duplicate entries.
if ".." not in sys.path:
    sys.path.append("..")

from lib import utils as U
from lib.ehr.dataset import load_dataset

In [4]:
# Assign the folder of the dataset to `DATA_DIR`.

# `os.path.expanduser('~')` resolves the home directory even when the HOME
# environment variable is unset; `os.environ.get('HOME')` would return None
# in that case and silently produce the bogus path 'None/GP/ehr-data'.
HOME = os.path.expanduser('~')
DATA_DIR = f'{HOME}/GP/ehr-data'
# Absolute path of the parent of the current working directory.
SOURCE_DIR = os.path.abspath("..")

In [5]:

# Name of the artefacts folder (relative to the working directory); it is
# created together with any missing parents, and an existing folder is reused.
output_dir = 'mimic_artefacts'
os.makedirs(output_dir, exist_ok=True)

In [6]:
# Load the 'M3' and 'M4' datasets inside the `U.modified_environ` context,
# which receives DATA_DIR (presumably exporting it as an environment
# variable for the loaders - TODO confirm against `lib.utils`).
with U.modified_environ(DATA_DIR=DATA_DIR):
    m3_dataset, m4_dataset = [load_dataset(tag) for tag in ('M3', 'M4')]
   

In [7]:
from lib.ehr.coding_scheme import DxCCS, DxFlatCCS, DxICD9, DxICD10
from lib.ehr import Subject_JAX
from lib.ehr import StaticInfoFlags

%load_ext autoreload
%autoreload 2

In [8]:
from lib.ehr import OutcomeExtractor, SurvivalOutcomeExtractor
# Alias for the outcome extractor class.
# NOTE(review): no later cell in this section reads `outcome_class`.
outcome_class = SurvivalOutcomeExtractor

In [9]:
# Coding-scheme configuration passed to `Subject_JAX.from_dataset` further down:
# 'dx' holds a diagnosis coding scheme instance, 'outcome' an outcome extractor.
code_scheme = {
    'dx': DxCCS(),  # other schemes imported in this notebook: DxFlatCCS, DxICD9, DxICD10
    'outcome': SurvivalOutcomeExtractor('dx_flatccs_filter_v1')
}

In [10]:
# Flags enabling the gender and age static fields.
static_info_flags = StaticInfoFlags(gender=True, age=True)

# Both datasets are converted with exactly the same settings, so the keyword
# arguments are collected once and shared by the two calls.
interface_options = dict(code_scheme=code_scheme,
                         static_info_flags=static_info_flags,
                         data_max_size_gb=1)
m3_interface = Subject_JAX.from_dataset(m3_dataset, **interface_options)
m4_interface = Subject_JAX.from_dataset(m4_dataset, **interface_options)

In [11]:

def subject_outcome_acquisition_sequence(interface, subject_adms, resolution=7):
    """
    Discretise one subject's admissions into a sequence of cumulative outcomes.

    Time is cut into consecutive bins of width `resolution`, the first of
    which ends at `resolution`. Element k of the result is the element-wise
    maximum of `adm.outcome[0]` over the first admission and every later
    admission whose `admission_time` is at or before the end of bin k. Bins
    that contain no admission repeat the previous cumulative state, so an
    outcome never appears before the admission that reports it.

    :param interface: unused; kept so existing callers keep working.
    :param subject_adms: sequence of admissions, each exposing
        `admission_time` and `outcome` (only `outcome[0]` is read). Assumed
        to be ordered by `admission_time`, expressed in the same unit as
        `resolution` - TODO confirm against the interface.
    :param resolution: width of a time bin; must be positive.
    :return: list with one cumulative outcome vector per time bin; empty
        when `subject_adms` is empty.
    :raises ValueError: if `resolution` is not positive (the bin-advancing
        loop could otherwise never terminate).
    """
    if resolution <= 0:
        raise ValueError(f'resolution must be positive, got {resolution}')
    if len(subject_adms) == 0:
        return []

    current_time = resolution
    X = [subject_adms[0].outcome[0]]
    for adm in subject_adms[1:]:
        # Cumulative state once this admission has been observed.
        new_outcome = np.maximum(X[-1], adm.outcome[0])

        # Advance to the bin containing this admission. Every bin opened on
        # the way carries the state known *before* the admission.
        while current_time < adm.admission_time:
            current_time += resolution
            X.append(X[-1])

        # Only the bin that contains the admission receives the new outcomes
        # (this also covers an admission falling in the current bin).
        X[-1] = new_outcome

    return X

def outcome_acquisition_sequence(interface, resolution=7):
    """
    Compute the outcome acquisition sequence of every subject in `interface`.

    :param interface: mapping whose `.items()` yields (key, admissions) pairs.
    :param resolution: bin width forwarded to
        `subject_outcome_acquisition_sequence`.
    :return: dict from each key of `interface` to that subject's sequence.
    """
    sequences = {}
    for subject_id, admissions in interface.items():
        sequences[subject_id] = subject_outcome_acquisition_sequence(
            interface, admissions, resolution)
    return sequences

def outcome_acquisition_transition(interface, resolution=7):
    """
    Pool consecutive (state, next state) pairs over all subjects.

    :param interface: mapping passed on to `outcome_acquisition_sequence`.
    :param resolution: bin width passed on to `outcome_acquisition_sequence`.
    :return: tuple `(X_prev, X_next, sequence)` where `X_prev[k]` and
        `X_next[k]` are adjacent elements of one subject's sequence and
        `sequence` is the per-subject dict the pairs were drawn from.
    """
    sequence = outcome_acquisition_sequence(interface, resolution)
    per_subject = list(sequence.values())
    # Every element except the last is a "previous" state ...
    befores = [state for seq in per_subject for state in seq[:-1]]
    # ... and every element except the first is the matching "next" state.
    afters = [state for seq in per_subject for state in seq[1:]]
    return befores, afters, sequence

In [51]:
# Transition pairs for the M3 interface with a bin width of 30.
X_prev, X_next, ack_seq = outcome_acquisition_transition(m3_interface, 30)
# Stack along axis=1 so that each pooled state becomes one column
# (assumes every state is a 1-D vector of equal length - TODO confirm).
X_prev, X_next = [np.stack(states, axis=1) for states in (X_prev, X_next)]

In [45]:
# Show the dimensions of the stacked `X_prev` array.
X_prev.shape

In [28]:
from pydmd import MrDMD
from pydmd import DMD
from pydmd import DMDBase
from pydmd.snapshots import Snapshots
from pydmd.utils import compute_tlsq
from pydmd.plotter import plot_eigs

In [25]:
class BatchDMD(DMD):
    """
    DMD variant fitted on explicitly paired snapshot matrices.

    `fit` receives the "previous" and "next" snapshots as two separate
    arrays, so pairs pooled from several independent sequences can be used.
    """

    def fit(self, X_prev, X_next):
        """
        Compute the Dynamic Mode Decomposition from paired snapshots.

        :param X_prev: the input snapshots; column k is taken to be the state
            preceding column k of `X_next` (assumed layout: one snapshot per
            column).
        :type X_prev: numpy.ndarray or iterable
        :param X_next: the successor snapshots, same shape as `X_prev`.
        :type X_next: numpy.ndarray or iterable
        :return: the fitted instance.
        """
        self._reset()

        # Only `X_prev` is registered as the snapshot set of this instance.
        self._snapshots_holder = Snapshots(X_prev)

        n_samples = self.snapshots.shape[1]

        # Feed the TLSQ-denoised pair (not the raw inputs) to the operator so
        # that the `tlsq_rank` setting actually takes effect.
        X, Y = compute_tlsq(X_prev, X_next, self._tlsq_rank)
        self._svd_modes, _, _ = self.operator.compute_operator(X, Y)

        # Default timesteps
        self._set_initial_time_dictionary(
            {"t0": 0, "tend": n_samples - 1, "dt": 1}
        )

        self._b = self._compute_amplitudes()

        return self
    
    
   

In [None]:
import matplotlib.pyplot as plt

def make_plot(X, figsize=(12, 8), title=''):
    """
    Draw the real part of the 2-D array `X` as a pseudocolour plot.

    Columns of `X` run along the horizontal axis (labelled 'Space') and rows
    along the vertical axis (labelled 'Time'). A colour bar is attached and
    the figure is shown immediately.

    :param X: 2-D array; only its real part is drawn.
    :param figsize: figure size forwarded to `plt.figure`.
    :param title: figure title.
    """
    plt.figure(figsize=figsize)
    plt.title(title)
    n_cols = X.shape[1]
    n_rows = X.shape[0]
    # One coordinate per column / per row, spread evenly over [0, count].
    col_coords = np.linspace(0, n_cols, n_cols)
    row_coords = np.linspace(0, n_rows, n_rows)
    mesh = plt.pcolor(col_coords, row_coords, np.real(X))
    plt.colorbar(mesh)
    plt.xlabel('Space')
    plt.ylabel('Time')
    plt.show()

    


# Plot the transpose of `X_prev`, so its columns become the rows of the image.
make_plot(X_prev.T)

In [83]:

# Fit a BatchDMD on the complemented arrays `1 - X_prev` / `1 - X_next`.
# NOTE(review): the reason for fitting on the complement rather than the raw
# arrays is not stated in this notebook - confirm the intent.
first_dmd = BatchDMD(svd_rank=100, opt=True, sorted_eigs='abs', tikhonov_regularization=1e-3)
first_dmd.fit(1-X_prev, 1-X_next)

In [84]:
# Plot the eigenvalues of the fitted model, with axes and the unit circle drawn.
plot_eigs(first_dmd, show_axes=True, show_unit_circle=True)

In [85]:
# Apply the fitted model to the same complemented input it was fitted on and
# show the shape of the prediction.
X_next_recons = first_dmd.predict(1-X_prev)
X_next_recons.shape

In [86]:
# Entry-wise comparison of each state with its successor.
unchanged = X_prev == X_next
constant_one_mask = unchanged & (X_prev == 1)    # stays at 1
constant_zero_mask = unchanged & (X_prev == 0)   # stays at 0
transition_mask = ~unchanged                     # value differs between the pair
# Row-wise mean of the complemented successor array.
freq = (1 - X_next).mean(axis=1)

# Fraction of entries falling in each of the three categories.
np.mean(constant_one_mask), np.mean(constant_zero_mask), np.mean(transition_mask)

In [87]:
# Absolute difference between the complemented target and the real part of
# the model's prediction.
error = np.abs((1 - X_next) - X_next_recons.real)
# Per-row mean of that error restricted to the transitioning entries.
# Rows without any transition divide 0 by 0 and therefore come out as nan.
transition_counts = transition_mask.sum(axis=1)
transition_error = (transition_mask * error).sum(axis=1) / transition_counts
transition_error

In [89]:
import matplotlib.pyplot as plt

# Overlay the real part of the first three rows of `first_dmd.modes.T`
# (i.e. the first three columns of `modes`) on one implicit pyplot figure.
for mode in first_dmd.modes.T[:3]:
    plt.plot(mode.real)
    # Re-set on every iteration; the title only exists if the loop runs at least once.
    plt.title("Modes")
plt.show()

# Overlay the real part of the first 12 samples of every entry of
# `first_dmd.dynamics` on a second figure.
for dynamic in first_dmd.dynamics:
    plt.plot(dynamic.real[:12])
    plt.title("Dynamics")
plt.show()

