# Choice of proposal distribution

We investigate how the choice of proposal (noise) distribution $q$ affects learning an unnormalised model $p_\theta$ with ranking-based noise-contrastive estimation (NCE).

In [None]:
import sys
import torch
from scipy.linalg import sqrtm
import matplotlib.pyplot as plt

sys.path.append("..")
from src.nce.cnce import CondNceCrit
from src.nce.rank import NceRankCrit

from src.noise_distr.normal import MultivariateNormal
from src.models.gaussian_model import DiagGaussianModel

from src.training.model_training import train_model, train_model_model_proposal
from src.data.normal import MultivariateNormalData
from src.training.training_utils import Mse, MvnKlDiv, no_change_stopping_condition, no_stopping
%load_ext autoreload
%autoreload 2

def mvn_curve(mu, cov, std=1, res=100):
    """Return points on the ``std``-standard-deviation contour of a Gaussian.

    Only the first two coordinates are used. For a D-dimensional Gaussian the
    curve is the contour of its 2-D marginal N(mu[:2], cov[:2, :2]), so the
    function works for any D >= 2 (for D == 2 it is the full contour).

    Args:
        mu: mean vector with at least two entries.
        cov: covariance matrix, at least 2x2. May be a tensor that requires
            grad (e.g. returned by a trainable model's ``cov()``).
        std: Mahalanobis radius of the contour (number of standard deviations).
        res: number of points on the closed curve (first and last coincide).

    Returns:
        Tensor of shape (res, 2) on the CPU holding the contour points.
    """
    with torch.no_grad():
        # Detach explicitly: torch.no_grad() does not affect tensors created
        # before entering the block, and numpy conversion (inside sqrtm)
        # raises on tensors that require grad.
        cov_2d = torch.as_tensor(cov).detach()[:2, :2].cpu()
        mu_2d = torch.as_tensor(mu).detach()[:2].cpu()

        angles = torch.linspace(0, 2 * torch.pi, res)
        unit_circle = torch.column_stack((torch.cos(angles), torch.sin(angles)))
        # sqrtm may return a complex array with negligible imaginary parts
        # even for a symmetric positive-definite input; keep the real part.
        sqrt_cov = torch.tensor(sqrtm(cov_2d.numpy()).real, dtype=unit_circle.dtype)
        ellipse = std * unit_circle @ sqrt_cov
        return mu_2d + ellipse
    
def plot_mvn(levels, ax, label):
    """Draw a curve onto ``ax``, using column 0 of ``levels`` as x and column 1 as y."""
    x_coords = levels[:, 0]
    y_coords = levels[:, 1]
    ax.plot(x_coords, y_coords, label=label)

def plot_distrs_ideal(p_d, p_t_d, p_t_t):
    """Plot 1-std contours of the data distribution and of two trained models.

    Args:
        p_d: data distribution; exposes ``mu`` and a ``cov`` attribute.
        p_t_d: model trained with proposal q = p_d; exposes ``mu`` and a
            ``cov()`` method.
        p_t_t: model trained with proposal q = p_theta; same interface as
            ``p_t_d``.
    """
    fig, ax = plt.subplots()
    ax.set_xlim([-3, 10])
    ax.set_ylim([-3, 10])
    # Labels are rendered with mathtext, so braces must be balanced.
    distrs = [
        (p_d.mu, p_d.cov, "$p_{d}$"),
        (p_t_d.mu, p_t_d.cov(), "$q=p_d$"),
        (p_t_t.mu, p_t_t.cov(), "$q = p_{\\theta}$"),
    ]
    for mu, cov, label in distrs:
        plot_mvn(mvn_curve(mu, cov), ax, label)
    ax.set_title("Comparison, optimal proposal distrs.")
    ax.legend()

def plot_distrs_adaptive(p_d, p_theta, q_phi):
    """Plot 1-std contours of the data distribution, the model and the adaptive proposal.

    Args:
        p_d: data distribution; exposes ``mu`` and a ``cov`` attribute.
        p_theta: trained model; exposes ``mu`` and a ``cov()`` method.
        q_phi: learned proposal; exposes ``mu`` and a ``cov()`` method.
    """
    fig, ax = plt.subplots()
    ax.set_xlim([-3, 10])
    ax.set_ylim([-3, 10])
    # Labels are rendered with mathtext, so braces must be balanced.
    distrs = [
        (p_d.mu, p_d.cov, "$p_{d}$"),
        (p_theta.mu, p_theta.cov(), "$p_{\\theta}$"),
        (q_phi.mu, q_phi.cov(), "$q_{\\varphi}$"),
    ]
    for mu, cov, label in distrs:
        plot_mvn(mvn_curve(mu, cov), ax, label)
    ax.set_title("Adaptive proposal")
    ax.legend()

# Common setup

In [None]:
# Problem size: dimension, number of data samples, number of negative samples.
D = 5
N = 100
J = 10

# True parameters of the data distribution.
mu_star = torch.ones(D)
cov_star = torch.eye(D)

# Data distribution
p_d = MultivariateNormal(mu_star, cov_star)

# Initial model parameters (mean 5, variance 4 per coordinate), away from
# the data's mean 1 / identity covariance.
init_mu = 5.0 * torch.ones(D)
init_cov = 4.0 * torch.eye(D)

# Optimisation: the batch size equals the number of data samples N, and the
# step size is scaled by the square root of the batch size.
num_epochs = 2500
batch_size = N
learn_rate = 0.01 * batch_size ** 0.5

# Metrics: both are built against the data distribution; KL is the one used.
kl_div = MvnKlDiv(p_d.mu, p_d.cov).metric
mse = Mse(p_d.mu).metric
metric = kl_div


# Idealistic case

Assuming that we can both evaluate and sample from $p_d$ and $p_\theta$,
which is the better choice of proposal distribution: $q = p_d$ or $q = p_\theta$?

In [None]:
# ---- Proposal q = p_d (the data distribution) ----
p_t_data_noise = DiagGaussianModel(init_mu.clone(), init_cov.clone())
criterion = NceRankCrit(p_t_data_noise, p_d, J)
print("Training with q = p_d")

# NOTE: a fresh MultivariateNormalData is constructed for each run.
training_data = MultivariateNormalData(mu_star, cov_star, N)
train_loader = torch.utils.data.DataLoader(
    training_data, batch_size=batch_size, shuffle=True
)
p_d_losses, p_d_metrics = train_model(
    criterion,
    metric,
    train_loader,
    None,
    neg_sample_size=J,
    num_epochs=num_epochs,
    stopping_condition=no_change_stopping_condition,
    lr=learn_rate,
)

# ---- Proposal q = p_theta (the model itself) ----
p_t_model_noise = DiagGaussianModel(init_mu.clone(), init_cov.clone())
print("Training with q = p_theta")

training_data = MultivariateNormalData(mu_star, cov_star, N)
train_loader = torch.utils.data.DataLoader(
    training_data, batch_size=batch_size, shuffle=True
)
# Arguments are positional: the criterion class is passed, not an instance.
p_t_losses, p_t_metrics = train_model_model_proposal(
    p_t_model_noise,
    NceRankCrit,
    metric,
    train_loader,
    None,
    J,
    num_epochs,
    lr=learn_rate,
)

# Adaptive proposal

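Assume that we have a learnable proposal $q_\varphi$.
We learn this proposal jointly with $p_\theta$ by minimising
$$
KL(p_\theta \| q_\varphi) = - \mathbb{E}_{x \sim p_\theta} \log q_\varphi(x) + \text{const} = \mathcal{L}_\varphi + \text{const},
$$
where the constant, $\mathbb{E}_{x \sim p_\theta} \log p_\theta(x)$, does not depend on $\varphi$.
The gradient is estimated with weighted samples,
$$
\nabla_\varphi \mathcal{L}_\varphi \approx - \sum_{j=0}^J w(x_j) \nabla_\varphi \log q_\varphi(x_j),
$$
where $w(x_j)$ are the (self-normalised importance) weights of the samples $x_j$.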

In [None]:
from src.nce.adaptive_rank import AdaptiveRankKernel
from src.noise_distr.adaptive import AdaptiveDiagGaussianModel


def train_model_adaptive_proposal(
    p_theta,
    q_phi,
    p_criterion,
    q_criterion,
    evaluation_metric,
    train_loader,
    save_dir,
    neg_sample_size,
    num_epochs,
    stopping_condition=no_stopping,
    lr: float = 0.1,
):
    """Jointly train an unnormalised model p_theta and an adaptive proposal q_phi.

    For every mini-batch, one SGD step is taken on p_theta (gradients assigned
    by ``p_criterion``), followed by one SGD step on q_phi (gradients assigned
    by ``q_criterion``). It is assumed that q_phi can be both sampled from and
    evaluated.

    Args:
        p_theta: model being learned (a torch module with parameters).
        q_phi: learnable proposal (a torch module with parameters).
        p_criterion: object whose ``calculate_crit_grad(y, idx)`` assigns
            gradients to p_theta's parameters.
        q_criterion: object whose ``calculate_crit_grad(y, idx)`` assigns
            gradients to q_phi's parameters.
        evaluation_metric: callable applied to p_theta; recorded once before
            training and after every step.
        train_loader: iterable yielding ``(y, idx)`` batches.
        save_dir: unused; kept for signature parity with the other loops.
        neg_sample_size: unused here (the criteria receive it at
            construction); kept for signature parity.
        num_epochs: maximum number of passes over ``train_loader``.
        stopping_condition: callable ``(new_params, old_params) -> bool``,
            checked after each epoch on the concatenated parameters of
            p_theta and q_phi.
        lr: learning rate of both SGD optimisers.

    Returns:
        ``(losses, metrics)`` tensors. ``losses`` is currently always empty
        because loss values are not recorded. ``metrics`` holds the initial
        metric followed by one value per step.
    """
    p_optimizer = torch.optim.SGD(p_theta.parameters(), lr=lr)
    q_optimizer = torch.optim.SGD(q_phi.parameters(), lr=lr)

    def _param_snapshot():
        # Detached copy of all trainable parameters. Joint training has only
        # converged once neither the model nor the proposal moves any more.
        return torch.cat(
            [
                torch.nn.utils.parameters_to_vector(p_theta.parameters()),
                torch.nn.utils.parameters_to_vector(q_phi.parameters()),
            ]
        ).detach()

    batch_losses = []
    with torch.no_grad():
        batch_metrics = [evaluation_metric(p_theta)]

    for _ in range(num_epochs):
        old_params = _param_snapshot()
        for y, idx in train_loader:
            # Model step: criterion assigns gradients, optimiser applies them.
            p_optimizer.zero_grad()
            p_criterion.calculate_crit_grad(y, idx)
            p_optimizer.step()

            # Proposal step, using the freshly updated model.
            q_optimizer.zero_grad()
            q_criterion.calculate_crit_grad(y, idx)
            q_optimizer.step()

            with torch.no_grad():
                batch_metrics.append(evaluation_metric(p_theta))

        if stopping_condition(_param_snapshot(), old_params):
            print("Training converged")
            break

    return torch.tensor(batch_losses), torch.tensor(batch_metrics)

# Model starts from the same initial parameters as in the idealistic runs.
p_theta = DiagGaussianModel(init_mu.clone(), init_cov.clone())
# The adaptive proposal is initialised at the data distribution's parameters.
q_phi = AdaptiveDiagGaussianModel(mu_star.clone(), cov_star.clone())
# p_theta is trained with ranking NCE against q_phi; q_phi is fitted towards p_theta.
p_crit = NceRankCrit(p_theta, q_phi, J)
q_crit = AdaptiveRankKernel(p_theta, q_phi, J)

training_data = MultivariateNormalData(mu_star, cov_star, N)
train_loader = torch.utils.data.DataLoader(
    training_data, batch_size=batch_size, shuffle=True
)

adaptive_losses, adaptive_metrics = train_model_adaptive_proposal(
    p_theta,
    q_phi,
    p_crit,
    q_crit,
    metric,
    train_loader,
    None,
    neg_sample_size=J,
    num_epochs=num_epochs,
    stopping_condition=no_stopping,
    lr=learn_rate,
)
# plot_distrs_adaptive(p_d, p_theta, q_phi)

In [None]:
# Compare how the evaluation metric evolves for the different proposals.
fig, ax = plt.subplots()
# Plot every 10th step within a 2050-step window, but never index past the
# recorded metrics: an early-stopped run records fewer steps, and indexing
# beyond them raises an IndexError.
last_step = min(2050, p_d_metrics.size(0))
iters = torch.arange(start=0, end=last_step, step=10)
ax.plot(iters, p_d_metrics[iters], label="$q=p_d$")
# When enabling the curves below, make sure `iters` also stays within their lengths.
#ax.plot(iters, p_t_metrics[iters], label="$q=p_{\\theta}$")
#ax.plot(iters, adaptive_metrics[iters], label="$q=q_{\\varphi}$")
ax.legend();
# ax.set_xlim([0, 250])
ax.set_title("Choice of proposal distribution")
ax.set_xlabel("Iter. step $t$")
ax.set_ylabel("KL$(p_d || p_{\\theta})$");