In [None]:
%load_ext autoreload
%autoreload 2
from copy import deepcopy
import os
# Work from the project's code directory (presumably so that the `src` package
# imported below resolves relative to it). The location can be overridden with
# the ENSEMBLE_UNCERTAINTY_CODE_DIR environment variable instead of editing this
# machine-specific absolute path.
os.chdir(os.environ.get("ENSEMBLE_UNCERTAINTY_CODE_DIR",
                        "/home/jakob/doktor/projects/EnsembleUncertainty/code"))
# Learning "logit" distribution in regression example
from pathlib import Path
from datetime import datetime
import logging
import numpy as np
import matplotlib.pyplot as plt
import torch

from src.dataloaders import gaussian_sinus, one_dim_regression
import src.utils as utils
from src.distilled import logits_probability_distribution
from src.ensemble import ensemble
from src.ensemble import sep_regressor, simple_regressor
import src.metrics as metrics

LOGGER = logging.getLogger(__name__)
EXPERIMENT_NAME = "regression_logits"

# Settings
class Args:
    """Mutable container for experiment settings; attributes are assigned below."""

    def __repr__(self):
        # Show the actual settings when the object is logged or printed; the
        # default object repr would only show a memory address.
        fields = ", ".join("{}={!r}".format(key, value)
                           for key, value in vars(self).items())
        return "Args({})".format(fields)


args = Args()
args.seed = 1
args.gpu = False
args.log_dir = Path("./logs")
args.log_level = logging.INFO
args.retrain = True

# Defaults; the ensemble-training cell overrides some of these for this run.
args.num_ensemble_members = 10
args.num_epochs = 100
args.lr = 0.01


In [None]:
def make_plots(distilled_model, data):
    """Scatter the dataset and overlay the distilled model's predictions.

    Args:
        distilled_model: model exposing `output_size` and
            `predict(inputs, t=None)`. Column 0 of its prediction is drawn as
            the mean and column 1 is treated as a variance (its square root is
            used for the error bars).
        data: dataset with an `n_samples` attribute that a torch DataLoader can
            iterate as (inputs, targets) batches.
    """
    test_loader = torch.utils.data.DataLoader(data,
                                              batch_size=16,
                                              shuffle=True,
                                              num_workers=0)

    predictions = np.zeros((data.n_samples, distilled_model.output_size))
    all_x = np.zeros((data.n_samples, 1))
    all_y = np.zeros((data.n_samples, 1))

    # Fill the arrays batch by batch using a running row offset, which also
    # handles a final batch that is smaller than `batch_size`.
    start = 0
    for inputs, targets in test_loader:
        stop = start + len(inputs)
        # `predictions` is 2-D (samples x outputs): index exactly two axes.
        predictions[start:stop, :] = \
            distilled_model.predict(inputs, t=None).data.numpy()
        all_x[start:stop, :] = inputs
        all_y[start:stop, :] = targets
        start = stop

    plt.scatter(np.squeeze(all_x), np.squeeze(all_y), label="Data", marker=".")

    # Points come out of a shuffled loader, so draw markers only (no lines).
    plt.errorbar(np.squeeze(all_x),
                 predictions[:, 0],
                 np.sqrt(predictions[:, 1]),
                 label="Distilled model predictions",
                 marker=".",
                 ls="none")

    plt.legend()
    plt.show()

In [None]:
# Experiment-specific overrides of the defaults from the settings cell. They are
# applied before the arguments are logged so the log reflects the values used.
args.num_ensemble_members = 2
args.num_epochs = 25
args.lr = 0.001

log_file = Path("{}_{}.log".format(
    EXPERIMENT_NAME,
    datetime.now().strftime("%Y%m%d_%H%M%S")))
log_dir = Path.cwd() / args.log_dir
# Make sure the log directory exists before the logger is pointed at it.
log_dir.mkdir(parents=True, exist_ok=True)
utils.setup_logger(log_path=log_dir / log_file,
                   log_level=args.log_level)
# vars() logs the actual settings rather than the object's default repr.
LOGGER.info("Args: {}".format(vars(args)))
device = utils.torch_settings(args.seed, args.gpu)
LOGGER.info("Creating dataloader")
data = gaussian_sinus.GaussianSinus(
    store_file=Path("none"))

input_size = 1
# input -> two hidden layers of 10 units -> 1 output
layer_sizes = [input_size, 10, 10, 1]
# Each member's prediction is later indexed as [..., 0] (mean) and [..., 1]
# (variance), hence twice the size of the last layer.
ensemble_output_size = layer_sizes[-1] * 2
train_loader = torch.utils.data.DataLoader(data,
                                           batch_size=32,
                                           shuffle=True,
                                           num_workers=0)

