In [None]:
from SIR_control import *
import numpy as np
import matplotlib.pyplot as plt
# Matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-<name>' and
# newer releases no longer accept the old names, so use the new name when this
# installation provides it and fall back to the legacy name otherwise.
poster_style = ('seaborn-v0_8-poster'
                if 'seaborn-v0_8-poster' in plt.style.available
                else 'seaborn-poster')
plt.style.use(poster_style)
import hjb
from pathlib import Path
# Every figure below is saved under ../figures; create it up front if missing.
figures_dir = Path("../figures")
figures_dir.mkdir(parents=True, exist_ok=True)

# Figure 1

In [None]:
# Phase-plane figure for beta = 0.3, gamma = 0.1.
fig = plot_phaseplane(beta=0.3, gamma=0.1)
plt.savefig('../figures/sigma3.pdf')

In [None]:
# Phase-plane figure for beta = 0.15, gamma = 0.1.
fig = plot_phaseplane(beta=0.15, gamma=0.1)
plt.savefig('../figures/sigma15.pdf')

# Figure 2

In [None]:
from scipy.optimize import fsolve

x0 = 0.99
y0 = 0.01
sigma0 = 3.


def xinf_sigma(sigma):
    """Evaluate x_inf at the fixed initial state (x0, y0) for the given sigma."""
    return x_inf(x0, y0, sigma)


def myfun(sigma):
    """Residual that vanishes when x_inf(x0, y0, sigma) equals 1/sigma0."""
    return xinf_sigma(sigma) - 1/sigma0


# Solve myfun(sigma) = 0, starting the root search from sigma0 itself.
sigmastar = fsolve(myfun, sigma0)[0]


def qfun(t, u):
    """Constant control 1 - sigmastar/sigma0; time t and state u are ignored."""
    return 1-sigmastar/sigma0


# Show the constant control value as the cell output.
1-sigmastar/sigma0

In [None]:
# Simulate with the current qfun.  The initial state is passed explicitly
# because sigmastar was computed for this (x0, y0); relying on SIR_forward's
# default initial state would silently decouple the two if the defaults differed.
x1, y1, t1 = SIR_forward(qfun=qfun, x0=x0, y0=y0, T=1000)

In [None]:
def qfun(t, u):
    """State-feedback control: 0 while u[0] > 1/sigma0, 1 otherwise.

    The time argument t is unused; only the first state component matters.
    """
    if u[0] > 1./sigma0:
        return 0
    return 1


# Pass the initial state explicitly so this run starts from the same (x0, y0)
# that the rest of this figure is built on, instead of depending on
# SIR_forward's defaults.
x2, y2, t2 = SIR_forward(qfun=qfun, x0=x0, y0=y0, T=1000)

In [None]:
# Phase-plane figure of the two trajectories computed for this figure.
fig = plot_phaseplane([x1, x2], [y1, y2])
plt.savefig('../figures/twocontrols.pdf')

# Figure 3

In [None]:
# Initial state for this path.
x0 = 0.7
y0 = 0.2
switch_times = np.array([0, 25])


def qfun(t, u):
    """Piecewise-constant control in time; the state u is ignored.

    np.argmax(t < switch_times) is the index of the first switch time that
    is still ahead of t (or 0 once t has passed all of them).  The control
    is 0 when that index is even and 1 when it is odd, so with
    switch_times = [0, 25] it is 1 for 0 <= t < 25 and 0 from t = 25 on.
    """
    upcoming = np.argmax(t < switch_times)
    return 0 if upcoming % 2 == 0 else 1


x1, y1, t1 = SIR_forward(qfun=qfun, x0=x0, y0=y0, T=54)

In [None]:
switch_times = np.array([8, 16.5])


def qfun(t, u):
    """Piecewise-constant control in time; the state u is ignored.

    The index of the first switch time still ahead of t (0 once none are
    left) decides the value: even index -> 0, odd index -> 1.  With
    switch_times = [8, 16.5] that is 0 for t < 8, 1 for 8 <= t < 16.5 and
    0 again from t = 16.5 on.
    """
    upcoming = np.argmax(t < switch_times)
    return 0 if upcoming % 2 == 0 else 1


