/
test_mrt_simplifications.py
38 lines (28 loc) · 1.31 KB
/
test_mrt_simplifications.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import pytest
import sympy as sp
from pystencils.sympyextensions import is_constant
from lbmpy import Stencil, LBStencil, Method, create_lb_collision_rule, LBMConfig, LBMOptimisation
@pytest.mark.parametrize('method', [Method.MRT, Method.CENTRAL_MOMENT, Method.CUMULANT])
def test_mrt_simplifications(method: Method):
    """Check that the 'auto' simplification leaves only clean subexpressions.

    A D3Q19 compressible collision rule is created for ``method`` with
    ``simplification='auto'``. Every subexpression of that rule must satisfy:

    * its right-hand side is not a bare symbol (no aliases),
    * it contains no ``sp.log`` terms,
    * no ``Add``/``Mul`` node inside it has a numeric argument other than
      +1 or -1 (constants are expected to be extracted),
    * it contains no power with a non-positive exponent, unless the
      right-hand side itself is a power with a negative exponent.

    Args:
        method: The lattice Boltzmann method the collision rule is built for.
    """
    stencil = Stencil.D3Q19
    lbm_config = LBMConfig(stencil=stencil, method=method, compressible=True)
    lbm_opt = LBMOptimisation(simplification='auto')

    cr = create_lb_collision_rule(lbm_config=lbm_config, lbm_optimisation=lbm_opt)

    # Numeric arguments that may legitimately remain inside sums and products.
    # Built once, since it does not depend on the subexpression being checked.
    trivial_constants = {sp.Number(1), sp.Number(-1), sp.Float(1), sp.Float(-1)}

    for subexp in cr.subexpressions:
        rhs = subexp.rhs

        # Check for aliases
        assert not isinstance(rhs, sp.Symbol)

        # Check for logarithms
        assert not rhs.atoms(sp.log)

        # Check for nonextracted constant summands or factors.
        # This used to call breakpoint(), which never fails the test; an
        # assertion makes the check effective in non-interactive runs.
        for expr in rhs.atoms(sp.Add, sp.Mul):
            for arg in expr.args:
                if isinstance(arg, sp.Number):
                    assert arg in trivial_constants, (
                        f"Non-extracted constant {arg} in {expr} "
                        f"(subexpression {subexp})"
                    )

        # Check for divisions. A right-hand side that is itself a negative
        # power is exempt from the check.
        # NOTE(review): this check only reads `rhs`, so it is placed at
        # per-subexpression level - confirm against the intended nesting.
        if not (isinstance(rhs, sp.Pow) and rhs.args[1] < 0):
            for p in rhs.atoms(sp.Pow):
                assert p.args[1] > 0