# Charting module: `plot_module` usage examples

In [None]:
from plot_module import interval_estimation_plot
from model_complex import InfluenzaData

# Interval-estimation plot (MCMC) for the 2019/40 – 2020/20 season.
city = 'Russia'
# Same calling convention as the other cells: the first argument is the city name.
epid_data = InfluenzaData(city, 2019, 40, 2020, 20)
method = 'mcmc'
# Named `model_type`, not `type`: a global `type` would shadow the builtin
# for every later cell in this kernel session.
model_type = 'total'
save_path = './'
epsilon = 1000

interval_estimation_plot(
    epid_data,
    city,
    method,
    model_type,
    save_path,
    epsilon,
)

Multiprocess sampling (4 chains in 4 jobs)
DEMetropolisZ: [alpha, beta]
  "accept": np.exp(accept),
  "accept": np.exp(accept),
  "accept": np.exp(accept),
  "accept": np.exp(accept),
Sampling 4 chains for 2_500 tune and 500 draw iterations (10_000 + 2_000 draws total) took 9 seconds.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters.  A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
Sampling: [sim]


<Figure size 500x500 with 0 Axes>

<Figure size 500x500 with 0 Axes>

In [None]:
from plot_module import calibration_plot
from model_complex import InfluenzaData

# Calibration plot (simulated annealing), including prevalence and recovered panels.
city = 'Russia'
epid_data = InfluenzaData(city, 2019, 40, 2020, 20)
method = 'annealing'
# Named `model_type`, not `type`: a global `type` would shadow the builtin.
model_type = 'total'
save_path = './'
epsilon = 1000

calibration_plot(
    epid_data,
    city,
    method,
    model_type,
    save_path=save_path,
    epsilon=epsilon,
    is_prevalence_plot=True,
    is_recovered_plot=True,
)



<Figure size 640x480 with 0 Axes>

In [None]:
from datetime import timedelta

from plot_module import forecast_plot
from model_complex import InfluenzaData

# Eight-week forecast (ABC calibration) from the 2024/40 – 2025/6 observations.
city = 'Russia'
epid_data = InfluenzaData(city, 2024, 40, 2025, 6)
forecast_duration = timedelta(weeks=8)
method = 'abc'
# Named `model_type`, not `type`: a global `type` would shadow the builtin.
model_type = 'total'
save_path = './'
epsilon = 300

forecast_plot(
    forecast_duration,
    epid_data,
    city,
    method,
    model_type,
    save_path,
    epsilon,
)

Initializing SMC sampler...
Sampling 6 chains in 6 jobs


      

The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters.  A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details


<Figure size 640x480 with 0 Axes>

# Charting model: calibration and plotting helpers

In [None]:
from model_complex import ModelParams, FactoryModel, Calibration

def calibration(
    epid_data,
    method,
    type,
    epsilon=1000,
):
    """Build a model of the given ``type`` and calibrate it on ``epid_data``.

    Parameters
    ----------
    epid_data :
        Data object providing ``get_wave_data(type=...)``, ``get_data()`` and
        ``get_rho()`` (e.g. ``InfluenzaData``).
    method : str
        Calibration method, case-insensitive: ``"annealing"``, ``"abc"``,
        ``"mcmc"`` or ``"optuna"``.
    type : str
        Model type passed to ``FactoryModel.get_model``. ``"age"`` starts
        with two groups of 100 initial infectious; any other type uses one.
    epsilon : default 1000
        Forwarded as ``epsilon`` to the ABC and MCMC calibrations. Annealing
        and optuna ignore it.

    Returns
    -------
    The model object created by ``FactoryModel.get_model(type)`` after it has
    been handed to ``Calibration`` and the chosen method has run
    (presumably it now holds the fitted parameters; confirm in model_complex).

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported names.
    """
    # When adding a new calibration method, extend this tuple and the dispatch below.
    method_key = method.lower()
    if method_key not in ("annealing", "abc", "mcmc", "optuna"):
        raise ValueError(
            f"Unknown calibration method {method!r}; "
            "expected 'annealing', 'abc', 'mcmc' or 'optuna'"
        )

    # Return value is not used here. Presumably this selects the wave that
    # get_data() returns next; confirm in model_complex.
    epid_data.get_wave_data(type=type)
    data = epid_data.get_data()

    model_params = ModelParams(
        alpha=[0],  # starting values; presumably replaced by the calibration
        beta=[0],
        # NOTE(review): a tenth of get_rho(); the reason is not visible here.
        population_size=epid_data.get_rho() // 10,
        initial_infectious=[100],
    )

    model = FactoryModel.get_model(type)

    # When adding new models, update this part (or move it into the function parameters).
    if type == "age":
        model_params.initial_infectious = [100, 100]

    calibrator = Calibration(model, data, model_params)

    if method_key == "annealing":
        calibrator.annealing_calibration()
    elif method_key == "abc":
        calibrator.abc_calibration(epsilon=epsilon)
    elif method_key == "mcmc":
        calibrator.mcmc_calibration(epsilon=epsilon)
    else:  # "optuna"
        calibrator.optuna_calibration()

    return model

