In [None]:
import sys
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import tikzplotlib

sys.path.append("..")
from src.nce.cd_cnce import CdCnceCrit
from src.nce.cd_mh_cnce import CdMHCnceCrit
from src.nce.per_cnce import PersistentCondNceCrit
from src.nce.per_cnce_batch import PersistentCondNceCritBatch
from src.nce.per_cnce_batch import PersistentCondNceCritBatch
from src.nce.per_mh_cnce_batch import PersistentMHCnceCritBatch

from src.noise_distr.conditional_normal import ConditionalMultivariateNormal

from src.models.ring_model.ring_model import RingModel, RingModelNCE, unnorm_ring_model_log_pdf
from src.data.ring_model_dataset import RingModelDataset

from src.training.model_training import train_model
from src.training.training_utils import PrecisionErrorMetric, no_stopping, remove_file

from src.experiments.ring_model_exp_utils import generate_true_params, initialise_params
from src.experiments.noise_distr_utils import get_nce_noise_distr_par, get_cnce_noise_distr_par

%load_ext autoreload
%autoreload 2

## EXPERIMENTS

In [None]:
# Data specs
num_dims = 5  # passed as num_dims to RingModelDataset in the experiment loop

# Experiments specs
num_samples = 1000  # passed as sample_size to RingModelDataset (one dataset per repetition)
num_neg_samples = [5, 10]  # values of J swept over; J is handed to each criterion
reps = 100  # number of repetitions per value of J

# Training specs
batch_size = 20
num_epochs = 50
base_lr = 0.01
lr = base_lr * batch_size ** 0.5  # base learning rate scaled by sqrt(batch_size)
lr_factor = 0.1  # forwarded to train_model together with decaying_lr=True

In [None]:
# Configurations to consider in experiments
#
# Keys read by the experiment loop:
#   "criterion"         - criterion class, called as criterion(p_m, p_n, J, [mcmc_steps,] save_acc_prob=...)
#   "label"             - prefix of the intermediate files written under res/
#   "estimate_part_fn"  - True selects RingModelNCE, False selects RingModel
#   "mcmc_steps"        - passed positionally to the criterion unless it is None
#   "calc_acc_prob"     - forwarded as save_acc_prob and gates loading of the acc. prob. files
# "conditional_noise_distr" is not read by any code in this notebook.

config_cnce = {
    "criterion": CdCnceCrit,
    "label": "cd_cnce",
    "estimate_part_fn": False,
    "conditional_noise_distr": True,
    "mcmc_steps": 1,
    "calc_acc_prob": True
}

config_cd_mh = {
    "criterion": CdMHCnceCrit,
    "label": "cd_mh",
    "estimate_part_fn": False,
    "conditional_noise_distr": True,
    "mcmc_steps": 1,
    "calc_acc_prob": True
}

# Not part of `configs` below. NOTE(review): shares its label with
# config_pers_cnce_batch, so the two would write to the same res/ files if
# both were ever run in the same sweep.
config_pers_cnce = {
    "criterion": PersistentCondNceCrit,
    "label": "pers_cnce",
    "estimate_part_fn": False,
    "conditional_noise_distr": True,
    "mcmc_steps": None,
    # The experiment loop reads this key unconditionally; without it, adding
    # this config to `configs` raises KeyError.
    # NOTE(review): assumes PersistentCondNceCrit accepts save_acc_prob - TODO confirm.
    "calc_acc_prob": False
}

config_pers_cnce_batch = {
    "criterion": PersistentCondNceCritBatch,
    "label": "pers_cnce",
    "estimate_part_fn": False,
    "conditional_noise_distr": True,
    "mcmc_steps": None,
    "calc_acc_prob": True
}

config_pers_mh_cnce_batch = {
    "criterion": PersistentMHCnceCritBatch,
    "label": "pers_mh_cnce",
    "estimate_part_fn": False,
    "conditional_noise_distr": True,
    "mcmc_steps": None,
    "calc_acc_prob": True
}

