In [None]:
# sphinx ignore

# NOTE(review): the marker on the first line presumably tells the documentation build to skip this
# cell - confirm against the docs tooling before moving or rewording it.
import sys

# Make the directory two levels above the working directory importable.
# NOTE(review): presumably this is the repository root, so that the `vanguard` imports resolve
# without an installed package - verify relative to where the notebook is executed.
sys.path.append("../..")

# IPython magic: disable Jedi-based tab completion for this kernel session.
%config Completer.use_jedi = False
# Single seed used to construct every random generator in this notebook.
random_seed = 1_989

In [None]:
import numpy as np
from gpytorch.likelihoods import DirichletClassificationLikelihood
from gpytorch.means import ZeroMean
from matplotlib import pyplot as plt

from vanguard.classification import DirichletMulticlassClassification
from vanguard.classification.kernel import DirichletKernelMulticlassClassification
from vanguard.classification.likelihoods import (
    DirichletKernelClassifierLikelihood,
    GenericExactMarginalLogLikelihood,
)
from vanguard.datasets.classification import MulticlassGaussianClassificationDataset
from vanguard.kernels import ScaledRBFKernel
from vanguard.vanilla import GaussianGPController

In [None]:
# Compare a Gamma(alpha_i, 1) distribution with its moment-matched lognormal approximation.
# Matching the mean (alpha_i) and the variance (alpha_i) of the gamma gives
#   sigma_i^2 = log(1 / alpha_i + 1)
#   mu_i      = log(alpha_i) - sigma_i^2 / 2
alpha_i = 0.86858729
sigma_squared_i = np.log(1 / alpha_i + 1)
# The variance correction is subtracted *outside* the logarithm; taking the log of
# (alpha_i - sigma_i^2 / 2) instead would give a lognormal whose mean is not alpha_i.
mu_i = np.log(alpha_i) - sigma_squared_i / 2

n_samples = 10_000

# default_rng(seed) builds a PCG64-backed Generator, matching how the other cells create their rng.
random_generator = np.random.default_rng(random_seed)
gamma_samples = random_generator.gamma(shape=alpha_i, scale=1.0, size=n_samples)
# `lognormal` takes the standard deviation of the underlying normal, hence the square root.
lognormal_samples = random_generator.lognormal(mean=mu_i, sigma=np.sqrt(sigma_squared_i), size=n_samples)

plt.figure(figsize=(10, 5))
n_bins = 150
plt.hist(gamma_samples, bins=n_bins, density=True, alpha=0.6, label="gamma")
plt.hist(lognormal_samples, bins=n_bins, density=True, alpha=0.6, label="lognormal")
# Cap the x-axis so that rare large samples do not squash the bulk of both histograms.
plt.xlim(right=8)
plt.xlabel("sample value")
plt.ylabel("density")
plt.title("Gamma samples vs. moment-matched lognormal samples")
plt.legend()
plt.show()

In [None]:
# Number of class labels; reused below by the classification decorators and the batch shapes.
NUM_CLASSES = 4

# Synthetic multiclass dataset, seeded from `random_seed` so repeated runs draw the same points.
DATASET = MulticlassGaussianClassificationDataset(
    num_classes=NUM_CLASSES,
    num_train_points=100,
    num_test_points=500,
    covariance_scale=1,
    rng=np.random.default_rng(random_seed),
)

In [None]:
# Open a square 8x8 inch figure before plotting the dataset.
# NOTE(review): DATASET.plot() presumably draws onto the current pyplot figure - confirm.
plt.figure(figsize=(8, 8))
DATASET.plot()
plt.show()

In [None]:
@DirichletMulticlassClassification(num_classes=NUM_CLASSES, ignore_methods=("__init__",))
class MulticlassGaussianClassifier(GaussianGPController):
    """
    Gaussian GP controller decorated with ``DirichletMulticlassClassification`` for ``NUM_CLASSES`` classes.

    The class body adds nothing of its own; all behaviour comes from the base controller and the decorator.
    """

In [None]:
# Instantiate the Dirichlet multiclass controller on the training split.
controller = MulticlassGaussianClassifier(
    DATASET.train_x,
    DATASET.train_y,
    ScaledRBFKernel,
    y_std=0,
    # Mean and kernel both receive a batch shape of (NUM_CLASSES,).
    # NOTE(review): presumably one latent GP per class - confirm.
    mean_class=ZeroMean,
    mean_kwargs={"batch_shape": (NUM_CLASSES,)},
    kernel_kwargs={"batch_shape": (NUM_CLASSES,)},
    # Likelihood and its settings.
    likelihood_class=DirichletClassificationLikelihood,
    likelihood_kwargs={"alpha_epsilon": 0.3, "learn_additional_noise": True},
    # Optimiser settings and a seeded generator for reproducibility.
    optim_kwargs={"lr": 0.05},
    rng=np.random.default_rng(random_seed),
)

In [None]:
# Classify the test points with the controller as constructed (no call to `fit` has been made yet).
# The call returns a pair, unpacked here into `predictions` and `probs`.
predictions, probs = controller.classify_points(DATASET.test_x)

