# Evaluate fusion methods

In [None]:
# standard libraries
import os
import pickle
from datetime import datetime
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

# third party
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from tqdm.notebook import tqdm

# first party
from config import Config
from data_containers import LocationSeries, SensorConfig, Sensors

In [None]:
class Directory:
    """Build file paths for, and read/write, the pickled artifacts of this notebook.

    Attributes are plain strings:
        gt: ground-truth indicator config (must expose `.source` and `.signal`).
        infections_root_dir: folder holding the `as_of_<date>.p` infection files.
        indicator_root_dir: folder holding indicator pickles.
        sensor_root_dir: folder holding sensor pickles.
        jhu_path: path built from `ROOT` and the ground-truth source/signal.
    """

    ROOT = '../data/'

    def __init__(self, gt_indicator=None):
        """Create the path helper.

        Args:
            gt_indicator: object with `.source` and `.signal`; when omitted,
                `Config.ground_truth_indicator` is used.
        """
        # Resolve the default at call time rather than at class-definition time,
        # so the class can be defined before Config is fully set up and always
        # sees the current Config value.
        if gt_indicator is None:
            gt_indicator = Config.ground_truth_indicator
        self.gt = gt_indicator
        self.infections_root_dir = './results/ntf_tapered/'
        self.indicator_root_dir = os.path.join(Directory.ROOT, 'indicators')
        self.sensor_root_dir = os.path.join(Directory.ROOT, 'sensors')
        self.jhu_path = os.path.join(
            Directory.ROOT, f'jhu-csse_confirmed_incidence_prop/{self.gt.source}_{self.gt.signal}')

    def deconv_gt_file(self, as_of):
        """Return the path of the infections pickle for the given `as_of`."""
        return os.path.join(self.infections_root_dir, f'as_of_{as_of}.p')

    def indicator_file(self, indicator, as_of):
        """Return the path of the pickle for `indicator` (source/signal) at `as_of`."""
        return os.path.join(
            self.indicator_root_dir,
            f'{indicator.source}-{indicator.signal}_{as_of}.p')

    def sensor_file(self, config, as_of):
        """Return the path of the sensor pickle for `config` (source/signal) at `as_of`."""
        return os.path.join(
            self.sensor_root_dir,
            f'{config.source}_{config.signal}_{as_of}.p')

    def maybe_load_file(self, file_name, verbose=False):
        """Unpickle `file_name` if it exists.

        Returns:
            The unpickled object, or `False` (not `None`) when the file is missing.
        """
        if not os.path.isfile(file_name):
            if verbose:
                print(file_name, 'does not exist')
            return False

        # NOTE: pickle executes arbitrary code on load; only use on trusted files.
        with open(file_name, 'rb') as handle:
            return pickle.load(handle)

    def maybe_write_file(self, data, file_name, overwrite=False, verbose=False):
        """Pickle `data` to `file_name`, creating parent directories as needed.

        Returns:
            `True` if the file was written, `False` if it already existed and
            `overwrite` is false.
        """
        if os.path.isfile(file_name) and not overwrite:
            if verbose:
                print(file_name, 'exists')
            return False

        # A bare file name has an empty dirname; os.makedirs('') would raise.
        dir_name = os.path.dirname(file_name)
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)

        with open(file_name, 'wb') as handle:
            pickle.dump(data, handle)
        return True

    @staticmethod
    def exists(file_name, overwrite=False):
        """Return True when `file_name` is an existing file and `overwrite` is false."""
        return os.path.isfile(file_name) and not overwrite
    
def conform(location_series):
    """Return `location_series` with its data keyed by `date` objects, or None.

    None is returned when the series itself is None, when it carries no data,
    or when all of its values are NaN. If the first key is a `datetime`, the
    series' `data` is replaced in place by a dict keyed by the matching dates.
    """
    # Nothing usable: no series, no payload, or only NaN values.
    if location_series is None or location_series.data is None:
        return None
    if np.isnan(location_series.values).all():
        return None

    # Keys that are not datetimes are left as they are.
    if not isinstance(location_series.dates[0], datetime):
        return location_series

    day_keys = [stamp.date() for stamp in location_series.dates]
    location_series.data = dict(zip(day_keys, location_series.values))
    return location_series

