In [1]:
import pandas as pd
import numpy as np

In [2]:
# Path (relative to the notebook's working directory) of the pickled dataset.
fname = "dataset_omop.pkl"

In [3]:
import pickle as pkl

# Load the pickled dataframe; the context manager guarantees the file handle is closed.
# NOTE(review): pickle.load can execute arbitrary code -- only load files from a trusted source.
with open(fname, 'rb') as fh:
    df = pkl.load(fh)

In [4]:
# df['measurement_datetime'] = df['measurement_datetime'].apply(pd.to_datetime)

In [5]:
# Preview the first ten rows of the loaded frame.
df.head(10)

Unnamed: 0,measurement_datetime,target,super_target,person_id,BP diastolic,BP systolic,Body temperature,Glasgow coma scale,Heart rate,Mean blood pressure,...,race_UNABLE TO OBTAIN,race_UNKNOWN/NOT SPECIFIED,race_WHITE,race_WHITE - BRAZILIAN,race_WHITE - EASTERN EUROPEAN,race_WHITE - OTHER EUROPEAN,race_WHITE - RUSSIAN,race_NaN,age,Respiratory rate avg h-2
0,2182-07-31 04:11:00,0,1,62063393,,,,,,,...,0,0,1,0,0,0,0,0,60.8,
1,2182-07-31 04:45:00,0,1,62063393,,,,,,,...,0,0,1,0,0,0,0,0,60.8,
2,2188-11-12 10:00:00,0,1,62063368,88.0,146.0,36.388901,15.0,,107.333,...,0,0,0,0,0,0,0,0,23.9,
3,2188-11-12 10:30:00,0,1,62063368,95.0,140.0,,,134.0,110.0,...,0,0,0,0,0,0,0,0,23.9,38.0
4,2188-11-12 10:45:00,0,1,62063368,97.0,144.0,,,134.0,112.667,...,0,0,0,0,0,0,0,0,23.9,36.0
5,2188-11-12 11:00:00,0,1,62063368,91.0,140.0,,,134.0,107.333,...,0,0,0,0,0,0,0,0,23.9,37.666667
6,2188-11-12 12:00:00,0,1,62063368,98.0,139.0,97.599998,,135.0,111.667,...,0,0,0,0,0,0,0,0,23.9,38.25
7,2188-11-12 13:00:00,0,1,62063368,108.0,158.0,36.666698,15.0,134.0,124.667,...,0,0,0,0,0,0,0,0,23.9,39.0
8,2188-11-12 13:03:00,0,1,62063368,,,36.7,,,,...,0,0,0,0,0,0,0,0,23.9,35.0
9,2188-11-12 13:30:00,0,1,62063368,105.0,159.0,,,134.0,123.0,...,0,0,0,0,0,0,0,0,23.9,35.0


In [6]:
# Column labels plus the whole frame as one 2-D NumPy array (rows x columns).
# NOTE(review): neither name appears to be used by the following cells -- TODO confirm or remove.
variable_names, mat = df.columns.values, np.asarray(df)

In [7]:
from datetime import timedelta
from fleming_lib.preprocessing import fill_last_upto

# Per patient, fill missing values with the last known one (looking back at most
# 24 hours before the measurement), then re-sort by index because the
# groupby-apply emits rows grouped patient by patient.
df = (
    df.groupby('person_id', group_keys=False)
      .apply(fill_last_upto, h=timedelta(hours=24))
      .sort_index()
)

[INFO] adding /home/paulroujansky/git/DataForGood/batch4_diafoirus_fleming to sys.path


In [14]:
# Split the frame into one sub-frame per patient.
# NOTE: despite its name, `list_df` is a dict mapping person_id -> sub-frame;
# `patients_id` holds the same ids, in groupby order.
list_df = {}
for patient_id, sub_df in df.groupby('person_id'):
    list_df[patient_id] = sub_df
patients_id = list(list_df.keys())

In [15]:
import warnings
from datetime import timedelta

