
For the sparsity constraint, we use a k-sparse constraint: **only the k largest activations in h are
retained** for each input, while **the rest are set to zero** (Makhzani et al., 2013; Gao et al., 2024). This approach avoids
issues such as shrinkage, where L1 regularisation can cause feature activations to be systematically
lower than their true values, potentially leading to suboptimal representations (Wright
et al., 2024; Rajamanoharan et al., 2024).
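The sketch below makes the shrinkage contrast concrete. It makes one assumption for illustration: the effect of an L1 penalty is modelled with soft-thresholding, the proximal operator of the L1 norm. In SAE training the L1 term is usually a penalty added to the loss instead. The helper names and the threshold `lam = 0.3` are hypothetical.

```python
import torch

def soft_threshold(h: torch.Tensor, lam: float) -> torch.Tensor:
    # Proximal operator of lam * ||h||_1 (illustrative stand-in for an L1 penalty):
    # every surviving value is pulled toward zero by lam.
    return torch.sign(h) * torch.clamp(h.abs() - lam, min=0.0)

def top_k(h: torch.Tensor, k: int) -> torch.Tensor:
    # Keep the k largest entries of each row unchanged; zero out the rest.
    _, idx = torch.topk(h, k, dim=1)
    mask = torch.zeros_like(h).scatter_(1, idx, 1.0)
    return h * mask

h = torch.tensor([[0.1, 0.5, 0.3, 0.8, 0.2, 0.4]])  # hypothetical "true" activations
lam = 0.3  # hypothetical L1 strength

print(soft_threshold(h, lam))  # surviving entries are reduced by lam (shrinkage)
print(top_k(h, k=2))           # retained entries keep their original magnitude
```

With soft-thresholding, 0.8 comes out as roughly 0.5. With top-k, it stays at 0.8; only which entries survive is decided, not how large they are.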

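```python
import torch

def k_sparse(self, x):
    # Implement the k-sparse constraint: keep the k largest activations in each row of x
    # (dim=1, i.e. per input) and set the rest to zero.
    _, indices = torch.topk(x, self.k, dim=1)
    # Binary mask: 1 at the top-k positions of each row, 0 elsewhere.
    mask = torch.zeros_like(x).scatter_(1, indices, 1)
    # Multiplying by the mask lets gradients flow only through the retained entries.
    return x * mask
```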
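The method above reads `self.k`, so it has to live inside a module. Below is a minimal sketch of such a host module. It assumes a single linear encoder and a single linear decoder, with no extra bias handling and no activation before the top-k step. The class name `TopKAutoencoder` and the layer sizes are hypothetical and do not describe the authors' architecture.

```python
import torch
import torch.nn as nn

class TopKAutoencoder(nn.Module):
    """Hypothetical minimal autoencoder hosting a k_sparse method."""

    def __init__(self, d_in: int, d_hidden: int, k: int):
        super().__init__()
        self.k = k
        self.encoder = nn.Linear(d_in, d_hidden)
        self.decoder = nn.Linear(d_hidden, d_in)

    def k_sparse(self, x: torch.Tensor) -> torch.Tensor:
        # Same masking logic as above: keep the k largest entries per row.
        _, indices = torch.topk(x, self.k, dim=1)
        mask = torch.zeros_like(x).scatter_(1, indices, 1)
        return x * mask

    def forward(self, x: torch.Tensor):
        h = self.k_sparse(self.encoder(x))  # sparse code: at most k non-zeros per row
        x_hat = self.decoder(h)             # reconstruction from the sparse code
        return x_hat, h

torch.manual_seed(0)
model = TopKAutoencoder(d_in=8, d_hidden=32, k=4)
batch = torch.randn(5, 8)
x_hat, h = model(batch)
loss = ((x_hat - batch) ** 2).mean()  # reconstruction loss only; no L1 term in this sketch
loss.backward()
print(h.shape, (h != 0).sum(dim=1))
```

In this sketch the constraint fixes the number of active latents directly, so the loss is just the reconstruction error and there is no sparsity coefficient to tune.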

The cells below walk through the same steps on a small 2 × 6 tensor with k = 2.

```python
import torch
torch.manual_seed(42)

x = torch.tensor([[0.1, 0.5, 0.3, 0.8, 0.2, 0.4],
                  [0.2, 0.9, 0.3, 0.1, 0.7, 0.4]],
                 requires_grad=True)
```

```python
topk_values, topk_indices = torch.topk(x, k=2, dim=1)  # values come back in descending order

topk_values
# tensor([[0.8000, 0.5000],
#         [0.9000, 0.7000]], grad_fn=<TopkBackward0>)

topk_indices
# tensor([[3, 1],
#         [1, 4]])
```

```python
mask = torch.zeros_like(x).scatter_(1, topk_indices, 1)
mask
# tensor([[0., 1., 0., 1., 0., 0.],
#         [0., 1., 0., 0., 1., 0.]])
```
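The `scatter_` construction guarantees exactly k ones per row. The sketch below compares it with a tempting alternative that tests each row against its k-th largest value. The two give the same mask when there are no ties. When values tie at the threshold, the alternative can keep more than k entries. The function names and the tensors `x_demo` and `tied` are hypothetical. They are named so they do not overwrite the notebook's `x`.

```python
import torch

def scatter_mask(x: torch.Tensor, k: int) -> torch.Tensor:
    # Exactly k ones per row, at the indices returned by topk.
    _, idx = torch.topk(x, k, dim=1)
    return torch.zeros_like(x).scatter_(1, idx, 1.0)

def threshold_mask(x: torch.Tensor, k: int) -> torch.Tensor:
    # 1 wherever an entry is >= the k-th largest value of its row.
    kth = torch.topk(x, k, dim=1).values[:, -1:]  # shape (rows, 1), broadcasts over columns
    return (x >= kth).to(x.dtype)

x_demo = torch.tensor([[0.1, 0.5, 0.3, 0.8, 0.2, 0.4],
                       [0.2, 0.9, 0.3, 0.1, 0.7, 0.4]])
print(torch.equal(scatter_mask(x_demo, 2), threshold_mask(x_demo, 2)))  # True: no ties here

tied = torch.tensor([[0.9, 0.5, 0.5, 0.1]])
print(scatter_mask(tied, 2).sum().item(), threshold_mask(tied, 2).sum().item())  # 2.0 3.0
```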

```python
x
# tensor([[0.1000, 0.5000, 0.3000, 0.8000, 0.2000, 0.4000],
#         [0.2000, 0.9000, 0.3000, 0.1000, 0.7000, 0.4000]], requires_grad=True)

x * mask
# tensor([[0.0000, 0.5000, 0.0000, 0.8000, 0.0000, 0.0000],
#         [0.0000, 0.9000, 0.0000, 0.0000, 0.7000, 0.0000]], grad_fn=<MulBackward0>)
```
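The constraint is applied by multiplying with a constant mask, so the gradient of a loss with respect to `x` is zero at every dropped position. The check below repeats the cells above so it runs on its own. It uses the sum of the masked output as a stand-in loss, which is an assumption for illustration only.

```python
import torch

x = torch.tensor([[0.1, 0.5, 0.3, 0.8, 0.2, 0.4],
                  [0.2, 0.9, 0.3, 0.1, 0.7, 0.4]], requires_grad=True)
_, topk_indices = torch.topk(x, k=2, dim=1)
mask = torch.zeros_like(x).scatter_(1, topk_indices, 1)

(x * mask).sum().backward()  # d/dx of sum(x * mask) is the mask itself
print(x.grad)
# tensor([[0., 1., 0., 1., 0., 0.],
#         [0., 1., 0., 0., 1., 0.]])
```

Only the k retained activations per input receive a gradient. The zeroed entries get none, whatever their pre-mask values were.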