In [1]:
# Interpolation scheme passed to interp_data: "cubic" (natural cubic) or "linear".
interp = "cubic"
datafile = "../data/abbrev.json.bz2"
# Derived from `interp` so switching schemes writes to a matching file name
# instead of mislabelling/overwriting the cubic output.
outfile = f"../data/interp-abbrev-{interp}.json.bz2"

In [2]:
import bz2
import json

In [3]:
import numpy as np
import pandas as pd
import torch

In [4]:
from tqdm import tqdm

In [5]:
import torchcde

In [6]:
def _observation_matrix(raw_observation, time_index, time_index_df, observation_id):
    """Align one observation onto ``time_index`` and append observation-count channels.

    Returns a 2-D array with, in order: one column holding the ``time_index``
    values, the observation's feature columns (NaN where no value was recorded
    at that time), and, for each feature column, the running count of
    non-missing values up to and including that row.

    Raises ValueError if aligning the observation does not yield exactly one
    row per ``time_index`` entry.
    """
    frame = pd.DataFrame(raw_observation).fillna(value=np.nan)
    # Column 0 of the observation is matched against the shared time grid; the
    # ordered (outer) merge inserts NaN rows for grid times it does not contain.
    merged = pd.merge_ordered(time_index_df, frame, left_on='time_index', right_on=0, fill_method=None)
    if len(merged) != len(time_index):
        # An outer merge grows when the observation has time stamps that are
        # absent from (or duplicated relative to) time_index; the index-aligned
        # concat below would then pair feature rows with the wrong times.
        raise ValueError(
            f"observation {observation_id!r} yields {len(merged)} rows after aligning on "
            f"time_index (expected {len(time_index)}); its time stamps must be unique "
            f"values taken from time_index"
        )
    values = merged.drop(['time_index', 0], axis=1).to_numpy()
    # Cumulative number of observed (non-NaN) entries per feature column.
    counts = (~np.isnan(values)).cumsum(axis=0)
    return pd.concat(
        [pd.DataFrame(time_index), pd.DataFrame(values), pd.DataFrame(counts)], axis=1
    ).to_numpy()


def interp_data(file_name, interpolation="cubic"):
    """Load per-observation time series and compute torchcde interpolation coefficients.

    Reads a bz2-compressed JSON file with the top-level keys ``time_index``,
    ``info``, ``outcome`` and ``data``. ``info`` maps each patient id to a dict
    keyed by that patient's observation ids. For every patient, each of its
    observations is aligned onto the shared ``time_index`` and augmented with
    cumulative observation-count channels. The observations are then stacked
    and converted into interpolation coefficients.

    Parameters
    ----------
    file_name : str or path-like
        Path to the bz2-compressed JSON input.
    interpolation : {"cubic", "linear"}
        ``"linear"`` uses ``torchcde.linear_interpolation_coeffs``;
        ``"cubic"`` uses ``torchcde.natural_cubic_coeffs``.

    Returns
    -------
    dict
        ``patient_list``: patient ids in the order they appear in ``info``.
        ``xdata``: patient id -> nested list of float32 coefficients, one
        entry per observation in sorted observation-id order.
        ``ydata``: patient id -> list of ``[time, outcome]`` pairs taken from
        ``outcome``, in the same observation order.

    Raises
    ------
    ValueError
        If ``interpolation`` is not ``"linear"`` or ``"cubic"``, or if an
        observation's time stamps do not align one-to-one with ``time_index``.
    """
    # Validate before doing any I/O; previously any unknown value silently
    # fell through to the cubic branch.
    if interpolation not in ("linear", "cubic"):
        raise ValueError(f"interpolation must be 'linear' or 'cubic', got {interpolation!r}")

    with bz2.open(file_name, 'rt', encoding="utf-8") as f:
        covid19_data = json.load(f)

    if interpolation == "linear":
        compute_coeffs = torchcde.linear_interpolation_coeffs
    else:
        compute_coeffs = torchcde.natural_cubic_coeffs

    time_index = covid19_data['time_index']
    time_index_df = pd.DataFrame({'time_index': time_index})
    patient_info = covid19_data['info']

    interpolation_data = {
        'xdata': {},
        'ydata': {},
        # Kept in file order (not sorted) to preserve the existing output format.
        'patient_list': list(patient_info.keys()),
    }

    for patient_id in tqdm(sorted(patient_info)):
        x_rows = []
        y_rows = []
        for observation_id in sorted(patient_info[patient_id]):
            outcome = covid19_data['outcome'][observation_id]
            y_rows.append([outcome['time'], outcome['outcome']])
            x_rows.append(
                _observation_matrix(
                    covid19_data['data'][observation_id], time_index, time_index_df, observation_id
                )
            )

        # Stack once in numpy; torch.Tensor(list_of_ndarrays) is the slow path.
        x_tensor = torch.tensor(np.stack(x_rows), dtype=torch.float32)
        y_tensor = torch.Tensor(y_rows)

        interpolation_data['xdata'][patient_id] = compute_coeffs(x_tensor).numpy().tolist()
        interpolation_data['ydata'][patient_id] = y_tensor.numpy().tolist()

    return interpolation_data

In [7]:
interp_data = interp_data(datafile, interpolation=interp)

100%|██████████| 30/30 [00:21<00:00,  1.38it/s]


In [8]:
with bz2.open(outfile, 'wt', encoding="utf-8") as f:
    json.dump(interp_data, f)