# Choice of proposal distribution

We investigate the effect of the proposal distribution when learning an unnormalised model.

In [None]:
import random
from pathlib import Path
from datasets import (
    ToyDataset,
)  # (this needs to be imported before torch, because cv2 needs to be imported before torch for some reason)
from ebmdn_model_K4 import ToyNet

import torch
import torch.utils.data
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import torch.distributions

import math
import numpy as np
import pickle

import matplotlib

# matplotlib.use("TkAgg")
# matplotlib.use("Agg")
import matplotlib.pyplot as plt

# NOTE! change this to not overwrite all log data when you train the model:
RESULTS_DIR = Path("1dregression_1/results/mdn_k4/")
EXPERIMENT_NAME = "mdn_k4"
# Model-id prefixes. The training cell passes CD as `model_id`, which ends up
# in the network name and in the checkpoint file names.
CD = "ebm_CD"
MDN = "ebm_MDN"

# Training settings. These first-cell values are re-assigned by later cells
# (NUM_EPOCHS becomes 12 and then 50) before the training loop reads them.
NUM_EPOCHS = 25
BATCH_SIZE = 32  # batch size handed to get_train_loader
# NOTE(review): LR is not read by the training loop in this notebook, which
# hard-codes lr=0.1 for SGD - confirm which value is intended.
LR = 0.001

NUM_SAMPLES = 32  # number of proposal samples drawn per training example
NUM_MODELS = 15


In [None]:
import sys
from pathlib import Path
from functools import partial
import math
import pickle
import numpy as np
import torch
import torch.utils.data
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import torch.distributions

from scipy.linalg import sqrtm
import matplotlib.pyplot as plt

sys.path.append("..")
sys.path.append("../../ebms_proposals/1dregression_1")
from datasets import (
    ToyDataset,
)  # (this needs to be imported before torch, because cv2 needs to be imported before torch for some reason)
from ebmdn_model_K4 import ToyNet
from src.nce.cnce import CondNceCrit
from src.nce.rank import NceRankCrit

from src.noise_distr.normal import MultivariateNormal
from src.models.gaussian_model import DiagGaussianModel

from src.training.model_training import train_model, train_model_model_proposal
from src.data.normal import MultivariateNormalData
from src.training.training_utils import Mse, MvnKlDiv, no_change_stopping_condition, no_stopping
%load_ext autoreload
%autoreload 2

# NOTE! change this to not overwrite all log data when you train the model:
# (this replaces the RESULTS_DIR from the first cell; it is the `project_dir`
# given to ToyNet in the training function)
RESULTS_DIR = Path("./results/")
EXPERIMENT_NAME = "mdn_k4"
# Model-id prefixes used when naming networks and checkpoints.
CD = "ebm_CD"
MDN = "ebm_MDN"

# Overridden again (to 50) in the "Common setup" cell before training starts.
NUM_EPOCHS = 12
BATCH_SIZE = 32  # batch size handed to get_train_loader
# NOTE(review): unused by the visible training loop, which hard-codes lr=0.1.
LR = 0.001

NUM_SAMPLES = 32  # number of proposal samples drawn per training example
NUM_MODELS = 15

@torch.no_grad()
def mvn_curve(mu, cov, std=1, res=100):
    """Return points on a level curve of a 2-D Gaussian.

    A unit circle sampled at ``res`` angles in [0, 2*pi] is scaled by ``std``,
    multiplied by the matrix square root of ``cov`` and shifted by ``mu``.
    The whole computation runs with gradient tracking disabled.

    Args:
        mu: centre of the curve; added to every point.
        cov: 2x2 covariance matrix, handed to ``scipy.linalg.sqrtm``.
        std: radius of the curve in standard deviations.
        res: number of points along the curve.

    Returns:
        Tensor with ``res`` rows and 2 columns.
    """
    theta = torch.linspace(0, 2 * torch.pi, res)
    unit_circle = torch.stack((torch.cos(theta), torch.sin(theta)), dim=1)
    cov_root = torch.Tensor(sqrtm(cov))
    return mu + (std * unit_circle) @ cov_root
    
