In [1]:
import numpy as np

from lcc.dataset import create_lcc_dataset_k1_t1_scalar

from probabilistic_classifier.dataset import create_joint_marginal_dataset
from probabilistic_classifier.estimate import estimate_mi_for_binary_classification
from probabilistic_classifier.train import train_binary_classifier_v2

In [2]:
# Base dataset: generated by lcc.dataset.create_lcc_dataset_k1_t1_scalar.
# Arguments are positional, in the order (prime, data_range, num_of_samples, weight).
# NOTE(review): no RNG seed is set anywhere in this notebook view, so the sampled
# data (and every estimate below) will differ between runs — consider seeding.
prime = 5
data_range = 2
weight = 1
num_of_samples = 800_000
dataset = create_lcc_dataset_k1_t1_scalar(prime, data_range, num_of_samples, weight)

In [3]:
# Remove column 3 from the base dataset, keeping all other columns in order.
# WARNING: this rebinds `dataset`, so running this cell twice drops a second
# column — re-run the dataset-creation cell before re-running this one.
cols_before_3 = dataset[:, :3]
cols_after_3 = dataset[:, 4:]
dataset = np.hstack((cols_before_3, cols_after_3))

In [4]:
# Build joint / marginal samples for estimating MI between the X columns and the
# YZ columns. Indices refer to `dataset` after the column-drop cell.
# NOTE(review): y_idx and z_idx are not read directly in this cell; only x_idx and
# yz_idx are passed to the helper.
x_idx, y_idx, z_idx = [0, 1], [3, 4, 5], [2]
yz_idx = z_idx + y_idx  # == [2, 3, 4, 5]
joint_data, joint_label, marginal_data, marginal_label = create_joint_marginal_dataset(dataset, x_idx, yz_idx)
data = np.concatenate([joint_data, marginal_data])
label = np.concatenate([joint_label, marginal_label])
assert len(data) == len(label), 'data/label length mismatch'
# Size the shuffle from the data itself instead of assuming the helper returned
# exactly num_of_samples joint + num_of_samples marginal rows; for equal sizes this
# is the same permutation as np.random.permutation(np.arange(2 * num_of_samples)).
randomize_idx = np.random.permutation(len(data))
data, label = data[randomize_idx], label[randomize_idx]

In [5]:
# Classifier hyper-parameters. Input width = number of X columns + number of YZ
# columns (presumably the classifier sees them concatenated — confirm in train).
num_input_features = sum(len(cols) for cols in (x_idx, yz_idx))
hidden_size_arr = [256] * 3  # hidden layer widths passed to the trainer
lr = 1e-3  # learning rate

In [6]:
num_of_outer_iteration = 20  # number of independent trials (fresh model each time)
num_of_inner_iteration = 50  # forwarded to train_binary_classifier_v2
batch_size = 4096

# Per-trial records: loss curves plus the three MI estimates (ldr, dv, nwj).
outer_running_loss = []
outer_running_loss_avg = []
ldr_estimations = []
dv_estimations = []
nwj_estimations = []

for outer_iter in range(num_of_outer_iteration):
    print('################################################################')
    model, inner_running_loss, inner_running_loss_avg, num_of_joint, num_of_marginal = train_binary_classifier_v2(
        data, label, num_input_features, hidden_size_arr, lr,
        num_of_inner_iteration, batch_size, outer_iter,
        save_avg=200, print_progress=True,
    )
    outer_running_loss.append(inner_running_loss)
    outer_running_loss_avg.append(inner_running_loss_avg)

    # Estimate MI between the X and YZ columns from the trained classifier.
    # (This is an MI estimate, not a CMI; the CMI is presumably derived later
    # from this and the I(X; Z) estimate — confirm.)
    curr_ldr, curr_dv, curr_nwj = estimate_mi_for_binary_classification(model, joint_data, num_of_joint, marginal_data, num_of_marginal)
    ldr_val, dv_val, nwj_val = curr_ldr.item(), curr_dv.item(), curr_nwj.item()
    print(f'trial: {outer_iter + 1}, ldr: {ldr_val}, dv: {dv_val}, nwj: {nwj_val}')
    print('################################################################\n')
    ldr_estimations.append(ldr_val)
    dv_estimations.append(dv_val)
    nwj_estimations.append(nwj_val)

# Average each estimator over all trials.
mean_ldr = np.mean(ldr_estimations)
mean_dv = np.mean(dv_estimations)
mean_nwj = np.mean(nwj_estimations)
print(f'final estimations:\n\tldr: {mean_ldr}\n\tdv: {mean_dv}\n\tnwj: {mean_nwj}')

################################################################
trial: 1, epoch, 1, iter: 1, curr loss: 0.6964414715766907, avg loss: 0.6964414715766907
trial: 1, epoch, 2, iter: 1, curr loss: 0.4438747763633728, avg loss: 0.4438747763633728
trial: 1, epoch, 3, iter: 1, curr loss: 0.4465521574020386, avg loss: 0.4465521574020386
trial: 1, epoch, 4, iter: 1, curr loss: 0.43235599994659424, avg loss: 0.43235599994659424
trial: 1, epoch, 5, iter: 1, curr loss: 0.42525187134742737, avg loss: 0.42525187134742737
trial: 1, epoch, 6, iter: 1, curr loss: 0.43201443552970886, avg loss: 0.43201443552970886
trial: 1, epoch, 7, iter: 1, curr loss: 0.4363255500793457, avg loss: 0.4363255500793457
trial: 1, epoch, 8, iter: 1, curr loss: 0.43732601404190063, avg loss: 0.43732601404190063
trial: 1, epoch, 9, iter: 1, curr loss: 0.43077680468559265, avg loss: 0.43077680468559265
trial: 1, epoch, 10, iter: 1, curr loss: 0.4282991290092468, avg loss: 0.4282991290092468
trial: 1, epoch, 11, iter: 1, curr

