In [23]:
import math
import os
import torch
import torch.distributions.constraints as constraints
import pyro
from pyro.optim import Adam
from pyro.distributions.testing.fakes import NonreparameterizedBeta
from pyro.infer import SVI, TraceGraph_ELBO
import pyro.distributions as dist
import sys

**モデル**

ベルヌーイ分布について考える。

$$
p(x|f) = {\rm Bern}(x|f)
$$

事前分布にはBeta分布を用いる.

$$
p(f)= {\rm Beta}(f| \alpha_0, \beta_0)
$$

**ガイド**

事後分布は

$$
p(f|{\bf X})= {\rm Beta}(f| \alpha_q, \beta_q)
$$

In [30]:
def param_abs_error(name, target):
    return torch.sum(torch.abs(target - pyro.param(name))).item()

class BernoulliBetaExample:
    def __init__(self, data, max_steps):
        self.max_steps = max_steps
        # モデルの事前分布: ベータ分布のパラメータの初期値
        self.alpha0 = 10.0
        self.beta0 = 10.0

        # サイコロのデータ
        self.data = data
        self.n_data = self.data.size(0)
        
        # ガイド(事後分布): ベータ分布のパラメータの初期値
        self.alpha_q_0 = 15.0
        self.beta_q_0 = 15.0

        # 真の事後分布のパラメータを計算する
        self.alpha_n = self.data.sum() + self.alpha0
        self.beta_n = - self.data.sum() + torch.tensor(self.beta0 + self.n_data)

    def model(self, use_decaying_avg_baseline):
        f = pyro.sample("latent_fairness", dist.Beta(self.alpha0, self.beta0))
        with pyro.plate("data_plate"):
            pyro.sample("obs", dist.Bernoulli(f), obs=self.data)

    def guide(self, use_decaying_avg_baseline):
        alpha_q = pyro.param("alpha_q", torch.tensor(self.alpha_q_0), constraint=constraints.positive)
        beta_q = pyro.param("beta_q", torch.tensor(self.beta_q_0), constraint=constraints.positive)

        baseline_dict = {
            'use_decaying_avg_baseline': use_decaying_avg_baseline,
            'baseline_beta': 0.90
        }
        
        pyro.sample("latent_fairness", 
                    NonreparameterizedBeta(alpha_q, beta_q),
                    infer=dict(baseline=baseline_dict))

    def do_inference(self, use_decaying_avg_baseline, tolerance=0.80):
        pyro.clear_param_store()
        
        # optimizerと推論手法
        optimizer = Adam({"lr": .0005, "betas": (0.93, 0.999)})
        svi = SVI(self.model, self.guide, optimizer, loss=TraceGraph_ELBO())

        print("Doing inference with use_decaying_avg_baseline=%s" % use_decaying_avg_baseline)

        # Maxstepまで学習
        for k in range(self.max_steps):
            svi.step(use_decaying_avg_baseline)
            if k % 100 == 0:
                print('.', end='')
                sys.stdout.flush()

            # compute the distance to the parameters of the true posterior
            alpha_error = param_abs_error("alpha_q", self.alpha_n)
            beta_error = param_abs_error("beta_q", self.beta_n)

            # stop inference early if we're close to the true posterior
            if alpha_error < tolerance and beta_error < tolerance:
                break

            print("\nDid %d steps of inference." % k)
            print(("Final absolute errors for the two variational parameters " + "were %.4f & %.4f") % (alpha_error, beta_error))


In [31]:
data = torch.zeros(10)
data[:6] = torch.ones(6)

In [32]:
bbe = BernoulliBetaExample(data, max_steps=10000)
bbe.do_inference(use_decaying_avg_baseline=True)
bbe.do_inference(use_decaying_avg_baseline=False)

Doing inference with use_decaying_avg_baseline=True
.
Did 0 steps of inference.
Final absolute errors for the two variational parameters were 1.0075 & 0.9925

Did 1 steps of inference.
Final absolute errors for the two variational parameters were 1.0138 & 0.9976

Did 2 steps of inference.
Final absolute errors for the two variational parameters were 1.0202 & 1.0028

Did 3 steps of inference.
Final absolute errors for the two variational parameters were 1.0189 & 1.0013

Did 4 steps of inference.
Final absolute errors for the two variational parameters were 1.0174 & 0.9994

Did 5 steps of inference.
Final absolute errors for the two variational parameters were 1.0162 & 0.9976

Did 6 steps of inference.
Final absolute errors for the two variational parameters were 1.0157 & 0.9965

