In [None]:
import numpy as np
import dedalus.public as d3
import matplotlib.pyplot as plt
import logging
# Logger for this notebook; the Newton loop further down reports progress through it.
logger = logging.getLogger(__name__)
# IPython magic: ask the inline backend to render figures in 'retina' format.
%config InlineBackend.figure_format = 'retina'

In [None]:
# Parameters that appear in the equation strings: `a` enters through the
# operator D(A) = dz(dz(A)) - a**2*A, and `C` multiplies the quadratic
# self-interaction terms.
a = 1
C = 6**(-0.5)

# σ and R as they appear in the momentum equation (1/σ and R*a**2*θ).
σ, R = 1, 1e5

# Toomre, Gough & Spiegel mostly report solutions with "rigid boundary
# conditions"; set to False to impose the stress-free conditions instead.
rigid = True

In [None]:
# Resolution, domain length, dealiasing factor and working precision.
nz, Lz = 256, 1
dealias = 3/2
dtype = np.float64

# Single coordinate z, discretized with a Chebyshev basis on [0, Lz].
coord = d3.Coordinate('z')  # alternative: d3.CartesianCoordinates(['z'])
dist = d3.Distributor(coord, dtype=dtype)
zbasis = d3.ChebyshevT(coord, size=nz, bounds=(0, Lz), dealias=dealias)

# The three unknowns that live on the z basis.
W, θ, T0 = (dist.Field(name=label, bases=zbasis) for label in ('W', 'θ', 'T0'))

# Eight fields created without any basis, named τ1 … τ8.
τ1, τ2, τ3, τ4, τ5, τ6, τ7, τ8 = (dist.Field(name=f'τ{k}') for k in range(1, 9))


# Substitutions (callables referenced by name inside the equation strings)
def dz(A):
    """Return the derivative of A with respect to the z coordinate."""
    return d3.Differentiate(A, coord)


lift_basis = zbasis.derivative_basis(1)


def lift(A, n):
    """Return d3.Lift of A on `lift_basis` with index n."""
    return d3.Lift(A, lift_basis, n)


def D(A):
    """Return dz(dz(A)) - a**2*A."""
    return dz(dz(A)) - a**2*A

In [None]:
# Problem
def dt(A):
    """Stand-in time derivative: identically zero, since a steady solution is sought."""
    return 0*A


problem = d3.NLBVP([W, θ, T0, τ1, τ2, τ3, τ4, τ5, τ6, τ7, τ8], namespace=locals())

# Bulk equations for W, θ and T0, each carrying its own lift(τ, ·) terms.
for equation in (
    "1/σ*dt(D(W)) - D(D(W)) + R*a**2*θ + lift(τ1, -1) + lift(τ2, -2) + lift(τ3, -3) + lift(τ4, -4) = -(C/σ)*(2*dz(W)*D(W) + W*D(dz(W)))",
    "dt(θ) - D(θ) + lift(τ5, -1) + lift(τ6, -2) = -dz(T0)*W - C*(2*W*dz(θ) + θ*dz(W))",
    "dt(T0) - dz(dz(T0)) + lift(τ7, -1) + lift(τ8, -2) = -dz(W*θ)",
):
    problem.add_equation(equation)

# Values imposed at z = 0 first, then at z = Lz.
for equation in (
    "T0(z=0) = 1",
    "θ(z=0) = 0",
    "W(z=0) = 0",
    "T0(z=Lz) = 0",
    "θ(z=Lz) = 0",
    "W(z=Lz) = 0",
):
    problem.add_equation(equation)

# Two further conditions on W: first derivative (rigid) or
# second derivative (stress-free) vanishing at both ends.
if rigid:
    wall_conditions = ("dz(W)(z=0) = 0", "dz(W)(z=Lz) = 0")
else:
    # stress-free
    wall_conditions = ("dz(dz(W))(z=0) = 0", "dz(dz(W))(z=Lz) = 0")
for equation in wall_conditions:
    problem.add_equation(equation)

In [None]:
# Initial guess for the Newton iteration, set in grid space.
z = dist.local_grid(zbasis)
# Linear profile that satisfies the boundary conditions T0(z=0) = 1 and
# T0(z=Lz) = 0 for any Lz (the former `Lz - z` only did so for Lz == 1,
# where the two expressions are identical).
T0['g'] = 1 - z/Lz
# Sinusoidal guesses for W and θ; both vanish at z = 0 and z = Lz.
W['g'] = 30*(np.sin(z/Lz*np.pi))
θ['g'] = np.sin(z/Lz*np.pi)+0.1*np.sin(z/Lz*3*np.pi)

In [None]:
# Build the solver for the NLBVP; the Newton loop below drives it via solver.newton_iteration().
solver = problem.build_solver()

In [None]:
# Raise the level of the 'subsystems' and 'solvers' loggers so that only
# warnings and above are shown while iterating.
for system in ['subsystems', 'solvers']:
    logging.getLogger(system).setLevel(logging.WARNING)

pert_norm = np.inf
tolerance = 1e-6
# Safety cap: without it a non-converging iteration would loop forever.
max_iterations = 200
iteration = 0
while pert_norm > tolerance:
    if iteration == max_iterations:
        raise RuntimeError(
            'Newton iteration did not reach tolerance {:g} within {:d} iterations '
            '(last perturbation norm: {:.3e})'.format(tolerance, max_iterations, pert_norm))
    solver.newton_iteration()
    iteration += 1
    pert_norm = sum(pert.allreduce_data_norm('c', 2) for pert in solver.perturbations)
    # A NaN norm makes `pert_norm > tolerance` False, which would otherwise
    # end the loop as if the iteration had converged.
    if not np.isfinite(pert_norm):
        raise RuntimeError(
            'Newton iteration diverged: non-finite perturbation norm '
            'at iteration {:d}'.format(iteration))
    # Diagnostic W*θ - dz(T0), sampled at the first grid point.
    N = (W*θ - dz(T0)).evaluate()['g'][0]
    logger.info('Perturbation norm: {:.3e}, N = {:g}'.format(pert_norm, N))

In [None]:
# Reset the fields to scale 1 before reading grid data, so that the data can
# be plotted against z (built with dist.local_grid(zbasis)).
for field in [W, θ, T0]:
    field.change_scales(1)

# W and θ are normalized by their maxima so all three profiles share one axis.
fig, ax = plt.subplots()
ax.plot(z, W['g']/np.max(W['g']), linestyle='dashdot', label='W / max(W)')
ax.plot(z, θ['g']/np.max(θ['g']), linestyle='dashed', label='θ / max(θ)')
ax.plot(z, T0['g'], label='T0')
ax.set_xlabel('z')
ax.legend();


In [None]:
# Diagnostic W*θ - dz(T0) at the first grid point, followed by the peak
# grid values of θ and W.
N = (W*θ - dz(T0)).evaluate()['g'][0]
for template, value in (
    ("θ = {:g}", np.max(θ['g'])),
    ("W = {:g}", np.max(W['g'])),
    ("N = {:g}", N),
):
    print(template.format(value))