# Choice of proposal distribution

We investigate the effect of the proposal distribution when learning an unnormalised model.

In [None]:
import sys
from functools import partial
import torch
from scipy.linalg import sqrtm
import matplotlib.pyplot as plt

sys.path.append("..")
from src.nce.cnce import CondNceCrit
from src.nce.rank import NceRankCrit

from src.noise_distr.normal import MultivariateNormal
from src.models.gaussian_model import DiagGaussianModel

from src.training.model_training import train_model, train_model_model_proposal
from src.data.normal import MultivariateNormalData
from src.training.training_utils import Mse, MvnKlDiv, no_change_stopping_condition, no_stopping
%load_ext autoreload
%autoreload 2

def mvn_curve(mu, cov, std=1, res=100):
    """Return points on a contour ellipse of a 2-D Gaussian.

    A circle of radius ``std`` is sampled at ``res`` angles in [0, 2*pi],
    mapped through the matrix square root of ``cov`` and shifted by ``mu``.
    The circle has two coordinates, so ``cov`` has to be 2x2.

    Args:
        mu: centre of the ellipse, broadcast against the (res, 2) points.
        cov: covariance matrix passed to ``scipy.linalg.sqrtm``.
        std: radius of the circle before it is mapped through ``sqrtm(cov)``.
        res: number of points along the curve.

    Returns:
        Tensor of shape (res, 2), computed with autograd disabled.
    """
    with torch.no_grad():
        theta = torch.linspace(0, 2 * torch.pi, res)
        unit_circle = torch.column_stack((theta.cos(), theta.sin()))
        cov_root = torch.Tensor(sqrtm(cov))
        offsets = std * unit_circle @ cov_root
        return mu + offsets
    
def plot_mvn(levels, ax, label):
    """Draw a curve on ``ax``.

    Args:
        levels: 2-D array-like; column 0 holds the x values, column 1 the y values.
        ax: object with a matplotlib-style ``plot`` method.
        label: legend label forwarded to ``ax.plot``.
    """
    x_coords = levels[:, 0]
    y_coords = levels[:, 1]
    ax.plot(x_coords, y_coords, label=label)

def plot_distrs_ideal(p_d, p_t_d, p_t_t):
    """Plot the contour ellipses of the data distribution and two models.

    A new figure is created with both axes limited to [-3, 10]; one curve
    from ``mvn_curve`` is drawn per distribution.

    Args:
        p_d: data distribution, read through ``p_d.mu`` and the attribute
            ``p_d.cov``.
        p_t_d: model labelled "q=p_d" (presumably the one trained with the
            data distribution as proposal); read through ``.mu`` and the
            method ``.cov()``.
        p_t_t: model labelled "q = p_theta"; same access pattern as
            ``p_t_d``.
    """
    fig, ax = plt.subplots()
    ax.set_xlim([-3, 10])
    ax.set_ylim([-3, 10])
    distrs = [
        # The label needs balanced braces to be valid mathtext.
        (p_d.mu, p_d.cov, "$p_{d}$"),
        (p_t_d.mu, p_t_d.cov(), "$q=p_d$"),
        (p_t_t.mu, p_t_t.cov(), "$q = p_{\\theta}$"),
    ]
    for mu, cov, label in distrs:
        plot_mvn(mvn_curve(mu, cov), ax, label)
    ax.set_title("Comparison, optimal proposal distrs.")
    ax.legend()

def plot_distrs_adaptive(p_d, p_theta, q_phi):
    """Plot the contour ellipses of the data distribution, model and proposal.

    A new figure is created with both axes limited to [-3, 10]; one curve
    from ``mvn_curve`` is drawn per distribution.

    Args:
        p_d: data distribution, read through ``p_d.mu`` and the attribute
            ``p_d.cov``.
        p_theta: model, read through ``.mu`` and the method ``.cov()``.
        q_phi: proposal, same access pattern as ``p_theta``.
    """
    fig, ax = plt.subplots()
    ax.set_xlim([-3, 10])
    ax.set_ylim([-3, 10])
    distrs = [
        # The label needs balanced braces to be valid mathtext.
        (p_d.mu, p_d.cov, "$p_{d}$"),
        (p_theta.mu, p_theta.cov(), "$p_{\\theta}$"),
        (q_phi.mu, q_phi.cov(), "$q_{\\varphi}$"),
    ]
    for mu, cov, label in distrs:
        plot_mvn(mvn_curve(mu, cov), ax, label)
    ax.set_title("Adaptive proposal")
    ax.legend()

