# Independent Markov decomposition

In this notebook, we will explain how to split a global system into weakly coupled subsystems with independent Markov decomposition (IMD) [<a id="ref-1" href="#cite-imd">1</a>,<a id="ref-2" href="#cite-syt">2</a>]. Using a test system, we will show how to find an optimal partition into Markov-independent subsystems and how to model them independently.



**Remember**:
- to run the currently highlighted cell, hold <kbd>&#x21E7; Shift</kbd> and press <kbd>&#x23ce; Enter</kbd>;
- to get help for a specific function, place the cursor within the function's brackets, hold <kbd>&#x21E7; Shift</kbd>, and press <kbd>&#x21E5; Tab</kbd>;
- you can find the full documentation at [PyEMMA.org](http://www.pyemma.org).

In [None]:
import numpy as np
from tqdm.notebook import tqdm
from matplotlib import pyplot as plt

import itertools
import networkx as nx

import mdshare
from deeptime.markov.msm import MaximumLikelihoodMSM, MarkovStateModel

## state mapping
We first look at two different representations of a global system state. Imagine a system that consists of two subsystems, each of which can exist in three states. On the one hand, we can write the system's state as a tuple, e.g. `(0, 2)` for the first subsystem being in state `0` and the second one in state `2`. On the other hand, we can encode the tuple as a single integer, compressing the information into one number. For the example system, the table of all possible states looks like this:

| **state integer** | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
| ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- |
| **state tuple** | (0, 0) | (0, 1) | (0, 2) | (1, 0) | (1, 1) | (1, 2) | (2, 0) | (2, 1) | (2, 2) |

This mapping generalizes to arbitrary numbers of subsystems, each with an arbitrary number of states.
The notion here is that the **integer** describes the global system's state, whereas the **tuple** encodes each local system's state individually.
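
Concretely, numpy's default row-major ("C") ordering reads the tuple as the digits of a number in base `n_states`, with the last subsystem varying fastest: for two subsystems with three states each, `(s0, s1)` maps to `3 * s0 + s1`, so `(2, 1)` maps to `7`, as in the table. A minimal plain-Python sketch of both directions for arbitrary shapes (the helper names `tuple_to_int` and `int_to_tuple` are hypothetical, for illustration only):

```python
def tuple_to_int(local_states, shape):
    """Map a tuple of local states to a global state integer (row-major order)."""
    index = 0
    for state, n in zip(local_states, shape):
        # shift the digits so far one place left in base n, then add the new digit
        index = index * n + state
    return index


def int_to_tuple(index, shape):
    """Map a global state integer back to a tuple of local states (row-major order)."""
    local_states = []
    for n in reversed(shape):
        # the last subsystem varies fastest, i.e. it is the least significant digit
        index, state = divmod(index, n)
        local_states.append(state)
    return tuple(reversed(local_states))


print(tuple_to_int((2, 1), (3, 3)))  # 7
print(int_to_tuple(7, (3, 3)))  # (2, 1)
```

These correspond to what `np.ravel_multi_index` and `np.unravel_index` do for a single state with the default `order='C'`.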

In practice, system states can be converted between the tuple (local states) and the integer (global state) using numpy. We only have to provide the corresponding *shape* of the system, `(3, 3)` in our case. Here's our example:

In [None]:
n_systems = 2  # number of local systems (tuple length)
n_states = 3  # number of states per local system
integer_trajectory = np.arange(9)  # global states (cf. first line of above table)
# this could be a time series!

shape = (n_states,) * n_systems  # one entry per local system
print('shape for unraveling: ', shape)

tuple_trajectory = np.vstack(
                     np.unravel_index(integer_trajectory, shape)
)
print('unraveled states:')
print(tuple_trajectory)

We see that numpy has converted our `integer_trajectory` into two separate trajectories, each representing the state of one local subsystem.

In [None]:
print('int \t tuple')
for t in range(9):
    int_state = integer_trajectory[t]
    subsys0_state = tuple_trajectory[0][t]
    subsys1_state = tuple_trajectory[1][t]
    print(f'{int_state} \t ({subsys0_state}, {subsys1_state})')

Looks familiar?

**Task:**
Please do the inverse operation: map back from the tuple trajectories into the space of full-system integers. There is a numpy function for this task.

In [None]:
remapped_int_traj = #FIXME

np.all(integer_trajectory == remapped_int_traj)

In [None]:
# solution
remapped_int_traj = np.ravel_multi_index(tuple_trajectory, 
                                         tuple((n_states for _ in range(n_systems))))

np.all(integer_trajectory == remapped_int_traj)

## A system of unknown structure

You are now given discrete data for a system of unknown structure. The task is a) to identify weakly coupled subsystems and b) to approximate one such subsystem using an independent MSM.

**Hint:** The system consists of ten 2-state subsystems, i.e., has a total of $2^{10}=1024$ states. Some of the subsystems are strongly coupled, others have weak couplings only.

**Task:** Please define the number of subsystems and the number of states per subsystem.

In [None]:
n_systems = #FIXME
n_states = #FIXME

In [None]:
# solution
n_systems = 10  # number of local systems
n_states = 2  # number of states per local system

### Data
First, we load the data. The trajectory was obtained by first defining a *global* transition matrix and then using a Markov chain sampler to create a time series from that matrix (saved every 20 steps). The *global* transition matrix enumerates its states with (global) state integers, so the loaded trajectory uses them as well.

In [None]:
file = mdshare.fetch('imd_full_system_trajectory.npy', working_directory='data')
full_sys_traj = np.load(file)

In order to check *dependencies* between subsystems, we first need to retrieve the subsystem time series.

**Task:** Compute the individual subsystem state trajectories as done above.

In [None]:
subsys_trajs = # FIXME

In [None]:
# solution
subsys_trajs = np.vstack(
    np.unravel_index(full_sys_traj, tuple((n_states for _ in range(n_systems))))
)

### Define dependency score
We now define the *dependency* score:

In [None]:
def compute_dependency(tmat12, tmat1, tmat2, score='frobenius'):
    """
    compute dependency score between two systems
    :param tmat12: np.ndarray, transition matrix in joint space
    :param tmat1: np.ndarray, transition matrix in subsystem 1
    :param tmat2: np.ndarray, transition matrix in subsystem 2
    :param score: str, optional, matrix norm. one of frobenius, tracenorm.
    :return: float, dependency score
    """
    if score == 'frobenius':
        d = np.linalg.norm(tmat12, ord='fro')**2 - \
                    np.linalg.norm(tmat1, ord='fro')**2 * np.linalg.norm(tmat2, ord='fro')**2
    elif score == 'tracenorm':
        d = np.linalg.norm(tmat12, ord='nuc') - \
                    np.linalg.norm(tmat1, ord='nuc') * np.linalg.norm(tmat2, ord='nuc')
    else:
        raise NotImplementedError('score must be one of frobenius, tracenorm.')
        
    return abs(d)
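
Written out, the function above computes, for the Frobenius norm and the trace (nuclear) norm respectively,

$$
d_F = \left| \lVert T_{12} \rVert_F^2 - \lVert T_1 \rVert_F^2 \, \lVert T_2 \rVert_F^2 \right|, \qquad
d_* = \left| \lVert T_{12} \rVert_* - \lVert T_1 \rVert_* \, \lVert T_2 \rVert_* \right|.
$$

If two subsystems evolve independently, their joint transition matrix factorizes as a Kronecker product, $T_{12} = T_1 \otimes T_2$, with joint states indexed as by `np.ravel_multi_index`. The singular values of $T_1 \otimes T_2$ are all pairwise products of the singular values of $T_1$ and $T_2$, so both norms are multiplicative, $\lVert T_1 \otimes T_2 \rVert = \lVert T_1 \rVert \, \lVert T_2 \rVert$, and both scores vanish. A large $d$ therefore indicates that the joint dynamics deviate from this product form, i.e., that the two subsystems are coupled.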


To compute the score on a pair of subsystems, we need to evaluate 
- the transition matrix of subsystem 1

- the transition matrix of subsystem 2

- the transition matrix in the joint space
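
A small numerical sanity check of the score, assuming two identical 2-state subsystems (the matrices and the coupling scheme are made up for illustration). The sketch re-implements the score inline so it runs on its own. For independent subsystems the joint matrix is `np.kron(tmat, tmat)` and the score is zero up to rounding; for a coupled variant in which both subsystems stay put together half of the time, it is not:

```python
import numpy as np


def dependency(tmat12, tmat1, tmat2, norm_ord):
    """Same score as compute_dependency: squared norms for 'fro', plain norms for 'nuc'."""
    p = 2 if norm_ord == 'fro' else 1
    return abs(np.linalg.norm(tmat12, ord=norm_ord)**p
               - (np.linalg.norm(tmat1, ord=norm_ord) * np.linalg.norm(tmat2, ord=norm_ord))**p)


# a 2-state subsystem that switches with probability 0.1 per step
tmat = np.array([[0.9, 0.1],
                 [0.1, 0.9]])
eye = np.eye(2)

# independent: Kronecker product; joint index s1 * 2 + s2 matches
# np.ravel_multi_index((s1, s2), (2, 2))
independent = np.kron(tmat, tmat)

# coupled: with probability 0.5 both subsystems stay put together,
# otherwise they move independently
coupled = 0.5 * np.kron(tmat, tmat) + 0.5 * np.kron(eye, eye)
# single-subsystem (marginal) dynamics of the coupled model
marginal = 0.5 * tmat + 0.5 * eye

for norm_ord in ['fro', 'nuc']:
    print(norm_ord,
          dependency(independent, tmat, tmat, norm_ord),     # zero up to rounding
          dependency(coupled, marginal, marginal, norm_ord))  # clearly nonzero
```

Note that the coupled case is compared against its own marginal matrices, which is what the notebook does when it estimates each subsystem's transition matrix from its trajectory.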

Let's start with the single subsystems.

In [None]:
# we will store the results in numpy arrays.
single_tmats = np.empty((n_systems, n_states, n_states))

**Task:** Compute each subsystem's transition matrix and store it in the above array.

In [None]:
for n in range(n_systems):
    single_tmats[n] = #FIXME

In [None]:
# solution:
for n in range(n_systems):
    msm = MaximumLikelihoodMSM(lagtime=1).fit_fetch(subsys_trajs[n])
    single_tmats[n] = msm.transition_matrix
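
For intuition about what the estimator does with each subsystem trajectory: an MSM estimate starts from a transition count matrix at the chosen lag time. A minimal sketch, assuming a single trajectory in which every state is visited, of sliding-window counting and of the *non-reversible* maximum likelihood estimate, which is simply the row-normalized count matrix. deeptime's `MaximumLikelihoodMSM` can additionally enforce detailed balance via its `reversible` parameter, so its result need not match this sketch exactly. The helper names are hypothetical:

```python
import numpy as np


def count_matrix(dtraj, n_states, lagtime=1):
    """Count transitions i -> j between frames t and t + lagtime (sliding window)."""
    counts = np.zeros((n_states, n_states))
    # np.add.at accumulates correctly when the same (i, j) pair occurs several times
    np.add.at(counts, (dtraj[:-lagtime], dtraj[lagtime:]), 1)
    return counts


def row_normalize(counts):
    """Non-reversible ML estimate T_ij = C_ij / sum_k C_ik (unvisited rows would divide by zero)."""
    return counts / counts.sum(axis=1, keepdims=True)


dtraj = np.array([0, 0, 1, 1, 0, 1])
counts = count_matrix(dtraj, n_states=2)
# transitions: 0->0 once, 0->1 twice, 1->0 once, 1->1 once
print(counts)
print(row_normalize(counts))
```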

We now compute the joint transition matrix for every pair of subsystems.

In [None]:
# the joint space of a pair of subsystems has n_states**2 states
joint_tmats = np.empty((n_systems, n_systems, n_states**2, n_states**2))

In [None]:
# compute pairwise transition matrices
for n1, n2 in itertools.combinations(range(n_systems), 2):
    dtraj_system1 = subsys_trajs[n1]
    dtraj_system2 = subsys_trajs[n2]
    
    # combine both system states into a global number
    # note that the number of systems in the *pair* is 2.
    combined_dtraj = np.ravel_multi_index((dtraj_system1, dtraj_system2), 
                                         tuple((n_states for _ in range(2))))
    
    msm = MaximumLikelihoodMSM(lagtime=1).fit_fetch(combined_dtraj)
    joint_tmats[n1, n2] = msm.transition_matrix

### graph analysis
We now compute dependencies for all pairs of systems and store them in a `networkx` graph. 

**Task**: Compute the dependency for all edges using the above defined function.

In [None]:
# compute different scores and store in a networkx graph object
graph_fronorm = nx.Graph()
graph_trace = nx.Graph()

# for all pairs of subsystems, compute dependency scores with Frobenius and trace norm
for n1, n2 in itertools.combinations(range(n_systems), 2):
    # compute with trace norm
    d = # FIXME
    graph_trace.add_edge(n1, n2, weight=d)
    
    # compute with frobenius norm
    d = #FIXME
    graph_fronorm.add_edge(n1, n2, weight=d)

In [None]:
# solution

# compute different scores and store in a networkx graph object
graph_fronorm = nx.Graph()
graph_trace = nx.Graph()

# for all pairs of subsystems, compute dependency scores with Frobenius and trace norm
for n1, n2 in itertools.combinations(range(n_systems), 2):
    d = compute_dependency(joint_tmats[n1, n2], 
                           single_tmats[n1], 
                           single_tmats[n2], 
                           score='tracenorm')
    graph_trace.add_edge(n1, n2, weight=d)
    

    d = compute_dependency(joint_tmats[n1, n2], 
                           single_tmats[n1], 
                           single_tmats[n2], 
                           score='frobenius')
    graph_fronorm.add_edge(n1, n2, weight=d)

### Draw the graph
We now have an edge-weighted graph, i.e., a network of subsystems (nodes) that are connected by their *dependency* (edges). We can use this graph to identify clusters of strongly coupled subsystems.

In [None]:
# some plot properties
_c = (0., 0., 0.)
nodesize = 35
edge_cmap = plt.matplotlib.colors.LinearSegmentedColormap.from_list("uwe", [(*_c, 0.025), (*_c, 1)])
font = plt.matplotlib.font_manager.FontProperties(size=12)

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(7.5, 3), gridspec_kw={'hspace':.25})
names = ['trace norm', 'frobenius norm']
for n_graph, graph in enumerate([graph_trace, graph_fronorm]):
    
    ax = axes[n_graph]
    # positions by Fruchterman-Reingold
    pos_dict = nx.spring_layout(graph, k=0.75 if n_graph == 0 else 0.4)
    ax.set_title(names[n_graph])

    weights = np.array(list(nx.get_edge_attributes(graph, 'weight').values()))
    
    # draw nodes
    nx.draw_networkx_nodes(graph, node_shape='s',
                           node_size=nodesize, 
                           pos=pos_dict,
                          ax=ax)
    nx.draw_networkx_labels(graph, pos=pos_dict, ax=ax, 
                            font_color='red', font_weight='bold', font_size=15)
    # draw all edges
    pc = nx.draw_networkx_edges(graph, edge_cmap=edge_cmap,
                     edge_color=weights, width=2.4,
                     pos=pos_dict, node_size=nodesize,
                     ax=ax,
                    )
    
    # define colormap
    pc.set_array(weights)
    pc.set_cmap(edge_cmap)

    cb = fig.colorbar(pc, ax=ax,
                      aspect=25, pad=.15)
    cb.set_label(r'$d$')
    cbarticks = cb.ax.yaxis.get_ticklabels()
    
    # set font properties
    for _t in list(cbarticks):
        _t.set_font_properties(font)
    ax.axis('off');

### Interpretation:
You should see the nodes group into two strongly coupled clusters. The *dependency* is large within each cluster and low between the clusters. The node labels tell you which subsystems belong to which cluster; they can be used to extract those subsystems for individual modeling.
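
The clusters can also be read off programmatically. One simple approach, which is our own assumption and not necessarily the partitioning procedure used in the IMD papers, is to discard all edges whose dependency lies below a chosen threshold and take the connected components of the remaining graph (networkx provides `nx.connected_components` for this). A dependency-free sketch; the toy scores and the threshold are made up for illustration:

```python
from collections import defaultdict


def clusters_above_threshold(weights, threshold):
    """Group nodes into clusters connected by edges with weight > threshold.

    weights: dict mapping (node_a, node_b) -> dependency score
    returns: list of sorted node lists, one per connected component
    """
    neighbors = defaultdict(set)
    nodes = set()
    for (a, b), w in weights.items():
        nodes.update((a, b))
        if w > threshold:  # keep only strong edges
            neighbors[a].add(b)
            neighbors[b].add(a)

    clusters, seen = [], set()
    for start in sorted(nodes):
        if start in seen:
            continue
        # depth-first search over strong edges only
        stack, component = [start], []
        seen.add(start)
        while stack:
            node = stack.pop()
            component.append(node)
            for other in neighbors[node]:
                if other not in seen:
                    seen.add(other)
                    stack.append(other)
        clusters.append(sorted(component))
    return clusters


# toy scores: {0, 1, 2} and {3, 4} are strongly coupled, the rest only weakly
toy = {(0, 1): 0.9, (1, 2): 0.8, (0, 2): 0.7, (3, 4): 0.95,
       (2, 3): 0.01, (0, 4): 0.02}
print(clusters_above_threshold(toy, threshold=0.1))  # [[0, 1, 2], [3, 4]]
```

With the graphs built above, the input could be assembled as `{(a, b): attr['weight'] for a, b, attr in graph_fronorm.edges(data=True)}`; a suitable threshold depends on the range of scores shown in the colorbar.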

## Modeling a single cluster independently

Now that we have found a partition into two clusters, we can estimate a model for one of them, ignoring the weak coupling between the clusters. (In practice, one would model both clusters independently; for the sake of time, we only look at one here. In this particular example, both clusters behave identically anyway.)

**Task:** Please choose a set of subsystems to be modeled independently of the rest.

In [None]:
system_nodes = #FIXME

In [None]:
# solution
system_nodes = [0, 2, 3, 4, 5] # or [1, 6, 7, 8, 9]

Next, we extract the trajectories of these subsystems from the data and combine them into a single integer trajectory that describes the joint state of that set of subsystems.

In [None]:
# subsystem indexing ordered to match resulting matrices
subsystem_trajectory = np.ravel_multi_index(
        np.array(subsys_trajs)[system_nodes], 
        tuple((n_states for _ in range(len(system_nodes))))
)

**Task:** Fit a maximum likelihood MSM to the subsystem-cluster trajectory; use a lag time of 1 step.

In [None]:
msm = # FIXME

In [None]:
# solution
msm = MaximumLikelihoodMSM(lagtime=1).fit_fetch(subsystem_trajectory)

Note that this transition matrix effectively models a lag time of 20 steps of the generating chain, because the trajectory was saved only every 20 steps.
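
The reason is the Markov property: recording a chain with transition matrix $T$ only every $\Delta t$ steps yields a Markov chain with transition matrix $T^{\Delta t}$, which is why the reference model in the next section uses `np.linalg.matrix_power(channel_tmat, dt)`. A small simulation sketch of this, with an arbitrary 2-state matrix and a stride of 5 instead of the data's 20 (both chosen only to keep the example fast):

```python
import numpy as np

rng = np.random.default_rng(42)

# arbitrary 2-state generating chain (example values only)
tmat = np.array([[0.95, 0.05],
                 [0.10, 0.90]])
dt = 5             # save every dt-th step (the data in this notebook uses 20)
n_steps = 200_000

# sample the chain: jump to state 0 with probability tmat[state, 0], else to state 1
p_to_0 = tmat[:, 0].tolist()
uniform = rng.random(n_steps).tolist()
traj = np.empty(n_steps, dtype=int)
state = 0
for t in range(n_steps):
    traj[t] = state
    state = 0 if uniform[t] < p_to_0[state] else 1

# keep every dt-th frame and estimate a lag-1 transition matrix from the saved frames
saved = traj[::dt]
counts = np.zeros((2, 2))
np.add.at(counts, (saved[:-1], saved[1:]), 1)
estimate = counts / counts.sum(axis=1, keepdims=True)

print(estimate)
print(np.linalg.matrix_power(tmat, dt))  # should agree up to sampling noise
```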

### compare transition matrices & implied timescales

In [None]:
# reference transition matrix (does not include weak couplings between the two clusters!)
channel_tmat = np.load(mdshare.fetch('imd_channel_transitionmatrix.npy', working_directory='data'))
dt = 20  # time step used for generating the data

# adjust lag time of generating matrix
ref_msm = MarkovStateModel(np.linalg.matrix_power(channel_tmat, dt))

In [None]:
fig, ax = plt.subplots(1, 2)
ax[0].imshow(msm.transition_matrix, norm=plt.matplotlib.colors.LogNorm())
ax[0].set_title('estimated')
ax[1].imshow(ref_msm.transition_matrix, norm=plt.matplotlib.colors.LogNorm())
ax[1].set_title('reference');

The transition matrices look very similar; however, a few pixels are empty (white). Even with 1,000,000 steps, not all transitions between the states of the chosen set of subsystems were sampled, and the resulting zero entries cannot be displayed on a logarithmic color scale.

In [None]:
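fig, ax = plt.subplots(1, 2, figsize=(8, 4))

# element-wise comparison of the two transition matrices
ax[0].plot(msm.transition_matrix.flat, ref_msm.transition_matrix.flat, '.')
ax[0].loglog()
ax[0].set_xlabel('estimated $T_{ij}$')
ax[0].set_ylabel('reference $T_{ij}$')

# implied timescales, in units of saved frames (lag time 1 for both models)
its_ref = ref_msm.timescales()
its_est = msm.timescales()

ax[1].plot(its_ref, 'r.', label='reference')
ax[1].plot(its_est, 'b.', label='estimate')
ax[1].semilogy()
ax[1].set_xlabel('timescale index')
ax[1].set_ylabel('implied timescale')
fig.legend()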

The implied timescales are well approximated. Because this model does not incorporate the weak coupling between the two clusters of nodes, it remains an approximation.
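
For reference, implied timescales follow from the eigenvalues $\lambda_i$ of a transition matrix estimated at lag time $\tau$ as $t_i = -\tau / \ln|\lambda_i|$, skipping the stationary eigenvalue $\lambda_0 = 1$; to our understanding, this is what `timescales()` evaluates. A minimal numpy sketch with an arbitrary 2-state matrix (the helper `implied_timescales` is hypothetical). It also shows that the timescale does not depend on the sampling stride, since $-\Delta t / \ln|\lambda_i^{\Delta t}| = -1 / \ln|\lambda_i|$:

```python
import numpy as np


def implied_timescales(tmat, lagtime=1):
    """t_i = -lagtime / ln|lambda_i| for all eigenvalues except the leading one."""
    eigvals = np.linalg.eigvals(tmat)
    # sort by magnitude, largest first, and drop the stationary eigenvalue 1
    eigvals = eigvals[np.argsort(-np.abs(eigvals))][1:]
    return -lagtime / np.log(np.abs(eigvals))


tmat = np.array([[0.9, 0.1],
                 [0.2, 0.8]])  # eigenvalues 1 and 0.7
print(implied_timescales(tmat))  # approx. 2.80 steps

# same process observed every dt steps: T^dt at lag time dt gives the same timescale
dt = 20
print(implied_timescales(np.linalg.matrix_power(tmat, dt), lagtime=dt))
```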

## References:

<a id="cite-imd"/><sup><a href=#ref-1>[1]</a></sup> Hempel, T.; del Razo, M. J.; Lee, C. T.; Taylor, B. C.; Amaro, R. E.; Noé, F. _Independent Markov Decomposition: Toward Modeling Kinetics of Biomolecular Complexes._ Proc Natl Acad Sci USA 2021, 118 (31), e2105230118. https://doi.org/10.1073/pnas.2105230118.

<a id="cite-syt"/><sup><a href=#ref-2>[2]</a></sup> Hempel, T.; Plattner, N.; Noé, F. _Coupling of Conformational Switches in Calcium Sensor Unraveled with Local Markov Models and Transfer Entropy._ J. Chem. Theory Comput. 2020, 16 (4), 2584–2593. https://doi.org/10.1021/acs.jctc.0c00043.
