In [1]:
import pickle as pkl
import pandas as pd
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import os

# Plot styling: a fixed 8-colour palette, large fonts and thick lines.
# (`sns.set(style="white")` already applies the white style, so no separate
# `sns.set_style` call is needed.)
palette = [
    "#1b4079", "#C6DDF0", "#048A81", "#B9E28C",
    "#8C2155", "#AF7595", "#E6480F", "#FA9500",
]
sns.set(
    palette=palette,
    font_scale=2.0,
    style="white",
    rc={"lines.linewidth": 4.0},
)

In [2]:
# Load station names, station lon/lat and the per-station series.
# NOTE(review): pickle.load can execute arbitrary code -- only load
# wind_data.p from a trusted source.
with open("./wind_data.p", 'rb') as wind_fh:  # closes the handle after loading
    stn_names, stn_lonlat, full_data = pkl.load(wind_fh)

In [4]:
def ECDF(sample_pxs, true_px):
    """Empirical CDF of ``true_px`` under the samples stacked along dim 0.

    Args:
        sample_pxs: tensor whose first dimension indexes samples.
        true_px: tensor broadcastable against ``sample_pxs``.

    Returns:
        Tensor reduced over dim 0: for each remaining position, the fraction of
        samples strictly less than ``true_px``.
    """
    n_draws = sample_pxs.shape[0]
    below = sample_pxs < true_px
    return torch.sum(below, dim=0) / n_draws
 
def Calibration(pcts, percentile=0.95):
    """Fraction of ECDF values strictly below a nominal level.

    Args:
        pcts: numpy array of ECDF values.
        percentile: nominal level to compare against.

    Returns:
        ``count(pcts < percentile) / pcts.shape[0]`` as a Python float.
    """
    n_total = pcts.shape[0]
    n_below = int(np.count_nonzero(pcts < percentile))
    return n_below / n_total

def GetCalibration(model, ema=True, k=100, theta=0.0, horizon=np.arange(75,100),
                   logger=None, exp=True):
    """Append calibration-curve rows for one saved-forecast configuration.

    For every station key in ``stn_names`` and every evaluation start index
    ``test_idx``, load ``./saved-outputs/stn<station>/<fname>`` (missing files
    are skipped). Keep the prediction columns in ``horizon``. Then compute, per
    step, the ECDF value of the observed series
    (``full_data[station][test_idx:] + 1`` indexed by ``horizon``) under the
    prediction samples. For each nominal level 0.05, 0.10, ..., 0.95, the
    fraction of ECDF values strictly below that level is appended to ``logger``.

    Args:
        model: filename prefix. 'volt', 'matern' and 'sm' get extra filename
            tags. Any other prefix (e.g. 'lstm') is loaded as
            ``<model>_<test_idx>.pt``.
        ema: select the EMA ('volt') / EWMA ('matern', 'sm') file variant.
        k: window tag used in the filename when ``ema`` is True. For 'sm'
            without EMA the filename tag is always 'constant200'.
        theta: theta tag used in 'volt' filenames.
        horizon: index array applied to the prediction columns and to the
            observed series starting at ``test_idx``.
        logger: list to append rows to. A new list is created when None.
        exp: exponentiate the loaded samples before comparing them with the
            observations.

    Returns:
        ``logger`` with rows ``[calibration, level, model, theta, ema, k]``
        appended, one per level. If no saved-output file exists for this
        configuration, ``logger`` is returned unchanged.
    """
    if logger is None:  # never share a mutable default list between calls
        logger = []

    ntime = full_data[0].shape[0]
    ntrain = 400       # first evaluation start index
    n_test_times = 20  # target number of evaluation start indices per station
    ntest = 100        # steps kept free at the end of the series
    test_idxs = torch.arange(ntrain, ntime-ntest,
                             int((ntime-ntest-ntrain)/n_test_times))

    ecdf_chunks = []  # collected, then concatenated once (avoids O(n^2) cat)
    for stn_idx in stn_names.keys():
        fpath = "./saved-outputs/stn" + str(stn_idx) + "/"
        for test_idx in test_idxs:
            fname = model + "_"
            if model == 'volt':
                if ema:
                    fname += "ema" + str(k) + "_"
                fname += "theta" + str(theta) + "_"
            elif model == 'matern':
                if ema:
                    fname += 'ewma' + str(k) + "_"
            elif model == 'sm':
                # Same naming scheme as GetNLL uses for 'sm' outputs.
                fname += ("ewma" + str(k) + "_") if ema else "constant200_"
            fname += str(test_idx.item()) + ".pt"

            full_path = fpath + fname
            if not os.path.exists(full_path):
                continue

            if model == 'volt':
                preds = torch.load(full_path)[0]
            elif model == 'matern':
                preds = torch.load(full_path).cpu()
            else:
                preds = torch.load(full_path)

            preds = preds[:, horizon]
            test_y = torch.tensor(full_data[stn_idx][test_idx:]) + 1
            if exp:
                preds = preds.exp()
            ecdf_chunks.append(ECDF(preds, test_y[horizon]))

    if not ecdf_chunks:
        # No files for this configuration: log nothing (as GetNLL does)
        # instead of dividing by zero inside Calibration.
        return logger

    pcts = torch.cat(ecdf_chunks).flatten().numpy()
    for pct in np.linspace(0.05, 0.95, 19):
        clb = Calibration(pcts, pct)
        logger.append([clb, np.round(pct, 2), model, theta, ema, k])

    return logger