# Common setup

In [None]:
# Problem size
D = 5    # dimension
N = 100  # number of data samples
J = 10   # number of negative samples

# Data distribution: mean of ones, identity covariance
mu_star = torch.ones(D)
cov_star = torch.eye(D)
p_d = MultivariateNormal(mu_star, cov_star)

# Initial parameters handed to the models in the cells below
init_mu = 5.0 * torch.ones(D)
init_cov = 4.0 * torch.eye(D)

# Optimisation (batch_size == N, so the learning rate scales with sqrt(N))
num_epochs = 100
batch_size = N
learn_rate = 0.05 * batch_size ** 0.5
scheduler_opts = (30, 0.9)

# Metrics; `metric` selects the one that is passed to the training loops
kl_div = MvnKlDiv(p_d.mu, p_d.cov).metric
mse = Mse(p_d.mu).metric
metric = kl_div


# Idealistic case

Assuming that we can evaluate and sample from $p_d$ and $p_\theta$,
which is the better alternative as the proposal distribution $q$? 

In [None]:
# --- Run 1: q = p_theta -----------------------------------------------------
# A fresh model is built from clones of the shared initial parameters, so the
# two runs of this cell do not share tensors.
p_t_model_noise = DiagGaussianModel(init_mu.clone(), init_cov.clone())
print("Training with q = p_theta")

# The whole data set forms one batch (batch_size == N in the setup cell).
training_data = MultivariateNormalData(mu_star, cov_star, N)
train_loader = torch.utils.data.DataLoader(training_data, batch_size=batch_size, shuffle=True)
# NOTE(review): the criterion *class* NceRankCrit is passed here, not an
# instance; presumably train_model_model_proposal builds the criterion with
# the model as its own proposal. Confirm in src.training.model_training.
p_t_losses, p_t_metrics = train_model_model_proposal(p_t_model_noise,
                                                     NceRankCrit,
                                                     metric,
                                                     train_loader,
                                                     None,
                                                     J,
                                                     num_epochs,
                                                     # stopping_condition=no_change_stopping_condition,
                                                     stopping_condition=no_stopping,
                                                     lr=learn_rate,
                                                     scheduler_opts=scheduler_opts
                                                    )

# --- Run 2: q = p_d ---------------------------------------------------------
# Same initial parameters; the criterion is built explicitly with the data
# distribution p_d as proposal.
p_t_data_noise = DiagGaussianModel(init_mu.clone(), init_cov.clone())
criterion = NceRankCrit(p_t_data_noise, p_d, J)
print("Training with q = p_d")

# A new data set is drawn for this run, so the two runs see different samples.
training_data = MultivariateNormalData(mu_star, cov_star, N)
train_loader = torch.utils.data.DataLoader(training_data, batch_size=batch_size, shuffle=True)
p_d_losses, p_d_metrics = train_model(criterion,
                                      metric,
                                      train_loader,
                                      None,
                                      neg_sample_size=J,
                                      num_epochs=num_epochs,
                                      # stopping_condition=no_change_stopping_condition,
                                      stopping_condition=no_stopping,
                                      lr=learn_rate,
                                      scheduler_opts=scheduler_opts
                                     )

# Adaptive proposal

Assume that we have a learnable proposal $q_\varphi$.
We jointly learn this proposal by minimising
$$
KL(p_\theta \| q_\varphi) = - \mathbb{E}_{x \sim p_\theta} \log q_\varphi(x) + \text{const.} = \mathcal{L}_\varphi + \text{const.},
$$
where the constant does not depend on $\varphi$, with
$$
\nabla_\varphi \mathcal{L}_\varphi \approx - \sum_{j=0}^J w(x_j) \nabla_\varphi \log q_\varphi(x_j)
$$

In [None]:
from src.nce.adaptive_rank import AdaptiveRankKernel
from src.noise_distr.adaptive import AdaptiveDiagGaussianModel