x2, y2, t2 = SIR_forward(qfun=qfun, x0=x0, y0=y0, T=16.5)

In [None]:
switch_times = np.array([2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 21.6])


def qfun(t, u):
    """Piecewise-constant control in time; the state u is ignored.

    The index of the first switch time still ahead of t (0 once none are
    left) decides the value: even index -> 0, odd index -> 1.  With the
    switch times above the control starts at 0 on [0, 2) and flips at each
    listed time.
    """
    upcoming = np.argmax(t < switch_times)
    return 0 if upcoming % 2 == 0 else 1


x3, y3, t3 = SIR_forward(qfun=qfun, x0=x0, y0=y0, T=21.6)

In [None]:
# Phase-plane figure of the three switching paths.
fig = plot_phaseplane([x1, x2, x3], [y1, y2, y3])
plt.savefig('../figures/threepaths.pdf')

# Figure 4

In [None]:
T = 100
beta = 0.3
gamma = 0.1
sigma0 = beta/gamma


def qfun(t, u):
    """Feedback control depending on time t and on x = u[0].

    Returns 1 when x is below 1/(sigma0*(1 - exp(-gamma*(T - t)))) and 0
    otherwise.  t == T is handled separately because the denominator of
    that threshold is zero there.
    """
    x = u[0]
    if t == T:
        return 0
    elif x < 1./(sigma0*(1-np.exp(-gamma*(T-t)))):
        return 1
    else:
        return 0


# The control law depends on the horizon T, so hand the same T to the
# integrator instead of relying on SIR_forward's default final time.
x, y, t = SIR_forward(qfun=qfun, x0=0.99, y0=0.01, beta=beta, gamma=gamma, T=T)
fig = plot_phaseplane([x], [y], color='k')
plt.savefig('../figures/example1_xy.pdf')

In [None]:
# Re-evaluate the control along the computed trajectory; qfun only reads the
# first state component, so the second entry of the state is a placeholder 0.
q = np.array([qfun(time, [susc, 0]) for time, susc in zip(t, x)])
control = 1 - q
fig = plot_timeline(x, y, control, t)
plt.savefig('../figures/example1_time.pdf')

# Figure 5

In [None]:
T = 70
beta, gamma = 0.3, 0.1
sigma0 = beta/gamma


def qfun(t, u):
    """Return 1 when u[0] is below the time-dependent threshold, else 0.

    The threshold is 1/(sigma0*(1 - exp(-gamma*(T - t)))); at t == T its
    denominator vanishes, so that case returns 0 directly.
    """
    if t == T:
        return 0
    threshold = 1./(sigma0*(1-np.exp(-gamma*(T-t))))
    return 1 if u[0] < threshold else 0


x1, y1, t1 = SIR_forward(qfun=qfun, x0=0.99, y0=0.01, beta=beta, gamma=gamma, T=T)

In [None]:
# Same control law with a shorter horizon.
T = 40


def qfun(t, u):
    """Return 1 when u[0] is below the time-dependent threshold, else 0.

    The threshold is 1/(sigma0*(1 - exp(-gamma*(T - t)))); at t == T its
    denominator vanishes, so that case returns 0 directly.
    """
    if t == T:
        return 0
    threshold = 1./(sigma0*(1-np.exp(-gamma*(T-t))))
    return 1 if u[0] < threshold else 0


x2, y2, t2 = SIR_forward(qfun=qfun, x0=0.99, y0=0.01, beta=beta, gamma=gamma, T=T)

In [None]:
# Same control law with the shortest horizon of the three.
T = 30


def qfun(t, u):
    """Return 1 when u[0] is below the time-dependent threshold, else 0.

    The threshold is 1/(sigma0*(1 - exp(-gamma*(T - t)))); at t == T its
    denominator vanishes, so that case returns 0 directly.
    """
    if t == T:
        return 0
    threshold = 1./(sigma0*(1-np.exp(-gamma*(T-t))))
    return 1 if u[0] < threshold else 0


x3, y3, t3 = SIR_forward(qfun=qfun, x0=0.99, y0=0.01, beta=beta, gamma=gamma, T=T)