# `labels` holds the display names used in the plots, in the same order as
# `configs`. The plotting cell relies on the first two entries being the
# (P-)CNCE methods and the last two the (P-)MH-CNCE methods.
configs = [config_cnce, config_pers_cnce_batch, config_cd_mh, config_pers_mh_cnce_batch]
labels = ['CNCE', 'P-CNCE', 'MH-CNCE', 'P-MH-CNCE']
assert len(labels) == len(configs), "Need one display label for each config"

In [None]:
def get_simple_cnce_noise_distr_par(y):
    """Build an isotropic covariance matrix for the conditional noise distribution.

    The standard deviation of `y` is taken along its last dimension, the
    resulting values are averaged into a single scalar `epsilon`, and the
    function returns `epsilon ** 2` times the identity matrix whose size is
    the length of the last dimension of `y`.
    """
    last_dim_size = y.shape[-1]
    std_along_last_dim = torch.std(y, dim=-1)
    epsilon = std_along_last_dim.mean()
    identity = torch.eye(last_dim_size)

    return identity * epsilon ** 2

In [None]:
# Run experiments
#
# For every number of negative samples J and every repetition, a fresh set of
# true parameters and a dataset are created, and every config in `configs` is
# trained from the same initial parameters and the same noise covariance.
# Per-run curves are read back from the .npy files under res/ and collected
# into the arrays below.

# Data saved over reps
# error_res axes:    (index into num_neg_samples, config index, iteration, rep)
# acc_prob_res axes: (index into num_neg_samples, config index, acc. prob. file [0: "cd_cnce", 1: "cd_mh"], iteration, rep)
# The iteration axis has ceil(num_samples / batch_size) * num_epochs entries.
error_res = np.zeros((len(num_neg_samples), len(configs), int(np.ceil(num_samples / batch_size) * num_epochs), reps))
acc_prob_res = np.zeros((len(num_neg_samples), len(configs), 2, int(np.ceil(num_samples / batch_size) * num_epochs), reps))

for i, J in enumerate(num_neg_samples):
    for rep in range(reps):

        # Get data: new true parameters and a new dataset for every (J, rep) pair
        mu, precision, _ = generate_true_params()
        error_metric = PrecisionErrorMetric(true_precision=precision).metric            

        training_data = RingModelDataset(sample_size=num_samples, num_dims=num_dims, mu=mu, precision=precision, 
                                         root_dir="res/datasets/ring_data_size_" + str(num_samples) + "_nn_" + str(J) + "_rep_" + str(rep))
        train_loader = torch.utils.data.DataLoader(training_data, batch_size=batch_size, shuffle=True)

        # Initialise: the same initial values are cloned for every config below
        _, log_precision_init, log_z_init = initialise_params()

        # Get noise distr. params
        # p_m is only needed by the alternative
        # get_cnce_noise_distr_par(training_data.get_full_data(), J, p_m), which is
        # currently not used; it is re-created inside the config loop.
        p_m = RingModel(mu=mu, log_precision=log_precision_init.clone())
        cov_noise_cnce = get_simple_cnce_noise_distr_par(training_data.get_full_data())

        for j, config in enumerate(configs):

            # Remove old acc. prob. files so that the values loaded after training
            # do not come from an earlier run with the same label and J
            remove_file("res/" + config["label"] + "_num_neg_" + str(J) + "_cd_cnce_acc_prob.npy")
            remove_file("res/" + config["label"] + "_num_neg_" + str(J) + "_cd_mh_acc_prob.npy")

            # Make sure that these are "reinitialised"
            p_m, p_n, criterion = None, None, None

            # Model with or without an explicit log partition function parameter
            if config["estimate_part_fn"]:
                p_m = RingModelNCE(mu=mu, log_precision=log_precision_init.clone(), log_part_fn=log_z_init.clone())
            else:
                p_m = RingModel(mu=mu, log_precision=log_precision_init.clone())

            p_n = ConditionalMultivariateNormal(cov=cov_noise_cnce)
            
            # Criteria configured with mcmc_steps receive it as an extra positional argument
            if config["mcmc_steps"] is not None: 
                criterion = config["criterion"](p_m, p_n, J, config["mcmc_steps"], save_acc_prob=config["calc_acc_prob"])
            else:
                criterion = config["criterion"](p_m, p_n, J, save_acc_prob=config["calc_acc_prob"])

            # Files for this run share the prefix res/<label>_num_neg_<J>
            save_dir_pre = "res/" + config["label"] + "_num_neg_" + str(num_neg_samples[i])
            _ = train_model(criterion, error_metric, train_loader, save_dir_pre + "_error", num_epochs=num_epochs,
                            stopping_condition=no_stopping, lr=lr, decaying_lr=True, lr_factor=lr_factor)

            # Fetch data that has been saved
            # (presumably written by train_model and the criterion - their code is not shown here)
            error_res[i, j, :, rep] = np.load(save_dir_pre + "_error.npy")
            if config["calc_acc_prob"]:
                # The loaded acc. prob. arrays are averaged over their first axis
                acc_prob_res[i, j, 0, :, rep] = np.load(save_dir_pre + "_" + "cd_cnce" + "_acc_prob.npy").mean(axis=0)
                acc_prob_res[i, j, 1, :, rep] = np.load(save_dir_pre + "_" + "cd_mh" + "_acc_prob.npy").mean(axis=0)