def train_model_adaptive_proposal(
    p_theta,
    q_phi,
    p_criterion,
    q_criterion,
    evaluation_metric,
    train_loader,
    save_dir,
    neg_sample_size,
    num_epochs,
    stopping_condition=no_stopping,
    lr: float = 0.1,
):
    """Training loop for adaptive proposal q_phi

    Training loop for jointly learning p_tilde_theta and q_phi.
    Where we assume that we can sample and evaluate q_phi.

    For every batch the model is updated first (``p_criterion``) and the
    proposal second (``q_criterion``), each with its own SGD optimiser.

    Args:
        p_theta: model; must provide ``parameters()``.
        q_phi: proposal; must provide ``parameters()``.
        p_criterion: object whose ``calculate_crit_grad(y, idx)`` assigns the
            gradients for ``p_theta``.
        q_criterion: object whose ``calculate_crit_grad(y, idx)`` assigns the
            gradients for ``q_phi``.
        evaluation_metric: callable applied to ``p_theta``.
        train_loader: iterable yielding ``(y, idx)`` pairs.
        save_dir: unused; kept for a uniform call signature.
        neg_sample_size: unused; kept for a uniform call signature.
        num_epochs: maximum number of passes over ``train_loader``.
        stopping_condition: callable ``(new_params, old_params) -> bool``
            evaluated once per epoch on the flattened parameters of
            ``q_phi``; training stops when it returns a truthy value.
        lr: learning rate used for both optimisers.

    Returns:
        Tuple ``(losses, metrics)`` of tensors. ``losses`` is always empty
        because no loss is recorded; ``metrics`` holds the metric of the
        initial model followed by one value per processed batch.
    """
    p_optimizer = torch.optim.SGD(p_theta.parameters(), lr=lr)
    q_optimizer = torch.optim.SGD(q_phi.parameters(), lr=lr)

    # NOTE: loss recording is disabled, so this list stays empty.
    batch_losses = []

    # The metric is only logged, never differentiated; evaluate the initial
    # value under no_grad exactly like the per-batch values below.
    with torch.no_grad():
        batch_metrics = [evaluation_metric(p_theta)]

    for _ in range(num_epochs):
        # Snapshot of the proposal parameters for the stopping criterion.
        old_params = torch.nn.utils.parameters_to_vector(q_phi.parameters())
        for y, idx in train_loader:
            # Model step
            p_optimizer.zero_grad()
            p_criterion.calculate_crit_grad(y, idx)
            p_optimizer.step()

            # Proposal step, using the freshly updated model
            q_optimizer.zero_grad()
            q_criterion.calculate_crit_grad(y, idx)
            q_optimizer.step()

            with torch.no_grad():
                batch_metrics.append(evaluation_metric(p_theta))

        if stopping_condition(
            torch.nn.utils.parameters_to_vector(q_phi.parameters()), old_params
        ):
            print("Training converged")
            break
    return torch.tensor(batch_losses), torch.tensor(batch_metrics)

# The model starts from the shared initial parameters, whereas the proposal is
# created from clones of mu_star / cov_star, i.e. the parameters that were
# used to build the data distribution p_d in the setup cell.
p_theta = DiagGaussianModel(init_mu.clone(), init_cov.clone())
q_phi = AdaptiveDiagGaussianModel(mu_star.clone(), cov_star.clone())
# Both criteria share the same model/proposal pair and J negative samples.
p_crit = NceRankCrit(p_theta, q_phi, J)
q_crit = AdaptiveRankKernel(p_theta, q_phi, J)

# Fresh data set, a single batch per epoch (batch_size == N).
training_data = MultivariateNormalData(mu_star, cov_star, N)
train_loader = torch.utils.data.DataLoader(training_data, batch_size=batch_size, shuffle=True)


# save_dir is None; train_model_adaptive_proposal does not use it.
adaptive_losses, adaptive_metrics = train_model_adaptive_proposal(p_theta,
                                                        q_phi,
                                                        p_crit,
                                                        q_crit,
                                                        metric,
                                                        train_loader,
                                                        None,
                                                        neg_sample_size=J,
                                                        num_epochs=num_epochs,
                                                        stopping_condition=no_stopping,
                                                        lr=learn_rate)
# plot_distrs_adaptive(p_d, p_theta, q_phi)

