In [None]:
from pathlib import Path
import pandas as pd
import numpy as np
from plotly.subplots import make_subplots
import plotly.graph_objects as go
from plotly.express.colors import colorbrewer

from aust_covid.inputs import load_google_mob_year_df, load_raw_pop_data
# Project root is the parent of the directory the notebook is run from;
# input data files live in its 'data' sub-directory
PROJECT_PATH = Path.cwd().resolve().parent
DATA_PATH = PROJECT_PATH.joinpath('data')

In [None]:
# Contact locations; each name is also the stem of a '<name>.csv' matrix file read from DATA_PATH
locations = ['school', 'home', 'work', 'other_locations']

In [None]:
# Read one matrix per location from '<location>.csv' (first column used as the row index)
# and keep the values as plain numpy arrays keyed by location name
raw_location_matrices = {}
for location in locations:
    location_frame = pd.read_csv(DATA_PATH / f'{location}.csv', index_col=0)
    raw_location_matrices[location] = location_frame.to_numpy()

In [None]:
# Read the 2021 and 2022 mobility reports and stack them into a single frame
# NOTE(review): index_col=8 is presumably the report date column (the index is parsed as dates below) - confirm
raw_data_2021, raw_data_2022 = (
    pd.read_csv(DATA_PATH / report_name, index_col=8)
    for report_name in ('2021_AU_Region_Mobility_Report.csv', '2022_AU_Region_Mobility_Report.csv')
)
raw_data = pd.concat((raw_data_2021, raw_data_2022))

# Keep the state-level rows only: sub_region_1 is populated but sub_region_2 is not
has_state = raw_data['sub_region_1'].notnull()
no_sub_state_region = raw_data['sub_region_2'].isnull()
state_data = raw_data.loc[has_state & no_sub_state_region]

# Index by date and pick out the names of the mobility (percent change) columns
state_data.index = pd.to_datetime(state_data.index)
change_str = '_percent_change_from_baseline'
mob_locs = [col_name for col_name in state_data.columns if change_str in col_name]

# Split non-WA from WA
# The non-WA selection is copied explicitly because a 'weights' column is written into it below;
# writing into a bare selection of state_data triggers pandas' SettingWithCopyWarning
non_wa_data = state_data.loc[state_data['sub_region_1'] != 'Western Australia'].copy()
wa_data = state_data.loc[state_data['sub_region_1'] == 'Western Australia', mob_locs]

# Add state population totals by state for non-WA states
# NOTE(review): the summed population data is looked up by sub_region_1 value below,
# so it is presumably keyed by jurisdiction name - confirm against load_raw_pop_data
state_pop_totals = load_raw_pop_data('31010do002_202206.xlsx').sum()

# Sorted (rather than a bare set) so that the jurisdiction order is the same on every kernel restart;
# string set ordering depends on Python's per-process hash seed
jurisdictions = sorted({j for j in state_data['sub_region_1'] if j != 'Australia'})
for juris in jurisdictions:
    # Rows for Western Australia are absent from non_wa_data, so that iteration selects nothing
    non_wa_data.loc[non_wa_data['sub_region_1'] == juris, 'weights'] = state_pop_totals[juris]

# Weighted average of each mobility column across the non-WA rows sharing a date,
# using the 'weights' column of non_wa_data as the weights
rows_by_date = non_wa_data.groupby(non_wa_data.index)
weighted_columns = {}
for mob_loc in mob_locs:
    weighted_columns[mob_loc] = rows_by_date.apply(
        lambda same_day: np.average(same_day[mob_loc], weights=same_day['weights']),
    )
state_averages = pd.DataFrame(weighted_columns, columns=mob_locs)
    
# Convert percentage change from baseline into a multiplier of baseline
# (e.g. a value of -20 becomes 0.8) and rename the columns to match
def _as_relative_name(col_name):
    """Swap the percent-change suffix of a mobility column name for '_relative_change'."""
    return col_name.replace(change_str, '_relative_change')


non_wa_relmob = state_averages.mul(1e-2).add(1.0).rename(columns=_as_relative_name)
wa_relmob = wa_data.mul(1e-2).add(1.0).rename(columns=_as_relative_name)

