In [None]:
import matplotlib.pyplot as plt
import numpy as np
import types
from typing import Sequence, Union
from tvb.simulator.lab import *

def _normalize_slice_durations(
    slice_dur: Union[float, Sequence[float], np.ndarray], n_slices: int
) -> list:
    """Return a list of ``n_slices`` positive float durations (ms) built from *slice_dur*.

    Accepts a single real number (Python or NumPy scalar, including a 0-d array) that
    is used for every slice, or a sequence / 1-D NumPy array with exactly ``n_slices``
    entries.

    Raises:
        ValueError: on an unsupported type, a length mismatch, or a duration that is
            not positive and finite.
    """
    if isinstance(slice_dur, np.ndarray) and slice_dur.ndim == 0:
        slice_dur = slice_dur.item()
    if isinstance(slice_dur, (int, float, np.integer, np.floating)):
        durations = [float(slice_dur)] * n_slices
    elif isinstance(slice_dur, (Sequence, np.ndarray)) and not isinstance(slice_dur, (str, bytes)):
        if len(slice_dur) != n_slices:
            raise ValueError(
                f"When providing a sequence, slice_dur length must match number of slices ({n_slices})"
            )
        durations = [float(d) for d in slice_dur]
    else:
        raise ValueError("slice_dur must be a float or sequence of floats matching number of slices")
    for k, d in enumerate(durations):
        # A zero/negative/NaN length would make the slice run zero (or nonsensical) steps.
        if not np.isfinite(d) or d <= 0:
            raise ValueError(f"slice_dur[{k}] must be a positive, finite duration in ms (got {d})")
    return durations


def _copy_or_none(value):
    """Return ``value.copy()``, or ``None`` when *value* is ``None``."""
    return None if value is None else value.copy()