In [None]:
from torch import Tensor
def nan_pad(data: Tensor, length: int):
    """Pad a 1-D tensor with NaN up to a given length.

    Args:
        data: shape (N, )
        length: target length, ``length >= N``

    Returns:
        Tensor of shape (length, ) whose first N entries are ``data`` and
        whose remaining entries are NaN. A floating point dtype of ``data``
        is kept; any other dtype is converted to the default float dtype,
        since NaN cannot be stored otherwise.

    Raises:
        ValueError: if ``data`` is not 1-D or ``length`` is smaller than N.
    """
    if data.dim() != 1:
        raise ValueError("Expects 1D array, with shape (N, )")
    num_values = data.size(0)
    if length < num_values:
        raise ValueError("Padding length must be at least the data length")

    if data.is_floating_point():
        out_dtype = data.dtype
    else:
        out_dtype = torch.get_default_dtype()
    # Start from an all-NaN buffer so only the data prefix has to be written.
    padded = torch.full((length,), float("nan"), dtype=out_dtype)
    padded[:num_values] = data
    return padded
# Try the helper on the q = p_d metrics, padded to num_epochs + 1 entries.
padded = nan_pad(p_d_metrics, num_epochs+1)
# NOTE(review): this rebinds the global N, which the setup cell defines as the
# number of data samples, to the number of recorded metric values. Cells that
# use N and are run after this one pick up the new value; confirm this is
# intended.
N = p_d_metrics.size(0)


In [None]:
fig, ax = plt.subplots()
# The three runs may differ in length (e.g. when a stopping condition fires),
# so every curve is NaN-padded to the longest one. The index vector is derived
# from that common length instead of a fixed upper bound, which keeps every
# index inside the padded tensors.
curves = [
    (p_d_metrics, "$q=p_d$"),
    (p_t_metrics, "$q=p_{\\theta}$"),
    (adaptive_metrics, "$q=q_{\\varphi}$"),
]
num_steps = max(curve.size(0) for curve, _ in curves)
iters = torch.arange(start=0, end=num_steps, step=10)  # every 10th step
for curve, curve_label in curves:
    ax.loglog(iters, nan_pad(curve, num_steps)[iters], label=curve_label)
ax.legend();
# ax.set_xlim([0, 250])
ax.set_title("Choice of proposal distribution")
ax.set_xlabel("Iter. step $t$")
ax.set_ylabel("KL$(p_d || p_{\\theta})$");

In [None]:
from src.experiments.utils import process_plot_data
from src.experiments.adaptive_proposal import plot_kl_div
# Stack the three metric curves as columns; torch.column_stack needs them to
# have the same length.
# NOTE(review): four columns are read back from a three-column input, so
# process_plot_data presumably prepends an iteration column. Confirm in
# src.experiments.utils.
data = process_plot_data(torch.column_stack((p_d_metrics, p_t_metrics, adaptive_metrics)), num_epochs+1, res=1)

plot_kl_div(data[:, 0], data[:, 1], data[:, 2], data[:, 3])


In [None]:
from pathlib import Path
import numpy as np
from src.experiments.utils import generate_bounds, table_data, format_table, skip_list_item

end_iter = 6500  # keep only the first end_iter columns of each stored array
runs_root = Path.cwd().parent

# NOTE(review): p_d was the data distribution object in the setup cell; from
# here on it is a NumPy array of stored KL values.
p_d = torch.load(runs_root / "fig/20_runs_kl_raw_p_d.pth").numpy()[:, :end_iter]
p_t = torch.load(runs_root / "fig/20_runs_kl_raw_p_t.pth").numpy()[:, :end_iter]
q_f = torch.load(runs_root / "fig/20_runs_kl_raw_q_f.pth").numpy()[:, :end_iter]
num_runs = p_d.shape[0]  # size of the first axis; used in the output file names

# Three values per array, as returned by generate_bounds
p_d_md, p_d_lower, p_d_upper = generate_bounds(p_d)
p_t_md, p_t_lower, p_t_upper = generate_bounds(p_t)
q_f_md, q_f_lower, q_f_upper = generate_bounds(q_f)




