In [None]:
import math
import torch


def rotate_dist(input_dist: torch.Tensor, rotations: torch.Tensor) -> torch.Tensor:
    """
    Rotate discrete circular distributions by the specified rotations (in radians).

    The 65 bins form a *periodic* grid with spacing 2*pi/65 covering one full
    turn (conventionally [-pi, pi); because the grid is periodic, the absolute
    offset does not affect the result). Note this is NOT
    ``torch.linspace(-pi, pi, 65)``: that grid would give -pi and +pi, which are
    the same angle, two different bins.

    Output bin ``j`` takes the linearly interpolated input value at angle
    ``theta_j - rotation``. For a positive rotation, mass therefore moves towards
    higher bin indices, wrapping around modulo 65.

    Arguments:
    ----------
    input_dist : (B, 65) tensor
        Each row is a discrete distribution over 65 equally spaced angles
        covering one full turn.
    rotations : tensor (or number) with either B elements, e.g. shape (B,) or
        (B, 1), or a single element that is applied to every row.
        The rotation (in radians) for each distribution in the batch.

    Returns:
    --------
    rotated : (B, 65) tensor
        The input distributions rotated by the specified angles, using linear
        interpolation between neighbouring bins. Each row's sum is preserved up
        to floating-point error.

    Raises:
    -------
    ValueError
        If the number of rotations is neither 1 nor B.
    """
    B, n_bins = input_dist.shape
    assert n_bins == 65, "Expected distributions of size 65 along axis=1."

    # as_tensor accepts Python numbers; reshape (unlike view) accepts any layout.
    # The rotations must live on the same device as the input for the math below.
    rotations = torch.as_tensor(rotations, device=input_dist.device).reshape(-1)
    if rotations.numel() == 1:
        # Broadcast a single rotation to every row. Without this, gather() would
        # receive a (1, n_bins) index and silently return only the first row.
        rotations = rotations.expand(B)
    elif rotations.numel() != B:
        raise ValueError(f"Expected 1 or {B} rotations, got {rotations.numel()}.")

    # 1) Periodic angle grid: bin j sits at j * dtheta, with dtheta = 2*pi / 65.
    dtheta = 2.0 * math.pi / n_bins
    angles = (
        torch.arange(n_bins, device=input_dist.device, dtype=input_dist.dtype) * dtheta
    )

    # 2) Output bin j samples the input at angle angles[j] - rotation.  (B, n_bins)
    target_angles = angles.unsqueeze(0) - rotations.unsqueeze(1)

    # 3) Fractional source-bin positions. Rounding to 4 decimals snaps values such
    #    as 2.9999997 back to 3, so rotations by exact multiples of dtheta land on
    #    whole bins. Fractional positions are thus quantized to 1e-4 of a bin.
    float_indices = torch.round(target_angles / dtheta, decimals=4)

    # 4) Lower and upper neighbouring bins, wrapped onto the circle. Tensor `%`
    #    (torch.remainder) takes the sign of the divisor, so negative indices
    #    wrap to the end of the grid.
    idx0 = torch.floor(float_indices).long()
    idx1 = idx0 + 1
    idx0_mod = idx0 % n_bins
    idx1_mod = idx1 % n_bins

    # 5) Linear interpolation weights, kept in float_indices' dtype so that
    #    low-precision inputs are not silently upcast.
    w1 = float_indices - idx0.to(float_indices.dtype)
    w0 = 1.0 - w1

    # 6) Gather the neighbouring input values and 7) interpolate between them.
    dist_gather_0 = input_dist.gather(1, idx0_mod)
    dist_gather_1 = input_dist.gather(1, idx1_mod)
    return w0 * dist_gather_0 + w1 * dist_gather_1

In [None]:
# Example usage:
import matplotlib.pyplot as plt

torch.manual_seed(0)  # reproducible fake data

batch_size = 1

# Fake data: random, sharply peaked distributions of size (batch_size, 65),
# normalized so each row sums to 1 (for demonstration only).
p = torch.rand(batch_size, 65) ** 10
p = p / p.sum(dim=1, keepdim=True)

# Rotation angles. Random rotations are uniform in [-pi, pi). By default the demo
# uses a fixed rotation of pi (half a turn) so the shift is easy to check by eye.
USE_RANDOM_ROTATIONS = False
if USE_RANDOM_ROTATIONS:
    alpha = 2 * math.pi * torch.rand(batch_size) - math.pi
else:
    alpha = torch.full((batch_size,), math.pi)

q = rotate_dist(p, alpha)  # q is the rotated distribution

print("p shape:", p.shape)  # (batch_size, 65)
print("alpha shape:", alpha.shape)  # (batch_size,)
print("q shape:", q.shape)  # (batch_size, 65)

# One row per batch element, one column per angle bin.
fig, ax = plt.subplots()
ax.imshow(p, aspect="auto")
ax.set(title="Input distribution p", xlabel="angle bin", ylabel="batch index");

In [None]:
# Rotated distribution(s) q: one row per batch element, one column per angle bin.
fig, ax = plt.subplots()
ax.imshow(q, aspect="auto")
ax.set(title="Rotated distribution q", xlabel="angle bin", ylabel="batch index");

In [None]:
# Sanity check: rotate_dist wraps bin indices with tensor `%` (torch.remainder),
# whose result takes the sign of the divisor (Python semantics, unlike
# torch.fmod), so negative indices wrap to the end of the 65-bin grid.
wrapped = torch.tensor([-1, 0, 64, 65, -66]) % 65
assert wrapped.tolist() == [64, 0, 64, 0, 64]
wrapped

In [None]:
# Quick test for the 2π identity: rotating by a full turn must return the input.

B = 4
N = 65

# Four "one-hot" distributions, each peaking in a different bin. The set includes
# the last bin (64) to exercise wrap-around.
peak_bins = [0, 30, 64, 10]
p_test = torch.zeros(B, N)
p_test[torch.arange(B), peak_bins] = 1.0

alpha_2pi = torch.full((B,), 2.0 * math.pi, dtype=torch.float32)

q_test = rotate_dist(p_test, alpha_2pi)

# q_test should match p_test up to tiny floating-point error.
print("Peak bins of p:", p_test.argmax(dim=1).tolist())
print("Peak bins of q:", q_test.argmax(dim=1).tolist())
print("Difference:", (q_test - p_test).abs().sum().item())
torch.testing.assert_close(q_test, p_test)