# Implementation of function log soft_top_k

In some cases, especially for the cross-entropy (CE) loss, we need $\log p$ rather than $p$.

*Value of $\log p$*
We have:
$$
\log p_i = g(\frac{b - r_i}{\alpha}),
$$
where:
$$
g(x) =
\begin{cases}
\log(1 - 0.5 \exp(-|x|)) & \text{for } x \geq 0, \\
-\log 2 + x & \text{for } x < 0.
\end{cases}
$$

To calculate derivatives, we need the auxiliary function:
$$
h(x) =
\frac{1}{\max(x,0.5)} - 1
$$

*Derivative with respect to $w$*
Writing $u = \log p$, we have
   $$
   \frac{\partial u_i}{\partial w} = \frac{1}{p_i}\cdot \frac{\partial p_i}{\partial w}=\frac{q_i}{p_i}.
   $$
  Thus
  $$
    \frac{\partial u}{\partial w} =q/p= h(p)/(S\alpha). 
  $$

*Derivative of $u$ with respect to $\alpha$* 
We have:
$$
\frac{\partial u}{\partial \alpha} =
\frac{1}{p} \odot \frac{\partial p}{\partial \alpha}.
$$
Substituting:
$$
\frac{\partial u}{\partial \alpha} =
\frac{1}{p} \odot \frac{1}{\alpha}(s \odot r-\langle q,r \rangle s).
$$

Thus:
$$
\frac{\partial u}{\partial \alpha} =
\frac{1}{\alpha^2}(h(p) \odot r - \langle q, r \rangle h(p)).
$$

*Derivative of $u$ with respect to $r$*
By the chain rule we have
$$
Du=\mathrm{diag}\left(\frac{1}{p}\right) Dp.
$$
Substituting $Dp = sq^T-\mathrm{diag}(s)$ gives
$$
Du=\mathrm{diag}\left(\frac{1}{p}\right)(sq^T-\mathrm{diag}(s))
=
\frac{1}{\alpha}\left(h(p) q^T-\mathrm{diag}(h(p))\right)
$$

*Computation of $v^T Du$*
We have
$$
v^T \cdot Du=\frac{1}{\alpha}(\langle v,h(p) \rangle q^T-
v^T \odot h(p)^T).
$$


In [None]:
from math import log
import time
import numpy as np
import torch

