In [1]:
import numpy as np

from lcc.dataset import create_lcc_dataset_k1_t1_scalar

from probabilistic_classifier.dataset import create_multiclass_conditional_dataset
from probabilistic_classifier.estimate import estimate_mi_for_multiclass_classification
from probabilistic_classifier.train import train_multiclass_classifier

In [3]:
# create the basis dataset
# Settings handed positionally to create_lcc_dataset_k1_t1_scalar below.
# NOTE(review): their exact meaning is defined by that helper, which is not
# visible here -- `prime` is presumably the field modulus; confirm in lcc.dataset.
prime = 5
data_range = 2
num_of_samples = 200_000  # also read by later cells
weight = 1

dataset = create_lcc_dataset_k1_t1_scalar(
    prime,
    data_range,
    num_of_samples,
    weight,
)

In [4]:
# Drop column 3 of `dataset`, keeping every other column in its original order.
# NOTE: this rebinds `dataset`, so running this cell again without first
# re-running the cell that creates `dataset` removes one more column.
dropped_col = 3
columns_before = dataset[:, :dropped_col]
columns_after = dataset[:, dropped_col + 1:]
dataset = np.hstack((columns_before, columns_after))

In [5]:
# create the joint and marginal datasets
# Column indices of `dataset` assigned to the x, y and z groups.
x_idx, y_idx, z_idx = [0, 1], [3, 4, 5], [2]

# The helper returns four (data, label) pairs; presumably one per class of the
# multiclass problem -- confirm against probabilistic_classifier.dataset.
(joint_data, joint_label,
 all_marginal_data, all_marginal_label,
 marginal_y_joint_xz_data, marginal_y_joint_xz_label,
 marginal_x_joint_yz_data, marginal_x_joint_yz_label) = create_multiclass_conditional_dataset(
    dataset, x_idx, y_idx, z_idx)

# Stack the four pairs into a single dataset, keeping data and labels aligned.
data = np.concatenate([
    joint_data,
    all_marginal_data,
    marginal_y_joint_xz_data,
    marginal_x_joint_yz_data,
])
label = np.concatenate([
    joint_label,
    all_marginal_label,
    marginal_y_joint_xz_label,
    marginal_x_joint_yz_label,
])
assert len(data) == len(label), 'data and label must have the same number of rows'

# Shuffle rows and labels with the same permutation. The permutation length is
# taken from the stacked array itself instead of assuming 4 * num_of_samples,
# so no row is dropped (and no index is out of range) if the helper returns a
# different number of rows.
randomize_idx = np.random.permutation(len(data))
data, label = data[randomize_idx], label[randomize_idx]

In [6]:
# define model parameters
# Total number of columns selected across the x, y and z index groups.
num_input_features = sum(len(idx_group) for idx_group in (x_idx, y_idx, z_idx))
# NOTE(review): presumably the hidden-layer widths and the optimiser learning
# rate -- both are passed unchanged to train_multiclass_classifier; confirm there.
hidden_size_arr = [256, 256]
lr = 1e-3

In [7]:
# Loop sizes and batch size; the inner count and batch size are passed straight
# to train_multiclass_classifier, the outer count is the number of trials.
num_of_outer_iteration = 20
num_of_inner_iteration = 50
batch_size = 4096

# One entry per trial is appended to each of these lists.
outer_running_loss = []
outer_running_loss_avg = []
ldr_estimations = []

for outer_iter in range(num_of_outer_iteration):
    print('################################################################')

    # Train on the shuffled dataset; the call also returns four counts that the
    # estimator below needs.
    (
        model,
        inner_running_loss,
        inner_running_loss_avg,
        num_of_joint,
        num_of_all_marginal,
        num_of_marginal_y_joint_xz,
        num_of_marginal_x_joint_yz,
    ) = train_multiclass_classifier(
        data,
        label,
        num_input_features,
        hidden_size_arr,
        lr,
        num_of_inner_iteration,
        batch_size,
        outer_iter,
        print_progress=True,
        save_avg=500,
    )
    outer_running_loss.append(inner_running_loss)
    outer_running_loss_avg.append(inner_running_loss_avg)

    ## estimate cmi
    # The estimator is evaluated on the joint samples only.
    curr_ldr = estimate_mi_for_multiclass_classification(
        model,
        joint_data,
        num_of_joint,
        num_of_all_marginal,
        num_of_marginal_y_joint_xz,
        num_of_marginal_x_joint_yz,
    )
    ldr_value = curr_ldr.item()  # plain Python scalar for printing and storing
    print('trial: {}, ldr: {}'.format(outer_iter + 1, ldr_value))
    print('################################################################\n')
    ldr_estimations.append(ldr_value)

# Report the average of the per-trial estimates.
mean_ldr = np.mean(ldr_estimations)
print('final estimations:\n\tldr: {}'.format(mean_ldr))

################################################################
trial: 1, epoch, 1, iter: 1, curr loss: 1.392756462097168, avg loss: 1.392756462097168
trial: 1, epoch, 2, iter: 1, curr loss: 0.8171751499176025, avg loss: 0.8171751499176025
trial: 1, epoch, 3, iter: 1, curr loss: 0.8039337992668152, avg loss: 0.8039337992668152
trial: 1, epoch, 4, iter: 1, curr loss: 0.7981716990470886, avg loss: 0.7981716990470886
trial: 1, epoch, 5, iter: 1, curr loss: 0.8001854419708252, avg loss: 0.8001854419708252
trial: 1, epoch, 6, iter: 1, curr loss: 0.8122058510780334, avg loss: 0.8122058510780334
trial: 1, epoch, 7, iter: 1, curr loss: 0.8035977482795715, avg loss: 0.8035977482795715
trial: 1, epoch, 8, iter: 1, curr loss: 0.7968491315841675, avg loss: 0.7968491315841675
trial: 1, epoch, 9, iter: 1, curr loss: 0.8048450946807861, avg loss: 0.8048450946807861
trial: 1, epoch, 10, iter: 1, curr loss: 0.8090829253196716, avg loss: 0.8090829253196716
trial: 1, epoch, 11, iter: 1, curr loss: 0.788

KeyboardInterrupt: 