In [None]:
#
# Code by Eric Cyr adapted from Ravi Patel's code
#
# Based on the MOR-Physics paper: https://doi.org/10.1016/j.cma.2020.113500
# also based on : https://proceedings.mlr.press/v190/patel22a/patel22a.pdf
#
# pip install numpy scipy matplotlib torch jupyter 

import pickle
import math
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.func as func
import matplotlib.pyplot as plt

def generate_soln_pairs(steps,trajectory,interval=1,rng=None):
  """
  Extract solution pairs (in->out) from a batch of trajectories.

    steps : number of time indices between each input and its target (>= 0)
    trajectory : array indexed as [trajectory, time, ...]
    interval : stride between sampled input times; controls the size of the data set
    rng : optional numpy Generator used to shuffle the pairs; when None the
          global numpy RNG (np.random.permutation) is used

  Returns (inputs, targets) with the trajectory and time axes flattened into a
  single, shuffled sample axis; targets[i] is the state `steps` time indices
  after inputs[i] in the same trajectory.
  """
  num_times = trajectory.shape[1]
  # use an explicit stop index: the `:-steps` idiom yields an empty slice when
  # steps == 0; clamp at 0 so steps >= num_times still yields no pairs
  inputs = trajectory[:,:max(num_times-steps,0):interval]
  targets = trajectory[:,steps::interval]

  collapsed_shape = (inputs.shape[0]*inputs.shape[1],) + tuple(inputs.shape[2:])
  inputs = inputs.reshape(collapsed_shape)
  targets = targets.reshape(collapsed_shape)

  # the same permutation is applied to both arrays so the pairs stay aligned
  permutation = np.random.permutation if rng is None else rng.permutation
  permute = permutation(inputs.shape[0])
  return inputs[permute],targets[permute]

# Seed every RNG the notebook uses so Restart & Run All reproduces the same
# data shuffles, weight initialization, trained surrogate and figures.
SEED = 0
np.random.seed(SEED)     # global numpy RNG (e.g. np.random.permutation)
torch.manual_seed(SEED)  # torch RNG (e.g. nn.Linear initialization)

# load up sample data, generating it on first run
filename = 'ks_samples.pkl'
if not os.path.exists(filename):
  import ks_solver   # local module; assumed to write `filename` -- TODO confirm
  ks_solver.main()

# NOTE: pickle.load can execute arbitrary code; only load files you trust
# (here, the locally generated solver output).
with open(filename, 'rb') as file:
  x,ks_sims = pickle.load(file)
X = torch.tensor(np.expand_dims(x,-1))

# plot an example
initial_step = 0
num_steps = 10000
stride = 10

spec_soln = ks_sims[0,initial_step:(initial_step+num_steps):stride].T.squeeze()
spec_soln = torch.tensor(spec_soln).to(torch.float32)

print(spec_soln.shape,x.shape)
plt.imshow(spec_soln.detach().numpy(),aspect=1.0)
plt.xlabel('Time')
plt.ylabel('Space')
plt.title('Kuramoto–Sivashinsky Solution')
plt.xticks([])
plt.yticks([])
plt.savefig('example-ks.png',dpi=200,bbox_inches='tight')

In [None]:
class MLP(nn.Module):
  """
  Fully connected network: Linear -> ELU -> ... -> ELU -> Linear.

    features : layer widths [in, hidden..., out]; needs at least two entries.

  Hidden Linear layers have their biases set to zero; all weights, and the
  output layer's bias, keep PyTorch's default nn.Linear initialization.
  No activation is applied after the output layer.
  """
  def __init__(self,features):
    super().__init__()
    assert len(features)>1

    stack = []
    for n_in,n_out in zip(features[:-2],features[1:-1]):
      hidden = nn.Linear(n_in,n_out)
      torch.nn.init.constant_(hidden.bias,0.0)
      stack += [hidden,nn.ELU()]
    stack.append(nn.Linear(*features[-2:]))
    self.mlp = nn.Sequential(*stack)

  def forward(self,x):
    # the layers act on the last dimension of x
    return self.mlp(x)

