In [1]:
import math
from typing import Callable, Generic, Optional, TypeVar

import torch
from torch.distributions import Normal, Uniform



# Custom PPL

In [2]:
# Type of the value returned by a probabilistic program (used by ProbRun and run_prob_prog).
T = TypeVar("T")


class ProbCtx:
    """Probabilistic context: keeps track of a probabilistic execution (samples, weight, etc.)"""

    def __init__(self, trace: torch.Tensor) -> None:
        """Creates a context that replays the given trace.

        Args:
            trace (torch.Tensor): values to replay for successive sample
                statements. NaN entries are redrawn from the distribution of
                the sample statement that reads them. Sample statements past
                the end of the trace draw fresh values that are appended.
        """
        # Index/address of the next sample variable.
        self.idx = 0
        # Sampled values so far in the trace: a detached copy of the given
        # trace, tracked by autograd so gradients w.r.t. samples can be taken.
        self.samples = trace.clone().detach()
        self.samples.requires_grad_(True)
        # Whether each sampled value is continuous.
        # A sample is discontinuous if a branch in the program depends on it.
        self.is_cont: torch.Tensor = torch.ones(self.samples.shape, dtype=torch.bool)
        # Logarithm of the weight: accumulates the pdf of every sample() as
        # well as every score().
        self.log_weight = torch.tensor(0.0, requires_grad=True)
        # Logarithm of the score: accumulates only score() contributions
        # (including those made through score_log/observe/constrain).
        self.log_score = torch.tensor(0.0, requires_grad=True)
        # Log probability of each sample under the distribution it was drawn from.
        self.sample_logps = torch.zeros(trace.size())

    def constrain(
        self,
        sample,
        geq: Optional[float] = None,
        lt: Optional[float] = None,
    ) -> None:
        """Constrains the sample to be >= geq and < lt.

        This is necessary for random variables whose support isn't all reals.
        A violated bound scores a log-weight of -inf, i.e. the trace gets
        weight zero.

        Args:
            sample: the sample to be constrained
            geq (float, optional): The inclusive lower bound. Defaults to None (no bound).
            lt (float, optional): The exclusive upper bound. Defaults to None (no bound).
        """
        if lt is not None and sample >= lt:
            self.score_log(torch.tensor(-math.inf))
        # Only values strictly below the inclusive lower bound are rejected;
        # `sample == geq` is inside the allowed interval [geq, lt).
        if geq is not None and sample < geq:
            self.score_log(torch.tensor(-math.inf))

    def sample(
        self,
        dist: torch.distributions.Distribution,
        is_cont: bool,
    ) -> torch.Tensor:
        """Samples from the given distribution.

        If the distribution has support not on all reals, this needs to be followed by suitable constrain() calls.

        Args:
            dist (torch.distributions.Distribution): the distribution to sample from
            is_cont (bool): whether or not the weight function is continuous in this variable

        Returns:
            the sample
        """
        samples = self.sample_n(1, dist, is_cont)
        return samples[0]

    def sample_n(
        self,
        n: int,
        dist: torch.distributions.Distribution,
        is_cont: bool,
    ) -> torch.Tensor:
        """Samples n times from the given distribution.

        Values are replayed from the trace when available. Missing positions
        are drawn fresh and appended, and NaN placeholders are redrawn. The
        log-density of the returned values is added to the log-weight.

        Args:
            n (int): the number of samples
            dist (torch.distributions.Distribution): the distribution to sample from
            is_cont (bool): whether or not the weight function is continuous in this variable

        Returns:
            the samples
        """
        needed = self.idx + n - len(self.samples)
        if needed > 0:
            # The trace is too short: extend it (and the per-sample
            # bookkeeping) with fresh draws.
            values = dist.sample((needed,))
            values.requires_grad_(True)
            self.samples = torch.cat((self.samples, values))
            self.sample_logps = torch.cat([self.sample_logps, torch.zeros((needed,))])
            self.is_cont = torch.cat(
                (self.is_cont, torch.ones(needed, dtype=torch.bool))
            )
        # NaN entries are placeholders: redraw them from `dist`. This runs under
        # no_grad because `self.samples` may still be the autograd leaf created
        # in __init__, and an in-place write into a leaf that requires grad
        # raises a RuntimeError outside no_grad.
        with torch.no_grad():
            for i in range(self.idx, self.idx + n):
                if math.isnan(self.samples[i]):
                    self.samples[i] = dist.sample(())
        values = self.samples[self.idx : self.idx + n]
        logps = dist.log_prob(values)
        self.sample_logps[self.idx : self.idx + n] = logps
        self.is_cont[self.idx : self.idx + n] = torch.tensor(is_cont).repeat(n)
        self.log_weight = self.log_weight + torch.sum(logps)
        self.idx += n
        return values

    def score(self, weight: torch.Tensor) -> None:
        """Multiplies the current trace by the given weight.

        Args:
            weight (torch.Tensor): the weight.
        """
        assert torch.is_tensor(weight), "weight is not a tensor"
        self.score_log(torch.log(weight))

    def score_log(self, log_weight: torch.Tensor) -> None:
        """Multiplies the current trace by exp(log_weight).

        The value is added to both the log-weight and the log-score.

        Args:
            log_weight (torch.Tensor): the logarithm of the weight.
        """
        assert torch.is_tensor(log_weight), "weight is not a tensor"
        self.log_weight = self.log_weight + log_weight
        self.log_score = self.log_score + log_weight

    def observe(
        self,
        obs: torch.Tensor,
        dist: torch.distributions.Distribution,
    ) -> None:
        """Conditions on `obs` having been drawn from `dist`.

        Scores the trace by dist's log-density at `obs`.

        Args:
            obs (torch.Tensor): the observed value.
            dist (torch.distributions.Distribution): the distribution of the observation.
        """
        self.score_log(dist.log_prob(obs))


