In [None]:
import numpy as np
import matplotlib.pyplot as plt
import dedalus.public as de
import dedalus.extras.plot_tools as plot_tools
import atmospheres as atmos
import tides
import modes
import parameters as param
import mpi4py.MPI as MPI
import uuid
import logging
logger = logging.getLogger(__name__)
%matplotlib notebook

In [None]:
import importlib

# Re-execute parameters.py so edits made since the first import take effect
# without restarting the kernel; the reloaded module object is displayed.
reloaded_param = importlib.reload(param)
reloaded_param

## Solve 1D linear tide

In [None]:
# Build and solve the 1D linear tide problem defined by `param`.
# `domain`, `problem` and `solver` are read by the cells below.
domain, problem = tides.linear_tide_1d(param)

solver = problem.build_solver()
solver.solve()

In [None]:
# Plot the magnitude of each solved field's spectral coefficients on a log
# scale, to check convergence of the solution.
fig, (ax1, ax2) = plt.subplots(2, 1)
for ax, name, ylabel in ((ax1, 'p1', "|p1['c']|"), (ax2, 'a1', "|a1['c']|")):
    ax.semilogy(np.abs(solver.state[name]['c']))
    ax.set_ylabel(ylabel)

In [None]:
# Plot the tide in the x-z plane. The solved field holds the complex
# z-profile f(z) of a single horizontal wavenumber kx; the plotted quantity is
# 2 Re[exp(i kx x) f(z)].
# The mesh arrays are named xmesh/zmesh (not X/Z) so this cell can be re-run
# without overwriting the global `X`, which later cells use for the solved
# pencil state.
scales = 1
kx = param.k_tide
x = np.linspace(0, param.Lx, param.Nx, endpoint=False)[:, None]
z = domain.grid(0, scales=scales)
zmesh, xmesh = plot_tools.quad_mesh(z.flatten(), x.flatten())

field = 'p1'
f = solver.state[field]
f.set_scales(scales)
tide_xz = 2 * np.real(np.exp(1j * kx * x) * f['g'])

fig, axes = plt.subplots(1, 1)
im = axes.pcolormesh(xmesh, zmesh, tide_xz)
plt.colorbar(im)
axes.set_title(field)
axes.set_xlabel('x')
axes.set_ylabel('z')

## Mode amplitude extraction

Eigenmodes:

$$\lambda_i M \cdot X_i + L \cdot X_i = 0$$

Adjoint modes:

$$Y_j \cdot M \cdot X_i = \delta_{ji}$$

Decompose arbitrary state:

$$X = \sum_i a_i X_i$$

$$a_j = Y_j \cdot M \cdot X$$

Decompose linear solution:

$$L \cdot X = F$$

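$$\sum_i a_i L \cdot X_i = F$$

Using the eigenmode relation $L \cdot X_i = -\lambda_i M \cdot X_i$:

$$\sum_i a_i \lambda_i M \cdot X_i = - F$$

Applying $Y_j$ and the biorthogonality condition $Y_j \cdot M \cdot X_i = \delta_{ji}$:

$$a_j = - \frac{Y_j \cdot F}{\lambda_j}$$

In the code the eigenvalue problem leaves out the tidal frequency, so the denominator becomes $\lambda_j - \omega_{\rm tide}$.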

In [None]:
# Sparse eigensolve used to cross-check the mode amplitudes obtained by
# projecting the RHS against those extracted from the solved coefficients.
# NOTE(review): N=20 / target=20 presumably request 20 modes near a target
# eigenvalue of 20 -- confirm against modes.compute_eigenmodes.
sparse_kw = dict(sparse=True, N=20, target=20)
evals, evecs, adj_evals, adj_evecs, proj, sevp, pevp = modes.compute_eigenmodes(
    param, param.k_tide, **sparse_kw)

In [None]:
# Rebuild the RHS of the first pencil from its equation forcing (Fe) and
# boundary forcing (Fb), and confirm the solved state satisfies L.X = b.
p = solver.pencils[0]
pFe, pFb = solver.Fe.get_pencil(p), solver.Fb.get_pencil(p)
b = p.G_eq * pFe + p.G_bc * pFb
X = solver.state.get_pencil(p).copy()
residual_ok = np.allclose(p.L @ X, b)
print('Check L.X = b:', residual_ok)

# Two routes to the modal amplitudes: project the RHS with the adjoint modes,
# or apply the projector to the solved state. The EVP omits the tidal
# frequency, so each eigenvalue is shifted by it in the denominator.
amps_proj = - adj_evecs.T.conj() @ b / (evals - param.ω_tide)
amps_solve = proj @ X
print('A_proj == A_solve:', np.allclose(amps_proj, amps_solve))

## Mode completeness

In [None]:
# Dense eigensolve used to check whether the modes are complete enough to
# represent the linear tide. This rebinds evals/evecs/... from the sparse
# solve above.
# NOTE(review): minreal=0 / maxabs=inf presumably disable the eigenvalue
# filters -- confirm against modes.compute_eigenmodes.
dense_kw = dict(sparse=False, minreal=0, maxabs=np.inf)
evals, evecs, adj_evals, adj_evecs, proj, sevp, pevp = modes.compute_eigenmodes(
    param, param.k_tide, **dense_kw)

In [None]:
# Sort the full (unfiltered) eigendecomposition by eigenvalue, normalize the
# forward modes by their energy, rescale the adjoint modes so that
# diag(Y^H M X) == 1, and image log10|Y^H M X| to inspect biorthogonality.
# Later cells index `sevp.full_eigenvalues`, `FE` and `FAE` with the same
# filter, so all three must share this ordering.
sorting = np.argsort(sevp.full_eigenvalues)
sevp.full_eigenvalues = sevp.full_eigenvalues[sorting]
sevp.full_eigenvectors = sevp.full_eigenvectors[:, sorting]
sevp.full_adjoint_eigenvectors = sevp.full_adjoint_eigenvectors[:, sorting]
# Permute the adjoint eigenvalues with their eigenvectors so sevp stays
# self-consistent (the vectors above are reordered by the same permutation).
sevp.full_adjoint_eigenvalues = sevp.full_adjoint_eigenvalues[sorting]

# modes.compute_energies is handed `sevp`; point its eigenpairs at the full,
# sorted set before calling it.
sevp.eigenvalues = sevp.full_eigenvalues
sevp.eigenvectors = sevp.full_eigenvectors

FE = sevp.full_eigenvectors
FAE = sevp.full_adjoint_eigenvectors

# Scale each forward mode by 1/sqrt(energy).
energy = modes.compute_energies(sevp)
FE = FE / np.sqrt(energy)

# Rescale the adjoints so the diagonal of FAE^H M FE is exactly one, then
# recompute the metric with the rescaled adjoints.
FM = FAE.conj().T @ pevp.M @ FE
FAE = FAE / np.diag(FM).conj()
FM = FAE.conj().T @ pevp.M @ FE

# NOTE(review): 509 is an unexplained cutoff, presumably the number of
# finite / well-resolved modes at this resolution -- confirm and derive it
# from the data instead of hard-coding it.
N_METRIC_MODES = 509
metric = FM[:N_METRIC_MODES, :N_METRIC_MODES]

plt.figure()
plt.imshow(np.log10(np.abs(metric)), cmap='viridis', clim=(-20, 0))
plt.colorbar()


In [None]:
def update_inv(F, I, M, indeces, loops):
    """Refine approximate left-inverse rows so that ``I @ M @ F`` approaches identity.

    Parameters
    ----------
    F : ndarray, shape (n, K)
        Forward basis vectors stored as columns.
    I : ndarray, shape (K, n)
        Initial approximate left-inverse rows (e.g. conjugated adjoint modes).
    M : ndarray or sparse matrix, shape (n, n)
        Matrix defining the product ``I @ M @ F``.
    indeces : array_like of int or bool mask
        Rows of ``I`` / columns of ``F`` to refine. Rows not selected are
        returned unchanged.
    loops : int
        Number of refinement iterations. ``0`` returns an unmodified copy.

    Returns
    -------
    ndarray
        Refined copy of ``I``. The input array is not modified.

    Each iteration recomputes ``P = I @ M @ F`` on the selected indices, then
    rescales every row so ``P[i, i] == 1``. It then applies one Newton-Schulz
    step ``I <- (2*Id - P) @ I``. This removes off-diagonal cross-talk to
    first order, and the residual ``Id - P`` is squared on each pass.
    """
    I = np.array(I, copy=True)
    idx = np.asarray(indeces)
    if idx.size == 0:
        return I
    for _ in range(loops):
        I_sel = I[idx]
        # Recompute the residual every pass: reusing a stale P stops the
        # iteration from converging beyond the first step.
        P = np.asarray(I_sel @ M @ F[:, idx])
        scale = np.diag(P).copy()
        I_sel = I_sel / scale[:, None]
        P = P / scale[:, None]
        # Off-diagonal part of the rescaled metric (its diagonal is exactly 1).
        off_diag = P - np.eye(len(scale))
        I[idx] = I_sel - off_diag @ I_sel
    return I

# Keep modes with 1e-5 < |Re(eigenvalue)| < 1e5 and refine their adjoint
# rows so that I @ M @ F is closer to the identity (3 refinement passes).
abs_real = np.abs(sevp.full_eigenvalues.real)
filt = (abs_real > 1e-5) & (abs_real < 1e5)
N = np.sum(filt)

F = FE[:, filt]
I = FAE[:, filt].conj().T
M = pevp.M

I = update_inv(F, I, M, np.arange(N), 3)

In [None]:
# Image log10 of the deviation of the refined metric I.M.F from the identity.
metric = I @ M @ F
deviation = np.abs(metric - np.eye(N))

plt.figure()
plt.imshow(np.log10(deviation), cmap='viridis')
plt.colorbar()

In [None]:
from scipy import linalg

# Explicit inverse of the full eigenvector matrix (kept as a global for
# interactive inspection), plus a log-magnitude image of the eigenvectors.
full_evecs = sevp.full_eigenvectors
inv_full_evecs = linalg.inv(full_evecs)
log_mag = np.log10(np.abs(full_evecs))

plt.figure()
plt.imshow(log_mag, cmap='viridis')
plt.colorbar()

In [None]:
# Diagnostics for the dense eigenmodes with nonzero real part.
filt = np.abs(evals.real) > 0

# L-weighted adjoint normalization, kept for comparison with the M-weighted
# projector `proj`.
metric_L = adj_evecs.T.conj() @ pevp.L @ evecs
adj_evecs_L = adj_evecs / np.diag(metric_L).conj()
proj_L = adj_evecs_L.T.conj() @ pevp.L
metric_filt = proj[filt, :] @ evecs[:, filt]

# Residual of the generalized EVP  lambda M.X + L.X = 0  for each kept mode.
gevp_error = pevp.M @ evecs[:, filt] @ np.diag(evals[filt]) + pevp.L @ evecs[:, filt]
gevp_rel_error = np.abs(gevp_error) / np.abs(evecs[:, filt])

# log10 images: eigenvectors, absolute and relative EVP residuals, and the
# deviation of the projector metric from the identity. The abs() must wrap the
# difference; abs(metric) - I gives zero/negative values whose log is -inf/NaN.
n_kept = np.sum(filt)
for image in (np.abs(evecs[:, filt].T),
              np.abs(gevp_error.T),
              np.abs(gevp_rel_error.T),
              np.abs(metric_filt - np.eye(n_kept))):
    plt.figure()
    plt.imshow(np.log10(image), cmap='viridis')
    plt.colorbar()

# Eigenvalue spectrum: kept modes in black, zero-real-part modes in red.
plt.figure()
plt.plot(evals[filt].real, evals[filt].imag, '.k')
plt.plot(evals[~filt].real, evals[~filt].imag, '.r')
# matplotlib >= 3.3 names the symlog threshold `linthresh`; the axis-suffixed
# `linthreshx` / `linthreshy` keywords were removed in 3.5.
plt.xscale('symlog', linthresh=1e-10)
plt.yscale('symlog', linthresh=1e-10)
plt.grid()

In [None]:
def plot_mode_components(mode, n_fields=6, show_adjoint=False):
    """Plot |coefficients| of one kept eigenmode, one curve per interleaved variable.

    Parameters
    ----------
    mode : int
        Column index into ``evecs[:, filt]`` (i.e. among the modes kept by ``filt``).
    n_fields : int
        Stride used to de-interleave the state vector. NOTE(review): 6 is
        presumably the number of variables per pencil -- confirm.
    show_adjoint : bool
        Also overlay the matching adjoint mode as dashed curves.

    Returns
    -------
    matplotlib.legend.Legend
        The legend of the created figure.
    """
    plt.figure()
    forward = evecs[:, filt][:, mode]
    for i in range(n_fields):
        plt.semilogy(np.abs(forward[i::n_fields]), label=str(i))
    if show_adjoint:
        # Restart the colour cycle so each dashed adjoint curve matches the
        # colour of its forward counterpart (unlabelled to avoid duplicate
        # legend entries).
        plt.gca().set_prop_cycle(None)
        adjoint = adj_evecs[:, filt][:, mode]
        for i in range(n_fields):
            plt.semilogy(np.abs(adjoint[i::n_fields]), '--')
    return plt.legend()


for mode in (122, 132):
    plot_mode_components(mode)

In [None]:
# Reconstruct the solved tide from the dense eigenmodes using the projected
# amplitudes a_j = -Y_j.b / (lambda_j - omega_tide), and measure the error.
amps = - adj_evecs.T.conj() @ b / (evals - param.ω_tide)
X_recon = evecs @ amps
dX = X - X_recon

print('X == X_recon:', np.allclose(X, X_recon))
# Ratio of 2-norms, matching the printed label (not the squared ratio).
print('|X - X_recon|/|X|:', np.linalg.norm(dX) / np.linalg.norm(X))

# Amplitude vs |Re(eigenvalue)|: blue circles for Re > 0, red dots for Re < 0.
# Non-positive abscissae are masked explicitly rather than handed to the log axis.
pos = evals.real > 0
neg = evals.real < 0
plt.figure()
plt.loglog(evals.real[pos], np.abs(amps[pos]), 'ob')
plt.loglog(-evals.real[neg], np.abs(amps[neg]), '.r')
plt.xlabel('|Re(eigenvalue)|')
plt.ylabel('|amplitude|')
plt.grid()

In [None]:
# Compare the solved state X (black, thick) with its modal reconstruction
# X_recon (red, thin) in grid space: solid = real part, dashed = imaginary part.
z = domain.grid(0)
f = 'p1'
plt.figure()

for vec, real_style, imag_style, lw in ((X, '-k', '--k', 2),
                                        (X_recon, '-r', '--r', 1)):
    solver.state.set_pencil(p, vec)
    solver.state.scatter()
    values = solver.state[f]['g']
    plt.plot(z, values.real, real_style, lw=lw)
    plt.plot(z, values.imag, imag_style, lw=lw)

# Restore the solved state. Otherwise solver.state is left holding the
# reconstruction, and any later or re-run cell reading the solver sees the
# wrong field.
solver.state.set_pencil(p, X)
solver.state.scatter()

## 2D linear tide

In [None]:
# Build and solve the same linear tide on the project's 2D domain.
domain2, problem2 = tides.linear_tide_2d(param)

solver2 = problem2.build_solver()
solver2.solve()

In [None]:
# Plot the 2D solution of the chosen field in grid space.
field = 'p1'
f = solver2.state[field]
f.require_grid_space()

fig, axes = plt.subplots(1, 1)
plot_tools.plot_bot_2d(f, axes=axes)