In [1]:
import pyprob
from pyprob import Model
import numpy as np
from pyprob.distributions import Normal, Uniform
import torch

# Run on the GPU when one is present, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
# Speed of light in vacuum in m/s (``step`` scales it by 1e-9 to get m/ns).
c_vac = 3.0e8

def _as_float_tensor(value):
    """Return ``value`` as a floating-point tensor on the module-level ``device``."""
    tensor = torch.as_tensor(value, device=device)
    if not tensor.is_floating_point():
        # e.g. the integer default phi=0; torch trig functions need float input.
        tensor = tensor.to(torch.get_default_dtype())
    return tensor


def sph_to_cart(theta, phi=0, r=1):
    """
    Transform spherical to cartesian coordinates.

    Parameters:
        theta: polar angle in radians, measured from the +z axis
            (tensor or Python number)
        phi: azimuthal angle in radians, measured from the +x axis in the
            x-y plane (tensor or Python number)
        r: radius

    Returns:
        torch.Tensor holding (x, y, z) stacked along a new leading axis, on
        ``device``: shape (3,) for scalar angles, (3, *S) when theta and phi
        broadcast to shape S.
    """
    # torch.sin/torch.cos reject plain Python numbers, so the default phi=0
    # used to raise a TypeError; coerce both angles to tensors first.
    theta = _as_float_tensor(theta)
    phi = _as_float_tensor(phi)

    x = r * torch.sin(theta) * torch.cos(phi)
    y = r * torch.sin(theta) * torch.sin(phi)
    z = r * torch.cos(theta)

    # torch.stack keeps dtype/device and works for non-scalar angles, unlike
    # torch.tensor([x, y, z]), which only accepts one-element tensors.
    # z does not depend on phi, so broadcast before stacking.
    return torch.stack(torch.broadcast_tensors(x, y, z))




