diff --git a/firedrake/variational_solver.py b/firedrake/variational_solver.py index 3c9a2689a4..609d599800 100644 --- a/firedrake/variational_solver.py +++ b/firedrake/variational_solver.py @@ -300,7 +300,7 @@ def solve(self, bounds=None): coefficients = utils.unique(chain.from_iterable(form.coefficients() for form in forms if form is not None)) # Make sure the solution dm is visited last solution_dm = self.snes.getDM() - problem_dms = [V.dm for V in utils.unique(c.function_space() for c in coefficients) if V.dm != solution_dm] + problem_dms = [V.dm for V in utils.unique(chain.from_iterable(c.function_space() for c in coefficients)) if V.dm != solution_dm] problem_dms.append(solution_dm) for dbc in problem.dirichlet_bcs(): diff --git a/tests/regression/test_star_pc.py b/tests/regression/test_star_pc.py index b636dc9710..38ad450514 100644 --- a/tests/regression/test_star_pc.py +++ b/tests/regression/test_star_pc.py @@ -1,4 +1,5 @@ import pytest +import warnings from firedrake import * from firedrake.petsc import DEFAULT_DIRECT_SOLVER try: @@ -20,6 +21,12 @@ def backend(request): return request.param +def filter_warnings(caller): + with warnings.catch_warnings(): + warnings.filterwarnings("error", "Creating new TransferManager", RuntimeWarning) + caller() + + def test_star_equivalence(problem_type, backend): distribution_parameters = {"partition": True, "overlap_type": (DistributedMeshOverlapType.VERTEX, 1)} @@ -171,12 +178,12 @@ def test_star_equivalence(problem_type, backend): star_params["mg_levels_pc_star_mat_ordering_type"] = "rcm" nvproblem = NonlinearVariationalProblem(a, u, bcs=bcs) star_solver = NonlinearVariationalSolver(nvproblem, solver_parameters=star_params, nullspace=nsp) - star_solver.solve() + filter_warnings(star_solver.solve) star_its = star_solver.snes.getLinearSolveIterations() u.assign(0) comp_solver = NonlinearVariationalSolver(nvproblem, solver_parameters=comp_params, nullspace=nsp) - comp_solver.solve() + filter_warnings(comp_solver.solve) comp_its = comp_solver.snes.getLinearSolveIterations() assert star_its == comp_its @@ -348,12 +355,12 @@ def test_vanka_equivalence(problem_type): vanka_params["mg_levels_pc_vanka_mat_ordering_type"] = "rcm" nvproblem = NonlinearVariationalProblem(a, u, bcs=bcs) star_solver = NonlinearVariationalSolver(nvproblem, solver_parameters=vanka_params, nullspace=nsp) - star_solver.solve() + filter_warnings(star_solver.solve) star_its = star_solver.snes.getLinearSolveIterations() u.assign(0) comp_solver = NonlinearVariationalSolver(nvproblem, solver_parameters=comp_params, nullspace=nsp) - comp_solver.solve() + filter_warnings(comp_solver.solve) comp_its = comp_solver.snes.getLinearSolveIterations() assert star_its == comp_its