In [1]:
import numpy as np

from lcc.dataset import create_lcc_dataset_k1_t1_scalar

from probabilistic_classifier.dataset import create_joint_marginal_dataset
from probabilistic_classifier.estimate import estimate_mi_for_binary_classification
from probabilistic_classifier.train import train_binary_classifier_v2



In [2]:
# Basis dataset: configuration first, then a single generation call.
# The four values below are forwarded positionally, in this order, to
# `create_lcc_dataset_k1_t1_scalar`.
prime = 5                    # presumably the prime field modulus -- TODO confirm against lcc.dataset
data_range = 2
num_of_samples = 800000      # also reused by later cells when shuffling
weight = np.asarray([[1]])   # 1x1 integer array

dataset = create_lcc_dataset_k1_t1_scalar(
    prime,
    data_range,
    num_of_samples,
    weight,
)

pols created
dataset is created


In [3]:
# Build the classifier training set for the x-vs-(y, z) comparison.
# The index lists are presumably column indices into `dataset`
# (x -> 0-1, z -> 2, y -> 3-5) -- TODO confirm against create_joint_marginal_dataset.
x_idx, y_idx, z_idx = [0, 1], [3, 4, 5], [2]
yz_idx = z_idx + y_idx  # [2, 3, 4, 5], derived so it cannot drift from y_idx / z_idx
joint_data, joint_label, marginal_data, marginal_label = create_joint_marginal_dataset(dataset, x_idx, yz_idx)

# Stack joint rows first, marginal rows second, with labels in the same order.
data = np.concatenate([joint_data, marginal_data])
label = np.concatenate([joint_label, marginal_label])
assert len(data) == len(label), "data and label must stay row-aligned"

# Shuffle rows and labels with one shared permutation. The permutation is sized
# from the array itself rather than from `2 * num_of_samples`, so no row is
# dropped (or indexed out of range) if the helper returns a different row count.
randomize_idx = np.random.permutation(len(data))
data, label = data[randomize_idx], label[randomize_idx]

In [4]:
# Classifier hyper-parameters for the x-vs-(y, z) run.
lr = 1e-3                                  # passed as `lr` to the training helper
hidden_size_arr = [256] * 3                # presumably one width per hidden layer -- TODO confirm
num_input_features = len(x_idx + yz_idx)   # one input feature per selected x and yz index

In [5]:
# Repeat "train a fresh classifier, then estimate" several times and average.
num_of_outer_iteration = 20   # number of independent trials
num_of_inner_iteration = 50   # forwarded to the trainer (shown as "epoch" in its progress log)
batch_size = 4096

# Per-trial bookkeeping: loss histories returned by the trainer and the three
# estimates (ldr / dv / nwj) returned by the estimator.
outer_running_loss = []
outer_running_loss_avg = []
ldr_estimations = []
dv_estimations = []
nwj_estimations = []

for outer_iter in range(num_of_outer_iteration):
    print('################################################################')
    # `outer_iter` is forwarded to the trainer (its log prints a 1-based trial number).
    model, inner_running_loss, inner_running_loss_avg, num_of_joint, num_of_marginal = train_binary_classifier_v2(
        data, label, num_input_features, hidden_size_arr, lr,
        num_of_inner_iteration, batch_size, outer_iter,
        save_avg=200, print_progress=True,
    )
    outer_running_loss.append(inner_running_loss)
    outer_running_loss_avg.append(inner_running_loss_avg)

    # Estimate mutual information between x and (y, z) with the trained model.
    # This is a plain MI estimate; a conditional MI would only follow from
    # combining it with a second estimate.
    curr_ldr, curr_dv, curr_nwj = estimate_mi_for_binary_classification(model, joint_data, num_of_joint, marginal_data, num_of_marginal)
    # Extract each scalar once and reuse it for both logging and accumulation.
    ldr_value, dv_value, nwj_value = curr_ldr.item(), curr_dv.item(), curr_nwj.item()
    print('trial: {}, ldr: {}, dv: {}, nwj: {}'.format(outer_iter + 1, ldr_value, dv_value, nwj_value))
    print('################################################################\n')
    ldr_estimations.append(ldr_value)
    dv_estimations.append(dv_value)
    nwj_estimations.append(nwj_value)

# Keep the averaged result under a stage-specific name: the x-vs-z training cell
# further down rebinds `ldr_estimations` / `dv_estimations` / `nwj_estimations`,
# so without this the x-vs-(y, z) result would survive only in the printed output.
mi_x_yz = {
    'ldr': np.mean(ldr_estimations),
    'dv': np.mean(dv_estimations),
    'nwj': np.mean(nwj_estimations),
}
print('final estimations:\n\tldr: {}\n\tdv: {}\n\tnwj: {}'.format(mi_x_yz['ldr'], mi_x_yz['dv'], mi_x_yz['nwj']))

################################################################
trial: 1, epoch, 1, iter: 1, curr loss: 0.6940299868583679, avg loss: 0.6940299868583679
trial: 1, epoch, 1, iter: 200, curr loss: 0.44217681884765625, avg loss: 0.4565826053917408
trial: 1, epoch, 2, iter: 1, curr loss: 0.4305155277252197, avg loss: 0.4305155277252197
trial: 1, epoch, 2, iter: 200, curr loss: 0.43757250905036926, avg loss: 0.4301589366793632
trial: 1, epoch, 3, iter: 1, curr loss: 0.4406145215034485, avg loss: 0.4406145215034485
trial: 1, epoch, 3, iter: 200, curr loss: 0.41929250955581665, avg loss: 0.42972503215074537
trial: 1, epoch, 4, iter: 1, curr loss: 0.42699992656707764, avg loss: 0.42699992656707764
trial: 1, epoch, 4, iter: 200, curr loss: 0.42695939540863037, avg loss: 0.42881056413054464
trial: 1, epoch, 5, iter: 1, curr loss: 0.42707228660583496, avg loss: 0.42707228660583496
trial: 1, epoch, 5, iter: 200, curr loss: 0.42135727405548096, avg loss: 0.4281506833434105
trial: 1, epoch, 6, iter

