In [19]:
import networkx as nx
from castle.datasets import DAG, IIDSimulation

weighted_random_dag = DAG.erdos_renyi(n_nodes=10, n_edges=15, weight_range=(0.5, 2.0), seed=1)
dataset = IIDSimulation(W = weighted_random_dag, n = 2000, method = "linear", sem_type = "gauss")
true_dag, data = dataset.B, dataset.X
true_network = nx.DiGraph(true_dag)
pos = nx.spring_layout(true_network)

2023-12-27 17:04:28,246 - /Users/chunyuko/Documents/anaconda3/envs/py_3_11/lib/python3.11/site-packages/castle/datasets/simulator.py[line:270] - INFO: Finished synthetic dataset


In [3]:
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go

def plotly_network(network, pos, title = None):
    
    nodes = pd.DataFrame([{"x": pos[node][0], "y": pos[node][1], "label": node} for node in network.nodes()])
    edges = [ (pos[i[0]][0], pos[i[0]][1], pos[i[1]][0], pos[i[1]][1]) for i in network.edges]
    arrows = [ go.layout.Annotation(dict(x=x, y=y, ax=ax, ay=ay, xref="x", yref="y", text="", showarrow=True, axref="x", ayref="y", arrowhead=2, arrowwidth=1, arrowcolor="rgba(140,31,40, .8)", standoff = 10, startstandoff = 10)) for x, y, ax, ay in edges]

    figure = px.scatter(nodes, x = "x", y = "y", text = "label", width=600, height = 400).\
        update_layout(
            xaxis=dict(showgrid=False, title="", zeroline = False, showticklabels=False),
            yaxis=dict(showgrid=False, title="", zeroline = False, showticklabels=False),
            title=title,
            template="plotly_white",
            annotations = arrows).\
        update_traces(marker=dict(size=30, color = "rgba(242,242,242,1)"),
                      textfont = dict(color = "rgba(4,64,64,1)"),
                      hovertemplate = "%{text}")
    
    return figure

fig1 = plotly_network(true_network, pos, "True DAG")
fig1

In [4]:
from castle.algorithms import PC

pc = PC()
pc.learn(data)
pc_1_network = nx.DiGraph(np.array(pc.causal_matrix))

In [9]:
plotly_network(pc_1_network, pos, "CPDAG by PC")

In [6]:
from castle.metrics import MetricsDAG

print(MetricsDAG(pc.causal_matrix, true_dag).metrics)

{'fdr': 0.1429, 'tpr': 0.6667, 'fpr': 0.0741, 'shd': 6, 'nnz': 14, 'precision': 0.6667, 'recall': 0.6667, 'F1': 0.6667, 'gscore': 0.3333}



DataFrame.applymap has been deprecated. Use DataFrame.map instead.


DataFrame.applymap has been deprecated. Use DataFrame.map instead.


DataFrame.applymap has been deprecated. Use DataFrame.map instead.



In [8]:
from random import shuffle

data_test = pd.DataFrame(data)
label_text = list(data_test.columns)
pc = PC()
result_random_orders = []

for i in range(1000):
    shuffle(label_text)
    pc.learn(np.array(data_test[label_text]))
    result = MetricsDAG(pc.causal_matrix, true_dag).metrics
    result["order"] = ", ".join([str(i) for i in label_text])
    result_random_orders.append(result)



DataFrame.applymap has been deprecated. Use DataFrame.map instead.


DataFrame.applymap has been deprecated. Use DataFrame.map instead.


DataFrame.applymap has been deprecated. Use DataFrame.map instead.


DataFrame.applymap has been deprecated. Use DataFrame.map instead.


DataFrame.applymap has been deprecated. Use DataFrame.map instead.


DataFrame.applymap has been deprecated. Use DataFrame.map instead.


DataFrame.applymap has been deprecated. Use DataFrame.map instead.


DataFrame.applymap has been deprecated. Use DataFrame.map instead.


DataFrame.applymap has been deprecated. Use DataFrame.map instead.


DataFrame.applymap has been deprecated. Use DataFrame.map instead.


DataFrame.applymap has been deprecated. Use DataFrame.map instead.


DataFrame.applymap has been deprecated. Use DataFrame.map instead.


DataFrame.applymap has been deprecated. Use DataFrame.map instead.


DataFrame.applymap has been deprecated. Use DataFrame.map instead.


DataFrame.applymap has been depre

In [11]:
result_random_orders = pd.DataFrame(result_random_orders)
result_random_orders["id"] = [*result_random_orders.index]

In [12]:
colors = {"fpr": "rgba(242,102,139,.3)", "tpr": "rgba(3,166,136,.3)"}

px.scatter(result_random_orders, x = "id", y = ["tpr", "fpr"], width=800, height = 400, marginal_y = "histogram").\
    update_layout(
        xaxis=dict(title="Random Test Times", zeroline = False, tickformat = ",.0f"),
        yaxis=dict(title="", zeroline = False, tickformat = ".0%"),
        template="plotly_white", legend_title_text = "DAG Metrics").\
    for_each_trace(lambda t: t.update(name=t.name.upper()+"%", marker=dict(color=colors[t.name])))

In [13]:
update_orders = [i.split(", ") for i in result_random_orders.sort_values("tpr", ascending = False).head(10)["order"].tolist()]

In [16]:
update_order = update_orders[0]
title = "Best Order: " + " > ".join(update_order)
update_order = [int(i) for i in update_order]
nodes = pd.DataFrame([{"x": pos[node][0], "y": pos[node][1], "label": node} for node in true_network.nodes()])
edges = [(pos[i[0]][0], pos[i[0]][1], pos[i[1]][0], pos[i[1]][1]) for i in true_network.edges()]

arrows = [go.layout.Annotation(dict(
    x=x, y=y, ax=ax, ay=ay, xref="x", yref="y",
    text="", showarrow=True, axref="x", ayref="y",
    arrowhead=2, arrowwidth=1, arrowcolor="rgba(140,31,40, .8)",
    standoff=10, startstandoff=10)) for x, y, ax, ay in edges]

nodes["color"] = "rgba(242,242,242,1)"

fig = go.Figure(data=[go.Scatter(x=nodes["x"], y=nodes["y"], text=nodes["label"], mode="markers+text", 
                                 marker=dict(color=nodes["color"], size=30), textfont=dict(color="rgba(4,64,64,1)"))],
                layout=go.Layout(
                    xaxis=dict(showgrid=False, title="", zeroline=False, showticklabels=False),
                    yaxis=dict(showgrid=False, title="", zeroline=False, showticklabels=False),
                    template="plotly_white",
                    annotations=arrows))

frames = []
for step, node in enumerate(update_order):
    nodes.loc[nodes["label"] == node, "color"] = "rgba(218,253,186,1)"
    frame = go.Frame(data=[go.Scatter(x=nodes["x"], y=nodes["y"], text=nodes["label"], mode="markers+text", 
                                      marker=dict(color=nodes["color"], size=30), textfont=dict(color="rgba(4,64,64,1)"))])
    frames.append(frame)

fig.frames = frames

fig.update_layout(title = title, width=400, height=400, 
                  updatemenus=[dict(type="buttons", 
                                    buttons=[dict(label="Play", method="animate",
                                                  args=[None, {"frame": {"duration": 500, "redraw": True}, 
                                                               "fromcurrent": True,
                                                               "transition": {"duration": 300, "easing": "quadratic-in-out"}
                                                               }
                                                        ]
                                                  )]
                                    )]
                  )

fig.update_layout()

fig.show()