In [3]:
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
from matplotlib.ticker import NullFormatter
%matplotlib inline
import seaborn as sns
sns.set(style="ticks")
from mpl_toolkits.mplot3d import Axes3D
import geopandas
import typing
from mpl_toolkits.axes_grid1 import make_axes_locatable

import cvxpy as cp
import pandas as pd
import scipy as sp
import covidcast
import datetime
import sklearn
import sklearn.preprocessing
import copy
import pynndescent
from functools import reduce
import os
import pickle
import pymde
from sklearn.linear_model import QuantileRegressor
from sklearn.metrics import pairwise_distances
from enum import Enum
import statsmodels
from statsmodels.graphics import gofplots
import csv

import sys
sys.path.insert(1, f"{os.getcwd()}/backend")
from np_backend.dro_conformal import *
import numpy as np
np.random.seed(1)

In [5]:
# Record the library versions the results were produced with.
for _name, _module in (("numpy", np), ("matplotlib", matplotlib), ("seaborn", sns),
                       ("geopandas", geopandas), ("cvxpy", cp), ("pandas", pd),
                       ("scipy", sp), ("sklearn", sklearn), ("pynndescent", pynndescent),
                       ("pymde", pymde), ("statsmodels", statsmodels)):
    print(f"{_name}: {_module.__version__}")
del _name, _module

numpy: 1.23.5
matplotlib: 3.7.0
seaborn: 0.12.2
geopandas: 0.14.0
cvxpy: 1.4.0
pandas: 1.5.3
scipy: 1.10.0
sklearn: 1.2.1
pynndescent: 0.5.10
pymde: 0.1.18
statsmodels: 0.13.5


# Settings

In [None]:
def B_func(n):
    """Integer quarter of n (presumably the hard-region size budget -- confirm at call site)."""
    return n // 4


# Regularization grid 2**-10, ..., 2**9 (used by the multitask ensemble).
lam_base = 2
lam_exps = np.arange(-10, 10, dtype=float)
lams = lam_base**lam_exps
lams_orig = np.copy(lams)  # untouched copy; `lams` is re-derived from it in the ensemble cell

In [None]:
EnsembleType = Enum("EnsembleType", ["Bagged", "Stacked", "Multitask", "PureLocal"])

# Select exactly one strategy (uncomment the one you want).
ensemble_type = EnsembleType.Bagged
# ensemble_type = EnsembleType.Stacked
# ensemble_type = EnsembleType.Multitask
# ensemble_type = EnsembleType.PureLocal

# Only the multitask model consumes the regularization grid.
lams = np.copy(lams_orig) if ensemble_type == EnsembleType.Multitask else None

In [None]:
want_robust_intervals = True
alpha = 0.1  # target miscoverage level


def kl(z):
    """Convex generator z*log(z) (cvxpy's entr(z) is -z*log(z))."""
    return -cp.entr(z)


def adjust_alpha(my_alpha, my_n_val):
    """Finite-sample corrected miscoverage: max(1 - (1 + 1/n) * (1 - alpha), 0)."""
    return np.maximum(1. - (1. + 1./my_n_val)*(1. - my_alpha), 0.)

In [10]:
# Data directory (relative to the notebook). File names are appended with `+`, so keep the trailing slash.
ff= "../data/jasa_10_07_2023_data/"

In [15]:
# Study period; the number of weekly windows is derived from it below.
start_date, end_date = datetime.date(2021, 1, 21), datetime.date(2021, 9, 1)

In [None]:
for _opt in ("display.max_rows", "display.max_columns",
             "display.max_info_columns", "display.max_seq_items"):
    pd.set_option(_opt, 100)
del _opt

# strftime format used to build the date suffixes of the wide-panel column names.
datetime_str_f = "%m-%d-%Y"

# Load data

In [11]:
# NOTE: unpickling can execute arbitrary code -- only load trusted local files.
df = pd.read_pickle(f"{ff}df_2021-10-05.pkl")

In [None]:
# Target column (per its name: indicator-combination, confirmed 7-day-average incidence proportion, NYT).
y_col = "indicator-combination_confirmed_7dav_incidence_prop_0_value_nyt"

In [13]:
# Only the first row is used: per its name, the file lists the county feature column names.
# newline='' is the csv module's recommended way to open files for csv.reader.
with open(ff + 'county_data_col_names.csv', 'r', newline='') as fp:
    line = next(csv.reader(fp, delimiter=","), None)
if line is None:
    # Fail loudly instead of leaving X_counties_col_names undefined (or stale from
    # an earlier kernel session) when the file is empty.
    raise ValueError(ff + "county_data_col_names.csv is empty")
X_counties_col_names = line

In [None]:
# Feature columns: the target column itself (its current-week value) followed by the
# county covariates. Uses the named result of the CSV cell rather than the leaked loop
# variable `line`, so this cell does not depend on that cell's loop internals.
X_cols_prefix = [y_col] + list(X_counties_col_names)

In [16]:
one_wk = datetime.timedelta(weeks=1)
num_wks = (end_date - start_date).days // 7  # whole weeks in the study period
t = start_date

# NOTE: unpickling can execute arbitrary code -- only load trusted local files.
Xy_tot = pd.read_pickle(f"{ff}Xy_tot_2021-10-05.pkl")

# Define helper functions

In [None]:
def apply_logit(y, a=1e-2):
    """Shifted logit, element-wise: log((y + a) / (1 - y + a)).

    The offset ``a`` keeps the transform finite at y = 0 and y = 1.

    :param y: array-like (presumably proportions in [0, 1]); not modified.
    :param a: smoothing offset.
    :return: new array of the same shape. Floating inputs keep their dtype;
        anything else (int, bool, object) is computed in float64.
    """
    y_arr = np.asarray(y)
    # The old element-wise loop wrote results back into np.copy(y), so integer
    # input silently truncated the real-valued logits. Work in a float dtype.
    out_dtype = y_arr.dtype if np.issubdtype(y_arr.dtype, np.floating) else float
    x = np.array(y_arr, dtype=out_dtype)
    return np.log((x + a) / (1 - x + a))

def apply_inverse_logit(y, a=1e-2):
    """Inverse of apply_logit, element-wise: (exp(z) * (1 + a) - a) / (1 + exp(z)).

    :param y: array-like of logit-space values; not modified.
    :param a: the same offset used in apply_logit.
    :return: new array of the same shape. Floating inputs keep their dtype;
        anything else is computed in float64. np.exp overflows to inf for z
        above roughly 709 (float64), which yields NaN (inf / inf) there.
    """
    y_arr = np.asarray(y)
    # As in apply_logit: integer input used to truncate the results on write-back.
    out_dtype = y_arr.dtype if np.issubdtype(y_arr.dtype, np.floating) else float
    x = np.array(y_arr, dtype=out_dtype)
    ex = np.exp(x)
    return (ex * (1 + a) - a) / (1 + ex)

def compute_scores(yhat, y, yprev=None,
                   want_apply_inverse_logit=True, a=0):  # a=1e-2
    """Absolute errors |yhat - y|, optionally scaled by a baseline error.

    :param yprev: if given, each error is divided by max(|yprev - y|, a).
        Otherwise it is divided by 1.
    :param want_apply_inverse_logit: unused.
    :param a: floor for the denominator. With the default a=0, a zero baseline
        error (yprev == y) gives inf, or nan when the numerator is also 0.
    :return: array of scores.
    """
    abs_err = np.abs(yhat - y)
    if yprev is None:
        return abs_err / 1
    return abs_err / np.maximum(np.abs(yprev - y), a)

