In [3]:
%matplotlib ipympl

import joblib
import matplotlib.pyplot as plt
import matplotlib as mpl

In [20]:
import numpy as np
import scipy.spatial
import itertools as it
import networkx as nx

class Schedule:
    """Ordered sequence of ``(label, length)`` entries.

    ``labels`` is a plain Python list and ``lengths`` is a NumPy array of the
    same size; entry ``k`` of the schedule is ``(labels[k], lengths[k])``.
    """

    def __init__(self, labels=None, lengths=None):
        """Build a schedule from parallel sequences, or an empty one.

        Parameters
        ----------
        labels : sequence, optional
            Entry labels. ``None`` or an empty sequence gives an empty schedule.
        lengths : sequence, optional
            One length per label; must have the same size as ``labels``.
        """
        if labels is not None and len(labels) > 0:
            assert len(labels) == len(lengths)
            # Copy so that append()/del on this schedule never mutate the
            # caller's list (and so tuples are accepted as input).
            self.labels = list(labels)
            self.lengths = np.array(lengths)
        else:
            self.labels = []
            self.lengths = np.array([])

    def append(self, a, t):
        """Add the entry ``(a, t)`` at the end of the schedule."""
        self.labels.append(a)
        self.lengths = np.append(self.lengths, [t])

    def length(self):
        """Return the sum of all entry lengths."""
        return self.lengths.sum()

    def scale_lengths(self, t):
        """Multiply every entry length by the factor ``t``."""
        # The original signature omitted ``self``, so the method could not be
        # called on an instance at all.
        self.lengths = self.lengths * t

    def reduce(self):
        """Drop, in place, every entry whose length is exactly zero."""
        selectors = self.lengths != 0
        self.labels = list(it.compress(self.labels, selectors))
        self.lengths = self.lengths[selectors]

    def is_reduced(self):
        """Return True when no entry has a length of exactly zero."""
        return not np.any(self.lengths == 0)

    def __str__(self):
        return ''.join(['({}, {})'.format(a, t) for a, t in zip(self.labels, self.lengths)])

    def __getitem__(self, key):
        """Return one ``(label, length)`` tuple, or a new Schedule for a slice."""
        # NumPy integers (e.g. from np.argmax) are not instances of ``int``;
        # treat them as scalar indices too instead of falling into the slice path.
        if isinstance(key, (int, np.integer)):
            return (self.labels[key], self.lengths[key])
        else:
            return Schedule(self.labels[key], self.lengths[key])

    def __delitem__(self, key):
        del self.labels[key]
        self.lengths = np.delete(self.lengths, key)

    def __iter__(self):
        """Iterate over ``(label, length)`` pairs."""
        return zip(self.labels, self.lengths)

    def __len__(self):
        # Also defines truthiness: an empty schedule is falsy.
        return len(self.labels)

    
