In [1]:
# Automatically reload edited modules (e.g. the local causaldynamics package)
# before each cell executes, so library changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

# Render matplotlib figures inline below the producing cell.
%matplotlib inline

In [2]:
import xarray as xr
import matplotlib as mpl

from causaldynamics.scm import create_scm_graph
from causaldynamics.plot import animate_3d_trajectories, plot_trajectories, plot_scm


All relevant data is stored in a single xarray dataset for easy access. Let's load it and inspect its contents:

In [None]:
# TODO: Update the path to the data (point it at your own simulation output file)
DATA_PATH = "../output/20250505_161143/data/Srandom_N5_T100_seed0.nc"

# Load the whole NetCDF file into memory and show an overview of its contents
ds = xr.load_dataset(DATA_PATH)
ds

The data is stored in the dataset's `Data variables`. To access a variable, e.g. `time_series`, index the dataset like a dictionary.

In [None]:
# Dictionary-style lookup of one data variable; the result is an xarray DataArray
time_series = ds["time_series"]
time_series

Let's plot the time series.

In [None]:
# Plot every node's trajectory; root nodes get their own color (dimgrey) so they
# stand apart from the remaining nodes (orange). Each subplot keeps its own y-scale.
root_nodes = ds["root_nodes"]
plot_trajectories(
    time_series,
    root_nodes=root_nodes,
    sharey=False,
    node_color="orange",
    root_node_color="dimgrey",
)

The adjacency matrix encodes the SCM graph structure. Its first dimension indexes the incoming nodes and its second dimension indexes the outgoing nodes.

In [None]:
# Pull out the adjacency matrix that encodes the SCM graph structure
# (kept as `A` because the following cells refer to it by that name)
A = ds["adjacency_matrix"]
A

In [None]:
# You can access the underlying numpy array directly via .data
print(A.data)
# You can also convert it to a pandas DataFrame. For a 2-D DataArray,
# .to_pandas() keeps the matrix layout (rows = first dimension,
# columns = second dimension), which is the natural view of an adjacency matrix.
# (.to_dataframe() would instead give a long-format table with one row per
# (row, column) pair.) `display` renders the frame as a rich HTML table.
display(A.to_pandas())

Let's visualize the SCM graph.

In [None]:
# Build a graph object from the raw adjacency array (the numpy data, not the DataArray),
# then draw it, passing the root nodes so the plot can distinguish them.
G = create_scm_graph(A.data)
plot_scm(G, root_nodes=root_nodes)


In [None]:
# Creating an animation of the trajectories.
# This may take a while to run...
#
# matplotlib's `animation.embed_limit` rcParam is expressed in MB (default 20),
# NOT in bytes. Once an embedded HTML animation exceeds it, matplotlib warns
# and drops the remaining frames. The previous value `50 * 1024**2` was
# therefore ~52 million MB (effectively no limit) rather than the intended 50 MB.
ANIMATION_EMBED_LIMIT_MB = 50
mpl.rcParams["animation.embed_limit"] = ANIMATION_EMBED_LIMIT_MB

animate_3d_trajectories(
    time_series,
    root_nodes=root_nodes,
    plot_type="subplots",
    frame_skip=5,
    rotation_speed=0.2,
    rotate=True,
    show_history=True,
    save_path=None,
    return_html_anim=True,  # return the animation as HTML for inline display in the notebook
    show_plot=False,
    root_node_alpha=0.5,
    node_alpha=0.5,
    linewidth=1.5,
)