In [4]:
import torch
import numpy as np
from torch.distributions import Normal, Uniform
import sys

In [13]:
def f():
    z1 = Normal(0,1).sample()
    z2 = R1(z1)
    z3 = Uniform(0,2).sample()
    z4 = R2(z2,z3)
    return z1,z2,z3,z4

# problem with tail recursion as the function continuies to evaluate R1 beyond
# the allowed recursion depth. To mitigate this we must add some additional decorators to deal with this
# and hack the python internals. Else, use functional programming. 
class TailRecurseException:
  def __init__(self, args, kwargs):
    self.args = args
    self.kwargs = kwargs

def tail_call_optimized(g):
  """
  This function decorates a function with tail call
  optimization. It does this by throwing an exception
  if it is it's own grandparent, and catching such
  exceptions to fake the tail call optimization.
  
  This function fails if the decorated
  function recurses in a non-tail context.
  """
  def func(*args, **kwargs):
    f = sys._getframe()
    if f.f_back and f.f_back.f_back \
        and f.f_back.f_back.f_code == f.f_code:
      raise TailRecurseException(args, kwargs)
    else:
      while 1:
        try:
          return g(*args, **kwargs)
        except TailRecurseException, e:
          args = e.args
          kwargs = e.kwargs
  func.__doc__ = g.__doc__
  return func


@tail_call_optimized
def R1(z1):
    temp = Normal(z1,1).sample()
    if temp.data > 0:
        return temp
    else:
        return R1(z1)

def R2(z2,z3):
    ''' Need to ensure that we don't return a 
    value generated from a deterministic process, i.e if return temp=0, that
    has not been generated from the while loop - although with R1 defined as it is 
    this will never be the case'''
    temp = 0
    while temp < z2:
        temp = Normal(z3,1).sample()
    return temp

n_samples = 10
samples_1 = torch.zeros(n_samples,1)
samples_2 = torch.zeros(n_samples,1)

for i in range(n_samples):
    z1,z2,z3,z4 = f()
    samples_1[i] = z2
    samples_2[i] = z4
    
print(samples_1)
print(samples_2)

SyntaxError: invalid syntax (<ipython-input-13-7b5284b49774>, line 32)

In [11]:
sys.getrecursionlimit()

3000