def simulate_time_varying_connectivity(
    connectivities: np.ndarray,
    slice_dur: Union[float, Sequence[float]],
    coupling_gain: float,
    dt: float,
    noise_sigma: float,
    monitor_list: tuple,
    centres: np.ndarray,
    tract_lengths: np.ndarray,
    *,
    noise_seed: Union[int, None] = None,
    verbose: bool = True,
):
    """
    Run a TVB simulation under time-varying connectivity.

    Each connectivity slice is simulated by its own ``Simulator``, for its own duration.
    The state is carried from one slice to the next. This covers the history buffer,
    current step and state, the noise RNG state, and the monitors' private
    ``_stock`` / ``_interim_stock`` / ``_state`` attributes. Outputs are returned for
    any number of monitors. Any number of regions and slices is supported.

    Args:
        connectivities: 3D array (n_regions, n_regions, n_slices) of weights; n_slices >= 1.
        slice_dur:      A single number (all slices the same length) or a sequence /
                        1-D array of n_slices durations, in ms. All must be > 0.
        coupling_gain:  Scalar gain for Linear coupling.
        dt:             Integrator time-step in ms.
        noise_sigma:    Noise level passed as ``nsig`` to TVB's ``noise.Additive``. It is
                        given as a single value; presumably TVB broadcasts it over regions
                        and state variables. Check the TVB docs for how nsig maps to a std. dev.
        monitor_list:   Tuple (or list) of TVB Monitor instances to record.
        centres:        (n_regions, 3) array of region coordinates.
        tract_lengths:  (n_regions, n_regions) array of tract lengths.
        noise_seed:     Keyword-only. Seed for the noise generator of the first slice.
                        Later slices continue from the restored RNG state. ``None`` (the
                        default) draws a seed with ``np.random.randint(0, 10000)``.
        verbose:        Keyword-only. Print each slice's connectivity summary (default True).

    Returns:
        List of tuples [(times_i, data_i), ...], one per monitor in monitor_list, with the
        slices concatenated along axis 0. A monitor that produced no samples at all
        yields ``(np.array([]), np.array([]))``.

    Raises:
        ValueError: if any argument has the wrong type or shape.
    """
    # Validate monitor_list
    if monitor_list is None or not isinstance(monitor_list, (list, tuple)):
        raise ValueError("monitor_list must be provided as a tuple of Monitor instances")
    # Validate connectivity
    if not isinstance(connectivities, np.ndarray) or connectivities.ndim != 3:
        raise ValueError("connectivities must be a 3D NumPy array of shape (n_regions, n_regions, n_slices)")
    if connectivities.shape[0] != connectivities.shape[1]:
        raise ValueError("The first two dimensions of connectivities must be equal (square per slice)")
    n_regions, _, n_slices = connectivities.shape
    if n_slices == 0:
        # Without this check the final np.concatenate fails with an obscure error.
        raise ValueError("connectivities must contain at least one slice (shape[2] >= 1)")
    # Validate centres and tract_lengths
    if centres is None or centres.shape != (n_regions, 3):
        raise ValueError(f"centres must be provided with shape ({n_regions}, 3)")
    if tract_lengths is None or tract_lengths.shape != (n_regions, n_regions):
        raise ValueError(f"tract_lengths must be provided with shape ({n_regions}, {n_regions})")
    slice_durs = _normalize_slice_durations(slice_dur, n_slices)

    # Only the first slice's seed matters: later slices overwrite the RNG state
    # with the one saved at the end of the previous slice.
    seed = int(np.random.randint(0, 10000)) if noise_seed is None else int(noise_seed)

    region_labels = np.array([f"R{i+1}" for i in range(n_regions)], dtype="<U128")

    n_monitors = len(monitor_list)
    times_storage = [[] for _ in range(n_monitors)]
    data_storage = [[] for _ in range(n_monitors)]

    # Checkpoint carried between slices (None until the first slice has run).
    history_buf = None
    current_step = 0
    current_state = None
    rng_state = None
    monitor_stock = [None] * n_monitors
    monitor_inner = [None] * n_monitors
    monitor_state = [None] * n_monitors

    for k, duration in enumerate(slice_durs):
        conn = connectivity.Connectivity(
            weights=connectivities[..., k],
            tract_lengths=tract_lengths,
            region_labels=region_labels,
            centres=centres,
        )
        sim = simulator.Simulator(
            model=models.Generic2dOscillator(),
            connectivity=conn,
            coupling=coupling.Linear(a=np.array([coupling_gain])),
            integrator=integrators.HeunStochastic(
                dt=dt,
                noise=noise.Additive(nsig=np.array([noise_sigma]), noise_seed=seed),
            ),
            monitors=monitor_list,
        )
        sim.configure()

        if verbose:
            # In TVB 2.x summary_info is a method; printing the attribute itself would
            # only show a bound-method repr. Older releases exposed it as a property.
            info = conn.summary_info
            print(info() if callable(info) else info)

        # Restore the checkpoint (after configure(), which re-initialises these).
        if history_buf is not None:
            sim.history.buffer = history_buf
            sim.current_step = current_step
            sim.current_state = current_state
            sim.integrator.noise.random_stream.set_state(rng_state)
            for i, mon in enumerate(sim.monitors):
                if monitor_stock[i] is not None:
                    setattr(mon, '_stock', monitor_stock[i].copy())
                if monitor_inner[i] is not None:
                    setattr(mon, '_interim_stock', monitor_inner[i].copy())
                if monitor_state[i] is not None:
                    setattr(mon, '_state', monitor_state[i].copy())

        sim.simulation_length = duration
        outputs = sim.run()

        # Times are not shifted here. current_step is carried over, so TVB's own time
        # stamps are expected to continue across slices.
        for i, (t_i, d_i) in enumerate(outputs):
            if np.size(t_i) == 0:
                # No samples this slice (e.g. slice shorter than the monitor period).
                # The empty 1-D data array would break the axis-0 concatenation below.
                continue
            times_storage[i].append(t_i)
            data_storage[i].append(d_i)

        # Save the checkpoint. Copies keep later in-place updates from leaking into it.
        history_buf = sim.history.buffer.copy()
        current_step = sim.current_step
        current_state = sim.current_state.copy()
        rng_state = sim.integrator.noise.random_stream.get_state()
        for i, mon in enumerate(sim.monitors):
            monitor_stock[i] = _copy_or_none(getattr(mon, '_stock', None))
            monitor_inner[i] = _copy_or_none(getattr(mon, '_interim_stock', None))
            monitor_state[i] = _copy_or_none(getattr(mon, '_state', None))

    full_outputs = []
    for t_parts, d_parts in zip(times_storage, data_storage):
        if t_parts:
            full_outputs.append((np.concatenate(t_parts), np.concatenate(d_parts, axis=0)))
        else:
            full_outputs.append((np.array([]), np.array([])))
    return full_outputs



