diff --git a/Tools/CodeGen/KrancScript.m b/Tools/CodeGen/KrancScript.m index 2ddfd3f7..4177c95f 100644 --- a/Tools/CodeGen/KrancScript.m +++ b/Tools/CodeGen/KrancScript.m @@ -50,6 +50,11 @@ (* Print["Full expression is: ", HoldForm[h[args]]]; *) ThrowError["Failed to parse script"]]; +AddTmpCalcs[eqns_] := Module[{tmpeqns={}}, + eqns /. {"TmpVar"[var_],"TmpCalc"[calc_],"Result"[res_]} :> AppendTo[tmpeqns,calc]; + Join[eqns,tmpeqns] +]; + process[thorn:"thorn"[content___]] := Module[ {calcs = {}, name, parameters = {}, variables = {}, temporaries = {}, tensors, kernels, @@ -72,8 +77,14 @@ _, ThrowError["Unrecognised element '"<>Head[el]<>"' in thorn"]], {el, {content}}]; + (* Pull out implicitly defined temporary equations *) + calcs = calcs /. (Equations->eqn_) :> (Equations->AddTmpCalcs[eqn]); + (* Pull out implicitly defined temporary variables *) + calcs = calcs /. {"TmpVar"[var_],"TmpCalc"[calc_],"Result"[res_]} :> ( + AppendTo[temporaries,var]; + res); + tensors = Join[variables,temporaries]; - Print["tensors ==> ",tensors]; kernels = Map[If[AtomQ[#],#,First[#]] &, tensors]; Scan[DefineTensor, kernels]; nonScalars = Cases[tensors, _tensor]; @@ -134,46 +145,39 @@ {b,builtIns}]; process["tensor"["name"[k_],inds_]] := - (Print["Define:","tensor"["name"[k],inds]]; - tensor[ToExpression[If[Names[k] === {}, "Global`"<>k, k]],Sequence@@process[inds]]); + tensor[ToExpression[If[Names[k] === {}, "Global`"<>k, k]],Sequence@@process[inds]]; fprint[x_] := (Print[x//InputForm]; x); (* Simple derivative rules *) TempNum = 1 Clear[PDiv,MakeTemp,dtens]; -(* -PDiv[n1_?NumberQ,inds_] := 0; -PDiv[n1_?NumberQ x2,inds_] := n1 PDiv[x2,inds]; -PDiv[x1_+x2_,inds_] := PDiv[x1,inds]+PDiv[x2,inds]; -PDiv[x1_ x2_,inds_] := PDiv[x1,inds] x2 + PDiv[x2,inds] x1; -PDiv[x1_^n2_?NumberQ,inds_] := n2 x1^(n2-1) PDiv[x1,inds]; -PDiv[x1_,inds_] := PD[x1,inds] -*) -MakeTemp[x1_,inds_] := Module[{tmpnm,inds2}, - tmpnm="Global`tmp"<>ToString[TempNum++]; - Print["indexes$$==",InputForm[inds]];tmpnm[Sequence@@inds]; +MakeTemp[x1_,inds_] := Module[{tmpnm,inds2,tensorAST}, + tmpnm="tmp"<>ToString[TempNum++]; inds2 = inds /. { TensorIndex[n1_,"l"] :> "lower_index"["index_symbol"[n1]], TensorIndex[n1_,"u"] :> "upper_index"["index_symbol"[n1]] } /. List->"indices"; - process["tensor"["name"[tmpnm],inds2]]]; + tensorAST="tensor"["name"[tmpnm],inds2]; + {tmpnm,tensorAST}]; MakeTemp[x1__] := MakeTemp[x1,freesIn[x1]]; -TensorStrQ["tensor"[__]] := True; -TensorStrQ[__] := False; - process["dtensor"["dname"[dname_],inds_,"tensor"[tensor__]]] := ToExpression[dname][process["tensor"[tensor]],Sequence@@process[inds]]; process["dtensor"["dname"["D"],inds_,"tensor"[tensor__]]] := PD[process["tensor"[tensor]],Sequence@@process[inds]]; -process["dtensor"["dname"["D"],inds_,"expr"[tensor__]]] := Module[{xxx,tmp}, +process["dtensor"["dname"["D"],inds_,"expr"[tensor__]]] := Module[{summed,tmp,tmpAST,procexpr}, procexpr =process["expr"[tensor]]; - tmp=MakeTemp[procexpr]; - Print["tmp calc=",tmp,"=",procexpr]; - xxx=makeSumOverDummies[PD[tmp,Sequence@@process[inds]]]; - xxx] + tmparray=MakeTemp[procexpr]; + tmpAST = tmparray[[2]]; + tmp = process[tmpAST]; + tmpnm = tmparray[[1]]; + summed=makeSumOverDummies[PD[tmp,Sequence@@process[inds]]]; + (* Construct the AST for the equation, re-use Kranc machinery *) + precalc="eqn"[tmpAST,"expr"[tensor]]; + calc=process[precalc]; + {"TmpVar"[tmp],"TmpCalc"[calc],"Result"[summed]}] process["dtensor"["dname"["D"], "indices"["lower_index"["index_symbol"["t"]]],"tensor"[tensor__]]] := dot[process["tensor"[tensor]]]; diff --git a/Tools/CodeGen/TensorTools.m b/Tools/CodeGen/TensorTools.m index 20a52a7c..dbedd8bb 100644 --- a/Tools/CodeGen/TensorTools.m +++ b/Tools/CodeGen/TensorTools.m @@ -269,14 +269,12 @@ DefineTensor[T_[inds__], Symmetric[{symInds__}]] := Module[{}, - Print["Defining ",T,"[",inds,"]"]; DefineTensor[T]; AssertSymetricIncreasing[T[inds], symInds]; ] DefineTensor[T_[_, _, _, _], RiemannSymmetric[{_, _, _, _}]] := Module[{}, - Print["Defining ",T,"[...]"]; DefineTensor[T]; Tensor[T, i_, j_, k_, l_] /; i > j := -T[j, i, k, l]; Tensor[T, i_, j_, k_, l_] /; i == j := 0; @@ -292,7 +290,6 @@ DefineTensor[T_] := Module[{}, - Print["Defining ",T,"[?]"]; Format[Tensor[T, is:((TensorIndex[_,_] | _Integer) ..) ], StandardForm] := Row[{T,is}]/.x_Integer->Subscript[null,x];