class MilestoningModel:
    """Milestoning model built on a graph of cells around anchor points.

    Each graph node is a cell (initially one per anchor); each graph edge
    between two neighbouring cells is a milestone. Trajectories are converted
    to schedules of ``((source, target), lifetime)`` entries, from which
    transition counts and accumulated times per milestone are estimated.
    """

    def __init__(self, anchors):
        """Create the cell graph for ``anchors``.

        Parameters
        ----------
        anchors : ndarray of shape (nanchors, ndim)
            Anchor coordinates. For ``ndim > 1`` neighbouring cells come from
            a Delaunay triangulation; for ``ndim == 1`` anchor ``i`` is joined
            to anchor ``i + 1`` (assumes the anchors are given in order along
            the coordinate -- TODO confirm).
        """
        nanchors, ndim = anchors.shape

        self._graph = nx.Graph()
        # Register every cell up front so that anchors without neighbours
        # still exist as graph nodes.
        self._graph.add_nodes_from(range(nanchors))
        if ndim > 1:
            # The original passed the undefined name ``points`` here.
            tri = scipy.spatial.Delaunay(anchors)
            indptr, indices = tri.vertex_neighbor_vertices
            for i in range(nanchors):
                self._graph.add_edges_from(
                    (i, int(j)) for j in indices[indptr[i]:indptr[i+1]])
        else:
            self._graph.add_edges_from((i, i+1) for i in range(nanchors-1))

        self._anchor_kdtree = scipy.spatial.cKDTree(anchors)
        # _parent_node[k] is the graph node (cell) that anchor k belongs to.
        self._parent_node = list(range(nanchors))
        self._node_anchors = {i: set([tuple(anchors[i])]) for i in range(nanchors)}
        self._sampled_edges = set()

        self._schedules = []

    @property
    def milestones(self):
        """List of milestones, each as the set of its two cell ids."""
        return [set(e) for e in self._graph.edges]

    @property
    def cell_anchors(self):
        """Mapping from cell id to the set of anchor coordinate tuples it holds."""
        return self._node_anchors

    @property
    def is_resolved(self):
        """True when every sampled jump corresponds to an edge of the graph."""
        return all([self._graph.has_edge(i, j) for i, j in self._sampled_edges])

    @property
    def unresolved_jumps(self):
        """Sorted list of sampled ``(source, target)`` jumps with no graph edge."""
        return sorted([(i, j) for i, j in self._sampled_edges
                       if not self._graph.has_edge(i, j)])

    def load_trajectory_data(self, trajs, dt=1):
        """Convert trajectories to milestone schedules and store them.

        Parameters
        ----------
        trajs : iterable of ndarray
            Each trajectory is passed to the anchor k-d tree, so it must be an
            array of points with the same dimensionality as the anchors.
        dt : float, optional
            Time credited to the current milestone for every frame-to-frame step.
        """
        for traj in trajs:
            # Cell visited at every frame: nearest anchor, mapped to its cell.
            node_path = [self._parent_node[k] for k in self._anchor_kdtree.query(traj)[1]]
            edge_path = list(zip(node_path[:-1], node_path[1:]))

            schedule = Schedule()

            # Everything before the first change of cell is discarded; the
            # schedule starts at the first crossing.
            start = 1
            for source, target in edge_path:
                if target != source:
                    schedule.append((source, target), dt)
                    break
                start += 1

            for source, target in edge_path[start:]:
                # Staying in the cell, or stepping back to the other side of
                # the current milestone, does not start a new milestone.
                if target in schedule.labels[-1]:
                    schedule.lengths[-1] += dt
                    continue
                schedule.append((source, target), dt)

            if schedule:
                self._schedules.append(schedule)
            self._sampled_edges |= set(schedule.labels)

    @staticmethod
    def _contract_schedule(entries, i, j):
        """Rewrite schedule entries after cell ``j`` has been merged into ``i``.

        ``entries`` is an iterable of ``((source, target), lifetime)``. Returns
        the parallel lists ``(labels, lengths)`` of the rewritten schedule.
        """
        labels, lengths = [], []
        for (source, target), lifetime in entries:
            # Relabel BEFORE comparing with the previous label: stored labels
            # never contain j, so an un-relabelled target == j could never be
            # recognised as a return across the current milestone.
            if source == j:
                source = i
            if target == j:
                target = i

            if source == target:
                # Crossing of the removed milestone: its time is credited to
                # the current milestone, or dropped if none has started yet.
                if labels:
                    lengths[-1] += lifetime
                continue

            if labels and target in labels[-1]:
                lengths[-1] += lifetime
                continue

            labels.append((source, target))
            lengths.append(lifetime)
        return labels, lengths

    def remove_milestone(self, i, j):
        """Remove the milestone between cells ``i`` and ``j`` by merging ``j`` into ``i``.

        The graph nodes are contracted, the anchors of ``j`` are reassigned to
        ``i``, and every stored schedule is rewritten for the merged cell.
        """
        self._graph = nx.contracted_nodes(self._graph, i, j, self_loops=False)
        # Anchors that mapped to j now map to i. The original assigned the
        # int ``i`` to ``self._parent_node``, destroying the lookup list.
        self._parent_node = [i if parent == j else parent
                             for parent in self._parent_node]
        self._node_anchors[i] |= self._node_anchors.pop(j)

        self._sampled_edges = set()

        schedules = []
        for schedule_old in self._schedules:
            labels, lengths = self._contract_schedule(schedule_old, i, j)
            if labels:
                schedules.append(Schedule(labels, lengths))
            self._sampled_edges |= set(labels)

        self._schedules = schedules

    def estimate(self):
        """Compute transition counts and times from the stored schedules.

        Sets ``_N`` (milestone-to-milestone transition counts), ``_T`` (time
        accumulated on each milestone before a counted transition), ``K``
        (``_N`` with each row divided by its sum) and ``t`` (``_T`` divided by
        the row sums of ``_N``). Rows without any counted transition divide
        by zero. The last entry of every schedule is never used as a source.
        """
        # Both orientations of an edge refer to the same milestone index.
        milestone_index = {}
        for m, (i, j) in enumerate(self._graph.edges):
            milestone_index[(i, j)] = m
            milestone_index[(j, i)] = m

        M = self._graph.number_of_edges()
        self._N = np.zeros((M, M), dtype=np.int32)
        self._T = np.zeros(M)

        for schedule in self._schedules:
            for (a, t), (b, _) in zip(schedule[:-1], schedule[1:]):
                # Jumps that are not graph edges (unresolved) are skipped.
                if a in milestone_index and b in milestone_index:
                    self._T[milestone_index[a]] += t
                    self._N[milestone_index[a], milestone_index[b]] += 1

        self.K = self._N / np.sum(self._N, axis=1)[:, np.newaxis]
        self.t = self._T / np.sum(self._N, axis=1)


