In [None]:
import torch

def batch_omp(X, D, sparsity):
    """
    Batched greedy sparse coding in the style of Orthogonal Matching Pursuit.

    For every signal, `sparsity` atoms are picked one at a time: the atom with
    the largest absolute inner product with the current residual is selected,
    its coefficient is the projection of the residual onto that atom, and the
    atom's contribution is subtracted from the residual.

    Note:
        This is the simplified variant: coefficients of previously selected
        atoms are NOT jointly re-fitted by least squares after each pick, and
        an atom may be selected more than once for the same signal. Selection
        uses the raw inner product, so atoms are ranked on equal footing only
        when the columns of D have equal norm.

    Args:
        X (torch.Tensor): Input signals of shape (B, M).
        D (torch.Tensor): Dictionary of shape (M, N), where each column is an atom of dimension M.
        sparsity (int): Number of atoms to select.

    Returns:
        support: (B, sparsity) LongTensor with indices of selected atoms.
        coeffs: (B, sparsity) Tensor (same dtype as X) with the corresponding coefficients.
        Y_hat: (B, M) Reconstructed signals from the sparse codes.
    """
    B = X.shape[0]
    # Row view of the dictionary, (N, M): indexing it with atom indices
    # gathers whole atoms.
    Dt = D.t()

    # Containers for support indices and coefficients. coeffs follows X's
    # dtype so that non-default precisions (e.g. float64) are not truncated.
    support = torch.zeros((B, sparsity), dtype=torch.long, device=X.device)
    coeffs = torch.zeros((B, sparsity), dtype=X.dtype, device=X.device)
    residual = X.clone()  # keep the caller's X untouched by the in-place updates

    for k in range(sparsity):
        # Correlate every residual with every atom: (B, M) @ (M, N) -> (B, N).
        # The product must use D itself; multiplying by its transpose only
        # type-checks when M == N and then correlates with the wrong vectors.
        projections = residual.matmul(D)

        # Select the best-matching atom for each sample
        idx = torch.argmax(torch.abs(projections), dim=1)  # (B,)
        support[:, k] = idx

        # Gather the chosen atom of every sample
        Dk = Dt[idx]  # (B, M)

        # Projection coefficient <r, d> / <d, d>. A zero-norm atom has a zero
        # numerator as well, so dividing by 1 there yields 0 instead of NaN.
        energy = torch.sum(Dk * Dk, dim=1)
        safe_energy = torch.where(energy > 0, energy, torch.ones_like(energy))
        alpha = torch.sum(residual * Dk, dim=1) / safe_energy
        coeffs[:, k] = alpha

        # Remove the new atom's contribution from the residual
        residual -= Dk * alpha.unsqueeze(1)

    # Final reconstruction: weighted sum of the selected atoms.
    A_final = Dt[support]  # (B, sparsity, M)
    Y_hat = torch.sum(coeffs.unsqueeze(2) * A_final, dim=1)  # (B, M)

    return support, coeffs, Y_hat

# Example usage: encode a batch of random signals against a random dictionary.
B, M, N, sparsity = 32, 128, 512, 10
# Use the GPU when one is available; otherwise fall back to the CPU so the
# example still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
X = torch.randn(B, M, device=device)
D = torch.randn(M, N, device=device)  # Dictionary shape (M, N)
support, coeffs, Y_hat = batch_omp(X, D, sparsity)