In [38]:
from scipy.sparse.csgraph import minimum_spanning_tree
from scipy.sparse.csgraph import breadth_first_order
from scipy.special import logsumexp
import numpy as np
import csv
import pandas as pd
from itertools import combinations
import math

In [39]:
def joint_prob(df,alpha,rv1,rv2,val1,val2):

    joint_p = (alpha + len(df[(df[rv1] == val1) & (df[rv2] == val2)])) / (4 * alpha + len(df))

    return joint_p


In [40]:
def single_prob(df, alpha, rv, val):

    single_p = (2* alpha + len(df[(df[rv] == val)])) / (4 * alpha + len(df))

    return single_p

In [41]:
def mutual_info(df,rv1,rv2,alpha):

    summ = 0
    for i in [0, 1]:
        for j in [0, 1]:

            summ += joint_prob(df,alpha,rv1,rv2,i,j) * \
                    math.log(joint_prob(df,alpha,rv1,rv2,i,j) / single_prob(df,alpha,rv1,i) * single_prob(df,alpha,rv2,j))

    return summ

In [110]:
#read dataset
with open("binary_datasets/nltcs/nltcs.train.data", "r" ) as file:

    reader = csv.reader(file, delimiter=',')
    dataset = np.array(list(reader)).astype(np.float)

df = pd.DataFrame(dataset, columns = list(range(0, dataset.shape[1])))

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  dataset = np.array(list(reader)).astype(np.float)


In [165]:
from scipy.sparse.csgraph import minimum_spanning_tree

class BinaryCLT:
    def __init__(self, data, root: int = None, alpha: float = 0.01):

        if root == None:
            root = np.random.choice(data.shape[1])

        #calculate mutual infromation for all RV combinations and fill the table
        mi_table = np.zeros((data.shape[1],data.shape[1]))

        for xi in range(data.shape[1]):
            for xj in range(xi + 1, data.shape[1]):

                mi_val = mutual_info(data,xi,xj,alpha)
                mi_table[xi,xj] = mi_val
#                 mi_table[xj,xi] = mi_val
        
        #calculate the minimum spanning tree
#         print(mi_table)
        min_span_tree = minimum_spanning_tree(-mi_table)

        #get directed tree
        directed_tree = breadth_first_order(min_span_tree,0,directed = True, return_predecessors=True)
        self.tree_nodes = directed_tree[0]
        self.ancestors = directed_tree[1]
                
    def get_tree(self):

        tree_nodes = self.tree_nodes
        ancestors = self.ancestors

        return tree_nodes,ancestors

    def get_log_params(self):

        pass

    def log_prob(self, x, exhaustive: bool = False):

        pass

    def sample(self, n_samples: int):

        pass


In [166]:
binaryclt_obj = BinaryCLT(data = df)
tree_nodes, ancestors = binaryclt_obj.get_tree()
print(tree_nodes,ancestors)

[ 0 15] [-9999 -9999 -9999 -9999 -9999 -9999 -9999 -9999 -9999 -9999 -9999 -9999
 -9999 -9999 -9999     0]


In [167]:
#log params[i,j,k] Dx2x2
log_params = np.zeros((16,2,2))
alpha = 0.01
print(tree_nodes,ancestors)
root = 0
# tree_nodes = [0,1,2]
# ancestors = [-9999,0,1,-9999,-9999,-9999,-9999,-9999,
#              -9999,-9999,-9999,-9999,-9999,-9999,-9999,-9999]

for i in range(df.shape[1]):
    
    if (ancestors[i]==-9999) and (i != root): continue
    else:
        if i == root:
            log_params[i,0,0] = math.log(single_prob(df, alpha, i, 0))
            log_params[i,0,1] = math.log(single_prob(df, alpha, i, 1))
            log_params[i,1,0] = log_params[i,0,0]
            log_params[i,1,1] = log_params[i,0,1]
            
        else:
            for j in [0,1]:
                for k in [0,1]:

                    cond_p = joint_prob(df,alpha,i,ancestors[i],j,k) / single_prob(df, alpha, ancestors[i], k)
                    log_params[i,j,k] = math.log(cond_p)
            
print(log_params)


[ 0 15] [-9999 -9999 -9999 -9999 -9999 -9999 -9999 -9999 -9999 -9999 -9999 -9999
 -9999 -9999 -9999     0]
[[[-0.1580114  -1.92305371]
  [-0.1580114  -1.92305371]]

 [[ 0.          0.        ]
  [ 0.          0.        ]]

 [[ 0.          0.        ]
  [ 0.          0.        ]]

 [[ 0.          0.        ]
  [ 0.          0.        ]]

 [[ 0.          0.        ]
  [ 0.          0.        ]]

 [[ 0.          0.        ]
  [ 0.          0.        ]]

 [[ 0.          0.        ]
  [ 0.          0.        ]]

 [[ 0.          0.        ]
  [ 0.          0.        ]]

 [[ 0.          0.        ]
  [ 0.          0.        ]]

 [[ 0.          0.        ]
  [ 0.          0.        ]]

 [[ 0.          0.        ]
  [ 0.          0.        ]]

 [[ 0.          0.        ]
  [ 0.          0.        ]]

 [[ 0.          0.        ]
  [ 0.          0.        ]]

 [[ 0.          0.        ]
  [ 0.          0.        ]]

 [[ 0.          0.        ]
  [ 0.          0.        ]]

 [[-0.05382667 -0.52788

In [164]:
np.round(log_params,3)

array([[[-0.158, -1.923],
        [-0.158, -1.923]],

       [[-0.143, -1.116],
        [-2.019, -0.397]],

       [[-0.131, -1.022],
        [-2.099, -0.446]],

       [[ 0.   ,  0.   ],
        [ 0.   ,  0.   ]],

       [[ 0.   ,  0.   ],
        [ 0.   ,  0.   ]],

       [[ 0.   ,  0.   ],
        [ 0.   ,  0.   ]],

       [[ 0.   ,  0.   ],
        [ 0.   ,  0.   ]],

       [[ 0.   ,  0.   ],
        [ 0.   ,  0.   ]],

       [[ 0.   ,  0.   ],
        [ 0.   ,  0.   ]],

       [[ 0.   ,  0.   ],
        [ 0.   ,  0.   ]],

       [[ 0.   ,  0.   ],
        [ 0.   ,  0.   ]],

       [[ 0.   ,  0.   ],
        [ 0.   ,  0.   ]],

       [[ 0.   ,  0.   ],
        [ 0.   ,  0.   ]],

       [[ 0.   ,  0.   ],
        [ 0.   ,  0.   ]],

       [[ 0.   ,  0.   ],
        [ 0.   ,  0.   ]],

       [[ 0.   ,  0.   ],
        [ 0.   ,  0.   ]]])