In [None]:
class LogSoftTopK(torch.autograd.Function):
    """Autograd function returning ``log p`` of the soft top-k probabilities.

    For every row of ``r`` the forward pass solves for an offset ``b`` (in
    units of ``r / alpha``) and evaluates, with ``sign = -1`` when
    ``descending`` else ``+1`` and ``x = sign * (r / alpha - b)``::

        log_p = log(1 - 0.5 * exp(-x))   where x > 0
        log_p = x - log(2)               otherwise

    The offset is chosen so that ``exp(log_p)`` sums to ``k`` along dim 1.
    """

    @staticmethod
    def _solve(s, t, a, b, e):
        """Return the offset lying between two neighbouring sorted values.

        The result ``y`` satisfies ``a * exp(y - t) - b * exp(s - y) = 2 * e``.
        Note that the parameter ``b`` here is a coefficient, not the offset.

        Args:
            s: Scaled value just below the offset.
            t: Scaled value just above the offset.
            a: Coefficient of the ``exp(y - t)`` term.
            b: Coefficient of the ``exp(s - y)`` term.
            e: Half of the right-hand side of the equation above.

        Returns:
            Tensor with the solution ``y``, broadcast over the inputs.
        """
        # Both branches are the same root of the quadratic in exp(y); the
        # form is picked by the sign of `e` so that the square root is always
        # added to |e| and never subtracted from it.
        z = torch.abs(e) + torch.sqrt(e**2 + a * b * torch.exp(s - t))
        ab = torch.where(e > 0, a, b)

        return torch.where(
            e > 0, t + torch.log(z) - torch.log(ab), s - torch.log(z) + torch.log(ab)
        )

    @staticmethod
    def forward(ctx, r, k, alpha, descending=False):
        """Compute ``log p`` for every entry of ``r``.

        Args:
            ctx: Autograd context; receives the tensors needed by ``backward``.
            r: 2-D tensor of scores, one row per batch element.
            k: Tensor with ``k.shape[0] == r.shape[0]``; target row sums of
                ``exp(log_p)``. Presumably 1-D (it is ``unsqueeze``d to a
                column) -- confirm against callers.
            alpha: Scale dividing ``r``; a Python number or a 0-d tensor.
            descending: Selects ``sign = -1`` (True) or ``sign = +1`` (False).

        Returns:
            Tensor of the same shape as ``r`` holding ``log p``.
        """
        # Check dimensions
        assert r.shape[0] == k.shape[0], "k must have same batch size as r"

        # Unpacking also enforces that `r` is 2-D.
        _, num_dim = r.shape
        # Scratch buffer, reused several times below and finally holding `x`.
        x = torch.empty_like(r, requires_grad=False)

        def finding_b():
            # s: each row of r sorted ascending and divided by alpha.
            scaled = torch.sort(r, dim=1)[0]
            scaled.div_(alpha)

            # eB[j] = sum_{i <= j} exp(s_i - s_j)
            eB = torch.logcumsumexp(scaled, dim=1)
            eB.sub_(scaled).exp_()

            # eA[j] = sum_{i >= j} exp(s_j - s_i), obtained by running the
            # cumulative log-sum-exp over the reversed, negated row and then
            # reversing the result back.
            torch.neg(scaled, out=x)
            eA = torch.flip(x, dims=(1,))
            torch.logcumsumexp(eA, dim=1, out=x)
            idx = torch.arange(start=num_dim - 1, end=-1, step=-1, device=x.device)
            torch.index_select(x, 1, idx, out=eA)
            eA.add_(scaled).exp_()

            # row[j] = 2 * j + 1
            row = torch.arange(1, 2 * num_dim + 1, 2, device=r.device)

            # x[j] = eA[j] - eB[j] + 2 * j + 1 (`alpha=-1` is torch.add's
            # multiplier for its second operand, not the scale `alpha`).
            # This is twice the sum of probabilities when the offset is s_j.
            torch.add(torch.add(eA, eB, alpha=-1, out=x), row.view(1, -1), out=x)

            w = (k if descending else num_dim - k).unsqueeze(1)
            # i: first position whose x is >= 2 * w; the offset lies between
            # the sorted values at positions i - 1 and i.
            i = torch.searchsorted(x, 2 * w)
            m = torch.clamp(i - 1, 0, num_dim - 1)
            n = torch.clamp(i, 0, num_dim - 1)

            b = LogSoftTopK._solve(
                scaled.gather(1, m),
                scaled.gather(1, n),
                # No neighbour on that side -> the corresponding term vanishes.
                torch.where(i < num_dim, eA.gather(1, n), 0),
                torch.where(i > 0, eB.gather(1, m), 0),
                w - i,
            )
            return b

        b = finding_b()

        sign = -1 if descending else 1

        # x = sign * (r / alpha - b)
        torch.div(r, alpha * sign, out=x)
        x.sub_(sign * b)

        sign_x = x > 0
        # qx = 1 - 0.5 * exp(-max(x, 0)); equals 0.5 wherever x <= 0.
        qx = torch.relu(x).neg_().exp_().mul_(-0.5).add_(1)

        ctx.save_for_backward(x, qx, r)
        ctx.alpha = alpha
        ctx.sign = sign

        log_p = torch.where(sign_x, torch.log(qx), x.sub(log(2)))
        return log_p

    @staticmethod
    def backward(ctx, grad_output):
        """Vector-Jacobian product of ``forward``.

        With ``w = 1 / qx - 1``, ``wgrad = w * grad_output``,
        ``wsum = sum(wgrad, dim=1)``, ``q = softmax(-|x|, dim=1)`` and
        ``R = 0.5 * sum(exp(-|x|), dim=1)``::

            grad_k     = abs(sign) / R * wsum
            grad_r     = -sign / alpha * (wsum * q - wgrad)
            grad_alpha = -1 / alpha * sum(grad_r * r)

        The saved tensors are never modified, so the graph can be
        back-propagated more than once (``retain_graph=True``).

        Returns:
            ``(grad_r, grad_k, grad_alpha, None)``; ``grad_alpha`` is ``None``
            when ``alpha`` does not require a gradient (e.g. a Python float),
            because autograd rejects a gradient for a non-tensor input.
        """
        x, qx, r = ctx.saved_tensors
        alpha = ctx.alpha
        sign = ctx.sign

        # `abs()` allocates a fresh buffer; in-place ops from here on touch
        # only that buffer, never the saved `x`.
        neg_abs = x.abs().neg_()
        grad_r = torch.softmax(neg_abs, dim=1)  # q
        neg_abs.exp_()
        grad_k = torch.sum(neg_abs, dim=1).mul_(0.5)  # R

        # wgrad, again built in a fresh buffer instead of overwriting `qx`.
        wgrad = qx.reciprocal().sub_(1).mul_(grad_output)
        wsum = wgrad.sum(dim=1, keepdim=True)

        # Gradients
        grad_k.reciprocal_().mul_(wsum.squeeze(1)).mul_(abs(sign))
        grad_r.mul_(wsum).sub_(wgrad).mul_(-sign / alpha)

        grad_alpha = None
        if ctx.needs_input_grad[2]:
            grad_alpha = torch.sum(r * grad_r).div_(-alpha)

        return grad_r, grad_k, grad_alpha, None