In [None]:
_, ax = plt.subplots()
for data, name in [(p_d, "p_d"), (p_t, "p_t"), (q_f, "q_f")]:
    tmp = table_data(*generate_bounds(data))
    # Thin out every column with skip_list_item(nth=30) before plotting
    it, med, low, upp = (skip_list_item(column, nth=30) for column in tmp)

    # NOTE(review): matplotlib reads a two-row yerr as offsets below/above the
    # plotted value; confirm that table_data returns offsets and not bounds.
    ax.errorbar(it, med, [low, upp], label=f"$q={name}$")
ax.legend()
# ax.set_xlim([1, end_iter + 1])
ax.set_title("Choice of proposal distribution")
ax.set_xlabel("Iter. step $t$")
ax.set_ylabel("KL$(p_d || p_{\\theta})$")
# Log-log axes
ax.set_xscale("log")
ax.set_yscale("log")
plt.show()


In [None]:
# One text table per proposal, written to "<num_runs>_runs_<name>.txt" in the
# current working directory.
for data, name in [(p_d, "p_d"), (p_t, "p_t"), (q_f, "q_f")]:
    with open(f"{num_runs}_runs_{name}.txt", "w") as f:
        tmp = table_data(*generate_bounds(data))
        # Same thinning (nth=30) as used for the plot
        tmp = tuple(skip_list_item(column, nth=30) for column in tmp)
        tbl = format_table(*tmp, ["t","kl","low","upp"])
        f.writelines(tbl)

In [None]:
# Scratch cell: `data` is whatever the most recently executed cell left in that
# name (after the export loop above it is q_f). Thinning here uses nth=20.
tmp = table_data(*generate_bounds(data))
it, med, low, upp = tuple(map(lambda x: skip_list_item(x, nth=20), tmp))

In [None]:
cis_kl = np.load("../../ebms_proposals/cis_kl_store.npy")
is_kl = np.load("../../ebms_proposals/is_kl_store.npy")
# Drop row 10 of the IS results (hard-coded; judging by the variable name it
# is a row that contains NaN values)
is_kl_nan_filter = np.delete(is_kl, 10, axis=0)

_, ax = plt.subplots()
for data, name in [(is_kl_nan_filter, "IS"), (cis_kl, "CIS")]:
    it, med, low, upp = table_data(*generate_bounds(data))
    ax.errorbar(it, med, [low, upp], label=f"${name}$")
ax.legend()
# ax.set_xlim([1, end_iter + 1])
ax.set_title("Approximate KL div.")
ax.set_xlabel("Epoch")
ax.set_ylabel("KL$(p_d || p_{\\theta})$")
plt.show()


In [None]:
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from src.experiments.utils import generate_bounds, table_data, format_table, skip_list_item

# Result directory is selected by the number of runs and epochs.
runs, epoch = 5, 20
res_dir = Path(f"../../ebms_proposals/1dregression_1/results/{runs}_runs_{epoch}_epochs")
cis_kl = np.load(res_dir / "cis_kl_eval.npy")
is_kl = np.load(res_dir / "is_kl_eval.npy")
# NOTE(review): the P-CIS files are always read from "5_runs_20_epochs",
# regardless of `runs` / `epoch` above.
pcis_kl = np.load(res_dir.parent / "5_runs_20_epochs" / "pcis_kl_eval.npy")


cis_q_nll = np.load(res_dir / "cis_q_nll.npy")
is_q_nll = np.load(res_dir / "is_q_nll.npy")
pcis_q_nll = np.load(res_dir.parent / "5_runs_20_epochs" / "pcis_q_nll.npy")

# cis_kl = np.delete(cis_kl, (15), axis=0)
# Two panels side by side: KL on the left, NLL of q on the right.
_, (ax_kl, ax_nll) = plt.subplots(1, 2)
# P-CIS is loaded above but currently left out of both panels.
vals = [
    # (pcis_kl, "P-CIS"),
    (cis_kl, "CIS"),
    (is_kl, "IS")
]
for data, name in vals:
    it, med, low, upp = table_data(*generate_bounds(data))
    # NOTE(review): matplotlib reads a two-row yerr as offsets below/above the
    # plotted value; confirm that table_data returns offsets and not bounds.
    ax_kl.errorbar(it, med, [low, upp], label=f"${name}$")
    # mean, std = data.mean(axis=0), data.std(axis=0)
    # it = np.arange(1, len(mean)+1)
    # ax_kl.errorbar(it, mean, std, label=f"${name}$")