prob_ensemble = ensemble.Ensemble(ensemble_output_size)
for _ in range(args.num_ensemble_members):
    model = sep_regressor.SepRegressor(layer_sizes,
                                       device=device,
                                       learning_rate=args.lr)
    # Presumably only the active ("mu") sub-network is trained in this stage;
    # the variance network is switched on in the next cell.
    model.switch_active_network("mu")
    prob_ensemble.add_member(model)
squared_error_metric = metrics.Metric(name="Squared error",
                                      function=metrics.mean_squared_error)
prob_ensemble.add_metrics([squared_error_metric])
prob_ensemble.train(train_loader, args.num_epochs)

In [None]:
# Second stage: activate each member's "sigma_sq" network and train again on
# the same loader for the same number of epochs.
for ens_member in prob_ensemble.members:
    ens_member.switch_active_network("sigma_sq")
prob_ensemble.train(train_loader, args.num_epochs)

In [None]:
# Evenly spaced evaluation grid on [start, end) as a column of inputs.
start = -5
end = 5
step = 0.1
# reshape(-1, 1) lets torch decide the length: precomputing it with
# int((end - start) / step) disagrees with arange's ceil-based length whenever
# step does not divide the range exactly, which would make reshape fail.
x = torch.arange(start=start,
                 end=end,
                 step=step).reshape((-1, 1)).float()
x_length = x.shape[0]
output = prob_ensemble.predict(x)
x = x.detach().numpy()[:, 0]

def plot_uncert(ax, data, x, mean_mu=None, ale=None, epi=None):
    """Plot data with a mean prediction and its uncertainty decomposition.

    Args:
        ax: matplotlib Axes to draw on.
        data: 2-D array whose last column is the target and whose other
            columns are the inputs.
        x: 1-D grid of input locations for the prediction curves.
        mean_mu: predicted mean at each point of `x` (red line).
        ale: aleatoric variance at each point of `x`; drawn as +/- sqrt(ale)
            error bars around `mean_mu` on every 3rd point.
        epi: epistemic variance at each point of `x`; drawn as a shaded
            +/- sqrt(epi) band around `mean_mu`.

    Raises:
        ValueError: if `ale` or `epi` is given without `mean_mu`, since both
            are drawn around the mean curve.
    """
    if mean_mu is None and (ale is not None or epi is not None):
        raise ValueError("ale/epi are drawn around mean_mu; pass mean_mu as well")
    inputs = data[:, :-1]
    targets = data[:, -1]
    ax.scatter(inputs, targets)
    # Dashed vertical lines at x = -3 and x = 3 (presumably the edges of the
    # training-input range -- TODO confirm against GaussianSinus).
    lower_x_bound = np.array([-3, -3])
    upper_x_bound = np.array([3, 3])
    y_bound = np.array([-2, 2])
    ax.plot(lower_x_bound, y_bound, "b--")
    ax.plot(upper_x_bound, y_bound, "b--")
    # Raw strings: "\m" and "\s" are invalid escapes in normal string literals
    # (SyntaxWarning on Python 3.12+); the rendered label text is unchanged.
    if mean_mu is not None:
        ax.plot(x, mean_mu, "r-", label=r"$\mu_{avg}(x)$")
    if ale is not None:
        every_nth = 3
        ax.errorbar(x, mean_mu,
                    np.sqrt(ale),
                    errorevery=every_nth,
                    color="r",
                    label=r"$E_w[\sigma_w^2(x)]$")
    if epi is not None:
        epi_std = np.sqrt(epi)
        ax.fill_between(x, mean_mu + epi_std, mean_mu - epi_std,
                        facecolor="blue", alpha=0.5, label=r"var$_w(\mu_w(x))$")
    # Attach the legend to `ax` itself rather than pyplot's "current" axes.
    ax.legend(prop={'size': 40})
    plt.show()
    
# Split the ensemble prediction into per-member means / variances (last axis)
# and separate aleatoric from epistemic uncertainty. Averaging over dim=1 is
# presumably over ensemble members -- TODO confirm layout of `output`.
mu = output[:, :, 0]
var = output[:, :, 1]
ale, epi = metrics.uncertainty_separation_parametric(mu, var)
mean_mu = mu.mean(dim=1).detach().numpy()
ale = ale.detach().numpy()
epi = epi.detach().numpy()

plt.rcParams['figure.figsize'] = [30, 30]
_, ax = plt.subplots()

plot_uncert(ax, data.get_full_data(), x, mean_mu=mean_mu, ale=ale, epi=epi)


In [None]:

