In [None]:
# standard library
import datetime
import os

# data
import numpy as np
import pandas as pd
import geopandas as gpd
import scipy
import matplotlib.dates as mdates

# custom data
import wastewater as ww

# plotting
import matplotlib.pylab as plt
import seaborn as sns

# machine learning
import sklearn
from sklearn.linear_model import LinearRegression

# custom machine learning
from wastewater.ml_utils import Dataset, RandomIntercepts, RandomEffects, bootstrap, predict_kfold
from xgboost import XGBRegressor

import contextily as ctx
from mpl_toolkits.axes_grid1 import make_axes_locatable
import boto3
import contextily as ctx


In [None]:
# dataframes location: root folder of the input dataframes, CIS tables and
# cached results. Set the WW_DATA_FOLDER environment variable to point the
# notebook at another copy of the data.
data_folder = os.environ.get(
    'WW_DATA_FOLDER',
    's3://jbc-staging-data-wip/jbc-wip/01. Data/01. Raw/Waste Water/ww_users/anna/prevalence_estimation/',
)
# Paths below are built by plain string concatenation, so a trailing slash is required.
if not data_folder.endswith('/'):
    data_folder += '/'

In [None]:
# Model input columns, grouped by what they describe. The order of vars_all is
# the feature order passed to the model.
vars_disease = ['sars_cov2_gc_l_mean']
vars_population = ['catch_in_cis_prop', 'cis_population']
vars_network = [
    'catchment_area',
    'suspended_solids_mg_l',
    'ammonia_mg_l',
    'ophosph_mg_l',
    'sample_ph_pre_ansis',
]
vars_sampling = ['compo_frac', 'reception_delay']
vars_lab = ['sars_below_lod', 'sars_below_loq', 'control_gc_l_mean']

# wastewater-only inputs (no population variables)
vars_ww = [*vars_disease, *vars_network, *vars_sampling, *vars_lab]
# every input group
vars_all = [*vars_disease, *vars_population, *vars_network, *vars_sampling, *vars_lab]

## Model performance

In [None]:
# Model input tables; the date column is stored day-first.
df = pd.read_csv(f'{data_folder}input_df/df.csv', dayfirst=True, parse_dates=['date'])
df_total = pd.read_csv(f'{data_folder}input_df/df_total.csv', dayfirst=True, parse_dates=['date'])

In [None]:
# Metrics on (observed, predicted) pairs


def compute_mae(y, pred):
    """Mean absolute error between `pred` and `y`."""
    abs_err = np.abs(pred - y)
    return abs_err.mean()


def compute_95_error(y, pred):
    """95th percentile of the absolute error between `pred` and `y`."""
    abs_err = np.abs(pred - y)
    return np.percentile(abs_err, 95)

In [None]:
# cross validation
# Score per CIS region: k-fold (n_splits=20) out-of-sample predictions from
# predict_kfold, scored per CIS region in the cells below.
dataset = Dataset(df.set_index(['CIS20CD', 'date']), vars_all, 'median_prob', input_offset=0.001)
# NOTE(review): prepare_no_split() appears to return the full feature matrix and
# target with no train/test split (the folds are made inside predict_kfold);
# confirm in wastewater.ml_utils.
x, y = dataset.prepare_no_split()
dict_pred = predict_kfold(XGBRegressor(), x, y, n_splits=20)

In [None]:
# Plot the cross-validated fit per CIS region: CIS target vs k-fold predictions,
# with the region's MAE (log10 scale) as the panel title.
fig, axes = plt.subplots(14, 6, figsize=(12, 20), sharex=True, sharey=True)

# Flatten the (CIS20CD, date) indexes once rather than on every iteration.
y_flat = y.reset_index()
pred_flat = dict_pred.reset_index()

for cis, ax in zip(pred_flat.CIS20CD.unique(), axes.flatten()):

    obs_cis = y_flat.query('CIS20CD==@cis')
    ycis = obs_cis.median_prob
    ypred = pred_flat.query('CIS20CD==@cis').median_prob
    xaxisval = obs_cis.date

    ax.plot(xaxisval, ycis, color='C0', label='CIS')
    ax.plot(xaxisval, ypred, color='C1', label='predictions')

    # score as title
    ax.set_title(f'{compute_mae(ycis, ypred):.2f}')
    # labelright only affects y-axis tick labels, so it is not passed for axis='x'
    ax.tick_params(axis='x', labelrotation=45)

plt.tight_layout()
#plt.savefig('figures/fit_predict_mae_log10.png', dpi=300)

In [None]:
def get_scores_cis(x, y, dict_pred):
    """Mean absolute error of the model estimates for each CIS region.

    Parameters
    ----------
    x : pandas.DataFrame
        Model inputs indexed by ('CIS20CD', 'date'); only the index is used.
    y : pandas.DataFrame or Series
        Targets exposing a 'median_prob' field after ``reset_index()``;
        presumably in the same row order as ``x`` (confirm with the caller).
    dict_pred
        Predictions with a ``.values`` attribute. They are attached to ``x``
        by position, so they must be in the same row order as ``x``.

    Returns
    -------
    pandas.DataFrame
        Indexed by CIS20CD, with a float 'mae_log10' column.
    """
    # Keep only the (CIS20CD, date) index of x and attach the estimates by position.
    estimates = x.drop(x.columns, axis=1).assign(estimates=dict_pred.values)

    # Reset the indexes once rather than once per region.
    est_flat = estimates.reset_index()
    y_flat = y.reset_index()

    cis_all = est_flat.CIS20CD.unique()

    # Explicit float dtype: a frame created from index/columns alone is
    # object-typed, and geopandas plots object columns as categories.
    score = pd.DataFrame(index=cis_all, columns=['mae_log10'], dtype=float)
    for cis in cis_all:
        ycis = y_flat.query('CIS20CD==@cis').median_prob
        ypred = est_flat.query('CIS20CD==@cis').estimates
        score.loc[cis, 'mae_log10'] = compute_mae(ycis, ypred)

    return score


score = get_scores_cis(x, y, dict_pred)
score

In [None]:
# load CIS boundaries
gdf = gpd.read_file('s3://jbc-staging-data-wip/jbc-wip/Reference/Covid Infection Survey December 2020 UK BUC/Covid_Infection_Survey__December_2020__UK_BUC.shp')

# add the scores
gdf = gdf.set_index('CIS20CD').join(score).reset_index()
gdf = gdf.set_crs('epsg:27700')

# Force a numeric dtype: a score column filled cell-by-cell with .loc can be
# object-typed, and geopandas draws object columns as categories, which would
# not match the continuous colour bar below.
gdf['mae_log10'] = pd.to_numeric(gdf['mae_log10'])

# One value range shared by the map and the colour bar so they agree.
mae_min, mae_max = gdf['mae_log10'].min(), gdf['mae_log10'].max()

# model fit
fig, ax = plt.subplots(figsize=(8, 8))
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="3%", pad=0.2)
ax = gdf.plot(ax=ax, column='mae_log10', edgecolor='k', cmap='summer', vmin=mae_min, vmax=mae_max)
ax.axis('off')

# Basemap: the custom Mapbox style needs an access token, which is read from the
# environment rather than kept in the notebook. Without a token, fall back to
# contextily's default tiles.
mapbox_token = os.environ.get('MAPBOX_ACCESS_TOKEN')
if mapbox_token:
    # thanks Mark for this :)
    ctx.add_basemap(
        ax=ax,
        crs=27700,
        source='https://api.mapbox.com/styles/v1/msricketts-jbc/ckninsxcj0oba17mo04t8knqn/tiles/256/{z}/{x}/{y}@2x?access_token=' + mapbox_token,
    )
else:
    ctx.add_basemap(ax=ax, crs=27700)

sm = plt.cm.ScalarMappable(cmap='summer', norm=plt.Normalize(vmin=mae_min, vmax=mae_max))
cbr = fig.colorbar(sm, cax=cax, label='MAE (log10 %)')

#cbr.ax.yaxis.set_ticks([0,1])
#cbr.ax.yaxis.set_ticklabels(['All grab','All composite'])
#plt.savefig('figures/map_mae_large_log10.png', dpi=300)

#### Contribution of variables

In [None]:
# Drop one variable at a time and recalculate the per-CIS scores to see how much
# each variable matters. Refitting the k-fold model once per variable is slow,
# so by default the saved results are loaded. Set RECOMPUTE_SCORES_VARS = True
# to regenerate them; this writes to the same path that is read below.
RECOMPUTE_SCORES_VARS = False
scores_vars_path = data_folder + 'results/scores_vars.csv'

if RECOMPUTE_SCORES_VARS:
    mae_without = {}
    for var in vars_all:
        vars_sel = [v for v in vars_all if v != var]

        # Run the model on all but one variable. Separate names are used so the
        # notebook-level x, y and dict_pred are left untouched.
        x_sel, y_sel = Dataset(df.set_index(['CIS20CD', 'date']), vars_sel, 'median_prob', input_offset=0.001).prepare_no_split()
        pred_sel = predict_kfold(XGBRegressor(), x_sel, y_sel, n_splits=20)

        # per-CIS model score with `var` removed
        mae_without[var] = get_scores_cis(x_sel, y_sel, pred_sel)['mae_log10']

    (pd.DataFrame(mae_without)
        .rename_axis('CIS20CD')
        .reset_index()
        .to_csv(scores_vars_path, index=False))

scores_vars = pd.read_csv(scores_vars_path)
scores_vars

In [None]:
# Box plot, across CIS regions, of the cross-validated MAE obtained when each
# variable in turn is left out of the model.
plt.figure(figsize=(8, 4))
sns.boxplot(
    data=scores_vars.melt(id_vars='CIS20CD'),
    x='value',
    y='variable',
    color='skyblue',
)
plt.gca().set(xlabel='MAE (log10 %)', ylabel='Variable removed')
plt.tight_layout()
#plt.savefig('figures/removing_variables.png', dpi=300)

In [None]:
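# Create columns with score differences relative to the all-variables model.
# NOTE: this cell is disabled. It needs a `scores_compare` frame (per-CIS scores
# with an 'all_vars' column), which no earlier cell builds.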
#change_in = scores_compare.copy()
#rel_change_in = scores_compare.copy()

#for col in scores_compare.columns:
#    change_in[col] = scores_compare[col] - scores_compare['all_vars']
#    rel_change_in[col] = ((scores_compare[col] - scores_compare['all_vars']) / scores_compare['all_vars']) * 100
#change_in = change_in.drop('all_vars',axis=1).reset_index().melt(id_vars='CIS20CD').sort_values('CIS20CD')
#rel_change_in = rel_change_in.drop('all_vars',axis=1).reset_index().melt(id_vars='CIS20CD').sort_values('CIS20CD')
#rel_change_in

# plot the change in score per CIS
#fig, axes = plt.subplots(14,6, figsize=(12,36), sharex=True, sharey=True)
#a = 0

#for cis, ax in zip(scores_compare.index, axes.flatten()):
    
#    vrs = rel_change_in.query('CIS20CD==@cis').variable
#    val = rel_change_in.query('CIS20CD==@cis').value
    
#    mask1 = val < 0
#    mask2 = val >= 0

#    ax.barh(vrs[mask1], val[mask1], color = 'green')
#    ax.barh(vrs[mask2], val[mask2], color = 'red')
    
#    ax.invert_yaxis()
    
    # add the MAE
#    ax.set_title(np.round(scores_compare.loc[cis, 'all_vars'],2))
    
#plt.tight_layout()
#plt.savefig('figures/variable_importance.png', dpi=300)

## Estimates using out-of-sample wastewater data

In [None]:
# Prediction frame: every row of df_total from the first date with CIS
# estimates onwards, which drops the initial period without CIS data.
start_date_train = pd.Timestamp('2020-09-03')  # start of cis estimate availability
start_date_test = pd.Timestamp('2021-02-11')  # end of sub-regional estimates

df_pred = df_total.loc[df_total['date'] >= start_date_train]
df_pred

In [None]:
# Out-of-sample run: fit on rows before start_date_test and predict the rows
# from start_date_test onwards.
dataset = Dataset(
    df_pred.set_index(['CIS20CD', 'date']),
    vars_all,
    'median_prob',
    input_offset=0.001,
)
x_train, x_test, y_train, y_test = dataset.temporal_split(start_date_test=start_date_test)
dict_pred_test = (
    XGBRegressor()
    .fit(x_train, y_train)
    .predict(x_test)
)


In [None]:
# run over seen values as well: k-fold predictions over the period with CIS
# data, so the plots can show the training-period fit next to the
# out-of-sample estimates
dataset = Dataset(df.set_index(['CIS20CD', 'date']), vars_all, 'median_prob', input_offset=0.001)
# NOTE(review): prepare_no_split() appears to return the full x / y without a
# train/test split (folds are made in predict_kfold); confirm in wastewater.ml_utils.
x, y = dataset.prepare_no_split()
dict_training_period = predict_kfold(XGBRegressor(), x, y, n_splits=20)

# .assign returns new frames instead of writing a column into frames that
# temporal_split may have returned as slices (SettingWithCopyWarning).
# NOTE(review): dict_training_period is computed on the rows of `df` but is
# attached to x_train (rows of df_pred before start_date_test). This assumes both
# cover the same (CIS20CD, date) rows in the same order; confirm.
x_test = x_test.assign(predictions=dict_pred_test)
x_train = x_train.assign(predictions=dict_training_period)
x_train

In [None]:
# Per-CIS-region view: training-period fit, CIS values and out-of-sample
# predictions. Also records 95th-percentile absolute errors per region in `score`.
fig, axes = plt.subplots(14, 6, figsize=(12, 20), sharex=True, sharey=True)

# Flatten the (CIS20CD, date) indexes once, outside the loop.
train_flat = x_train.reset_index()
test_flat = x_test.reset_index()
obs_flat = y.reset_index()

for cis, ax in zip(test_flat.CIS20CD.unique(), axes.flatten()):

    train_cis = train_flat.query('CIS20CD==@cis')
    obs_cis = obs_flat.query('CIS20CD==@cis')
    test_cis = test_flat.query('CIS20CD==@cis')

    # wastewater training predictions
    ax.plot(train_cis.date, train_cis.predictions, color='lightpink', label='training')

    # cis values for training
    ax.plot(obs_cis.date, obs_cis.median_prob, color='C0', label='CIS')

    # predictions
    ax.plot(test_cis.date, test_cis.predictions, color='C1', label='predictions')

    ax.tick_params(axis='x', labelrotation=90)

    # Pair CIS values with model estimates for the same date before differencing.
    # The CIS series and the test-period predictions come from different frames,
    # so subtracting them directly would align on unrelated integer row labels.
    paired = obs_cis[['date', 'median_prob']].merge(
        pd.concat([train_cis, test_cis])[['date', 'predictions']], on='date'
    )
    if paired.empty:
        p95 = p95_log10 = np.nan
    else:
        p95 = compute_95_error(10**paired.median_prob, 10**paired.predictions)
        p95_log10 = compute_95_error(paired.median_prob, paired.predictions)
    score.loc[cis, 'p95'] = p95
    score.loc[cis, 'p95_log10'] = p95_log10

    ax.set_title(cis)

plt.tight_layout()
#plt.savefig('figures/per_cis_prediction.png', dpi=300)

In [None]:

# Lookup from CIS sub-region (CIS20CD) to ONS region code (RGN19CD) and name.
regions = (
    ww.load_cis()
    .rename(columns={'CIS.name': 'CIS20CD'})[['CIS20CD', 'RGN19CD']]
    .drop_duplicates()
    .merge(
        ww.read_dataset('LOOKUP_REGISTER_ONS_GEO_CODE')
        .set_index('region_code')[['region_name']]
        .dropna()
        .drop_duplicates()
        .reset_index()
        .rename(columns={'region_code': 'RGN19CD'})
    )
)

# Daily CIS estimates labelled with their region name.
# NOTE(review): df_cis_daily is replaced by the 'cis/df_cis_daily.csv' load in a
# later cell before it is used for plotting.
df_cis_daily = (
    ww.load_cis(daily=True)
    .rename(columns={'CIS.name': 'CIS20CD'})
    .merge(regions)
    .rename(columns={'region_name': 'region'})
)
df_cis_daily

In [None]:
def aggregate_to_region(x_test, regions_lookup=None):
    """Aggregate CIS sub-regional predictions to daily regional series.

    Parameters
    ----------
    x_test : pandas.DataFrame
        Frame whose index (or columns) provide 'CIS20CD' and 'date', with a
        'predictions' column and the 'cis_population' column used as weights.
    regions_lookup : pandas.DataFrame, optional
        Lookup with 'CIS20CD' and 'region_name' columns. Defaults to the
        notebook-level ``regions``.

    Returns
    -------
    melted_pred : pandas.DataFrame
        Columns 'date', 'region', 'predictions'. Holds the per-date average of
        the predictions weighted by 'cis_population', put on a daily grid and
        linearly interpolated within each region between its first and last
        observation. Sorted by date, then region.
    predictions_by_region : pandas.DataFrame
        The same daily grid before interpolation (gaps are NaN), indexed by
        date, with columns 'region' and 'predictions'.
    """
    if regions_lookup is None:
        regions_lookup = regions

    vars_to_average = ['predictions']
    average_on = 'cis_population'

    def fn(s):
        # weighted mean of each column in vars_to_average
        weights = s[average_on]
        return pd.Series({col: np.average(s[col], weights=weights) for col in vars_to_average})

    # add the region of each CIS sub-region
    x_regional = (
        x_test.reset_index()
        .merge(regions_lookup, on='CIS20CD')
        .rename(columns={'region_name': 'region'})
    )

    # Take the weighted average per region and date. Only the needed columns are
    # passed to apply, because applying over grouping columns is deprecated.
    by_region = (
        x_regional.groupby(['region', 'date'])[[average_on] + vars_to_average]
        .apply(fn)
        .reset_index()
    )
    if by_region.empty:
        empty = pd.DataFrame(columns=['date', 'region', 'predictions'])
        return empty, empty.set_index('date')

    # Align to daily with one column per region. Interpolating in this wide
    # layout keeps regions separate; interpolating the long frame would blend
    # neighbouring rows that belong to different regions.
    wide = by_region.pivot(index='date', columns='region', values='predictions')
    wide = wide.reindex(pd.date_range(start=wide.index.min(), end=wide.index.max()))
    # limit_area='inside' fills only the gaps between a region's observations,
    # instead of extending its last value flat to the end of the grid.
    wide_interp = wide.interpolate(limit_area='inside')

    def _to_long(frame):
        # wide (date x region) -> long frame with columns date, region, predictions
        return (
            frame.rename_axis(index='date', columns=None)
            .reset_index()
            .melt(id_vars='date', var_name='region', value_name='predictions')
            .sort_values(['date', 'region'])
            .reset_index(drop=True)
        )

    melted_pred = _to_long(wide_interp).dropna(subset=['predictions']).reset_index(drop=True)
    predictions_by_region = _to_long(wide).set_index('date')

    return melted_pred, predictions_by_region


# Regional daily estimates for the out-of-sample (test) and training periods;
# the second return value of each call is not needed here.
(melted_pred, _), (melted_train, _) = (
    aggregate_to_region(x_test),
    aggregate_to_region(x_train),
)
melted_pred

In [None]:
# Regional CIS estimates used for evaluation: 14-day and daily series
# (dates stored day-first).
df_cis_14day = pd.read_csv(f'{data_folder}cis/df_cis_14day.csv', dayfirst=True, parse_dates=['date_start', 'date_end'])
df_cis_daily = pd.read_csv(f'{data_folder}cis/df_cis_daily.csv', dayfirst=True, parse_dates=['date'])

In [None]:
# Plot the CIS against the regionally averaged wastewater model estimates.
# The estimates are on a log10 scale, hence the 10** back-transform.
fig, axes = plt.subplots(3, 3, figsize=(14, 8), sharex=True, sharey=True, constrained_layout=True)

# Look up each region's training series by name instead of zipping two
# groupbys, which would silently mis-pair panels if a region were missing from
# one of the frames.
train_by_region = {name: grp for name, grp in melted_train.groupby('region')}

for i, ((region, group), ax) in enumerate(zip(melted_pred.groupby('region'), axes.flatten())):

    # plot training
    train_group = train_by_region.get(region)
    if train_group is not None:
        ax.plot(train_group.date, 10**train_group.predictions, color='lightpink', label='wastewater training')

    # plot predictions
    ax.plot(group.date, 10**group.predictions, color='C0', label='wastewater testing')

    ax.set_title(region)

    # plot the 14 day evaluation data from CIS
    eval_cis_14day = df_cis_14day.query('region==@region').sort_values('date_end')
    ax.plot(eval_cis_14day.date_end, eval_cis_14day.percent_positive, color='C2', label='CIS (14-day)')
    ax.fill_between(eval_cis_14day.date_end, eval_cis_14day.lower_ci, eval_cis_14day.upper_ci, alpha=0.1, color='C2')
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%b %y'))

    # plot daily evaluation data from CIS
    eval_cis_daily = df_cis_daily.query('region==@region').sort_values('date')
    ax.plot(eval_cis_daily.date, eval_cis_daily.percent_positive, color='C1', label='CIS (daily)')
    ax.fill_between(eval_cis_daily.date, eval_cis_daily.lower_ci, eval_cis_daily.upper_ci, alpha=0.1, color='C1')
    ax.set_ylabel('Modelled % positive')

    # every panel shows the same series, so one legend is enough
    if i == 0:
        ax.legend()

#*** sup-ylabel
#plt.savefig('figures/estimates_vs_cis_regional_with_training.png', dpi=300)

#### Viewing the input wastewater data against CIS

In [None]:
# Apply regional averaging to concentrations: per date and region, the average
# of sars_cov2_gc_l_mean weighted by catch_cis_population.
vars_to_average = ['sars_cov2_gc_l_mean']
average_on = 'catch_cis_population'

def fn(s):
    """Weighted mean of each column in `vars_to_average`, weighted by `average_on`."""
    weights = s[average_on]
    return pd.Series({col: np.average(s[col], weights=weights) for col in vars_to_average})

# Only the weight and value columns are passed to fn, because applying over the
# grouping columns is deprecated in recent pandas.
df_total_sars_agg = (
    df_total.merge(regions)
    .groupby(['date', 'region_name'])[[average_on] + vars_to_average]
    .apply(fn)
    .reset_index()
    .sort_values('date')
)

In [None]:
# Plot the CIS (left axis, log scale) against the regionally averaged wastewater
# concentrations (right axis, log scale, gc/ml).
fig, axes = plt.subplots(3, 3, figsize=(14, 8), sharex=True, sharey=True, constrained_layout=True)

# Merge once outside the loop instead of twice per region.
df_total_regions = df_total.merge(regions)

for i, ((region, group), ax) in enumerate(zip(df_total_sars_agg.groupby('region_name'), axes.flatten())):

    # plot the individual wastewater samples of this region on a secondary axis
    ax2 = ax.twinx()
    samples = df_total_regions.query('region_name==@region')
    ax2.scatter(samples.date, samples.sars_cov2_gc_l_mean/1e3, s=1, color='lightgrey', label='Individual samples')

    # plot the 14 day evaluation data from CIS
    eval_cis_14day = df_cis_14day.query('region==@region').sort_values('date_end')
    ax.plot(eval_cis_14day.date_end, eval_cis_14day.percent_positive, color='C2', label='CIS (14-day)')
    ax.fill_between(eval_cis_14day.date_end, eval_cis_14day.lower_ci, eval_cis_14day.upper_ci, alpha=0.1, color='C2')

    # plot daily evaluation data from CIS
    eval_cis_daily = df_cis_daily.query('region==@region').sort_values('date')
    ax.plot(eval_cis_daily.date, eval_cis_daily.percent_positive, color='C1', label='CIS (daily)')
    ax.fill_between(eval_cis_daily.date, eval_cis_daily.lower_ci, eval_cis_daily.upper_ci, alpha=0.1, color='C1')

    # Plot the aggregated wastewater data as rolling medians over calendar-day
    # windows. A row-count window would span more than 7 / 21 days wherever
    # sampling dates are missing, contradicting the labels.
    conc = group.set_index('date')['sars_cov2_gc_l_mean'].sort_index() / 1e3
    ax2.plot(conc.index, conc.rolling('7D').median(), color='skyblue', label='7-day rolling conc. (gc/ml)')
    ax2.plot(conc.index, conc.rolling('21D').median(), color='C0', lw=2, label='21-day rolling conc. (gc/ml)')

    # to log
    ax.set_yscale('log')
    ax2.set_yscale('log')
    #ax2.set_ylim(top=500)

    # axes format
    ax.set_title(region)
    ax.set_ylabel('Modelled % positive')
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%b %y'))
    ax2.xaxis.set_major_formatter(mdates.DateFormatter('%b %y'))

    # every panel shows the same series, so one pair of legends is enough
    if i == 0:
        ax.legend()
        ax2.legend(loc='lower left')

#*** sup-ylabel
#plt.savefig('figures/regional_cis_ww_log_21day.png', dpi=300)

#### Regression to the mean?

In [None]:
# Second experiment: train only from the start of the second wave, keeping the
# same test start date as before.
start_date_train = pd.Timestamp('2020-12-15')  # start of the second wave
start_date_test = pd.Timestamp('2021-02-11')  # end of sub-regional estimates

df_total = pd.read_csv(f'{data_folder}input_df/df_total.csv', dayfirst=True, parse_dates=['date'])
df_pred = df_total.loc[df_total['date'] >= start_date_train]


In [None]:
# run model on out-of-sample wastewater data (second-wave training window)
dataset = Dataset(df_pred.set_index(['CIS20CD', 'date']), vars_all, 'median_prob', input_offset=0.001)
x_train, x_test, y_train, y_test = dataset.temporal_split(start_date_test=start_date_test)
dict_pred_test = XGBRegressor().fit(x_train, y_train).predict(x_test)

# Run over seen values as well: k-fold predictions over the period with CIS data.
# NOTE(review): prepare_no_split() appears to return the full x / y without a
# train/test split (folds are made in predict_kfold); confirm in wastewater.ml_utils.
dataset = Dataset(df.set_index(['CIS20CD', 'date']), vars_all, 'median_prob', input_offset=0.001)
x, y = dataset.prepare_no_split()
dict_training_period = predict_kfold(XGBRegressor(), x, y, n_splits=20)

# .assign returns new frames instead of writing a column into frames that
# temporal_split may have returned as slices (SettingWithCopyWarning).
# NOTE(review): x_train only covers df_pred rows from start_date_train
# (2020-12-15) up to start_date_test. dict_training_period, however, is computed
# on all rows of `df` with the full-period model, so this relies on it aligning
# with x_train's rows; confirm.
x_test = x_test.assign(predictions=dict_pred_test)
x_train = x_train.assign(predictions=dict_training_period)
x_train

In [None]:
# Regional daily estimates for the second-wave test and training periods.
(melted_pred, _), (melted_train, _) = (
    aggregate_to_region(x_test),
    aggregate_to_region(x_train),
)

In [None]:
# Plot the CIS against the regionally averaged wastewater model estimates
# (second-wave training window). The estimates are on a log10 scale, hence the
# 10** back-transform.
fig, axes = plt.subplots(3, 3, figsize=(14, 8), sharex=True, sharey=True, constrained_layout=True)

# Look up each region's training series by name instead of zipping two
# groupbys, which would silently mis-pair panels if a region were missing from
# one of the frames.
train_by_region = {name: grp for name, grp in melted_train.groupby('region')}

for i, ((region, group), ax) in enumerate(zip(melted_pred.groupby('region'), axes.flatten())):

    # plot training
    train_group = train_by_region.get(region)
    if train_group is not None:
        ax.plot(train_group.date, 10**train_group.predictions, color='lightpink', label='wastewater training')

    # plot predictions
    ax.plot(group.date, 10**group.predictions, color='C0', label='wastewater testing')

    ax.set_title(region)

    # plot the 14 day evaluation data from CIS
    eval_cis_14day = df_cis_14day.query('region==@region').sort_values('date_end')
    ax.plot(eval_cis_14day.date_end, eval_cis_14day.percent_positive, color='C2', label='CIS (14-day)')
    ax.fill_between(eval_cis_14day.date_end, eval_cis_14day.lower_ci, eval_cis_14day.upper_ci, alpha=0.1, color='C2')
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%b %y'))

    # plot daily evaluation data from CIS
    eval_cis_daily = df_cis_daily.query('region==@region').sort_values('date')
    ax.plot(eval_cis_daily.date, eval_cis_daily.percent_positive, color='C1', label='CIS (daily)')
    ax.fill_between(eval_cis_daily.date, eval_cis_daily.lower_ci, eval_cis_daily.upper_ci, alpha=0.1, color='C1')
    ax.set_ylabel('Modelled % positive')

    # every panel shows the same series, so one legend is enough
    if i == 0:
        ax.legend()

#*** sup-ylabel
# distinct file name so the first experiment's figure is not overwritten
#plt.savefig('figures/estimates_vs_cis_regional_second_wave.png', dpi=300)

#### Adding previous CIS data to improve the model

In [None]:
# TODO (Anna): regional-level CIS estimates are expected to remain available, so
# their historical values could be an additional model input, to help predict
# particularly low or high prevalence.