From 73adf0650416ceb8b9d147a9d1353c22daa64b1e Mon Sep 17 00:00:00 2001 From: kabdelhak <38032125+kabdelhak@users.noreply.github.com> Date: Tue, 18 Jun 2024 11:59:28 +0200 Subject: [PATCH] [NB] correctly differentiate min and max functions (#12601) - safety update: match list outputs instead of just hoping they have the correct length --- .../Compiler/NBackEnd/Util/NBDifferentiate.mo | 53 ++++++++++++++----- 1 file changed, 41 insertions(+), 12 deletions(-) diff --git a/OMCompiler/Compiler/NBackEnd/Util/NBDifferentiate.mo b/OMCompiler/Compiler/NBackEnd/Util/NBDifferentiate.mo index a2f9447b78..27e583bf72 100644 --- a/OMCompiler/Compiler/NBackEnd/Util/NBDifferentiate.mo +++ b/OMCompiler/Compiler/NBackEnd/Util/NBDifferentiate.mo @@ -854,7 +854,12 @@ public // d/dz delay(x, delta) = (dt/dz - d delta/dz) * delay(der(x), delta) case (Expression.CALL()) guard(name == "delay") algorithm - {arg1, arg2, arg3} := Call.arguments(exp.call); + (arg1, arg2, arg3) := match Call.arguments(exp.call) + case {arg1, arg2, arg3} then (arg1, arg2, arg3); + else algorithm + Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for: " + Expression.toString(exp) + "."}); + then fail(); + end match; // if z = t then dt/dz = 1 else dt/dz = 0 ret1 := Expression.REAL(if diffArguments.diffType == DifferentiationType.TIME then 1.0 else 0.0); // d delta/dz @@ -876,13 +881,12 @@ public // SMOOTH case (Expression.CALL()) guard(name == "smooth") algorithm - {arg1, arg2} := Call.arguments(exp.call); - ret := match arg1 - case Expression.INTEGER(i) guard(i > 0) algorithm + ret := match Call.arguments(exp.call) + case {arg1 as Expression.INTEGER(i), arg2} guard(i > 0) algorithm (ret2, diffArguments) := differentiateExpression(arg2, diffArguments); exp.call := Call.setArguments(exp.call, {Expression.INTEGER(i-1), ret2}); then exp; - case Expression.INTEGER(i) algorithm + case {arg1 as Expression.INTEGER(i), arg2} algorithm (ret2, diffArguments) := differentiateExpression(arg2, diffArguments); exp := Expression.CALL(Call.makeTypedCall( fn = NFBuiltinFuncs.NO_EVENT, @@ -900,15 +904,25 @@ public // NO_EVENT case (Expression.CALL()) guard(name == "noEvent") algorithm - {arg1} := Call.arguments(exp.call); + arg1 := match Call.arguments(exp.call) + case {arg1} then arg1; + else algorithm + Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for: " + Expression.toString(exp) + "."}); + then fail(); + end match; (ret1, diffArguments) := differentiateExpression(arg1, diffArguments); exp.call := Call.setArguments(exp.call, {ret1}); then exp; - // HOMOTOPY - case (Expression.CALL()) guard(name == "homotopy") + // MIN, MAX, HOMOTOPY + case (Expression.CALL()) guard(List.contains({"min", "max", "homotopy"}, name, stringEqual)) algorithm - {arg1, arg2} := Call.arguments(exp.call); + (arg1, arg2) := match Call.arguments(exp.call) + case {arg1, arg2} then (arg1, arg2); + else algorithm + Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for: " + Expression.toString(exp) + "."}); + then fail(); + end match; (ret1, diffArguments) := differentiateExpression(arg1, diffArguments); (ret2, diffArguments) := differentiateExpression(arg2, diffArguments); exp.call := Call.setArguments(exp.call, {ret1, ret2}); @@ -927,7 +941,12 @@ public // d sL(x, m1, m2)/dt = sL(x, dm1/dt, dm2/dt) + dx/dt * if (x>=0) then m1 else m2 case (Expression.CALL()) guard(name == "semiLinear") algorithm - {arg1, arg2, arg3} := Call.arguments(exp.call); + (arg1, arg2, arg3) := match Call.arguments(exp.call) + case {arg1, arg2, arg3} then (arg1, arg2, arg3); + else algorithm + Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for: " + Expression.toString(exp) + "."}); + then fail(); + end match; // dx/dt, dm1/dt, dm2/dt (diffArg1, diffArguments) := differentiateExpression(arg1, diffArguments); @@ -956,8 +975,13 @@ public // df(y)/dx = df/dy * dy/dx case (Expression.CALL()) guard(listLength(Call.arguments(exp.call)) == 1) algorithm + arg1 := match Call.arguments(exp.call) + case {arg1} then arg1; + else algorithm + Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for: " + Expression.toString(exp) + "."}); + then fail(); + end match; // differentiate the call - {arg1} := Call.arguments(exp.call); (ret, diffArguments) := differentiateBuiltinCall1Arg(name, arg1, diffArguments); if not Expression.isZero(ret) then // differentiate the argument (inner derivative) @@ -970,8 +994,13 @@ public // df(y,z)/dx = df/dy * dy/dx + df/dz * dz/dx case (Expression.CALL()) guard(listLength(Call.arguments(exp.call)) == 2) algorithm + (arg1, arg2) := match Call.arguments(exp.call) + case {arg1, arg2} then (arg1, arg2); + else algorithm + Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for: " + Expression.toString(exp) + "."}); + then fail(); + end match; // differentiate the call - {arg1, arg2} := Call.arguments(exp.call); (ret1, ret2) := differentiateBuiltinCall2Arg(name, arg1, arg2); // df/dy and df/dz diffArg1 := differentiateExpression(arg1, diffArguments); // dy/dx diffArg2 := differentiateExpression(arg2, diffArguments); // dz/dx