In [None]:
# sphinx ignore

import sys

# Make the in-repo ``vanguard`` package importable without installing it.
# NOTE(review): "../.." is resolved relative to the kernel's working directory,
# so this presumably assumes Jupyter was started from this notebook's folder,
# two levels below the repository root — confirm.
sys.path.append("../..")

# Automatically reload imported modules before each cell runs, so edits to the
# library are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# sphinx ignore

import warnings

warnings.filterwarnings("ignore", category=UserWarning, module="gpytorch.utils.linear_cg", lineno=234)
warnings.filterwarnings("ignore", category=UserWarning, module="gpytorch.utils.linear_cg", lineno=266)

In [None]:
import time
from contextlib import contextmanager

import matplotlib.pyplot as plt
from gpytorch.kernels import MaternKernel, ScaleKernel

from vanguard.datasets.bike import BikeDataset
from vanguard.distribute import Distributed, aggregators, partitioners
from vanguard.vanilla import GaussianGPController
from vanguard.warps import SetWarp, warpfunctions

In [None]:
# Bike-sharing data set configuration: number of samples, fraction used for
# training (the remainder presumably forms the test split), and noise scale.
# NOTE(review): if BikeDataset subsamples or adds noise randomly, results will
# differ between runs unless a seed is fixed — confirm.
N_SAMPLES = 5000
TRAINING_PROPORTION = 0.9
NOISE_SCALE = 0.01

DATASET = BikeDataset(
    n_samples=N_SAMPLES,
    training_proportion=TRAINING_PROPORTION,
    noise_scale=NOISE_SCALE,
)

In [None]:
# Visualise the data set's y values using the data set's own plotting helper,
# which presumably draws onto the current matplotlib figure created here.
plt.figure(figsize=(10, 5))
DATASET.plot_y()
plt.show()

In [None]:
class ScaledMaternKernel(ScaleKernel):
    """A scaled Matern kernel with one lengthscale per input dimension (ARD).

    With no arguments this is exactly the kernel used throughout this notebook:
    a Matern kernel with ``nu=1.5`` whose number of ARD dimensions is taken
    from the training inputs of the global ``DATASET``.

    :param nu: Smoothness parameter forwarded to :class:`MaternKernel`.
    :param ard_num_dims: Number of input dimensions, each of which gets its own
        lengthscale. If ``None`` (the default), ``DATASET.train_x.shape[1]`` is used.
    """

    def __init__(self, nu=1.5, ard_num_dims=None):
        if ard_num_dims is None:
            # Fall back to the notebook's data set, read at construction time so
            # the kernel matches whatever DATASET is current when a controller is built.
            ard_num_dims = DATASET.train_x.shape[1]
        super().__init__(MaternKernel(nu=nu, ard_num_dims=ard_num_dims))

In [None]:
# Baseline Gaussian GP controller trained on the full training set.
# ``learn_additional_noise`` is forwarded to the likelihood; presumably it lets
# the model learn extra noise on top of the supplied y_std — confirm in vanguard docs.
controller_kwargs = dict(
    train_x=DATASET.train_x,
    train_y=DATASET.train_y,
    kernel_class=ScaledMaternKernel,
    y_std=DATASET.train_y_std,
    likelihood_kwargs={"learn_additional_noise": True},
)
gp = GaussianGPController(**controller_kwargs)

In [None]:
@contextmanager
def timer(label="Time taken"):
    """Print the wall-clock time spent inside a ``with`` block.

    The elapsed time is reported even if the block raises, so a failing
    ``fit`` call still shows how long it ran before the error propagated.

    :param label: Text printed before the elapsed time. Defaults to
        ``"Time taken"``, so existing ``with timer():`` calls print exactly
        the same message as before.
    """
    # perf_counter is monotonic and high-resolution; time.time() can jump
    # backwards or forwards if the system clock is adjusted mid-measurement.
    start = time.perf_counter()
    try:
        yield
    finally:
        elapsed = time.perf_counter() - start
        print(f"{label}: {elapsed:.3f}s")

In [None]:
# Train the baseline GP for 100 SGD iterations and print how long it took.
# The return value of fit is kept as ``loss`` but is not used further below.
with timer():
    loss = gp.fit(n_sgd_iters=100)

In [None]:
# Compute the posterior at the test inputs and plot it via the data set's helper.
posterior = gp.posterior_over_point(DATASET.test_x)

plt.figure(figsize=(10, 5))
# confidence_interval() returns several values that are unpacked as positional
# arguments; presumably (mean, lower, upper) — confirm against vanguard docs.
DATASET.plot_prediction(*posterior.confidence_interval())
plt.show()

In [None]:
# Output warp built by composing two warp functions with ``@``.
# NOTE(review): presumably the affine part is y -> 3*y - 1 and the Box-Cox
# part uses lambda = 0.2, with the right-hand operand applied first — confirm.
affine_warp = warpfunctions.AffineWarpFunction(a=3, b=-1)
box_cox_warp = warpfunctions.BoxCoxWarpFunction(0.2)
warp = affine_warp @ box_cox_warp


@SetWarp(warp, ignore_all=True)
class WarpedGPController(GaussianGPController):
    """GaussianGPController with the composed ``warp`` applied by SetWarp."""

