Skip to content

Commit

Permalink
Calculation.m: Add function to compute derivatives in separate loops …
Browse files Browse the repository at this point in the history
…and store them in grid functions
  • Loading branch information
ianhinder committed Apr 3, 2012
1 parent d73aed9 commit 8498976
Showing 1 changed file with 53 additions and 0 deletions.
53 changes: 53 additions & 0 deletions Tools/CodeGen/Calculation.m
Expand Up @@ -31,6 +31,7 @@
CalculationOnDevice;
GetCalculationWhere;
SplitCalculations;
SeparateDerivatives;

Begin["`Private`"];

Expand Down Expand Up @@ -71,6 +72,11 @@
GetPartialDerivatives[calc_List] :=
lookup[calc,PartialDerivatives]];

DefFn[
GetDerivatives[calc_] :=
GridFunctionDerivativesInExpression[
GetPartialDerivatives[calc], GetEquations[calc],{(* TODO: implement ZeroDimensions here *)}]];

DefFn[
GetCalculationParameters[calc_List] :=
Module[
Expand Down Expand Up @@ -152,6 +158,53 @@
partialCalculation[calc, nameSuffix, {}, splitVars]]],
splitBy]]]];

DefFn[
SeparateDerivatives[calcs_List] :=
Flatten[separateDerivativesInCalculation/@calcs,1]];

(* If the calculation contains a SeparatedDerivatives key, split the
calculation into two. The first one will compute all the
derivatives matching the SeparatedDerivatives pattern and store the
results in grid functions. The second will then use these grid
functions instead of computing the derivatives. *)

separateDerivativesInCalculation[calc_] :=
Module[
{sepPat = lookup[calc,SeparatedDerivatives, None]},
If[sepPat === None, {calc},
If[lookupDefault[calc, Schedule, Automatic] === Automatic,
ThrowError["Separating derivatives in an automatically scheduled function is not supported"]];

Module[
{derivGFName, derivs, sepDerivs, calc2, replaceSymmetric},
derivGFName[pd_[var_,inds___]] :=
Symbol["Global`D"<>ToString[pd]<>ToString[var]<>Apply[StringJoin,Map[ToString,{inds}]]];

replaceSymmetric = pd_[var_,i_,j_] /; i > j :> pd[var,j,i];
derivs = DeleteDuplicates[GetDerivatives[calc] /. replaceSymmetric];
sepDerivs = Flatten[Map[Cases[derivs, #] &, sepPat],1];

derivCalc[sepDeriv_] :=
Module[
{calc1, currentGroups, localGroups, derivName = derivGFName[sepDeriv]},
calc1 = mapReplace[calc,
Equations,
{derivName -> sepDeriv}];
calc1 = mapReplace[calc1, Schedule, Map[#<>" before "<>lookup[calc,Name] &, lookup[calc,Schedule]]];
calc1 = mapReplace[calc1, Name, lookup[calc,Name]<>"_"<>ToString@derivName];
currentGroups = lookup[calc, LocalGroups, {}];
localGroups = Append[currentGroups, {ToString@derivName<>"_group", {derivName}}];
calc1 = mapReplaceAdd[calc1, LocalGroups, localGroups];
calc1];

derivCalcs = Map[derivCalc, sepDerivs];

calc2 = mapReplace[calc,
Equations,
(GetEquations[calc]/.replaceSymmetric) /. Map[# -> derivGFName[#] &, sepDerivs]];

Append[derivCalcs, calc2]]]];

End[];

EndPackage[];

0 comments on commit 8498976

Please sign in to comment.