class ConditioningOnTailExample(Model):
    """
    Photon random walk towards a spherical target, used to demonstrate
    conditioning on a tail (low-probability) outcome.

    A photon starts at the origin with an isotropic direction and takes
    exponentially distributed steps, scattering (Henyey-Greenstein) after each
    step that does not hit the target sphere. The observable ``obs0`` is 1 if
    the photon hit the target before ``loop`` stopped it, and 0 otherwise.
    """

    # Directions with |pz| above 1 - _PARA_Z_TOL are treated as (anti)parallel
    # to the z axis in ``calc_new_direction``.
    _PARA_Z_TOL = 1e-6

    def __init__(self, target_x, target_r):
        """
        Parameters:
            target_x: centre of the target sphere (assumed float[3] tensor)
            target_r: float, radius of the target sphere
        """
        super().__init__(name="ConditioningOnTailExample")

        self.target_x = target_x
        self.target_r = target_r

    def photon_sphere_intersection(self, photon_x, photon_p, step_size):
        """
        Calculate intersection.

        Given a photon origin, a photon direction and a step size, calculate
        whether the photon's straight-line step intersects the target sphere
        (``self.target_x``, ``self.target_r``) and the intersection point.

        Parameters:
            photon_x: float[3], photon position
            photon_p: float[3], photon direction (assumed to be unit length)
            step_size: float, length of the step

        Returns:
            tuple(bool, float[3])
                (True, intersection position) if the nearer intersection lies
                strictly within (0, step_size) along the direction; otherwise
                (False, [1E8, 1E8, 1E8]) as a far-away placeholder.
        """
        p_normed = photon_p  # assume normed

        # Use the instance's target; the previous version read the notebook
        # global ``target_x`` here, which only worked by coincidence.
        offset = photon_x - self.target_x
        # Line-sphere intersection: solve |x + d*p - c|^2 = r^2 for d.
        a = torch.dot(p_normed, offset)
        b = a**2 - (torch.linalg.norm(offset) ** 2 - self.target_r**2)
        # Distance of the nearer intersection point along the line. Clamp the
        # discriminant so a miss (b < 0) does not produce a NaN; such cases
        # are rejected by the ``b >= 0`` test below anyway.
        d = -a - torch.sqrt(torch.clamp(b, min=0))

        isected = (b >= 0) & (d > 0) & (d < step_size)

        if isected:
            return True, photon_x + d * p_normed
        return False, torch.tensor([1E8, 1E8, 1E8], device=device)

    def scattering_func(self, g=0.97):
        """
        Henyey-Greenstein scattering in one plane.

        Samples ``scattering_eta`` ~ U(0, 1) and maps it through the inverse
        CDF of the Henyey-Greenstein phase function with asymmetry ``g``.

        Returns:
            Scattering angle theta in radians.
        """
        eta = pyprob.sample(Uniform(0, 1), name="scattering_eta")
        if g == 0:
            # Isotropic limit of Henyey-Greenstein; the general formula
            # below divides by g.
            costheta = 2 * eta - 1
        else:
            costheta = (
                1 / (2 * g) * (1 + g**2 - ((1 - g**2) / (1 + g * (2 * eta - 1))) ** 2)
            )
        # Rounding can push cos(theta) marginally outside [-1, 1] at the
        # extremes of eta (eta = 0 gives exactly -1 in exact arithmetic),
        # which would make arccos return NaN.
        costheta = torch.clamp(costheta, -1.0, 1.0)
        return torch.arccos(costheta)

    def initialize_direction_isotropic(self):
        """
        Draw direction uniformly on a sphere.

        cos(theta) is sampled uniformly in [-1, 1] (``init_theta``) and phi
        uniformly in [0, 2*pi) (``init_phi``).
        """
        theta = torch.arccos(pyprob.sample(Uniform(-1, 1), name="init_theta"))
        phi = pyprob.sample(Uniform(0, 2*np.pi), name="init_phi")
        direction = sph_to_cart(theta, phi, r=1)

        return direction

    def calc_new_direction(self, old_dir):
        """
        Calculate new direction after sampling a scattering angle.

        Scattering is calculated in a reference frame local
        to the photon (e_z) and then rotated back to the global coordinate system.

        Parameters:
            old_dir: float[3], current unit direction (px, py, pz)

        Returns:
            float[3] tensor on ``device``, the renormalised new direction.
        """
        # The sampling order (polar angle first, then azimuth) is kept so
        # trace addresses stay in the same order.
        theta = self.scattering_func()

        cos_theta = torch.cos(theta)
        sin_theta = torch.sin(theta)

        phi = pyprob.sample(Uniform(0, 2 * np.pi), name="scattering_phi")
        cos_phi = torch.cos(phi)
        sin_phi = torch.sin(phi)

        px, py, pz = old_dir

        # The general rotation divides by sqrt(1 - pz^2). Using a tolerance
        # instead of an exact ``== 1`` also catches |pz| that rounding has
        # pushed to or slightly past 1, where that square root is ~0 or NaN.
        is_para_z = torch.abs(pz) > 1.0 - self._PARA_Z_TOL

        if is_para_z:
            new_dir = torch.tensor(
                [
                    sin_theta * cos_phi,
                    torch.sign(pz) * sin_theta * sin_phi,
                    torch.sign(pz) * cos_theta,
                ],
                device=device
            )
        else:
            sin_pz = torch.sqrt(1.0 - pz**2)
            new_dir = torch.tensor(
                [
                    (px * cos_theta)
                    + (sin_theta * (px * pz * cos_phi - py * sin_phi)) / sin_pz,
                    (py * cos_theta)
                    + (sin_theta * (py * pz * cos_phi + px * sin_phi)) / sin_pz,
                    (pz * cos_theta) - (sin_theta * cos_phi * sin_pz),
                ],
                device=device
            )

        # Renormalise so rounding errors do not accumulate over many steps.
        new_dir = new_dir / torch.linalg.norm(new_dir)

        return new_dir

    def scattering_length_function(self, wavelength):
        """Scattering length; constant 40 for every wavelength (presumably metres — confirm)."""
        return 40.

    def ref_index_func(self, wavelength):
        """Refractive index of the medium; constant 1.32 for every wavelength."""
        return 1.32

    def step(self, photon_state):
        """
        Single photon step.

        Parameters:
            photon_state: dict with keys "pos", "dir", "time", "intersected",
                "steps" and "wavelength" (as built in ``forward``).

        Returns:
            New state dict with the same keys. If the step hits the target, the
            photon is moved to the intersection point and keeps its direction.
            Otherwise it moves the full step, scatters into a new direction,
            and its step counter is incremented.
        """
        pos = photon_state["pos"]
        direction = photon_state["dir"]
        time = photon_state["time"]
        stepcnt = photon_state["steps"]
        wavelength = photon_state["wavelength"]

        sca_coeff = 1 / self.scattering_length_function(wavelength)
        c_medium = (
            c_vac * 1e-9 / self.ref_index_func(wavelength)
        )  # m/ns

        # Inverse-transform sample of an exponential step length with rate
        # sca_coeff (could just sample from expon). A uniform draw can be
        # exactly 0; clamp it so -log(eta) stays finite.
        eta = pyprob.sample(Uniform(0, 1), name="step_len_eta")
        eta = torch.clamp(eta, min=torch.finfo(eta.dtype).tiny)
        step_size = -torch.log(eta) / sca_coeff

        # Calculate intersection
        isec, isec_pos = self.photon_sphere_intersection(
            photon_x=pos,
            photon_p=direction,
            step_size=step_size,
        )

        if isec:
            # Stop at the target surface; only the travelled part counts.
            new_pos = isec_pos
            new_time = time + torch.linalg.norm(pos - isec_pos) / c_medium
            new_dir = direction
        else:
            new_pos = pos + step_size * direction
            new_time = time + step_size / c_medium
            new_dir = self.calc_new_direction(direction)
            stepcnt += 1

        new_photon_state = {
            "pos": new_pos,
            "dir": new_dir,
            "time": new_time,
            "intersected": isec,
            "steps": stepcnt,
            "wavelength": wavelength,
        }

        return new_photon_state

    def loop(self, state, max_steps=10):
        """
        Propagate the photon until it hits the target or has scattered
        ``max_steps`` times.

        Parameters:
            state: photon state dict (see ``step``)
            max_steps: int, scattering budget (default 10)

        Returns:
            The final photon state dict.
        """
        while state["steps"] < max_steps and not state["intersected"]:
            state = self.step(state)
        return state

    def forward(self):
        """
        Simulate one photon from the origin and observe whether it hit the target.

        Returns:
            The final photon state dict (see ``step``).
        """
        state = {"dir": self.initialize_direction_isotropic(),
                 "pos": torch.tensor([0., 0., 0.], device=device),
                 "steps": 0,
                 "time": 0,
                 "intersected": False,
                 "wavelength": 450}

        state = self.loop(state)

        # Deterministic observation: obs0 == 1 exactly when the photon hit the
        # target. Conditioning on obs0 = 1 therefore gives every missing trace
        # zero likelihood.
        probs = [0, 1] if state["intersected"] else [1, 0]

        obs_distr = pyprob.distributions.Categorical(probs)

        pyprob.observe(obs_distr, name='obs0')
        return state



