#### Generate random DAGs of increasing size, run the identification finder and time how long it takes.

In [None]:
import random
import sys
import time

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import seaborn as sns

sys.path.append("../src")

from identification_strategy_finder import AdjustmentSetFinder

#### Helpers

In [None]:
def generate_random_dag(num_nodes):
    """
    Generate a random DAG with exactly ``num_nodes`` nodes.

    The integers ``0 .. num_nodes - 1`` are shuffled into a random order and
    every pair ``(earlier, later)`` in that order is connected by an edge
    ``earlier -> later`` with probability 1/2. Because edges only ever point
    forward in the shuffled order, the result is acyclic by construction.

    Args:
        num_nodes: Number of nodes the graph should contain.

    Returns:
        An ``nx.DiGraph`` containing all ``num_nodes`` nodes, including nodes
        that happened to receive no edges.
    """
    G = nx.DiGraph()
    nodes = list(range(num_nodes))
    random.shuffle(nodes)

    # Register every node explicitly: add_edge alone only creates the nodes
    # that take part in an edge, so isolated nodes would silently disappear
    # and the graph could end up smaller than requested.
    G.add_nodes_from(nodes)

    for i, source in enumerate(nodes):
        for target in nodes[i + 1:]:
            if random.choice([True, False]):
                G.add_edge(source, target)
    return G

def pick_random_nodes(G, num_hidden=1):
    """
    Pick a random X node, a random Y node and ``num_hidden`` random hidden
    nodes from a graph, all pairwise distinct.

    Args:
        G: Graph-like object whose ``nodes`` attribute is iterable.
        num_hidden: How many hidden nodes to draw. Negative values are
            treated as 0.

    Returns:
        A tuple ``(X, Y, hidden_nodes)`` where ``X`` is a one-element set,
        ``Y`` is a single node and ``hidden_nodes`` is a set of
        ``num_hidden`` nodes.

    Raises:
        ValueError: If the graph has fewer than ``2 + num_hidden`` nodes, so
            that distinct picks are impossible.
    """
    nodes = list(G.nodes)
    num_hidden = max(num_hidden, 0)
    needed = 2 + num_hidden
    if len(nodes) < needed:
        # Retrying random picks until they are distinct can never terminate
        # on such a graph, so fail loudly instead.
        raise ValueError(
            f"Need at least {needed} nodes to pick X, Y and {num_hidden} "
            f"hidden node(s), but the graph only has {len(nodes)}."
        )

    # One draw without replacement yields distinct nodes with the same
    # uniform distribution as picking them one at a time with rejection.
    x_node, y_node, *hidden = random.sample(nodes, needed)
    return {x_node}, y_node, set(hidden)

def analyze_graph(G, X, Y, hidden_nodes):
    """
    Build an AdjustmentSetFinder for the given DAG, X, Y and hidden nodes and
    return whatever its ``find_adjustment_sets_nuisance()`` call produces.

    The finder is always constructed with the extra positional arguments
    ``"total"`` and ``False``.
    NOTE(review): their meaning is defined by AdjustmentSetFinder, which is
    not visible here - confirm against that class.
    """
    return AdjustmentSetFinder(
        G, X, Y, hidden_nodes, "total", False
    ).find_adjustment_sets_nuisance()

#### Run the simulation and save to pickle file
This can easily be parallelized if desired.

In [None]:
from pathlib import Path

# Benchmark configuration: DAG sizes to sweep and how many random DAGs to
# time per size.
min_nodes = 5
max_nodes = 15
number_dags = 100

# One (num_nodes, elapsed_seconds) row per DAG that was analysed successfully.
results = []
for num_nodes in range(min_nodes, max_nodes + 1):
    for _ in range(number_dags):
        G = generate_random_dag(num_nodes)
        X, Y, hidden_nodes = pick_random_nodes(G, num_hidden=1 if num_nodes > 2 else 0)

        # perf_counter is a monotonic, high-resolution clock intended for
        # measuring durations; time.time() can jump when the system clock is
        # adjusted.
        start_time = time.perf_counter()
        try:
            analyze_graph(G, X, Y, hidden_nodes)
        except AssertionError as e:
            # Runs that fail an assertion are reported and left out of the
            # timing data.
            print(e)
            continue
        end_time = time.perf_counter()
        results.append((num_nodes, end_time - start_time))

df_results = pd.DataFrame(results, columns=["Number of Nodes", "Time (s)"])

# Make sure the output directory exists so a long run is not lost at the
# final write.
Path("./results").mkdir(parents=True, exist_ok=True)
df_results.to_pickle("./results/dag_analysis_times.pkl")

#### Plot

In [None]:
# Render all figure text with LaTeX in a serif font and set the font sizes
# used by the plot below.
latex_style = {
    "text.usetex": True,
    "font.family": "serif",
    "font.serif": ["Computer Modern Roman"],
    "text.latex.preamble": r"\usepackage{amsmath}\usepackage{amssymb}",
    "axes.labelsize": 22,
    "xtick.labelsize": 16,
    "ytick.labelsize": 18,
    "legend.fontsize": 14,
    "figure.titlesize": 22,
    "axes.titlesize": 22,
    "legend.title_fontsize": 18,
}
plt.rcParams.update(latex_style)

# Two colours from seaborn's "Paired" palette: one for the mean markers, one
# for the error bars.
palette = sns.color_palette("Paired", 12)
mean_color, error_bar_color = palette[1], palette[3]

# Per-size mean runtime and its standard error.
grouped = df_results.groupby("Number of Nodes")["Time (s)"]
means = grouped.mean()
stderr = grouped.std() / np.sqrt(grouped.count())

# Interval of +/- 1.96 standard errors around each mean.
confidence_interval = 1.96 * stderr
lower_bounds = means - confidence_interval
upper_bounds = means + confidence_interval
lower_err = means - lower_bounds
upper_err = upper_bounds - means

plt.figure(figsize=(12, 8))
plt.errorbar(
    means.index,
    means,
    yerr=[lower_err, upper_err],
    fmt="o",
    capsize=5,
    ecolor=error_bar_color,
    color=mean_color,
)
# Redraw the mean markers with a higher z-order so they sit on top of the
# error bars.
plt.scatter(means.index, means, color=mean_color, zorder=5)

plt.xlabel("Number of Nodes")
plt.ylabel("Seconds")
plt.grid(True)
plt.savefig("./plots/execution_time_plot.svg", format="svg")
plt.show()