In [None]:
import torch

def categorical_l2_project(
    atoms: torch.Tensor,
    target_z: torch.Tensor,
    target_p: torch.Tensor,
) -> torch.Tensor:
    """Projects a batch of target distributions onto a fixed support by linear interpolation.

    Every target point is located between its two neighbouring atoms with
    `torch.searchsorted`, and its probability is split between those two atoms
    in proportion to its distance from each. Target points that fall below
    `atoms[0]` or above `atoms[-1]` put all of their probability on that
    boundary atom. The split weights always sum to exactly 1, so the projected
    rows carry the same total mass as the rows of `target_p`.

    Note: a later cell of this notebook defines another function with the same
    name; whichever definition was executed last is the one in scope.

    Args:
        atoms: 1D tensor of support points, shape `(num_atoms,)`. Must be sorted in ascending order.
        target_z: 2D tensor of target support points, shape `(batch_size, num_target_atoms)`.
        target_p: 2D tensor of probabilities for `target_z`, shape `(batch_size, num_target_atoms)`.
            Must sum to 1 over the last dimension (valid probability distribution).

    Returns:
        Projected probabilities over `atoms`, shape `(batch_size, num_atoms)`.

    Raises:
        AssertionError: if the inputs do not have the ranks/shapes described above.
    """
    # Validate input shapes
    assert atoms.ndim == 1, f"atoms must be 1D, got {atoms.ndim}D"
    assert target_z.ndim == 2 and target_p.ndim == 2, \
        f"target_z and target_p must be 2D, got {target_z.ndim}D and {target_p.ndim}D"
    assert target_z.shape == target_p.shape, \
        f"target_z and target_p must have the same shape, got {target_z.shape} vs {target_p.shape}"

    num_atoms = atoms.shape[0]
    batch_size = target_z.shape[0]

    # With right=False the returned index b satisfies atoms[b-1] < target_z <= atoms[b]
    # (equivalent to jnp.searchsorted(..., side="left")).
    insert_idx = torch.searchsorted(atoms, target_z, right=False)  # (batch_size, num_target_atoms)

    # Neighbour indices, clamped so out-of-range targets collapse onto a boundary atom.
    lower_idx = torch.clamp(insert_idx - 1, 0, num_atoms - 1)
    upper_idx = torch.clamp(insert_idx, 0, num_atoms - 1)

    atoms_l = atoms[lower_idx]  # (batch_size, num_target_atoms)
    atoms_u = atoms[upper_idx]  # (batch_size, num_target_atoms)

    # After clamping, both neighbours coincide only when the target sits at or beyond
    # a boundary atom; that atom then receives the full weight.
    same_neighbor = (atoms_u == atoms_l)

    # Divide by the exact gap. The previous `gap + eps` normaliser made the two weights
    # sum to gap / (gap + eps) < 1, silently leaking probability mass. The zero-gap case
    # is masked out below, so a placeholder of 1 is enough to keep the division finite.
    gap = atoms_u - atoms_l
    safe_gap = torch.where(same_neighbor, torch.ones_like(gap), gap)
    w_l = torch.where(same_neighbor, torch.ones_like(gap), (atoms_u - target_z) / safe_gap)
    w_u = torch.where(same_neighbor, torch.zeros_like(gap), (target_z - atoms_l) / safe_gap)

    # Accumulate the split probabilities onto the fixed support.
    p_proj = torch.zeros(
        (batch_size, num_atoms),
        device=target_z.device,
        dtype=target_z.dtype
    )
    p_proj.scatter_add_(dim=1, index=lower_idx, src=w_l * target_p)
    p_proj.scatter_add_(dim=1, index=upper_idx, src=w_u * target_p)

    return p_proj

In [None]:
# DeepMind-style categorical projection (Cramér projection).
# NOTE: this redefines `categorical_l2_project` from the previous cell.

