In [None]:
"""Learning "logit" distribution for UCI data"""
import os
CODE_DIR = <Path to repo 'code' dir>
os.chdir(CODE_DIR)
import logging
import zipfile
from copy import copy, deepcopy
import urllib.request
from pathlib import Path
from datetime import datetime
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim import Optimizer
from torch.optim.sgd import SGD
from torchvision import datasets, transforms
from torchvision.utils import make_grid
from sklearn.model_selection import KFold
from src.dataloaders.uci import uci_base, wine, bost
from src import metrics
from src import utils
from src.ensemble import simple_regressor, ensemble
from src.distilled import gauss_logits, norm_inv_wish
from src import loss as custom_loss
import matplotlib.pyplot as plt
import tikzplotlib

# Settings
class Args:
    """Bare attribute container for the notebook's run-time settings."""


args = Args()
# Populate every setting in one go; insertion order matches the listing below.
vars(args).update(
    seed=1,
    gpu=True,
    log_dir=Path("./logs"),
    log_level=logging.WARNING,
    retrain=True,
    num_ensemble_members=1,
    num_epochs=1,
    lr=0.01,
)

# Module logger and a per-run log file named "<experiment>_<timestamp>.log",
# written below <cwd>/<args.log_dir>.
LOGGER = logging.getLogger(__name__)
EXPERIMENT_NAME = "uci_wine"

log_file = Path("{name}_{stamp}.log".format(
    name=EXPERIMENT_NAME,
    stamp=datetime.now().strftime("%Y%m%d_%H%M%S")))
utils.setup_logger(
    log_path=Path.cwd().joinpath(args.log_dir, log_file),
    log_level=args.log_level,
)

# General constructs
# A single RMSE metric is built once; the training list receives a deep copy
# of it while the test list holds the original object.
rmse = metrics.Metric(name="RMSE", function=metrics.root_mean_squared_error)
train_metrics = [deepcopy(rmse)]
test_metrics = [rmse]

BATCH_SIZE = 32

# Apply the configured seed so that weight initialisation (and any other
# numpy/torch randomness) is repeatable between runs.
np.random.seed(args.seed)
torch.manual_seed(args.seed)

# Use the GPU only when it is both requested (args.gpu) and actually present;
# otherwise fall back to the CPU instead of failing on the first CUDA call.
if args.gpu and torch.cuda.is_available():
    device = torch.device("cuda")
    LOGGER.info("Using CUDA device: %s",
                torch.cuda.get_device_name(torch.cuda.current_device()))
else:
    device = torch.device("cpu")
    LOGGER.info("Using CPU")

In [None]:
def create_ensemble(num_ensemble_members,
                    input_dim,
                    num_hidden,
                    lr,
                    ensemble_output_size):
    """Assemble an ensemble of `simple_regressor.Model` networks.

    Every member is built with layer sizes
    ``[input_dim, num_hidden, ensemble_output_size]``, placed on the
    notebook-level ``device`` and given its own Adam optimizer.

    Args:
        num_ensemble_members: number of networks to add to the ensemble.
        input_dim: size of the first layer.
        num_hidden: size of the middle layer.
        lr: learning rate passed to each member's Adam optimizer.
        ensemble_output_size: size of the last layer; also passed to the
            `ensemble.Ensemble` constructor.

    Returns:
        The populated `ensemble.Ensemble` instance.
    """
    members = ensemble.Ensemble(ensemble_output_size)
    built = 0
    while built < num_ensemble_members:
        sizes = [input_dim, num_hidden, ensemble_output_size]
        member = simple_regressor.Model(
            layer_sizes=sizes,
            device=device,
            variance_transform=utils.positive_linear_asymptote(1e-6),
            loss_function=custom_loss.gaussian_neg_log_likelihood_1d,
        )
        member.optimizer = torch.optim.Adam(member.parameters(), lr=lr)
        members.add_member(member)
        built += 1
    return members

def mean_and_std_from_list(samples):
    """Return ``(mean, std)`` of a sequence that numpy can convert to an array."""
    values = np.array(samples)
    return np.mean(values), np.std(values)

