In [None]:
import pickle
from datetime import date, datetime, timedelta
from pathlib import Path

import matplotlib.dates as mdates  # Corrected import statement 
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.dates import DateFormatter
from sklearn.metrics import mean_squared_error, mean_absolute_percentage_error, r2_score, mean_absolute_error, root_mean_squared_error

import scoring  # the scoring module
from firefly_optimizer import *
from plot_results import *
from transition_probability_estimation import *
from Data_synthesize import *
from rw_data_processing import *

%load_ext autoreload
%autoreload 
plt.style.use(r"./rw_visualization.mplstyle")


In [None]:
# import warnings filter
from warnings import simplefilter
# ignore all future warnings (FutureWarning only; other warning categories still show)
simplefilter(action='ignore', category=FutureWarning)

# Color: default seaborn palette; entries 0, 2 and 3 are used as the line colours in the plots below.
# NOTE(review): `seaborn` is never imported explicitly in this notebook; it presumably arrives
# through one of the `from ... import *` lines in the first cell — TODO import it explicitly.
current_palette = seaborn.color_palette()
current_palette


# 1. Predefine the Taiwan source outbreak infection days

In [36]:
def transform_dates_to_days(dates):
    """
    Convert dates into day offsets from the earliest date.

    Input
    -----
    dates: Iterable of datetime/date objects, in any order

    Output
    ------
    days: List of ints, one per input date, in ascending date order. The
          earliest date maps to 0. An empty input returns an empty list.
    """
    sorted_dates = sorted(dates)
    if not sorted_dates:
        # Nothing to offset against; returning [0] here would invent a date
        return []
    first_date = sorted_dates[0]
    return [(current - first_date).days for current in sorted_dates]

In [37]:
# Local positive cases confirmed dates (the source cases set in the CovSyn simulation)
sources_confirmed_dates = ["2020/1/28", "2020/1/30", "2020/2/3", "2020/2/19", "2020/2/23", "2020/2/28", "2020/3/5", "2020/3/13", "2020/3/18", "2020/3/18", "2020/3/19", "2020/3/20", "2020/3/22", "2020/3/20", "2020/3/24", "2020/3/26", "2020/3/26", "2020/3/28", "2020/3/28", "2020/3/29", "2020/3/31", "2020/3/31", "2020/4/2", "2020/4/2", "2020/4/3", "2020/4/4", "2020/4/8", "2020/4/12"]
infection_dates = [datetime.strptime(date_str, "%Y/%m/%d") for date_str in sources_confirmed_dates]

# Replicate the same source-case schedule 1100 times
# (presumably one copy per synthetic dataset to generate — confirm against data_synthesis.sh).
N_SOURCE_SCHEDULES = 1100
# The day offsets are identical for every copy, so compute them once
infection_days = transform_dates_to_days(infection_dates)
infection_dates_list = [infection_dates for _ in range(N_SOURCE_SCHEDULES)]
# Independent list objects per copy, so a consumer mutating one copy cannot alter the others
infection_days_list = [list(infection_days) for _ in range(N_SOURCE_SCHEDULES)]

# Make sure the output directory exists before writing the pickle
variable_dir = Path('./variable')
variable_dir.mkdir(parents=True, exist_ok=True)
with open(variable_dir / 'infection_days_list.pkl', 'wb') as f:
    pickle.dump(infection_days_list, f)

## NOTE: Before continuing, run `data_synthesis.sh` with mode `taiwan_first_outbreak`.

# 2. Load Taiwan first outbreak population-level data

In [38]:
# Load data: the 'Summary' sheet of the structured course-of-disease dataset
tw_data_path = Path('./data/structured_course_of_disease_data/figshare_taiwan_covid.xlsx')
tw_summary_data = pd.read_excel(tw_data_path, sheet_name='Summary')

# Active analysis window: rows announced from 2020-01-27 to 2020-08-01 (both inclusive),
# i.e. the first confirmed case date to the last confirmed case date
tw_summary_data = tw_summary_data.loc[\
     (tw_summary_data['announce_date'] >= pd.to_datetime('2020-01-27')) &\
     (pd.to_datetime('2020-08-01') >= tw_summary_data['announce_date'])] 

# Alternative analysis windows (disabled). To use one, comment out the active filter above
# and uncomment exactly one of the blocks below.

# # Data time range of our structured course of disease dataset with contact tracing information 
# # (Wu, Yu-Heng, and Torbjörn EM Nordling. "A structured course of disease dataset with contact tracing information in Taiwan for COVID-19 modelling." Scientific Data 11.1 (2024): 821.)
# tw_summary_data = tw_summary_data.loc[\
#      (tw_summary_data['announce_date'] >= pd.to_datetime('2020-01-21')) &\
#      (pd.to_datetime('2020-11-9') >= tw_summary_data['announce_date'])] 

# # First outbreak
# tw_summary_data = tw_summary_data.loc[\
#      (tw_summary_data['announce_date'] >= pd.to_datetime('2020-02-16')) &\
#      (pd.to_datetime('2020-04-11') >= tw_summary_data['announce_date'])] 

# # One day before the first death to the end of the first outbreak
# tw_summary_data = tw_summary_data.loc[\
#      (tw_summary_data['announce_date'] >= pd.to_datetime('2020-02-14')) &\
#      (pd.to_datetime('2020-04-11') >= tw_summary_data['announce_date'])] 

