In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import yaml
import os
import numpy as np
from june.records import RecordReader
from june.hdf5_savers import generate_world_from_hdf5
from pandas.core.groupby.groupby import DataError
import matplotlib.dates as mdates
import datetime
from difflib import SequenceMatcher

In [None]:
def load_and_format_sim_infs(path, start_date="2020-10-01", end_date="2020-12-25"):
    """Count the recorded infections of one run per timestamp, split by activity type.

    Parameters
    ----------
    path : str
        Location handed to ``RecordReader``.
    start_date : str
        Accepted for interface compatibility; not used by the body.
    end_date : str
        A run whose world summary ends before this date is discarded.

    Returns
    -------
    pandas.DataFrame or None
        Columns ``"school"`` and ``"other"`` holding the number of rows per
        ``timestamp`` whose ``primary_activity_type`` is / is not ``"school"``.
        ``None`` when reading raises ``DataError`` or the run ends too early.
    """
    print(f"loading {path}")
    try:
        reader = RecordReader(path)
        infections = reader.get_table_with_extras("infections", "infected_ids")
    except DataError:
        print("sim has not started yet. Discarded.")
        return None

    last_summary_entry = reader.get_world_summary().index.max()
    if last_summary_entry < pd.to_datetime(end_date):
        print(f"run {path} not done yet. Discarded")
        return None

    # One count series per group; the column order is fixed by `labels`.
    labels = ["school", "other"]
    school_rows = infections.loc[infections.primary_activity_type == "school", :]
    other_rows = infections.loc[infections.primary_activity_type != "school", :]
    counts = [rows.groupby("timestamp").size() for rows in (school_rows, other_rows)]
    return pd.concat(counts, keys=labels, axis=1)

In [None]:
def load_and_format_sim_hosps(path, start_date="2020-10-01", end_date="2020-12-25"):
    """Count the recorded hospital admissions of one run per timestamp, split by activity type.

    Parameters
    ----------
    path : str
        Location handed to ``RecordReader``.
    start_date : str
        Accepted for interface compatibility; not used by the body.
    end_date : str
        A run whose world summary ends before this date is discarded.

    Returns
    -------
    pandas.DataFrame or None
        Columns ``"school"`` and ``"other"`` holding the number of rows per
        ``timestamp`` whose ``primary_activity_type`` is / is not ``"school"``.
        ``None`` when reading raises ``DataError`` or the run ends too early.
    """
    print(f"loading {path}")
    try:
        reader = RecordReader(path)
        admissions = reader.get_table_with_extras("hospital_admissions", "patient_ids")
    except DataError:
        print("sim has not started yet. Discarded.")
        return None

    last_summary_entry = reader.get_world_summary().index.max()
    if last_summary_entry < pd.to_datetime(end_date):
        print(f"run {path} not done yet. Discarded")
        return None

    # One count series per group; the column order is fixed by `labels`.
    labels = ["school", "other"]
    school_rows = admissions.loc[admissions.primary_activity_type == "school", :]
    other_rows = admissions.loc[admissions.primary_activity_type != "school", :]
    counts = [rows.groupby("timestamp").size() for rows in (school_rows, other_rows)]
    return pd.concat(counts, keys=labels, axis=1)

In [None]:
# Root directory of the simulation results. Scenario folders such as "q68_schools_open/"
# are appended by plain string concatenation, so the trailing slash is required.
mount = "/home/lheger/JUNE_germany_private/parameter_studies/results/" # path to your data

In [None]:
# Load the infection counts of every q68 run with schools open.
# The directory is listed exactly once so that the frames and the keys labelling them come
# from the same snapshot (listing twice can disagree if runs appear while loading).
# Runs the loader discards come back as None; pd.concat drops None entries.
schools_open_dir = mount + "q68_schools_open/"
schools_open_runs = os.listdir(schools_open_dir)
q_68_i_so = pd.concat(
    [load_and_format_sim_infs(schools_open_dir + run, end_date="2021-03-19")
     for run in schools_open_runs],
    keys=schools_open_runs,
)

