Skip to content

Commit

Permalink
Changed the definition of TypeOp for a more generic replacement
Browse files Browse the repository at this point in the history
  • Loading branch information
chewxy committed Nov 15, 2016
1 parent 2687bad commit f6b4dcc
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 34 deletions.
8 changes: 4 additions & 4 deletions example_greenspun_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,9 @@ func (t prim) String() string {
}

// implement TypeOp
func (t prim) Types() Types { return nil }
func (t prim) Replace(TypeVariable, Type) TypeOp { return t }
func (t prim) Clone() TypeOp { return t }
func (t prim) Types() Types { return nil }
func (t prim) Replace(Type, Type) TypeOp { return t }
func (t prim) Clone() TypeOp { return t }

func (t prim) IsConst() bool { return true }

Expand Down Expand Up @@ -159,7 +159,7 @@ func Example_greenspun() {

fmt.Printf("Type: %v | err: %v", t, err)

// Outputs:
// Outputs
// Type: Float | err: <nil>

}
25 changes: 17 additions & 8 deletions functionType.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,33 +61,42 @@ func (t *FunctionType) String() string { return fmt.Sprintf("%v", t) }

func (t *FunctionType) Types() Types { return Types(t.ts[:]) }

func (t *FunctionType) Replace(tv TypeVariable, with Type) TypeOp {
func (t *FunctionType) Replace(what, with Type) TypeOp {
switch tt := t.ts[0].(type) {
case TypeVariable:
if tt.Eq(tv) {
if tt.Eq(what) {
t.ts[0] = with
}
case TypeConst:
// do nothing
case TypeOp:
t.ts[0] = tt.Replace(tv, with)
if t.ts[0].Eq(what) {
t.ts[0] = with
} else {
t.ts[0] = tt.Replace(what, with)
}
default:
panic("WTF?")
panic("Unreachable")

}

switch tt := t.ts[1].(type) {
case TypeVariable:
if tt.Eq(tv) {
if tt.Eq(what) {
t.ts[1] = with
}
case TypeConst:
// do nothing
case TypeOp:
t.ts[1] = tt.Replace(tv, with)
if tt.Eq(what) {
t.ts[1] = with
} else {
tt = tt.Replace(what, with)
}
default:
panic("WTF?")
}
panic("Unreachable")

}
return t
}

Expand Down
18 changes: 9 additions & 9 deletions test_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@ func (t particle) Eq(other Type) bool {
return false
}

func (t particle) Name() string { return t.String() }
func (t particle) Format(state fmt.State, c rune) { fmt.Fprintf(state, t.String()) }
func (t particle) Types() Types { return nil }
func (t particle) Clone() TypeOp { return t }
func (t particle) Replace(TypeVariable, Type) TypeOp { return t }
func (t particle) IsConstant() bool { return true }
func (t particle) Name() string { return t.String() }
func (t particle) Format(state fmt.State, c rune) { fmt.Fprintf(state, t.String()) }
func (t particle) Types() Types { return nil }
func (t particle) Clone() TypeOp { return t }
func (t particle) Replace(Type, Type) TypeOp { return t }
func (t particle) IsConstant() bool { return true }
func (t particle) String() string {
switch t {
case proton:
Expand Down Expand Up @@ -90,16 +90,16 @@ func (t list) Clone() TypeOp {
return retVal
}

func (t list) Replace(tv TypeVariable, with Type) TypeOp {
func (t list) Replace(what, with Type) TypeOp {
switch tt := t.t.(type) {
case TypeVariable:
if tt.Eq(tv) {
if tt.Eq(what) {
t.t = with
}
case TypeConst:
// do nothing
case TypeOp:
t.t = tt.Replace(tv, t)
t.t = tt.Replace(what, with)
default:
panic("WTF")
}
Expand Down
2 changes: 1 addition & 1 deletion type.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ type TypeOp interface {
Clone() TypeOp

// Replaces all the instances of tv with t. If your data structure is recursive, it needs to be replaced for the entire data structure
Replace(tv TypeVariable, t Type) TypeOp
Replace(a, b Type) TypeOp
}

// TypeConst is a constant type. Replace() will not change the TypeOp. It's useful for implementing atomic types. Formerly called Atomic
Expand Down
49 changes: 37 additions & 12 deletions typeSystem.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,35 @@ import "github.com/pkg/errors"
//
// The more complicated constructor unification and arrow unification isn't quite covered yet.
func Unify(t1, t2 Type) (retVal1, retVal2 Type, replacements map[TypeVariable]Type, err error) {
logf("Unifying %v and %v", t1, t2)
enterLoggingContext()
defer leaveLoggingContext()
a := Prune(t1)
b := Prune(t2)

switch at := a.(type) {
case TypeVariable:
retVal1, retVal2, err = UnifyVar(at, b)
if retVal1, retVal2, err = UnifyVar(at, b); err != nil {
return
}
if replacements == nil {
replacements = make(map[TypeVariable]Type)
}

replacements[at] = retVal1
case TypeOp:
switch bt := b.(type) {
case TypeVariable:
retVal2, retVal1, err = UnifyVar(bt, at) // note the order change
// note the order change
if retVal2, retVal1, err = UnifyVar(bt, at); err != nil {
return
}

if replacements == nil {
replacements = make(map[TypeVariable]Type)
}

replacements[bt] = retVal2
case TypeOp:
atypes := at.Types()
btypes := bt.Types()
Expand All @@ -44,11 +61,13 @@ func Unify(t1, t2 Type) (retVal1, retVal2 Type, replacements map[TypeVariable]Ty
return
}

enterLoggingContext()
var t_a, t_b Type
for i := 0; i < len(atypes); i++ {
t_a = atypes[i]
t_b = btypes[i]

logf("Unifying recursively %v and %v", t_a, t_b)
var t_a2, t_b2 Type
var r2 map[TypeVariable]Type
if t_a2, t_b2, r2, err = Unify(t_a, t_b); err != nil {
Expand All @@ -62,26 +81,32 @@ func Unify(t1, t2 Type) (retVal1, retVal2 Type, replacements map[TypeVariable]Ty
replacements[k] = v
}
}
logf("r: %v", replacements)

pt_a2 := Prune(t_a2)
pt_b2 := Prune(t_b2)

logf("Replacing %v with %v in %v", t_a, pt_a2, at)
logf("Replacing %v with %v in %v", t_b, pt_b2, bt)

at = at.Replace(t_a, pt_a2)
bt = bt.Replace(t_b, pt_b2)

logf("at: %v", at)
logf("bt: %v", bt)

if tv, ok := t_a.(TypeVariable); ok {
at = at.Replace(tv, Prune(t_a2))
if replacements == nil {
replacements = make(map[TypeVariable]Type)
}
replacements[tv] = t_a2
replacements[tv] = pt_a2
}

if tv, ok := t_b.(TypeVariable); ok {
bt = bt.Replace(tv, Prune(t_b2))
if replacements == nil {
replacements = make(map[TypeVariable]Type)
}
replacements[tv] = t_b2
replacements[tv] = pt_b2
}

atypes = at.Types()
btypes = bt.Types()
}
leaveLoggingContext()

retVal1 = at
retVal2 = bt
Expand Down
3 changes: 3 additions & 0 deletions typeSystem_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ var unifyTests = []struct {
{"List a ~ List proton", list{NewTypeVar("a")}, list{proton}, list{proton}, list{proton}, false},
{"List a ~ GoateeList proton", list{NewTypeVar("a")}, mirrorUniverseList{list{proton}}, nil, nil, true},

// function types
{"List a → List a ~ List proton → List proton", NewFnType(list{NewTypeVar("a")}, list{NewTypeVar("a")}), NewFnType(list{proton}, list{proton}), NewFnType(list{proton}, list{proton}), NewFnType(list{proton}, list{proton}), false},

{"malformed ~ a", malformed{}, NewTypeVar("a"), nil, nil, true},
{"proton ~ malformed{}", proton, malformed{}, nil, nil, true},

Expand Down

0 comments on commit f6b4dcc

Please sign in to comment.