In [None]:
!pip uninstall pytorch-lightning plotly -qy
!pip install pytorch-lightning pytorch_forecasting==0.9.0 plotly kaleido -q

import pytorch_forecasting
import os
import warnings
import plotly.graph_objects as go

# Silence every Python warning for the rest of the session. This also hides
# deprecation notices from the libraries imported below, so re-enable warnings
# when debugging unexpected behaviour.
warnings.filterwarnings("ignore")

import matplotlib
import matplotlib.pyplot as plt
#plt.rcParams["figure.figsize"] = (16,10)

from pathlib import Path
import pandas as pd
import numpy as np
import torch
import copy
import pickle
from tqdm.auto import tqdm
import itertools
import datetime
import re
import random
import collections
import math
import statistics
import gc
import ast

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger

from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer, Baseline
from pytorch_forecasting.data import GroupNormalizer, EncoderNormalizer
from pytorch_forecasting.metrics import PoissonLoss, QuantileLoss, SMAPE, CrossEntropy, RMSE, BetaDistributionLoss, MAPE
from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters
from pytorch_forecasting.data.encoders import NaNLabelEncoder

from optuna.visualization import plot_intermediate_values
from optuna.visualization import plot_optimization_history

# Fix the global random seed so that runs are repeatable. Per the Lightning
# docs this seeds Python's `random`, NumPy and PyTorch, so the `random.sample`
# call used later for the validation split depends on this line running first.
pl.seed_everything(42)


In [None]:
# Location of the input CSV.
# NOTE(review): this is an empty placeholder; it must point at the real file
# before the notebook can run from a fresh kernel.
path_df = ''

# Rows without a cycle or cycle day cannot be used and are dropped on load.
df = pd.read_csv(path_df).dropna(subset=['cycle', 'cycle_day'])

# Columns stored as string-valued categoricals. The raw labels are kept as-is
# (no remapping of 'sex' to '0'/'1').
for string_categorical in ('sex', 'cycle', 'age', 'gcsf', 'transfusion'):
    df[string_categorical] = df[string_categorical].map(str).astype('category')

# 'danger' becomes categorical without going through str() first.
df['danger'] = df['danger'].astype('category')

# Integer-valued columns: the day within the cycle and the time index.
df['cycle_day'] = df['cycle_day'].map(int).astype('int')
df['timestamp'] = df['timestamp'].astype(int)

# Continuous measurements.
for real_column in ('cd34', 'pbscc'):
    df[real_column] = df[real_column].astype(float)

In [None]:

# Column to forecast, plus the normalizer and loss objects used for modelling.
# NOTE(review): `encoder` is not passed to any dataset in the visible cells.
target = 'ANC'
encoder = EncoderNormalizer()
loss = SMAPE()

# Longest history available for any single patient (number of rows per `no`).
# Kept as a plain Python int because TimeSeriesDataSet asserts the encoder
# lengths are `int` instances (a NumPy integer would be rejected).
rows_per_patient = df.groupby('no').size()
max_encoder_length = int(rows_per_patient.max()) if len(rows_per_patient) else 0
print(max_encoder_length)

# `min_encoder_length` is consumed when the dataset is built but was never
# assigned, so a fresh kernel failed with a NameError.
# NOTE(review): the originally intended value is unknown; half of the maximum
# mirrors the pytorch-forecasting tutorial -- confirm against the experiment.
min_encoder_length = max_encoder_length // 2

# Patient ids 601-702 are reserved as the never-trained-on ("unseen") cohort.
UNSEEN_ID_RANGE = range(601, 703)
all_patient_ids = df['no'].unique().tolist()
seen = [pid for pid in all_patient_ids if int(pid) not in UNSEEN_ID_RANGE]

# Internal validation hold-out: 50 patients drawn from the *seen* cohort only.
# Sampling from every id (as before) could pick unseen-cohort patients, which
# silently shrank the hold-out actually removed from the training set.
# Reproducible because the global seed is fixed in the first cell.
internal_val = random.sample(seen, 50)

# Forecast horizon in time steps.
max_prediction_length = 30

In [None]:
# Template dataset built on the full frame. The train / validation / unseen
# datasets below are derived from it with `from_dataset`, so they all share the
# same encoders and scaling parameters.
# `min_encoder_length`, `max_encoder_length` and `max_prediction_length` must
# be defined by an earlier cell.
total_data = TimeSeriesDataSet(
    df,
    time_idx="timestamp",
    target=target,
    group_ids=["no"],
    min_encoder_length=min_encoder_length,
    max_encoder_length=max_encoder_length,
    max_prediction_length=max_prediction_length,
    static_categoricals=["sex"],
    time_varying_known_categoricals=['cycle'],
    time_varying_known_reals=['timestamp', 'cycle_day', 'cd34', 'pbscc'],
    time_varying_unknown_categoricals=['gcsf', 'transfusion'],
    # add_nan=True lets the encoders map categories not seen during fitting.
    categorical_encoders={'cycle': NaNLabelEncoder(add_nan=True), 'no': NaNLabelEncoder(add_nan=True)},
    # Every column from position 5 onwards except the four handled above.
    # NOTE(review): this depends on the CSV column order -- verify the slice
    # does not pick up categorical columns or the known reals listed above.
    time_varying_unknown_reals=[
        column for column in df.columns[5:]
        if column not in ['gcsf', 'transfusion', 'cd34', 'pbscc']
    ],
    add_target_scales=True,
    add_encoder_length=True,
    allow_missing_timesteps=True,
)