class Dataloader():
    """Slice one dataframe of timestamped measurements into time-indexed batches.

    Typical use: ``load_data(df)`` -> ``make_timeline(...)`` -> ``build_batches()``,
    then read batches with ``get_batch(i)`` or convert them with ``batch_to_matrix``.
    """

    def __init__(self):
        pass

    def load_data(self, df):
        """Attach ``df`` and record its measurement time span.

        Parameters
        ----------
        df : pandas.DataFrame
            Must have a ``measurement_datetime`` column whose values can be
            compared and shifted by ``timedelta`` (e.g. datetime64).

        Raises
        ------
        ValueError
            If ``df`` has no rows.
        """
        if len(df.index) == 0:
            raise ValueError('Cannot load an empty dataframe.')
        self.df = df
        self.measurement_datetime = df['measurement_datetime']
        # min/max instead of first/last row so an unsorted frame still gets its
        # true time span.
        self.start_dt = self.measurement_datetime.min()
        self.end_dt = self.measurement_datetime.max()

        self.variables = df.columns.values

    def make_timeline(self, start_dt=None, end_dt=None, step=timedelta(days=1), window=timedelta(days=1)):
        """Build the list of batch end-times.

        The timeline is ``start_dt, start_dt + step, start_dt + 2*step, ...`` and
        stops at the first point that is ``>= end_dt``, so the last batch covers
        ``end_dt``.

        Parameters
        ----------
        start_dt, end_dt : datetime-like, optional
            Default to the earliest / latest measurement time.
        step : timedelta
            Spacing between consecutive timeline points; must be positive.
        window : timedelta
            Look-back length of each batch; must be positive.

        Raises
        ------
        ValueError
            If the bounds are inconsistent with the data or with each other, or if
            ``step`` / ``window`` is not positive.
        """
        if start_dt is None:
            start_dt = self.start_dt
        if end_dt is None:
            end_dt = self.end_dt

        if start_dt > self.end_dt:
            raise ValueError('`start_dt` cannot be greater than {}'.format(self.end_dt))
        if end_dt < self.start_dt:
            raise ValueError('`end_dt` cannot be smaller than {}'.format(self.start_dt))
        if start_dt > end_dt:
            raise ValueError('`end_dt` should be greater than or equal to `start_dt`')
        # A non-positive step never reaches end_dt, and a non-positive window can
        # never widen in _make_batch: both would loop forever.
        if step <= timedelta(0):
            raise ValueError('`step` must be a positive duration')
        if window <= timedelta(0):
            raise ValueError('`window` must be a positive duration')

        self.step = step
        self.window = window

        # Always include start_dt (so start_dt == end_dt still yields one point),
        # then keep stepping until the timeline reaches or passes end_dt.
        # Points are computed as start_dt + i*step (not accumulated) to avoid drift.
        timeline = [start_dt]
        while timeline[-1] < end_dt:
            timeline.append(start_dt + len(timeline) * step)

        self.timeline = timeline
        self.n_times = len(self.timeline)

    def get_time(self, i):
        """Return timeline point ``i`` (negative indices count from the end).

        Raises ValueError if ``i`` is out of range.
        """
        if not -self.n_times <= i < self.n_times:
            raise ValueError('Cannot fetch time index {} (max {}).'.format(i, self.n_times - 1))
        return self.timeline[i]

    def _make_batch(self, i):
        """Return the rows of the look-back window ending at timeline point ``i``.

        The window is ``(time - window, time]``. If it is empty, it is widened by
        whole multiples of ``window`` until it holds at least one row, and a
        warning is issued. If no row exists at or before ``time``, an empty frame
        is returned with a warning; the original version looped forever here.
        """
        time = self.get_time(i)
        times = self.measurement_datetime

        # Without any row at or before `time`, widening the window cannot help.
        if not (times <= time).any():
            warnings.warn('No data at or before {}.'.format(time))
            return self.df.iloc[:0]

        j = 1
        mask = (times > time - j * self.window) & (times <= time)
        while not mask.any():
            j += 1
            mask = (times > time - j * self.window) & (times <= time)
        if j > 1:
            warnings.warn('No data between {} and {}. Going back {}.'.format(
                time - self.window, time, j * self.window))
        return self.df[mask]

    def build_batches(self):
        """Build one batch per timeline point and store them in ``self.batches``."""
        self.batches = [self._make_batch(i) for i in range(self.n_times)]
        self.n_batches = len(self.batches)

    def get_batch(self, i):
        """Return batch ``i`` (negative indices count from the end).

        Raises ValueError if ``i`` is out of range.
        """
        if not -self.n_batches <= i < self.n_batches:
            raise ValueError('Cannot fetch batch {} (max {}).'.format(i, self.n_batches - 1))
        return self.batches[i]

    def batch_to_matrix(self, batches):
        """Convert one batch, or a list of batches, to NumPy array(s).

        ``DataFrame.as_matrix`` was removed from pandas; ``to_numpy`` replaces it.
        """
        if isinstance(batches, list):
            # Convert the batches that were passed in, not self.batches.
            return [batch.to_numpy() for batch in batches]
        return batches.to_numpy()

