# Complexity-Regularized Tree-Structured Partition for Mutual Information Estimation
#### Miguel Videla A. <br> Information and Decision Systems Group

Implementation of the algorithm described in the paper [Complexity-Regularized Tree-Structured Partition for Mutual Information Estimation](http://repositorio.uchile.cl/bitstream/handle/2250/125678/Silva_Jorge.pdf?sequence=1&isAllowed=y).

In [0]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
@authors: Miguel Videla and Mauricio Gonzalez
@organization: Information and Decision Systems Group
@date: 1/10/2020
@licence: GPL
@version: 1.0.0
@maintainer: Miguel Videla
@email: miguel.videla@ug.uchile.cl
@status: Development
"""

import time
import numpy as np
import multiprocessing as mp

class TSP():
  """Tree-structured partition (TSP) estimator of mutual information.

  The estimator grows a binary tree over the joint samples of X and Y
  (see `setObservations`), reads the empirical mutual information (EMI)
  from the tree leaves (see `emi`) and can optionally replace the tree by
  a complexity-regularized pruning of itself (see `regularize`).

  Parameters:
    l: exponent used in the stopping threshold and in the penalty scale;
      must lie strictly inside (0, 1/3).
    factor_kn: multiplicative factor of the stopping threshold
      `factor_kn * n ** (1 - l)` and of `bn = factor_kn * n ** (-l)`.
    alpha: weight of the complexity penalty used by `regularize`.
    num_workers: number of processes used by `regularize`.

  Raises:
    ValueError: if `l` is outside the open interval (0, 1/3).
  """

  def __init__(self, l=0.1666, factor_kn=0.05, alpha=0.00065, num_workers=1):
    if l >= 1/3 or l <= 0:
      raise ValueError("Parameter 'l' must belong to the interval (0, 1/3).")
    self.root = None       # root node of the tree; None until data is given
    self.data = None       # joint samples, X columns first, then Y columns
    self.dim_X = None      # number of leading columns of `data` that are X
    self.l = l
    self.factor_kn = factor_kn
    self.alpha = alpha
    self.num_workers = num_workers

  def __grow(self):
    """Build the full tree from `self.data` and compute every node size."""
    # A cell is split further only while it holds more samples than this.
    stopping_criterion = max(1, self.factor_kn * self.data.shape[0] ** (1 - self.l))
    self.root = TSPNode()
    self.root.grow(self.data, self.dim_X, self.data, self.data.shape[1], stopping_criterion)
    self.root.updateSize()

  def _penalty_scale(self):
    """Return the factor Cn that multiplies sqrt(tree size) in the penalty."""
    n, d = self.data.shape
    bn = self.factor_kn * n ** (-self.l)
    return 12 / bn * np.sqrt(((d + 1) * np.log(2) + d * np.log(n)) * 8 / n)

  def _adopt(self, optimal_tree):
    """Install `optimal_tree` as the new root and refresh its node sizes.

    Trees produced by pruning carry no size information, so sizes are
    recomputed here; otherwise `size()` would return None afterwards.
    If no candidate was selected (every cost compared false, e.g. NaN)
    the current tree is kept instead of being replaced by None.
    """
    if optimal_tree is None:
      return
    self.root = optimal_tree
    self.root.updateSize()

  def setObservations(self, X, Y):
    """Store the samples and grow the tree.

    Parameters:
      X: 2-D array whose rows are samples of the first variable.
      Y: 2-D array with the same number of rows as X.
    """
    self.data = np.concatenate((X, Y), axis=1)
    self.dim_X = X.shape[1]
    self.__grow()

  def emi(self):
    """Return the empirical mutual information of the current tree.

    Raises:
      TypeError: if no observations have been provided yet.
    """
    if self.root is None:
      raise TypeError("Observations not provided.")
    return self.root.getEMI()

  def regularize_old(self, alpha=1):
    """Sequential version of `regularize` with per-candidate logging.

    For every size k from 1 to the current tree size, the best pruning of
    size k is computed and the one minimizing
    `-EMI + alpha * Cn * sqrt(k)` replaces the current tree.

    Parameters:
      alpha: weight of the complexity penalty (independent of `self.alpha`).

    Raises:
      TypeError: if no observations have been provided yet.
    """
    if self.root is None:
      raise TypeError("Observations not provided.")
    Cn = self._penalty_scale()
    min_cost = np.inf
    optimal_tree = None
    start_global_time = time.time()
    for k in range(1, self.size() + 1):
      start_time = time.time()
      pruned_tree = self.root.prune(k)
      print('k: {} | Elapsed Time: {} [s] | Total Time: {} [s]'.format(k, time.time() - start_time, time.time() - start_global_time))
      candidate_emi = pruned_tree.getEMI()
      # The penalty must use the candidate's own size k: a penalty based on
      # the full tree size is the same constant for every candidate and
      # could never change which one wins.
      penalty = alpha * Cn * np.sqrt(k)
      cost = -candidate_emi + penalty
      if cost < min_cost:
        print('Optimal K: {} | EMI: {} | Regularizer {}'.format(k, candidate_emi, penalty))
        optimal_tree = pruned_tree
        min_cost = cost
    self._adopt(optimal_tree)

  def regularize(self):
    """Replace the tree by its complexity-regularized pruning.

    The best pruning of every size k (1 .. current size) is computed in a
    pool of `num_workers` processes; the candidate minimizing
    `-EMI + self.alpha * Cn * sqrt(k)` becomes the new tree.

    Raises:
      TypeError: if no observations have been provided yet.
    """
    if self.root is None:
      raise TypeError("Observations not provided.")
    Cn = self._penalty_scale()
    sizes = range(1, self.size() + 1)
    start_time = time.time()
    # The context manager releases the worker processes even if a task fails.
    with mp.Pool(self.num_workers) as pool:
      # `map` returns results in the same order as `sizes`.
      pruned_trees = pool.map(self.root.poolPrune, sizes)
    print('Prunning Time: {} [min]'.format((time.time() - start_time) / 60))
    min_cost = np.inf
    optimal_tree = None
    for k, pruned_tree in zip(sizes, pruned_trees):
      # Penalty depends on the candidate size k (see `regularize_old`).
      cost = -pruned_tree.getEMI() + self.alpha * Cn * np.sqrt(k)
      if cost < min_cost:
        optimal_tree = pruned_tree
        min_cost = cost
    self._adopt(optimal_tree)

  def size(self):
    """Return the size stored on the root node (its number of leaves)."""
    return self.root.size


class TSPNode():
  """Internal node of the tree-structured partition.

  Attributes (all set by the methods below):
    leftNode, rightNode: child nodes.
    emi: contribution of this node's cell to the empirical mutual
      information, as computed by `partitionEMI`.
    emp_joint_dist: fraction of all samples that fall in this node's cell.
    size: number of leaves below this node; None until `updateSize` runs.
  """

  def __init__(self, leftNode=None, rightNode=None, emi=None, emp_joint_dist=None):
    self.leftNode = leftNode
    self.rightNode = rightNode
    self.emi = emi
    self.emp_joint_dist = emp_joint_dist
    self.size = None

  def grow(self, partition_data, dim_X, data, dim, grow_thresh, proj_dim=0):
    """Recursively split `partition_data` at the median of one coordinate.

    Parameters:
      partition_data: samples (rows) that belong to this node's cell.
      dim_X: number of leading columns that belong to X.
      data: the complete sample set.
      dim: total number of columns; the split coordinate is
        `proj_dim % dim`, so coordinates are cycled with depth.
      grow_thresh: a half is split further only if it holds more samples
        than this value; otherwise it becomes a terminal node.
      proj_dim: depth counter selecting the split coordinate.
    """
    self.updateEMI(partition_data, data, dim_X)
    order = np.argsort(partition_data[:, proj_dim % dim])
    median_idx = order.shape[0] // 2
    lower_half = partition_data[order[:median_idx]]
    upper_half = partition_data[order[median_idx:]]
    self.leftNode = self._grow_child(lower_half, dim_X, data, dim, grow_thresh, proj_dim + 1)
    self.rightNode = self._grow_child(upper_half, dim_X, data, dim, grow_thresh, proj_dim + 1)

  def _grow_child(self, child_data, dim_X, data, dim, grow_thresh, proj_dim):
    """Build the child holding `child_data`: a subtree or a terminal node."""
    if child_data.shape[0] > grow_thresh:
      child = TSPNode()
      child.grow(child_data, dim_X, data, dim, grow_thresh, proj_dim)
    else:
      # The terminal node starts empty; its statistics are filled in by
      # updateEMI (the constructor arguments are emi/emp_joint_dist, not data).
      child = TSPTerminalNode()
      child.updateEMI(child_data, data, dim_X)
    return child

  def empirical_marginal_distribution(self, partition_data, data, dim_array):
    """Return the fraction of `data` inside the bounding box of a cell.

    The box is the [min, max] range of `partition_data` along every
    coordinate listed in `dim_array`; other coordinates are ignored.
    """
    cols = list(dim_array)
    cell = partition_data[:, cols]
    lower = cell.min(axis=0)
    upper = cell.max(axis=0)
    projected = data[:, cols]
    inside = ((projected >= lower) & (projected <= upper)).all(axis=1)
    return inside.sum() / data.shape[0]

  def partitionEMI(self, partition_data, data, dim_X):
    """Return p * log2(p / (p_X * p_Y)) for the cell holding `partition_data`.

    p is the fraction of samples in the cell; p_X and p_Y are the
    empirical marginals of the cell over the X and Y coordinates.
    """
    emp_joint_dist = partition_data.shape[0] / data.shape[0]
    marginal_dist_X = self.empirical_marginal_distribution(partition_data, data, range(dim_X))
    marginal_dist_Y = self.empirical_marginal_distribution(partition_data, data, range(dim_X, data.shape[1]))
    return emp_joint_dist * np.log2(emp_joint_dist / (marginal_dist_X * marginal_dist_Y))

  def getEMI(self):
    """Return the sum of the EMI of both subtrees (i.e. of all leaves below)."""
    return self.leftNode.getEMI() + self.rightNode.getEMI()

  def cmi_gain(self):
    """Return the score used by `prune` to compare candidate splits.

    Each child's EMI is weighted by the ratio between this node's and the
    child's empirical joint probability, and the two terms are added.
    """
    node_prob = self.emp_joint_dist
    left_cmi = (node_prob / self.leftNode.emp_joint_dist) * self.leftNode.getEMI()
    right_cmi = (node_prob / self.rightNode.emp_joint_dist) * self.rightNode.getEMI()
    return left_cmi + right_cmi

  def poolPrune(self, size):
    """Run `prune(size)`, print the elapsed time and return the result."""
    start_time = time.time()
    pruned_tree = self.prune(size)
    print('k: {} | Elapsed Time: {} [s]'.format(size, time.time() - start_time))
    return pruned_tree

  def prune(self, size):
    """Return the pruning of this subtree with `size` leaves and maximal gain.

    For `size == 1` the node collapses into a terminal node carrying its
    own statistics. Otherwise every feasible split of `size` between the
    two children is tried and the candidate with the largest `cmi_gain`
    is returned. Requires node sizes to be set (see `updateSize`).
    Returns None when no split is feasible (size larger than the subtree).
    """
    if size == 1:
      return TSPTerminalNode(self.emi, self.emp_joint_dist)
    # -np.inf instead of np.NINF: the alias was removed in NumPy 2.0.
    max_cmi = -np.inf
    max_cmi_tree = None
    for i in range(1, size):
      j = size - i
      if (i <= self.leftNode.size) and (j <= self.rightNode.size):
        pruned_tree = TSPNode(self.leftNode.prune(i), self.rightNode.prune(j),
                              self.emi, self.emp_joint_dist)
        cmi = pruned_tree.cmi_gain()
        if cmi > max_cmi:
          max_cmi = cmi
          max_cmi_tree = pruned_tree
    return max_cmi_tree

  def updateSize(self):
    """Recompute, store and return the number of leaves below this node."""
    size = self.leftNode.updateSize() + self.rightNode.updateSize()
    self.size = size
    return size

  def updateEMI(self, partition_data, data, dim_X):
    """Store this node's EMI term and empirical joint probability."""
    self.emi = self.partitionEMI(partition_data, data, dim_X)
    self.emp_joint_dist = partition_data.shape[0] / data.shape[0]


class TSPTerminalNode(TSPNode):
  """Leaf of the partition tree: it has no children and always counts as size 1."""

  def __init__(self, emi=None, emp_joint_dist=None):
    # A leaf never has children, so both child slots stay None.
    super().__init__(None, None, emi, emp_joint_dist)

  def grow(self, partition_data, dim_X, data, dim, grow_thresh, proj_dim=0):
    """Do nothing: a leaf is not split any further."""
    return

  def prune(self, size):
    """Return this leaf itself, whatever `size` is requested."""
    return self

  def updateSize(self):
    """Record and return the size of a leaf, which is 1."""
    self.size = 1
    return self.size

  def getEMI(self):
    """Return the EMI term stored on this leaf."""
    return self.emi

  def cmi_gain(self):
    """Always raise: the gain is only defined for nodes with children."""
    raise TypeError('TerminalNode can not compute conditional mutual information gain.')

In [0]:
# Data generator
def multivariate_gaussian_sampler(dim, corr_factor, n_samples, rng=None):
  '''
  Generates samples from two gaussian random variables X_a and X_b of
  dimension dim with Corr(X_a^i, X_b^j) = delta_ij * corr_factor, returning the
  theoretical mutual information MI(X_a, X_b) in bits.

  Parameters:
    dim: dimension of each of the two variables.
    corr_factor: correlation between matching coordinates of X_a and X_b;
      must lie in [-1, 1].
    n_samples: number of joint samples to draw.
    rng: optional random source exposing `multivariate_normal`
      (e.g. `np.random.default_rng(seed)` or `np.random.RandomState(seed)`).
      When None, the global `np.random` state is used.

  Returns:
    (X_joint, X_a, X_b, theoretical_mi) where X_joint has shape
    (n_samples, 2 * dim) and X_a / X_b are its first / last `dim` columns.

  Raises:
    ValueError: if corr_factor is outside [-1, 1].
  '''
  if abs(corr_factor) > 1:
    raise ValueError("Parameter 'corr_factor' must belong to the interval [-1, 1].")
  joint_mean = np.zeros(2 * dim)
  identity = np.identity(dim)
  cross_cov = corr_factor * identity
  # Block matrix [[I, c*I], [c*I, I]]; its diagonal is already all ones.
  joint_cov = np.concatenate([np.concatenate([identity, cross_cov], axis=1),
                              np.concatenate([cross_cov, identity], axis=1)],
                             axis=0)
  # slogdet works on the log scale, so the determinant cannot underflow to 0
  # (which would turn the result into inf) for large dim or |corr_factor|
  # close to 1.
  _, log_det = np.linalg.slogdet(joint_cov)
  theoretical_mi = -(1/2) * log_det / np.log(2)
  sampler = np.random if rng is None else rng
  X_joint = sampler.multivariate_normal(joint_mean, joint_cov, size=n_samples)
  X_a = X_joint[:, :dim]
  X_b = X_joint[:, dim:]
  return X_joint, X_a, X_b, theoretical_mi

In [0]:
# Test
# The experiment starts worker processes (TSP.regularize uses a process pool).
# Under the 'spawn' / 'forkserver' start methods the workers re-import this
# module, so the driver code must only run in the main process.
if __name__ == '__main__':
  # Experiment settings: 1-D X and Y with correlation 0.8.
  dim = 1
  corr = 0.8
  n_samples = 9442
  X_joint, X_a, X_b, mi = multivariate_gaussian_sampler(dim, corr, n_samples)

  # Grow the full tree and report its size.
  tsp = TSP(l=0.2, factor_kn=0.05, alpha=0.00065, num_workers=4)
  tsp.setObservations(X_a, X_b)
  print('TSP Size: {}'.format(tsp.size()))

  # Estimate from the full (unregularized) tree; only emi() is timed.
  start_time = time.time()
  emi = tsp.emi()
  print('Theoretical MI: {} | Estimated MI: {} | Elapsed Time: {} [sec]'.format(mi, emi, time.time() - start_time))

  # Estimate after complexity regularization; pruning and emi() are timed.
  start_time = time.time()
  tsp.regularize()
  emi = tsp.emi()
  print('Theoretical MI: {} | Regularized MI: {} | Elapsed Time: {} [min]'.format(mi, emi, (time.time() - start_time) / 60))