class ProbRun(Generic[T]):
    """Outcome of executing a probabilistic program against a ProbCtx.

    Holds the context's bookkeeping (weights, samples, continuity flags,
    per-sample log-densities) together with the program's return value.
    """

    def __init__(self, ctx: ProbCtx, value: T) -> None:
        """Capture the state of `ctx` and the program's return value.

        Fields without their own comment carry the same meaning as the
        ProbCtx attribute of the same name.

        Args:
            ctx (ProbCtx): the context the program was executed with.
            value (T): whatever the probabilistic program returned.
        """
        # Cache for gradU(); stays None until the gradient is first requested.
        self._gradU: torch.Tensor = None
        self.log_weight, self.log_score = ctx.log_weight, ctx.log_score
        self.samples = ctx.samples
        self.sample_logps = ctx.sample_logps
        self.is_cont = ctx.is_cont
        # How many sample statements were executed, i.e. the length of the
        # prefix of `samples` that the program actually consumed.
        self.len = ctx.idx
        # The probabilistic program's return value.
        self.value = value

    def gradU(self) -> torch.Tensor:
        """Gradient of the potential U = -log_weight with respect to `samples`.

        The result is cached. torch.autograd.grad is called without
        retain_graph, so the graph cannot be differentiated a second time. If
        `samples` does not contribute to the weight at all, a zero tensor of
        the same shape is returned.
        """
        if self._gradU is None:
            potential = -self.log_weight
            grads = torch.autograd.grad(potential, self.samples, allow_unused=True)
            grad = grads[0]
            self._gradU = torch.zeros(self.samples.shape) if grad is None else grad
        return self._gradU

    def used_samples(self) -> torch.Tensor:
        """Return the leading `len` entries of `samples`, the part the program read."""
        return self.samples[: self.len]


def run_prob_prog(program: Callable[[ProbCtx], T], trace: torch.Tensor) -> ProbRun[T]:
    """Runs the given probabilistic program on the given trace.

    The program is replayed against `trace`. If it executes more sample
    statements than `trace` holds, ProbCtx appends fresh draws, and the program
    is run again on the extended trace (`ctx.samples`). The loop only returns
    once a run consumed no more samples than the trace it was given.

    An exception raised by the program is swallowed only when neither
    `log_score` nor `log_weight` is greater than -inf at that point. In that
    case the run's value is None. Otherwise the exception propagates.

    Args:
        program (Callable[[ProbCtx], T]): the probabilistic program.
        trace (torch.Tensor): the trace to replay.

    Returns:
        ProbRun: the result of the probabilistic run.

    Raises:
        Exception: whatever `program` raised, if the execution still had
            nonzero weight.
    """
    tensor_trace = trace
    while True:
        ctx = ProbCtx(tensor_trace)
        ret = None
        try:
            ret = program(ctx)
        except Exception:
            # NOTE(review): this re-raises unless *both* log_score and
            # log_weight are -inf. A trace whose weight is zero only through a
            # sample's log-density (log_score still finite) counts as
            # nonzero-weight here. Confirm that is intended.
            if ctx.log_score.item() > -math.inf or ctx.log_weight.item() > -math.inf:
                print("Exception in code with nonzero weight!")
                # Bare `raise` re-raises the active exception with its original traceback.
                raise
            print("Info: exception in branch with zero weight")
        if ctx.idx > len(tensor_trace):
            # The program drew past the end of the trace; ctx.samples now holds
            # the extended trace, so replay the program on it.
            tensor_trace = ctx.samples
            continue
        return ProbRun(ctx, ret)


In [3]:
def walk_model(ctx: ProbCtx) -> float:
    """Random walk model.

    Source: Mak et al. (2020), "Densities of almost-surely terminating
    probabilistic programs are differentiable almost everywhere".

    A walker starts at a position drawn from Uniform(0, 3). While the position
    is positive and the total distance covered is below 10, it takes
    Uniform(-1, 1) steps. The total distance travelled is observed under
    Normal(1.1, 0.1).

    Args:
        ctx (ProbCtx): probabilistic context used for sampling and scoring.

    Returns:
        float: the sampled starting position.
    """
    start = ctx.sample(Uniform(0, 3), is_cont=False)
    travelled = torch.tensor(0.0, requires_grad=True)
    pos = start
    # Every draw feeds the loop condition, hence is_cont=False throughout.
    while pos > 0 and travelled < 10:
        stride = ctx.sample(Uniform(-1, 1), is_cont=False)
        pos = pos + stride
        travelled = travelled + torch.abs(stride)
    ctx.observe(travelled, Normal(1.1, 0.1))
    return start.item()