In [None]:
# Set-up all the method configurations.
# The 'fast_' versions use a subset of lower-latency sensors.
# NOTE(review): the last SensorConfig argument looks like the sensor lag
# (it is passed as `lag=` for ar3) — confirm against data_containers.
infections_config = SensorConfig('jhu-csse', 'confirmed_incidence_prop','deconv_infections', 2)
ar3_config = SensorConfig('ar3', 'ntf_tapered_infections', 'ar3', lag=1)
simple_avg_config = SensorConfig('all', 'simple_average', 'average', 4)
fast_simple_avg_config = SensorConfig('fast_all', 'simple_average', 'average', 1)
simple_avg_no_google_aa_config = SensorConfig('all_no_google_aa', 'simple_average', 'average', 4)
fast_simple_avg_no_google_aa_config= SensorConfig('fast_all_no_google_aa', 'simple_average', 'average', 1)

simple_reg_config = SensorConfig('all', 'simple_reg', 'regression', 4)
fast_simple_reg_config = SensorConfig('fast_all', 'simple_reg', 'regression', 1)

ridge_config = SensorConfig('all', 'ridge', 'ridge', 4)
fast_ridge_config = SensorConfig('fast_all', 'ridge', 'ridge', 1)

lasso_config = SensorConfig('all', 'lasso', 'lasso', 4)
fast_lasso_config = SensorConfig('fast_all', 'lasso', 'lasso', 1)

kf_sf_config = SensorConfig('all', 'kf_sf', 'kf_sf', 4)
fast_kf_sf_config = SensorConfig('fast_all', 'kf_sf', 'kf_sf', 1)

# Errors are only kept for estimates at most this many days before the as_of date.
max_eval_lag = 10
directory = Directory(infections_config)
# 'dc' and its county FIPS '11001' are excluded from the evaluation.
states = sorted(Config.states - {'dc'})
evaluation_geos = sorted(set(states) | (Config.top_counties - {'11001'}))


def get_geo_type(x):
    """Return 'county' for an all-numeric geo code and 'state' otherwise."""
    return 'county' if x.isnumeric() else 'state'


# state -> the evaluation geos belonging to it: the state itself plus every
# geo whose first two characters equal the state's FIPS code.
state_map = {}
for state in states:
    fips_code = Config.state_fips[state]
    state_map[state] = [geo for geo in evaluation_geos if geo == state or geo[:2] == fips_code]

# testing dates are those not used in previous experiments
as_of_range = Config.as_of_range
# Build the exclusion set once instead of once per candidate date.
_previously_used_as_of = set(Config.every_10_as_of_range)
evaluation_as_of_range = [d for d in as_of_range if d not in _previously_used_as_of]

## Load sensors

In [None]:
def load_sensors(as_of_range, sensor_configs, directory):
    """Gather every loadable sensor series into one `Sensors` container.

    Args:
        as_of_range: iterable of as_of values to load.
        sensor_configs: mapping of sensor name -> config used to locate its file.
        directory: object providing `sensor_file` and `maybe_load_file`.

    Returns:
        A `Sensors` instance holding, per as_of and sensor name, each series
        that `conform` did not reject. Missing or empty files are skipped.
    """
    collected = Sensors()
    for as_of in tqdm(as_of_range):
        for sensor_name, sensor_config in sensor_configs.items():
            path = directory.sensor_file(sensor_config, as_of)
            loaded = directory.maybe_load_file(path)
            if loaded:
                for key in loaded.keys():
                    series = loaded[key]
                    conformed = conform(series)
                    if conformed is not None:
                        collected.add_data(as_of, sensor_name, series.geo_value, conformed)
    return collected

# Individual indicator sensors (plus the AR(3) baseline), keyed by sensor name.
indicator_sensors = dict(
    fb_cliic=Config.fb_cliic,
    dv_cli=Config.dv_cli,
    google_aa=Config.google_aa,
    chng_cli=Config.chng_cli,
    chng_covid=Config.chng_covid,
    ar3=ar3_config,
)

# Fusion methods, keyed by sensor name; each method is followed by its 'fast_' variant.
fusion_sensors = dict(
    simple_avg=simple_avg_config,
    fast_simple_avg=fast_simple_avg_config,
    simple_avg_no_google_aa=simple_avg_no_google_aa_config,
    fast_simple_avg_no_google_aa=fast_simple_avg_no_google_aa_config,
    simple_reg=simple_reg_config,
    fast_simple_reg=fast_simple_reg_config,
    ridge=ridge_config,
    fast_ridge=fast_ridge_config,
    lasso=lasso_config,
    fast_lasso=fast_lasso_config,
    kf_sf=kf_sf_config,
    fast_kf_sf=fast_kf_sf_config,
)