def plot_mvn(levels, ax, label):
    """Draw the curve ``levels`` on ``ax`` under the legend entry ``label``.

    Column 0 of ``levels`` is used as the x coordinates and column 1 as the
    y coordinates.
    """
    xs, ys = levels[:, 0], levels[:, 1]
    ax.plot(xs, ys, label=label)

def plot_distrs_ideal(p_d, p_t_d, p_t_t):
    """Plot Gaussian contours for the data distribution and two models.

    Each distribution is drawn as the curve returned by ``mvn_curve`` with its
    default settings, on axes fixed to [-3, 10] in both directions.

    Args:
        p_d: data distribution; read through its ``mu`` and ``cov`` attributes.
        p_t_d: model labelled "q=p_d" in the legend (presumably the model
            trained with the data distribution as proposal - confirm with the
            caller); read through ``mu`` and the ``cov()`` method.
        p_t_t: model labelled "q = p_theta" in the legend; read through ``mu``
            and the ``cov()`` method.
    """
    _, ax = plt.subplots()
    ax.set_xlim([-3, 10])
    ax.set_ylim([-3, 10])
    distrs = [
        # The label used to be "$p_{d}}$"; the stray closing brace is not
        # valid mathtext and breaks rendering of the legend.
        (p_d.mu, p_d.cov, "$p_{d}$"),
        (p_t_d.mu, p_t_d.cov(), "$q=p_d$"),
        (p_t_t.mu, p_t_t.cov(), "$q = p_{\\theta}$"),
    ]
    for mu, cov, label in distrs:
        plot_mvn(mvn_curve(mu, cov), ax, label)
    ax.set_title("Comparison, optimal proposal distrs.")
    ax.legend()

def plot_distrs_adaptive(p_d, p_theta, q_phi):
    """Plot Gaussian contours for the data, the model and the learned proposal.

    Each distribution is drawn as the curve returned by ``mvn_curve`` with its
    default settings, on axes fixed to [-3, 10] in both directions.

    Args:
        p_d: data distribution; read through its ``mu`` and ``cov`` attributes.
        p_theta: model distribution; read through ``mu`` and the ``cov()``
            method.
        q_phi: proposal distribution; read through ``mu`` and the ``cov()``
            method.
    """
    _, ax = plt.subplots()
    ax.set_xlim([-3, 10])
    ax.set_ylim([-3, 10])
    distrs = [
        # The label used to be "$p_{d}}$"; the stray closing brace is not
        # valid mathtext and breaks rendering of the legend.
        (p_d.mu, p_d.cov, "$p_{d}$"),
        (p_theta.mu, p_theta.cov(), "$p_{\\theta}$"),
        (q_phi.mu, q_phi.cov(), "$q_{\\varphi}$"),
    ]
    for mu, cov, label in distrs:
        plot_mvn(mvn_curve(mu, cov), ax, label)
    ax.set_title("Adaptive proposal")
    ax.legend()
    
# Locate the experiment folder relative to the notebook's working directory:
# the parent of the current directory is treated as the repository root.
repo_dir = Path.cwd().parent
exp_dir = repo_dir / "1d_toy"
# Bare expression: evaluates whether the folder exists; the result is not
# stored or acted on here.
exp_dir.exists()

In [None]:
from typing import Optional

import scipy.stats

# import pickle

################################################################################
# run this once to generate the training data:
################################################################################

def generate_data(data_dir: Optional[Path]):
    """Sample the 1-D toy regression training set.

    2000 inputs are drawn uniformly from [-3, 3). For each input the target is
    drawn as follows:

    * ``x < 0``: a Gaussian with standard deviation 0.075 whose mean is
      ``sin(x)`` with probability 0.8 (components 1-4 of a uniform draw over
      1..5) and ``-sin(x)`` with probability 0.2 (component 5).
    * otherwise: a log-normal draw (mean 0, sigma 0.25 in log space)
      shifted down by 1.

    Args:
        data_dir: directory that receives ``x.pkl`` and ``y.pkl``; pass
            ``None`` to skip writing.

    Returns:
        ``(x, y)`` as float32 NumPy arrays of length 2000.
    """
    noise_std = 0.15 * (1.0 / (1 + 1))

    def _draw_target(x_value):
        # One target per input; the order of the random draws matters for
        # reproducibility under a fixed NumPy seed.
        if not x_value < 0:
            return np.random.lognormal(0.0, 0.25) - 1.0
        component = np.random.randint(low=1, high=6)  # uniform over 1..5
        mean = -np.sin(x_value) if component == 5 else np.sin(x_value)
        return np.random.normal(mean, noise_std)

    x = np.random.uniform(low=-3.0, high=3.0, size=(2000,)).astype(np.float32)
    y = np.array([_draw_target(x_value) for x_value in x], dtype=np.float32)

    if data_dir is not None:
        for filename, array in (("x.pkl", x), ("y.pkl", y)):
            with open(data_dir / filename, "wb") as file:
                pickle.dump(array, file)
    return x, y
        