def mean_and_std_from_metric(metric, rescale=1):
    """Return the metric's mean and std, each multiplied by ``rescale``.

    ``metric`` may be any object exposing ``mean()`` and ``std()`` methods.
    """
    scaled_mean = metric.mean() * rescale
    scaled_std = metric.std() * rescale
    return scaled_mean, scaled_std

def test_distilled(distilled, x, y_true, device, scale):
    """Evaluate a distilled model on held-out data.

    Args:
        distilled: model whose ``forward(x)`` returns a ``(z_mean, z_var)``
            pair; column 0 of ``z_mean`` is used as the predicted mean and
            column 1 of both outputs feeds the uncertainty estimate.
        x: inputs, convertible with ``torch.tensor``.
        y_true: targets; ``len(y_true)`` is the number of samples and
            ``y_true.shape`` is the shape the predicted mean is reshaped to.
        device: torch device the tensors are moved to.
        scale: factor used to map the metrics back from normalised targets
            (RMSE is multiplied by it, ``log(scale)`` is added to the NLL).

    Returns:
        Tuple ``(rmse, nll, ause)`` of Python floats.
    """
    with torch.no_grad():
        num_samples = len(y_true)
        x = torch.tensor(x).float().to(device)
        z_mean, z_var = distilled.forward(x)
        mu_dist = z_mean[:, 0].reshape(y_true.shape)
        # softplus(z) == log(1 + exp(z)), but without overflowing to inf for
        # large z the way the explicit exp() does.
        ale_dist = F.softplus(z_mean[:, 1])
        epi_dist = z_var[:, 1]
        # Give the variance the same shape as the mean (as test_ensemble
        # does) so the two never broadcast against each other downstream.
        tot_uncert = (ale_dist + epi_dist).reshape(mu_dist.shape)
        y_true = torch.tensor(y_true,
                              device=device,
                              dtype=torch.float).reshape((num_samples, 1, 1))

        rmse, nll, ause = common_test(y_true=y_true,
                                      mu=mu_dist,
                                      sigma_sq=tot_uncert,
                                      uncert=tot_uncert)
        # Undo the target normalisation.
        rmse *= scale
        nll += np.log(scale)

    return rmse.item(), nll.item(), ause.item()

    
def test_ensemble(prob_ensemble, x, y_true, device, scale):
    """Evaluate an ensemble on held-out data.

    The ensemble's ``predict`` output is split along its last axis into
    per-member means (index 0) and variances (index 1), which are collapsed
    with ``utils.gaussian_mixture_moments`` before scoring.

    Args:
        prob_ensemble: ensemble exposing ``predict``.
        x: inputs, convertible with ``torch.tensor``.
        y_true: targets; ``len(y_true)`` is the number of samples.
        device: torch device used for all tensors.
        scale: factor used to map the metrics back from normalised targets
            (RMSE is multiplied by it, ``log(scale)`` is added to the NLL).

    Returns:
        Tuple ``(rmse, nll, ause)`` of Python floats.
    """
    n_obs = len(y_true)
    column_shape = (n_obs, 1)
    with torch.no_grad():
        inputs = torch.tensor(x, device=device, dtype=torch.float)
        member_out = prob_ensemble.predict(inputs)
        member_mu = member_out[:, :, 0]
        member_var = member_out[:, :, 1]

        mix_mu, mix_var = utils.gaussian_mixture_moments(member_mu, member_var)
        mix_mu = mix_mu.reshape(column_shape).to(device)
        mix_var = mix_var.reshape(column_shape).to(device)

        targets = torch.tensor(y_true, device=device, dtype=torch.float)
        targets = targets.reshape((n_obs, 1, 1))

        # The mixture variance serves both as sigma^2 and as the uncertainty.
        rmse, nll, ause = common_test(targets, mix_mu, mix_var, mix_var)
        rmse *= scale
        nll += np.log(scale)

    return rmse.item(), nll.item(), ause.item()

