In [1]:
from dyn_net.dynamical_systems import get_drift
from dyn_net.noise import get_noise
from dyn_net.integrator.params import EulerMaruyamaParams
from dyn_net.dynamical_systems.stats_registry import get_stats
from dyn_net.utils.stats import open_stats_writer, close_stats_writer
from dyn_net.utils.state import open_state_writer, close_state_writer
from dyn_net.networks import get_network
import numpy as np
from dyn_net.dynamical_systems.jit_kuramoto import kuramoto_chunk, build_kuramoto_kernel_params
from dyn_net.integrator.jit import integrate_chunked_jit_timed


In [2]:
# Single seed for reproducibility: drives both the random graph and the
# initial phases, so Restart & Run All yields the same trajectory.
SEED = 1

# Number of agents
n = 1000
# Dynamical System definition
system = "kuramoto"
params_system = {
    "theta": 1.0,
    # Placeholder: replaced by the adjacency matrix once the network is built.
    "A": None,
}
# Network definition
network_name = "erdos_renyi"
network_params = {
    "n": n,
    "p": 0.2,
    "seed": SEED,
    "directed": False,
}
# Noise definition
params_noise = {
    "sigma": 0.3,
}

# Parameters integrator
p_int = EulerMaruyamaParams(
    tmin=0.0,
    tmax=1000.0,
    dt=0.01,          # (tmax - tmin) / dt = 100_000 integration steps
    stats_every=10,   # NOTE(review): presumably measured in steps — confirm
    state_every=200,  # NOTE(review): presumably measured in steps — confirm
    write_stats_at_start=True,
    write_state_at_start=True,
)

# Initial condition: phases drawn uniformly from [0, 2π) with a seeded
# generator (the global np.random state is never seeded in this notebook).
rng = np.random.default_rng(SEED)
x0 = rng.uniform(0.0, 2 * np.pi, size=n)


In [3]:
# Build the coupling network and attach its adjacency matrix to the system
# parameters (fills the "A" placeholder from the config cell).
build_net, p_net = get_network(network_name, network_params)
A = build_net(p_net)
params_system["A"] = A

# Resolve drift, noise, per-system statistics and the JIT kernel parameters.
F, pF = get_drift(system, params_system)
G, pG = get_noise("additive_gaussian", params_noise)
stats_fn, stats_fields = get_stats(system)
kernel_params = build_kuramoto_kernel_params(pF, pG)

# Open the output files; the integration cell closes both in its `finally`.
stats_writer = open_stats_writer(
    "stats.h5",
    fieldnames=stats_fields,
)
try:
    state_writer = open_state_writer(
        "state.h5",
        dim=len(x0),
    )
except BaseException:
    # Don't leak the already-open stats file if the state file cannot be opened.
    close_stats_writer(stats_writer)
    raise


In [4]:
# Run the chunked, JIT-compiled Kuramoto integration. Statistics and state
# snapshots are streamed to the writers opened in the previous cell, and the
# per-phase timing breakdown is printed.
try:
    x_final, timings = integrate_chunked_jit_timed(
        kuramoto_chunk,
        x0,
        params_int=p_int,
        kernel_params=kernel_params,
        stats_fn=stats_fn,
        stats_writer=stats_writer,
        stats_params=pF,
        state_writer=state_writer,
        # Stored snapshots are phases wrapped into [0, 2π).
        state_transform=lambda x: np.mod(x, 2 * np.pi),
    )
    print(timings)
finally:
    # Nested so the state file is still closed even if closing the stats
    # file raises.
    try:
        close_stats_writer(stats_writer)
    finally:
        close_state_writer(state_writer)


{'loop_s': 24.18048054096289, 'write_stats_s': 4.005512463743798, 'write_state_s': 0.13420967245474458, 'compute_s': 20.040758404764347, 'steps': 100000}
