In [1]:
# !pip install tensorflow-probability nb_black gin-config
# !pip install torch==1.8.0+cpu torchvision==0.9.0+cpu torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html

In [2]:
# Load IPython's autoreload extension and use mode 2, so imported modules are
# re-imported before each cell runs and edits to project code are picked up.
%load_ext autoreload
%autoreload 2

In [3]:
import torch
from perturbations_torch.fenchel_young import FenchelYoungLoss

def ranks(inputs, dim=-1):
    """Computes the 1-based rank of every entry of `inputs` along `dim`.

    Args:
      inputs: a tensor of values to rank.
      dim: the dimension along which the values are compared.

    Returns:
      A tensor with the same shape and dtype as `inputs`; the smallest value
      along `dim` gets rank 1.
    """
    # Indices that would sort the values along `dim`.
    sorting_order = torch.argsort(inputs, dim=dim)
    # Argsorting the sorting permutation gives each entry's 0-based position.
    zero_based_rank = torch.argsort(sorting_order, dim=dim)
    return zero_based_rank.to(inputs.dtype) + 1

# Random starting scores: three rows of five values behind a leading axis of size 1.
# NOTE(review): no seed is set before sampling, so the values differ between runs.
x = torch.randn(3, 5).float()[None]
print(x)
# Turn on gradient tracking after the print, so the printed tensor carries no
# requires_grad flag.
x.requires_grad_(True)
# Target vector 0..4, tiled once per entry along x's leading axis.
y_true = torch.arange(5).float()[None].repeat(x.size(0), 1)
print(torch.argsort(x, dim=-1))

tensor([[[ 2.1268, -1.1092, -0.7310, -1.3338,  0.1591],
         [ 0.3222,  1.5026, -0.6778,  1.4786,  0.2953],
         [ 1.9041,  0.2432, -0.4155, -0.3035,  0.6855]]])
tensor([[[3, 1, 2, 4, 0],
         [2, 4, 0, 3, 1],
         [2, 3, 1, 4, 0]]])


In [4]:
# Plain SGD over the single tensor x, with a learning rate of 0.01.
optim = torch.optim.SGD(params=[x], lr=0.01)

In [5]:
# Build the loss object once: its only argument (`ranks`) does not change
# between iterations, so re-creating it inside the loop was redundant work.
# NOTE(review): assumes constructing FenchelYoungLoss has no side effects — confirm.
criterion = FenchelYoungLoss(ranks)

for iteration in range(200):
    optim.zero_grad()
    # Sum the loss values into a scalar so backward() can be called on it.
    loss = criterion(y_true, x).sum()
    loss.backward()
    optim.step()
    if iteration % 50 == 0:
        # The ordering printed here is that of x *after* this step's update,
        # while the loss value was computed *before* the update.
        print(x.argsort(-1))
        print(loss.item())

tensor([[[3, 1, 2, 4, 0],
         [2, 0, 4, 1, 3],
         [2, 3, 1, 4, 0]]])
92.51268005371094
tensor([[[1, 2, 3, 0, 4],
         [0, 2, 1, 3, 4],
         [0, 1, 2, 3, 4]]])
29.038894653320312
tensor([[[0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4]]])
15.000822067260742
tensor([[[0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4]]])
15.000313758850098


In [33]:
# Display the current values of x, the tensor handed to the SGD optimiser.
# NOTE(review): presumably run after the optimisation loop — confirm cell order.
x

tensor([[[-2.4375, -2.3969, -2.3266, -2.2855, -1.4416],
         [-2.0240, -1.5973, -1.5565, -0.9716, -0.9297],
         [-1.7595, -1.7227, -1.6877, -1.6510, -1.0653]]], requires_grad=True)

In [38]:
# Shift x by 2, map the sign {-1, 0, 1} to {0, 0.5, 1}, cast to bool (only 0 is
# False), then take a logical OR along dimension 1.
(((x + 2).sign() + 1) / 2).bool().any(dim=1)

tensor([[True, True, True, True, True]])

In [31]:
from perturbations_torch import perturbations


def reduce_sign_any(input_tensor, axis=-1):
    """A logical or of the signs of a tensor along an axis.

    Implemented with torch ops: the notebook imports `torch` and
    `perturbations_torch`, whereas `tf` is not imported in the visible cells.

    Args:
      input_tensor: Tensor<float> of any shape.
      axis: the axis along which we want to compute a logical or of the signs
        of the values.

    Returns:
      A Tensor<float>, which has the same shape as the input tensor, but
      without the axis on which we reduced. Entries are 1.0 where at least one
      value along `axis` is non-negative and -1.0 otherwise.
    """
    # sign in {-1, 0, 1} -> {0, 0.5, 1}; casting to bool makes only strictly
    # negative entries False (0.5 is truthy, so zeros count as "positive").
    boolean_sign = torch.any(
        ((torch.sign(input_tensor) + 1) / 2.0).bool(), dim=axis
    )
    # Map {False, True} back to {-1.0, 1.0} in the input's dtype.
    return boolean_sign.to(input_tensor.dtype) * 2.0 - 1.0


