In [None]:
# @title Imports
# General imports.
import datetime
import os
import numpy as np
import matplotlib.pyplot as plt

# Library imports.
from dm_c19_modelling.evaluation import constants
from dm_c19_modelling.modelling.training import checkpointing
from dm_c19_modelling.modelling.training import dataset_factory
from dm_c19_modelling.modelling.training import model_factory

In [None]:
# @title Util functions.

def restore_state(base_checkpoint_path, checkpoint_name):
  """Restores and returns the experiment state stored in a named checkpoint.

  Args:
    base_checkpoint_path: Path under which checkpoints are stored; it is
      joined with `checkpoint_name` to form the restore path.
    checkpoint_name: Name of the checkpoint to load, e.g. "latest_fine_tune".

  Returns:
    The experiment-state object obtained from the checkpointer, returned
    after `restore` has been called on the checkpointer.
  """
  # NOTE(review): `directory` is presumably where the checkpointer would write
  # new checkpoints; this helper only restores, so a scratch dir is used.
  ckpt = checkpointing.Checkpointer(
      directory="/tmp/",
      max_to_keep=1,
      restore_path=os.path.join(base_checkpoint_path, checkpoint_name))
  # Keep this order: the state handle is fetched before `restore` runs
  # (presumably `restore` fills that object in place — confirm).
  experiment_state = ckpt.get_experiment_state("")
  ckpt.restore("")
  return experiment_state

def date_str_to_datetime_sequence(date_str_list):
  """Parses date strings into a numpy array of `datetime.datetime` objects.

  Args:
    date_str_list: Iterable of date strings in `constants.DATE_FORMAT`.

  Returns:
    A numpy array holding one parsed `datetime.datetime` per input string,
    in input order.
  """
  parsed_dates = []
  for text in date_str_list:
    parsed_dates.append(
        datetime.datetime.strptime(text, constants.DATE_FORMAT))
  return np.array(parsed_dates)

def _set_site_title(ax, dataset, site_id):
  """Titles a panel with the site index and the site's name."""
  ax.set_title(f"Site {site_id} ({dataset.sites[site_id]})")


def _add_grid_legend(axes):
  """Adds a single legend for a grid of site panels.

  The legend is placed to the right of the top-right panel. Its entries are
  taken from the first panel, which is always the first one populated. The
  top-right panel's own entries are not used because that panel is blank
  when fewer sites than panels are plotted.

  Args:
    axes: 2-D array of matplotlib axes, as returned by `plt.subplots`.
  """
  handles, labels = axes.flat[0].get_legend_handles_labels()
  axes[0, -1].legend(handles, labels,
                     bbox_to_anchor=[1., 0.5], loc="center left")


def _plot_seir_diagnostics(dataset, site_id_order, forecast_aux_dict,
                           input_and_evaluation_dates):
  """Plots the SEIR ODE state and parameter sequences, one panel per site.

  Args:
    dataset: Dataset providing the site names.
    site_id_order: Site indices in plotting order; only as many as fit in
      the 4x5 grid are drawn.
    forecast_aux_dict: Forecast auxiliary outputs. Must contain
      "full_seir_state_sequence" and "full_seir_params_sequence". Each is
      presumably a namedtuple of arrays indexed [date, site] — confirm.
    input_and_evaluation_dates: Dates for the full state sequence. The
      parameter sequence uses all but the last of these.
  """
  panels = [
      # ODE state as a function of time.
      ("ODE State as function of time",
       "Fraction of population",
       forecast_aux_dict["full_seir_state_sequence"],
       input_and_evaluation_dates),
      # ODE parameters as a function of time.
      ("ODE Parameters as function of time",
       "Parameter value",
       forecast_aux_dict["full_seir_params_sequence"],
       input_and_evaluation_dates[:-1],  # No parameter for last date.
       ),
  ]
  for title, ylabel, aux_data, dates in panels:
    fig, axes = plt.subplots(4, 5, figsize=(18, 10), sharex=True, sharey=True)
    for site_id, ax in zip(site_id_order, axes.flatten()):
      _set_site_title(ax, dataset, site_id)
      for name, values in aux_data._asdict().items():
        ax.plot(dates, values[:, site_id], label=name)
      ax.set_yscale("log")
      ax.tick_params(axis="x", rotation=45)

    _add_grid_legend(axes)
    for ax in axes[:, 0]:
      ax.set_ylabel(ylabel)
    fig.subplots_adjust(wspace=0.1, hspace=0.3)
    fig.suptitle(title)


