Skip to content

Commit

Permalink
Merge pull request #183 from firedrakeproject/swe_checkpoints
Browse files Browse the repository at this point in the history
Shallow water checkpoints
  • Loading branch information
JHopeCollins committed Apr 23, 2024
2 parents 894b19a + e78df9b commit cf6ffba
Show file tree
Hide file tree
Showing 3 changed files with 276 additions and 2 deletions.
53 changes: 53 additions & 0 deletions case_studies/checkpoint_to_pvd.py
@@ -0,0 +1,53 @@
import firedrake as fd
from firedrake.output import VTKFile
from firedrake.petsc import PETSc

import argparse
parser = argparse.ArgumentParser(
description='Read a timeseries from a checkpoint file and write to a pvd file',
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)

parser.add_argument('--ifilename', type=str, default='time_series', help='Name of checkpoint file.')
parser.add_argument('--ofilename', type=str, default='time_series', help='Name of vtk file.')
parser.add_argument('--ifuncname', type=str, default='func', help='Name of the Function in the input checkpoint file.')
parser.add_argument('--ofuncnames', type=str, nargs='+', default='func', help='Names of the (sub)Function(s) to write to the output pvd file.')
parser.add_argument('--nsteps', type=int, default=0, help='How many timesteps in the checkpoint file. If nsteps is 0 then only one Function is written and the idx argument to CheckpointFile.load_function is not used.')
parser.add_argument('--dt', type=float, default=1, help='Timestep size between different checkpoints.')
parser.add_argument('--show_args', action='store_true', help='Output all the arguments.')

args = parser.parse_known_args()
args = args[0]

if args.show_args:
PETSc.Sys.Print(args)

is_series = args.nsteps > 0

with fd.CheckpointFile(f"{args.ifilename}.h5", "r") as checkpoint:
pfile = VTKFile(f"{args.ofilename}.pvd")
mesh = checkpoint.load_mesh()

if is_series:
idx = 0
func = checkpoint.load_function(mesh, args.ifuncname, idx=idx)
else:
func = checkpoint.load_function(mesh, args.ifuncname)

if len(args.ofuncnames) != len(func.subfunctions):
msg = "--ofuncnames should contain one name for every component of the Function in the CheckpointFile. " \
+ f"{len(args.ofuncnames)} names given for {len(func.subfunctions)} subfunctions."
raise ValueError(msg)

outputfuncs = tuple(fd.Function(fsub.function_space(), name=fname).assign(fsub)
for fsub, fname in zip(func.subfunctions, args.ofuncnames))
if is_series:
pfile.write(*outputfuncs, t=idx*args.dt)
else:
pfile.write(*outputfuncs)

for idx in range(1, args.nsteps):
func = checkpoint.load_function(mesh, args.ifuncname, idx=idx)
for g, f in zip(outputfuncs, func.subfunctions):
g.assign(f)
pfile.write(*outputfuncs, t=idx*args.dt)
213 changes: 213 additions & 0 deletions case_studies/shallow_water/checkpoints/galewsky_checkpoints.py
@@ -0,0 +1,213 @@

import firedrake as fd
from firedrake.petsc import PETSc

from utils import units
from utils import mg
from utils.planets import earth
import utils.shallow_water as swe
from utils.shallow_water import galewsky

from utils.serial import SerialMiniApp

PETSc.Sys.popErrorHandler()

# Run the Galewsky test case and checkpoint solution to disk at specified intervals.
#
# These checkpoints can be used for:
# 1. converting to vtu to check solution veracity using `checkpoint_to_pvd.py`
# 2. starting a test from partway through the Galewsky case - e.g. starting
# a Paradiag run from several days in once nonlinearity has developed.
# 3. Testing solvers for the nonlinear blocks using standalone complex-proxy
# scripts. The checkpoints are used to linearise the nonlinear operator around
# to construct the complex-proxy blocks.

# get command arguments
import argparse
parser = argparse.ArgumentParser(
description='Galewsky testcase using fully implicit SWE solver.',
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)

parser.add_argument('--ref_level', type=int, default=2, help='Refinement level of icosahedral grid.')
parser.add_argument('--nt', type=int, default=48, help='Number of time steps.')
parser.add_argument('--dt', type=float, default=0.5, help='Timestep in hours.')
parser.add_argument('--degree', type=float, default=swe.default_degree(), help='Degree of the depth function space.')
parser.add_argument('--theta', type=float, default=0.5, help='Parameter for implicit theta method. 0.5 for trapezium rule, 1 for backwards Euler.')
parser.add_argument('--atol', type=float, default=1e0, help='Absolute tolerance for solution of each timestep.')
parser.add_argument('--filename', type=str, default='hdf5/galewsky_series', help='Name of checkpoint file.')
parser.add_argument('--save_freq', type=int, default=12, help='How many timesteps between each checkpoint.')
parser.add_argument('--show_args', action='store_true', help='Output all the arguments.')
parser.add_argument('--verbose', '-v', action='store_true', help='Print SNES and KSP outputs.')

args = parser.parse_known_args()
args = args[0]

nt = args.nt
degree = args.degree

if args.show_args:
PETSc.Sys.Print(args)

PETSc.Sys.Print('')
PETSc.Sys.Print('### === --- Setting up --- === ###')
PETSc.Sys.Print('')