In [None]:
import matplotlib.pyplot as plt
from sklearn.metrics import r2_score

from ..epid_results import prevalence_plot, recovered_plot


def calibration_plot(
    model,
    epid_data,
    method,
    type,
    plot_modes=None,
    save_path="./",
):
    """Plot calibrated model output against observed data and save it to disk.

    For every parameter set from ``model.get_ci_params()`` the model is
    re-simulated, and each group is drawn as a thin translucent line. The run
    with ``model.get_best_params()`` is drawn solid, with a legend entry
    showing its R^2 against the observations. The observations are overlaid
    as dashed lines with markers.

    NOTE: this definition shadows ``plot_module.calibration_plot`` imported
    earlier in the notebook, which takes a different argument list.

    Parameters
    ----------
    model :
        Model exposing ``simulate``, ``get_ci_params``, ``get_best_params``
        and ``get_weekly_newly_infected_by_group`` /
        ``get_daily_newly_infected_by_group``.
    epid_data :
        Data object exposing ``get_duration``, ``get_city``,
        ``prepare_for_plot`` and ``attrs["time_step"]``.
    method : str
        Calibration method name, used in the title and the file name.
    type : str
        Model type, used in the title and the file name.
    plot_modes : optional
        Currently unused; kept for signature compatibility.
    save_path : str, default "./"
        Directory that receives ``{city}_{method}_{type}.png`` and ``.pdf``.
    """
    import os

    duration = epid_data.get_duration()
    city = epid_data.get_city()
    # Indexed as plot_data[:, i] below, i.e. one observed column per group.
    plot_data = epid_data.prepare_for_plot()

    base_colors = {0: "blue", 1: "orange"}

    def group_color(i):
        # Groups beyond the first two fall back to matplotlib's colour cycle
        # ("C2", "C3", ...) instead of raising KeyError.
        return base_colors.get(i, f"C{i}")

    # Nothing below should need changes when new models are added.
    time_step = epid_data.attrs["time_step"]
    if time_step == "week":
        get_newly_infected = model.get_weekly_newly_infected_by_group
    else:
        get_newly_infected = model.get_daily_newly_infected_by_group

    fig, ax = plt.subplots()
    try:
        for ci_params in model.get_ci_params():
            model.simulate(params=ci_params, modeling_duration=duration)
            res = get_newly_infected()
            for i in range(len(res)):
                ax.plot(res[i], lw=0.3, alpha=0.5, color=group_color(i))

        model.simulate(params=model.get_best_params(), modeling_duration=duration)
        res = get_newly_infected()
        for i in range(len(res)):
            ax.plot(
                res[i],
                label=f"$R^2_{i}$: {round(r2_score(plot_data[:, i], res[i]), 2)}",
                color=group_color(i),
            )
            ax.plot(plot_data[:, i], "--o", color=group_color(i))

        ax.set_title(f"{method.upper()}, {type.capitalize()}")
        ax.set_xlabel(f"Time ({time_step})")
        ax.set_ylabel("Newly infected")
        ax.legend()

        # os.path.join also works when save_path has no trailing separator.
        file_stem = os.path.join(save_path, f"{city}_{method}_{type}")
        fig.savefig(file_stem + ".png", dpi=600)
        fig.savefig(file_stem + ".pdf", dpi=600)
    finally:
        # Close (rather than only clear) the figure so it is released and not
        # left behind, empty, in pyplot's figure registry.
        plt.close(fig)


