In [None]:
import javabridge
import os
import glob
import pandas as pd
import pydot
from IPython.display import SVG

In [None]:
# Register every jar found in src/pycausal/lib (two levels above the working
# directory) with javabridge so the classes are available once the JVM starts.
tetrad_libdir = os.path.join(os.getcwd(), '../../','src', 'pycausal', 'lib')
# glob returns files in arbitrary order; sorting keeps the class path order
# reproducible between runs.
for jar_path in sorted(glob.glob(os.path.join(tetrad_libdir, "*.jar"))):
    # Parenthesised single-argument print behaves the same on Python 2 and 3.
    print(jar_path)
    javabridge.JARS.append(str(jar_path))

In [None]:
# Start the JVM through javabridge in headless mode with a maximum heap of
# '100M', then attach to it.
# NOTE(review): javabridge.JARS is filled in by an earlier cell; presumably
# that has to happen before start_vm for the jars to be picked up — confirm.
javabridge.start_vm(run_headless=True, max_heap_size = '100M')
javabridge.attach()

In [None]:
# Path of the audiology dataset, built relative to the working directory.
# Despite its name, data_dir points at a file rather than a directory.
data_dir = os.path.join(os.getcwd(), '../../', os.path.join('data', 'audiology.txt'))
data_dir

In [None]:
# Parse the tab-separated dataset into a DataFrame and preview its first rows.
dframe = pd.read_csv(data_dir, sep="\t")
dframe.head()

In [None]:
# Allocate a Tetrad integer data box with one row per DataFrame row and one
# column per DataFrame column.
VerticalIntDataBox = javabridge.JClassWrapper("edu.cmu.tetrad.data.VerticalIntDataBox")
n_rows, n_cols = dframe.shape
dataBox = VerticalIntDataBox(n_rows, n_cols)

In [None]:
# Convert the DataFrame into Tetrad's discrete representation:
#   * node_list receives one DiscreteVariable per column, whose categories are
#     the column's distinct values in sorted order (as strings);
#   * dataBox receives, for every cell, the integer position of the cell's
#     value within that sorted category list.
# Build the Java class wrappers once instead of once per cell.
ArrayList = javabridge.JClassWrapper("java.util.ArrayList")
DiscreteVariable = javabridge.JClassWrapper("edu.cmu.tetrad.data.DiscreteVariable")
JInteger = javabridge.JClassWrapper("java.lang.Integer")

node_list = ArrayList()
# load dataset
for col_no, col in enumerate(dframe.columns):
    column_values = dframe[col]

    cat_array = sorted(set(column_values))
    # value -> category code; replaces a linear cat_array.index() scan per cell
    cat_codes = {cat: code for code, cat in enumerate(cat_array)}

    cat_list = ArrayList()
    for cat in cat_array:
        cat_list.add(str(cat))

    nodi = DiscreteVariable(col, cat_list)
    node_list.add(nodi)

    # Walk the column by position rather than through the removed
    # DataFrame.ix accessor, so the data box row is always the 0-based row
    # number even when the DataFrame index is not 0..n-1.
    for row_no, cell in enumerate(column_values):
        value = JInteger(cat_codes[cell])
        dataBox.set(row_no, col_no, value)

In [None]:
# Wrap the filled data box and its variable list in a Tetrad BoxDataSet.
BoxDataSet = javabridge.JClassWrapper("edu.cmu.tetrad.data.BoxDataSet")
boxData = BoxDataSet(dataBox, node_list)

In [None]:
# Build a BDeu score over the dataset and set both of its priors to 1.0.
BDeuScore = javabridge.JClassWrapper("edu.cmu.tetrad.search.BDeuScore")
score = BDeuScore(boxData)
score.setStructurePrior(1.0)
score.setSamplePrior(1.0)

In [None]:
# Create the FGES search object driven by the score built above.
Fges = javabridge.JClassWrapper("edu.cmu.tetrad.search.Fges")
fges = Fges(score)

