In [1]:
# Experiment 2 Script for COVID-19 Forecasting
# Change path to the root directory of the project
import os
os.chdir("../../")

# Description: This script contains the code for the second experiment in the project, 
# forecasting COVID-19 MVBeds using various RNN models and hyperparameter tuning with Simulated Annealing.

# Imports for handling data
import shutil
import numpy as np
import pandas as pd
from pathlib import Path
from itertools import cycle

# Imports for machine learning
import torch
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping

from sklearn.metrics import mean_absolute_error as mae, mean_squared_error as mse

# Imports for visualization
import plotly.express as px
import plotly.graph_objects as go

# Progress bar
from tqdm.autonotebook import tqdm
tqdm.pandas()

# Local imports for data loaders and models
from src.utils import plotting_utils
from src.dl.dataloaders import TimeSeriesDataModule
from src.dl.multivariate_models import SingleStepRNNConfig, SingleStepRNNModel, Seq2SeqConfig, Seq2SeqModel, RNNConfig
from src.transforms.target_transformations import AutoStationaryTransformer

# Set seeds for reproducibility
pl.seed_everything(42)
torch.manual_seed(42)
np.random.seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

torch.set_float32_matmul_precision('high')

# Set default plotly template
import plotly.io as pio
pio.templates.default = "plotly_white"

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")
import logging

# Set logging configuration
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


Global seed set to 42


In [2]:
# Utility Functions

def format_plot(fig, legends=None, xlabel="Time", ylabel="Value", title="", font_size=15):
    if legends:
        names = cycle(legends)
        fig.for_each_trace(lambda t: t.update(name=next(names)))
    fig.update_layout(
        autosize=False,
        width=900,
        height=500,
        title_text=title,
        title={"x": 0.5, "xanchor": "center", "yanchor": "top"},
        titlefont={"size": 20},
        legend_title=None,
        legend=dict(
            font=dict(size=font_size),
            orientation="h",
            yanchor="bottom",
            y=0.98,
            xanchor="right",
            x=1,
        ),
        yaxis=dict(
            title_text=ylabel,
            titlefont=dict(size=font_size),
            tickfont=dict(size=font_size),
        ),
        xaxis=dict(
            title_text=xlabel,
            titlefont=dict(size=font_size),
            tickfont=dict(size=font_size),
        ),
    )
    return fig

def mase(actual, predicted, insample_actual):
    mae_insample = np.mean(np.abs(np.diff(insample_actual)))
    mae_outsample = np.mean(np.abs(actual - predicted))
    return mae_outsample / mae_insample

def forecast_bias(actual, predicted):
    return np.mean(predicted - actual)

def plot_forecast(pred_df, forecast_columns, selected_area, forecast_display_names=None, save_path=None):
    if forecast_display_names is None:
        forecast_display_names = forecast_columns
    else:
        assert len(forecast_columns) == len(forecast_display_names)

    mask = ~pred_df[forecast_columns[0]].isnull()
    colors = px.colors.qualitative.D3  # Use a colorblind-friendly palette
    act_color = colors[0]
    colors = cycle(colors[1:])

    fig = go.Figure()

    # Actual data plot
    fig.add_trace(
        go.Scatter(
            x=pred_df[mask].index,
            y=pred_df[mask].covidOccupiedMVBeds,
            mode="lines+markers",
            marker=dict(size=8, opacity=0.7, symbol='circle'),
            line=dict(color=act_color, width=3),
            name="Actual COVID-19 MVBeds trends",
        )
    )

    # Predicted data plot
    line_styles = ["solid", "dash", "dot", "dashdot"]
    markers = ['circle', 'square', 'diamond', 'cross', 'x', 'triangle-up']
    for col, display_col, line_style, marker in zip(forecast_columns, forecast_display_names, cycle(line_styles), cycle(markers)):
        fig.add_trace(
            go.Scatter(
                x=pred_df[mask].index,
                y=pred_df.loc[mask, col],
                mode="lines+markers",
                marker=dict(size=6, symbol=marker),
                line=dict(color=next(colors), width=2, dash=line_style),
                name=display_col,
            )
        )

    fig.update_layout(
        title=f"COVID-19 MVBeds Forecast Comparison for {selected_area}",
        title_font=dict(size=20),
        xaxis_title="Date",
        yaxis_title="COVID-19 MVBeds",
        xaxis=dict(title_font=dict(size=15), tickfont=dict(size=12)),
        yaxis=dict(title_font=dict(size=15), tickfont=dict(size=12)),
        legend=dict(
            font=dict(size=10),
            orientation="h",
            yanchor="top",
            y=1.02,  # Move the legend down to separate it from the title
            xanchor="center",
            x=0.5
        ),
        template="plotly_white",
        plot_bgcolor="rgba(0,0,0,0)",
        margin=dict(l=10, r=40, t=80, b=40),  # Adjust left margin for better alignment
        width=1200,
        height=600
    )

    if save_path:
        pio.write_image(fig, save_path)
    return fig