def pick_out_complement(v, all_idxes, some_idxes):
    """Select the entries/rows of v whose index is in all_idxes but not in some_idxes.

    :param v: 1-D array (entries selected) or N-D array (rows along axis 0 selected).
    :param all_idxes: candidate indices; their order is preserved.
    :param some_idxes: (hashable) indices to exclude.
    :return: the selected sub-array.
    """
    # Set lookup: testing membership in a list made this O(len(all) * len(some)).
    excluded = set(some_idxes)
    some_idxes_complement = [idx for idx in all_idxes if idx not in excluded]

    if len(v.shape) == 1:
        return v[some_idxes_complement]
    return v[some_idxes_complement, :]
    
def add_suffix_to_X_cols_prefix(suffix):
    """Return the notebook-level X_cols_prefix names, each with ``suffix`` appended."""
    return [prefix + suffix for prefix in X_cols_prefix]

def get_train_val_test_data(t, t_idx, one_wk, num_wks, datetime_str_f,
                            Xy_tot, X_cols_prefix, y_col):
    """Slice one rolling train/val/test window out of the wide panel ``Xy_tot``.

    Features for week w are the columns ``<prefix>_<w>`` for each prefix in
    ``X_cols_prefix`` (w formatted with ``datetime_str_f``). The first window
    (t_idx == 0) instead uses the unsuffixed ``X_cols_prefix`` columns. Each
    target is the ``y_col`` column one week after its features. Column 0 of
    every feature matrix and all targets go through apply_logit.

    :param num_wks: unused; kept for interface compatibility.
    :return: X_train, y_train, y_train_h, X_val, y_val, y_val_h, n_val,
             X_test, y_test, y_test_h, n_test (``_h`` = logit-transformed).
    """
    t_str = t.strftime(datetime_str_f)
    tplus1_str = (t+one_wk).strftime(datetime_str_f)
    tplus2_str = (t+2*one_wk).strftime(datetime_str_f)
    tplus3_str = (t+3*one_wk).strftime(datetime_str_f)

    print("Processing (t_idx = " + str(t_idx) + ") ...")
    print("\tTrain set range: [" + t_str + ", " + tplus1_str + ")")
    print("\tVal. set range: [" + tplus1_str + ", " + tplus2_str + ")")
    print("\tTest set range: [" + tplus2_str + ", " + tplus3_str + ")")
    print()

    def _features(suffix):
        # Build names from the X_cols_prefix *argument*. The old code did this only
        # for t_idx == 0 and otherwise silently used the global X_cols_prefix
        # (via add_suffix_to_X_cols_prefix).
        X = Xy_tot[[col + suffix for col in X_cols_prefix]].to_numpy()
        X[:, 0] = apply_logit(X[:, 0])
        return X

    def _target(date_str):
        return Xy_tot[y_col + "_" + date_str].to_numpy()

    X_train = _features("" if t_idx == 0 else "_" + t_str)
    y_train = _target(tplus1_str)

    X_val = _features("_" + tplus1_str)
    y_val = _target(tplus2_str)
    n_val = y_val.shape[0]

    X_test = _features("_" + tplus2_str)
    y_test = _target(tplus3_str)
    n_test = y_test.shape[0]

    y_train_h = apply_logit(y_train)
    y_val_h = apply_logit(y_val)
    y_test_h = apply_logit(y_test)

    return X_train, y_train, y_train_h, \
            X_val, y_val, y_val_h, n_val, \
            X_test, y_test, y_test_h, n_test

def compute_len_cvg(hi, lo, y_test, n_test):
    """Mean interval length and per-point coverage indicators.

    :return: (mean of hi - lo, float array of length n_test with 1. where
        lo[i] <= y_test[i] <= hi[i] and 0. otherwise).
    """
    mean_width = np.mean(hi - lo)
    covered = np.fromiter((float(lo[i] <= y_test[i] <= hi[i]) for i in range(n_test)),
                          dtype=float, count=n_test)
    return mean_width, covered

def gaussian_kernel_matrix(X,Y, sigma):
    """Gaussian kernel matrix K[i, j] = exp(-||X[i] - Y[j]||^2 / sigma).

    ``sigma`` divides the squared distance directly, so it plays the role of
    2 * bandwidth**2 rather than the bandwidth itself.

    :param X: array of shape (n, d), or (n,) which is treated as (n, 1).
    :param Y: array of shape (m, d), or (m,) which is treated as (m, 1).
    :param sigma: positive scale.
    :return: array of shape (n, m) with entries in [0, 1].
    """
    X = np.asarray(X)
    Y = np.asarray(Y)
    if X.ndim == 1:
        X = X[:, None]
    if Y.ndim == 1:
        Y = Y[:, None]
    # ||x - y||^2 expanded as -2<x, y> + ||x||^2 + ||y||^2. For nearly equal
    # points with large norms this cancels catastrophically and can come out
    # negative (e.g. x = 1e8 vs its next float gives -2), which made kernel
    # values exceed 1. True squared distances are >= 0, so clamp.
    pw_dist = - 2.0 * X.dot(Y.T) + (X**2).sum(axis=1, keepdims=True) \
        + (Y**2).sum(axis=1, keepdims=True).T
    pw_dist = np.maximum(pw_dist, 0.0)
    return np.exp(- pw_dist / sigma)

def kl_M1_estimator(K_YY, K_YX, lambda_n):
    """Kernel convex-program estimate of a KL divergence, with weights over the Y sample.

    Solves with SCS, over alpha >= 0 with one entry per Y point:
        min  1/(2*lambda_n) * (alpha' K_YY alpha - 2 * alpha' (K_YX @ 1/n))
             - (1/p) * sum(log(p * alpha))
    and returns ``(-(1/p) * sum(log(p * alpha_hat)), alpha_hat)``.

    :param K_YY: (p, p) kernel matrix on the p points of the Y sample.
    :param K_YX: (p, n) cross-kernel, rows indexing Y and columns the n points
        of X (run_wainwright_to_est_shift passes ``K_XY.T``).
    :param lambda_n: regularization weight.
    :return: (estimate, alpha_hat). NOTE(review): cvxpy leaves ``alpha.value`` as
        None when SCS finds no solution, and the final np.log would then raise.
    """
    # K_YX is (#Y, #X). It used to be unpacked as ``n, p``, which sized alpha by
    # #X while K_YY is #Y x #Y and the row mean needed #Y columns. That is
    # consistent only when #X == #Y. ``p, n`` follows kl_M2's convention
    # (n = #X, p = #Y), works for any sizes, and is identical when they match.
    p, n = K_YX.shape
    alpha = cp.Variable(p)  # local CVXPY variable; shadows the notebook-level `alpha`

    kl_obj = cp.Minimize(
        1.0 / (2 * lambda_n) * (
            cp.quad_form(alpha, K_YY) - 2 * cp.sum(cp.multiply(alpha,K_YX.dot(np.ones(n)/n)))
        ) - 1.0 / p * cp.sum(cp.log(p * alpha))
    )
    kl_prob = cp.Problem(kl_obj, constraints=[alpha>=0])
    kl_prob.solve(solver=cp.SCS)

    return -1.0 / p * np.sum(np.log(p*alpha.value)), alpha.value

def kl_bregman_div(x,y):
    """Element-wise generalized KL divergence x*log(x/y) - x + y.

    Delegates to ``scipy.special.kl_div``. It uses the same formula for x > 0,
    y > 0, applies the limit 0*log(0) = 0 at x = 0 (giving y), and returns inf
    for other inputs. This matches the ``cvxpy.kl_div`` term that
    kl_M2_estimator optimizes. The naive formula returned NaN at x = 0 and NaN
    for slightly negative solver output, either of which poisons a subsequent
    ``.sum()``.
    """
    from scipy.special import kl_div as _kl_div
    return _kl_div(x, y)

