# Setup
## Import libraries

In [495]:
import numpy as np 
import plotly.graph_objects as go
from tqdm import tqdm
import plotly.express as px
np.random.seed(0)

## Define constants

In [496]:
# --- Geometry and time stepping -------------------------------------------
L = 2 * np.pi   # side length of the periodic domain
dt = 1e-2       # time step size

# Simulation box boundaries: x in [x0, x1], z in [z0, z1].
x0, x1 = 0, L
z0, z1 = 0, L

# --- Training parameters --------------------------------------------------
Ne = 5_000      # number of episodes
gamma = 0.999   # discount factor: weight of future relative to present rewards
                # (values close to 0 give a myopic agent)
eps = 0         # fraction of steps on which an exploratory action is allowed;
                # 0 means the greedy action is always taken (no exploration)
alpha = 0.04    # learning rate
Ns = 4_000      # number of steps in an episode

# --- Reinforcement learning problem ---------------------------------------
N_states = 12   # number of states: one per coarse-grained level of vorticity
N_actions = 4   # number of actions: one per coarse-grained swimming direction

# --- Dimensionless groups -------------------------------------------------
Φ = 1           # swimming number = v_s/u_0
Ψ = 1           # stability number = B w_0, where B is the characteristic time a
                # perturbed cell takes to return to orientation ka when w = 0;
                # smaller values mean swimming more closely aligned with ka

# --- Plotting -------------------------------------------------------------
n_updates = 1   # plot the particle trajectory every n_updates episodes
                # during the learning process


## Define swimmer class

In [497]:
class Swimmer:
    """Point swimmer moving through the flow field returned by ``tgf``.

    Attributes:
        X: local position [x, z] inside the periodic box.
        X_total: unwrapped (absolute) position [x_total, z_total].
        theta: initial polar angle in the x-z plane (not updated afterwards).
        p: unit orientation vector [px, pz].
        U: translational velocity, set by ``calc_velocity``.
        W: rate of change of ``p``, set by ``calc_velocity``.
        ka: preferred swimming direction.
        history_X, history_X_total: positions recorded since construction.

    Relies on the module-level names ``L``, ``dt``, ``x0``, ``x1``, ``z0``,
    ``z1``, ``Φ``, ``Ψ`` and ``tgf``.
    """

    def __init__(self):
        # local position within the periodic box. X = [x, z]^T with 0 <= x < 2 pi and 0 <= z < 2 pi
        self.X = np.array([np.random.uniform(0, L), np.random.uniform(0, L)])

        # absolute position. -inf. <= x_total < inf. and -inf. <= z_total < inf.
        # Copy so the local and absolute positions never share one buffer.
        self.X_total = self.X.copy()

        # particle orientation
        self.theta = np.random.uniform(0, 2*np.pi) # polar angle theta in the x-z plane
        self.p = np.array([np.cos(self.theta), np.sin(self.theta)]) # p = [px, pz]^T

        # translational and rotational velocity
            # TODO: check if this is the proper initialization
        self.U = np.zeros(2)
        self.W = np.zeros(2)

        # preferred swimming direction (equal to [1,0], [0,1], [-1,0], or [0,-1])
        self.ka = np.array([0,1])

        # history of local and global position. Only store information for this episode.
        # Snapshots are copies so later in-place edits cannot rewrite the past.
        self.history_X = [self.X.copy()]
        self.history_X_total = [self.X_total.copy()]

    def update_kinematics(self):
        """Advance the swimmer by one time step ``dt``."""
        # calculate new translational and rotational velocity
        self.calc_velocity()

        self.update_position()
        self.update_orientation()

    def update_position(self):
        """Explicit-Euler position update, periodic wrap, and history logging."""
        self.X = self.X + dt*self.U
        self.X_total = self.X_total + dt*self.U

        # check if still in the periodic box
        self.check_in_box()

        # store positions
        self.history_X.append(self.X.copy())
        self.history_X_total.append(self.X_total.copy())

        # TODO: explore use of RK4

        # TODO: add in noise in update

    def update_orientation(self):
        """Explicit-Euler orientation update followed by renormalisation."""
        self.p = self.p + dt*self.W

        # ensure the vector p has unit length
        self.p /= np.linalg.norm(self.p)

        # TODO: explore use of RK4

        # TODO: add in noise in update

    def calc_velocity(self):
        """Set ``U`` and ``W`` from the local flow and the current orientation."""
        ux, uz, w = tgf(self.X[0], self.X[1])

        # Flow velocity plus self-propulsion along p. The components must be
        # wrapped in a list: np.array(ux, uz) would treat uz as the dtype.
        self.U = np.array([ux, uz]) + Φ*self.p

        px = self.p[0]
        pz = self.p[1]
        # First term: component of ka perpendicular to p, scaled by 1/(2Ψ).
        # Second term: half the vorticity w acting on p, i.e. w*[pz, -px]/2.
        self.W = 1/2/Ψ*(self.ka - np.dot(self.ka, self.p)*self.p) + 1/2*np.array([pz*w, -px*w])

    def check_in_box(self):
        """Wrap ``X`` back into the half-open box [x0, x1) x [z0, z1).

        Only a single period ``L`` is added or subtracted per call.
        """
        if self.X[0] < x0:
            self.X[0] += L
        elif self.X[0] >= x1:
            self.X[0] -= L
        if self.X[1] < z0:
            self.X[1] += L
        elif self.X[1] >= z1:
            self.X[1] -= L