# icosahedral mg mesh
mesh = swe.create_mg_globe_mesh(ref_level=args.ref_level, coords_degree=1)
x = fd.SpatialCoordinate(mesh)

# time step
dt = args.dt*units.hour

# shallow water equation function spaces (velocity and depth)
W = swe.default_function_space(mesh, degree=args.degree)
Vu, Vh = W.subfunctions

# parameters
g = earth.Gravity

topography_expr = galewsky.topography_expression(*x)
coriolis_expr = swe.earth_coriolis_expression(*x)
b = fd.Function(Vh, name="topography").project(topography_expr)
f = fd.Function(Vh, name="coriolis").project(coriolis_expr)

# initial conditions
w_initial = fd.Function(W)
u_initial, h_initial = w_initial.subfunctions

u_initial.project(galewsky.velocity_expression(*x))
h_initial.project(galewsky.depth_expression(*x))

# current and next timestep
w0 = fd.Function(W).assign(w_initial)
w1 = fd.Function(W).assign(w_initial)

# mean height
H = galewsky.H0


# shallow water equation forms
def form_mass(u, h, v, q):
return swe.nonlinear.form_mass(mesh, u, h, v, q)


def form_function(u, h, v, q, t):
return swe.nonlinear.form_function(mesh, g, b, f,
u, h, v, q, t)


def aux_form_function(u, h, v, q, t):
return swe.linear.form_function(mesh, g, H, f,
u, h, v, q, t)


appctx = {'aux_form_function': aux_form_function}


# solver parameters for the implicit solve

linear_snes_params = {
'lag_jacobian': -2,
'lag_jacobian_persists': None,
'lag_preconditioner': -2,
'lag_preconditioner_persists': None,
}

lu_params = {
'ksp_type': 'preonly',
'pc_type': 'lu',
'pc_factor': {
'mat_solver_type': 'mumps',
'reuse_ordering': None,
'reuse_fill': None,
}
}

hybridization_sparams = {
"mat_type": "matfree",
"pc_type": "python",
"pc_python_type": "firedrake.HybridizationPC",
"hybridization": lu_params,
"hybridization_snes": linear_snes_params
}

aux_sparams = {
"mat_type": "matfree",
"pc_type": "python",
"pc_python_type": "asQ.AuxiliaryRealBlockPC",
"aux": hybridization_sparams,
"aux_snes": linear_snes_params
}


sparameters = {
'snes': {
'rtol': 1e-12,
'atol': args.atol,
'ksp_ew': None,
'ksp_ew_version': 1,
'ksp_ew_threshold': 1e-5,
'ksp_ew_rtol0': 1e-2,
'lag_preconditioner': -2,
'lag_preconditioner_persists': None,
},
'ksp_type': 'fgmres',
'ksp': {
'atol': args.atol,
'rtol': 1e-5,
},
}
sparameters.update(aux_sparams)

if args.verbose:
sparameters['snes_monitor'] = None
sparameters['snes_converged_reason'] = None
sparameters['ksp_monitor'] = None
sparameters['ksp_converged_rate'] = None

# set up nonlinear solver
miniapp = SerialMiniApp(dt, args.theta,
w_initial,
form_mass,
form_function,
sparameters,
appctx=appctx)

miniapp.nlsolver.set_transfer_manager(
mg.ManifoldTransferManager())

# save initial conditions

PETSc.Sys.Print('### === --- Timestepping loop --- === ###')


def preproc(app, step, t):
if args.verbose:
PETSc.Sys.Print('')
PETSc.Sys.Print(f'=== --- Timestep {step} --- ===')


wout = fd.Function(W, name="swe").assign(w_initial)
checkpoint = fd.CheckpointFile(f"{args.filename}.h5", 'w')
checkpoint.save_mesh(mesh)
checkpoint.save_function(b)
checkpoint.save_function(f)
idx = 0
checkpoint.save_function(wout, idx=idx)
idx += 1


def postproc(app, step, t):
global idx
if ((step+1) % args.save_freq) == 0:
wout.assign(app.w1)
checkpoint.save_function(wout, idx=idx)
idx += 1


miniapp.solve(args.nt,
preproc=preproc,
postproc=postproc)

checkpoint.close()
12 changes: 10 additions & 2 deletions utils/serial/miniapp.py
Expand Up @@ -12,7 +12,7 @@ def __init__(self,
form_mass,
form_function,
solver_parameters,
bcs=[]):
bcs=[], appctx={}):
'''
A miniapp to integrate a finite element form forward in time using the implicit theta method
Expand Down Expand Up @@ -46,9 +46,17 @@ def __init__(self,
self.dt, self.theta,
self.w0, self.w1)

appctx['uref'] = self.w1
appctx['bcs'] = bcs
appctx['tref'] = self.time
appctx['theta'] = theta
appctx['dt'] = dt
appctx['form_mass'] = form_mass
appctx['form_function'] = form_function

self.nlproblem = fd.NonlinearVariationalProblem(self.form_full, self.w1, bcs=bcs)

self.nlsolver = fd.NonlinearVariationalSolver(self.nlproblem,
self.nlsolver = fd.NonlinearVariationalSolver(self.nlproblem, appctx=appctx,
solver_parameters=self.solver_parameters)

def set_theta_form(self, form_mass, form_function, dt, theta, w0, w1):
Expand Down

0 comments on commit cf6ffba

Please sign in to comment.