Skip to content

Commit

Permalink
Remove the global toStringFns map.
Browse files Browse the repository at this point in the history
  • Loading branch information
corywalker committed Oct 10, 2018
1 parent fd5082e commit 731ea52
Show file tree
Hide file tree
Showing 15 changed files with 79 additions and 53 deletions.
2 changes: 1 addition & 1 deletion expreduce/builtin_pattern.go
Expand Up @@ -11,7 +11,7 @@ func ToStringBlankType(repr string, parts []Ex, params ToStringParams) (bool, st
} else if len(parts) == 2 {
var buffer bytes.Buffer
buffer.WriteString(repr)
buffer.WriteString(parts[1].String())
buffer.WriteString(parts[1].String(params.es))
return true, buffer.String()
}
return false, ""
Expand Down
5 changes: 1 addition & 4 deletions expreduce/cas.go
Expand Up @@ -4,15 +4,12 @@ package expreduce

type ToStringFnType (func(*Expression, ToStringParams) (bool, string))

// A nasty global to keep track of ToString functions. TODO: Fix this.
var toStringFns = make(map[string]ToStringFnType)

// The interface that fundamental types must implement.
type Ex interface {
Eval(es *EvalState) Ex
// TODO(corywalker): Deprecate this function. All stringification should go
// through StringForm.
String() string
String(es *EvalState) string
StringForm(params ToStringParams) string
IsEqual(b Ex, cl *CASLogger) string
DeepCopy() Ex
Expand Down
49 changes: 38 additions & 11 deletions expreduce/cas_test.go
Expand Up @@ -7,6 +7,7 @@ import (
"math/big"
"regexp"
"testing"
"sync"
)

var testmodules = flag.String("testmodules", "",
Expand Down Expand Up @@ -137,9 +138,9 @@ func TestLowLevel(t *testing.T) {

// Test basic float functionality
var f *Flt = NewReal(big.NewFloat(5.5))
assert.Equal(t, "5.5", f.String())
assert.Equal(t, "5.5", f.String(es))
f.Eval(es)
assert.Equal(t, "5.5", f.String())
assert.Equal(t, "5.5", f.String(es))

// Test nested addition functionality
var a = NewExpression([]Ex{
Expand All @@ -163,13 +164,13 @@ func TestLowLevel(t *testing.T) {

// Test evaluation
newa := a.Eval(es)
assert.Equal(t, "87.5", newa.String())
assert.Equal(t, "87.5", newa.String(es))

// Test basic Symbol functionality
var v *Symbol = NewSymbol("System`x")
assert.Equal(t, "x", v.String())
assert.Equal(t, "x", v.String(es))
v.Eval(es)
assert.Equal(t, "x", v.String())
assert.Equal(t, "x", v.String(es))

assert.Equal(t, "(a + b + c + d + e + f)", EasyRun("a + b + c +d +e +f", es))
assert.Equal(t, "(a*b*c*d*e*f)", EasyRun("a * b * c *d *e *f", es))
Expand All @@ -189,6 +190,9 @@ func TestLowLevel(t *testing.T) {
func TestDeepCopy(t *testing.T) {
fmt.Println("Testing deepcopy")

// So that we can print the values. Not very necessary.
es := NewEvalState()

// Test deepcopy
var t1 = NewSymbol("System`x")
t2 := *t1
Expand All @@ -205,12 +209,35 @@ func TestDeepCopy(t *testing.T) {
var t4 = NewReal(big.NewFloat(2))
t5 := *t4
t6 := t4.DeepCopy().(*Flt)
assert.Equal(t, "2.", t4.String())
assert.Equal(t, "2.", t5.String())
assert.Equal(t, "2.", t6.String())
assert.Equal(t, "2.", t4.String(es))
assert.Equal(t, "2.", t5.String(es))
assert.Equal(t, "2.", t6.String(es))
t5.Val.Add(t5.Val, big.NewFloat(2))
t6.Val.Add(t6.Val, big.NewFloat(3))
assert.Equal(t, "4.", t4.String()) // Because we used the wrong copy method
assert.Equal(t, "4.", t5.String())
assert.Equal(t, "5.", t6.String())
assert.Equal(t, "4.", t4.String(es)) // Because we used the wrong copy method
assert.Equal(t, "4.", t5.String(es))
assert.Equal(t, "5.", t6.String(es))
}

func TestConcurrency(t *testing.T) {

fmt.Println("Testing concurrency")

es1 := NewEvalState()
es2 := NewEvalState()

CasAssertSame(t, es1, "3", "1+2")
CasAssertSame(t, es2, "3", "1+2")

var wg sync.WaitGroup

for i := 0; i < 10; i++ {
wg.Add(1)
go func (t *testing.T) {
defer wg.Done()
esTest := NewEvalState()
CasAssertSame(t, esTest, "3", "1+2")
}(t)
}
wg.Wait()
}
16 changes: 9 additions & 7 deletions expreduce/evalstate.go
Expand Up @@ -23,6 +23,7 @@ type EvalState struct {
thrown *Expression
reapSown *Expression
interrupted bool
toStringFns map[string]ToStringFnType
}

func (this *EvalState) Load(def Definition) {
Expand Down Expand Up @@ -58,15 +59,15 @@ func (this *EvalState) Load(def Definition) {
newDef.defaultExpr = Interp(def.Default, this)
}
if def.toString != nil {
// Global so that standard String() interface can access these
toStringFns[def.Name] = def.toString
this.toStringFns[def.Name] = def.toString
}
this.defined[def.Name] = newDef
EvalInterp("$Context = \"System`\"", this)
}

func (es *EvalState) Init(loadAllDefs bool) {
es.defined = make(map[string]Def)
es.toStringFns = make(map[string]ToStringFnType)
// These are fundamental symbols that affect even the parsing of
// expressions. We must define them before even the bootstrap definitions.
es.Define(NewSymbol("System`$Context"), NewString("System`"))
Expand Down Expand Up @@ -292,8 +293,8 @@ func (this *EvalState) GetDef(name string, lhs Ex) (Ex, bool, *Expression) {
defStr, lhsDefStr := "", ""
started := int64(0)
if this.isProfiling {
defStr = def.String()
lhsDefStr = lhs.String() + defStr
defStr = def.String(this)
lhsDefStr = lhs.String(this) + defStr
started = time.Now().UnixNano()
}

Expand Down Expand Up @@ -381,7 +382,7 @@ func (this *EvalState) MarkSeen(name string) {

// Attempts to compute a specificity metric for a rule. Higher specificity rules
// should be tried first.
func ruleSpecificity(lhs Ex, rhs Ex, name string) int {
func ruleSpecificity(lhs Ex, rhs Ex, name string, es *EvalState) int {
if name == "Rubi`Int" {
return 100
}
Expand All @@ -404,7 +405,7 @@ func ruleSpecificity(lhs Ex, rhs Ex, name string) int {
contextPath: contextPath,
// No need for the EvalState reference. Used for string expansion for
// Definition[], which should not be in an actual definition.
es: nil,
es: es,
}
specificity := len(lhs.StringForm(stringParams))
if _, rhsIsCond := HeadAssertion(rhs, "System`Condition"); rhsIsCond {
Expand Down Expand Up @@ -499,13 +500,14 @@ func (this *EvalState) Define(lhs Ex, rhs Ex) {
// Insert into definitions for name. Maintain order of decreasing
// complexity.
var tmp = this.defined[name]
newSpecificity := ruleSpecificity(heldLhs, rhs, name)
newSpecificity := ruleSpecificity(heldLhs, rhs, name, this)
for i, dv := range this.defined[name].downvalues {
if dv.specificity == 0 {
dv.specificity = ruleSpecificity(
dv.rule.Parts[1],
dv.rule.Parts[2],
name,
this,
)
}
if dv.specificity < newSpecificity {
Expand Down
4 changes: 2 additions & 2 deletions expreduce/ex_complex.go
Expand Up @@ -28,10 +28,10 @@ func (this *Complex) StringForm(p ToStringParams) string {
return fmt.Sprintf("(%v + %v*I)", this.Re.StringForm(p), this.Im.StringForm(p))
}

func (this *Complex) String() string {
func (this *Complex) String(es *EvalState) string {
context, contextPath := DefaultStringFormArgs()
return this.StringForm(ToStringParams{
form: "InputForm", context: context, contextPath: contextPath})
form: "InputForm", context: context, contextPath: contextPath, es: es})
}

func (this *Complex) IsEqual(other Ex, cl *CASLogger) string {
Expand Down
10 changes: 5 additions & 5 deletions expreduce/ex_expression.go
Expand Up @@ -200,8 +200,8 @@ func (this *Expression) Eval(es *EvalState) Ex {
currHeadStr := ""
started := int64(0)
if es.isProfiling {
currStr = curr.String()
currHeadStr = curr.Parts[0].String()
currStr = curr.String(es)
currHeadStr = curr.Parts[0].String(es)
started = time.Now().UnixNano()
}

Expand Down Expand Up @@ -426,7 +426,7 @@ func (this *Expression) StringForm(params ToStringParams) string {
if isHeadSym && !fullForm {
res, ok := "", false
headStr := headAsSym.Name
toStringFn, hasToStringFn := toStringFns[headStr]
toStringFn, hasToStringFn := params.es.toStringFns[headStr]
if hasToStringFn {
ok, res = toStringFn(this, params)
}
Expand Down Expand Up @@ -464,10 +464,10 @@ func (this *Expression) StringForm(params ToStringParams) string {
return buffer.String()
}

func (this *Expression) String() string {
func (this *Expression) String(es *EvalState) string {
context, contextPath := DefaultStringFormArgs()
return this.StringForm(ToStringParams{
form: "InputForm", context: context, contextPath: contextPath})
form: "InputForm", context: context, contextPath: contextPath, es: es})
}

func (this *Expression) IsEqual(otherEx Ex, cl *CASLogger) string {
Expand Down
4 changes: 2 additions & 2 deletions expreduce/ex_integer.go
Expand Up @@ -27,9 +27,9 @@ func (i *Integer) StringForm(params ToStringParams) string {
return fmt.Sprintf("%d", i.Val)
}

func (this *Integer) String() string {
func (this *Integer) String(es *EvalState) string {
context, contextPath := DefaultStringFormArgs()
return this.StringForm(ToStringParams{form: "InputForm", context: context, contextPath: contextPath})
return this.StringForm(ToStringParams{form: "InputForm", context: context, contextPath: contextPath, es: es})
}

func (this *Integer) IsEqual(other Ex, cl *CASLogger) string {
Expand Down
4 changes: 2 additions & 2 deletions expreduce/ex_rational.go
Expand Up @@ -64,9 +64,9 @@ func (this *Rational) StringForm(params ToStringParams) string {
return fmt.Sprintf("%d/%d", this.Num, this.Den)
}

func (this *Rational) String() string {
func (this *Rational) String(es *EvalState) string {
context, contextPath := DefaultStringFormArgs()
return this.StringForm(ToStringParams{form: "InputForm", context: context, contextPath: contextPath})
return this.StringForm(ToStringParams{form: "InputForm", context: context, contextPath: contextPath, es: es})
}

func (this *Rational) IsEqual(other Ex, cl *CASLogger) string {
Expand Down
4 changes: 2 additions & 2 deletions expreduce/ex_real.go
Expand Up @@ -34,9 +34,9 @@ func (f *Flt) StringForm(params ToStringParams) string {
return buffer.String()
}

func (this *Flt) String() string {
func (this *Flt) String(es *EvalState) string {
context, contextPath := DefaultStringFormArgs()
return this.StringForm(ToStringParams{form: "InputForm", context: context, contextPath: contextPath})
return this.StringForm(ToStringParams{form: "InputForm", context: context, contextPath: contextPath, es: es})
}

func (this *Flt) IsEqual(other Ex, cl *CASLogger) string {
Expand Down
4 changes: 2 additions & 2 deletions expreduce/ex_string.go
Expand Up @@ -20,9 +20,9 @@ func (this *String) StringForm(params ToStringParams) string {
return fmt.Sprintf("\"%v\"", this.Val)
}

func (this *String) String() string {
func (this *String) String(es *EvalState) string {
context, contextPath := DefaultStringFormArgs()
return this.StringForm(ToStringParams{form: "InputForm", context: context, contextPath: contextPath})
return this.StringForm(ToStringParams{form: "InputForm", context: context, contextPath: contextPath, es: es})
}

func (this *String) IsEqual(other Ex, cl *CASLogger) string {
Expand Down
4 changes: 2 additions & 2 deletions expreduce/ex_symbol.go
Expand Up @@ -50,9 +50,9 @@ func (this *Symbol) StringForm(params ToStringParams) string {
return fmt.Sprintf("%v", this.Name)
}

func (this *Symbol) String() string {
func (this *Symbol) String(es *EvalState) string {
context, contextPath := DefaultStringFormArgs()
return this.StringForm(ToStringParams{form: "InputForm", context: context, contextPath: contextPath})
return this.StringForm(ToStringParams{form: "InputForm", context: context, contextPath: contextPath, es: es})
}

func (this *Symbol) IsEqual(other Ex, cl *CASLogger) string {
Expand Down
8 changes: 4 additions & 4 deletions expreduce/interp_test.go
Expand Up @@ -65,8 +65,8 @@ func TestInterp(t *testing.T) {
//CasAssertSame(t, es, "Plus[Times[2,a],Optional[Pattern[a,Blank[]],5]]", "a + a_ : 5 + a")

// Test newline handling
assert.Equal(t, "CompoundExpression[a, b]", Interp("a;b\n", es).String())
//assert.Equal(t, "Sequence[a, b]", Interp("a\nb\n", es).String())
assert.Equal(t, "(c = a*b)", Interp("c = (a\nb)\n", es).String())
assert.Equal(t, "(c = a*b)", Interp("c = (a\n\nb)\n", es).String())
assert.Equal(t, "CompoundExpression[a, b]", Interp("a;b\n", es).String(es))
//assert.Equal(t, "Sequence[a, b]", Interp("a\nb\n", es).String(es))
assert.Equal(t, "(c = a*b)", Interp("c = (a\nb)\n", es).String(es))
assert.Equal(t, "(c = a*b)", Interp("c = (a\n\nb)\n", es).String(es))
}
4 changes: 2 additions & 2 deletions expreduce/pdmanager.go
Expand Up @@ -49,7 +49,7 @@ func (this *PDManager) Len() int {
return len(this.patternDefined)
}

func (this *PDManager) String() string {
func (this *PDManager) String(es *EvalState) string {
var buffer bytes.Buffer
buffer.WriteString("{")
// We sort the keys here such that converting identical PDManagers always
Expand All @@ -63,7 +63,7 @@ func (this *PDManager) String() string {
v := this.patternDefined[k]
buffer.WriteString(k)
buffer.WriteString("_: ")
buffer.WriteString(v.String())
buffer.WriteString(v.String(es))
buffer.WriteString(", ")
}
if strings.HasSuffix(buffer.String(), ", ") {
Expand Down
6 changes: 3 additions & 3 deletions expreduce/testing.go
Expand Up @@ -64,7 +64,7 @@ type SameTestEx struct {
}

func (this *SameTestEx) Run(t *testing.T, es *EvalState, td TestDesc) bool {
succ, s := CasTestInner(es, this.In.Eval(es), this.Out.Eval(es), this.In.String(), true, td.desc)
succ, s := CasTestInner(es, this.In.Eval(es), this.Out.Eval(es), this.In.String(es), true, td.desc)
assert.True(t, succ, s)
return succ
}
Expand All @@ -81,13 +81,13 @@ func CasTestInner(es *EvalState, inTree Ex, outTree Ex, inStr string, test bool,
context, contextPath := DefinitionComplexityStringFormArgs()
var buffer bytes.Buffer
buffer.WriteString("(")
buffer.WriteString(inTree.StringForm(ToStringParams{form: "InputForm", context: context, contextPath: contextPath}))
buffer.WriteString(inTree.StringForm(ToStringParams{form: "InputForm", context: context, contextPath: contextPath, es: es}))
if test {
buffer.WriteString(") != (")
} else {
buffer.WriteString(") == (")
}
buffer.WriteString(outTree.StringForm(ToStringParams{form: "InputForm", context: context, contextPath: contextPath}))
buffer.WriteString(outTree.StringForm(ToStringParams{form: "InputForm", context: context, contextPath: contextPath, es: es}))
buffer.WriteString(")")
buffer.WriteString("\n\tInput was: ")
buffer.WriteString(inStr)
Expand Down
8 changes: 4 additions & 4 deletions expreduce/utils.go
Expand Up @@ -4,11 +4,11 @@ import (
"bytes"
)

func ExArrayToString(exArray []Ex) string {
func ExArrayToString(exArray []Ex, es *EvalState) string {
var buffer bytes.Buffer
buffer.WriteString("{")
for i, e := range exArray {
buffer.WriteString(e.String())
buffer.WriteString(e.String(es))
if i != len(exArray)-1 {
buffer.WriteString(", ")
}
Expand All @@ -17,11 +17,11 @@ func ExArrayToString(exArray []Ex) string {
return buffer.String()
}

func PFArrayToString(pfArray []parsedForm) string {
func PFArrayToString(pfArray []parsedForm, es *EvalState) string {
var buffer bytes.Buffer
buffer.WriteString("{")
for i, e := range pfArray {
buffer.WriteString(e.origForm.String())
buffer.WriteString(e.origForm.String(es))
if i != len(pfArray)-1 {
buffer.WriteString(", ")
}
Expand Down

0 comments on commit 731ea52

Please sign in to comment.