In [34]:
!pip install tensorflow==1.15 sparsemax > /dev/null

In [35]:
import torch
import tensorflow as tf
import numpy as np
from sparsemax import sparsemax

## Sample data

In [36]:
np.random.seed(42)  # seed numpy's global RNG so the outputs below are reproducible
data = np.random.randn(10)  # numpy array, shape (10,)
# pytorch tensor with a leading batch axis: shape (1, 10), dtype float64 (a copy).
# Converting the array directly (instead of `torch.tensor([data])`) avoids
# PyTorch's slow list-of-ndarrays conversion path and the warning it emits.
data_t = torch.tensor(data[None, :])

## Tensorflow
I'll be using TF 1.15 because it is the last version that includes sparsemax in the `contrib` module (`tf.contrib` was removed in TF 2.0).


In [37]:
x = tf.constant(data[None, ...])  # add a leading batch axis -> shape (1, 10)
output = tf.contrib.sparsemax.sparsemax(x)
# Use the session as a context manager so it is closed (releasing its
# resources) as soon as the result has been fetched.
with tf.Session() as sess:
    tf_result = sess.run(output)
tf_result

array([[0.        , 0.        , 0.        , 0.47190852, 0.        ,
        0.        , 0.52809148, 0.        , 0.        , 0.        ]])

## Pytorch

In [38]:
# github repo: https://github.com/aced125/sparsemax
# `SparsemaxFunction` is a torch autograd Function: `apply` is invoked on the
# class itself, so there is no need to instantiate it (newer PyTorch versions
# deprecate instantiating autograd Functions). `sp.apply(...)` works unchanged.
sp = sparsemax.SparsemaxFunction


In [39]:
# Forward pass of the reference PyTorch implementation on the same sample data.
torch_result = sp.apply(data_t)
torch_result

tensor([[0.0000, 0.0000, 0.0000, 0.4719, 0.0000, 0.0000, 0.5281, 0.0000, 0.0000,
         0.0000]], dtype=torch.float64)

## Custom

In [40]:
class MySparsemax(torch.autograd.Function):
  """Sparsemax (Martins & Astudillo, 2016) as a custom autograd Function.

  Computes ``max(input - tau, 0)`` where the threshold ``tau`` is chosen per
  slice along ``dim`` so the outputs along ``dim`` sum to 1. Unlike softmax,
  sparsemax can assign exact zeros.

  Usage: ``MySparsemax.apply(input)`` or ``MySparsemax.apply(input, dim)``.
  """

  @staticmethod
  def forward(ctx, input, dim=-1):
    """Compute sparsemax of ``input`` along ``dim``.

    Args:
      ctx: autograd context; the output and ``dim`` are saved for backward.
      input: floating-point tensor with at least one dimension.
      dim: dimension to normalise over (default: the last one).

    Returns:
      Tensor shaped like ``input``, non-negative and summing to 1 along ``dim``.
    """
    z_sorted, _ = torch.sort(input, dim=dim, descending=True)
    z_cumsum = z_sorted.cumsum(dim=dim)

    # k = 1..n, reshaped to broadcast along `dim` and allocated on the input's
    # device so the op also works for non-CPU tensors.
    n = input.size(dim)
    k_shape = [1] * input.dim()
    k_shape[dim] = n
    k = torch.arange(1, n + 1, device=input.device).view(k_shape)

    # Support size k(z) = max{k : 1 + k * z_(k) > sum_{j<=k} z_(j)}.
    # For finite input the condition always holds at k = 1; clamp(min=1) only
    # keeps NaN input from producing an index of -1 in the gather below.
    in_support = (1 + k * z_sorted) > z_cumsum
    k_z = in_support.sum(dim=dim, keepdim=True).clamp(min=1)

    # tau = (sum of the k(z) largest entries - 1) / k(z). `gather` picks the
    # cumulative sum at position k(z) - 1 for every slice, for any input rank.
    tau = (z_cumsum.gather(dim, k_z - 1) - 1) / k_z
    output = (input - tau).clamp(min=0)

    ctx.save_for_backward(output)
    ctx.dim = dim
    return output

  @staticmethod
  def backward(ctx, grad_output):
    """Backpropagate through sparsemax.

    On the support S (entries whose output is non-zero) the gradient is
    ``grad_output - mean(grad_output over S)``; elsewhere it is zero.

    Returns:
      ``(grad_input, None)``: the ``None`` is the gradient for ``dim``.
      Autograd accepts the trailing ``None`` even when ``dim`` was not passed.
    """
    output, = ctx.saved_tensors
    dim = ctx.dim

    nonzeros = output != 0
    support_size = nonzeros.sum(dim=dim, keepdim=True)
    v_hat = (grad_output * nonzeros).sum(dim=dim, keepdim=True) / support_size

    return nonzeros * (grad_output - v_hat), None

In [41]:
# Forward pass of the custom implementation on the same sample data.
my_result = MySparsemax.apply(data_t)
my_result

tensor([[0.0000, 0.0000, 0.0000, 0.4719, 0.0000, 0.0000, 0.5281, 0.0000, 0.0000,
         0.0000]], dtype=torch.float64)

## Compare my result with the other implementations

In [42]:
# `.detach()` is the supported way to get a graph-free tensor before `.numpy()`;
# `.data` bypasses autograd's safety checks and is discouraged.
np.testing.assert_allclose(my_result.detach().numpy(), tf_result)
np.testing.assert_allclose(my_result.detach().numpy(), torch_result.detach().numpy())

## Check gradient

In [43]:
torch.random.manual_seed(42)
# gradcheck needs double precision. Creating the tensor directly as float64
# samples in float64 and keeps `inputs` a leaf tensor. Calling `.double()`
# after `requires_grad=True` sampled in float32 and returned a non-leaf copy.
inputs = torch.randn(2, 5, dtype=torch.double, requires_grad=True)
torch.autograd.gradcheck(MySparsemax.apply, inputs=inputs)

True