ax_kl.legend()
ax_kl.set_title("Approximate KL div.")
ax_kl.set_xlabel("Epoch")
ax_kl.set_ylabel("KL$(p_d || p_{\\theta})$")

vals = [
    # (pcis_q_nll, "P-CIS"),
    (cis_q_nll, "CIS"),
    (is_q_nll, "IS")
] 
# Same plot for the NLL arrays. After this loop `data` and `name` hold the last
# entry of `vals`, which the scratch cells below read.
for data, name in vals:
    it, med, low, upp = table_data(*generate_bounds(data))
    ax_nll.errorbar(it, med, [low, upp], label=f"${name}$")
    # mean, std = data.mean(axis=0), data.std(axis=0)
    # it = np.arange(1, len(mean)+1)
    # ax_nll.errorbar(it, mean, std, label=f"${name}$")
ax_nll.legend()
ax_nll.set_title("NLL: q")
ax_nll.set_xlabel("Epoch")
ax_nll.set_ylabel("NLL q")

plt.show()


In [None]:
# Index tuples of all NaN entries in cis_kl (empty result if there are none)
np.argwhere(np.isnan(cis_kl))

In [None]:
# Scratch cell: recompute the table for whatever `data` currently refers to
# and inspect the result. Only the last expression (`name`) is displayed by
# the notebook; `upp.shape` is evaluated and discarded.
it, med, low, upp = table_data(*generate_bounds(data))
upp.shape
name

In [None]:
import torch
def get_y_p(ys, xs, y_pers_dict):
    """Look up the persistent y value for every x, initialising missing ones.

    For each row ``i`` the scalar ``xs[i].item()`` is used as dictionary key.
    If the key maps to a value that is not None, that value is returned for
    the row; otherwise ``ys[i].item()`` is stored under the key and ``ys[i]``
    is returned.

    Args:
        ys: tensor whose rows are the fallback values.
        xs: tensor iterated along its first axis; each row must hold a
            single element.
        y_pers_dict: dict from x value to persistent y value; updated in
            place.

    Returns:
        New tensor with the size of ``ys`` (default dtype).
    """
    y_ps = torch.empty(ys.size())
    for row, x_row in enumerate(xs):
        key = x_row.item()
        stored = y_pers_dict.get(key)
        if stored is not None:
            y_ps[row] = stored
            continue
        # First time this x is seen: remember its current y
        y_pers_dict[key] = ys[row].item()
        y_ps[row] = ys[row]
    return y_ps

def update_y_pers(xs, y_p_J, w_tilde, y_pers_dict):
    """Resample the persistent y value of every x from its candidates.

    One column index per row is drawn from a categorical distribution with
    logits ``w_tilde``; the candidate at that column of ``y_p_J`` is stored in
    ``y_pers_dict`` under the key ``xs[i].item()``.

    Args:
        xs: tensor iterated along its first axis; each row must hold a
            single element.
        y_p_J: candidate values, indexed as ``y_p_J[row, column]``.
        w_tilde: logits of the categorical distribution, one row per x.
        y_pers_dict: dict from x value to persistent y value; updated in
            place.
    """
    picks = torch.distributions.categorical.Categorical(logits=w_tilde).sample()
    for row, x_row in enumerate(xs):
        chosen = picks[row]
        y_pers_dict[x_row.item()] = y_p_J[row, chosen].item()



# Small demo of the two helpers.
# NOTE(review): this rebinds the global J (number of negative samples in the
# setup cell) to 4.
B, J = 3, 4
xs = torch.arange(0, B).reshape((B, 1))
ys = torch.ones(B, 1)

# Empty store: every x is new, so each one gets its y from `ys`
y_pers_dict = {}
y_p = get_y_p(ys, xs, y_pers_dict)
print(y_pers_dict)

# Every row offers the candidates 0..J-1 with equal logits
y_p_J = torch.ones(B, J) * torch.arange(0, J)
w_tilde = torch.ones(B, J)
update_y_pers(xs, y_p_J, w_tilde, y_pers_dict)
y_pers_dict

In [None]:
# NOTE(review): in the visible code `smpl` is only assigned as a local inside
# update_y_pers, so this cell raises NameError unless the name is defined
# elsewhere in the session.
smpl