# Choice of proposal distribution

We investigate how the choice of proposal distribution affects learning an unnormalised model.

In [None]:
import sys
import torch
from scipy.linalg import sqrtm
import matplotlib.pyplot as plt

sys.path.append("..")
from src.nce.cnce import CondNceCrit
from src.nce.rank import NceRankCrit

from src.noise_distr.normal import MultivariateNormal
from src.models.gaussian_model import DiagGaussianModel

from src.training.model_training import train_model, train_model_model_proposal
from src.data.normal import MultivariateNormalData
from src.training.training_utils import Mse, MvnKlDiv, no_change_stopping_condition, no_stopping
%load_ext autoreload
%autoreload 2

def mvn_curve(mu, cov, std=1, res=100):
    """Return points on the ``std``-sigma contour ellipse of a 2-D Gaussian.

    The unit circle is sampled at ``res`` angles in [0, 2*pi]. It is mapped
    through the symmetric square root of ``cov`` and then shifted by ``mu``.
    For a symmetric positive-definite ``cov``, every returned row ``x``
    satisfies ``(x - mu)^T cov^{-1} (x - mu) == std**2``.

    Args:
        mu: Mean; broadcast-added to the ``(res, 2)`` curve.
        cov: ``(2, 2)`` covariance (tensor or array-like). It may be a tensor
            that requires grad, such as the output of a model's ``cov()``.
        std: Number of standard deviations of the contour.
        res: Number of points along the curve. The first and last points
            coincide, so the curve is closed.

    Returns:
        Tensor of shape ``(res, 2)``, computed under ``torch.no_grad()``.
    """
    with torch.no_grad():
        angles = torch.linspace(0, 2 * torch.pi, res)
        curve_param = torch.column_stack((torch.cos(angles), torch.sin(angles)))
        # scipy needs a plain ndarray. A tensor that requires grad refuses
        # .numpy() even inside no_grad, so detach explicitly first.
        cov_arr = torch.as_tensor(cov).detach().cpu().numpy()
        # sqrtm can return a complex dtype with a ~0 imaginary part; keep the
        # real part and match the curve's dtype.
        cov_sqrt = torch.as_tensor(sqrtm(cov_arr).real, dtype=curve_param.dtype)
        ellipsis = std * curve_param @ cov_sqrt
        return mu + ellipsis
    
def plot_mvn(levels, ax, label):
    """Draw a 2-D curve onto ``ax`` as a labelled line.

    Args:
        levels: Indexable array of points whose column 0 holds x-coordinates
            and column 1 holds y-coordinates (e.g. the output of ``mvn_curve``).
        ax: Axes object to draw on; its ``plot`` method is called once.
        label: Legend label for the line.
    """
    xs = levels[:, 0]
    ys = levels[:, 1]
    ax.plot(xs, ys, label=label)
    
def plot_distrs(p_d, p_t_d, p_t_t):
    """Overlay the 1-sigma contours of the data distribution and two models.

    Args:
        p_d: Data distribution; its ``mu`` and ``cov`` attributes are used.
        p_t_d: Model trained with proposal p_n = p_d; its ``mu`` attribute and
            ``cov()`` method are used.
        p_t_t: Model trained with proposal p_n = p_theta; same interface as
            ``p_t_d``.
    """
    fig, ax = plt.subplots()
    # Fixed viewing window so plots from different runs are comparable.
    ax.set_xlim([-3, 10])
    ax.set_ylim([-3, 10])
    # Braces must balance, otherwise matplotlib's mathtext fails to parse the
    # label when the legend is rendered.
    plot_mvn(mvn_curve(p_d.mu, p_d.cov), ax, "$p_{d}$")
    # Model parameters may track gradients; detach them so the plotting
    # helpers only ever see plain tensors.
    plot_mvn(
        mvn_curve(p_t_d.mu.detach(), p_t_d.cov().detach()),
        ax,
        "$p_{\\theta}, p_n = p_d$",
    )
    plot_mvn(
        mvn_curve(p_t_t.mu.detach(), p_t_t.cov().detach()),
        ax,
        "$p_{\\theta}, p_n = p_{\\theta}$",
    )

    ax.legend()

# Idealised case

Assume that we can both evaluate and sample from $p_d$ and $p_\theta$.
Which of the two is the better choice of proposal distribution $q$?

In [None]:
# Experiment: ranking NCE (NceRankCrit) with two proposal choices, q = p_d
# (the true data distribution) versus q = p_theta (the model itself).
# Both models start from the same initialisation. The KL metric against p_d
# (MvnKlDiv) returned by each training loop is plotted side by side.
D, N, J = 2, 100, 10 # Dimension, Num. data samples, Num neg. samples
# Data model: Gaussian with mean vector of ones and identity covariance
mu_star, cov_star = torch.ones(D,), torch.eye(D)
p_d = MultivariateNormal(mu_star, cov_star)

# Shared initialisation for both models: mean 5*ones, covariance 4*I (far from p_d)
init_mu, init_cov =5.0*torch.ones(D,), 4*torch.eye(D)

num_epochs = 200
batch_size = 20
# Learning rate scaled with sqrt(batch_size)
learn_rate = 0.01*batch_size**0.5

# presumably N samples drawn from N(mu_star, cov_star) — confirm in MultivariateNormalData
training_data = MultivariateNormalData(mu_star, cov_star, N)
train_loader = torch.utils.data.DataLoader(training_data, batch_size=batch_size, shuffle=True)

# Metrics
kl_div = MvnKlDiv(p_d.mu, p_d.cov).metric
# NOTE(review): `mse` is built but unused in this cell; `metric` selects the KL.
mse = Mse(p_d.mu).metric
metric = kl_div

# q = p_d
# .clone() gives each model its own copy of the initial parameters.
p_t_data_noise = DiagGaussianModel(init_mu.clone(), init_cov.clone())
criterion = NceRankCrit(p_t_data_noise, p_d, J)  # p_d is the noise/proposal distribution
p_d_losses, p_d_metrics = train_model(criterion,
                  metric,
                  train_loader,
                  None,  # presumably save_dir (no checkpointing) — confirm against train_model
                  neg_sample_size=J,
                  num_epochs=num_epochs,
                  stopping_condition=no_change_stopping_condition,
                  lr=learn_rate)

# q = p_theta
# NOTE(review): unlike the q = p_d run, no stopping_condition is passed here,
# so the two runs may stop under different rules — confirm this is intended.
p_t_model_noise = DiagGaussianModel(init_mu.clone(), init_cov.clone())
p_t_losses, p_t_metrics = train_model_model_proposal(p_t_model_noise,
                           NceRankCrit,  # criterion constructor, presumably rebuilt with the current model as proposal
                           metric,
                           train_loader,
                           None,  # presumably save_dir — confirm
                           J,
                           num_epochs,
                           lr=learn_rate)

# Plot both metric histories against their step index.
fig, ax = plt.subplots()
ax.plot(torch.arange(p_d_metrics.size(0)), p_d_metrics, label="$q=p_d$")
ax.plot(torch.arange(p_t_metrics.size(0)), p_t_metrics, label="$q=p_{\\theta}$")
ax.legend();
ax.set_title("Choice of proposal distribution")
ax.set_xlabel("Iter. step $t$")
ax.set_ylabel("KL$(p_d || p_{\\theta})$");

# Adaptive proposal

Assume that we have a learnable proposal $q_\varphi$.
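We jointly learn this proposal by minimising
$$
KL(p_\theta \| q_\varphi) = \mathcal{L}_\varphi + \text{const}, \qquad \mathcal{L}_\varphi = - \mathbb{E}_{x \sim p_\theta} \log q_\varphi(x),
$$
where the constant (the negative entropy of $p_\theta$) does not depend on $\varphi$.
The gradient is approximated with
$$
\nabla_\varphi \mathcal{L}_\varphi \approx - \sum_{j=0}^J w(x_j) \nabla_\varphi \log q_\varphi(x_j),
$$
where $w(x_j) \propto \tilde{p}_\theta(x_j) / q_\varphi(x_j)$ are self-normalised importance weights.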

In [None]:
def train_model_adaptive_proposal(
    p_theta,
    q_phi,
    crit_constructor,
    evaluation_metric,
    train_loader,
    save_dir,
    neg_sample_size: int = 10,
    num_epochs: int = 100,
    stopping_condition=no_stopping,
    lr: float = 0.1,
):
    """Training loop for adaptive proposal q_phi.

    Intended for jointly learning p_tilde_theta and q_phi, assuming that we can
    both sample from and evaluate q_phi.

    NOTE(review): ``q_phi`` is not used yet. The proposal is currently a frozen
    snapshot of ``p_theta`` rebuilt for every batch, so this loop behaves like
    the "q = p_theta" baseline. TODO: pass ``q_phi`` as the noise distribution
    and update it by minimising KL(p_theta || q_phi).

    Args:
        p_theta: Model being trained. Its ``parameters()`` are optimised with
            SGD, and its ``mu`` attribute and ``cov()`` method define the
            proposal snapshot.
        q_phi: Learnable proposal (currently unused, see note above).
        crit_constructor: Callable ``(model, noise_distr, neg_sample_size)``
            returning a criterion with ``crit(y, idx)`` and
            ``calculate_crit_grad(y, idx)`` methods.
        evaluation_metric: Callable mapping the model to a metric value
            (assumed scalar, since the history is stacked with ``torch.tensor``).
        train_loader: Iterable yielding ``(y, idx)`` batches.
        save_dir: Unused; kept for signature parity with the other loops.
        neg_sample_size: Number of negative samples (J).
        num_epochs: Maximum number of passes over ``train_loader``.
        stopping_condition: Callable ``(new_params, old_params) -> bool``,
            checked after each epoch; training stops when it returns True.
        lr: SGD learning rate.

    Returns:
        ``(losses, metrics)``: one loss per processed batch, and one metric
        for the initial model plus one per processed batch.
    """
    optimizer = torch.optim.SGD(p_theta.parameters(), lr=lr)

    batch_metrics = [evaluation_metric(p_theta)]
    batch_losses = []
    for _epoch in range(1, num_epochs + 1):
        # parameters_to_vector concatenates into new storage, so this is a
        # snapshot of the parameters before this epoch's updates.
        old_params = torch.nn.utils.parameters_to_vector(p_theta.parameters())
        for y, idx in train_loader:
            # Proposal = detached copy of the current model, so no gradient
            # flows into p_theta through the noise distribution.
            q = MultivariateNormal(
                p_theta.mu.detach().clone(), p_theta.cov().detach().clone()
            )
            criterion = crit_constructor(p_theta, q, neg_sample_size)
            optimizer.zero_grad()
            # The loss value is only logged; gradients are assigned below.
            with torch.no_grad():
                loss = criterion.crit(y, None)
                batch_losses.append(loss.item())
            # Calculate and assign gradients, then take a gradient step.
            criterion.calculate_crit_grad(y, idx)
            optimizer.step()

            batch_metrics.append(evaluation_metric(p_theta))
        if stopping_condition(
            torch.nn.utils.parameters_to_vector(p_theta.parameters()), old_params
        ):
            print("Training converged")
            break
    return torch.tensor(batch_losses), torch.tensor(batch_metrics)
