# Revisiting the idea of obtaining MVaR derivatives via component-wise mapping.

First, we need to re-implement the approximate derivatives. Then, we can compare
it with the finite-difference estimates.

In [107]:
from typing import List, Optional

import torch
from botorch.acquisition.multi_objective.multi_output_risk_measures import MVaR
from torch import Tensor


class DiffMVaR(MVaR):

    def make_diffable(self, prepared_samples: Tensor, mvar_set: List[Tensor]) -> List[Tensor]:
        r"""An experimental approach for obtaining the gradient of the MVaR via
        component-wise mapping to original samples.

        Args:
            prepared_samples: A `(sample_shape * batch_shape * q) x n_w x m`-dim tensor of
                posterior samples. The q-batches should be ordered so that each
                `n_w` block of samples correspond to the same input.
            mvar_set: A `(sample_shape * batch_shape * q)` list of `k x m`-dim tensor
                of MVaR values, where `k` is varies depending on the particular batch.

        Returns:
            The same `mvar_set` with entries mapped to inputs to produce gradients.
        """
        for batch_idx in range(prepared_samples.shape[0]):
            base_samples = prepared_samples[batch_idx]
            mvars = mvar_set[batch_idx]
            equal_check = mvars.unsqueeze(-2) == base_samples
            new_mvars_list = []
            for check in equal_check:
                p1 = base_samples[check[:, 0], 0].mean()
                p2 = base_samples[check[:, 1], 1].mean()
                new_mvars_list.append(torch.stack([p1, p2]))
            mvar_set[batch_idx] = torch.stack(new_mvars_list)
        return mvar_set


    def forward(self, samples: Tensor, X: Optional[Tensor] = None) -> Tensor:
        r"""Calculate the MVaR corresponding to the given samples.

        Args:
            samples: A `sample_shape x batch_shape x (q * n_w) x m`-dim tensor of
                posterior samples. The q-batches should be ordered so that each
                `n_w` block of samples correspond to the same input.
            X: A `batch_shape x q x d`-dim tensor of inputs. Ignored.

        Returns:
            A `sample_shape x batch_shape x q x m`-dim tensor of MVaR values,
            if `self.expectation=True`.
            Otherwise, this returns a `sample_shape x batch_shape x (q * k') x m`-dim
            tensor, where `k'` is the maximum `k` across all batches that is returned
            by `get_mvar_set_...`. Each `(q * k') x m` corresponds to the `k` MVaR
            values for each `q` batch of `n_w` inputs, padded up to `k'` by repeating
            the last element. If `self.pad_to_n_w`, we set `k' = self.n_w`, producing
            a deterministic return shape.
        """
        batch_shape, m = samples.shape[:-2], samples.shape[-1]
        prepared_samples = self._prepare_samples(samples)
        # This is -1 x n_w x m.
        prepared_samples = prepared_samples.reshape(-1, *prepared_samples.shape[-2:])
        # Get the mvar set using the appropriate method based on device, m & n_w.
        # NOTE: The `n_w <= 64` part is based on testing on a 24 core CPU.
        # `get_mvar_set_gpu` heavily relies on parallelized batch computations and
        # may scale worse on CPUs with fewer cores.
        # Using `no_grad` here since `MVaR` is not differentiable.
        with torch.no_grad():
            if (
                samples.device == torch.device("cpu")
                and m == 2
                and prepared_samples.shape[-2] <= 64
            ):
                mvar_set = self.get_mvar_set_cpu(prepared_samples)
            else:
                mvar_set = self.get_mvar_set_gpu(prepared_samples)
        if samples.requires_grad:
            mvar_set = self.make_diffable(prepared_samples, mvar_set)
        # Set the `pad_size` to either `self.n_w` or the size of the largest MVaR set.
        pad_size = self.n_w if self.pad_to_n_w else max([_.shape[0] for _ in mvar_set])
        padded_mvar_list = []
        for mvar_ in mvar_set:
            if self.expectation:
                padded_mvar_list.append(mvar_.mean(dim=0))
            else:
                # Repeat the last entry to make `mvar_set` `n_w x m`.
                repeats_needed = pad_size - mvar_.shape[0]
                padded_mvar_list.append(
                    torch.cat([mvar_, mvar_[-1].expand(repeats_needed, m)], dim=0)
                )
        mvars = torch.stack(padded_mvar_list, dim=0)
        return mvars.view(*batch_shape, -1, m)


def func(X: Tensor, n_w: int = 5, seed: int = 0) -> Tensor:
    torch.manual_seed(seed)
    perturbed_X = X.unsqueeze(-2) + torch.rand(n_w, X.shape[-1])
    return perturbed_X.pow(2)

X = torch.ones(1, 1, 2, requires_grad=True)
Y = func(X)
mvar = DiffMVaR(n_w=5, alpha=0.6)
mvar_vals = mvar(Y)
grad = torch.autograd.grad(mvar_vals.sum(), X)
print(mvar_vals)
print(grad)

tensor([[[[2.1189, 2.6644],
          [1.7094, 2.6702]]]], grad_fn=<ViewBackward>)
(tensor([[[5.5261, 6.5328]]]),)


In [108]:
torch.seed()
X = torch.rand(3, 2, 2, requires_grad=True)
eps = 1e-4
X_eps = X + eps
Y = func(X)
mvar_Y = mvar(Y)
Y_eps = func(X_eps)
mvar_Y_eps = mvar(Y_eps)

grad = torch.autograd.grad(mvar_Y.sum(), X)
grad_fd = (mvar_Y_eps.sum() - mvar_Y.sum()) / eps
# print(mvar_Y)
print(grad[0].sum(), grad_fd)
print(grad_fd - grad[0].sum())


tensor(50.7030) tensor(50.6973, grad_fn=<DivBackward0>)
tensor(-0.0056, grad_fn=<SubBackward0>)


In [109]:
from botorch.test_functions.multi_objective import DH3

X = torch.rand(3, 2, 5, 3, requires_grad=True)
func = DH3(dim=3)
eps = 1e-4
X_eps = X + eps
Y = func(X)
mvar_Y = mvar(Y)
Y_eps = func(X_eps)
mvar_Y_eps = mvar(Y_eps)

grad = torch.autograd.grad(mvar_Y.sum(), X)
grad_fd = (mvar_Y_eps.sum() - mvar_Y.sum()) / eps
# print(mvar_Y)
print(grad[0].sum(), grad_fd)
print(grad_fd - grad[0].sum())


tensor(1483.1537) tensor(1485.5957, grad_fn=<DivBackward0>)
tensor(2.4420, grad_fn=<SubBackward0>)