# Load both groups over the full as_of range.
indicator_sensor_data = load_sensors(
    as_of_range=as_of_range, sensor_configs=indicator_sensors, directory=directory)
fusion_sensor_data = load_sensors(
    as_of_range=as_of_range, sensor_configs=fusion_sensors, directory=directory)

In [None]:
# For states whose state_map entry contains a single geo (the state itself,
# i.e. no evaluated counties), overwrite the kf_sf predictions with the
# corresponding ridge predictions.
_kf_sf_replacements = (('kf_sf', 'ridge'), ('fast_kf_sf', 'fast_ridge'))
for state in states:
    if len(state_map[state]) != 1:
        continue
    for as_of in fusion_sensor_data.data.keys():
        data = fusion_sensor_data.data[as_of]
        for _kf_name, _ridge_name in _kf_sf_replacements:
            if _kf_name in data:
                fusion_sensor_data.data[as_of][_kf_name][state] = data[_ridge_name][state]
    # report which states were patched
    print(state)

In [None]:
# add in natural trend filtering estimates
for as_of in tqdm(as_of_range):
    infections_file = directory.deconv_gt_file(as_of)
    data = directory.maybe_load_file(infections_file)
    # maybe_load_file reports a missing file with False (not None), so an
    # `is not None` check would never trigger; fail loudly and explicitly.
    if data is None or data is False:
        raise FileNotFoundError(
            f'no infections file for as_of={as_of}: {infections_file}')

    for k in data.keys():
        vals = conform(data[k])
        if vals is None:
            continue
        # the same estimates are registered in both containers
        indicator_sensor_data.add_data(as_of, 'ntf_tapered', k, vals)
        fusion_sensor_data.add_data(as_of, 'ntf_tapered', k, vals)

# register the sensor name so the error loops iterate over it too
indicator_sensors['ntf_tapered'] = infections_config
fusion_sensors['ntf_tapered'] = infections_config

## Compute errors

In [None]:
# Ground-truth series; indexed by geo where it is used.
# NOTE: pickle executes arbitrary code on load; only use on trusted files.
with open('../data/tf_ground_truths.p', 'rb') as truth_file:
    truth = pickle.load(truth_file)


def df(ls):
    """Return a DataFrame built from `ls.data`, one row per key of the dict."""
    return pd.DataFrame.from_dict(ls.data, orient='index')

In [None]:
# Absolute error of every indicator sensor against the ground truth;
# one pickled DataFrame is written per as_of.
for as_of in tqdm(as_of_range):
    output = []
    for geo in evaluation_geos:
        x = df(truth[geo])
        for sensor in indicator_sensors:
            x_hat = indicator_sensor_data.get_data(as_of, sensor, geo)
            if x_hat is None:
                continue
            x_hat = df(x_hat)
            # The subtraction aligns on the index; rows missing on either side
            # become NaN and are dropped.
            err = np.abs((x - x_hat).dropna()).reset_index()
            err.columns = ['dates', 'abs_err']
            err['as_of'] = as_of
            err['geo'] = geo
            err['sensor'] = sensor
            # lag = days between the as_of date and the estimated date
            err['lag'] = (pd.to_datetime(as_of) - err.dates).dt.days 
            err = err[err.lag.le(max_eval_lag)]
            output.append(err)
    output = pd.concat(output, ignore_index=True)
    # Close the file deterministically instead of leaking the handle.
    with open(f'./results/errors/indicator_{as_of}.p', 'wb') as out_handle:
        pickle.dump(output, out_handle)

In [None]:
# Absolute error of every fusion sensor against the ground truth;
# one pickled DataFrame is written per as_of.
for as_of in tqdm(as_of_range):
    out_file = f'./results/errors/fusion_{as_of}.p'
    output = []
    for geo in evaluation_geos:
        x = df(truth[geo])
        for sensor in fusion_sensors:
            x_hat = fusion_sensor_data.get_data(as_of, sensor, geo)
            if x_hat is None:
                continue
            x_hat = df(x_hat)
            # The subtraction aligns on the index; rows missing on either side
            # become NaN and are dropped.
            err = np.abs((x - x_hat).dropna()).reset_index()
            err.columns = ['dates', 'abs_err']
            err['as_of'] = as_of
            err['geo'] = geo
            err['sensor'] = sensor
            # lag = days between the as_of date and the estimated date
            err['lag'] = (pd.to_datetime(as_of) - err.dates).dt.days 
            err = err[err.lag.le(max_eval_lag)]
            output.append(err)
    output = pd.concat(output, ignore_index=True)
    # Close the file deterministically instead of leaking the handle.
    with open(out_file, 'wb') as out_handle:
        pickle.dump(output, out_handle)

