# Loading packages

In [1]:
import os
import sys
import argparse
import datetime
from functools import partial

import numpy as np
import torch
import optuna
import matplotlib.pyplot as plt
import random
from torch.utils.tensorboard import SummaryWriter

# load model
from latent_ode.trainer_glunet_interpol import LatentODEWrapper
from latent_ode.eval_glunet import test

# utils for darts
from utils.darts_training import print_callback
from utils.darts_dataset import SamplingDatasetDual, SamplingDatasetInferenceDual
from utils.darts_processing import load_data, reshuffle_data

# Loading and visualizing data

In [2]:
formatter, series, scalers = load_data(seed=0, 
                                       study_file=None, 
                                       dataset='hall',
                                       use_covs=True, 
                                       cov_type='dual',
                                       use_static_covs=True)

--------------------------------
Loading column definition...
Checking column definition...
Loading data...
Dropping columns / rows...
Checking for NA values...
Setting data types...
Dropping columns / rows...
Encoding data...
	Updated column definition:
		id: REAL_VALUED (ID)
		time: DATE (TIME)
		gl: REAL_VALUED (TARGET)
		Age: REAL_VALUED (STATIC_INPUT)
		BMI: REAL_VALUED (STATIC_INPUT)
		A1C: REAL_VALUED (STATIC_INPUT)
		FBG: REAL_VALUED (STATIC_INPUT)
		ogtt.2hr: REAL_VALUED (STATIC_INPUT)
		insulin: REAL_VALUED (STATIC_INPUT)
		hs.CRP: REAL_VALUED (STATIC_INPUT)
		Tchol: REAL_VALUED (STATIC_INPUT)
		Trg: REAL_VALUED (STATIC_INPUT)
		HDL: REAL_VALUED (STATIC_INPUT)
		LDL: REAL_VALUED (STATIC_INPUT)
		mean_glucose: REAL_VALUED (STATIC_INPUT)
		sd_glucose: REAL_VALUED (STATIC_INPUT)
		range_glucose: REAL_VALUED (STATIC_INPUT)
		min_glucose: REAL_VALUED (STATIC_INPUT)
		max_glucose: REAL_VALUED (STATIC_INPUT)
		quartile.25_glucose: REAL_VALUED (STATIC_INPUT)
		median_glucose: REAL_VA

In [None]:
fig, axs = plt.subplots(1, 3, figsize=(15, 5))
for i, j in enumerate(random.sample(range(len(series['train']['target'])), 3)):
    s = scalers['target'].inverse_transform(series['train']['target'][j])
    id = scalers['static'].inverse_transform(series['train']['static'][j])
    id = id.values()[0, -1]
    s.plot(ax=axs[i])
    axs[i].set_title(f'Patient {int(id)}')
    axs[i].set_ylabel('Glucose (mg/dL)')
    axs[i].set_xlabel('Time (date)')
    if axs[i].get_legend() is not None:
        axs[i].get_legend().remove()

In [3]:
# create datasets
out_len = 12
in_len = 24
max_samples_per_ts = 100
dataset_train = SamplingDatasetDual(series['train']['target'],
                                    series['train']['future'],
                                    output_chunk_length=out_len,
                                    input_chunk_length=in_len,
                                    use_static_covariates=True,
                                    max_samples_per_ts=max_samples_per_ts,)
dataset_val = SamplingDatasetDual(series['val']['target'],
                                    series['val']['future'],   
                                    output_chunk_length=out_len,
                                    input_chunk_length=in_len,
                                    use_static_covariates=True,)
dataset_test = SamplingDatasetInferenceDual(target_series=series['test']['target'],
                                            covariates=series['test']['future'],
                                            input_chunk_length=in_len,
                                            output_chunk_length=out_len,
                                            use_static_covariates=True,
                                            array_output_only=True)
dataset_test_ood = SamplingDatasetInferenceDual(target_series=series['test_ood']['target'],
                                                covariates=series['test_ood']['future'],
                                                input_chunk_length=in_len,
                                                output_chunk_length=out_len,
                                                use_static_covariates=True,
                                                array_output_only=True)

In [5]:
dataset_train[0][0].shape

(24, 1)

In [20]:
# convert samples to series
import darts
import pandas as pd



In [21]:
out_len = 1
in_len =  48
max_samples_per_ts = 100



# Define a model

In [22]:
model = LatentODEWrapper(device = 'cuda',
                         latents = 5,
                         rec_dims = 50,
                         rec_layers = 3,
                         gen_layers = 3,
                         units = 300,
                         gru_units = 100)

# Train

In [23]:
model_path = 'output/model.ckpt'
writer = SummaryWriter('output/tensorboard')
model.fit(dataset_train,
          dataset_train,
          learning_rate = 1e-3,
          batch_size = 32,
          epochs = 100,
          num_samples = 2,
          device = 'cuda',
          model_path = model_path,
          trial = None,
          logger = writer,
          visualize=True,)

  dydt = (dydt / mag)


KeyboardInterrupt: 

In [24]:
datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

'20230426-164445'

In [1]:
# list folders in the directory plots
import os
import numpy as np
import imageio
folders = os.listdir('./plots')
# in each folder, grab list of files

m = {'20230426-165058': 'glucose_id',
     '20230426-165144': 'periodic_id',
     '20230426-204814': 'periodic_interpol',
     '20230426-204841': 'glucose_interpol',
     '20230426-204954': 'glucose_extrapol',}
for folder in folders:
    files = os.listdir(f'./plots/{folder}')
    files = sorted(files, key=lambda x: int(x.split('_')[1].split('.')[0]))
    select1 = np.linspace(0, 200, 60, dtype=int)
    select2 = np.linspace(200, len(files)-1, 60, dtype=int)
    files = [files[i] for i in select1] + [files[i] for i in select2]
    images = []
    for file in files:
        images.append(imageio.imread(f'./plots/{folder}/{file}'))
    imageio.mimsave(f'./plots/{folder}.gif', images, duration=0.1)



  images.append(imageio.imread(f'./plots/{folder}/{file}'))