def generate_scores(data_dir: Optional[Path]):
    """Tabulate the ground-truth conditional density on a regular grid.

    For each of 2048 float32 ``x`` values evenly spaced over [-3, 3], the
    density is evaluated at 2048 ``y`` values evenly spaced over [-3, 3]:

    * ``x < 0``: 0.8 * N(sin(x), 0.075) + 0.2 * N(-sin(x), 0.075)
    * otherwise: log-normal pdf (shape 0.25) evaluated at ``y + 1``

    Args:
        data_dir: directory that receives ``gt_x_values_2_scores.pkl``; pass
            ``None`` to skip writing.

    Returns:
        ``(x, x_values_2_scores)`` where the dict maps every ``x`` value to
        its array of densities over the ``y`` grid.
    """
    num_samples = 2048
    sigma = 0.15 * (1.0 / (1 + 1))
    x = np.linspace(-3.0, 3.0, num_samples, dtype=np.float32)
    y_samples = np.linspace(-3.0, 3.0, num_samples)  # (shape: (num_samples, ))

    def _density(x_value):
        if x_value < 0:
            main_mode = scipy.stats.norm.pdf(y_samples, np.sin(x_value), sigma)
            mirrored_mode = scipy.stats.norm.pdf(y_samples, -np.sin(x_value), sigma)
            return 0.8 * main_mode + 0.2 * mirrored_mode
        return scipy.stats.lognorm.pdf(y_samples + 1.0, 0.25)

    x_values_2_scores = {x_value: _density(x_value) for x_value in x}

    if data_dir is not None:
        with open(data_dir / "gt_x_values_2_scores.pkl", "wb") as file:
            pickle.dump(x_values_2_scores, file)
    return x, x_values_2_scores


def save_checkpoint(network, model_id, model_idx, epoch):
    """Write the network's state dict into its checkpoint directory.

    The file is named ``model_{model_id}_{model_idx}_epoch_{epoch+1}.pth``,
    so the zero-based ``epoch`` is stored one-based in the file name. The
    directory is taken from ``network.checkpoints_dir``.
    """
    filename = f"model_{model_id}_{model_idx}_epoch_{epoch+1}.pth"
    torch.save(network.state_dict(), network.checkpoints_dir / filename)


In [None]:
from src.data.generic import Generic
# Draw a fresh training set in memory (data_dir=None skips writing pickles).
x, y = generate_data(None)
x, y = torch.tensor(x), torch.tensor(y)
# Ground-truth density table on a grid; only referenced by the commented-out
# plot below.
x_range, scores = generate_scores(None)
plt.plot(x, y, 'k.')
# plt.plot(x_range, scores, 'k.')

# Pair each x with its y as one row of a two-column tensor.
# NOTE(review): `Generic` presumably wraps the tensor as a dataset - confirm
# in src.data.generic.
training_data = Generic(torch.column_stack((x,y)))

# Common setup

In [None]:
from torch.optim.lr_scheduler import LinearLR
# Overrides the NUM_EPOCHS values from the earlier setup cells; this is the
# value the training function reads for its epoch loop and loss buffer.
NUM_EPOCHS = 50

def get_train_loader(batch_size):
    """Build a shuffling ``DataLoader`` over a freshly constructed ``ToyDataset``.

    Also prints ``len(dataset) / batch_size`` truncated to an integer as the
    number of training batches.

    Args:
        batch_size: number of examples per batch.

    Returns:
        A ``torch.utils.data.DataLoader`` with ``shuffle=True``.
    """
    dataset = ToyDataset()
    print("num_train_batches:", int(len(dataset) / batch_size))
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=True,
    )
    return loader

