Skip to content

Commit

Permalink
Convert addition to a more formal fold.
Browse files Browse the repository at this point in the history
  • Loading branch information
corywalker committed Jul 16, 2017
1 parent 81aab97 commit fbd9f3a
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 126 deletions.
197 changes: 77 additions & 120 deletions expreduce/builtin_arithmetic.go
Expand Up @@ -31,6 +31,70 @@ func RationalAssertion(num Ex, den Ex) (r *Rational, isR bool) {
return NewRational(numInt.Val, denInt.Val), true
}

func typedRealPart(i *Integer, r *Rational, f *Flt) Ex {
if f != nil {
toReturn := f
if r != nil {
toReturn.AddR(r)
}
if i != nil {
toReturn.AddI(i)
}
return toReturn
}
if r != nil {
toReturn := r
if i != nil {
toReturn.AddI(i)
}
return toReturn
}
if i != nil {
return i
}
return nil
}

func computeRealPart(e *Expression) (Ex, int) {
var foldedInt *Integer
var foldedRat *Rational
var foldedFlt *Flt
for i := 1; i < len(e.Parts); i++ {
asInt, isInt := e.Parts[i].(*Integer)
if isInt {
if foldedInt == nil {
// Try deepcopy if problems. I think this does not cause
// problems now because we will only modify the value if we end
// up creating an entirely new expression.
foldedInt = asInt.DeepCopy().(*Integer)
continue
}
foldedInt.AddI(asInt)
continue
}
asRat, isRat := e.Parts[i].(*Rational)
if isRat {
if foldedRat == nil {
foldedRat = asRat.DeepCopy().(*Rational)
continue
}
foldedRat.AddR(asRat)
continue
}
asFlt, isFlt := e.Parts[i].(*Flt)
if isFlt {
if foldedFlt == nil {
foldedFlt = asFlt.DeepCopy().(*Flt)
continue
}
foldedFlt.AddF(asFlt)
continue
}
return typedRealPart(foldedInt, foldedRat, foldedFlt), i
}
return typedRealPart(foldedInt, foldedRat, foldedFlt), -1
}

