# EHRSHOT Dataset Preprocessing (S2P2 Paper)
This notebook includes code for preparing the EHRSHOT event sequence dataset from the raw [EHRSHOT dataset](https://som-shahlab.github.io/ehrshot-website/), where medical services and procedures are treated as marks, as identified by _Current Procedural Terminology_ (CPT-4) codes.

This version of the dataset was originally used to evaluate the [State-Space Point Process (S2P2)](https://openreview.net/pdf?id=74SvE2GZwW) model. Note that we cannot distribute the raw data (or any derivative dataset) under the terms of the original EHRSHOT dataset. Access to the data can be requested [here](https://stanford.redivis.com/datasets/53gc-8rhx41kgt).

In [1]:
import pandas as pd
import numpy as np
from collections import defaultdict
import heapq
from tqdm import tqdm
from easy_tpp.utils import set_seed
import random
import json

### 0. Load data and check if it's complete

In [4]:
# Raw EHRSHOT event table.
path_to_data_csv = '../data/EHRSHOT/EHRSHOT_ASSETS/data/ehrshot.csv'
df_dataset = pd.read_csv(path_to_data_csv)

# Patient-id map that carries the split assignment of each patient.
path_to_splits_csv = '../data/EHRSHOT/EHRSHOT_ASSETS/splits/person_id_map.csv'
df_split = pd.read_csv(path_to_splits_csv)

  df_dataset = pd.read_csv(path_to_data_csv)


In [5]:
# Sanity check: the counts printed here should match the statistics of the original repo:
# https://github.com/som-shahlab/ehrshot-benchmark/blob/main/ehrshot/stats.ipynb
print("# of events:", len(df_dataset))
for label, column in (("# of patients:", 'patient_id'), ("# of visits:", 'visit_id')):
    print(label, df_dataset[column].nunique())

# Number of distinct patients assigned to each split.
split_labels = {
    'train': "# of train patients",
    'val': "# of val patients",
    'test': "# of test patients",
}
for split_name, label in split_labels.items():
    in_split = df_split['split'] == split_name
    print(label, df_split.loc[in_split, 'omop_person_id'].nunique())

# of events: 41661637
# of patients: 6739
# of visits: 921499
# of train patients 2295
# of val patients 2232
# of test patients 2212


### 1. Get event times for visit occurrence

In [7]:
# Keep only visit-occurrence rows. .copy() makes df_visit an independent frame so the
# timestamp conversion below does not write into a slice of df_dataset.
df_visit = df_dataset[df_dataset['omop_table'] == 'visit_occurrence'].copy()
# Convert each visit start datetime to integer epoch seconds.
df_visit['start'] = pd.to_datetime(df_visit['start']).apply(lambda x: int(round(x.timestamp())))
# One row per distinct (patient, start) pair. The default keep='first' retains one copy
# of each repeated pair; keep=False would drop *every* copy, silently removing any
# timestamp at which a patient has more than one visit row.
df_visit_time = df_visit[['patient_id', 'start']].drop_duplicates()
# Per patient: chronologically sorted list of distinct visit start times.
df_visit_time = df_visit_time.groupby(['patient_id'])['start'].apply(lambda x: sorted(set(x))).reset_index(name='timestamp')
# patient_id -> sorted list of visit timestamps.
visit_dict = pd.Series(df_visit_time.timestamp.values, index=df_visit_time.patient_id).to_dict()
# Ids of all patients that have at least one visit occurrence.
patient_visit = df_visit_time['patient_id'].to_numpy()

### 2. Get CPT4 codes that occur at least 100 times

In [14]:
# Rows whose code string contains 'CPT4' (case-insensitive); rows with a missing code are excluded.
is_cpt4 = df_dataset['code'].str.contains('CPT4', case=False, na=False)
df_cpt4 = df_dataset[is_cpt4]

# Distinct CPT4 codes and the number of rows carrying each of them.
mark_val, mark_count = np.unique(df_cpt4['code'].to_numpy(), return_counts=True)

# Boolean mask (aligned with mark_val) selecting the codes seen at least 100 times.
mark_mask = mark_count >= 100
print(f'Number of marks after filtering: {sum(mark_mask)}')

Number of marks after filtering: 668


In [15]:
# Restrict the vocabulary to the frequent CPT4 codes selected by mark_mask.
# NOTE: this rebinds mark_val, so the cell is not idempotent; re-run the cell that
# builds mark_val / mark_mask before re-running this one.
mark_val = mark_val[mark_mask]
mark_val_set = set(mark_val)
# Events carrying a retained code, reduced to the columns needed downstream.
# .copy() yields an independent frame so the column assignments below do not write
# into a slice of df_cpt4 (chained assignment / SettingWithCopyWarning).
df_cpt4_subset = df_cpt4[df_cpt4['code'].isin(mark_val_set)][['patient_id', 'start', 'code']].copy()
# Event start datetimes -> integer epoch seconds.
df_cpt4_subset['start'] = pd.to_datetime(df_cpt4_subset.loc[:,'start']).apply(lambda x: int(round(x.timestamp())))
# Replace each code string by its integer category code (the mark id).
df_cpt4_subset['code'] = df_cpt4_subset['code'].astype('category').cat.codes
# Number of events per mark id, used later to rank marks by frequency.
mark_val_subset, mark_count_subset = np.unique(df_cpt4_subset.loc[:,'code'].to_numpy(), return_counts=True)
mark_count_dict = dict(zip(mark_val_subset, mark_count_subset))
# Patients that have at least one retained CPT4 event, in order of first appearance.
patient_cpt4 = df_cpt4_subset['patient_id'].unique()

### 3. Generate event sequences

In [16]:
def sample_event_times(real_event_time, std, size):
    """Draw jittered event times around ``real_event_time``.

    Samples ``size`` values from a Normal distribution with mean
    ``real_event_time`` and standard deviation ``std`` using NumPy's global
    random state. The whole batch is redrawn until every sampled value is
    strictly positive.

    Args:
        real_event_time: Mean of the Normal distribution.
        std: Standard deviation of the Normal distribution.
        size: Number of values to draw (passed to ``np.random.normal``).

    Returns:
        The accepted batch of sampled times, as returned by ``np.random.normal``.
    """
    while True:
        candidate_times = np.random.normal(real_event_time, scale=std, size=size)
        # Accept only a batch in which all times are > 0; otherwise draw a fresh batch.
        if np.all(candidate_times > 0):
            return candidate_times

In [19]:
time_norm = 60 * 60  # in seconds; dividing epoch-second offsets by this expresses times in hours
min_events = 5  # patients with fewer CPT4 events than this are skipped
max_marks_per_time = 10  # at most this many events are kept per distinct timestamp
# Id used for the padding mark and reported as 'dim_process'. Real marks are the category
# codes 0..K-1 assigned when df_cpt4_subset['code'] was encoded, so K (the number of
# distinct marks) is the first unused id. K is 668 for the data this notebook was run on.
padding_events = int(df_cpt4_subset['code'].nunique())
all_sequences = []
idx = 0
set_seed(123)  # NOTE(review): presumably seeds the NumPy / Python RNGs used below — confirm in easy_tpp


for patient in tqdm(patient_cpt4):
    patient = int(patient)
    data = df_cpt4_subset[df_cpt4_subset.patient_id == patient]
    sorted_unique_times = sorted(data.start.unique())
    # Skip patients with too few events, or whose events all share a single timestamp
    # (at least two distinct times are needed to compute a minimum gap below).
    if len(data) < min_events or len(sorted_unique_times) < 2:
        continue
    events = list(zip(data['start'], data['code']))
    min_diff = min(np.diff(sorted_unique_times))  # minimum time between two consecutive events

    # Group marks by their normalized time offset from the patient's first event.
    event_dict = defaultdict(list)
    base_time = int(sorted_unique_times[0])
    for t, m in events:
        event_dict[(t - base_time)/time_norm].append(m)

    std = min(min_diff/time_norm, 1) / 10  # std. for Normal distribution to jitter event times
    event_times = []
    event_marks = []
    for t in sorted_unique_times:
        t = (t - base_time)/time_norm
        v = event_dict[t]
        if len(v) > max_marks_per_time:  # choose mark by frequency
            # Min-heap of (frequency, mark) capped at max_marks_per_time entries:
            # pushing past the cap evicts the least frequent entry, so the most
            # frequent marks survive.
            h = []
            for mark in v:
                if len(h) < max_marks_per_time:
                    heapq.heappush(h, (mark_count_dict[mark], mark))
                else:
                    heapq.heappushpop(h, (mark_count_dict[mark], mark))
            v = [x[1] for x in h]
        else:
            np.random.shuffle(v)

        # The first mark keeps the true time t; each remaining mark gets a jittered time
        # so that events sharing a timestamp receive distinct times.
        sampled_times = sample_event_times(t, std, min(max_marks_per_time, len(v)) - 1)
        times = sorted([t] + list(sampled_times))
        times = [float(t) for t in times]
        event_times.extend(times)
        event_marks.extend(v)
        assert len(v) <= max_marks_per_time
    assert(len(event_times) == len(event_marks))
    assert(min_events <= len(event_times))

    # padding the start and end of sequences to have padding events
    # (this overwrites the marks of the first and last event of the sequence)
    event_marks[0] = padding_events
    event_marks[-1] = padding_events

    all_sequences.append(
        {
            'dim_process': padding_events,
            'seq_idx': idx,
            'seq_len': len(event_times),
            'time_since_start': event_times,
            'time_since_last_event': [0] + [event_times[i+1] - event_times[i] for i in range(len(event_times) - 1)],
            'type_event': event_marks,
        }
    )
    idx += 1
print(len(all_sequences))  # 6183

100%|██████████| 6634/6634 [00:08<00:00, 786.48it/s] 

6183





In [20]:
# Target fractions of the shuffled sequences assigned to each split.
test_pct, valid_pct, train_pct = 0.15, 0.15, 0.7
test_seqs, valid_seqs, train_seqs = [], [], []

# Shuffle in place, then cut the shuffled list by cumulative position:
# the leading test_pct goes to test, the next valid_pct to valid, the rest to train.
random.shuffle(all_sequences)
n_sequences = len(all_sequences)
valid_cutoff = test_pct + valid_pct
for position, seq in enumerate(all_sequences, start=1):
    progress = position / n_sequences
    if progress > valid_cutoff:
        train_seqs.append(seq)
    elif progress > test_pct:
        valid_seqs.append(seq)
    else:
        test_seqs.append(seq)

for split_name, split_seqs in (('test', test_seqs), ('valid', valid_seqs), ('train', train_seqs)):
    print(f'{split_name}: {len(split_seqs)}')

test: 927
valid: 927
train: 4329


In [21]:
# # Save results
# with open('./ehrshot_cpt4/train.json', 'w') as f:
#     json.dump(train_seqs, f)
#
# with open('./ehrshot_cpt4/dev.json', 'w') as f:
#     json.dump(valid_seqs, f)
#
# with open('./ehrshot_cpt4/test.json', 'w') as f:
#     json.dump(test_seqs, f)