def highlight_abs_min(s, props=""):
    return np.where(s == np.nanmin(np.abs(s.values)), props, "")


In [3]:
# Load and Prepare Data
data_path = Path("data/hos_data/merged_data.csv")
data = pd.read_csv(data_path).drop("Unnamed: 0", axis=1)
data["date"] = pd.to_datetime(data["date"])

# Select and Process Data
selected_area = "Midlands"
data_filtered = data[data["areaName"] == selected_area]

# Data Processing
data_filtered["date"] = pd.to_datetime(data_filtered["date"])
data_filtered.sort_values(by=["date", "areaName"], inplace=True)
data_filtered.drop(
    [
        "areaName",
        "cumAdmissions",
        "cumulative_confirmed",
        "cumulative_deceased",
        "population",
        "latitude",
        "longitude",
        "epi_week",
    ],
    axis=1,
    inplace=True,
)
data_filtered.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 888 entries, 4439 to 3552
Data columns (total 6 columns):
 #   Column               Non-Null Count  Dtype         
---  ------               --------------  -----         
 0   date                 888 non-null    datetime64[ns]
 1   covidOccupiedMVBeds  888 non-null    float64       
 2   hospitalCases        888 non-null    float64       
 3   newAdmissions        888 non-null    int64         
 4   new_confirmed        888 non-null    float64       
 5   new_deceased         888 non-null    float64       
dtypes: datetime64[ns](1), float64(4), int64(1)
memory usage: 48.6 KB


In [4]:
# Add rolling features
def add_rolling_features(df, window_size, columns, agg_funcs=None):
    if agg_funcs is None:
        agg_funcs = ["mean"]
    added_features = {}
    for column in columns:
        for func in agg_funcs:
            roll_col_name = f"{column}_rolling_{window_size}_{func}"
            df[roll_col_name] = df[column].rolling(window_size).agg(func)
            if column not in added_features:
                added_features[column] = []
            added_features[column].append(roll_col_name)
    df.dropna(inplace=True)
    return df, added_features

window_size = 7
columns_to_roll = ["hospitalCases", "newAdmissions", "new_confirmed", "new_deceased"]
agg_funcs = ["mean", "std"]

data_filtered, added_features = add_rolling_features(data_filtered, window_size, columns_to_roll, agg_funcs)

for column, features in added_features.items():
    logging.info(f"{column}: {', '.join(features)}")

2024-05-22 14:22:44,837 - INFO - hospitalCases: hospitalCases_rolling_7_mean, hospitalCases_rolling_7_std
2024-05-22 14:22:44,838 - INFO - newAdmissions: newAdmissions_rolling_7_mean, newAdmissions_rolling_7_std
2024-05-22 14:22:44,839 - INFO - new_confirmed: new_confirmed_rolling_7_mean, new_confirmed_rolling_7_std
2024-05-22 14:22:44,840 - INFO - new_deceased: new_deceased_rolling_7_mean, new_deceased_rolling_7_std


In [5]:
# Add time-lagged features
def add_lags(data, lags, features):
    added_features = []
    for feature in features:
        for lag in lags:
            new_feature = feature + f"_lag_{lag}"
            data[new_feature] = data[feature].shift(lag)
            added_features.append(new_feature)
    return data, added_features

lags = [1, 2, 3, 5, 7, 14, 21]
data_filtered, added_features = add_lags(data_filtered, lags, ["covidOccupiedMVBeds"])
data_filtered.dropna(inplace=True)

