Skip to content

Commit

Permalink
[NB] correctly differentiate min and max functions (OpenModelica#12601)
Browse files Browse the repository at this point in the history
- safety update: match list outputs instead of just hoping they have the correct length
  • Loading branch information
kabdelhak committed Jun 18, 2024
1 parent 680083d commit 73adf06
Showing 1 changed file with 41 additions and 12 deletions.
53 changes: 41 additions & 12 deletions OMCompiler/Compiler/NBackEnd/Util/NBDifferentiate.mo
Original file line number Diff line number Diff line change
Expand Up @@ -854,7 +854,12 @@ public
// d/dz delay(x, delta) = (dt/dz - d delta/dz) * delay(der(x), delta)
case (Expression.CALL()) guard(name == "delay")
algorithm
{arg1, arg2, arg3} := Call.arguments(exp.call);
(arg1, arg2, arg3) := match Call.arguments(exp.call)
case {arg1, arg2, arg3} then (arg1, arg2, arg3);
else algorithm
Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for: " + Expression.toString(exp) + "."});
then fail();
end match;
// if z = t then dt/dz = 1 else dt/dz = 0
ret1 := Expression.REAL(if diffArguments.diffType == DifferentiationType.TIME then 1.0 else 0.0);
// d delta/dz
Expand All @@ -876,13 +881,12 @@ public
// SMOOTH
case (Expression.CALL()) guard(name == "smooth")
algorithm
{arg1, arg2} := Call.arguments(exp.call);
ret := match arg1
case Expression.INTEGER(i) guard(i > 0) algorithm
ret := match Call.arguments(exp.call)
case {arg1 as Expression.INTEGER(i), arg2} guard(i > 0) algorithm
(ret2, diffArguments) := differentiateExpression(arg2, diffArguments);
exp.call := Call.setArguments(exp.call, {Expression.INTEGER(i-1), ret2});
then exp;
case Expression.INTEGER(i) algorithm
case {arg1 as Expression.INTEGER(i), arg2} algorithm
(ret2, diffArguments) := differentiateExpression(arg2, diffArguments);
exp := Expression.CALL(Call.makeTypedCall(
fn = NFBuiltinFuncs.NO_EVENT,
Expand All @@ -900,15 +904,25 @@ public
// NO_EVENT
case (Expression.CALL()) guard(name == "noEvent")
algorithm
{arg1} := Call.arguments(exp.call);
arg1 := match Call.arguments(exp.call)
case {arg1} then arg1;
else algorithm
Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for: " + Expression.toString(exp) + "."});
then fail();
end match;
(ret1, diffArguments) := differentiateExpression(arg1, diffArguments);
exp.call := Call.setArguments(exp.call, {ret1});
then exp;

// HOMOTOPY
case (Expression.CALL()) guard(name == "homotopy")
// MIN, MAX, HOMOTOPY
case (Expression.CALL()) guard(List.contains({"min", "max", "homotopy"}, name, stringEqual))
algorithm
{arg1, arg2} := Call.arguments(exp.call);
(arg1, arg2) := match Call.arguments(exp.call)
case {arg1, arg2} then (arg1, arg2);
else algorithm
Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for: " + Expression.toString(exp) + "."});
then fail();
end match;
(ret1, diffArguments) := differentiateExpression(arg1, diffArguments);
(ret2, diffArguments) := differentiateExpression(arg2, diffArguments);
exp.call := Call.setArguments(exp.call, {ret1, ret2});
Expand All @@ -927,7 +941,12 @@ public
// d sL(x, m1, m2)/dt = sL(x, dm1/dt, dm2/dt) + dx/dt * if (x>=0) then m1 else m2
case (Expression.CALL()) guard(name == "semiLinear")
algorithm
{arg1, arg2, arg3} := Call.arguments(exp.call);
(arg1, arg2, arg3) := match Call.arguments(exp.call)
case {arg1, arg2, arg3} then (arg1, arg2, arg3);
else algorithm
Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for: " + Expression.toString(exp) + "."});
then fail();
end match;

// dx/dt, dm1/dt, dm2/dt
(diffArg1, diffArguments) := differentiateExpression(arg1, diffArguments);
Expand Down Expand Up @@ -956,8 +975,13 @@ public
// df(y)/dx = df/dy * dy/dx
case (Expression.CALL()) guard(listLength(Call.arguments(exp.call)) == 1)
algorithm
arg1 := match Call.arguments(exp.call)
case {arg1} then arg1;
else algorithm
Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for: " + Expression.toString(exp) + "."});
then fail();
end match;
// differentiate the call
{arg1} := Call.arguments(exp.call);
(ret, diffArguments) := differentiateBuiltinCall1Arg(name, arg1, diffArguments);
if not Expression.isZero(ret) then
// differentiate the argument (inner derivative)
Expand All @@ -970,8 +994,13 @@ public
// df(y,z)/dx = df/dy * dy/dx + df/dz * dz/dx
case (Expression.CALL()) guard(listLength(Call.arguments(exp.call)) == 2)
algorithm
(arg1, arg2) := match Call.arguments(exp.call)
case {arg1, arg2} then (arg1, arg2);
else algorithm
Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for: " + Expression.toString(exp) + "."});
then fail();
end match;
// differentiate the call
{arg1, arg2} := Call.arguments(exp.call);
(ret1, ret2) := differentiateBuiltinCall2Arg(name, arg1, arg2); // df/dy and df/dz
diffArg1 := differentiateExpression(arg1, diffArguments); // dy/dx
diffArg2 := differentiateExpression(arg2, diffArguments); // dz/dx
Expand Down

0 comments on commit 73adf06

Please sign in to comment.