Skip to content

Commit

Permalink
nrpy/infrastructures/BHaH/general_relativity/BSSN_C_codegen_library.p…
Browse files Browse the repository at this point in the history
…y: improvements
  • Loading branch information
zachetienne committed Jun 20, 2024
1 parent 0ae6781 commit 6e3dc2e
Show file tree
Hide file tree
Showing 9 changed files with 225 additions and 45 deletions.
146 changes: 107 additions & 39 deletions nrpy/infrastructures/BHaH/general_relativity/BSSN_C_codegen_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,10 +328,11 @@ def register_CFunction_rhs_eval(
enable_KreissOliger_dissipation: bool,
LapseEvolutionOption: str,
ShiftEvolutionOption: str,
KreissOliger_strength_mult_by_W: bool = False,
# when mult by W, strength_gauge=0.99 & strength_nongauge=0.3 is best.
KreissOliger_strength_gauge: float = 0.3,
KreissOliger_strength_nongauge: float = 0.3,
enable_CAKO: bool = False,
enable_CAHD: bool = False,
enable_SSL: bool = False,
OMP_collapse: int = 1,
fp_type: str = "double",
validate_expressions: bool = False,
Expand All @@ -348,9 +349,11 @@ def register_CFunction_rhs_eval(
:param enable_KreissOliger_dissipation: Whether to enable Kreiss-Oliger dissipation.
:param LapseEvolutionOption: Lapse evolution equation choice.
:param ShiftEvolutionOption: Lapse evolution equation choice.
:param KreissOliger_strength_mult_by_W: Whether to multiply Kreiss-Oliger strength by W.
:param KreissOliger_strength_gauge: Gauge strength for Kreiss-Oliger dissipation.
:param KreissOliger_strength_nongauge: Non-gauge strength for Kreiss-Oliger dissipation.
:param enable_CAKO: Whether to enable curvature-aware Kreiss-Oliger dissipation (multiply strength by W).
:param enable_CAHD: Whether to enable curvature-aware Hamiltonian-constraint damping.
:param enable_SSL: Whether to enable slow-start lapse.
:param OMP_collapse: Degree of OpenMP loop collapsing.
:param fp_type: Floating point type, e.g., "double".
:param validate_expressions: Whether to validate generated sympy expressions against trusted values.
Expand Down Expand Up @@ -398,6 +401,24 @@ def register_CFunction_rhs_eval(
sorted(local_BSSN_RHSs_varname_to_expr_dict.items())
)

# Define conformal factor W.
Bq = BSSN_quantities[
CoordSystem
+ ("_rfm_precompute" if enable_rfm_precompute else "")
+ ("_RbarDD_gridfunctions" if enable_RbarDD_gridfunctions else "")
]
EvolvedConformalFactor_cf = par.parval_from_str("EvolvedConformalFactor_cf")
if EvolvedConformalFactor_cf == "W":
W = Bq.cf
elif EvolvedConformalFactor_cf == "chi":
W = sp.sqrt(Bq.cf)
elif EvolvedConformalFactor_cf == "phi":
W = sp.exp(-2 * Bq.cf)
else:
raise ValueError(
"Error: only EvolvedConformalFactor_cf = (W or chi or phi) supported."
)

# Add Kreiss-Oliger dissipation to the BSSN RHSs:
if enable_KreissOliger_dissipation:
diss_strength_gauge, diss_strength_nongauge = par.register_CodeParameters(
Expand All @@ -408,22 +429,11 @@ def register_CFunction_rhs_eval(
commondata=True,
)

if KreissOliger_strength_mult_by_W:
Bq = BSSN_quantities[
CoordSystem
+ ("_rfm_precompute" if enable_rfm_precompute else "")
+ ("_RbarDD_gridfunctions" if enable_RbarDD_gridfunctions else "")
]
EvolvedConformalFactor_cf = par.parval_from_str("EvolvedConformalFactor_cf")
if EvolvedConformalFactor_cf == "W":
diss_strength_gauge *= Bq.cf
diss_strength_nongauge *= Bq.cf
elif EvolvedConformalFactor_cf == "chi":
diss_strength_gauge *= sp.sqrt(Bq.cf)
diss_strength_nongauge *= sp.sqrt(Bq.cf)
elif EvolvedConformalFactor_cf == "phi":
diss_strength_gauge *= sp.exp(-2 * Bq.cf)
diss_strength_nongauge *= sp.exp(-2 * Bq.cf)
# vvv BEGIN CAKO vvv
if enable_CAKO:
diss_strength_gauge *= W
diss_strength_nongauge *= W
# ^^^ END CAKO ^^^

rfm = refmetric.reference_metric[
CoordSystem + "_rfm_precompute" if enable_rfm_precompute else CoordSystem
Expand Down Expand Up @@ -465,6 +475,60 @@ def register_CFunction_rhs_eval(
diss_strength_nongauge * hDD_dKOD[i][j][k] * rfm.ReU[k]
) # ReU[k] = 1/scalefactor_orthog_funcform[k]

# vvv BEGIN CAHD vvv
if enable_CAHD:
Bcon = BSSN_constraints[
CoordSystem
+ ("_rfm_precompute" if enable_rfm_precompute else "")
+ ("_RbarDD_gridfunctions" if enable_RbarDD_gridfunctions else "")
+ ("_T4munu" if enable_T4munu else "")
]
if "cahdprefactor" not in gri.glb_gridfcs_dict:
_ = gri.register_gridfunctions("cahdprefactor")
C_CAHD = par.register_CodeParameter(
"REAL", __name__, "C_CAHD", 0.15, commondata=True, add_to_parfile=True
)
# Initialize CAHD_term assuming phi is the evolved conformal factor. CFL_FACTOR is defined in MoL.
# CAHD_term = -C_CAHD * (sp.symbols("CFL_FACTOR") * sp.symbols("dsmin")) * Bcon.H
CAHD_term = -sp.symbols("cahdprefactor") * Bcon.H
if EvolvedConformalFactor_cf == "phi":
pass # CAHD_term already assumes phi is the evolved conformal factor.
elif EvolvedConformalFactor_cf == "W":
# \partial_t W = \partial_t e^{-2 phi} = -2 W \partial_t phi
CAHD_term *= -2 * Bq.cf
elif EvolvedConformalFactor_cf == "chi":
# \partial_t chi = \partial_t e^{-4 phi} = -4 chi \partial_t phi
CAHD_term *= -4 * Bq.cf
else:
raise ValueError(
"Error: only EvolvedConformalFactor_cf = (W or chi or phi) supported."
)
local_BSSN_RHSs_varname_to_expr_dict["cf_rhs"] += CAHD_term
# ^^^ END CAHD ^^^

# vvv BEGIN SSL vvv
if enable_SSL:
SSL_gaussian_prefactor = par.register_CodeParameter(
"REAL",
__name__,
"SSL_gaussian_prefactor",
1.0,
commondata=True,
add_to_parfile=False,
)
_SSL_h, _SSL_sigma = par.register_CodeParameters(
"REAL",
__name__,
["SSL_h", "SSL_sigma"],
[0.6, 20.0],
commondata=True,
add_to_parfile=True,
)
local_BSSN_RHSs_varname_to_expr_dict["alpha_rhs"] -= (
W * SSL_gaussian_prefactor * (Bq.alpha - W)
)
# ^^^ END SSL ^^^

BSSN_RHSs_access_gf: List[str] = []
for var in local_BSSN_RHSs_varname_to_expr_dict.keys():
BSSN_RHSs_access_gf += [
Expand Down Expand Up @@ -1196,23 +1260,27 @@ def register_CFunction_psi4_spinweightm2_decomposition_on_sphlike_grids() -> Non
ShiftEvolOption = "GammaDriving2ndOrder_Covariant"
for Rbar_gfs in [True, False]:
for T4munu_enable in [True, False]:
results_dict = register_CFunction_rhs_eval(
CoordSystem=Coord,
enable_rfm_precompute=True,
enable_RbarDD_gridfunctions=Rbar_gfs,
enable_T4munu=T4munu_enable,
enable_simd=False,
enable_fd_functions=False,
enable_KreissOliger_dissipation=True,
LapseEvolutionOption=LapseEvolOption,
ShiftEvolutionOption=ShiftEvolOption,
validate_expressions=True,
)
ve.compare_or_generate_trusted_results(
os.path.abspath(__file__),
os.getcwd(),
# File basename. If this is set to "trusted_module_test1", then
# trusted results_dict will be stored in tests/trusted_module_test1.py
f"{os.path.splitext(os.path.basename(__file__))[0]}_{LapseEvolOption}_{ShiftEvolOption}_{Coord}_Rbargfs{Rbar_gfs}_T4munu{T4munu_enable}",
cast(Dict[str, Union[mpf, mpc]], results_dict),
)
for enable_Improvements in [True, False]:
results_dict = register_CFunction_rhs_eval(
CoordSystem=Coord,
enable_rfm_precompute=True,
enable_RbarDD_gridfunctions=Rbar_gfs,
enable_T4munu=T4munu_enable,
enable_simd=False,
enable_fd_functions=False,
enable_KreissOliger_dissipation=True,
LapseEvolutionOption=LapseEvolOption,
ShiftEvolutionOption=ShiftEvolOption,
enable_CAKO=enable_Improvements,
enable_CAHD=enable_Improvements,
enable_SSL=enable_Improvements,
validate_expressions=True,
)
ve.compare_or_generate_trusted_results(
os.path.abspath(__file__),
os.getcwd(),
# File basename. If this is set to "trusted_module_test1", then
# trusted results_dict will be stored in tests/trusted_module_test1.py
f"{os.path.splitext(os.path.basename(__file__))[0]}_{LapseEvolOption}_{ShiftEvolOption}_{Coord}_Rbargfs{Rbar_gfs}_T4munu{T4munu_enable}_Improvements{enable_Improvements}",
cast(Dict[str, Union[mpf, mpc]], results_dict),
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from mpmath import mpf # type: ignore

trusted_dict = {
"a_rhsDD00": mpf("69421.8569726864077633773520086218"),
"a_rhsDD01": mpf("63338.392455907554107729958303091"),
"a_rhsDD02": mpf("7542.40362913282752430183606927844"),
"a_rhsDD11": mpf("92397.9594314949128611676531792922"),
"a_rhsDD12": mpf("64832.8427837073297770326604194367"),
"a_rhsDD22": mpf("103443.389039802068547607289496256"),
"alpha_rhs": mpf("1.43755962968502667895353294595469"),
"bet_rhsU0": mpf("-459357.481307345836572572277583582"),
"bet_rhsU1": mpf("527920.270319507679149920541081937"),
"bet_rhsU2": mpf("-299450.652180767992214068137391737"),
"cf_rhs": mpf("-150754.451928636598842856910853664"),
"h_rhsDD00": mpf("30.2061942218421715191473115028784"),
"h_rhsDD01": mpf("28.9606608467494555508828366456674"),
"h_rhsDD02": mpf("6.39276587975033212268750672045724"),
"h_rhsDD11": mpf("42.3283467195269134291581082722674"),
"h_rhsDD12": mpf("32.2991327983378680834578313124515"),
"h_rhsDD22": mpf("51.8948766361664809982989623435546"),
"lambda_rhsU0": mpf("-612397.437920142318886113953300607"),
"lambda_rhsU1": mpf("703797.513292818430552961290146218"),
"lambda_rhsU2": mpf("-399214.390904380225420814759105428"),
"trK_rhs": mpf("3608.14456884928328004617440611385"),
"vet_rhsU0": mpf("56.3772664742160443704069163257009"),
"vet_rhsU1": mpf("-63.287039038369237559871735826842"),
"vet_rhsU2": mpf("41.7728207000753948065078730776413"),
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from mpmath import mpf # type: ignore

trusted_dict = {
"a_rhsDD00": mpf("69194.2157349814477743801679736856"),
"a_rhsDD01": mpf("62722.8762224442147470616033680122"),
"a_rhsDD02": mpf("6776.99911156377271544768660397319"),
"a_rhsDD11": mpf("91367.8546422242458261243976095765"),
"a_rhsDD12": mpf("63921.9664939839708927198386704627"),
"a_rhsDD22": mpf("102968.772400124015386760602592297"),
"alpha_rhs": mpf("1.43755962968502667895353294595469"),
"bet_rhsU0": mpf("-459434.270609846610793207217165059"),
"bet_rhsU1": mpf("527794.286024702434715574819766772"),
"bet_rhsU2": mpf("-299533.911756559408276733172263271"),
"cf_rhs": mpf("-150757.09148416778076748410091957"),
"h_rhsDD00": mpf("30.2061942218421715191473115028784"),
"h_rhsDD01": mpf("28.9606608467494555508828366456674"),
"h_rhsDD02": mpf("6.39276587975033212268750672045724"),
"h_rhsDD11": mpf("42.3283467195269134291581082722674"),
"h_rhsDD12": mpf("32.2991327983378680834578313124515"),
"h_rhsDD22": mpf("51.8948766361664809982989623435546"),
"lambda_rhsU0": mpf("-612499.823656810017846960539408933"),
"lambda_rhsU1": mpf("703629.534233078104640500328392974"),
"lambda_rhsU2": mpf("-399325.403672102113504368138934053"),
"trK_rhs": mpf("4246.50022341571559667830337355975"),
"vet_rhsU0": mpf("56.3772664742160443704069163257009"),
"vet_rhsU1": mpf("-63.287039038369237559871735826842"),
"vet_rhsU2": mpf("41.7728207000753948065078730776413"),
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"a_rhsDD02": mpf("-30.9240355525784715027439938760933"),
"a_rhsDD11": mpf("-359.185832966049654823394884978503"),
"a_rhsDD12": mpf("-279.06616827406854451123184758721"),
"a_rhsDD22": mpf("-430.500831984357897224020758115333"),
"a_rhsDD22": mpf("-430.500831984357897224020758115232"),
"alpha_rhs": mpf("1.86343241276651738837621566532146"),
"bet_rhsU0": mpf("-459356.057986185400806741677016022"),
"bet_rhsU1": mpf("527921.799575815979614714973298668"),
Expand All @@ -21,7 +21,7 @@
"lambda_rhsU0": mpf("-612397.344985818760575739355092038"),
"lambda_rhsU1": mpf("703797.685870452238730555226669028"),
"lambda_rhsU2": mpf("-399214.257047651378801301938993041"),
"trK_rhs": mpf("3608.25992596186612275396039375975"),
"trK_rhs": mpf("3608.25992596186612275396039375934"),
"vet_rhsU0": mpf("57.2743391537225002976754927475324"),
"vet_rhsU1": mpf("-62.2809474810246133206795120322693"),
"vet_rhsU2": mpf("42.4311464203862956100197642341794"),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from mpmath import mpf # type: ignore

trusted_dict = {
"a_rhsDD00": mpf("-305.590211004687372621858242982517"),
"a_rhsDD01": mpf("-306.072433041488754781876273519897"),
"a_rhsDD02": mpf("-31.0189720222153563713973145801687"),
"a_rhsDD11": mpf("-359.333521492285340915757894748913"),
"a_rhsDD12": mpf("-279.191368758453845395590996384951"),
"a_rhsDD22": mpf("-430.613858263561015993770724784613"),
"alpha_rhs": mpf("1.43755962968502667895353294595469"),
"bet_rhsU0": mpf("-459357.481307345836572572277583582"),
"bet_rhsU1": mpf("527920.270319507679149920541081937"),
"bet_rhsU2": mpf("-299450.652180767992214068137391737"),
"cf_rhs": mpf("325.609470965348508787744352325275"),
"h_rhsDD00": mpf("30.2061942218421715191473115028784"),
"h_rhsDD01": mpf("28.9606608467494555508828366456674"),
"h_rhsDD02": mpf("6.39276587975033212268750672045724"),
"h_rhsDD11": mpf("42.3283467195269134291581082722674"),
"h_rhsDD12": mpf("32.2991327983378680834578313124515"),
"h_rhsDD22": mpf("51.8948766361664809982989623435546"),
"lambda_rhsU0": mpf("-612397.437920142318886113953300607"),
"lambda_rhsU1": mpf("703797.513292818430552961290146218"),
"lambda_rhsU2": mpf("-399214.390904380225420814759105428"),
"trK_rhs": mpf("3608.14456884928328004617440611385"),
"vet_rhsU0": mpf("56.3772664742160443704069163257009"),
"vet_rhsU1": mpf("-63.287039038369237559871735826842"),
"vet_rhsU2": mpf("41.7728207000753948065078730776413"),
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"a_rhsDD00": mpf("-533.108740980496100937097099267652"),
"a_rhsDD01": mpf("-921.51434608052129573246854342304"),
"a_rhsDD02": mpf("-796.328553121633280356893459193684"),
"a_rhsDD11": mpf("-1389.29062223671668986665045397035"),
"a_rhsDD11": mpf("-1389.29062223671668986665045397337"),
"a_rhsDD12": mpf("-1189.9424579974274288240535959141"),
"a_rhsDD22": mpf("-905.117471662411058070707661947768"),
"alpha_rhs": mpf("1.86343241276651738837621566532146"),
Expand All @@ -18,10 +18,10 @@
"h_rhsDD11": mpf("42.4917015225065157490920912601867"),
"h_rhsDD12": mpf("32.3644647640678045705913110855292"),
"h_rhsDD22": mpf("51.9783657551188411079835306608628"),
"lambda_rhsU0": mpf("-612499.730722486459536585941200364"),
"lambda_rhsU0": mpf("-612499.730722486459536585941200571"),
"lambda_rhsU1": mpf("703629.706810711912818094264915785"),
"lambda_rhsU2": mpf("-399325.269815373266884855318821666"),
"trK_rhs": mpf("4246.6155805282984393860893612"),
"lambda_rhsU2": mpf("-399325.269815373266884855318821718"),
"trK_rhs": mpf("4246.61558052829843938608936121131"),
"vet_rhsU0": mpf("57.2743391537225002976754927475324"),
"vet_rhsU1": mpf("-62.2809474810246133206795120322693"),
"vet_rhsU2": mpf("42.4311464203862956100197642341794"),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from mpmath import mpf # type: ignore

trusted_dict = {
"a_rhsDD00": mpf("-533.231448709647361619042277356778"),
"a_rhsDD01": mpf("-921.588666504828115450231208125833"),
"a_rhsDD02": mpf("-796.423489591270165225546779897823"),
"a_rhsDD11": mpf("-1389.43831076295237595901346374378"),
"a_rhsDD12": mpf("-1190.06765848181272970841274470497"),
"a_rhsDD22": mpf("-905.230497941614176840457628605587"),
"alpha_rhs": mpf("1.43755962968502667895353294595469"),
"bet_rhsU0": mpf("-459434.270609846610793207217165059"),
"bet_rhsU1": mpf("527794.286024702434715574819766772"),
"bet_rhsU2": mpf("-299533.911756559408276733172263271"),
"cf_rhs": mpf("322.969915434166584160554286439647"),
"h_rhsDD00": mpf("30.2061942218421715191473115028784"),
"h_rhsDD01": mpf("28.9606608467494555508828366456674"),
"h_rhsDD02": mpf("6.39276587975033212268750672045724"),
"h_rhsDD11": mpf("42.3283467195269134291581082722674"),
"h_rhsDD12": mpf("32.2991327983378680834578313124515"),
"h_rhsDD22": mpf("51.8948766361664809982989623435546"),
"lambda_rhsU0": mpf("-612499.82365681001784696053940914"),
"lambda_rhsU1": mpf("703629.534233078104640500328392974"),
"lambda_rhsU2": mpf("-399325.403672102113504368138934105"),
"trK_rhs": mpf("4246.50022341571559667830337355975"),
"vet_rhsU0": mpf("56.3772664742160443704069163257009"),
"vet_rhsU1": mpf("-63.287039038369237559871735826842"),
"vet_rhsU2": mpf("41.7728207000753948065078730776413"),
}

0 comments on commit 6e3dc2e

Please sign in to comment.