## Define Taylor-Green vortex

In [498]:
# given position, return local velocity and vorticity 
def tgf(x, z):
    """Evaluate the Taylor-Green vortex at position (x, z).

    Works elementwise on scalars or numpy arrays.

    Returns:
        (ux, uz, w): the two velocity components and the vorticity.
    """
    cos_x, sin_x = np.cos(x), np.sin(x)
    cos_z, sin_z = np.cos(z), np.sin(z)

    velocity_x = -0.5 * cos_x * sin_z
    velocity_z = 0.5 * sin_x * cos_z
    vorticity = -cos_x * cos_z
    return velocity_x, velocity_z, vorticity

In [499]:
# Visualize the vorticity field on a 100 x 100 grid covering the periodic box.
# x, z and w are left at module level for reuse by later plotting cells.
x = np.linspace(0, L, 100)
z = np.linspace(0, L, 100)
xv, zv = np.meshgrid(x, z)
ux, uz, w = tgf(xv, zv)

vorticity_contour = go.Contour(x=x, y=z, z=w)
fig = go.Figure(data=vorticity_contour)
fig.update_layout(title=r"$\text{Vorticity }(w)$", xaxis_title="$x$", yaxis_title="$z$")
fig.show()

# Training process

In [503]:
# Q-table: one row per state and one column per action, initialised to 1.
Q = np.ones((N_states, N_actions))

# NOTE: debug override of the configured episode count. Delete this line to
# train for the full number of episodes set in the constants cell.
Ne = 1

for ep in tqdm(range(Ne)):  # for each episode
    # Fresh swimmer per episode so its history holds this episode only.
    agent = Swimmer()

    for stage in range(Ns): # for each stage
        agent.update_kinematics()

    # Plot the trajectory every n_updates episodes. The gate is on the episode
    # index; the stage index always equals Ns-1 once the inner loop has ended.
    if ep % n_updates == 0:
        history_X_total = np.array(agent.history_X_total)
        # One colour value per recorded position (initial position + Ns steps).
        step_index = np.arange(len(history_X_total))
        fig = go.Figure(go.Scatter(x=history_X_total[:,0], y=history_X_total[:,1],mode='markers',
            marker=dict(size=4,
                color=step_index, #set color equal to a variable
                colorscale='plasma', # one of plotly colorscales
                showscale=True,
                colorbar=dict(title="Step")
            )))
        fig.update_layout(
            title="Trajectory for episode " + str(ep),
            xaxis_title="$x$",
            yaxis_title="$z$",
        )
        fig.show()


  0%|          | 0/1 [00:00<?, ?it/s][A


100%|██████████| 1/1 [00:00<00:00,  2.31it/s][A


In [501]:
# Grayscale vorticity background with its own colorbar.
vorticity_background = go.Contour(
    x=x, y=z, z=w,
    colorscale='Gray',
    colorbar=dict(x=1, title="w"),
)
fig = go.Figure(data=vorticity_background)

# Wrapped (in-box) positions recorded by the agent.
history_X = np.array(agent.history_X)

# Colour each marker by its step index; the second colorbar is shifted right
# so it does not overlap the vorticity colorbar.
step_markers = dict(
    size=4,
    color=np.linspace(0, Ns, Ns + 1),
    colorscale='plasma',
    showscale=True,
    colorbar=dict(x=1.1, title="Step"),
)
trajectory = go.Scatter(
    x=history_X[:, 0],
    y=history_X[:, 1],
    mode='markers',
    marker=step_markers,
)
fig.add_trace(trajectory)

# Clamp both axes to the periodic box.
fig.update_layout(
    title="Trajectory for the final episode",
    xaxis_title="$x$",
    yaxis_title="$z$",
    xaxis=dict(range=[0, L]),
    yaxis=dict(range=[0, L]),
)

fig.show()

In [502]:
# list of colorscales
# ['aggrnyl', 'agsunset', 'algae', 'amp', 'armyrose', 'balance',
#              'blackbody', 'bluered', 'blues', 'blugrn', 'bluyl', 'brbg',
#              'brwnyl', 'bugn', 'bupu', 'burg', 'burgyl', 'cividis', 'curl',
#              'darkmint', 'deep', 'delta', 'dense', 'earth', 'edge', 'electric',
#              'emrld', 'fall', 'geyser', 'gnbu', 'gray', 'greens', 'greys',
#              'haline', 'hot', 'hsv', 'ice', 'icefire', 'inferno', 'jet',
#              'magenta', 'magma', 'matter', 'mint', 'mrybm', 'mygbm', 'oranges',
#              'orrd', 'oryel', 'peach', 'phase', 'picnic', 'pinkyl', 'piyg',
#              'plasma', 'plotly3', 'portland', 'prgn', 'pubu', 'pubugn', 'puor',
#              'purd', 'purp', 'purples', 'purpor', 'rainbow', 'rdbu', 'rdgy',
#              'rdpu', 'rdylbu', 'rdylgn', 'redor', 'reds', 'solar', 'spectral',
#              'speed', 'sunset', 'sunsetdark', 'teal', 'tealgrn', 'tealrose',
#              'tempo', 'temps', 'thermal', 'tropic', 'turbid', 'twilight',
#              'viridis', 'ylgn', 'ylgnbu', 'ylorbr', 'ylorrd']