# Setup
## Import libraries

In [432]:
import numpy as np 
import plotly.graph_objects as go
from tqdm.notebook import tqdm
import plotly.express as px
np.random.seed(0)

## Define constants

In [433]:
L = 2*np.pi # periodic domain size 
dt = 0.01   # time step size 

# define boundaries of simulation box
x0 = 0      
x1 = L
z0 = 0
z1 = L 

# define training parameters
Ne = 5000         # number of episodes
gamma = 0.999     # discount - how much we weight future to present rewards. Close to 0 = myopic view. 
eps = 0           # fraction of the time we allow for exploration in selecting the following action.
                  # value of 0 means we always take the greedy choice; no exploration.
    
alpha = 4e-2      # learning rate 
Ns = 4000         # number of steps in an episode 
N_ensemble = 10 # sample size for ensemble average. I.e. how many times we repeat a given episode.

# define reinforcement learning problem 
N_states = 12 # number of states - one for each coarse-grained degree of vorticity 
N_actions = 4 # number of actions - one for each coarse-grained swimming direction

# define dimensionless groups
Φ = 1 # swimming number = v_s/u_0
Ψ = 1 # stability number = B w_0. B is the characteristic time a perturbed cell takes to return 
         # to orientation ka if w = 0. smaller means swimming more aligned with ka. 
    
# plotting parameters
n_updates = 5 # how often to plot the trajectory undertaken by the particle during the learning process. measured
                 # in number of episodes. 


## Utility functions

In [434]:
# documentation: https://github.com/scipy/scipy/blob/v0.14.0/scipy/stats/stats.py#L1864
# license: https://www.scipy.org/scipylib/license.html
def signaltonoise(a, axis=0, ddof=0):
    a = np.asanyarray(a)
    m = a.mean(axis)
    sd = a.std(axis=axis, ddof=ddof)
    return np.where(sd == 0, 0, m/sd)

## Define useful data structures
### Define a dictionary of the possible states and their assigned indices

In [435]:
direction_states = ["right","down","left","up"] # coarse-grained directions
vort_states = ["w+", "w0", "w-"] # coarse-grained levels of vorticity 
product_states = [(x,y) for x in direction_states for y in vort_states]  # all possible states
state_lookup_table = {product_states[i]:i for i in range(len(product_states))} # returns index of given state

### Define an agent class for reinforcement learning

In [436]:
class Agent:
    def __init__(self):
        self.r = np.zeros(Ns) # reward for each stage
        
    # calculate reward given from entering a new state after a selected action is undertaken
    def calc_reward(self):
        # enforce implementation by subclass
        if self.__class__ == AbstractClass:
                raise NotImplementedError
                
    def update_state(self):
        # enforce implementation by subclass
        if self.__class__ == AbstractClass:
                raise NotImplementedError
                
    def take_random_action(self):
        # enforce implementation by subclass
        if self.__class__ == AbstractClass:
                raise NotImplementedError
                
    def take_greedy_action(self):
        # enforce implementation by subclass
        if self.__class__ == AbstractClass:
                raise NotImplementedError

### Define swimmer class derived from agent

