In [1]:
import numpy as np
from probabilistic_classifier.knn_sampling.cmie_utils import batch_construction, train_classifier, estimate_CMI
from lcc.dataset import create_lcc_dataset_k1_t1_scalar

In [2]:
# dataset variables
# idx_* are positions in the list of column pieces built in the experiment
# cell (passed to batch_construction); dim_* are the matching feature widths.
idx_x, idx_y, idx_z = [0], [2], [1]
dim_x, dim_y, dim_z = 2, 3, 1
seed = 123
num_of_samples = 10000

# Seed numpy's global RNG once, up front, so Restart & Run All is repeatable.
# Previously `seed` was only handed to train_classifier, leaving dataset
# creation and batch construction unseeded.
# NOTE(review): assumes create_lcc_dataset_k1_t1_scalar / batch_construction
# draw from numpy's global RNG — confirm; seed torch / `random` too if they
# use those instead.
np.random.seed(seed)

# model variables
k_neighbor = 2                      # passed as K_neighbor to batch_construction
lr = 0.002
tau_clipping = 0.0001
batch_size = num_of_samples // 2    # passed as set_size to batch_construction
trial = 20                          # trials (fresh batch + fresh classifier) per repetition
repetition = 10                     # repetitions (fresh dataset each)
epoch = 300
input_size = dim_x + dim_y + dim_z  # first entry of Params for train_classifier
hidden_size = 64
num_classes = 2

# dresfl variables (arguments to create_lcc_dataset_k1_t1_scalar)
prime = 5
data_range = 2
weight = 1

In [3]:
# Experiment: repeated CMI estimation.
# Each repetition draws a fresh dataset; each trial within it builds a fresh
# batch, trains a fresh classifier and records the first three values returned
# by estimate_CMI (printed as ldr / dv / nwj). The per-repetition mean of each
# is appended to last_ldr_arr / last_dv_arr / last_nwj_arr.
# NOTE(review): in the recorded run the dv and nwj estimates diverge on some
# trials (e.g. dv ~ -5.4, nwj ~ -2285), so plain means over trials are
# dominated by those outliers.
# NOTE(review): every trial passes the same Seed to train_classifier; if that
# reseeds a global RNG, trials may be less independent than intended — confirm.

# np.split(dataset, [2, 3, 4], axis=1) yields the column pieces
# [:, 0:2], [:, 2:3], [:, 3:4] and [:, 4:]; piece 2 is discarded and the
# remaining three are addressed by idx_x / idx_y / idx_z (positions in the
# kept list).
# NOTE(review): the split points are hard-coded and only line up with
# dim_x=2, dim_z=1 — confirm they still match if dims or the generator change.
split_points = [2, 3, 4]
dropped_piece = 2

last_ldr_arr, last_dv_arr, last_nwj_arr = [], [], []
for r in range(repetition):
    ldr_arr, dv_arr, nwj_arr = [], [], []
    dataset = create_lcc_dataset_k1_t1_scalar(prime, data_range, num_of_samples, weight)
    # The split depends only on `dataset`, so do it once per repetition
    # instead of re-splitting on every trial.
    pieces = [piece for i, piece in enumerate(np.split(dataset, split_points, axis=1))
              if i != dropped_piece]
    for t in range(trial):
        # Hand over a fresh list each trial (as before) in case
        # batch_construction mutates the list it receives.
        train_data, train_label, joint_test, prod_test = batch_construction(
            list(pieces), [idx_x, idx_y, idx_z], set_size=batch_size, K_neighbor=k_neighbor)
        model, loss_e = train_classifier(
            BatchTrain=train_data, TargetTrain=train_label,
            Params=(input_size, hidden_size, num_classes, tau_clipping),
            Epoch=epoch, Lr=lr, Seed=seed)
        curr_cmi = estimate_CMI(model, joint_test, prod_test)
        curr_ldr, curr_dv, curr_nwj = curr_cmi[:3]
        ldr_arr.append(curr_ldr)
        dv_arr.append(curr_dv)
        nwj_arr.append(curr_nwj)
        print('repetition: {}, trial: {}, ldr: {}, dv: {}, nwj: {}'.format(r, t, curr_ldr, curr_dv, curr_nwj))
    last_ldr_arr.append(np.mean(ldr_arr))
    last_dv_arr.append(np.mean(dv_arr))
    last_nwj_arr.append(np.mean(nwj_arr))

repetition: 0, trial: 0, ldr: 2.3450300958710937, dv: 2.6460022975763007, nwj: 2.6049317499383413
repetition: 0, trial: 1, ldr: 2.3310493775734678, dv: 2.645934265345092, nwj: 2.6011764908455124
repetition: 0, trial: 2, ldr: 2.3351973608304135, dv: -5.400597291345625, nwj: -2285.491605651496
repetition: 0, trial: 3, ldr: 1.9728411848598615, dv: -4.788370591318611, nwj: -860.7153174762351
repetition: 0, trial: 4, ldr: 3.635979545365451, dv: 4.338631660765911, nwj: 4.140709498228726
repetition: 0, trial: 5, ldr: 2.330182324176143, dv: -5.382983449877607, nwj: -2234.2846578596755
repetition: 0, trial: 6, ldr: 3.6822268218377707, dv: -3.9147440214174716, nwj: -1987.469977973213
repetition: 0, trial: 7, ldr: 2.388358257211677, dv: -5.324525921155103, nwj: -2233.5964679473173
repetition: 0, trial: 8, ldr: 1.7581070239795955, dv: 2.033005167579745, nwj: 1.9984575295324183
repetition: 0, trial: 9, ldr: 1.8722081058109823, dv: 2.2086063931603914, nwj: 2.1578695686637546
repetition: 0, trial: 10

KeyboardInterrupt: 

In [5]:
# Mean of the per-repetition DV estimates collected by the experiment cell.
# Guard against an empty list: if the experiment was interrupted before any
# repetition finished, np.mean([]) silently yields nan (with only a
# RuntimeWarning), and any previously displayed value would be stale.
assert last_dv_arr, "no repetition has completed; run the experiment cell to completion first"
np.mean(last_dv_arr)

0.9499470167851358