class MORPhysics_1d(nn.Module):
  """
  A 1D MOR-Physics network as developed in our paper. The network can be
  trained to learn an operator N mapping an input function u, evaluated at
  the points 'x' of the physical domain, to an output function f evaluated
  at the same points:

    N(u(x)) -> f(x)

  The architecture sandwiches a learned spectral multiplier between a
  forward and an inverse Fourier transform. With 'F' the Fourier transform:

    N(u) = inv(F) * g(K) * F * h(u)

  The network 'h' acts pointwise (on the channel dimension) in physical space
  and can inject a nonlinearity. The network 'g' acts in frequency space on
  the wavenumbers 'K' corresponding to the grid 'x', producing a complex
  (features[-1] x dim_out) multiplier for each mode.

    dim_in   : number of input channels (last dimension of u)
    dim_out  : number of output channels (last dimension of N(u))
    x        : 1D array of grid points, assumed uniformly spaced and periodic
    features : list of widths; h is MLP([dim_in]+features), so features[-1]
               is the number of lifted channels, and g reuses features[:-1]
               as its hidden widths
  """

  def __init__(self,dim_in,dim_out,x,features):
    super().__init__()

    self.features = features
    self.dim_in = dim_in
    self.dim_out = dim_out

    space_dim = len(x.shape)
    assert space_dim==1

    # setup the model evaluation points
    self.X = np.expand_dims(x,-1)
    # K is a non-persistent buffer: it follows the model through .to(device)
    # and dtype casts, but stays out of the state_dict (as before)
    self.register_buffer('K',self.compute_fourier_frequencies(self.X),persistent=False)

    self.h_model = MLP([dim_in]+features)
    self.g_model = MLP([space_dim]+features[0:-1]+[2*features[-1]*dim_out])

  def to_complex(self,x):
    """
    Interpret the last dimension of x as [real parts | imaginary parts] and
    return the complex tensor reshaped to (..., features[-1], dim_out).
    """
    real,imag = torch.split(x,x.shape[-1]//2,dim=-1)
    y = torch.complex(real,imag)
    return y.reshape(*y.shape[:-1],self.features[-1],self.dim_out)

  def compute_fourier_frequencies(self,X):
    """
    Wavenumbers for a grid X of shape (n_1, ..., n_d, d).

    All but the last axis use the full (fftfreq) ordering and the last axis
    the one-sided (rfftfreq) ordering, matching the layout of a real FFT.
    Returns a float32 tensor of shape (n_1, ..., n_{d-1}, n_d//2+1, d).
    """
    space_dim = len(X.shape)-1
    seg1 = tuple(space_dim*[1])
    seg0 = tuple(space_dim*[0])
    dx = X[seg1]-X[seg0]   # per-axis grid spacing; assumes a uniform grid
    n = np.array(X.shape[0:-1])
    L = dx*n   # domain size per axis, assuming a periodic grid

    # NOTE(review): 2*pi*fftfreq(n)/L equals (2*pi*k/L)/n, i.e. the physical
    # wavenumbers scaled by 1/n. g only consumes these as inputs, so this acts
    # as an input scaling; confirm it is intended before changing it (a
    # trained g_model depends on it).
    ks = [2*np.pi*np.fft.fftfreq(ni)/Li for ni,Li in zip(n[0:-1],L[0:-1])] \
       + [2*np.pi*np.fft.rfftfreq(n[-1])/L[-1]]
    # convert axis by axis: torch.tensor(list_of_arrays) is slow and fails
    # when the axes have different numbers of modes
    ks = [torch.as_tensor(k) for k in ks]

    K = torch.stack(torch.meshgrid(*ks,indexing='ij'),-1).to(torch.float32)
    return K

  def forward(self,u):
    """
    Evaluate N(u).

      u : tensor of shape (batch, n_points, dim_in); n_points is expected to
          match len(x), since K holds one row per rfft mode of the x grid
      returns : real tensor of shape (batch, n_points, dim_out)
    """
    n_points = u.shape[-2]
    h = self.h_model(u)                        # (batch, n_points, F)
    g = self.to_complex(self.g_model(self.K))  # (n_modes, F, dim_out)
    h_freq = torch.fft.rfft(h,dim=-2)          # (batch, n_modes, F)
    # mix the F lifted channels mode by mode:
    #   Nu_freq[b,k,o] = sum_c g[k,c,o] * h_freq[b,k,c]
    Nu_freq = torch.einsum('kco,bkc->bko',g,h_freq)
    # give irfft the spatial length explicitly; by default it returns
    # 2*(n_modes-1) points, which drops one point when n_points is odd
    return torch.fft.irfft(Nu_freq,n=n_points,dim=-2)

In [None]:
def loss_fn(model,x,y,dt,steps):
  """
  Multi-step rollout loss.

  Advances x by `steps` explicit-Euler updates, x <- x + dt*model(x), and
  returns the mean squared error between the resulting state and y.
  """
  state = x
  for _ in range(steps):
    state = state + dt*model(state)
  return (state - y).pow(2).mean()

def train_step(model, optimizer, x, y,dt,steps):
  """
  Perform a single optimization step on one mini-batch.

  Clears stale gradients, evaluates loss_fn(model, x, y, dt, steps),
  back-propagates, and lets the optimizer update the parameters in place.
  Returns the loss tensor computed before the update.
  """
  optimizer.zero_grad()
  batch_loss = loss_fn(model, x, y, dt, steps)
  batch_loss.backward()   # fill each parameter's .grad
  optimizer.step()        # in-place parameter update from those gradients
  return batch_loss

def get_batch(u,v,bs):
  """
  Yield consecutive mini-batches (u_batch, v_batch) as float32 tensors.

    u, v : arrays with matching leading (sample) dimension
    bs : batch size; the final batch holds the remainder and may be smaller

  Both inputs and targets are cast to float32, so the loss is computed in the
  model's precision instead of silently upcasting to the targets' dtype.
  """
  for start in range(0,u.shape[0],bs):
    yield (torch.tensor(u[start:start+bs]).to(torch.float32),
           torch.tensor(v[start:start+bs]).to(torch.float32))

def train(model,u_train,v_train,bs,epochs,dt,steps,lrs=(1e-3,)):
  """
  Train `model` with Adam on (u_train -> v_train) pairs.

    model : network whose parameters are optimized
    u_train, v_train : input/target samples with matching first dimension
    bs : mini-batch size; the trailing partial batch is skipped each epoch
    epochs : number of passes over the data per learning-rate stage
    dt, steps : Euler step size and step count forwarded to train_step
    lrs : learning rates applied in sequence, each for `epochs` epochs; the
          Adam state is carried across stages
  """
  optimizer = optim.Adam(model.parameters(),lr=1e-3)
  batches_per_epoch = len(u_train)//bs  # floor division: drop the partial batch

  for lr in lrs:
    # the optimizer is built once, so each stage must set its rate explicitly
    for group in optimizer.param_groups:
      group['lr'] = lr
    for epoch in range(epochs):
      data = get_batch(u_train,v_train,bs)
      for b in range(batches_per_epoch):
        u,v = next(data)
        loss_value = train_step(model,optimizer,u,v,dt,steps)
        if b % 100==0:
          print(f'epoch: {epoch:4d}, learning rate: {lr}, iter: {b}, loss: {float(loss_value):4e}')

# define the architecture
#########################################
dim_in = 1   # channels of the input field u
dim_out = 1  # channels of the output field N(u)
depth = 4
width = 8
features = [width for _ in range(depth)]   # `depth` layers, each `width` wide

morp = MORPhysics_1d(dim_in,dim_out,x,features)

# train the architecture
#########################################
bs = 32
init_skips = 1
# 0.01 is presumably the time between stored snapshots in ks_sims -- confirm
# against ks_solver
init_dt = init_skips*0.01

# Curriculum: train the model to predict 1, 2, 3 and then 4 stored time
# indices ahead by unrolling that many Euler steps inside the loss. Fitting
# multi-step rollouts implicitly stabilizes the surrogate so that it can
# evolve the entire time domain; without this progressive training a stable
# surrogate was not found.
num_train_trajs = 8
dt = init_dt
for steps in range(1,5):
  print(f'starting {steps}')
  skips = init_skips*steps
  u_train,v_train = generate_soln_pairs(skips,ks_sims[:num_train_trajs],interval=4)
  train(morp,u_train,v_train,bs,10,dt,steps)

In [None]:
def eval_model(u0,steps,model,dt,track_grad=False):
  """
  Roll out `model` with explicit Euler steps from the initial state u0.

    u0 : initial state, presumably 1D of shape (n_points,)
    steps : number of Euler steps to take
    model : callable mapping a (1, n_points, 1) state to its time derivative
    dt : step size
    track_grad : if False (default) the rollout runs under torch.no_grad(),
                 so a long rollout does not retain its autograd graph; if
                 True the ambient grad mode is left unchanged

  Returns a tensor of shape (n_points, steps+1) whose column i is the state
  after i steps.
  """
  import contextlib
  grad_mode = contextlib.nullcontext() if track_grad else torch.no_grad()
  with grad_mode:
    u = u0.unsqueeze(-1).unsqueeze(0)   # (n_points,) -> (1, n_points, 1)
    results = [u.squeeze()]
    for _ in range(steps):
      u = u+dt*model(u)   # explicit Euler step
      results.append(u.squeeze())
    return torch.stack(results).T

dt = init_dt

initial_step = 0
num_steps = 10000
stride = 10

# Trajectories ks_sims[:num_train_trajs] were used for training, so indices
# below num_train_trajs are in-sample and the rest are held out.
fig,axs = plt.subplots(4,2,figsize=(10,8))
for i,index in enumerate([3,7,8,9]):
  spec_soln = ks_sims[index,initial_step:(initial_step+num_steps):stride].T.squeeze()
  spec_soln = torch.tensor(spec_soln).to(torch.float32)
  split = 'train' if index < num_train_trajs else 'test'

  # compute NN prediction. In training, `steps` Euler steps of size dt span
  # steps*init_skips stored time indices; with init_skips == 1 a rollout of
  # (n_frames-1)*stride steps ends exactly on the last reference frame, so
  # subsampling by `stride` yields one model frame per reference frame.
  n_frames = spec_soln.shape[1]
  with torch.no_grad():   # inference only: don't keep the rollout's autograd graph
    trajectory = eval_model(spec_soln[:,0],(n_frames-1)*stride,morp,dt)
  morp_soln = trajectory[:,::stride]

  # one shared color scale so amplitude errors are visible in the comparison
  vmin,vmax = spec_soln.min().item(),spec_soln.max().item()

  axs[i,0].imshow(morp_soln.detach().numpy(),aspect='auto',vmin=vmin,vmax=vmax)
  axs[i,0].set_title(f'MOR-Physics {index} ({split})')
  axs[i,0].set_xticks([])
  axs[i,0].set_yticks([])

  axs[i,1].imshow(spec_soln.detach().numpy(),aspect='auto',vmin=vmin,vmax=vmax)
  axs[i,1].set_title(f'Reference {index}')
  axs[i,1].set_xticks([])
  axs[i,1].set_yticks([])

plt.subplots_adjust(hspace=0.5)
plt.savefig('ks-morp-compare.png',dpi=200,bbox_inches='tight')