In [None]:
# Load the infection counts of every q68 run of the complete second wave.
# The directory is listed exactly once so that the frames and the keys labelling them come
# from the same snapshot (listing twice can disagree if runs appear while loading).
# Runs the loader discards come back as None; pd.concat drops None entries.
second_wave_dir = mount + "q68_complete_second_wave/"
second_wave_runs = os.listdir(second_wave_dir)
q_68_i = pd.concat(
    [load_and_format_sim_infs(second_wave_dir + run, end_date="2021-03-19")
     for run in second_wave_runs],
    keys=second_wave_runs,
)

In [None]:
# Load the hospital admission counts of every q68 run with schools open.
# The directory is listed exactly once so that the frames and the keys labelling them come
# from the same snapshot (listing twice can disagree if runs appear while loading).
# Runs the loader discards come back as None; pd.concat drops None entries.
schools_open_dir = mount + "q68_schools_open/"
schools_open_runs = os.listdir(schools_open_dir)
q_68_h_so = pd.concat(
    [load_and_format_sim_hosps(schools_open_dir + run, end_date="2021-03-19")
     for run in schools_open_runs],
    keys=schools_open_runs,
)

In [None]:
# Load the hospital admission counts of every q68 run of the complete second wave.
# (The previous comment said "schools open", but the folder read is q68_complete_second_wave.)
# The directory is listed exactly once so that the frames and the keys labelling them come
# from the same snapshot (listing twice can disagree if runs appear while loading).
# Runs the loader discards come back as None; pd.concat drops None entries.
second_wave_dir = mount + "q68_complete_second_wave/"
second_wave_runs = os.listdir(second_wave_dir)
q_68_h = pd.concat(
    [load_and_format_sim_hosps(second_wave_dir + run, end_date="2021-03-19")
     for run in second_wave_runs],
    keys=second_wave_runs,
)

In [None]:
# Load the infection counts of every q68 run without policies.
# (The previous comment said "schools open", but the folder read is q68_no_policies.)
# The directory is listed exactly once so that the frames and the keys labelling them come
# from the same snapshot (listing twice can disagree if runs appear while loading).
# Runs the loader discards come back as None; pd.concat drops None entries.
no_policies_dir = mount + "q68_no_policies/"
no_policies_runs = os.listdir(no_policies_dir)
q_68_i_np = pd.concat(
    [load_and_format_sim_infs(no_policies_dir + run, end_date="2021-03-19")
     for run in no_policies_runs],
    keys=no_policies_runs,
)

In [None]:
# Load the hospital admission counts of every q68 run without policies.
# (The previous comment said "schools open", but the folder read is q68_no_policies.)
# The directory is listed exactly once so that the frames and the keys labelling them come
# from the same snapshot (listing twice can disagree if runs appear while loading).
# Runs the loader discards come back as None; pd.concat drops None entries.
no_policies_dir = mount + "q68_no_policies/"
no_policies_runs = os.listdir(no_policies_dir)
q_68_h_np = pd.concat(
    [load_and_format_sim_hosps(no_policies_dir + run, end_date="2021-03-19")
     for run in no_policies_runs],
    keys=no_policies_runs,
)

In [None]:
# Collect the interaction betas of every run in q68_schools_open, keyed by the run's
# directory name. Runs whose config lacks the "interaction"/"betas" entries are skipped.
named_configs = {}
p = mount + "q68_schools_open/"
for run in os.listdir(p):
    with open(f"{p}{run}/config.yaml") as f:
        ff = yaml.load(f, Loader=yaml.FullLoader)
    try:
        run_betas = ff["interaction"]["betas"]
    except KeyError:
        print(f"{run} does not contain proper configs, dismiss")
    else:
        named_configs[run] = run_betas

In [None]:
# Pair runs across scenarios: a q68_schools_open run and a q68_complete_second_wave run
# belong together when their interaction betas compare equal.
# `match` maps schools-open run name -> second-wave run name.
match = {}
p = mount + "q68_complete_second_wave/"
for file in os.listdir(p):
    with open(f"{p}{file}/config.yaml") as f:
        ff = yaml.load(f, Loader=yaml.FullLoader)
        try:
            betas = ff["interaction"]["betas"]
        except KeyError:
            print(f"{file} does not contain proper configs, dismiss")
            continue
        # A later second-wave run with the same betas overwrites an earlier pairing.
        for config, config_betas in named_configs.items():
            if betas == config_betas:
                match[config] = file