class PerturbationsTest(parameterized.TestCase, tf.test.TestCase):
    """Testing the perturbations module."""

    def setUp(self):
        """Seeds the TensorFlow random generator before each test."""
        super().setUp()
        tf.random.set_seed(0)

    @parameterized.parameters([perturbations._GUMBEL, perturbations._NORMAL])
    def test_sample_noise_with_gradients(self, noise):
        """Samples and gradients must both have the requested shape."""
        shape = (3, 2, 4)
        samples, gradients = perturbations.sample_noise_with_gradients(
            noise, shape
        )
        self.assertAllEqual(samples.shape, shape)
        self.assertAllEqual(gradients.shape, shape)

    def test_sample_noise_with_gradients_raise(self):
        """An unrecognised noise name must raise ValueError."""
        with self.assertRaises(ValueError):
            _, _ = perturbations.sample_noise_with_gradients(
                "unknown", (3, 2, 4)
            )

    @parameterized.parameters([1e-3, 1e-2, 1e-1])
    def test_perturbed_reduce_sign_any(self, sigma):
        """For the tested sigmas the perturbed output matches the hard one."""
        input_tensor = tf.constant([[-0.3, -1.2, 1.6], [-0.4, -2.4, -1.0]])
        soft_reduce_any = perturbations.perturbed(reduce_sign_any, sigma=sigma)
        output_tensor = soft_reduce_any(input_tensor, axis=-1)
        self.assertAllClose(output_tensor, [1.0, -1.0])

    def test_perturbed_reduce_sign_any_gradients(self):
        """Gradients are positive and largest for the entry nearest zero."""
        # We choose a point where the gradient should be above noise, that is
        # to say the distance to 0 along one direction is about sigma.
        sigma = 0.1
        input_tensor = tf.constant(
            [[-0.6, -1.2, 0.5 * sigma], [-2 * sigma, -2.4, -1.0]]
        )
        soft_reduce_any = perturbations.perturbed(reduce_sign_any, sigma=sigma)
        with tf.GradientTape() as tape:
            # A constant is not tracked automatically, so watch it explicitly.
            tape.watch(input_tensor)
            output_tensor = soft_reduce_any(input_tensor)
        gradient = tape.gradient(output_tensor, input_tensor)
        # The two values that could change the soft logical or should be the one
        # with real positive impact on the final values.
        self.assertAllGreater(gradient[0, 2], 0.0)
        self.assertAllGreater(gradient[1, 0], 0.0)
        # The value that is more on the fence should bring more gradient than any
        # other one.
        self.assertAllLessEqual(gradient, gradient[0, 2].numpy())

    def test_unbatched_rank_one_raise(self):
        """A rank-1 input without `batched=False` must raise ValueError."""
        with self.assertRaises(ValueError):
            input_tensor = tf.constant([-0.6, -0.5, 0.5])
            dim = len(input_tensor)
            n = 10000000

            argmax = lambda t: tf.one_hot(tf.argmax(t, 1), dim)
            soft_argmax = perturbations.perturbed(
                argmax, sigma=0.5, num_samples=n
            )
            _ = soft_argmax(input_tensor)

    def test_perturbed_argmax_gradients_without_minibatch(self):
        """Autodiff gradient agrees with a finite difference (rank-1 input)."""
        input_tensor = tf.constant([-0.6, -0.5, 0.5])
        # The input is rank 1, so its length is the number of classes.
        dim = len(input_tensor)
        eps = 1e-2
        n = 10000000

        argmax = lambda t: tf.one_hot(tf.argmax(t, 1), dim)
        soft_argmax = perturbations.perturbed(
            argmax, sigma=0.5, num_samples=n, batched=False
        )
        norm_argmax = lambda t: tf.reduce_sum(tf.square(soft_argmax(t)))

        # Random unit direction along which the derivative is compared.
        w = tf.random.normal(input_tensor.shape)
        w /= tf.linalg.norm(w)
        var = tf.Variable(input_tensor)
        with tf.GradientTape() as tape:
            value = norm_argmax(var)

        grad = tape.gradient(value, var)
        grad = tf.reshape(grad, input_tensor.shape)

        value_minus = norm_argmax(input_tensor - eps * w)
        value_plus = norm_argmax(input_tensor + eps * w)

        # Directional derivative from autodiff vs. central finite difference.
        lhs = tf.reduce_sum(w * grad)
        rhs = (value_plus - value_minus) * 1.0 / (2 * eps)
        self.assertAllLess(tf.abs(lhs - rhs), 0.05)

    def test_perturbed_argmax_gradients_with_minibatch(self):
        """Autodiff gradient agrees with a finite difference (batched input)."""
        input_tensor = tf.constant([[-0.6, -0.7, 0.5], [0.9, -0.6, -0.5]])
        # Number of classes is the size of the last axis (3). Using
        # len(input_tensor) here would give the batch size (2), and one_hot
        # would then encode argmax index 2 as an all-zero row.
        dim = input_tensor.shape[-1]
        eps = 1e-2
        n = 10000000

        argmax = lambda t: tf.one_hot(tf.argmax(t, -1), dim)
        soft_argmax = perturbations.perturbed(argmax, sigma=2.5, num_samples=n)
        norm_argmax = lambda t: tf.reduce_sum(tf.square(soft_argmax(t)))

        # Random unit direction along which the derivative is compared.
        w = tf.random.normal(input_tensor.shape)
        w /= tf.linalg.norm(w)
        var = tf.Variable(input_tensor)
        with tf.GradientTape() as tape:
            value = norm_argmax(var)

        grad = tape.gradient(value, var)
        grad = tf.reshape(grad, input_tensor.shape)

        value_minus = norm_argmax(input_tensor - eps * w)
        value_plus = norm_argmax(input_tensor + eps * w)

        # Directional derivative from autodiff vs. central finite difference.
        lhs = tf.reduce_sum(w * grad)
        rhs = (value_plus - value_minus) * 1.0 / (2 * eps)
        self.assertAllLess(tf.abs(lhs - rhs), 0.05)


# Entry point: enable TensorFlow v2 behaviour, then hand control to
# TensorFlow's test runner.
# NOTE(review): in a notebook kernel __name__ is also "__main__", so this block
# executes whenever the cell is run — confirm that is intended.
if __name__ == "__main__":
    tf.enable_v2_behavior()
    tf.test.main()
