In [None]:
%matplotlib inline
from dpm.train import train
from dpm.visualize import (
    plot_model, plot_models, plot_stats, 
    plot_hist, plot_loss_function
)
from dpm.distributions import Normal
from dpm.divergences import forward_kl, reverse_kl, js_divergence
from dpm.mixture_models import MixtureModel

### Forward KL

In [None]:
# Normal Examples: two Gaussians, p_model = Normal(-7.3, 3.2) and
# q_model = Normal(0., 1.), plotted together before any training.
p_model, q_model = Normal(-7.3, 3.2), Normal(0., 1.)
plot_models(p_model, q_model)

In [None]:
# Train: fit with forward_kl for 4000 epochs, then plot the training stats.
# The goals passed to plot_stats are p_model's loc and scale, i.e. the values
# q_model's parameters are presumably expected to approach.
# NOTE(review): train presumably updates q_model in place (the result cell plots
# q_model without reassigning it), so re-running this cell continues training.
stats = train(p_model, q_model, forward_kl, epochs=4000)
plot_stats(stats, goals=[p_model.loc.item(), p_model.scale.item()])

In [None]:
# Plot Result: p_model and q_model after forward-KL training, for comparison
# with the pre-training plot.
plot_models(p_model, q_model)

### Forward KL Bimodal

In [None]:
# plot a bimodal example: an equal-weight (0.5 / 0.5) mixture of
# Normal(-2., 1.4) and Normal(2., 1.4), shown on its own for illustration.
bimodal = MixtureModel([Normal(-2., 1.4), Normal(2., 1.4)], [0.5, 0.5])
plot_model(bimodal)

In [None]:
# Show dists: a well-separated equal-weight mixture of Normal(-7., 1.4) and
# Normal(7., 1.4) as p_model, against a Normal(0., 1.) q_model, before training.
p_model, q_model = (
    MixtureModel([Normal(-7., 1.4), Normal(7., 1.4)], [0.5, 0.5]),
    Normal(0., 1.),
)
plot_models(p_model, q_model)

In [None]:
# Train: fit q_model to the bimodal p_model with forward_kl for 4000 epochs.
# No goals are passed to plot_stats here, presumably because a mixture has no
# single loc/scale for a unimodal q_model to target.
stats = train(p_model, q_model, forward_kl, epochs=4000)
plot_stats(stats)

In [None]:
# Plot Result: the bimodal p_model and the forward-KL-trained q_model.
plot_models(p_model, q_model)

In [None]:
# plot loss dynamics of forward_kl against a freshly constructed copy of the
# bimodal target (a new MixtureModel, not the p_model variable above).
plot_loss_function(forward_kl, p_model=MixtureModel([Normal(-7., 1.4), Normal(7., 1.4)], [0.5, 0.5]))

### Reverse KL Bimodal

In [None]:
# Show dists: the same bimodal mixture (Normal(-7., 1.4) and Normal(7., 1.4),
# weights 0.5 / 0.5) as p_model, and a Normal(0., 1.) q_model, before training.
p_model, q_model = (
    MixtureModel([Normal(-7., 1.4), Normal(7., 1.4)], [0.5, 0.5]),
    Normal(0., 1.),
)
plot_models(p_model, q_model)

In [None]:
# Train (Mode Collapse): fit with reverse_kl (conventionally KL(q || p)), which
# tends to be mode-seeking, so the unimodal q_model is expected to settle on one
# of p_model's two modes rather than spread across both.
# Both models are re-created here so the cell starts from the same setup on re-run.
p_model = MixtureModel([Normal(-7., 1.4), Normal(7., 1.4)], [0.5, 0.5])
q_model = Normal(0., 1.)
stats = train(p_model, q_model, reverse_kl, epochs=4000)
plot_stats(stats)

In [None]:
# Plot Result: the bimodal p_model and q_model after the first reverse-KL run.
plot_models(p_model, q_model)

In [None]:
# Train 2 (Move to other side): a second reverse_kl run from freshly created models.
# NOTE(review): the initialization is identical to the previous run and no random
# seed is set, so collapsing onto the *other* mode presumably relies on training
# randomness. If a deterministic demonstration is wanted, start q_model off-center
# on the opposite side instead (e.g. Normal(1., 1.) or Normal(-1., 1.)).
p_model = MixtureModel([Normal(-7., 1.4), Normal(7., 1.4)], [0.5, 0.5])
q_model = Normal(0., 1.)
stats = train(p_model, q_model, reverse_kl, epochs=4000)
plot_stats(stats)

In [None]:
# Plot Result: the bimodal p_model and q_model after the second reverse-KL run.
plot_models(p_model, q_model)

In [None]:
# plot loss dynamics of reverse_kl against a freshly constructed copy of the
# bimodal target (a new MixtureModel, not the p_model variable above).
plot_loss_function(reverse_kl, p_model=MixtureModel([Normal(-7., 1.4), Normal(7., 1.4)], [0.5, 0.5]))

### Jensen-Shannon Divergence

In [None]:
# Show dists: bimodal p_model (equal-weight Normal(-7., 1.4) / Normal(7., 1.4))
# and a Normal(0., 1.) q_model, recreated fresh for the Jensen-Shannon section.
p_model, q_model = (
    MixtureModel([Normal(-7., 1.4), Normal(7., 1.4)], [0.5, 0.5]),
    Normal(0., 1.),
)
plot_models(p_model, q_model)

In [None]:
# Train: fit q_model to the bimodal p_model with js_divergence (the divergence
# this section demonstrates) for 4000 epochs, then plot the training stats.
stats = train(p_model, q_model, js_divergence, epochs=4000)
plot_stats(stats)

In [None]:
# Plot Result: the bimodal p_model and q_model after Jensen-Shannon training.
plot_models(p_model, q_model)

In [None]:
# plot loss dynamics of js_divergence against a freshly constructed copy of the
# bimodal target (a new MixtureModel, not the p_model variable above).
plot_loss_function(js_divergence, p_model=MixtureModel([Normal(-7., 1.4), Normal(7., 1.4)], [0.5, 0.5]))

### f-Divergences?