In [5]:
import numpy as np

from probabilistic_classifier.dataset import create_multiclass_conditional_dataset
from probabilistic_classifier.estimate import estimate_mi_for_multiclass_classification
from probabilistic_classifier.train import train_multiclass_classifier

In [6]:
# create the basis dataset
# Seed NumPy's global RNG so the sampled dataset (and any later draw made
# through np.random in this notebook) is identical on every
# Restart & Run All. This only covers NumPy's legacy global generator.
random_seed = 0
np.random.seed(random_seed)

# zero-mean trivariate Gaussian with unit variances; the off-diagonal entries
# are 0.8 between columns 0 and 1, 0.5 between columns 0 and 2, and 0 between
# columns 1 and 2
mean = [0, 0, 0]
cov = [[1, 0.8, 0.5],
       [0.8, 1, 0],
       [0.5, 0, 1]]
num_of_samples = 800000
# one row per sample, one column per variable
dataset = np.random.multivariate_normal(mean=mean, cov=cov, size=num_of_samples)

In [7]:
# create the joint and marginal datasets
# column indices of the x, y and z variables inside `dataset`
x_idx, y_idx, z_idx = [0], [1], [2]
(joint_data, joint_label,
 all_marginal_data, all_marginal_label,
 marginal_y_joint_xz_data, marginal_y_joint_xz_label,
 marginal_x_joint_yz_data, marginal_x_joint_yz_label) = create_multiclass_conditional_dataset(
    dataset, x_idx, y_idx, z_idx)

# stack the four (data, label) pairs into a single training set, keeping the
# same order for samples and labels
data = np.concatenate([joint_data, all_marginal_data,
                       marginal_y_joint_xz_data, marginal_x_joint_yz_data])
label = np.concatenate([joint_label, all_marginal_label,
                        marginal_y_joint_xz_label, marginal_x_joint_yz_label])
# a length mismatch would silently pair samples with the wrong labels
assert len(data) == len(label), 'data and label must have the same number of rows'

# shuffle samples and labels with one shared permutation; the permutation
# length comes from the stacked array itself instead of assuming it holds
# exactly 4 * num_of_samples rows
randomize_idx = np.random.permutation(len(data))
data, label = data[randomize_idx], label[randomize_idx]

In [8]:
# model hyper-parameters
# input width = total number of x, y and z columns fed to the classifier
num_input_features = sum(len(idx) for idx in (x_idx, y_idx, z_idx))
# three hidden sizes of 256 each
hidden_size_arr = [256] * 3
# learning rate handed to the training routine
lr = 1e-3

In [9]:
# training schedule: number of independent trials, the inner iteration count
# handed to the trainer, and the mini-batch size
num_of_outer_iteration = 20
num_of_inner_iteration = 50
batch_size = 4096

# per-trial bookkeeping, one entry appended per outer iteration
outer_running_loss = []
outer_running_loss_avg = []
ldr_estimations = []

for outer_iter in range(num_of_outer_iteration):
    print('################################################################')

    # train a classifier for this trial on the full shuffled dataset
    (model,
     inner_running_loss,
     inner_running_loss_avg,
     num_of_joint,
     num_of_all_marginal,
     num_of_marginal_y_joint_xz,
     num_of_marginal_x_joint_yz) = train_multiclass_classifier(
        data, label, num_input_features, hidden_size_arr, lr,
        num_of_inner_iteration, batch_size, outer_iter,
        print_progress=True, save_avg=500)
    outer_running_loss.append(inner_running_loss)
    outer_running_loss_avg.append(inner_running_loss_avg)

    # estimate on the joint samples with the freshly trained model
    curr_ldr = estimate_mi_for_multiclass_classification(
        model, joint_data, num_of_joint, num_of_all_marginal,
        num_of_marginal_y_joint_xz, num_of_marginal_x_joint_yz)
    # read the scalar once and reuse it for both the log line and the record
    ldr_value = curr_ldr.item()
    print('trial: {}, ldr: {}'.format(outer_iter + 1, ldr_value))
    print('################################################################\n')
    ldr_estimations.append(ldr_value)

# report the average estimate over all trials
print('final estimations:\n\tldr: {}'.format(np.mean(ldr_estimations)))

################################################################
trial: 1, epoch, 1, iter: 1, curr loss: 1.3872872591018677, avg loss: 1.3872872591018677
trial: 1, epoch, 1, iter: 500, curr loss: 1.1742324829101562, avg loss: 1.1799034812450409
trial: 1, epoch, 2, iter: 1, curr loss: 1.1682443618774414, avg loss: 1.1682443618774414
trial: 1, epoch, 2, iter: 500, curr loss: 1.181355357170105, avg loss: 1.175142041683197
trial: 1, epoch, 3, iter: 1, curr loss: 1.1830161809921265, avg loss: 1.1830161809921265
trial: 1, epoch, 3, iter: 500, curr loss: 1.1591696739196777, avg loss: 1.1745931642055512
trial: 1, epoch, 4, iter: 1, curr loss: 1.1728137731552124, avg loss: 1.1728137731552124
trial: 1, epoch, 4, iter: 500, curr loss: 1.1772903203964233, avg loss: 1.174594832420349
trial: 1, epoch, 5, iter: 1, curr loss: 1.1663143634796143, avg loss: 1.1663143634796143
trial: 1, epoch, 5, iter: 500, curr loss: 1.1787867546081543, avg loss: 1.1742442948818206
trial: 1, epoch, 6, iter: 1, curr loss