In [2]:
import sys
import os
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import math

sys.path.append("..")
from src.nce.cond_cd_cnce import CondCdCnceCrit
from src.nce.cond_per_cnce import CondPersistentCnceCrit


from src.noise_distr.conditional_bernoulli import ConditionalMultivariateBernoulli
from src.noise_distr.conditional_bernoulli import SpatialConditionalMultivariateBernoulli

from src.models.rbm.conditional_rbm import CondRbm
from src.data.mnist_w_labels import MnistDatasetWLabs

from src.training.model_training import train_model
from src.training.training_utils import no_stopping

from src.experiments.mnist_exp_utils import initialise_cond_rbm_params

%load_ext autoreload
%autoreload 2

## EXPERIMENTS

In [3]:
# --- Data specs: square MNIST images of side img_dim, flattened to num_dims pixels.
img_dim = 28
num_dims = img_dim * img_dim
num_classes = 10

# --- Model specs: number of hidden units in the RBM.
num_hidden = 200

# --- Training specs: negative samples passed to the criterion, SGD step size,
# mini-batch size and number of passes over the training set.
num_neg_samples, lr, batch_size, num_epochs = 5, 0.1, 64, 50

In [4]:
# Check accuracy of model
def rbm_acc(rbm, data_loader, k=100):
    """Average per-example fraction of entries where a model sample matches the data.

    For every batch ``(y, _)`` yielded by ``data_loader``, the model is sampled with
    ``rbm.sample(y, k=k)``. The second returned value is compared with ``y``
    element-wise. The result is averaged over the last dimension and summed over
    the batch. The total is divided by ``len(data_loader.dataset)``.

    Sampling runs under ``torch.no_grad()``. This is pure evaluation, so there is
    no reason to record an autograd graph for the sampling steps.

    Args:
        rbm: model exposing ``sample(y, k=...)``, which returns a pair whose second
            element can be compared with ``y``.
        data_loader: loader yielding 2-tuples whose first element is ``y``
            (the second element is ignored).
        k: passed through unchanged to ``rbm.sample``.

    Returns:
        0-dim tensor with the average agreement. An empty dataset raises
        ZeroDivisionError.
    """
    acc = 0
    with torch.no_grad():
        for y, _labels in data_loader:
            _probs, y_pred = rbm.sample(y, k=k)
            acc += (y_pred == y).float().mean(dim=-1).sum()

    return acc / len(data_loader.dataset)

def placeholder_metric(model):
    """Stand-in metric for ``train_model``: the mean of the model's ``weights`` tensor."""
    model_weights = model.weights
    mean_weight = model_weights.mean()
    return mean_weight

In [None]:
# Configurations to consider in experiments.
# Both variants share a conditional noise distribution and skip partition-function
# estimation. They differ in the criterion class, the label used for the saved
# parameter file, and the number of MCMC steps. With mcmc_steps None, the training
# cell constructs the criterion without a step-count argument.
_shared_config = {
    "estimate_part_fn": False,
    "conditional_noise_distr": True,
}

config_cnce = {
    "criterion": CondCdCnceCrit,
    "label": "cd_cnce",
    **_shared_config,
    "mcmc_steps": 1,
}

config_per_cnce = {
    "criterion": CondPersistentCnceCrit,
    "label": "pers_cd_cnce",
    **_shared_config,
    "mcmc_steps": None,
}

del _shared_config

# Configurations actually trained by the training cell.
configs = [config_cnce]


In [None]:
# Get data: train/test MNIST datasets with labels, plus their loaders.
# NOTE: this overrides num_classes = 10 from the spec cell. Every later cell
# (parameter init, visualisation) therefore sees num_classes = 5.
# Presumably MnistDatasetWLabs then keeps only 5 classes — confirm.
num_classes = 5
_mnist_root = "../src/data/datasets/"

training_data = MnistDatasetWLabs(train=True, root_dir=_mnist_root, num_classes=num_classes)
test_data = MnistDatasetWLabs(train=False, root_dir=_mnist_root, num_classes=num_classes)