Did 7 steps of inference.
Final absolute errors for the two variational parameters were 1.0140 & 0.9942

Did 8 steps of inference.
Final absolute errors for the two variational parameters were 1.0130 & 0.9926

D

.
Did 100 steps of inference.
Final absolute errors for the two variational parameters were 0.8909 & 0.8513

Did 101 steps of inference.
Final absolute errors for the two variational parameters were 0.8892 & 0.8495

Did 102 steps of inference.
Final absolute errors for the two variational parameters were 0.8875 & 0.8479

Did 103 steps of inference.
Final absolute errors for the two variational parameters were 0.8860 & 0.8464

Did 104 steps of inference.
Final absolute errors for the two variational parameters were 0.8838 & 0.8436

Did 105 steps of inference.
Final absolute errors for the two variational parameters were 0.8817 & 0.8409

Did 106 steps of inference.
Final absolute errors for the two variational parameters were 0.8797 & 0.8383

Did 107 steps of inference.
Final absolute errors for the two variational parameters were 0.8778 & 0.8359

Did 108 steps of inference.
Final absolute errors for the two variational parameters were 0.8758 & 0.8335

Did 109 steps of inference.
Final a


Did 31 steps of inference.
Final absolute errors for the two variational parameters were 1.0410 & 1.0114

Did 32 steps of inference.
Final absolute errors for the two variational parameters were 1.0418 & 1.0116

Did 33 steps of inference.
Final absolute errors for the two variational parameters were 1.0425 & 1.0115

Did 34 steps of inference.
Final absolute errors for the two variational parameters were 1.0430 & 1.0113

Did 35 steps of inference.
Final absolute errors for the two variational parameters were 1.0433 & 1.0106

Did 36 steps of inference.
Final absolute errors for the two variational parameters were 1.0422 & 1.0089

Did 37 steps of inference.
Final absolute errors for the two variational parameters were 1.0424 & 1.0089

Did 38 steps of inference.
Final absolute errors for the two variational parameters were 1.0422 & 1.0086

Did 39 steps of inference.
Final absolute errors for the two variational parameters were 1.0428 & 1.0090

Did 40 steps of inference.
Final absolute err


Did 132 steps of inference.
Final absolute errors for the two variational parameters were 0.9297 & 0.8851

Did 133 steps of inference.
Final absolute errors for the two variational parameters were 0.9293 & 0.8842

Did 134 steps of inference.
Final absolute errors for the two variational parameters were 0.9294 & 0.8840

Did 135 steps of inference.
Final absolute errors for the two variational parameters were 0.9294 & 0.8837

Did 136 steps of inference.
Final absolute errors for the two variational parameters were 0.9301 & 0.8842

Did 137 steps of inference.
Final absolute errors for the two variational parameters were 0.9307 & 0.8846

Did 138 steps of inference.
Final absolute errors for the two variational parameters were 0.9313 & 0.8849

Did 139 steps of inference.
Final absolute errors for the two variational parameters were 0.9321 & 0.8853

Did 140 steps of inference.
Final absolute errors for the two variational parameters were 0.9330 & 0.8858

Did 141 steps of inference.
Final ab


Did 223 steps of inference.
Final absolute errors for the two variational parameters were 0.9076 & 0.8381

Did 224 steps of inference.
Final absolute errors for the two variational parameters were 0.9094 & 0.8399

Did 225 steps of inference.
Final absolute errors for the two variational parameters were 0.9108 & 0.8410

Did 226 steps of inference.
Final absolute errors for the two variational parameters were 0.9124 & 0.8424

Did 227 steps of inference.
Final absolute errors for the two variational parameters were 0.9138 & 0.8435

Did 228 steps of inference.
Final absolute errors for the two variational parameters were 0.9158 & 0.8455

Did 229 steps of inference.
Final absolute errors for the two variational parameters were 0.9171 & 0.8467

Did 230 steps of inference.
Final absolute errors for the two variational parameters were 0.9188 & 0.8486

Did 231 steps of inference.
Final absolute errors for the two variational parameters were 0.9203 & 0.8501

Did 232 steps of inference.
Final ab

.
Did 300 steps of inference.
Final absolute errors for the two variational parameters were 0.9601 & 0.9115

Did 301 steps of inference.
Final absolute errors for the two variational parameters were 0.9602 & 0.9119

Did 302 steps of inference.
Final absolute errors for the two variational parameters were 0.9602 & 0.9121

Did 303 steps of inference.
Final absolute errors for the two variational parameters were 0.9605 & 0.9125

