In [None]:
import os
import sys
import numpy as np
import pints
import pints.plot
from scipy.integrate import odeint
import matplotlib.pyplot as plt
import pandas as pd
import math
import arviz as az
import seaborn as sns

# Make `<current working directory>/../models` importable so the model module
# imported below can be resolved.
cwd = os.getcwd()
model_path = os.path.abspath(os.path.join(cwd, os.pardir, 'models'))
# Guard the append so that re-running this cell does not pile up duplicate
# entries in sys.path.
if model_path not in sys.path:
    sys.path.append(model_path)
# NOTE(review): the wildcard import hides which names the notebook depends on;
# the cells visible here only use `diamondODEModel`.
from diamond_ODE_model import *

In [None]:
# Instantiate the model with its default constructor arguments. This object is
# later passed to `pints.MultiOutputProblem` together with the case data.
# (`diamondODEModel` presumably comes from the wildcard import of
# `diamond_ODE_model` — confirm.)
model = diamondODEModel()

In [None]:
# Load the case table and discard the two columns that are not used below.
df = pd.read_csv('../data/data_cases_symp.csv').drop(columns=['day_no', 'all'])

# Round-trip the onset dates through datetime so every entry is parsed with,
# and rendered back in, the same 'day-month' text format. The column ends up
# holding strings again, so the x-axis below is categorical.
onset_dates = pd.to_datetime(df['onset_date'], format='%d-%b')
df['onset_date'] = onset_dates.dt.strftime('%d-%b')

# Draw both series on a single, explicitly created Axes.
fig, ax = plt.subplots(figsize=(8, 5))
for column, series_label, marker in (('pass', 'Passengers', 'o'),
                                     ('crew', 'Crew', 'x')):
    ax.plot(df['onset_date'], df[column], label=series_label, marker=marker)

# Labels, title and legend; tilt the date labels so they do not overlap.
ax.set_xlabel('Date')
ax.set_ylabel('Number of Cases')
ax.set_title('COVID-19 Cases Over Time')
ax.tick_params(axis='x', labelrotation=45)
ax.legend()
fig.tight_layout()

plt.show()

In [None]:
# Seed NumPy's global generator so that the initial points drawn below (and
# any later code that draws from the same generator) are repeatable across
# Restart & Run All. Change RANDOM_SEED to obtain a different realisation.
RANDOM_SEED = 0
np.random.seed(RANDOM_SEED)

# Observation times are taken from the DataFrame index; the two observed
# outputs are the 'pass' and 'crew' case-count columns.
data_time = np.array(df.index)
data_cases = df[['pass', 'crew']].values
problem = pints.MultiOutputProblem(model, data_time, data_cases)
log_likelihood = pints.GaussianLogLikelihood(problem)

# Uniform prior: lower bounds first, then upper bounds, one entry per sampled
# parameter in this order:
# [bbarcp, cpp, b1, tpp, tcc, thetaa, thetap, chi, sigmap, sigmac]
# The last two entries bound the noise parameters of the Gaussian likelihood
# (presumably one per output: passengers, crew — confirm against the model).
log_prior = pints.UniformLogPrior(
    [0, 0.5, 0.8, 10, 15, 0, 0, 0.5, 0, 0],
    [6, 2, 1, 20, 25, 1, 1, 1, 5, 5]
)
log_posterior = pints.LogPosterior(log_likelihood, log_prior)

# One starting point per chain, sampled via the posterior's prior.
n_chains = 3
xs = pints.sample_initial_points(log_posterior, n_chains, parallel=True)

In [None]:
# Sampler settings gathered in one place so they are easy to tune.
max_iterations = 100000           # stopping criterion
initial_phase_iterations = 5000   # iterations before adaptation starts

mcmc = pints.MCMCController(
    log_posterior,
    n_chains,
    xs,
    method=pints.HaarioBardenetACMC,
)
mcmc.set_max_iterations(max_iterations)
mcmc.set_initial_phase_iterations(initial_phase_iterations)

