In [8]:
import torch

def variation_ratio(logits: torch.Tensor) -> torch.Tensor:
    """
    Computes the variation ratio for each sample in a Bayesian model. The variation ratio is
        the proportion of predicted class labels that are not the modal class prediction. 
        Specifically, it is computed as the number of ensemble members that give a different 
        prediction than the modal prediction, divided by the total number of ensemble members.

    Args:
        logits: A tensor of shape (n_samples, n_members, n_classes) representing the logits
            of a Bayesian model, where n_samples is the number of input samples, n_members is
            the number of members in the ensemble, and n_classes is the number of output classes.

    Returns:
        A tensor of shape (n_samples,) containing the variation ratio for each sample.

    Raises:
        ValueError: If logits tensor is not 3-dimensional.

    """
    if logits.ndim != 3:
        raise ValueError(f"Input logits tensor must be 3-dimensional, got shape {logits.shape}")
    n_member = logits.size(1)

    preds_classes = logits.argmax(dim=-1)
    print(preds_classes)

    # Compute the modal prediction for each sample
    modal_preds, _ = torch.mode(preds_classes, dim=1)
    print(modal_preds)

    # Compute a binary mask indicating different predictions
    diff_mask = (preds_classes != modal_preds.unsqueeze(dim=1))

    # Compute the variation ratio
    num_diff = torch.sum(diff_mask, dim=1)
    var_ratio = num_diff.float() / n_member

    return var_ratio

torch.manual_seed(0)
logits = torch.randn(10, 5, 3)
variation_ratio(logits)


tensor([[2, 1, 2, 1, 1],
        [0, 1, 2, 0, 2],
        [0, 2, 2, 2, 2],
        [0, 2, 0, 0, 2],
        [2, 2, 1, 0, 2],
        [2, 1, 0, 2, 2],
        [1, 2, 0, 0, 1],
        [2, 2, 1, 2, 0],
        [2, 2, 1, 2, 0],
        [0, 0, 2, 1, 1]])
tensor([1, 0, 2, 0, 2, 2, 0, 2, 2, 0])


tensor([0.4000, 0.6000, 0.2000, 0.4000, 0.4000, 0.4000, 0.6000, 0.4000, 0.4000,
        0.6000])