In [1]:
import numpy as np
from matplotlib import pyplot as plt
# Global figure styling: large figures and text, small tick labels.
plt.rcParams.update({
    'figure.figsize': (20, 10),
    'font.size': 22,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
})

In [72]:
# --- Weights and rates (the first six are shown in the figure title) ---
beta = 1      # weight on the exp(S1) term of the first costate equation
lam = 10      # appears in the second costate equation and the control update
r = 0.99      # rate in the exp(+/- r*t) factors (presumably a discount rate)
alpha = 0.8   # multiplies (2*S2 - Sbar) in the second costate equation
delta = .7    # linear decay rate of S1 in the state equation
gamma = 1     # scales the candidate control in the sweep update

# --- Boundary conditions ---
Sbar = 1                # level used in the S2 equation; also the initial S2
S1_0, S2_0 = 0, Sbar    # initial state (S1, S2)
p1_tf, p2_tf = 0, 0     # presumably terminal values for the two costates

# --- Discretisation and run label ---
tf = 1        # final time
n = 365       # number of grid points on [0, tf]
iters = 11    # run label, used in the saved figure's filename

In [73]:
#Code from RK4 Lab with minor edits
def initialize_all(y0, t0, tf, n):
    """Set up storage, time grid and step size for a fixed-step ODE solver.

    Parameters
    ----------
    y0 : float or ndarray
        Initial condition. An ndarray yields a solution array of shape
        (n, y0.size) passed through ``squeeze`` (so a size-1 array yields
        shape (n,)); anything else yields shape (n,).
    t0, tf : float
        Start and end of the integration interval.
    n : int
        Number of equispaced grid points, endpoints included. Must be at
        least 2, since the step size divides by n - 1.

    Returns
    -------
    Y : ndarray
        Solution array with only Y[0] (= y0) filled; other rows are
        uninitialised memory from ``np.empty``.
    T : ndarray
        ``np.linspace(t0, tf, n)``.
    h : float
        Step size (tf - t0) / (n - 1).
    """
    solution = (np.empty((n, y0.size)).squeeze()
                if isinstance(y0, np.ndarray) else np.empty(n))
    solution[0] = y0
    grid = np.linspace(t0, tf, n)
    step = float(tf - t0) / (n - 1)
    return solution, grid, step

def RK4(f, y0, t0, tf, n):
    """Classical fourth-order Runge-Kutta on n equispaced points in [t0, tf].

    Approximates the solution of y' = f(t, y, i) with y(t0) = y0.

    Parameters
    ----------
    f : callable
        ``f(t, y, i)`` returning dy/dt with the same shape as ``y``. ``i``
        is the index of the step being taken. The same ``i`` is passed
        to all four stages of step i, so any array that f indexes with
        it (for example a control) is held constant over [T[i], T[i+1]].
    y0 : float or 1-D ndarray
        Initial condition at t0.
    t0, tf : float
        Start and end of the integration interval.
    n : int
        Number of grid points (at least 2).

    Returns
    -------
    Y : ndarray
        Shape (n,) if y0 is a scalar or a size-1 array, otherwise
        (n, y0.size). Y[k] approximates y at np.linspace(t0, tf, n)[k].
    """
    Y, T, h = initialize_all(y0, t0, tf, n)
    half = h / 2.
    for k in range(n - 1):
        t_k = T[k]
        y_k = Y[k]
        s1 = f(t_k, y_k, k)
        s2 = f(t_k + half, y_k + half * s1, k)
        s3 = f(t_k + half, y_k + half * s2, k)
        s4 = f(T[k + 1], y_k + h * s3, k)
        weighted = s1 + 2 * s2 + 2 * s3 + s4
        Y[k + 1] = y_k + h / 6. * weighted
    return Y

In [74]:
# initialize global variables, state, costate, and u.
# `state` is a placeholder: lambda_hat reads it as a global, and every
# forward solve in the sweep overwrites it.
state = np.zeros((n,2))
state0 = np.array([S1_0, S2_0], dtype=float)

costate = np.zeros((n,2))
# The costates are integrated in reversed time, so their "initial" value is
# the terminal condition at tf; take it from the p1_tf / p2_tf parameters
# instead of hard-coding zeros.
costate0 = np.array([p1_tf, p2_tf], dtype=float)

# Initial guess for the control: the constant 0.1 on every grid point.
u = np.full(n, .1)