In [None]:
# One panel per jurisdiction showing every raw mobility series for that jurisdiction
colours = colorbrewer.Accent
n_cols = 2
juris_names = list(jurisdictions)
n_rows = max(1, -(-len(juris_names) // n_cols))  # ceiling division, at least one row
fig = make_subplots(rows=n_rows, cols=n_cols, subplot_titles=juris_names)
fig.update_layout(height=1500)
for j, juris in enumerate(juris_names):
    # subplot_titles are applied in row-major order (left to right, then down),
    # so the panels must be filled in the same order for each title to sit above its own data
    row_idx, col_idx = divmod(j, n_cols)
    juris_data = state_data[state_data['sub_region_1'] == juris]
    for l, mob_loc in enumerate(mob_locs):
        estimates = juris_data[mob_loc]
        fig.add_trace(
            go.Scatter(
                x=estimates.index,
                y=estimates,
                name=mob_loc,
                # Wrap around the palette rather than failing if there are more series than colours
                line=dict(color=colours[l % len(colours)]),
                # Only the first panel contributes legend entries, to avoid repeating them per panel
                showlegend=j == 0,
            ),
            row=row_idx + 1, col=col_idx + 1,
        )
fig

In [None]:
# Side-by-side panels of the relative mobility series: WA on the left, weighted rest of Australia on the right
panel_titles = ['Western Australia', 'Weighted average for rest of Australia']
fig = make_subplots(rows=1, cols=2, subplot_titles=panel_titles)
fig.update_layout(height=500)
for panel_col, mob_data in enumerate([wa_relmob, non_wa_relmob], start=1):
    # Legend entries come from the left-hand panel only, so each series is listed once
    in_first_panel = panel_col == 1
    for colour_idx, mob_loc in enumerate(mob_data.columns):
        estimates = mob_data[mob_loc]
        series_trace = go.Scatter(
            x=estimates.index,
            y=estimates,
            name=mob_loc,
            line=dict(color=colours[colour_idx]),
            showlegend=in_first_panel,
        )
        fig.add_trace(series_trace, row=1, col=panel_col)
fig

In [None]:
# Weight applied to each relative mobility column when building the mobility
# series for each contact location listed here
mob_map = {
    'other_locations': 
        {
            'retail_and_recreation_relative_change': 0.34, 
            'grocery_and_pharmacy_relative_change': 0.33,
            'parks_relative_change': 0.0,
            'transit_stations_relative_change': 0.33,
            'workplaces_relative_change': 0.0,
            'residential_relative_change': 0.0,
        },
    'work':
        {
            'retail_and_recreation_relative_change': 0.0, 
            'grocery_and_pharmacy_relative_change': 0.0,
            'parks_relative_change': 0.0,
            'transit_stations_relative_change': 0.0,
            'workplaces_relative_change': 1.0,
            'residential_relative_change': 0.0,
        },  
}

# Relative mobility frame for each patch
patch_data = {
    'wa': wa_relmob,
    'non_wa': non_wa_relmob,
}

# For every (patch, location) pair, take the weighted sum across the mobility columns.
# Each patch's own frame is multiplied by the weights (matched on column name), so the result
# keeps that patch's date index and does not depend on any other patch's frame.
# As with DataFrame.sum's default, missing values are skipped in the row sums.
model_mob_columns = {}
for patch in patch_data.keys():
    for mob_loc in mob_map.keys():
        weights = pd.Series(mob_map[mob_loc])
        data = patch_data[patch].mul(weights, axis='columns').sum(axis=1)
        model_mob_columns[(patch, mob_loc)] = data

# Build the frame in one go instead of assigning columns into a frame with an empty index;
# the tuple keys become the (patch, location) column MultiIndex, in the same order as the loops above
model_mob = pd.concat(
    list(model_mob_columns.values()),
    axis=1,
    keys=list(model_mob_columns.keys()),
)

In [None]:
# Side-by-side panels of the modelled mobility series per contact location: one panel per patch
panel_titles = ['Western Australia', 'Weighted average for rest of Australia']
fig = make_subplots(rows=1, cols=2, subplot_titles=panel_titles)
fig.update_layout(height=500)
for panel_col, patch in enumerate(patch_data.keys(), start=1):
    # Legend entries come from the left-hand panel only, so each location is listed once
    in_first_panel = panel_col == 1
    for colour_idx, mob_loc in enumerate(mob_map.keys()):
        estimates = model_mob[(patch, mob_loc)]
        location_trace = go.Scatter(
            x=estimates.index,
            y=estimates,
            name=mob_loc,
            line=dict(color=colours[colour_idx]),
            showlegend=in_first_panel,
        )
        fig.add_trace(location_trace, row=1, col=panel_col)
fig