def common_test(y_true, mu, sigma_sq, uncert, num_partitions=10):
    """Score predictions with RMSE, Gaussian NLL and AUSE.

    Args:
        y_true: targets; reshaped to ``(len(y_true), 1)`` for the AUSE call.
        mu: predicted means.
        sigma_sq: predicted variances, paired with ``mu`` for the NLL.
        uncert: uncertainty measure handed to ``utils.ause``.
        num_partitions: forwarded to ``utils.ause``.

    Returns:
        Tuple ``(rmse, nll, ause)`` as returned by the respective helpers.
    """
    n_obs = len(y_true)
    rmse = metrics.root_mean_squared_error(predictions=mu, targets=y_true)
    nll = custom_loss.gaussian_neg_log_likelihood_1d((mu, sigma_sq), y_true)
    ause = utils.ause(
        y_true=y_true.reshape((n_obs, 1)),
        y_pred=mu,
        uncert_meas=uncert,
        num_partitions=num_partitions,
    )
    return rmse, nll, ause

In [None]:
def train_ensemble(data,
                     num_ensemble_members,
                     num_epochs,
                     num_units,
                     n_splits,
                     learn_rate,
                     weight_decay,
                     train_metrics,
                     test_metrics,
                     batch_size):
    """Cross-validate an ensemble and a model distilled from it.

    For every K-fold split a fresh ensemble and a fresh distilled model are
    created, the data is standardised with the training fold's statistics,
    both models are trained and then scored on the held-out fold.

    Args:
        data: 2-D array; the last column is the target, all preceding
            columns are features.
        num_ensemble_members: number of networks in the ensemble.
        num_epochs: number of epochs the ensemble is trained for.
        num_units: hidden layer size of each ensemble member.
        n_splits: number of cross-validation folds.
        learn_rate: learning rate of the ensemble members.
        weight_decay: accepted for interface compatibility; not used.
        train_metrics: metrics reset every fold and attached to the ensemble.
        test_metrics: metrics reset every fold (not used otherwise).
        batch_size: ensemble training batch size; ``None`` means "the whole
            training fold in a single batch".

    Returns:
        ``((ens_rmses, ens_nlls, ens_auses), (dist_rmses, dist_nlls,
        dist_auses))`` - per-fold result lists for the ensemble and for the
        distilled model. Folds whose distillation raised are left out.
    """
    ens_rmses, ens_nlls, ens_auses = [], [], []
    dist_rmses, dist_nlls, dist_auses = [], [], []

    kf = KFold(n_splits=n_splits)
    in_dim = data.shape[1] - 1

    hidden_size = 50
    distilled_output_size = 4
    layer_sizes = [in_dim, hidden_size, hidden_size, distilled_output_size]

    for j, (train_index, test_index) in enumerate(kf.split(data)):
        print("Fold: {}".format(j))
        for metric in train_metrics:
            metric.reset()

        for metric in test_metrics:
            metric.reset()

        prob_ensemble = create_ensemble(num_ensemble_members=num_ensemble_members,
                                        input_dim=in_dim,
                                        num_hidden=num_units,
                                        lr=learn_rate,
                                        ensemble_output_size=2)
        prob_ensemble.add_metrics(train_metrics)

        # The distilled model takes its learning rate from the notebook-level
        # settings (args.lr), not from `learn_rate`.
        distilled_model = gauss_logits.Model(
            layer_sizes=layer_sizes,
            teacher=prob_ensemble,
            variance_transform=utils.positive_linear_asymptote(0.001),
            device=device,
            learning_rate=args.lr)

        x_train, y_train = data[train_index, :in_dim], data[train_index, in_dim:]
        x_test, y_test = data[test_index, :in_dim], data[test_index, in_dim:]

        # Standardise with training-fold statistics only.
        x_means, x_stds = x_train.mean(axis=0), x_train.std(axis=0)
        y_means, y_stds = y_train.mean(axis=0), y_train.std(axis=0)
        # A feature that is constant within the fold has zero std; divide by
        # one instead so it becomes 0 rather than NaN.
        x_stds = np.where(x_stds == 0, 1.0, x_stds)

        x_train = (x_train - x_means) / x_stds
        y_train = (y_train - y_means) / y_stds

        x_test = (x_test - x_means) / x_stds
        y_test = (y_test - y_means) / y_stds

        data_std = y_stds[0]
        # Resolve "full batch" per fold; the argument itself is left untouched
        # so one fold's size never carries over to the next.
        fold_batch_size = x_train.shape[0] if batch_size is None else batch_size

        trainloader = uci_base.uci_dataloader(x_train, y_train, fold_batch_size)
        unlabelled_loader = uci_base.uci_dataloader(x_train, y_train, 128)

        prob_ensemble.train(train_loader=trainloader,
                            num_epochs=num_epochs)
        try:
            distilled_model.train(unlabelled_loader, 30)
        except (ValueError, RuntimeError) as error:
            # The whole fold is skipped (ensemble results included) when
            # distillation fails; record why instead of discarding the error.
            print("NaN")
            LOGGER.warning("Fold %d skipped, distillation failed: %s", j, error)
            continue

        rmse_ens, nll_ens, ause_ens = test_ensemble(prob_ensemble=prob_ensemble,
                                                    x=x_test,
                                                    y_true=y_test,
                                                    device=device,
                                                    scale=data_std)
        rmse_dist, nll_dist, ause_dist = test_distilled(distilled_model,
                                                        x=x_test,
                                                        y_true=y_test,
                                                        device=device,
                                                        scale=data_std)

        ens_rmses.append(rmse_ens)
        ens_nlls.append(nll_ens)
        ens_auses.append(ause_ens)

        dist_rmses.append(rmse_dist)
        dist_nlls.append(nll_dist)
        dist_auses.append(ause_dist)
        print("rmse: {}\t nll: {}, ause: {}".format(rmse_ens, nll_ens, ause_ens))
        print("rmse: {}\t nll: {}, ause: {}".format(rmse_dist, nll_dist, ause_dist))

    ens_results = (ens_rmses, ens_nlls, ens_auses)
    dist_results = (dist_rmses, dist_nlls, dist_auses)

    # Summary: ensemble first, then distilled model, three lines each.
    for results in (ens_results, dist_results):
        for label, values in zip(("RMSE", "NLL", "AUSE"), results):
            print("Test {}\t = {:.3f} +/- {:.3f}".format(
                label, *mean_and_std_from_list(values)))

    return ens_results, dist_results

