In [None]:
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [None]:
from colabtools import adhoc_import
import xarray
import functools
import jax
#jax.config.update('jax_platform_name', 'cpu')

import matplotlib.pyplot as plt
import jax_cfd.base as cfd
import jax_cfd.spectral.utils as spectral_utils
import jax_cfd.spectral.equations as spectral_equations
import jax_cfd.spectral.time_stepping as spectral_stepping
from jax_cfd.base import grids

#   import jax_cfd.ml as ml

#   equations = adhoc_import.Reload(equations, reset_flags=True)
#   utils = adhoc_import.Reload(utils, reset_flags=True)

import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint
from jax import grad, jit, vmap, jacfwd, jvp, vjp
from jax import random
from tqdm.auto import tqdm
from scipy.ndimage import correlate1d
from jax.scipy.signal import correlate
import numpy as np

In [None]:
 def truncated_rfft(u):
  """Real FFT of `u` with the top third of the modes discarded (2/3 rule).

  Args:
    u: real-space samples of the signal, as a 1-D array.

  Returns:
    The lowest ceil(2/3 * M) of the M rfft coefficients of `u`, multiplied
    by 2/3 (the ratio of coarse to fine grid sizes).
  """
  spectrum = jnp.fft.rfft(u)
  (num_modes,) = spectrum.shape
  keep = int(np.ceil(2 / 3 * num_modes))
  return spectrum[:keep] * (2 / 3)


def padded_irfft(uhat):
  """Inverse real FFT after zero-padding the spectrum to 3/2 as many modes.

  Args:
    uhat: 1-D array of n rfft coefficients.

  Returns:
    `jnp.fft.irfft` of `uhat` zero-padded to floor(3n/2) coefficients,
    multiplied by 3/2 so the fine-grid samples keep the coarse amplitude.
  """
  (num_modes,) = uhat.shape
  padded_len = num_modes + num_modes // 2  # == floor(3/2 * num_modes)
  padded = jnp.pad(uhat, [(0, padded_len - num_modes)])
  assert padded.shape == (padded_len,), "incorrect padded shape"
  return (3/2) * jnp.fft.irfft(padded)

In [None]:
import dataclasses
#from jax_cfd.base import boundaries
@dataclasses.dataclass
class NonlinearSchrodinger(spectral_stepping.ImplicitExplicitODE):
  """Nonlinear Schrodinger equation split in implicit and explicit parts.

  The right-hand side integrated by this class is

    psi_t = -i psi_xx/8 - i|psi|^2 psi/2

  The `-psi_x/2` advection term of the full NLS is NOT applied by either
  `explicit_terms` or `implicit_terms`; `advection_term` is precomputed but
  unused by this class.

  State layout: the complex field psi = u + i v is stored as the
  concatenation [rfft(u), rfft(v)] of the real FFTs of its real and imaginary
  parts, so multiplication by `i` is the block rotation in `mul_i`.

  Attributes:
    grid: underlying grid of the process
    smooth: dealias the non-linear term by evaluating it on a 3/2-times finer
      grid (`padded_irfft` / `truncated_rfft`)
  """
  grid: grids.Grid
  smooth: bool = True

  def __post_init__(self):
    self.kx, = self.grid.rfft_axes()
    self.two_pi_i_k = 2j * jnp.pi * self.kx
    # Fourier symbol of -d^2/dx^2 / 8 ((2 pi i k)^2 is the symbol of d^2/dx^2),
    # duplicated so it acts on both the real and imaginary blocks.
    diffusive_term = -self.two_pi_i_k**2 / 8
    self.diffusive_term = jnp.concatenate([diffusive_term, diffusive_term])
    # Symbol of -d/dx / 2 on the stacked state; not used by this class.
    self.advection_term = -jnp.concatenate(
        [self.two_pi_i_k, self.two_pi_i_k]) / 2
    self.rfft = truncated_rfft if self.smooth else jnp.fft.rfft
    self.irfft = padded_irfft if self.smooth else jnp.fft.irfft

  def mul_i(self, psi):
    """Multiplies the stacked state [re, im] by i, giving [-im, re]."""
    N = len(psi) // 2
    real, imag = psi[:N], psi[N:]
    return jnp.concatenate([-imag, real])

  def explicit_terms(self, psihat):
    """Non-linear part of the equation, `-i|psi|^2 psi/2`, in stacked form."""
    N = len(psihat) // 2
    uhat, vhat = psihat[:N], psihat[N:]
    u = self.irfft(uhat)
    v = self.irfft(vhat)
    psi_squared = u**2 + v**2  # |psi|^2 in real space
    cubic_real = self.rfft(psi_squared * u)
    cubic_imag = self.rfft(psi_squared * v)
    ipsi_cubed_hat = self.mul_i(jnp.concatenate([cubic_real, cubic_imag]))
    return -ipsi_cubed_hat / 2

  def implicit_terms(self, psihat):
    """Linear part of the equation, `-i psi_xx/8`, in stacked form."""
    return self.diffusive_term * self.mul_i(psihat)

  def implicit_solve(self, psihat, time_step):
    """Solves `(1 - time_step * D i) x = psihat` for x.

    D is `diffusive_term` and `i` is the `mul_i` rotation. Because that
    rotation squares to -1,
      (1 - dt D i)^-1 = (1 + dt D i) / (1 + dt^2 D^2),
    so no linear system has to be solved.
    """
    scaled = time_step * self.diffusive_term
    return (psihat + scaled * self.mul_i(psihat)) / (1 + scaled**2)