In [45]:
# Calibration of each model/variant compared in the main results table.
calib_configs = [
    dict(model='volt', ema=False, k=200, theta=0.025),
    dict(model='lstm', ema=False, theta=0.0),
    dict(model='matern', ema=False, theta=0.0),
    dict(model='volt', ema=True, k=200, theta=0.025),
    dict(model='matern', ema=True, k=200, theta=0.0),
]
logger = []
for cfg in calib_configs:
    logger = GetCalibration(exp=True, logger=logger, **cfg)

In [46]:
# Name the columns at construction so an empty logger still yields a frame
# with the expected schema (assigning .columns afterwards raises in that case).
df = pd.DataFrame(logger, columns=['Calibration', 'Percentile', 'Type', 'theta', 'ema', 'k'])

In [47]:
# Persist the main calibration table.
pd.to_pickle(df, "wind_calib_df.pkl")

## Theta Sensitivity

In [40]:
# Calibration of volt (no EMA) across theta values.
logger = []
for theta_value in (0.0, 0.01, 0.025, 0.05):
    logger = GetCalibration('volt', theta=theta_value,
                            exp=True, logger=logger, ema=False)


In [41]:
# Name the columns at construction so an empty logger still yields a frame
# with the expected schema (assigning .columns afterwards raises in that case).
df = pd.DataFrame(logger, columns=['Calibration', 'Percentile', 'Type', 'theta', 'ema', 'k'])

In [43]:
# Persist the theta-sensitivity calibration table.
pd.to_pickle(df, "theta_calib_df.pkl")

## EMA Calibration

In [5]:
# Calibration of volt (theta=0.025) across EMA window sizes, plus no-EMA.
ema_sweep = [(True, 50), (True, 100), (True, 200), (True, 400), (False, 400)]
logger = []
for use_ema, window in ema_sweep:
    logger = GetCalibration('volt', ema=use_ema, theta=0.025, k=window,
                            exp=True, logger=logger)

In [6]:
# Name the columns at construction so an empty logger still yields a frame
# with the expected schema (assigning .columns afterwards raises in that case).
df = pd.DataFrame(logger, columns=['Calibration', 'Percentile', 'Type', 'theta', 'ema', 'k'])

In [7]:
# Persist the EMA-window calibration table.
pd.to_pickle(df, "./ema_calibration.pkl")

## Negative log-likelihood (NLL)

In [34]:
def GetNLL(model, ema=True, k=100, theta=0.0, horizon=np.arange(75,100),
                   logger=None, exp=True):
    """Append a Gaussian negative-log-likelihood summary for one configuration.

    For every station key in ``stn_names`` and every evaluation start index
    ``test_idx``, load ``./saved-outputs/stn<station>/<fname>`` (missing files
    are skipped) and keep the prediction columns in ``horizon``. Then score the
    observed series (``full_data[station][test_idx:] + 1`` indexed by
    ``horizon``) under a Normal whose loc/scale are the per-step sample mean/std
    of the predictions (taken over dim 0). Steps whose shifted observation is
    not positive are dropped. A file's log-probs are kept only if the absolute
    mean log-prob is below 500 (NaN means are therefore discarded too).

    Args:
        model: filename prefix. 'volt', 'matern' and 'sm' get extra filename
            tags. Any other prefix (e.g. 'lstm') is loaded as
            ``<model>_<test_idx>.pt``.
        ema: select the EMA ('volt') / EWMA ('matern', 'sm') file variant.
        k: window tag used in the filename when ``ema`` is True. For 'sm'
            without EMA the filename tag is always 'constant200'.
        theta: theta tag used in 'volt' filenames (not recorded in the row).
        horizon: index array applied to the prediction columns and to the
            observed series starting at ``test_idx``.
        logger: list to append to. A new list is created when None.
        exp: exponentiate the loaded samples before scoring.

    Returns:
        ``logger`` with one row ``[total NLL, mean NLL, std of log-prob, model,
        ema, k]`` appended. ``logger`` is unchanged if no file contributed any
        log-probabilities.
    """
    if logger is None:  # never share a mutable default list between calls
        logger = []

    ntime = full_data[0].shape[0]
    ntrain = 400       # first evaluation start index
    n_test_times = 20  # target number of evaluation start indices per station
    ntest = 100        # steps kept free at the end of the series
    test_idxs = torch.arange(ntrain, ntime-ntest,
                             int((ntime-ntest-ntrain)/n_test_times))

    logprob_chunks = []  # collected, then concatenated once (avoids O(n^2) cat)
    for stn_idx in stn_names.keys():
        fpath = "./saved-outputs/stn" + str(stn_idx) + "/"
        for test_idx in test_idxs:
            fname = model + "_"
            if model == 'volt':
                if ema:
                    fname += "ema" + str(k) + "_"
                fname += "theta" + str(theta) + "_"
            elif model == 'matern':
                if ema:
                    fname += 'ewma' + str(k) + '_'
            elif model == 'sm':
                fname += ("ewma" + str(k) + "_") if ema else "constant200_"
            fname += str(test_idx.item()) + ".pt"

            full_path = fpath + fname
            if not os.path.exists(full_path):
                continue

            if model == 'volt':
                preds = torch.load(full_path)[0]
            elif model == 'matern':
                preds = torch.load(full_path).cpu()
            else:
                preds = torch.load(full_path)

            preds = preds[:, horizon]
            test_y = (torch.tensor(full_data[stn_idx][test_idx:]) + 1)[horizon]
            if exp:
                preds = preds.exp()

            # Clear out broken data: drop steps whose shifted observation is <= 0.
            keepers = torch.where(test_y > 0)[0]
            preds = preds[:, keepers]
            test_y = test_y[keepers]

            # NOTE(review): with torch.distributions argument validation
            # enabled (the default), a zero or NaN sample std raises ValueError
            # here -- TODO decide whether such files should be skipped instead.
            log_probs = torch.distributions.Normal(preds.mean(0), preds.std(0)).log_prob(test_y)
            if log_probs.mean().abs() < 500:  # ignore clearly broken cases
                logprob_chunks.append(log_probs)

    if logprob_chunks:
        # Newest chunk first, so float32 reductions run in the same order as the
        # prepend-style accumulation that produced the saved tables.
        all_log_probs = torch.cat(logprob_chunks[::-1])
        logger.append([-all_log_probs.sum().item(), -all_log_probs.mean().item(),
                       all_log_probs.std().item(), model, ema, k])

    return logger