In [None]:
## Example 1: three connectivity slices of different lengths
# (10000 ms, then 5000 ms, then 20000 ms; the third slice has all weights set to zero).

# Weight matrices stacked along a new third axis -> shape (n_regions, n_regions, n_slices) = (4, 4, 3)
connectivities = np.stack([
        [[ 1,  -2,   3, 6],
         [ 5,  -6,   7, 7],
         [ 9, -10,  11, -12],
         [13, -14,  15, -3]],

        [[-9,   2,  -3, 1],
         [-5,   6,  -7, 1],
         [-9,  10, -11, 1],
         [-13,  14, -15,2]],

        [[ 0,   0,   0, 0],
         [ 0,   0,   0, 0],
         [ 0,   0,   0, 0],
         [ 0,   0,   0, 0]]
    ], axis=2)


# Per-slice durations in ms (one entry per slice; 35000 ms in total)
slice_duration = [10000.0, 5000.0, 20000.0]

# Simple centres and tract_lengths
n_regions = connectivities.shape[0]
centres = np.zeros((n_regions, 3))
tract_lengths = np.ones((n_regions, n_regions))

# Monitor list: raw and bold
monitor_list = (
    monitors.Raw(period=1.0),
    monitors.Bold(period=1000.0),
)

# Run simulation
outputs = simulate_time_varying_connectivity(
    connectivities=connectivities,
    slice_dur=slice_duration,
    coupling_gain=0.5,
    dt=1.0,
    noise_sigma=0.001,
    monitor_list=monitor_list,
    centres=centres,
    tract_lengths=tract_lengths
)

# Unpack and plot first monitor (raw).
# [:, 0, :, 0] selects state variable 0 of every region (mode 0); this assumes TVB's
# (time, state-variable, region, mode) output layout.
t_raw, data_raw = outputs[0]
plt.figure()
plt.plot(t_raw, data_raw[:, 0, :, 0])
plt.title('Raw time-series of all regions, var 0')
plt.xlabel('Time (ms)')
plt.show()

# Unpack and plot second monitor (bold)
t_bold, data_bold = outputs[1]
plt.figure()
plt.plot(t_bold, data_bold[:, 0, :, 0])
plt.title('BOLD signals of all regions')
plt.xlabel('Time (ms)')
plt.show()

In [None]:
## Example 2: compare simulate_time_varying_connectivity with a standard simulation.
# Both slices carry the same weights, so this run is meant to be compared with the
# single fixed-connectivity run of Example 3 (same total length, 2 x 5000 ms).
# The noise seed is drawn at random inside the function, so the traces are not
# expected to match Example 3 sample-for-sample.

# One 4x4 weight matrix repeated along a new third (slice) axis -> shape (4, 4, 2)
connectivities_constant = np.repeat(
    np.array([
        [ 1,  -2,   3,   6],
        [ 5,  -6,   7,   7],
        [ 9, -10,  11, -12],
        [13, -14,  15,  -3],
    ])[:, :, np.newaxis],
    2,
    axis=2,
)

slice_duration = 5000.0  # ms, used for every slice

# Trivial geometry: every centre at the origin, every tract of length 1
n_regions = connectivities_constant.shape[0]
centres = np.zeros((n_regions, 3))
tract_lengths = np.ones((n_regions, n_regions))

# Record the raw state (every 1 ms) and a BOLD signal (every 1000 ms)
monitor_list = (monitors.Raw(period=1.0), monitors.Bold(period=1000.0))

outputs_2 = simulate_time_varying_connectivity(
    connectivities=connectivities_constant,
    slice_dur=slice_duration,
    coupling_gain=0.5,
    dt=1.0,
    noise_sigma=0.001,
    monitor_list=monitor_list,
    centres=centres,
    tract_lengths=tract_lengths,
)