In [None]:
# Phase-plane figure of the three horizons, labelled by their final time.
horizon_labels = ['T=70', 'T=40', 'T=30']
fig = plot_phaseplane([x1, x2, x3], [y1, y2, y3], labels=horizon_labels)
plt.legend()
plt.savefig('../figures/diff-time-opt.pdf')

# Figure 6

## Via PMP (necessary conditions)

In [None]:
qmax = 0.6
N = 50000

# Use parameter continuation to get a good initial guess and successively improve it
c2 = 0.8
x, y, sigma, t, newguess, J = solve_pmp(T=100,qmax=qmax,c2=1.25,N=N,guess=None)
reduction_factor = 10.

while c2 > 1.e-7:
    # Value of c2 attempted in this step.
    c2_trial = c2/reduction_factor
    x, y, sigma, t, newg, J = solve_pmp(T=100,qmax=qmax,c2=c2_trial,N=N,guess=newguess)
    if isinstance(x,int):
        # An int in place of the solution is treated as a failed solve
        # (presumably solve_pmp's failure signal -- confirm in SIR_control):
        # retry with a smaller step in c2.
        reduction_factor = np.sqrt(reduction_factor)
    else:
        # Report the value that actually converged (c2 itself still holds
        # the previous accepted value at this point).
        print('Solver converged for c2 = ', c2_trial)
        c2 = c2_trial
        newguess = newg
        # Success: be more aggressive on the next step.
        reduction_factor = reduction_factor**2

# Final solve at the target value, warm-started from the last accepted guess.
x, y, sigma, t, newguess, J = solve_pmp(T=100,qmax=qmax,c2=1e-7,N=N,guess=newguess)

In [None]:
# Timeline figure; sigma is passed divided by sigma0.
fig = plot_timeline(x, y, sigma/sigma0, t)
plt.savefig('../figures/example2_time.pdf')

In [None]:
# Phase-plane figure of the PMP trajectory, drawn in black.
fig = plot_phaseplane([x], [y], color='k')
plt.savefig('../figures/example2_xy.pdf')

## Via HJB (necessary + sufficient conditions)

In [None]:
# Problem parameters for the HJB solve.
beta = 0.3
gamma = 0.1
x0 = 0.99
y0 = 0.01
c2 = 0.
T = 100
qmax = 0.6
# Pass the c2 defined above rather than a second hard-coded literal, so the
# variable and the value actually used cannot drift apart.
x, y, sigma, t = hjb.solve_hjb(beta=beta, gamma=gamma, x0=x0, y0=y0, c2=c2, T=T, qmax=qmax)

In [None]:
# Phase-plane figure of the HJB trajectory (displayed only, not saved).
fig = plot_phaseplane([x], [y], color='k')

# Figure 7

## Via PMP

In [None]:
x0, y0 = 0.9, 0.1
beta, gamma = 0.3, 0.1
sigma0 = beta/gamma
c2s = [2e-2, 1e-3, 1e-5]
T = 100

# Continuation in c2: every solve is warm-started with the guess returned by
# the previous one.  The first chain ends at c2s[0] and yields solution 1.
newguess = None
for c2_step in (1, 0.1, c2s[0]):
    x1, y1, sigma1, t1, newguess, J = solve_pmp(c2=c2_step, T=T, guess=newguess, x0=x0, y0=y0)

# Intermediate values whose solutions are discarded; only the guess is kept.
for c2_step in (1e-2, 5e-3, 3e-3):
    _, _, _, _, newguess, J = solve_pmp(c2=c2_step, T=T, guess=newguess, x0=x0, y0=y0)

# Solutions 2 and 3, for the remaining two entries of c2s.
x2, y2, sigma2, t2, newguess, J = solve_pmp(c2=c2s[1], T=T, guess=newguess, x0=x0, y0=y0)
x3, y3, sigma3, t3, newguess, J = solve_pmp(c2=c2s[2], T=T, guess=newguess, x0=x0, y0=y0)

In [None]:
# One label per c2 value, e.g. '$c_2=0.02$'.
labels = ['$c_2={}$'.format(val) for val in c2s]
scaled_controls = [sigma1/sigma0, sigma2/sigma0, sigma3/sigma0]
fig = plot_timelines([x1, x2, x3], [y1, y2, y3], scaled_controls, [t1, t2, t3], labels=labels)
plt.savefig('../figures/varying_c2.pdf')

