In [None]:
import sys
import os
import torch
import numpy as np
import matplotlib.pyplot as plt

sys.path.append("..")
from src.nce.cd_cnce import CdCnceCrit
from src.nce.cd_mh import CdMHCrit

from src.noise_distr.conditional_normal import ConditionalMultivariateNormal

from src.models.ring_model.ring_model import RingModel, RingModelNCE, unnorm_ring_model_log_pdf
from src.data.ring_model_dataset import RingModelDataset

from src.training.model_training import train_model
from src.training.training_utils import PrecisionErrorMetric, no_stopping, remove_file

from src.experiments.ring_model_exp_utils import generate_true_params, initialise_params
from src.experiments.noise_distr_utils import get_nce_noise_distr_par, get_cnce_noise_distr_par

%load_ext autoreload
%autoreload 2

## EXPERIMENTS

In [None]:
# Data specs
num_dims = 2

# Experiments specs
num_samples = 200
num_neg_samples = [5, 10]  # values of J iterated over in the experiment loop
reps = 100

# Training specs
batch_size = 20
num_epochs = 20

# Reproducibility: seed the global NumPy and PyTorch generators so that a
# "Restart & Run All" reproduces the same numbers.
# NOTE(review): assumes the project helpers draw from these global generators
# rather than from private ones -- confirm.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Configurations to consider in experiments

# Settings that are identical for both criteria
_shared_settings = {
    "estimate_part_fn": False,
    "conditional_noise_distr": True,
    "mcmc_steps": 1,
}

# Each configuration gets its own dict (the spread copies the shared entries)
config_conditional_multi = {"criterion": CdCnceCrit, "label": "cd_cnce", **_shared_settings}
config_cd_mh = {"criterion": CdMHCrit, "label": "cd_mh", **_shared_settings}

configs = [config_conditional_multi, config_cd_mh]

In [None]:
# Run experiments

# Length of the per-run iteration axis: batches per epoch times number of epochs
num_iters = int(np.ceil(num_samples / batch_size) * num_epochs)

# Data saved over reps (last axis indexes the repetition)
error_res = np.zeros((len(num_neg_samples), len(configs), num_iters, reps))
acc_prob_res = np.zeros((len(num_neg_samples), len(configs), len(configs), num_iters, reps))

for i, J in enumerate(num_neg_samples):
    for rep in range(reps):

        # Ground-truth parameters and the error metric built from the true precision
        mu, precision, _ = generate_true_params()
        error_metric = PrecisionErrorMetric(true_precision=precision).metric

        # Training data for this (J, rep) combination
        data_dir = "res/datasets/ring_data_size_" + str(num_samples) + "_nn_" + str(J) + "_rep_" + str(rep)
        training_data = RingModelDataset(sample_size=num_samples, num_dims=num_dims, mu=mu.numpy(),
                                         precision=precision.numpy(), root_dir=data_dir)
        train_loader = torch.utils.data.DataLoader(training_data, batch_size=batch_size, shuffle=True)

        # Initial parameters, cloned below so every configuration starts from the same values
        _, log_precision_init, log_z_init = initialise_params()

        # Noise distribution covariance, computed from the full data set and the initial model
        p_m = RingModel(mu=mu, log_precision=log_precision_init.clone())
        cov_noise_cnce = get_cnce_noise_distr_par(training_data.get_full_data(), J, p_m)

        for j, config in enumerate(configs):
            # Common prefix of every file belonging to this configuration and J
            save_dir_pre = "res/" + config["label"] + "_num_neg_" + str(J)

            # Remove old acc. prob.
            for stale_suffix in ("_cd_cnce_acc_prob.npy", "_cd_mh_acc_prob.npy"):
                remove_file(save_dir_pre + stale_suffix)

            # Drop the objects from the previous pass before rebuilding them
            p_m, p_n, criterion = None, None, None

            p_m = (
                RingModelNCE(mu=mu, log_precision=log_precision_init.clone(), log_part_fn=log_z_init.clone())
                if config["estimate_part_fn"]
                else RingModel(mu=mu, log_precision=log_precision_init.clone())
            )
            p_n = ConditionalMultivariateNormal(cov=cov_noise_cnce)
            criterion = config["criterion"](p_m, p_n, J, config["mcmc_steps"], save_acc_prob=True)

            _ = train_model(criterion, error_metric, train_loader, save_dir_pre + "_error",
                            num_epochs=num_epochs, stopping_condition=no_stopping)

            # Fetch data that has been saved
            error_res[i, j, :, rep] = np.load(save_dir_pre + "_error.npy")
            for k in (0, 1):
                acc_prob_file = save_dir_pre + "_" + configs[k]["label"] + "_acc_prob.npy"
                acc_prob_res[i, j, k, :, rep] = np.load(acc_prob_file).mean(axis=0)

# Save res
np.save("res/final_param_error_cnce_acceptance_prob_exp", error_res)
np.save("res/final_acc_prob_cnce_acceptance_prob_exp", acc_prob_res)

In [None]:
# Visualise results
error_res = np.load("res/final_param_error_cnce_acceptance_prob_exp.npy")
acc_prob_res = np.load("res/final_acc_prob_cnce_acceptance_prob_exp.npy")

# One row per number of negative samples. Column 0 shows the parameter error,
# followed by one acceptance-probability panel per training configuration.
# squeeze=False keeps `ax` two-dimensional even when there is only one row.
fig, ax = plt.subplots(len(num_neg_samples), 1 + len(configs), figsize=(16, 15), squeeze=False)
colors = ['C0', 'C1']
assert len(colors) == len(configs), "Need one colour for each method"

for i, num_neg in enumerate(num_neg_samples):

    # Column 0: log error per iteration, averaged over the repetitions (last axis)
    for j, config in enumerate(configs):
        ax[i, 0].plot(np.log(error_res[i, j, :, :]).mean(axis=-1), color=colors[j], label=config["label"])
    ax[i, 0].set_xlabel("Iter.")
    ax[i, 0].set_ylabel("Log(SE)")
    ax[i, 0].legend()

    # Columns 1..: acceptance probabilities stored while training with configs[j].
    # Index k selects which of the saved acceptance-probability files is shown
    # (presumably the CNCE and MH acceptance probabilities -- confirm against the criteria).
    for j, config in enumerate(configs):
        panel = ax[i, j + 1]
        for k, config_2 in enumerate(configs):
            panel.plot(acc_prob_res[i, j, k, :, :].mean(axis=-1), color=colors[k], label=config_2["label"])
        panel.set_xlabel("Iter.")
        panel.set_ylabel("Acc. prob., " + config["label"])
        # The curves are labelled, so show the legend to tell them apart
        panel.legend()

    ax[i, 1].set_title("Num neg. samples: {}".format(num_neg))

plt.show()