In [None]:
from pathlib import Path
import pandas as pd
import numpy as np
from plotly.subplots import make_subplots
import plotly.graph_objects as go
from plotly.express.colors import colorbrewer

from aust_covid.inputs import load_google_mob_year_df, load_raw_pop_data
# Project root: the parent of the (resolved) current working directory,
# so the notebook must be run from a direct sub-folder of the project.
PROJECT_PATH = Path().resolve().parent
# Folder from which every CSV in this notebook is read.
DATA_PATH = PROJECT_PATH / 'data'
# Substring identifying the mobility columns of the mobility reports;
# it is stripped from the column names before plotting.
CHANGE_STR = '_percent_change_from_baseline'
# Palette indexed by trace position in all figures below.
COLOURS = colorbrewer.Accent

In [None]:
# Model contact locations; a '<name>.csv' file is read from DATA_PATH
# for each entry when the raw location matrices are loaded.
model_mob_locs = [
    'school', 
    'home', 
    'work', 
    'other_locations',
]

In [None]:
# Load one matrix per model location from '<location>.csv' in the data
# folder; the first CSV column is the index and is dropped by to_numpy().
raw_location_matrices = {}
for location in model_mob_locs:
    matrix_frame = pd.read_csv(DATA_PATH / f'{location}.csv', index_col=0)
    raw_location_matrices[location] = matrix_frame.to_numpy()

In [None]:
def get_raw_state_mobility():
    """Load the 2021 and 2022 mobility reports and keep the state-level rows.

    Both CSV files are read from DATA_PATH with the column at position 8 as
    the index (presumably the report date - it is parsed with
    pd.to_datetime below), then stacked in year order.

    Returns:
        DataFrame restricted to rows where 'sub_region_1' is populated and
        'sub_region_2' is missing, indexed by the parsed datetimes.
    """
    report_files = (
        '2021_AU_Region_Mobility_Report.csv',
        '2022_AU_Region_Mobility_Report.csv',
    )
    yearly_reports = [pd.read_csv(DATA_PATH / name, index_col=8) for name in report_files]
    combined = pd.concat(yearly_reports)

    # State-level rows name a first-level sub-region but no second-level one
    has_state = combined['sub_region_1'].notnull()
    no_sub_state = combined['sub_region_2'].isnull()
    state_data = combined.loc[has_state & no_sub_state]
    state_data.index = pd.to_datetime(state_data.index)
    return state_data

def get_non_wa_mob_averages(state_data, mob_locs, jurisdictions):
    """Weighted average mobility over all jurisdictions except Western Australia.

    Args:
        state_data: State-level mobility frame with a 'sub_region_1' column
            and one column per entry of mob_locs; rows sharing an index
            value are averaged together.
        mob_locs: Names of the mobility columns to average.
        jurisdictions: Jurisdiction names as they appear in 'sub_region_1'.
            Every name (including 'Western Australia', if present) must be
            a valid key of the population totals.

    Returns:
        DataFrame indexed by the unique index values of the non-WA rows,
        with one weighted-average column per mobility location. np.average
        does not skip NaN, so a missing value for any jurisdiction on a
        given date makes that date's average NaN.
    """
    # Explicit copy: a 'weights' column is added below, and writing to a
    # bare .loc slice of the caller's frame triggers SettingWithCopyWarning.
    non_wa_data = state_data.loc[state_data['sub_region_1'] != 'Western Australia'].copy()

    # Attach each jurisdiction's total (summed over the raw population data)
    # to its rows, to be used as the averaging weight
    state_pop_totals = load_raw_pop_data('31010do002_202206.xlsx').sum()
    for juris in jurisdictions:
        non_wa_data.loc[non_wa_data['sub_region_1'] == juris, 'weights'] = state_pop_totals[juris]

    # Weighted average per index value; the grouping is the same for every
    # mobility column, so it is built once
    by_date = non_wa_data.groupby(non_wa_data.index)
    state_averages = pd.DataFrame(columns=mob_locs)
    for mob_loc in mob_locs:
        state_averages[mob_loc] = by_date.apply(
            lambda x: np.average(x[mob_loc], weights=x['weights']),
        )
    return state_averages

