In [21]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import norm
from tqdm import tqdm
import sys
import matplotlib as mpl

#insert path
sys.path.insert(0, '../methods/')

sys.modules.pop('generate_syn_data', None)
from generate_syn_data import *

sys.modules.pop('ARWQE', None)
from ARWQE import *

sys.modules.pop('plots', None)
from plots import *

In [44]:
# ELEC2 data set
# downloaded from https://www.kaggle.com/yashsharan/the-elec2-dataset
data = pd.read_csv('electricity-normalized.csv')
col_names = data.columns
data = data.to_numpy()

# remove the first stretch of time where 'transfer' does not vary
data = data[17760:]

# set up variables for the task (predicting 'transfer')
covariate_col = ['nswprice', 'nswdemand', 'vicprice', 'vicdemand']
response_col = 'transfer'
# keep data points for 9:00am - 12:00pm
keep_rows = np.where((data[:,2]>data[17,2])&(data[:,2]<data[24,2]))[0]

X = data[keep_rows][:,np.where([t in covariate_col for t in col_names])[0]]
Y = data[keep_rows][:,np.where(col_names==response_col)[0]].flatten()
X = X.astype('float64')
Y = Y.astype('float64')

N = X.shape[0]; p = X.shape[1]; alpha = 0.1
train_lag = 100 # start predicting after train_lag many observations

# randomly permuted data 
perm = np.random.choice(N,N,replace=False)
X_perm = X[perm]
Y_perm = Y[perm]

# weights and tags (parameters for new methods)
rho = 0.99; rho_LS = 0.99

In [53]:
alpha =0.1; delta = 0.1; gamma=1

#NOTE!!: change num periods later
num_periods = X.shape[0]

np.random.seed(6)

B_arr = np.ones(num_periods)
B_arr_starts = np.arange(num_periods)
B_arr_ends = np.cumsum(B_arr) - 1

#split and use odd indices for training
inds_odd = np.arange(1,int(np.ceil(num_periods/2)*2),2)
inds_even = np.arange(0,int(np.floor(num_periods/2)*2),2)

X_tr = X[inds_odd]; Y_tr = Y[inds_odd]
X_val = X[inds_even]; Y_val = Y[inds_even] 

B_arr_tr = B_arr[inds_odd].astype(int); B_arr_val = B_arr[inds_even].astype(int)
B_tr_starts = B_arr_starts[inds_odd].astype(int); B_val_starts = B_arr_starts[inds_even].astype(int)
B_tr_ends = B_arr_ends[inds_odd].astype(int); B_val_ends = B_arr_ends[inds_even].astype(int)
print(B_tr_ends)

[   1    3    5 ... 3439 3441 3443]


In [54]:
X.shape, Y.shape

((3444, 4), (3444,))

In [77]:
num_trials = 1
fixed_windows = [1, 16, 256, 4096]
cdf_dict = create_empty_dict(fixed_windows, num_trials)
interval_dict = create_empty_dict(fixed_windows, num_trials)

seeds = np.arange(num_trials) + 2024

for (trial, seed) in enumerate(seeds):
    
    np.random.seed(seed)

    k_hat_all = []

    for t in tqdm(range(len(B_arr_tr)-1)):
        
        X_tr_t = X[:B_tr_ends[t]+1]; Y_tr_t = Y[:B_tr_ends[t]+1]
        X_val_t = X[:B_val_ends[t]+1]; Y_val_t = Y[:B_val_ends[t]+1]

        reg, S_t = fit_and_get_scores(X_tr_t, Y_tr_t, X_val_t, Y_val_t)

        B_val_t = B_arr_val[:t+1]

        khat, qt_khat, qtk_all = ARWQE(S_t, B_val_t, alpha, delta, gamma)
        k_hat_all.append(khat)

        #compute prediction set on the next training point
        X_test = X[B_tr_ends[t]+1]; Y_test = Y[B_tr_ends[t]+1]
        y_hat = reg.predict(X_test.reshape(1,-1))
        coverage_ARW = check_coverage(y_hat, qt_khat, Y_test)
        
        cdf_dict['ARW'][trial].append(coverage_ARW[0])
        interval_dict['ARW'][trial].append(qt_khat)

        #baseline: take quantile of fixed k
        for ik, k in enumerate(fixed_windows):
            log2k = int(np.log2(k))
            qtk = qtk_all[min(log2k, len(qtk_all)-1)]
            coverage_k = check_coverage(y_hat, qtk, Y_test)
            cdf_dict[f'Val_{k}'][trial].append(coverage_k[0])
            interval_dict[f'Val_{k}'][trial].append(qtk)

    #plot k_hat_all
    #plt.plot(k_hat_all)

100%|██████████| 1721/1721 [00:03<00:00, 561.52it/s]


In [78]:
arw_cov = np.array(cdf_dict['ARW'][0])
print(arw_cov.mean())

0.8413712957582801


In [79]:
#process results in cdf_dict
methods = ['ARW'] + [f'Val_{k}' for k in fixed_windows]
for method in methods:
    print(method, np.array(cdf_dict[method][0]).mean())

ARW 0.8413712957582801
Val_1 0.8762347472399767
Val_16 0.8773968622893666
Val_256 0.8617083091226031
Val_4096 0.8413712957582801
