Skip to content

Commit

Permalink
[NB] update multi strong components to have slices (OpenModelica#12586)
Browse files Browse the repository at this point in the history
* [NB] update multi strong components to have slices

 - support slices in multi strong components

* [testsuite] update for new debug output
  • Loading branch information
kabdelhak committed Jun 15, 2024
1 parent 025d351 commit ac439f6
Show file tree
Hide file tree
Showing 11 changed files with 73 additions and 75 deletions.
36 changes: 17 additions & 19 deletions OMCompiler/Compiler/NBackEnd/Classes/NBStrongComponent.mo
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ public
record MULTI_COMPONENT
"component for all equations that can solve for more than one variable instance
ALGORITHM, WHEN_EQUATION, IF_EQUATION"
list<Pointer<Variable>> vars;
Pointer<Equation> eqn;
list<Slice<VariablePointer>> vars;
Slice<EquationPointer> eqn;
Solve.Status status;
end MULTI_COMPONENT;

Expand Down Expand Up @@ -171,10 +171,8 @@ public
case MULTI_COMPONENT() algorithm
str := StringUtil.headline_3("BLOCK" + indexStr + ": Multi Strong Component (status = " + Solve.statusString(comp.status) + ")");
str := str + "### Variables:\n";
for var in comp.vars loop
str := str + Variable.toString(Pointer.access(var), "\t") + "\n";
end for;
str := str + "\n### Equation:\n" + Equation.toString(Pointer.access(comp.eqn), "\t") + "\n";
str := str + List.toString(comp.vars, function Slice.toString(func = BVariable.pointerToString, maxLength = 10), "", "\t", "\n\t", "");
str := str + "\n### Equation:\n" + Slice.toString(comp.eqn, function Equation.pointerToString(str = "\t")) + "\n";
then str;

case SLICED_COMPONENT() algorithm
Expand Down Expand Up @@ -248,7 +246,7 @@ public
then ();

case MULTI_COMPONENT() algorithm
_ := match Pointer.access(comp.eqn)
_ := match Pointer.access(Slice.getT(comp.eqn))
case Equation.ALGORITHM() algorithm collector.multi_algorithm := collector.multi_algorithm + 1; Pointer.update(collector_ptr, collector); then ();
case Equation.WHEN_EQUATION() algorithm collector.multi_when := collector.multi_when + 1; Pointer.update(collector_ptr, collector); then ();
case Equation.IF_EQUATION() algorithm collector.multi_if := collector.multi_if + 1; Pointer.update(collector_ptr, collector); then ();
Expand Down Expand Up @@ -282,7 +280,7 @@ public
algorithm
i := match comp
case SINGLE_COMPONENT() then BVariable.hash(comp.var) + Equation.hash(comp.eqn);
case MULTI_COMPONENT() then Equation.hash(comp.eqn);
case MULTI_COMPONENT() then Equation.hash(Slice.getT(comp.eqn));
case SLICED_COMPONENT() then ComponentRef.hash(comp.var_cref) + Equation.hash(Slice.getT(comp.eqn));
case GENERIC_COMPONENT() then Equation.hash(Slice.getT(comp.eqn));
case ENTWINED_COMPONENT() then sum(hash(sub_comp) for sub_comp in comp.entwined_slices);
Expand All @@ -298,7 +296,7 @@ public
algorithm
b := match(comp1, comp2)
case (SINGLE_COMPONENT(), SINGLE_COMPONENT()) then BVariable.equalName(comp1.var, comp2.var) and Equation.isEqualPtr(comp1.eqn, comp2.eqn);
case (MULTI_COMPONENT(), MULTI_COMPONENT()) then Equation.isEqualPtr(comp1.eqn, comp2.eqn);
case (MULTI_COMPONENT(), MULTI_COMPONENT()) then Equation.isEqualPtr(Slice.getT(comp1.eqn), Slice.getT(comp2.eqn));
case (SLICED_COMPONENT(), SLICED_COMPONENT()) then ComponentRef.isEqual(comp1.var_cref, comp2.var_cref) and Slice.isEqual(comp1.eqn, comp2.eqn, Equation.isEqualPtr);
case (GENERIC_COMPONENT(), GENERIC_COMPONENT()) then Slice.isEqual(comp1.eqn, comp2.eqn, Equation.isEqualPtr);
case (ENTWINED_COMPONENT(), ENTWINED_COMPONENT()) then List.isEqualOnTrue(comp1.entwined_slices, comp2.entwined_slices, isEqual);
Expand Down Expand Up @@ -477,7 +475,7 @@ public
algorithm
eqn := match comp
case SINGLE_COMPONENT(status = NBSolve.Status.EXPLICIT) then comp.eqn;
case MULTI_COMPONENT(status = NBSolve.Status.EXPLICIT) then comp.eqn;
case MULTI_COMPONENT(status = NBSolve.Status.EXPLICIT) then Slice.getT(comp.eqn);
case SLICED_COMPONENT(status = NBSolve.Status.EXPLICIT) then Slice.getT(comp.eqn);
case GENERIC_COMPONENT() then Slice.getT(comp.eqn);
else algorithm
Expand Down Expand Up @@ -529,11 +527,11 @@ public
then ();

case MULTI_COMPONENT() algorithm
dependencies := Equation.collectCrefs(Pointer.access(comp.eqn), function Slice.getDependentCrefCausalized(set = set));
dependencies := Equation.collectCrefs(Pointer.access(Slice.getT(comp.eqn)), function Slice.getDependentCrefCausalized(set = set));
dependencies := list(ComponentRef.stripIteratorSubscripts(dep) for dep in dependencies);
dependencies := List.flatten(list(ComponentRef.scalarizeAll(dep) for dep in dependencies));
for var in comp.vars loop
for cref in ComponentRef.scalarizeAll(BVariable.getVarName(var)) loop
for cref in ComponentRef.scalarizeAll(BVariable.getVarName(Slice.getT(var))) loop
updateDependencyMap(cref, dependencies, map, jacType);
end for;
end for;
Expand Down Expand Up @@ -683,7 +681,7 @@ public
algorithm
vars := match comp
case SINGLE_COMPONENT() then {comp.var};
case MULTI_COMPONENT() then comp.vars;
case MULTI_COMPONENT() then list(Slice.getT(v) for v in comp.vars);
case SLICED_COMPONENT() then {Slice.getT(comp.var)};
case ENTWINED_COMPONENT() then List.flatten(list(getVariables(slice) for slice in comp.entwined_slices));
case ALGEBRAIC_LOOP() then Tearing.getResidualVars(comp.strict); // + inner?
Expand All @@ -705,7 +703,7 @@ public
algorithm
b := match comp
case SINGLE_COMPONENT() then Equation.isDiscrete(comp.eqn);
case MULTI_COMPONENT() then Equation.isDiscrete(comp.eqn);
case MULTI_COMPONENT() then Equation.isDiscrete(Slice.getT(comp.eqn));
case SLICED_COMPONENT() then Equation.isDiscrete(Slice.getT(comp.eqn));
case ENTWINED_COMPONENT() then List.all(list(isDiscrete(c) for c in comp.entwined_slices), bool_ident);
case GENERIC_COMPONENT() then Equation.isDiscrete(Slice.getT(comp.eqn));
Expand Down Expand Up @@ -769,9 +767,9 @@ public
else
// case 2: just create a single strong component
comp := match Pointer.access(eqn)
case Equation.WHEN_EQUATION() then MULTI_COMPONENT({var}, eqn, NBSolve.Status.UNPROCESSED);
case Equation.IF_EQUATION() then MULTI_COMPONENT({var}, eqn, NBSolve.Status.UNPROCESSED);
case Equation.ALGORITHM() then MULTI_COMPONENT({var}, eqn, NBSolve.Status.UNPROCESSED);
case Equation.WHEN_EQUATION() then MULTI_COMPONENT({Slice.SLICE(var, {})}, Slice.SLICE(eqn, {}), NBSolve.Status.UNPROCESSED);
case Equation.IF_EQUATION() then MULTI_COMPONENT({Slice.SLICE(var, {})}, Slice.SLICE(eqn, {}), NBSolve.Status.UNPROCESSED);
case Equation.ALGORITHM() then MULTI_COMPONENT({Slice.SLICE(var, {})}, Slice.SLICE(eqn, {}), NBSolve.Status.UNPROCESSED);
else SINGLE_COMPONENT(var, eqn, NBSolve.Status.UNPROCESSED);
end match;
end if;
Expand Down Expand Up @@ -803,8 +801,8 @@ public
// getting to this point is an actual algebraic loop
case (_, {eqn_slice}) guard(not Equation.isForEquation(Slice.getT(eqn_slice)))
then MULTI_COMPONENT(
vars = list(Slice.getT(comp_var) for comp_var in comp_vars),
eqn = Slice.getT(eqn_slice),
vars = comp_vars,
eqn = eqn_slice,
status = NBSolve.Status.UNPROCESSED
);

Expand Down
6 changes: 0 additions & 6 deletions OMCompiler/Compiler/NBackEnd/Modules/1_Main/NBSorting.mo
Original file line number Diff line number Diff line change
Expand Up @@ -487,12 +487,6 @@ public
Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed because of unknown adjacency matrix type."});
then fail();
end match;
/*
print(Adjacency.Matrix.toString(adj, "before"));
print(Matching.toString(matching, "before"));
print(Adjacency.Matrix.toString(phase2_adj, "after"));
print(Matching.toString(phase2_matching, "after"));
*/
end create;

function collapse
Expand Down
68 changes: 43 additions & 25 deletions OMCompiler/Compiler/NBackEnd/Modules/3_Post/NBSolve.mo
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,8 @@ public
then ({StrongComponent.SINGLE_COMPONENT(comp.var, Pointer.create(eqn), solve_status)}, solve_status);

case StrongComponent.MULTI_COMPONENT() algorithm
(eqn, funcTree, solve_status, implicit_index) := solveMultiStrongComponent(Pointer.access(comp.eqn), comp.vars, funcTree, systemType, implicit_index, slicing_map);
then ({StrongComponent.MULTI_COMPONENT(comp.vars, Pointer.create(eqn), solve_status)}, solve_status);
(eqn_slice, funcTree, solve_status, implicit_index) := solveMultiStrongComponent(comp.eqn, comp.vars, funcTree, systemType, implicit_index, slicing_map);
then ({StrongComponent.MULTI_COMPONENT(comp.vars, eqn_slice, solve_status)}, solve_status);

case StrongComponent.ALGEBRAIC_LOOP(strict = strict) algorithm
for inner_comp in listReverse(arrayList(strict.innerEquations)) loop
Expand Down Expand Up @@ -403,32 +403,37 @@ public
end solveSingleStrongComponent;

function solveMultiStrongComponent
input output Equation eqn;
input list<Pointer<Variable>> vars;
input output Slice<EquationPointer> eqn_slice;
input list<Slice<VariablePointer>> var_slices;
input output FunctionTree funcTree;
input SystemType systemType;
output Status status;
input output Integer implicit_index;
input UnorderedMap<ComponentRef, list<Pointer<Equation>>> slicing_map;
algorithm
(eqn, funcTree, status) := match eqn
protected
Equation eqn = Pointer.access(Slice.getT(eqn_slice));
algorithm
(eqn_slice, funcTree, status) := match eqn
local
list<Pointer<Variable>> vars = list(Slice.getT(v) for v in var_slices);
Equation solved_eqn;
IfEquationBody if_body;
Expression lhs, rhs;
list<Option<Pointer<Variable>>> record_parents;
Pointer<Variable> parent;
ComponentRef var_cref;
UnorderedSet<ComponentRef> record_crefs;

case Equation.IF_EQUATION() algorithm
(if_body, funcTree, status, implicit_index) := solveIfBody(eqn.body, VariablePointers.fromList(vars), funcTree, systemType, implicit_index, slicing_map);
eqn.body := if_body;
then (eqn, funcTree, status);
then (Slice.SLICE(Pointer.create(eqn), eqn_slice.indices), funcTree, status);

// ToDo: inverse algorithms
case Equation.ALGORITHM() then (eqn, funcTree, Status.EXPLICIT);
case Equation.ALGORITHM()
then (Slice.SLICE(Pointer.clone(Slice.getT(eqn_slice)), eqn_slice.indices), funcTree, Status.EXPLICIT);

// for now assume they are solved
case Equation.WHEN_EQUATION() then (eqn, funcTree, Status.EXPLICIT);
case Equation.WHEN_EQUATION()
then (Slice.SLICE(Pointer.clone(Slice.getT(eqn_slice)), eqn_slice.indices), funcTree, Status.EXPLICIT);

// solve tuple equations
case Equation.RECORD_EQUATION() algorithm
Expand All @@ -442,21 +447,26 @@ public
then (eqn, Status.EXPLICIT);
else algorithm
// check if all belong to the same record
record_parents := list(BVariable.getParent(var) for var in vars);
solved_eqn := match UnorderedSet.unique_list(record_parents, function Util.optionHash(inFunc = BVariable.hash), function Util.optionEqual(inFunc = BVariable.equalName))
case {SOME(parent)} algorithm
(solved_eqn, funcTree, status, _) := solveBody(eqn, BVariable.getVarName(parent), funcTree);
record_crefs := UnorderedSet.new(ComponentRef.hash, ComponentRef.isEqual);
for var_slice in var_slices loop
(var_cref, status) := getVarSlice(BVariable.getVarName(Slice.getT(var_slice)), eqn);
UnorderedSet.add(var_cref, record_crefs);
if status == Status.UNSOLVABLE then break; end if;
end for;

solved_eqn := match (UnorderedSet.toList(record_crefs), status)
case ({var_cref}, Status.UNPROCESSED) algorithm
(solved_eqn, funcTree, status, _) := solveBody(eqn, var_cref, funcTree);
then solved_eqn;
else algorithm
status := Status.IMPLICIT;
then eqn;
else eqn;
end match;

then (solved_eqn, status);
end match;
then (solved_eqn, funcTree, status);
then (Slice.SLICE(Pointer.create(solved_eqn), eqn_slice.indices), funcTree, status);

else algorithm
Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for equation:\n" + Equation.toString(eqn)});
Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for equation:\n" + Slice.toString(eqn_slice, function Equation.pointerToString(str = ""))});
then fail();
end match;
end solveMultiStrongComponent;
Expand All @@ -474,14 +484,15 @@ public
(eqn, funcTree, status, invertRelation) := match eqn
local
Equation body;
Slice<EquationPointer> body_slice;
Pointer<Variable> indexed_var;

// For equations are expected to only have one body equation at this point
case Equation.FOR_EQUATION(body = {body as Equation.IF_EQUATION()}) algorithm
// create indexed variable to trick matching algorithm to solve for it
indexed_var := BVariable.makeVarPtrCyclic(BVariable.getVar(cref), cref);
(body, funcTree, status, implicit_index) := solveMultiStrongComponent(body, {indexed_var}, funcTree, systemType, implicit_index, slicing_map);
eqn.body := {body};
(body_slice, funcTree, status, implicit_index) := solveMultiStrongComponent(Slice.SLICE(Pointer.create(body), {}), {Slice.SLICE(indexed_var, {})}, funcTree, systemType, implicit_index, slicing_map);
eqn.body := {Pointer.access(Slice.getT(body_slice))};
then (eqn, funcTree, status, false);

case Equation.FOR_EQUATION(body = {body}) algorithm
Expand Down Expand Up @@ -759,16 +770,23 @@ protected
output Status solve_status;
protected
list<ComponentRef> slices_lst;
Option<Pointer<Variable>> record_parent;
algorithm
slices_lst := Equation.collectCrefs(eqn, function Slice.getSliceCandidates(name = var_cref));

if listLength(slices_lst) == 1 then
var_cref := List.first(slices_lst);
solve_status := Status.UNPROCESSED;
else
// todo: choose best slice of list if more than one.
// only fail for listLength == 0
solve_status := Status.UNSOLVABLE;
// check if the record parents occur (todo: vice versa?)
record_parent := BVariable.getParent(BVariable.getVarPointer(var_cref));
if Util.isSome(record_parent) then
(var_cref, solve_status) := getVarSlice(BVariable.getVarName(Util.getOption(record_parent)), eqn);
else
// todo: choose best slice of list if more than one.
// only fail for listLength == 0
solve_status := Status.UNSOLVABLE;
end if;
end if;
end getVarSlice;

Expand Down
2 changes: 1 addition & 1 deletion OMCompiler/Compiler/NBackEnd/Modules/3_Post/NBTearing.mo
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ public
case StrongComponent.MULTI_COMPONENT() algorithm
new_comp := StrongComponent.ALGEBRAIC_LOOP(
idx = index,
strict = singleImplicit(List.first(comp.vars), comp.eqn), // this is wrong! need to take all vars
strict = singleImplicit(Slice.getT(List.first(comp.vars)), Slice.getT(comp.eqn)), // this is wrong! need to take all vars
casual = NONE(),
linear = false,
mixed = false,
Expand Down
10 changes: 5 additions & 5 deletions OMCompiler/Compiler/NBackEnd/Util/NBDifferentiate.mo
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ public
local
Pointer<Variable> new_var;
Pointer<Equation> new_eqn;
list<Pointer<Variable>> new_vars;
list<Slice<VariablePointer>> new_var_slices;
list<Pointer<Equation>> new_eqns;
ComponentRef new_cref;
Slice<VariablePointer> new_var_slice;
Expand All @@ -151,10 +151,10 @@ public
then StrongComponent.SINGLE_COMPONENT(new_var, new_eqn, comp.status);

case StrongComponent.MULTI_COMPONENT() algorithm
new_vars := list(differentiateVariablePointer(var, diffArguments_ptr) for var in comp.vars);
new_eqn := differentiateEquationPointer(comp.eqn, diffArguments_ptr, name);
Equation.createName(new_eqn, idx, context);
then StrongComponent.MULTI_COMPONENT(new_vars, new_eqn, comp.status);
new_var_slices := list(Slice.apply(var, function differentiateVariablePointer(diffArguments_ptr = diffArguments_ptr)) for var in comp.vars);
new_eqn_slice := Slice.apply(comp.eqn, function differentiateEquationPointer(diffArguments_ptr = diffArguments_ptr, name = name));
Equation.createName(Slice.getT(new_eqn_slice), idx = idx, context = context);
then StrongComponent.MULTI_COMPONENT(new_var_slices, new_eqn_slice, comp.status);

case StrongComponent.SLICED_COMPONENT() algorithm
(Expression.CREF(cref = new_cref), diffArguments) := differentiateComponentRef(Expression.fromCref(comp.var_cref), Pointer.access(diffArguments_ptr));
Expand Down
2 changes: 1 addition & 1 deletion OMCompiler/Compiler/NBackEnd/Util/NBSlice.mo
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ public
String sliceStr;
algorithm
str := func(slice.t);
if maxLength > 0 then
if maxLength > 0 and not listEmpty(slice.indices) then
str := str + "\n\t slice: " + List.toString(inList = slice.indices, inPrintFunc = intString, maxLength = 10);
end if;
end toString;
Expand Down
2 changes: 1 addition & 1 deletion OMCompiler/Compiler/NSimCode/NSimStrongComponent.mo
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,7 @@ public
then (tmp, getIndex(tmp));

case StrongComponent.MULTI_COMPONENT() algorithm
(tmp, simCodeIndices) := createEquation(NBVariable.DUMMY_VARIABLE, Pointer.access(comp.eqn), comp.status, simCodeIndices, systemType, simcode_map, equation_map);
(tmp, simCodeIndices) := createEquation(NBVariable.DUMMY_VARIABLE, Pointer.access(Slice.getT(comp.eqn)), comp.status, simCodeIndices, systemType, simcode_map, equation_map);
then (tmp, getIndex(tmp));

case StrongComponent.SLICED_COMPONENT() guard(Equation.isForEquation(Slice.getT(comp.eqn))) algorithm
Expand Down
4 changes: 4 additions & 0 deletions OMCompiler/Compiler/Util/Pointer.mo
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ static inline void* pointerAccess(void *ptr)
");
end access;

function clone
input output Pointer<T> mutable = create(access(mutable));
end clone;

function apply
input output Pointer<T> mutable;
input Func func;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ val(x[3],1);
// x[4]
// ### Equation:
// [SCAL] (1) x[4] = $FUN_2 ($RES_SIM_0)
// slice: {}
//
// BLOCK 3: Sliced Component (status = Solve.UNPROCESSED)
// --------------------------------------------------------
Expand Down Expand Up @@ -103,7 +102,6 @@ val(x[3],1);
// x[4]
// ### Equation:
// [SCAL] (1) x[4] = $FUN_2 ($RES_SIM_0)
// slice: {}
//
// BLOCK 3: Sliced Component (status = Solve.UNPROCESSED)
// --------------------------------------------------------
Expand Down Expand Up @@ -161,7 +159,6 @@ val(x[3],1);
// x[4]
// ### Equation:
// [SCAL] (1) x[4] = $FUN_2 ($RES_SIM_0)
// slice: {}
//
// BLOCK 3: Sliced Component (status = Solve.UNPROCESSED)
// --------------------------------------------------------
Expand Down Expand Up @@ -221,7 +218,6 @@ val(x[3],1);
// x[4]
// ### Equation:
// [SCAL] (1) x[4] = $FUN_2 ($RES_SIM_0)
// slice: {}
//
// --- Alias of INI[1 | 3] ---
// BLOCK 3: Generic Component (status = Solve.EXPLICIT)
Expand Down Expand Up @@ -282,7 +278,6 @@ val(x[3],1);
// x[4]
// ### Equation:
// [SCAL] (1) x[4] = $FUN_2 ($RES_SIM_0)
// slice: {}
//
// BLOCK 3: Generic Component (status = Solve.EXPLICIT)
// ------------------------------------------------------
Expand Down
Loading

0 comments on commit ac439f6

Please sign in to comment.