In [4]:
# Target sphere: radius 0.5, centred on the z axis at z = 10.
target_r = 0.5
target_x = torch.tensor([0., 0., 10.], device=device)
model = ConditioningOnTailExample(target_x=target_x, target_r=target_r)

In [5]:
# Train an LSTM proposal network; obs0 is embedded with a 3-layer, 64-dim
# feed-forward embedding.
model.learn_inference_network(
    inference_network=pyprob.InferenceNetwork.LSTM,
    observe_embeddings={'obs0': {'dim': 64, 'depth': 3}},
    num_traces=10000,
)



Creating new inference network...
Observable obs0: reshape not specified, using shape torch.Size([]).
Observable obs0: using embedding dim torch.Size([64]).
Observable obs0: observe embedding not specified, using the default FEEDFORWARD.
Observable obs0: using embedding depth 3.
Observe embedding dimension: 64
Train. time | Epoch| Trace     | Init. loss| Min. loss | Curr. loss| T.since min | Learn.rate| Traces/sec
New layers, address: 20__forward__initialize_direction_isotropic__?__Un..., distribution: Uniform
New layers, address: 48__forward__initialize_direction_isotropic__phi__..., distribution: Uniform
New layers, address: 96__forward__loop__step__eta__Uniform__1, distribution: Uniform
New layers, address: 16__forward__loop__step__calc_new_direction__scatt..., distribution: Uniform
New layers, address: 50__forward__loop__step__calc_new_direction__phi__..., distribution: Uniform
New layers, address: 96__forward__loop__step__eta__Uniform__2, distribution: Uniform
New layers, address:

In [6]:
# Condition on the tail event: the photon reached the target.
condition = {'obs0': 1}

prior = model.prior(num_traces=2000)

# Importance sampling with the trained network as proposal. Because obs0 is a
# hard 0/1 observation, traces that miss get zero weight, so expect a small
# ESS (the run above reports 3 out of 2000).
posterior = model.posterior(
    num_traces=2000,
    observe=condition,
    inference_engine=pyprob.InferenceEngine.IMPORTANCE_SAMPLING_WITH_INFERENCE_NETWORK,
)

Time spent  | Time remain.| Progress             | Trace     | ESS    | Traces/sec
0d:00:01:21 | 0d:00:00:00 | #################### | 2000/2000 | 2000.00 | 24.67       
Time spent  | Time remain.| Progress             | Trace     | ESS    | Traces/sec
0d:00:14:23 | 0d:00:00:00 | #################### | 2000/2000 |   3.00 | 2.32        


In [7]:
# Resample 1000 traces from each empirical distribution (prior first, as before).
prior_samples = [prior.sample() for _ in range(1000)]
post_samples = [posterior.sample() for _ in range(1000)]



In [15]:
# Inspect the recorded random choices (and the obs0 observation) of one posterior sample.
post_samples[15].variables

[Variable(name:init_theta, observable:True, observed:False, tagged:False, control:True, address:20__forward__initialize_direction_isotropic__?__Uniform__1, distribution:Uniform(low=-1.0, high=1.0), value:tensor(0.9995), log_importance_weight:0.001259922981262207, log_prob:tensor(-0.6931)),
 Variable(name:init_phi, observable:True, observed:False, tagged:False, control:True, address:48__forward__initialize_direction_isotropic__phi__Uniform__1, distribution:Uniform(low=0.0, high=6.2831854820251465), value:tensor(0.4760), log_importance_weight:0.0015921592712402344, log_prob:tensor(-1.8379)),
 Variable(name:step_len_eta, observable:True, observed:False, tagged:False, control:True, address:96__forward__loop__step__eta__Uniform__1, distribution:Uniform(low=0.0, high=1.0), value:tensor(0.3114), log_importance_weight:-0.0006515979766845703, log_prob:tensor(0.)),
 Variable(name:obs0, observable:True, observed:True, tagged:False, control:False, address:94__forward__?__Categorical(len_probs:2)__

In [12]:
# Display the posterior's repr (number of items and whether it is weighted).
posterior

Empirical(items:2000, weighted:True)