## Flow-based sampling on the circle

In [None]:
from math import pi as π

import torch
import matplotlib.pyplot as plt

from vonmises.distributions import VonMisesFisherDensity, MixtureDensity
from vonmises.flows import CircularFlow
from vonmises.transforms import (
    MobiusMixtureTransform,
    CircularRQSplineTransform,
)
from vonmises.model import FlowBasedModel
from vonmises.utils import get_trainer

In [None]:
# ?CircularFlow
# ?MobiusMixtureTransform
# ?CircularRQSplineTransform
# ?VonMisesFisherDensity
# ?MixtureDensity
# ?get_trainer

In [None]:
# Pick the circular transformer used by the flow. To compare against the
# Möbius-mixture transformer, swap in:
#   transformer = MobiusMixtureTransform(10)
transformer = CircularRQSplineTransform(10)

# Target density: a mixture of three von Mises-Fisher components, one per
# (concentration κ, mean direction μ) pair listed below.
# NOTE(review): (-1, 1) is not unit-length; unless VonMisesFisherDensity
# normalises μ internally, that component's effective concentration is not
# 15 — confirm.
concentrations = [5, 10, 15]
mean_directions = [(0, 1), (1, 0), (-1, 1)]
components = [
    VonMisesFisherDensity(κ, μ)
    for κ, μ in zip(concentrations, mean_directions)
]
target = MixtureDensity(components)

# Flow built from a single layer of the chosen transformer.
flow = CircularFlow(transformer, n_layers=1)

# Bundle flow and target into a trainable model; each batch draws 2000 samples
# (per the batch_size argument).
model = FlowBasedModel(flow, target, batch_size=2000)

# NOTE(review): what 2000 controls here (e.g. number of training steps) is
# defined in vonmises.utils.get_trainer, which is not visible in this notebook.
trainer = get_trainer(2000)

In [None]:
# Train the flow against the target density.
trainer.fit(model)

# Evaluate the trained model. The single-element list pattern requires
# trainer.test to return exactly one item (a ValueError is raised otherwise),
# and binds that item to `metrics`.
[metrics] = trainer.test(model)

In [None]:
# Visual diagnostics for the trained flow, all produced by the project's
# CircularFlowVisualiser wrapped around the fitted model.
from vonmises.viz import CircularFlowVisualiser

vis = CircularFlowVisualiser(model)

# 72 bins; if the histogram spans the full circle this is 5° per bin — confirm.
vis.histogram(72)

vis.scatter()
vis.scatter_2()
vis.weights()
vis.forces()


In [None]:
# Side-by-side figure: Cartesian axes on the left (`ax`) and polar axes on the
# right (`ax2`, created with projection="polar").
# NOTE(review): the axes are not drawn on within the visible part of this
# notebook; presumably they are populated further down — confirm.
# NOTE(review): plt is already imported in the first cell; this re-import is
# redundant but harmless.
import matplotlib.pyplot as plt

fig = plt.figure(figsize=(8, 4))
ax = fig.add_subplot(121)
ax2 = fig.add_subplot(122, projection="polar")
