# In [None]:
# neff_adaptive_dropout.py
import torch
import torch.nn as nn

class NeffDropout(nn.Module):
    """
    Neff-adaptive dropout.

    For every 1-D slice along `dim`, the absolute activations are normalised
    to a distribution p_i = |x_i| / sum_j |x_j| and the effective number of
    active units Neff = 1 / sum_i p_i^2 is computed. Only
    k = clamp(floor(beta * Neff), 1, L) activations are kept (L = size of
    `dim`); the rest are zeroed. Optionally the kept values are rescaled so
    that the L1 mass of the slice is preserved.

    The module is the identity in eval mode. In training mode the boolean
    keep-mask of the most recent call (same shape as the input) is stored in
    `self.last_mask` for debugging/inspection.

    Degenerate slices (all zeros, or Neff not finite) keep all L entries,
    except that a NaN product beta * Neff (e.g. beta == 0) keeps a single one.

    Args:
      beta: float >= 0, multiplicative factor on Neff.
      dim: int, feature dimension (default: last).
      mode: 'topk' (deterministic kWTA) or 'random' (Gumbel-Top-k by p_i).
      keep_l1: if True, rescales masked output so sum|x| along `dim` is preserved.
      eps: numerical epsilon to avoid 0/0.

    Raises:
      ValueError: if `mode` is not 'topk' or 'random'.
    """

    def __init__(self, beta: float = 1.0, dim: int = -1,
                 mode: str = "topk", keep_l1: bool = False, eps: float = 1e-12):
        super().__init__()
        # Explicit raise instead of `assert`: asserts are stripped under -O,
        # which would let an invalid mode silently behave like 'random'.
        if mode not in ("topk", "random"):
            raise ValueError(f"mode must be 'topk' or 'random', got {mode!r}")
        self.beta = beta
        self.dim = dim
        self.mode = mode
        self.keep_l1 = keep_l1
        self.eps = eps
        self.last_mask = None  # for debugging/inspection

    def _move_last(self, x):
        """Permute `x` so that `self.dim` is last.

        Returns the permuted view, the permutation used, and the inverse
        permutation that restores the original dimension order.
        """
        dim = self.dim if self.dim >= 0 else x.dim() + self.dim
        perm = [d for d in range(x.dim()) if d != dim] + [dim]
        inv = [0] * x.dim()
        for new_pos, old_pos in enumerate(perm):
            inv[old_pos] = new_pos
        return x.permute(*perm), perm, inv

    @torch.no_grad()
    def _build_mask(self, y: torch.Tensor) -> torch.Tensor:
        """Return the boolean keep-mask (True = keep) for a 2-D `y` of shape [N, L].

        Selection is not differentiable, so it runs without autograd tracking.
        """
        n_feat = y.shape[1]
        mag = y.abs()
        # Do the statistics in at least float32: in float16 both `eps` and
        # p**2 underflow to zero, which makes Neff infinite.
        mag = mag.to(torch.promote_types(mag.dtype, torch.float32))
        p = mag / (mag.sum(dim=1, keepdim=True) + self.eps)  # probabilities on simplex
        neff = 1.0 / p.pow(2).sum(dim=1, keepdim=True)       # [N, 1]

        # Sanitise and clamp while still floating point: casting inf/nan to an
        # integer dtype is platform-dependent.
        k = torch.floor(self.beta * neff)
        k = torch.nan_to_num(k, nan=1.0, posinf=float(n_feat), neginf=1.0)
        k = k.clamp(min=1, max=n_feat).to(torch.long)        # at least 1, at most all

        if self.mode == "topk":
            scores = p
        else:
            # random, weighted by p_i via Gumbel-Top-k
            gumbel = -torch.log(-torch.log(torch.rand_like(p) + self.eps) + self.eps)
            scores = torch.log(p + self.eps) + gumbel

        _, order = torch.sort(scores, dim=1, descending=True)  # ranks per row
        ranks = torch.arange(n_feat, device=y.device).view(1, -1)
        keep_sorted = ranks < k                                # [N, L] via broadcasting
        mask = torch.zeros_like(scores, dtype=torch.bool)
        mask.scatter_(1, order, keep_sorted)                   # back to feature order
        return mask

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply Neff-adaptive dropout to `x` (identity when not training).

        Returns a contiguous tensor with the same shape as `x`.
        """
        if not self.training:
            return x
        # move target dim last for simpler vectorized ops
        x_moved, _, inv = self._move_last(x)
        lead_shape = x_moved.shape[:-1]
        n_feat = x_moved.shape[-1]
        y = x_moved.reshape(-1, n_feat)      # [N, L]

        mask = self._build_mask(y)
        y_masked = y * mask

        if self.keep_l1:
            # Accumulate in at least float32 so `eps` still guards 0/0 for
            # half-precision inputs; gradients flow through the scale.
            mag = y.abs()
            acc_dtype = torch.promote_types(mag.dtype, torch.float32)
            pre = mag.to(acc_dtype).sum(dim=1, keepdim=True)             # sum|x| before
            post = y_masked.abs().to(acc_dtype).sum(dim=1, keepdim=True)  # sum|x| after
            scale = pre / (post + self.eps)
            if y.is_floating_point():
                # keep the output dtype equal to the input dtype
                scale = scale.to(y.dtype)
            y_masked = y_masked * scale

        self.last_mask = mask.reshape(*lead_shape, n_feat).permute(*inv).contiguous()
        return y_masked.reshape(*lead_shape, n_feat).permute(*inv).contiguous()
