In [1]:
import torch
import torch.nn.functional as F

In [52]:
import torch

def smoothen(
    labels: "Int[torch.Tensor, 'N L']",
    n_quantiles: int,
    sigma: float,
    threshold: int
) -> "Float[torch.Tensor, 'N L ?']":
    """Turn integer labels into per-position target distributions over bins.

    Labels ``<= threshold`` become hard one-hot vectors.  Labels
    ``> threshold`` become a Gaussian (std ``sigma``, in bins) centred on the
    label. It is restricted to bins ``threshold + 1 .. K - 1`` and
    renormalised, so no probability mass bleeds into the one-hot-only bins
    ``0 .. threshold``.

    Args:
        labels: Integer label tensor, typically ``(N, L)``; any shape and any
            integer dtype is accepted.  Every label must lie in ``[0, K)``.
        n_quantiles: Together with ``threshold`` fixes the number of output
            bins, ``K = n_quantiles + threshold``.
        sigma: Standard deviation of the Gaussian, in bins.  Must be > 0.
        threshold: Labels at or below this value are one-hot encoded.

    Returns:
        float32 tensor of shape ``labels.shape + (K,)``.  Each row sums to 1,
        except that rows for labels above the threshold are NaN when no bin
        lies above the threshold (``n_quantiles <= 1``).

    Raises:
        ValueError: if ``sigma`` is not positive.

    NOTE(review): only bins ``threshold + 1 .. n_quantiles + threshold - 1``
    (``n_quantiles - 1`` of them) can receive smoothed mass.  Confirm this
    off-by-one relative to the name ``n_quantiles`` is intended.
    """
    if sigma <= 0:
        # sigma == 0 divides 0 by 0 at the label's own bin, silently yielding NaN.
        raise ValueError(f"sigma must be positive, got {sigma!r}")

    n_bins = n_quantiles + threshold
    # Bin indices 0..K-1 as floats, shape (K,).  Broadcasting against the
    # (..., 1) label centres replaces an explicit expand to (N, L, K).
    bins = torch.arange(n_bins, device=labels.device, dtype=torch.float32)
    centres = labels.to(torch.float32).unsqueeze(-1)

    # Hard targets.  one_hot (like scatter_) needs an int64 index, so cast;
    # otherwise int32/uint8 label tensors are rejected.
    one_hot = F.one_hot(labels.long(), num_classes=n_bins).to(torch.float32)

    # Gaussian around each label, zeroed on the one-hot-only bins, then
    # normalised over the remaining bins.  A single normalisation suffices.
    gaussian = torch.exp(-0.5 * ((bins - centres) / sigma) ** 2)
    gaussian = gaussian * (bins >= threshold + 1)
    gaussian = gaussian / gaussian.sum(dim=-1, keepdim=True)

    # Pick the smoothed row only where the label is above the threshold.
    # Rows at or below it keep their one-hot; their gaussian row is discarded.
    return torch.where((labels > threshold).unsqueeze(-1), gaussian, one_hot)

# Example usage
n_quantiles = 5  # Together with threshold: 5 + 3 = 8 output bins
labels = torch.tensor([[0, 2, 3, 6], [1, 2, 4, 3]])  # Example (N=2, L=4) label batch (on the CPU)
sigma = 1.0  # Smoothing parameter: std of the Gaussian, in bins
threshold = 3  # Only smooth labels above this threshold; labels <= 3 stay one-hot

smoothed_labels = smoothen(labels, n_quantiles, sigma, threshold)
print(smoothed_labels)


tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0576, 0.2583, 0.4258, 0.2583]],

        [[0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.5705, 0.3460, 0.0772, 0.0063],
         [0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000]]])