def save_checkpoint(network, model_id, model_idx, epoch):
    """Persist ``network.state_dict()`` under ``network.checkpoints_dir``.

    Args:
        network: module exposing ``state_dict()`` and a ``checkpoints_dir``
            path attribute.
        model_id: model-id prefix used in the file name.
        model_idx: model index used in the file name.
        epoch: zero-based epoch; the file name uses ``epoch + 1``.
    """
    target = network.checkpoints_dir / (
        f"model_{model_id}_{model_idx}_epoch_{epoch+1}.pth"
    )
    weights = network.state_dict()
    torch.save(weights, target)

def train_model_cd_obj(train_loader, model_id, model_idx):
    """Train a ``ToyNet`` energy model together with its mixture proposal.

    Every batch contributes two terms to one loss:

    * ``loss_ebm_nce`` - a ranking-style NCE loss in which the observed ``y``
      competes against ``NUM_SAMPLES`` draws from the proposal ``q``. The
      proposal densities and samples are detached, so this term sends no
      gradient into the proposal network.
    * ``loss_mdn_kl`` - the negative log-density of ``q`` at the observed and
      sampled points, weighted by self-normalised importance weights that are
      computed from detached scores and densities. The proposal network is
      fed detached features, so this term does not update the feature net.

    Args:
        train_loader: iterable yielding ``(xs, ys)`` batches; both are moved
            to the GPU and given a trailing dimension of size 1.
        model_id: prefix used in the network name and checkpoint file names.
        model_idx: index appended to ``model_id``.

    Returns:
        Tensor of length ``NUM_EPOCHS`` with the mean batch loss per epoch.

    Side effects: prints the loss and saves a checkpoint after every epoch.
    Requires CUDA, because the network and the batches are moved with
    ``.cuda()``.
    """
    network = ToyNet(f"{model_id}_{model_idx}", project_dir=RESULTS_DIR).cuda()
    # One optimizer for the joint loss. The previous code created two
    # identical SGD optimizers over the same parameters (`p_optimizer`,
    # `q_optimizer`) but then stepped an undefined `optimizer`, which raised
    # NameError on the first batch.
    optimizer = torch.optim.SGD(network.parameters(), lr=0.1)

    epoch_losses_train = torch.empty((NUM_EPOCHS,))
    for epoch in range(NUM_EPOCHS):

        network.train()  # (set in training mode, this affects BatchNorm and dropout)
        batch_losses = []
        for xs, ys in train_loader:
            xs = xs.cuda().unsqueeze(1)  # (shape: (batch_size, 1))
            ys = ys.cuda().unsqueeze(1)  # (shape: (batch_size, 1))

            x_features = network.feature_net(xs)  # (shape: (batch_size, hidden_dim))
            # Detached input: the proposal does not back-propagate into the
            # shared feature extractor.
            means, log_sigma2s, weights = network.noise_net(
                x_features.detach()
            )  # (all have shape: (batch_size, K))
            sigmas = torch.exp(log_sigma2s / 2.0)  # (shape: (batch_size, K))
            q_distr = torch.distributions.normal.Normal(loc=means, scale=sigmas)

            # Mixture density q(y | x) at the observed targets.
            q_ys_K = torch.exp(
                q_distr.log_prob(torch.transpose(ys, 1, 0).unsqueeze(2))
            )  # (shape: (1, batch_size, K))
            q_ys = torch.sum(
                weights.unsqueeze(0) * q_ys_K, dim=2
            )  # (shape: (1, batch_size))
            q_ys = q_ys.squeeze(0)  # (shape: (batch_size))

            # Sample from the mixture: draw from every component, then pick
            # one component per sample according to the mixture weights.
            y_samples_K = q_distr.sample(
                sample_shape=torch.Size([NUM_SAMPLES])
            )  # (shape: (num_samples, batch_size, K))
            inds = torch.multinomial(
                weights, num_samples=NUM_SAMPLES, replacement=True
            ).unsqueeze(
                2
            )  # (shape: (batch_size, num_samples, 1))
            inds = torch.transpose(inds, 1, 0)  # (shape: (num_samples, batch_size, 1))
            y_samples = y_samples_K.gather(2, inds).squeeze(
                2
            )  # (shape: (num_samples, batch_size))
            y_samples = y_samples.detach()

            # Mixture density q(y | x) at the sampled points.
            q_y_samples_K = torch.exp(
                q_distr.log_prob(y_samples.unsqueeze(2))
            )  # (shape: (num_samples, batch_size, K))
            q_y_samples = torch.sum(
                weights.unsqueeze(0) * q_y_samples_K, dim=2
            )  # (shape: (num_samples, batch_size))
            y_samples = torch.transpose(
                y_samples, 1, 0
            )  # (shape: (batch_size, num_samples))
            q_y_samples = torch.transpose(
                q_y_samples, 1, 0
            )  # (shape: (batch_size, num_samples))

            scores_gt = network.predictor_net(
                x_features, ys
            )  # (shape: (batch_size, 1))
            scores_gt = scores_gt.squeeze(1)  # (shape: (batch_size))

            scores_samples = network.predictor_net(
                x_features, y_samples
            )  # (shape: (batch_size, num_samples))

            # EBM loss: log-ratio f - log q for the observed target (column 0)
            # and for every proposal sample (columns 1..J).
            p_N_samples = q_y_samples.detach()
            p_N_0 = q_ys.detach()
            exp_vals_0 = scores_gt - torch.log(p_N_0)
            exp_vals_samples = scores_samples - torch.log(p_N_samples)
            exp_vals = torch.cat([exp_vals_0.unsqueeze(1), exp_vals_samples], dim=1)
            loss_ebm_nce = -torch.mean(exp_vals_0 - torch.logsumexp(exp_vals, dim=1))

            # Proposal loss. The self-normalised weights are
            # w_j = (exp(f_j) / q_j) / sum_k (exp(f_k) / q_k), which is the
            # softmax of the detached log-ratios. Using softmax avoids the
            # overflow that exponentiating raw scores can cause.
            w_norm = torch.softmax(exp_vals.detach(), dim=1)
            # Assemble log q(x_0:J); this is the only part that carries
            # gradient into the proposal network.
            log_qs = torch.log(torch.column_stack((q_ys, q_y_samples)))
            loss_mdn_kl = -(w_norm * log_qs).sum(dim=1).mean()

            loss = loss_ebm_nce + loss_mdn_kl
            batch_losses.append(loss.item())

            ########################################################################
            # optimization step:
            ########################################################################
            optimizer.zero_grad()  # (reset gradients)
            loss.backward()  # (compute gradients)
            optimizer.step()  # (perform optimization step)
        epoch_loss = torch.tensor(batch_losses).mean().item()
        epoch_losses_train[epoch] = epoch_loss
        print(f"Train loss: {epoch_loss}")
        save_checkpoint(network, model_id, model_idx, epoch)

    return epoch_losses_train
