# Sensor fusion

In [None]:
# standard libraries
import itertools
import os
import pickle
from datetime import datetime, timedelta
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

# third party
import numpy as np
import matplotlib.pyplot as plt
from pandas import date_range
from sklearn.linear_model import Lasso
from tqdm.notebook import tqdm

# first party
from data_containers import GroundTruth, LocationSeries, SensorConfig, Sensors
from fusion import covariance, fusion
from config import Config

In [None]:
class Directory:
    """Build file paths for the notebook's data products and pickle them.

    Paths are derived from ``ROOT`` (indicators, sensors, JHU ground truth)
    and from a fixed results directory for the deconvolved infections.
    """
    # Root folder under which indicator / sensor / JHU data are stored.
    ROOT = '../data/'

    def __init__(self, gt_indicator = Config.ground_truth_indicator):
        """Store the ground-truth indicator and derive the directory layout.

        Args:
            gt_indicator: object exposing ``source`` and ``signal``; it is
                kept as ``self.gt`` and used to build ``self.jhu_path``.
        """
        self.gt = gt_indicator
        self.infections_root_dir = './results/ntf_tapered/'
        self.indicator_root_dir = os.path.join(Directory.ROOT, 'indicators')
        self.sensor_root_dir = os.path.join(Directory.ROOT, 'sensors')
        self.jhu_path = os.path.join(
            Directory.ROOT, f'jhu-csse_confirmed_incidence_prop/{self.gt.source}_{self.gt.signal}')

    def deconv_gt_file(self, as_of):
        """Return the pickle path of the deconvolved infections for ``as_of``."""
        return os.path.join(self.infections_root_dir, f'as_of_{as_of}.p')

    def indicator_file(self, indicator, as_of):
        """Return the pickle path of ``indicator`` (source-signal) for ``as_of``."""
        return os.path.join(
            self.indicator_root_dir,
            f'{indicator.source}-{indicator.signal}_{as_of}.p')

    def sensor_file(self, config, as_of):
        """Return the pickle path of the sensor described by ``config`` for ``as_of``."""
        return os.path.join(
            self.sensor_root_dir,
            f'{config.source}_{config.signal}_{as_of}.p')

    def maybe_load_file(self, file_name, verbose=False):
        """Unpickle ``file_name`` if it exists.

        Returns:
            The unpickled object, or ``False`` (not ``None``) when the file
            does not exist.
        """
        if not os.path.isfile(file_name):
            if verbose:
                print(file_name, 'does not exist')
            return False

        # NOTE: pickle can execute arbitrary code; only load trusted files.
        # The context manager closes the handle even if unpickling fails.
        with open(file_name, 'rb') as handle:
            return pickle.load(handle)

    def maybe_write_file(self, data, file_name, overwrite=False, verbose=False):
        """Pickle ``data`` to ``file_name`` unless it already exists.

        Missing parent directories are created. An existing file is only
        replaced when ``overwrite`` is true.

        Returns:
            ``True`` if the file was written, ``False`` if it was skipped.
        """
        if os.path.isfile(file_name) and not overwrite:
            if verbose:
                print(file_name, 'exists')
            return False

        dir_name = os.path.dirname(file_name)
        # A bare file name has an empty dirname; os.makedirs('') would raise.
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)

        with open(file_name, 'wb') as handle:
            pickle.dump(data, handle)
        return True

    @staticmethod
    def exists(file_name, overwrite=False):
        """Return ``True`` when ``file_name`` exists and must not be overwritten."""
        return bool(os.path.isfile(file_name) and not overwrite)
    
def conform(location_series):
    """Normalize a location series so that its data is keyed by ``date``.

    Returns ``None`` when the series is ``None``, has no data, or holds only
    NaN values. Otherwise, if the first date is a ``datetime``, ``data`` is
    rebuilt as a dict keyed by the ``.date()`` of every entry. The (possibly
    modified) input object is returned.
    """
    if location_series is None or location_series.data is None:
        return None
    if np.isnan(location_series.values).all():
        return None

    first_stamp = location_series.dates[0]
    if isinstance(first_stamp, datetime):
        location_series.data = {
            stamp.date(): value
            for stamp, value in zip(location_series.dates, location_series.values)
        }
    return location_series

In [None]:
# Set up every method configuration.
# The 'fast_' variants use a subset of lower-latency sensors.
infections_config = SensorConfig('jhu-csse', 'confirmed_incidence_prop', 'deconv_infections', 2)
ar3_config = SensorConfig('ar3', 'ntf_tapered_infections', 'ar3', lag=1)

# simple averages (with and without the google_aa sensor)
simple_avg_config = SensorConfig('all', 'simple_average', 'average', 4)
fast_simple_avg_config = SensorConfig('fast_all', 'simple_average', 'average', 1)
simple_avg_no_google_aa_config = SensorConfig('all_no_google_aa', 'simple_average', 'average', 4)
fast_simple_avg_no_google_aa_config = SensorConfig('fast_all_no_google_aa', 'simple_average', 'average', 1)

# unregularized regression
simple_reg_config = SensorConfig('all', 'simple_reg', 'regression', 4)
fast_simple_reg_config = SensorConfig('fast_all', 'simple_reg', 'regression', 1)

