Skip to content

Commit

Permalink
Optimize.m: Fix up CSE for vectorised code
Browse files Browse the repository at this point in the history
  • Loading branch information
eschnett authored and ianhinder committed May 27, 2011
1 parent 9c11d83 commit b4ac7ab
Showing 1 changed file with 40 additions and 7 deletions.
47 changes: 40 additions & 7 deletions Tools/CodeGen/Optimize.m
Expand Up @@ -26,6 +26,9 @@

Begin["`Private`"];

CSEPrint[___] = null;
(* CSEPrint = Print; *)

Options[EliminateCommonSubexpressions] = ThornOptions;
EliminateCommonSubexpressions[calc_List, OptionsPattern[]] :=
Module[{eqs, shorts, name, pdDefs, derivs, newShorts, newEqs, allShorts, newCalc},
Expand Down Expand Up @@ -55,9 +58,14 @@
];

cse[eqs_, v_, exceptions_, minSaving_:0] :=
Module[{subexprs, replacements, newEqs, defs, newDefs, i, relabelVars, allEqs, sortedEqs, newVars},
Module[{subexprs, replacements, replace, newEqs, defs, newDefs, i, relabelVars, allEqs, sortedEqs, newVars},
(* Find all possible subexpressions and how many times they occur *)
subexprs = Tally[Reap[Scan[If[! AtomQ[#], Sow[#]] &, eqs[[All,2]], Infinity]][[2, 1]]];
CSEPrint["CSE"];
CSEPrint["CSE: eqs=", eqs];
subexprs = Reap[Scan[If[! AtomQ[#], Sow[#]] &, eqs[[All,2]], Infinity]];
CSEPrint["CSE: subexprs=", subexprs];
If[subexprs[[2]]=={}, Return[{{}, eqs}]];
subexprs = Tally[subexprs[[2, 1]]];

(* Discard subexpressions which only appear once *)
subexprs = Select[subexprs, #[[2]] >= 2 &];
Expand All @@ -78,42 +86,60 @@
replacements = Thread[subexprs -> Table[v[i], {i, Length[subexprs]}]];

(* Replace common subexpressions with new variables *)
newEqs = eqs //. replacements;
(* Do not replace certain terms, e.g. the first argument of IfThen. *)
(* newEqs = eqs //. replacements; *)
CSEPrint["CSE: eqs=", eqs];
CSEPrint["CSE: replacements=", replacements];
replace[expr_] := Replace[Switch[expr,
IfThen[_,_,_], IfThen[expr[[1]], replace[expr[[2]]], replace[expr[[3]]]],
(* ToReal[_], ToReal[expr[[1]]], *)
_?AtomQ, expr,
_, Map[replace, expr]],
replacements];
newEqs = FixedPoint[replace, eqs];
CSEPrint["CSE: newEqs=", newEqs];

(* Build up definitions for the new variables *)
defs = Reverse/@replacements;
CSEPrint["CSE: defs=", defs];
For[i = 2, i <= Length[subexprs], i++,
defs[[i,2]] = defs[[i,2]] /. replacements[[1;;i-1]];
];
CSEPrint["CSE: defs=", defs];

(* Select only the definitions which are needed for the new expressions.
This accounts for cases where a subespression appears multiple times,
This accounts for cases where a subexpression appears multiple times,
but always as part of the same larger subexpression. For example, in
expr = Sqrt[(a+b)(a-b)c]+(a+b)(a-b)c+(a+b)d+Sqrt[(a+b)d+(a+b)c];
we would identify the subexpressions
{v[1]->a+b,v[2]->d v[1],v[3]->a-b,v[4]->c v[1] v[3]};
whereas all we really want it to identify is
{v[1]->a+b,v[2]->d v[1],v[4]->(a-b) c v[1]};
and the introduction of v[3] is unnecessary. To achieve this, we only
keep temporary variables which appear in the expression after substition
keep temporary variables which appear in the expression after substitution
or which appear more than once in the definition of the temporary variables.
*)
newDefs = Select[defs, (Count[newEqs, #[[1]], Infinity] > 0) ||
(Count[defs[[All,2]], #[[1]], Infinity] > 1) &];
CSEPrint["CSE: newDefs=", newDefs];

(* Replace any temporaries eliminated by the previous procedure with their definition *)
newDefs = newDefs //. Complement[defs, newDefs];
CSEPrint["CSE: newDefs2=", newDefs];

(* Check we actually have subexpressions to eliminate. Otherwise just return the original expression *)
If[Length[newDefs]==0, Return[{{}, eqs}]];

(* This is our new system of equations *)
allEqs = Join[newDefs, newEqs];
CSEPrint["CSE: allEqs=", allEqs];

sortedEqs = Fold[InsertNewEquation, newEqs, Reverse[newDefs]];
CSEPrint["CSE: sortedEqs=", sortedEqs];

(* Relabel new temporary variables so that they are sequential and C friendly *)
newVars = Select[sortedEqs[[All,1]], MemberQ[newDefs[[All, 1]],#]&];
CSEPrint["CSE: newVars=", newVars];
i = 0;
relabelVars = (# -> Symbol[ToString[v] <> ToString[i++]]) & /@ newVars;

Expand Down Expand Up @@ -148,8 +174,15 @@
];

InsertNewEquation[oldEqs_, newEq_] := Module[{before},
before = Position[oldEqs[[All,2]], newEq[[1]]][[1,1]];
Insert[oldEqs, newEq, before]
CSEPrint["InsertNewEquation oldEqs=", oldEqs, " newEq=", newEq];
(* For some reason, we can be asked to insert an equation that is
not actually needed. This should not be the case. However, handle
it gracefully for now. *)
(* before = Position[oldEqs[[All,2]], newEq[[1]]][[1,1]]; *)
before = Position[oldEqs[[All,2]], newEq[[1]]];
If[before=={},
oldEqs,
Insert[oldEqs, newEq, before[[1,1]]]]
];

End[];
Expand Down

0 comments on commit b4ac7ab

Please sign in to comment.