In [None]:
# Configure the FGES search before it is run.
# NOTE(review): -1 presumably means "no limit on node degree" — confirm
# against the Tetrad documentation.
fges.setMaxDegree(-1)
fges.setNumPatternsToStore(0)
fges.setFaithfulnessAssumed(True)
# Parallelism is fixed at 2 here.
fges.setParallelism(2)
fges.setVerbose(True)

In [None]:
# Background knowledge handed to the search.
prior = javabridge.JClassWrapper('edu.cmu.tetrad.data.Knowledge2')()

# forbidden directed edges (pairs are passed to setForbidden in this order)
for first_var, second_var in [('history_noise', 'class'),
                              ('history_fluctuating', 'class')]:
    prior.setForbidden(first_var, second_var)

# Tier 0 is flagged as forbidden-within, then the three variables are added to it.
prior.setTierForbiddenWithin(0, True)
for tier_var in ['class', 'history_fluctuating', 'history_noise']:
    prior.addToTier(0, tier_var)

fges.setKnowledge(prior)
prior

In [None]:
# Run the configured FGES search and keep the result; the bare expression
# makes the result the cell's displayed output.
tetradGraph = fges.search()
tetradGraph

In [None]:
# Show the search result's string representation.
tetradGraph.toString()

In [None]:
# Show the node names of the search result.
tetradGraph.getNodeNames()

In [None]:
# Show the edges of the search result.
tetradGraph.getEdges()

In [None]:
# Create an empty directed pydot graph; the following cells add the nodes and
# edges of the search result to it.
graph = pydot.Dot(graph_type='digraph')

In [None]:
# Turn the node-name list's text form ("[a, b, c]") into:
#   n     - list of node names (stripped of surrounding whitespace)
#   nodes - list of pydot nodes, parallel to n
# and add every node to the pydot graph.
names_text = tetradGraph.getNodeNames().toString()
# Drop the enclosing brackets, then split on the separating commas.
n = [name.strip() for name in names_text[1:-1].split(",")]
# An empty list prints as "[]", which splits into a single blank entry;
# discard blanks so no nameless node is added to the graph.
n = [name for name in n if name]
nodes = [pydot.Node(name) for name in n]
for node in nodes:
    graph.add_node(node)

In [None]:
def isNodeExisting(nodes,node):
    """Report whether ``node`` occurs in ``nodes``.

    Args:
        nodes: list of known node names.
        node: the node name to look for.

    Returns:
        True when ``node`` is an element of ``nodes``; otherwise prints a
        diagnostic message and returns False.
    """
    # A membership test avoids relying on list.index(), which raises
    # ValueError (not IndexError) for a missing element.
    if node in nodes:
        return True
    # Format the name into the message; the parenthesised single-argument
    # print behaves the same on Python 2 and 3.
    print("Node %s does not exist!" % node)
    return False

# Parse the edge list's text form and add each edge to the pydot graph.
# Every entry is split on single spaces into source, arc and destination;
# entries with fewer than three tokens are skipped.
edges_text = tetradGraph.getEdges().toString()
e = [edge_text.strip() for edge_text in edges_text[1:len(edges_text)-1].split(",")]
for edge_text in e:
    token = edge_text.split(" ")
    if len(token) < 3:
        continue
    src, arc, dst = token[0], token[1], token[2]
    # Only draw edges whose two endpoints are both known nodes.
    if not (isNodeExisting(n,src) and isNodeExisting(n,dst)):
        continue
    edge = pydot.Edge(nodes[n.index(src)],nodes[n.index(dst)])
    # "---" arcs are drawn without an arrowhead.
    if arc == "---":
        edge.set_arrowhead("none")
    graph.add_edge(edge)

In [None]:
# Render the pydot graph to SVG with the 'dot' layout program and display it
# inline as the cell output.
svg_str = graph.create_svg(prog='dot')
SVG(svg_str)

In [None]:
# Detach from the JVM and shut it down.
# NOTE(review): javabridge documents that a JVM cannot be started again in the
# same process after kill_vm — presumably the kernel must be restarted before
# re-running the notebook; confirm.
javabridge.detach()
javabridge.kill_vm()