# ridge
ridge_config = SensorConfig('all', 'ridge', 'ridge', 4)
fast_ridge_config = SensorConfig('fast_all', 'ridge', 'ridge', 1)

# lasso
lasso_config = SensorConfig('all', 'lasso', 'lasso', 4)
fast_lasso_config = SensorConfig('fast_all', 'lasso', 'lasso', 1)

# Kalman-filter sensor fusion
kf_sf_config = SensorConfig('all', 'kf_sf', 'kf_sf', 4)
fast_kf_sf_config = SensorConfig('fast_all', 'kf_sf', 'kf_sf', 1)

# number of most recent dates kept in every output series
num_backcast = 15
directory = Directory(infections_config)

# 'dc' and FIPS 11000 / 11001 are excluded from the location lists
# (presumably the District of Columbia -- confirm against Config).
states = sorted(Config.states - {'dc'})
geos = sorted((set(states) | Config.megacounties | Config.top_counties) - {'11000', '11001'})
evaluation_geos = sorted((set(states) | Config.top_counties) - {'11001'})


def get_geo_type(x):
    """Return 'county' for an all-numeric identifier, otherwise 'state'."""
    return 'county' if x.isnumeric() else 'state'


# For each state: the state itself plus every geo whose first two characters
# equal the state's FIPS code, and the same list paired with its geo type.
state_map = {}
state_type_map = {}
for state in states:
    fips_code = Config.state_fips[state]
    state_map[state] = [geo for geo in geos if state == geo or fips_code == geo[:2]]
    state_type_map[state] = [(geo, get_geo_type(geo)) for geo in state_map[state]]

as_of_range = Config.as_of_range

## Load infections

In [None]:
# Load the deconvolved infections (ground truth) for every as-of date.
infections_data = GroundTruth()
for as_of in tqdm(as_of_range):
    infections_file = directory.deconv_gt_file(as_of)
    data = directory.maybe_load_file(infections_file)
    # maybe_load_file reports a missing file with False (not None), so an
    # `is not None` check can never fire; fail loudly here instead of with an
    # AttributeError on `data.keys()` below.
    if data is None or data is False:
        raise FileNotFoundError(
            f'infections file missing for as_of {as_of}: {infections_file}')

    for k in data.keys():
        vals = conform(data[k])
        # conform returns None for empty / all-NaN series; skip those
        if vals is None:
            continue
        infections_data.add_data(as_of, k, vals)

## Load sensors

In [None]:
def load_sensors(as_of_range, sensor_configs, directory):
    """Collect every available sensor series into a ``Sensors`` container.

    For each as-of date and each ``name -> config`` entry of
    ``sensor_configs`` the matching sensor pickle is loaded through
    ``directory``. Files that are missing or load as an empty/falsy object
    are skipped, as are series that ``conform`` rejects.
    """
    collected = Sensors()
    for as_of in tqdm(as_of_range):
        for name, cfg in sensor_configs.items():
            loaded = directory.maybe_load_file(directory.sensor_file(cfg, as_of))
            if loaded:
                for key in loaded.keys():
                    series = conform(loaded[key])
                    if series is not None:
                        collected.add_data(as_of, name, loaded[key].geo_value, series)
    return collected

# Sensors used by the full (higher-latency) methods.
full_sensors = dict(
    fb_cliic=Config.fb_cliic,
    dv_cli=Config.dv_cli,
    google_aa=Config.google_aa,
    chng_cli=Config.chng_cli,
    chng_covid=Config.chng_covid,
    ar3=ar3_config,
)

# Subset of the sensors above used by the 'fast_' methods.
fast_sensors = {name: full_sensors[name] for name in ('fb_cliic', 'google_aa', 'ar3')}

full_sensor_data = load_sensors(
    as_of_range=as_of_range, sensor_configs=full_sensors, directory=directory)
fast_sensor_data = load_sensors(
    as_of_range=as_of_range, sensor_configs=fast_sensors, directory=directory)

## Simple average

In [None]:
def gen_simple_average(output_config, as_of_range, sensor_dict, sensor_data, directory, overwrite):
    """Write, per as-of date, the NaN-ignoring average of the given sensors.

    For every as-of date and every geo in the module-level
    ``evaluation_geos``, the sensors named in ``sensor_dict`` are read over
    the window ``[as_of - num_backcast, as_of - output_config.lag]`` and
    averaged date by date with ``np.nanmean``. The result is pickled as a
    ``{geo: LocationSeries}`` dict at ``directory.sensor_file(output_config, as_of)``.

    Args:
        output_config: config of the sensor being produced (its ``lag`` is used).
        as_of_range: iterable of as-of dates.
        sensor_dict: mapping whose keys name the sensors to average.
        sensor_data: container queried with ``get_data(as_of, sensor, geo)``.
        directory: ``Directory`` used for paths and pickling.
        overwrite: when false, existing output files are left untouched.
    """
    for as_of in tqdm(as_of_range):
        output_file = directory.sensor_file(output_config, as_of)
        # check first: no point building the date window for a skipped file
        if directory.exists(output_file, overwrite):
            print(output_file, 'exists')
            continue

        last_date = as_of - timedelta(output_config.lag)
        full_dates = [day.date() for day in date_range(
            as_of - timedelta(num_backcast), last_date)]

        output = {}
        for geo in evaluation_geos:
            geo_output = []
            for sensor in sensor_dict:
                signal = sensor_data.get_data(as_of, sensor, geo)
                if signal is None:
                    continue
                try:
                    vals = signal.get_data_range(full_dates[0], full_dates[-1])
                except Exception as e:
                    # best effort: one bad sensor must not abort the whole run
                    print(e, geo, as_of, sensor, 'failed')
                    continue
                geo_output.append(vals)

            # With no sensor at all, np.nanmean([], axis=0) returns a scalar
            # NaN and zip(full_dates, scalar) raises TypeError; skip the geo.
            if not geo_output:
                print(geo, as_of, 'no sensor data available')
                continue

            output[geo] = LocationSeries(
                geo, get_geo_type(geo), dict(zip(full_dates, np.nanmean(geo_output, axis=0)))
            )

        directory.maybe_write_file(output, output_file, overwrite)