@dataclasses.dataclass
class ModifiedNonlinearSchrodinger(spectral_stepping.ImplicitExplicitODE):
  """Modified nonlinear Schrodinger equation split in implicit and explicit parts.

  The right-hand side integrated by this class is

    psi_t = -i psi_xx/8                                   (implicit)
            + psi_xxx/16 - i|psi|^2 psi/2
            - (3/2)|psi|^2 psi_x - (1/4) psi^2 conj(psi)_x
            - i psi phi_x                                 (explicit)

  with the induced mean-flow term F[phi_x] = -|kappa| F[|psi|^2] / 2, where
  kappa = 2 pi kx is the angular wavenumber. The higher-order terms follow
  the form of Dysthe's modified NLS. The `-psi_x/2` advection term is not
  applied. The state is [rfft(Re psi), rfft(Im psi)], as in
  `NonlinearSchrodinger`.

  Attributes:
    grid: underlying grid of the process
    smooth: dealias the non-linear terms with the 3/2 padding rule
  """
  grid: grids.Grid
  smooth: bool = True

  def __post_init__(self):
    self.kx, = self.grid.rfft_axes()
    self.two_pi_i_k = 2j * jnp.pi * self.kx
    # Derivative symbol duplicated to act on both real and imaginary blocks.
    self.doubled = jnp.concatenate([self.two_pi_i_k, self.two_pi_i_k])
    implicit_term = -self.two_pi_i_k**2 / 8
    self.implicit_term = jnp.concatenate([implicit_term, implicit_term])
    # |kappa| with kappa = 2 pi kx: `kx` is an ordinary frequency (every
    # derivative here uses 2 pi i kx), so the mean-flow symbol needs 2 pi.
    self.abs_kappa = 2 * jnp.pi * jnp.abs(self.kx)
    self.rfft = truncated_rfft if self.smooth else jnp.fft.rfft
    self.irfft = padded_irfft if self.smooth else jnp.fft.irfft

  def mul_i(self, psi):
    """Multiplies the stacked state [re, im] by i, giving [-im, re]."""
    N = len(psi) // 2
    real, imag = psi[:N], psi[N:]
    return jnp.concatenate([-imag, real])

  def explicit_terms(self, psihat):
    """Dispersive and non-linear parts of the equation (see class docstring)."""
    N = len(psihat) // 2
    uhat, vhat = psihat[:N], psihat[N:]
    u = self.irfft(uhat)
    v = self.irfft(vhat)
    psi_squared = u**2 + v**2  # |psi|^2

    # -i|psi|^2 psi / 2, using -i(u + iv) = v - iu.
    cubic = jnp.concatenate(
        [self.rfft(psi_squared * v), self.rfft(-psi_squared * u)]) / 2
    # +psi_xxx / 16: (2 pi i k)^3 is the symbol of d^3/dx^3.
    dispersion = (psihat * self.doubled**3) / 16

    dx_u = self.irfft(uhat * self.two_pi_i_k)
    dx_v = self.irfft(vhat * self.two_pi_i_k)
    # -(3/2) |psi|^2 psi_x
    transport_a = jnp.concatenate([
        self.rfft(-(3 / 2) * psi_squared * dx_u),
        self.rfft(-(3 / 2) * psi_squared * dx_v)])
    # -(1/4) psi^2 conj(psi)_x with psi^2 = (u^2 - v^2) + 2i uv and
    # conj(psi)_x = u_x - i v_x. (Previously |psi|^2 was used instead of psi^2.)
    psi2_real = u**2 - v**2
    psi2_imag = 2 * u * v
    transport_b = jnp.concatenate([
        self.rfft(-(psi2_real * dx_u + psi2_imag * dx_v) / 4),
        self.rfft(-(psi2_imag * dx_u - psi2_real * dx_v) / 4)])
    # Mean flow: F[phi_x] = -|kappa| F[|psi|^2] / 2.
    dx_potential = self.irfft(-self.rfft(psi_squared) * self.abs_kappa / 2)
    # -i psi phi_x = phi_x v - i phi_x u
    potential_term = jnp.concatenate(
        [self.rfft(dx_potential * v), self.rfft(-dx_potential * u)])
    return cubic + dispersion + transport_a + transport_b + potential_term

  def implicit_terms(self, psihat):
    """Linear part of the equation, `-i psi_xx/8`, in stacked form."""
    return self.implicit_term * self.mul_i(psihat)

  def implicit_solve(self, psihat, time_step):
    """Solves `(1 - time_step * A i) x = psihat` with A = `implicit_term`.

    Uses (1 - dt A i)^-1 = (1 + dt A i) / (1 + dt^2 A^2), valid because the
    `mul_i` rotation squares to -1.
    """
    scaled = time_step * self.implicit_term
    return (psihat + scaled * self.mul_i(psihat)) / (1 + scaled**2)

