In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Dirichlet
from torch.distributions import Independent

In [30]:
# Candidate concentration values: 30 rows of 20 standard-normal draws.
# Constant alternative kept for reference: 0.1 * torch.ones((30, 20))
random_x = torch.randn(30, 20)

In [31]:
# random_x is drawn with torch.randn, so it contains negative entries, which
# are not valid Dirichlet concentrations. Argument validation is switched off
# explicitly so this cell also runs on PyTorch builds that validate by default
# and the resulting (invalid) samples can be inspected in the following cells.
head = Dirichlet(random_x, validate_args=False)

In [32]:
# Draw a single reparameterized sample from the distribution built above.
out = head.rsample()

In [33]:
# Inspect the shape of the sampled tensor.
out.shape

torch.Size([30, 20])

In [34]:
# Sum every sample over its last dimension; a valid Dirichlet sample sums to 1.
torch.sum(out, dim=-1)

tensor([1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 2.3510e-37, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 2.3510e-37, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00])

In [35]:
# True only when every row of `out` sums to 1 within an absolute tolerance of 1e-4.
bool(torch.all(torch.abs(out.sum(dim=-1) - 1) < 1.e-4))

False

## エラーが出るようなパラメータを求める

以下の例はGoogle Colabではエラーが出るが，Python 3.7環境ではエラーが出ない（PyTorchのバージョンによって引数検証の既定動作が異なるためと思われる）．

In [7]:
# Search for concentration tensors that make building the Dirichlet
# distribution, or sampling from it, raise an exception. Every offending
# tensor is printed. Argument validation is deliberately left at the library
# default here, because this cell probes for errors.
counter_lim = 1000
counter = 0
while counter < counter_lim:  # run exactly counter_lim trials
    random_x = torch.randn((30, 20))
    try:
        head = Dirichlet(random_x)
        out = head.rsample()
    except Exception:
        # Catch Exception instead of using a bare `except:` so that
        # interrupting the kernel (KeyboardInterrupt) still stops the loop.
        print(random_x)
    counter += 1

## サンプリング値がおかしい場合

In [40]:
# Search for concentration tensors whose sample does not sum to 1 and print them.
# torch.randn yields negative entries, which argument validation would reject
# before any sample could be drawn, so validation is disabled explicitly to
# keep the sampling behaviour observable on every PyTorch version.
counter_lim = 1000
counter = 0
while counter < counter_lim:  # run exactly counter_lim trials
    random_x = torch.randn((1, 20))
    head = Dirichlet(random_x, validate_args=False)
    out = head.rsample()
    # Per-row check: |sum - 1| must stay below the 1e-4 tolerance.
    sums_to_one = torch.abs(out.sum(dim=-1) - 1) < 1.e-4
    if not bool(sums_to_one.all()):
        print(random_x)
    counter += 1

tensor([[ 1.9957e-01,  3.2629e+00,  4.0906e-01, -8.4434e-01, -8.1881e-01,
          1.1297e+00, -3.5708e-01, -6.6563e-01,  5.9348e-01,  1.6351e-01,
         -2.3143e+00,  1.4304e+00, -9.5094e-01,  1.4831e-01,  4.4692e-01,
         -1.2896e-03, -5.2810e-02,  2.3991e-01,  1.4483e+00,  1.0790e+00]])
tensor([[-0.0215, -0.2960, -0.8306,  0.8585, -1.4803,  0.8773, -0.4882,  0.0924,
         -1.2277, -1.4266, -1.6086,  0.6579, -0.1315, -1.9777, -0.8895,  1.0256,
         -0.0031, -0.9086,  0.8029,  1.0828]])
tensor([[ 1.3403e+00,  1.4186e+00,  1.0830e+00,  6.2232e-01, -6.9022e-01,
         -5.8756e-01,  1.9541e+00,  8.2538e-01,  1.8332e-01,  4.5170e-01,
         -1.0230e+00,  5.3085e-01, -1.4168e-01, -9.4577e-04,  1.1987e+00,
         -5.4771e-01, -7.5666e-01,  1.8050e+00, -2.6698e-01,  2.7226e+00]])