def categorical_l2_project(
    atoms: torch.Tensor,
    target_z: torch.Tensor,
    target_p: torch.Tensor,
) -> torch.Tensor:
    """Cramér (L2-on-CDF) projection of batched target distributions onto `atoms`.

    For every support atom, each (clipped) target point contributes
    `clamp(1 - |target - atom| / gap, 0, 1)` of its probability, where `gap` is
    the spacing between the atom and its neighbour on the side the target lies on.
    A non-positive gap (duplicate atoms, or the wrap-around produced by
    `torch.roll` at the two ends) is treated as an inverse gap of zero.

    Args:
        atoms: 1D tensor of support points, shape `(num_atoms,)`. Must be sorted in ascending order.
        target_z: 2D tensor of target support points, shape `(batch_size, num_target_atoms)`.
            Documented as sorted ascending per row; the computation itself sums
            over this axis and does not read the order.
        target_p: 2D tensor of probabilities for `target_z`, shape `(batch_size, num_target_atoms)`.
            Must sum to 1 over the last dimension (valid probability distribution).

    Returns:
        Projected probabilities over `atoms`, shape `(batch_size, num_atoms)`.
    """
    # Shape checks.
    assert atoms.ndim == 1, f"atoms must be 1D, got {atoms.ndim}D"
    assert target_z.ndim == 2 and target_p.ndim == 2, \
        f"target_z and target_p must be 2D, got {target_z.ndim}D and {target_p.ndim}D"
    assert target_z.shape == target_p.shape, \
        f"target_z and target_p must have the same shape, got {target_z.shape} vs {target_p.shape}"

    # Spacing to the right / left neighbour of every atom, as column vectors (num_atoms, 1).
    gap_right = (torch.roll(atoms, shifts=-1) - atoms).unsqueeze(-1)  # atoms[i+1] - atoms[i]
    gap_left = (atoms - torch.roll(atoms, shifts=1)).unsqueeze(-1)    # atoms[i] - atoms[i-1]
    support = atoms.unsqueeze(-1)

    # Reciprocal spacings; non-positive gaps map to 0 instead of dividing by zero.
    inv_gap_left = torch.where(gap_left > 0, 1. / gap_left, torch.zeros_like(gap_left))
    inv_gap_right = torch.where(gap_right > 0, 1. / gap_right, torch.zeros_like(gap_right))

    # Clip targets into [atoms[0], atoms[-1]] and make room for the atom axis.
    clipped_z = torch.clamp(target_z, atoms[0], atoms[-1]).unsqueeze(1)  # (batch_size, 1, num_target_atoms)

    # Signed offset of every target from every atom: (batch_size, num_atoms, num_target_atoms).
    offset = clipped_z - support[None, :, :]
    on_right = (offset >= 0.0).to(target_p.dtype)

    # Absolute offset measured in units of the gap on the side the target lies on.
    scaled_right = on_right * offset * inv_gap_right[None, :, :]
    scaled_left = (1.0 - on_right) * offset * inv_gap_left[None, :, :]
    scaled_offset = scaled_right - scaled_left

    # Triangular weights, then marginalise over the target atoms.
    weights = torch.clamp(1.0 - scaled_offset, 0.0, 1.0)
    probs = target_p.unsqueeze(1)  # (batch_size, 1, num_target_atoms)
    return torch.sum(weights * probs, dim=-1)  # (batch_size, num_atoms)

In [None]:
import torch
import torch.nn.functional as F

