In [None]:
import numpy as np
import pyemma
import matplotlib.pyplot as plt
import copy
%matplotlib inline

In [None]:
# Define class to store trajectories and calculate discrete state trajectories for milestoning
class allTrajs(object):
    """Container for 2D trajectories and their milestone discretization.

    The plane is partitioned into integer states:
      * 0: the bound disk ``r <= 1``;
      * ``1 .. entry_div``: angular sectors of the entry annulus
        ``rentry1 <= r < rentry2``;
      * ``entry_div+1 .. entry_div+exit_div``: angular sectors of the exit
        annulus ``rexit1 <= r <= rexit2``;
      * ``entry_div+exit_div+1``: the bath ``r > rexit2``.
    A point lying in none of these regions keeps the previous state of its
    trajectory (``None`` if there is no previous state).
    """

    def __init__(self, Trajs=None):
        # `is None` rather than `== None`: comparing a numpy array with `==`
        # broadcasts element-wise and the result cannot be used in an `if`.
        if Trajs is None:
            Trajs = []
        # Main variables
        self.Trajs = Trajs          # continuous trajectories; traj[t] -> (x, y, ...)
        self.dTrajs = []            # discrete trajectories (lists of int / None states)
        self.milestones = {}        # state -> [x, y] center, or ['Bound'] / ['Bath']
        self.dTrajsclean = []       # dTrajs with None states and empty trajectories removed
        # Milestone choice variables
        self.entry_div = 6          # number of angular sectors in the entry annulus
        self.exit_div = 6           # number of angular sectors in the exit annulus
        self.rentry1 = 2.0          # inner radius (inclusive) of the entry annulus
        self.rentry2 = 2.2          # outer radius (exclusive) of the entry annulus
        self.rexit1 = 2.2           # inner radius (inclusive) of the exit annulus
        self.rexit2 = 3.0           # outer radius (inclusive) of the exit annulus; bath beyond

    def getdTrajs(self):
        """Discretize every trajectory in ``self.Trajs`` into milestone states.

        Each point is mapped with :meth:`getState`, passing the state of the
        previous point so that points outside every milestone region inherit it.

        Returns:
            list of lists: ``self.dTrajs``, one list of states (int or None)
            per trajectory.
        """
        self.dTrajs = []
        for traj in self.Trajs:
            dtraj = []
            prevstate = None  # first point of a trajectory has no history
            for coord in traj:
                prevstate = self.getState(coord, prevstate)
                dtraj.append(prevstate)
            self.dTrajs.append(dtraj)
        return self.dTrajs

    def getdTrajsclean(self):
        """Return the discrete trajectories with all ``None`` states removed.

        ``None`` states appear when a trajectory starts in a region that is
        not a milestone (there is no previous state to inherit). Trajectories
        that are ``None`` or become empty after filtering are dropped.
        ``self.dTrajs`` is computed first if it has not been computed yet.

        Returns:
            list of lists of int: ``self.dTrajsclean``.
        """
        # If dTrajs haven't been yet calculated, do so
        if self.dTrajs == []:
            self.getdTrajs()
        # Filter into fresh lists instead of popping from a copy. This avoids
        # any index bookkeeping after an element has been removed.
        filtered = (
            [state for state in dtraj if state is not None]
            for dtraj in self.dTrajs
            if dtraj is not None
        )
        self.dTrajsclean = [dtraj for dtraj in filtered if dtraj]
        return self.dTrajsclean

    @staticmethod
    def _angularSector(th, ndiv):
        """Index in ``[0, ndiv)`` of the angular sector containing angle ``th``.

        Sectors are half-open intervals ``[-pi + k*w, -pi + (k+1)*w)`` with
        ``w = 2*pi/ndiv``. ``np.arctan2`` can return exactly ``+pi`` (for
        ``y = +0.0, x < 0``). That is the same direction as ``-pi``, so it is
        mapped to sector 0. The last sector has no explicit upper bound, so
        floating-point rounding of ``-pi + ndiv*w`` cannot leave an angle
        unassigned.
        """
        angint = 2*np.pi/ndiv
        if th >= np.pi:
            th = th - 2*np.pi
        # th >= -pi always holds, so the first upper bound exceeding th
        # identifies the sector (its lower bound is the previous upper bound).
        for k in range(ndiv - 1):
            if th < -np.pi + (k+1)*angint:
                return k
        return ndiv - 1

    def getState(self, coord, prevst):
        """Assign the milestone state of a single point.

        Args:
            coord: sequence whose first two entries are the x, y coordinates.
            prevst: state of the previous point of the trajectory (or None).
                It is returned unchanged when ``coord`` lies in no milestone
                region.

        Returns:
            int or None: 0 for bound, ``1..entry_div`` for entry sectors,
            ``entry_div+1..entry_div+exit_div`` for exit sectors,
            ``entry_div+exit_div+1`` for bath, otherwise ``prevst``.
            The center of each state region is given by :meth:`getMilestones`.
        """
        x = coord[0]
        y = coord[1]
        r = np.sqrt(x*x + y*y)
        th = np.arctan2(y, x)
        # Bound state
        if r <= 1:
            return 0
        # Entry states
        if self.rentry1 <= r < self.rentry2:
            return self._angularSector(th, self.entry_div) + 1
        # Exit states
        if self.rexit1 <= r <= self.rexit2:
            return self.entry_div + self._angularSector(th, self.exit_div) + 1
        # Bath state
        if r > self.rexit2:
            return self.entry_div + self.exit_div + 1
        # Didn't change state
        return prevst

    def getMilestones(self):
        """Compute the center of every milestone state.

        Each entry or exit sector center lies on the mid-radius of its annulus,
        at the mid-angle of the sector. The bound and bath states have no
        single center and are stored as the labels ``['Bound']`` and ``['Bath']``.

        Returns:
            dict: ``self.milestones``, mapping state -> [x, y] or a label list.
        """
        # Bound state is 0; stored as a label rather than coordinates
        self.milestones[0] = ['Bound']
        angint_entry = 2*np.pi/self.entry_div
        angint_exit = 2*np.pi/self.exit_div
        rentry = (self.rentry1 + self.rentry2)/2.0
        rexit = (self.rexit1 + self.rexit2)/2.0
        # Loop over entry states
        for k in range(self.entry_div):
            llim = -np.pi + k*angint_entry
            rlim = -np.pi + (k+1)*angint_entry
            th = (rlim + llim)/2.0
            x = rentry*np.cos(th)
            y = rentry*np.sin(th)
            self.milestones[k+1] = [x,y]
        # Loop over exit states
        for k in range(self.exit_div):
            llim = -np.pi + k*angint_exit
            rlim = -np.pi + (k+1)*angint_exit
            th = (rlim + llim)/2.0
            x = rexit*np.cos(th)
            y = rexit*np.sin(th)
            self.milestones[k + 1 + self.entry_div] = [x,y]
        # Bath state (outside the exit annulus)
        self.milestones[self.entry_div + self.exit_div + 1] = ['Bath']
        return self.milestones
                
            