## Plotting maps (legend labels, colors, markers)

In [None]:
# Sensor key -> display label used in the plots.
legend_map = dict(
    ntf_tapered='NTF (tapered)',
    simple_avg='Simple average (+claims)',
    fast_simple_avg='Simple average',
    simple_avg_no_google_aa='Simple average (+claims) (-google-aa)',
    fast_simple_avg_no_google_aa='Simple average (-google-aa)',
    simple_reg='Simple regression (+claims)',
    fast_simple_reg='Simple regression',
    ridge='Ridge (+claims)',
    fast_ridge='Ridge',
    lasso='Lasso (+claims)',
    fast_lasso='Lasso',
    kf_sf='KF-SF (+claims)',
    fast_kf_sf='KF-SF',
    ar3='AR(3)',
    dv_cli='DV-CLI',
    chng_cli='CHNG-CLI',
    chng_covid='CHNG-COVID',
    fb_cliic='CTIS-CLIIC',
    google_aa='Google-AA',
)

# Display label -> color.
color_map = dict([
    ('NTF (tapered)', '#949494'),
    ('Simple average', '#0173B2'),
    ('Simple regression', '#DE8F05'),
    ('Ridge', '#1f7a34'),
    ('Lasso', '#56B4E9'),
    ('KF-SF', '#CC78BC'),
    ('AR(3)', '#CA9161'),
    ('DV-CLI', '#029E73'),
    ('CHNG-CLI', 'brown'),
    ('CHNG-COVID', 'tab:purple'),
    ('CTIS-CLIIC', '#FBAFE4'),
    ('Google-AA', '#ECE133'),
])

# Display label -> marker: 'o' for the five fusion methods, 'd' for the rest.
_circle_marker_labels = ('Simple average', 'Simple regression', 'Ridge', 'Lasso', 'KF-SF')
marker_map = {
    label: ('o' if label in _circle_marker_labels else 'd')
    for label in color_map
}

In [None]:
# lag -> labels of the sensors considered at that lag.
_latency_lag_1 = ['CTIS-CLIIC', 'Google-AA', 'AR(3)',
                  'Simple average', 'Simple regression',
                  'Ridge', 'Lasso', 'KF-SF']
_latency_lag_2 = _latency_lag_1 + ['NTF (tapered)']
_latency_lag_4 = _latency_lag_2 + ['CHNG-CLI', 'CHNG-COVID', 'DV-CLI']

sensor_latency = {
    1: _latency_lag_1,
    2: _latency_lag_2,
    3: _latency_lag_2,  # same list object as lag 2
    4: _latency_lag_4,
}
# every lag from 5 up to max_eval_lag shares the lag-4 list
sensor_latency.update({lag: _latency_lag_4 for lag in range(5, max_eval_lag + 1)})

In [None]:
# Collect the per-as_of error frames for the evaluation dates and stack them.
sensor_err_df = []
fusion_err_df = []
for as_of in tqdm(evaluation_as_of_range):
    # `with` closes each file right after loading instead of leaking the handle.
    with open(f'./results/errors/indicator_{as_of}.p', 'rb') as err_file:
        sensor_err = pickle.load(err_file)
    with open(f'./results/errors/fusion_{as_of}.p', 'rb') as err_file:
        fusion_err = pickle.load(err_file)
    sensor_err_df.append(sensor_err)
    fusion_err_df.append(fusion_err)
    
sensor_err_df = pd.concat(sensor_err_df, ignore_index=True)
fusion_err_df = pd.concat(fusion_err_df, ignore_index=True)

In [None]:
# Translate sensor keys into display labels. The nested-dict form restricts the
# replacement to the 'sensor' column, so the other columns are neither scanned
# nor at risk of an accidental match.
sensor_err_df = sensor_err_df.replace({'sensor': legend_map})
fusion_err_df = fusion_err_df.replace({'sensor': legend_map})
all_err_df = pd.concat([sensor_err_df, fusion_err_df], ignore_index=True)

In [None]:
# keep only rows whose lag does not exceed max_eval_lag
all_err_df = all_err_df.loc[all_err_df['lag'] <= max_eval_lag]