def log_soft_top_k(r, k, alpha, descending=False):
    """Functional entry point: apply ``LogSoftTopK`` to the given arguments.

    Args:
        r: Scores, forwarded unchanged.
        k: Per-row targets, forwarded unchanged.
        alpha: Scale, forwarded unchanged.
        descending: Direction flag, forwarded unchanged.

    Returns:
        Whatever ``LogSoftTopK.apply`` returns for these arguments.
    """
    log_p = LogSoftTopK.apply(r, k, alpha, descending)
    return log_p

## Test

In [None]:
def numerical_vjp(x, k, alpha, descending, v, h=1e-5):
    """Approximate the vector-Jacobian product of ``log_soft_top_k`` w.r.t. ``x``.

    Each entry of ``x`` is perturbed by ``+h`` and ``-h`` in turn and the
    central difference of the outputs is contracted with ``v``.

    Args:
        x: Input tensor passed as the first argument of ``log_soft_top_k``.
        k: Forwarded to ``log_soft_top_k``.
        alpha: Forwarded to ``log_soft_top_k``.
        descending: Forwarded to ``log_soft_top_k``.
        v: Tensor contracted with the output difference (flattened dot product).
        h: Finite-difference step.

    Returns:
        Tensor shaped like ``x`` holding the approximated product; it carries
        no autograd graph.
    """
    grad_approx = torch.zeros_like(x)
    flat_grad = grad_approx.view(-1)
    flat_v = v.flatten()

    # Only forward values are needed. Without no_grad every evaluation on a
    # `requires_grad` input would record a graph (with its saved tensors) that
    # stays alive through `grad_approx` for the whole loop.
    with torch.no_grad():
        for i in range(x.numel()):
            e = torch.zeros_like(x).view(-1)
            e[i] = h  # Perturb one dimension at a time
            e = e.view_as(x)  # Reshape back to original shape

            diff = log_soft_top_k(x + e, k, alpha, descending) - log_soft_top_k(
                x - e, k, alpha, descending
            )
            flat_grad[i] = torch.dot(flat_v, diff.flatten()) / (2 * h)
    return grad_approx


def check_value(x, v, text):
    """Print whether ``x`` matches ``v`` and at which absolute tolerance.

    Tolerances ``1e-15`` up to ``1e-1`` are tried in increasing order (with a
    fixed ``rtol=1e-05``); the first one that passes is reported. If none
    passes, the norm of the difference and both values are printed.

    Args:
        x: Obtained value; a ``torch.Tensor`` selects the torch routines,
            anything else is handled with NumPy.
        v: Expected value, same shape as ``x``.
        text: Label included in the printed line.
    """
    assert x.shape == v.shape, f"Shape mismatch: {x.shape} vs {v.shape}"

    if isinstance(x, torch.Tensor):
        is_close, norm = torch.allclose, torch.linalg.norm
    else:
        is_close, norm = np.allclose, np.linalg.norm

    # Smallest exponent in [-15, -1] whose tolerance is satisfied, if any.
    passed_exp = next(
        (
            exponent
            for exponent in range(-15, 0)
            if is_close(x, v, rtol=1e-05, atol=10**exponent)
        ),
        None,
    )

    if passed_exp is not None:
        print(f"✅ - {text} (Error within atol=1e{passed_exp})")
        return

    print(f"❌ - {text} [dist: {norm(x - v):.4f}]")
    print(f"Expected: {v}")
    print(f"Got: {x}")


