In [2]:
%load_ext autoreload
%autoreload 2
import torch
from torch import nn, optim
import numpy as np
import pandas as pd
import scipy
import matplotlib as mpl
import matplotlib.pyplot as plt
import pyro
import pyro.distributions as dist
from pyro.distributions import Categorical, Normal, Multinomial, Binomial, MultivariateNormal, Beta, constraints
from pyro.distributions.torch import Bernoulli
import pyro.infer as infer
from pyro.infer import TraceEnum_ELBO, Trace_ELBO
from pyro import poutine
from pyro.poutine.trace_messenger import TraceMessenger
from pyro.poutine.messenger import Messenger

plt.style.use('seaborn-v0_8')


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def model(θ):
    sigma = pyro.param("sigma", lambda: torch.tensor(1.0))
    z = pyro.sample("z", dist.Normal(θ, sigma))
    o = z ** 2
    return o

In [4]:
pyro.clear_param_store()
pyro.set_rng_seed(0)
model(0)

tensor(2.3747, grad_fn=<PowBackward0>)

In [5]:
class LogMessenger(Messenger):
    
    def __init__(self, data):
        self.data = data
        self.params = []
    def __call__(self, fn):
        def wrapper(*args, **kwargs):
            with self:
                fn(*args, **kwargs)
                return self.logp.clone()
        
        return wrapper
    
    def __enter__(self):
        self.logp = torch.tensor(0.0)
        return super().__enter__()
    
    def __exit__(self, exc_type, exc_value, traceback):
        self.logp = torch.tensor(0.0)
        return super().__exit__(exc_type, exc_value, traceback)
    
    def _process_message(self, msg):
        print("Process message ", msg)
        # return msg
        return super()._process_message(msg)

    def _pyro_sample(self, msg):
        print(msg)
        msg["value"] = msg["fn"].sample()
        print("Pyro sample Sampled is ", msg["value"])
        value = torch.tensor(1.0)
        print("Setting value to ", value)
        msg["value"] = value
        msg["is_observed"] = True
        self.logp += msg["fn"].log_prob(value).sum()
        msg["logp"] = self.logp
        # return msg

    def _pyro_param(self, msg):
        print("Pyro param: ", msg["name"])
        # msg["value"].requires_grad_(False)
        # return msg
    
    def _postprocess_message(self, msg):
        print()
        print("Postprocess message ", msg["name"], msg["value"])
        print()
        return super()._postprocess_message(msg)

    

In [6]:
pyro.clear_param_store()
pyro.set_rng_seed(0)
with LogMessenger({"z": torch.tensor(0.5)}) as m:
    print("Model result: ", model(0))
    print("Log P: ", m.logp.clone())

Process message  {'type': 'param', 'name': 'sigma', 'fn': <bound method ParamStoreDict.get_param of <pyro.params.param_store.ParamStoreDict object at 0x7f5f3e5ee880>>, 'is_observed': False, 'args': ('sigma', <function model.<locals>.<lambda> at 0x7f604210ae50>), 'kwargs': {'constraint': Real(), 'event_dim': None}, 'value': None, 'scale': 1.0, 'mask': None, 'cond_indep_stack': (), 'done': False, 'stop': False, 'continuation': None, 'infer': {}}
Pyro param:  sigma

Postprocess message  sigma tensor(1., requires_grad=True)

Process message  {'type': 'sample', 'name': 'z', 'fn': Normal(loc: 0.0, scale: 1.0), 'is_observed': False, 'args': (), 'kwargs': {}, 'value': None, 'infer': {}, 'scale': 1.0, 'mask': None, 'cond_indep_stack': (), 'done': False, 'stop': False, 'continuation': None}
{'type': 'sample', 'name': 'z', 'fn': Normal(loc: 0.0, scale: 1.0), 'is_observed': False, 'args': (), 'kwargs': {}, 'value': None, 'infer': {}, 'scale': 1.0, 'mask': None, 'cond_indep_stack': (), 'done': Fals

In [7]:
class ScaleMessenger(Messenger):
    
    def __init__(self, scale):
        self.scale = scale

    def __enter__(self):
        return super().__enter__()
    
    def __exit__(self, exc_type, exc_value, traceback):
        return super().__exit__(exc_type, exc_value, traceback)
    
    def _process_message(self, msg):
        print("ScaleMessenger Process message ", msg)
        return super()._process_message(msg)

    def _pyro_sample(self, msg):
        print("ScaleMessenger pyro sample: ", msg["name"])

    def _pyro_param(self, msg):
        print("ScaleMessenger pyro param: ", msg["name"])

    def _postprocess_message(self, msg):
        print()
        print("ScaleMessenger Postprocess message ", msg["name"], msg["value"])
        print()
        return super()._postprocess_message(msg)

In [8]:
pyro.clear_param_store()
pyro.set_rng_seed(0)
with ScaleMessenger(2.0) as scale:
    with LogMessenger({"z": torch.tensor(0.5)}) as m:
        model(0.)

Process message  {'type': 'param', 'name': 'sigma', 'fn': <bound method ParamStoreDict.get_param of <pyro.params.param_store.ParamStoreDict object at 0x7f5f3e5ee880>>, 'is_observed': False, 'args': ('sigma', <function model.<locals>.<lambda> at 0x7f605b9fd310>), 'kwargs': {'constraint': Real(), 'event_dim': None}, 'value': None, 'scale': 1.0, 'mask': None, 'cond_indep_stack': (), 'done': False, 'stop': False, 'continuation': None, 'infer': {}}
Pyro param:  sigma
ScaleMessenger Process message  {'type': 'param', 'name': 'sigma', 'fn': <bound method ParamStoreDict.get_param of <pyro.params.param_store.ParamStoreDict object at 0x7f5f3e5ee880>>, 'is_observed': False, 'args': ('sigma', <function model.<locals>.<lambda> at 0x7f605b9fd310>), 'kwargs': {'constraint': Real(), 'event_dim': None}, 'value': None, 'scale': 1.0, 'mask': None, 'cond_indep_stack': (), 'done': False, 'stop': False, 'continuation': None, 'infer': {}}
ScaleMessenger pyro param:  sigma

ScaleMessenger Postprocess message 

In [9]:
def model1(θ):
    sigma = pyro.deterministic("sigma", lambda: torch.tensor(1.0))
    z = pyro.sample("z", dist.Normal(θ, sigma))
    o = pyro.deterministic("o", z**2)
    return o
pyro.clear_param_store()
pyro.set_rng_seed(0)
model(0)

tensor(2.3747, grad_fn=<PowBackward0>)

In [12]:
d = dist.Delta(torch.tensor(0.5))
s = d.sample()
s

tensor(0.5000)

In [14]:
d.log_prob(s).exp()

tensor(1.)

In [15]:
d.log_prob(s + 0.1).exp()

tensor(0.)