def kl_M2_estimator(K_XX, K_XY, lambda_n):
    """Kernel convex-program estimate of a KL divergence, with weights over the X sample.

    Solves with SCS, over alpha >= 0 with one entry per X point:
        min  1/(2*lambda_n) * (alpha' K_XX alpha - 2 * alpha' (K_XY @ 1/p))
             + sum(kl_div(alpha, 1/n))
    and returns ``(sum(kl_bregman_div(alpha_hat, 1/n)), alpha_hat)``.

    :param K_XX: (n, n) kernel matrix on the n points of X.
    :param K_XY: (n, p) cross-kernel between X (rows) and the p points of Y.
    :param lambda_n: regularization weight.
    :return: (estimate, alpha_hat). NOTE(review): cvxpy leaves ``alpha.value``
        as None when SCS finds no solution; the return line would then fail.
    """
    n, p= K_XY.shape
    # Local CVXPY variable; shadows the notebook-level miscoverage `alpha`.
    alpha = cp.Variable(n)

    # K_XY.dot(np.ones(p)/p): for each X point, its mean kernel value against Y.
    kl_obj = cp.Minimize(
        1.0 / (2 * lambda_n) * (
            cp.quad_form(alpha, K_XX) - 2 * cp.sum(cp.multiply(alpha,K_XY.dot(np.ones(p)/p)))
        ) + cp.sum(cp.kl_div(alpha, np.ones(n)/n))
    )
    # alpha >= 0 is also implied by the kl_div domain.
    kl_prob = cp.Problem(kl_obj, constraints=[alpha>=0])
    kl_prob.solve(solver=cp.SCS)

    # The estimate is the (numpy) generalized KL between alpha_hat and the uniform weights.
    return kl_bregman_div(alpha.value, np.ones(n)/n).sum(), alpha.value

def run_wainwright_to_est_shift(X1, X2, lambda_n, gaus_sigma=1.0):
    """Estimate the divergence between samples X1 and X2 with both kernel estimators.

    Builds Gaussian kernels on (X1, X1), (X2, X2) and (X1, X2), then runs
    kl_M1_estimator and kl_M2_estimator. Their fitted weights are discarded.

    :return: (kl1, kl2), the estimates from kl_M1_estimator and kl_M2_estimator.
    """
    def kernel(A, B):
        return gaussian_kernel_matrix(A, B, sigma=gaus_sigma)

    K_11 = kernel(X1, X1)
    K_22 = kernel(X2, X2)
    K_12 = kernel(X1, X2)

    kl1 = kl_M1_estimator(K_22, K_12.T, lambda_n)[0]
    kl2 = kl_M2_estimator(K_11, K_12, lambda_n)[0]
    return kl1, kl2

def _make_ensemble_model(ensemble_type, lams, hard_idxes):
    """Return an unfitted ensemble model configured for ``ensemble_type``.

    :raises ValueError: for an unsupported ensemble_type (previously this surfaced
        later as an UnboundLocalError).
    """
    if ensemble_type == EnsembleType.Bagged:
        return BaggedOrStackedModel()
    if ensemble_type == EnsembleType.Stacked:
        model = BaggedOrStackedModel()
        model.ensemble_type = EnsembleType.Stacked
        return model
    if ensemble_type == EnsembleType.Multitask:
        model = MultitaskModel(regularization=lams)
        model.ensemble_type = EnsembleType.Multitask
        return model
    if ensemble_type == EnsembleType.PureLocal:
        model = PureLocalStrategy()
        # Route by these two feature columns of the global X_cols_prefix
        # (presumably latitude/longitude, per their names).
        model.X_cols_idxes_for_dist_calc = [X_cols_prefix.index("pclat10"),
                                            X_cols_prefix.index("pclon10")]
        model.hard_idxes = hard_idxes
        return model
    raise ValueError(f"Unsupported ensemble_type: {ensemble_type!r}")


def fit_and_eval_ensemble_model_on_test_data(X_train, y_train_h, hard_idxes,
                                             X_val, y_val_h,
                                             X_test, y_test_h,
                                             datetime_str_f, t_idx_start, start_date, one_wk, num_wks, Xy_tot, y_col,
                                             ensemble_type, lams=None):
    """Fit an ensemble over the hard regions (or one global model) and score it on test data.

    With a non-empty ``hard_idxes`` the model type is chosen by ``ensemble_type``.
    Otherwise a single Model is fit on all training data. Test scores come from
    compute_scores in the inverse-logit space, with ``y_val_h`` as the baseline
    ``yprev`` (presumably the previous week's target).

    datetime_str_f, t_idx_start, start_date, one_wk, Xy_tot and y_col are unused
    and kept for interface compatibility.

    :return: (None, None, AE, AEs, mixture_weights). AE is the array of per-point
        scores, AEs is a list of ``num_wks`` Nones, and mixture_weights is the
        fitted model's ``coeffs_hat`` (None for the single-model branch).
    """
    if hard_idxes:
        model = _make_ensemble_model(ensemble_type, lams, hard_idxes)
        model.fit(y_train_h, X_train, idxes=hard_idxes, y_val_h=y_val_h, X_val=X_val)
        mixture_weights = model.coeffs_hat
    else:
        model = Model()
        model.fit(y_train_h, X_train)
        mixture_weights = None

    yhat_test = model.predict(X_test)
    AE = compute_scores(apply_inverse_logit(yhat_test), apply_inverse_logit(y_test_h),
                        apply_inverse_logit(y_val_h), True)

    AEs = [None]*num_wks

    return None, None, \
           AE, AEs, \
           mixture_weights

def compute_robust_pred_ints_stats(hist, alg_idx, num_wks):
    """Aggregate one algorithm's interval results over the first num_wks - 3 weeks.

    The last three weeks are excluded, presumably because a window needs t+3
    weeks of data (see get_train_val_test_data) -- confirm.

    Each ``hist[t]`` must expose 2-D arrays ``cvgs``, ``lens``, ``his``, ``los``
    and ``true`` that are indexed as ``[alg_idx, :]``.

    :return: (mean coverage, mean length, per-week upper bounds,
              per-week lower bounds clipped at 0, per-week true values).
    """
    weeks = [hist[t_idx] for t_idx in range(num_wks - 3)]

    def rows(attr):
        return [getattr(week, attr)[alg_idx, :] for week in weeks]

    my_cvg = np.mean(np.concatenate(rows("cvgs")))
    my_len = np.mean(np.concatenate(rows("lens")))
    my_his = rows("his")
    my_los = [np.maximum(lo_row, 0) for lo_row in rows("los")]
    true = rows("true")
    return my_cvg, my_len, my_his, my_los, true