def print_time_stats(times, name):
    """Print average, minimum, maximum and the full list of timings.

    Nothing is printed when ``times`` is empty.

    Args:
        times: Sequence of durations in seconds.
        name: Label used in the header line.
    """
    if not times:
        return

    mean_t = sum(times) / len(times)
    fastest = min(times)
    slowest = max(times)
    formatted = [f"{t:.4f}" for t in times]

    report = (
        f"\n{name} time stats (seconds):",
        f"\033[0;1;35m  Average: {mean_t:.4f}\033[0m",
        f"  Min:     {fastest:.4f}",
        f"  Max:     {slowest:.4f}",
        f"  All times: {formatted}",
    )
    for line in report:
        print(line)

In [None]:
torch.set_default_dtype(torch.float64)

# ==============  Parameters  =================
use_gpu = True  # set to False to stay on the CPU even when CUDA is available
descending = True  # set to False to exercise the other direction
h = 1e-5  # finite-difference step
bs = 3  # number of rows of r
n = 500  # number of columns of r
# =============================================

device = torch.device("cuda" if use_gpu and torch.cuda.is_available() else "cpu")
print(f"{device=}\n")

factory_kwargs = {"device": device, "requires_grad": True}

for i in range(3):
    # Fresh random problem for every repetition.
    alpha = torch.tensor(np.random.rand(), **factory_kwargs)
    r = torch.randn(bs, n, **factory_kwargs)
    k = torch.tensor(np.random.rand(bs) * n, dtype=r.dtype, **factory_kwargs)

    print(f"bs={bs}, n={n}, alpha={alpha.item()}")
    assert (
        alpha.dtype == k.dtype == r.dtype
    ), f"You have different types of tensors: {alpha.dtype=}, {k.dtype=}, {r.dtype=}"

    # Vector contracted with the Jacobian in the backward pass.
    v = torch.randn_like(r)

    # ---- Forward pass (timed; CUDA work is awaited before stopping the clock)
    start_forward = time.perf_counter()
    prob = log_soft_top_k(r, k, alpha, descending)
    if device.type == "cuda":
        torch.cuda.synchronize()
    forward_time = time.perf_counter() - start_forward
    print(f"\033[0;32mForward pass time: {forward_time:.4g} s\033[0m")

    # exp(prob) summed over each row is compared against k.
    test_sum = torch.logsumexp(prob, dim=-1).exp()
    check_value(test_sum, k, "test sum")

    # ======================================================
    print("=" * 10, "Gradients", "=" * 10, sep="   ")

    # ---- Backward pass (timed the same way)
    start_backward = time.perf_counter()
    r.grad = None  # Clear gradients
    k.grad = None
    alpha.grad = None
    prob.backward(v)
    if device.type == "cuda":
        torch.cuda.synchronize()
    backward_time = time.perf_counter() - start_backward
    print(f"\033[0;34mBackward pass time: {backward_time:.4g} s\033[0m")
    print(f"\033[0;33mTotal time: {forward_time + backward_time:.4g} s\033[0m")

    # NOTE: a torch.autograd.gradcheck-based check for r and k is disabled in
    # this notebook; the comparisons below use explicit central differences.

    # ---- Gradient w.r.t. r
    numerical_derivative = numerical_vjp(r, k, alpha, descending, v, h)
    check_value(r.grad, numerical_derivative, "grad r")

    # ---- Gradient w.r.t. k (one value per row)
    k_difference = log_soft_top_k(r, k + h, alpha, descending) - log_soft_top_k(
        r, k - h, alpha, descending
    )
    numerical_k_grad = (torch.mul(v, k_difference) / (2 * h)).sum(1)
    check_value(k.grad, numerical_k_grad, "grad k")

    # ---- Gradient w.r.t. alpha (a single scalar)
    alpha_difference = log_soft_top_k(r, k, alpha + h, descending) - log_soft_top_k(
        r, k, alpha - h, descending
    )
    numerical_alpha_grad = torch.mul(v, alpha_difference) / (2 * h)
    check_value(alpha.grad, numerical_alpha_grad.sum(), "grad alpha")
    print()