In [2]:
import torch
from torch.nn import functional as F
from torch.distributions import Categorical
from torch import Tensor

In [3]:
def sample_t(data: Tensor, n_steps=None) -> Tensor:
    if n_steps == 0 or n_steps is None:
        t = torch.rand(data.size(0), device=data.device).unsqueeze(-1)
    else:
        t = (
            torch.randint(
                0, n_steps, (data.size(0),), device=data.device
            ).unsqueeze(-1)
            / n_steps
        )
    t = (torch.ones_like(data).flatten(start_dim=1) * t).reshape_as(data)
    return t

In [4]:
batch_size, seq_len, K = 2, 3, 4
beta_1 = 20.4054 / K

In [5]:
ground_truth = torch.randint(0, K, (batch_size, seq_len))
mock_logits = torch.randn(batch_size, seq_len, K)

In [6]:
t = sample_t(ground_truth)

In [7]:
t

tensor([[0.9959, 0.9959, 0.9959],
        [0.0521, 0.0521, 0.0521]])

In [8]:
cat = Categorical(logits=mock_logits)
cat_probs = cat.probs

In [9]:
softmax_probs = torch.softmax(mock_logits, dim=-1)

In [10]:
assert torch.allclose(cat_probs, softmax_probs)

In [11]:
target = F.one_hot(ground_truth, num_classes=K).float()

In [12]:
kl = K * (target - softmax_probs).square().sum(-1)

In [13]:
kl.sum()

tensor(18.5951)

In [14]:
torch.mean(t * beta_1 * kl)

tensor(9.1797)

In [15]:
t_batch, _ = torch.max(t, dim=-1)

In [16]:
torch.mean(
    torch.sum((target - softmax_probs) ** 2, dim=(-2, -1)) * K * beta_1 * t_batch / seq_len
)

tensor(9.1797)

In [18]:
result = torch.sum(K * beta_1 * t_batch * torch.sum((target - softmax_probs) ** 2) / (batch_size**2))
result = result / seq_len
print(result)

tensor(8.2842)
