Code for illustrating Sparsemax
========

This code is intended to illustrate and help explain the sparsemax operation, as defined in the [paper](https://arxiv.org/abs/1602.02068) "From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification" by André F. T. Martins and Ramón Fernandez Astudillo.

**Proposition 1** defines $k(z)$ as:

$k(z) = \max \{k \in \{1, 2, ..., K\} \mid 1 + kz_{(k)} > \sum_{j \leq k} z_{(j)}\}$

where $z \in \mathbb{R}^K$ and $z_{(1)} \geq z_{(2)} \geq ... \geq z_{(K)}$ are the elements of $z$ sorted in descending order.

Let's see it in action.

In [57]:
%matplotlib inline

import matplotlib
import numpy as np
import matplotlib.pyplot as pl

In [34]:
K = 10
z = np.random.uniform(-5, 5, K)
z_sorted = sorted(z, reverse=True)

print(z)
print(z_sorted)

[ 2.44828998 -0.8733298   0.06289572  1.62920888  3.22873505 -2.08432584
 -1.01466484  3.8326466  -3.37218107  4.65226222]
[4.652262215877265, 3.832646599074419, 3.2287350512536506, 2.4482899814934553, 1.6292088802445157, 0.06289571927698834, -0.8733298042918101, -1.0146648437284456, -2.084325841272495, -3.3721810696340593]


This is the left-hand side of the inequality inside the *max*, $1 + kz_{(k)}$, evaluated for every $k$:

In [35]:
# range starting from 1
k_range = np.arange(1, len(z) + 1)
1 + k_range * z_sorted

array([  5.65226222,   8.6652932 ,  10.68620515,  10.79315993,
         9.1460444 ,   1.37737432,  -5.11330863,  -7.11731875,
       -17.75893257, -32.7218107 ])

And this is the right-hand side, the cumulative sum $\sum_{j \leq k} z_{(j)}$:

In [37]:
np.cumsum(z_sorted)

array([ 4.65226222,  8.48490881, 11.71364387, 14.16193385, 15.79114273,
       15.85403845, 14.98070864, 13.9660438 , 11.88171796,  8.50953689])

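Comparing the two sides element-wise gives a boolean mask marking the values of $k$ that satisfy the inequality: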

In [38]:
inds = 1 + k_range * z_sorted > np.cumsum(z_sorted)
print(inds)

[ True  True False False False False False False False False]
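The mask above is a run of `True` values followed by `False` values, and this is not a coincidence: moving from $k$ to $k + 1$ changes the difference between the two sides by $k(z_{(k+1)} - z_{(k)}) \leq 0$, so once the inequality fails it keeps failing. The following is a minimal sketch that checks this on random inputs; the helper name `support_mask`, the seed, and the number of trials are illustrative choices, not part of the original notebook.

```python
import numpy as np

# Seed and trial count are arbitrary, chosen only for reproducibility.
rng = np.random.default_rng(0)

def support_mask(z):
    """Boolean mask of the k values (1..K) satisfying the inequality."""
    z_sorted = np.sort(z)[::-1]                  # descending order
    k_range = np.arange(1, len(z_sorted) + 1)    # k = 1, ..., K
    return 1 + k_range * z_sorted > np.cumsum(z_sorted)

for _ in range(200):
    mask = support_mask(rng.uniform(-5, 5, 10))
    k_value = int(mask.sum())
    # The first k_value entries are True and all remaining ones are False,
    # so counting the True values gives the same k as taking the max.
    assert k_value >= 1
    assert mask[:k_value].all() and not mask[k_value:].any()
```

Under this property, the largest valid $k$ is simply the number of `True` entries in the mask.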


Putting it all together:

In [41]:
def k(z):
    z_sorted = sorted(z, reverse=True)
    
    # range starting from 1
    k_range = np.arange(1, len(z) + 1)
    inds = 1 + k_range * z_sorted > np.cumsum(z_sorted)
    
    # we know the last element in k_range[inds] is the max 
    # because k_range is already ordered
    max_ind = k_range[inds][-1]
    return max_ind

Let's test it on samples drawn from a few distributions:

In [80]:
distributions = [np.random.normal(0, 1, K),
                 np.random.normal(0, 10, K),
                 np.random.gamma(1, size=K),
                 np.random.gamma(10, size=K)]
for dist in distributions:
    print(k(dist))

2
1
2
2


Now for the $\tau(z)$ function:

$\tau(z) = \frac{(\sum_{j \leq k(z)} z_{(j)}) - 1}{k(z)}$
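One way to read this formula, relying on the paper's result that exactly the $k(z)$ largest entries stay positive after thresholding: $\tau(z)$ is the value that makes those surviving entries sum to 1,

$\sum_{j \leq k(z)} (z_{(j)} - \tau(z)) = (\sum_{j \leq k(z)} z_{(j)}) - k(z)\tau(z) = 1$

and solving this for $\tau(z)$ gives the expression above.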

In [76]:
def tau(z):
    k_value = k(z)
    
    # z was already sorted inside k(z); an optimized version would sort only once
    z_sorted = np.sort(z)[::-1]
    
    # recall that k starts from 1
    sum_minus_1 = z_sorted[:k_value].sum() - 1
    return sum_minus_1 / k_value
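As the comment in `tau` notes, the sort is performed twice. A possible single-pass variant is sketched below: it sorts once and reuses the cumulative sum for both $k(z)$ and $\tau(z)$. The function name `k_and_tau` is hypothetical and the input values are made up for illustration.

```python
import numpy as np

def k_and_tau(z):
    """Return (k(z), tau(z)) using a single sort and a single cumulative sum."""
    z_sorted = np.sort(np.asarray(z, dtype=float))[::-1]  # descending order
    cumsum = np.cumsum(z_sorted)
    k_range = np.arange(1, len(z_sorted) + 1)             # k = 1, ..., K

    # largest k satisfying 1 + k * z_(k) > sum_{j <= k} z_(j)
    support = 1 + k_range * z_sorted > cumsum
    k_value = k_range[support][-1]

    # cumsum[k_value - 1] is the sum of the k_value largest entries
    tau_value = (cumsum[k_value - 1] - 1) / k_value
    return int(k_value), float(tau_value)

# Illustrative input, small enough to check by hand.
k_value, tau_value = k_and_tau([1.2, 0.5, 0.1])
print(k_value, tau_value)
```

For this input the two largest entries, 1.2 and 0.5, are kept, so $k(z) = 2$ and $\tau(z) = (1.7 - 1) / 2$, which is about 0.35.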

In [84]:
for dist in distributions:
    print('z:', dist)
    print('tau(z):', tau(dist))
    print()

z: [ 1.939607    1.1502079  -0.88095171 -0.67453557 -0.52282631 -0.85172683
 -0.4600404  -1.32632804 -2.35361093  0.20246752]
tau(z): 1.0449074461701429

z: [ 12.54312349  -7.53803208   2.52981906   1.25311422  -3.85766725
  15.67246015  -0.06798676  -4.20733613   4.387767   -12.74671914]
tau(z): 14.672460153704167

z: [4.35197807 0.16903238 0.4451037  0.24514015 0.22523818 3.37137648
 1.34018159 0.08429251 2.16652902 0.20441061]
tau(z): 3.3616772751717097

z: [10.96969989  6.39395072  9.9947205  11.68669685  4.44256289 12.30690624
  7.58271533  7.6362838   6.86791408  7.6610175 ]
tau(z): 11.496801543924576



### Computing sparsemax

Now let's use the $\tau(z)$ we just implemented to compute the sparsemax, $\text{sparsemax}_i(z) = \max(z_i - \tau(z), 0)$, of the samples above and compare it with their softmax.

In [85]:
def softmax(z):
    # subtracting the max avoids overflow in np.exp and does not change the result
    exps = np.exp(z - np.max(z))
    return exps / exps.sum()

def sparsemax(z):
    return np.maximum(z - tau(z), 0)

In [87]:
for dist in distributions:
    print('Softmax: ', softmax(dist))
    print('Sparsemax: ', sparsemax(dist))
    print()

Softmax:  [0.48730979 0.22129595 0.02903034 0.03568596 0.0415321  0.02989126
 0.04422334 0.01859634 0.00665709 0.08577783]
Sparsemax:  [0.89469955 0.10530045 0.         0.         0.         0.
 0.         0.         0.         0.        ]

Softmax:  [4.19126249e-02 7.96544068e-11 1.87768182e-06 5.23788710e-07
 3.15916236e-09 9.58072791e-01 1.39768489e-07 2.22696134e-09
 1.20370077e-05 4.35617944e-13]
Sparsemax:  [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]

Softmax:  [0.61181362 0.0093323  0.01229943 0.01007029 0.00987186 0.22948234
 0.03010319 0.00857407 0.06878451 0.00966837]
Sparsemax:  [0.9903008 0.        0.        0.        0.        0.0096992 0.
 0.        0.        0.       ]

Softmax:  [1.35718116e-01 1.39770400e-03 5.11928909e-02 2.77987743e-01
 1.98581252e-04 5.16867424e-01 4.58869425e-03 4.84120650e-03
 2.24519921e-03 4.96244056e-03]
Sparsemax:  [0.         0.         0.         0.18989531 0.         0.81010469
 0.         0.         0.         0.        ]
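A useful special case for checking an implementation is the two-class one: the paper points out that sparsemax of $(t, 0)$ reduces to a "hard" sigmoid, with the first component equal to $(t + 1) / 2$ clipped to $[0, 1]$. The sketch below verifies this numerically; `sparsemax_sketch` is a self-contained re-implementation written for this check (a hypothetical name, not the notebook's `sparsemax`), and the values of `t` are arbitrary.

```python
import numpy as np

def sparsemax_sketch(z):
    """Self-contained sparsemax for a 1-D input (single sort)."""
    z = np.asarray(z, dtype=float)
    z_sorted = np.sort(z)[::-1]                   # descending order
    cumsum = np.cumsum(z_sorted)
    k_range = np.arange(1, len(z) + 1)            # k = 1, ..., K
    k_value = k_range[1 + k_range * z_sorted > cumsum][-1]
    tau_value = (cumsum[k_value - 1] - 1) / k_value
    return np.maximum(z - tau_value, 0)

for t in [-2.0, -0.5, 0.0, 0.5, 2.0]:
    p = sparsemax_sketch([t, 0.0])
    hard_sigmoid = np.clip((t + 1) / 2, 0, 1)
    assert np.isclose(p.sum(), 1.0)         # output is a probability vector
    assert np.isclose(p[0], hard_sigmoid)   # matches the two-class closed form
    print(t, p)
```

Unlike the softmax, which in this setting never outputs exactly 0 or 1, the two-class sparsemax saturates as soon as $|t| \geq 1$.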