In [9]:
# Build joint / marginal samples for estimating MI between the X columns and the
# Z column. Index lists are restated so this cell does not rely on an earlier one.
# NOTE(review): y_idx / yz_idx are not used in this cell.
x_idx, y_idx, z_idx = [0, 1], [3, 4, 5], [2]
yz_idx = z_idx + y_idx  # == [2, 3, 4, 5]
# Narrow into a separate array instead of overwriting `dataset`: rebinding
# `dataset` to its first 3 columns broke any re-run of the X/YZ cells, whose
# yz_idx refers to columns 3-5.
xz_dataset = dataset[:, :3]
joint_data, joint_label, marginal_data, marginal_label = create_joint_marginal_dataset(xz_dataset, x_idx, z_idx)
data = np.concatenate([joint_data, marginal_data])
label = np.concatenate([joint_label, marginal_label])
assert len(data) == len(label), 'data/label length mismatch'
# Shuffle size comes from the data, not from an assumed 2 * num_of_samples.
randomize_idx = np.random.permutation(len(data))
data, label = data[randomize_idx], label[randomize_idx]

In [10]:
# Classifier hyper-parameters. Input width = number of X columns + number of Z
# columns (presumably the classifier sees them concatenated — confirm in train).
num_input_features = sum(len(cols) for cols in (x_idx, z_idx))
hidden_size_arr = [256] * 3  # hidden layer widths passed to the trainer
lr = 1e-3  # learning rate

In [11]:
num_of_outer_iteration = 20  # number of independent trials (fresh model each time)
num_of_inner_iteration = 50  # forwarded to train_binary_classifier_v2
batch_size = 4096

# Per-trial records: loss curves plus the three MI estimates (ldr, dv, nwj).
outer_running_loss = []
outer_running_loss_avg = []
ldr_estimations = []
dv_estimations = []
nwj_estimations = []

for outer_iter in range(num_of_outer_iteration):
    print('################################################################')
    model, inner_running_loss, inner_running_loss_avg, num_of_joint, num_of_marginal = train_binary_classifier_v2(
        data, label, num_input_features, hidden_size_arr, lr,
        num_of_inner_iteration, batch_size, outer_iter,
        save_avg=200, print_progress=True,
    )
    outer_running_loss.append(inner_running_loss)
    outer_running_loss_avg.append(inner_running_loss_avg)

    # Estimate MI between the X and Z columns from the trained classifier.
    # (This is an MI estimate, not a CMI; presumably the CMI is obtained as
    # I(X; YZ) - I(X; Z) via the chain rule — confirm.)
    curr_ldr, curr_dv, curr_nwj = estimate_mi_for_binary_classification(model, joint_data, num_of_joint, marginal_data, num_of_marginal)
    ldr_val, dv_val, nwj_val = curr_ldr.item(), curr_dv.item(), curr_nwj.item()
    print(f'trial: {outer_iter + 1}, ldr: {ldr_val}, dv: {dv_val}, nwj: {nwj_val}')
    print('################################################################\n')
    ldr_estimations.append(ldr_val)
    dv_estimations.append(dv_val)
    nwj_estimations.append(nwj_val)

# Average each estimator over all trials.
mean_ldr = np.mean(ldr_estimations)
mean_dv = np.mean(dv_estimations)
mean_nwj = np.mean(nwj_estimations)
print(f'final estimations:\n\tldr: {mean_ldr}\n\tdv: {mean_dv}\n\tnwj: {mean_nwj}')

################################################################
trial: 1, epoch, 1, iter: 1, curr loss: 0.693889856338501, avg loss: 0.693889856338501
trial: 1, epoch, 2, iter: 1, curr loss: 0.5279072523117065, avg loss: 0.5279072523117065
trial: 1, epoch, 3, iter: 1, curr loss: 0.5308347940444946, avg loss: 0.5308347940444946
trial: 1, epoch, 4, iter: 1, curr loss: 0.5243828296661377, avg loss: 0.5243828296661377
trial: 1, epoch, 5, iter: 1, curr loss: 0.530073881149292, avg loss: 0.530073881149292
trial: 1, epoch, 6, iter: 1, curr loss: 0.519391655921936, avg loss: 0.519391655921936
trial: 1, epoch, 7, iter: 1, curr loss: 0.5324214696884155, avg loss: 0.5324214696884155
trial: 1, epoch, 8, iter: 1, curr loss: 0.5272299647331238, avg loss: 0.5272299647331238
trial: 1, epoch, 9, iter: 1, curr loss: 0.5247963666915894, avg loss: 0.5247963666915894
trial: 1, epoch, 10, iter: 1, curr loss: 0.5254008173942566, avg loss: 0.5254008173942566
trial: 1, epoch, 11, iter: 1, curr loss: 0.5268428