# Emulators

In [None]:
# Notebook setup: inline plotting, core imports, project dtype/device settings and
# the output/data directories used by the cells below.
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
import L96sim

# dtype/device are used for torch tensors below, dtype_np for numpy arrays
from L96_emulator.util import dtype, dtype_np, device

# NOTE(review): absolute paths on a specific cluster filesystem; adjust per machine.
res_dir = '/gpfs/work/nonnenma/results/emulators/L96/'  # not referenced in the cells shown here
data_dir = '/gpfs/work/nonnenma/data/emulators/L96/'  # simulated trajectories are saved to / loaded from here

### load / simulate data

In [None]:
# Simulate (or load) a two-level Lorenz-96 trajectory `out` and wrap it as the
# training dataset `dg_train`.
from L96sim.L96_base import f1, f2, J1, J1_init, f1_juliadef, f2_juliadef
from L96_emulator.util import predictor_corrector
from L96_emulator.run import sel_dataset_class

# model parameters passed to f1/f2 (F also scales the random initial state below)
F, h, b, c = 10, 1, 10, 10
# K, J: sizes of the two levels (the state vector has K*(J+1) entries);
# T: simulated time; dt: spacing of the saved time points
K, J, T, dt = 36, 10, 65, 0.001
spin_up_time, train_frac = 5., 0.8

# e.g. 'out_K36_J10_T65_dt0_001'; str(dt)[2:] drops the leading '0.'
fn_data = f'out_K{K}_J{J}_T{T}_dt0_{str(dt)[2:]}'
# Right-hand side dX/dt for the integrator. dX_dt is a global buffer (defined in the
# resimulate branch) handed to f1/f2 -- presumably filled in place; confirm in L96sim.
if J > 0:
    def fun(t, x):
        return f2(x, F, h, b, c, dX_dt, K, J)
else:
    def fun(t, x):
        return f1(x, F, dX_dt, K)

resimulate, save_sim = True, False
if resimulate:
    print('simulating data')
    # random initial state with mean 0.5*F/max(J,10) and std F/max(J,10); unseeded
    X_init = F * (0.5 + np.random.randn(K*(J+1)) * 1.0).astype(dtype=dtype_np) / np.maximum(J,10)
    dX_dt = np.empty(X_init.size, dtype=X_init.dtype)
    times = np.linspace(0, T, int(np.floor(T/dt)+1))  # floor(T/dt)+1 equally spaced time points

    out = predictor_corrector(fun=fun, y0=X_init.copy(), times=times, alpha=0.5)

    # optionally cache the trajectory (np.save appends '.npy' to the filename)
    if save_sim: 
        np.save(data_dir + fn_data, out.astype(dtype=dtype_np))
else:
    print('loading data')
    out = np.load(data_dir + fn_data + '.npy')

# space-time plot of the trajectory; assumes out is (time, state) -- TODO confirm
plt.figure(figsize=(8,4))
plt.imshow(out.T, aspect='auto')
plt.xlabel('time')
plt.ylabel('location')
plt.show()


# Training split: skip the spin-up and stop after the first train_frac of the
# trajectory; offset=lead_time presumably is the prediction lead in time steps.
prediction_task = 'state'
lead_time = 1
DatasetClass = sel_dataset_class(prediction_task=prediction_task)
dg_train = DatasetClass(data=out, J=J, offset=lead_time, normalize=False, 
                   start=int(spin_up_time/dt), 
                   end=int(np.floor(out.shape[0]*train_frac)))

# symmetric solver (one-level Lorenz-96 for now)

## leapfrog 
- Leapfrog is a method for second-order differential equations, so we first derive the second time derivative of the state.

In [None]:
# Re-simulate (or load) a trajectory with J = 0, i.e. one-level Lorenz-96.
# NOTE(review): this rebinds the globals `J`, `fun`, `dX_dt`, `X_init`, `times` and `out`
# from the two-level cell, but `dg_train` is not rebuilt and still wraps the J=10
# trajectory -- later cells mix both; confirm this is intended.
from L96sim.L96_base import f1, f2, J1, J1_init, f1_juliadef, f2_juliadef
from L96_emulator.util import predictor_corrector
from L96_emulator.run import sel_dataset_class

F, h, b, c = 10, 1, 10, 10
K, J, T, dt = 36, 0, 65, 0.001  # J = 0: no second level, state has K entries
spin_up_time, train_frac = 5., 0.8

# e.g. 'out_K36_J0_T65_dt0_001'; str(dt)[2:] drops the leading '0.'
fn_data = f'out_K{K}_J{J}_T{T}_dt0_{str(dt)[2:]}'
# Right-hand side dX/dt; with J = 0 the f1 branch is the one defined.
if J > 0:
    def fun(t, x):
        return f2(x, F, h, b, c, dX_dt, K, J)
else:
    def fun(t, x):
        return f1(x, F, dX_dt, K)

resimulate, save_sim = True, False
if resimulate:
    print('simulating data')
    # random initial state with mean 0.5*F/max(J,10) and std F/max(J,10); unseeded
    X_init = F * (0.5 + np.random.randn(K*(J+1)) * 1.0).astype(dtype=dtype_np) / np.maximum(J,10)
    dX_dt = np.empty(X_init.size, dtype=X_init.dtype)
    times = np.linspace(0, T, int(np.floor(T/dt)+1))  # floor(T/dt)+1 equally spaced time points

    out = predictor_corrector(fun=fun, y0=X_init.copy(), times=times, alpha=0.5)

    # optionally cache the trajectory (np.save appends '.npy' to the filename)
    if save_sim: 
        np.save(data_dir + fn_data, out.astype(dtype=dtype_np))
else:
    print('loading data')
    out = np.load(data_dir + fn_data + '.npy')

# space-time plot of the trajectory; assumes out is (time, state) -- TODO confirm
plt.figure(figsize=(8,4))
plt.imshow(out.T, aspect='auto')
plt.xlabel('time')
plt.ylabel('location')
plt.show()


In [None]:
from L96_emulator.networks import AnalyticModel_oneLevel, AnalyticModel_twoLevel

# One-level analytic model; only its td_mat shift matrices are used in this cell.
model = AnalyticModel_oneLevel(K=K)#, J=J)
X_init = F * (0.5 + np.random.randn(K) * 1.0).astype(dtype=dtype_np) / np.maximum(J,10)
dX_dt = np.empty(K, dtype=X_init.dtype)  # buffer that fun() hands to f1

# Periodic shift matrices. kplus1.dot(x) is assumed to give x_{k+1}, kminus1.dot(x)
# x_{k-1}, kminus2.dot(x) x_{k-2} -- NOTE(review): confirm against td_mat.
kplus1, kminus1, kminus2 = model.td_mat(K,1), model.td_mat(K,-1), model.td_mat(K,-2)
kdiff = kplus1 - kminus2  # x -> x_{k+1} - x_{k-2}, shared by both product-rule terms


def d2fdt2(x):
    """Return minus the second time derivative of the one-level L96 state ``x``.

    Assuming ``fun`` implements dx_k/dt = f_k = x_{k-1}(x_{k+1} - x_{k-2}) - x_k + F,
    the product rule gives

        d2x_k/dt2 = f_{k-1}(x_{k+1} - x_{k-2}) + x_{k-1}(f_{k+1} - f_{k-2}) - f_k.

    This function returns the negative of that, so ``-d2fdt2(x)`` is the
    acceleration used by the leapfrog experiment. The f_k term previously had
    the same sign as the product-rule terms, so ``-d2fdt2`` contained +f_k
    rather than -f_k. The differentiated ``-x_k`` term always contributes -f_k,
    whatever the shift convention.
    """
    # fun() may return the shared dX_dt buffer (the original copied every call),
    # so evaluate the right-hand side once, copy it, and reuse it.
    f = fun(0, x).copy()
    return f - kminus1.dot(f) * kdiff.dot(x) - kdiff.dot(f) * kminus1.dot(x)


# Numerical check: forward difference of fun along the stored trajectory vs. the
# analytic second derivative -d2fdt2 (the two curves should coincide).
t_check = 5000  # time index of the check; no longer overwrites the horizon T
plt.plot((fun(0.,out[t_check+1]).copy()-fun(0.,out[t_check]))/dt)
plt.plot(-d2fdt2(out[t_check]))
plt.show()

In [None]:
# Integration grid for the second-order (leapfrog) experiment: T = 1 time unit.
# Note: this rebinds the globals dt and T (model_forward later reads dt; same value).
dt = 0.001
T = 1.0
times = np.linspace(0, T, int(np.floor(T/dt)+1))

def fun2(t,x):
    # second-derivative right-hand side passed to leapfrog(): the negation of d2fdt2; t is unused
    return - d2fdt2(x)

def leapfrog(fun, y0, z0, times):
    """Integrate the second-order ODE ``y'' = fun(t, y)`` with velocity Verlet.

    Starting from position ``y0`` and velocity ``z0`` at ``times[0]``, every step
    updates the position with ``y + dt*z + dt**2/2 * a`` and the velocity with
    the mean of the old and new accelerations (the velocity-Verlet form of the
    leapfrog scheme). Step sizes come from consecutive entries of ``times``, so
    non-uniform grids work.

    The acceleration at the end of one step is exactly the one needed at the
    start of the next. It is therefore evaluated once per step, not twice.

    Args:
        fun: callable ``fun(t, y)`` returning the acceleration as an array with
            the shape of ``y0``. Its result is copied, so it may reuse a buffer.
        y0: initial position (numpy array).
        z0: initial velocity, same shape as ``y0``.
        times: 1-D sequence of time points.

    Returns:
        ``(y, z)``: arrays of shape ``(len(times), *y0.shape)`` with the dtype of
        ``y0``, holding the positions and velocities at every time point. Empty
        ``times`` gives empty arrays.
    """
    y = np.zeros((len(times), *y0.shape), dtype=y0.dtype)
    z = np.zeros_like(y)
    if len(times) == 0:
        return y, z
    y[0] = y0
    z[0] = z0
    if len(times) == 1:
        return y, z

    a = fun(times[0], y[0]).copy()
    for i in range(1, len(times)):
        dt = times[i] - times[i-1]
        y[i] = y[i-1] + dt * z[i-1] + dt**2 / 2. * a
        a_next = fun(times[i], y[i]).copy()
        z[i] = z[i-1] + dt * (a + a_next) / 2.
        a = a_next

    return y, z

# Integrate from X_init with leapfrog, using the first derivative fun(0, X_init) as the
# initial velocity, and show the resulting space-time field (velocities discarded).
out2, _ = leapfrog(fun=fun2, y0=X_init.copy(), z0=fun(0., X_init).copy(), times=times)

plt.figure(figsize=(8,4))
plt.imshow(out2.T, aspect='auto')
plt.colorbar()
plt.xlabel('time')
plt.ylabel('location')
plt.show()

### load model

In [None]:
# Build the emulator network and the output normalisation statistics of the training set.
from L96_emulator.networks import MinimalNetL96
from L96_emulator.util import sortL96fromChannels, sortL96intoChannels
import torch 

# NOTE(review): J is 0 here (rebound by the one-level data cell), whereas dg_train
# was built from the J=10 trajectory -- confirm the shapes of std/mean match the model.
model = MinimalNetL96(K,J,F,b,c,h,skip_conn=True,loc=1e3)
std_out = torch.as_tensor(dg_train.std, device=device, dtype=dtype)
mean_out = torch.as_tensor(dg_train.mean, device=device, dtype=dtype)

def model_forward(x):
    """Advance states by one step of size ``dt`` through the emulator ``model``.

    The update is x + dt * (f(x) + f(x + dt * f(x))) / 2 with f = model.forward,
    i.e. the two stages are weighted 0.5 / 0.5 (explicit trapezoidal / Heun form).

    Input handling depends on ``x.ndim``:
      * 3: channel layout normalised with mean_out / std_out. It is de-normalised
        and converted with sortL96fromChannels before stepping. The result is
        returned normalised, in channel layout.
      * 2: used as-is. The result is normalised and converted back with
        sortL96fromChannels.
      * otherwise: used as-is. The result is returned normalised, in channel layout.

    NOTE(review): the 2-D path normalises its output but never de-normalises its
    input, so the two live in different spaces unless mean_out == 0 and
    std_out == 1 (presumably the case, since dg_train uses normalize=False).
    """
    in_ndim = x.ndim
    if in_ndim == 3:
        x = sortL96fromChannels(x * std_out + mean_out)

    k_first = model.forward(x)
    k_second = model.forward(x + dt * k_first)
    weight = 0.5
    x_next = x + dt * (weight * k_first + (1 - weight) * k_second)

    x_next = (sortL96intoChannels(x_next, J=J) - mean_out) / std_out
    if in_ndim == 2:
        return sortL96fromChannels(x_next)
    return x_next

### example rollout

In [None]:
from L96_emulator.eval import get_rollout_fun, plot_rollout
from L96_emulator.eval import solve_from_init

# Roll the emulator out from a dataset state and compare against the data and,
# optionally, against a numerical solver started from the same state.
model_simulate = get_rollout_fun(dg_train, model_forward, prediction_task)

# Start right after the spin-up. The previous `100*int(spin_up_time/dt)` evaluated
# to 500000, far beyond the 65/0.001 + 1 = 65001 simulated time points.
# NOTE(review): confirm whether dg_train indexes absolute time or time since `start`.
T_start, T_dur = int(spin_up_time/dt), 10000
out_model = model_simulate(y0=dg_train[T_start].copy(), 
                           dy0=dg_train[T_start]-dg_train[T_start-dg_train.offset],
                           T=T_dur)
out_model = sortL96fromChannels(out_model * dg_train.std + dg_train.mean)

solver_comparison = True 
# Only pass a comparison when one was computed here. Otherwise the stale `out2`
# from the leapfrog cell would be plotted as the solver result.
# NOTE(review): assumes plot_rollout's out_comparison parameter has a default.
comparison_kwargs = {}
if solver_comparison:
    # fall back to the default parameters only if they were never defined
    try: 
        print(F, h, b, c)
    except NameError: 
        F, h, b, c = 10, 1, 10, 10

    out2 = solve_from_init(K, J, 
                           T_burnin=T_start, T_=T_dur, dt=dt, 
                           F=F, h=h, b=b, c=c, 
                           data=out, dilation=2, norm_mean=0., norm_std=1.)
    comparison_kwargs['out_comparison'] = out2

fig = plot_rollout(out, out_model, T_start=T_start, T_dur=T_dur, K=None, **comparison_kwargs)


# Solving a fully-observed inverse problem

In [None]:
"""
from L96_emulator.eval import Rollout

T_start = np.array([5000, 10000, 150000])
T, N = 10, len(T_start)

roller_outer = Rollout(model_forward, prediction_task='state', K=K, J=J, N=N)
x_init = roller_outer.X.detach().numpy().copy()

target = torch.as_tensor(out[T_start+T], dtype=dtype, device=device)

n_steps, lr, weight_decay = 1000, 5e-2, 0.
roller_outer.train()
optimizer = torch.optim.Adam(roller_outer.parameters(), lr=lr, weight_decay=weight_decay)
loss_vals = np.zeros(n_steps)
for i in range(n_steps):
        optimizer.zero_grad()
        loss = ((roller_outer.forward(T=T) - target)**2).mean()
        loss.backward()
        optimizer.step()
        loss_vals[i] = loss.detach().numpy()

plt.figure(figsize=(8,2))
plt.semilogy(loss_vals, label='initialization')
plt.title('rollout final state loss across gradient descent steps')
plt.ylabel('MSE)')
plt.xlabel('gradient step')
plt.show()
"""

## ADAM, solve across full rollout time in one go 

In [None]:
import time
from L96_emulator.eval import Rollout

# Estimate the states at T_start from the states observed T_rollout steps later by
# optimising the initial states of a T_rollout-step emulator rollout with Adam.
T_start = np.array([5000, 10000, 150000])
T_rollout, N = 100, len(T_start)
if np.any(T_start + T_rollout >= out.shape[0]):
    # NOTE(review): with the 65001-step trajectory simulated above, 150000 is out of
    # range -- possibly meant to be 15000; confirm.
    raise IndexError(f'T_start + T_rollout must stay below the {out.shape[0]} time points in out')

# initial guess for the unknown initial states: the observed final states
x_init = out[T_start+T_rollout].copy()
roller_outer_ADAM = Rollout(model_forward, prediction_task='state', K=K, J=J, N=N, x_init=x_init)
x_init = roller_outer_ADAM.X.detach().cpu().numpy().copy()

target = torch.as_tensor(out[T_start+T_rollout], dtype=dtype, device=device)

n_steps, lr, weight_decay = 2000, 0.01, 0.0
roller_outer_ADAM.train()

optimizer = torch.optim.Adam(roller_outer_ADAM.parameters(), lr=lr, weight_decay=weight_decay)

# final-state MSE of the rollout from the initial guess, before any optimisation
print(((roller_outer_ADAM.forward(T=T_rollout).detach().cpu().numpy() - out[T_start+T_rollout])**2).mean())

loss_vals_ADAM = np.zeros(n_steps)
t0 = time.time()
time_vals_ADAM = np.zeros(n_steps)  # seconds since t0, recorded after each step
for i in range(n_steps):
    optimizer.zero_grad()
    loss = ((roller_outer_ADAM.forward(T=T_rollout) - target)**2).mean()
    loss.backward()
    optimizer.step()
    loss_vals_ADAM[i] = loss.item()
    time_vals_ADAM[i] = time.time() - t0
    print((time_vals_ADAM[i], loss_vals_ADAM[i]))

plt.figure(figsize=(8,2))
plt.semilogy(loss_vals_ADAM, label='initialization')
plt.title('rollout final state loss across gradient descent steps')
plt.ylabel('MSE')
plt.xlabel('gradient step')
plt.show()

## ADAM, split rollout time into chunks, solve sequentially from end to beginning

In [None]:
import time
from L96_emulator.eval import Rollout

# Same inverse problem, but the rollout is split into n_chunks chunks solved one after
# another from the end backwards: each chunk's solution becomes both the target and
# the initial guess of the chunk before it.
T_start = np.array([5000, 10000, 150000])
T_rollout, N = 100, len(T_start)
if np.any(T_start + T_rollout >= out.shape[0]):
    raise IndexError(f'T_start + T_rollout must stay below the {out.shape[0]} time points in out')


n_steps, lr, weight_decay = 2000, 0.01, 0.0

loss_vals_test = np.zeros(n_steps)
t0 = time.time()
time_vals_test = np.zeros(n_steps)  # seconds since t0, recorded after each step

n_chunks = 10
# Chunk lengths. `dtype=np.int` no longer exists in NumPy >= 1.24, so use int.
# Any remainder goes to the first chunks so the lengths always add up to T_rollout.
T_rollout_i = np.full(n_chunks, T_rollout // n_chunks, dtype=int)
T_rollout_i[:T_rollout % n_chunks] += 1

x_inits = np.zeros((n_chunks, N, K*(J+1)))  # initial guess used for each chunk
x_init = out[T_start+T_rollout].copy()

targets = np.zeros((n_chunks, N, K*(J+1)))  # end state each chunk is fitted to
targets[0] = out[T_start+T_rollout]


i_ = 0
for j in range(n_chunks):

    roller_outer_test = Rollout(model_forward, prediction_task='state', K=K, J=J, N=N, x_init=x_init)
    x_inits[j] = roller_outer_test.X.detach().cpu().numpy().copy()
    optimizer = torch.optim.Adam(roller_outer_test.parameters(), lr=lr, weight_decay=weight_decay)
    target = torch.as_tensor(targets[j], dtype=dtype, device=device)
    roller_outer_test.train()
    for i in range(n_steps//n_chunks):

        optimizer.zero_grad()
        loss = ((roller_outer_test.forward(T=T_rollout_i[j]) - target)**2).mean()
        loss.backward()
        optimizer.step()
        loss_vals_test[i_] = loss.item()
        time_vals_test[i_] = time.time() - t0
        print((time_vals_test[i_], loss_vals_test[i_]))

        i_ += 1

    # this chunk's estimate seeds (and is the target of) the next, earlier chunk
    x_init = roller_outer_test.X.detach().cpu().numpy().copy()
    if j < n_chunks - 1:
        targets[j+1] = x_init

plt.figure(figsize=(8,2))
plt.semilogy(loss_vals_test, label='initialization')
plt.title('rollout final state loss across gradient descent steps')
plt.ylabel('MSE')
plt.xlabel('gradient step')
plt.show()

In [None]:
# Overlay the Adam loss curves of the single-rollout solve and the chunked solve.
plt.figure(figsize=(8,2))
for curve, curve_label in ((loss_vals_ADAM, 'in one go'), (loss_vals_test, 'in 10 chunks')):
    plt.semilogy(curve, label=curve_label)
plt.title('rollout final state loss across gradient descent steps')
plt.ylabel('MSE')
plt.xlabel('gradient step')
plt.legend()
plt.show()

In [None]:
# Final-state MSE of the full T_rollout-step rollout from the chunked-Adam estimate
# (roller_outer_test holds the last chunk's solution); displayed as the cell output.
target = torch.as_tensor(targets[0], dtype=dtype, device=device)
((roller_outer_test.forward(T=T_rollout) - target)**2).mean()

In [None]:
# Final-state MSE of the full T_rollout-step rollout from the single-rollout Adam
# estimate; displayed as the cell output.
target = torch.as_tensor(targets[0], dtype=dtype, device=device)
((roller_outer_ADAM.forward(T=T_rollout) - target)**2).mean()

In [None]:
# Per sample: estimated initial states (top row) and the final states their
# T_rollout-step rollouts reach (bottom row), single-rollout vs. chunked Adam.
# Each estimate is converted / rolled out once instead of once per subplot.
X_one_go = roller_outer_ADAM.X.detach().cpu().numpy().copy()
X_chunked = roller_outer_test.X.detach().cpu().numpy().copy()
final_one_go = roller_outer_ADAM.forward(T=T_rollout).detach().cpu().numpy().copy()
final_chunked = roller_outer_test.forward(T=T_rollout).detach().cpu().numpy().copy()

plt.figure(figsize=(16,16))
for i in range(N):
    plt.subplot(2,N,i+1)
    plt.plot(X_one_go[i], label='one go')
    plt.plot(X_chunked[i], '--', label='in 10 chunks')
    plt.legend()

    plt.subplot(2,N,N+i+1)
    plt.plot(final_one_go[i], label='one go')
    plt.plot(final_chunked[i], '--', label='in 10 chunks')
    plt.legend()

plt.show()

### compare with plain gradient descent (SGD with single data point)

In [None]:
# Same single-rollout problem as the Adam run, solved with plain gradient descent
# for comparison.
x_init = out[T_start+T_rollout].copy()
roller_outer_SGD = Rollout(model_forward, prediction_task='state', K=K, J=J, N=N, x_init=x_init)
x_init = roller_outer_SGD.X.detach().cpu().numpy().copy()

# Set the target explicitly rather than relying on whichever cell last rebound it.
target = torch.as_tensor(out[T_start+T_rollout], dtype=dtype, device=device)

n_steps, lr, weight_decay = 2000, 0.01, 0.0
roller_outer_SGD.train()

# torch.optim.SGD without momentum is plain gradient descent, as the section heading
# and the 'SGD' plot labels state; this previously constructed torch.optim.Adam.
optimizer = torch.optim.SGD(roller_outer_SGD.parameters(), lr=lr, weight_decay=weight_decay)
# Disabled alternative (string literal, not executed):
"""
optimizer = torch.optim.LBFGS(params=roller_outer.parameters(), 
                              lr=lr, 
                              max_iter=20, 
                              max_eval=None, 
                              tolerance_grad=1e-07, 
                              tolerance_change=1e-09, 
                              history_size=100, 
                              line_search_fn=None)
"""
loss_vals_SGD = np.zeros(n_steps)
t0 = time.time()
time_vals_SGD = np.zeros(n_steps)  # seconds since t0, recorded after each step
for i in range(n_steps):
    optimizer.zero_grad()
    loss = ((roller_outer_SGD.forward(T=T_rollout) - target)**2).mean()
    loss.backward()
    optimizer.step()
    loss_vals_SGD[i] = loss.item()
    time_vals_SGD[i] = time.time() - t0
    print((time_vals_SGD[i], loss_vals_SGD[i]))

plt.figure(figsize=(8,2))
plt.semilogy(loss_vals_SGD, label='initialization')
plt.title('rollout final state loss across gradient descent steps')
plt.ylabel('MSE')
plt.xlabel('gradient step')
plt.show()

In [None]:
# Compare the Adam and gradient-descent loss curves against gradient steps (right
# panel) and against wall-clock time (left panel).
plt.figure(figsize=(16,8))

plt.subplot(1,2,2)
plt.semilogy(loss_vals_ADAM, '--', label='ADAM, lr=0.01')
try:
    plt.semilogy(loss_vals_SGD, label='SGD, lr=0.01')
except NameError:
    # gradient-descent cell not run yet: plot Adam only
    pass
plt.xlabel('# gradient steps')

plt.subplot(1,2,1)
plt.semilogy(time_vals_ADAM, loss_vals_ADAM, '--', label='ADAM, lr=0.01')
try:
    plt.semilogy(time_vals_SGD, loss_vals_SGD, label='SGD, lr=0.01')
except NameError:
    # gradient-descent cell not run yet: plot Adam only
    pass
plt.legend()
plt.suptitle('rollout final state loss across gradient descent steps')
plt.ylabel('MSE')
plt.xlabel('time [s]')
plt.show()

## L-BFGS, split rollout time into chunks, solve sequentially from end to beginning

In [None]:
import time
from L96_emulator.eval import Rollout

# Chunked backward solve as in the chunked Adam experiment, but each chunk is solved
# with L-BFGS (strong-Wolfe line search).
T_start = np.array([5000, 10000, 150000])
T_rollout, N = 100, len(T_start)
if np.any(T_start + T_rollout >= out.shape[0]):
    raise IndexError(f'T_start + T_rollout must stay below the {out.shape[0]} time points in out')


n_steps, lr, weight_decay = 1000, 1.0, 0.0  # weight_decay is not used by L-BFGS

loss_vals_LBFGS_chunks = np.zeros(n_steps)
t0 = time.time()
time_vals_LBFGS_chunks = np.zeros(n_steps)  # seconds since t0, recorded after each step

n_chunks = 10
# Chunk lengths. `dtype=np.int` no longer exists in NumPy >= 1.24, so use int.
# Any remainder goes to the first chunks so the lengths always add up to T_rollout.
T_rollout_i = np.full(n_chunks, T_rollout // n_chunks, dtype=int)
T_rollout_i[:T_rollout % n_chunks] += 1

x_inits = np.zeros((n_chunks, N, K*(J+1)))  # initial guess used for each chunk
x_init = out[T_start+T_rollout].copy()

targets = np.zeros((n_chunks, N, K*(J+1)))  # end state each chunk is fitted to
targets[0] = out[T_start+T_rollout]

x_sols = np.zeros_like(x_inits)  # solution of each chunk

i_ = 0
for j in range(n_chunks):

    roller_outer_LBFGS_chunks = Rollout(model_forward, prediction_task='state', K=K, J=J, N=N, x_init=x_init)
    x_inits[j] = roller_outer_LBFGS_chunks.X.detach().cpu().numpy().copy()
    optimizer = torch.optim.LBFGS(params=roller_outer_LBFGS_chunks.parameters(), 
                                  lr=lr, 
                                  max_iter=20, 
                                  max_eval=None, 
                                  tolerance_grad=1e-07, 
                                  tolerance_change=1e-09, 
                                  history_size=50, 
                                  line_search_fn='strong_wolfe')
    target = torch.as_tensor(targets[j], dtype=dtype, device=device)
    chunk_len = T_rollout_i[j]
    roller_outer_LBFGS_chunks.train()

    def closure():
        loss = ((roller_outer_LBFGS_chunks.forward(T=chunk_len) - target)**2).mean()
        optimizer.zero_grad()
        loss.backward()
        return loss

    for i in range(n_steps//n_chunks):
        # LBFGS.step returns the loss of its first closure evaluation, i.e. the loss
        # before this step; no separate forward pass is needed to record it.
        loss = optimizer.step(closure)
        loss_vals_LBFGS_chunks[i_] = loss.item()
        time_vals_LBFGS_chunks[i_] = time.time() - t0
        print((time_vals_LBFGS_chunks[i_], loss_vals_LBFGS_chunks[i_]))
        i_ += 1

    # this chunk's estimate seeds (and is the target of) the next, earlier chunk
    x_init = roller_outer_LBFGS_chunks.X.detach().cpu().numpy().copy()
    x_sols[j] = x_init
    if j < n_chunks - 1:
        targets[j+1] = x_init

plt.figure(figsize=(8,2))
plt.semilogy(loss_vals_LBFGS_chunks, label='initialization')
plt.title('rollout final state loss across gradient descent steps')
plt.ylabel('MSE')
plt.xlabel('gradient step')
plt.show()

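## L-BFGS, fit the final state directly (no intermediate targets) over growing horizons up to the full rollout time, each from the naive initialization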

In [None]:
import time
from L96_emulator.eval import Rollout

# Fit the observed final states directly with L-BFGS over growing horizons
# (j+1)*T_rollout//n_chunks, j = 0..n_chunks-1. Every horizon restarts from the same
# naive guess x_init (the observed final states).
T_start = np.array([5000, 10000, 150000])
T_rollout, N = 100, len(T_start)
if np.any(T_start + T_rollout >= out.shape[0]):
    raise IndexError(f'T_start + T_rollout must stay below the {out.shape[0]} time points in out')


n_steps, lr, weight_decay = 1000, 1.0, 0.0  # weight_decay is not used by L-BFGS

loss_vals_LBFGS_chunks = np.zeros(n_steps)
t0 = time.time()
time_vals_LBFGS_chunks = np.zeros(n_steps)  # seconds since t0, recorded after each step
x_init = out[T_start+T_rollout].copy()
# the observed final states are the target for every horizon
target = torch.as_tensor(out[T_start+T_rollout], dtype=dtype, device=device)

n_chunks = 10
i_ = 0
for j in range(n_chunks):
    horizon = (j+1)*T_rollout//n_chunks

    roller_outer_LBFGS_chunks = Rollout(model_forward, prediction_task='state', K=K, J=J, N=N, x_init=x_init)
    optimizer = torch.optim.LBFGS(params=roller_outer_LBFGS_chunks.parameters(), 
                                  lr=lr, 
                                  max_iter=20, 
                                  max_eval=None, 
                                  tolerance_grad=1e-07, 
                                  tolerance_change=1e-09, 
                                  history_size=50, 
                                  line_search_fn='strong_wolfe')
    roller_outer_LBFGS_chunks.train()

    def closure():
        loss = ((roller_outer_LBFGS_chunks.forward(T=horizon) - target)**2).mean()
        optimizer.zero_grad()
        loss.backward()
        return loss

    for i in range(n_steps//n_chunks):
        # LBFGS.step returns the loss of its first closure evaluation, i.e. the loss
        # before this step; no separate forward pass is needed to record it.
        loss = optimizer.step(closure)
        loss_vals_LBFGS_chunks[i_] = loss.item()
        time_vals_LBFGS_chunks[i_] = time.time() - t0
        print((time_vals_LBFGS_chunks[i_], loss_vals_LBFGS_chunks[i_]))
        i_ += 1

plt.figure(figsize=(8,2))
plt.semilogy(loss_vals_LBFGS_chunks, label='initialization')
plt.title('rollout final state loss across gradient descent steps')
plt.ylabel('MSE')
plt.xlabel('gradient step')
plt.show()

## L-BFGS, solve across full rollout time in one go, initialize from chunked approach with 10-step delay

In [None]:
import time
from L96_emulator.eval import Rollout

# Fit the observed final states directly with L-BFGS over growing horizons
# (j+1)*T_rollout//n_chunks. Horizon j starts from x_inits[j], the initial guess the
# chunked L-BFGS cell used for its chunk j.
T_start = np.array([5000, 10000, 150000])
T_rollout, N = 100, len(T_start)
if np.any(T_start + T_rollout >= out.shape[0]):
    raise IndexError(f'T_start + T_rollout must stay below the {out.shape[0]} time points in out')


n_steps, lr, weight_decay = 1000, 1.0, 0.0  # weight_decay is not used by L-BFGS

loss_vals_LBFGS_chunks = np.zeros(n_steps)
t0 = time.time()
time_vals_LBFGS_chunks = np.zeros(n_steps)  # seconds since t0, recorded after each step
# the observed final states are the target for every horizon
target = torch.as_tensor(out[T_start+T_rollout], dtype=dtype, device=device)

n_chunks = 10
i_ = 0
for j in range(n_chunks):
    horizon = (j+1)*T_rollout//n_chunks

    x_init = x_inits[j]
    roller_outer_LBFGS_chunks = Rollout(model_forward, prediction_task='state', K=K, J=J, N=N, x_init=x_init)
    optimizer = torch.optim.LBFGS(params=roller_outer_LBFGS_chunks.parameters(), 
                                  lr=lr, 
                                  max_iter=20, 
                                  max_eval=None, 
                                  tolerance_grad=1e-07, 
                                  tolerance_change=1e-09, 
                                  history_size=50, 
                                  line_search_fn='strong_wolfe')
    roller_outer_LBFGS_chunks.train()

    def closure():
        loss = ((roller_outer_LBFGS_chunks.forward(T=horizon) - target)**2).mean()
        optimizer.zero_grad()
        loss.backward()
        return loss

    for i in range(n_steps//n_chunks):
        # LBFGS.step returns the loss of its first closure evaluation, i.e. the loss
        # before this step; no separate forward pass is needed to record it.
        loss = optimizer.step(closure)
        loss_vals_LBFGS_chunks[i_] = loss.item()
        time_vals_LBFGS_chunks[i_] = time.time() - t0
        print((time_vals_LBFGS_chunks[i_], loss_vals_LBFGS_chunks[i_]))
        i_ += 1

plt.figure(figsize=(8,2))
plt.semilogy(loss_vals_LBFGS_chunks, label='initialization')
plt.title('rollout final state loss across gradient descent steps')
plt.ylabel('MSE')
plt.xlabel('gradient step')
plt.show()

## L-BFGS, solve across full rollout time in one go, initialize from chunked approach

In [None]:
import time
from L96_emulator.eval import Rollout

# Fit the observed final states directly with L-BFGS over growing horizons
# (j+1)*T_rollout//n_chunks. Horizon j starts from x_sols[j], the solution the
# chunked L-BFGS cell found for its chunk j.
T_start = np.array([5000, 10000, 150000])
T_rollout, N = 100, len(T_start)
if np.any(T_start + T_rollout >= out.shape[0]):
    raise IndexError(f'T_start + T_rollout must stay below the {out.shape[0]} time points in out')


n_steps, lr, weight_decay = 1000, 1.0, 0.0  # weight_decay is not used by L-BFGS

loss_vals_LBFGS_chunks = np.zeros(n_steps)
t0 = time.time()
time_vals_LBFGS_chunks = np.zeros(n_steps)  # seconds since t0, recorded after each step
# the observed final states are the target for every horizon
target = torch.as_tensor(out[T_start+T_rollout], dtype=dtype, device=device)

n_chunks = 10
i_ = 0
for j in range(n_chunks):
    horizon = (j+1)*T_rollout//n_chunks

    x_init = x_sols[j]
    roller_outer_LBFGS_chunks = Rollout(model_forward, prediction_task='state', K=K, J=J, N=N, x_init=x_init)
    optimizer = torch.optim.LBFGS(params=roller_outer_LBFGS_chunks.parameters(), 
                                  lr=lr, 
                                  max_iter=20, 
                                  max_eval=None, 
                                  tolerance_grad=1e-07, 
                                  tolerance_change=1e-09, 
                                  history_size=50, 
                                  line_search_fn='strong_wolfe')
    roller_outer_LBFGS_chunks.train()

    def closure():
        loss = ((roller_outer_LBFGS_chunks.forward(T=horizon) - target)**2).mean()
        optimizer.zero_grad()
        loss.backward()
        return loss

    for i in range(n_steps//n_chunks):
        # LBFGS.step returns the loss of its first closure evaluation, i.e. the loss
        # before this step; no separate forward pass is needed to record it.
        loss = optimizer.step(closure)
        loss_vals_LBFGS_chunks[i_] = loss.item()
        time_vals_LBFGS_chunks[i_] = time.time() - t0
        print((time_vals_LBFGS_chunks[i_], loss_vals_LBFGS_chunks[i_]))
        i_ += 1

plt.figure(figsize=(8,2))
plt.semilogy(loss_vals_LBFGS_chunks, label='initialization')
plt.title('rollout final state loss across gradient descent steps')
plt.ylabel('MSE')
plt.xlabel('gradient step')
plt.show()

In [None]:
# Final-state MSE reached from each chunked L-BFGS solution x_sols[j] when it is
# rolled out over the matching horizon (j+1)*T_rollout//n_chunks.
MSEs_chunks = np.zeros(n_chunks)
# TODO: not filled yet -- presumably meant for the direct solves initialised from the
# chunked solutions and from the previous chunk's initial guess.
MSEs_direct__init_chunks = np.zeros(n_chunks)
MSEs_direct__init_prev = np.zeros(n_chunks)

target = torch.as_tensor(out[T_start+T_rollout], dtype=dtype, device=device)
for j in range(n_chunks):

    roller_outer_LBFGS_chunks = Rollout(model_forward, prediction_task='state', K=K, J=J, N=N, x_init=x_sols[j])
    with torch.no_grad():  # evaluation only; no autograd graph needed
        final_state = roller_outer_LBFGS_chunks.forward(T=(j+1)*T_rollout//n_chunks)
    MSEs_chunks[j] = ((final_state - target)**2).mean().item()

In [None]:
# Compare single-rollout Adam, chunked Adam and chunked L-BFGS estimates: initial
# states (top row) and the final states their T_rollout-step rollouts reach (bottom
# row). roller_outer_LBFGS_chunks is whichever L-BFGS rollout object was built last.
# Each estimate is converted / rolled out once instead of once per subplot.
X_adam = roller_outer_ADAM.X.detach().cpu().numpy().copy()
X_adam_chunks = roller_outer_test.X.detach().cpu().numpy().copy()
X_lbfgs_chunks = roller_outer_LBFGS_chunks.X.detach().cpu().numpy().copy()
final_adam = roller_outer_ADAM.forward(T=T_rollout).detach().cpu().numpy().copy()
final_adam_chunks = roller_outer_test.forward(T=T_rollout).detach().cpu().numpy().copy()
final_lbfgs_chunks = roller_outer_LBFGS_chunks.forward(T=T_rollout).detach().cpu().numpy().copy()

plt.figure(figsize=(16,16))
for i in range(N):
    plt.subplot(2,N,i+1)
    plt.plot(X_adam[i], label='one go')
    plt.plot(X_adam_chunks[i], '--', label='in 10 chunks')
    plt.plot(X_lbfgs_chunks[i], label='in 10 chunks, L-BFGS')
    plt.legend()

    plt.subplot(2,N,N+i+1)
    plt.plot(final_adam[i], label='one go')
    plt.plot(final_adam_chunks[i], '--', label='in 10 chunks')
    plt.plot(final_lbfgs_chunks[i], '--', label='in 10 chunks, L-BFGS')
    plt.legend()

plt.show()

## more plotting

In [None]:
# Detailed look at one inverse-problem estimate for the first N_max samples:
# estimated vs. true initial states, their errors, and the resulting final-state errors.
# NOTE(review): this cell previously used `roller_outer`, which only the disabled
# string-literal cell defines (NameError). It now inspects the single-rollout Adam
# estimate; point roller_est / x_init_est at another run to inspect that one instead.
roller_est = roller_outer_ADAM
x_init_est = out[T_start+T_rollout]  # the initial guess roller_outer_ADAM started from

N_max = min(2, N)  # choose N_max << N if N is very large and you don't want hundreds of subplots
n_rows = (N_max + 1) // 2  # == ceil(N_max/2); subplot() needs an integer row count
X_est = roller_est.X.detach().cpu().numpy()

plt.figure(figsize=(16,2*N_max))
for n in range(N_max):
    plt.subplot(n_rows,2,n+1)
    plt.plot(x_init_est[n], 'k', label='init', alpha=0.2)
    # the estimate is roller_est.X and the truth is out[T_start]; labels were swapped
    plt.plot(X_est[n,:], color='orange', linewidth=1.5, label='est.')
    plt.plot(out[T_start[n]].T, 'b--', linewidth=0.5, label='target')
    plt.xlabel('state dimension')
    plt.ylabel(f'initial state at T = {T_start[n]}')
    plt.legend()
plt.suptitle('estimated initial state')
plt.show()

plt.figure(figsize=(16,2*N_max))
for n in range(N_max):
    plt.subplot(n_rows,2,n+1)
    # labels now describe the plotted differences (true minus estimate, etc.)
    plt.plot(out[T_start[n]].flatten() - X_est[n,:], 'orange', 
             label='true - est. initial state')
    plt.plot(out[T_start[n]].flatten() - out[T_start[n]+T_rollout].flatten(), 'k', alpha=0.3, 
             label='initial - future state')
    plt.legend()
    plt.xlabel('state dimension')
    plt.ylabel(f'initial state error at T = {T_start[n]}')
plt.suptitle('error of estimated initial state')
plt.show()

plt.figure(figsize=(16,2*N_max))
roller_outer2 = Rollout(model_forward, prediction_task='state', K=K, J=J, N=N, x_init=out[T_start])
# roll out once for all samples instead of once per subplot
final_from_est = roller_est.forward(T=T_rollout).detach().cpu().numpy()
final_from_true = roller_outer2.forward(T=T_rollout).detach().cpu().numpy()
for n in range(N_max):
    plt.subplot(n_rows,2,n+1)
    plt.plot(out[T_start[n]+T_rollout].flatten() - final_from_est[n], 
             'orange', label='true final state - rollout from est. init. state')
    plt.plot(out[T_start[n]+T_rollout].flatten() - final_from_true[n], 
             'k', alpha=0.3, label='true final state - rollout from true init. state')
    plt.legend()
    plt.xlabel('state dimension')
    plt.ylabel(f'final state error at T = {T_start[n]}+{T_rollout}')
plt.suptitle('error of estimated final state (under the learned model)')
plt.show()