In [39]:
# Local positive cases, treated as cumulative counts (they are differenced with np.diff further below)
number_of_local_positive_case = tw_summary_data['number_of_local_positive_cases'].to_numpy()
# Boolean mask of rows with a missing local count; reused to select the same rows from the other columns
nan_index = np.isnan(number_of_local_positive_case)
# Remove all NaN values
number_of_local_positive_case = number_of_local_positive_case[~nan_index]
# Cast to int32 and reverse the row order.
# NOTE(review): presumably the sheet is newest-first, so this makes the series chronological
# (insert_missing_date below requires ascending dates) — confirm against the spreadsheet.
number_of_local_positive_case = np.flip(number_of_local_positive_case.astype(np.int32),0)

In [40]:
# Local unknown-source positive cases (added to the local cases when computing daily counts)
number_of_unknown_positive_cases = tw_summary_data['number_of_unknown_positive_cases'].\
    to_numpy()
# Remove all NaN values — rows are selected with the NaN mask of the LOCAL positive cases so
# both series stay aligned; NaNs in this column itself are not removed here.
number_of_unknown_positive_cases = number_of_unknown_positive_cases[~nan_index]
# Reverse into the same order as the local positive cases
number_of_unknown_positive_cases = np.flip(number_of_unknown_positive_cases.astype(int),0)

In [41]:
# Announcement dates, aligned row-for-row with the case series above
date_time = tw_summary_data['announce_date'].to_numpy()
# Remove all NaN values (same rows as dropped for the local positive cases)
date_time = date_time[~nan_index]
# Reverse into the same order as the case series
date_time = np.flip(date_time,0)

In [None]:
# Cumulative deaths from the summary sheet, aligned with the local-case series
number_of_death_cases = tw_summary_data['dead'].to_numpy()
# Remove all NaN values (same rows as dropped for the local positive cases)
number_of_death_cases = number_of_death_cases[~nan_index]
# Reverse into the same order as the case series
number_of_death_cases = np.flip(number_of_death_cases.astype(int),0)

# ID 19, local, confirmed date 02-03, death date 02-15
# ID 27, local, confirmed date 02-23, death date 03-20
# ID 34, local, confirmed date 02-28, death date 03-29
# ID 101, abroad, confirmed date 03-19, death date 04-10
# ID 108, abroad, confirmed date 03-19, death date 03-29
# ID 170, abroad, confirmed date 03-23, death date 03-29
# ID 197, abroad, confirmed date 03-24, death date 05-11

# Only keep the first three deaths, since only they are local cases. Every abroad death happens
# on or after 03-29, the day the third local death (ID 34) is reached, so capping the cumulative
# count at 3 keeps exactly the local deaths. Unlike searching for the first day whose count equals
# 5, the cap cannot fail with IndexError when no day reports exactly 5 deaths.
MAX_LOCAL_DEATHS = 3
number_of_death_cases = np.minimum(number_of_death_cases, MAX_LOCAL_DEATHS)


In [43]:
def insert_missing_date(date, number_of_case):
    """
    Insert the missing data by copying the data from the previous day.

    Input
    -----
    date: Numpy array of available dates in ascending order, with part of the dates missing;
          a date may appear more than once
    number_of_case: Numpy array of number of cases aligned with `date`, e.g. abroad positive
          cases, recovered cases, or death cases

    Output
    ------
    number_of_case_full: Numpy float array with one value per day of date_full. A missing day
          repeats the previous day's value; a duplicated date uses its LAST entry.
    date_full: Numpy datetime64[D] array of every day from date[0] to date[-1]

    Raises
    ------
    ValueError: if the first day of date_full has no matching entry in `date`
          (e.g. date[0] carries a non-midnight time component), so nothing can be carried forward.
    """
    date_full = np.arange(date[0], date[-1]+np.timedelta64(1, 'D'), dtype='datetime64[D]')
    number_of_case_full = np.empty(len(date_full), dtype=float)
    last_value = None
    for position, day in enumerate(date_full):
        matches = np.flatnonzero(date == day)
        if matches.size > 0:
            # Existing date; with duplicates (any number of them) use the last entry
            last_value = number_of_case[matches[-1]]
        elif last_value is None:
            raise ValueError(f"No data for the first date {day}; cannot forward-fill")
        # Missing dates fall through and reuse the value of the previous day
        number_of_case_full[position] = last_value
    return number_of_case_full, date_full

In [44]:
# Fill calendar gaps in each cumulative series (a missing day repeats the previous day's value).
# date_full is the complete daily date axis shared by all three series.
number_of_local_positive_case_full, date_full = insert_missing_date(date_time, number_of_local_positive_case)
number_of_unknown_positive_cases_full, _ = insert_missing_date(date_time, number_of_unknown_positive_cases)
number_of_death_cases_full, _ = insert_missing_date(date_time, number_of_death_cases)

In [45]:
# Convert the cumulative series into daily new counts. `prepend=x[0]` keeps index k aligned with
# date_full[k] (day 0 gets 0 new cases). A plain np.diff puts the increase observed on
# date_full[k + 1] at index k, i.e. one day early relative to daily_tw_date_time below.
# Totals are unchanged: both forms sum to x[-1] - x[0].
daily_tw_local_confirmed_cases = (
    np.diff(number_of_local_positive_case_full, prepend=number_of_local_positive_case_full[0])
    + np.diff(number_of_unknown_positive_cases_full, prepend=number_of_unknown_positive_cases_full[0])
)
daily_tw_local_deaths = np.diff(number_of_death_cases_full, prepend=number_of_death_cases_full[0])

# Padding to make all arrays 365 days long (zeros after the last observed day)
TW_SERIES_LENGTH = 365
if len(daily_tw_local_confirmed_cases) > TW_SERIES_LENGTH:
    raise ValueError(f"Observed period spans {len(daily_tw_local_confirmed_cases)} days, "
                     f"more than {TW_SERIES_LENGTH}; widen TW_SERIES_LENGTH or narrow the date filter")