# Create temporal features
def create_temporal_features(df, date_column):
    df["month"] = df[date_column].dt.month
    df["day"] = df[date_column].dt.day
    df["day_of_week"] = df[date_column].dt.dayofweek
    return df

data_filtered = create_temporal_features(data_filtered, "date")
data_filtered = data_filtered.set_index("date")

In [6]:
# Load and process the SEIRD data
seird_data = pd.read_csv(f"reports/output/pinn_{selected_area}_output.csv")
seird_data["date"] = pd.to_datetime(seird_data["date"])
seird_data.set_index("date", inplace=True)

# Merge the two dataframes on the date index
merged_data = pd.merge(data_filtered, seird_data, left_index=True, right_index=True, how="inner")

# Drop rows with any missing values
merged_data.dropna(inplace=True)
merged_data.info()

<class 'pandas.core.frame.DataFrame'>
DatetimeIndex: 1584 entries, 2020-05-01 to 2021-05-31
Data columns (total 28 columns):
 #   Column                        Non-Null Count  Dtype  
---  ------                        --------------  -----  
 0   covidOccupiedMVBeds           1584 non-null   float64
 1   hospitalCases                 1584 non-null   float64
 2   newAdmissions                 1584 non-null   int64  
 3   new_confirmed                 1584 non-null   float64
 4   new_deceased                  1584 non-null   float64
 5   hospitalCases_rolling_7_mean  1584 non-null   float64
 6   hospitalCases_rolling_7_std   1584 non-null   float64
 7   newAdmissions_rolling_7_mean  1584 non-null   float64
 8   newAdmissions_rolling_7_std   1584 non-null   float64
 9   new_confirmed_rolling_7_mean  1584 non-null   float64
 10  new_confirmed_rolling_7_std   1584 non-null   float64
 11  new_deceased_rolling_7_mean   1584 non-null   float64
 12  new_deceased_rolling_7_std    1584 non-null 

In [6]:
# Load and process the SEIRD data
seird_data = pd.read_csv(f"reports/output/pinn_{selected_area}_output.csv")
seird_data["date"] = pd.to_datetime(seird_data["date"])
seird_data.set_index("date", inplace=True)

# Merge the two dataframes on the date index
merged_data = pd.merge(data_filtered, seird_data, left_index=True, right_index=True, how="inner")

# Drop rows with any missing values
merged_data.dropna(inplace=True)
merged_data.info()

<class 'pandas.core.frame.DataFrame'>
DatetimeIndex: 396 entries, 2020-05-01 to 2021-05-31
Data columns (total 28 columns):
 #   Column                        Non-Null Count  Dtype  
---  ------                        --------------  -----  
 0   covidOccupiedMVBeds           396 non-null    float64
 1   hospitalCases                 396 non-null    float64
 2   newAdmissions                 396 non-null    int64  
 3   new_confirmed                 396 non-null    float64
 4   new_deceased                  396 non-null    float64
 5   hospitalCases_rolling_7_mean  396 non-null    float64
 6   hospitalCases_rolling_7_std   396 non-null    float64
 7   newAdmissions_rolling_7_mean  396 non-null    float64
 8   newAdmissions_rolling_7_std   396 non-null    float64
 9   new_confirmed_rolling_7_mean  396 non-null    float64
 10  new_confirmed_rolling_7_std   396 non-null    float64
 11  new_deceased_rolling_7_mean   396 non-null    float64
 12  new_deceased_rolling_7_std    396 non-null   

In [7]:
# Set the target variable and make it stationary
target = "covidOccupiedMVBeds"
seasonal_period = 7
auto_stationary = AutoStationaryTransformer(seasonal_period=seasonal_period)
data_stat = auto_stationary.fit_transform(merged_data[[target]], freq="D")
merged_data[target] = data_stat.values

# Ensure the index is a DateTimeIndex
merged_data.index = pd.to_datetime(merged_data.index)

In [8]:
merged_data.columns

