# TransportEquation and TransportSolver

This notebook defines new `TransportSolver` and `LaxWendroff` classes that generalize icepack's `prognostic_solve` to scalar transport problems other than thickness. For now, other notebooks import this one as a module; eventually this code should become part of icepack.

## Imports

In [1]:
from abc import ABC
from firedrake import *
from icepack.calculus import FacetNormal
from icepack.optimization import MinimizationProblem, NewtonSolver
from icepack.utilities import default_solver_parameters
from operator import itemgetter

## Write the model and solver

In [2]:
#####################################
#####################################
### transport equation and solver ###
#####################################
#####################################

# Copyright (C) 2023 by Daniel Shapero <shapero@uw.edu>
#
# This file is part of icepack.
#
# icepack is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# The full text of the license can be found in the file LICENSE in the
# icepack source directory or at <http://www.gnu.org/licenses/>.

###############################################
###############################################
### copy/paste icepack's transport equation ###
###############################################
###############################################

class TransportEquation(ABC):
    r"""Weak form of a scalar advection equation for one named field.

    Parameters
    ----------
    field_name : str
        Keyword under which the transported scalar field is passed to
        `flux` and `sources`. It is also the prefix of the optional
        `<field_name>_inflow` and the required `<field_name>_accumulation`
        keyword arguments.
    conservative : bool
        If true, the cell flux term is written in conservative form,
        -∫ q u·∇φ dx. Otherwise it uses the non-conservative form,
        -∫ q ∇·(u φ) dx.
    """

    def __init__(self, field_name, conservative):
        self._field_name = field_name
        self._conservative = conservative

    def flux(self, **kwargs):
        r"""Return the weak form of the advective flux of the transported field.

        Keyword arguments
        -----------------
        <field_name>
            The transported field; must be scalar-valued.
        velocity
            The advecting velocity.
        <field_name>_inflow, optional
            Value of the field on inflow boundaries (defaults to zero).

        Returns
        -------
        The sum of the cell term, the outflow and inflow boundary terms, and
        (for "Discontinuous Lagrange" fields only) an upwind interior-facet
        term.

        Raises
        ------
        KeyError
            If `<field_name>` or `velocity` is missing from `kwargs`.
        NotImplementedError
            If the transported field is not scalar-valued.
        """
        keys = (self._field_name, "velocity")
        q, u = itemgetter(*keys)(kwargs)
        q_inflow = kwargs.get(f"{self._field_name}_inflow", Constant(0.0))

        if q.ufl_shape != ():
            raise NotImplementedError(
                "Transport equation only implemented for scalar problems!"
            )

        Q = q.function_space()
        φ = TestFunction(Q)

        mesh = Q.mesh()
        n = FacetNormal(mesh)
        # On extruded meshes the lateral boundary measure is `ds_v`. Bind the
        # choice to a new name: rebinding `ds` itself would turn it into a
        # local variable and raise UnboundLocalError on the right-hand side.
        ds_boundary = ds if mesh.layers is None else ds_v

        if self._conservative:
            flux_cells = -inner(q * u, grad(φ)) * dx
        else:
            flux_cells = -q * div(u * φ) * dx

        # Upwinding at the domain boundary: outflow uses the field itself,
        # inflow uses the prescribed inflow value.
        flux_out = q * max_value(0, inner(u, n)) * φ * ds_boundary
        flux_in = q_inflow * min_value(0, inner(u, n)) * φ * ds_boundary

        if q.ufl_element().family() == "Discontinuous Lagrange":
            # DG fields also need an upwind flux across interior facets.
            f = q * max_value(0, inner(u, n))
            flux_faces = (f("+") - f("-")) * (φ("+") - φ("-")) * dS
            return flux_cells + flux_faces + flux_out + flux_in

        return flux_cells + flux_out + flux_in

    def sources(self, **kwargs):
        r"""Return the weak form of the source term, ∫ a φ dx.

        Only the matching pair (`<field_name>`, `<field_name>_accumulation`)
        is accepted, e.g. (thickness, thickness_accumulation) or
        (damage, damage_accumulation).

        Raises
        ------
        KeyError
            If either keyword is missing from `kwargs`.
        """
        keys = (self._field_name, f"{self._field_name}_accumulation")
        ψ, a = itemgetter(*keys)(kwargs)
        φ = TestFunction(ψ.function_space())
        return a * φ * dx

#################################################################
#################################################################
### essentially copy/paste the relevant bits from flow_solver ###
#################################################################
#################################################################