daily_tw_local_confirmed_cases = np.pad(daily_tw_local_confirmed_cases, (0, TW_SERIES_LENGTH - len(daily_tw_local_confirmed_cases)), 'constant')
daily_tw_local_deaths = np.pad(daily_tw_local_deaths, (0, TW_SERIES_LENGTH - len(daily_tw_local_deaths)), 'constant')
daily_tw_date_time = np.arange(date_full[0], date_full[0] + np.timedelta64(TW_SERIES_LENGTH, 'D'), dtype='datetime64[D]')

In [None]:
# Gap between the first reported symptom onset in Wuhan and the first date of the Taiwan series;
# this presumably motivates the 57-day search window used in optimize_time_shift below.
# "The symptom onset date of the first patient identified was Dec 1, 2019." Huang, Chaolin, et al. "Clinical features of patients infected with 2019 novel coronavirus in Wuhan, China." The lancet 395.10223 (2020): 497-506.
print('Largest possible time shift: ', daily_tw_date_time[0]-np.datetime64('2019-12-01'))

# 3. Load CovSyn data and transform it to population-level data

In [None]:
# CovSyn output directory for the Taiwan first-outbreak scenario
data_path = Path('./synthetic_data_results_taiwan_first_outbreak')
# Load 1000 synthetic datasets. min_case_num is the number of source cases + 1, presumably so only
# runs with at least one secondary spread are kept (load_synthetic_data is defined elsewhere).
demographic_data_list_all, social_data_list_all, course_of_disease_data_list_all, contact_data_list_all, transmission_digraph_all = \
    load_synthetic_data(data_path, return_len=1000, memory_limit=1e9, min_case_num=len(sources_confirmed_dates)+1)

In [48]:
# Convert every synthetic dataset into daily population-level series over `time_limit` days
# and stack them into (number of runs, time_limit) matrices, one row per run.
time_limit = 365
data_len = len(demographic_data_list_all)
daily_infected_cases_matrix = np.zeros((data_len, time_limit))
daily_confirmed_cases_matrix = np.zeros((data_len, time_limit))
daily_recoverd_cases_matrix = np.zeros((data_len, time_limit))
daily_deaths_matrix = np.zeros((data_len, time_limit))
for i in range(data_len):
    course_of_disease_data = course_of_disease_data_list_all[i]
    contact_data = contact_data_list_all[i]
    # Only the infected, confirmed, recovered and death series are stored; the other outputs are unused here.
    # time_limit-1 is passed while each row holds time_limit values, so the helper presumably returns
    # days 0..time_limit-1 inclusive (TODO confirm). population_size is presumably Taiwan's population.
    daily_susceptible_population, daily_infected_cases, daily_contagious_cases, daily_symptomatic_cases, \
                daily_confirmed_cases, daily_tested_cases, daily_suspected_cases, daily_isolation_cases, daily_critically_ill_cases, daily_recovered_cases, \
                daily_deaths = transform_course_object_to_population_data(course_of_disease_data,
                                                                            contact_data,
                                                                            time_limit=time_limit-1,
                                                                            population_size=23008366)
    daily_infected_cases_matrix[i, :] = daily_infected_cases
    daily_confirmed_cases_matrix[i, :] = daily_confirmed_cases
    daily_recoverd_cases_matrix[i, :] = daily_recovered_cases
    daily_deaths_matrix[i, :] = daily_deaths


In [49]:
# Create daily source infected cases array. This stores the pattern of the source infection cases which I set in the CovSyn simulation.
np_source_confirmed_dates = np.zeros(len(sources_confirmed_dates), dtype='datetime64[D]')
for i, date_str in enumerate(sources_confirmed_dates):
    # Parse 'YYYY/MM/DD' and convert to a day-resolution numpy date
    confirm_date = np.datetime64(datetime.strptime(date_str, '%Y/%m/%d').strftime('%Y-%m-%d'))
    np_source_confirmed_dates[i] = confirm_date

# The earliest source date is day 0 of the simulation time axis
min_confirmed_date = np.min(np_source_confirmed_dates)
daily_source_infected_cases = np.zeros(time_limit)
for i, np_source_confirmed_date in enumerate(np_source_confirmed_dates):
    # Convert timedelta to integer days relative to the earliest source date
    index = int((np_source_confirmed_date - min_confirmed_date) / np.timedelta64(1, 'D'))
    if 0 <= index < time_limit:  # Ignore dates beyond the simulated horizon
        daily_source_infected_cases[index] += 1  # Several sources can share a day, so count them

In [None]:
# Sanity check: total simulated confirmed cases per run (mean and median across runs)
# versus the number of seeded source cases and the observed Taiwan total
print('Mean of total confirmed cases: ', np.mean(np.sum(daily_confirmed_cases_matrix, axis=1)))
print('Median of total confirmed cases: ', np.median(np.sum(daily_confirmed_cases_matrix, axis=1)))
print('Number of source cases: ', len(sources_confirmed_dates))
print('Total number of TW local confirmed cases (summary data): ', sum(daily_tw_local_confirmed_cases))


# 4. Optimize the time shift between the observed and the mean simulated cumulative cases

In [51]:
def shift_and_pad(array1, array2, time_shift):
    """
    Delay array1 by `time_shift` steps and extend array2 by the same number of steps.

    array1 receives `time_shift` leading zeros; array2 receives `time_shift`
    trailing copies of its own last element. Returns the pair
    (padded array1, padded array2).
    """
    delayed = np.pad(array1, (time_shift, 0), mode='constant', constant_values=0)
    last_value = array2[-1]
    extended = np.pad(array2, (0, time_shift), mode='constant', constant_values=last_value)
    return delayed, extended

