In [None]:
%load_ext autoreload
%autoreload 2
# Learning a "logit" distribution in a regression example.
from copy import deepcopy
import os
# Working directory of the project code. Override via the environment variable
# so the notebook also runs on machines other than the original author's.
# The chdir runs before the `src` imports below, presumably so that they
# resolve against this directory.
os.chdir(os.environ.get("ENSEMBLE_UNCERTAINTY_CODE_DIR",
                        "/home/jakob/doktor/projects/EnsembleUncertainty/code"))
from pathlib import Path
from datetime import datetime
import logging
import numpy as np
import matplotlib.pyplot as plt
import torch

from src.dataloaders import one_dim_regression
import src.utils as utils
from src.distilled import logits_probability_distribution
from src.ensemble import ensemble
from src.ensemble import simple_regressor
import src.metrics as metrics

LOGGER = logging.getLogger(__name__)
EXPERIMENT_NAME = "regression_logits"


# Settings
class Args:
    """Plain attribute container for the experiment settings (argparse-like)."""

    def __repr__(self):
        # Show the settings themselves instead of "<Args object at 0x...>",
        # so that logging the object records something useful.
        settings = ", ".join("{}={!r}".format(key, value)
                             for key, value in vars(self).items())
        return "Args({})".format(settings)


args = Args()
args.seed = 1
args.gpu = False
args.log_dir = Path("./logs")
args.log_level = logging.DEBUG
args.retrain = True

args.num_ensemble_members = 10
args.num_epochs = 10
args.lr = 0.01


In [None]:
def make_plots(distilled_model, data):
    """Scatter the data set and overlay the distilled model's predictions.

    Every sample of `data` is passed through ``distilled_model.predict``;
    column 0 of the prediction is drawn as the point estimate and the square
    root of column 1 as a symmetric error bar.

    Args:
        distilled_model: object exposing ``output_size`` and
            ``predict(inputs, t=None)``, which returns a tensor of shape
            (batch, output_size).
        data: map-style dataset yielding ``(input, target)`` pairs and exposing
            ``n_samples``; inputs and targets are assumed to be 1-D per sample.
    """
    test_loader = torch.utils.data.DataLoader(data,
                                              batch_size=16,
                                              shuffle=True,
                                              num_workers=0)

    predictions = np.zeros((data.n_samples, distilled_model.output_size))
    all_x = np.zeros((data.n_samples, 1))
    all_y = np.zeros((data.n_samples, 1))

    # Fill the buffers with a running offset, because the last batch may be
    # shorter than batch_size. Shuffling is harmless: x, y and the prediction
    # of a sample are always written to the same row.
    offset = 0
    with torch.no_grad():
        for inputs, targets in test_loader:
            batch_len = inputs.shape[0]
            rows = slice(offset, offset + batch_len)
            # `predictions` is 2-D (n_samples, output_size), so it takes exactly
            # two indices.
            predictions[rows, :] = \
                distilled_model.predict(inputs, t=None).detach().cpu().numpy()
            all_x[rows, :] = inputs.reshape(batch_len, -1).cpu().numpy()
            all_y[rows, :] = targets.reshape(batch_len, -1).cpu().numpy()
            offset += batch_len

    plt.scatter(np.squeeze(all_x), np.squeeze(all_y), label="Data", marker=".")

    plt.errorbar(np.squeeze(all_x),
                 predictions[:, 0],
                 np.sqrt(predictions[:, 1]),
                 label="Distilled model predictions",
                 marker=".",
                 ls="none")

    plt.legend()
    plt.show()

In [None]:
# Log file "<experiment>_<YYYYmmdd_HHMMSS>.log" under args.log_dir (relative to the cwd).
log_file = Path("{}_{}.log".format(
    EXPERIMENT_NAME,
    datetime.now().strftime("%Y%m%d_%H%M%S")))