In [21]:
# Load the stored trajectories and append a trailing axis of length 1 to each.
# NOTE(review): assumes the joblib file holds a dict-like of 1-D arrays -- confirm.
# NOTE(review): the absolute path is machine-specific; joblib.load unpickles, so only use trusted files.
trajs = [x[:, np.newaxis] for x in joblib.load('/data/p38a-SB2/short_md/path1/x.joblib').values()]

In [22]:
# Density histogram (1000 bins) of all trajectory frames pooled together.
plt.figure()

plt.hist(np.concatenate(trajs), 1000, density=True, histtype='step')
plt.ylabel('Empirical Density')
# Raw string: '\m' and '\A' are invalid escape sequences in a normal string
# literal; the text passed to matplotlib is unchanged.
_ = plt.xlabel(r'Reaction Coordinate ($\mathrm{\AA}$)')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [26]:
# Anchors at 13, 14, ..., 29 as a column vector of shape (17, 1).
anchors = np.arange(13, 30, 1)[:, np.newaxis]

In [27]:
# Build the cell graph from the anchors, then convert the trajectories to
# milestone schedules, crediting dt=0.1 per frame-to-frame step.
model = MilestoningModel(anchors)
model.load_trajectory_data(trajs, dt=0.1)

In [28]:
# Sampled (source, target) jumps that have no corresponding edge in the model graph.
model.unresolved_jumps

[]

In [29]:
# Fill model._N, model._T, model.K and model.t from the loaded schedules.
model.estimate()

In [33]:
# Time accumulated on milestone index 0 before counted transitions.
model._T[0]

41.89999999999993

In [31]:
# Milestone-to-milestone transition counts accumulated by estimate().
model._N

