In [9]:
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt

sys.path.append("..")
from src.nce.binary import NceBinaryCrit
from src.nce.cnce import CondNceCrit

from src.noise_distr.normal import Normal
from src.noise_distr.conditional_normal import ConditionalNormal

from src.models.ring_model.ring_model import RingModel, RingModelNCE
from src.data.ring_model_dataset import RingModelDataset

from src.training.model_training import train_model
from src.training.model_training import EuclideanPrecisionMetric

In [None]:
# Data specs: parameters passed to RingModelDataset for every generated dataset
num_dims = 5
mu = 3
precision = 2

# Experiment specs
num_samples = [20, 50, 100, 200]   # training-set sizes to sweep over
num_neg_samples = [2, 5, 10]       # values passed to train_model as neg_sample_size
reps = 5                           # repetitions per (sample size, neg samples) configuration

# Criteria: classes (instantiated per run in the experiment cell) and their labels,
# which are used in result file names and plot legends
criteria = [NceBinaryCrit, CondNceCrit]
crit_labels = ["binary", "conditional"]

# Noise distribution spec
# NOTE(review): `sigma_noise` is used by the experiment cell but was not defined anywhere
# in this notebook (it only worked via leftover kernel state). 1.0 is a placeholder — TODO confirm.
sigma_noise = 1.0

# Training specs
error_metric = EuclideanPrecisionMetric(true_precision=precision)
batch_size = 32
num_epochs = 10

# Seed the RNGs used for parameter initialisation (numpy) and DataLoader shuffling (torch)
# so that Restart & Run All reproduces the same results.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
# Run experiments.
# For every (sample size, number of negative samples, repetition) a fresh ring-model dataset
# is built, and one model per criterion is trained from a shared initial log-precision.
# The value returned by train_model for each run is stored in error_res[i, j, k, rep], where
# i indexes num_samples, j indexes num_neg_samples, k indexes criteria and rep the repetition.

error_res = np.zeros((len(num_samples), len(num_neg_samples), len(criteria), reps))

# Sorry for all the loops
for i, sample_size in enumerate(num_samples):

    for j, neg_sample_size in enumerate(num_neg_samples):

        for rep in range(reps):

            data_dir = f"data/datasets/ring_data_size_{sample_size}_nn_{neg_sample_size}_rep_{rep}"
            training_data = RingModelDataset(sample_size=sample_size, num_dims=num_dims, mu=mu,
                                             precision=precision, root_dir=data_dir)
            train_loader = torch.utils.data.DataLoader(training_data, batch_size=batch_size, shuffle=True)

            # Shared by all criteria in this repetition so they start from the same log-precision
            log_precision_init = np.random.normal()

            for k, (crit, lab) in enumerate(zip(criteria, crit_labels)):

                # `crit` is a class, not an instance, so compare with issubclass (isinstance
                # would always be False here).
                if issubclass(crit, NceBinaryCrit):
                    # Binary NCE also learns a log-partition-function parameter
                    log_part_fun_init = np.random.normal()
                    unnorm_distr = RingModelNCE(mu=mu, log_precision=log_precision_init,
                                                log_part_fun=log_part_fun_init)
                else:
                    unnorm_distr = RingModel(mu=mu, log_precision=log_precision_init)

                if issubclass(crit, CondNceCrit):
                    noise_distr = ConditionalNormal(mu=0, sigma_sq=sigma_noise)
                else:
                    noise_distr = Normal(sigma_sq=sigma_noise)

                criterion = crit(unnorm_distr, noise_distr)

                save_dir = f"res/param_error_{lab}_samples_{sample_size}_num_neg_{neg_sample_size}_rep_{rep}"
                error_res[i, j, k, rep] = train_model(criterion, error_metric, train_loader, save_dir,
                                                      neg_sample_size=neg_sample_size, num_epochs=num_epochs)

np.save("res/final_param_error_ring_model_all", error_res)

In [None]:
# For visualising results
def plot_res(x, error, label, col, ax):
    """Draw a mean-error curve with a min-max band onto a given matplotlib Axes.

    Args:
        x: x-coordinates, one per entry of ``error.mean(axis=-1)``.
        error: array of errors whose last axis indexes repetitions. The curve is
            the mean over that axis, and the shaded band spans its min and max.
        label: legend label for the curve.
        col: matplotlib colour used for both the curve and the band.
        ax: Axes to draw on. The drawing goes to this Axes, not to pyplot's current axes.

    Returns:
        The Axes that was drawn on.
    """
    error = np.asarray(error)
    ax.plot(x, error.mean(axis=-1), color=col, linewidth=1.0, marker='o', label=label)

    # Shade the range spanned by the repetitions
    ax.fill_between(x, error.min(axis=-1), error.max(axis=-1), alpha=0.1, color=col)
    return ax

In [None]:
# Visualise results: one panel per number of negative samples, with one curve per criterion
# (mean error with a min-max band over repetitions) against log sample size.
fig, ax = plt.subplots(len(num_neg_samples), 1, sharex=True, squeeze=False,
                       figsize=(6, 3 * len(num_neg_samples)))
colors = ['b', 'r']  # one colour per criterion, in the same order as `criteria`

log_num_samples = np.log(np.array(num_samples))
for j, (axis, neg_sample_size) in enumerate(zip(ax.reshape(-1), num_neg_samples)):
    for k, (lab, col) in enumerate(zip(crit_labels, colors)):
        plot_res(log_num_samples, error_res[:, j, k, :], lab, col, axis)
    axis.set_title("Negative samples: " + str(neg_sample_size))
    axis.set_ylabel("Parameter error")
    axis.legend()

ax.reshape(-1)[-1].set_xlabel("log(number of samples)")
fig.tight_layout()
plt.show()