def plot_reg_with_pred(data, x, pred_mean, pred_var, ax):
    """Scatter the data and overlay a prediction with +/- 1 std error bars.

    Args:
        data: 2-D array whose last column is the target and whose other
            columns are the inputs.
        x: 1-D grid of input locations for the prediction.
        pred_mean: predicted mean at each point of `x`.
        pred_var: predicted variance at each point of `x`; its square root is
            drawn as error bars on every 10th point.
        ax: matplotlib Axes to draw on.
    """
    inputs = data[:, :-1]
    targets = data[:, -1]
    ax.scatter(inputs, targets)
    # Use the arguments, not the notebook-global `means` / `var`.
    ax.plot(x, pred_mean, 'r-')
    every_nth = 10
    ax.errorbar(x, pred_mean,
                np.sqrt(pred_var),
                errorevery=every_nth,
                color="r")
    plt.show()

# Mean and variance predicted by a single ensemble member (index `member`).
plt.rcParams['figure.figsize'] = [30, 30]
member = 1
member_output = output.detach().numpy()[:, member]
means = member_output[:, 0]
var = member_output[:, 1]
_, ax = plt.subplots()
plot_reg_with_pred(data.get_full_data(), x, means, var, ax)


In [None]:
# Create the distilled model.
hidden_size = 10
# Must be assigned before `layer_sizes` uses it (otherwise NameError on a fresh
# kernel). Twice the ensemble output size -- presumably a mean and a variance
# for each ensemble output, matching the (z_mean, z_var) pair from forward().
distilled_output_size = ensemble_output_size * 2
layer_sizes = [input_size, hidden_size, hidden_size, distilled_output_size]
distilled_model = logits_probability_distribution.LogitsProbabilityDistribution(
    layer_sizes=layer_sizes,
    teacher=prob_ensemble,
    device=device,
    learning_rate=args.lr)

In [None]:
# Retrain: fit the distilled model on unlabelled inputs, presumably drawn by
# GaussianSinus from the interval given by range_=(lower, upper).
lower, upper = -5, 5
num_distill_epochs = 50
unlabelled_data = gaussian_sinus.GaussianSinus(store_file=Path("None"),
                                               train=False,
                                               range_=(lower, upper))
unlabelled_loader = torch.utils.data.DataLoader(unlabelled_data,
                                                batch_size=6,
                                                shuffle=True,
                                                num_workers=0)

distilled_model.train(unlabelled_loader, num_distill_epochs)


In [None]:
# Leftover inspection cell. `ens_output` is only created in the next cell, so
# evaluating it here raised NameError under Restart & Run All. Inspect it after
# the next cell has run instead.

In [None]:
def plot_reg_with_pred(data, x, pred_mean, pred_var, ax):
    """Draw the raw data and a predicted mean with +/- 1 std error bars.

    `data` holds inputs in all but the last column and targets in the last.
    Error bars (sqrt of `pred_var`) are shown on every 10th point of `x`.
    """
    ax.scatter(data[:, :-1], data[:, -1])
    ax.plot(x, pred_mean, 'r-')
    pred_std = np.sqrt(pred_var)
    ax.errorbar(x, pred_mean, pred_std, errorevery=10, color="r")
    plt.show()

def plot_uncert(ax, data, x, mean_mu=None, ale=None, epi=None):
    """Plot data with a mean prediction and its uncertainty decomposition.

    Args:
        ax: matplotlib Axes to draw on.
        data: 2-D array whose last column is the target and whose other
            columns are the inputs.
        x: 1-D grid of input locations for the prediction curves.
        mean_mu: predicted mean at each point of `x` (red line).
        ale: aleatoric variance at each point of `x`; drawn as +/- sqrt(ale)
            error bars around `mean_mu` on every 3rd point.
        epi: epistemic variance at each point of `x`; drawn as a shaded
            +/- sqrt(epi) band around `mean_mu`.

    Raises:
        ValueError: if `ale` or `epi` is given without `mean_mu`, since both
            are drawn around the mean curve.
    """
    if mean_mu is None and (ale is not None or epi is not None):
        raise ValueError("ale/epi are drawn around mean_mu; pass mean_mu as well")
    inputs = data[:, :-1]
    targets = data[:, -1]
    ax.scatter(inputs, targets)
    # Dashed vertical lines at x = -3 and x = 3 (presumably the edges of the
    # training-input range -- TODO confirm against GaussianSinus).
    lower_x_bound = np.array([-3, -3])
    upper_x_bound = np.array([3, 3])
    y_bound = np.array([-2, 2])
    ax.plot(lower_x_bound, y_bound, "b--")
    ax.plot(upper_x_bound, y_bound, "b--")
    # Raw strings: "\m" and "\s" are invalid escapes in normal string literals
    # (SyntaxWarning on Python 3.12+); the rendered label text is unchanged.
    if mean_mu is not None:
        ax.plot(x, mean_mu, "r-", label=r"$\mu_{avg}(x)$")
    if ale is not None:
        every_nth = 3
        ax.errorbar(x, mean_mu,
                    np.sqrt(ale),
                    errorevery=every_nth,
                    color="r",
                    label=r"$E_w[\sigma_w^2(x)]$")
    if epi is not None:
        epi_std = np.sqrt(epi)
        ax.fill_between(x, mean_mu + epi_std, mean_mu - epi_std,
                        facecolor="blue", alpha=0.5, label=r"var$_w(\mu_w(x))$")
    # Attach the legend to `ax` itself rather than pyplot's "current" axes.
    ax.legend(prop={'size': 40})
    plt.show()
    
