# Choice of proposal distribution

We investigate how the choice of proposal distribution affects learning an unnormalised model.

In [None]:
import sys
from functools import partial
import torch
from scipy.linalg import sqrtm
import matplotlib.pyplot as plt

sys.path.append("..")
from src.nce.rank import NceRankCrit

from src.data.normal import MultivariateNormal
from src.models.ebm import Ebm
from src.noise_distr.adaptive import AdaptiveMdn
from src.nce.adaptive_rank import AdaptiveRankKernel

from src.training.model_training import train_model, train_model_model_proposal
from src.data.normal import MultivariateNormalData
from src.training.training_utils import Mse, MvnKlDiv, no_change_stopping_condition, no_stopping
from src.experiments.utils import mvn_curve, plot_mvn, plot_distrs_adaptive
%load_ext autoreload
%autoreload 2

# Common setup

In [None]:
# Problem size
D = 2  # dimension
N = 100  # number of data samples
J = 10  # number of negative samples
mu_star = torch.ones(D)
cov_star = torch.eye(D)

# Data distribution
p_d = MultivariateNormal(mu_star, cov_star)
# Model distribution (initialisation currently not used):
# init_mu, init_cov = 5.0 * torch.ones(D), 4.0 * torch.eye(D)

# Optimisation
num_epochs = 100
batch_size = 16
learn_rate = 0.05 * batch_size ** 0.5  # base rate scaled by sqrt(batch_size)
scheduler_opts = (30, 0.9)

# Metrics, both built from the true parameters of p_d
kl_div = MvnKlDiv(p_d.mu, p_d.cov).metric
mse = Mse(p_d.mu).metric
metric = kl_div  # metric passed to the training loop below

# Adaptive proposal

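Assume that we have a learnable proposal $q_\varphi$.
We learn this proposal jointly with the model by minimising
$$
KL(p_\theta \| q_\varphi) = - \mathbb{E}_{x \sim p_\theta} \log q_\varphi(x) + \text{const.} = \mathcal{L}_\varphi + \text{const.},
$$
where the constant (the negative entropy of $p_\theta$) does not depend on $\varphi$, so minimising the KL over $\varphi$ is equivalent to minimising $\mathcal{L}_\varphi$.
Its gradient is approximated by
$$
\nabla_\varphi \mathcal{L}_\varphi \approx - \sum_{j=0}^J w(x_j) \nabla_\varphi \log q_\varphi(x_j),
$$
where $w(x_j)$ are weights assigned to the samples $x_j$.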

In [None]:
# Quick check of the adaptive proposal: predict its parameters for a random
# batch and call `sample` with sample shape (batch_size, J).
# Use D rather than a literal 2 so the proposal's input dimension always
# matches the dimension of `x`.
q_phi = AdaptiveMdn(num_components=4, input_dim=D)

x = torch.randn(batch_size, D)
means, sigmas, weights = q_phi.predict_params(x)
inds = q_phi.sample(torch.Size((x.size(0), J)), x)
for ind in inds:
    print(ind)

In [None]:
def train_model_adaptive_proposal(
    p_theta,
    q_phi,
    p_criterion,
    q_criterion,
    evaluation_metric,
    train_loader,
    save_dir,
    neg_sample_size,
    num_epochs,
    stopping_condition=no_stopping,
    lr: float = 0.1,
):
    """Jointly train an unnormalised model p_theta and an adaptive proposal q_phi.

    For every mini-batch two SGD updates are made in sequence: first the
    parameters of p_theta are updated from the gradients produced by
    ``p_criterion``, then the parameters of q_phi are updated from the
    gradients produced by ``q_criterion``. Both optimisers share ``lr``.

    Args:
        p_theta: model (``torch.nn.Module``-like) optimised with ``p_criterion``.
        q_phi: proposal (``torch.nn.Module``-like) optimised with ``q_criterion``.
        p_criterion: object with ``calculate_crit_grad(y, idx)`` that is
            expected to populate ``.grad`` on p_theta's parameters.
        q_criterion: object with ``calculate_crit_grad(y, idx)`` that is
            expected to populate ``.grad`` on q_phi's parameters.
        evaluation_metric: callable ``metric(p_theta)`` evaluated before
            training and after every batch.
        train_loader: iterable yielding ``(y, idx)`` batches.
        save_dir: currently unused.
        neg_sample_size: currently unused (the criteria presumably carry
            their own negative-sample count).
        num_epochs: maximum number of passes over ``train_loader``.
        stopping_condition: callable ``(new_params, old_params) -> bool``
            evaluated on q_phi's flattened (detached) parameters at the end
            of every epoch; training stops early when it returns True.
        lr: learning rate of both SGD optimisers.

    Returns:
        ``(losses, metrics)`` as tensors. ``metrics`` holds the initial
        evaluation followed by one value per batch. ``losses`` is currently
        always empty: no batch loss is recorded.
    """
    p_optimizer = torch.optim.SGD(p_theta.parameters(), lr=lr)
    q_optimizer = torch.optim.SGD(q_phi.parameters(), lr=lr)
    # NOTE(review): never populated; kept so the return signature is unchanged.
    batch_losses = []
    # Evaluate without autograd tracking, consistent with the per-batch evaluations.
    with torch.no_grad():
        batch_metrics = [evaluation_metric(p_theta)]

    for _ in range(num_epochs):
        # Detached snapshot of q_phi's parameters, compared after the epoch.
        old_params = torch.nn.utils.parameters_to_vector(q_phi.parameters()).detach()
        for y, idx in train_loader:
            # Model update.
            p_optimizer.zero_grad()
            p_criterion.calculate_crit_grad(y, idx)
            p_optimizer.step()

            # Proposal update.
            q_optimizer.zero_grad()
            q_criterion.calculate_crit_grad(y, idx)
            q_optimizer.step()

            with torch.no_grad():
                batch_metrics.append(evaluation_metric(p_theta))

        new_params = torch.nn.utils.parameters_to_vector(q_phi.parameters()).detach()
        if stopping_condition(new_params, old_params):
            print("Training converged")
            break
    return torch.tensor(batch_losses), torch.tensor(batch_metrics)

# Build the model, the adaptive proposal and their criteria, then train jointly.
input_dim = mu_star.size(0)
p_theta = Ebm(input_dim=input_dim)
# `input_dim` must be passed by keyword: a positional argument after a keyword
# argument is a SyntaxError and prevented this cell from running at all.
q_phi = AdaptiveMdn(num_components=4, input_dim=input_dim)
p_crit = NceRankCrit(p_theta, q_phi, J)
q_crit = AdaptiveRankKernel(p_theta, q_phi, J)

training_data = MultivariateNormalData(mu_star, cov_star, N)
train_loader = torch.utils.data.DataLoader(training_data, batch_size=batch_size, shuffle=True)

adaptive_losses, adaptive_metrics = train_model_adaptive_proposal(
    p_theta,
    q_phi,
    p_crit,
    q_crit,
    metric,
    train_loader,
    save_dir=None,
    neg_sample_size=J,
    num_epochs=num_epochs,
    stopping_condition=no_stopping,
    lr=learn_rate,
)
# plot_distrs_adaptive(p_d, p_theta, q_phi)

In [None]:
# Compare KL(p_d || p_theta) over iteration steps for the three proposal choices.
# NOTE(review): p_d_metrics, p_t_metrics and nan_pad are not defined anywhere in
# this notebook (presumably left over from other cells/runs) - confirm before running.
num_steps = adaptive_metrics.size(0)
# Plot every 10th step. The end is derived from the recorded metrics instead of
# a hard-coded 2050, which overran adaptive_metrics under the current settings.
iters = torch.arange(start=0, end=num_steps, step=10)
# NOTE(review): nan_pad(..., num_epochs + 1) must return at least num_steps
# entries for the indexing below to be valid - TODO confirm its length argument.

fig, ax = plt.subplots()
ax.loglog(iters, nan_pad(p_d_metrics, num_epochs + 1)[iters], label="$q=p_d$")
ax.loglog(iters, nan_pad(p_t_metrics, num_epochs + 1)[iters], label="$q=p_{\\theta}$")
ax.loglog(iters, adaptive_metrics[iters], label="$q=q_{\\varphi}$")
ax.legend()
ax.set_title("Choice of proposal distribution")
ax.set_xlabel("Iter. step $t$")
ax.set_ylabel("KL$(p_d || p_{\\theta})$");

In [None]:
from src.experiments.utils import process_plot_data
from src.experiments.adaptive_proposal import plot_kl_div

# Stack the three metric traces column-wise and post-process them for plotting.
data = process_plot_data(
    torch.column_stack((p_d_metrics, p_t_metrics, adaptive_metrics)),
    num_epochs + 1,
    res=1,
)

# Pass the first four columns of the processed data as separate arguments.
plot_kl_div(*(data[:, col] for col in range(4)))