In [None]:
# Load the red-wine data set and run the 5-fold ensemble/distillation experiment.
wine_data = wine.WineData("data/uci/wine/winequality-red.csv")
result = train_ensemble(
    data=wine_data.data,
    num_ensemble_members=10,
    num_epochs=40,
    num_units=50,
    n_splits=5,
    learn_rate=1e-1,
    weight_decay=0.0,  # alternative once considered: 1e-1/len(data)**0.5
    train_metrics=train_metrics,
    test_metrics=test_metrics,
    batch_size=None,  # None -> full training fold per batch
)

In [None]:
# Sparsification-error plots for the ensemble and the distilled model on the
# full (normalised) data set.
# NOTE(review): `prob_ensemble` and `distilled_model` are read below, but no
# visible earlier cell binds them at notebook scope (train_ensemble only
# creates them as locals) - confirm where they come from before relying on
# Restart & Run All.
x_train, y_train, _, _ = wine_data.create_train_val_split(1)

x_means, x_stds = x_train.mean(axis = 0), x_train.var(axis = 0)**0.5
y_means, y_stds = y_train.mean(axis = 0), y_train.var(axis = 0)**0.5

x_train = (x_train - x_means) / x_stds
y_train = (y_train - y_means) / y_stds

x_tensor = torch.tensor(x_train).float().to(device)
ens_output = prob_ensemble.predict(x_tensor)
mu_ens = ens_output[:, :, 0]
var_ens = ens_output[:, :, 1]
mean_mu_ens = torch.mean(mu_ens, dim=1).reshape(y_train.shape).cpu().detach().numpy()

ale_ens, epi_ens = metrics.uncertainty_separation_parametric(mu_ens, var_ens)
# Move to host memory before converting: Tensor.numpy() does not accept
# tensors that live on a CUDA device.
ale_ens = ale_ens.detach().cpu().numpy()
epi_ens = epi_ens.detach().cpu().numpy()

z_mean, z_var = distilled_model.forward(x_tensor)
z_mean = z_mean.cpu().detach()
z_var = z_var.cpu().detach().numpy()
mu_dist = z_mean[:, 0].reshape(y_train.shape).numpy()
# softplus(z) == log(1 + exp(z)) without overflowing for large z.
ale_dist = F.softplus(z_mean[:, 1]).numpy()
epi_dist = z_var[:, 1]

