# Axisymmetric Magneto-Stokes Flow with a Rigid Free-slip Lid

This notebook performs the simulations described in Section 3.2 of David et al. (2024), https://doi.org/10.1017/jfm.2024.674.

**NOTE:** *This notebook uses Dedalus v2 (https://github.com/DedalusProject/dedalus/releases/tag/v2.2207)*

In [None]:
import dedalus.public as de
from dedalus.tools import post
import magstokes_soln_minimal as magstokes
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.colors
from scipy.special import erf
import os
import pandas as pd
import glob
import file_tools as flt
import re

# Implement solver for magneto-Couette flow

In [None]:
def MagnetoCouetteFlow(Re,H,R,simtime,save_freq,dt,nz,nr,timestepper,save_dir,sim_name,Print=True):
    """Run one axisymmetric (z, r) simulation with Dedalus v2 and write analysis output.

    The domain uses a SinCos basis in z on (-0.02, 1) and a Chebyshev basis in r
    on (R, 1). The initial value problem evolves the velocity components
    (ur, ut, uz), their radial derivatives (urr, utr, uzr), the pressure p and
    the two auxiliary fields Theta and Radius. A smooth mask ``wall`` (close to
    1 for z < 0 and close to 0 for z > 0) multiplies a linear damping term of
    strength ``Gamma`` in each momentum equation.

    Parameters
    ----------
    Re : float
        Registered as the problem parameter ``Re``; it enters the equations as
        ``1/Re`` and sets ``Gamma = 1/Re * eps**(-2)``.
    H : float
        Registered as the problem parameter ``H``; it enters the equations
        through ``H**2`` factors.
        NOTE(review): presumably the layer aspect ratio -- confirm against the paper.
    R : float
        Lower bound of the radial interval ``(R, 1)``; also enters the equations
        through ``(1+R)`` factors.
    simtime : float
        Stepping continues while ``solver.sim_time <= simtime``.
    save_freq : int or float
        Output cadence in iterations (rounded to the nearest integer).
    dt : float
        Fixed time step passed to ``solver.step``.
    nz, nr : int
        Sizes of the z and r bases.
    timestepper : str
        Name of an attribute of ``de.timesteppers`` (e.g. ``'SBDF2'``).
    save_dir : str
        Directory passed to ``flt.makedir``; output is written to
        ``f'{save_dir}/analysis-{sim_name}'``.
    sim_name : str
        Suffix of the analysis output name.
    Print : bool, optional
        When truthy, print progress every 100 iterations.

    Returns
    -------
    None
        Results are written through the Dedalus file handler only.
    """
    # Passed to the problem as parameter 'A': the damping is split into an implicit
    # part A*Gamma*u (left-hand side) and an explicit part Gamma*(A-wall)*u
    # (right-hand side), which together amount to -Gamma*wall*u.
    split = 1

    save_freq = int(round(save_freq))

    zbasis = de.SinCos('z',nz,interval=(-.02,1))
    rbasis = de.Chebyshev('r',nr,interval=(R,1))
    domain = de.Domain([zbasis,rbasis],grid_dtype=np.float64)
    zg, rg = domain.all_grids()
    r,z = [de.Field(domain, name='r'),de.Field(domain, name='z')]
    z.meta['z']['parity'] = -1
    r.meta['z']['parity'] = 1
    # Broadcast the 1D grids to the full 2D grid shape.
    zz, rr = zg+0*rg, 0*zg+rg
    r['g'] = rr
    z['g'] = zz

    # Smooth step: tends to 1 for x << 0 and to 0 for x >> 0.
    mask = lambda x: (1/2)*(1 - erf(np.pi**(1/2)*x))
    delta = 0.01  # width of the smooth step in z
    # NOTE(review): origin of this constant is not shown here; it only rescales delta into eps.
    A = 3.11346786
    wall = domain.new_field()
    wall.meta['z']['parity'] = 1
    wall['g'] = mask(zg/delta)
    eps = delta/A
    Gamma = 1/Re*eps**(-2)

    # NOTE: 'switch' is registered as a problem parameter below but is not
    # referenced by any of the equation strings.
    switch = domain.new_field()
    switch.meta['z']['parity'] = -1
    switch['g'] = np.tanh(-(zg-1)/delta) - wall['g']

    def heaviside(x):
        """Return 0 for x < 0 and 1 otherwise."""
        return 0 if x < 0 else 1

    from dedalus.core.operators import GeneralFunction

    # GeneralFunction subclass used for the time-dependent forcing amplitude.
    class ConstantFunction(GeneralFunction):
        def __init__(self, domain, layout, func, args=None, kw=None, out=None):
            # Forward the caller's arguments; fresh containers avoid sharing
            # mutable defaults between instances.
            super().__init__(domain, layout, func,
                             args=[] if args is None else list(args),
                             kw={} if kw is None else dict(kw),
                             out=out)

        def meta_constant(self, axis):
            return True

        def meta_parity(self, axis):
            return 1

    def impulse(solver): return heaviside(solver.sim_time)

    impulse_func = ConstantFunction(domain, layout='g', func=impulse)

    problem = de.IVP(domain,variables=['ur','urr','ut','utr','uz','uzr','p','Theta','Radius'])
    problem.parameters['Re'] = Re
    problem.parameters['wall'] = wall
    problem.parameters['H'] = H
    problem.parameters['R'] = R
    problem.parameters['Gamma'] = Gamma
    problem.parameters['switch'] = switch
    problem.parameters['i'] = impulse_func
    problem.parameters['A'] = split

    problem.meta['ur','ut','urr','utr','p','Theta','Radius']['z']['parity'] = 1 # cosines
    problem.meta['uz','uzr']['z']['parity'] = -1 # sines

    # First-order reduction in r.
    problem.add_equation('dr(ur) - urr = 0')
    problem.add_equation('dr(ut) - utr = 0')
    problem.add_equation('dr(uz) - uzr = 0')
    # Incompressibility (multiplied through by r).
    problem.add_equation('ur + r*urr + r*dz(uz) = 0')
    # Momentum equations; 'i' is the forcing amplitude in the azimuthal equation.
    problem.add_equation('dt(ur) + (1+R)*dr(p)      - 1/Re*(H**2*dr(urr) + dz(dz(ur))) + A*Gamma*ur = -(1+R)*(ur*urr + uz*dz(ur) -  ut**2/r ) + H**2/Re*(urr/r - ur/r**2)                 + Gamma*(A-wall)*ur')
    problem.add_equation('dt(ut)                    - 1/Re*(H**2*dr(utr) + dz(dz(ut))) + A*Gamma*ut = -(1+R)*(ur*utr + uz*dz(ut) + (ut*ur)/r) + H**2/Re*(utr/r - ut/r**2) + (1+R)/Re*i/r + Gamma*(A-wall)*ut')
    problem.add_equation('dt(uz) + (1+R)/H**2*dz(p) - 1/Re*(H**2*dr(uzr) + dz(dz(uz))) + A*Gamma*uz = -(1+R)*(ur*uzr + uz*dz(uz)            ) + H**2/Re*uzr/r                             + Gamma*(A-wall)*uz')
    # Time integrals of (1+R)*ut/r and (1+R)*ur.
    problem.add_equation('dt(Theta) - (1+R)*ut/r = 0')
    problem.add_equation('dt(Radius) - (1+R)*ur = 0')

    # Zero velocity on both radial boundaries; for the nz == 0 mode the inner
    # ur condition is replaced by a pressure condition (presumably to fix the gauge).
    problem.add_bc('left(ur) = 0',condition='nz != 0')
    problem.add_bc('left(p) = 0', condition='nz == 0')
    problem.add_bc('left(ut) = 0')
    problem.add_bc('left(uz) = 0')
    problem.add_bc('right(ur) = 0')
    problem.add_bc('right(ut) = 0')
    problem.add_bc('right(uz) = 0')

    solver = problem.build_solver(getattr(de.timesteppers,timestepper))

    # The solver only exists now, so hand it to the forcing function after the build.
    impulse_func.original_args = impulse_func.args = [solver]

    ur, urr, ut, utr, uz, uzr, p, Theta, Radius = [solver.state[_] for _ in problem.variables]
    Radius['g'] = rg[0,:]

    # Reference solution for ut, refreshed before every step in the main loop.
    # NOTE(review): the meaning of the trailing argument 5 of magstokes.uND is not visible here.
    utTheory = de.Field(domain, name='utTheory')
    utTheory.meta['z']['parity'] = 1
    utTheory['g'] = magstokes.uND(rr,zz,solver.sim_time,H,R,Re,5)

    import file_tools as flt # helper module from a separate repository
    flt.makedir(save_dir)
    analysis = solver.evaluator.add_file_handler(f'{save_dir}/analysis-{sim_name}',iter=save_freq)
    for f in problem.variables: analysis.add_task(f)

    # r-weighted normalisations over (z, r) and over r alone.
    norm_op = de.operators.integrate(r*1,'z','r')
    rad_norm_op = de.operators.integrate(r*1,'r')

    merid_mag2 = (ur*ur + H*uz*H*uz)
    rms_u_merid_op =  (1/norm_op * de.operators.integrate((1-wall)*r*merid_mag2,'z','r'))**(1/2)

    # Weighted mean of ut. The original expression additionally took a square
    # root (apparently copied from the RMS diagnostics), which does not give an
    # average and yields NaN whenever the mean is negative.
    avg_ut_op =  1/norm_op * de.operators.integrate((1-wall)*r*ut,'z','r')

    ut_mag2 = ut*ut
    rms_ut_op =  (1/norm_op * de.operators.integrate((1-wall)*r*ut_mag2,'z','r'))**(1/2)

    error_mag2 = (ut - utTheory)**2
    rms_ut_error_op =  (1/norm_op * de.operators.integrate((1-wall)*r*error_mag2,'z','r'))**(1/2)
    rad_rms_ut_error2_op =  (1/rad_norm_op * de.operators.integrate((1-wall)*r*error_mag2,'r'))

    rel_error_mag2 = ((ut - utTheory)/utTheory)**2
    rms_rel_ut_error_op =  (1/norm_op * de.operators.integrate((1-wall)*r*rel_error_mag2,'z','r'))**(1/2)
    rad_rms_rel_ut_error2_op =  (1/rad_norm_op * de.operators.integrate((1-wall)*r*rel_error_mag2,'r'))

    analysis.add_task(1/H*zbasis.Differentiate(ur) - H*rbasis.Differentiate(uz), layout='g', name='vorttheta')
    analysis.add_task(rms_u_merid_op, layout='g', name='rms_u_merid')
    analysis.add_task(avg_ut_op, layout='g', name='avg_ut')
    analysis.add_task(rms_ut_op, layout='g', name='rms_ut')
    analysis.add_task(rms_ut_error_op, layout='g', name='rms_ut_error')
    analysis.add_task(rad_rms_ut_error2_op, layout='g', name='rad_rms_ut_error2')
    analysis.add_task(rms_rel_ut_error_op, layout='g', name='rms_rel_ut_error')
    analysis.add_task(rad_rms_rel_ut_error2_op, layout='g', name='rad_rms_rel_ut_error2')
    analysis.add_task(wall, layout='g', name='wall')

    while solver.sim_time <= simtime:
        # Refresh the reference solution at the current time before stepping.
        utTheory['g'] = magstokes.uND(rr,zz,solver.sim_time,H,R,Re,5)
        solver.step(dt)
        if solver.iteration % 100 == 0:
            if Print:
                print(f'it {solver.iteration}',
                      f't {solver.sim_time:.2f}',
                      f'|ut| max {np.abs(ut["g"]).max():.3f}')
            # Stop early once the solution has blown up.
            if np.any(np.isnan(ut['g'])):
                print('Broken')
                break

In [None]:
# Load the table of non-dimensional parameters and preview its first four rows.
param_csv_path = os.path.join(
    '..',
    'laboratory',
    'laboratory-data',
    'nondimensional-parameters.csv',
)
control_exp_param = pd.read_csv(param_csv_path)
control_exp_param.head(4)

In [None]:
# Labels of the runs to simulate.
runs = np.array([
    'RUN A3',
    'RUN C2',
    'RUN D1',
    'RUN E1',
])

In [None]:
# Directory that receives the simulation output; handed to flt.makedir before use.
save_dir = os.path.join(
    'DNS-data',
    'RigidLid',
    'exp-matching',
)
flt.makedir(save_dir)

## Simulate

In [None]:
%%time
# Should take around 6 hours total

for run in runs:
    # Select this run's row once and read its control parameters from it
    # (.item() requires exactly one matching row).
    run_row = control_exp_param.loc[control_exp_param['RUN'] == run]
    Re = run_row['Re'].item()
    H = run_row['H'].item()
    R = run_row['R'].item()

    tau = (2/np.pi)**2 * Re
    # Alternative durations left by the author: 5*tau, 2*5*tau
    simtime = 1.1*np.pi
    save_freq = 10
    timestepper = 'SBDF2'

    # Alternative step left by the author: 2.5e-4*simtime
    dt = simtime/1e4
    # Alternative resolutions left by the author: nz = 20; nr = 500 or 50
    nz = 512
    nr = 256

    run_name = run.replace(' ','-')
    sim_name = f'mhd-2D-axisymmetric-VP-implicit-{timestepper}-{run_name}-nz={nz}-nr={nr}'
    print(sim_name)

    MagnetoCouetteFlow(
        Re, H, R,
        simtime, save_freq, dt,
        nz, nr,
        timestepper,
        save_dir, sim_name,
        Print=True,
    )
    

In [None]:
# Get sorted list of directories containing data files
analysis_dirs = sorted(glob.glob(os.path.join(f'{save_dir}','analysis*VP*')))
# Merge output files
for analysis in analysis_dirs:
    post.merge_analysis(analysis,cleanup=True)
    
    # Get list of HDF5 set paths
    setpaths = glob.glob(os.path.join(f'{analysis}','*s?.h5'))
    if len(setpaths)!=0:
        jointpath = re.sub('_s*.', '', setpaths[0])
        post.merge_sets(jointpath, setpaths, cleanup=True)

## Plot

In [None]:
# Import data
run='RUN A3'

nz = 512
nr = 256
run_name = run.replace(' ','-')
file = glob.glob(f'{save_dir}/*/*{run_name}-nz={nz}-nr={nr}.h5')[0]
r, z, t = flt.load_data(file,'r/1.0','z/1.0','sim_time',group='scales')
zz, rr = np.meshgrid(z, r, indexing='ij')
uts, = flt.load_data(file, 'ut',group='tasks')
urs, = flt.load_data(file, 'ur',group='tasks')
uzs, = flt.load_data(file, 'uz',group='tasks')
vortthetas, = flt.load_data(file, 'vorttheta',group='tasks')
Radii, = flt.load_data(file, 'Radius',group='tasks')
Thetas, = flt.load_data(file, 'Theta',group='tasks')
rms_u_merid, = flt.load_data(file, 'rms_u_merid',group='tasks')
rms_ut,= flt.load_data(file, 'rms_ut',group='tasks')
avg_ut,= flt.load_data(file, 'avg_ut',group='tasks')
rms_ut_error, = flt.load_data(file, 'rms_ut_error',group='tasks')
rms_rel_ut_error, = flt.load_data(file, 'rms_rel_ut_error',group='tasks')
wall, = flt.load_data(file, 'wall',group='tasks')

In [None]:
# Plot azimuthal velocity

# Radial profiles of ut at the last saved time, drawn at z_levels evenly
# strided z indices, starting from the largest index.
z_levels = 5
z_stride = zz.shape[0]//z_levels
cmap = matplotlib.cm.Blues
for level in reversed(range(1, z_levels + 1)):
    z_idx = level*z_stride - 1
    profile_label = f'$\\zeta = {zz[z_idx,0]:.1f}$'
    plt.plot(
        rr[z_idx,:],
        uts[-1,z_idx,:],
        color=cmap(level/z_levels),
        label=profile_label,
    )
plt.xlabel('$\\rho = r/r_o$')
plt.ylabel('$\\upsilon_\\phi$')
plt.legend()
plt.show()