In [4]:
# Run the random-walk model once on an empty trace: every sample() call draws a
# fresh value, which is appended to ctx.samples.
trace = torch.tensor([])
ctx = ProbCtx(trace)
walk_model(ctx)
# List the attributes and methods of the resulting context.
dir(ctx)

['__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 'constrain',
 'idx',
 'is_cont',
 'log_score',
 'log_weight',
 'observe',
 'sample',
 'sample_logps',
 'sample_n',
 'samples',
 'score',
 'score_log']

In [5]:
# The trace recorded by the run above: the starting position followed by every step.
ctx.samples

tensor([ 1.6704,  0.7230, -0.3760, -0.7057, -0.3816,  0.5948,  0.6755, -0.2594,
         0.0581, -0.6915,  0.7120, -0.2331, -0.4883,  0.4422, -0.8988, -0.2762,
         0.2867, -0.7014, -0.7489], grad_fn=<CatBackward0>)

# Pyro

In [6]:
import pyro
import torch

def pyro_walk_model():
    """Pyro port of the random walk model (Mak et al., 2020).

    Sample sites:
      - "start" ~ Uniform(0, 3).
      - "step_0", "step_1", ... ~ Uniform(-1, 1), drawn while the position is
        positive and the total distance covered is below 10.
      - "obs": observes the total distance under Normal(1.1, 0.1).

    Returns:
        float: the sampled starting position.
    """
    start = pyro.sample("start", pyro.distributions.Uniform(0, 3))
    t = 0
    position = start
    distance = torch.tensor(0.0)
    # Stop once the *distance travelled* reaches 10, as in the custom-PPL
    # walk_model. The previous `position < 10` guard terminated on a different
    # condition, so this port did not describe the same model.
    while position > 0 and distance < 10:
        step = pyro.sample(f"step_{t}", pyro.distributions.Uniform(-1, 1))
        distance = distance + torch.abs(step)
        position = position + step
        t = t + 1
    pyro.sample("obs", pyro.distributions.Normal(1.1, 0.1), obs=distance)
    return start.item()


In [28]:
# Record one execution of the Pyro model and inspect its sample sites.
pyro_trace = pyro.poutine.trace(pyro_walk_model).get_trace()
pyro_trace.nodes

OrderedDict([('_INPUT',
              {'name': '_INPUT', 'type': 'args', 'args': (), 'kwargs': {}}),
             ('start',
              {'type': 'sample',
               'name': 'start',
               'fn': Uniform(low: 0.0, high: 3.0),
               'is_observed': False,
               'args': (),
               'kwargs': {},
               'value': tensor(1.0567),
               'infer': {},
               'scale': 1.0,
               'mask': None,
               'cond_indep_stack': (),
               'done': True,
               'stop': False,
               'continuation': None}),
             ('step_0',
              {'type': 'sample',
               'name': 'step_0',
               'fn': Uniform(low: -1.0, high: 1.0),
               'is_observed': False,
               'args': (),
               'kwargs': {},
               'value': tensor(-0.0357),
               'infer': {},
               'scale': 1.0,
               'mask': None,
               'cond_indep_stack': (),
   

In [29]:
# Log-density of step_3's recorded value under its own distribution; every step
# is Uniform(-1, 1), so this is log(1/2).
pyro_trace.nodes['step_3']['fn'].log_prob(pyro_trace.nodes['step_3']['value'])

tensor(-0.6931)

In [34]:
# Overwrite the first two steps with -0.99, then replay the model against the
# edited trace and record the resulting execution as the new pyro_trace.
pyro_trace.nodes['step_0']['value'] = torch.tensor(-0.99)
pyro_trace.nodes['step_1']['value'] = torch.tensor(-0.99)
pyro_trace = pyro.poutine.trace(pyro.poutine.replay(pyro_walk_model, trace=pyro_trace)).get_trace()

In [35]:
# Inspect the replayed trace: "start" keeps its recorded value and step_0 is now -0.99.
pyro_trace.nodes

OrderedDict([('_INPUT',
              {'name': '_INPUT', 'type': 'args', 'args': (), 'kwargs': {}}),
             ('start',
              {'type': 'sample',
               'name': 'start',
               'fn': Uniform(low: 0.0, high: 3.0),
               'is_observed': False,
               'args': (),
               'kwargs': {},
               'value': tensor(1.0567),
               'infer': {},
               'scale': 1.0,
               'mask': None,
               'cond_indep_stack': (),
               'done': True,
               'stop': False,
               'continuation': None}),
             ('step_0',
              {'type': 'sample',
               'name': 'step_0',
               'fn': Uniform(low: -1.0, high: 1.0),
               'is_observed': False,
               'args': (),
               'kwargs': {},
               'value': tensor(-0.9900),
               'infer': {},
               'scale': 1.0,
               'mask': None,
               'cond_indep_stack': (),
   