# Bootstrapped regression

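Reproduces the bootstrapped neural-network regression experiment from the [**Bootstrapped DQN**](https://papers.nips.cc/paper/6501-deep-exploration-via-bootstrapped-dqn.pdf) paper: an ensemble of small networks is trained, each on its own bootstrap resample of the observations, and the spread of their predictions is plotted as an uncertainty estimate.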

In [1]:
import numpy as np
import torch
from torch.distributions import Uniform, Normal
import pandas as pd

import altair as alt
alt.renderers.enable('notebook')

import time
from bokeh.io import push_notebook, show, output_notebook
from bokeh.plotting import figure
from bokeh.models import Band, ColumnDataSource
output_notebook()

In [2]:
def f(x, α=4, β=13, μ=0, σ=0.03, noisy=True):
    """Evaluate ``x + sin(α·(x + w)) + sin(β·(x + w)) + w``.

    Args:
        x: Tensor of input locations (any shape).
        α: Frequency of the first sine term.
        β: Frequency of the second sine term.
        μ: Mean of the Gaussian noise ``w``.
        σ: Standard deviation of the Gaussian noise ``w``.
        noisy: If False, ``w`` is 0 and the result is deterministic.

    Returns:
        Tensor with the same shape as ``x``.
    """
    w = 0
    if noisy:
        # Draw one independent noise value per element of x. Sampling with the
        # full shape of x (rather than only its first dimension) keeps w
        # aligned with x for inputs of any rank; a (N, 1) input would otherwise
        # broadcast against (N,) noise and yield a (N, N) result.
        w = Normal(μ, σ).sample(x.shape)
    return x + torch.sin(α * (x + w)) + torch.sin(β * (x + w)) + w

In [3]:
def generate_dataset():
    """Build one DataFrame holding the true curve, the training set and the test set.

    The frame has columns ``x``, ``y`` and ``series``; ``series`` is one of
    ``'f(x)'`` (noise-free curve on 100 evenly spaced points in [-1, 2]),
    ``'obs'`` (35 noisy training points) or ``'test'`` (35 noisy test points).

    Returns:
        pandas.DataFrame with a fresh 0..n-1 index over all rows.
    """
    columns = ['x', 'y', 'series']

    def as_frame(xs, ys, label):
        # One row per point, tagged with the name of the series it belongs to.
        rows = list(zip(xs.numpy(), ys.numpy(), [label] * xs.shape[0]))
        return pd.DataFrame(rows, columns=columns)

    # the function that generates the data
    x = torch.linspace(-1, 2, 100)
    true_data = as_frame(x, f(x, noisy=False), 'f(x)')

    # the training dataset
    x = torch.cat((Uniform(0, 0.6).sample((24,)), Uniform(0.8, 1).sample((11,))))
    obs_data = as_frame(x, f(x, noisy=True), 'obs')

    # the test dataset (its second cluster is drawn from [0.9, 1.1], not [0.8, 1])
    x = torch.cat((Uniform(0, 0.6).sample((24,)), Uniform(0.9, 1.1).sample((11,))))
    test_data = as_frame(x, f(x, noisy=True), 'test')

    # DataFrame.append was removed in pandas 2.0; a single concat also avoids
    # growing the frame step by step and gives one consistent index.
    return pd.concat([true_data, obs_data, test_data], ignore_index=True)

In [4]:
# Build the combined frame (true curve + training + test points) once.
# NOTE(review): no random seed is set in the cells shown here, so the sampled
# points differ between runs.
data = generate_dataset()

In [215]:
# Split the combined frame by series so that each one gets its own glyph.
_series = data['series']
true_data = data.loc[_series == 'f(x)']
obs_data = data.loc[_series == 'obs']
tst_data = data.loc[_series == 'test']


data_plt = figure(
    title="Data",
    plot_height=400,
    plot_width=900,
    background_fill_color='#ffffff',
)

# Every glyph reads the same columns and is labelled by its 'series' value.
_columns = dict(x='x', y='y', legend='series')

# Glyphs are added in this order: true curve, test points, training points.
fx = data_plt.line(source=true_data, color="mediumspringgreen",
                   line_width=3, alpha=0.8, **_columns)
tst = data_plt.scatter(source=tst_data, size=7, alpha=0.8,
                       fill_color="indianred", line_color='indianred', **_columns)
obs = data_plt.scatter(source=obs_data, size=10, alpha=0.8,
                       fill_color="slateblue", line_color='slateblue', **_columns)
data_plt.legend.location = "top_left"

In [216]:
# Render the data figure inline in the notebook.
show(data_plt, notebook_handle=True)

## Create a custom dataset that bootstraps the data

In [75]:
from torch.utils.data import DataLoader, Dataset

In [76]:
class BootstrappedLoader(Dataset):
    """Dataset that serves B bootstrap resamples of (features, targets) in parallel.

    Each of the B resamples has the same length as the original data and is
    drawn once, at construction time, by sampling indices with replacement.
    Indexing the dataset returns the ``idx``-th pair from every resample.

    Args:
        features: Array of inputs, indexed along its first dimension. Lists and
            tuples are converted to numpy arrays.
        targets: Array of targets with the same length as ``features``.
        B: Number of bootstrap resamples to draw.

    Raises:
        ValueError: If ``features`` and ``targets`` differ in length.
    """

    def __init__(self, features, targets, B=10):
        # Plain Python sequences do not support indexing with an index array,
        # so turn them into numpy arrays; other inputs are used as given.
        if isinstance(features, (list, tuple)):
            features = np.asarray(features)
        if isinstance(targets, (list, tuple)):
            targets = np.asarray(targets)
        if len(features) != len(targets):
            raise ValueError(
                f'features and targets must have the same length, '
                f'got {len(features)} and {len(targets)}')

        # Generate B datasets by sampling with replacement
        # from the original data
        self.__data_sz = N = features.shape[0]
        masks = [np.random.randint(N, size=(N,)) for _ in range(B)]
        # The same index array is applied to features and targets so that
        # every (feature, target) pair stays intact.
        self.__boot_dsets = [(features[m], targets[m]) for m in masks]

    def __getitem__(self, idx):
        """Return a list of B ``(feature, target)`` pairs, one per resample."""
        samples = [(boot_ds[0][idx], boot_ds[1][idx]) for boot_ds in self.__boot_dsets]
        return samples

    def __len__(self):
        """Return the size of the original data (and of every resample)."""
        return self.__data_sz

In [77]:
def _collate(samples):
    features_batches = [torch.tensor([[el[0]] for el in batch]) for batch in zip(*samples)]
    target_batches = [torch.tensor([[el[1]] for el in batch]) for batch in zip(*samples)]
    
    return zip(features_batches, target_batches)

## Create ensemble and train it

In [166]:
from torch import nn
from torch import optim
import torch.nn.functional as F

In [197]:
# Hyper-parameters of the experiment.
B = 50  # number of bootstrapped samples (one dataset and one model per sample)
batch_sz = 10
epochs = 6000

# Noisy observations are used for training, the 'test' series for evaluation.
_series = data['series']
train_data = data.loc[_series == 'obs']
test_data = data.loc[_series == 'test']

# Show a few random rows from the two sampled series.
data.loc[_series.isin(['obs', 'test'])].sample(5)

Unnamed: 0,x,y,series
152,0.140042,1.48733,test
125,0.982207,1.029302,obs
151,0.432964,0.534979,test
129,0.901366,-0.109248,obs
144,0.112636,1.182059,test


In [198]:
# One small 1 -> 20 -> 1 ReLU network per bootstrapped dataset.
ensemble = [
    nn.Sequential(
        nn.Linear(1, 20, bias=True),
        nn.ReLU(inplace=True),
        nn.Linear(20, 1, bias=True),
    )
    for _ in range(B)
]

# Each network gets its own SGD optimizer.
optims = [optim.SGD(net.parameters(), lr=0.1) for net in ensemble]

# Resample the training observations B times and batch them; _collate turns
# each batch into one (features, targets) tensor pair per resample.
boot_ds = BootstrappedLoader(
    features=train_data['x'].values,
    targets=train_data['y'].values,
    B=B,
)
train_loader = DataLoader(boot_ds, batch_size=batch_sz, collate_fn=_collate)

In [199]:
def train(train_loader, test_data, ensemble, optims, epochs=100, verbose=True, update_plot=False, plt_data=None):
    """Train every ensemble member on its own batches and evaluate after each epoch.

    Args:
        train_loader: Iterable yielding, per step, an iterable of ``(x, t)``
            tensor pairs; the i-th pair is fed to the i-th model.
        test_data: DataFrame with ``x`` and ``y`` columns used for evaluation.
        ensemble: Sequence of models.
        optims: Sequence of optimizers, one per model, in the same order.
        epochs: Number of passes over ``train_loader``.
        verbose: If True, print the test losses after every epoch.
        update_plot: Optional callable invoked after every epoch as
            ``update_plot(x_axis, y_axis, std_axis, std_var_axis)``.
        plt_data: Optional list of four arrays (epochs, ensemble loss, mean
            prediction std, std of the prediction std) that the new values are
            appended to. When omitted, four empty arrays are used.

    Returns:
        ``[x_axis, y_axis, std_axis, std_var_axis]`` when ``update_plot`` is
        given, otherwise ``None``.
    """
    if update_plot:
        if plt_data is None:
            # Start the histories empty instead of failing to unpack None.
            plt_data = [np.array([]) for _ in range(4)]
        x_axis, y_axis, std_axis, std_var_axis = plt_data

    x_tst = torch.from_numpy(test_data['x'].values).unsqueeze(1).float()
    t_tst = torch.from_numpy(test_data['y'].values).unsqueeze(1).float()

    for epoch in range(epochs):
        for boot_batches in train_loader:
            # Each model takes one SGD step on the batch from its own resample.
            for (x, t), model, sgd in zip(boot_batches, ensemble, optims):
                sgd.zero_grad()

                y = model(x)
                loss = F.mse_loss(y, t)
                loss.backward()
                sgd.step()

        with torch.no_grad():
            predictions = [model(x_tst) for model in ensemble]
            # One column per model; concatenated once and reused below.
            stacked = torch.cat(predictions, 1)
            point_std = stacked.std(1)

            # Loss of the averaged prediction vs. loss of each single model.
            ensemble_loss = F.mse_loss(stacked.mean(1).unsqueeze(1), t_tst).item()
            losses = [F.mse_loss(pred, t_tst).item() for pred in predictions]

            # Disagreement between models per test point: its mean and spread.
            std = point_std.mean().item()
            std_var = point_std.std().item()

        if verbose:
            print(f'epoch {epoch:02d} done, testing...')
            print(f'  | loss={ensemble_loss:2.3f}',
                  f'  |  mean_loss={np.mean(losses):2.3f}, std_loss={np.std(losses):2.3f}')

        if update_plot:
            x_axis = np.append(x_axis, [epoch])
            y_axis = np.append(y_axis, [ensemble_loss])
            std_axis = np.append(std_axis, [std])
            std_var_axis = np.append(std_var_axis, [std_var])
            update_plot(x_axis, y_axis, std_axis, std_var_axis)
    if update_plot:
        return [x_axis, y_axis, std_axis, std_var_axis]

## Visualize the training process

In [205]:
# Empty histories for the live plot: epoch, test loss, mean prediction std and
# the std of the prediction std. Four separate arrays, one per series.
x, y, std, std_var = (np.array([]) for _ in range(4))
plt_data = [x, y, std, std_var]

train_plt = figure(
    title="Ensemble Loss / Mean Model Uncertainty",
    plot_height=400,
    plot_width=800,
    x_range=(0, epochs),
    background_fill_color='whitesmoke',
)

# All three curves share the same line styling.
_line_style = dict(line_width=2.5, alpha=0.9)

std_plt = train_plt.line(x, std, color="dodgerblue", legend='test uncertainty', **_line_style)
std_var_plt = train_plt.line(x, std_var, color="MediumPurple", legend='var uncertainty', **_line_style)
loss_plt = train_plt.line(x, y, color="orangered", legend='test loss', **_line_style)
train_plt.legend.location = "top_left"


def update(x, y, std, std_var):
    """Replace the data of the three training curves and push it to the notebook.

    ``x`` is the shared horizontal axis; ``y``, ``std`` and ``std_var`` are the
    vertical values for the loss, uncertainty and uncertainty-spread curves.
    """
    curves = (
        (loss_plt, y),
        (std_plt, std),
        (std_var_plt, std_var),
    )
    for renderer, values in curves:
        renderer.data_source.data.update(x=x, y=values)
    push_notebook()

In [206]:
# Render the (initially empty) training figure inline.
# NOTE(review): update() calls push_notebook() without an explicit handle, so
# this presumably has to be the most recently shown figure when training
# starts - confirm against the Bokeh docs.
show(train_plt, notebook_handle=True)

In [202]:
# Train the whole ensemble; `update` is called after every epoch with the
# metric histories, and the final histories are kept in plt_data.
plt_data = train(train_loader, test_data, ensemble, optims, epochs=epochs,
                 verbose=False, update_plot=update, plt_data=plt_data)

In [207]:
def get_predictions(ensemble, x):
    """Return the ensemble's mean prediction and its spread at inputs ``x``.

    Every model is evaluated on ``x`` without tracking gradients and the
    outputs are concatenated along dimension 1 (one column per model for
    (N, 1) outputs). The mean and standard deviation are taken over that
    dimension.
    """
    with torch.no_grad():
        stacked = torch.cat([member(x) for member in ensemble], 1)
        mean = stacked.mean(1)
        spread = stacked.std(1)
    return mean, spread

In [208]:
# Evaluate the ensemble on a dense grid over [-1, 2].
x_axis = torch.arange(-1, 2.01, 1/33).unsqueeze(1).float()
y_axis, std = get_predictions(ensemble, x_axis)

# Flatten everything to 1-D numpy arrays for plotting.
x_axis, y_axis, std = (t.squeeze().numpy() for t in (x_axis, y_axis, std))

x_obs, y_obs = obs_data['x'].values, obs_data['y'].values
x_tst, y_tst = tst_data['x'].values, tst_data['y'].values

# Band of one standard deviation around the ensemble mean.
df = pd.DataFrame({
    'x': x_axis,
    'lower': y_axis - std,
    'upper': y_axis + std,
}).sort_values(by="x")
source = ColumnDataSource(df.reset_index())

data_plt = figure(
    title="Data",
    plot_height=400,
    plot_width=900,
    y_range=(-1, 5),
    x_range=(-1, 2),
    background_fill_color='#ffffff',
)

# Glyphs are added in this order: ensemble mean, training points, test points.
fx = data_plt.line(x_axis, y_axis, legend='model(x)',
                   color="mediumspringgreen", line_width=3, alpha=0.8)
obs = data_plt.scatter(x_obs, y_obs, legend='observed', size=10, alpha=0.8,
                       fill_color="slateblue", line_color="slateblue")
tst = data_plt.scatter(x_tst, y_tst, legend='test data', size=7, alpha=0.8,
                       fill_color="indianred", line_color="indianred")

band = Band(
    base='x', lower='lower', upper='upper', source=source, level='underlay',
    fill_alpha=0.3, fill_color='MediumAquamarine', line_width=0, line_color='#ffffff',
)
data_plt.add_layout(band)

In [209]:
# Render the prediction figure (ensemble mean, uncertainty band, data points).
show(data_plt, notebook_handle=True)