# define state equations
def state_equations(t,y,i):
    '''
    Right-hand side of the state ODE for [S1, S2].

    S1 and S2 are the quantities plotted as "Pollutant" and "Fish Stock";
    the control u is plotted as "Fishing".

    Parameters
    ---------------
    t : float
        the time (not used directly; time enters only through u[i])
    y : ndarray (2,)
        the state [S1, S2] at time t
    i : int
        index into the global control array u. RK4 passes the same i
        to all four stages of a step.
    Returns
    --------------
    y_dot : ndarray (2,)
        the derivative [S1', S2'] at time t
    '''
    y_dot = np.zeros_like(y)
    # S1 is driven by the control and decays linearly at rate delta.
    y_dot[0] = u[i] - delta*y[0]
    # Logistic growth alpha*S2*(Sbar - S2) minus the harvest u. The alpha
    # factor makes this consistent with the costate equation, which uses
    # the derivative alpha*(Sbar - 2*S2) of this growth term.
    y_dot[1] = alpha*y[1]*(Sbar - y[1]) - u[i]
    return y_dot

In [75]:
def lambda_hat(t,y,i):
    '''
    Right-hand side of the costate ODE, written in reversed time.

    The sweep integrates this system from 0 to tf with RK4 and then flips
    the result, so the argument t here corresponds to the physical time
    tf - t. Step i likewise corresponds to index -i-1 of the forward-time
    globals `state` and `u`. Every time-dependent factor must therefore use
    the physical time tf - t, both in the discount and in the indexing.

    Parameters
    ---------------
    t : float
        the reversed time (physical time is tf - t)
    y : ndarray (2,)
        the two costate values at reversed time t
    i : int
        reversed step index; the matching forward index is -i-1
    Returns
    --------------
    y_dot : ndarray (2,)
        the derivative of the costates with respect to reversed time t.
    '''
    j = -i - 1               # forward-time index matching reversed step i
    s = tf - t               # physical time matching reversed time t
    S1, S2 = state[j]
    y_dot = np.zeros_like(y)
    # Both terms now use the same discount exp(-r*(tf - t)). Previously the
    # first term used exp(-r*t), i.e. the reversed time.
    y_dot[0] = delta*y[0] + beta*np.exp(S1 - r*s)
    # Alternative cost term (disabled):
    # y_dot[0] = delta*y[0] + 2.5*beta*np.exp(-r*s)*S1**1.5
    y_dot[1] = y[1]*alpha*(2*S2 - Sbar) - lam*u[j]*np.exp(-r*s) / (S2**2)
    return y_dot


In [76]:
epsilon = 0.001    # tolerance on sum(|u_old - u_new|); also the margin in the u <= S2 - epsilon cap
max_sweeps = 1000  # safety cap: the sweep is not guaranteed to converge
test = epsilon + 1
z = np.zeros(n)

# The forward-time grid and the exp(r*t) factor are the same in every sweep.
t_grid = np.linspace(0, tf, n)
growth = np.exp(r*t_grid)

sweep = 0
while test > epsilon:
    sweep += 1
    if sweep > max_sweeps:
        raise RuntimeError(
            "forward-backward sweep did not converge within {} sweeps "
            "(last change {:.3g})".format(max_sweeps, test))
    oldu = u.copy()

    # solve the state equations forward in time
    state = RK4(state_equations, state0, 0, tf, n)

    # solve the costate equations backwards in time: integrate the
    # reversed-time system, then flip the result into forward order
    costate = RK4(lambda_hat, costate0, 0, tf, n)[::-1]

    # candidate control: clipped below at 0, scaled by gamma, then capped at
    # S2 - epsilon. NOTE(review): if S2 < epsilon the cap itself is negative.
    u1 = gamma * np.maximum(z, costate[:,1] / (lam + state[:,1]*growth*(state[:,1] - costate[:,0])))
    u1 = np.minimum(u1, state[:,1] - epsilon)

    # relaxed update: average the candidate with the previous control
    u = 0.5*(u1 + oldu)

    # test for convergence
    test = np.abs(oldu - u).sum()
    # A NaN change compares False against epsilon and would end the loop as
    # if it had converged, so fail loudly instead.
    if not np.isfinite(test):
        raise FloatingPointError(
            "forward-backward sweep produced a non-finite control "
            "on sweep {}".format(sweep))

In [77]:
domain = np.linspace(0, tf, n)

# One panel per series: the control and the two state components.
panels = [
    (u, "Fishing"),
    (state[:,0], "Pollutant"),
    (state[:,1], "Fish Stock"),
]

fig, axes = plt.subplots(1, 3)
for ax, (values, title) in zip(axes, panels):
    ax.plot(domain, values)
    ax.axis("tight")
    ax.set_title(title)
    ax.set_xlabel("t")

fig.suptitle(r"$\beta=${}, $\lambda$={}, $\alpha=${}, $\delta=${}, $\gamma=${}, $r=${}".format(beta,lam,alpha,delta,gamma,r))
fig.savefig("parameters{}.pdf".format(iters))
# The figure is only saved to disk; close it so repeated runs do not keep
# figures alive.
plt.close(fig)