In [6]:
# Build the classifier training set for the x-vs-z comparison.
# The index lists are presumably column indices into the dataset -- TODO confirm.
x_idx, y_idx, z_idx = [0, 1], [3, 4, 5], [2]
yz_idx = [2, 3, 4, 5]

# Work on a view of the first three columns under a NEW name. Rebinding
# `dataset` itself would discard columns 3-5, after which the earlier
# x-vs-(y, z) cell (which indexes those columns) could no longer be re-run.
dataset_xz = dataset[:, :3]
joint_data, joint_label, marginal_data, marginal_label = create_joint_marginal_dataset(dataset_xz, x_idx, z_idx)

# Stack joint rows first, marginal rows second, with labels in the same order.
data = np.concatenate([joint_data, marginal_data])
label = np.concatenate([joint_label, marginal_label])
assert len(data) == len(label), "data and label must stay row-aligned"

# Shuffle rows and labels with one shared permutation, sized from the array
# itself rather than from `2 * num_of_samples`.
randomize_idx = np.random.permutation(len(data))
data, label = data[randomize_idx], label[randomize_idx]

In [7]:
# Classifier hyper-parameters for the x-vs-z run.
# The input width is the number of x indices plus the number of z indices.
num_input_features = sum(len(idx) for idx in (x_idx, z_idx))
hidden_size_arr = [256 for _ in range(3)]  # presumably one width per hidden layer -- TODO confirm
lr = 1e-3                                  # passed as `lr` to the training helper

In [8]:
# Repeat "train a fresh classifier, then estimate" several times and average.
num_of_outer_iteration = 20   # number of independent trials
num_of_inner_iteration = 50   # forwarded to the trainer (shown as "epoch" in its progress log)
batch_size = 4096

# Per-trial bookkeeping: loss histories returned by the trainer and the three
# estimates (ldr / dv / nwj) returned by the estimator.
outer_running_loss = []
outer_running_loss_avg = []
ldr_estimations = []
dv_estimations = []
nwj_estimations = []

for outer_iter in range(num_of_outer_iteration):
    print('################################################################')
    # `outer_iter` is forwarded to the trainer (its log prints a 1-based trial number).
    model, inner_running_loss, inner_running_loss_avg, num_of_joint, num_of_marginal = train_binary_classifier_v2(
        data, label, num_input_features, hidden_size_arr, lr,
        num_of_inner_iteration, batch_size, outer_iter,
        save_avg=200, print_progress=True,
    )
    outer_running_loss.append(inner_running_loss)
    outer_running_loss_avg.append(inner_running_loss_avg)

    # Estimate mutual information between x and z with the trained model.
    # This is a plain MI estimate; a conditional MI would only follow from
    # combining it with a second estimate.
    curr_ldr, curr_dv, curr_nwj = estimate_mi_for_binary_classification(model, joint_data, num_of_joint, marginal_data, num_of_marginal)
    # Extract each scalar once and reuse it for both logging and accumulation.
    ldr_value, dv_value, nwj_value = curr_ldr.item(), curr_dv.item(), curr_nwj.item()
    print('trial: {}, ldr: {}, dv: {}, nwj: {}'.format(outer_iter + 1, ldr_value, dv_value, nwj_value))
    print('################################################################\n')
    ldr_estimations.append(ldr_value)
    dv_estimations.append(dv_value)
    nwj_estimations.append(nwj_value)

# Keep the averaged result under a stage-specific name: the accumulator lists
# above share their names with the earlier x-vs-(y, z) run, so the generic names
# alone do not say which comparison they currently hold.
mi_x_z = {
    'ldr': np.mean(ldr_estimations),
    'dv': np.mean(dv_estimations),
    'nwj': np.mean(nwj_estimations),
}
print('final estimations:\n\tldr: {}\n\tdv: {}\n\tnwj: {}'.format(mi_x_z['ldr'], mi_x_z['dv'], mi_x_z['nwj']))

################################################################
trial: 1, epoch, 1, iter: 1, curr loss: 0.6922325491905212, avg loss: 0.6922325491905212
trial: 1, epoch, 1, iter: 200, curr loss: 0.5293221473693848, avg loss: 0.5348994317650795
trial: 1, epoch, 2, iter: 1, curr loss: 0.5356400012969971, avg loss: 0.5356400012969971
trial: 1, epoch, 2, iter: 200, curr loss: 0.5210230946540833, avg loss: 0.5242860004305839
trial: 1, epoch, 3, iter: 1, curr loss: 0.5239117741584778, avg loss: 0.5239117741584778
trial: 1, epoch, 3, iter: 200, curr loss: 0.5202776193618774, avg loss: 0.5239950707554817
trial: 1, epoch, 4, iter: 1, curr loss: 0.5150713920593262, avg loss: 0.5150713920593262
trial: 1, epoch, 4, iter: 200, curr loss: 0.5196893215179443, avg loss: 0.5237035164237023
trial: 1, epoch, 5, iter: 1, curr loss: 0.5267345905303955, avg loss: 0.5267345905303955
trial: 1, epoch, 5, iter: 200, curr loss: 0.5248969793319702, avg loss: 0.5230937004089355
trial: 1, epoch, 6, iter: 1, curr l