In [1]:
from typing import Optional

import torch
from torch.distributions.categorical import Categorical
from torch import einsum
from einops import  reduce


class CategoricalMasked(Categorical):
    """Categorical distribution over actions, with optional masking of illegal ones.

    Args:
        logits: Unnormalized log-probabilities; must be 2-D
            ``(batch, nb_action)`` (enforced by the size unpacking below).
        mask: Optional tensor broadcastable with ``logits``. ``True`` marks a
            legal action and ``False`` an action to mask out. Non-bool masks
            (e.g. 0/1 integers) are converted with ``.bool()``. When ``None``,
            this behaves exactly like ``Categorical(logits=logits)``.
    """

    def __init__(self, logits: torch.Tensor, mask: Optional[torch.Tensor] = None):
        # torch.where requires a boolean condition; accept 0/1 masks as well.
        # (.bool() on an already-bool tensor returns the same tensor.)
        self.mask = None if mask is None else mask.bool()
        # Unpacking both raises on non-2-D logits and records the sizes.
        self.batch, self.nb_action = logits.size()
        if self.mask is not None:
            # Most negative *finite* value rather than -inf: softmax still gives
            # exactly 0 probability, but `logits * probs` in entropy() stays
            # finite (-inf * 0 would be NaN).
            # Created on logits' device so torch.where never mixes devices.
            self.mask_value = torch.tensor(
                torch.finfo(logits.dtype).min,
                dtype=logits.dtype,
                device=logits.device,
            )
            logits = torch.where(self.mask, logits, self.mask_value)
        super().__init__(logits=logits)

    def entropy(self):
        """Return the entropy of each batch row, counting only legal actions.

        Returns:
            Tensor of shape ``(batch,)``.
        """
        if self.mask is None:
            return super().entropy()
        # self.logits holds normalized log-probabilities, so this is p * log p.
        p_log_p = self.logits * self.probs
        # Masked actions contribute nothing to the sum.
        p_log_p = p_log_p.masked_fill(~self.mask, 0.0)
        return -p_log_p.sum(dim=-1)

In [14]:
# Baseline: a plain, unmasked Categorical.
# NOTE: Categorical's first positional parameter is `probs`, not `logits`, so
# these rows are unnormalized *probabilities* (normalized to 1/6, 1/3, 1/2).
a = Categorical
print(dir(a))

b = a(probs=torch.tensor([[1.0, 2.0, 3.0]] * 2))
print(b.entropy(), b.logits, b.probs, sep="\n")

['__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_extended_shape', '_get_checked_instance', '_new', '_validate_args', '_validate_sample', 'arg_constraints', 'batch_shape', 'cdf', 'entropy', 'enumerate_support', 'event_shape', 'expand', 'has_enumerate_support', 'has_rsample', 'icdf', 'log_prob', 'logits', 'mean', 'mode', 'param_shape', 'perplexity', 'probs', 'rsample', 'sample', 'sample_n', 'set_default_validate_args', 'stddev', 'support', 'variance']
tensor([1.0114, 1.0114])
tensor([[-1.7918, -1.0986, -0.6931],
        [-1.7918, -1.0986, -0.6931]])
tensor([[0.1667, 0.3333, 0.5000],
        [0.1667, 0.3333, 0.5000]])


In [11]:
class CategoricalMasking(Categorical):
    """Categorical distribution over actions with a mask of legal actions.

    Args:
        logits: Unnormalized log-probabilities; must be 2-D
            ``(batch, nb_action)`` (enforced by the size unpacking below).
        mask: Tensor broadcastable with ``logits``; ``True`` marks a legal
            action. Non-bool masks (e.g. 0/1 integers) are converted with
            ``.bool()``. Passing ``None`` (allowed by the annotation) falls
            back to a plain, unmasked ``Categorical``.
    """

    def __init__(self, logits: torch.Tensor, mask: Optional[torch.Tensor]):
        # torch.where needs a bool condition; None previously crashed here.
        self.mask = None if mask is None else mask.bool()
        self.batch, self.nb_action = logits.size()
        if self.mask is not None:
            # Finite minimum (not -inf) keeps `logits * probs` NaN-free in
            # entropy(); allocated on logits' device to avoid device mixing.
            self.mask_value = torch.tensor(
                torch.finfo(logits.dtype).min,
                dtype=logits.dtype,
                device=logits.device,
            )
            logits = torch.where(self.mask, logits, self.mask_value)
        super().__init__(logits=logits)

    def entropy(self):
        """Return the per-row entropy over legal actions only, shape ``(batch,)``."""
        if self.mask is None:
            return super().entropy()
        # self.logits are normalized log-probabilities, so this is p * log p.
        p_log_p = (self.logits * self.probs).masked_fill(~self.mask, 0.0)
        return -p_log_p.sum(dim=-1)


