In [None]:
# © Crown Copyright GCHQ
#
# Licensed under the GNU General Public License, version 3 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.gnu.org/licenses/gpl-3.0.en.html
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [None]:
# sphinx ignore

import sys

# Put the directory two levels above the current working directory on the import path,
# presumably so the repository's ``vanguard`` package can be imported without installing it.
sys.path.append("../..")

In [None]:
import functools
from typing import Any, Callable, Iterable, List, Type, TypeVar, Union

import numpy as np
import torch
from gpytorch.kernels import RBFKernel
from gpytorch.likelihoods import FixedNoiseGaussianLikelihood
from gpytorch.means import ConstantMean
from gpytorch.mlls import ExactMarginalLogLikelihood
from numpy.typing import ArrayLike, NDArray

from vanguard.base import GPController
from vanguard.decoratorutils import Decorator, process_args, wraps_class
from vanguard.optimise import SmartOptimiser
from vanguard.uncertainty import GaussianUncertaintyGPController

In [None]:
# Generic type variable shared by the decorator examples in this notebook.
T = TypeVar("T")
# Types allowed for a ``seed`` argument, which is handed straight to ``numpy.random.RandomState``.
SeedT = Union[None, ArrayLike, np.random.BitGenerator]

In [None]:
def is_py_file(file_path: str) -> bool:
    """
    Report whether a path names a Python source file.

    :param file_path: The path to inspect (converted to text first)
    :return: :data:`True` when ``file_path`` ends with ``.py``, otherwise :data:`False`
    """
    path_text = str(file_path)
    return path_text.endswith(".py")


is_py_file("foo.py"), is_py_file("bar.js")

In [None]:
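# TODO: Put back in? Note that ``is_py_file`` converts its argument with ``str()``,
# so ``is_py_file(42)`` currently returns False instead of raising AttributeError.
# # sphinx expect AttributeError
#
# is_py_file(42)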

In [None]:
# ``Callable[..., Any]`` rather than ``Callable[[str, ...], Any]``: an ellipsis may only stand in for
# the whole argument list of a ``Callable``, and older Pythons (e.g. 3.9) raise TypeError for the latter.
CallableStringT = TypeVar("CallableStringT", bound=Callable[..., Any])


def check_string(func: CallableStringT) -> CallableStringT:
    """
    Check that the input is a string.

    :param func: The function to decorate
    :return: A wrapper around ``func`` that raises :class:`TypeError` if any positional or
        keyword argument is not a :class:`str`, and otherwise returns ``func``'s result
    """

    @functools.wraps(func)  # keep func's __name__, __doc__, etc. on the wrapper
    def inner_function(*args: str, **kwargs: str) -> Any:
        # Check keyword arguments too, and forward them, so ``func(file_path=...)`` still works.
        for arg in (*args, *kwargs.values()):
            if not isinstance(arg, str):
                raise TypeError("All inputs must be strings.")
        return func(*args, **kwargs)

    return inner_function

In [None]:
# sphinx expect TypeError


@check_string  # equivalent to: is_py_file = check_string(is_py_file)
def is_py_file(file_path: str) -> bool:
    """
    Report whether a path names a Python source file.

    :param file_path: The path to inspect (converted to text first)
    :return: :data:`True` when ``file_path`` ends with ``.py``, otherwise :data:`False`
    """
    path_text = str(file_path)
    return path_text.endswith(".py")


is_py_file("foo.py"), is_py_file("bar.js")

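# TODO: Put back in? Until this call is restored, nothing in this cell raises the
# TypeError that the "sphinx expect TypeError" marker at the top of the cell anticipates.
# is_py_file(42)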

In [None]:
# ``Callable[..., Any]`` rather than ``Callable[[T, ...], Any]``: an ellipsis may only stand in for
# the whole argument list of a ``Callable``, and older Pythons (e.g. 3.9) raise TypeError for the latter.
CallableTT = TypeVar("CallableTT", bound=Callable[..., Any])


def check_type(t: Type[T]) -> Callable[[CallableTT], CallableTT]:
    """
    Check that the input is of a certain type.

    :param t: The type every argument of the decorated function must be an instance of
    :return: A decorator. The function it produces raises :class:`TypeError` if any positional
        or keyword argument is not an instance of ``t``, and otherwise returns the original
        function's result.
    """

    def decorator(func: CallableTT) -> CallableTT:
        @functools.wraps(func)  # keep func's __name__, __doc__, etc. on the wrapper
        def inner_function(*args: T, **kwargs: T) -> Any:
            # Check keyword arguments too, and forward them, so keyword calls still work.
            for arg in (*args, *kwargs.values()):
                if not isinstance(arg, t):
                    raise TypeError(f"All inputs must be of type {t}.")
            return func(*args, **kwargs)

        return inner_function

    return decorator

In [None]:
# sphinx expect TypeError


@check_type(str)  # equivalent to: is_py_file = check_type(str)(is_py_file)
def is_py_file(file_path: str) -> bool:
    """
    Report whether a path names a Python source file.

    :param file_path: The path to inspect (converted to text first)
    :return: :data:`True` when ``file_path`` ends with ``.py``, otherwise :data:`False`
    """
    path_text = str(file_path)
    return path_text.endswith(".py")


is_py_file("foo.py"), is_py_file("bar.js")

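# TODO: Put back in? Until this call is restored, nothing in this cell raises the
# TypeError that the "sphinx expect TypeError" marker at the top of the cell anticipates.
# is_py_file(42)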

In [None]:
def consistent_shuffle(*arrays: NDArray[float], seed: SeedT = None) -> List[NDArray[float]]:
    """
    Shuffle all arrays into the same order, to maintain consistency.

    Every array is indexed with one shared random permutation of its first axis, so entries
    that lined up before shuffling still line up afterwards.

    :param arrays: The arrays to shuffle along their first axis; all must have the same length
    :param seed: Seed for :class:`numpy.random.RandomState`, making the permutation reproducible
    :return: The shuffled arrays, in the order they were passed in (empty if none were given)
    :raises ValueError: If the arrays do not all have the same length
    """
    if not arrays:
        return []

    # Without this check, arrays longer than the first would be silently truncated by the
    # indexing below, and shorter ones would fail with an unhelpful IndexError.
    lengths = [len(array) for array in arrays]
    if len(set(lengths)) > 1:
        raise ValueError(f"All arrays must have the same length; got lengths {lengths}.")

    # Keep the legacy RandomState generator so seeded shuffles give the same order as before.
    rng = np.random.RandomState(seed=seed)
    indices = np.arange(lengths[0])
    rng.shuffle(indices)

    return [array[indices] for array in arrays]

In [None]:
# Small toy dataset: x = 1..5 and y = x squared.
x = np.arange(1, 6)
y = x**2

In [None]:
# Shuffle ``x`` and ``y`` with one shared permutation; ``seed=1`` makes the order reproducible.
consistent_shuffle(x, y, seed=1)

In [None]:
# Show how ``process_args`` maps these arguments onto the parameters of ``GPController.__init__``.
# ``None`` occupies the ``self`` slot, because the function is accessed on the class rather than on
# an instance. (The decorators below ``pop`` names such as "train_x" from the mapping it returns.)
process_args(
    GPController.__init__,
    None,
    x,
    y,
    RBFKernel,
    mean_class=ConstantMean,
    y_std=0.1,
    likelihood_class=FixedNoiseGaussianLikelihood,
    marginal_log_likelihood_class=ExactMarginalLogLikelihood,
    optimiser_class=torch.optim.Adam,
    smart_optimiser_class=SmartOptimiser,
)

In [None]:
class ShuffleDecorator(Decorator):
    """
    Shuffles input data.

    Instances of a decorated controller have ``train_x``, ``train_y`` and ``y_std`` shuffled into
    one shared random order before the wrapped ``__init__`` runs.
    """

    def __init__(self, **kwargs: Any) -> None:
        """
        :param kwargs: Keyword arguments passed on to :class:`~vanguard.decoratorutils.Decorator`
        """
        super().__init__(framework_class=GPController, required_decorators={}, **kwargs)

    def _decorate_class(self, cls: Type[T]) -> Type[T]:
        class InnerClass(cls):
            """An inner class."""

            def __init__(self, *args: Any, **kwargs: Any) -> None:
                # Gather positional and keyword arguments into one name -> value mapping.
                all_parameters_as_kwargs = process_args(super().__init__, *args, **kwargs)

                old_train_x = all_parameters_as_kwargs.pop("train_x")
                old_train_y = all_parameters_as_kwargs.pop("train_y")
                old_y_std = all_parameters_as_kwargs.pop("y_std")  # pop to avoid duplication

                # Expand a scalar y_std so it can be shuffled with the data. ``np.ndim`` also
                # catches NumPy scalars (e.g. np.int64, np.float32), which are not instances of
                # int/float and would otherwise reach consistent_shuffle and fail to be indexed.
                # NOTE(review): the expanded array takes train_x's shape; confirm that is the shape
                # the controller expects for y_std when train_x is multi-dimensional.
                if np.ndim(old_y_std) == 0:
                    old_y_std = np.ones_like(old_train_x) * old_y_std

                new_train_x, new_train_y, new_y_std = consistent_shuffle(old_train_x, old_train_y, old_y_std)

                super().__init__(train_x=new_train_x, train_y=new_train_y, y_std=new_y_std, **all_parameters_as_kwargs)

        return InnerClass

In [None]:
# Calling the decorator instance directly on the class...
ShuffledGPController = ShuffleDecorator()(GPController)


# ...is what the ``@`` syntax below does, rebinding the same name.
@ShuffleDecorator()
class ShuffledGPController(GPController):  # noqa: F811
    """Shuffles inputs to the controller."""

In [None]:
# Inspect the name and docstring of the decorated class (this version of ShuffleDecorator does not use wraps_class).
print(ShuffledGPController.__name__, ShuffledGPController.__doc__, sep="\n")

In [None]:
class ShuffleDecorator(Decorator):
    """
    Shuffles input data.

    Instances of a decorated controller have ``train_x``, ``train_y`` and ``y_std`` shuffled into
    one shared random order before the wrapped ``__init__`` runs.
    """

    def __init__(self, **kwargs: Any) -> None:
        """
        :param kwargs: Keyword arguments passed on to :class:`~vanguard.decoratorutils.Decorator`
        """
        super().__init__(framework_class=GPController, required_decorators={}, **kwargs)

    def _decorate_class(self, cls: Type[T]) -> Type[T]:
        @wraps_class(cls)
        class InnerClass(cls):
            """An inner class."""

            def __init__(self, *args: Any, **kwargs: Any) -> None:
                # Gather positional and keyword arguments into one name -> value mapping.
                all_parameters_as_kwargs = process_args(super().__init__, *args, **kwargs)

                old_train_x = all_parameters_as_kwargs.pop("train_x")
                old_train_y = all_parameters_as_kwargs.pop("train_y")
                old_y_std = all_parameters_as_kwargs.pop("y_std")  # pop to avoid duplication

                # Expand a scalar y_std so it can be shuffled with the data. ``np.ndim`` also
                # catches NumPy scalars (e.g. np.int64, np.float32), which are not instances of
                # int/float and would otherwise reach consistent_shuffle and fail to be indexed.
                # NOTE(review): the expanded array takes train_x's shape; confirm that is the shape
                # the controller expects for y_std when train_x is multi-dimensional.
                if np.ndim(old_y_std) == 0:
                    old_y_std = np.ones_like(old_train_x) * old_y_std

                new_train_x, new_train_y, new_y_std = consistent_shuffle(old_train_x, old_train_y, old_y_std)

                super().__init__(train_x=new_train_x, train_y=new_train_y, y_std=new_y_std, **all_parameters_as_kwargs)

        return InnerClass

In [None]:
@ShuffleDecorator()
class ShuffledGPController(GPController):  # noqa: F811
    """Shuffles inputs to the controller."""


# Compare with the earlier output: this version of ShuffleDecorator applies wraps_class.
print(ShuffledGPController.__name__, ShuffledGPController.__doc__, sep="\n")

In [None]:
class ShuffleDecorator(Decorator):
    """
    Shuffles input data.

    Instances of a decorated controller have ``train_x``, ``train_y`` and ``y_std`` shuffled into
    one shared random order before the wrapped ``__init__`` runs.

    :param seed: Seed passed to :func:`consistent_shuffle`; a fixed seed makes the shuffle reproducible
    :param kwargs: Keyword arguments passed on to :class:`~vanguard.decoratorutils.Decorator`
    """

    def __init__(self, seed: SeedT = None, **kwargs: Any) -> None:
        super().__init__(framework_class=GPController, required_decorators={}, **kwargs)
        self.seed = seed

    def _decorate_class(self, cls: Type[T]) -> Type[T]:
        # Capture the seed in a local so the inner class (whose ``self`` is the controller) can use it.
        seed = self.seed

        @wraps_class(cls)
        class InnerClass(cls):
            """An inner class."""

            def __init__(self, *args: Any, **kwargs: Any) -> None:
                # Gather positional and keyword arguments into one name -> value mapping.
                all_parameters_as_kwargs = process_args(super().__init__, *args, **kwargs)

                old_train_x = all_parameters_as_kwargs.pop("train_x")
                old_train_y = all_parameters_as_kwargs.pop("train_y")
                old_y_std = all_parameters_as_kwargs.pop("y_std")  # pop to avoid duplication

                # Expand a scalar y_std so it can be shuffled with the data. ``np.ndim`` also
                # catches NumPy scalars (e.g. np.int64, np.float32), which are not instances of
                # int/float and would otherwise reach consistent_shuffle and fail to be indexed.
                # NOTE(review): the expanded array takes train_x's shape; confirm that is the shape
                # the controller expects for y_std when train_x is multi-dimensional.
                if np.ndim(old_y_std) == 0:
                    old_y_std = np.ones_like(old_train_x) * old_y_std

                new_train_x, new_train_y, new_y_std = consistent_shuffle(old_train_x, old_train_y, old_y_std, seed=seed)

                super().__init__(train_x=new_train_x, train_y=new_train_y, y_std=new_y_std, **all_parameters_as_kwargs)

        return InnerClass

In [None]:
# Apply ShuffleDecorator, with its default options, to GaussianUncertaintyGPController.
@ShuffleDecorator()
class ShuffledGaussianUncertaintyGPController(GaussianUncertaintyGPController):
    """Shuffles inputs to the controller."""

In [None]:
# As above, but passing ``ignore_methods`` through ShuffleDecorator to the underlying Decorator.
@ShuffleDecorator(
    ignore_methods={
        "predict_at_point",
        "_get_additive_grad_noise",
        "_noise_transform",
        "_append_constant_to_infinite_generator",
    }
)
class ShuffledGaussianUncertaintyGPController(GaussianUncertaintyGPController):  # noqa: F811
    """Shuffles inputs to the controller."""

In [None]:
class ShuffleDecorator(Decorator):
    """
    Shuffles input data.

    Instances of a decorated controller have ``train_x``, ``train_y``, ``y_std`` and any additional
    named parameters shuffled into one shared random order before the wrapped ``__init__`` runs.

    :param seed: Seed passed to :func:`consistent_shuffle`; a fixed seed makes the shuffle reproducible
    :param additional_params_to_shuffle: Names of further ``__init__`` parameters to shuffle together
        with ``train_x``, ``train_y`` and ``y_std``
    :param kwargs: Keyword arguments passed on to :class:`~vanguard.decoratorutils.Decorator`
    """

    def __init__(self, seed: SeedT = None, additional_params_to_shuffle: Iterable[str] = (), **kwargs: Any) -> None:
        if additional_params_to_shuffle:
            # Add "__init__" to the ignored methods. ``.get`` avoids a KeyError when the caller
            # did not supply an ``ignore_methods`` of their own.
            kwargs["ignore_methods"] = set(kwargs.get("ignore_methods", ())) | {"__init__"}

        super().__init__(framework_class=GPController, required_decorators={}, **kwargs)

        self.seed = seed
        self.params_to_shuffle = set.union({"train_x", "train_y", "y_std"}, set(additional_params_to_shuffle))

    def _decorate_class(self, cls: Type[T]) -> Type[T]:
        # Capture decorator settings in locals so the inner class (whose ``self`` is the controller) can use them.
        seed = self.seed
        params_to_shuffle = self.params_to_shuffle

        @wraps_class(cls)
        class InnerClass(cls):
            """An inner class."""

            def __init__(self, *args: Any, **kwargs: Any) -> None:
                # Gather positional and keyword arguments into one name -> value mapping.
                all_parameters_as_kwargs = process_args(super().__init__, *args, **kwargs)

                # Scalar parameters are expanded to this array's shape before shuffling.
                array_for_reference = all_parameters_as_kwargs["train_x"]

                # Pop (to avoid passing them twice) in set-iteration order; the zip below iterates
                # the same unmodified set, so names and values stay paired.
                pre_shuffled_args = [all_parameters_as_kwargs.pop(param) for param in params_to_shuffle]
                # ``np.ndim(...) == 0`` also catches NumPy scalars (e.g. np.int64, np.float32),
                # which are not instances of int/float and could not be indexed when shuffled.
                pre_shuffled_args_as_arrays = [
                    np.ones_like(array_for_reference) * arg if np.ndim(arg) == 0 else arg
                    for arg in pre_shuffled_args
                ]
                shuffled_args = consistent_shuffle(*pre_shuffled_args_as_arrays, seed=seed)

                shuffled_params_as_kwargs = dict(zip(params_to_shuffle, shuffled_args))

                super().__init__(**shuffled_params_as_kwargs, **all_parameters_as_kwargs)

        return InnerClass

In [None]:
# Method names handed to ShuffleDecorator (and on to the underlying Decorator) as ``ignore_methods``.
ignore_methods = {
    "_get_posterior_over_fuzzy_point_in_eval_mode",
    "__init__",
    "_sgd_round",
    "_process_x_std",
    "_set_requires_grad",
    "predict_at_point",
    "_get_additive_grad_noise",
    "_noise_transform",
    "_append_constant_to_infinite_generator",
}


@ShuffleDecorator(
    seed=1,
    additional_params_to_shuffle={"train_x_std"},
    ignore_methods=ignore_methods,
)
class ShuffledGaussianUncertaintyGPController(GaussianUncertaintyGPController):  # noqa: F811
    """Shuffles inputs to the controller."""

In [None]:
# Toy training set: inputs 1..5 with squared targets, plus per-point standard deviations.
train_x = np.arange(1, 6)
train_y = train_x**2
train_x_std = train_x / 100  # 0.01, 0.02, ..., 0.05
y_std = train_x / 50  # 0.02, 0.04, ..., 0.1

In [None]:
# Pass every data argument by name; the decorator shuffles train_x, train_x_std, train_y and y_std together.
controller = ShuffledGaussianUncertaintyGPController(
    train_x=train_x,
    train_x_std=train_x_std,
    train_y=train_y,
    y_std=y_std,
    kernel_class=RBFKernel,
    mean_class=ConstantMean,
    likelihood_class=FixedNoiseGaussianLikelihood,
    marginal_log_likelihood_class=ExactMarginalLogLikelihood,
    optimiser_class=torch.optim.Adam,
)

In [None]:
# Print each stored (shuffled) quantity, transposed, in the same order as before.
for attribute_name in ("train_x", "train_x_std", "train_y", "_y_variance"):
    print(getattr(controller, attribute_name).T)