In [None]:
# Phase-plane figure of the three PMP solutions.
fig = plot_phaseplane([x1, x2, x3], [y1, y2, y3])
plt.savefig('../figures/varying_c2_xy.pdf')

## Via HJB

This code takes several minutes to run, and even so the results are not as sharp as in the figure above.  Sharper results could be obtained with even more computational effort.

In [None]:
# Grid resolution handed to the HJB solver (mx, my).
mx = 500
my = 500

# Arguments shared by the three solves; only c2 changes between them.
hjb_args = dict(beta=beta, gamma=gamma, x0=x0, y0=y0, T=T, mx=mx, my=my)
x1, y1, sigma1, t1 = hjb.solve_hjb(c2=c2s[0], **hjb_args)
x2, y2, sigma2, t2 = hjb.solve_hjb(c2=c2s[1], **hjb_args)
x3, y3, sigma3, t3 = hjb.solve_hjb(c2=c2s[2], **hjb_args)

In [None]:
# Phase-plane figure of the three HJB solutions (displayed only, not saved).
fig = plot_phaseplane([x1, x2, x3], [y1, y2, y3])

# Figure 8

## PMP

In [None]:
beta, gamma = 0.3, 0.1
sigma0 = beta/gamma
x0, y0 = 0.9, 0.01
c2 = 1.e-2
c3 = 100.
ymax = 0.1
T = 100

# Controlled solution from the PMP solver (no warm start).
x, y, sigma, t, newguess, J = solve_pmp(c2=c2, c3=c3, ymax=ymax, T=T, guess=None, x0=x0, y0=y0)
# Second trajectory from the same initial state with SIR_forward called without a qfun.
x2, y2, t2 = SIR_forward(beta=beta, gamma=gamma, x0=x0, y0=y0, T=T)

fig = plot_timeline(x, y, sigma/sigma0, t, y2=y2, t2=t2)
plt.savefig('../figures/min_hosp_1_t.pdf')
print(J)

In [None]:
# Phase-plane figure of the PMP solution, with (x2, y2) passed as the second curve.
fig = plot_phaseplane([x], [y], color='k', x2=x2, y2=y2)
plt.savefig('../figures/min_hosp_1_xy.pdf')

## HJB

In [None]:
# Same problem via the HJB solver on a 500 x 500 (mx, my) grid.
x, y, sigma, t = hjb.solve_hjb(
    beta=beta, gamma=gamma, x0=x0, y0=y0,
    c2=c2, c3=c3, ymax=ymax, T=T, mx=500, my=500,
)

In [None]:
# Timeline figure of the HJB solution (displayed only, not saved).
fig = plot_timeline(x, y, sigma/sigma0, t)

In [None]:
# Phase-plane figure of the HJB solution (displayed only, not saved).
fig = plot_phaseplane([x], [y], color='k')

# Figure 9

## PMP

In [None]:
beta, gamma = 0.3, 0.1
sigma0 = beta/gamma
x0, y0 = 0.9, 0.01
c2 = 1e-2
c3 = 1.
ymax = 0.1
T = 100

# Two-step continuation in c2: solve at 5e-2 first (no warm start), then use
# the returned guess to solve at the target value c2.
newguess = None
for c2_step in (5e-2, c2):
    x, y, sigma, t, newguess, J = solve_pmp(c2=c2_step, c3=c3, ymax=ymax, T=T,
                                            guess=newguess, x0=x0, y0=y0)

# Second trajectory from the same initial state with SIR_forward called without a qfun.
x2, y2, t2 = SIR_forward(beta=beta, gamma=gamma, x0=x0, y0=y0, T=T)

fig = plot_timeline(x, y, sigma/sigma0, t, y2=y2, t2=t2)
plt.savefig('../figures/min_hosp_2_t.pdf')
print(J)

In [None]:
# Phase-plane figure of the PMP solution, with (x2, y2) passed as the second curve.
fig = plot_phaseplane([x], [y], color='k', x2=x2, y2=y2)
plt.savefig('../figures/min_hosp_2_xy.pdf')

# Figure 10