torch.manual_seed(0)  # make the random logits (and printed probs) reproducible

logits = torch.randn(2, 100, requires_grad=True)  # (batch=2, nb_action=100)
# Only the first three actions are legal in every row; False = masked out.
mask = torch.zeros(2, 100, dtype=torch.bool)
mask[:, :3] = True

dist = CategoricalMasking(logits=logits, mask=mask)

# Sanity checks: each row sums to 1 and masked actions get exactly zero mass.
assert torch.allclose(dist.probs.sum(dim=-1), torch.ones(2))
assert (dist.probs[~mask] == 0).all()

print(dist.probs)

tensor([[0.1594, 0.0198, 0.8208, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0430, 0.0455, 0.9115, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0

In [5]:
torch.manual_seed(0)  # reproducible logits across Restart & Run All

batch_size, nb_action = 2, 10000
logits_or_qvalues = torch.randn((batch_size, nb_action), requires_grad=True)
print(logits_or_qvalues)

# True = legal action, False = masked-out action.
mask = torch.zeros((batch_size, nb_action), dtype=torch.bool)
mask[0, 2] = True        # row 0: a single legal action
mask[1, [0, 1]] = True   # row 1: two legal actions
print(mask)


tensor([[ 1.2366e-01, -7.6685e-01,  2.4841e-01,  ...,  8.3988e-01,
          2.2024e+00,  6.5887e-01],
        [-5.1119e-01, -1.0158e+00, -1.6483e-03,  ...,  6.6559e-01,
         -5.9711e-02,  5.7094e-01]], requires_grad=True)
tensor([[False, False,  True,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False]])


In [6]:
head = CategoricalMasked(logits=logits_or_qvalues)
print(head.probs)  # impossible actions are NOT masked: every action keeps some mass

head_masked = CategoricalMasked(logits=logits_or_qvalues, mask=mask)
print(head_masked.probs)  # impossible actions are masked: zero probability

# Check the claims above instead of hard-coding numbers from an older run.
assert (head.probs > 0).all()
assert (head_masked.probs[~mask] == 0).all()
assert torch.allclose(head_masked.probs.sum(dim=-1), torch.ones(mask.shape[0]))

print(head.entropy())
print(head_masked.entropy())

# Masked entropy is bounded by log(#legal actions): 0 when only one action is
# legal (row 0 of `mask`), at most log 2 with two legal actions (row 1).
n_legal = mask.sum(dim=-1).to(head_masked.probs.dtype)
assert (head_masked.entropy() <= n_legal.log() + 1e-6).all()

tensor([[6.7861e-05, 2.7853e-05, 7.6876e-05,  ..., 1.3889e-04, 5.4248e-04,
         1.1589e-04],
        [3.6736e-05, 2.2179e-05, 6.1149e-05,  ..., 1.1917e-04, 5.7699e-05,
         1.0841e-04]], grad_fn=<SoftmaxBackward0>)
tensor([[0.0000, 0.0000, 1.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.6235, 0.3765, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],
       grad_fn=<SoftmaxBackward0>)
tensor([8.7067, 8.7181], grad_fn=<NegBackward0>)
tensor([-0.0000, 0.6623], grad_fn=<NegBackward0>)


In [15]:
7 * 44 ** 2  # scratch calculation (= 13552); unrelated to masking — TODO explain or remove

13552