In [437]:
class Swimmer(Agent):
    def __init__(self):
        # call init for superclass
        super().__init__()
        
        # local position within the periodic box. X = [x, z]^T with 0 <= x < 2 pi and 0 <= z < 2 pi
        self.X = np.array([np.random.uniform(0, L), np.random.uniform(0, L)])
        
        # absolute position. -inf. <= x_total < inf. and -inf. <= z_total < inf.
        self.X_total = self.X
        
        # particle orientation 
        self.theta = np.random.uniform(0, 2*np.pi) # polar angle theta in the x-z plane 
        self.p = np.array([np.cos(self.theta), np.sin(self.theta)]) # p = [px, pz]^T
        
        # translational and rotational velocity
            # TODO: check if this is the proper initialization
        self.U = np.zeros(2)
        self.W = np.zeros(2)
        
        # preferred swimming direction (equal to [1,0], [0,1], [-1,0], or [0,-1])
        self.ka = np.array([0,1])
        
        # history of local and global position. Only store information for this episode. 
        self.history_X = [self.X]
        self.history_X_total = [self.X_total]
        
        # local vorticity at the current location
        _, _, self.w = tgf(self.X[0], self.X[1])
        
        # update coarse-grained state
        self.update_state()
        
    def reinitialize(self):
        self.X = np.array([np.random.uniform(0, L), np.random.uniform(0, L)])
        self.X_total = self.X
        
        self.theta = np.random.uniform(0, 2*np.pi) # polar angle theta in the x-z plane 
        self.p = np.array([np.cos(self.theta), np.sin(self.theta)]) # p = [px, pz]^T
        
        self.U = np.zeros(2)
        self.W = np.zeros(2)

        self.ka = np.array([0,1])

        self.history_X = [self.X]
        self.history_X_total = [self.X_total]
        
        
    def update_kinematics(self):
        # calculate new translational and rotational velocity 
        self.calc_velocity()
        
        self.update_position()
        self.update_orientation()
    
    def update_position(self):
        # use explicit euler to update 
        self.X = self.X + dt*self.U
        self.X_total = self.X_total + dt*self.U
        
        # check if still in the periodic box
        self.check_in_box()
        
        # store positions
        self.history_X.append(self.X)
        self.history_X_total.append(self.X_total)
        
        # TODO: explore use of RK4
        
        # TODO: add in noise in update
    
    def update_orientation(self):
        self.p = self.p + dt*self.W
        
        # ensure the vector p has unit length 
        self.p /= np.linalg.norm(self.p)
        
        # update polar angle
        x = self.X[0]
        y = self.X[1]
        self.theta = np.arctan2(y,x) if y >= 0 else (np.arctan2(y,x) + 2*np.pi)
        
        # TODO: explore use of RK4
        
        # TODO: add in noise in update

    def calc_velocity(self):
        ux, uz, self.w = tgf(self.X[0], self.X[1])
        
        self.U = np.array(ux, uz) + Φ*self.p
        
        px = self.p[0]
        pz = self.p[1]
        self.W = 1/2/Ψ*(self.ka - np.dot(self.ka, self.p)*self.p) + 1/2*np.array(pz*self.w, -px*self.w)
        
        
    def check_in_box(self): 
        if self.X[0] < x0:
            self.X[0] += L 
        elif self.X[0] > x1:
            self.X[0] -= L 
        if self.X[1] < z0:
            self.X[1] += L 
        elif self.X[1] > z1:
            self.X[1] -= L    
            
    def calc_reward(self, n):
        self.r[n] = naive.history_X[-1][1]-naive.history_X[-2][1]
        
    def update_state(self):
        if self.w < -0.33:
            w_state = "w-"
        elif self.w >= -0.33 and self.w <= 0.33:
            w_state = "w0"
        else:
            w_state = "w+"

        if self.theta >= np.pi/4 and self.theta < 3*np.pi/4:
            p_state = "up"
        elif self.theta >= 3*np.pi/4 and self.theta < 5*np.pi/4:
            p_state = "left"
        elif self.theta >= 5*np.pi/4 and self.theta < 7*np.pi/4:
            p_state = "down"
        else:
            p_state = "right"

        self.my_state = (p_state, w_state)
        
    def take_greedy_action(self, Q):
        state_index = lookup_table[self.my_state]
        action_index = np.argmax(Q[state_index])  # find largest entry in this row of Q (i.e. this state)
        if action_index == 0:   # up
            self.ka = [0, 1]
        elif action_index == 1: # down
            self.ka = [0, -1]
        elif action_index == 2: # right
            self.ka = [1, 0]
        else:                   # left
            self.ka = [-1, 0]
        return action_index
            
    def take_random_action(self, Q):
        action_index = np.random.randint(4)
        if action_index == 0:   # up
            self.ka = [0, 1]
        elif action_index == 1: # down
            self.ka = [0, -1]
        elif action_index == 2: # right
            self.ka = [1, 0]
        else:                   # left
            self.ka = [-1, 0]
        return action_index

## Define Taylor-Green vortex

In [438]:
# given position, return local velocity and vorticity 
def tgf(x, z):
    ux = -1/2*np.cos(x)*np.sin(z)
    uz = 1/2*np.sin(x)*np.cos(z)
    w = -np.cos(x)*np.cos(z)
    return ux, uz, w

In [439]:
# visualize 
x = np.linspace(0,L,100)
z = np.linspace(0,L,100)
xv, zv = np.meshgrid(x, z)
ux, uz, w = tgf(xv, zv)

fig = go.Figure(data = go.Contour(x = x, y = z, z=w))

fig.update_layout(
    title=r"$\text{Vorticity }(w)$",
    xaxis_title="$x$",
    yaxis_title="$z$"
)

fig.show()

# Training process

In [440]:
Ne = 10
Q = L*Ns*np.ones((12, 4))   # 12 states and 4 possible actions. Each column is an action, ka: up, down, right, left

stored_histories = []       # story position = f(t) every so often for an episode
Σ = np.zeros(Ne)            # learning gain per episode

naive = Swimmer()
smart = Swimmer()

# total long-term vertical displacement, one for each realization
R_tots_naive = np.zeros(N_ensemble)
R_tots_smart = np.zeros(N_ensemble)