def plot(dataset,
         forecast, forecast_aux_dict,
         validation_forecast, validation_forecast_aux_dict,
         model=None):
  """Plots forecasts against observed targets, one panel per site.

  Draws a 4x5 grid for the (up to 20) sites with the largest accumulated
  target. Each panel shows:
    * the observed targets over the input dates,
    * the forecast over the evaluation dates,
    * the validation forecast over the last input dates.
  For SEIRLSTM models, two further grids show the ODE state and parameter
  sequences from `forecast_aux_dict`.

  Args:
    dataset: Dataset the forecasts were computed for. Must have exactly one
      target name.
    forecast: Forecast for the evaluation dates, indexed [date, site, target].
    forecast_aux_dict: Auxiliary outputs of the evaluation-date forecast.
    validation_forecast: Forecast for the validation dates (the last
      `len(dataset.evaluation_dates)` input dates), indexed
      [date, site, target].
    validation_forecast_aux_dict: Auxiliary outputs of the validation
      forecast. If it contains "predictions_for_inputs", those are plotted
      too.
    model: Optional model that produced the forecasts. If given, the SEIR
      diagnostics are drawn iff it is a SEIRLSTM. If omitted, they are drawn
      iff `forecast_aux_dict` contains both SEIR sequences.

  Raises:
    ValueError: If the dataset does not have exactly one target.
  """
  if len(dataset.target_names) != 1:
    raise ValueError(
        "This function is designed to plot models with a single target.")

  # Plot for sites in decreasing order of accumulated incidence.
  site_id_order = np.argsort(dataset.targets[..., 0].sum(0))[::-1]

  # Get the dates for the input, validation, and evaluation ranges. The
  # validation dates are the last `len(evaluation_dates)` input dates. An
  # explicit split index is used because negative slicing with a count of
  # zero (`x[:-0]`, `x[-0:]`) would select the wrong ranges.
  evaluation_dates = date_str_to_datetime_sequence(dataset.evaluation_dates)
  input_dates = date_str_to_datetime_sequence(dataset.dates)
  validation_start = max(len(input_dates) - len(evaluation_dates), 0)
  input_dates_without_validation = input_dates[:validation_start]
  validation_dates = input_dates[validation_start:]
  input_and_evaluation_dates = np.concatenate(
      [input_dates, evaluation_dates], axis=0)

  # Some models produce predictions for some of the last input dates, which
  # can be useful for forecasting. They are aligned with the end of the input
  # dates that precede the validation range.
  if "predictions_for_inputs" in validation_forecast_aux_dict:
    predictions_for_inputs = validation_forecast_aux_dict[
        "predictions_for_inputs"]
    fit_start = max(
        len(input_dates_without_validation) - predictions_for_inputs.shape[0],
        0)
    fit_dates = input_dates_without_validation[fit_start:]
  else:
    predictions_for_inputs = None

  # Plot targets.
  fig, axes = plt.subplots(4, 5, figsize=(18, 10), sharex=True)
  for site_id, ax in zip(site_id_order, axes.flatten()):
    _set_site_title(ax, dataset, site_id)

    # Plot the targets for the input dates.
    ax.plot(input_dates, dataset.targets[:, site_id, 0],
            linewidth=4, color="grey", alpha=0.5,
            label="Observed ground truth")

    # Plot the model forecast.
    ax.plot(evaluation_dates, forecast[:, site_id, 0], color="green",
            label="Forecast")

    # Plot the model forecast for the validation dates.
    ax.plot(validation_dates, validation_forecast[:, site_id, 0], color="red",
            label="Validation forecast")

    if predictions_for_inputs is not None:
      ax.plot(fit_dates, predictions_for_inputs[:, site_id, 0],
              color="orange", label="Fit to last input dates")
    ax.tick_params(axis="x", rotation=45)

  _add_grid_legend(axes)
  for ax in axes[:, 0]:
    ax.set_ylabel(dataset.target_names[0])
  fig.subplots_adjust(wspace=0.4, hspace=0.3)
  fig.suptitle("Target prediction")

  # Decide from the model passed in, not from a notebook-global `model`.
  # Reading the global made this function depend on hidden kernel state, and
  # it raised NameError when that global was absent.
  if model is not None:
    plot_seir = model.__class__.__name__ == "SEIRLSTM"
  else:
    plot_seir = ("full_seir_state_sequence" in forecast_aux_dict and
                 "full_seir_params_sequence" in forecast_aux_dict)
  if plot_seir:
    _plot_seir_diagnostics(dataset, site_id_order, forecast_aux_dict,
                           input_and_evaluation_dates)


In [None]:
# Choose which checkpoint to inspect. `checkpoint_name` is one of:
# "latest_eval": Model at the end of the first phase of training (usually
#   overfits).
# "best_eval": Early stopping using validation dates.
# "latest_fine_tune": "best_eval" fine-tuned on validation dates.
# NOTE(review): the base path looks like a placeholder; point it at the
# directory the training run wrote its checkpoints to.
base_checkpoint_path = "/tmp/checkpoint_path/"
checkpoint_name = "latest_fine_tune"

# Load the state of the experiment.
state = restore_state(base_checkpoint_path, checkpoint_name)

# Rebuild the dataset from the factory kwargs stored in the checkpoint's
# build info.
dataset = dataset_factory.get_training_dataset(
    **state.build_info["dataset_factory_kwargs"])

# Rebuild the model from the dataset spec and model factory kwargs stored in
# the checkpoint's build info.
model = model_factory.get_model(state.build_info["dataset_spec"],
                                **state.build_info["model_factory_kwargs"])

# Make a forecast prediction for dates after the "last_observation_date" of
# the dataset. This is the forecast that would be submitted to the
# forecast index for independent evaluation of the model, and as such we do not
# have access to the targets during the model design phase.
forecast, forecast_aux_dict = model.evaluate(
    state.model_state, dataset)

# Get another version of the dataset with the validation dates removed from
# the inputs. This lets us visualize the performance of our model on
# validation dates, for which we do have targets during the model design
# phase (i.e. dates before the "last_observation_date"). The second return
# value of `remove_validation_dates` is not needed here.
(dataset_without_validation_dates, _
 ) = dataset_factory.remove_validation_dates(dataset)
validation_forecast, validation_forecast_aux_dict = model.evaluate(
    state.model_state, dataset_without_validation_dates)

plot(dataset,
     forecast, forecast_aux_dict,
     validation_forecast, validation_forecast_aux_dict)