In [None]:
# Simple average over the full sensor set; existing outputs are kept.
gen_simple_average(
    simple_avg_config, as_of_range, full_sensors,
    full_sensor_data, directory, overwrite=False,
)

In [None]:
# Simple average over the fast sensor set; existing outputs are kept.
gen_simple_average(
    fast_simple_avg_config, as_of_range, fast_sensors,
    fast_sensor_data, directory, overwrite=False,
)

Run without GOOGLE-AA.

In [None]:
# Full sensor set with the google_aa entry removed.
full_sensors_no_google_aa = {
    name: config for name, config in full_sensors.items() if name != "google_aa"
}
gen_simple_average(
    simple_avg_no_google_aa_config, as_of_range, full_sensors_no_google_aa,
    full_sensor_data, directory, overwrite=False,
)

In [None]:
# Fast sensor set with the google_aa entry removed.
fast_sensors_no_google_aa = {
    name: config for name, config in fast_sensors.items() if name != "google_aa"
}
gen_simple_average(
    fast_simple_avg_no_google_aa_config, as_of_range, fast_sensors_no_google_aa,
    fast_sensor_data, directory, overwrite=False,
)

## Simple regression

In [None]:
def gen_simple_regression(output_config, as_of_range, sensor_dict,
                          infections_data, sensor_data, directory, overwrite):
    """Write, per as-of date, least-squares fits of infections on the sensors.

    For each as-of date and each geo in the module-level ``evaluation_geos``
    an intercept-plus-sensors linear model is fit on a training window that
    starts ``2 * Config.max_delay_days`` before ``as_of`` and ends at the
    earlier of the last infection date and the last output date. When the
    output lag is shorter than the ground-truth lag (``n_test > 0``) the
    fitted model also predicts the ``n_test`` later dates. The last
    ``num_backcast`` estimates are stored as a ``LocationSeries``.

    Args:
        output_config: config of the sensor being produced (its ``lag`` is used).
        as_of_range: iterable of as-of dates.
        sensor_dict: mapping whose sorted keys define the covariate columns.
        infections_data: container queried with ``get_data(as_of, geo)``.
        sensor_data: container queried with ``get_data(as_of, sensor, geo)``.
        directory: ``Directory`` used for paths, ground-truth lag and pickling.
        overwrite: when false, existing output files are left untouched.

    Raises:
        numpy.linalg.LinAlgError: for linear-algebra failures other than a
            singular matrix (those geos are reported and skipped).
    """
    def _regression(y, X):
        # ordinary least squares through the normal equations
        return np.linalg.inv(X.T @ X) @ X.T @ y

    p = len(sensor_dict)
    max_delay = Config.max_delay_days
    col_order = sorted(sensor_dict.keys())
    for as_of in tqdm(as_of_range):
        output_file = directory.sensor_file(output_config, as_of)
        if directory.exists(output_file, overwrite):
            print(output_file, 'exists')
            continue

        last_infection_date = as_of - timedelta(directory.gt.lag)
        last_date = as_of - timedelta(output_config.lag)
        last_training_date = min(last_infection_date, last_date)
        training_dates = [day.date() for day in date_range(
            as_of - timedelta(2 * max_delay), last_training_date)]
        n_train = len(training_dates)
        n_test = (last_date - last_infection_date).days

        # Strictly positive: with n_test == 0 there are no test dates, and
        # indexing test_dates[0] would raise and be misreported as a sensor
        # failure by the broad except below.
        has_test = n_test > 0
        test_dates = None
        if has_test:
            test_dates = [day.date() for day in date_range(
                last_date - timedelta(n_test), last_date)][1:]

        output = {}
        for geo in evaluation_geos:
            response_series = infections_data.get_data(as_of, geo)
            response = response_series.get_data_range(training_dates[0], training_dates[-1])
            covariates = np.full((n_train, p), np.nan)
            test_covariates = np.full((n_test, p), np.nan) if has_test else None

            for j, col in enumerate(col_order):
                covariate = sensor_data.get_data(as_of, col, geo)
                if covariate is None:
                    continue
                try:
                    covariates[:, j] = covariate.get_data_range(training_dates[0], training_dates[-1])
                    if has_test:
                        test_covariates[:, j] = covariate.get_data_range(test_dates[0], test_dates[-1])
                except Exception as e:
                    # best effort: a failing sensor leaves its column as NaN
                    print(e, geo, as_of, col, 'failed')
                    continue

            # drop sensors with no training data or any missing test value
            nan_cols = np.all(np.isnan(covariates), axis=0)
            if has_test:
                nan_cols = np.logical_or(nan_cols, np.any(np.isnan(test_covariates), axis=0))
            covariates = np.c_[np.ones(covariates.shape[0]), covariates[:, ~nan_cols]]
            try:
                beta = _regression(response, covariates)
            except np.linalg.LinAlgError as e:
                if str(e) != "Singular matrix":
                    raise
                print(geo, as_of, 'Singular matrix')
                continue

            est = (covariates @ beta).flatten()
            dates = training_dates
            if has_test:
                test_design = np.c_[np.ones(test_covariates.shape[0]),
                                    test_covariates[:, ~nan_cols]]
                est = np.r_[est, (test_design @ beta).flatten()]
                dates = np.r_[dates, test_dates]

            # keep only the most recent num_backcast estimates
            est = est[-num_backcast:]
            dates = dates[-num_backcast:]
            output[geo] = LocationSeries(geo, get_geo_type(geo), dict(zip(dates, est)))
        directory.maybe_write_file(output, output_file, overwrite)