utils.setup_logger(log_path=Path.cwd() / args.log_dir / log_file,
                   log_level=args.log_level)
# Log the attribute dict so the actual settings end up in the log,
# independent of how the Args object itself is printed.
LOGGER.info("Args: %s", vars(args))
# Presumably seeds the RNGs and selects CPU/GPU; returns the torch device.
device = utils.torch_settings(args.seed, args.gpu)
LOGGER.info("Creating dataloader")
data = one_dim_regression.SyntheticRegressionData(
    store_file=Path("data/2d_gaussian_1000"))

# Network sizes: 1-D input, two outputs per member (indexed later as the
# predicted mean and variance).
input_size = 1
hidden_size = 5
ensemble_output_size = 2
train_loader = torch.utils.data.DataLoader(data,
                                           batch_size=6,
                                           shuffle=True,
                                           num_workers=0)

# Build the ensemble from separately constructed regressors, attach the
# squared-error metric and train it on the labelled data.
prob_ensemble = ensemble.Ensemble(ensemble_output_size)
for _ in range(args.num_ensemble_members):
    model = simple_regressor.SimpleRegressor(input_size,
                                             hidden_size,
                                             hidden_size,
                                             ensemble_output_size,
                                             device=device,
                                             learning_rate=args.lr)
    prob_ensemble.add_member(model)
squared_error_metric = metrics.Metric(name="Squared error",
                                      function=metrics.squared_error)
prob_ensemble.add_metrics([squared_error_metric])
prob_ensemble.train(train_loader, args.num_epochs)

In [None]:
# Evaluation grid over [start, end); the ensemble is queried on every point.
start = -5
end = 6
step = 0.1
# Let torch decide the number of points and reshape with -1: recomputing the
# length as int((end - start) / step) can disagree with torch.arange by one
# point when the float division rounds (see the torch.arange docs).
x = torch.arange(start=start, end=end, step=step).float().reshape(-1, 1)
x_length = x.shape[0]
output = prob_ensemble.predict(x)
x = x.detach().numpy()[:, 0]

In [None]:
def plot_uncert(data, x, mean_mu, ale, epi, ax, ale_scale=0.2):
    """Plot the data with the ensemble mean and its uncertainty decomposition.

    Args:
        data: 2-D array; all columns but the last are inputs, the last is the target.
        x: grid on which `mean_mu`, `ale` and `epi` are evaluated.
        mean_mu: ensemble-averaged predicted mean on `x`.
        ale: aleatoric part, E_w[sigma_w^2(x)], on `x`.
        epi: epistemic part, var_w(mu_w(x)), on `x`.
        ax: matplotlib Axes to draw into.
        ale_scale: factor applied to `ale` before plotting (default 0.2, the
            previously hard-coded value). The legend label states the factor,
            so the curve is not mistaken for the unscaled variance.
    """
    inputs = data[:, :-1]
    targets = data[:, -1]
    ax.scatter(inputs, targets)
    # Raw strings: "\m" and "\s" are invalid escape sequences in normal literals.
    ax.plot(x, mean_mu, "r-", label=r"$\mu_{avg}(x)$")
    if ale_scale == 1:
        ale_label = r"$E_w[\sigma_w^2(x)]$"
    else:
        ale_label = rf"${ale_scale:g}\,E_w[\sigma_w^2(x)]$"
    ax.plot(x, ale_scale * ale, "g-", label=ale_label)
    ax.plot(x, epi, "b-", label=r"var$_w(\mu_w(x))$")
    # Put the legend on the axes we were given, not pyplot's "current" axes.
    ax.legend(prop={'size': 40})
    plt.show()

# Very large default figure size (30x30 inches). rcParams is global, so every
# figure created afterwards inherits it.
plt.rcParams['figure.figsize'] = [30, 30]