In [52]:
def optimize_time_shift(dates, cumulative_cases, cumulative_simulated_cases, max_shift=57):
    """
    Find the time shift between the observed and simulated cumulative curves
    that minimizes the mean squared error.

    For a shift s, shift_and_pad() prepends s zeros to the observed curve and
    appends s copies of the last simulated value. Simulation day 0 therefore
    lines up with shifted_dates[0], and observed day 0 with shifted_dates[s].
    Every shift in [0, max_shift] is evaluated (max_shift + 1 MSE evaluations),
    so the global minimum is returned without assuming a unimodal error curve.

    Input
    -----
    dates: Numpy datetime64[D] array aligned with cumulative_cases
    cumulative_cases: Numpy array of observed cumulative cases
    cumulative_simulated_cases: Numpy array of simulated cumulative cases
    max_shift: Largest shift (in days) to try. The default 57 is the original
               search window (presumably the 'Largest possible time shift'
               from 2019-12-01 to the first observed date).

    Output
    ------
    time_shift: Optimal shift in days (the smallest one on ties)
    shifted_dates: Daily dates from dates[0] - time_shift to dates[-1] inclusive
    shifted_cumulative_cases: Observed curve padded for time_shift
    shifted_simulated_cases: Simulated curve padded for time_shift
    loss: MSE between the two padded curves

    Raises
    ------
    ValueError: if max_shift is negative.
    """
    if max_shift < 0:
        raise ValueError(f"max_shift must be non-negative, got {max_shift}")

    best_shift, best_loss, best_pair = None, None, None
    for time_shift in range(max_shift + 1):
        padded_cases, padded_simulated = shift_and_pad(cumulative_cases, cumulative_simulated_cases, time_shift)
        loss = mean_squared_error(padded_cases, padded_simulated)
        # Strict '<' keeps the smallest shift when several shifts tie
        if best_loss is None or loss < best_loss:
            best_shift, best_loss, best_pair = time_shift, loss, (padded_cases, padded_simulated)

    # Build the date axis from the SAME shift that is returned, so all outputs have matching lengths
    shifted_dates = np.arange(dates[0] - best_shift, dates[-1] + 1)

    return (best_shift, shifted_dates, best_pair[0], best_pair[1], best_loss)

In [53]:
# Mean over simulation runs (axis 0) of the daily and cumulative series.
# The 95% percentile intervals are computed after the time shift is applied.

# Daily confirmed cases
mean_daily_confirmed_cases = np.mean(daily_confirmed_cases_matrix, axis=0)

# Daily deaths
mean_daily_deaths = np.mean(daily_deaths_matrix, axis=0)

# Cumulative confirmed cases (cumsum per run, then mean across runs)
cumulative_confirmed_cases_matrix = np.cumsum(daily_confirmed_cases_matrix, axis=1)
mean_cumulative_confirmed_cases = np.mean(cumulative_confirmed_cases_matrix, axis=0)

# Cumulative deaths
cumulative_deaths_matrix = np.cumsum(daily_deaths_matrix, axis=1)
mean_cumulative_deaths = np.mean(cumulative_deaths_matrix, axis=0)

# Cumulative Taiwan data
cumulative_tw_local_confirmed_cases = np.cumsum(daily_tw_local_confirmed_cases)
cumulative_tw_local_death = np.cumsum(daily_tw_local_deaths)

In [None]:
# Optimize time shift: align the observed cumulative local confirmed cases with the
# mean simulated cumulative confirmed cases (a single shift for all runs)
optimal_t_shift, shifted_dates, shifted_cumulative_tw_local_confirmed_cases, shifted_mean_cumulative_confirmed_cases, mse_loss = optimize_time_shift(daily_tw_date_time, cumulative_tw_local_confirmed_cases, mean_cumulative_confirmed_cases)
print(f"Optimal time shift: {optimal_t_shift} days")
print("MSE loss: ", mse_loss)

In [55]:
# Shift the case matrices by optimal_t_shift and get the mean and the 95% interval
# (2.5th / 97.5th percentiles across simulation runs).
# Convention shared with optimize_time_shift/shift_and_pad: SIMULATED series stay in simulation
# time and are zero-padded at the END; OBSERVED Taiwan series are delayed by zero-padding at the
# FRONT. All arrays below then share the date axis shifted_dates.

# Daily confirmed cases (simulated)
shifted_daily_confirmed_cases_matrix = np.pad(daily_confirmed_cases_matrix, ((0, 0), (0, optimal_t_shift)), 'constant', constant_values=0)
shifted_mean_daily_confirmed_cases = np.mean(shifted_daily_confirmed_cases_matrix, axis=0)
shifted_lb_daily_confirmed_cases = np.percentile(shifted_daily_confirmed_cases_matrix, 2.5, axis=0)
shifted_ub_daily_confirmed_cases = np.percentile(shifted_daily_confirmed_cases_matrix, 97.5, axis=0)

# Daily deaths (simulated): padded at the END like the simulated confirmed cases. Padding at the
# front would delay the simulated deaths by an extra optimal_t_shift days relative to the
# simulated cases of the very same runs.
shifted_daily_deaths_matrix = np.pad(daily_deaths_matrix, ((0, 0), (0, optimal_t_shift)), 'constant', constant_values=0)
shifted_mean_daily_deaths = np.mean(shifted_daily_deaths_matrix, axis=0)
shifted_lb_daily_deaths = np.percentile(shifted_daily_deaths_matrix, 2.5, axis=0)
shifted_ub_daily_deaths = np.percentile(shifted_daily_deaths_matrix, 97.5, axis=0)