func getArithmeticDefinitions() (defs []Definition) {
defs = append(defs, Definition{
Name: "Plus",
Expand All @@ -51,133 +115,26 @@ func getArithmeticDefinitions() (defs []Definition) {
return &Integer{big.NewInt(0)}
}

addends := this.Parts[1:len(this.Parts)]
// If this expression contains any floats, convert everything possible to
// a float
if ExArrayContainsFloat(addends) {
for i, e := range addends {
subint, isint := e.(*Integer)
subrat, israt := e.(*Rational)
if isint {
newfloat := big.NewFloat(0)
newfloat.SetInt(subint.Val)
addends[i] = &Flt{newfloat}
} else if israt {
num := big.NewFloat(0)
den := big.NewFloat(0)
newquo := big.NewFloat(0)
num.SetInt(subrat.Num)
den.SetInt(subrat.Den)
newquo.Quo(num, den)
addends[i] = &Flt{newquo}
}
}
}

// Accumulate floating point values towards the end of the expression
var lastf *Flt = nil
for _, e := range addends {
f, ok := e.(*Flt)
if ok {
if lastf != nil {
f.Val.Add(f.Val, lastf.Val)
lastf.Val = big.NewFloat(0)
}
lastf = f
}
}

if len(addends) == 1 {
f, fOk := addends[0].(*Flt)
if fOk {
if f.Val.Cmp(big.NewFloat(0)) == 0 {
return f
}
}
i, iOk := addends[0].(*Integer)
if iOk {
if i.Val.Cmp(big.NewInt(0)) == 0 {
return i
}
}
}

// Remove zero Floats
for i := len(addends) - 1; i >= 0; i-- {
f, ok := addends[i].(*Flt)
if ok && f.Val.Cmp(big.NewFloat(0)) == 0 && len(addends) > 1 {
addends[i] = addends[len(addends)-1]
addends[len(addends)-1] = nil
addends = addends[:len(addends)-1]
}
}

// Accumulate integer values towards the end of the expression
var lasti *Integer = nil
for _, e := range addends {
i, ok := e.(*Integer)
if ok {
if lasti != nil {
i.Val.Add(i.Val, lasti.Val)
lasti.Val = big.NewInt(0)
}
lasti = i
}
}

// Accumulate rational values towards the end of the expression
var lastr *Rational = nil
for _, e := range addends {
therat, ok := e.(*Rational)
if ok {
if lastr != nil {
tmp := big.NewInt(0)
// lastrNum/lastrDen + theratNum/theratDen // Together
tmp.Mul(therat.Den, lastr.Num)
therat.Den.Mul(therat.Den, lastr.Den)
therat.Num.Mul(therat.Num, lastr.Den)
therat.Num.Add(therat.Num, tmp)
lastr.Num = big.NewInt(0)
lastr.Den = big.NewInt(1)
}
lastr = therat
}
}

// If there is one Integer and one Rational left, merge the Integer into
// the Rational
if lasti != nil && lastr != nil {
lasti.Val.Mul(lasti.Val, lastr.Den)
lastr.Num.Add(lastr.Num, lasti.Val)
lasti.Val = big.NewInt(0)
}

// Remove zero Integers and Rationals
for i := len(addends) - 1; i >= 0; i-- {
toRemove := false
theint, isInt := addends[i].(*Integer)
if isInt {
toRemove = theint.Val.Cmp(big.NewInt(0)) == 0
}
therat, isRat := addends[i].(*Rational)
if isRat {
toRemove = therat.Num.Cmp(big.NewInt(0)) == 0 && therat.Den.Cmp(big.NewInt(1)) == 0
res := this
realPart, symStart := computeRealPart(this)
if realPart != nil {
if symStart == -1 {
return realPart
}
if toRemove && len(addends) > 1 {
addends[i] = addends[len(addends)-1]
addends[len(addends)-1] = nil
addends = addends[:len(addends)-1]
res = NewExpression([]Ex{&Symbol{"Plus"}})
rAsInt, rIsInt := realPart.(*Integer)
if !(rIsInt && rAsInt.Val.Cmp(big.NewInt(0)) == 0) {
res.Parts = append(res.Parts, realPart)
}
res.Parts = append(res.Parts, this.Parts[symStart:]...)
}

// If one expression remains, replace this Plus with the expression
if len(addends) == 1 {
return addends[0]
if len(res.Parts) == 2 {
return res.Parts[1]
}

this.Parts = this.Parts[0:1]
this.Parts = append(this.Parts, addends...)
return this
return res
},
SimpleExamples: []TestInstruction{
&SameTest{"2", "1 + 1"},
Expand Down
16 changes: 10 additions & 6 deletions expreduce/ex_expression.go
Expand Up @@ -124,12 +124,6 @@ func (this *Expression) Eval(es *EvalState) Ex {
return curr
}

currStr := ""
started := int64(0)
if es.isProfiling {
currStr = currEx.String()
started = time.Now().UnixNano()
}
if *printevals {
fmt.Printf("Evaluating %v.\n", currEx)
}
Expand All @@ -152,6 +146,15 @@ func (this *Expression) Eval(es *EvalState) Ex {
return toReturn
}

currStr := ""
currHeadStr := ""
started := int64(0)
if es.isProfiling {
currStr = curr.String()
currHeadStr = curr.Parts[0].String()
started = time.Now().UnixNano()
}

// Start by evaluating each argument
headSym, headIsSym := &Symbol{}, false
attrs := Attributes{}
Expand Down Expand Up @@ -264,6 +267,7 @@ func (this *Expression) Eval(es *EvalState) Ex {
if es.isProfiling {
elapsed := float64(time.Now().UnixNano() - started) / 1000000000
es.timeCounter.AddTime(CounterGroupEvalTime, currStr, elapsed)
es.timeCounter.AddTime(CounterGroupHeadEvalTime, currHeadStr, elapsed)
}
}
curr, isExpr := currEx.(*Expression)
Expand Down
10 changes: 10 additions & 0 deletions expreduce/ex_integer.go
Expand Up @@ -59,3 +59,13 @@ func (this *Integer) Hash(h *hash.Hash64) {
bytes, _ := this.Val.MarshalText()
(*h).Write(bytes)
}

func (this *Integer) AsBigFloat() *big.Float {
newfloat := big.NewFloat(0)
newfloat.SetInt(this.Val)
return newfloat
}

func (this *Integer) AddI(i *Integer) {
this.Val.Add(this.Val, i.Val)
}
25 changes: 25 additions & 0 deletions expreduce/ex_rational.go
Expand Up @@ -103,3 +103,28 @@ func (this *Rational) Hash(h *hash.Hash64) {
dBytes, _ := this.Den.MarshalText()
(*h).Write(dBytes)
}

func (this *Rational) AsBigFloat() *big.Float {
num := big.NewFloat(0)
den := big.NewFloat(0)
newquo := big.NewFloat(0)
num.SetInt(this.Num)
den.SetInt(this.Den)
newquo.Quo(num, den)
return newquo
}

func (this *Rational) AddI(i *Integer) {
tmp := big.NewInt(0)
tmp.Mul(i.Val, this.Den)
this.Num.Add(this.Num, tmp)
}

func (this *Rational) AddR(r *Rational) {
tmp := big.NewInt(0)
// lastrNum/lastrDen + theratNum/theratDen // Together
tmp.Mul(this.Den, r.Num)
this.Den.Mul(this.Den, r.Den)
this.Num.Mul(this.Num, r.Den)
this.Num.Add(this.Num, tmp)
}
12 changes: 12 additions & 0 deletions expreduce/ex_real.go
Expand Up @@ -73,3 +73,15 @@ func (this *Flt) Hash(h *hash.Hash64) {
bytes, _ := this.Val.MarshalText()
(*h).Write(bytes)
}

func (this *Flt) AddI(i *Integer) {
this.Val.Add(this.Val, i.AsBigFloat())
}

func (this *Flt) AddR(r *Rational) {
this.Val.Add(this.Val, r.AsBigFloat())
}

func (this *Flt) AddF(f *Flt) {
this.Val.Add(this.Val, f.Val)
}
7 changes: 7 additions & 0 deletions expreduce/time_counter.go
Expand Up @@ -62,18 +62,21 @@ const (
CounterGroupDefTime CounterGroupType = iota + 1
CounterGroupLhsDefTime
CounterGroupEvalTime
CounterGroupHeadEvalTime
)

type TimeCounterGroup struct {
defTimeCounter TimeCounter
lhsDefTimeCounter TimeCounter
evalTimeCounter TimeCounter
headEvalTimeCounter TimeCounter
}

func (tcg *TimeCounterGroup) Init() {
tcg.defTimeCounter.Init()
tcg.lhsDefTimeCounter.Init()
tcg.evalTimeCounter.Init()
tcg.headEvalTimeCounter.Init()
}

func (tcg *TimeCounterGroup) AddTime(counter CounterGroupType, key string, elapsed float64) {
Expand All @@ -83,19 +86,23 @@ func (tcg *TimeCounterGroup) AddTime(counter CounterGroupType, key string, elaps
tcg.lhsDefTimeCounter.AddTime(key, elapsed)
} else if counter == CounterGroupEvalTime {
tcg.evalTimeCounter.AddTime(key, elapsed)
} else if counter == CounterGroupHeadEvalTime {
tcg.headEvalTimeCounter.AddTime(key, elapsed)
}
}

func (tcg *TimeCounterGroup) Update(other *TimeCounterGroup) {
tcg.defTimeCounter.Update(&other.defTimeCounter)
tcg.lhsDefTimeCounter.Update(&other.lhsDefTimeCounter)
tcg.evalTimeCounter.Update(&other.evalTimeCounter)
tcg.headEvalTimeCounter.Update(&other.headEvalTimeCounter)
}

func (tcg *TimeCounterGroup) String() string {
var buffer bytes.Buffer
buffer.WriteString(tcg.defTimeCounter.String() + "\n")
buffer.WriteString(tcg.lhsDefTimeCounter.String() + "\n")
buffer.WriteString(tcg.evalTimeCounter.String() + "\n")
buffer.WriteString(tcg.headEvalTimeCounter.String() + "\n")
return buffer.String()
}

0 comments on commit fbd9f3a

Please sign in to comment.