In [1]:
import numpy as np
import pyemma
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
def filter(f, stride):
    """Lazily yield every ``stride``-th item of ``f``, starting with the first.

    Parameters
    ----------
    f : iterable
        Any iterable, e.g. an open text file iterated line by line.
    stride : int
        Keep the items whose zero-based position is a multiple of ``stride``.

    Yields
    ------
    The items of ``f`` at positions 0, ``stride``, 2*``stride``, ...

    Note: this name shadows the builtin ``filter`` in the notebook namespace.
    """
    position = 0
    for entry in f:
        # Positions that are exact multiples of the stride are kept
        if position % stride == 0:
            yield entry
        position += 1

In [139]:
# Define class to store trajectories and their state for milestoning
class allTrajs(object):
    """Store truncated trajectories and their milestoning discretization.

    Attributes
    ----------
    truncTrajs : list
        Continuous truncated trajectories; each one is indexable per time
        step and every step provides ``[x, y]`` coordinates.
    dTrajs : list
        Discrete trajectories (one list of states per truncated trajectory),
        filled by ``getdTrajs()``.
    milestones : dict
        Maps state -> ``[x, y]`` center, filled by ``getMilestones()``.
    """
    # Number of equal angular sectors the ring 2 <= r <= 3 is divided into
    th_divisions = 4

    def __init__(self, truncTrajs, dTrajs=None, milestones=None):
        """Store the given containers, creating fresh ones for omitted arguments."""
        # Identity test on purpose: ``== None`` is evaluated element-wise on
        # numpy arrays and makes the ``if`` raise an ambiguous-truth error.
        if dTrajs is None:
            dTrajs = []
        if milestones is None:
            milestones = {}
        self.truncTrajs = truncTrajs
        self.dTrajs = dTrajs
        self.milestones = milestones

    # Get discretized trajectories (dTrajs) in chosen milestones
    # from continuous truncated trajectories (truncTrajs)
    def getdTrajs(self):
        """Discretize every truncated trajectory and return ``self.dTrajs``.

        Each time step is mapped with ``getState``; the state of the previous
        step is handed over so that points outside any state region keep the
        last visited state. The first step of a trajectory has no previous
        state, so it is ``None`` if it lies outside every state region.
        """
        self.dTrajs = []
        for traj in self.truncTrajs:
            dtraj = []
            prevstate = None
            for coord in traj:
                prevstate = self.getState(coord, prevstate)
                dtraj.append(prevstate)
            self.dTrajs.append(dtraj)
        return self.dTrajs

    # Given coordinates, assigns a state which corresponds to an area
    # in space. The state is assigned with an integer value. The center of the
    # state region is given by getMilestones() function
    def getState(self, coord, prevst):
        """Return the integer state for ``coord``.

        Parameters
        ----------
        coord : sequence
            ``coord[0]`` is x and ``coord[1]`` is y.
        prevst : int or None
            State of the previous time step, returned unchanged when the
            point lies in no state region (1 < r < 2 or r > 3).

        Returns
        -------
        0 for r <= 1, ``k + 1`` for the k-th angular sector of the ring
        2 <= r <= 3 (sectors counted from angle -pi), otherwise ``prevst``.
        """
        x = coord[0]
        y = coord[1]
        r = np.sqrt(x*x + y*y)
        if r <= 1:
            return 0
        if not (2 <= r <= 3):
            # Between or beyond the state regions: remain in the last state
            return prevst
        th = np.arctan2(y, x)
        ndiv = self.th_divisions
        angint = 2*np.pi/ndiv
        for k in range(ndiv):
            llim = -np.pi + k*angint
            rlim = -np.pi + (k+1)*angint
            if llim <= th < rlim:
                return k + 1
        # arctan2 returns exactly +pi for y == +0, x < 0, which no half-open
        # sector [llim, rlim) contains; treat the last sector as closed on
        # the right instead of silently returning None.
        return ndiv

    # Get x,y centers of milestones in a dictionary: milestones[state] = [x,y]
    def getMilestones(self):
        """Fill and return ``self.milestones`` with the center of every state.

        State 0 is centered at the origin; state ``k + 1`` is centered at
        radius 2.5 on the mid-angle of the k-th angular sector.
        """
        # Bound state is 0 and assigned origin as center
        self.milestones[0] = [0,0]
        ndiv = self.th_divisions
        angint = 2*np.pi/ndiv
        r = 2.5  # mid-radius of the ring 2 <= r <= 3 used in getState
        for k in range(ndiv):
            llim = -np.pi + k*angint
            rlim = -np.pi + (k+1)*angint
            th = (rlim + llim)/2.0
            x = r*np.cos(th)
            y = r*np.sin(th)
            self.milestones[k+1] = [x,y]
        return self.milestones
                
            

