In [None]:
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Fetch the sparse_soft_topk repository only when no `sparse_soft_topk`
# directory exists yet, so re-running this cell does not re-clone.
# NOTE(review): the `import sparse_soft_topk` in the next cell presumably
# resolves to this checkout; confirm the package layout (and consider pinning
# a commit for reproducibility).
![ -d sparse_soft_topk ] || git clone https://github.com/google-research/sparse_soft_topk

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import jax.numpy as jnp
import jax
import sparse_soft_topk

In [None]:
# We define the hard top-k operator
def hard_topk_mask(s, k):
  """Returns the hard top-k mask of a 1-D score vector.

  Entry i of the result is 1 when s[i] is strictly greater than the
  (k+1)-th largest score and 0 otherwise. Without ties this selects exactly
  the k largest entries. If entries tie at the cut-off value, none of the
  tied entries is selected. When k >= len(s), every entry is selected.

  Args:
    s: 1-D array of scores. Batched input should go through `jax.vmap`,
      because the sort runs along the last axis while the reversal and
      `s.shape[0]` refer to the first one.
    k: number of entries to select (a Python int or a scalar integer array).

  Returns:
    Array of the same shape as `s`, in JAX's default floating dtype,
    holding 0s and 1s.
  """
  n = s.shape[0]
  s_desc = jnp.sort(s)[::-1]
  # s_desc[k] is the (k+1)-th largest score. For k >= n no such score exists.
  # JAX would silently clamp the index to n - 1 and wrongly drop the
  # smallest entry, so clamp explicitly and keep every entry in that case.
  threshold = s_desc[jnp.minimum(k, n - 1)]
  drop = (s <= threshold) & (k < n)
  return jnp.where(drop, jnp.zeros(s.shape), jnp.ones(s.shape))

In [None]:
# Defining the input array: for each grid value s in [0, 5] we build the score
# vector theta(s) = (3, 1, s - 1, s) and stack these vectors as the rows of
# `s`, which therefore has shape (n_discretization, 4).
n_discretization = 500
x = jnp.linspace(0, 5, n_discretization)
s = jnp.stack(
    [jnp.full_like(x, 3.0), jnp.ones_like(x), x - 1, x],
    axis=1,
)

In [None]:
# Regularization strength shared by both soft operators.
l = 6e-1
# Number of entries each mask should select.
k = 2

# Hard top-k: apply the 1-D operator to every row of `s` (one row per grid
# value), then transpose so that row i traces mask entry i along the grid.
hards = jax.vmap(lambda row: hard_topk_mask(row, k))(s).T

# Sparse and smooth top-k, with l2 (p = 2) and l-(4/3) (p = 4/3) regularization.
# NOTE(review): `.T` mirrors the hard case and presumably yields one row per
# mask entry; confirm the output shape of sparse_soft_topk_mask_pav.
softs_2, softs_4 = (
    sparse_soft_topk.sparse_soft_topk_mask_pav(s, k=k, l=l, p=exponent).T
    for exponent in (2.0, 4 / 3)
)

In [None]:
# Plotting the outputs: the sum of mask entries 2 and 3 (1-indexed), i.e. the
# entries attached to the scores 1 and s - 1, as the grid value s varies.
# The figure scale is deliberately not called `p`, which already denotes the
# regularization exponent shown in the legend.
fig_scale = 0.9
fig, ax = plt.subplots(figsize=(4 * fig_scale, 3 * fig_scale))
ax.plot(x, hards[1] + hards[2], "--", lw=3, label="Hard")
ax.plot(x, softs_2[1] + softs_2[2], lw=3, label="$p = 2$")
ax.plot(x, softs_4[1] + softs_4[2], lw=3, label="$p = 4/3$", alpha=0.6)

ax.set_xlabel(r"$s$")
ax.set_ylabel(r"$\mathrm{topkmask}(\theta(s))_2 + \mathrm{topkmask}(\theta(s))_3$")
ax.legend(loc="upper center")
fig.tight_layout()
plt.show()