In [15]:
def _passage_times(points, first, last, boundary):
    """Return ``(AB, BA)`` passage times, in frames of ``points``, for one sequence.

    ``first`` and ``last`` are the sets of origin (A) and target (B) bins, and
    ``boundary`` is their union.
    """
    AB, BA = [], []
    last_in = None  # most recent boundary bin visited; None until one is reached
    last_in_idx = 0  # frame index at which the current passage started

    for idx, point in enumerate(points):
        if point not in boundary:
            continue
        if last_in is None:
            # First boundary bin on this sequence: start the clock here.
            # (None, not -1, so that any integer can be a valid bin label.)
            last_in, last_in_idx = point, idx
        elif last_in in first and point in last:
            # Just finished an A -> B transit.
            AB.append(idx - last_in_idx)
            last_in, last_in_idx = point, idx
        elif last_in in last and point in first:
            # Just finished a B -> A transit.
            BA.append(idx - last_in_idx)
            last_in, last_in_idx = point, idx
        # Revisiting the set we were last in does not restart the clock.
    return AB, BA


def brute_force_mfpt(trajectory, lags, statesA, statesB, dt=1, stride=1):
    """Compute brute-force first-passage times between two sets of bins.

    Call signature::

        brute_force_mfpt(trajectory, lags, statesA, statesB, dt=1, stride=1)

    Each trajectory is subsampled with ``stride``. Then, for every lag, it is
    read at every ``lag``-th frame, starting from each offset ``0 .. lag-1``
    (a "sliding" window, so every frame is used). Along each subsequence, a
    passage is recorded whenever the walker, whose last visited boundary bin
    was in one set, arrives in the other set.

    A passage is timed from the first frame spent in the origin set since the
    walker was last in the other set, or from the first boundary frame of the
    subsequence. Revisiting the origin set does not restart the clock.

    Parameters
    ----------
    trajectory : list of lists or 2D array
        One or more trajectories, each a sequence of integer bin labels.
    lags : sequence of int
        Lag times, in strided frames, at which to compute passage times.
    statesA : iterable of hashable
        Bins making up the origin state A.
    statesB : iterable of hashable
        Bins making up the target state B.
    dt : float, optional
        Timestep of the trajectory (default 1).
    stride : int, optional
        Stride applied to each trajectory before lagging (default 1).

    Returns
    -------
    all_AB : list of list
        ``all_AB[i]`` holds every A->B passage time found at ``lags[i]``,
        expressed as ``frames * lag * stride * dt``.
    all_BA : list of list
        Same as ``all_AB``, for B->A passages.
    lags
        The ``lags`` argument, returned unchanged.
    """
    # Sets give O(1) membership tests in the per-frame inner loop.
    first = set(statesA)
    last = set(statesB)
    boundary = first | last

    all_AB, all_BA = [], []
    for lag in lags:
        AB, BA = [], []
        for traj in trajectory:
            strided = traj[::stride]
            # Use every offset so all frames contribute at this lag.
            for start in range(lag):
                ab, ba = _passage_times(strided[start::lag], first, last, boundary)
                AB.extend(ab)
                BA.extend(ba)
        # Convert frame counts in the lagged, strided sequence to time units.
        all_AB.append([x * lag * stride * dt for x in AB])
        all_BA.append([x * lag * stride * dt for x in BA])

    return all_AB, all_BA, lags


In [55]:
import numpy as np

# Synthetic benchmark input: one trajectory of uniformly random bins 0-9.
# Seeded so the timings and MFPTs below are reproducible on Restart & Run All.
SEED = 42
N_POINTS = 1_000_000
N_BINS = 10

rng = np.random.default_rng(SEED)
random_traj = [rng.integers(0, N_BINS, size=N_POINTS)]

In [59]:
%%time
# A = bin 0, B = bin 3, a single lag of 1 frame (dt and stride left at 1).
AB, BA, lags = brute_force_mfpt(random_traj, lags=[1], statesA=[0], statesB=[3])

CPU times: user 963 ms, sys: 0 ns, total: 963 ms
Wall time: 963 ms
