In [None]:
import sys
import torch
from scipy.linalg import sqrtm
import matplotlib.pyplot as plt

sys.path.append("..")
from src.nce.cnce import CondNceCrit
from src.nce.rank import NceRankCrit

from src.noise_distr.normal import MultivariateNormal
from src.models.gaussian_model import DiagGaussianModel

from src.training.model_training import train_model
from src.data.generic import Generic
from src.training.training_utils import Mse, no_change_stopping_condition
%load_ext autoreload
%autoreload 2

In [None]:
def mvn_curve(mu, cov, std=1, res=100):
    """Return points on the `std`-sigma contour of a 2-D Gaussian.

    The unit circle is sampled at `res` angles, mapped through the matrix
    square root of `cov`, scaled by `std` and shifted by `mu`.

    Args:
        mu: Mean of the distribution; must broadcast against a (res, 2) tensor.
        cov: 2x2 covariance matrix (tensor or array-like). It may be attached
            to an autograd graph; it is detached before being handed to SciPy.
        std: Number of standard deviations at which the contour is drawn.
        res: Number of points on the curve. The angle grid includes both 0
            and 2*pi, so the first and last points coincide.

    Returns:
        Tensor of shape (res, 2) with the contour points; it never requires
        grad.
    """
    with torch.no_grad():
        # scipy's sqrtm converts its argument through NumPy, and that
        # conversion raises for tensors that require grad (no_grad() does not
        # change this), so detach explicitly before leaving torch.
        center = torch.as_tensor(mu).detach().cpu()
        cov_np = torch.as_tensor(cov).detach().cpu().numpy()
        cov_root = torch.Tensor(sqrtm(cov_np))

        angles = torch.linspace(0, 2 * torch.pi, res)
        unit_circle = torch.column_stack((torch.cos(angles), torch.sin(angles)))
        ellipse = std * unit_circle @ cov_root
        return center + ellipse
    
def plot_mvn(levels, ax, label):
    """Draw a contour curve on `ax`.

    Args:
        levels: 2-D indexable whose first column holds the x coordinates and
            whose second column holds the y coordinates.
        ax: Object exposing a matplotlib-style `plot` method.
        label: Legend label attached to the drawn line.
    """
    xs = levels[:, 0]
    ys = levels[:, 1]
    ax.plot(xs, ys, label=label)
    
def plot_distrs(p_d, p_t_d, p_t_t):
    """Plot the one-sigma contours of the data distribution and two models.

    Args:
        p_d: Data distribution; read through its `mu` and `cov` attributes.
        p_t_d: Model labelled as trained with p_n = p_d; read through its
            `mu` attribute and `cov()` method.
        p_t_t: Model labelled as trained with p_n = p_theta; same interface
            as `p_t_d`.
    """
    fig, ax = plt.subplots()
    ax.set_xlim([-3, 10])
    ax.set_ylim([-3, 10])
    # Equal scaling on both axes so an isotropic covariance is drawn as a
    # circle rather than being stretched by the figure's aspect ratio.
    ax.set_aspect("equal")

    # The first label used to be "$p_{d}}$": the stray closing brace is not
    # valid mathtext and makes rendering the legend fail.
    plot_mvn(mvn_curve(p_d.mu, p_d.cov), ax, "$p_{d}$")
    plot_mvn(mvn_curve(p_t_d.mu, p_t_d.cov()), ax, "$p_{\\theta}, p_n = p_d$")
    plot_mvn(mvn_curve(p_t_t.mu, p_t_t.cov()), ax, "$p_{\\theta}, p_n = p_{\\theta}$")

    ax.legend()

In [None]:
# Run experiments: configuration and synthetic data.

# Seed torch's global RNG so that Restart & Run All reproduces the same data
# sample and the same DataLoader shuffling order.
# NOTE(review): assumes MultivariateNormal.sample draws from torch's global
# RNG - confirm against src.noise_distr.normal.
SEED = 0
torch.manual_seed(SEED)

D, N, J = 2, 100, 1  # Dimension, num. data samples, num. negative samples
mu_star, Sigma_star = torch.ones(D,), torch.eye(D)
p_d = MultivariateNormal(mu_star, Sigma_star)  # data distribution
data_sample = p_d.sample((N,), None)
num_epochs = 50
batch_size = 10


In [None]:
# Fit a diagonal Gaussian model using the data distribution p_d as the noise
# distribution of the ranking criterion.
p_t_data_noise = DiagGaussianModel(mu=torch.full((D,), 5.0), cov=torch.eye(D))

# Wrap the sampled data so it can be served in shuffled mini-batches.
training_data = Generic(data_sample)
train_loader = torch.utils.data.DataLoader(
    training_data,
    batch_size=batch_size,
    shuffle=True,
)

criterion = NceRankCrit(p_t_data_noise, p_d, J)
error_metric = Mse(p_d).metric  # also reused by the manual training loop below
err = train_model(
    criterion,
    error_metric,
    train_loader,
    None,
    neg_sample_size=J,
    num_epochs=num_epochs,
    stopping_condition=no_change_stopping_condition,
)

In [None]:

# Fit a second model where the noise distribution is the current model itself:
# before every gradient step, q is rebuilt from a detached snapshot of the
# model's mean and covariance.
p_t_model_noise = DiagGaussianModel(mu=5.0*torch.ones(D,), cov=4*torch.eye(D))
optimizer = torch.optim.SGD(p_t_model_noise.parameters(), lr=0.1)

metric = []  # value of error_metric after each epoch
for epoch in range(1, num_epochs + 1):
    for y, idx in train_loader:
        # Snapshot the current parameters; detach so that no gradient flows
        # through the noise distribution, clone so q owns its own storage.
        q = MultivariateNormal(p_t_model_noise.mu.detach().clone(),
                               p_t_model_noise.cov().detach().clone())
        criterion = NceRankCrit(p_t_model_noise, q, J)
        optimizer.zero_grad()

        # The loss value is evaluated but not aggregated yet (see TODO below).
        # NOTE(review): the call is kept because crit() may draw noise samples
        # and so influence the RNG stream - confirm before removing it.
        with torch.no_grad():
            loss = criterion.crit(y, None)

        # Calculate and assign gradients
        criterion.calculate_crit_grad(y, idx)

        # Take gradient step
        optimizer.step()

        # TODO: decide how to accumulate `loss` into a per-epoch running loss.
    metric.append(error_metric(p_t_model_noise).detach().numpy())

print("Finished training")

In [None]:
# Mean of the data distribution next to the mean of the model trained with
# itself as noise; the latter is displayed as the cell output.
x_0, x_1 = p_d.mu, p_t_model_noise.mu
x_1

In [None]:
# Overlay the data distribution with both fitted models.
plot_distrs(
    p_d=p_d,
    p_t_d=p_t_data_noise,
    p_t_t=p_t_model_noise,
)

In [None]:
# Display the mean of the model that was trained with itself as noise.
p_t_model_noise.mu