def compute_robust_pred_ints(scores_val, alpha, n_val, yhat_test, y_train_h, y_val_h, y_test_h):
    """Symmetric prediction intervals pred +/- q for several conformal variants.

    Targets and predictions are mapped through apply_inverse_logit first. The
    rows of every returned (5, n_test) array are:
      0 - split conformal, quantile at 1 - adjust_alpha(alpha, n_val);
      1 - the same with alpha / 2;
      2 - dro_conformal_quantile_procedure_cvx with radius from kl_M1_estimator;
      3 - the same with radius from kl_M2_estimator
          (both radii come from run_wainwright_to_est_shift on train vs. val targets);
      4 - reserved for Alg. 2 of the robust_cv paper (learnable_direction_quantile),
          currently disabled; it stays at +inf.

    :param scores_val: nonconformity scores on the validation set.
    :param alpha: target miscoverage level.
    :param n_val: ignored; the validation size is recomputed from ``y_val_h``.
    :return: (his, los, true, cvgs, lens, preds), each of shape (5, n_test).
    """
    train_sig = apply_inverse_logit(y_train_h)
    n_train = train_sig.shape[0]
    val_sig = apply_inverse_logit(y_val_h)
    n_val = val_sig.shape[0]  # overrides the argument of the same name
    pred_sig = apply_inverse_logit(yhat_test)
    true_sig = apply_inverse_logit(y_test_h)
    n_test = true_sig.shape[0]

    his, los, true, cvgs, lens, preds = (np.full((5, n_test), np.inf) for _ in range(6))

    def record(row, q):
        """Fill row ``row`` of every output for the interval pred_sig +/- q."""
        upper = pred_sig + q
        lower = pred_sig - q
        width, covered = compute_len_cvg(upper, lower, true_sig, n_test)
        his[row, :] = upper
        los[row, :] = lower
        true[row, :] = true_sig
        cvgs[row, :] = covered
        lens[row, :] = width
        preds[row, :] = pred_sig

    level = adjust_alpha(alpha, n_val)
    record(0, np.quantile(scores_val, 1. - level))
    record(1, np.quantile(scores_val, 1. - adjust_alpha(alpha/2., n_val)))

    rho_k11, rho_k12 = run_wainwright_to_est_shift(train_sig.reshape(-1, 1),
                                                   val_sig.reshape(-1, 1),
                                                   lambda_n=1.0/np.minimum(n_train, n_val))
    for row, rho in ((2, rho_k11), (3, rho_k12)):
        q_dro, _ = dro_conformal_quantile_procedure_cvx(scores_val, kl, level, rho,
                                                        want_bisection=True, verbose=False,
                                                        solver=cp.SCS)
        record(row, q_dro)

    return his, los, true, cvgs, lens, preds

    
def make_df(idxes, intensities, fips,
            wk_idx, start_date, one_wk):
    """Build a covidcast-style county DataFrame holding one week's values.

    :param idxes: indices of entries to keep; all others become NaN. ``None`` keeps all.
    :param intensities: per-county values for the ``value`` column (copied, not modified).
    :param fips: array of county FIPS codes (anything ``int()`` accepts). They are
        emitted as 5-digit zero-padded strings in ``geo_value``.
    :param wk_idx: week offset; ``time_value`` is start_date + wk_idx * one_wk.
    :return: (df, df_data, df_columns).
    """
    selected = np.copy(intensities)
    if idxes is not None:
        keep = set(idxes)  # O(1) membership instead of scanning idxes per entry
        drop = [idx for idx in range(len(selected)) if idx not in keep]
        if drop:
            selected[drop] = np.nan

    num_rows = fips.shape[0]
    time_value = start_date + wk_idx*one_wk
    geo_type = "county"
    data_source = "indicator-combination"
    signal = "confirmed_7dav_incidence_prop"
    null = np.nan*np.ones(num_rows)

    df_columns = ["geo_value", "time_value",
                  "issue", "lag", "missing_value", "missing_stderr", "missing_sample_size",
                  "value",
                  "stderr", "sample_size",
                  "geo_type", "data_source", "signal"]

    df_data = {df_columns[0]:fips, df_columns[1]:time_value,
               df_columns[2]:null, df_columns[3]:null, df_columns[4]:null, df_columns[5]:null, df_columns[6]:null,
               df_columns[7]:selected,
               df_columns[8]:null, df_columns[9]:null,
               df_columns[10]:geo_type, df_columns[11]:data_source, df_columns[12]:signal}

    df = pd.DataFrame(df_data, columns=df_columns)

    # Zero-pad FIPS codes in a single pass. The old per-row .loc writes were
    # slow and wrote strings into a numeric column one cell at a time.
    df["geo_value"] = df["geo_value"].map(lambda code: str(int(code)).zfill(5))

    return df, df_data, df_columns

def plot(data, time_value=None, combine_megacounties=True, plot_type="choropleth",
         vmax_mean=1, vmax_std=0, cbar=True, for_res_stmt=False, my_title=None, **kwargs):
    """Plot one day of a covidcast-format frame (e.g. from make_df) over a US state background.

    :param time_value: day to plot. A falsy value selects the latest ``data.time_value``.
    :param plot_type: only "choropleth" draws data; anything else returns the bare background.
    :param vmax_mean, vmax_std: default colour maximum vmax_mean + 3 * vmax_std,
        used unless ``vmax`` is passed in kwargs.
    :param cbar: draw a colorbar.
    :param for_res_stmt: add a "Michigan" annotation and larger fonts (reads the
        global ``res_stmt_font_size``).
    :param my_title: unused.
    :param kwargs: forwarded to _plot_choro; ``figsize`` defaults to (12, 6).
    :return: (fig, ax).
    """
    data_source, signal, geo_type = covidcast.plotting._detect_metadata(data)  # pylint: disable=W0212
    # The metadata result is not used below; the call is kept so any error it raises still surfaces.
    covidcast.plotting._signal_metadata(data_source, signal, geo_type)  # pylint: disable=W0212

    day_to_plot = time_value if time_value else max(data.time_value)
    on_day = data.time_value == pd.to_datetime(day_to_plot)
    day_data = data.loc[on_day, :]

    kwargs.setdefault("vmax", vmax_mean + 3*vmax_std)
    kwargs.setdefault("figsize", (12, 6))

    fig, ax = covidcast.plotting._plot_background_states(kwargs["figsize"])  # pylint: disable=W0212
    if plot_type != "choropleth":
        return fig, ax

    if for_res_stmt:
        ax.annotate("Michigan", xy=(0.59, 0.68), xytext=(0.65, 0.94),
                    xycoords='figure fraction', textcoords='figure fraction',
                    fontsize=res_stmt_font_size, arrowprops=dict(facecolor='black', width=1.))
    _plot_choro(ax, day_data, combine_megacounties, "vertical", cbar, for_res_stmt, **kwargs)
    return fig, ax

def _plot_choro(ax: matplotlib.axes.Axes,
                data: "geopandas.GeoDataFrame",
                combine_megacounties: bool,
                orientation: str,
                cbar: bool,
                for_res_stmt: bool,
                **kwargs: typing.Any) -> None:
    """Draw a choropleth of ``data["value"]`` on ``ax``, optionally with a colorbar.

    :param ax: Matplotlib axes to plot on (modified in place).
    :param data: frame with a ``value`` column. Geometries are attached via
        covidcast.plotting.get_geo_df.
    :param combine_megacounties: forwarded to get_geo_df.
    :param orientation: colorbar orientation string, e.g. "vertical".
    :param cbar: if true, draw a colorbar in an axis appended to the right of ax.
    :param for_res_stmt: wider colorbar and larger tick labels (reads the global
        ``res_stmt_font_size``).
    :param kwargs: forwarded to ``GeoDataFrame.plot()``. ``vmax`` is required;
        ``vmin`` defaults to 0 and ``cmap`` to "YlOrRd".
    :return: None.
    """
    # The `data` annotation is a string so it is not evaluated at definition time.
    # The previous `geopandas.gpd.GeoDataFrame` is not a geopandas attribute path.
    kwargs["vmin"] = kwargs.get("vmin", 0)
    kwargs["cmap"] = kwargs.get("cmap", "YlOrRd")
    data_w_geo = covidcast.plotting.get_geo_df(data, combine_megacounties=combine_megacounties)
    for shape in covidcast.plotting._project_and_transform(data_w_geo):  # pylint: disable=W0212
        if not shape.empty:
            shape.plot(column="value", ax=ax, **kwargs)
    sm = plt.cm.ScalarMappable(cmap=kwargs["cmap"],
                               norm=plt.Normalize(vmin=kwargs["vmin"], vmax=kwargs["vmax"]))
    # Public replacement for the old private `sm._A = []` workaround.
    sm.set_array([])

    if cbar:
        divider = make_axes_locatable(ax)
        my_size = "5.%" if for_res_stmt else "3.%"
        cax = divider.append_axes("right", size=my_size, pad=0.1)
        my_cbar = plt.colorbar(sm, ticks=np.linspace(kwargs["vmin"], kwargs["vmax"], 8), ax=ax,
                               orientation=orientation, anchor=(0.5, 1.8), pad=0.1, format="%.2f",
                               cax=cax)
        if for_res_stmt:
            my_cbar.ax.tick_params(labelsize=res_stmt_font_size)
        