In [None]:
# Negative half of the FFT frequencies for n=99 (indices ceil(99/2) onward).
jnp.fft.fftfreq(99)[-(-99 // 2):]

In [None]:
# All 99 FFT sample frequencies in standard FFT order.
jnp.fft.fftfreq(n=99)

In [None]:
# ceil(99 / 2) via integer arithmetic.
-(-99 // 2)

In [None]:
# Shape check of the 2x-truncation helper: 32 samples keep 16 modes.
# (Previously called with no argument, which raises TypeError. Note that
# `fft_truncated_2x` is defined in a later cell.)
fft_truncated_2x(jnp.ones(32)).shape

In [None]:
# Round trip: pad to ~2x modes, then truncate back; the mean deviation from
# the original spectrum should be ~0. Relies on `u0`, `ifft_padded_2x` and
# `fft_truncated_2x`, which are set in later cells.
z = jnp.fft.fft(u0)
z1 = ifft_padded_2x(z)      # real space on the finer grid
z2 = fft_truncated_2x(z1)   # back to the original modes
jnp.mean(jnp.abs(z2 - z))

In [None]:
 def fft_truncated_2x(u):
  """FFT of `u` keeping only the lowest (k+1)//2 modes (the 1/2 rule).

  Inverse of `ifft_padded_2x`. For an input of length k the output has
  m = (k+1)//2 modes in standard FFT order, rescaled by m/k so they use the
  normalization of an m-point grid. For even k this is the same /2 scaling
  as before; for odd k (the lengths `ifft_padded_2x` produces for odd n) the
  amplitude is now exact.

  Args:
    u: the (complex) 1-D input signal of length k.

  Returns:
    Downsampled version of `u` in fft-space, of length (k+1)//2.
  """
  uhat = jnp.fft.fftshift(jnp.fft.fft(u))
  k, = uhat.shape
  final_size = (k + 1) // 2
  # After fftshift the zero mode of a length-k spectrum is at index k//2.
  # Choose the window that puts it at index final_size//2, where ifftshift
  # of a length-final_size array expects it. (The old slice
  # [final_size//2 : (-final_size+1)//2] gave the wrong length whenever
  # k % 4 was 2 or 3.)
  start = k // 2 - final_size // 2
  kept = uhat[start:start + final_size]
  return jnp.fft.ifftshift(kept) * (final_size / k)


def ifft_padded_2x(uhat):
  """Inverse FFT after zero-padding `uhat` to roughly twice as many modes.

  For n input modes the output grid has m = n + 2*(n//2) points (2n for even
  n, 2n-1 for odd n). The result is scaled by m/n, so it samples the same
  signal as `jnp.fft.ifft(uhat)` on the finer grid. The old fixed factor of
  2 was only exact for even n.

  Args:
    uhat: 1-D FFT coefficients in standard (unshifted) order.

  Returns:
    Complex samples on the m-point grid.
  """
  n, = uhat.shape
  added = n // 2
  final_size = n + 2 * added
  # Pad symmetrically around the zero mode, which fftshift moves to the middle.
  smoothed = jnp.pad(jnp.fft.fftshift(uhat), (added, added))
  assert smoothed.shape == (final_size,), "incorrect padded shape"
  return (final_size / n) * jnp.fft.ifft(jnp.fft.ifftshift(smoothed))

@dataclasses.dataclass
class NLS(spectral_stepping.ImplicitExplicitODE):
  """Nonlinear Schrodinger equation on a complex FFT state.

  The right-hand side integrated by this class is

    psi_t = -i psi_xx/8 - i|psi|^2 psi/2

  The `-psi_x/2` advection term is not applied. The state is the full
  complex FFT `jnp.fft.fft(psi)`, so `i` is ordinary complex multiplication.

  Attributes:
    grid: underlying grid of the process
    smooth: dealias the non-linear term by evaluating it on a grid with
      roughly twice as many points (`ifft_padded_2x` / `fft_truncated_2x`)
  """
  grid: grids.Grid
  smooth: bool = True

  def __post_init__(self):
    self.kx, = self.grid.fft_axes()
    self.two_pi_i_k = 2j * jnp.pi * self.kx
    # Fourier symbol of the linear operator -i d^2/dx^2 / 8.
    self.implicit_symbol = -1j * self.two_pi_i_k**2 / 8
    self.fft = fft_truncated_2x if self.smooth else jnp.fft.fft
    self.ifft = ifft_padded_2x if self.smooth else jnp.fft.ifft

  def explicit_terms(self, psihat):
    """Non-linear part of the equation, `-i|psi|^2 psi/2`, in Fourier space."""
    psi = self.ifft(psihat)
    ipsi_cubed = 1j * psi * jnp.abs(psi)**2
    return -self.fft(ipsi_cubed) / 2

  def implicit_terms(self, psihat):
    """Linear part of the equation, `-i psi_xx/8`, in Fourier space."""
    return self.implicit_symbol * psihat

  def implicit_solve(self, psihat, time_step):
    """Solves `(1 - time_step * S) x = psihat` for x, S = `implicit_symbol`.

    The operator is diagonal in Fourier space, so this is a pointwise division.
    """
    return psihat / (1 - time_step * self.implicit_symbol)

In [None]:
def rollout(stepfn,steps,u0,max_samples=1024):
  """Applies `stepfn` repeatedly from `u0` and records a subsampled trajectory.

  A frame is recorded every `max(steps // max_samples, 1)` steps. At most
  `max_samples` frames are recorded, and never more frames than `steps`, so
  a short run no longer integrates past the requested number of steps. With
  `start_with_input=True`, each frame is presumably the state at the start
  of its block, so the first frame is `u0`.

  Args:
    stepfn: single time-step function, state -> state.
    steps: requested number of time steps.
    u0: initial state.
    max_samples: maximum number of frames to record.

  Returns:
    The pair produced by `cfd.funcutils.trajectory`. Its second element is
    the stacked trajectory; the first is presumably the final state.
  """
  inner_steps = max(steps // max_samples, 1)
  num_frames = min(max_samples, steps)
  multistepfn = jit(cfd.funcutils.repeated(stepfn, inner_steps))
  return cfd.funcutils.trajectory(
      multistepfn, num_frames, start_with_input=True)(u0)


def solve(u0, t_final=1., max_samples=1024,dt=1e-2,L=500):
  """Integrates the FFT-based `NLS` equation from `u0` with Crank-Nicolson RK4.

  Args:
    u0: complex initial condition sampled on a grid over (-L/2, L/2).
    t_final: requested integration time. floor(t_final/dt) steps are
      requested; a tiny tolerance absorbs round-off such as 0.3/0.1.
    max_samples: maximum number of saved frames.
    dt: time step.
    L: domain length.

  Returns:
    (u_traj, xs, timesteps): the saved complex frames, the grid coordinates,
    and one time per saved frame.
  """
  N = len(u0)
  grid = grids.Grid((N,), domain=((-L / 2, L / 2),))
  xs, = grid.axes(offset=(0,))
  eq = NLS(grid=grid)
  stepfn = spectral_stepping.crank_nicolson_rk4(eq, dt)
  # stepfn = spectral_stepping.imex_runge_kutta(eq,dt)
  uhat0 = jnp.fft.fft(u0)
  numsteps = int(np.floor(t_final / dt + 1e-6))
  inner_steps = max(numsteps // max_samples, 1)
  _, uhat_traj = rollout(stepfn, numsteps, uhat0, max_samples)
  u_traj = jax.vmap(jnp.fft.ifft)(uhat_traj)
  # This cell's `rollout` records the state at the start of each block of
  # `inner_steps` steps, so frame j is at time j * inner_steps * dt. Size the
  # time axis from the trajectory so the two always have the same length.
  timesteps = jnp.arange(uhat_traj.shape[0]) * dt * inner_steps
  return u_traj, xs, timesteps

# Single-step smoke test of the FFT-based NLS stepper on a small grid.
L = 500  # 256*np.pi
N = 2**4 + 1
eps = .05
sig = .1
sigma = .01  # original .01
dt = 1e-2

# Random-phase initial condition with a Gaussian spectrum.
dx = L / N
k = np.fft.fftfreq(N, d=dx)
dk = k[1] - k[0]
u0_spectrum = (eps**2) * np.exp(-(k / sigma)**2 / 2) / np.sqrt(2 * np.pi * sigma**2)
u0_phase = np.exp(2 * np.pi * np.random.rand(N) * 1j)
u0 = np.fft.ifft(u0_phase * np.sqrt(2 * dk * u0_spectrum), norm='forward')

N = len(u0)
grid = grids.Grid((N,), domain=((-L / 2, L / 2),))
dx, = grid.step
xs, = grid.axes(offset=(0,))
eq = NLS(grid=grid)
stepfn = spectral_stepping.crank_nicolson_rk2(eq, dt)
# stepfn = spectral_stepping.imex_runge_kutta(eq,dt)
uhat0 = jnp.fft.fft(u0)
stepfn(uhat0)

In [None]:


# def rollout(stepfn,steps,u0,max_samples=1024):
#   stepfn = jit(stepfn)
#   u=u0
#   out = []
#   step_out = []
#   for step in range(steps):
#     u = stepfn(u)
#     if steps<=max_samples or not (step%(steps//max_samples)):
#       out.append(u)
#       step_out.append(step)
#   return jnp.array(step_out),jnp.stack(out,axis=0)

def rollout(stepfn,steps,u0,max_samples=1024):
  """Applies `stepfn` repeatedly from `u0` and records a subsampled trajectory.

  A frame is recorded every `max(steps // max_samples, 1)` steps. At most
  `max_samples` frames are recorded, and never more frames than `steps`, so
  a short run no longer integrates past the requested number of steps.
  Without `start_with_input`, each frame is presumably the state after its
  block of steps.

  Args:
    stepfn: single time-step function, state -> state.
    steps: requested number of time steps.
    u0: initial state.
    max_samples: maximum number of frames to record.

  Returns:
    The pair produced by `cfd.funcutils.trajectory`. Its second element is
    the stacked trajectory; the first is presumably the final state.
  """
  inner_steps = max(steps // max_samples, 1)
  num_frames = min(max_samples, steps)
  multistepfn = jit(cfd.funcutils.repeated(stepfn, inner_steps))
  return cfd.funcutils.trajectory(multistepfn, num_frames)(u0)


def solve(u0, t_final=1., max_samples=1024,dt=1e-2,L=500):
  """Integrates `NonlinearSchrodinger` from the complex initial condition `u0`.

  The real and imaginary parts of the field are each stored by their rfft,
  and the state is advanced with `spectral_stepping.crank_nicolson_rk4`.

  Args:
    u0: complex initial condition sampled on a grid over (-L/2, L/2).
    t_final: requested integration time. floor(t_final/dt) steps are
      requested; a tiny tolerance absorbs round-off such as 0.3/0.1.
    max_samples: maximum number of saved frames.
    dt: time step.
    L: domain length.

  Returns:
    (u, xs, timesteps): complex frames of shape (num_frames, len(u0)), the
    grid coordinates, and one time per frame. Frame j is at time
    (j + 1) * stride * dt, since this cell's `rollout` records the state
    after each block of `stride` steps.
  """
  N = len(u0)
  grid = grids.Grid((N,), domain=((-L / 2, L / 2),))
  xs, = grid.axes(offset=(0,))
  eq = NonlinearSchrodinger(grid=grid)
  stepfn = spectral_stepping.crank_nicolson_rk4(eq, dt)
  # stepfn = spectral_stepping.imex_runge_kutta(eq,dt)
  uhat0_real = jnp.fft.rfft(jnp.real(u0))
  n = len(uhat0_real)
  uhat0 = jnp.concatenate([uhat0_real, jnp.fft.rfft(jnp.imag(u0))])
  numsteps = int(np.floor(t_final / dt + 1e-6))
  stride = max(numsteps // max_samples, 1)
  _, uhat_traj = rollout(stepfn, numsteps, uhat0, max_samples)
  # Pass n=N: irfft otherwise returns 2*(n-1) samples, one short for odd N.
  to_real_space = jax.vmap(lambda coeffs: jnp.fft.irfft(coeffs, n=N))
  u_traj_real = to_real_space(uhat_traj[:, :n])
  u_traj_imag = to_real_space(uhat_traj[:, n:])
  # Derive times from the number of frames actually returned so the lengths
  # always agree (previously e.g. 1025 times for 1024 frames).
  timesteps = (1 + jnp.arange(uhat_traj.shape[0])) * dt * stride
  return u_traj_real + 1j * u_traj_imag, xs, timesteps

## Random phase initial condition distribution

In [None]:
# Random-phase initial condition with a Gaussian wave spectrum.
L = 500  # 256*np.pi
N = 2**11
eps = .05
sig = .1
sigma = .01  # original .01

dx = L / N
k = np.fft.fftfreq(N, d=dx)
dk = k[1] - k[0]
u0_spectrum = (eps**2) * np.exp(-(k / sigma)**2 / 2) / np.sqrt(2 * np.pi * sigma**2)
u0_phase = np.exp(2 * np.pi * np.random.rand(N) * 1j)
u0 = np.fft.ifft(u0_phase * np.sqrt(2 * dk * u0_spectrum), norm='forward')
plt.plot(jnp.abs(u0))

In [None]:
# Integrate to time T, saving T frames (about one per time unit).
T = 1024  # 4096
dt = 1e-2
soln, x_ds, t_ds = solve(u0, T, dt=dt, max_samples=T)

In [None]:
jnp.shape(soln)  # (frames, grid points)

In [None]:
jnp.shape(t_ds)  # should match the number of saved frames

In [None]:
# Space-time plot of |psi|.
u_ds = jnp.abs(soln)
import xarray
plt.figure(figsize=(9, 6))
space_time = xarray.DataArray(
    u_ds,
    dims=["time", "space"],
    coords={"time": t_ds[:u_ds.shape[0]], "space": x_ds},
)
space_time.plot.imshow(cmap="RdBu", robust=True)
plt.grid(False)
plt.show()

In [None]:
import scipy

In [None]:
# Downsample 32x in time and 16x in space by spline interpolation.
donwsampled_u = scipy.ndimage.zoom(soln, (1 / 32, 1 / 16))
a = donwsampled_u  # short alias used by later cells

In [None]:
np.shape(a)

In [None]:
jnp.shape(t_ds)

In [None]:

import xarray
# Imaginary part of the downsampled field on the matching coarse axes.
plt.figure(figsize=(9, 6))
coarse = xarray.DataArray(
    jnp.imag(donwsampled_u),
    dims=["time", "space"],
    coords={"time": t_ds[::32], "space": x_ds[::16]},
)
coarse.plot.imshow(cmap="RdBu", robust=True)
plt.grid(False)
plt.show()

In [None]:
jnp.shape(u_ds[::8, ::16])

In [None]:
max_height = jnp.max(jnp.abs(u_ds), axis=-1)
plt.plot(t_ds, max_height)
plt.xlabel('Time t')
plt.ylabel('Max wave height')

In [None]:
coarse_max_height = jnp.max(jnp.abs(a), axis=-1)
plt.plot(t_ds[::32], coarse_max_height)
plt.xlabel('Time t')
plt.ylabel('Max wave height')

In [None]:
import tensorflow as tf
#tf.config.experimental.set_visible_devices([], "GPU")

from colabtools import adhoc_import
import importlib
from simulation_research.diffusion.config import get_config
from simulation_research.diffusion.train import train_and_evaluate
from simulation_research.diffusion import ode_datasets
from simulation_research.diffusion import diffusion_unet
from simulation_research.diffusion import samplers
from simulation_research.diffusion import diffusion as train
importlib.reload(ode_datasets)
importlib.reload(diffusion_unet)
importlib.reload(samplers)
importlib.reload(train)

In [None]:
jnp.shape(jnp.abs(soln))

In [None]:
# Training data: |psi| subsampled 8x in time and 16x in space
# (alternative: jnp.abs(u_ds[::64,::32])). Test set is the same array.
train_x = jnp.abs(u_ds[::8, ::16])
test_x = train_x

In [None]:
jnp.shape(test_x[:64])

In [None]:
# Add batch and channel axes: (1, time, space, 1). Alternative: next(dataiter()).
x = test_x[None, :, :, None]
t = np.random.rand(x.shape[0])
unet_config = diffusion_unet.unet_64_config(out_dim=x.shape[-1], base_channels=24)
model = diffusion_unet.UNet(unet_config)

In [None]:
jnp.shape(x)

In [None]:
import sys

from absl import logging
# logging.getLogger().setLevel(logging.INFO)
# Route absl log output to the notebook's stdout. `sys` is not imported
# anywhere else in this notebook, so this line used to raise NameError.
logging.get_absl_handler().python_handler.stream = sys.stdout

In [None]:
# Initialize UNet parameters on one example with random diffusion times.
t = np.random.rand(x.shape[0])
init_rng = random.PRNGKey(42)
params = model.init(init_rng, x=x, t=t, train=False)

In [None]:
# `bs` (batch size) was not defined anywhere in this notebook, so this cell
# raised NameError. TODO: confirm the intended batch size; 32, capped at the
# dataset size, is a placeholder. An existing `bs` is respected.
if 'bs' not in globals():
  bs = min(32, len(train_x))
dataset = tf.data.Dataset.from_tensor_slices(train_x)
# `dataiter` is a factory: call `dataiter()` for a fresh shuffled iterator.
dataiter = dataset.shuffle(len(dataset)).batch(bs).as_numpy_iterator

In [None]:
import scipy
def downsample(u):
  """Keeps every other sample along the last axis (alt: scipy.signal.decimate)."""
  return u[..., ::2]


# Re-run the solver at half spatial resolution for a convergence check.
soln2, x_ds2, _ = solve(jnp.array(downsample(u0)), T, dt=dt)

In [None]:
# Relative RMS difference per frame between the half-resolution run and the
# full-resolution run sampled at the same points. `soln` is complex, so use
# |.|^2: `jnp.square` of a complex number is z**2, not |z|**2, which made
# the "error" complex-valued.
errs = jnp.sqrt(
    (jnp.abs(downsample(soln) - soln2)**2).mean(-1)
    / (jnp.abs(soln)**2).mean(-1))
plt.plot(t_ds, errs)
plt.xlabel('Time t')
plt.ylabel('Relative error vs doubled resolution')
plt.yscale('log')

In [None]:
# RW solution: rational rogue-wave profile (compare `gt_soln` below) used as
# the initial condition.
N = 2**11
L = 40 * jnp.pi  # *np.sqrt(2)
grid = grids.Grid((N,), domain=((-L / 2, L / 2),))
dx, = grid.step
xs, = grid.axes(offset=(0,))
zs = xs * np.sqrt(2)
u0 = (4 * zs**2 - 3) / (1 + 4 * zs**2)

tau = 8
T = tau * 2
dt = 3e-4
soln, x_ds, t_ds = solve(u0, T, dt=dt, L=L)
z_ds = x_ds * np.sqrt(2)
tau_ds = t_ds / 2
u_ds = jnp.abs(soln)

import xarray
plt.figure(figsize=(9, 6))
rw_plot = xarray.DataArray(
    u_ds,
    dims=["time", "space"],
    coords={"time": t_ds[:u_ds.shape[0]], "space": x_ds},
)
rw_plot.plot.imshow(cmap="RdBu", robust=False)
plt.grid(False)
plt.show()

In [None]:
print(*(arr.shape for arr in (u_ds, t_ds, x_ds)))

In [None]:

# Closed-form reference solution on the saved (tau, z) grid.
tau_col = tau_ds[:, None]  # times down the rows, broadcast against z_ds
gt_soln = jnp.conj(
    (1 - 4 * (1 + 2j * tau_col) / (1 + 4 * (z_ds**2 + tau_col**2)))
    * jnp.exp(1j * tau_col))

In [None]:
# Space-time plot of the reference |psi|.
plt.figure(figsize=(9, 6))
gt_plot = xarray.DataArray(
    jnp.abs(gt_soln),
    dims=["time", "space"],
    coords={"time": t_ds[:u_ds.shape[0]], "space": x_ds},
)
gt_plot.plot.imshow(cmap="RdBu", robust=False)
plt.grid(False)
plt.show()

In [None]:
jnp.shape(gt_soln)

In [None]:
mean_psi_err = jnp.abs(soln - gt_soln).mean(-1)
plt.plot(t_ds, mean_psi_err)
plt.yscale('log')
plt.xlabel('Time t')
plt.ylabel('Psi error')

In [None]:
# Mean absolute error against the reference stays below 1e-3?
jnp.mean(jnp.abs(soln - gt_soln)) < 1e-3

In [None]:
mean_abs_psi_err = jnp.abs(jnp.abs(u_ds) - jnp.abs(gt_soln)).mean(-1)
plt.plot(t_ds, mean_abs_psi_err)
plt.yscale('log')
plt.xlabel('Time t')
plt.ylabel('|Psi| error')

In [None]:
# |psi| of the input, the reference's first frame and the solver's first frame.
abs_u0, abs_gt0, abs_soln0 = (jnp.abs(arr) for arr in (u0, gt_soln[0], soln[0]))
plt.plot(x_ds, abs_u0)
plt.plot(x_ds, abs_gt0)
plt.plot(x_ds, abs_soln0)

In [None]:
import xarray
plt.figure(figsize=(9, 6))
gt_plot = xarray.DataArray(
    jnp.abs(gt_soln),
    dims=["time", "space"],
    coords={"time": t_ds[:u_ds.shape[0]], "space": x_ds},
)
gt_plot.plot.imshow(cmap="RdBu", robust=False)
plt.grid(False)
plt.show()

In [None]:
# Reference travelling wave Re exp(i (x - t) / 30) on the same axes.
wave_phase = 1j * (x_ds - t_ds[:, None]) / 30
wave = jnp.real(jnp.exp(wave_phase))
plt.figure(figsize=(18, 12))
wave_plot = xarray.DataArray(
    wave,
    dims=["time", "space"],
    coords={"time": t_ds[:u_ds.shape[0]], "space": x_ds},
)
wave_plot.plot.imshow(cmap="RdBu", robust=True)
plt.grid(False)
plt.show()

In [None]:
u_ds = jnp.abs(soln)
import xarray
plt.figure(figsize=(9, 6))
space_time = xarray.DataArray(
    u_ds,
    dims=["time", "space"],
    coords={"time": t_ds[:u_ds.shape[0]], "space": x_ds},
)
space_time.plot.imshow(cmap="RdBu", robust=True)
plt.grid(False)
plt.show()

In [None]:
abs_soln = jnp.abs(soln)
plt.imshow(abs_soln)

In [None]:
magnitude, real_part = jnp.abs(u0), jnp.real(u0)
plt.plot(magnitude)
plt.plot(real_part)

In [None]:

import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp
N = 1024
dx = 500 / N
x = jnp.arange(N) * dx

eps = .5
a = .001

# Random-phase initial condition with a narrow Gaussian spectrum.
k = np.fft.fftfreq(N, d=dx)
eps = 2.5
sigma = .01
u0_spectrum = (eps**2) * np.exp(-(k / sigma)**2 / 2) / np.sqrt(2 * np.pi * sigma**2)
u0_phase = np.exp(2 * np.pi * np.random.rand(N) * 1j)
u0 = np.fft.ifft(u0_phase * np.sqrt(2 * u0_spectrum))

# Second-order central-difference stencils for d/dx and d^2/dx^2.
ddx = np.array([-1, 0, 1]) / (2 * dx)
ddx2 = np.array([1, -2, 1]) / (dx**2)


def nls(t, u):
  """Periodic finite-difference NLS right-hand side, `solve_ivp` signature."""
  dudx = correlate1d(u, ddx, mode='wrap')
  d2udx2 = correlate1d(u, ddx2, mode='wrap')
  u3 = (jnp.abs(u)**2) * u
  return dudx / 2 + 1j * d2udx2 / 8 + 1j * u3 / 2


sol = solve_ivp(nls, (0, 200), u0, rtol=1e-6, method='DOP853')

In [None]:
# Keep roughly 100 evenly spaced frames of the solve_ivp output.
stride = len(sol.t) // 100 + 1
y = sol.y[:, ::stride]
t = sol.t[::stride]

In [None]:
# Superpose randomly placed, sized and modulated Gaussian wave packets.
# The packet count was previously stored in `k`, clobbering the wavenumber
# array that later cells index (e.g. `k[1]-k[0]`).
num_packets = 15
cs = np.random.rand(num_packets) * 300 + 80   # packet centers
ws = np.random.rand(num_packets) * 15 + 10    # packet widths
vs = np.random.randn(num_packets) * 10        # phase gradients (scaled by 1/w)
rs = np.random.randn(num_packets) * 3         # amplitudes
u0 = .1 * sum(jnp.exp(-((x - c) / w)**2 / 2 - v * 1j * x / w) * r
              for c, w, v, r in zip(cs, ws, vs, rs))
# u0= jnp.exp(-((x-300)/30)**2/2 - 20*1j*x)*.2
plt.plot(np.real(u0))

In [None]:

# `mnls` is not defined anywhere in this notebook (NameError on a fresh
# kernel); fall back to the finite-difference `nls` RHS defined earlier.
# TODO: restore `mnls` if a modified-NLS RHS was intended.
try:
  rhs = mnls
except NameError:
  rhs = nls
sol = solve_ivp(rhs, (0, 200), u0, rtol=1e-6, method='BDF')

In [None]:
# Number of saved time samples (presumably the subsampled `t` from above).
len(t)

In [None]:
frame_76 = sol.y[:, 76]  # one saved solve_ivp frame
plt.plot(jnp.abs(frame_76))

In [None]:
from matplotlib import animation
from matplotlib import rc
rc('animation', html='jshtml')

fig = plt.figure()
ax1 = fig.add_subplot(111)
# The frames show the real part of psi, so label and initialize with it.
# Passing the complex `u0` directly makes matplotlib drop the imaginary part
# with a ComplexWarning, and the old label claimed |psi|.
line, = ax1.plot(x, np.real(u0), c='r', label=r'$\mathrm{Re}\,\psi(x)$')
plt.ylim(-.3, .3)


def init():
  """Resets the line to the (real part of the) initial condition."""
  line.set_data(x, np.real(u0))
  return [line]


def animate(i):
  """Draws the real part of saved frame `i`."""
  line.set_data(x, jnp.real(y[:, i]))
  return [line]


anim = animation.FuncAnimation(
    fig,
    animate,
    frames=len(t),
    interval=33,
    init_func=init,
    blit=False)

In [None]:
# Display the animation (rendered as JS/HTML via the 'jshtml' rc setting).
anim

In [None]:
k_index = k / (k[1] - k[0])  # frequencies in units of the frequency spacing
plt.plot(k_index)

In [None]:
# Domain and spectral parameters for the initial-condition cells below.
N = 2**12
L = 256 * np.pi
eps = .05
sig = .1
dx = L / N
k = np.fft.fftfreq(N, d=dx)

In [None]:
# u0RandPhase from https://www.dropbox.com/sh/ov0952luetkpgwr/AAAJyAo94peygtxQyb9_FmAua/%2BIC?dl=0&subfolder_nav_tracking=1
eps = .05
sigma = .1
u0 = 1 * np.ones(N) + 0j
x = (np.arange(N) / N) * L
ii = np.arange(1000) + 1
# Add Gaussian-weighted random-phase modes at +k, then at -k.
for sign in (1, -1):
  u0 += np.exp(-(2 * np.pi * ii / (L * sigma))**2
               + sign * 1j * (2 * np.pi / L) * ii * x[:, None]
               + 2 * np.pi * np.random.rand(len(ii)) * 1j).sum(1)
u0 = u0 * eps / np.sqrt(2 * np.pi * sigma**2)
#k = np.fft.fftfreq(N, d=dx)

In [None]:
# u0GaussSpec from https://www.dropbox.com/sh/ov0952luetkpgwr/AAAJyAo94peygtxQyb9_FmAua/%2BIC?dl=0&preview=u0GaussSpec.m&subfolder_nav_tracking=1
eps = .05
sigma = .1
x = (np.arange(N) / N) * L
dkx = 2 * np.pi / L  # angular wavenumber spacing
# Angular wavenumbers, whose spacing equals dkx. The old factor of 6 (likely
# a rounded 2*pi) made k inconsistent with the dkx used to normalize S below.
k = 2 * np.pi * np.fft.fftfreq(N, d=dx)
S = (1 + eps * np.random.randn(N)) * (eps**2 * np.exp(-(k / sigma)**2 / 2) / np.sqrt(2 * np.pi * sigma**2))
uhat = np.sqrt(2 * dkx * S) * np.exp(2j * np.pi * np.random.rand(N))
u0 = np.fft.ifft(uhat)

In [None]:
# y0JONSWAP_Hs from https://www.dropbox.com/sh/ov0952luetkpgwr/AAAJyAo94peygtxQyb9_FmAua/%2BIC?dl=0&preview=u0JONSWAP_Hs.m&subfolder_nav_tracking=1

k0 = 2 * np.pi / 200  # carrier (peak) angular wavenumber
L = 256 * np.pi / k0
# Grid sized so the Nyquist angular wavenumber pi*N/L is 8*k0.
N = 2**int(np.ceil(np.log2(8 * k0 * L / np.pi)))
x = (np.arange(N) / N) * L

dx = L / N
ksl = k0
gamma = 5
Hs = 9.5
# Angular wavenumber offsets from the carrier (same units as k0); kappa is the
# absolute wavenumber.
k = 2 * np.pi * np.fft.fftfreq(N, d=dx)
kappa = k + k0
sig0 = .07 * (k <= 0) + .09 * (k > 0)
# Evaluate the spectrum only where it is kept. The old form divided by k,
# giving NaN at the carrier (k = 0), which then made every sample of u0 NaN.
# It also mixed k and k + k0 between factors; the peak enhancement now
# centers on kappa = k0, matching the sig0 switch at k = 0.
S = np.zeros(N)
valid = (kappa > 0) & (np.abs(k) <= ksl)
kv = kappa[valid]
S[valid] = (kv**(-3) * np.exp(-1.5 * (k0 / kv)**2)
            * gamma**np.exp(-(kv - k0)**2 / 2 / (sig0[valid] * k0)**2))
uhat = np.sqrt(S) * np.exp(2j * np.pi * np.random.rand(len(S)))
u = np.fft.ifft(uhat)
# Rescale so the significant wave height 4*std(Re u) equals Hs.
Hs_0 = 4 * np.std(np.real(u))
u0 = (Hs / Hs_0) * u

In [None]:
# Random-phase initial condition with a Gaussian spectrum on 2**12 points.
L = 500  # 256*np.pi
N = 2**12
x = (np.arange(N) / N) * L
eps = .05
sig = .1
sigma = .01

dx = L / N
k = np.fft.fftfreq(N, d=dx)
dk = k[1] - k[0]
u0_spectrum = (eps**2) * np.exp(-(k / sigma)**2 / 2) / np.sqrt(2 * np.pi * sigma**2)
u0_phase = np.exp(2 * np.pi * np.random.rand(N) * 1j)
u0 = np.fft.ifft(u0_phase * np.sqrt(2 * dk * u0_spectrum), norm='forward')

In [None]:
import matplotlib.pyplot as plt

# Real part of the envelope times a unit-wavenumber carrier exp(i x).
plt.figure(figsize=(10, 3))
carrier_wave = np.real(u0 * np.exp(1j * x))
plt.plot(x, carrier_wave)
plt.xlabel('Space')
plt.ylim(-.3, .3)
plt.xlim(0, L)

In [None]:
plt.plot(x, np.absolute(u0))  # envelope magnitude |u0|

In [None]:
u0.real[:20]  # first 20 samples of Re u0

In [None]:

# Spectrum magnitude of `uhat` versus `k`. `k` here comes from
# np.fft.fftfreq, i.e. ordinary (cycles per unit length), not angular,
# wavenumber. NOTE(review): `uhat` is set by an earlier initial-condition
# cell that may use a different N than `k`; the plot fails if the lengths differ.
plt.plot(k, np.abs(uhat))
plt.xlabel('wavenumber k (cycles per unit length)')
plt.yscale('log')

In [None]:

# Spectrum magnitude of u0. `k` is from np.fft.fftfreq (cycles per unit
# length), so the old "angular frequency" label was wrong.
plt.plot(k, np.abs(np.fft.fft(u0)))
plt.xlabel('wavenumber k (cycles per unit length)')
plt.yscale('log')
# plt.xlim(-2,2)

In [None]:
np.fft.fft(u0, axis=-1)  # raw FFT coefficients of u0

In [None]:
# NOTE(review): `out` is not defined anywhere in this notebook, so this cell
# fails on a fresh kernel unless `out` survives from a removed cell.
real_out, imag_out = np.real(out), np.imag(out)
plt.plot(real_out)
plt.plot(imag_out)

In [None]:
from colabtools import adhoc_import
import importlib
import jax_cfd.base as cfd
import jax_cfd.spectral.utils as spectral_utils
import jax_cfd.spectral.equations as spectral_equations
import jax_cfd.spectral.time_stepping as spectral_stepping
from jax_cfd.base import grids
from jax_cfd.spectral.equations_test import EquationsTest1D

In [None]:
# Instantiate jax_cfd's 1-D spectral-equation test case to run a test by hand.
# NOTE(review): this rebinds `t`, which earlier cells use for time arrays.
t = EquationsTest1D()

In [None]:
# Run the `test_nls_equation` method of `EquationsTest1D` directly.
t.test_nls_equation()