In [None]:
# sphinx ignore

import sys

sys.path.append("../..")

%config Completer.use_jedi = False

In [None]:
import numpy as np
import torch
from gpytorch.mlls import VariationalELBO
from matplotlib import pyplot as plt

from vanguard.classification import CategoricalClassification
from vanguard.classification.likelihoods import MultitaskBernoulliLikelihood, SoftmaxLikelihood
from vanguard.datasets.classification import MulticlassGaussianClassificationDataset
from vanguard.kernels import ScaledRBFKernel
from vanguard.multitask import Multitask
from vanguard.vanilla import GaussianGPController
from vanguard.variational import VariationalInference

In [None]:
# Fix the global random seeds so the synthetic dataset and the training runs below are
# reproducible under "Restart & Run All".
# NOTE(review): assumes the dataset generator and the vanguard controllers draw from the
# global NumPy / PyTorch RNGs — confirm against the vanguard version in use.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

NUM_CLASSES = 4

DATASET = MulticlassGaussianClassificationDataset(
    num_train_points=100, num_test_points=500, num_classes=NUM_CLASSES, covariance_scale=1
)

In [None]:
# Plot the synthetic dataset before any model is fitted.
plt.figure(figsize=(8, 8))
DATASET.plot()
# Figure-level title so this plot can be told apart from the prediction plots; suptitle does
# not replace any axes title drawn by ``DATASET.plot``.
plt.suptitle("Multiclass Gaussian dataset")
plt.show()

In [None]:
@CategoricalClassification(
    num_classes=NUM_CLASSES,
    ignore_all=True,
)
@Multitask(
    num_tasks=NUM_CLASSES,
    ignore_all=True,
)
@VariationalInference(ignore_all=True)
class CategoricalMultitaskClassifier(GaussianGPController):
    """Categorical classifier built on a multitask variational GP with one task per class.

    The class body adds nothing: all behaviour comes from the vanguard decorators, which
    Python applies bottom-up (``VariationalInference`` innermost, ``CategoricalClassification``
    outermost) on top of ``GaussianGPController``.
    """

In [None]:
# Build the multitask classifier on the one-hot training labels (presumably one column per
# class, matching ``num_tasks=NUM_CLASSES``). Keyword arguments are order-independent.
controller = CategoricalMultitaskClassifier(
    DATASET.train_x,
    DATASET.one_hot_train_y,
    ScaledRBFKernel,
    marginal_log_likelihood_class=VariationalELBO,
    likelihood_class=MultitaskBernoulliLikelihood,
    y_std=0,
)

In [None]:
# Baseline: classify the test points with the freshly constructed, not-yet-fitted controller.
# ``probs`` (presumably per-class probabilities) is not plotted.
predictions, probs = controller.classify_points(DATASET.test_x)

In [None]:
# Predicted classes of the untrained multitask model on the test points.
plt.figure(figsize=(8, 8))
DATASET.plot_prediction(predictions)
# Figure-level title distinguishes this from the post-training plot; suptitle does not
# replace any axes title drawn by ``plot_prediction``.
plt.suptitle("Multitask classifier: test predictions before training")
plt.show()

In [None]:
# Train the controller, then re-classify the test points with the fitted model.
# NOTE(review): 100 is presumably the number of optimisation iterations — confirm against
# ``GaussianGPController.fit``.
controller.fit(100)
predictions, probs = controller.classify_points(DATASET.test_x)

In [None]:
# Predicted classes of the trained multitask model on the test points.
plt.figure(figsize=(8, 8))
DATASET.plot_prediction(predictions)
# Figure-level title; does not replace any axes title drawn by ``plot_prediction``.
plt.suptitle("Multitask classifier: test predictions after training")
plt.show()

In [None]:
# Confusion matrix of the trained multitask model's test predictions.
plt.figure(figsize=(8, 8))
DATASET.plot_confusion_matrix(predictions)
# Figure-level title; does not replace any axes title drawn by ``plot_confusion_matrix``.
plt.suptitle("Multitask classifier: confusion matrix after training")
plt.show()

In [None]:
# Passed to ``Multitask`` as ``lmc_dimension``; presumably the number of latent GPs combined
# into the NUM_CLASSES task outputs.
NUM_LATENTS = 10


# NOTE(review): this rebinds ``CategoricalMultitaskClassifier`` and shadows the class defined
# earlier in the notebook. The cells below use this version. A distinct name would make the two
# models easier to tell apart.
@CategoricalClassification(
    num_classes=NUM_CLASSES,
    ignore_all=True,
)
@Multitask(
    num_tasks=NUM_CLASSES,
    lmc_dimension=NUM_LATENTS,
    ignore_all=True,
)
@VariationalInference(ignore_all=True)
class CategoricalMultitaskClassifier(GaussianGPController):
    """Categorical multitask variational GP classifier with an ``lmc_dimension`` of NUM_LATENTS.

    Identical to the earlier multitask classifier except that ``Multitask`` is also given
    ``lmc_dimension=NUM_LATENTS``. The class body adds nothing beyond the decorators.
    """