In [None]:
# Keep only the schools-open runs that have a counterpart in the second-wave scenario.
matched_runs = list(match)
q_68_i_so = q_68_i_so.loc[matched_runs]
q_68_h_so = q_68_h_so.loc[matched_runs]

In [None]:
# Relabel the second-wave runs (index level 0) with the name of their matching
# schools-open run so that both scenarios share run labels.
second_wave_to_named = {wave_run: named_run for named_run, wave_run in match.items()}
q_68_i = q_68_i.rename(index=second_wave_to_named, level=0)
q_68_h = q_68_h.rename(index=second_wave_to_named, level=0)

In [None]:
q_68_i = q_68_i.sort_index()
q_68_i_so = q_68_i_so.sort_index()

q_68_h = q_68_h.sort_index()
q_68_h_so = q_68_h_so.sort_index()

q_68_i_np = q_68_i_np.sort_index()
q_68_h_np = q_68_i_np.sort_index()

In [None]:
q_68_h = q_68_h.fillna(0)
q_68_h_so = q_68_h_so.fillna(0)

q_68_i_np = q_68_i_np.fillna(0)
q_68_h_np = q_68_i_np.fillna(0)


In [None]:
# Column-wise totals ("school" / "other") for every entry of index level 0, i.e. per run.
# NOTE(review): this name is not referenced again in the cells visible here.
q_68_i_norm_fac = q_68_i.groupby(level=0).sum()

In [None]:
# Hospital admissions of the "other" (non-school) group, 7-step rolling mean:
# state policies vs. no state policies, with a band of +/- one standard deviation over runs.
fig, ax = plt.subplots()

# Standard deviation over runs for every timestamp.
err = q_68_h.groupby("timestamp").std()
err_so = q_68_h_so.groupby("timestamp").std()  # not drawn here; kept so the name stays defined
err_np = q_68_h_np.groupby("timestamp").std()

# Mean over runs for every timestamp (index level 1).
y = q_68_h.groupby(level=1).mean()
y_np = q_68_h_np.groupby(level=1).mean()

# Smooth each curve once and reuse it for both the line and its band.
y_smooth = y.loc[:, "other"].rolling(7).mean()
y_np_smooth = y_np.loc[:, "other"].rolling(7).mean()
err_smooth = err.loc[:, "other"].rolling(7).mean()
err_np_smooth = err_np.loc[:, "other"].rolling(7).mean()

y_smooth.plot(ax=ax, color="#1f77b4", label="state policies", lw=2)
y_np_smooth.plot(ax=ax, color="#ff7f0e", label="no state policies", lw=2)

ax.fill_between(y_smooth.index,
                y_smooth - err_smooth,
                y_smooth + err_smooth,
                alpha=0.2)
ax.fill_between(y_np_smooth.index,
                y_np_smooth - err_np_smooth,
                y_np_smooth + err_np_smooth,
                alpha=0.2)

ax.legend(fontsize=10, loc="upper right")

ax.set_ylabel("hospitalisations", fontsize=10)
ax.set_xlabel("")

# The y-axis call used to appear twice; one of them was presumably meant for the x axis.
ax.xaxis.set_tick_params(labelsize=10)
ax.yaxis.set_tick_params(labelsize=10)

plt.tight_layout()

In [None]:
# Hospital admissions of the "other" (non-school) group, 7-step rolling mean:
# state policies vs. schools kept open, with a band of +/- one standard deviation over runs.
fig, ax = plt.subplots()

# Standard deviation over runs for every timestamp.
err = q_68_h.groupby("timestamp").std()
err_so = q_68_h_so.groupby("timestamp").std()
err_np = q_68_h_np.groupby("timestamp").std()  # not drawn here; kept so the name stays defined

# Mean over runs for every timestamp (index level 1).
y = q_68_h.groupby(level=1).mean()
y_so = q_68_h_so.groupby(level=1).mean()

# Smooth each curve once and reuse it for both the line and its band.
y_smooth = y.loc[:, "other"].rolling(7).mean()
y_so_smooth = y_so.loc[:, "other"].rolling(7).mean()
err_smooth = err.loc[:, "other"].rolling(7).mean()
err_so_smooth = err_so.loc[:, "other"].rolling(7).mean()