for ep in tqdm(range(Ne)):  # for each episode
    
    # initialize Q with its final value from the last episode
    
    # iterate over different realizations of noise and initial conditions
    for realization in tqdm(range(N_ensemble)):
        
        # assign random orientation and position 
        naive.reinitialize()
        smart.reinitialize()
    
        # iterate over stages within an episode
        for stage in range(Ns): 
            # select an action eps-greedily. Note naive never changes its action/strategy (i.e. trying to swim upward)
            if np.random.uniform(0, 1) < eps:
                action_index = smart.take_random_action()
            else:
                action_index = smart.take_greedy_action(Q)
                
            # record prior state
            old_state_index = lookup_table[smart.my_state]
            
            # given selected action, update the state
            naive.update_kinematics()
            smart.update_kinematics()
            smart.update_state()      # only need to update smart particle since naive has ka = [0, 1]

            # calculate reward based on new state
            naive.calc_reward(stage)
            smart.calc_reward(stage)

            # update Q matrix 
            state_index = lookup_table[smart.my_state]
            Q[old_state_index, action_index] += alpha*(smart.r[-1] + \
                    gamma*(np.max(Q[state_index,:])-np.max(Q[old_state_index,:])))
            
        # calculate Rtot for this realization
        R_tot_naive = np.mean(naive.r)
        R_tot_smart = np.mean(smart.r)
        
        # collect Rtot across realizations
        R_tots_naive[realization] = R_tot_naive
        R_tots_smart[realization] = R_tot_smart
        
    # TODO: add warning based on average initial position and orientation vanishing 
#     if abs(signaltonoise(R_tots_naive))<1 or abs(signaltonoise(R_tots_smart))<1:
#       raise Exception(("Signal to nosie ratios % 5.2f and % 5.2f are too small. " + \
#             " Consider increasing ensemble size.") %(signaltonoise(R_tots_naive), signaltonoise(R_tots_smart)))
    
    # calculate ensemble average of total gain for this episode
    avg_R_tot_naive = np.mean(R_tots_naive)
    avg_R_tot_smart = np.mean(R_tots_smart)
        
    # calculate learning gain for this episode
    Σ[ep] = avg_R_tot_smart/avg_R_tot_naive - 1
        
    # plot trajectory every so often 
    if stage%n_updates==0:
        history_X_total = np.array(smart.history_X_total)
        stored_histories.append((ep,history_X_total))
        


HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))





## Plot learning gain over time

In [441]:
fig = go.Figure()
fig.add_trace(go.Scatter(x=np.arange(Ne), y=Σ, mode='lines'))
fig.update_layout(
    title="Learning gain over time",
    xaxis_title="Episode, E",
    yaxis_title=r"$\text{Learning gain, }\Sigma$",
)
fig.show()

## Plot trajectories for for selected episodes

In [442]:
for episode in stored_histories:
        (ep, history_X_total) = episode
        fig = go.Figure(go.Scatter(x=history_X_total[:,0], y=history_X_total[:,1],mode='markers',
            marker=dict(size=4,
                color=np.linspace(0,Ns,Ns+1), #set color equal to a variable
                colorscale='plasma', # one of plotly colorscales
                showscale=True,
                colorbar=dict(title="Step")
            )))
        fig.update_layout(
            title="Trajectory for episode " + str(ep),
            xaxis_title="$x$",
            yaxis_title="$z$",
        )
        fig.show()
    

## Visualize strategy learned by smart gyrotactic particle

In [443]:
fig = go.Figure(data = go.Contour(x = x, y = z, z=w, colorscale='Gray', colorbar=dict(x=1,title="w")))

history_X = np.array(smart.history_X)

fig.add_trace(go.Scatter(x=history_X[:,0], y=history_X[:,1],mode='markers',
    marker=dict(
        size=4,
        color=np.linspace(0,Ns,Ns+1), #set color equal to a variable
        colorscale='plasma', # one of plotly colorscales
        showscale=True,
        colorbar=dict(x=1.1,title="Step")
    )))

fig.update_layout(
    title="Trajectory for the final episode",
    xaxis_title="$x$",
    yaxis_title="$z$",
    xaxis=dict(
        range=[0, L]
    ),
    yaxis=dict(
        range=[0, L]
    )
)

fig.show()

In [444]:
# list of colorscales
# ['aggrnyl', 'agsunset', 'algae', 'amp', 'armyrose', 'balance',
#              'blackbody', 'bluered', 'blues', 'blugrn', 'bluyl', 'brbg',
#              'brwnyl', 'bugn', 'bupu', 'burg', 'burgyl', 'cividis', 'curl',
#              'darkmint', 'deep', 'delta', 'dense', 'earth', 'edge', 'electric',
#              'emrld', 'fall', 'geyser', 'gnbu', 'gray', 'greens', 'greys',
#              'haline', 'hot', 'hsv', 'ice', 'icefire', 'inferno', 'jet',
#              'magenta', 'magma', 'matter', 'mint', 'mrybm', 'mygbm', 'oranges',
#              'orrd', 'oryel', 'peach', 'phase', 'picnic', 'pinkyl', 'piyg',
#              'plasma', 'plotly3', 'portland', 'prgn', 'pubu', 'pubugn', 'puor',
#              'purd', 'purp', 'purples', 'purpor', 'rainbow', 'rdbu', 'rdgy',
#              'rdpu', 'rdylbu', 'rdylgn', 'redor', 'reds', 'solar', 'spectral',
#              'speed', 'sunset', 'sunsetdark', 'teal', 'tealgrn', 'tealrose',
#              'tempo', 'temps', 'thermal', 'tropic', 'turbid', 'twilight',
#              'viridis', 'ylgn', 'ylgnbu', 'ylorbr', 'ylorrd']

In [450]:
# R tots naive and smart are identical