def make_boxplot(cvgs, lens, algs, alpha, mode, ff, fontsize):
    """Box plot of per-algorithm coverage or interval length, saved as a dated PDF.

    :param cvgs, lens: sequences with one 1-D array per algorithm (stacked with
        np.vstack, so the arrays must have equal length).
    :param algs: tick labels, one per algorithm.
    :param alpha: miscoverage level. A red line at 1 - alpha is drawn in "Coverage" mode.
    :param mode: "Coverage" plots ``cvgs``; any other value plots ``lens``.
    :param ff: output path prefix (concatenated directly with the file name).
    :param fontsize: tick-label size; the title uses fontsize + 2.
    :return: (fig, ax), so callers can customise or close the figure.
    """
    obj_to_plot = np.vstack(cvgs if mode == "Coverage" else lens).T

    fig, ax = plt.subplots(figsize=(12, 5))
    ax.boxplot(obj_to_plot,
               labels=algs,
               showmeans=True,
               showfliers=True)
    ax.tick_params(axis="both", labelsize=fontsize)
    ax.set_title(mode, fontsize=fontsize+2)

    if mode == "Coverage":
        ax.axhline((1. - alpha), c='r', linestyle="-", linewidth=1)
    fig.savefig(ff + "covid_rob_pred_ints_boxplot_{}_{}.pdf".format(mode, datetime.date.today()),
                bbox_inches="tight")
    return fig, ax

In [None]:
class Alg:
    """Base class for hard-region selection strategies, plus per-strategy result slots.

    Subclasses override ``identify_weirdest_points``. This class provides
    ``identify_groups_of_weirdest_points``, which calls it repeatedly to peel the
    points into successive groups.
    """

    def __init__(self):
        # Selected indices: the single weirdest group, and all peeled groups.
        self.sel_idxes = None
        self.all_sel_idxes = None

        self.y_train = None
        self.y_val = None
        self.y_test = None

        # Error summaries; set by code outside this class.
        self.AE = None
        self.AEs = None
        self.MAE = None

        self.props = {}

    def identify_weirdest_points(self, X, weirdness, budget,
                                 remove_idxes=None, round_idx=0, delta=0.01, knns=None):
        """Return indices of a group of weird points; subclasses must override."""
        raise Exception("Not implemented.")

    def identify_groups_of_weirdest_points(self, X, weirdness, budget,
                                           just_identify_the_single_weirdest_group=False, delta=0.01, knns=None):
        """Split the points into groups of weird points, weirdest first.

        Round 0 selects from all n points. Each later round (n // budget rounds
        in total) selects again, passing every previously chosen index as
        ``remove_idxes``. Points never chosen form a final leftover group.

        :param X: array whose first dimension is the number of points n.
        :param budget: group size forwarded to identify_weirdest_points (must be > 0).
        :param just_identify_the_single_weirdest_group: if true, stop after round 0.
        :return: list of index collections, one per group.
        """
        n = X.shape[0]

        sel_idxes = self.identify_weirdest_points(X, weirdness, budget,
                                                  delta=delta, knns=knns)
        all_sel_idxes = [copy.deepcopy(sel_idxes)]
        if just_identify_the_single_weirdest_group:
            return all_sel_idxes

        # Keep a plain list so extending concatenates even if a subclass returns an ndarray
        # (with ndarray `+=` would add element-wise).
        remove_idxes = list(sel_idxes)
        num_weird_rounds = n // budget
        for weird_round_idx in range(1, num_weird_rounds):
            sel_idxes_cur = self.identify_weirdest_points(X, weirdness, budget,
                                                          remove_idxes=remove_idxes, round_idx=weird_round_idx,
                                                          delta=delta, knns=knns)
            all_sel_idxes.append(copy.deepcopy(sel_idxes_cur))
            remove_idxes.extend(sel_idxes_cur)

        if num_weird_rounds * budget < n:
            # Set lookup; the old list-membership scan made this step O(n * |removed|).
            removed = set(remove_idxes)
            all_sel_idxes.append([idx for idx in range(n) if idx not in removed])
        return all_sel_idxes
    
class Balls(Alg):
    """Pick the hard region as a nearest-neighbour "ball": a centre point plus its
    closest neighbours. The centre and size are chosen to maximize an
    (optionally penalized) score."""
    def __init__(self, use_penalty=True):
        super().__init__()
        # If False, the size penalty below is replaced by 0.
        self.use_penalty = use_penalty

    def identify_weirdest_points(self, X, weirdness, budget, remove_idxes=None, round_idx=0,
                                 delta=0.01, knns=None):
        """Return the indices of the best-scoring ball of at most ``budget`` points.

        :param X: array of shape (n, p); rows listed in ``remove_idxes`` are dropped first.
        :param weirdness: per-point scores. Only their standard deviation is used here.
        :param budget: neighbours per ball (also pynndescent's n_neighbors).
        :param remove_idxes: indices into the full X to exclude, or None.
        :param round_idx: unused here.
        :param delta: confidence parameter in the log(1/delta) penalty term.
        :param knns: optional neighbour-index table; built with pynndescent if None.
        :return: list of selected indices into the full X.

        NOTE(review): the rewards read ``norm_ranks``, which is neither a parameter
        nor an attribute. It must be a global defined elsewhere in the notebook
        (presumably normalized ranks of ``weirdness`` -- confirm).
        NOTE(review): with ``remove_idxes`` set and ``knns`` built here, ``knns``
        holds positions in the *reduced* X. ``norm_ranks`` (if it spans all points)
        is then indexed by those positions, so the lookup may be misaligned.
        If a precomputed full-size ``knns`` is passed together with
        ``remove_idxes``, ``rewards`` has one entry per row of ``knns`` while
        ``avg_weirdness`` has only n_reduced rows, and the column assignment
        fails on shape. Verify the intended index spaces.
        """
        n = X.shape[0]
        p = X.shape[1]
        X = np.copy(X)

        if remove_idxes is not None:
            # Order of list(set(...)) is set-iteration order. Selections are mapped
            # back through this same list below, so the mapping is self-consistent.
            keep_idxes = list(set(range(n)) - set(remove_idxes))
            X = X[keep_idxes, :]
            n = X.shape[0]

        if knns is None:
            print("\tEnumerating balls (using pynndescent) ...")
            index_obj = pynndescent.NNDescent(X, n_neighbors=budget)
            # Neighbour-index table of shape (n, budget). Presumably each row lists
            # the point itself first -- TODO confirm for pynndescent.
            knns = index_obj.neighbor_graph[0]

        my_std = np.std(weirdness)
        # avg_weirdness[c, r]: score of the ball around point c made of its first r+1 neighbours.
        avg_weirdness = np.inf*np.ones((n, budget))
        for cur_budget_idx, cur_budget in enumerate(range(budget)):
            # Sum of norm_ranks over the first cur_budget+1 neighbours, scaled by 1/sqrt(ball size).
            rewards = np.sum(norm_ranks[knns][:,0:cur_budget+1], axis=1)*(1./np.sqrt(cur_budget+1))
            # Size-dependent penalty scaled by std(weirdness), or 0. without penalty.
            pen = my_std * np.sqrt(p * np.log(n/(cur_budget+1.) + 1) + np.log(1./delta)) if self.use_penalty \
                else 0.

            avg_weirdness[:, cur_budget_idx] = rewards - pen
        # Best (centre, size - 1) pair over the whole table.
        cstar, rstar = np.unravel_index(np.argmax(avg_weirdness, axis=None), avg_weirdness.shape)
        sel_idxes = knns[cstar,0:rstar+1].tolist()

        if remove_idxes is not None:
            # Map positions in the reduced X back to indices of the full X.
            sel_idxes = [keep_idxes[sel_idx] for sel_idx in sel_idxes]
        return sel_idxes
    