Index(['covidOccupiedMVBeds', 'hospitalCases', 'newAdmissions',
       'new_confirmed', 'new_deceased', 'hospitalCases_rolling_7_mean',
       'hospitalCases_rolling_7_std', 'newAdmissions_rolling_7_mean',
       'newAdmissions_rolling_7_std', 'new_confirmed_rolling_7_mean',
       'new_confirmed_rolling_7_std', 'new_deceased_rolling_7_mean',
       'new_deceased_rolling_7_std', 'covidOccupiedMVBeds_lag_1',
       'covidOccupiedMVBeds_lag_2', 'covidOccupiedMVBeds_lag_3',
       'covidOccupiedMVBeds_lag_5', 'covidOccupiedMVBeds_lag_7',
       'covidOccupiedMVBeds_lag_14', 'covidOccupiedMVBeds_lag_21', 'month',
       'day', 'day_of_week', 'susceptible', 'exposed', 'active_cases',
       'recovered', 'cumulative_deceased'],
      dtype='object')

In [9]:
# Filter data between the specified dates
start_date = "2020-05-01"
end_date = "2021-05-31"
merged_data = merged_data[start_date:end_date]

min_date = merged_data.index.min()
max_date = merged_data.index.max()

# Calculate the range of dates
date_range = max_date - min_date
logging.info(f"Data ranges from {min_date} to {max_date} ({date_range.days} days)")

# Calculate split points
total_days = date_range.days
train_end = min_date + pd.Timedelta(days=int(total_days * 0.70))
val_end = train_end + pd.Timedelta(days=int(total_days * 0.20))

# Split the data into training, validation, and testing sets
train = merged_data[merged_data.index <= train_end]
val = merged_data[(merged_data.index > train_end) & (merged_data.index <= val_end)]
test = merged_data[merged_data.index > val_end]

# Calculate the percentage of dates in each dataset
total_sample = len(merged_data)
train_sample = len(train) / total_sample * 100
val_sample = len(val) / total_sample * 100
test_sample = len(test) / total_sample * 100

print(f"Train: {train_sample:.2f}%, Validation: {val_sample:.2f}%, Test: {test_sample:.2f}%")
print(f"Train: {len(train)} samples, Validation: {len(val)} samples, Test: {len(test)} samples")
print(f"Max date in train: {train.index.max()}, Min date in validation: {val.index.min()}, Max date in test: {test.index.max()}")


2024-05-22 13:53:15,828 - INFO - Data ranges from 2020-05-01 00:00:00 to 2021-05-31 00:00:00 (395 days)


Train: 69.95%, Validation: 19.95%, Test: 10.10%
Train: 277 samples, Validation: 79 samples, Test: 40 samples
Max date in train: 2021-02-01 00:00:00, Min date in validation: 2021-02-02 00:00:00, Max date in test: 2021-05-31 00:00:00


In [10]:
train_dates = (train.index.min(), train.index.max())
val_dates = (val.index.min(), val.index.max())
test_dates = (test.index.min(), test.index.max())

print(f"Train dates: {train_dates}, Val dates: {val_dates}, Test dates: {test_dates}")

Train dates: (Timestamp('2020-05-01 00:00:00'), Timestamp('2021-02-01 00:00:00')), Val dates: (Timestamp('2021-02-02 00:00:00'), Timestamp('2021-04-21 00:00:00')), Test dates: (Timestamp('2021-04-22 00:00:00'), Timestamp('2021-05-31 00:00:00'))


In [11]:
features = [
    "covidOccupiedMVBeds",
    "hospitalCases_rolling_7_mean",
    "hospitalCases_rolling_7_std",
    "newAdmissions_rolling_7_mean",
    "newAdmissions_rolling_7_std",
    "new_confirmed_rolling_7_mean",
    "new_confirmed_rolling_7_std",
    "new_deceased_rolling_7_mean",
    "new_deceased_rolling_7_std",
    "covidOccupiedMVBeds_lag_1",
    "covidOccupiedMVBeds_lag_2",
    "covidOccupiedMVBeds_lag_3",
    "covidOccupiedMVBeds_lag_5",
    "covidOccupiedMVBeds_lag_7",
    "covidOccupiedMVBeds_lag_14",
    "covidOccupiedMVBeds_lag_21",
    "month",
    "day",
    "day_of_week",
    "susceptible",
    "exposed",
    "active_cases",
    "recovered",
    "cumulative_deceased",
]

