
Add utilities to distributions #1583

@noahfarr


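In PyTorch, the following is straightforward:

```python
from torch.distributions import Categorical

logits = ...  # unnormalized class scores, shape (..., num_classes)
dist = Categorical(logits=logits)
log_prob = dist.log_prob(value)  # log-probability of the class indices in `value`
entropy = dist.entropy()         # entropy of each distribution
```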

In MLX, however, I have to compute the log-probability and entropy by hand. Would it be possible to add support for these methods? It would make working with distributions in MLX much more convenient (at least for me).