In [None]:
# Plot the predictions obtained before fitting, as a baseline for the fitted result.
plt.figure(figsize=(8, 8))
DATASET.plot_prediction(predictions)
plt.show()

In [None]:
# Fit the controller.
# NOTE(review): 100 is presumably the number of optimisation iterations - confirm.
controller.fit(100)
# Re-classify the test points so that `predictions` and `probs` come from the fitted controller.
predictions, probs = controller.classify_points(DATASET.test_x)

In [None]:
# Plot the predictions recomputed after fitting.
plt.figure(figsize=(8, 8))
DATASET.plot_prediction(predictions)
plt.show()

In [None]:
# This deliberately reuses the name of the earlier classifier class; from here on
# `MulticlassGaussianClassifier` refers to the Dirichlet-kernel variant.
@DirichletKernelMulticlassClassification(num_classes=NUM_CLASSES, ignore_methods=("__init__",))
class MulticlassGaussianClassifier(GaussianGPController):
    """
    Gaussian GP controller decorated with ``DirichletKernelMulticlassClassification`` for ``NUM_CLASSES`` classes.

    The class body adds nothing of its own; all behaviour comes from the base controller and the decorator.
    """

In [None]:
# Instantiate the Dirichlet-kernel controller with `learn_alpha` switched off.
# NOTE(review): presumably this keeps `alpha` at the supplied value of 5 during fitting - confirm.
controller = MulticlassGaussianClassifier(
    DATASET.train_x,
    DATASET.train_y,
    y_std=0,
    # Model components.
    kernel_class=ScaledRBFKernel,
    mean_class=ZeroMean,
    likelihood_class=DirichletKernelClassifierLikelihood,
    marginal_log_likelihood_class=GenericExactMarginalLogLikelihood,
    # Settings for the likelihood and the optimiser, plus a seeded generator for reproducibility.
    likelihood_kwargs={"learn_alpha": False, "alpha": 5},
    optim_kwargs={"lr": 0.1, "early_stop_patience": 5},
    rng=np.random.default_rng(random_seed),
)

In [None]:
# Classify the test points with the newly constructed, not-yet-fitted controller before plotting.
# Without this step `predictions` would still hold the output of the previous controller, so the
# figure would not show the model that was just built.
predictions, probs = controller.classify_points(DATASET.test_x)

plt.figure(figsize=(8, 8))
DATASET.plot_prediction(predictions)
plt.show()

In [None]:
# Fit inside the metrics tracker's `print_metrics` context manager.
# NOTE(review): presumably this prints the tracked metrics every 25 iterations of the 100 requested - confirm.
with controller.metrics_tracker.print_metrics(every=25):
    controller.fit(100)

In [None]:
# Classify the test points with the fitted controller, replacing the earlier `predictions` and `probs`.
predictions, probs = controller.classify_points(DATASET.test_x)

In [None]:
# Plot the predictions of the fitted controller.
plt.figure(figsize=(8, 8))
DATASET.plot_prediction(predictions)
plt.show()

In [None]:
# Instantiate the Dirichlet-kernel controller again, this time with `learn_alpha` switched on.
# NOTE(review): presumably `alpha` then starts at 5 and is optimised during fitting - confirm.
controller = MulticlassGaussianClassifier(
    DATASET.train_x,
    DATASET.train_y,
    y_std=0,
    # Model components.
    kernel_class=ScaledRBFKernel,
    mean_class=ZeroMean,
    likelihood_class=DirichletKernelClassifierLikelihood,
    marginal_log_likelihood_class=GenericExactMarginalLogLikelihood,
    # Settings for the likelihood and the optimiser, plus a seeded generator for reproducibility.
    likelihood_kwargs={"learn_alpha": True, "alpha": 5},
    optim_kwargs={"lr": 0.1, "early_stop_patience": 5},
    rng=np.random.default_rng(random_seed),
)

In [None]:
# Classify the test points with the newly constructed, not-yet-fitted controller before plotting.
# Without this step `predictions` would still hold the output of the previous controller, so the
# figure would not show the model that was just built.
predictions, probs = controller.classify_points(DATASET.test_x)

plt.figure(figsize=(8, 8))
DATASET.plot_prediction(predictions)
plt.show()

In [None]:
# Fit inside the metrics tracker's `print_metrics` context manager.
# NOTE(review): presumably this prints the tracked metrics every 25 iterations of the 100 requested - confirm.
with controller.metrics_tracker.print_metrics(every=25):
    controller.fit(100)

In [None]:
# Classify the test points with the fitted controller, replacing the earlier `predictions` and `probs`.
predictions, probs = controller.classify_points(DATASET.test_x)

In [None]:
# Plot the predictions of the fitted controller.
plt.figure(figsize=(8, 8))
DATASET.plot_prediction(predictions)
plt.show()

In [None]:
# Plot a confusion matrix for the most recent predictions.
# NOTE(review): presumably compared against the dataset's held-out test labels - confirm.
plt.figure(figsize=(8, 8))
DATASET.plot_confusion_matrix(predictions)
plt.show()