In [None]:
# --- Inputs to the cost coefficients ---
N = 1
alpha = 0.006  # IFR
eta = alpha    # Increase in IFR when no medical care is given
d = 1e4        # Days left of life for average victim
eps = 0.2      # Fraction of value of a day of life that is lost due to intervention

# --- Cost coefficients derived from the inputs above ---
c1 = N*alpha
c2 = N*eps/d
c3 = eta*N

# --- Epidemic parameters ---
gamma = 0.1
sigma0 = 3.2
beta = sigma0*gamma

# --- Constraint, initial state, horizon and number of points for the solver ---
ymax = 0.02
y0 = 1e-3
x0 = 0.999
T = 200
npts = 10000

In [None]:
# Use parameter continuation to get a good initial guess and successively improve it
c2temp = 900*c2
print(c2,c2temp)
x, y, sigma, t, newguess, J = solve_pmp(beta=beta,gamma=gamma,c1=c1,T=T,ymax=ymax,c2=c2temp,c3=c3,N=npts,
                                        guess=None,x0=x0,y0=y0)
reduction_factor = 10.

while c2temp > c2:
    print(c2,c2temp/reduction_factor)

    x, y, sigma, t, newg, J = solve_pmp(beta=beta,gamma=gamma,c1=c1,T=T,ymax=ymax,c2=c2temp/reduction_factor,c3=c3,N=npts,
                                        guess=newguess,x0=x0,y0=y0)
    if isinstance(x,int):
        reduction_factor = np.sqrt(reduction_factor)
    else:
        c2temp = c2temp/reduction_factor
        newguess = newg
        reduction_factor = reduction_factor**2

x, y, sigma, t, newguess, J = solve_pmp(beta=beta,gamma=gamma,c1=c1,T=T,ymax=ymax,c2=c2,c3=c3,N=npts,
                                        guess=newguess,x0=x0,y0=y0)

In [None]:
# Second trajectory from the same initial state with SIR_forward called without a qfun.
x2, y2, t2 = SIR_forward(beta=beta, gamma=gamma, x0=x0, y0=y0, T=T)
fig = plot_timeline(x, y, sigma/sigma0, t, y2=y2, t2=t2)
plt.savefig('../figures/real_world_1_t.pdf')

In [None]:
# Phase-plane figure for this beta/gamma, with (x2, y2) passed as the second curve.
fig = plot_phaseplane([x], [y], beta=beta, gamma=gamma, color='k', x2=x2, y2=y2)
plt.savefig('../figures/real_world_1_xy.pdf')

In [None]:
# Print the three cost coefficients used for this figure.
print(c1, c2, c3)

# Figure 11

In [None]:
# --- Inputs to the cost coefficients ---
N = 1
alpha = 0.012  # IFR
eta = alpha    # Increase in IFR when no medical care is given
d = 1e4        # Days left of life for average victim
eps = 0.05     # Fraction of value of a day of life that is lost due to intervention

# --- Cost coefficients derived from the inputs above ---
c1 = N*alpha
c2 = N*eps/d
c3 = eta*N

# --- Epidemic parameters, constraint, initial state and horizon ---
gamma = 1./10
sigma0 = 3.2
beta = sigma0*gamma
ymax = 0.02
y0 = 1e-3
x0 = 1-y0
T = 200

# NOTE: npts is not set in this cell; it is read from the notebook namespace.

# Use parameter continuation to get a good initial guess and successively improve it
c2temp = 900*c2
print(c2, c2temp)
x, y, sigma, t, newguess, J = solve_pmp(beta=beta, gamma=gamma, c1=c1, c2=c2temp, c3=c3,
                                        ymax=ymax, T=T, N=npts, guess=None, x0=x0, y0=y0)
reduction_factor = 10.

while c2temp > c2:
    # If the next step would land below the target, shrink it so that it
    # lands just below c2 instead.
    if c2temp/reduction_factor < c2:
        reduction_factor = c2temp/c2 * 1.001
    c2trial = c2temp/reduction_factor
    print(c2, c2trial)

    x, y, sigma, t, newg, J = solve_pmp(beta=beta, gamma=gamma, c1=c1, c2=c2trial, c3=c3,
                                        T=T, ymax=ymax, N=npts, guess=newguess, x0=x0, y0=y0)
    if not isinstance(x, int):
        # Step accepted: keep its guess and try a larger step next time.
        c2temp = c2trial
        newguess = newg
        reduction_factor = reduction_factor**2
    else:
        # x came back as an int: retry from the same c2temp with a smaller step.
        reduction_factor = np.sqrt(reduction_factor)