In [None]:
# every (geo, as_of) combination under evaluation
full_index = pd.MultiIndex.from_product(
    iterables=[evaluation_geos, evaluation_as_of_range],
    names=['geo', 'as_of'],
)

## Sensor availability

In [None]:
# Sensor labels shown in the availability table, in row order.
plot_sensors = [
    'NTF (tapered)', 'AR(3)',
    'CTIS-CLIIC', 'CHNG-CLI', 'CHNG-COVID', 'DV-CLI', 'Google-AA',
    'Simple average',
]

def get_availability(all_err_df, full_index, sensor_list):
    """Return the (geo, as_of) pairs for which every listed sensor has a row.

    Args:
        all_err_df: DataFrame with at least `geo`, `as_of` and `sensor` columns.
        full_index: MultiIndex of candidate (geo, as_of) pairs.
        sensor_list: sensor labels that must all be present.

    Returns:
        The subset of `full_index` present for every sensor; a copy of
        `full_index` when `sensor_list` is empty.
    """
    available_index = full_index.copy()
    for sensor in sensor_list:
        sensor_rows = all_err_df[all_err_df.sensor.eq(sensor)]
        sensor_index = sensor_rows.set_index(['geo', 'as_of']).index
        # narrow the running intersection to pairs this sensor also covers
        available_index = available_index.intersection(sensor_index)

    return available_index

def get_availability_by_lag(lag, all_err_df, full_index, sensor_list, verbose=True):
    """Compute per-sensor and joint availability at a single lag.

    Args:
        lag: lag value to select rows by.
        all_err_df: DataFrame with `geo`, `as_of`, `sensor` and `lag` columns.
        full_index: MultiIndex of candidate (geo, as_of) pairs.
        sensor_list: sensor labels to examine, in order.
        verbose: when true, print one line per sensor with the sensor's own
            percentage and the running joint percentage.

    Returns:
        A tuple `(index, fractions)`: the pairs of `full_index` covered by every
        sensor at this lag, and a dict sensor -> fraction of `full_index` it covers.
    """
    n_pairs = len(full_index)
    joint_index = full_index.copy()
    fraction_by_sensor = {}
    for sensor in sensor_list:
        selected = all_err_df.sensor.eq(sensor) & all_err_df.lag.eq(lag)
        observed = all_err_df[selected].set_index(['geo', 'as_of']).index
        own_index = full_index.intersection(observed)
        joint_index = joint_index.intersection(observed)
        fraction_by_sensor[sensor] = len(own_index) / n_pairs
        if not verbose:
            continue
        print(lag,
              f'{sensor:20s}{100*len(own_index)/n_pairs:3.3f}',
              f'{100*len(joint_index)/n_pairs:3.3f}')

    return joint_index, fraction_by_sensor

# Per-lag availability, with and without Google-AA in the sensor set.
sensor_availability, total_availability, availability = {}, {}, {}
total_availability_no_google_aa, availablity_no_google_aa = {}, {}

print("\t\t\tSensor\tTotal")
for lag in range(1, max_eval_lag + 1):
    print(lag)
    # sensors from plot_sensors that are listed for this lag
    sensor_list = [s for s in plot_sensors if s in sensor_latency[lag]]
    available_index, lag_sensor_availability = get_availability_by_lag(
        lag, all_err_df, full_index, sensor_list)
    sensor_availability[lag] = lag_sensor_availability
    total_availability[lag] = len(available_index) / len(full_index)
    availability[lag] = available_index

    print('--- no google aa ---')
    # same computation with Google-AA removed from the sensor set
    sensor_list = [s for s in sensor_list if s != 'Google-AA']
    available_index, lag_sensor_availability = get_availability_by_lag(
        lag, all_err_df, full_index, sensor_list)
    total_availability_no_google_aa[lag] = len(available_index) / len(full_index)
    availablity_no_google_aa[lag] = available_index

In [None]:
def _availability_row(row_label, fraction_by_lag):
    """Return a one-row frame named `row_label` with one column per lag."""
    frame = pd.DataFrame({
        'lag': fraction_by_lag.keys(),
        row_label: fraction_by_lag.values()})
    return frame.set_index('lag').T


# Rows: one per plotted sensor, followed by the two intersection rows.
tmp = pd.DataFrame.from_dict(sensor_availability)
tmp = pd.concat([
    tmp.reindex(plot_sensors),
    _availability_row('Intersection', total_availability),
    _availability_row('Intersection without\nGoogle-AA', total_availability_no_google_aa),
])