In [16]:
# Each batch is taken every `step` and covers the preceding `window` of measurements.
# Both are loop-invariant, so they are defined once here.
step = timedelta(days=1)
window = timedelta(days=1)

# One Dataloader per patient, keyed by person_id.
dataloaders = dict()
for patient_id, sub_df in list_df.items():
    dataloader = Dataloader()
    # Use the sub-frame bound by the loop instead of looking it up again.
    dataloader.load_data(sub_df)
    dataloader.make_timeline(step=step, window=window)
    dataloader.build_batches()
    # Report the number of batches actually built (equals the timeline length).
    print('n_batches: {}'.format(dataloader.n_batches))

    dataloaders[patient_id] = dataloader

n_batches: 12
n_batches: 5
n_batches: 2


In [18]:
# Show batch index 1 (the second window) of the first patient's dataloader.
dataloaders[patients_id[0]].get_batch(i=1)

Unnamed: 0,measurement_datetime,target,super_target,person_id,BP diastolic,BP systolic,Body temperature,Glasgow coma scale,Heart rate,Mean blood pressure,...,race_UNABLE TO OBTAIN,race_UNKNOWN/NOT SPECIFIED,race_WHITE,race_WHITE - BRAZILIAN,race_WHITE - EASTERN EUROPEAN,race_WHITE - OTHER EUROPEAN,race_WHITE - RUSSIAN,race_NaN,age,Respiratory rate avg h-2
3,2188-11-12 10:30:00,0,1,62063368,95,140,36.3889,15,134,110,...,0,0,0,0,0,0,0,0,23.9,38
4,2188-11-12 10:45:00,0,1,62063368,97,144,36.3889,15,134,112.667,...,0,0,0,0,0,0,0,0,23.9,36
5,2188-11-12 11:00:00,0,1,62063368,91,140,36.3889,15,134,107.333,...,0,0,0,0,0,0,0,0,23.9,37.6667
6,2188-11-12 12:00:00,0,1,62063368,98,139,97.6,15,135,111.667,...,0,0,0,0,0,0,0,0,23.9,38.25
7,2188-11-12 13:00:00,0,1,62063368,108,158,36.6667,15,134,124.667,...,0,0,0,0,0,0,0,0,23.9,39
8,2188-11-12 13:03:00,0,1,62063368,108,158,36.7,15,134,124.667,...,0,0,0,0,0,0,0,0,23.9,35
9,2188-11-12 13:30:00,0,1,62063368,105,159,36.7,15,134,123,...,0,0,0,0,0,0,0,0,23.9,35
10,2188-11-12 14:00:00,0,1,62063368,107,157,36.7,15,137,123.667,...,0,0,0,0,0,0,0,0,23.9,33.6667
11,2188-11-12 14:45:00,0,1,62063368,101,149,36.7,15,147,117,...,0,0,0,0,0,0,0,0,23.9,32.3333
12,2188-11-12 15:00:00,0,1,62063368,100,147,36.7,15,146,115.667,...,0,0,0,0,0,0,0,0,23.9,34.75
