# Transition path theory

Transition path theory (TPT) is a method to study the ensemble of reactive trajectories, i.e., trajectories that come from a defined set of states $A$ and go to a set $B$ next, without returning to $A$ in between. It can determine the rate at which such transitions occur and reveal parallel pathways, traps, sequences of events, etc. Furthermore, it introduces the notion of 'committor functions', which give the probability of reaching set $A$ or set $B$ first, given that the trajectory starts in some state, potentially outside $A\cup B$.
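
For a discrete-time Markov chain with transition matrix $P$, the forward committor $q^+_i$ is the probability of reaching $B$ before $A$ when starting in state $i$. It solves $q^+_i = \sum_j P_{ij} q^+_j$ for $i \notin A\cup B$, with $q^+ = 0$ on $A$ and $q^+ = 1$ on $B$. The following minimal NumPy sketch solves that linear system; the function name `forward_committor` is ours, not part of sktime. For a reversible chain the backward committor is simply $q^- = 1 - q^+$.

```python
import numpy as np


def forward_committor(P, A, B):
    """Probability of hitting set B before set A, for every starting state.

    Solves q_i = sum_j P_ij q_j on intermediate states with the boundary
    conditions q = 0 on A and q = 1 on B.
    """
    n = P.shape[0]
    M = np.eye(n) - P          # rows encode (I - P) q = 0
    rhs = np.zeros(n)
    for i in list(A) + list(B):
        # Overwrite the rows of boundary states with q_i = rhs_i.
        M[i] = 0.0
        M[i, i] = 1.0
    rhs[list(B)] = 1.0
    return np.linalg.solve(M, rhs)


# Reversible random walk on five states; state 0 forms A and state 4 forms B.
P = np.array([
    [0.5, 0.5, 0.0, 0.0, 0.0],
    [0.5, 0.0, 0.5, 0.0, 0.0],
    [0.0, 0.5, 0.0, 0.5, 0.0],
    [0.0, 0.0, 0.5, 0.0, 0.5],
    [0.0, 0.0, 0.0, 0.5, 0.5],
])
q_plus = forward_committor(P, A=[0], B=[4])
q_minus = 1.0 - q_plus  # backward committor, valid for reversible chains only
print(q_plus)  # values 0, 0.25, 0.5, 0.75, 1 (gambler's ruin)
```

The committor of an interior state is the average of its neighbours' committors, which is why it grows linearly along this chain.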

A mathematical description of TPT can be found in <cite data-cite="nbtpt-weinan2006towards">(Weinan, 2006)</cite> and <cite data-cite="nbtpt-metzner2009transition">(Metzner, 2009)</cite>. The implementation is based on <cite data-cite="nbtpt-noe2009constructing">(Noe, 2009)</cite>. Coarse-graining by path decomposition is presented in <cite data-cite="nbtpt-noe2009constructing">(Noe, 2009)</cite> and <cite data-cite="nbtpt-berezhkovskii2009reactive">(Berezhkovskii, 2009)</cite>.
    
To demonstrate the TPT API ([API docs here](../api/generated/sktime.markov.ReactiveFlux.rst#sktime.markov.ReactiveFlux)), we use the example of a drunkard's walk. The example is motivated by <cite data-cite="nbtpt-doyle1984random">(Doyle, 1984)</cite> and <cite data-cite="nbtpt-valleriani2015circular">(Valleriani, 2015)</cite>: a drunkard is placed on a network of states that contains two special states, home and the bar. Once the drunkard reaches either of these special states, the trajectory stays there with high probability.
One can then ask which paths the drunkard can take and, given a current position, with which probability the drunkard will reach either of the two states.

To this end, we import sktime, as well as numpy for general numerical operations.

In [None]:
import numpy as np
import sktime

We can create a [DrunkardsWalk](../api/generated/sktime.data.drunkards_walk.rst#sktime.data.drunkards_walk) simulator by specifying the bar and home locations. As the drunkard lives on a 2-dimensional grid, the locations are given as integer coordinates. Internally, each coordinate is mapped to one of the $\mathrm{width}\times\mathrm{height}$ states of the Markov chain.
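
The coordinate-to-state mapping can be pictured as flattening the grid into a single index. The sketch below assumes row-major ordering via `np.ravel_multi_index`; the helper names mirror `sim.coordinate_to_state` and `sim.state_to_coordinate`, but they are hypothetical stand-ins, and whether sktime uses exactly this ordering is an assumption.

```python
import numpy as np

# Grid shape as used in this notebook: grid_size=(10, 10).
grid_size = (10, 10)


def coordinate_to_state(coord, grid_size):
    """Hypothetical: map an integer (x, y) coordinate to a flat state index (row-major)."""
    return int(np.ravel_multi_index(coord, grid_size))


def state_to_coordinate(state, grid_size):
    """Hypothetical inverse: map a flat state index back to an (x, y) coordinate."""
    return tuple(int(c) for c in np.unravel_index(state, grid_size))


home = [(8, 8), (8, 9), (9, 8), (9, 9)]
home_states = [coordinate_to_state(c, grid_size) for c in home]
print(home_states)  # [88, 89, 98, 99]
```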

In [None]:
sim = sktime.data.drunkards_walk(grid_size=(10, 10),
                                 bar_location=[(0, 0), (0, 1), (1, 0), (1, 1)], 
                                 home_location=[(8, 8), (8, 9), (9, 8), (9, 9)])

To make the scenario a bit more interesting, we can add hard and soft barriers by specifying their start and end points. If no weight is given, the barrier is hard, i.e., it cannot be crossed by a trajectory; with a weight, the barrier is soft and can be crossed, but with reduced probability.
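
To illustrate the difference between hard and soft barriers, here is a hypothetical model of a nearest-neighbour random walk on a grid: a hard barrier cell can never be entered, and the unnormalized probability of stepping into a soft barrier cell with weight $w$ is divided by $w$. This is an illustrative assumption, not the DrunkardsWalk implementation, which may treat barriers and weights differently.

```python
import numpy as np


def grid_random_walk(grid_size, barrier_weights):
    """Hypothetical barrier model: nearest-neighbour walk on a 2D grid.

    barrier_weights maps a cell (x, y) to a weight. np.inf makes the cell a hard
    barrier that can never be entered; a finite weight w > 1 makes it soft by
    dividing the unnormalized probability of stepping into it by w.
    """
    width, height = grid_size
    P = np.zeros((width * height, width * height))
    for x in range(width):
        for y in range(height):
            i = np.ravel_multi_index((x, y), grid_size)
            P[i, i] = 1.0  # self-transition keeps every row normalizable
            for dx, dy in [(-1, 0), (1, 0), (0, -1), (0, 1)]:
                nx, ny = x + dx, y + dy
                if 0 <= nx < width and 0 <= ny < height:
                    j = np.ravel_multi_index((nx, ny), grid_size)
                    # 1 / inf == 0, so hard barrier cells receive no probability.
                    P[i, j] = 1.0 / barrier_weights.get((nx, ny), 1.0)
            P[i] /= P[i].sum()  # normalize the outgoing probabilities
    return P


# 3x3 grid: hard barrier at (1, 1), soft barrier with weight 5 at (1, 0).
P = grid_random_walk((3, 3), {(1, 1): np.inf, (1, 0): 5.0})
corner = np.ravel_multi_index((0, 0), (3, 3))
print(P[corner])
```

From the corner $(0, 0)$, stepping onto the soft barrier at $(1, 0)$ is five times less likely than stepping to the free cell $(0, 1)$.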

In [None]:
sim.add_barrier((5, 1), (5, 5))
sim.add_barrier((0, 9), (5, 8))
sim.add_barrier((9, 2), (7, 6))
sim.add_barrier((2, 6), (5, 6))

sim.add_barrier((7, 9), (7, 7), weight=5.)
sim.add_barrier((8, 7), (9, 7), weight=5.)

sim.add_barrier((0, 2), (2, 2), weight=5.)
sim.add_barrier((2, 0), (2, 1), weight=5.)

Now we can simulate a trajectory on this grid by specifying a starting point and a number of simulation steps. The resulting trajectory may be shorter than the number of simulation steps, as the simulation stops once the drunkard reaches home or the bar.
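
The early stopping can be sketched with a generic sampler that draws each next state from the corresponding row of the transition matrix and halts once a stop state is hit. The function `walk` and the toy chain below are hypothetical; in particular, whether sktime's `walk` includes the start state in the returned trajectory is not assumed here.

```python
import numpy as np


def walk(P, start, n_steps, stop_states, seed=None):
    """Sample a trajectory of at most n_steps transitions from transition matrix P.

    Sampling stops early as soon as the chain enters one of stop_states, so the
    returned trajectory (which includes the start state) may be shorter than
    n_steps + 1.
    """
    rng = np.random.default_rng(seed)
    stop_states = set(stop_states)
    state = start
    traj = [state]
    for _ in range(n_steps):
        if state in stop_states:
            break
        state = int(rng.choice(len(P), p=P[state]))
        traj.append(state)
    return np.array(traj)


# Random walk on five states where 0 and 4 play the roles of bar and home.
P = np.array([
    [0.5, 0.5, 0.0, 0.0, 0.0],
    [0.5, 0.0, 0.5, 0.0, 0.0],
    [0.0, 0.5, 0.0, 0.5, 0.0],
    [0.0, 0.0, 0.5, 0.0, 0.5],
    [0.0, 0.0, 0.0, 0.5, 0.5],
])
traj = walk(P, start=2, n_steps=250, stop_states=[0, 4], seed=40)
print("Number of steps in the walk:", len(traj))
```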

In [None]:
start = (7, 2)
walk = sim.walk(start=start, n_steps=250, seed=40)
print("Number of steps in the walk:", len(walk))

The trajectory can be visualized with a few helper functions attached to the simulator:

In [None]:
import matplotlib as mpl
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

fig, ax = plt.subplots(figsize=(10, 10))

ax.scatter(*start, marker='*', label='Start', c='cyan', s=150, zorder=5)
handles, labels = sim.plot_2d_map(ax)
sim.plot_path(ax, walk)
ax.legend(handles=handles, labels=labels);

In [None]:
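# Reactive flux from the home states (set A) to the bar states (set B).
# The flux from the start position to either home or bar could be obtained analogously with
# sim.msm.reactive_flux([sim.coordinate_to_state(start)], np.concatenate([sim.home_state, sim.bar_state]))
flux = sim.msm.reactive_flux(sim.home_state, sim.bar_state)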
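
The reactive flux combines the stationary distribution $\pi$ with the committors. The gross flux is $f_{ij} = \pi_i q^-_i P_{ij} q^+_j$ for $i \neq j$, and the net flux is $f^+_{ij} = \max(0, f_{ij} - f_{ji})$. The total flux $F$ is the gross flux leaving $A$, and the rate per lag time is $k_{AB} = F / \sum_i \pi_i q^-_i$. The following NumPy sketch computes these quantities for a small reversible birth-death chain. It assumes reversibility, so that $q^- = 1 - q^+$, and it is not sktime's implementation.

```python
import numpy as np


def committor(P, A, B):
    """Forward committor: probability of hitting B before A from each state."""
    n = len(P)
    M = np.eye(n) - P
    rhs = np.zeros(n)
    for i in list(A) + list(B):
        M[i] = 0.0
        M[i, i] = 1.0
    rhs[list(B)] = 1.0
    return np.linalg.solve(M, rhs)


def stationary_distribution(P):
    """Left eigenvector of P for eigenvalue 1, normalized to sum to one."""
    evals, evecs = np.linalg.eig(P.T)
    pi = np.real(evecs[:, np.argmin(np.abs(evals - 1.0))])
    return pi / pi.sum()


def reactive_flux(P, A, B):
    pi = stationary_distribution(P)
    q_plus = committor(P, A, B)
    q_minus = 1.0 - q_plus  # assumes a reversible chain
    # Gross flux f_ij = pi_i q^-_i P_ij q^+_j, without self-transitions.
    gross = (pi * q_minus)[:, None] * P * q_plus[None, :]
    np.fill_diagonal(gross, 0.0)
    net = np.maximum(gross - gross.T, 0.0)
    total = gross[list(A)].sum()  # flux leaving A (f_ij = 0 for j in A)
    rate = total / np.dot(pi, q_minus)
    return gross, net, total, rate


# Reversible birth-death chain with pi = (1/6, 1/3, 1/3, 1/6); A = {0}, B = {3}.
P = np.array([
    [0.50, 0.50, 0.00, 0.00],
    [0.25, 0.50, 0.25, 0.00],
    [0.00, 0.25, 0.50, 0.25],
    [0.00, 0.00, 0.50, 0.50],
])
gross, net, total, rate = reactive_flux(P, A=[0], B=[3])
print(total, rate)  # approximately 1/36 and 1/18
```

In this linear chain, the net flux is the same along every edge ($1/36$), which reflects flux conservation between $A$ and $B$.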

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(18, 10))
dividers = [make_axes_locatable(axes[i]) for i in range(len(axes))]
caxes = [divider.append_axes("right", size="5%", pad=0.05) for divider in dividers]

titles = ["Gross flux", "Net flux"]
fluxes = [flux.gross_flux.m, flux.net_flux.m]

cmap = plt.cm.copper_r
thresh = [0, 1e-12]

for i in range(len(axes)):
    ax = axes[i]
    F = fluxes[i]
    ax.set_title(titles[i])

    vmin = np.min(F[np.nonzero(F)])
    vmax = np.max(F)

    sim.plot_2d_map(ax)
    sim.plot_network(ax, F, cmap=cmap, connection_threshold=thresh[i])
    norm = mpl.colors.LogNorm(vmin=vmin, vmax=vmax)
    fig.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap), cax=caxes[i]);

In [None]:
paths, capacities = flux.pathways()
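
`flux.pathways()` decomposes the net flux into individual pathways with capacities. One common scheme repeatedly extracts the path from $A$ to $B$ whose smallest net-flux edge (its bottleneck) is largest, records that bottleneck as the path's capacity, and subtracts it along the path. The sketch below implements this scheme and finds the widest path with a Dijkstra-style search. The function names are hypothetical, and we only assume that sktime follows a similar idea; tie-breaking and stopping criteria may differ.

```python
import heapq

import numpy as np


def widest_path(F, sources, targets, tol=1e-12):
    """Path from a source to a target that maximizes the smallest edge weight."""
    n = len(F)
    best = np.zeros(n)            # best bottleneck found so far for each node
    prev = np.full(n, -1)         # predecessor on the best path
    done = np.zeros(n, dtype=bool)
    heap = []
    for s in sources:
        best[s] = np.inf
        heapq.heappush(heap, (-np.inf, s))
    while heap:
        neg_width, u = heapq.heappop(heap)
        if done[u]:
            continue
        done[u] = True
        if u in targets:
            path = [u]
            while prev[path[-1]] != -1:
                path.append(int(prev[path[-1]]))
            return path[::-1], -neg_width
        for v in np.flatnonzero(F[u] > tol):
            width = min(-neg_width, F[u, v])
            if width > best[v]:
                best[v] = width
                prev[v] = u
                heapq.heappush(heap, (-width, int(v)))
    return None, 0.0


def decompose_pathways(net_flux, A, B, max_paths=10):
    """Greedy bottleneck decomposition of a net flux matrix into pathways."""
    F = np.array(net_flux, dtype=float)
    targets = set(B)
    paths, capacities = [], []
    for _ in range(max_paths):
        path, capacity = widest_path(F, A, targets)
        if path is None:
            break
        for u, v in zip(path[:-1], path[1:]):
            F[u, v] -= capacity   # remove this pathway's share of the flux
        paths.append(path)
        capacities.append(float(capacity))
    return paths, capacities


# Two parallel channels from A = {0} to B = {3}: 0->1->3 (0.3) and 0->2->3 (0.1).
net = np.zeros((4, 4))
net[0, 1] = net[1, 3] = 0.3
net[0, 2] = net[2, 3] = 0.1
paths, capacities = decompose_pathways(net, A=[0], B=[3])
print(paths, capacities)  # [[0, 1, 3], [0, 2, 3]] [0.3, 0.1]
```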

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(15, 10))
dividers = [make_axes_locatable(axes[i]) for i in range(len(axes))]
caxes = [divider.append_axes("right", size="5%", pad=0.05) for divider in dividers]
titles = ["Forward committor", "Backward committor"]

for i, ax in enumerate(axes):
    ax.set_title(titles[i])
    ax.scatter(*start, marker='*', label='Start', c='cyan', s=150, zorder=5)
    handles, labels = sim.plot_2d_map(ax, barrier_mode='hollow')

    for capacity, path in zip((capacities / np.array(capacities).sum())[:10], paths[:10]):
        path = np.array([sim.state_to_coordinate(state) for state in path])
        sim.plot_path(ax, path, lw=1., intermediates=False, color_lerp=False)
        ax.scatter(*path.T, marker='x')
    
    if i == 0:
        Q = flux.forward_committor.reshape(sim.grid_size)
    else:
        Q = flux.backward_committor.reshape(sim.grid_size)
    cb = ax.imshow(Q, interpolation='nearest', origin='lower', cmap='coolwarm')
    fig.colorbar(cb, cax=caxes[i])
    if i == 0:
        fig.delaxes(fig.axes[2])

    ax.legend(handles=handles, labels=labels)


plt.tight_layout()

In [None]:
gross_flux = flux.gross_flux.m

In [None]:
gross_flux.shape

In [None]:
x = np.arange(-1/2, sim.grid_size[0]-1 + 1/2, 1/2.)
y = np.arange(-1/2, sim.grid_size[1]-1 + 1/2, 1/2.)
X, Y = np.meshgrid(x, y, indexing='ij')
U = np.zeros_like(X)
V = np.zeros_like(X)

In [None]:
F = flux.net_flux.m

In [None]:
for i in range(sim.grid_size[0]):
    for j in range(sim.grid_size[1]):
        i_grid = 2*i + 1
        j_grid = 2*j + 1
        coord = (i, j)
        state = sim.coordinate_to_state(coord)
        
        for offset_i in [-1, 0, 1]:
            for offset_j in [-1, 0, 1]:
                neighbor_coord = (i + offset_i, j + offset_j)
                if sim.is_valid_coordinate(neighbor_coord):
                    neighbor_state = sim.coordinate_to_state(neighbor_coord)
                    neighbor_F = F[state, neighbor_state]
                    U[i_grid + offset_i, j_grid + offset_j] += offset_i * neighbor_F
                    V[i_grid + offset_i, j_grid + offset_j] += offset_j * neighbor_F

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(10, 10))

ax.scatter(*start, marker='*', label='Start', c='cyan', s=150, zorder=5)

sim.plot_2d_map(ax)

for capacity, path in zip(capacities[:15], paths[:15]):
    path = np.array([sim.state_to_coordinate(state) for state in path])
    sim.plot_path(ax, path, intermediates=False, color_lerp=False, alpha=.3)

C = np.linalg.norm(np.vstack((U.flatten(),V.flatten())),axis=0)
ax.quiver(X, Y, U, V, C, scale=8*flux.total_flux.m)
ax.set_xlim([-.5, sim.grid_size[0]-.5])
ax.set_ylim([-.5, sim.grid_size[1]-.5])

In [None]:
pcca = sim.msm.pcca(6)

In [None]:
fig, axes = plt.subplots(2, 3, figsize=(15, 10), sharex=True, sharey=True)

for i, ax in enumerate(axes.flatten()):

    handles, labels = sim.plot_2d_map(ax, barrier_mode='hollow')

    Q = pcca.memberships[:, i].reshape(sim.grid_size)
    cb = ax.imshow(Q, interpolation='nearest', origin='lower');

In [None]:
from tqdm.notebook import tqdm
trajs = []
for _ in tqdm(range(1000)):
    trajs.append(sim.walk(start=start, n_steps=2000, return_states=True, stop=False))

In [None]:
count_model = sktime.markov.TransitionCountEstimator(1, 'sliding', n_states=sim.n_states) \
    .fit(trajs).fetch_model()
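
The `'sliding'` count mode counts every pair of states that are one lag time apart (here lag time 1, as passed above), shifting the window one step at a time along each trajectory. A minimal NumPy sketch of such sliding-window counting; the function `count_matrix` is ours, for illustration only:

```python
import numpy as np


def count_matrix(trajs, lagtime, n_states):
    """Sliding-window transition counts.

    C[i, j] counts every time t with traj[t] == i and traj[t + lagtime] == j,
    over all trajectories. Requires lagtime >= 1.
    """
    C = np.zeros((n_states, n_states), dtype=int)
    for traj in trajs:
        traj = np.asarray(traj)
        # Pair each frame with the frame lagtime steps later; np.add.at handles
        # repeated (i, j) pairs correctly, unlike fancy-index assignment.
        np.add.at(C, (traj[:-lagtime], traj[lagtime:]), 1)
    return C


trajs = [[0, 1, 1, 2, 1, 0], [2, 2, 1]]
C = count_matrix(trajs, lagtime=1, n_states=3)
print(C)  # rows: [0 1 0], [1 1 1], [0 2 1]
```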

In [None]:
count_model = count_model.submodel_largest()
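
`submodel_largest()` restricts the count model to its largest connected set of states, so that the estimated chain is irreducible and has a unique stationary distribution. The sketch below finds strongly connected sets with Warshall's transitive closure in plain NumPy and computes the fractions of states and counts that remain, analogous to `mlmsm.state_fraction` and `mlmsm.count_fraction`. Whether sktime selects the largest set by number of states, as done here, and defines the fractions exactly this way are assumptions.

```python
import numpy as np


def largest_strongly_connected_set(C):
    """States of the largest strongly connected set of the count graph C."""
    n = len(C)
    reach = (np.asarray(C) > 0) | np.eye(n, dtype=bool)
    for k in range(n):
        # Warshall's algorithm: i reaches j if i reaches k and k reaches j.
        reach |= reach[:, [k]] & reach[[k], :]
    mutual = reach & reach.T   # i and j reach each other
    sets = {tuple(np.flatnonzero(row)) for row in mutual}
    return np.array(max(sets, key=len))


# States 0-2 communicate; state 3 is entered from 2 but never left.
C = np.array([
    [0, 2, 0, 0],
    [1, 0, 1, 0],
    [0, 1, 0, 1],
    [0, 0, 0, 0],
])
active = largest_strongly_connected_set(C)
state_fraction = len(active) / len(C)
count_fraction = C[np.ix_(active, active)].sum() / C.sum()
print(active, state_fraction, count_fraction)  # [0 1 2] 0.75 0.833...
```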

In [None]:
mlmsm = sktime.markov.msm.MaximumLikelihoodMSM().fit(count_model).fetch_model()
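
Given a count matrix $C$, the non-reversible maximum likelihood estimate is simply $P_{ij} = C_{ij} / \sum_k C_{ik}$. A reversible estimate, which we assume is the default of `MaximumLikelihoodMSM`, can be obtained with the standard self-consistent iteration $x_{ij} \leftarrow (c_{ij} + c_{ji}) / (c_i / x_i + c_j / x_j)$, where $c_i = \sum_k c_{ik}$ and $x_i = \sum_k x_{ik}$. The sketch below implements both estimates for illustration; it is not sktime's solver.

```python
import numpy as np


def mle_nonreversible(C):
    """Row-normalized counts: the non-reversible maximum likelihood estimate."""
    C = np.asarray(C, dtype=float)
    return C / C.sum(axis=1, keepdims=True)


def mle_reversible(C, max_iter=10000, tol=1e-12):
    """Reversible MLE via the self-consistent iteration; returns (P, pi)."""
    C = np.asarray(C, dtype=float)
    c = C.sum(axis=1)            # row sums of the count matrix
    S = C + C.T
    X = S.copy()                 # symmetric initial guess
    for _ in range(max_iter):
        x = X.sum(axis=1)
        X_new = S / (c[:, None] / x[:, None] + c[None, :] / x[None, :])
        converged = np.max(np.abs(X_new - X)) < tol
        X = X_new
        if converged:
            break
    # X stays symmetric, so P = X / rowsum(X) satisfies detailed balance
    # with respect to pi proportional to the row sums of X.
    x = X.sum(axis=1)
    return X / x[:, None], x / x.sum()


C = np.array([[0, 1, 0], [1, 1, 1], [0, 2, 1]])
P_nonrev = mle_nonreversible(C)
P_rev, pi = mle_reversible(C)
print(P_nonrev)
print(P_rev)
```

Because the iterate $X$ stays symmetric, detailed balance $\pi_i P_{ij} = \pi_j P_{ji}$ holds for every iterate, not only at convergence.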

In [None]:
# Reactive flux from home (A) to bar (B) on the estimated MSM. The MSM only covers the
# largest connected set, so the simulator's state indices (symbols) are first mapped
# to the MSM's state indices.
flux = mlmsm.reactive_flux(
    mlmsm.count_model.symbols_to_states(sim.home_state),
    mlmsm.count_model.symbols_to_states(sim.bar_state))

In [None]:
print("Count fraction:", mlmsm.count_fraction)
print("State fraction:", mlmsm.state_fraction)

In [None]:
paths, capacities = flux.pathways()

In [None]:
fig, ax = plt.subplots(figsize=(10, 10))

ax.scatter(*start, marker='*', label='Start', c='cyan', s=150, zorder=5)
handles, labels = sim.plot_2d_map(ax, barriers=False)

for capacity, path in zip((capacities / np.array(capacities).sum())[:10], paths[:10]):
    path = mlmsm.count_model.states_to_symbols(path)
    path = np.array([sim.state_to_coordinate(state) for state in path])
    sim.plot_path(ax, path, lw=1., intermediates=False, color_lerp=False)
    ax.scatter(*path.T, marker='x')

# The committor is only defined on the MSM's active set; all other cells remain NaN.
Q = np.full(sim.n_states, np.nan)
Q[mlmsm.state_symbols()] = flux.forward_committor
Q = Q.reshape(sim.grid_size)

cb = ax.imshow(Q, interpolation='nearest', origin='lower')
fig.colorbar(cb, ax=ax)

ax.legend(handles=handles, labels=labels);