In [None]:
from pathlib import Path
import pandas as pd
import numpy as np
from plotly.subplots import make_subplots
import plotly.graph_objects as go

from aust_covid.mobility import get_non_wa_mob_averages, get_constants_from_mobility, get_relative_mobility, map_mobility_locations
from aust_covid.plotting import plot_state_mobility, plot_processed_mobility
from aust_covid.inputs import load_raw_pop_data, get_raw_state_mobility
# The project root is taken to be the parent of the kernel's working directory, so this
# assumes the notebook is launched from a direct subdirectory of the project.
PROJECT_PATH = Path('.').resolve().parent
DATA_PATH = PROJECT_PATH.joinpath('data')
# NOTE(review): presumably the suffix of the raw mobility column names — confirm against
# get_raw_state_mobility.
CHANGE_STR = '_percent_change_from_baseline'

In [None]:
# Model locations; one matrix per location is read from DATA_PATH/<location>.csv,
# using the CSV's first column as the row index and keeping only the values.
model_mob_locs = ['school', 'home', 'work', 'other_locations']
raw_location_matrices = {
    location: pd.read_csv(DATA_PATH / f'{location}.csv', index_col=0).to_numpy()
    for location in model_mob_locs
}

In [None]:
# Load raw state-level mobility and keep Western Australia separate from the other jurisdictions.
state_data = get_raw_state_mobility()
jurisdictions, mob_locs = get_constants_from_mobility(state_data)
wa_data = state_data.loc[state_data['sub_region_1'] == 'Western Australia', mob_locs]
# NOTE(review): presumably the average over the non-WA jurisdictions, per the helper's name — confirm.
state_averages = get_non_wa_mob_averages(state_data, mob_locs, jurisdictions)
non_wa_relmob = get_relative_mobility(state_averages)
wa_relmob = get_relative_mobility(wa_data)

# Weight given to each raw mobility category when building each model location's mobility.
mob_map = {
    'other_locations': 
        {
            'retail_and_recreation': 0.34, 
            'grocery_and_pharmacy': 0.33,
            'parks': 0.0,
            'transit_stations': 0.33,
            'workplaces': 0.0,
            'residential': 0.0,
        },
    'work':
        {
            'retail_and_recreation': 0.0, 
            'grocery_and_pharmacy': 0.0,
            'parks': 0.0,
            'transit_stations': 0.0,
            'workplaces': 1.0,
            'residential': 0.0,
        },  
}
# As written, each location's weights sum to one; guard against edits that silently rescale a
# location. Remove this check if non-normalised weights are ever intended.
assert all(np.isclose(sum(weights.values()), 1.0) for weights in mob_map.values()), 'mob_map weights for each location should sum to one'

# Rolling-mean window length, in rows (presumably days — confirm the index frequency).
MOB_SMOOTH_WINDOW = 7
model_mob = map_mobility_locations(wa_relmob, non_wa_relmob, mob_map)
# rolling() defaults min_periods to the window size, so the leading incomplete windows are NaN
# and are removed by dropna().
smoothed_model_mob = model_mob.rolling(MOB_SMOOTH_WINDOW).mean().dropna()

In [None]:
# Optional diagnostic: uncomment to plot the raw state mobility data (state_data) loaded above.
# plot_state_mobility(state_data, jurisdictions, mob_locs)

In [None]:
# Optional diagnostic: uncomment to plot model_mob together with smoothed_model_mob.
# plot_processed_mobility(model_mob, smoothed_model_mob)

In [None]:
from summer2.functions.time import get_piecewise_scalar_function, get_linear_interpolation_function

In [None]:
# FIXME(review): neither `model` nor `mobility_effect` is defined anywhere in this notebook, so
# this cell only works if they were left in the kernel by code outside it (hidden state). It
# fails under Restart & Run All; build both in an earlier cell. The check below replaces the
# bare NameError with a message naming what is missing.
_missing_inputs = [name for name in ('model', 'mobility_effect') if name not in globals()]
if _missing_inputs:
    raise NameError(f'Define {", ".join(_missing_inputs)} in an earlier cell before building mobility_effect_func')
# NOTE(review): presumably mobility_effect is a pandas Series with a DatetimeIndex, and
# dti_to_index maps those dates onto the model's numeric time axis — confirm.
mobility_effect_times = model.get_epoch().dti_to_index(mobility_effect.index)
mobility_effect_func = get_linear_interpolation_function(mobility_effect_times, mobility_effect.to_numpy())