plt.figure(figsize=(12, 4))
sns.heatmap(
    tmp,
    annot=True,
    cmap="BuPu",
    linewidths=.25,
    vmin=0,
    cbar_kws={'label': 'Proportion', 'shrink': 0.8},
)
plt.xlabel('Days back from nowcast time')
plt.title('Sensor availability during estimation period\n')
plt.tight_layout()
plt.savefig('./figures/availability.pdf', bbox_inches='tight')
plt.show()

## MAE line plot and rank plots

In [None]:
# Variant-specific display label -> the shared label used in plot legends.
label_legend_map = {
    'Simple average (+claims)': 'Simple average',
    'Simple average (+claims) (-google-aa)': 'Simple average',
    'Simple average (-google-aa)': 'Simple average',
    'Simple regression (+claims)': 'Simple regression',
    'Ridge (+claims)': 'Ridge',
    'Lasso (+claims)': 'Lasso',
    'KF-SF (+claims)': 'KF-SF',
}


def strip_claims(s):
    """Return `s` with every ' (+claims)' occurrence removed."""
    return s.replace(' (+claims)', '')


def strip_google_aa(s):
    """Return `s` with every ' (-google-aa)' occurrence removed."""
    return s.replace(' (-google-aa)', '')

def plot(err_df, full_index, sensor_list, line_lags, rank_lags, out_name):
    """Draw and save the MAE-by-lag line plot and the rank histogram.

    Args:
        err_df: DataFrame with `geo`, `as_of`, `dates`, `sensor`, `lag` and
            `abs_err` columns.
        full_index: MultiIndex of every (geo, as_of) pair under evaluation.
        sensor_list: sensor labels (as found in `err_df.sensor`) to compare;
            also determines the hue order of both plots.
        line_lags: lags included in the line plot.
        rank_lags: lags included in the rank plot.
        out_name: suffix for the PDFs written to ./figures/.
    """
    available_index = get_availability(err_df, full_index, sensor_list)
    tmp = err_df[err_df.lag.isin(line_lags) & err_df.sensor.isin(sensor_list)].set_index(['geo', 'as_of'])
    plot_df = tmp.loc[available_index].reset_index()
    # only the 'sensor' column holds labels, so restrict the replacement to it
    plot_df.replace({'sensor': label_legend_map}, inplace=True)
    # Use the function's own argument rather than a global loop variable.
    stripped_hue_order = [strip_google_aa(strip_claims(s)) for s in sensor_list]
    
    # create lineplot
    plt.figure(figsize=(5, 5))
    sns.lineplot(data=plot_df,
        x='lag',
        y='abs_err',
        hue='sensor',
        hue_order=stripped_hue_order,
        marker='o',
        markers=marker_map,
        dashes=False, 
        palette=color_map,
        n_boot=500,
        err_kws={'alpha':0.1}
    )
    plt.legend(fontsize=10)
    plt.xlabel('Days back from nowcast time')
    plt.ylabel('Mean absolute error')
    plt.xticks(line_lags[::2])
    plt.tight_layout()
    plt.savefig(f'./figures/lineplot_{out_name}.pdf')
    plt.show()
    
    # create rank plot: for each lag keep only the (geo, as_of) pairs that
    # every sensor covers at that lag
    plot_df = []
    for lag in rank_lags:
        available_index, _ = get_availability_by_lag(lag, err_df, full_index, sensor_list, verbose=False)
        tmp = err_df[err_df.lag.eq(lag) & err_df.sensor.isin(sensor_list)].set_index(['geo', 'as_of'])
        plot_df.append(tmp.loc[available_index].reset_index())
    plot_df = pd.concat(plot_df, ignore_index=True)
    plot_df.replace({'sensor': label_legend_map}, inplace=True)
    
    # rank sensors by mean absolute error within each (dates, geo, lag) group
    sensor_rank_df = plot_df.groupby(['dates', 'geo', 'sensor', 'lag']).abs_err.mean().reset_index()
    sensor_rank_df['rank'] = sensor_rank_df.groupby(['dates', 'geo', 'lag']).abs_err.rank('dense', ascending=True)
    plt.figure(figsize=(6, 6))
    g = sns.histplot(
        data=sensor_rank_df,
        x='rank',
        hue='sensor',
        hue_order=stripped_hue_order,
        element='bars',
        discrete=True,
        multiple='fill',
        palette=color_map,
        legend=True)

    # Matplotlib 3.7 renamed Legend.legendHandles to legend_handles (the old
    # name was later removed); support both.
    auto_legend = g.legend_
    legend_handles = getattr(auto_legend, 'legend_handles', None)
    if legend_handles is None:
        legend_handles = auto_legend.legendHandles
    legend_labels = [t.get_text() for t in auto_legend.texts]
    g.legend(handles=legend_handles, labels=legend_labels,
              title=None,
              bbox_to_anchor=(1, 1), 
             loc='upper left', fontsize=10)
    g.set_xticks(np.arange(1, len(sensor_list)+1))
    g.set_xticklabels(np.arange(1, len(sensor_list)+1))
    plt.ylabel('Proportion')
    plt.xlabel('Rank')
    plt.savefig(f'./figures/rankplot_{out_name}.pdf', bbox_inches='tight')
    plt.show()

