In [12]:
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.nn.functional import one_hot

from probabilistic_classifier.dataset import create_knn_sampling_joint_cond_marginal_dataset
from probabilistic_classifier.probabilistic_classifier import ProbabilisticClassifier

In [13]:
# Create the basis dataset: samples drawn from a zero-mean 3-D Gaussian.
# NumPy's global RNG is seeded first so that this draw (and every later
# np.random call that follows it in execution order) is reproducible under
# Restart Kernel -> Run All.
random_seed = 42
np.random.seed(random_seed)

mean = [0, 0, 0]
# Covariance: column 0 is correlated with column 1 (0.8) and with column 2 (0.5);
# columns 1 and 2 are uncorrelated with each other.
cov = [[1, 0.8, 0.5],
       [0.8, 1, 0],
       [0.5, 0, 1]]
num_of_samples = 10000
dataset = np.random.multivariate_normal(mean=mean, cov=cov, size=num_of_samples)

In [14]:
# Create the joint and marginal datasets, then merge and shuffle them into a
# single (data, label) pair for classifier training.
num_of_neighbors = 4
# Column indices of x, y and z inside `dataset`. They are passed to the helper
# below by name so the dataset construction cannot drift from any other code
# that reads the same index lists.
x_idx, y_idx, z_idx = [0], [1], [2]
# NOTE(review): the second positional argument is the literal 2 while
# `num_of_neighbors` (= 4) above is never passed. The helper's signature is not
# visible here -- confirm which value is intended. Kept as 2 so that results
# do not change.
joint_data, joint_label, marginal_data, marginal_label = create_knn_sampling_joint_cond_marginal_dataset(
    dataset, 2, x_idx, y_idx, z_idx)
data = np.concatenate([joint_data, marginal_data])
label = np.concatenate([joint_label, marginal_label])
# Shuffle over the actual number of concatenated rows instead of assuming the
# helper returned exactly `num_of_samples` rows for each of the two sets.
randomize_idx = np.random.permutation(len(data))
data, label = data[randomize_idx], label[randomize_idx]

In [15]:
# Model and optimiser hyper-parameters.
# The input width is the total number of selected x, y and z columns.
num_input_features = sum(len(indices) for indices in (x_idx, y_idx, z_idx))
hidden_size_arr = [64]
num_output_features = 2
lr = 0.001
# Prefer the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [16]:
# Experiment schedule: number of independent trials, optimisation steps per
# trial, and minibatch size.
num_of_outer_iteration, num_of_inner_iteration, batch_size = 1, 80000, 4096

# Histories collected across trials (one entry appended per trial).
outer_running_loss, outer_running_loss_avg = [], []
ldr_estimations, dv_estimations, nwj_estimations = [], [], []

# Train a fresh classifier per trial to separate joint samples (label 0) from
# marginal samples (label 1), then turn its outputs into three estimates built
# from the per-sample density ratio: LDR, DV and NWJ.

# Convert the full dataset to tensors once; every training step only indexes
# into these instead of re-converting a NumPy slice.
data_tensor = torch.from_numpy(data).to(torch.float32)
label_tensor = torch.from_numpy(label).to(torch.long)
# Sample from the real number of rows rather than assuming 2 * num_of_samples.
num_of_total_samples = data_tensor.size(0)