# Final solve at the target c2, warm-started from the last accepted guess.
x, y, sigma, t, newguess, J = solve_pmp(beta=beta, gamma=gamma, c1=c1, c2=c2, c3=c3,
                                        ymax=ymax, T=T, N=npts, guess=newguess, x0=x0, y0=y0)

In [None]:
# Second trajectory from the same initial state with SIR_forward called without a qfun.
x2, y2, t2 = SIR_forward(beta=beta, gamma=gamma, x0=x0, y0=y0, T=T)
fig = plot_timeline(x, y, sigma/sigma0, t, y2=y2, t2=t2)
plt.savefig('../figures/real_world_2_t.pdf')

In [None]:
# Phase-plane figure for this beta/gamma, with (x2, y2) passed as the second curve.
fig = plot_phaseplane([x], [y], beta=beta, gamma=gamma, color='k', x2=x2, y2=y2)
plt.savefig('../figures/real_world_2_xy.pdf')

In [None]:
# Print the three cost coefficients used for this figure.
print(c1, c2, c3)

# Figure 12

In [None]:
# --- Inputs to the cost coefficients ---
N = 1
alpha = 0.006  # IFR
eta = alpha    # Increase in IFR when no medical care is given
d = 1e4        # Days left of life for average victim
eps = 1.0      # Fraction of value of a day of life that is lost due to intervention

# --- Cost coefficients derived from the inputs above ---
c1 = N*alpha
c2 = N*eps/d
c3 = eta*N

# --- Epidemic parameters, constraint, initial state and horizon ---
gamma = 1./10
sigma0 = 3.2
beta = sigma0*gamma
ymax = 0.02
y0 = 1e-3
x0 = 1-y0
T = 200

# NOTE: npts is not set in this cell; it is read from the notebook namespace.

# Use parameter continuation to get a good initial guess and successively improve it
c2temp = 900*c2
print(c2, c2temp)
x, y, sigma, t, newguess, J = solve_pmp(beta=beta, gamma=gamma, c1=c1, c2=c2temp, c3=c3,
                                        ymax=ymax, T=T, N=npts, guess=None, x0=x0, y0=y0)
reduction_factor = 10.

while c2temp > c2:
    # If the next step would land below the target, shrink it so that it
    # lands just below c2 instead.
    if c2temp/reduction_factor < c2:
        reduction_factor = c2temp/c2 * 1.001
    c2trial = c2temp/reduction_factor
    print(c2, c2trial)

    x, y, sigma, t, newg, J = solve_pmp(beta=beta, gamma=gamma, c1=c1, c2=c2trial, c3=c3,
                                        T=T, ymax=ymax, N=npts, guess=newguess, x0=x0, y0=y0)
    if not isinstance(x, int):
        # Step accepted: keep its guess and try a larger step next time.
        c2temp = c2trial
        newguess = newg
        reduction_factor = reduction_factor**2
    else:
        # x came back as an int: retry from the same c2temp with a smaller step.
        reduction_factor = np.sqrt(reduction_factor)

# Final solve at the target c2, warm-started from the last accepted guess.
x, y, sigma, t, newguess, J = solve_pmp(beta=beta, gamma=gamma, c1=c1, c2=c2, c3=c3,
                                        ymax=ymax, T=T, N=npts, guess=newguess, x0=x0, y0=y0)

In [None]:
# Second trajectory from the same initial state with SIR_forward called without a qfun.
x2, y2, t2 = SIR_forward(beta=beta, gamma=gamma, x0=x0, y0=y0, T=T)
fig = plot_timeline(x, y, sigma/sigma0, t, y2=y2, t2=t2)
plt.savefig('../figures/real_world_3_t.pdf')

In [None]:
# Phase-plane figure for this beta/gamma, with (x2, y2) passed as the second curve.
fig = plot_phaseplane([x], [y], beta=beta, gamma=gamma, color='k', x2=x2, y2=y2)
plt.savefig('../figures/real_world_3_xy.pdf')