In [39]:
import requests
from io import StringIO
import pandas as pd

import getpass
import os

# Weekly ARI/PCR report (CSV, '|'-separated) from db.influenza.spb.ru.
# Placeholders: begin year/week, end year/week, district code, auth token.
url = (
        "https://db.influenza.spb.ru/scripts/report/rmancgi.exe"
        + "?reportname=get_csv&id=aripcr&byear={}&bweek={}&eyear={}&eweek={}&district={}&auth={}"
)

begin_year = 2024
begin_week = 40
end_year = 2025
end_week = 8
district = 0  # 0 -> whole country (rows come back with DISTRICT_NAME "РФ")

# Never hard-code the API token in the notebook: read it from the environment
# or prompt for it. A token that was ever committed should be rotated.
auth_token = os.environ.get("INFLUENZA_DB_TOKEN") or getpass.getpass(
    "db.influenza.spb.ru auth token: "
)

response = requests.get(
    url.format(begin_year, begin_week, end_year, end_week, district, auth_token),
    timeout=60,  # fail instead of hanging forever on a stalled server
)
response.raise_for_status()  # don't parse an HTTP error page as CSV
data = response.content.decode("utf-8")

cases_df = pd.read_csv(StringIO(data), sep="|")

In [41]:
# Preview the first weeks of the report; the frame has dozens of columns, so avoid dumping it all.
cases_df.head(10)

Unnamed: 0,YEAR,WEEK,REGION,DISTRICT,REGION_NAME,DISTRICT_NAME,ARI_TOTAL,ARI_0_2,ARI_3_6,ARI_7_14,...,POP_3_4,POP_15,POP_TOTAL,SWB_TOTAL,A_TOTAL,PDM_TOTAL,H3_TOTAL,B_TOTAL,SWBNC_TOTAL,NC_TOTAL
0,2024,40,0,0,РФ,РФ,408903,40194,64671,83100,...,1115465,45837361,54912793,9659,0,0,0,3,15097,657
1,2024,41,0,0,РФ,РФ,392284,41539,65410,79467,...,1112932,45840747,54905514,10012,0,0,1,3,12443,779
2,2024,42,0,0,РФ,РФ,380656,39802,64904,78292,...,1112932,45840747,54905514,9933,0,1,0,0,12372,752
3,2024,43,0,0,РФ,РФ,369626,38604,63402,73403,...,2053415,45980947,55086645,9796,3,0,0,3,11324,652
4,2024,44,0,0,РФ,РФ,332324,37733,58660,51460,...,1116902,45980947,55086645,10470,2,0,0,3,11812,570
5,2024,45,0,0,РФ,РФ,308735,34506,53851,52413,...,1116902,45980947,55086645,9454,3,0,3,8,17115,674
6,2024,46,0,0,РФ,РФ,337366,36357,56194,66277,...,1116902,45980947,55086645,10423,2,4,0,13,19391,809
7,2024,47,0,0,РФ,РФ,354857,37046,58603,76086,...,1116902,45980947,55086645,10715,8,2,1,6,13549,984
8,2024,48,0,0,РФ,РФ,373443,38424,62330,85662,...,1116902,45980947,55086645,11988,10,23,0,12,13409,1165
9,2024,49,0,0,РФ,РФ,386491,38973,62388,89586,...,1116902,45980947,55086645,12686,21,43,0,10,13186,1180


In [33]:
# Last ten '|'-separated fields of the raw response text (shows how the payload ends).
data.rsplit("|", 10)[-10:]

['1172324', '44679241', '52628283', '776', '2', '3', '5', '46', '', '\r\n\r\n']