# Save res (np.save appends the ".npy" extension to these names)
np.save("res/final_param_error_cnce_exp_w_pers_mh_num_samples_" + str(num_samples), error_res)
np.save("res/final_acc_prob_cnce_exp_w_pers_mh_num_samples_" + str(num_samples), acc_prob_res)

In [None]:
# Visualise results

# The experiment cell saves its results as "<base>.npy", while earlier results
# were kept under "<base>bs_20.npy". Use the "bs_20" pair when both files are
# present (unchanged behaviour); otherwise fall back to the files written by
# the experiment cell so that a fresh Restart & Run All does not fail with
# FileNotFoundError.
error_res_base = "res/final_param_error_cnce_exp_w_pers_mh_num_samples_" + str(num_samples)
acc_prob_res_base = "res/final_acc_prob_cnce_exp_w_pers_mh_num_samples_" + str(num_samples)

res_suffix = "bs_20.npy"
if not (os.path.exists(error_res_base + res_suffix) and os.path.exists(acc_prob_res_base + res_suffix)):
    res_suffix = ".npy"

error_res = np.load(error_res_base + res_suffix)
acc_prob_res = np.load(acc_prob_res_base + res_suffix)

# One RGB colour per entry of `configs` / `labels`
colors = [[31/255,68/255,156/255], [240/255,80/255,57/255], [124/255,161/255,204/255], [238/255,186/255,180/255]]
assert len(colors) >= len(configs), "Need one colour for each method"

def get_statistics(data):
    """Summarise `data` along its last axis.

    Returns a tuple `(median, maximum, minimum)`, each reduced over the last
    axis of `data`. Note that the second and third entries are the extreme
    values, not quartiles, even though callers unpack them into names
    containing "quartile".
    """
    reducers = (np.median, np.max, np.min)
    centre, highest, lowest = (stat_fn(data, axis=-1) for stat_fn in reducers)

    return centre, highest, lowest
        
def skip_list_item(lst: list, nth: int):
    """Keep every nth element of `lst`, starting with the first.

    The elements at positions 0, nth, 2 * nth, ... are returned as a new
    list; all other elements are dropped.
    """
    return [item for position, item in enumerate(lst) if position % nth == 0]
    
