In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from model.dropout_vi import Dropout_VI
from testing import tester,display_stats,avg_stats,save_stats
import matplotlib.pyplot as plt
from sklearn.model_selection import KFold
from tqdm import tqdm as pbar
from torch.distributions.normal import Normal
from trainers.dropout_vi_trainer import Dropout_VI_trainer
from utils.data_utils import data_helper,uci_helper



# Pick the compute device once for the whole notebook: the GPU when CUDA is
# usable, otherwise the CPU. The chosen device is echoed for the reader.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

cuda


In [3]:
def train_and_test(data,input_dim,num_units=128,batch_size = 512,qr_reg = False, 
              K=0.0,epochs = 64 ,lr =1e-2, n_splits=5,spacing = 64):
    """
    Train and evaluate a Dropout_VI network with k-fold cross-validation.

    For each of the ``n_splits`` folds a new network, Adam optimizer and
    trainer are built. The network is trained on the fold's training loader,
    ``trainer.fit_isotonic`` is run on that same training loader, and the
    result is scored on the fold's test loader with ``tester``.

    Args:
        data: dataset handed to ``KFold.split`` and ``data_helper``.
        input_dim: number of input features of the network.
        num_units: forwarded to ``Dropout_VI`` as ``num_units``.
        batch_size: forwarded to the trainer.
        qr_reg: forwarded unchanged to both the trainer and the tester.
        K: forwarded unchanged to both the trainer and the tester.
        epochs: number of epochs passed to ``trainer.train``.
        lr: learning rate of the Adam optimizer.
        n_splits: number of cross-validation folds.
        spacing: forwarded to the trainer.

    Returns:
        A tuple ``(dataset_stats, max_fx_up, iso_nll_inf_count)``:
        ``dataset_stats`` is ``avg_stats`` applied to the per-fold statistics
        (each extended with the training time and the isotonic fitting time);
        ``max_fx_up`` is the largest ``torch.max(fx_up)`` seen over the folds
        (starting from -1); ``iso_nll_inf_count`` is the largest ``count``
        reported by any single fold (starting from 0).
    """
    # NOTE(review): KFold is created without shuffle/random_state, so the folds
    # are contiguous and identical on every call — confirm this is intended
    # when the function is invoked repeatedly.
    kf = KFold(n_splits = n_splits)
    output_dim = 1
    dataset_stats = []
    iso_nll_inf_count = 0
    max_fx_up = -1
    for train_index, test_index in kf.split(data):
        helper       = data_helper(data,train_index,test_index,input_dim)
        train_loader = helper.train_loader
        test_loader  = helper.test_loader

        # Target statistics supplied by the helper; passed to trainer and tester.
        y_mean = helper.y_mean
        y_std  = helper.y_std

        # A fresh model per fold so that no weights leak between folds.
        network      = Dropout_VI(input_dim =input_dim , output_dim=output_dim,
                                  num_units=num_units,drop_prob=0.5)

        optimizer    = torch.optim.Adam(network.parameters(), lr=lr)
        trainer      = Dropout_VI_trainer(network = network,input_dim =input_dim,
                                          batch_size=batch_size,optimizer = optimizer,
                                          device = device,mean=y_mean,std=y_std,
                                          qr_reg = qr_reg,K = K,spacing = spacing)

        train_time = trainer.train(train_loader,epochs)
        iso,delta,t_cdf,iso_time  = trainer.fit_isotonic(train_loader)
        test_util                 = tester(network = network,delta=delta,iso=iso,
                                      t_cdf=t_cdf,mean=y_mean,std=y_std,
                                      qr_reg =qr_reg,K=K)

        current_split_stats,fx_up,count       = test_util.test(test_loader)
        current_fx_up = torch.max(fx_up)

        # NOTE(review): this keeps the per-fold maximum of `count`, not the sum
        # over folds — confirm that is the intended aggregation.
        iso_nll_inf_count = max(count,iso_nll_inf_count)
        max_fx_up = max(current_fx_up.item(),max_fx_up)

        # Timing figures are appended last, after the tester's own statistics.
        current_split_stats.append(train_time)
        current_split_stats.append(iso_time)
        dataset_stats.append(current_split_stats)

    dataset_stats = avg_stats(dataset_stats)
    return dataset_stats,max_fx_up,iso_nll_inf_count
        
        
    