In [None]:
# Same configuration as the baseline GP, but using the warped controller.
controller_kwargs = dict(
    train_x=DATASET.train_x,
    train_y=DATASET.train_y,
    kernel_class=ScaledMaternKernel,
    y_std=DATASET.train_y_std,
    likelihood_kwargs={"learn_additional_noise": True},
)
gp = WarpedGPController(**controller_kwargs)

In [None]:
# Train the warped GP with the same number of SGD iterations as the baseline,
# so the printed timings are comparable. ``loss`` is not used further below.
with timer():
    loss = gp.fit(n_sgd_iters=100)

In [None]:
# Posterior of the warped GP at the test inputs, plotted like the baseline.
posterior = gp.posterior_over_point(DATASET.test_x)

plt.figure(figsize=(10, 5))
# The unpacked values are presumably (mean, lower, upper) — confirm.
DATASET.plot_prediction(*posterior.confidence_interval())
plt.show()

In [None]:
# Illustrate a k-means partition of the training inputs. Use the same number of
# experts (4) as N_EXPERTS, which every distributed controller in this notebook
# uses, so the plot matches the expert count of the models trained below.
# NOTE(review): the controllers partition their own data subset internally, so
# this plot is illustrative rather than the exact partition they use.
partitioner = partitioners.KMeansPartitioner(DATASET.train_x, n_experts=4)
partition = partitioner.create_partition()

In [None]:
# Draw the k-means partition created above using the partitioner's own helper.
plt.figure(figsize=(8, 8))
partitioner.plot_partition(partition)
plt.show()

In [None]:
# Number of experts used by the distributed controllers below; each is also
# given subset_fraction = 1 / N_EXPERTS.
N_EXPERTS = 4

In [None]:
# Distributed GP: N_EXPERTS sub-models built with a k-means partitioner and
# combined with the XGRBCM aggregator.
# NOTE(review): ignore_methods presumably suppresses the decorator's checks for
# the listed methods (here only __init__) — confirm in vanguard's Decorator docs.
@Distributed(
    n_experts=N_EXPERTS,
    subset_fraction=1 / N_EXPERTS,
    aggregator_class=aggregators.XGRBCMAggregator,
    partitioner_class=partitioners.KMeansPartitioner,
    ignore_methods=("__init__",),
)
class DistributedGPController(GaussianGPController):
    """GaussianGPController distributed across N_EXPERTS experts."""

    pass

In [None]:
# Distributed controller with the same kernel and likelihood settings.
# NOTE(review): unlike the non-distributed controllers above, only the first
# element of train_y_std is passed, i.e. one noise level for all points.
# Presumably the distributed controller needs a scalar y_std — confirm.
controller_kwargs = dict(
    train_x=DATASET.train_x,
    train_y=DATASET.train_y,
    kernel_class=ScaledMaternKernel,
    y_std=DATASET.train_y_std[0],
    likelihood_kwargs={"learn_additional_noise": True},
)
gp = DistributedGPController(**controller_kwargs)

In [None]:
# Train the distributed GP for 100 SGD iterations; compare the printed time
# with the single-model runs above. ``loss`` is not used further below.
with timer():
    loss = gp.fit(n_sgd_iters=100)

In [None]:
# Aggregated posterior of the distributed GP at the test inputs.
posterior = gp.posterior_over_point(DATASET.test_x)

plt.figure(figsize=(10, 5))
# The unpacked values are presumably (mean, lower, upper) — confirm.
DATASET.plot_prediction(*posterior.confidence_interval())
plt.show()

In [None]:
# Build a fresh warp with the same settings as before instead of reusing the
# earlier ``warp`` object.
# NOTE(review): presumably this keeps warp parameters from being shared with the
# WarpedGPController trained above — confirm whether SetWarp copies the warp.
affine_warp = warpfunctions.AffineWarpFunction(a=3, b=-1)
box_cox_warp = warpfunctions.BoxCoxWarpFunction(0.2)
warp = affine_warp @ box_cox_warp


# Decorators apply bottom-up: SetWarp wraps the controller first, then
# Distributed wraps the warped controller.
@Distributed(
    n_experts=N_EXPERTS,
    subset_fraction=1 / N_EXPERTS,
    aggregator_class=aggregators.XGRBCMAggregator,
    partitioner_class=partitioners.KMeansPartitioner,
    ignore_all=True,
)
@SetWarp(warp, ignore_all=True)
class WarpedDistributedGPController(GaussianGPController):
    """Warped GaussianGPController distributed across N_EXPERTS experts."""

In [None]:
# Warped distributed controller; as for the plain distributed controller, only
# the first element of train_y_std is passed (a scalar noise level).
# NOTE(review): confirm this scalar y_std is intended.
controller_kwargs = dict(
    train_x=DATASET.train_x,
    train_y=DATASET.train_y,
    kernel_class=ScaledMaternKernel,
    y_std=DATASET.train_y_std[0],
    likelihood_kwargs={"learn_additional_noise": True},
)
gp = WarpedDistributedGPController(**controller_kwargs)

In [None]:
# Train the warped distributed GP for 100 SGD iterations and print the time.
# ``loss`` is not used further below.
with timer():
    loss = gp.fit(n_sgd_iters=100)

In [None]:
# Posterior of the warped distributed GP at the test inputs.
posterior = gp.posterior_over_point(DATASET.test_x)

plt.figure(figsize=(10, 5))
# The unpacked values are presumably (mean, lower, upper) — confirm.
DATASET.plot_prediction(*posterior.confidence_interval())
plt.show()