# plot_all = True : one panel per method with every repetition drawn as its own curve
# plot_all = False: three panels per J - error, acc. prob. for the (P-)CNCE
#                   runs and acc. prob. for the (P-)MH-CNCE runs
plot_all = False
num_its = error_res.shape[-2]
x = np.arange(num_its) + 1  # 1-based iteration index


# For the 1000-sample runs only every 5th iteration is drawn; the curves below
# are thinned in the same way so that they stay aligned with x
if num_samples == 1000:
    x = skip_list_item(x, 5)


for i, J in enumerate(num_neg_samples):

    if plot_all:
        fig, ax = plt.subplots(1, 4, figsize=(16, 5))
    else:
        fig, ax = plt.subplots(1, 3, figsize=(16, 5))
        ax[0].set_yscale('log')

    for j, config in enumerate(configs):

        if plot_all:
            for k in range(reps):
                if num_samples == 1000:
                    data = skip_list_item(error_res[i, j, :, k], 5)
                else:
                    data = error_res[i, j, :, k]

                # Only the first repetition carries the legend label
                if k == 0:
                    ax[j].plot(x, data, color=colors[j], label=labels[j])
                else:
                    ax[j].plot(x, data, color=colors[j])

            ax[j].legend()
            ax[j].set_ylim([-1, 100])
            ax[j].set_xlabel("Iter.")
            ax[j].set_ylabel("Avg. Sq. Error")

        else:
            if num_samples == 1000:
                err_data = skip_list_item(error_res[i, j, :, :], 5)
                cnce_data = skip_list_item(acc_prob_res[i, j, 0, :, :], 5)
                mh_cnce_data = skip_list_item(acc_prob_res[i, j, 1, :, :], 5)

            else:
                err_data = error_res[i, j, :, :]
                cnce_data = acc_prob_res[i, j, 0, :, :]
                mh_cnce_data = acc_prob_res[i, j, 1, :, :]

            # get_statistics returns (median, max, min) over the repetitions
            err_median, err_max, _ = get_statistics(err_data)

            # Solid line: median over repetitions; dashed line: maximum over repetitions
            ax[0].plot(x, err_median, color=colors[j], label=labels[j])
            ax[0].plot(x, err_max, '--', color=colors[j])

            cnce_median, _, _ = get_statistics(cnce_data)
            mh_cnce_median, _, _ = get_statistics(mh_cnce_data)

            # configs[0:2] are the (P-)CNCE runs and configs[2:4] the (P-)MH-CNCE
            # runs. Each run has two acc. prob. curves ("cd_cnce" and "cd_mh"),
            # drawn with the colour and label of the matching method.
            if j == 0 or j == 1:
                ax[1].plot(x, cnce_median, color=colors[j], label=labels[j])
                ax[1].plot(x, mh_cnce_median, color=colors[j+2], label=labels[j+2])
            else:
                ax[2].plot(x, cnce_median, color=colors[j-2], label=labels[j-2])
                # The label must be the method name (a string), in line with the
                # other three plot calls, and not an entry of `colors`
                ax[2].plot(x, mh_cnce_median, color=colors[j], label=labels[j])

    if plot_all:
        tikzplotlib.save("res/cnce_acc_prob_res_num_samples_" + str(num_samples) + "_num_neg_samples_" + str(J) + "_all.tex")
    else:
        ax[0].set_xlabel("Iter.")
        ax[0].set_ylabel("Sq. Error")

        ax[1].set_xlabel("Iter.")
        ax[1].set_ylabel("Acc. Prob., (P-)CNCE")
        # The middle panel contains one curve per method, so its legend is the only one drawn
        ax[1].legend(loc='upper center', bbox_to_anchor=(0.5, 1.12), ncol=4, fancybox=True, shadow=True)

        ax[2].set_xlabel("Iter.")
        ax[2].set_ylabel("Acc. Prob., (P-)MH-CNCE")

        tikzplotlib.save("res/cnce_acc_prob_res_num_samples_" + str(num_samples) + "_num_neg_samples_" + str(J) + ".tex")

    plt.show()