In [4]:
def multiple_runs(name,qr=False,K=0.0,times=10,spacing=64):
    """
    Repeat the cross-validated experiment on one dataset and save the results.

    The dataset is loaded with ``uci_helper(name)`` and ``train_and_test`` is
    called ``times`` times with a fixed set of hyperparameters. The statistics
    of every repetition are passed to ``save_stats``, and the largest count
    and largest ``fx_up`` value over all repetitions are printed.

    Args:
        name: dataset identifier given to ``uci_helper`` and ``save_stats``.
        qr: forwarded to ``train_and_test`` as ``qr_reg``.
        K: forwarded to ``train_and_test`` as ``K``.
        times: number of repetitions.
        spacing: forwarded to ``train_and_test`` as ``spacing``.

    Returns:
        None. Results are saved and a summary line is printed.
    """
    data = uci_helper(name)

    # All columns but one are counted as network inputs.
    n_features = data.shape[1] - 1

    # Fixed training hyperparameters shared by every repetition.
    n_epochs    = 64
    hidden_size = 128
    step_size   = 1e-2
    batch       = 512

    per_run_stats = []
    worst_count = -1
    peak_fx_up = -1
    for _ in pbar(range(times)):
        run_stats, run_fx_up, run_count = train_and_test(
            data, n_features, hidden_size, batch,
            epochs=n_epochs, lr=step_size,
            qr_reg=qr, K=K, spacing=spacing,
        )
        per_run_stats.append(run_stats)

        # Running maxima over the repetitions (both start at -1).
        if run_count > worst_count:
            worst_count = run_count
        if run_fx_up > peak_fx_up:
            peak_fx_up = run_fx_up

    save_stats(per_run_stats, name, qr, K)
    print("iso nll count : {} , maximum likelihood :{}".format(worst_count, peak_fx_up))


In [5]:
# Baseline: run the experiment on "airfoil" with qr disabled (K stays at its default 0.0).
# NOTE(review): the recorded progress bar below shows 5 iterations while `times`
# defaults to 10, so this output may predate the current definition — re-run to confirm.
multiple_runs("airfoil",qr=False)

100%|██████████| 5/5 [00:40<00:00,  8.01s/it]

calib         : 12.46 -+ 1.55
iso_calib     : 14.73 -+ 0.61
rmse          : 3.63 -+ 0.07
iso_rmse      : 3.59 -+ 0.04
nll           : 2.69 -+ 0.01
iso_nll       : -1.15 -+ 0.52
time          : 1.31 -+ 0.04
iso_time      : 0.07 -+ 0.00
iso nll count : 28 , maximum likelihood :4512.7685546875





In [6]:
# Run the experiment on "airfoil" with qr enabled and K=1.0.
# NOTE(review): the recorded progress bar below shows 5 iterations while `times`
# defaults to 10, so this output may predate the current definition — re-run to confirm.
multiple_runs("airfoil",qr=True,K=1.0)

100%|██████████| 5/5 [00:50<00:00, 10.07s/it]

calib         : 9.39 -+ 1.81
iso_calib     : 9.51 -+ 1.49
rmse          : 3.94 -+ 0.07
iso_rmse      : 3.92 -+ 0.05
nll           : 2.80 -+ 0.03
iso_nll       : -1.90 -+ 1.61
time          : 1.85 -+ 0.06
iso_time      : 0.08 -+ 0.01
iso nll count : 23 , maximum likelihood :42114.05859375



