In [None]:
import torch
import numpy as np
import pandas as pd
import requests

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger

from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer, Baseline, DeepAR
from pytorch_forecasting.data import GroupNormalizer

from pytorch_forecasting.metrics import PoissonLoss, QuantileLoss, SMAPE
from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters

In [None]:
# Source of the full Our World in Data COVID-19 export (JSON).
URL: str = 'https://covid.ourworldindata.org/data/owid-covid-data.json'

In [None]:
# Download the dataset (a JSON object keyed by location, e.g. 'AFG').
# The timeout keeps the cell from hanging forever on a stalled connection, and
# raise_for_status() reports HTTP errors here instead of as a confusing JSON decode error.
r = requests.get(URL, timeout=60)
r.raise_for_status()
data_json = r.json()

In [None]:
# One row per location; the nested daily records remain in the 'data' column.
df = pd.DataFrame.from_dict(data_json, orient='index')

In [None]:
# Split the nested daily series (raw_ys) from the static per-location covariates (xs).
raw_ys = df['data']
xs = df.drop(columns=['data'])

In [None]:
# (rows, columns) of the static covariate table: one row per location
xs.shape

In [None]:
import matplotlib.pyplot as plt


In [None]:
# Number of missing entries in each covariate column.
nans = xs.isna().sum(axis=0)

In [None]:
# Bar chart of the missing-value count per covariate.
fig, ax = plt.subplots()
ax.bar(nans.index.values, nans.values)
ax.set_title("Missing Values in (Exogenous) Covariates")
ax.tick_params(axis='x', labelrotation=90);

In [None]:
# Number of locations per continent (the groupby is computed once and reused).
continent_counts = xs.groupby('continent').size()
plt.bar(continent_counts.index, continent_counts)
plt.xticks(rotation = 90);

In [None]:
# Plot the distribution of every static covariate except the location id:
# histograms for numeric columns, bar charts of group sizes for the others.
import math

from pandas.api.types import is_numeric_dtype

plot_cols = [col for col in xs.columns if col != 'location']
# Size the square grid from the number of columns. A fixed 4x4 grid makes
# add_subplot raise as soon as there are more than 16 covariates.
n_grid = max(1, math.ceil(math.sqrt(len(plot_cols))))

fig = plt.figure(figsize=(20, 20))
fig.suptitle("Distribution of Covariates")

for i, col in enumerate(plot_cols, start=1):
    ax = fig.add_subplot(n_grid, n_grid, i)
    ax.set_title(col)
    if is_numeric_dtype(xs[col]):
        # many covariates are sparse; drop NaNs explicitly before binning
        ax.hist(xs[col].dropna(), density=True)
    else:
        counts = xs.groupby(col).size()
        ax.bar(counts.index, counts)

fig.tight_layout()
plt.show()

In [None]:
# Flatten the nested per-location daily records into one long frame with a
# 'location' column. All per-location frames are built first and concatenated
# once: DataFrame.append was removed in pandas 2.0, appending in a loop is
# quadratic, and iterating with .items() avoids positional raw_ys[i] lookups on
# a string-indexed Series.
ys = pd.concat(
    [pd.DataFrame(records).assign(location=loc) for loc, records in raw_ys.items()],
    ignore_index=True,
)
# keep 'location' as the first column
ys = ys[['location'] + [col for col in ys.columns if col != 'location']]




In [None]:
# column dtypes and non-null counts of the flattened daily table
ys.info()

In [None]:
# Number of daily rows per reported test unit (the groupby is computed once and reused).
tests_unit_counts = ys.groupby('tests_units').size()
plt.bar(tests_unit_counts.index, tests_unit_counts)
plt.xticks(rotation = 90);

In [None]:
# count of entries exactly equal to 0, per column
df.eq(0).sum()

In [None]:
# compare the static covariate columns (xs) with the daily time-series columns (ys)
xs.columns, ys.columns

In [None]:
# Attach each location's static covariates to every daily row. xs carries its own
# 'location' column, which arrives as 'location_x' and is dropped.
data = (
    ys.join(xs, on='location', rsuffix="_x")
      .drop(columns=['location_x'])
)

In [None]:
# The left join onto ys must neither add nor drop rows; fail loudly if it did.
assert data.shape[0] == ys.shape[0], "join changed the number of rows"
data.shape, ys.shape

In [None]:
# Prepare the data for TimeSeriesDataSet, following
# https://pytorch-forecasting.readthedocs.io/en/latest/tutorials/stallion.html
# time_idx = whole days elapsed since the earliest date in the dataset
data['date'] = pd.to_datetime(data['date'])
t_zero = data['date'].min()
data['time_idx'] = (data['date'] - t_zero).dt.days

# Rows without a continent (presumably world-level aggregates) get 'Global';
# rows without test units get 'NA'. Both columns become categoricals.
data['continent'] = data['continent'].fillna('Global').astype('category')
data['tests_units'] = data['tests_units'].fillna('NA').astype('category')

# TODO: see if we might need additional features
data['month'] = data['date'].dt.month.astype(str).astype('category')


In [None]:
# fixed random_state so the displayed sample is the same on every run
data.sample(10, random_state=42)

In [None]:
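# optional (disabled): keep only European locations that appear in a random sample of 10 rows (i.e. at most 10 countries), for quicker experiments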
# data = data[(data["continent"] == "Europe")
#                 &
#             (data["location"].isin(data["location"].sample(10)))]

In [None]:
# Give every row a fresh, unique 0..n-1 integer index.
data.index = pd.RangeIndex(len(data))

In [None]:
# fraction of rows with a missing population value
data['population'].isna().mean()

In [None]:
max_pred_length = 7  # predict at most one week ahead
max_encoder_length = 60  # use at most ~2 months of history as input
training_cutoff = data['time_idx'].max() - max_pred_length

targets = 'new_cases'

# static real-valued covariates fed to the model
static_reals = ['population', 'population_density', 'median_age', 'aged_65_older', 'aged_70_older',
                'gdp_per_capita', 'cardiovasc_death_rate', 'diabetes_prevalence', 'handwashing_facilities',
                'hospital_beds_per_thousand', 'life_expectancy', 'human_development_index', 'extreme_poverty',
                'female_smokers', 'male_smokers']

# see https://github.com/jdb78/pytorch-forecasting/issues/187#issuecomment-743797144
# Simple imputation: replace NaNs with 0. The result goes into a separate frame so
# that `data` keeps its NaNs for the missing-value/interpolation experiments later on.
data_filled = data.fillna({name: 0.0 for name in static_reals + [targets]})

training = TimeSeriesDataSet(
    data_filled[lambda x: x.time_idx <= training_cutoff],
    time_idx='time_idx',
    target=targets,
    group_ids=['location'],
    min_encoder_length=max_encoder_length // 2,
    max_encoder_length=max_encoder_length,
    min_prediction_length=1,
    max_prediction_length=max_pred_length,
    static_categoricals=['location', 'continent', 'tests_units'],
    static_reals=static_reals,
    time_varying_known_categoricals=['month'],
    time_varying_known_reals=[
        'time_idx',
        # 'stringency_index', 'new_tests': unknown, but could be used for conditional forecasts
    ],
    target_normalizer=GroupNormalizer(groups=['location'], transformation="softplus"),
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
    allow_missings=True,
)

# validation set (predict=True): predict the last max_prediction_length points in time for each series
validation = TimeSeriesDataSet.from_dataset(training, data_filled, predict=True, stop_randomization=True)

# create dataloaders for model
batch_size = 128  # set this between 32 to 128
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=8)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size * 10, num_workers=8)

In [None]:
# Baseline mean absolute error: each prediction is the last value observed in the history.
actuals = torch.cat([target for _, target in val_dataloader])
baseline_predictions = Baseline().predict(val_dataloader)
torch.mean(torch.abs(actuals - baseline_predictions)).item()

In [None]:
# inspect the validation targets collected for the baseline comparison
actuals

In [None]:
# configure network and trainer
pl.seed_everything(42)

# Stop training once the validation loss stops improving. The best checkpoint is
# reloaded afterwards, so it need not be the last epoch.
early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=10, verbose=False, mode="min")
# log the learning rate, which reduce_on_plateau_patience below may lower during training
lr_logger = LearningRateMonitor()

trainer = pl.Trainer(
    # use every visible GPU (0 -> CPU) instead of requiring exactly 8
    gpus=torch.cuda.device_count(),
    # clipping gradients is a hyperparameter and important to prevent divergence
    # of the gradient for recurrent neural networks
    gradient_clip_val=0.1,
    callbacks=[lr_logger, early_stop_callback],
)


tft = TemporalFusionTransformer.from_dataset(
    training,
    # not meaningful for finding the learning rate but otherwise very important
    learning_rate=0.03,
    hidden_size=16,  # most important hyperparameter apart from learning rate
    # number of attention heads. Set to up to 4 for large datasets
    attention_head_size=1,
    dropout=0.1,  # between 0.1 and 0.3 are good values
    hidden_continuous_size=8,  # set to <= hidden_size
    output_size=7,  # 7 quantiles by default
    loss=QuantileLoss(),
    # reduce learning rate if no improvement in validation loss after x epochs
    reduce_on_plateau_patience=4,
)
print(f"Number of parameters in network: {tft.size()/1e3:.1f}k")

In [None]:
# find optimal learning rate
# min_lr starts well below the 0.001-0.1 range searched by the Optuna study further
# down; with min_lr=1e-2 the finder could never suggest a smaller learning rate.
res = trainer.tuner.lr_find(
    tft,
    train_dataloader=train_dataloader,
    val_dataloaders=val_dataloader,
    max_lr=10.0,
    min_lr=1e-6,
)

print(f"suggested learning rate: {res.suggestion()}")
# plot(show=True) already displays the figure; a second fig.show() is redundant
fig = res.plot(show=True, suggest=True)

In [None]:
# Train the TFT, validating on val_dataloader.
trainer.fit(tft, train_dataloader=train_dataloader, val_dataloaders=val_dataloader)

In [None]:
# Hyperparameter search with Optuna; the study object is pickled so tuning can resume later.
import pickle

from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters

# create study
study = optimize_hyperparameters(
    train_dataloader,
    val_dataloader,
    model_path="optuna_test",
    n_trials=200,
    max_epochs=50,
    gradient_clip_val_range=(0.01, 1.0),
    hidden_size_range=(8, 128),
    hidden_continuous_size_range=(8, 128),
    attention_head_size_range=(1, 4),
    learning_rate_range=(0.001, 0.1),
    dropout_range=(0.1, 0.3),
    trainer_kwargs=dict(limit_train_batches=30),
    reduce_on_plateau_patience=4,
    use_learning_rate_finder=False,  # use Optuna to find ideal learning rate or use in-built learning rate finder
)

# save study results - also we can resume tuning at a later point in time
# (only unpickle this file if you trust where it came from: pickle.load can run arbitrary code)
with open("test_study.pkl", "wb") as fout:
    pickle.dump(study, fout)

# show best hyperparameters
print(study.best_trial.params)

In [None]:
# Reload the best model according to the validation loss. If early stopping is
# configured, this is not necessarily the last epoch.
checkpoint_callback = trainer.checkpoint_callback
best_model_path = checkpoint_callback.best_model_path
best_tft = TemporalFusionTransformer.load_from_checkpoint(best_model_path)

In [None]:
# Calculate the mean absolute error on the validation set. .item() returns a plain
# float, matching the baseline MAE cell so the two numbers compare directly.
actuals = torch.cat([y for x, y in iter(val_dataloader)])
predictions = best_tft.predict(val_dataloader)
(actuals - predictions).abs().mean().item()

In [None]:
# Raw mode returns a dictionary from which all kinds of information, including
# quantiles, can be extracted. return_x=True also returns the network inputs that
# plot_prediction needs.
raw_predictions, x = best_tft.predict(
    val_dataloader,
    return_x=True,
    mode="raw",
)

In [None]:
# Plot the first 10 validation examples, each titled with its loss.
n_examples = 10
for idx in range(n_examples):
    best_tft.plot_prediction(x, raw_predictions, idx=idx, add_loss_to_title=True)

In [None]:
# Persist only the learned weights (state dict), not the module object.
best_state = best_tft.state_dict()
torch.save(best_state, "20201213_tft_basemodel.pkl")

# 2nd Try - More Sophisticated Data Cleaning & Value Imputation

In [None]:
# Distribution of series length: number of distinct dates per location.
# nunique is computed on 'date' only, not on every column of the frame.
fig, ax = plt.subplots()
ax.hist(data.groupby('location')['date'].nunique())
ax.set_title("Distinct dates per location")
ax.set_xlabel("distinct dates")
ax.set_ylabel("locations");

In [None]:
# Number of distinct locations with a row on each date. The groupby runs once, on
# 'location' only, instead of twice over every column.
locations_per_date = data.groupby('date')['location'].nunique()
fig, ax = plt.subplots()
ax.bar(locations_per_date.index, locations_per_date)
ax.set_title("Locations with data per date")
ax.set_xlabel("date")
ax.set_ylabel("locations");

Given the high number of missing values, it may be worth retrying model fitting without the sparsest features (e.g. `handwashing_facilities`).
See [this article](https://www.wandb.com/articles/pytorch-lightning-with-weights-biases) for how to log PyTorch Lightning runs to Weights & Biases (wandb).

In [None]:
targets = 'new_cases'

# Candidate static real-valued covariates, one name per line.
reals = [
    'population',
    'population_density',
    'median_age',
    'aged_65_older',
    'aged_70_older',
    'gdp_per_capita',
    'cardiovasc_death_rate',
    'diabetes_prevalence',
    'handwashing_facilities',
    'hospital_beds_per_thousand',
    'life_expectancy',
    'human_development_index',
    'extreme_poverty',
    'female_smokers',
    'male_smokers',
]

In [None]:
# Drop real covariates with more than 50% NaNs; the remaining ones are interpolated below.
# NOTE: if `data` was zero-filled in place earlier, no column exceeds the threshold.
data2 = data.copy()
missing_share = data2[reals].isna().mean()
high_missing = missing_share[missing_share > .5].index
print("removing: ", high_missing)
# slice assignment keeps the same list object, pruned in place
reals[:] = [col for col in reals if col not in high_missing]

In [None]:
# real covariates left after dropping the sparse ones
reals

In [None]:
# Interpolate the remaining reals linearly, separately for each country.
# - Values are interpolated down the rows of each country (its time order), not with
#   axis=1, which would interpolate across different features of the same row.
# - The result is written back with `data2[reals] = ...`. Chained indexing such as
#   `data2[mask][reals] = ...` only writes into a temporary copy and leaves data2 unchanged.
data2[reals] = data2.groupby('location')[reals].transform(
    lambda country_rows: country_rows.interpolate(method='linear')
)

data2[reals].isna().mean()

In [None]:
# Alternative: Akima interpolation per country, filling in both directions.
# As above, interpolate along each country's rows (not axis=1 across features) and
# assign via `data3[reals] = ...`. Chained indexing would only modify a temporary copy.
# Akima uses the row index as x-coordinates and requires scipy.
data3 = data.copy()

data3[reals] = data3.groupby('location')[reals].transform(
    lambda country_rows: country_rows.interpolate(method='akima', limit_direction='both')
)

data3[reals].isna().mean()

In [None]:
# interpolated covariates for a single location (Afghanistan, 'AFG')
data3.loc[data3['location'] == 'AFG', reals]