class TransportSolver:
    r"""Time-stepping driver for a scalar transport model.

    Parameters
    ----------
    model
        The transport model (for example a `TransportEquation`) to solve.
    prognostic_solver_parameters : dict, optional
        Solver parameters forwarded to the Lax-Wendroff integrator. Defaults
        to icepack's `default_solver_parameters`.
    """

    def __init__(self, model, **kwargs):
        self._model = model
        self._fields = {}

        parameters = kwargs.get(
            "prognostic_solver_parameters", default_solver_parameters
        )
        # The integrator receives (and fills in) this very dictionary, so
        # `self.fields` always reflects the integrator's stored fields.
        self._prognostic_solver = LaxWendroff(model, self._fields, parameters)

    @property
    def model(self):
        r"""Transport model that this solver advances in time"""
        return self._model

    @property
    def fields(self):
        r"""Dictionary holding every field that takes part in the simulation"""
        return self._fields

    def prognostic_solve(self, dt, **kwargs):
        r"""Advance the transported scalar field by a time step `dt`.

        Keyword arguments are forwarded unchanged to the integrator, and its
        result is returned.
        """
        integrator = self._prognostic_solver
        return integrator.solve(dt, **kwargs)

class LaxWendroff:
    r"""Lax-Wendroff time integrator for a scalar transport model.

    The variational problem is built once, on the first call to `solve`, from
    the fields supplied at that time. Later calls only update the values of
    those same fields.

    Parameters
    ----------
    model
        Object providing `flux(**fields)`, `sources(**fields)` and a
        `_field_name` attribute naming the transported field.
    fields : dict
        Dictionary in which copies of the simulation fields are stored. It is
        mutated in place, so a caller may share it.
    solver_parameters : dict
        Parameters passed to `NonlinearVariationalSolver`.
    """

    def __init__(self, model, fields, solver_parameters):
        self._model = model
        self._fields = fields
        self._solver_parameters = solver_parameters

    def setup(self, **kwargs):
        r"""Store copies of the input fields and build the reusable solver.

        Fields already present in the dictionary are updated in place. New
        fields are copied (a `Constant` is wrapped in a new `Constant`; a
        `Function` is deep-copied).

        Raises
        ------
        TypeError
            If a new field is neither a `Constant` nor a `Function`.
        """
        for name, field in kwargs.items():
            if name in self._fields:
                self._fields[name].assign(field)
            elif isinstance(field, Constant):
                self._fields[name] = Constant(field)
            elif isinstance(field, Function):
                self._fields[name] = field.copy(deepcopy=True)
            else:
                raise TypeError(
                    "Input %s field has type %s, must be Constant or Function!"
                    % (name, type(field))
                )

        # The time step is a Constant so that later solves can change it
        # without rebuilding the form.
        dt = Constant(1.0)
        ψ = self._fields[self._model._field_name]  # generalized from thickness h
        u = self._fields["velocity"]
        ψ_0 = ψ.copy(deepcopy=True)

        Q = ψ.function_space()
        mesh = Q.mesh()
        n = FacetNormal(mesh)
        outflow = max_value(0, inner(u, n))
        inflow = min_value(0, inner(u, n))
        # On extruded meshes the lateral boundary measure is `ds_v`. Use a new
        # name: rebinding `ds` itself would make it local (UnboundLocalError).
        ds_boundary = ds if mesh.layers is None else ds_v

        # Additional streamlining terms that give 2nd-order accuracy
        q = TestFunction(Q)
        flux_cells = -div(ψ * u) * inner(u, grad(q)) * dx
        flux_out = div(ψ * u) * q * outflow * ds_boundary
        flux_in = div(ψ_0 * u) * q * inflow * ds_boundary
        d2ψ_dt2 = flux_cells + flux_out + flux_in

        sources = self._model.sources(**self._fields)
        flux = self._model.flux(**self._fields)
        dψ_dt = sources - flux
        # Residual: ∫(ψ - ψ_0) q dx - dt * (dψ/dt + dt/2 * d²ψ/dt²)
        F = (ψ - ψ_0) * q * dx - dt * (dψ_dt + 0.5 * dt * d2ψ_dt2)

        problem = NonlinearVariationalProblem(F, ψ)
        self._solver = NonlinearVariationalSolver(
            problem, solver_parameters=self._solver_parameters
        )

        self._ψ_old = ψ_0
        self._timestep = dt

    def solve(self, dt, **kwargs):
        r"""Advance the transported scalar field by a time step `dt`.

        On the first call the problem is built from `kwargs` (see `setup`).
        On later calls every keyword must name a field that was supplied at
        setup time; its stored value is overwritten before solving.

        Returns
        -------
        A deep copy of the updated transported field.

        Raises
        ------
        KeyError
            If, after setup, a keyword names a field that is not part of the
            problem. No field is modified in that case.
        """
        if not hasattr(self, "_solver"):
            self.setup(**kwargs)
        else:
            # Validate before assigning anything so a bad call cannot leave
            # the stored fields partially updated.
            unknown = [name for name in kwargs if name not in self._fields]
            if unknown:
                raise KeyError(
                    "Fields %s were not provided when the solver was set up; "
                    "the variational form cannot be extended after the first "
                    "solve" % unknown
                )
            for name, field in kwargs.items():
                self._fields[name].assign(field)

        ψ = self._fields[self._model._field_name]
        self._ψ_old.assign(ψ)
        self._timestep.assign(dt)
        self._solver.solve()
        return ψ.copy(deepcopy=True)