## Flow-based sampling on the sphere

In [None]:
from math import pi as π

import torch
import matplotlib.pyplot as plt

from vonmises.distributions import VonMisesFisherDensity, MixtureDensity
from vonmises.models import RecursiveFlowS2
from vonmises.transforms import (
    MobiusMixtureTransform,
    RQSplineTransform,
    CircularRQSplineTransform,
    BSplineTransform,
)
from vonmises.utils import Trainer

In [None]:
# Interactive help cell (IPython syntax, not plain Python): a leading `?`
# displays the object's signature and docstring. Only BSplineTransform is
# currently inspected; uncomment any other line to inspect that object.
# ?RecursiveFlowS2
# ?MobiusMixtureTransform
# ?RQSplineTransform
# ?CircularRQSplineTransform
?BSplineTransform
# ?VonMisesFisherDensity
# ?MixtureDensity
# ?Trainer

In [None]:
# Conditional transformers for the recursive flow on S^2: one acting on the
# (x, y) part and one acting on z. The integer argument (10) is presumably
# the number of mixture components / spline segments -- TODO confirm
# against the transform docstrings.
xy_transformer = MobiusMixtureTransform(10)
z_transformer = BSplineTransform(10)

# Target density: a mixture of three von Mises-Fisher components, given as
# (κ, μ) pairs -- presumably (concentration, mean direction). The means
# point along +z, +x and -y respectively.
target = MixtureDensity(
    [
        VonMisesFisherDensity(κ, μ)
        for κ, μ in (
            (5, (0, 0, 1)),
            (10, (1, 0, 0)),
            (15, (0, -1, 0)),
        )
    ],
)

# A single-layer flow; `net_hidden_shape=[]` presumably means the
# conditioning networks have no hidden layers -- confirm in RecursiveFlowS2.
model = RecursiveFlowS2(
    target,
    z_transformer,
    xy_transformer,
    init_lr=0.01,
    batch_size=3000,
    n_layers=1,
    net_hidden_shape=[],
)

# The positional 3000 is presumably the number of training steps -- confirm
# against the Trainer signature.
trainer = Trainer(3000)

In [None]:
# Optimise the flow against the target density.
trainer.fit(model)

# Evaluate the trained model. `test` must return exactly one item, and
# unpacking raises otherwise; that single item is bound to `metrics`.
[metrics] = trainer.test(model)

In [None]:
# Plotting helpers. `heatmap` is not used in this cell but stays imported
# for use elsewhere in the notebook.
from vonmises.utils.plot import (
    heatmap,
    histogram,
    scatter,
)

# Draw samples from the trained flow. `sample` returns three values: the
# points `x` and two log-density quantities. `logptilde` is presumably the
# target's (unnormalised) log-density at `x` and `logp` the model's
# log-density -- confirm in RecursiveFlowS2.sample.
x, logptilde, logp = model.sample(10_000)

# Bind the return values to `_` so the notebook does not echo them.
_ = histogram(x)
_ = scatter(x, s=0.5, colours=logp)  # points coloured by logp