# Sensor sets compared in the line/rank plots, keyed by output-file suffix.
_base_sensors = ['NTF (tapered)', 'AR(3)', 'CTIS-CLIIC']
_claims_sensors = ['CHNG-CLI', 'CHNG-COVID', 'DV-CLI']

versions = {
    'claims_no_google_aa': dict(
        sensor_list=_base_sensors + _claims_sensors + ['Simple average (+claims) (-google-aa)'],
        line_lags=np.arange(1, max_eval_lag + 1),
        rank_lags=list(range(4, 9)),
    ),
    'claims': dict(
        sensor_list=_base_sensors + _claims_sensors + ['Google-AA', 'Simple average (+claims)'],
        line_lags=np.arange(1, max_eval_lag + 1),
        rank_lags=list(range(4, 9)),
    ),
    'no_claims_no_google_aa': dict(
        sensor_list=_base_sensors + ['Simple average (-google-aa)'],
        line_lags=np.arange(1, max_eval_lag + 1),
        rank_lags=list(range(2, 7)),
    ),
    'no_claims': dict(
        sensor_list=_base_sensors + ['Google-AA', 'Simple average'],
        line_lags=np.arange(1, max_eval_lag + 1),
        rank_lags=list(range(2, 7)),
    ),
}

full_index = pd.MultiIndex.from_product(
    [evaluation_geos, evaluation_as_of_range], names=['geo', 'as_of'])
# NOTE: the loop variable is deliberately called `version`; other code in this
# notebook may read it as a global.
for version_key, version in versions.items():
    print(version_key)
    err_df = all_err_df[all_err_df.sensor.isin(version['sensor_list'])]
    plot(
        err_df,
        full_index,
        version['sensor_list'],
        version['line_lags'],
        version['rank_lags'],
        version_key,
    )

## Fusion box plot and rank plots

In [None]:
def map_state_type(x):
    """Bucket a count: 'Small' below 5, 'Medium' below 15, otherwise 'Large'."""
    for upper_bound, size_label in ((5, 'Small'), (15, 'Medium')):
        if x < upper_bound:
            return size_label
    return 'Large'

# inverse of Config.state_fips: FIPS code -> state
map_fips_to_state = {fips: state for state, fips in Config.state_fips.items()}

