In [1]:
import numpy as np
from probabilistic_classifier.knn_sampling.cmie_utils import batch_construction, train_classifier, estimate_CMI

In [2]:
# dataset variables
idx_x, idx_y, idx_z = [0], [1], [2]
dim_x, dim_y, dim_z = len(idx_x), len(idx_y), len(idx_z)
seed = 123
num_of_samples = 80000


# model variables
k_neighbor = 2
lr = 0.002
tau_clipping = 0.0001
batch_size = num_of_samples // 2 
trial = 20
repetition = 10
epoch = 300
input_size = dim_x + dim_y + dim_z
hidden_size = 64
num_classes = 2

In [3]:
last_ldr_arr, last_dv_arr, last_nwj_arr = [], [], []
for r in range(repetition):
    ldr_arr, dv_arr, nwj_arr = [], [], []
    dataset = np.random.multivariate_normal([0, 0, 0], [[1, 0.8, 0.5],[0.8, 1, 0], [0.5, 0, 1]], num_of_samples)
    for t in range(trial):
        train_data, train_label, joint_test, prod_test = batch_construction(np.split(dataset, [1, 2], axis=1), [idx_x, idx_y, idx_z], set_size=batch_size, K_neighbor=k_neighbor)
        model, loss_e = train_classifier(BatchTrain=train_data, TargetTrain=train_label, Params=(input_size, hidden_size, num_classes, tau_clipping), Epoch=epoch, Lr=lr, Seed=seed)
        curr_cmi = estimate_CMI(model, joint_test, prod_test)
        curr_ldr, curr_dv, curr_nwj = curr_cmi[0], curr_cmi[1], curr_cmi[2]
        ldr_arr.append(curr_ldr)
        dv_arr.append(curr_dv)
        nwj_arr.append(curr_nwj)
        print('repetition: {}, trial: {}, ldr: {}, dv: {}, nwj: {}'.format(r, t, curr_ldr, curr_dv, curr_nwj))
    last_ldr_arr.append(np.mean(ldr_arr))
    last_dv_arr.append(np.mean(dv_arr))
    last_nwj_arr.append(np.mean(nwj_arr))

repetition: 0, trial: 0, ldr: 0.9802377833714149, dv: 0.9502484711011732, nwj: 0.9497942625786322
repetition: 0, trial: 1, ldr: 0.909350935771365, dv: 0.9485046495389715, nwj: 0.9477480495884961
repetition: 0, trial: 2, ldr: 0.9558629481955752, dv: 0.961390795469287, nwj: 0.9613755450351763
repetition: 0, trial: 3, ldr: 0.9358579732121154, dv: 0.96934856641674, nwj: 0.9687939650502593
repetition: 0, trial: 4, ldr: 0.9652745062834746, dv: 0.9511342544581987, nwj: 0.9510338082106908
repetition: 0, trial: 5, ldr: 0.9745873562886084, dv: 0.9446866398603218, nwj: 0.9442351244631664
repetition: 0, trial: 6, ldr: 0.9133785034175037, dv: 0.9634575230632662, nwj: 0.9622242417592428
repetition: 0, trial: 7, ldr: 0.9704764450832558, dv: 0.9517378287832511, nwj: 0.9515611591231856
repetition: 0, trial: 8, ldr: 0.9267344915579119, dv: 0.9422410536055352, nwj: 0.9421214459050741
repetition: 0, trial: 9, ldr: 0.9565741492497443, dv: 0.9387514455279571, nwj: 0.9385916733647719
repetition: 0, trial: 10

In [5]:
np.mean(last_dv_arr)

0.9499470167851358