# Keep progress logging to the screen switched on.
mcmc.set_log_to_screen(True)

# Run the sampler; the result is indexed below as [chain, iteration, parameter].
chains = mcmc.run()

In [None]:
# Visual check of the raw chains (warm-up still included at this point) using
# PINTS' built-in trace plot, then render the figure.
pints.plot.trace(chains)
plt.show()

In [None]:
# Number of initial iterations of every chain that are discarded as warm-up.
warmup = 20000
trimmed_chains = chains[:, warmup:, :]

# Column names for the saved samples, in the sampler's parameter order.
param_names = ['bbarcp', 'cpp', 'b1', 'tpp', 'tcc', 'thetaa', 'thetap', 'chi', 'sigmap', 'sigmac']

# Take the parameter count from the array itself instead of hard-coding it,
# and fail early with a clear message if it disagrees with the column names.
n_parameters = trimmed_chains.shape[-1]
if n_parameters != len(param_names):
    raise ValueError(
        f'chains hold {n_parameters} parameters but {len(param_names)} names were given'
    )
# Refuse to overwrite the CSV with an empty table when the warm-up swallows
# every iteration.
if trimmed_chains.shape[1] == 0:
    raise ValueError(
        f'warmup={warmup} leaves no samples; chains only have {chains.shape[1]} iterations'
    )

# Stack the chains: (chain, iteration, parameter) -> (sample, parameter).
combined_chains = trimmed_chains.reshape(-1, n_parameters)

# Tabulate the pooled samples with the parameter names as column headings.
inferred_df = pd.DataFrame(combined_chains, columns=param_names)

# Persist the samples so they can be reloaded without re-running the sampler.
inferred_df.to_csv('../data/inferred_parameters.csv', index=False)

In [None]:
# Reload the posterior samples from the CSV written by the previous cell; this
# replaces any in-memory `inferred_df` with the on-disk copy.
inferred_df = pd.read_csv('../data/inferred_parameters.csv')

In [None]:
# Plot from the tabulated samples: `inferred_df` exists both right after
# sampling and after reloading the CSV, so this cell no longer requires the
# MCMC run to have happened in the current kernel session.
data = inferred_df.to_numpy()

# Display labels (LaTeX), in the same order as the columns of `data`.
param_names = [r'$\bar{\beta}$', r'$c_{pp}$', r'$b_1$', r'$\tau_{pp}$', r'$\tau_{cc}$',
               r'$\theta_a$', r'$\theta_p$', r'$\chi$', r'$\sigma_p$', r'$\sigma_c$']
n_params = len(param_names)
if data.shape[1] != n_params:
    raise ValueError(
        f'expected {n_params} parameter columns, found {data.shape[1]}'
    )

# Set the font before the figure is built: text objects take their font family
# from rcParams when they are created, so setting it afterwards only took
# effect on a second run of the cell.
plt.rcParams['font.family'] = 'Arial'

# Square grid of panels: histograms on the diagonal, pairwise scatter plots in
# the lower triangle, upper triangle left blank.
fig, axes = plt.subplots(n_params, n_params, figsize=(15, 15))
last_row = n_params - 1

for i in range(n_params):
    for j in range(n_params):
        ax = axes[i, j]
        if i < j:
            # Skip the upper triangle
            ax.axis('off')
        elif i == j:
            # Marginal histogram of parameter i on the diagonal
            sns.histplot(data[:, i], bins=30, kde=False, color='#31688e', ax=ax, linewidth=0)
            ax.set_title(param_names[i])
        else:
            # Parameter j (x) against parameter i (y) in the lower triangle
            ax.scatter(data[:, j], data[:, i], color='#31688e', s=.01)

        # No axis labels on any panel; the diagonal titles identify the
        # parameters.
        ax.set_xlabel('')
        ax.set_ylabel('')

        # Only the bottom row keeps x tick labels and only the first column
        # keeps y tick labels.
        if i != last_row:
            ax.set_xticklabels([])
        if j != 0:
            ax.set_yticklabels([])

plt.tight_layout()
plt.show()