y_smooth.plot(ax=ax, color="#1f77b4", label="state policies", lw=2)
y_so_smooth.plot(ax=ax, color="#ff7f0e", label="no school closures", lw=2)

ax.fill_between(y_smooth.index,
                y_smooth - err_smooth,
                y_smooth + err_smooth,
                alpha=0.2)
ax.fill_between(y_so_smooth.index,
                y_so_smooth - err_so_smooth,
                y_so_smooth + err_so_smooth,
                alpha=0.2)

ax.legend(fontsize=10, loc="upper left")

ax.set_ylabel("hospitalisations", fontsize=10)
ax.set_xlabel("")

# The y-axis call used to appear twice; one of them was presumably meant for the x axis.
ax.xaxis.set_tick_params(labelsize=10)
ax.yaxis.set_tick_params(labelsize=10)

plt.tight_layout()

# Make plots that require infections per super area

In [None]:
def load_and_format_sim_infs_per_super_area(path, start_date="2020-10-01", end_date="2020-12-25"):
    """Count the recorded infections of one run per timestamp and super area.

    Parameters
    ----------
    path : str
        Location handed to ``RecordReader``.
    start_date : str
        Accepted for interface compatibility; not used by the body.
    end_date : str
        A run whose latest infection timestamp lies before this date is discarded.

    Returns
    -------
    pandas.Series or None
        Number of infection rows for every ``(timestamp, name_super_area)`` pair,
        or ``None`` when reading raises ``DataError`` or the run ends too early.
    """
    print(f"loading {path}")
    try:
        infections = RecordReader(path).get_table_with_extras("infections", "infected_ids")
    except DataError:
        print("sim has not started yet. Discarded.")
        return None

    latest_infection = infections.timestamp.max()
    if latest_infection < pd.to_datetime(end_date):
        print(f"run {path} not done yet. Discarded")
        return None

    grouping = ["timestamp", "name_super_area"]
    return infections.groupby(grouping).size()

def load_and_format_target_infs_per_super_area(path="/home/lheger/june_fitting/data/infektionen.csv"):
    """Load the reported infections for Rheinland-Pfalz as a date x (kreis, variable) table.

    Two defects of the previous version are fixed:

    * the CSV location was built as ``mount + "/home/..."``, i.e. an absolute path
      appended to the results directory; the absolute path is now used directly
      (and can be overridden through ``path``);
    * the date labels were converted with ``reindex(pd.to_datetime(index))``, which
      looks the new Timestamp labels up in the old string index, finds none of them
      and returns only NaN. The index is now replaced by the parsed dates.

    Parameters
    ----------
    path : str, optional
        CSV file to read. It must provide the columns ``_id``, ``ags5``, ``ags2``,
        ``bundesland``, ``kreis`` and ``variable``; every other column whose name
        starts with ``"d"`` has that first character stripped and the remainder
        parsed as a date.

    Returns
    -------
    pandas.DataFrame
        Rows are dates (``DatetimeIndex``), columns are ``(kreis, variable)`` pairs,
        values are the sums of the matching CSV rows. Only variables containing
        ``"kr_inf_a"`` but not ``"kr_inf_aktiv"`` are kept (presumably the per
        age group infection counts; confirm against the data description).
    """
    raw = pd.read_csv(path)
    rlp = (raw
           .loc[raw.bundesland == "Rheinland-Pfalz"]
           .drop(["_id", "ags5", "ags2", "bundesland"], axis=1))

    variable = rlp.loc[:, "variable"]
    wanted = variable.str.contains("kr_inf_a") & ~variable.str.contains("kr_inf_aktiv")

    # Strip the leading "d" of the date columns so that the labels can be parsed as dates.
    date_labels = {col: col[1:] for col in rlp.columns if col[0] == "d"}
    targets = rlp.rename(columns=date_labels).loc[wanted, :]

    targets = (targets
               .groupby(["kreis", "variable"])
               .sum()
               .T)

    # Replace the string labels by timestamps; the row order is unchanged.
    targets.index = pd.to_datetime(targets.index)
    return targets