def compare_ale(ax, x, ale_ens, ale_dist):
    """Overlay a reference noise curve with two aleatoric-uncertainty estimates.

    Args:
        ax: matplotlib Axes to draw on.
        x: 1-D array of input locations.
        ale_ens: ensemble estimate at each point of `x` (label "Ens").
        ale_dist: distilled-model estimate at each point of `x` (label "Dist").
    """
    # Reference curve 0.15 * sigmoid(x), presumably the noise level used to
    # generate the data. TODO confirm whether it is a std or a variance: the
    # ale estimates elsewhere in this notebook are variances.
    sigma = 0.15 * 1 / (1 + np.exp(-x))
    ax.plot(x, sigma, label="true")
    ax.plot(x, ale_ens, label="Ens")
    ax.plot(x, ale_dist, label="Dist")
    ax.legend(prop={'size': 40})
    plt.show()


# Evaluate the ensemble and the distilled model on a common grid on [start, end).
start = -5
end = 5
step = 0.25
# reshape(-1, 1) lets torch choose the length, so any `step` works (a length
# precomputed as int((end - start) / step) can disagree with arange's).
x = torch.arange(start=start,
                 end=end,
                 step=step, requires_grad=False).reshape((-1, 1)).float()
x_length = x.shape[0]

ens_output = prob_ensemble.predict(x)
mu_ens = ens_output[:, :, 0]
var_ens = ens_output[:, :, 1]
ale_ens, epi_ens = metrics.uncertainty_separation_parametric(mu_ens, var_ens)
ale_ens = ale_ens.detach().numpy()
epi_ens = epi_ens.detach().numpy()
z_mean, z_var = distilled_model.forward(x)
z_mean = z_mean.detach().numpy()
z_var = z_var.detach().numpy()
x = x.detach().numpy()[:, 0]

mu = z_mean[:, 0]
# Softplus log(1 + exp(z)) of the second mean component. logaddexp(0, z)
# computes the same value without overflowing exp() for large z.
ale_dist = np.logaddexp(0.0, z_mean[:, 1])
epi_dist = z_var[:, 1]


# Distilled model: mean curve, aleatoric error bars and epistemic band,
# drawn over the unlabelled data used for distillation.
plt.rcParams['figure.figsize'] = [30, 30]
_, ax = plt.subplots()
plot_uncert(ax, unlabelled_data.get_full_data(), x,
            mean_mu=mu, ale=ale_dist, epi=epi_dist)


In [None]:
from src.loss import gaussian_neg_log_likelihood

# Sanity check at a single input: compare the teacher's logits with the
# distilled Gaussian and evaluate the distillation loss there.
x_test = torch.tensor([[3.0]])
ens_output = prob_ensemble.get_logits(x_test)
print("Ens", ens_output.shape)
print("Ens mean", ens_output.mean(1))

z_mean, z_var = distilled_model.forward(x_test)
mu = z_mean[0, :].reshape((1, 2))
cov = z_var[0, :].reshape((1, 2))
print("mu", mu)
print("cov", cov)

with torch.no_grad():
    loss = gaussian_neg_log_likelihood((mu, cov), ens_output)
print(loss)

In [None]:
# Plot the teacher (ensemble) predictions the distilled model is trained on.
start = -4
end = 4
step = 0.01
# reshape(-1, 1) lets torch choose the length, so any `step` works.
x = torch.arange(start=start,
                 end=end,
                 step=step, requires_grad=False).reshape((-1, 1)).float()
x_length = x.shape[0]
ensemble_preds = distilled_model._generate_teacher_predictions(x)
# A bare `.shape` mid-cell is never displayed; print it explicitly.
print("Teacher predictions shape:", tuple(ensemble_preds.shape))

plt.plot(x.numpy()[:, 0], ensemble_preds[:, :, 0].detach().numpy());


In [None]:
# Scratch check: a 1x1 identity covariance and the outer product of a
# 2x1 column vector with itself (2x2).
cov_mat = torch.eye(1)
diff = torch.ones(torch.Size([2, 1]))

outer = torch.matmul(diff, diff.T)
diff.shape