Did 304 steps of inference.
Final absolute errors for the two variational parameters were 0.9609 & 0.9130

Did 305 steps of inference.
Final absolute errors for the two variational parameters were 0.9613 & 0.9132

Did 306 steps of inference.
Final absolute errors for the two variational parameters were 0.9622 & 0.9141

Did 307 steps of inference.
Final absolute errors for the two variational parameters were 0.9626 & 0.9145

Did 308 steps of inference.
Final absolute errors for the two variational parameters were 0.9634 & 0.9152

Did 309 steps of inference.
Final a

.
Did 400 steps of inference.
Final absolute errors for the two variational parameters were 0.9664 & 0.9230

Did 401 steps of inference.
Final absolute errors for the two variational parameters were 0.9661 & 0.9228

Did 402 steps of inference.
Final absolute errors for the two variational parameters were 0.9661 & 0.9228

Did 403 steps of inference.
Final absolute errors for the two variational parameters were 0.9656 & 0.9223

Did 404 steps of inference.
Final absolute errors for the two variational parameters were 0.9643 & 0.9211

Did 405 steps of inference.
Final absolute errors for the two variational parameters were 0.9633 & 0.9201

Did 406 steps of inference.
Final absolute errors for the two variational parameters were 0.9621 & 0.9188

Did 407 steps of inference.
Final absolute errors for the two variational parameters were 0.9609 & 0.9174

Did 408 steps of inference.
Final absolute errors for the two variational parameters were 0.9592 & 0.9156

Did 409 steps of inference.
Final a


Did 499 steps of inference.
Final absolute errors for the two variational parameters were 0.9523 & 0.8927
.
Did 500 steps of inference.
Final absolute errors for the two variational parameters were 0.9513 & 0.8919

Did 501 steps of inference.
Final absolute errors for the two variational parameters were 0.9506 & 0.8912

Did 502 steps of inference.
Final absolute errors for the two variational parameters were 0.9491 & 0.8897

Did 503 steps of inference.
Final absolute errors for the two variational parameters were 0.9475 & 0.8882

Did 504 steps of inference.
Final absolute errors for the two variational parameters were 0.9467 & 0.8876

Did 505 steps of inference.
Final absolute errors for the two variational parameters were 0.9465 & 0.8875

Did 506 steps of inference.
Final absolute errors for the two variational parameters were 0.9466 & 0.8878

Did 507 steps of inference.
Final absolute errors for the two variational parameters were 0.9457 & 0.8871

Did 508 steps of inference.
Final a


Did 577 steps of inference.
Final absolute errors for the two variational parameters were 0.8971 & 0.8446

Did 578 steps of inference.
Final absolute errors for the two variational parameters were 0.8976 & 0.8452

Did 579 steps of inference.
Final absolute errors for the two variational parameters were 0.8980 & 0.8455

Did 580 steps of inference.
Final absolute errors for the two variational parameters were 0.8985 & 0.8459

Did 581 steps of inference.
Final absolute errors for the two variational parameters were 0.8972 & 0.8450

Did 582 steps of inference.
Final absolute errors for the two variational parameters were 0.8958 & 0.8439

Did 583 steps of inference.
Final absolute errors for the two variational parameters were 0.8946 & 0.8429

Did 584 steps of inference.
Final absolute errors for the two variational parameters were 0.8938 & 0.8423

Did 585 steps of inference.
Final absolute errors for the two variational parameters were 0.8936 & 0.8421

Did 586 steps of inference.
Final ab


Did 674 steps of inference.
Final absolute errors for the two variational parameters were 0.9564 & 0.9409

Did 675 steps of inference.
Final absolute errors for the two variational parameters were 0.9548 & 0.9394

Did 676 steps of inference.
Final absolute errors for the two variational parameters were 0.9528 & 0.9374

Did 677 steps of inference.
Final absolute errors for the two variational parameters were 0.9503 & 0.9349

Did 678 steps of inference.
Final absolute errors for the two variational parameters were 0.9481 & 0.9327

Did 679 steps of inference.
Final absolute errors for the two variational parameters were 0.9464 & 0.9308

Did 680 steps of inference.
Final absolute errors for the two variational parameters were 0.9445 & 0.9288

Did 681 steps of inference.
Final absolute errors for the two variational parameters were 0.9433 & 0.9274

Did 682 steps of inference.
Final absolute errors for the two variational parameters were 0.9413 & 0.9254

Did 683 steps of inference.
Final ab