# Only the training loader is shuffled; test order stays fixed.
train_loader = torch.utils.data.DataLoader(training_data, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size)

In [None]:
# Initialise model parameters once. The training and visualisation cells each
# build their CondRbm from clones of these, so every run starts from the same point.
(weights, vis_bias, hidden_bias, class_weights) = initialise_cond_rbm_params(
    num_visible=num_dims,
    num_hidden=num_hidden,
    num_conditional=num_classes,
)

In [None]:
def get_cond_bern_params(y, eps=1e-1):
    """Estimate per-dimension parameters of the conditional Bernoulli noise distribution.

    Each column ``i`` of ``y`` is split by thresholding at 0.5:

    - ``p_0[i]`` is the mean of the entries of column ``i`` that are below 0.5.
    - ``p_1[i]`` is the mean of the entries that are at or above 0.5.

    If a column has no entries on one side, that parameter is set to 0.5.

    The computation is vectorised over columns instead of looping over each
    dimension in Python.

    Args:
        y: 2-D tensor ``(num_examples, num_dims)``; columns are indexed as ``y[:, i]``.
            NOTE(review): presumably pixel intensities in [0, 1] — confirm.
        eps: ``p_0`` is clamped from below at ``eps`` and ``p_1`` from above at
            ``1 - eps``.

    Returns:
        Tuple ``(p_0, p_1)``. Each is a 1-D tensor of length ``num_dims`` in
        torch's default dtype.
    """
    is_low = y < 0.5
    # Not ~is_low: NaN entries then fall in neither group, as in an explicit mask.
    is_high = y >= 0.5

    def _masked_column_mean(mask):
        # Mean of the selected entries per column; 0.5 where nothing is selected.
        counts = mask.sum(dim=0)
        sums = torch.where(mask, y, torch.zeros_like(y)).sum(dim=0)
        # clamp avoids 0/0 for empty columns; those are replaced by 0.5 below anyway.
        means = sums / counts.clamp(min=1)
        return torch.where(counts > 0, means, torch.full_like(means, 0.5))

    out_dtype = torch.get_default_dtype()
    p_0 = _masked_column_mean(is_low).to(out_dtype)
    p_1 = _masked_column_mean(is_high).to(out_dtype)

    # Keep p_0 >= eps and p_1 <= 1 - eps.
    p_0 = p_0.clamp(min=eps)
    p_1 = p_1.clamp(max=1 - eps)
    return p_0, p_1
        
# Estimate the noise-distribution parameters from the whole training set. They
# are used by the noise sanity check and by ConditionalMultivariateBernoulli in
# the training cell.
p_0, p_1 = get_cond_bern_params(training_data.y)

In [None]:
# Test noise distr. params: binarise a few training images (threshold 0.5) and
# corrupt them with the estimated parameters.
# Where the binarised pixel is 0, the noisy pixel is Bernoulli(p_0).
# Where it is 1, the noisy pixel is Bernoulli(p_1).
num_samples = 16
y_true = (training_data.y[:num_samples, :] >= 0.5).to(training_data.y.dtype)

# Draw an independent noise pattern for every image. Sampling shape (1,) would
# broadcast one single pattern to all images.
noise_where_0 = torch.distributions.Bernoulli(p_0).sample((num_samples,))
noise_where_1 = torch.distributions.Bernoulli(p_1).sample((num_samples,))
y_sample = (1 - y_true) * noise_where_0 + y_true * noise_where_1

fig, ax = plt.subplots(2, 1)
ax[0].imshow(np.transpose(torchvision.utils.make_grid(y_true.reshape(-1, 1, img_dim, img_dim), nrow=4).numpy(), (1, 2, 0)))
ax[0].set_title("Binarised training images")
ax[1].imshow(np.transpose(torchvision.utils.make_grid(y_sample.reshape(-1, 1, img_dim, img_dim), nrow=4).numpy(), (1, 2, 0)))
ax[1].set_title("Samples from the conditional noise distribution")
plt.show()