class Naive(Alg):
    """Baseline selector that ranks points purely by their weirdness score."""

    def __init__(self):
        super().__init__()

    def identify_weirdest_points(self, X, weirdness, budget, remove_idxes=None, round_idx=0,
                                 delta=None, knns=None):
        """Return the round_idx-th block of ``budget`` indices, counting down from the weirdest.

        ``remove_idxes``, ``X``, ``delta`` and ``knns`` are ignored; the round
        number alone determines which slice of the sorted order is returned.
        """
        order = np.argsort(weirdness)
        if not round_idx:
            return list(order[-budget:])
        stop = -round_idx * budget
        start = stop - budget
        return list(order[start:stop])
    
class History:
    """Container for everything recorded about one week of the rolling evaluation."""

    def __init__(self):
        # Identification of the week and its counties; set outside this class.
        self.fips = None
        self.date = None
        self.test_hardness = None

        # One result holder per selection strategy.
        self.balls = Balls()
        self.balls_no_penalty = Balls(use_penalty=False)
        self.naive = Naive()
        self.random, self.actual, self.all, self.half = Alg(), Alg(), Alg(), Alg()

        # Prediction-interval results. compute_robust_pred_ints_stats reads these
        # as 2-D arrays indexed [alg_idx, :].
        self.lens = self.cvgs = self.his = self.los = self.true = self.preds = None
        
class Model:
    """Median (0.5-quantile, unpenalized) linear regression.

    ``coeffs_hat`` holds the feature weights followed by the intercept as its last entry.
    """

    def __init__(self):
        self.coeffs_hat = None
        self.training_MAE = None
        self.props = {}

    def fit(self, y_h, X, idxes=None, y_val_h=None, X_val=None):
        """Fit on all rows, or only on the rows in ``idxes``. y_val_h and X_val are unused."""
        regressor = QuantileRegressor(quantile=0.5, solver="highs-ds", alpha=0)
        if idxes is not None:
            X, y_h = X[idxes, :], y_h[idxes]
        regressor.fit(X, y_h)
        self.coeffs_hat = np.append(regressor.coef_, regressor.intercept_)

    def predict(self, X, want_props=True):
        """Return X @ weights + intercept. ``want_props`` is unused."""
        weights, intercept = self.coeffs_hat[:-1], self.coeffs_hat[-1]
        return X @ weights + intercept
    
class BaggedOrStackedModel(Model):
    """Ensemble of per-hard-region median-regression models.

    * Bagged (default): members are averaged with equal weights.
    * Stacked: members are fit on training rows only. Non-negative weights that
      sum to one are then chosen (SCS) to minimize the L1 error of the weighted
      member predictions on the validation set.
    """

    def __init__(self, identifier=None, regularization=None):
        super().__init__()
        self.ensemble_type = EnsembleType.Bagged  # callers may switch this to Stacked before fit()
        self.model_objs = None
        self.coeffs_hat = None  # ensemble weights, one per member

    def fit(self, y_h, X, idxes=None, y_val_h=None, X_val=None):
        """Fit one member per entry of ``idxes``, then the ensemble weights."""
        stacked = self.ensemble_type == EnsembleType.Stacked
        model_preds_val = np.inf*np.ones((X_val.shape[0], len(idxes))) if stacked else None

        self.model_objs = []
        for region, region_idxes in enumerate(idxes):
            print("\tFitting model to hard region " + str(region) + " ...")
            member = Model()
            if stacked:
                member.fit(y_h, X, idxes=region_idxes)
                model_preds_val[:, region] = member.predict(X_val)
            else:
                # NOTE(review): region_idxes select rows of the stacked (train, val)
                # arrays. If they are row indices into X (< len(X)), only training
                # rows are used and X_val never contributes. Confirm whether the
                # matching validation rows (idx + len(X)) were meant to be included.
                member.fit(np.hstack((y_h, y_val_h)), np.vstack((X, X_val)),
                           idxes=region_idxes)
            self.model_objs.append(member)

        print("\tFitting ensembled model to validation data ...")
        if stacked:
            w = cp.Variable(model_preds_val.shape[1])
            objf = cp.sum(cp.abs(y_val_h - model_preds_val @ w))
            prob = cp.Problem(cp.Minimize(objf), [w >= 0, cp.sum(w) == 1])
            prob.solve(solver=cp.SCS)
            self.coeffs_hat = w.value
        else:
            self.coeffs_hat = (1/len(idxes)) * np.ones(len(idxes))

    def predict(self, X, want_props=True):
        """Weighted sum of member predictions, with weights from ``self.coeffs_hat``."""
        yhat = 0
        for idx, member in enumerate(self.model_objs):
            yhat = yhat + self.coeffs_hat[idx] * member.predict(X)
        return yhat
    
class MultitaskModel(Model):
    """L1 ("parent") regression shrunk toward per-hard-region "child" models.

    A child Model is fit on each hard region. For each lam in ``self.lams`` the
    parent coefficients c minimize
        sum |y_h - [X, 1] c|  +  lam * sum_k ||c - c_k||_2
    (solved with SCS). The lam with the lowest median validation score
    (compute_scores) is kept.
    """

    def __init__(self, identifier=None, regularization=None):
        super().__init__()
        self.ensemble_type = EnsembleType.Multitask
        self.model_objs = None
        self.coeffs_hat = None
        self.lams = regularization  # grid of regularization strengths; required by fit()

    def fit(self, y_h, X, idxes=None, y_val_h=None, X_val=None):
        """Fit the children, then tune lam on the validation data and keep the best parent.

        :raises RuntimeError: if SCS returns no solution for the selected lam
            (i.e. for every lam, since failed lams score +inf).
        """
        self.model_objs = []
        for hard_region_idx in range(len(idxes)):
            print("\tFitting 'child' model to hard region " + str(hard_region_idx) + " ...")
            model_obj = Model()
            model_obj.fit(y_h, X, idxes=idxes[hard_region_idx])
            self.model_objs.append(model_obj)

        print("\tFitting 'parent' model to hard region (and tuning regularization strength) ...")
        coeffs = cp.Variable(X.shape[1]+1)  # +1 for bias
        X_pad = np.hstack([X, np.ones((X.shape[0], 1))])
        objf = cp.sum(cp.abs(y_h - X_pad @ coeffs))

        reg = 0
        for model_obj in self.model_objs:
            reg += cp.norm(coeffs - model_obj.coeffs_hat)
        lam_param = cp.Parameter(nonneg=True)
        prob = cp.Problem(cp.Minimize(objf + lam_param*reg))

        coeffs_hats = [None]*len(self.lams)
        errs = np.inf*np.ones(len(self.lams))
        # Sweep this model's own grid. The loop used to iterate over the notebook-level
        # global `lams`, which need not match self.lams in length or values.
        for lam_idx, lam in enumerate(self.lams):
            lam_param.value = lam
            prob.solve(solver=cp.SCS, warm_start=True)
            if coeffs.value is None:
                continue  # no solution for this lam: its error stays +inf
            self.coeffs_hat = np.copy(coeffs.value)
            coeffs_hats[lam_idx] = np.copy(self.coeffs_hat)

            yhat_val = self.predict(X_val)
            errs[lam_idx] = np.median(compute_scores(apply_inverse_logit(yhat_val), apply_inverse_logit(y_val_h),
                                                     apply_inverse_logit(y_h), True))
        # np.argmin returns the first NaN if one is present, so rank NaN errors as worst.
        errs[np.isnan(errs)] = np.inf
        best_lam_idx = np.argmin(errs)
        if coeffs_hats[best_lam_idx] is None:
            raise RuntimeError("SCS found no solution for any regularization strength.")
        print("\tPicked lam idx " + str(best_lam_idx) + " ...")
        self.coeffs_hat = coeffs_hats[best_lam_idx]
        