# Training patients: the seen cohort minus the internal validation hold-out.
# A set makes the membership test O(1) per patient.
held_out_ids = set(internal_val)
train_ids = [pid for pid in seen if pid not in held_out_ids]
training = TimeSeriesDataSet.from_dataset(total_data, df[df['no'].isin(train_ids)])

# Internal validation set. The original referenced an undefined `for_test`;
# `internal_val` is the hold-out list excluded from training above.
# predict=True yields one prediction window per patient (at the end of the
# series) and stop_randomization=True disables random encoder/decoder lengths.
validation = TimeSeriesDataSet.from_dataset(
    total_data,
    df[df['no'].isin(internal_val)],
    predict=True,
    stop_randomization=True,
)

# External test set: every patient that is not part of the seen cohort.
unseen = TimeSeriesDataSet.from_dataset(
    total_data,
    df[~df['no'].isin(seen)],
    predict=True,
    stop_randomization=True,
)

train_dataloader = training.to_dataloader(train=True, batch_size=128, pin_memory=True)
val_dataloader = validation.to_dataloader(train=False, batch_size=256, pin_memory=True)
unseen_dataloader = unseen.to_dataloader(train=False, batch_size=256, pin_memory=True)

# Number of batches per loader, as a quick sanity check of the split.
print(len(train_dataloader), len(val_dataloader), len(unseen_dataloader))

In [None]:
# Point TensorFlow's gfile at TensorBoard's stub implementation.
# NOTE(review): presumably the usual workaround for the TensorBoard/TensorFlow
# gfile incompatibility hit while logging -- confirm it is still required.
import tensorflow as tf
import tensorboard as tb
tf.io.gfile = tb.compat.tensorflow_stub.io.gfile

# Run the Optuna hyperparameter search for the Temporal Fusion Transformer.
# Each trial's checkpoints are written to `<model_path>/trial_<number>`; the
# model-loading cell reads from `./optuna_test/trial_<number>`, so
# `model_path` has to be "optuna_test" (it was "" before, which put the
# checkpoints in a directory the loading cell never looked at).
# NOTE(review): the library also applies a wall-clock timeout to the study by
# default, which may end the search before `n_trials` is reached -- confirm.
study = optimize_hyperparameters(
    train_dataloader,
    val_dataloader,
    model_path="optuna_test",
    n_trials=100,
    max_epochs=200,
    loss=loss,
    gradient_clip_val_range=(0.01, 1.0),
    hidden_size_range=(8, 128),
    hidden_continuous_size_range=(8, 64),
    attention_head_size_range=(1, 8),
    learning_rate_range=(0.001, 0.1),
    dropout_range=(0.1, 0.3),
    # Cap every training epoch at 30 batches to keep trials short.
    trainer_kwargs=dict(limit_train_batches=30),
    reduce_on_plateau_patience=3,
    use_learning_rate_finder=False,
)

# Save the study so the results can be inspected (or tuning resumed) later.
with open("test_study.pkl", "wb") as fout:
    pickle.dump(study, fout)

# Show the best hyperparameters found.
print(study.best_trial.params)

In [None]:
# Optuna figure of the objective value across the trials of the study.
plot_optimization_history(study)

In [None]:
# Optuna figure of the intermediate values reported by each trial during training.
plot_intermediate_values(study)

In [None]:
# Directory holding the checkpoint(s) of the best Optuna trial.
# NOTE(review): this must match the `model_path` that was passed to
# `optimize_hyperparameters` when the study was run.
best_trial_dir = f'./optuna_test/trial_{study.best_trial.number}'

# Only consider Lightning checkpoint files, and sort them so the choice does
# not depend on the arbitrary order returned by os.listdir.
checkpoint_files = sorted(
    file_name for file_name in os.listdir(best_trial_dir) if file_name.endswith('.ckpt')
)
if not checkpoint_files:
    raise FileNotFoundError(f'no .ckpt checkpoint found in {best_trial_dir}')

best_model_path = os.path.join(best_trial_dir, checkpoint_files[0])
best_tft = TemporalFusionTransformer.load_from_checkpoint(best_model_path)

# Persist the loaded model; the context manager guarantees the file is flushed
# and closed (the bare open() used before leaked the handle).
with open('best_tft.p', 'wb') as fout:
    pickle.dump(best_tft, fout)

In [None]:
# Predict on the internal validation loader. mode='raw' returns the network's
# full raw output (needed by the interpretation cell) and return_x=True also
# returns the model inputs that produced it.
raw_predictions, x_data = best_tft.predict(val_dataloader, mode='raw', return_x=True)

# Store both objects for later analysis; `with` closes each file handle
# deterministically instead of leaving it to the garbage collector.
with open("x_data.p", "wb") as fout:
    pickle.dump(x_data, fout)
with open("raw_predictions.p", "wb") as fout:
    pickle.dump(raw_predictions, fout)

In [None]:
# Aggregate the interpretation outputs of the raw validation predictions,
# summed over all samples (reduction="sum").
interpretation = best_tft.interpret_output(raw_predictions, reduction="sum")

# Save the interpretation; the context manager closes the file handle.
with open("interpretation.p", "wb") as fout:
    pickle.dump(interpretation, fout)

# Render the interpretation figures.
best_tft.plot_interpretation(interpretation)

In [None]:
## unseen data
# Same raw-mode prediction as for the validation loader, but on the patients
# that were excluded from the seen cohort.
new_raw_predictions, new_x = best_tft.predict(unseen_dataloader, mode='raw', return_x=True)

# Store predictions and inputs; `with` closes each file handle deterministically.
with open("new_raw_predictions.p", "wb") as fout:
    pickle.dump(new_raw_predictions, fout)
with open("new_x.p", "wb") as fout:
    pickle.dump(new_x, fout)