# Per-member predicted mean (index 0) and variance (index 1); presumably shaped
# (grid points, members, 2), with axis 1 averaged over for the ensemble mean.
mu, var = output[:, :, 0], output[:, :, 1]
ale, epi = metrics.uncertainty_separation_parametric(mu, var)
mean_mu, ale, epi = [tensor.detach().numpy()
                     for tensor in (torch.mean(mu, dim=1), ale, epi)]
_, ax = plt.subplots()
plot_uncert(data.get_full_data(), x, mean_mu, ale, epi, ax)


In [None]:

def plot_reg_with_pred(data, x, pred_mean, pred_var, ax):
    """Scatter the data and overlay a predicted mean with +/- 1 std error bars.

    Args:
        data: 2-D array; all columns but the last are inputs, the last is the target.
        x: grid on which the prediction is evaluated.
        pred_mean: predicted mean on `x`.
        pred_var: predicted variance on `x`; its square root is the error bar.
        ax: matplotlib Axes to draw into.
    """
    inputs = data[:, :-1]
    targets = data[:, -1]
    ax.scatter(inputs, targets)
    # Use the arguments, not the notebook globals `means`/`var`, so the
    # function plots what it is given.
    ax.plot(x, pred_mean, 'r-')
    every_nth = 10  # only every 10th error bar, to keep the plot readable
    ax.errorbar(x, pred_mean,
                np.sqrt(pred_var),
                errorevery=every_nth,
                color="r")
    plt.show()

# Inspect one ensemble member (index 9, i.e. the last of the
# args.num_ensemble_members = 10 members) on the same grid.
plt.rcParams['figure.figsize'] = [30, 30]
member = 9
member_pred = output.detach().numpy()[:, member, :]
means, var = member_pred[:, 0], member_pred[:, 1]
_, ax = plt.subplots()
plot_reg_with_pred(data.get_full_data(), x, means, var, ax)


In [None]:
# Split of the same synthetic file loaded with train=False; used as the
# unlabelled data on which the ensemble is distilled into a single model.
unlabelled_data = one_dim_regression.SyntheticRegressionData(
    store_file=Path("data/2d_gaussian_1000"), train=False)
unlabelled_loader = torch.utils.data.DataLoader(unlabelled_data,
                                                batch_size=6,
                                                shuffle=True,
                                                num_workers=0)

# The distilled network has twice as many outputs as an ensemble member;
# presumably a mean and a variance per member output (TODO confirm).
distilled_output_size = 2 * ensemble_output_size
distilled_model = logits_probability_distribution.LogitsProbabilityDistribution(
    input_size,
    hidden_size,
    hidden_size,
    distilled_output_size,
    teacher=prob_ensemble,
    device=device,
    learning_rate=args.lr)
num_distill_epochs = 100
distilled_model.train(unlabelled_loader, num_distill_epochs)

In [None]:
# Re-load the labelled data (same file as the ensemble training cell) for plotting.
data = one_dim_regression.SyntheticRegressionData(
    store_file=Path("data/2d_gaussian_1000"))

# Evaluation grid over [start, end) for the distilled model. Reshape with -1
# so the column length always matches what torch.arange produced.
start = -6
end = 4
step = 0.1
x = torch.arange(start=start, end=end, step=step).float().reshape(-1, 1)
x_length = x.shape[0]

# Inference only: no autograd graph is needed for plotting.
with torch.no_grad():
    z_mean, z_var = distilled_model.forward(x)
x = x.detach().numpy()[:, 0]

mu = z_mean[:, 0].detach().numpy()
# The second mean component is exponentiated, i.e. presumably a log-variance;
# sigma_sq is then the predicted (aleatoric) variance.
sigma_sq = torch.exp(z_mean[:, 1]).detach().numpy()

epi_var = z_var[:, 0].detach().numpy()