# Not used further in this cell.
dist_spread = z_var.sum(1)
window_size = 20

uncert_ens = ale_ens + epi_ens
uncert_dist = ale_dist + epi_dist

num_partitions = 10

fig, (ax_ens, ax_dist) = plt.subplots(1, 2, sharey=True)

utils.plot_sparsification_error(ax_ens,
                 y_true=y_train,
                 y_pred=mean_mu_ens,
                 uncert_meas=uncert_ens,
                 num_partitions=num_partitions,
                 label="Ensemble")

utils.plot_sparsification_error(ax_dist,
                 y_true=y_train,
                 y_pred=mu_dist,
                 uncert_meas=uncert_dist,
                 num_partitions=num_partitions,
                 label="Distilled")

# The y axis is shared, so labelling the left panel is enough.
ax_ens.set_ylabel("$SE$")

In [None]:
# Build a distilled model and train it on a 90/10 split of the wine data.
# NOTE(review): `prob_ensemble` is read below, but no visible earlier cell
# binds it at notebook scope - confirm where the teacher comes from.
hidden_size = 50
distilled_output_size = 4
layer_sizes = [wine_data.input_dim, hidden_size, hidden_size, distilled_output_size]
# `gauss_logits.Model` is the distilled-model class this notebook imports and
# that train_ensemble constructs with the same keyword arguments.
distilled_model = gauss_logits.Model(
    layer_sizes=layer_sizes,
    teacher=prob_ensemble,
    variance_transform=utils.positive_linear_asymptote(),
    device=device,
    learning_rate=args.lr)

# Create the split first: the loaders below need x_test / y_test to exist.
x_train, y_train, x_test, y_test = wine_data.create_train_val_split(0.9)
x_means, x_stds = x_train.mean(axis = 0), x_train.var(axis = 0)**0.5
y_means, y_stds = y_train.mean(axis = 0), y_train.var(axis = 0)**0.5

# Normalise both parts with the training statistics.
x_train = (x_train - x_means) / x_stds
y_train = (y_train - y_means) / y_stds

x_test = (x_test - x_means) / x_stds
y_test = (y_test - y_means) / y_stds

unlabelled_loader = uci_base.uci_dataloader(x_train, y_train, 128)
test_loader = uci_base.uci_dataloader(x_test, y_test, len(y_test))

distilled_model.train(unlabelled_loader, 30)


In [None]:
# Score the distilled model on the held-out part of a 90/10 split.
# NOTE(review): this calls create_train_val_split again; if that split is
# randomised, these test rows may overlap the rows the model was trained on -
# confirm.
x_train, y_train, x_test, y_test = wine_data.create_train_val_split(0.9)
x_means, x_stds = x_train.mean(axis = 0), x_train.var(axis = 0)**0.5
y_means, y_stds = y_train.mean(axis = 0), y_train.var(axis = 0)**0.5

x_train = (x_train - x_means) / x_stds
y_train = (y_train - y_means) / y_stds

x_test = (x_test - x_means) / x_stds
y_test = (y_test - y_means) / y_stds

# Target std of this split, used to map the metrics back to the original
# scale (same definition as inside train_ensemble).
data_std = y_stds[0]

test_nlls = list()
test_rmses = list()
with torch.no_grad():
    x_test_tensor = torch.tensor(x_test, device=device, dtype=torch.float)
    y_test_tensor = torch.tensor(y_test,
                                     device=device,
                                     dtype=torch.float).reshape((len(y_test), 1, 1))

    mean_dist, var_dist = distilled_model.forward(x_test_tensor)
    mu_dist = mean_dist[:, 0].unsqueeze(1)
    # Column 0 is the predicted mean; the variance comes from column 1
    # (the column test_distilled also uses for it).
    sigma_sq_dist = distilled_model.variance_transform(mean_dist[:, 1].unsqueeze(1))
    test_rmse = metrics.root_mean_squared_error(predictions=mu_dist,
                                    targets=y_test_tensor) * data_std
    test_nll = custom_loss.gaussian_neg_log_likelihood_1d((mu_dist, sigma_sq_dist),
                                                    y_test_tensor) + np.log(data_std)

test_nlls.append(test_nll.item())
test_rmses.append(test_rmse.item())

print(test_rmse)
print(test_nll)