for outer_iter in range(num_of_outer_iteration):
    model = ProbabilisticClassifier(num_input_features, hidden_size_arr, num_output_features).to(device)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = Adam(model.parameters(), lr=lr)

    ## train classifier
    model.train()
    inner_running_loss = []
    inner_running_loss_avg = []
    curr_inner_running_loss_avg = 0

    # Running counts of how many samples of each class were seen in the
    # minibatches; their ratio corrects the density ratio below.
    num_of_joint = 0
    num_of_marginal = 0
    for inner_iter in range(num_of_inner_iteration):
        selected_samples = np.random.choice(num_of_total_samples, batch_size, replace=False)
        selected_samples = torch.from_numpy(selected_samples).to(torch.long)
        batch_data, batch_label = data_tensor[selected_samples], label_tensor[selected_samples]

        # Non-zero labels are marginal samples; count them once per batch.
        marginal_in_batch = torch.count_nonzero(batch_label)
        num_of_marginal += marginal_in_batch
        num_of_joint += batch_size - marginal_in_batch

        # BCEWithLogitsLoss needs float targets with the same shape as the logits.
        batch_label = one_hot(batch_label, num_classes=num_output_features).to(torch.float32)
        batch_data, batch_label = batch_data.to(device), batch_label.to(device)

        optimizer.zero_grad()
        logits = model(batch_data)
        loss = criterion(logits, batch_label)
        loss.backward()
        optimizer.step()

        # Read the scalar once: every .item() call forces a host/device sync.
        loss_value = loss.item()
        curr_inner_running_loss_avg += loss_value

        if inner_iter == 0 or ((inner_iter + 1) % 100) == 0:
            avg_loss = curr_inner_running_loss_avg / (inner_iter + 1)
            print('trial: {}, iter: {}, curr loss: {}, avg loss: {}'.format(outer_iter + 1, inner_iter + 1, loss_value, avg_loss))
            inner_running_loss_avg.append(avg_loss)

        inner_running_loss.append(loss_value)
    outer_running_loss.append(inner_running_loss)
    outer_running_loss_avg.append(inner_running_loss_avg)

    ## estimate cmi
    model.eval()
    with torch.no_grad():
        model = model.to('cpu')
        estimated_logits_for_joint = model(torch.from_numpy(joint_data).to(torch.float32))
        estimated_logits_for_marginal = model(torch.from_numpy(marginal_data).to(torch.float32))
        # Per-class probabilities, kept so they remain available for inspection
        # after the cell has run.
        joint_cond_prob = torch.sigmoid(estimated_logits_for_joint)
        marginal_cond_prob = torch.sigmoid(estimated_logits_for_marginal)
        class_distribution = num_of_marginal / num_of_joint
        log_class_distribution = torch.log(class_distribution)

        # log(sigmoid(a) / sigmoid(b)) is evaluated as logsigmoid(a) - logsigmoid(b):
        # the same quantity, but it stays finite when a sigmoid saturates to 0
        # in float32, where the direct division would give inf or nan.
        log_ratio_joint = (nn.functional.logsigmoid(estimated_logits_for_joint[:, 0])
                           - nn.functional.logsigmoid(estimated_logits_for_joint[:, 1]))
        log_ratio_marginal = (nn.functional.logsigmoid(estimated_logits_for_marginal[:, 0])
                              - nn.functional.logsigmoid(estimated_logits_for_marginal[:, 1]))

        # Log density ratio on joint samples, plain density ratio on marginal samples.
        pointwise_dependency_joint = log_ratio_joint + log_class_distribution
        pointwise_dependency_marginal = torch.exp(log_ratio_marginal + log_class_distribution)

        mean_log_ratio_joint = pointwise_dependency_joint.mean()
        mean_ratio_marginal = pointwise_dependency_marginal.mean()

        # LDR: mean log-ratio over joint samples.
        curr_ldr = mean_log_ratio_joint
        # DV:  mean log-ratio over joint samples - log(mean ratio over marginal samples).
        curr_dv = mean_log_ratio_joint - torch.log(mean_ratio_marginal)
        # NWJ: mean log-ratio over joint samples - mean ratio over marginal samples + 1.
        curr_nwj = mean_log_ratio_joint - mean_ratio_marginal + 1
        print('trial: {}, ldr: {}, dv: {}, nwj: {}'.format(outer_iter + 1, curr_ldr.item(), curr_dv.item(), curr_nwj.item()))
        ldr_estimations.append(curr_ldr.item())
        dv_estimations.append(curr_dv.item())
        nwj_estimations.append(curr_nwj.item())

trial: 1, iter: 1, curr loss: 0.692747950553894, avg loss: 0.692747950553894
trial: 1, iter: 100, curr loss: 0.5982009172439575, avg loss: 0.6452596783638
trial: 1, iter: 200, curr loss: 0.49811217188835144, avg loss: 0.5952605031430721
trial: 1, iter: 300, curr loss: 0.4440138339996338, avg loss: 0.5530812853574752
trial: 1, iter: 400, curr loss: 0.4260384440422058, avg loss: 0.523169476389885
trial: 1, iter: 500, curr loss: 0.42652058601379395, avg loss: 0.5028366670608521
trial: 1, iter: 600, curr loss: 0.4219977557659149, avg loss: 0.4887246939043204
trial: 1, iter: 700, curr loss: 0.41412353515625, avg loss: 0.4782590565936906
trial: 1, iter: 800, curr loss: 0.4160095453262329, avg loss: 0.4702806865051389
trial: 1, iter: 900, curr loss: 0.41736942529678345, avg loss: 0.46402409040265613
trial: 1, iter: 1000, curr loss: 0.4124114215373993, avg loss: 0.4590521534979343
trial: 1, iter: 1100, curr loss: 0.404771625995636, avg loss: 0.4549628913944418
trial: 1, iter: 1200, curr loss: 