def categorical_q_learning(
    q_atoms_tm1: torch.Tensor,
    q_logits_tm1: torch.Tensor,
    a_tm1: int,
    r_t: torch.Tensor,
    discount_t: torch.Tensor,
    q_atoms_t: torch.Tensor,
    q_logits_t: torch.Tensor,
    stop_target_gradients: bool = True,
) -> torch.Tensor:
    """Implements Q-learning for categorical Q distributions.

    See "A Distributional Perspective on Reinforcement Learning", by
    Bellemare, Dabney and Munos (https://arxiv.org/pdf/1707.06887.pdf).

    The time-t atoms are shifted and scaled by the reward and discount, the
    distribution of the greedy action at time t is projected onto the time-(t-1)
    atoms with `categorical_l2_project`, and the loss is the cross-entropy between
    that projected target and the predicted distribution of the action taken.

    Args:
        q_atoms_tm1: atoms of Q distribution at time t-1, shape (num_atoms,).
        q_logits_tm1: logits of Q distribution at time t-1, shape (num_actions, num_atoms).
        a_tm1: action index at time t-1.
        r_t: reward at time t, scalar tensor.
        discount_t: discount at time t, scalar tensor.
        q_atoms_t: atoms of Q distribution at time t, shape (num_atoms,).
        q_logits_t: logits of Q distribution at time t, shape (num_actions, num_atoms).
        stop_target_gradients: bool indicating whether to apply stop gradient to targets.

    Returns:
        Categorical Q-learning loss (temporal difference error), a scalar tensor.

    Raises:
        AssertionError: if an argument has the wrong rank, type or a non-float dtype.
    """
    # Input validation (matching chex assertions)
    assert q_atoms_tm1.ndim == 1, f"q_atoms_tm1 must be 1D, got {q_atoms_tm1.ndim}D"
    assert q_logits_tm1.ndim == 2, f"q_logits_tm1 must be 2D, got {q_logits_tm1.ndim}D"
    assert isinstance(a_tm1, int), f"a_tm1 must be int, got {type(a_tm1)}"
    assert r_t.ndim == 0, f"r_t must be scalar, got {r_t.ndim}D"
    assert discount_t.ndim == 0, f"discount_t must be scalar, got {discount_t.ndim}D"
    assert q_atoms_t.ndim == 1, f"q_atoms_t must be 1D, got {q_atoms_t.ndim}D"
    assert q_logits_t.ndim == 2, f"q_logits_t must be 2D, got {q_logits_t.ndim}D"

    assert q_atoms_tm1.dtype.is_floating_point, "q_atoms_tm1 must be float"
    assert q_logits_tm1.dtype.is_floating_point, "q_logits_tm1 must be float"
    assert r_t.dtype.is_floating_point, "r_t must be float"
    assert discount_t.dtype.is_floating_point, "discount_t must be float"
    assert q_atoms_t.dtype.is_floating_point, "q_atoms_t must be float"
    assert q_logits_t.dtype.is_floating_point, "q_logits_t must be float"

    # Scale and shift time-t distribution atoms by discount and reward
    target_z = r_t + discount_t * q_atoms_t  # (num_atoms,)

    # Convert logits to distribution and find greedy action in state s_t
    q_t_probs = F.softmax(q_logits_t, dim=-1)  # (num_actions, num_atoms)
    # Mean Q-value per action: expectation of the atoms under each action's distribution
    q_t_mean = torch.sum(q_t_probs * q_atoms_t.unsqueeze(0), dim=1)  # (num_actions,)
    pi_t = torch.argmax(q_t_mean)  # Greedy action index

    # Distribution of the greedy action
    p_target_z = q_t_probs[pi_t]  # (num_atoms,)

    # Project the shifted distribution back onto the time-(t-1) support.
    # `categorical_l2_project` in this notebook takes (atoms, target_z, target_p) and
    # expects batched 2D targets, so the support goes first and a batch dimension of
    # size 1 is added and then removed again. (The previous call used the
    # (z_p, probs, z_q) order with 1D inputs and failed the projection's shape checks.)
    target = categorical_l2_project(
        q_atoms_tm1,
        target_z.unsqueeze(0),
        p_target_z.unsqueeze(0),
    ).squeeze(0)  # (num_atoms,)

    # Stop gradient flow to targets if required
    if stop_target_gradients:
        target = target.detach()

    # Cross-entropy between the target probabilities and the predicted
    # distribution (log-softmax of the logits) of the action actually taken
    logit_qa_tm1 = q_logits_tm1[a_tm1]  # (num_atoms,)
    loss = -torch.sum(target * F.log_softmax(logit_qa_tm1, dim=0))

    return loss