In [12]:
train.head()

Unnamed: 0_level_0,covidOccupiedMVBeds,hospitalCases,newAdmissions,new_confirmed,new_deceased,hospitalCases_rolling_7_mean,hospitalCases_rolling_7_std,newAdmissions_rolling_7_mean,newAdmissions_rolling_7_std,new_confirmed_rolling_7_mean,...,covidOccupiedMVBeds_lag_14,covidOccupiedMVBeds_lag_21,month,day,day_of_week,susceptible,exposed,active_cases,recovered,cumulative_deceased
date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
2020-05-01,63.351164,1211.0,127,492.0,72.0,1341.571429,82.687881,152.714286,28.877409,443.0,...,320.0,280.0,5,1,4,6221338.5,5.704047,11255.908,24.35237,2814.158
2020-05-02,60.890383,1190.0,111,303.0,38.0,1313.857143,97.314659,150.285714,31.862801,432.714286,...,312.0,283.0,5,2,5,6220566.5,6897.2144,11504.0,1221.9509,2802.905
2020-05-03,59.896116,1223.0,166,317.0,54.0,1285.428571,89.184934,152.857143,32.369886,427.0,...,306.0,291.0,5,3,6,6222003.5,9310.359,11591.686,-2363.8323,2842.6924
2020-05-04,59.146535,1138.0,106,392.0,36.0,1246.142857,84.044489,149.857143,35.941752,409.714286,...,294.0,264.0,5,4,0,6221165.5,9295.368,11566.316,179.91927,2855.6702
2020-05-05,59.146535,1160.0,90,332.0,48.0,1213.428571,60.417594,136.428571,38.526429,391.0,...,279.0,303.0,5,5,1,6220008.0,9924.149,11582.3955,-614.2166,2924.508


In [13]:
# minmax scaling for the features and target variable

# Define the features to be scaled
# features = [col for col in train.columns if col not in ['susceptible', 'exposed', 'active_cases', 'recovered', 'cumulative_deceased']]

# Initialize the MinMaxScaler
scaler = MinMaxScaler()

# Fit the scaler on the training set
scaler.fit(train[features])

# Transform the training, validation, and test sets
train[features] = scaler.transform(train[features])
val[features] = scaler.transform(val[features])
test[features] = scaler.transform(test[features])


In [14]:
train.head()

Unnamed: 0_level_0,covidOccupiedMVBeds,hospitalCases,newAdmissions,new_confirmed,new_deceased,hospitalCases_rolling_7_mean,hospitalCases_rolling_7_std,newAdmissions_rolling_7_mean,newAdmissions_rolling_7_std,new_confirmed_rolling_7_mean,...,covidOccupiedMVBeds_lag_14,covidOccupiedMVBeds_lag_21,month,day,day_of_week,susceptible,exposed,active_cases,recovered,cumulative_deceased
date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
2020-05-01,0.600239,1211.0,127,492.0,72.0,0.312642,0.22074,0.312166,0.396367,0.052339,...,0.839474,0.768595,0.363636,0.0,0.666667,0.998003,0.0,0.079177,0.009127,0.00133
2020-05-02,0.576586,1190.0,111,303.0,38.0,0.306069,0.260041,0.307122,0.439019,0.051012,...,0.818421,0.77686,0.363636,0.033333,0.833333,0.995685,0.050487,0.081102,0.013704,0.0
2020-05-03,0.56703,1223.0,166,317.0,54.0,0.299326,0.238197,0.312463,0.446264,0.050274,...,0.802632,0.798898,0.363636,0.066667,1.0,1.0,0.068165,0.081783,0.0,0.004701
2020-05-04,0.559825,1138.0,106,392.0,36.0,0.290008,0.224385,0.306231,0.497295,0.048042,...,0.771053,0.724518,0.363636,0.1,0.0,0.997484,0.068055,0.081586,0.009721,0.006235
2020-05-05,0.559825,1160.0,90,332.0,48.0,0.282249,0.160901,0.278338,0.534222,0.045626,...,0.731579,0.831956,0.363636,0.133333,0.166667,0.994008,0.072662,0.081711,0.006686,0.014368