# Cumulative confirmed cases (cumsum per run, then statistics across runs)
shifted_cumulative_confirmed_cases_matrix = np.cumsum(shifted_daily_confirmed_cases_matrix, axis=1)
shifted_mean_cumulative_confirmed_cases = np.mean(shifted_cumulative_confirmed_cases_matrix, axis=0)
shifted_lb_cumulative_confirmed_cases = np.percentile(shifted_cumulative_confirmed_cases_matrix, 2.5, axis=0)
shifted_ub_cumulative_confirmed_cases = np.percentile(shifted_cumulative_confirmed_cases_matrix, 97.5, axis=0)

# Cumulative deaths
shifted_cumulative_deaths_matrix = np.cumsum(shifted_daily_deaths_matrix, axis=1)
shifted_mean_cumulative_deaths = np.mean(shifted_cumulative_deaths_matrix, axis=0)
shifted_lb_cumulative_deaths = np.percentile(shifted_cumulative_deaths_matrix, 2.5, axis=0)
shifted_ub_cumulative_deaths = np.percentile(shifted_cumulative_deaths_matrix, 97.5, axis=0)

# Shift Taiwan data (observed): delayed by padding at the front
shifted_daily_tw_local_confirmed_cases = np.pad(daily_tw_local_confirmed_cases, (optimal_t_shift, 0), 'constant', constant_values=0)
shifted_daily_tw_local_deaths = np.pad(daily_tw_local_deaths, (optimal_t_shift, 0), 'constant', constant_values=0)
shifted_cumulative_tw_local_deaths = np.cumsum(shifted_daily_tw_local_deaths)

# Shift the source infection cases (simulation input, hence simulation time: padded at the end)
shifted_daily_source_infected_cases = np.pad(daily_source_infected_cases, (0, optimal_t_shift), 'constant', constant_values=0)
shifted_cumulative_source_infected_cases = np.cumsum(shifted_daily_source_infected_cases)

# 5. Calculate goodness of fit (GOF)

In [56]:
def gof(actual, predicted, t_type='daily'):
    """
    Goodness-of-fit metrics between an observed and a predicted series.

    Input
    -----
    actual: Numpy array of observed daily values
    predicted: Numpy array of predicted daily values, same length as `actual`
    t_type: 'daily' (use the series as given), 'weekly' (7-day moving average)
            or 'monthly' (31-day moving average). The moving average is applied
            to both series with np.convolve(mode='same') before scoring.

    Output
    ------
    Tuple (mae, mae_cumsum, mse, mse_cumsum, rmse, rmse_cumsum, nae, nae_cumsum,
    mape, mape_cumsum, r2, r2_cumsum). The *_cumsum values are computed on the
    cumulative sums of the (smoothed) series. NAE is the mean of
    |a - p| / (a + p) over days; a day where both values are 0 contributes 0
    (perfect agreement) rather than 0/0 = NaN.

    Raises
    ------
    ValueError: if t_type is not 'daily', 'weekly' or 'monthly'.
    """
    # Window length per time type; odd windows keep the mode='same' average centred
    window_sizes = {'daily': None, 'weekly': 7, 'monthly': 31}
    if t_type not in window_sizes:
        raise ValueError(f"t_type must be one of {list(window_sizes)}, got {t_type!r}")
    window_size = window_sizes[t_type]
    if window_size is not None:
        kernel = np.ones(window_size)/window_size
        actual = np.convolve(actual, kernel, mode='same')
        predicted = np.convolve(predicted, kernel, mode='same')

    def normalized_absolute_error(a, p):
        # Mean of |a - p| / (a + p); days where both values are zero count as 0 instead of NaN
        abs_diff = np.abs(a - p)
        total = a + p
        with np.errstate(divide='ignore', invalid='ignore'):
            ratio = np.where((abs_diff == 0) & (total == 0), 0.0, abs_diff/total)
        return np.mean(ratio)

    cumulative_actual = np.cumsum(actual)
    cumulative_predicted = np.cumsum(predicted)
    mae = mean_absolute_error(actual, predicted)
    mae_cumsum = mean_absolute_error(cumulative_actual, cumulative_predicted)
    mse = mean_squared_error(actual, predicted)
    mse_cumsum = mean_squared_error(cumulative_actual, cumulative_predicted)
    rmse = root_mean_squared_error(actual, predicted)
    rmse_cumsum = root_mean_squared_error(cumulative_actual, cumulative_predicted)
    nae = normalized_absolute_error(actual, predicted)
    nae_cumsum = normalized_absolute_error(cumulative_actual, cumulative_predicted)
    mape = mean_absolute_percentage_error(actual, predicted)
    mape_cumsum = mean_absolute_percentage_error(cumulative_actual, cumulative_predicted)
    r2 = r2_score(actual, predicted)
    r2_cumsum = r2_score(cumulative_actual, cumulative_predicted)

    return mae, mae_cumsum, mse, mse_cumsum, rmse, rmse_cumsum, nae, nae_cumsum, mape, mape_cumsum, r2, r2_cumsum

In [57]:
# Time window for the goodness-of-fit evaluation.
# Alternative start: first confirmed case date to the last confirmed case date
# start_date = np.datetime64('2020-01-27')

# Start at shifted_dates[0], i.e. simulation day 0 after the time shift
start_date = shifted_dates[0]
start_index = np.where(shifted_dates == start_date)[0][0]
# Alternative end: end_date = np.datetime64('2020-08-01')
# End 28*6 = 168 days later; the window is [start_date, end_date), so end_date itself is excluded
end_date = start_date + 28*6
end_index = np.where(shifted_dates == end_date)[0][0]
target_time_range_indices = range(start_index, end_index)
gof_dates = shifted_dates[target_time_range_indices]