In [None]:
def find_most_similar_super_area_name(old_super_area_name,
                                      new_super_area_names
                                     ):
    """Return the entry of ``new_super_area_names`` that best matches ``old_super_area_name``.

    An exact member is returned unchanged. Otherwise every candidate is scored with
    ``difflib.SequenceMatcher(...).ratio()`` and the first candidate with the highest
    score wins.

    Raises
    ------
    ValueError
        If no candidate scores above zero (this includes an empty candidate list).
    """
    if old_super_area_name in new_super_area_names:
        best_match = old_super_area_name
    else:
        scored = [
            (SequenceMatcher(None, old_super_area_name, candidate).ratio(), candidate)
            for candidate in new_super_area_names
        ]
        # max() keeps the first of several equally good candidates.
        top_score, top_candidate = max(scored, key=lambda pair: pair[0], default=(0, None))
        best_match = top_candidate if top_score > 0 else None

    if best_match is None:
        raise ValueError("the new super area name cannot be None")
    return best_match

In [None]:
# Load the simulated world from its HDF5 file (path relative to the notebook's working directory).
# Its super areas are used further down to look up the population per super area.
world = generate_world_from_hdf5("../data/world_rlp.hdf5")

In [None]:
# Super area code -> super area name. The simulation output carries codes while the real
# data carries names, so the codes have to be translated before the two can be compared.
df = (
    pd.read_csv("/home/lheger/JUNE_germany_private/data/geography/super_area_coordinates.csv")
    .drop(["latitude", "longitude"], axis=1)
)
super_area_name_lookup = {
    code: name for code, name in zip(df.super_area, df.super_area_name)
}
# Data fix: the Kreis Kaiserslautern and the kreisfreie Stadt Kaiserslautern carry the
# same name in the file; the kreisfreie Stadt (D07312) is renamed to tell them apart.
super_area_name_lookup["D07312"] = "Stadt Kaiserslautern"

In [None]:
# Number of people per super area of the loaded world, keyed by the super area's name
# (translated from its code via super_area_name_lookup).
pop_per_sa_lookup = {super_area_name_lookup[sa.name]:len(sa.people) for sa in world.super_areas}

In [None]:
# Load the infection counts per timestamp and super area of every second-wave run.
# The directory is listed exactly once so that the series and the keys labelling them come
# from the same snapshot (listing twice can disagree if runs appear while loading).
# Runs the loader discards come back as None; pd.concat drops None entries.
second_wave_dir = mount + "q68_complete_second_wave/"
second_wave_runs = os.listdir(second_wave_dir)
q68_i_sa = pd.concat(
    [load_and_format_sim_infs_per_super_area(second_wave_dir + run, end_date="2021-03-19")
     for run in second_wave_runs],
    keys=second_wave_runs,
)

In [None]:
# Translate the super area codes (index level 2) to Kreis names and name the index levels.
# rename_axis only sets the level names; the former reindex() call realigned the whole
# series against a copy of its own index just to achieve the same thing.
# Note: this cell is not re-runnable, because after the first run level 2 holds names
# that are no longer keys of super_area_name_lookup.
code_to_name = {ags: super_area_name_lookup[ags] for ags in q68_i_sa.index.levels[2]}
q68_i_sa = (q68_i_sa
            .rename(index=code_to_name, level=2)
            .rename_axis(["run", "timestamp", "name_super_area"]))

In [None]:
# Load the reported infections and add up all variables (age groups) per Kreis and date.
targets = load_and_format_target_infs_per_super_area()
targets.index.name = "timestamp"
targets = (targets
           .unstack()
           .groupby(["kreis", "timestamp"])
           .sum())

# Rename the Kreise of the targets to the closest JUNE super area name.
# NOTE(review): fuzzy matching could map two Kreise to the same name; verify new_idx.
june_names = list(pop_per_sa_lookup.keys())
new_idx = {old_idx: find_most_similar_super_area_name(old_idx, june_names)
           for old_idx in targets.index.levels[0]}
targets = targets.rename(index=new_idx, level=0)
targets = targets.sort_index()

# Cumulative infections per Kreis, read off at the evaluation date.
eval_date = pd.to_datetime("2021-03-19")
target_infs_per_sa = targets.groupby("kreis").cumsum().loc[:, eval_date]

In [None]:
# Error assumed on the reported totals: 5 % of each value.
# NOTE(review): "schott" presumably names the source of the 5 % figure — confirm.
# make_scatterplot() recomputes this itself and does not read this variable.
yerr = target_infs_per_sa*0.05 # schott