tensor([[-4.2373e-01,  1.9527e+00, -2.3924e+00,  3.8037e-01, -1.3410e+00,
         -7.6841e-01, -4.4676e-01,  2.8070e-01, -5.9809e-04,  8.3449e-01,
         -1.2968e+00,  1.9828e-01, -3.6699e-01,

Dirichlet分布のパラメータalphaは正の値でなければならないので，absを取って負の値をなくせばよい

In [41]:
# Repeat the sum-to-one search with the absolute value of the random draws,
# so that no concentration entry is negative. Any tensor whose sample still
# fails the check is printed.
counter_lim = 1000
counter = 0
while counter < counter_lim:  # run exactly counter_lim trials
    random_x = torch.abs(torch.randn((1, 20)))
    head = Dirichlet(random_x)
    out = head.rsample()
    # Per-row check: |sum - 1| must stay below the 1e-4 tolerance.
    sums_to_one = torch.abs(out.sum(dim=-1) - 1) < 1.e-4
    if not bool(sums_to_one.all()):
        print(random_x)
    counter += 1

alphaがすべて0のときは大丈夫（サンプルの和は1になる）

In [44]:
# Check the all-zero concentration case. Zero is not a valid (strictly
# positive) concentration, so argument validation is disabled explicitly;
# otherwise the constructor raises on PyTorch builds that validate by default.
random_x = torch.zeros((30, 20))
head = Dirichlet(random_x, validate_args=False)
out = head.rsample()
# Each row of the sample should still sum to 1.
out.sum(dim=-1)

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000])

## モデルを利用する場合

In [7]:
class DirichletHead(nn.Module):
    """Parameter-free module that wraps a tensor in a Dirichlet distribution."""

    def __init__(self):
        super().__init__()

    def forward(self, alpha):
        """Return a ``torch.distributions.Dirichlet`` with concentration ``alpha``."""
        distribution = torch.distributions.Dirichlet(concentration=alpha)
        return distribution

In [8]:
class Policy(nn.Module):
    """Convolutional policy that turns an image batch into a Dirichlet distribution.

    Three strided convolutions, each followed by batch normalisation, feed a
    3x3 average pooling layer. The pooled tensor is reshaped into rows of
    ``out_number`` values, clamped to [0.01, 10] and handed to
    ``DirichletHead`` as concentration parameters.

    Args:
        in_channels: number of channels of the input images.
        out_number: number of Dirichlet concentration values per row.
    """

    def __init__(self, in_channels=3, out_number=20):
        super().__init__()
        self.out_number = out_number
        channels = self.out_number

        # Feature extractor: conv -> batch norm, three times.
        self.conv1 = nn.Conv2d(in_channels, 12, kernel_size=5, stride=2, padding=2)
        self.bn1 = nn.BatchNorm2d(12)
        self.conv2 = nn.Conv2d(12, channels, kernel_size=5, stride=(2, 3), padding=2)
        self.bn2 = nn.BatchNorm2d(channels)
        self.conv3 = nn.Conv2d(
            channels,
            channels,
            kernel_size=(4, 3),
            stride=(2, 3),
            padding=(2, 1),
        )
        self.bn3 = nn.BatchNorm2d(channels)

        self.avgpool = nn.AvgPool2d(kernel_size=3)
        self.head = DirichletHead()

    def forward(self, x):
        """Return the Dirichlet distribution predicted for the batch ``x``."""
        features = F.relu(self.bn1(self.conv1(x)))
        features = F.relu(self.bn2(self.conv2(features)))
        # The third block has no ReLU of its own; positivity is enforced below.
        features = self.bn3(self.conv3(features))
        pooled = self.avgpool(features)

        # One row of out_number values per sample when the pooled map is 1x1
        # (the case for the 20x50 inputs used in this notebook).
        alpha = pooled.reshape(-1, self.out_number)

        # ReLU followed by clamping keeps every concentration within [0.01, 10].
        alpha = torch.clamp(F.relu(alpha), min=0.01, max=10)
        return self.head(alpha)

In [9]:
# Dummy input batch (30 three-channel images of size 20x50) and a fresh policy.
random_x = torch.randn(30, 3, 20, 50)
policy = Policy(in_channels=3, out_number=20)

In [10]:
# Run the policy, then show the distribution's batch/event shapes and the
# shape of one reparameterized sample.
out = policy(random_x)
print(out.batch_shape, out.event_shape)
print(out.rsample().shape)

torch.Size([30]) torch.Size([20])
torch.Size([30, 20])