def get_constants_from_mobility(state_data):
    """Derive the jurisdiction names and mobility column names from the data.

    Args:
        state_data: Mobility frame with a 'sub_region_1' column.

    Returns:
        Tuple of (set of the distinct 'sub_region_1' values other than
        'Australia', list of the column names containing CHANGE_STR in
        their original column order).
    """
    region_names = {name for name in state_data['sub_region_1'] if name != 'Australia'}
    change_columns = list(filter(lambda col: CHANGE_STR in col, state_data.columns))
    return region_names, change_columns

def get_relative_mobility(mobility_df):
    """Convert percent-change mobility values into multipliers around 1.0.

    A value of -20 (percent change) becomes 0.8. CHANGE_STR is removed from
    the column names of the result.

    Args:
        mobility_df: Frame of percent changes. It is not modified; the
            renaming is applied to a copy so that re-running a cell or
            reusing the input frame never sees altered column names.

    Returns:
        New DataFrame equal to 1.0 + mobility_df / 100 with the shortened
        column names.
    """
    renamed = mobility_df.rename(columns=lambda c: c.replace(CHANGE_STR, ''))
    return 1.0 + renamed * 1e-2

def plot_state_mobility(state_data, jurisdictions=None, mob_locs=None, n_cols=2):
    """Plot every mobility series for each jurisdiction, one panel each.

    Args:
        state_data: State-level mobility frame with a 'sub_region_1' column
            and the mobility columns.
        jurisdictions: Jurisdiction names to plot. Derived from state_data
            with get_constants_from_mobility when omitted.
        mob_locs: Mobility column names to plot. Derived from state_data
            with get_constants_from_mobility when omitted.
        n_cols: Number of panel columns in the grid.

    Returns:
        The plotly figure.
    """
    if jurisdictions is None or mob_locs is None:
        derived_juris, derived_locs = get_constants_from_mobility(state_data)
        jurisdictions = derived_juris if jurisdictions is None else jurisdictions
        mob_locs = derived_locs if mob_locs is None else mob_locs

    # Sets iterate in an order that can differ between kernel sessions;
    # sorting makes the panel layout reproducible.
    juris_order = sorted(jurisdictions)
    n_rows = max(1, -(-len(juris_order) // n_cols))  # ceiling division

    fig = make_subplots(rows=n_rows, cols=n_cols, subplot_titles=juris_order)
    fig.update_layout(height=1500)
    for j, juris in enumerate(juris_order):
        juris_data = state_data[state_data['sub_region_1'] == juris]
        # subplot_titles are assigned in row-major order, so the traces must
        # be placed row-major too or the titles label the wrong panels.
        row, col = divmod(j, n_cols)
        for l, mob_loc in enumerate(mob_locs):
            estimates = juris_data[mob_loc]
            fig.add_trace(
                go.Scatter(
                    x=estimates.index,
                    y=estimates,
                    name=mob_loc.replace(CHANGE_STR, '').replace('_', ' '),
                    line=dict(color=COLOURS[l % len(COLOURS)]),
                    showlegend=j == 0,  # one legend entry per mobility type
                ),
                row=row + 1, col=col + 1,
            )
    return fig

def plot_averaged_mobility(wa_relmob, non_wa_relmob):
    """Draw the two relative-mobility frames side by side.

    Args:
        wa_relmob: Frame shown in the left panel ('Western Australia').
        non_wa_relmob: Frame shown in the right panel
            ('weighted average for rest of Australia').

    Returns:
        The plotly figure, with one line per column of each frame and
        legend entries taken from the left panel only.
    """
    panel_titles = ['Western Australia', 'weighted average for rest of Australia']
    fig = make_subplots(rows=1, cols=2, subplot_titles=panel_titles)
    fig.update_layout(height=500)
    for panel, frame in enumerate((wa_relmob, non_wa_relmob), start=1):
        in_first_panel = panel == 1
        for colour_idx, column in enumerate(frame.columns):
            series = frame[column]
            trace = go.Scatter(
                x=series.index,
                y=series,
                name=column.replace('_', ' '),
                line=dict(color=COLOURS[colour_idx]),
                showlegend=in_first_panel,
            )
            fig.add_trace(trace, row=1, col=panel)
    return fig

def map_mobility_locations(patch_data, mob_map):
    """Combine mobility columns into model locations as weighted sums.

    Args:
        patch_data: Mapping of patch name to a DataFrame of mobility values
            (one column per mobility category).
        mob_map: Mapping of model location name to a dict of
            {mobility category: weight}.

    Returns:
        Numeric DataFrame with (patch, model location) MultiIndex columns,
        ordered patch-major, where each column is the row-wise sum of
        weight * value over the categories listed for that location.
        Categories missing from a patch's frame contribute nothing, and
        data columns absent from the weight dict are ignored.
    """
    combined = {}
    for patch, frame in patch_data.items():
        for mob_loc, category_weights in mob_map.items():
            weights = pd.Series(category_weights, dtype=float)
            # Restrict to the weighted categories so that unlisted columns
            # cannot leak into the sum
            weighted = frame.reindex(columns=weights.index).mul(weights, axis=1)
            combined[(patch, mob_loc)] = weighted.sum(axis=1)

    if not combined:
        columns = pd.MultiIndex.from_product([list(patch_data), list(mob_map)])
        return pd.DataFrame(columns=columns)

    # Building the frame in one step keeps float dtypes; filling an empty
    # object-dtype frame through .loc can leave object columns on recent
    # pandas, which breaks the downstream rolling mean.
    return pd.DataFrame(combined)

In [None]:
# Load the state-level reports and derive the jurisdiction / column names
state_data = get_raw_state_mobility()
jurisdictions, mob_locs = get_constants_from_mobility(state_data)

# Relative mobility for Western Australia on its own
is_wa = state_data['sub_region_1'] == 'Western Australia'
wa_data = state_data.loc[is_wa, mob_locs]
wa_relmob = get_relative_mobility(wa_data)

# Relative mobility averaged (with weights) over the other jurisdictions
state_averages = get_non_wa_mob_averages(state_data, mob_locs, jurisdictions)
non_wa_relmob = get_relative_mobility(state_averages)

plot_state_mobility(state_data)

In [None]:
# Compare relative mobility in Western Australia (left panel) with the
# weighted average over the other jurisdictions (right panel).
plot_averaged_mobility(wa_relmob, non_wa_relmob)

In [None]:
# Mobility categories, as named in the relative-mobility frames
google_categories = (
    'retail_and_recreation',
    'grocery_and_pharmacy',
    'parks',
    'transit_stations',
    'workplaces',
    'residential',
)
# Every category starts at zero weight; each model location overrides
# only the categories that contribute to it.
zero_weights = dict.fromkeys(google_categories, 0.0)
mob_map = {
    'other_locations': {
        **zero_weights,
        'retail_and_recreation': 0.34,
        'grocery_and_pharmacy': 0.33,
        'transit_stations': 0.33,
    },
    'work': {**zero_weights, 'workplaces': 1.0},
}
patch_data = dict(wa=wa_relmob, non_wa=non_wa_relmob)

model_mob = map_mobility_locations(patch_data, mob_map)
# Rolling mean over 7 rows; rows without a full window are dropped
smoothed_model_mob = model_mob.rolling(window=7).mean().dropna()

In [None]:
# Raw and smoothed model mobility, one panel per patch
panel_titles = ['Western Australia', 'weighted average for rest of Australia']
fig = make_subplots(rows=1, cols=2, subplot_titles=panel_titles)
fig.update_layout(height=500)

# (frame to plot, trace-name prefix, offset into COLOURS); the smoothed
# series use colours shifted by two so they differ from the raw ones.
trace_sources = [
    (model_mob, '', 0),
    (smoothed_model_mob, 'smoothed_', 2),
]
for panel, patch in enumerate(patch_data, start=1):
    for loc_idx, mob_loc in enumerate(mob_map):
        for source, prefix, colour_offset in trace_sources:
            estimates = source.loc[:, (patch, mob_loc)]
            fig.add_trace(
                go.Scatter(
                    x=estimates.index,
                    y=estimates,
                    name=f'{prefix}{mob_loc}'.replace('_', ' '),
                    line=dict(color=COLOURS[loc_idx + colour_offset]),
                    showlegend=panel == 1,
                ),
                row=1, col=panel,
            )
fig