# Select appropriate actual data (observed Taiwan series)
actual_confirmed_cases = shifted_daily_tw_local_confirmed_cases[target_time_range_indices]
cumulative_actual_confirmed_cases = shifted_cumulative_tw_local_confirmed_cases[target_time_range_indices]
actual_deaths = shifted_daily_tw_local_deaths[target_time_range_indices]
cumulative_actual_deaths = shifted_cumulative_tw_local_deaths[target_time_range_indices]

# Select appropriate predicted mean data (mean over simulation runs)
predicted_confirmed_cases = shifted_mean_daily_confirmed_cases[target_time_range_indices]
cumulative_predicted_confirmed_cases = shifted_mean_cumulative_confirmed_cases[target_time_range_indices]
predicted_deaths = shifted_mean_daily_deaths[target_time_range_indices]
cumulative_predicted_deaths = shifted_mean_cumulative_deaths[target_time_range_indices]


In [58]:
# Select appropriate predicted data matrix: one row per simulation run, columns restricted to the GOF window
predicted_confirmed_cases_matrix = shifted_daily_confirmed_cases_matrix[:, target_time_range_indices]
cumulative_predicted_confirmed_cases_matrix = shifted_cumulative_confirmed_cases_matrix[:, target_time_range_indices]
predicted_deaths_matrix = shifted_daily_deaths_matrix[:, target_time_range_indices]
cumulative_predicted_deaths_matrix = shifted_cumulative_deaths_matrix[:, target_time_range_indices]

In [None]:
# Number of days in the GOF window (end_index - start_index)
len(actual_confirmed_cases)

In [None]:
# Goodness of fit of every simulation run against the observed data, for the daily series and
# its 7-/31-day moving averages. Each table cell holds the mean over runs, plus the
# 2.5th ("lb") and 97.5th ("ub") percentiles across runs.

# Define data types, time periods, and metrics
data_types = ['Confirmed', 'Deaths']
t_types = ['daily', 'weekly', 'monthly']
# Names of the 12 values returned by gof(), in return order
metrics = ['MAE', 'MAE (Cumulative)', 'MSE', 'MSE (Cumulative)', 'RMSE', 'RMSE (Cumulative)', 'NAE', 'NAE (Cumulative)', 
           'MAPE', 'MAPE (Cumulative)', 'R2', 'R2 (Cumulative)']
# Metrics written to the table (NAE and MAPE are computed by gof() but not reported).
# This order fixes the column order of the CSV: for each metric -> mean, ub, lb.
reported_metrics = ['MAE', 'MAE (Cumulative)', 'MSE', 'MSE (Cumulative)', 'RMSE', 'RMSE (Cumulative)',
                    'R2', 'R2 (Cumulative)']

# Observed series and predicted matrix (runs x days) for each data type
series_by_data_type = {
    'Confirmed': (actual_confirmed_cases, predicted_confirmed_cases_matrix),
    'Deaths': (actual_deaths, predicted_deaths_matrix),
}

# Create dictionary to store results: {row name: {column name: value}}
results_dict = {}

# Loop through each combination of data type and time period
for t_type in t_types:
    for data_type in data_types:
        row_name = f"{t_type} {data_type}"
        actual, predicted_matrix = series_by_data_type[data_type]

        # One row of gof() values per simulation run -> shape (n_runs, len(metrics))
        per_run_scores = np.array(
            [gof(actual, predicted, t_type=t_type) for predicted in predicted_matrix],
            dtype=float,
        ).reshape(-1, len(metrics))

        row = {}
        for metric in reported_metrics:
            values = per_run_scores[:, metrics.index(metric)]
            row[metric] = np.mean(values)
            row[f'{metric} ub'] = np.percentile(values, 97.5)
            row[f'{metric} lb'] = np.percentile(values, 2.5)
        results_dict[row_name] = row

# Create DataFrame with data types and time periods as rows, metrics as columns
results_df = pd.DataFrame.from_dict(results_dict, orient='index')

# Move the row names into a first, unnamed column
results_df = results_df.reset_index().rename(columns={'index': ''})

# Round all numeric values to 2 decimal places
numeric_columns = results_df.columns[1:]  # All columns except the unnamed row-label column
results_df[numeric_columns] = results_df[numeric_columns].round(2)

# Save to CSV
results_df.to_csv('goodness_of_fit_results.csv', index=False)

# Display the table
print(results_df)

# 6. Plot the time-shifted results

In [61]:
# Cumulative seeded source cases restricted to the GOF window (for plotting)
gof_cumulative_source_infected_cases = shifted_cumulative_source_infected_cases[target_time_range_indices]

In [62]:
# Get the 95% interval (2.5th / 97.5th percentiles across simulation runs) within the GOF window
gof_cumulative_confirmed_cases_lb = shifted_lb_cumulative_confirmed_cases[target_time_range_indices]
gof_cumulative_confirmed_cases_ub = shifted_ub_cumulative_confirmed_cases[target_time_range_indices]
gof_cumulative_deaths_lb = shifted_lb_cumulative_deaths[target_time_range_indices]
gof_cumulative_deaths_ub = shifted_ub_cumulative_deaths[target_time_range_indices]

In [63]:
# Weighted interval score of the cumulative confirmed cases for the central 95% interval
# (alpha = 0.05, i.e. the 0.025 and 0.975 quantiles across runs), scored per day in the GOF window
quantile_cumulative_confirmed_cases_dict = {0.025: np.quantile(shifted_cumulative_confirmed_cases_matrix, 0.025, axis=0)[target_time_range_indices],
                      0.975: np.quantile(shifted_cumulative_confirmed_cases_matrix, 0.975, axis=0)[target_time_range_indices]}
