In [None]:
%matplotlib widget
import numpy as np
import cupy as cp
from functools import partial
from time import time as gct

# Plotting only (not needed for the simulation itself)
import matplotlib.pyplot as plt
import ipywidgets as iwi

# Toggle CPU/GPU: `mp` is the array module the simulation cells allocate and compute
# with. Leave `mp = np` for NumPy on the CPU; set `mp = cp` to use CuPy on the GPU.
# NOTE(review): with `mp = cp` the results stay in GPU memory, and matplotlib cannot
# plot CuPy arrays directly — copy them to the host (e.g. `cp.asnumpy`) before plotting.
mp = np

In [None]:
## Simulate an ensemble of Lorenz systems
### https://en.wikipedia.org/wiki/Lorenz_system

# u is [x, y, z]
def RHS(u, t, σ, ρ, β):
    """Right-hand side du/dt of the Lorenz system.

    Row 0 of ``u`` holds x, row 1 holds y and row 2 holds z. Any trailing axes
    (e.g. the ensemble axis) are carried along elementwise. ``t`` is unused
    because the system is autonomous. It is kept so that, once σ, ρ and β are
    bound with ``partial``, the function has the ``rhs(u, t)`` form RK4 calls.
    The result is a new array of the same shape and dtype as ``u``.
    """
    x, y, z = u[0], u[1], u[2]
    dudt = mp.empty_like(u)
    dudt[0] = σ*(y - x)
    dudt[1] = x*(ρ - z) - y
    dudt[2] = x*y - β*z
    return dudt

## One step of Runge-Kutta of 4th order
### https://en.wikipedia.org/wiki/Runge%E2%80%93Kutta_methods
def RK4(rhs, u, t, dt):
    """Advance ``u`` from time ``t`` to ``t + dt`` with one classical RK4 step.

    ``rhs(u, t)`` must return du/dt in a form that can be added to ``u``.
    The new state is returned. This function does not modify ``u`` in place.
    """
    half = 0.5*dt
    t_mid = t + half
    s1 = rhs(u, t)
    s2 = rhs(u + half*s1, t_mid)
    s3 = rhs(u + half*s2, t_mid)
    s4 = rhs(u + dt*s3, t + dt)
    weighted = s1 + 2*s2 + 2*s3 + s4  # Simpson-like 1-2-2-1 weighting of the stages
    return u + dt*weighted/6

In [None]:
### MAIN LOOP
N_ensemble = 16
SEED = 0  # fixed seed so the ensemble and every figure below are reproducible
mp.random.seed(SEED)  # seeds the global RNG of whichever backend `mp` is (numpy or cupy)

# Initial condition(s): each member starts at (1, 1, 1) plus a ~1e-4 Gaussian perturbation
u = mp.zeros((3, N_ensemble))
u[:] = mp.array([1, 1, 1])[:, None] + mp.random.randn(*u.shape)*1e-4

# Simulation parameters
dt, t_f = 0.01, 30
time = mp.r_[0:(t_f + dt/2):dt]  # the dt/2 overshoot keeps t_f itself in the grid
sol = mp.empty((time.size, 3, N_ensemble))  # To plot (not needed to compute)
d = mp.empty(time.size)  # ensemble spread at each sample time

rhs = partial(RHS, σ=10.0, ρ=28.0, β=8/3)

# CuPy launches GPU kernels asynchronously: synchronise before each clock read so the
# measured interval covers the real computation (skipped on the NumPy backend).
if mp is cp:
    cp.cuda.Device().synchronize()
__t0 = gct()
for it, t in enumerate(time):
    sol[it] = u
    # Sum of |u_i - u_j|^2 over ordered member pairs, divided by 2 N (N - 1): half the
    # mean squared pairwise distance, i.e. the total sample variance (ddof=1).
    d[it] = ((u[:, None, :] - u[:, :, None])**2).sum()/(2*N_ensemble*(N_ensemble - 1))
    u = RK4(rhs, u, t, dt)

if mp is cp:
    cp.cuda.Device().synchronize()
__t1 = gct()
print(f"Elapsed time: {__t1 - __t0} seconds.")

In [None]:
# VISUALISATION
TAIL = 10  # number of consecutive samples drawn per ensemble member

plt.close("all")
fig3d = plt.figure()
ax = fig3d.add_subplot(projection='3d')

# matplotlib needs host (NumPy) arrays. cp.asnumpy also accepts NumPy input, so this
# works for both settings of the CPU/GPU toggle.
sol_host = cp.asnumpy(sol)  # shape (time, 3, member)

# One line per ensemble member, initially showing its first TAIL samples.
lines = [ax.plot(*s[:, :TAIL])[0] for s in sol_host.T]

# Fit the axes to the whole ensemble history (5% padding) instead of hard-coding
# limits taken from one particular run.
lo = sol_host.min(axis=(0, 2))
hi = sol_host.max(axis=(0, 2))
pad = 0.05*(hi - lo)
ax.set_xlim(lo[0] - pad[0], hi[0] + pad[0])
ax.set_ylim(lo[1] - pad[1], hi[1] + pad[1])
ax.set_zlim(lo[2] - pad[2], hi[2] + pad[2])

# Largest window start whose TAIL-sample window still fits, so the final window
# ends on the last sample (time.size - 1).
tm = time.size - TAIL
slider = iwi.IntSlider(min=0, max=tm, value=0, step=1,
                       description='Iteration', continuous_update=True, readout=False)
play   = iwi.Play(min=0, max=tm, step=1, interval=5, continuous_update=False)
iwi.jslink((slider, 'value'), (play, 'value'))
out = iwi.Output()

def time_changed(change):
    """Slider callback: show samples [change.new, change.new + TAIL) of every member."""
    it0 = change.new
    it1 = it0 + TAIL
    for line, s in zip(lines, sol_host.T):
        line.set_data_3d(*s[:, it0:it1])
    # Explicit redraw, so the update does not rely on pyplot's interactive auto-draw.
    fig3d.canvas.draw_idle()
    with out:
        print(f"{it0:05d}", end='\r')

slider.observe(time_changed, 'value')

iwi.HBox([play, slider, out])

In [None]:
# Ensemble spread vs. time
plt.close("all")
fig_spread, ax_spread = plt.subplots()
# Copy to host first: matplotlib cannot plot CuPy arrays (cp.asnumpy passes NumPy through).
ax_spread.plot(cp.asnumpy(time), cp.asnumpy(d))
ax_spread.set_xlabel("Time")
# `d` is half the mean of |u_i - u_j|^2 over member pairs (i != j). That is a *squared*
# distance (the ensemble variance), not an L2 distance.
ax_spread.set_ylabel("Ensemble variance (mean squared pairwise distance / 2)")
ax_spread.set_yscale("log")
ax_spread.set_title("Spread of the Lorenz ensemble over time")
fig_spread.tight_layout()

In [None]:
# Show the installed Matplotlib version as the cell output.
# NOTE(review): presumably kept to diagnose version-dependent behaviour (e.g. the
# 6-value `ax.axis(...)` call on 3-D axes or the ipympl widget backend). Confirm, and
# consider moving this import into the first (imports) cell.
import matplotlib
matplotlib.__version__