array([[   0,    9,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0],
       [   9,    0, 1290,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0],
       [   0, 1301,    0, 9357,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0],
       [   0,    0, 9345,    0, 6218,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0],
       [   0,    0,    0, 6247,    0,  791,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0],
       [   0,    0,    0,    0,  774,    0,  558,    0,    0,    0,    0,
           0,    0,    0,    0,    0],
       [   0,    0,    0,    0,    0,  560,    0, 1467,    0,    0,    0,
           0,    0,    0,    0,    0],
       [   0,    0,    0,    0,    0,    0, 1482,    0, 1863,    0,    0,
           0,    0,    0,    0,    0],
       [   0,    0,    0,    0,    0,    0,    0, 1923,    0, 1298,    0,
           0,    0,   

In [32]:
# K: transition counts with each row divided by its sum; t: accumulated time
# per milestone divided by that milestone's number of counted transitions.
model.K, model.t

(array([[0.        , 1.        , 0.        , 0.        , 0.        ,
         0.        , 0.        , 0.        , 0.        , 0.        ,
         0.        , 0.        , 0.        , 0.        , 0.        ,
         0.        ],
        [0.00692841, 0.        , 0.99307159, 0.        , 0.        ,
         0.        , 0.        , 0.        , 0.        , 0.        ,
         0.        , 0.        , 0.        , 0.        , 0.        ,
         0.        ],
        [0.        , 0.12206793, 0.        , 0.87793207, 0.        ,
         0.        , 0.        , 0.        , 0.        , 0.        ,
         0.        , 0.        , 0.        , 0.        , 0.        ,
         0.        ],
        [0.        , 0.        , 0.60046264, 0.        , 0.39953736,
         0.        , 0.        , 0.        , 0.        , 0.        ,
         0.        , 0.        , 0.        , 0.        , 0.        ,
         0.        ],
        [0.        , 0.        , 0.        , 0.88761012, 0.        ,
         0.1123

In [None]:
# NOTE(review): `counts` and `t` are only assigned in cells further down, so
# this cell fails on a fresh kernel run from top to bottom -- confirm or remove.
model._T, counts, t

In [None]:
# NOTE(review): hard-coded numbers; presumably accumulated time divided by
# transition count for two milestones -- confirm where these values come from.
46911.2 / 32527, 2559.6 / 1118

In [None]:
# Plot one trajectory, coloured by the milestone it is currently assigned to.
# NOTE(review): `milestone_schedule`, `milestone_idx`, `traj`, `dt` and `faces`
# are not defined in any visible cell, and `milestones` only in a later one --
# confirm where they come from before relying on this cell.
plt.figure()  # bare `figure()` is not defined; only `plt` is imported

# pyplot.get_cmap replaces matplotlib.cm.get_cmap, which newer matplotlib
# releases deprecate and remove.
cmap = plt.get_cmap('jet', len(milestones))

t = 0
for milestone, lifetime in milestone_schedule:
    # lifetime is used as a frame count here (it indexes `traj`).
    T = dt * np.arange(t, t + lifetime + 1)
    plt.plot(T, traj[t:t+lifetime+1], c=cmap(milestone_idx[milestone]))
    t += lifetime

plt.hlines(faces, xmin=0, xmax=t*dt, colors='gray', linestyles='dashed', linewidth=1, zorder=10)
plt.xlim([0, dt*(t+1)])
plt.xlabel('Time (ps)')
# Raw string avoids the invalid '\m' and '\A' escape sequences.
_ = plt.ylabel(r'Reaction Coordinate ($\mathrm{\AA}$)')

In [None]:
# Gather every distinct label found after the first entry of any schedule,
# then freeze the result as a list.
milestones = set()
for sched in milestone_schedules:
    milestones.update({entry[0] for entry in sched[1:]})
milestones = [*milestones]

In [None]:
# Position of each milestone in the `milestones` list.
milestone_index = {label: pos for pos, label in enumerate(milestones)}

mean_lifetimes = np.zeros(len(milestones))
counts = np.zeros(len(milestones), dtype=int)

# Accumulate lifetimes and visit counts, leaving out the first and the last
# entry of every schedule.
for sched in milestone_schedules:
    for label, duration in sched[1:-1]:
        slot = milestone_index[label]
        mean_lifetimes[slot] += duration
        counts[slot] += 1

# Turn the accumulated totals into per-milestone means and report them.
mean_lifetimes /= counts
for label in milestones:
    print(label, mean_lifetimes[milestone_index[label]])