In [None]:
# Simple regression on the full sensor set; existing outputs are kept.
gen_simple_regression(
    output_config=simple_reg_config,
    as_of_range=as_of_range,
    sensor_dict=full_sensors,
    infections_data=infections_data,
    sensor_data=full_sensor_data,
    directory=directory,
    overwrite=False,
)

In [None]:
# Simple regression on the fast sensor set; existing outputs are kept.
gen_simple_regression(
    output_config=fast_simple_reg_config,
    as_of_range=as_of_range,
    sensor_dict=fast_sensors,
    infections_data=infections_data,
    sensor_data=fast_sensor_data,
    directory=directory,
    overwrite=False,
)

## Ridge and Lasso

In [None]:
def _lasso_regression(y, X, lam):
    """Fit a no-intercept Lasso with penalty ``lam`` and return its coefficients.

    ``y`` is flattened to one dimension before fitting.
    """
    target = np.array(y).reshape(-1,)
    model = Lasso(alpha=lam, fit_intercept=False)
    model.fit(X, target)
    return model.coef_

def _ridge_regression(y, X, lam):
    """Return the closed-form ridge solution ``(X'X + lam*I)^-1 X'y``.

    The penalty applies to every column of ``X``.
    """
    penalized_gram = X.T @ X + lam*np.eye(X.shape[1])
    gram_inverse = np.linalg.inv(penalized_gram)
    return gram_inverse @ X.T @ y

def mean_impute(A, col_means=None):
    """Replace NaNs in the 2-D array ``A`` with per-column fill values.

    The fill value of a column is taken from ``col_means`` when given,
    otherwise from the column's NaN-ignoring mean. ``A`` is modified in
    place and also returned.
    """
    missing = np.isnan(A)
    fill_values = np.nanmean(A, axis=0) if col_means is None else col_means
    # column index of every missing cell, in the same row-major order that
    # boolean-mask assignment uses
    missing_cols = np.nonzero(missing)[1]
    A[missing] = np.take(fill_values, missing_cols)
    return A

def response_matrix(as_of, dates, input_geos, response_data):
    """Return a ``len(dates) x len(input_geos)`` matrix of response values.

    Cell ``(i, j)`` is ``response_data.get_val(as_of, dates[i], input_geos[j])``.
    """
    values = np.full((len(dates), len(input_geos)), np.nan)
    cells = itertools.product(enumerate(dates), enumerate(input_geos))
    for (row, day), (col, location) in cells:
        values[row, col] = response_data.get_val(as_of, day, location)
    return values
    
def covariate_matrix(as_of, dates, input_pairs, covariate_data,
                     return_nan_cols=True):
    """Return a ``len(dates) x len(input_pairs)`` matrix of covariate values.

    Each entry of ``input_pairs`` is a ``(covariate, geo)`` pair; cell
    ``(i, j)`` is ``covariate_data.get_val(as_of, dates[i], covariate, geo)``.
    With ``return_nan_cols`` (the default) a boolean vector flagging the
    columns that are entirely NaN is returned alongside the matrix.
    """
    values = np.full((len(dates), len(input_pairs)), np.nan)
    cells = itertools.product(enumerate(dates), enumerate(input_pairs))
    for (row, day), (col, (sensor, location)) in cells:
        values[row, col] = covariate_data.get_val(as_of, day, sensor, location)
    if not return_nan_cols:
        return values
    return values, np.all(np.isnan(values), axis=0)