In [None]:
# Spread of the simulated totals per super area: the standard deviation over runs at every
# timestamp, combined over all timestamps as the square root of the sum of squares.
per_step_std = q68_i_sa.groupby(["timestamp", "name_super_area"]).std()
xerr = (per_step_std
        .groupby("name_super_area")
        .apply(lambda x: np.sqrt(np.sum(x**2))))

In [None]:
def make_scatterplot(sims, target):
    """Draw simulated against reported totals per super area, with error bars.

    Parameters
    ----------
    sims : pandas.Series
        Simulation counts whose index has the levels ``timestamp`` and
        ``name_super_area`` (plus one entry per run for each such pair).
    target : pandas.Series
        Reported total per super area. It is paired with the simulation values by
        position, so it has to be ordered like the grouped super area names.

    Returns
    -------
    matplotlib.figure.Figure
        The current pyplot figure, which holds the plot.
    """
    by_step_and_area = sims.groupby(["timestamp", "name_super_area"])

    # x: mean per timestamp and super area, summed over all timestamps.
    sim_totals = by_step_and_area.mean().groupby("name_super_area").sum().to_numpy()
    # x error: std per timestamp and super area, combined as root of the sum of squares.
    sim_spread = (by_step_and_area
                  .std()
                  .groupby("name_super_area")
                  .apply(lambda x: np.sqrt(np.sum(x**2))))

    # y: reported totals with a flat 5 % error.
    reported_totals = target.to_numpy()
    reported_spread = target * 0.05

    plt.errorbar(sim_totals, reported_totals,
                 yerr=reported_spread, xerr=sim_spread,
                 fmt='o', capsize=3, capthick=2, elinewidth=1, markersize=5)
    plt.ylabel("data")
    plt.xlabel("simulation")
    plt.tight_layout()
    return plt.gcf()

In [None]:
def calculate_corr(sims, targets):
    """Correlate each run's totals per super area with the target totals.

    The previous version looped over ``range(22)``; it raised ``IndexError`` with
    fewer than 22 runs and ignored every run beyond the 22nd. All runs present in
    ``sims`` are used now.

    Parameters
    ----------
    sims : pandas.Series
        Simulation counts whose index contains the levels ``run`` and
        ``name_super_area``.
    targets : pandas.Series
        One target value per super area. Values are paired with the simulation
        columns by position, so ``targets`` must be ordered like the unstacked
        ``name_super_area`` columns.

    Returns
    -------
    tuple of float
        Mean and standard deviation over runs of the Pearson correlation
        coefficient between a run's totals and the targets.
    """
    totals_per_run = sims.groupby(["run", "name_super_area"]).sum().unstack().to_numpy()
    reference = targets.to_numpy()
    corrs = np.array([np.corrcoef(run_totals, reference)[0, 1]
                      for run_totals in totals_per_run])
    return corrs.mean(), corrs.std()

In [None]:
# Scatter plot of absolute totals per super area: simulation (x) against reported data (y).
f = make_scatterplot(q68_i_sa, target_infs_per_sa)

In [None]:
# Express targets and simulation results per 10,000 inhabitants of each super area.
norm_targets = {}
norm_sim = {}
for k, population in pop_per_sa_lookup.items():
    norm_fac = population / 10000
    norm_targets[k] = target_infs_per_sa.loc[k] / norm_fac
    # Selects super area k on the third index level; the result keeps the first two levels.
    norm_sim[k] = q68_i_sa.loc[:, :, k] / norm_fac

In [None]:
# Reassemble the normalised simulation results into one long series. The super area names
# are the columns of the intermediate frame and become the last index level when stacked.
q68_i_sa_norm = pd.DataFrame(norm_sim).stack()
idx = q68_i_sa_norm.index
# Only the name of the last level is missing; set it directly rather than reindexing the
# series against a renamed copy of its own index.
q68_i_sa_norm.index = idx.set_names("name_super_area", level=-1)
targets_i_sa_norm = pd.Series(norm_targets)

In [None]:
# Same scatter plot for the values normalised per 10,000 inhabitants.
f2 = make_scatterplot(q68_i_sa_norm,targets_i_sa_norm)