In [18]:
import torch
import torch.nn.functional as F

def soft_topk(s: torch.Tensor, k: int, tau1: float = 0.01, tau2: float = 0.07):
    # softmax 归1
    s = F.softmax(s/1, dim=-1)

    # 差值矩阵
    diff = torch.unsqueeze(s, -1) - torch.unsqueeze(s, -2)

    # sigma
    sigma = torch.sigmoid(-diff / tau1)

    # row sum
    row_sum = sigma.sum(dim=-1) - 0.5

    # r_tilde
    r_tilde = 1.0 + row_sum

    # gating
    eps = 6
    a = torch.sigmoid((k - r_tilde) / tau2 +eps)
    a = a / a.sum(-1, keepdim=True)
    print("a:", a)
    # final output
    a = a * s
    return a
z = torch.tensor([0.2, 0.3])  # 输入张量
k = 1

softz = soft_topk(z, k)          # 软 topk

# softz= softz/softz.sum(-1,keepdim=True)

z=F.softmax(z, dim=-1)
topkz,_ = torch.topk(z, k)    # 正常的 topk
# topkz= topkz/topkz.sum(-1,keepdim=True)  



print("topk:", topkz)
print("soft topk:", softz)


    

a: tensor([2.7810e-04, 9.9972e-01])
topk: tensor([0.5250])
soft topk: tensor([1.3210e-04, 5.2483e-01])


In [2]:
import torch
import torch.nn.functional as F

def pick_taus_from_eps(s_logits: torch.Tensor, k: int, eps_hard: float,
                       p: float = 0.7,  # 过渡带定义：p->1-p
                       tau1_clip=(1e-4, 1.0)):
    """
    eps_hard: 你希望 rank 轴上的过渡带宽度（越小越硬）
    p=0.95 表示用 0.95->0.05 的宽度；p=0.99 更硬
    """
    # 常数：logit(p)-logit(1-p)
    C = torch.log(torch.tensor(p/(1-p))) * 2  # = logit(p) - logit(1-p)

    # tau2 由 eps_hard 严格控制
    tau2 = float(eps_hard / float(C))

    # 估计 tau1：用第k与第k+1名的典型 gap
    s = F.softmax(s_logits, dim=-1)
    topv, _ = torch.topk(s, k + 1, dim=-1, largest=True, sorted=True)
    gap = (topv[..., k-1] - topv[..., k]).detach()  # kth - (k+1)th
    gap_med = torch.median(gap)

    tau1 = float((gap_med / C).clamp(min=tau1_clip[0], max=tau1_clip[1]))
    return tau1, tau2

tau1,tau2= pick_taus_from_eps(z, k=1, eps_hard=0.6)
print(tau1,tau2)

0.014737431891262531 0.35406675582413005