def gen_regularized_regression(
    output_config, as_of_range, fit_func, sensor_dict,
    infections_data, sensor_data,
    cv_grid, cv_folds, directory, overwrite
):
    """Write, per as-of date, regularized regression estimates of infections.

    As-of dates are processed in reverse order. For every state the
    covariates are all (sensor, location) pairs of that state's locations;
    each evaluated geo in the state gets its own model of its infections on
    an intercept plus those covariates. The tuning parameter is chosen
    separately for each of the last ``num_backcast`` positions by minimizing
    the absolute error averaged over ``cv_folds`` truncated refits. The
    final estimates are stored as a ``LocationSeries`` per geo.

    Args:
        output_config: config of the sensor being produced (its ``lag`` is used).
        as_of_range: sequence of as-of dates.
        fit_func: callable ``(y, X, lam) -> beta``.
        sensor_dict: mapping whose sorted keys name the input sensors.
        infections_data: container queried with ``get_data(as_of, geo)``.
        sensor_data: container queried through ``covariate_matrix``.
        cv_grid: candidate tuning-parameter values.
        cv_folds: number of cross-validation folds.
        directory: ``Directory`` used for paths, ground-truth lag and pickling.
        overwrite: when false, existing output files are left untouched.
    """
    max_delay = Config.max_delay_days
    input_sensors = sorted(sensor_dict.keys())
    evaluated = set(evaluation_geos)
    for as_of in tqdm(as_of_range[::-1]):
        output_file = directory.sensor_file(output_config, as_of)
        if directory.exists(output_file, overwrite):
            print(output_file, 'exists')
            continue

        # find last observed data according to given lags
        last_infection_date = as_of - timedelta(directory.gt.lag)
        last_date = as_of - timedelta(output_config.lag)
        last_training_date = min(last_infection_date, last_date)

        # training window starts 2 * max_delay days before as_of
        training_dates = [day.date() for day in date_range(
            as_of - timedelta(2 * max_delay), last_training_date)]
        n_test = (last_date - last_infection_date).days

        # Out-of-sample estimates exist only for n_test > 0. With n_test == 0
        # the test matrix has no rows, np.all(..., axis=0) is then True for
        # every column, and every covariate would be dropped.
        has_test = n_test > 0
        test_dates = None
        dates = training_dates
        if has_test:
            test_dates = [day.date() for day in date_range(
                last_date - timedelta(n_test), last_date)][1:]
            dates = np.r_[training_dates, test_dates]
        dates = dates[-num_backcast:]

        output = {}
        # one covariate set per state, one model per evaluated geo
        for state in states:
            input_locations = state_map[state]
            input_pairs = np.array(list(itertools.product(input_sensors, input_locations)))

            # The design matrix depends only on the state, so it is built
            # once here instead of once per geo.
            covariates, nan_cols = covariate_matrix(as_of, training_dates, input_pairs, sensor_data)
            test_design = None
            if has_test:
                test_covariates, test_nan_cols = covariate_matrix(
                    as_of, test_dates, input_pairs, sensor_data
                )
                nan_cols = np.logical_or(nan_cols, test_nan_cols)
                test_design = np.c_[np.ones(test_covariates.shape[0]),
                                    test_covariates[:, ~nan_cols]]
            covariates = np.c_[np.ones(covariates.shape[0]), covariates[:, ~nan_cols]]
            covariates = mean_impute(covariates)

            for geo in input_locations:
                # skip any geographies we are not interested in (megacounties)
                if geo not in evaluated:
                    continue

                response_series = infections_data.get_data(as_of, geo)
                response = response_series.get_data_range(training_dates[0], training_dates[-1])

                # perform cross-validation to find tuning parameters
                cv_errors = np.full((len(cv_grid), cv_folds, num_backcast), np.inf)
                for j, fold in enumerate(range(1, cv_folds + 1)):
                    n = covariates.shape[0] - fold
                    if has_test:  # hold out one more row for a 1-ahead prediction
                        n -= 1
                        cv_test_covariates = covariates[n]
                        cv_test_response = response[n]
                    cv_covariates = covariates[:n]

                    for i, lam in enumerate(cv_grid):
                        cv_response = response[:n]
                        try:
                            beta = fit_func(cv_response, cv_covariates, lam)

                            cv_predictions = (cv_covariates @ beta).flatten()
                            assert len(cv_predictions) == n
                            if has_test:
                                cv_predictions = np.r_[cv_predictions,
                                                       cv_test_covariates @ beta]
                                cv_response = np.r_[cv_response, cv_test_response]

                            cv_errors[i, j, :] = np.abs(cv_predictions - cv_response)[-num_backcast:]
                        except np.linalg.LinAlgError:
                            # leave the error at +inf for this (lam, fold)
                            print(geo, as_of)
                            continue

                # take argmin of cross-validation errors for each lag
                cv_errors = np.mean(cv_errors, axis=1)
                cv_lam = np.full((num_backcast), np.nan)
                for k in range(num_backcast):
                    cv_lam[k] = cv_grid[np.argmin(cv_errors[:, k])]
                unique_lams = list(set(cv_lam))

                # perform final fits for each value of the tuning parameter set
                out_values = np.full((num_backcast,), np.nan)
                for lam in unique_lams:
                    beta = fit_func(response, covariates, lam)
                    est = (covariates @ beta).flatten()
                    if has_test:
                        est = np.r_[est, (test_design @ beta).flatten()]

                    est = est[-num_backcast:]
                    for k in range(num_backcast):
                        if cv_lam[k] == lam:
                            out_values[k] = est[k]

                assert np.any(~np.isnan(out_values)), out_values
                output[geo] = LocationSeries(geo, get_geo_type(geo), dict(zip(dates, out_values)))
        directory.maybe_write_file(output, output_file, overwrite)

