# Generate diverse data examples for a visual abstract

In [None]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [104]:
import matplotlib.pyplot as plt
from causaldynamics.creator import create_scm, simulate_system
from causaldynamics.plot import plot_scm, plot_trajectories, plot_3d_trajectories
from causaldynamics.scm import create_scm_graph 
from causaldynamics.utils import get_timestamp

from pathlib import Path
# Make sure the top-level output folder exists before any later cell writes into it.
Path("visual_abstract").mkdir(parents=True, exist_ok=True)

def run(num_nodes=3, confounders=True, time_lag=0, time_lag_edge_probability=0., system_name='random', output_dir="visual_abstract",
        *, example_index=None, node_dim=3, num_timesteps=1500, noise=0.0):
    """Sample one SCM, simulate it, and save a trajectory plot and a graph plot.

    Two PNG files are written to ``output_dir``. Their names share the stem
    ``{example_index:04d}_N{num_nodes}_C{confounders}_tl{time_lag}_tlp{time_lag_edge_probability}``
    and end in ``_3d_trajectories.png`` and ``_graph.png`` respectively.

    Args:
        num_nodes: Number of nodes; forwarded to ``create_scm`` and ``simulate_system``.
        confounders: Forwarded to ``create_scm``.
        time_lag: Forwarded to ``create_scm`` and ``simulate_system``.
        time_lag_edge_probability: Forwarded to ``create_scm``.
        system_name: Forwarded to ``simulate_system``. Earlier versions
            ignored this argument and always passed ``"random"``.
        output_dir: Folder for the figures; created (with parents) if missing.
        example_index: Integer used as the zero-padded file-name prefix.
            If None, the global ``i`` is used instead, which is the loop
            variable of the notebook driver cell. This keeps the old
            behaviour for existing callers.
        node_dim: Forwarded to ``create_scm``. The 3D trajectory plot
            presumably expects 3 (TODO confirm).
        num_timesteps: Forwarded to ``simulate_system``.
        noise: Passed to ``simulate_system`` as ``make_trajectory_kwargs={'noise': noise}``.
    """
    if example_index is None:
        # Legacy behaviour: the driver cells number files via the global loop variable `i`.
        example_index = i

    # Ratios of dynamical systems and periodic drivers at root nodes. Here: 2:1 (not equal).
    init_ratios = [2, 1]

    A, W, b, root_nodes, _ = create_scm(num_nodes,
                                        node_dim,
                                        confounders=confounders,
                                        time_lag=time_lag,
                                        time_lag_edge_probability=time_lag_edge_probability)

    data = simulate_system(A, W, b,
                           num_timesteps=num_timesteps,
                           num_nodes=num_nodes,
                           system_name=system_name,
                           init_ratios=init_ratios,
                           time_lag=time_lag,
                           standardize=False,
                           make_trajectory_kwargs={'noise': noise})

    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    # One stem for both figures, so the two file names cannot drift apart.
    stem = f"{example_index:04d}_N{num_nodes}_C{confounders}_tl{time_lag}_tlp{time_lag_edge_probability}"

    plot_3d_trajectories(data, root_nodes, line_alpha=1., show_background=False, show_grid=False)
    plt.tight_layout()
    plt.savefig(out_dir / f"{stem}_3d_trajectories.png")
    plt.close()

    plot_scm(G=create_scm_graph(A), root_nodes=root_nodes)
    plt.savefig(out_dir / f"{stem}_graph.png")
    plt.close()


In [None]:
# Note: The output is stored in the visual_abstract/<timestamp> folder.
# Every (num_nodes, num_examples) pair is combined with every configuration
# below. The calls happen in the same order as in the original copy-pasted loops.
configurations = [
    # (confounders, time_lag, time_lag_edge_probability)
    (True, 0, 0.0),
    (False, 0, 0.0),
    (False, 10, 0.1),
    (True, 10, 0.1),
]
# (num_nodes, number of examples per configuration): fewer samples for the largest graphs.
sizes = [(3, 10), (5, 10), (10, 5)]

ts = get_timestamp()
for num_nodes, num_examples in sizes:
    for confounders, time_lag, time_lag_edge_probability in configurations:
        # NOTE: run() numbers its output files with the global loop variable `i`,
        # so this loop variable must keep that name.
        for i in range(num_examples):
            run(num_nodes=num_nodes,
                confounders=confounders,
                time_lag=time_lag,
                time_lag_edge_probability=time_lag_edge_probability,
                system_name='random',
                output_dir=f"visual_abstract/{ts}")
    