class PureLocalStrategy(Model):
    """Predict each point with the local model of its nearest hard region.

    `fit` trains one child `Model` per hard region. `predict` assigns each
    row of `X` to the region whose validation points (rows
    `self.hard_idxes[k]` of `self.X_val`) come closest. Distances are
    pairwise distances on the columns `self.X_cols_idxes_for_dist_calc`, or
    on all columns when that is None. The row is then predicted with that
    region's child model.

    NOTE: `fit` does not set `hard_idxes` or `X_cols_idxes_for_dist_calc`;
    callers must assign `hard_idxes` (and optionally the column selection)
    before calling `predict`.
    """
    def __init__(self, identifier=None, regularization=None):
        super().__init__()
        self.ensemble_type = EnsembleType.PureLocal
        self.model_objs = None
        self.coeffs_hat = None
        self.hard_idxes = None  # per-region row indices into X_val; assigned by the caller
        self.X_val = None
        self.X_cols_idxes_for_dist_calc = None  # feature columns used for distances (None = all)
        
    def fit(self, y_h, X, idxes=None, y_val_h=None, X_val=None):
        """Fit one child model per hard region and keep `X_val` for routing.

        `idxes[k]` is forwarded as `idxes=` to the k-th child's `Model.fit`.
        `y_val_h` is accepted for interface parity with the other ensembles
        and is not used.
        """
        self.model_objs = []
        for region_idxes in idxes:
            model_obj = Model()
            model_obj.fit(y_h, X, idxes=region_idxes)
            self.model_objs.append(model_obj)
        self.X_val = X_val
        
    def predict(self, X, want_props=True):
        """Return one prediction per row of `X` from its nearest region's model.

        `want_props` is accepted for interface compatibility and ignored.

        Raises:
            ValueError: if `hard_idxes` or `X_val` has not been set.
        """
        if self.hard_idxes is None or self.X_val is None:
            raise ValueError("PureLocalStrategy.predict requires `hard_idxes` and `X_val` to be set.")
        yhat = np.nan*np.ones(X.shape[0])
        if X.shape[0] == 0:
            return yhat

        dist_cols = (slice(None) if self.X_cols_idxes_for_dist_calc is None
                     else self.X_cols_idxes_for_dist_calc)
        # Test-to-validation distances do not depend on the region, so compute
        # them once instead of once per local model.
        test_2_val_dists = pairwise_distances(X[:, dist_cols], self.X_val[:, dist_cols])

        num_local_models = len(self.model_objs)
        test_pt_2_cluster_dists = np.nan*np.ones((X.shape[0], num_local_models))
        for model_obj_idx in range(num_local_models):
            model_obj_hard_idxes = self.hard_idxes[model_obj_idx]
            # Distance to a region = distance to its closest validation point.
            test_pt_2_cluster_dists[:, model_obj_idx] = np.min(
                test_2_val_dists[:, model_obj_hard_idxes], axis=1)
        closest_model_obj_idxes = np.argmin(test_pt_2_cluster_dists, axis=1)

        for row_idx in range(X.shape[0]):
            x = X[row_idx, :]
            yhat[row_idx] = self.model_objs[closest_model_obj_idxes[row_idx]].predict(x)
        return yhat

In [None]:
# Rolling weekly loop. For each week t_idx:
#   - load the train/val/test split;
#   - fit a model on the training split and score validation/test points;
#   - optionally build robust prediction intervals;
#   - select "weird" (hard) test points three ways.
# Results are written into `hist` at offsets t_idx .. t_idx+3 (see
# "Update loop variables" at the bottom of the loop body).
hist = [History() for i in range(num_wks)]
for t_idx in range(num_wks):
    t = start_date + t_idx*one_wk
    t_str = t.strftime(datetime_str_f)
    # Region identifiers (the geo_value column of Xy_tot), saved into hist below.
    fips_test = Xy_tot["geo_value"].to_numpy()
    # Split for week t (get_train_val_test_data is defined elsewhere in the notebook).
    X_train, y_train, y_train_h, X_val, y_val, y_val_h, n_val, X_test, y_test, y_test_h, n_test \
        = get_train_val_test_data(t, t_idx, one_wk, num_wks, datetime_str_f, Xy_tot, X_cols_prefix, y_col)

    n_test_range = range(n_test)  # NOTE(review): not used anywhere else in this cell
    if t_idx == 0:
        # The budget and neighbor graph are computed only on the first week
        # and reused for every later week.
        B = B_func(n_test)
        print("\tFixing budget = " + str(B))

        print("\tEnumerating balls (once, using pynndescent) ...")
        # Approximate B-nearest-neighbor graph over the (pclon10, pclat10)
        # columns of Xy_tot; neighbor_graph[0] holds the neighbor indices.
        index_obj = pynndescent.NNDescent(Xy_tot[["pclon10", "pclat10"]].to_numpy(), n_neighbors=B) 
        knns = index_obj.neighbor_graph[0]            

    # Fit model(s).
    print("\tFitting model ...")
    model_obj = Model()
    model_obj.fit(y_train_h, X_train)

    # Compute (normalized) ranks.
    # Scores come from compute_scores on inverse-logit values (helper defined
    # elsewhere; the meaning of a=1e-2 is not visible in this cell).
    print("\tComputing (normalized) ranks ...")
    yhat_val = model_obj.predict(X_val)
    scores_val = compute_scores(apply_inverse_logit(yhat_val), apply_inverse_logit(y_val_h),
                                apply_inverse_logit(y_train_h), True, a=1e-2)

    yhat_test = model_obj.predict(X_test)
    scores_test = compute_scores(apply_inverse_logit(yhat_test), apply_inverse_logit(y_test_h),
                                 apply_inverse_logit(y_train_h), True, a=1e-2)
    if want_robust_intervals:
        print("\tComputing robust intervals ...")
        # Validation scores computed with None in place of the training targets
        # feed the robust-interval computation.
        scores_val_unscaled = compute_scores(apply_inverse_logit(yhat_val), apply_inverse_logit(y_val_h),
                                             None, True, a=1e-2)
        his, los, true, cvgs, lens, preds = compute_robust_pred_ints(
            scores_val_unscaled, alpha, n_val, yhat_test, y_train_h, y_val_h, y_test_h)
        hist[t_idx].his = his
        hist[t_idx].los = los
        hist[t_idx].true = true
        hist[t_idx].cvgs = cvgs
        hist[t_idx].lens = lens
        hist[t_idx].preds = preds

    # Normalized rank of each test score among the validation scores plus
    # itself, divided by n_val + 1.
    norm_ranks = [sp.stats.rankdata(np.append(scores_val, scores_test[idx]))[-1] / (n_val+1) \
                      for idx in range(n_test)]
    norm_ranks = np.array(norm_ranks)

    # Identify weird (test) points (two different ways):
    # NOTE(review): the selectors are called on hist[t_idx]'s objects, but their
    # results are stored on hist[t_idx+3] below — confirm this is intended.
    print("\tIdentifying weird points ...")
    # 1a,b) Use balls (w/ and w/o penalty) to pick weird points.
    all_sel_idxes = hist[t_idx].balls.identify_groups_of_weirdest_points(X_test, norm_ranks, B,
                                                                         just_identify_the_single_weirdest_group=True,
                                                                         knns=knns)

    all_sel_idxes_balls_no_penalty = hist[t_idx].balls_no_penalty.identify_groups_of_weirdest_points(X_test, norm_ranks, B,
                                                                         just_identify_the_single_weirdest_group=True,
                                                                         knns=knns)

    # 2) Just pick the weirdest points (i.e., a "naive strategy").
    all_sel_idxes_baseline = hist[t_idx].naive.identify_groups_of_weirdest_points(X_test, norm_ranks, B,
                                                                                  just_identify_the_single_weirdest_group=True)

    # Update loop variables.
    # Layout: hist[t_idx] gets this week's date and original values; hist[t_idx+1],
    # [t_idx+2] and [t_idx+3] get the train, val and test targets respectively,
    # and hist[t_idx+3] also gets the test hardness and all selections.
    print("\tSaving epoch ...\n")
    hist[t_idx].fips = fips_test
    
    hist[t_idx].date = t
    hist[t_idx+1].date = t+one_wk
    hist[t_idx+2].date = t+2*one_wk
    hist[t_idx+3].date = t+3*one_wk
    
    hist[t_idx+3].test_hardness = norm_ranks
    
    # The first week reads the un-suffixed "_original" column; later weeks read
    # the date-suffixed one.
    if t_idx == 0:
        hist[t_idx].actual = Xy_tot[y_col + "_original"].to_numpy()
    else:
        hist[t_idx].actual = Xy_tot[y_col + "_original" + "_" + t_str].to_numpy()
    hist[t_idx+1].actual = y_train
    hist[t_idx+2].actual = y_val
    hist[t_idx+3].actual = y_test
    
    hist[t_idx+3].balls.sel_idxes = all_sel_idxes[0]
    hist[t_idx+3].balls.all_sel_idxes = all_sel_idxes
    
    hist[t_idx+3].balls_no_penalty.sel_idxes = all_sel_idxes_balls_no_penalty[0]
    hist[t_idx+3].balls_no_penalty.all_sel_idxes = all_sel_idxes_balls_no_penalty
    
    hist[t_idx+3].naive.sel_idxes = all_sel_idxes_baseline[0]
    hist[t_idx+3].naive.all_sel_idxes = all_sel_idxes_baseline
    
    # Stop before the next iteration would write hist[t_idx+1+3] (out of range).
    # Fill fips for the trailing entries that no later iteration will reach.
    if (t_idx+1 + 3) >= num_wks:
        hist[t_idx+1 + 0].fips = fips_test
        hist[t_idx+1 + 1].fips = fips_test
        hist[t_idx+1 + 2].fips = fips_test
        break
        