In [None]:
# Tuning-parameter grid: 10 log-spaced values shifted down by 1, followed by
# 5 log-spaced values between 10**1.05 and 10**2.
cv_grid = np.concatenate([
    np.logspace(1e-4, 1, 10) - 1,
    np.logspace(1.05, 2, 5),
])
cv_folds = 7

# visual check of the grid
plt.plot(cv_grid, marker=".")
plt.show()

In [None]:
# Ridge on the full sensor set; existing outputs are kept.
gen_regularized_regression(
    output_config=ridge_config,
    as_of_range=as_of_range,
    fit_func=_ridge_regression,
    sensor_dict=full_sensors,
    infections_data=infections_data,
    sensor_data=full_sensor_data,
    cv_grid=cv_grid,
    cv_folds=cv_folds,
    directory=directory,
    overwrite=False,
)

In [None]:
# Ridge on the fast sensor set; existing outputs are kept.
gen_regularized_regression(
    output_config=fast_ridge_config,
    as_of_range=as_of_range,
    fit_func=_ridge_regression,
    sensor_dict=fast_sensors,
    infections_data=infections_data,
    sensor_data=fast_sensor_data,
    cv_grid=cv_grid,
    cv_folds=cv_folds,
    directory=directory,
    overwrite=False,
)

In [None]:
# Lasso on the full sensor set; existing outputs are kept.
gen_regularized_regression(
    output_config=lasso_config,
    as_of_range=as_of_range,
    fit_func=_lasso_regression,
    sensor_dict=full_sensors,
    infections_data=infections_data,
    sensor_data=full_sensor_data,
    cv_grid=cv_grid,
    cv_folds=cv_folds,
    directory=directory,
    overwrite=False,
)

In [None]:
# Lasso on the fast sensor set; existing outputs are kept.
gen_regularized_regression(
    output_config=fast_lasso_config,
    as_of_range=as_of_range,
    fit_func=_lasso_regression,
    sensor_dict=fast_sensors,
    infections_data=infections_data,
    sensor_data=fast_sensor_data,
    cv_grid=cv_grid,
    cv_folds=cv_folds,
    directory=directory,
    overwrite=False,
)

## Kalman filter fusion

In [None]:
# Load population dataframe and index it by FIPS code.
# NOTE: pickle can execute arbitrary code; only load trusted files.
# The context manager closes the file handle once loading is done.
with open("./fusion/top_250_pops", "rb") as pop_file:
    pop_df = pickle.load(pop_file)
pop_df.set_index("fips", inplace=True)

In [None]:
def generate_statespace(state_id, input_location_types, atom_list, pop_df=pop_df):
    """Build the H and W maps for one state (specific to a state-only hierarchy).

    Every location is expressed as population shares over ``atom_list``: a
    county puts all of its weight on itself, a state spreads its weight over
    all atoms in proportion to ``pop_df.loc[atom].population``.

    Args:
        state_id: identifier of the state.
        input_location_types: ``(location, type)`` pairs of the inputs, with
            ``type`` either ``'county'`` or ``'state'``.
        atom_list: the atoms (counties) the state is composed of.
        pop_df: table indexed by location with a ``population`` column.

    Returns:
        ``(H, W, output_locations)`` where ``H`` and ``W`` come from
        ``fusion.determine_statespace`` and ``output_locations`` lists the
        ``(location, type)`` pairs selected by its returned indices.

    Raises:
        Exception: for an unknown location type, or when a location's total
            population over the atoms is zero.
    """
    # the state first, then every atom as a county
    all_location_types = [(state_id, 'state')] + [(loc, 'county') for loc in atom_list]

    def population_shares(location, location_type, atoms):
        if location_type == 'county':
            atom_populations = [
                pop_df.loc[atom].population if atom == location else 0
                for atom in atoms
            ]
        elif location_type == 'state':
            atom_populations = [pop_df.loc[atom].population for atom in atoms]
        else:
            raise Exception("get_weight_row: invalid location_type passed")

        total_population = sum(atom_populations)
        # sanity check
        if total_population == 0:
            raise Exception(('location has no constituent atoms', location))

        return [population / total_population for population in atom_populations]

    def weight_matrix(location_types, atoms):
        """Stack one row of population shares per location."""
        return np.array([population_shares(loc[0], loc[1], atoms) for loc in location_types])

    H0 = weight_matrix(input_location_types, atom_list)
    W0 = weight_matrix(all_location_types, atom_list)

    # get H and W from H0 and W0
    H, W, output_idx = fusion.determine_statespace(H0, W0)
    output_locations = [all_location_types[idx] for idx in output_idx]

    return H, W, output_locations


def response_matrix(as_of, dates, input_geos, response_data):
    """Tabulate ``response_data.get_val(as_of, date, geo)`` by date and geo.

    Rows follow ``dates`` and columns follow ``input_geos``.
    """
    n_rows, n_cols = len(dates), len(input_geos)
    table = np.full((n_rows, n_cols), np.nan)
    for row, col in np.ndindex(n_rows, n_cols):
        table[row, col] = response_data.get_val(as_of, dates[row], input_geos[col])
    return table
    
