In [None]:
from dedalus.tools import post
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as ani
import time
import h5py
import pathlib
import shutil
import random
import logging
from dedalus.extras import flow_tools
import dedalus.public as de
# Named logger used for this notebook's progress messages ("2DRB" presumably
# stands for 2D Rayleigh-Benard).
logger = logging.getLogger('2DRB')

In [None]:
# Dimensionless control parameters, box geometry and resolution.
# Lx: horizontal box length; d: layer depth (the vertical length unit).
Lx, d = (5., 1.)
# Number of modes in x (Fourier) and z (Chebyshev).
xres, zres = (128, 64)
Ra = 100000  # Rayleigh number
Pr = 5       # Prandtl number

# x is periodic (Fourier basis); z is bounded on [0, d] (Chebyshev basis).
# dealias=3/2 applies the usual 3/2-rule padding for nonlinear products.
xbasis = de.Fourier('x', xres, interval=(0, Lx), dealias=3/2)
zbasis = de.Chebyshev('z', zres, interval=(0, d), dealias=3/2)
domain = de.Domain([xbasis, zbasis], grid_dtype=np.float64)
# Tz, uz and wz are auxiliary first z-derivatives. They keep every equation
# first order in z (see the "dz(...) - ...z = 0" equations in the next cell).
problem = de.IVP(domain, variables=['T', 'p', 'u', 'w', 'Tz', 'uz', 'wz'])

In [None]:
#input equations
# Boussinesq Rayleigh-Benard system written first order in z. The coefficients
# (1 on the momentum Laplacian, 1/Pr on the temperature Laplacian, Ra/Pr on
# the buoyancy term) correspond to a viscous nondimensionalisation
# (time ~ d**2/nu). NOTE(review): confirm this is the intended scaling.
# Per the Dedalus convention, left-hand-side terms are linear (implicit) and
# right-hand-side terms hold the nonlinear advection (explicit).
problem.parameters['Pr'] = Pr
problem.parameters['Ra'] = Ra
# Not used by the equations below; exposed so string expressions can refer to it.
problem.parameters['xres'] = xres
# x- and z-momentum; buoyancy enters the z equation as +(Ra/Pr) * T.
problem.add_equation("dt(u) + dx(p) - (dx(dx(u)) + dz(uz)) = - (u * dx(u) + w * uz)")
problem.add_equation("dt(w) + dz(p) - (dx(dx(w)) + dz(wz)) - (Ra / Pr) * T = - (u * dx(w) + w * wz)")
# Temperature advection-diffusion.
problem.add_equation("dt(T) - (1 / Pr) * (dx(dx(T)) + dz(Tz)) = - (u * dx(T) + w * Tz)")
# Definitions of the auxiliary first-derivative variables.
problem.add_equation("dz(u) - uz = 0")
problem.add_equation("dz(w) - wz = 0")
problem.add_equation("dz(T) - Tz = 0")
# Incompressibility.
problem.add_equation("dx(u) + wz = 0")
# Boundary conditions: left(...) is the bottom wall z = 0, right(...) is the
# top wall z = d.
# Fixed temperature 0.2 at the top and fixed temperature gradient +1 at the
# bottom.
# NOTE(review): with Tz = +1 at the bottom, the conductive profile
# (T = z - 0.8 for d = 1) increases with height. Given the +(Ra/Pr)*T
# buoyancy this is stably stratified. Heating from below would normally use
# left(Tz) = -1; confirm the intended sign.
problem.add_bc("right(T) = 0.2")
problem.add_bc("left(Tz) = 1")
# No-slip, impermeable walls.
problem.add_bc("left(u) = 0")
problem.add_bc("right(u) = 0")
problem.add_bc("left(w) = 0")
# For the horizontally uniform mode (nx == 0), incompressibility gives wz = 0.
# Together with left(w) = 0 this already forces w = 0, so the top w condition
# is redundant. It is replaced by a pressure gauge condition.
problem.add_bc("right(w) = 0",condition="(nx != 0)")
problem.add_bc("right(p) = 0",condition="(nx == 0)")

In [None]:
# Build the initial-value solver. RK111 is, per the Dedalus docs, a
# first-order, one-stage IMEX Runge-Kutta scheme.
timestepper = de.timesteppers.RK111
solver = problem.build_solver(timestepper)
logger.info('Solver built')

In [None]:
if not pathlib.Path('restart.h5').exists():
    # Fresh start: seed the temperature field with small random noise.
    x, z = domain.all_grids()
    T = solver.state['T']
    Tz = solver.state['Tz']

    # Random perturbations. The noise is drawn on the *global* grid with a fixed
    # seed and then sliced to this process's portion. That way the initial
    # condition does not depend on how the domain is distributed.
    gshape = domain.dist.grid_layout.global_shape(scales=1)
    slices = domain.dist.grid_layout.slices(scales=1)
    rand = np.random.RandomState(seed=42)
    noise = rand.standard_normal(gshape)[slices]
    # (zt - z) * (z - zb) vanishes at both walls, damping the noise there.
    zb, zt = zbasis.interval
    pert = 1e-1 * noise * (zt - z) * (z - zb)
    T['g'] = pert
    # Keep the auxiliary gradient variable consistent with T.
    T.differentiate('z', out=Tz)

    dt = 1e-5
    stop_sim_time = 5
    fh_mode = 'overwrite'

else:
    # Restart from the last write in restart.h5. load_state returns the write
    # index (unused here) and the timestep stored with that write.
    _, last_dt = solver.load_state('restart.h5', -1)
    dt = last_dt
    # NOTE(review): stop_sim_time is compared against the absolute simulation
    # time. 5e-3 is smaller than the fresh-run value of 5, so a restarted run
    # will stop immediately if the checkpoint was written later than t = 5e-3.
    # Confirm the intended value.
    stop_sim_time = 5e-3
    fh_mode = 'append'