def plot_fusion(err_df, full_index, sensor_list, boxen_lags, rank_lags, out_name):
    """Draw and save the per-state boxen plot and the rank histogram for fusion methods.

    Args:
        err_df: DataFrame with `geo`, `as_of`, `dates`, `sensor`, `lag` and
            `abs_err` columns.
        full_index: MultiIndex of every (geo, as_of) pair under evaluation.
        sensor_list: fusion sensor labels (as found in `err_df.sensor`) to compare.
        boxen_lags: lags included in the boxen plot.
        rank_lags: lags included in the rank plot.
        out_name: suffix for the PDFs written to ./figures/.
    """
    available_index = get_availability(err_df, full_index, sensor_list)
    tmp = err_df[err_df.lag.isin(boxen_lags) & err_df.sensor.isin(sensor_list)].set_index(['geo', 'as_of'])
    plot_df = tmp.loc[available_index].reset_index()
    # only the 'sensor' column holds labels, so restrict the replacement to it
    plot_df.replace({'sensor': label_legend_map}, inplace=True)
    
    # boxen plot: average the error per state, then group the states by the
    # number of evaluation geos they contain
    plot_df['state'] = plot_df['geo'].apply(lambda x: map_fips_to_state[x[:2]] if get_geo_type(x) == 'county' else x)
    plot_df['n_geos_in_state'] = plot_df.state.apply(lambda x: len(state_map[x]))
    plot_df_states = plot_df.groupby(['lag', 'sensor', 'state', 'n_geos_in_state']).abs_err.mean().reset_index()
    plot_df_states['state_size'] = plot_df_states.n_geos_in_state.apply(map_state_type)
    
    # catplot is figure-level and creates its own figure, so no plt.figure()
    # call precedes it (one would only leave an empty figure behind).
    sns.set_palette("colorblind")
    g = sns.catplot(
        data=plot_df_states,
        x='lag',
        y='abs_err',
        hue='sensor',
        col='state_size', 
        kind='boxen',
        hue_order=[
            'Simple average', 
            'Simple regression',
            'Lasso', 
            'Ridge',
            'KF-SF'],
        col_order=['Small', 'Medium', 'Large'],
        palette=color_map,
        showfliers=False,
        legend=False,
        saturation=1,
    )
    g.set_titles("{col_name} states")
    plt.legend(bbox_to_anchor=(.58, .975), loc='upper left', fontsize=10)
    g.set_xlabels('Days back from nowcast time')
    g.set_ylabels('Mean absolute error')
    plt.tight_layout()
    plt.savefig(f'./figures/boxenplot_{out_name}.pdf')
    plt.show()

    # rank plot
    # for each lag, figure out the intersection of dates
    plot_df = []
    for lag in rank_lags:
        available_index, _ = get_availability_by_lag(lag, err_df, full_index, sensor_list, verbose=False)
        tmp = err_df[err_df.lag.eq(lag) & err_df.sensor.isin(sensor_list)].set_index(['geo', 'as_of'])
        plot_df.append(tmp.loc[available_index].reset_index())
    plot_df = pd.concat(plot_df, ignore_index=True)
    plot_df.replace({'sensor': label_legend_map}, inplace=True)
    # rank methods by mean absolute error within each (dates, geo, lag) group
    ensemble_rank_df = plot_df.groupby(['dates', 'geo', 'sensor', 'lag']).abs_err.mean().reset_index()
    ensemble_rank_df['rank'] = ensemble_rank_df.groupby(['dates', 'geo', 'lag']).abs_err.rank('dense', ascending=True)

    plt.figure(figsize=(6, 6))
    g = sns.histplot(
        data=ensemble_rank_df,
        x='rank',
        hue='sensor',
        # Use the function's own argument rather than a global loop variable.
        hue_order=[strip_claims(s) for s in sensor_list],
        element='bars',
        discrete=True,
        multiple='fill',
        palette=color_map,
        legend=True)

    # Matplotlib 3.7 renamed Legend.legendHandles to legend_handles (the old
    # name was later removed); support both.
    auto_legend = g.legend_
    legend_handles = getattr(auto_legend, 'legend_handles', None)
    if legend_handles is None:
        legend_handles = auto_legend.legendHandles
    legend_labels = [t.get_text() for t in auto_legend.texts]
    g.legend(handles=legend_handles, labels=legend_labels,
              title=None,
              bbox_to_anchor=(1, 1), 
             loc='upper left', fontsize=10)
    plt.ylabel('Proportion')
    plt.xlabel('Rank')
    plt.savefig(f'./figures/rankplot_fusion_{out_name}.pdf', bbox_inches='tight')
    plt.show()
    
# Fusion-method sets compared in the boxen/rank plots, keyed by output-file suffix.
versions = {
    'claims': dict(
        sensor_list=[
            'Simple average (+claims)',
            'Simple regression (+claims)',
            'Ridge (+claims)',
            'Lasso (+claims)',
            'KF-SF (+claims)',
        ],
        boxen_lags=np.arange(1, max_eval_lag + 1),
        rank_lags=list(range(4, 9)),
    ),
    'no_claims': dict(
        sensor_list=[
            'Simple average',
            'Simple regression',
            'Ridge',
            'Lasso',
            'KF-SF',
        ],
        boxen_lags=np.arange(1, max_eval_lag + 1),
        rank_lags=list(range(1, 6)),
    ),
}

full_index = pd.MultiIndex.from_product(
    [evaluation_geos, evaluation_as_of_range], names=['geo', 'as_of'])
# NOTE: the loop variable is deliberately called `version`; other code in this
# notebook may read it as a global.
for version_key, version in versions.items():
    print(version_key)
    err_df = all_err_df[all_err_df.sensor.isin(version['sensor_list'])]
    plot_fusion(
        err_df,
        full_index,
        version['sensor_list'],
        version['boxen_lags'],
        version['rank_lags'],
        version_key,
    )