cumulative_confirmed_cases_wis_score, sharpness, calibration = scoring.weighted_interval_score(cumulative_actual_confirmed_cases, alphas=[0.05], 
                              q_dict=quantile_cumulative_confirmed_cases_dict)

In [64]:
# Weighted interval score of the cumulative deaths for the central 95% interval
# (alpha = 0.05, i.e. the 0.025 and 0.975 quantiles across runs), scored per day in the GOF window.
# The observed series is cumulative_actual_deaths (cumulated from the start of the shifted axis),
# mirroring the confirmed-case WIS; re-cumsumming only the window would drop deaths before start_date.
quantile_cumulative_deaths_dict = {0.025: np.quantile(shifted_cumulative_deaths_matrix, 0.025, axis=0)[target_time_range_indices],
                      0.975: np.quantile(shifted_cumulative_deaths_matrix, 0.975, axis=0)[target_time_range_indices]}
cumulative_deaths_wis_score, death_sharpness, death_calibration = scoring.weighted_interval_score(cumulative_actual_deaths, alphas=[0.05], 
                              q_dict=quantile_cumulative_deaths_dict)

In [None]:
# Two-panel figure. Top: observed vs. CovSyn cumulative local confirmed cases and deaths, with
# 95% percentile bands and the source cases. Bottom: weighted interval score over time.
# Both panels share 28-day major / 7-day minor date ticks starting at start_date.
fig, axs = plt.subplots(2, 1, figsize=(10, 9), gridspec_kw={'height_ratios': [2, 1]})

# Observed cumulative confirmed cases and deaths (dotted lines)
axs[0].plot(gof_dates, cumulative_actual_confirmed_cases, ':', color=current_palette[0])
axs[0].plot(gof_dates, cumulative_actual_deaths, ':', color=current_palette[2])
# Create a mask for points where there's an increase in confirmed cases (the first point is always shown)
is_case_increase = np.concatenate([[True], np.diff(cumulative_actual_confirmed_cases) > 0])
# Create a mask for points where there's an increase in deaths
is_death_increase = np.concatenate([[True], np.diff(cumulative_actual_deaths) > 0])
# Markers only on days where the observed count increased
axs[0].plot(gof_dates[is_case_increase], cumulative_actual_confirmed_cases[is_case_increase], 
         '*', color=current_palette[0], markersize=8, label='Observed local confirmed cases')
axs[0].plot(gof_dates[is_death_increase], cumulative_actual_deaths[is_death_increase], 
         'X', color=current_palette[2], markersize=8, label='Observed local deaths')

# CovSyn cumulative confirmed cases: mean over runs and 95% band
axs[0].plot(gof_dates, cumulative_predicted_confirmed_cases, color=current_palette[0], label='CovSyn local confirmed cases')
axs[0].fill_between(gof_dates, gof_cumulative_confirmed_cases_lb, gof_cumulative_confirmed_cases_ub, color=current_palette[0], alpha=0.3, label='95% CI for CovSyn cases')

# CovSyn cumulative deaths: mean over runs and 95% band
axs[0].plot(gof_dates, cumulative_predicted_deaths, color=current_palette[2], label='CovSyn local deaths')
axs[0].fill_between(gof_dates, gof_cumulative_deaths_lb, gof_cumulative_deaths_ub, color=current_palette[2], alpha=0.3, label='95% CI for CovSyn deaths')

# Source infected cases (seeded into the simulation)
axs[0].plot(gof_dates, gof_cumulative_source_infected_cases, '--', color=current_palette[3], label='Source infected cases')

# Generate exact dates for major ticks every 28 days
major_tick_dates = []
current_date = start_date
plot_end_date = end_date
while current_date <= plot_end_date:
    major_tick_dates.append(current_date)
    current_date = current_date + timedelta(days=28)

# Generate weekly minor ticks between each major tick
minor_tick_dates = []
for i in range(len(major_tick_dates)-1):
    current = major_tick_dates[i] + timedelta(days=7)
    while current < major_tick_dates[i+1]:
        minor_tick_dates.append(current)
        current = current + timedelta(days=7)

# Convert dates to matplotlib's numeric date format
major_tick_locations = [mdates.date2num(d) for d in major_tick_dates]
minor_tick_locations = [mdates.date2num(d) for d in minor_tick_dates]

# Configure first subplot (axs[0])
axs[0].legend(numpoints=1, loc='best')

# Set major ticks but hide labels (the dates are shown on the bottom panel only)
axs[0].set_xticks(major_tick_locations)
# Don't set DateFormatter here - that would override our empty labels
axs[0].tick_params(axis='x', which='major', length=6)  # Make ticks visible

# Important: After setting ticks, hide the labels by setting them to empty strings
axs[0].set_xticklabels(['' for _ in range(len(major_tick_locations))])

# Set minor ticks
axs[0].set_xticks(minor_tick_locations, minor=True)
axs[0].tick_params(axis='x', which='minor', bottom=True)

# Add grid aligned with ticks
axs[0].grid(True, which='major', axis='both', linestyle='dotted', color='gray', alpha=0.7)

axs[0].set_ylabel('Number of cases')
# Two days of margin on either side of the tick range
axs[0].set_xlim(start_date - timedelta(days=2), plot_end_date + timedelta(days=2))
# axs[0].set_ylim([0, 4.1])

# WIS plot for second subplot (axs[1])
axs[1].plot(gof_dates, cumulative_confirmed_cases_wis_score, color=current_palette[0], label='Confirmed cases')
axs[1].plot(gof_dates, cumulative_deaths_wis_score, color=current_palette[2], label='Deaths')