In [140]:
# Extract truncated trajectories as before, but now using the allTrajs class
fname = '../data/2DmodifiedLJmultipleTrajsLongR4.txt'
alltrajs = allTrajs([]) # define allTrajs object
trajs = []
bathtoMSMs = []
MSMtobaths = []
trajAssignment = []  # source-trajectory index of every truncated segment
fileIndex = []  # unused in this cell; kept in case other cells reference it
dimension = 2
# Only the read needs the file handle; genfromtxt loads everything into memory
with open(fname) as f:
    data = np.genfromtxt(filter(f, 1))
# Columns are laid out as `dimension` consecutive coordinates per trajectory.
# Floor division keeps ntrajs an integer on Python 3 as well.
ntrajs = data.shape[1] // dimension
for i in range(ntrajs):
    traj = data[:, dimension*i:dimension*(i+1)]
    trajs.append(traj)
    # Distance from the origin at every time step (named so that the
    # builtin `abs` is not shadowed for the rest of the notebook)
    radii = np.linalg.norm(traj, axis=1)
    MSMdomain = (radii < 3.)
    # Index of the first step inside the domain after a step outside it ...
    bathtoMSM = np.where(np.logical_and(~MSMdomain[:-1], MSMdomain[1:]))[0] + 1
    # ... and of the first step outside the domain after a step inside it
    MSMtobath = np.where(np.logical_and(MSMdomain[:-1], ~MSMdomain[1:]))[0] + 1
    # Make sure both arrays have the same length: a trajectory that starts
    # (ends) inside the domain gets an entry (exit) at its first (past-last) index
    if MSMdomain[0]:
        bathtoMSM = np.insert(bathtoMSM, 0, 0)
    if MSMdomain[-1]:
        MSMtobath = np.append(MSMtobath, len(MSMdomain))
    bathtoMSMs.append(bathtoMSM)
    MSMtobaths.append(MSMtobath)
    # Separate loop variable so the outer trajectory index is not clobbered
    for j in range(len(bathtoMSM)):
        # Record which full trajectory this segment was cut from; previously
        # the same empty `fileIndex` list was appended for every segment.
        trajAssignment.append(i)
        alltrajs.truncTrajs.append(traj[bathtoMSM[j]:MSMtobath[j], :])

In [80]:
# Calculate discrete trajectories and save them into alltrajs.dTrajs
alltrajs.getdTrajs()

In [81]:
# Number of discrete trajectories stored in alltrajs.dTrajs
len(alltrajs.dTrajs)

223080

In [82]:
# Number of truncated trajectories; getdTrajs() creates one discrete trajectory per entry
len(alltrajs.truncTrajs)

223080

In [103]:
# Discrete states of the first truncated trajectory
# NOTE(review): the saved output shows state 3 for points in the upper-left
# quadrant, while getState() as defined in this notebook (4 divisions) assigns
# state 4 there -- the output likely predates the latest class definition;
# re-run the cells in order to confirm.
alltrajs.dTrajs[0]

[3, 3, 3, 3, 3, 3, 3]

In [104]:
# Coordinates of the first truncated trajectory (one [x, y] row per time step)
alltrajs.truncTrajs[0]

array([[-0.11318649,  2.460619  ],
       [-0.228444  ,  2.48857247],
       [-0.20297694,  2.37455592],
       [-0.23480105,  2.60129679],
       [-0.12868021,  2.69948628],
       [-0.12764051,  2.78087705],
       [-0.36861899,  2.94587384]])

In [142]:
# Build the dictionary of milestone centers: state -> [x, y]
dd = alltrajs.getMilestones()

In [152]:
# Display the milestone centers
dd

{0: [0, 0],
 1: [-1.7677669529663687, -1.7677669529663689],
 2: [1.7677669529663689, -1.7677669529663687],
 3: [1.7677669529663689, 1.7677669529663687],
 4: [-1.7677669529663687, 1.7677669529663689]}