In [None]:
# Stopping criteria: only the simulation-time limit is active.
solver.stop_sim_time = stop_sim_time
solver.stop_wall_time = np.inf
solver.stop_iteration = np.inf

# Output. Old snapshots are wiped only on a fresh run. On a restart the file
# handler is opened in 'append' mode, and deleting the directory would throw
# away the earlier part of the run being appended to.
if fh_mode == 'overwrite':
    shutil.rmtree('snapshots', ignore_errors=True)
snapshots = solver.evaluator.add_file_handler('snapshots',sim_dt=1e-3,max_writes=200,mode=fh_mode)
# Horizontal mean of T: the x-integral over [0, Lx] divided by the box length
# Lx (dividing by the number of grid points would mis-scale it by Lx/xres).
snapshots.add_task("integ(T,'x')/%r" % float(Lx), layout='g', name='<Tx>')
snapshots.add_task("0.5 * (u ** 2 + w ** 2)", layout='g', name='KE')
snapshots.add_task("sqrt(u ** 2 + w ** 2)", layout='g', name='|uvec|')
snapshots.add_system(solver.state)

#CFL (don't touch)
CFL = flow_tools.CFL(solver, initial_dt=dt, cadence=10, safety=0.5, max_change=1.5, min_change=1, max_dt=1e-3, threshold=0.05)
CFL.add_velocities(('u', 'w'))

# Flow diagnostics evaluated every 10 iterations. The momentum equations have a
# unit viscous coefficient (velocity scale nu/d), so the dimensionless speed is
# itself the local Reynolds number. Dividing by Ra does not give Re.
flow = flow_tools.GlobalFlowProperty(solver, cadence=10)
flow.add_property("sqrt(u ** 2 + w ** 2)", name='Re')

In [None]:
# Main time-stepping loop. The CFL object picks each timestep. Progress is
# logged every 10 iterations, matching the cadence of the `flow` diagnostics.
logger.info('Starting loop')
# Set before `try` so the `finally` block can always compute the run time.
start_time = time.time()
try:
    while solver.proceed:
        dt = CFL.compute_dt()
        dt = solver.step(dt)
        if (solver.iteration - 1) % 10 == 0:
            logger.info('Iteration: %i, Time: %e, dt: %e' % (solver.iteration, solver.sim_time, dt))
            logger.info('Max Re = %f' % flow.max('Re'))
except BaseException:
    # Deliberately broad (includes KeyboardInterrupt from stopping the cell):
    # log the failure, then re-raise so it is never swallowed.
    logger.error('Exception raised, triggering end of main loop.')
    raise
finally:
    end_time = time.time()
    logger.info('Iterations: %i' % solver.iteration)
    logger.info('Sim end time: %f' % solver.sim_time)
    logger.info('Run time: %.2f minutes' % ((end_time - start_time) / 60))

In [None]:
# Combine the per-process output into one file per set, then concatenate all
# sets into snapshots/snapshots.h5.
post.merge_process_files("snapshots", cleanup=True)
# Path.glob yields files in arbitrary order, and a plain string sort would put
# "s10" before "s2". Sort by the numeric set index so the sets are explicitly
# in time order.
set_paths = sorted(
    pathlib.Path("snapshots").glob("snapshots_s*.h5"),
    key=lambda path: int(path.stem.rsplit('_s', 1)[1]),
)
if set_paths:
    post.merge_sets("snapshots/snapshots.h5", set_paths, cleanup=True)
else:
    logger.warning('No snapshot sets found in snapshots/; nothing to merge.')

In [None]:
#gif
# Animate the temperature snapshots into convection.gif. The names below are
# distinct from the solver's `T`, `x` and `z` so this cell does not overwrite
# them in the kernel namespace.
with h5py.File("./snapshots/snapshots.h5", mode='r') as snap_file:

    # Load datasets: T is indexed (time, x, z). Its dimension scales carry the
    # simulation time and the grid coordinates.
    T_dset = snap_file['tasks']['T']
    t_sim = T_dset.dims[0]['sim_time'][:]
    x_grid = T_dset.dims[1][0][:]
    z_grid = T_dset.dims[2][0][:]

    #Plot data
    fig, ax = plt.subplots(figsize=(7, 6), dpi=100)
    # pcolormesh wants C as (len(z), len(x)), hence the transpose.
    # NOTE(review): the colour range 0..0.1 is fixed; values outside it saturate.
    quad = ax.pcolormesh(x_grid, z_grid, T_dset[0].T, shading='nearest', cmap='coolwarm', vmin=0, vmax=1e-1)
    fig.colorbar(quad, ax=ax)
    ax.set_xlabel('x')
    ax.set_ylabel('z')
    title = ax.set_title('T, t = %.4f' % t_sim[0])
    fig.tight_layout()

    def animate(i):
        """Update the mesh colours and the title to show snapshot ``i``."""
        quad.set_array(T_dset[i].T)
        title.set_text('T, t = %.4f' % t_sim[i])
        return quad, title

    #Animation
    # GIF frame delays are stored in hundredths of a second, and viewers clamp
    # very short delays to a slow default. A requested 200 fps (5 ms per frame)
    # therefore plays much slower than intended. 50 fps (20 ms) is about the
    # fastest rate that is reliably honoured.
    animation = ani.FuncAnimation(fig, animate, frames=len(T_dset))
    animation.save('convection.gif', fps=50)
    plt.close(fig)