# Configure second subplot (axs[1]) with the same x-axis ticks, this time with date labels
axs[1].set_xticks(major_tick_locations)
axs[1].xaxis.set_major_formatter(DateFormatter('%Y-%m-%d'))
axs[1].tick_params(axis='x', rotation=30)

# Set minor ticks
axs[1].set_xticks(minor_tick_locations, minor=True)
axs[1].tick_params(axis='x', which='minor', bottom=True)

# Add grid aligned with ticks
axs[1].grid(True, which='major', axis='both', linestyle='dotted', color='gray', alpha=0.7)

axs[1].set_ylabel('Weighted interval score')
# axs[1].set_xlabel('Date')
axs[1].set_xlim(start_date - timedelta(days=2), plot_end_date + timedelta(days=2))
axs[1].legend()

# Set horizontal alignment for tick labels - only for the second subplot
plt.setp(axs[1].get_xticklabels(), ha='right')

plt.tight_layout()

# Table of the plotted WIS values and observed cumulative cases (CSV export currently disabled)
df = pd.DataFrame({'Dates': gof_dates, 'WIS':cumulative_confirmed_cases_wis_score, 
                   'WIS death': cumulative_deaths_wis_score, 'Actual data': cumulative_actual_confirmed_cases})
# df.to_csv('wis_score.csv')
# plt.savefig("RW2025_Covsyn_population_data_wis.pdf")

In [None]:
# Single-panel figure: observed vs. CovSyn cumulative local confirmed cases and deaths, with
# 95% percentile bands and the source cases; 28-day major / 7-day minor date ticks.
fig, axs = plt.subplots(1, 1, figsize=(9.5, 6))

# Observed cumulative confirmed cases and deaths (dotted lines)
axs.plot(gof_dates, cumulative_actual_confirmed_cases, ':', color=current_palette[0])
axs.plot(gof_dates, cumulative_actual_deaths, ':', color=current_palette[2])
# Create a mask for points where there's an increase in confirmed cases (the first point is always shown)
is_case_increase = np.concatenate([[True], np.diff(cumulative_actual_confirmed_cases) > 0])
# Create a mask for points where there's an increase in deaths
is_death_increase = np.concatenate([[True], np.diff(cumulative_actual_deaths) > 0])
# Markers only on days where the observed count increased
axs.plot(gof_dates[is_case_increase], cumulative_actual_confirmed_cases[is_case_increase], 
         '*', color=current_palette[0], markersize=8, label='Observed local confirmed cases')
axs.plot(gof_dates[is_death_increase], cumulative_actual_deaths[is_death_increase], 
         'X', color=current_palette[2], markersize=8, label='Observed local deaths')

# CovSyn cumulative confirmed cases: mean over runs and 95% band
axs.plot(gof_dates, cumulative_predicted_confirmed_cases, color=current_palette[0], label='CovSyn local confirmed cases')
axs.fill_between(gof_dates, gof_cumulative_confirmed_cases_lb, gof_cumulative_confirmed_cases_ub, color=current_palette[0], alpha=0.3, label='95% CI for CovSyn cases')

# CovSyn cumulative deaths: mean over runs and 95% band
axs.plot(gof_dates, cumulative_predicted_deaths, color=current_palette[2], label='CovSyn local deaths')
axs.fill_between(gof_dates, gof_cumulative_deaths_lb, gof_cumulative_deaths_ub, color=current_palette[2], alpha=0.3, label='95% CI for CovSyn deaths')

# Plot the source infected cases (seeded into the simulation)
axs.plot(gof_dates, gof_cumulative_source_infected_cases, '--', color=current_palette[3], label='Source infected cases')

# Generate exact dates for major ticks every 28 days
major_tick_dates = []
current_date = start_date
# plot_end_date = np.datetime64('2020-08-10')
plot_end_date = end_date
while current_date <= plot_end_date:
    major_tick_dates.append(current_date)
    current_date = current_date + timedelta(days=28)

# Generate weekly minor ticks between each major tick
minor_tick_dates = []
for i in range(len(major_tick_dates)-1):
    current = major_tick_dates[i] + timedelta(days=7)
    while current < major_tick_dates[i+1]:
        minor_tick_dates.append(current)
        current = current + timedelta(days=7)

# Convert dates to matplotlib's numeric date format
major_tick_locations = [mdates.date2num(d) for d in major_tick_dates]
minor_tick_locations = [mdates.date2num(d) for d in minor_tick_dates]

# Legend
axs.legend(numpoints=1, loc='best')

# Configure the x-axis: date-labelled major ticks
axs.set_xticks(major_tick_locations)
axs.xaxis.set_major_formatter(DateFormatter('%Y-%m-%d'))
axs.tick_params(axis='x', rotation=30)

# (Disabled here so the single axis keeps its date labels)
# axs.set_xticklabels(['' for _ in range(len(major_tick_locations))])

# Set minor ticks
axs.set_xticks(minor_tick_locations, minor=True)
axs.tick_params(axis='x', which='minor', bottom=True)

# Right-align the rotated date labels
plt.setp(axs.get_xticklabels(), ha='right')

# Add grid aligned with ticks
axs.grid(True, which='major', axis='both', linestyle='dotted', color='gray', alpha=0.7)

axs.set_ylabel('Cumulative number of cases')
# Two days of margin on either side of the tick range
axs.set_xlim(start_date - timedelta(days=2), plot_end_date + timedelta(days=2))
# axs.set_yscale('log')



plt.tight_layout()
# plt.savefig("RW2025_Covsyn_population_data.pdf")