idx = 0
# Train one model with the CD model id using batches of BATCH_SIZE.
# NOTE(review): the result is stored as `epoch_losses_mdn` although the run
# uses the CD id - confirm the intended name.
epoch_losses_mdn = train_model_cd_obj(get_train_loader(BATCH_SIZE), CD, idx)

In [None]:
from shutil import copyfile
# Collect the per-model training-loss pickles of the sibling project into a
# local "losses" folder, renamed to cd_<i>.pkl / mdn_<i>.pkl.
src_dir = Path.cwd().parent.parent / "ebms_proposals/1dregression_1/results/mdn_k4/training_logs/"
dst_dir = Path.cwd() / "losses"
# copyfile does not create missing parent directories, so make sure the
# destination exists before copying.
dst_dir.mkdir(parents=True, exist_ok=True)

for i in range(15):
    for model_id, prefix in (("ebm_CD", "cd"), ("ebm_MDN", "mdn")):
        src = src_dir / f"model_{model_id}_{i}/epoch_losses_train.pkl"
        copyfile(src, dst_dir / f"{prefix}_{i}.pkl")


In [None]:
import torch
from pathlib import Path
import matplotlib
import numpy as np
import matplotlib.pyplot as plt
matplotlib.use("TkAgg")

# Compare the IS and CIS runs: mean loss / NLL per epoch with error bars.
dir_ = Path.cwd().parent.parent / "ebms_proposals/1dregression_1/results"

