In [None]:
# Relative output folders for tables and figures. These are plain path
# prefixes (each ends in '/'); nothing here changes the working directory,
# so both resolve against whatever the current directory is.
TABLE_DIR, FIG_DIR = '../Tables/', '../Figures/'

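Plot the exact (independent noise) and inexact (correlated noise) versions of a directed acyclic graph (DAG) over three variables $z_1, z_2, z_3$, each driven by its own noise term $\epsilon_i$. In the inexact version, dashed undirected edges between the $\epsilon_i$ indicate that the noise terms are correlated.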

In [None]:
import matplotlib.pyplot as plt
import networkx as nx

def plot_causal_graph(fig_dir=None):
    """Draw the three-variable causal DAG and save two PNG variants.

    The graph has variable nodes z_1, z_2, z_3 with edges
    z_1 -> z_2 -> z_3 and z_1 -> z_3, plus one noise node epsilon_i with
    an edge into each z_i. Both files come from the same figure:

    * ``causal_graph.png`` -- the DAG with independent noise terms.
    * ``confounded_causal_graph.png`` -- the same drawing with dashed,
      undirected edges added between every pair of noise nodes
      (correlated noise).

    Args:
        fig_dir: Directory the PNGs are written to. Defaults to the
            module-level ``FIG_DIR``. It is created if it does not exist.

    The figure is not closed, so it remains the current matplotlib figure
    after the call (e.g. for inline display).
    """
    import os

    if fig_dir is None:
        fig_dir = FIG_DIR
    # savefig raises FileNotFoundError if the target folder is missing.
    # An empty fig_dir means "current directory", which always exists.
    if fig_dir:
        os.makedirs(fig_dir, exist_ok=True)

    plt.figure(figsize=(5, 5))
    graph = nx.DiGraph()

    # Nodes 1-3 are the variables z_i; nodes 4-6 are their noise terms.
    graph.add_nodes_from([
        (1, {'label': r'$z_1$'}),
        (2, {'label': r'$z_2$'}),
        (3, {'label': r'$z_3$'}),
        (4, {'label': r'$\epsilon_1$'}),
        (5, {'label': r'$\epsilon_2$'}),
        (6, {'label': r'$\epsilon_3$'}),
    ])
    graph.add_edges_from([(1, 2), (2, 3), (1, 3), (4, 1), (5, 2), (6, 3)])

    # Fixed layout: each noise node sits one unit above its variable.
    # z_2 is raised off the z_1-z_3 line so the z_1 -> z_3 edge stays visible.
    pos = {
        1: (0, 0),
        2: (1, 0.5),
        3: (2, 0),
        4: (0, 1),
        5: (1, 1.5),
        6: (2, 1),
    }

    labels = nx.get_node_attributes(graph, 'label')

    nx.draw(graph, pos, with_labels=True, labels=labels, node_size=1000,
            node_color="white", font_size=30, font_weight="bold",
            arrowsize=30, width=5)
    plt.savefig(os.path.join(fig_dir, 'causal_graph.png'), dpi=300)

    # Overlay the noise correlations on the same figure. These pairs are
    # drawn only; they are not added as edges of the DAG itself.
    dashed_edges = [(4, 5), (5, 6), (4, 6)]
    nx.draw_networkx_edges(graph, pos, edgelist=dashed_edges, style='dashed',
                           arrows=False, width=2)
    plt.savefig(os.path.join(fig_dir, 'confounded_causal_graph.png'), dpi=300)

# Draw the DAG and save both variants (independent and correlated noise) under FIG_DIR.
plot_causal_graph()