def covariate_matrix(as_of, dates, input_pairs, covariate_data,
                     return_nan_cols=True):
    """Tabulate covariate values by date and ``(covariate, geo)`` pair.

    Rows follow ``dates`` and columns follow ``input_pairs``; each cell is
    ``covariate_data.get_val(as_of, date, covariate, geo)``. Unless
    ``return_nan_cols`` is false, a boolean vector marking the all-NaN
    columns is returned together with the matrix.
    """
    n_rows, n_cols = len(dates), len(input_pairs)
    table = np.full((n_rows, n_cols), np.nan)
    for row, col in np.ndindex(n_rows, n_cols):
        sensor, location = input_pairs[col]
        table[row, col] = covariate_data.get_val(as_of, dates[row], sensor, location)
    if not return_nan_cols:
        return table
    empty_cols = np.all(np.isnan(table), axis=0)
    return table, empty_cols

def gen_kf_sf(
    output_config, as_of_range, sensor_dict,
    infections_data, sensor_data,
    cv_grid, cv_folds, directory, overwrite
):
    """Write, per as-of date, the sensor-fusion estimates of infections.

    One model is fit per state. The inputs are all (sensor, location) pairs
    of the state's locations that have data; ``generate_statespace`` supplies
    the maps ``H`` and ``W`` and the output locations. The sensor-noise
    covariance ``R`` is estimated from ``response - sensor`` and shrunk
    toward the identity as ``(1 - lam) * R + lam * I``. ``lam`` is chosen
    separately for each of the last ``num_backcast`` positions by
    cross-validation over ``cv_folds`` earlier as-of dates. Estimates for the
    output locations that are in ``evaluation_geos`` are stored as
    ``LocationSeries``.

    Args:
        output_config: config of the sensor being produced (its ``lag`` is used).
        as_of_range: sequence of as-of dates; the first 8 entries are skipped.
        sensor_dict: mapping whose sorted keys name the input sensors.
        infections_data: container queried through ``response_matrix``.
        sensor_data: container queried through ``covariate_matrix``.
        cv_grid: candidate shrinkage values ``lam``.
        cv_folds: number of cross-validation folds.
        directory: ``Directory`` used for paths, ground-truth lag and pickling.
        overwrite: when false, existing output files are left untouched.
    """
    max_delay = Config.max_delay_days
    input_sensors = sorted(sensor_dict.keys())
    # NOTE(review): the first 8 as-of dates are skipped -- presumably so the
    # CV folds (as_of - 1 .. as_of - cv_folds) have earlier snapshots; confirm.
    for as_of in tqdm(as_of_range[8:]):
        output_file = directory.sensor_file(output_config, as_of)
        if directory.exists(output_file, overwrite):
            print(output_file, 'exists')
            continue

        # find last observed data according to given lags
        last_infection_date = as_of - timedelta(directory.gt.lag)
        last_date = as_of - timedelta(output_config.lag)
        last_training_date = min(last_infection_date, last_date)

        # training window starts 2 * max_delay days before as_of
        training_dates = [day.date() for day in date_range(
            as_of - timedelta(2 * max_delay), last_training_date)]
        n_test = (last_date - last_infection_date).days

        # Out-of-sample estimates exist only for n_test > 0. With n_test == 0
        # the test matrix has no rows, np.all(..., axis=0) is then True for
        # every column, and every input pair would be dropped.
        has_test = n_test > 0
        test_dates = None
        all_dates = training_dates
        if has_test:
            test_dates = [day.date() for day in date_range(
                last_date - timedelta(n_test), last_date)][1:]
            all_dates = np.r_[training_dates, test_dates]
        dates = all_dates[-num_backcast:]

        # train one model per state
        output = {}
        for state in states:
            input_locations = state_map[state]
            atom_list = [loc for loc in input_locations if loc != state]

            # find sensors available at time-date pairs
            input_pairs = np.array(list(itertools.product(input_sensors, input_locations)))
            covariates, nan_cols = covariate_matrix(as_of, training_dates, input_pairs, sensor_data)
            test_covariates = None
            if has_test:
                test_covariates, test_nan_cols = covariate_matrix(
                    as_of, test_dates, input_pairs, sensor_data
                )
                nan_cols = np.logical_or(nan_cols, test_nan_cols)

            input_pairs = input_pairs[~nan_cols]
            input_geos = [loc for _, loc in input_pairs]
            input_response = response_matrix(as_of, training_dates, input_geos, infections_data)
            col_means = np.nanmean(covariates[:, ~nan_cols], axis=0)
            covariates = mean_impute(covariates[:, ~nan_cols])
            p = len(input_pairs)

            # determine measurement maps
            H, W, output_locations = generate_statespace(
                state, [(loc, get_geo_type(loc)) for _, loc in input_pairs], atom_list)
            output_geos = [loc[0] for loc in output_locations]

            # perform cross-validation to find regularization parameter
            cv_errors = np.full((len(cv_grid), cv_folds, num_backcast), np.inf)
            for j, fold in enumerate(range(1, cv_folds + 1)):
                n = covariates.shape[0] - fold
                cv_as_of = as_of - timedelta(fold)
                cv_training_dates = training_dates[:n]

                # sensors and input responses as known at the earlier as-of;
                # the output response is read at the current as-of
                cv_covariates, cv_nan_cols = covariate_matrix(
                    cv_as_of, cv_training_dates, input_pairs, sensor_data)
                cv_input_response = response_matrix(
                    cv_as_of, cv_training_dates, input_geos, infections_data)
                cv_output_response = response_matrix(
                    as_of, cv_training_dates, output_geos, infections_data)

                # fall back to the full-window column means when the fold has
                # an entirely missing column
                cv_col_means = col_means
                if not np.any(cv_nan_cols):
                    cv_col_means = np.nanmean(cv_covariates, axis=0)
                cv_covariates = mean_impute(cv_covariates, cv_col_means)

                cv_test_covariates = None
                if has_test:  # produce a 1-ahead prediction
                    cv_test_covariates, _ = covariate_matrix(
                        cv_as_of, [training_dates[n]], input_pairs, sensor_data)
                    cv_test_covariates = mean_impute(cv_test_covariates, cv_col_means)
                    cv_test_output_response = response_matrix(
                        as_of, [training_dates[n]], output_geos, infections_data)
                    cv_output_response = np.r_[
                        cv_output_response, cv_test_output_response]

                cv_noise = cv_input_response - cv_covariates
                a, b = covariance.nancov(cv_noise)
                R = a / b
                for i, lam in enumerate(cv_grid):
                    R_tilde = (1 - lam) * R + lam * np.eye(p)
                    try:
                        RiH = np.dot(np.linalg.inv(R_tilde), H)
                        P = np.linalg.inv(np.dot(H.T, RiH))
                        B = np.dot(P, RiH.T)
                        cv_predictions = np.dot(W, np.dot(B, cv_covariates.T)).T
                        assert cv_predictions.shape[0] == n

                        if has_test:
                            cv_predictions = np.r_[
                                cv_predictions,
                                np.dot(W, np.dot(B, cv_test_covariates.T)).T.reshape(1, -1)]

                        cv_errors[i, j, :] = np.mean(np.abs(cv_predictions - cv_output_response), axis=1)[-num_backcast:]
                    except np.linalg.LinAlgError as e:
                        # Report the state: `geo` is only bound by the output
                        # loop further down, so printing it here raised
                        # UnboundLocalError on the first failure.
                        print(state, as_of, lam, e)
                        print(np.sum(np.isnan(cv_covariates)))
                        print(np.sum(np.isnan(cv_input_response)))
                        print(np.sum(np.isnan(cv_noise)))
                        print(R)
                        print(R_tilde)
                        continue

            # take argmin of cross-validation errors for each lag
            cv_errors = np.mean(cv_errors, axis=1)
            cv_lam = np.full((num_backcast), np.nan)
            for k in range(num_backcast):
                cv_lam[k] = cv_grid[np.argmin(cv_errors[:, k])]
            unique_lams = list(set(cv_lam))

            # perform final fits for each value of the tuning parameter set
            out_values = np.full((len(output_locations), num_backcast), np.nan)
            noise = input_response - covariates
            a, b = covariance.nancov(noise)
            R = a / b
            for lam in unique_lams:
                R_tilde = (1 - lam) * R + lam * np.eye(p)
                RiH = np.linalg.inv(R_tilde) @ H
                P = np.linalg.inv(H.T @ RiH)
                B = P @ RiH.T
                est = np.dot(W, B @ covariates.T)
                if has_test:
                    est = np.c_[est,
                        np.dot(W, np.dot(B, test_covariates[:, ~nan_cols].T))
                    ]

                assert est.shape[1] == len(all_dates)
                est = est[:, -num_backcast:]
                for k in range(num_backcast):
                    if cv_lam[k] == lam:
                        out_values[:, k] = est[:, k]
            assert np.any(~np.isnan(out_values)), out_values

            # store output in dictionary
            for row, (geo, geotype) in enumerate(output_locations):
                # skip any geographies we are not interested in (megacounties)
                if geo not in evaluation_geos:
                    continue
                output[geo] = LocationSeries(geo, geotype, dict(zip(dates, out_values[row, :])))
        directory.maybe_write_file(output, output_file, overwrite)

In [None]:
# Shrinkage grid: the logistic function evaluated at 30 evenly spaced points
# in [-10, 10], with the endpoint 1 appended.
kf_sf_cv_grid = np.append(1 / (1 + np.exp(-np.linspace(-10, 10, 30))), 1)
cv_folds = 7

# visual check of the grid
plt.plot(kf_sf_cv_grid, marker=".")

In [None]:
# Sensor fusion on the full sensor set; existing outputs are kept.
gen_kf_sf(
    output_config=kf_sf_config,
    as_of_range=as_of_range,
    sensor_dict=full_sensors,
    infections_data=infections_data,
    sensor_data=full_sensor_data,
    cv_grid=kf_sf_cv_grid,
    cv_folds=cv_folds,
    directory=directory,
    overwrite=False,
)

In [None]:
# Sensor fusion on the fast sensor set; existing outputs are kept.
gen_kf_sf(
    output_config=fast_kf_sf_config,
    as_of_range=as_of_range,
    sensor_dict=fast_sensors,
    infections_data=infections_data,
    sensor_data=fast_sensor_data,
    cv_grid=kf_sf_cv_grid,
    cv_folds=cv_folds,
    directory=directory,
    overwrite=False,
)