print("All done.")

# Make plots

In [None]:
# Shared color scale for the interval maps: mean/std of the per-week "true"
# values reported by compute_robust_pred_ints_stats for algorithm 0.
# The stats call does not depend on the week, so it is made once rather than
# once per week (assumes it does not mutate `hist` — TODO confirm).
_, _, _, _, true = compute_robust_pred_ints_stats(hist, 0, num_wks)
my_intensities = [true[wk_idx] for wk_idx in range(num_wks-3)]

all_intensities = np.concatenate(my_intensities)
vmax_mean_frac = np.mean(all_intensities)
vmax_std_frac = np.std(all_intensities)

In [None]:
# Index of the algorithm whose per-week interval maps are drawn, and the number
# of algorithms iterated over when reading interval statistics.
special_alg_idx, num_algs = 0, 5

In [None]:
# Print coverage/length stats for every algorithm, then save his/los/true maps
# (with and without a colorbar) for each week of the highlighted algorithm.
for alg_idx in range(num_algs):
    my_cvg, my_len, my_his, my_los, true = compute_robust_pred_ints_stats(hist, alg_idx, num_wks)
    print(str(alg_idx) + ":")
    print(my_cvg, my_len)

    # Only the highlighted algorithm gets per-week maps.
    if alg_idx != special_alg_idx:
        continue

    # Output-name suffix -> per-week arrays to plot (insertion order = his, los, true).
    per_wk_intensities = {"his": my_his, "los": my_los, "true": true}
    for hi_lo_true, wk_arrays in per_wk_intensities.items():
        for wk_idx in range(num_wks-3):
            intensities = wk_arrays[wk_idx]
            all_idxes = list(range(len(hist[wk_idx].actual)))
            fips = hist[wk_idx].fips
            cur_df, _, _ = make_df(all_idxes, intensities, fips,
                                   wk_idx, start_date, one_wk)

            fn = "covid_rob_pred_ints_{}_{}_alg_idx_{}".format(
                (start_date + wk_idx*one_wk).strftime(datetime_str_f), hi_lo_true, alg_idx)
            for suffix, cbar in [("", False), ("_cbar", True)]:
                fig, _ = plot(cur_df, vmax_mean=vmax_mean_frac, vmax_std=vmax_std_frac, cbar=cbar)
                fig.savefig(ff + fn + suffix + ".pdf", bbox_inches="tight")
                # Release the figure; this loop would otherwise keep every map open.
                plt.close(fig)

In [None]:
# Algorithm index left out of the coverage/length summary further below.
skip_alg_idx = 2

# Display labels for the interval-construction methods.
algs = [
    "SC",
    r"SC-$\alpha/2$",
    "KL-M2",
    "KL-R",
]

In [None]:
# Shared plot styling and color-scale settings.
fontsize = 16
linewidth = 2
linestyles = ["-",
              ":",
              "-",
              "-",
              "-",
              "-",
              "-"]
colors = ["mediumvioletred",
          "mediumvioletred",
          "blue",
          "deepskyblue",
          "chartreuse",
          "gold"]
# Labels for the region-selection methods. Kept under their own name so they
# do not overwrite `algs`, the interval-method labels passed to make_boxplot
# together with one series per non-skipped algorithm.
region_sel_algs = ["Algorithm 2",
                   "Algorithm 2, unpenalized",
                   "Hardest points",
                   "Uniformly at random",
                   "Pure global"]

# Color scale for the "true values" maps, from the original target column of df.
vmax_mean = df[y_col + "_original"].mean(skipna=True)
vmax_std = df[y_col + "_original"].std(skipna=True)

res_stmt_font_size = 30

In [None]:
# Mean coverage and mean interval length per week for every algorithm except
# skip_alg_idx, then one box plot per metric.
kept_alg_idxes = [a_idx for a_idx in range(num_algs) if a_idx != skip_alg_idx]
n_summary_wks = num_wks - 3

my_cvgs = [[np.mean(hist[wk].cvgs[a_idx, :]) for wk in range(n_summary_wks)]
           for a_idx in kept_alg_idxes]
my_lens = [[np.mean(hist[wk].lens[a_idx, :]) for wk in range(n_summary_wks)]
           for a_idx in kept_alg_idxes]

for plot_mode in ("Coverage", "Length"):
    make_boxplot(my_cvgs, my_lens, algs, alpha=alpha, mode=plot_mode, ff=ff, fontsize=fontsize)

In [None]:
# Save maps of the recorded "actual" values for each week from index 3 on,
# once without and once with a colorbar, on the shared vmax_mean/vmax_std scale.
for wk_idx in range(3, num_wks):
    all_idxes = list(range(len(hist[wk_idx].actual)))
    intensities = hist[wk_idx].actual  # alternative: hist[wk_idx].test_hardness
    fips = hist[wk_idx].fips
    cur_df, _, _ = make_df(all_idxes, intensities, fips,
                           wk_idx, start_date, one_wk)

    date_str = (start_date + wk_idx*one_wk).strftime(datetime_str_f)
    # Produces covid_regions_true_<date>.pdf and covid_regions_true_<date>_cbar.pdf.
    for suffix, cbar in [("", False), ("_cbar", True)]:
        fp = ff + "covid_regions_true_{}{}.pdf".format(date_str, suffix)
        fig, _ = plot(cur_df, vmax_mean=vmax_mean, vmax_std=vmax_std, cbar=cbar)
        fig.savefig(fp, bbox_inches="tight")
        # Release the figure; this loop would otherwise keep every map open.
        plt.close(fig)