In [None]:
# Rebuild the controller from the (rebound) LMC classifier class on the same one-hot labels.
# Keyword arguments are order-independent.
controller = CategoricalMultitaskClassifier(
    DATASET.train_x,
    DATASET.one_hot_train_y,
    ScaledRBFKernel,
    marginal_log_likelihood_class=VariationalELBO,
    likelihood_class=MultitaskBernoulliLikelihood,
    y_std=0,
)

In [None]:
# Baseline: classify the test points with the freshly constructed, not-yet-fitted LMC controller.
predictions, probs = controller.classify_points(DATASET.test_x)

In [None]:
# Predicted classes of the untrained LMC multitask model on the test points.
plt.figure(figsize=(8, 8))
DATASET.plot_prediction(predictions)
# Figure-level title; does not replace any axes title drawn by ``plot_prediction``.
plt.suptitle("LMC multitask classifier: test predictions before training")
plt.show()

In [None]:
# Train the LMC controller, then re-classify the test points with the fitted model.
# NOTE(review): 100 is presumably the number of optimisation iterations — confirm against
# ``GaussianGPController.fit``.
controller.fit(100)
predictions, probs = controller.classify_points(DATASET.test_x)

In [None]:
# Predicted classes of the trained LMC multitask model on the test points.
plt.figure(figsize=(8, 8))
DATASET.plot_prediction(predictions)
# Figure-level title; does not replace any axes title drawn by ``plot_prediction``.
plt.suptitle("LMC multitask classifier: test predictions after training")
plt.show()

In [None]:
# Confusion matrix of the trained LMC multitask model's test predictions.
plt.figure(figsize=(8, 8))
DATASET.plot_confusion_matrix(predictions)
# Figure-level title; does not replace any axes title drawn by ``plot_confusion_matrix``.
plt.suptitle("LMC multitask classifier: confusion matrix after training")
plt.show()

In [None]:
# Restated (same value as earlier) so this cell does not depend on the previous model's cell.
NUM_LATENTS = 10
# Used as ``num_tasks`` here instead of NUM_CLASSES; presumably the number of GP outputs fed
# into the softmax likelihood — confirm against vanguard's SoftmaxLikelihood.
NUM_FEATURES = 6


@CategoricalClassification(
    num_classes=NUM_CLASSES,
    ignore_all=True,
)
@Multitask(
    num_tasks=NUM_FEATURES,
    lmc_dimension=NUM_LATENTS,
    ignore_all=True,
)
@VariationalInference(ignore_all=True)
class CategoricalSoftmaxMultitaskClassifier(GaussianGPController):
    """Categorical classifier on a multitask variational GP with NUM_FEATURES tasks.

    Unlike the earlier classifiers, the number of tasks is decoupled from the number of
    classes. The class body adds nothing beyond the decorators.
    """

In [None]:
# Unlike the earlier controllers, this one receives the plain ``train_y`` labels rather than
# ``one_hot_train_y``, together with a softmax likelihood. Keyword arguments are
# order-independent.
controller = CategoricalSoftmaxMultitaskClassifier(
    DATASET.train_x,
    DATASET.train_y,
    ScaledRBFKernel,
    marginal_log_likelihood_class=VariationalELBO,
    likelihood_class=SoftmaxLikelihood,
    y_std=0,
)

In [None]:
# Train the softmax controller, then classify the test points with the fitted model
# (no untrained baseline is shown for this model).
# NOTE(review): 100 is presumably the number of optimisation iterations — confirm against
# ``GaussianGPController.fit``.
controller.fit(100)
predictions, probs = controller.classify_points(DATASET.test_x)

In [None]:
# Predicted classes of the trained softmax multitask model on the test points.
plt.figure(figsize=(8, 8))
DATASET.plot_prediction(predictions)
# Figure-level title; does not replace any axes title drawn by ``plot_prediction``.
plt.suptitle("Softmax multitask classifier: test predictions after training")
plt.show()

In [None]:
# Confusion matrix of the trained softmax multitask model's test predictions.
plt.figure(figsize=(8, 8))
DATASET.plot_confusion_matrix(predictions)
# Figure-level title; does not replace any axes title drawn by ``plot_confusion_matrix``.
plt.suptitle("Softmax multitask classifier: confusion matrix after training")
plt.show()