def plot_reg_with_pred(data, x, pred_mean, pred_var, ax):
    """Scatter (input, target) pairs and overlay a predicted mean +/- 1 std.

    Args:
        data: 2-D array whose last column is the target; the remaining
            columns are the inputs.
        x: grid on which the prediction is evaluated.
        pred_mean: predicted mean on `x`.
        pred_var: predicted variance on `x`; its square root is the error bar.
        ax: matplotlib Axes to draw into.
    """
    features, labels = data[:, :-1], data[:, -1]
    ax.scatter(features, labels)
    ax.plot(x, pred_mean, 'r-')
    # One error bar every 10 grid points keeps the band readable.
    std = np.sqrt(pred_var)
    ax.errorbar(x, pred_mean, std, errorevery=10, color="r")
    plt.show()

def plot_uncert(ax, data, x, mean_mu=None, ale=None, epi=None):
    """Scatter the data and overlay any of the mean / uncertainty curves given.

    Args:
        ax: matplotlib Axes to draw into.
        data: 2-D array; all columns but the last are inputs, the last is the target.
        x: grid on which the curves are evaluated.
        mean_mu: optional mean prediction on `x` (red).
        ale: optional aleatoric curve, labelled E_w[sigma_w^2(x)] (green).
        epi: optional epistemic curve, labelled var_w(mu_w(x)) (blue).

    A curve passed as None is skipped. The legend is drawn only if at least
    one curve was plotted, which avoids matplotlib's "no artists" warning.
    """
    inputs = data[:, :-1]
    targets = data[:, -1]
    ax.scatter(inputs, targets)
    # Raw strings: "\m" and "\s" are invalid escape sequences in normal literals.
    curves = (
        (mean_mu, "r-", r"$\mu_{avg}(x)$"),
        (ale, "g-", r"$E_w[\sigma_w^2(x)]$"),
        (epi, "b-", r"var$_w(\mu_w(x))$"),
    )
    plotted_any = False
    for values, fmt, label in curves:
        if values is not None:
            ax.plot(x, values, fmt, label=label)
            plotted_any = True
    if plotted_any:
        # Legend on the axes we were given, not pyplot's "current" axes.
        ax.legend(prop={'size': 40})
    plt.show()
    
# Very large default figure size (30x30 inches); rcParams is global, so later
# figures inherit it.
plt.rcParams['figure.figsize'] = [30, 30]
_, ax = plt.subplots()

# The aleatoric curve is drawn at 10% of sigma_sq; note that its legend label
# (E_w[sigma_w^2(x)]) does not mention this scaling.
plot_uncert(ax,
            data.get_full_data(),
            x,
            mean_mu=mu,
            ale=0.1 * sigma_sq,
            epi=epi_var)


In [None]:
# Fine grid over [start, end) to inspect the raw teacher (ensemble) predictions.
# Reshape with -1 so the length always matches what torch.arange produced.
start = -4
end = 4
step = 0.01
x = torch.arange(start=start, end=end, step=step).float().reshape(-1, 1)
x_length = x.shape[0]
# NOTE(review): relies on a private helper of the distilled model.
ensemble_preds = distilled_model._generate_teacher_predictions(x)

# Output index 0 of every prediction (the mean, going by the earlier cells);
# presumably one line per ensemble member.
plt.plot(x.numpy(), ensemble_preds[:, :, 0].detach().numpy());


In [None]:
# Further distillation: call train() again on the same unlabelled loader with
# 300 (presumably epochs), presumably continuing from the current weights;
# confirm in LogitsProbabilityDistribution.train.
distilled_model.train(unlabelled_loader, 300)


In [None]:
# Leftover inspection: display the dataset's `train` attribute (the constructor
# takes a `train` argument, cf. train=False for the unlabelled split).
data.train

In [None]:
# Scratch check of tensor shapes for an outer product.
cov_mat = torch.eye(1)  # 1x1 identity; torch.eye takes the row count as an int
diff = torch.ones(torch.Size([2, 1]))

# Keep the result: (2, 1) @ (1, 2) -> (2, 2).
outer = torch.matmul(diff, diff.T)
diff.shape