In [None]:
import numpy as np
import matplotlib.pyplot as plt
from nodepy import rk
import cvxpy as cp


# Explicit Runge-Kutta methods from nodepy's library, each converted with
# __num__() to a numeric copy of the method.
rk4, ssp2, ssp3, ssp104, merson4, bs5 = (
    rk.loadRKM(method_name).__num__()
    for method_name in ('RK44', 'SSP22', 'SSP33', 'SSP104', 'Merson43', 'BS5')
)
# Composition of two rk4 steps, formed with nodepy's method product.
rk4x2 = rk4*rk4

# Implicit methods (TR-BDF2, backward Euler, Lobatto IIIA).
trbdf, be, irk2 = (
    rk.loadRKM(method_name).__num__()
    for method_name in ('TR-BDF2', 'BE', 'LobattoIIIA2')
)

In [None]:
def OC_b(A,c,p):
    """Assemble the order conditions on the weights b of a Runge-Kutta method.

    For fixed ``A`` and ``c`` every order condition is linear in ``b``, so the
    conditions up to order ``p`` are returned as the system ``V @ b == rhs``.
    Apart from the "bushy-tree" rows ``b.c**k = 1/(k+1)``, the conditions are
    expressed through the stage residuals ``tau_k = c**k - k*A@c**(k-1)``,
    whose weighted sums must vanish.

    Parameters:
        A: (s, s) Runge-Kutta coefficient matrix (array-like)
        c: length-s abscissa vector (array-like)
        p: desired order. Values <= 0 give an empty system. Only p <= 4 is
           implemented.

    Returns:
        V: array with one row per condition and s columns
        rhs: 1-D array with one entry per row of V

    Raises:
        NotImplementedError: if p >= 5
    """
    A = np.asarray(A)
    c = np.asarray(c)  # allows plain lists: c**k needs an ndarray
    if p >= 5:
        # The next residual would be tau4 = c**4 - 4*A@c**3, but the
        # order-5 conditions are not written out yet.
        raise NotImplementedError("order conditions are only implemented for p <= 4")

    s = len(c)
    rows = []
    rhs = []
    if p >= 1:
        rows.append(np.ones(s)); rhs.append(1.)    # b.e   = 1
    if p >= 2:
        rows.append(c); rhs.append(1/2)            # b.c   = 1/2
    if p >= 3:
        tau2 = c**2 - 2*(A @ c)
        rows.append(c**2); rhs.append(1/3)         # b.c^2 = 1/3
        rows.append(tau2); rhs.append(0.)          # with the row above: b.Ac = 1/6
    if p >= 4:
        tau3 = c**3 - 3*(A @ c**2)
        rows.append(c**3); rhs.append(1/4)         # b.c^3 = 1/4
        rows.append(tau2*c); rhs.append(0.)        # with b.c^3 = 1/4: b.(c*Ac) = 1/8
        rows.append(A @ tau2); rhs.append(0.)      # with b.Ac^2 = 1/12: b.AAc = 1/24
        rows.append(tau3); rhs.append(0.)          # with b.c^3 = 1/4: b.Ac^2 = 1/12

    if not rows:
        # p <= 0: no conditions. Keep s columns so V @ b stays well-defined.
        return np.empty((0, s)), np.empty(0)

    V = np.vstack(rows)
    return V, np.array(rhs, dtype=float)

In [None]:
# Assemble the order-condition system for classical RK4 at its own order.
rkm = rk4
V, rhs = OC_b(A=rkm.A, c=rkm.c, p=rkm.p)

In [None]:
# Residual of the order conditions for rkm's own weights (expected ~0).
V @ rkm.b - rhs

In [None]:
def RRK_linprog(rkm, dt, f, w0=[1.,0], t_final=1.):
    """Integrate w' = f(w), re-choosing the RK weights at each step by a linear program.

    Each step computes the stage derivatives of ``rkm`` explicitly. It then
    solves a linear program (with cvxpy) for weights ``b`` that satisfy the
    order conditions of ``rkm`` (assembled by ``OC_b``) and keep every
    component of the updated solution non-negative. The solution is advanced
    with those weights.

    Options:

        rkm: Base Runge-Kutta method, in Nodepy format. Only the strictly
             lower-triangular part of ``rkm.A`` is used, so the method is
             treated as explicit.
        dt: time step size (the last step is shortened to land on t_final)
        f: RHS of ODE system; called as ``f(w)`` on a 1-D array
        w0: Initial data
        t_final: final solution time

    Returns:
        tt: list of times, from 0 to ``t_final``
        ww: array of shape ``(len(w0), len(tt))``; column k is the solution
            at ``tt[k]``

    Raises:
        RuntimeError: if cvxpy returns no weights at some step (e.g. the
            linear program is infeasible).
    """
    w = np.array(w0)
    t = 0
    tt = [t]
    # Snapshots go into a list because the exact number of steps is not known
    # in advance (shortened last step, floating-point drift in t).
    snapshots = [w.copy()]

    s = len(rkm)
    m = len(w0)
    A = rkm.A
    y = np.zeros((s, m))  # stage values
    F = np.zeros((s, m))  # stage derivatives f(y_i)

    e = np.ones(s)
    bb = cp.Variable(s)

    # The order conditions depend only on the base method, so assemble them once.
    V, rhs = OC_b(rkm.A, rkm.c, rkm.p)

    while t < t_final:
        if t + dt >= t_final:
            dt = t_final - t # Hit final time exactly

        # Explicit stages. Each f(y_j) is evaluated once and reused by later stages.
        for i in range(s):
            y[i, :] = w
            for j in range(i):
                y[i, :] += A[i, j] * dt * F[j, :]
            F[i, :] = f(y[i, :])

        # Set up and solve the linear program for this step's weights.
        # NOTE(review): for p >= 1 the constraints include sum(bb) == 1, which
        # makes this objective constant, so the LP is effectively a
        # feasibility problem. Confirm the intended objective.
        prob = cp.Problem(cp.Minimize(e @ bb),
                          [V @ bb == rhs, (F.T @ bb) * dt + w >= 0])
        prob.solve()
        if bb.value is None:
            raise RuntimeError(
                f"RRK_linprog: no weights found at t={t} (status: {prob.status})")
        b = bb.value

        w = w + dt * (F.T @ b)
        t += dt
        tt.append(t)
        snapshots.append(w.copy())

    return tt, np.array(snapshots).T