In [None]:
# Define filter for trajectory extraction from file
def filter(f, stride):
    """Yield every ``stride``-th line of ``f``, starting with the first one.

    Note: this name shadows the builtin ``filter`` for the rest of the notebook.

    Args:
        f: any iterable of lines (e.g. an open text file).
        stride: a line is kept when its 0-based index is divisible by ``stride``.
    """
    yield from (line for lineno, line in enumerate(f) if lineno % stride == 0)
# Extract trajectories from file using allTrajs class.
# Layout assumed by the slicing below: one row per time step, with
# `dimension` consecutive columns per trajectory (x1 y1 x2 y2 ...).
fname = '../data/2DmodifiedLJmultipleTrajsLongR4.txt'
alltrajs = allTrajs([]) # define allTrajs object
dimension = 2
with open(fname) as f:
    # atleast_2d keeps the column indexing valid if the file holds a single row
    data = np.atleast_2d(np.genfromtxt(filter(f, 1)))
# Integer division: in Python 3 `/` returns a float, which range() rejects
ntrajs, leftover = divmod(data.shape[1], dimension)
assert leftover == 0, "number of columns is not a multiple of dimension"
for i in range(ntrajs):
    traj = data[:, dimension*i:dimension*(i+1)]
    alltrajs.Trajs.append(traj)

In [None]:
# Discretize all trajectories into milestone states and drop "None" entries
# (points that could not be assigned a milestone state)
dtrajs = alltrajs.getdTrajsclean()
# Centers of each milestone keyed by state: [x, y], or ['Bound'] / ['Bath']
centers = alltrajs.getMilestones()
# Sanity check: every remaining state is a known milestone
assert all(state in centers for dtraj in dtrajs for state in dtraj)