Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion firedrake/variational_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def solve(self, bounds=None):
coefficients = utils.unique(chain.from_iterable(form.coefficients() for form in forms if form is not None))
# Make sure the solution dm is visited last
solution_dm = self.snes.getDM()
problem_dms = [V.dm for V in utils.unique(c.function_space() for c in coefficients) if V.dm != solution_dm]
problem_dms = [V.dm for V in utils.unique(chain.from_iterable(c.function_space() for c in coefficients)) if V.dm != solution_dm]
problem_dms.append(solution_dm)

for dbc in problem.dirichlet_bcs():
Expand Down
15 changes: 11 additions & 4 deletions tests/regression/test_star_pc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
import warnings
from firedrake import *
from firedrake.petsc import DEFAULT_DIRECT_SOLVER
try:
Expand All @@ -20,6 +21,12 @@ def backend(request):
return request.param


def filter_warnings(caller):
with warnings.catch_warnings():
warnings.filterwarnings("error", "Creating new TransferManager", RuntimeWarning)
caller()


def test_star_equivalence(problem_type, backend):
distribution_parameters = {"partition": True,
"overlap_type": (DistributedMeshOverlapType.VERTEX, 1)}
Expand Down Expand Up @@ -171,12 +178,12 @@ def test_star_equivalence(problem_type, backend):
star_params["mg_levels_pc_star_mat_ordering_type"] = "rcm"
nvproblem = NonlinearVariationalProblem(a, u, bcs=bcs)
star_solver = NonlinearVariationalSolver(nvproblem, solver_parameters=star_params, nullspace=nsp)
star_solver.solve()
filter_warnings(star_solver.solve)
star_its = star_solver.snes.getLinearSolveIterations()

u.assign(0)
comp_solver = NonlinearVariationalSolver(nvproblem, solver_parameters=comp_params, nullspace=nsp)
comp_solver.solve()
filter_warnings(comp_solver.solve)
comp_its = comp_solver.snes.getLinearSolveIterations()

assert star_its == comp_its
Expand Down Expand Up @@ -348,12 +355,12 @@ def test_vanka_equivalence(problem_type):
vanka_params["mg_levels_pc_vanka_mat_ordering_type"] = "rcm"
nvproblem = NonlinearVariationalProblem(a, u, bcs=bcs)
star_solver = NonlinearVariationalSolver(nvproblem, solver_parameters=vanka_params, nullspace=nsp)
star_solver.solve()
filter_warnings(star_solver.solve)
star_its = star_solver.snes.getLinearSolveIterations()

u.assign(0)
comp_solver = NonlinearVariationalSolver(nvproblem, solver_parameters=comp_params, nullspace=nsp)
comp_solver.solve()
filter_warnings(comp_solver.solve)
comp_its = comp_solver.snes.getLinearSolveIterations()

assert star_its == comp_its