# One (times, data) pair per monitor, in monitor_list order: Raw first, then BOLD
(t_raw_2, data_raw_2), (t_bold_2, data_bold_2) = outputs_2

plt.figure()
plt.plot(t_raw_2, data_raw_2[:, 0, :, 0])
plt.title('Raw time-series of all regions')
plt.xlabel('Time (ms)')
plt.show()

plt.figure()
plt.plot(t_bold_2, data_bold_2[:, 0, :, 0])
plt.title('BOLD signals of all regions')
plt.xlabel('Time (ms)')
plt.show()

In [None]:
## Example 3: reference run, i.e. one standard TVB simulation with fixed connectivity.
# It uses the same weights (slice 0 of connectivities_constant), coupling gain, dt,
# noise level and monitors as Example 2, over the same total length (10000 ms).
# NOTE: the noise here is seeded with 12, while Example 2 draws its seed at random.

sim_nominal = simulator.Simulator(
    model=models.Generic2dOscillator(),
    connectivity=connectivity.Connectivity(
        weights=connectivities_constant[:, :, 0],
        tract_lengths=tract_lengths,
        region_labels=np.array([f"R{i+1}" for i in range(n_regions)], dtype="<U128"),
        centres=centres,
    ),
    coupling=coupling.Linear(a=np.array([0.5])),
    integrator=integrators.HeunStochastic(
        dt=1.0,
        noise=noise.Additive(nsig=np.array([0.001]), noise_seed=12),
    ),
    monitors=monitor_list,
    simulation_length=10000.0,
)
sim_nominal.configure()
outputs_3 = sim_nominal.run()

# One (times, data) pair per monitor: Raw first, then BOLD
(t_raw_3, data_raw_3), (t_bold_3, data_bold_3) = outputs_3

plt.figure()
plt.plot(t_raw_3, data_raw_3[:, 0, :, 0])
plt.title('Raw time-series of all regions (normal simulation)')
plt.xlabel('Time (ms)')
plt.show()

plt.figure()
plt.plot(t_bold_3, data_bold_3[:, 0, :, 0])
plt.title('BOLD signals of all regions (normal simulation)')
plt.xlabel('Time (ms)')
plt.show()

In [None]:
# Example 4: genuinely time-varying connectivity (two different weight matrices,
# 5000 ms each), to contrast with the constant-connectivity runs of Examples 2 and 3.

# Weight matrices stacked along a new third (slice) axis -> shape (4, 4, 2)
connectivities_varying = np.stack(
    [
        [[  1,  -2,   3,   6],
         [  5,  -6,   7,   7],
         [  9, -10,  11, -12],
         [ 13, -14,  15,  -3]],

        [[ 10,   8,   4,   2],
         [ -9, -10, -11, -12],
         [  1,   2,   3,   1],
         [-19,  -1,   1,  -3]],
    ],
    axis=2,
)

# Trivial geometry: every centre at the origin, every tract of length 1
n_regions = connectivities_varying.shape[0]
centres = np.zeros((n_regions, 3))
tract_lengths = np.ones((n_regions, n_regions))
slice_duration = 5000.0  # ms, used for every slice

# Record the raw state (every 1 ms) and a BOLD signal (every 1000 ms)
monitor_list = (monitors.Raw(period=1.0), monitors.Bold(period=1000.0))

outputs_4 = simulate_time_varying_connectivity(
    connectivities=connectivities_varying,
    slice_dur=slice_duration,
    coupling_gain=0.5,
    dt=1.0,
    noise_sigma=0.001,
    monitor_list=monitor_list,
    centres=centres,
    tract_lengths=tract_lengths,
)

# One (times, data) pair per monitor, in monitor_list order: Raw first, then BOLD
(t_raw_4, data_raw_4), (t_bold_4, data_bold_4) = outputs_4

plt.figure()
plt.plot(t_raw_4, data_raw_4[:, 0, :, 0])
plt.title('Raw time-series of all regions')
plt.xlabel('Time (ms)')
plt.show()

plt.figure()
plt.plot(t_bold_4, data_bold_4[:, 0, :, 0])
plt.title('BOLD signals of all regions')
plt.xlabel('Time (ms)')
plt.show()