In [None]:
# Run experiments: train one model per configuration in `configs` and save its
# parameters to res/params_rbm_<label>.

# torch.save does not create missing directories, so make sure res/ exists.
os.makedirs("res", exist_ok=True)

for config in configs:

    # Fresh model from clones of the shared initial parameters, so every
    # configuration starts from the same point.
    p_m = CondRbm(weights=weights.clone(), vis_bias=vis_bias.clone(), hidden_bias=hidden_bias.clone(), class_weights=class_weights.clone())

    # Noise distribution built from the (p_0, p_1) estimated above.
    # Alternative: SpatialConditionalMultivariateBernoulli(p_0, p_1)
    p_n = ConditionalMultivariateBernoulli(p_0, p_1) if config["conditional_noise_distr"] else None

    # Only pass an MCMC step count when the configuration specifies one.
    criterion_args = (p_m, p_n, num_neg_samples)
    if config["mcmc_steps"] is not None:
        criterion_args += (config["mcmc_steps"],)
    criterion = config["criterion"](*criterion_args)

    save_dir = None
    _ = train_model(criterion, placeholder_metric, train_loader, save_dir, num_epochs=num_epochs,
                    decaying_lr=True, weight_decay=1e-3, stopping_condition=no_stopping, Adam=False)

    # Save model
    torch.save(p_m.state_dict(), "res/params_rbm_" + config["label"])

    # Test accuracy check (currently disabled):
    #acc = rbm_acc(p_m, test_loader)
    #print("Model accuracy: {}".format(acc))
    


In [None]:
# Visualise results: for each trained configuration, run the model's sampler
# from random hidden states and random one-hot class vectors. Show both the
# returned probabilities and the samples.
configs = [config_cnce, config_per_cnce]

num_samples = 8


def _as_image_grid(images):
    """Arrange flattened img_dim x img_dim images into a 4-wide grid in HWC layout for imshow."""
    grid = torchvision.utils.make_grid(images.reshape(-1, 1, img_dim, img_dim), nrow=4)
    return np.transpose(grid.numpy(), (1, 2, 0))


# The training cell may not have produced every configuration. On a fresh kernel
# it only trains the configs listed in the configuration cell. Skip missing
# parameter files instead of crashing in torch.load.
plot_configs = []
for config in configs:
    params_path = "res/params_rbm_" + config["label"]
    if os.path.exists(params_path):
        plot_configs.append(config)
    else:
        print("Skipping {}: {} not found".format(config["label"], params_path))

if plot_configs:
    # squeeze=False always returns a 2-D array of axes, even for a single row.
    fig, ax = plt.subplots(len(plot_configs), 2, squeeze=False)

    p_m = CondRbm(weights=weights.clone(), vis_bias=vis_bias.clone(), hidden_bias=hidden_bias.clone(), class_weights=class_weights.clone())

    for i, config in enumerate(plot_configs):
        p_m.load_state_dict(torch.load("res/params_rbm_" + config["label"]))

        # Pure inference. Disabling autograd guarantees the outputs do not require
        # grad; .numpy() raises on tensors that do.
        with torch.no_grad():
            y_noise = torch.distributions.Bernoulli(0.5).sample((num_samples, num_hidden))
            x = torch.distributions.OneHotCategorical(torch.full((num_classes,), 1 / num_classes)).sample((num_samples,))
            print("Sampled classes: {}".format(x.argmax(dim=-1).tolist()))
            y_pred_prob, y_pred = p_m.sample_from_hidden(y_noise, x, k=100)

        ax[i, 0].imshow(_as_image_grid(y_pred_prob))
        ax[i, 0].set_title(config["label"] + " predicted probabilities")

        ax[i, 1].imshow(_as_image_grid(y_pred))
        ax[i, 1].set_title(config["label"] + " samples")

    plt.show()

In [None]:
# TODO: SGD seems to work better for CNCE