In [16]:
# NLL of each model/variant compared in the main results table.
# 'lstm' has no EMA variant (ema/k do not affect its filename). Pass ema=False
# explicitly, as the calibration cell does, so its logged row does not carry
# GetNLL's default ema=True.
nll_configs = [
    dict(model='volt', ema=False, k=200, theta=0.025),
    dict(model='lstm', ema=False, theta=0.0),
    dict(model='matern', ema=False, theta=0.0),
    dict(model='volt', ema=True, k=200, theta=0.025),
    dict(model='matern', ema=True, k=200, theta=0.0),
]
logger = []
for cfg in nll_configs:
    logger = GetNLL(exp=True, logger=logger, **cfg)

./saved-outputs/stn0/volt_theta0.025_400.pt
./saved-outputs/stn0/volt_theta0.025_5631.pt
./saved-outputs/stn0/volt_theta0.025_10862.pt
./saved-outputs/stn0/volt_theta0.025_16093.pt
./saved-outputs/stn0/volt_theta0.025_21324.pt
./saved-outputs/stn0/volt_theta0.025_26555.pt
./saved-outputs/stn0/volt_theta0.025_31786.pt
./saved-outputs/stn0/volt_theta0.025_37017.pt
./saved-outputs/stn0/volt_theta0.025_42248.pt
./saved-outputs/stn0/volt_theta0.025_47479.pt
./saved-outputs/stn0/volt_theta0.025_52710.pt
./saved-outputs/stn0/volt_theta0.025_57941.pt
./saved-outputs/stn0/volt_theta0.025_63172.pt
./saved-outputs/stn0/volt_theta0.025_68403.pt
./saved-outputs/stn0/volt_theta0.025_73634.pt
./saved-outputs/stn0/volt_theta0.025_78865.pt
./saved-outputs/stn0/volt_theta0.025_84096.pt
./saved-outputs/stn0/volt_theta0.025_89327.pt
./saved-outputs/stn0/volt_theta0.025_94558.pt
./saved-outputs/stn0/volt_theta0.025_99789.pt
./saved-outputs/stn1/volt_theta0.025_400.pt
./saved-outputs/stn1/volt_theta0.025_56

In [35]:
# NLL of the 'sm' model with and without EWMA (k=200).
logger = []
for use_ema in (True, False):
    logger = GetNLL('sm', theta=0.0, ema=use_ema,
                    exp=True, logger=logger, k=200)

In [36]:
# Name the columns at construction so an empty logger still yields a frame
# with the expected schema (assigning .columns afterwards raises in that case).
df = pd.DataFrame(logger, columns=["NLL", "Mean_NLL", "Std_NLL", "kernel", "ema", "k"])

In [37]:
# Display the NLL summary: one row per GetNLL call that found saved outputs.
df

Unnamed: 0,NLL,Mean_NLL,Std_NLL,kernel,ema,k
0,2532691.0,110.069138,163.039618,sm,True,200
1,1899801.0,70.144774,134.305314,sm,False,200