runs, num_epochs = 5, 30
run_dir = dir_ / f"{runs}_runs_{num_epochs}_epochs"
cis_p_loss = np.load(run_dir / "cis_p_loss.npy")
cis_q_loss = np.load(run_dir / "cis_q_loss.npy")
cis_q_nll = np.load(run_dir / "cis_q_nll.npy")
cis_p_nll = np.load(run_dir / "cis_p_nll.npy")

is_p_loss = np.load(run_dir / "is_p_loss.npy")
is_q_loss = np.load(run_dir / "is_q_loss.npy")
is_q_nll = np.load(run_dir / "is_q_nll.npy")
is_p_nll = np.load(run_dir / "is_p_nll.npy")

#is_ = np.load(dir_ / f"{num_epochs}_epochs/is_losses.npy")

# Statistics are taken over axis 0, i.e. across runs, leaving one value per
# epoch (assumes each array is shaped (runs, num_epochs), as the file naming
# and the epoch x-axis imply).
epochs = np.arange(1, num_epochs+1)
is_p_loss_mean, is_p_loss_std = np.mean(is_p_loss, axis=0), np.std(is_p_loss, axis=0)
is_q_loss_mean, is_q_loss_std = np.mean(is_q_loss, axis=0), np.std(is_q_loss, axis=0)
is_q_nll_mean, is_q_nll_std = np.mean(is_q_nll, axis=0), np.std(is_q_nll, axis=0)
is_p_nll_mean, is_p_nll_std = np.mean(is_p_nll, axis=0), np.std(is_p_nll, axis=0)


cis_p_loss_mean, cis_p_loss_std = np.mean(cis_p_loss, axis=0), np.std(cis_p_loss, axis=0)
cis_q_loss_mean, cis_q_loss_std = np.mean(cis_q_loss, axis=0), np.std(cis_q_loss, axis=0)
cis_q_nll_mean, cis_q_nll_std = np.mean(cis_q_nll, axis=0), np.std(cis_q_nll, axis=0)
cis_p_nll_mean, cis_p_nll_std = np.mean(cis_p_nll, axis=0), np.std(cis_p_nll, axis=0)

# Standard error of the mean: the averages are over runs, so the std must be
# divided by sqrt(number of runs). It was previously divided by
# sqrt(num_epochs), which made the error bars too small.
sem_denominator = np.sqrt(runs)

_, (ax_loss, ax_nll) = plt.subplots(1, 2)
ax_loss.errorbar(epochs, is_p_loss_mean, is_p_loss_std / sem_denominator, label="Loss $p_{IS}$")
ax_loss.errorbar(epochs, is_q_loss_mean, is_q_loss_std / sem_denominator, label="Loss $q_{IS}$")
ax_loss.errorbar(epochs, cis_p_loss_mean, cis_p_loss_std / sem_denominator, label="Loss $p_{CIS}$")
ax_loss.errorbar(epochs, cis_q_loss_mean, cis_q_loss_std / sem_denominator, label="Loss $q_{CIS}$")
ax_loss.set_ylabel("Loss")
ax_loss.set_xlabel("Epoch")
ax_loss.legend()

ax_nll.errorbar(epochs, is_q_nll_mean, is_q_nll_std / sem_denominator, label="NLL $q_{IS}$")
ax_nll.errorbar(epochs, cis_q_nll_mean, cis_q_nll_std / sem_denominator, label="NLL $q_{CIS}$")
#ax_nll.errorbar(epochs, is_p_nll_mean, is_p_nll_std / sem_denominator, label="NLL $p_{IS}$")
#ax_nll.errorbar(epochs, cis_p_nll_mean, cis_p_nll_std / sem_denominator, label="NLL $p_{CIS}$")
ax_nll.set_ylabel("NLL")
ax_nll.set_xlabel("Epoch")
ax_nll.legend()
plt.show()

In [None]:
# Notebook display: show the raw `is_q_nll` array loaded from is_q_nll.npy.
is_q_nll