diff --git a/hm.go b/hm.go index 5d8cc8b..39098c9 100644 --- a/hm.go +++ b/hm.go @@ -334,14 +334,14 @@ func Unify(a, b Type) (sub Subs, err error) { switch at := a.(type) { case TypeVariable: - return bind(at, b) + return Bind(at, b) default: if a.Eq(b) { return nil, nil } if btv, ok := b.(TypeVariable); ok { - return bind(btv, a) + return Bind(btv, a) } atypes := a.Types() btypes := b.Types() @@ -385,7 +385,7 @@ func unifyMany(a, b Types) (sub Subs, err error) { if sub == nil { sub = s2 } else { - sub2 := compose(sub, s2) + sub2 := Compose(sub, s2) defer ReturnSubs(s2) if sub2 != sub { defer ReturnSubs(sub) @@ -396,11 +396,12 @@ func unifyMany(a, b Types) (sub Subs, err error) { return } -func bind(tv TypeVariable, t Type) (sub Subs, err error) { +// Bind binds a TypeVariable to a Type. It returns a substitution list. +func Bind(tv TypeVariable, t Type) (sub Subs, err error) { logf("Binding %v to %v", tv, t) switch { // case tv == t: - case occurs(tv, t): + case Occurs(tv, t): err = errors.Errorf("recursive unification") default: ssub := BorrowSSubs(1) @@ -411,7 +412,8 @@ func bind(tv TypeVariable, t Type) (sub Subs, err error) { return } -func occurs(tv TypeVariable, s Substitutable) bool { +// Occurs checks if a TypeVariable exists in any Substitutable (type, scheme, map etc). +func Occurs(tv TypeVariable, s Substitutable) bool { ftv := s.FreeTypeVar() defer ReturnTypeVarSet(ftv) diff --git a/solver.go b/solver.go index 80a1142..3c055f0 100644 --- a/solver.go +++ b/solver.go @@ -27,7 +27,7 @@ func (s *solver) solve(cs Constraints) { sub, s.err = Unify(c.a, c.b) defer ReturnSubs(s.sub) - s.sub = compose(sub, s.sub) + s.sub = Compose(sub, s.sub) cs = cs[1:].Apply(s.sub).(Constraints) s.solve(cs) diff --git a/substitutions.go b/substitutions.go index a21dd11..3a85e25 100644 --- a/substitutions.go +++ b/substitutions.go @@ -14,6 +14,15 @@ type Subs interface { Clone() Subs } +// MakeSubs is a utility function to help make substitution lists. +// This is useful for cases where there isn't a real need to implement Subs +func MakeSubs(n int) Subs { + if n >= 30 { + return make(mSubs) + } + return newSliceSubs(n) +} + // A Substitution is a tuple representing the TypeVariable and the replacement Type type Substitution struct { Tv TypeVariable @@ -116,7 +125,8 @@ func (s mSubs) Clone() Subs { return retVal } -func compose(a, b Subs) (retVal Subs) { +// Compose composes two substitution lists together. +func Compose(a, b Subs) (retVal Subs) { if b == nil { return a } diff --git a/substitutions_test.go b/substitutions_test.go index aabbd14..9e467a0 100644 --- a/substitutions_test.go +++ b/substitutions_test.go @@ -131,7 +131,7 @@ var composeTests = []struct { func TestCompose(t *testing.T) { for i, cts := range composeTests { - subs := compose(cts.a, cts.b) + subs := Compose(cts.a, cts.b) for _, v := range cts.expected.Iter() { if T, ok := subs.Get(v.Tv); !ok { diff --git a/type.go b/type.go index 6d2a1bc..384687e 100644 --- a/type.go +++ b/type.go @@ -7,10 +7,10 @@ import ( // Type represents all the possible type constructors. type Type interface { Substitutable - Name() string // Name is the name of the constructor - Normalize(TypeVarSet, TypeVarSet) (Type, error) // Normalize normalizes all the type variable names in the type - Types() Types // If the type is made up of smaller types, then it will return them - Eq(Type) bool // equality operation + Name() string // Name is the name of the constructor + Normalize(k TypeVarSet, v TypeVarSet) (Type, error) // Normalize normalizes all the type variable names in the type + Types() Types // If the type is made up of smaller types, then it will return them + Eq(Type) bool // equality operation fmt.Formatter fmt.Stringer diff --git a/typeVariable.go b/typeVariable.go index 2d33945..1fc383e 100644 --- a/typeVariable.go +++ b/typeVariable.go @@ -9,7 +9,7 @@ import ( // TypeVariable is a variable that ranges over the types - that is to say it can take any type. type TypeVariable rune -func (t TypeVariable) Name() string { return string(t) } +func (t TypeVariable) Name() string { return fmt.Sprintf("%v", t) } func (t TypeVariable) Apply(sub Subs) Substitutable { if sub == nil { return t @@ -30,7 +30,13 @@ func (t TypeVariable) Normalize(k, v TypeVarSet) (Type, error) { return nil, errors.Errorf("Type Variable %v not in signature", t) } -func (t TypeVariable) Types() Types { return nil } -func (t TypeVariable) String() string { return string(t) } -func (t TypeVariable) Format(s fmt.State, c rune) { fmt.Fprintf(s, "%c", rune(t)) } -func (t TypeVariable) Eq(other Type) bool { return other == t } +func (t TypeVariable) Types() Types { return nil } +func (t TypeVariable) String() string { return fmt.Sprintf("%v", t) } +func (t TypeVariable) Format(s fmt.State, c rune) { + if t >= 'a' && t <= 'z' { + fmt.Fprintf(s, "%c", rune(t)) + return + } + fmt.Fprintf(s, "<%d>", rune(t)) +} +func (t TypeVariable) Eq(other Type) bool { return other == t } diff --git a/types/commonutils.go b/types/commonutils.go new file mode 100644 index 0000000..fd9a8d9 --- /dev/null +++ b/types/commonutils.go @@ -0,0 +1,90 @@ +package hmtypes + +import "github.com/chewxy/hm" + +// Pair is a convenient structural abstraction for types that are composed of two types. +// Depending on use cases, it may be useful to embed Pair, or define a new type base on *Pair. +// +// Pair partially implements hm.Type, as the intention is merely for syntactic abstraction +// +// It has very specific semantics - +// it's useful for a small subset of types like function types, or supertypes. +// See the documentation for Apply and FreeTypeVar. +type Pair struct { + A, B hm.Type +} + +// Apply applies a substitution on both the first and second types of the Pair. +func (t *Pair) Apply(sub hm.Subs) { + t.A = t.A.Apply(sub).(hm.Type) + t.B = t.B.Apply(sub).(hm.Type) +} + +// Types returns all the types of the Pair's constituents +func (t *Pair) Types() hm.Types { + retVal := hm.BorrowTypes(2) + retVal[0] = t.A + retVal[1] = t.B + return retVal +} + +// FreeTypeVar returns a set of free (unbound) type variables. +func (t *Pair) FreeTypeVar() hm.TypeVarSet { return t.A.FreeTypeVar().Union(t.B.FreeTypeVar()) } + +// Clone implements Cloner +func (t *Pair) Clone() *Pair { + retVal := borrowPair() + + if ac, ok := t.A.(Cloner); ok { + retVal.A = ac.Clone().(hm.Type) + } else { + retVal.A = t.A + } + + if bc, ok := t.B.(Cloner); ok { + retVal.B = bc.Clone().(hm.Type) + } else { + retVal.B = t.B + } + return retVal +} + +// Monuple is a convenient structural abstraction for types that are composed of one type. +// +// Monuple implements hm.Substitutable, but with very specific semantics - +// It's useful for singly polymorphic types like arrays, linear types, reference types, etc +type Monuple struct { + T hm.Type +} + +// Apply applies a substitution to the monuple type. +func (t Monuple) Apply(subs hm.Subs) Monuple { + t.T = t.T.Apply(subs).(hm.Type) + return t +} + +// FreeTypeVar returns the set of free type variables in the monuple. +func (t Monuple) FreeTypeVar() hm.TypeVarSet { return t.T.FreeTypeVar() } + +// Normalize is the method to normalize all type variables +func (t Monuple) Normalize(k, v hm.TypeVarSet) (Monuple, error) { + var t2 hm.Type + var err error + if t2, err = t.T.Normalize(k, v); err != nil { + return Monuple{}, err + } + t.T = t2 + return t, nil +} + +// Pairer is any type that can be represented by a Pair +type Pairer interface { + hm.Type + AsPair() *Pair +} + +// Monupler is any type that can be represented by a Monuple +type Monupler interface { + hm.Type + AsMonuple() Monuple +} diff --git a/types/function.go b/types/function.go new file mode 100644 index 0000000..a727dd3 --- /dev/null +++ b/types/function.go @@ -0,0 +1,105 @@ +package hmtypes + +import ( + "fmt" + + "github.com/chewxy/hm" +) + +// Function is a type constructor that builds function types. +type Function Pair + +// NewFunction creates a new FunctionType. Functions are by default right associative. This: +// NewFunction(a, a, a) +// is short hand for this: +// NewFunction(a, NewFunction(a, a)) +func NewFunction(ts ...hm.Type) *Function { + if len(ts) < 2 { + panic("Expected at least 2 input types") + } + + retVal := borrowFn() + retVal.A = ts[0] + + if len(ts) > 2 { + retVal.B = NewFunction(ts[1:]...) + } else { + retVal.B = ts[1] + } + return retVal +} + +func (t *Function) Name() string { return "→" } +func (t *Function) Apply(sub hm.Subs) hm.Substitutable { ((*Pair)(t)).Apply(sub); return t } +func (t *Function) FreeTypeVar() hm.TypeVarSet { return ((*Pair)(t)).FreeTypeVar() } +func (t *Function) Format(s fmt.State, c rune) { fmt.Fprintf(s, "%v → %v", t.A, t.B) } +func (t *Function) String() string { return fmt.Sprintf("%v", t) } +func (t *Function) Normalize(k, v hm.TypeVarSet) (hm.Type, error) { + var a, b hm.Type + var err error + if a, err = t.A.Normalize(k, v); err != nil { + return nil, err + } + + if b, err = t.B.Normalize(k, v); err != nil { + return nil, err + } + + return NewFunction(a, b), nil +} +func (t *Function) Types() hm.Types { return ((*Pair)(t)).Types() } + +func (t *Function) Eq(other hm.Type) bool { + if ot, ok := other.(*Function); ok { + return ot.A.Eq(t.A) && ot.B.Eq(t.B) + } + return false +} + +// Other methods (accessors mainly) + +// Arg returns the type of the function argument +func (t *Function) Arg() hm.Type { return t.A } + +// Ret returns the return type of a function. If recursive is true, it will get the final return type +func (t *Function) Ret(recursive bool) hm.Type { + if !recursive { + return t.B + } + + if fnt, ok := t.B.(*Function); ok { + return fnt.Ret(recursive) + } + + return t.B +} + +// FlatTypes returns the types in FunctionTypes as a flat slice of types. This allows for easier iteration in some applications +func (t *Function) FlatTypes() hm.Types { + retVal := hm.BorrowTypes(8) // start with 8. Can always grow + retVal = retVal[:0] + + if a, ok := t.A.(*Function); ok { + ft := a.FlatTypes() + retVal = append(retVal, ft...) + hm.ReturnTypes(ft) + } else { + retVal = append(retVal, t.A) + } + + if b, ok := t.B.(*Function); ok { + ft := b.FlatTypes() + retVal = append(retVal, ft...) + hm.ReturnTypes(ft) + } else { + retVal = append(retVal, t.B) + } + return retVal +} + +// Clone implenents cloner +func (t *Function) Clone() interface{} { + p := (*Pair)(t) + cloned := p.Clone() + return (*Function)(cloned) +} diff --git a/types/function_test.go b/types/function_test.go new file mode 100644 index 0000000..4e28fa4 --- /dev/null +++ b/types/function_test.go @@ -0,0 +1,99 @@ +package hmtypes + +import ( + "testing" + + "github.com/chewxy/hm" + "github.com/stretchr/testify/assert" +) + +func TestFunctionTypeBasics(t *testing.T) { + fnType := NewFunction(hm.TypeVariable('a'), hm.TypeVariable('a'), hm.TypeVariable('a')) + if fnType.Name() != "→" { + t.Errorf("FunctionType should have \"→\" as a name. Got %q instead", fnType.Name()) + } + + if fnType.String() != "a → a → a" { + t.Errorf("Expected \"a → a → a\". Got %q instead", fnType.String()) + } + + if !fnType.Arg().Eq(hm.TypeVariable('a')) { + t.Error("Expected arg of function to be 'a'") + } + + if !fnType.Ret(false).Eq(NewFunction(hm.TypeVariable('a'), hm.TypeVariable('a'))) { + t.Error("Expected ret(false) to be a → a") + } + + if !fnType.Ret(true).Eq(hm.TypeVariable('a')) { + t.Error("Expected final return type to be 'a'") + } + + // a very simple fn + fnType = NewFunction(hm.TypeVariable('a'), hm.TypeVariable('a')) + if !fnType.Ret(true).Eq(hm.TypeVariable('a')) { + t.Error("Expected final return type to be 'a'") + } + + ftv := fnType.FreeTypeVar() + if len(ftv) != 1 { + t.Errorf("Expected only one free type var") + } + + for _, fas := range fnApplyTests { + fn := fas.fn.Apply(fas.sub).(*Function) + if !fn.Eq(fas.expected) { + t.Errorf("Expected %v. Got %v instead", fas.expected, fn) + } + } + + // bad shit + f := func() { + NewFunction(hm.TypeVariable('a')) + } + assert.Panics(t, f) +} + +var fnApplyTests = []struct { + fn *Function + sub hm.Subs + + expected *Function +}{ + {NewFunction(hm.TypeVariable('a'), hm.TypeVariable('a')), mSubs{'a': proton, 'b': neutron}, NewFunction(proton, proton)}, + {NewFunction(hm.TypeVariable('a'), hm.TypeVariable('b')), mSubs{'a': proton, 'b': neutron}, NewFunction(proton, neutron)}, + {NewFunction(hm.TypeVariable('a'), hm.TypeVariable('b')), mSubs{'c': proton, 'd': neutron}, NewFunction(hm.TypeVariable('a'), hm.TypeVariable('b'))}, + {NewFunction(hm.TypeVariable('a'), hm.TypeVariable('b')), mSubs{'a': proton, 'c': neutron}, NewFunction(proton, hm.TypeVariable('b'))}, + {NewFunction(hm.TypeVariable('a'), hm.TypeVariable('b')), mSubs{'c': proton, 'b': neutron}, NewFunction(hm.TypeVariable('a'), neutron)}, + {NewFunction(electron, proton), mSubs{'a': proton, 'b': neutron}, NewFunction(electron, proton)}, + + // a -> (b -> c) + {NewFunction(hm.TypeVariable('a'), hm.TypeVariable('b'), hm.TypeVariable('a')), mSubs{'a': proton, 'b': neutron}, NewFunction(proton, neutron, proton)}, + {NewFunction(hm.TypeVariable('a'), hm.TypeVariable('a'), hm.TypeVariable('b')), mSubs{'a': proton, 'b': neutron}, NewFunction(proton, proton, neutron)}, + {NewFunction(hm.TypeVariable('a'), hm.TypeVariable('b'), hm.TypeVariable('c')), mSubs{'a': proton, 'b': neutron}, NewFunction(proton, neutron, hm.TypeVariable('c'))}, + {NewFunction(hm.TypeVariable('a'), hm.TypeVariable('c'), hm.TypeVariable('b')), mSubs{'a': proton, 'b': neutron}, NewFunction(proton, hm.TypeVariable('c'), neutron)}, + + // (a -> b) -> c + {NewFunction(NewFunction(hm.TypeVariable('a'), hm.TypeVariable('b')), hm.TypeVariable('a')), mSubs{'a': proton, 'b': neutron}, NewFunction(NewFunction(proton, neutron), proton)}, +} + +func TestFunctionType_FlatTypes(t *testing.T) { + fnType := NewFunction(hm.TypeVariable('a'), hm.TypeVariable('b'), hm.TypeVariable('c')) + ts := fnType.FlatTypes() + correct := hm.Types{hm.TypeVariable('a'), hm.TypeVariable('b'), hm.TypeVariable('c')} + assert.Equal(t, ts, correct) + + fnType2 := NewFunction(fnType, hm.TypeVariable('d')) + correct = append(correct, hm.TypeVariable('d')) + ts = fnType2.FlatTypes() + assert.Equal(t, ts, correct) +} + +func TestFunctionType_Clone(t *testing.T) { + fnType := NewFunction(hm.TypeVariable('a'), hm.TypeVariable('b'), hm.TypeVariable('c')) + assert.Equal(t, fnType.Clone(), fnType) + + rec := NewRecordType("", hm.TypeVariable('a'), NewFunction(hm.TypeVariable('a'), hm.TypeVariable('b')), hm.TypeVariable('c')) + fnType = NewFunction(rec, rec) + assert.Equal(t, fnType.Clone(), fnType) +} diff --git a/types/interfaces.go b/types/interfaces.go new file mode 100644 index 0000000..b0afe8a --- /dev/null +++ b/types/interfaces.go @@ -0,0 +1,5 @@ +package hmtypes + +type Cloner interface { + Clone() interface{} +} diff --git a/types/monuples.go b/types/monuples.go new file mode 100644 index 0000000..2faded1 --- /dev/null +++ b/types/monuples.go @@ -0,0 +1,88 @@ +package hmtypes + +import ( + "fmt" + + "github.com/chewxy/hm" +) + +// Slice is the type of a Slice/List +type Slice Monuple + +func (t Slice) Name() string { return "List" } +func (t Slice) Apply(subs hm.Subs) hm.Substitutable { return Slice(Monuple(t).Apply(subs)) } +func (t Slice) FreeTypeVar() hm.TypeVarSet { return Monuple(t).FreeTypeVar() } +func (t Slice) Format(s fmt.State, c rune) { fmt.Fprintf(s, "[]%v", t.T) } +func (t Slice) String() string { return fmt.Sprintf("%v", t) } +func (t Slice) Types() hm.Types { return hm.Types{t.T} } + +func (t Slice) Normalize(k, v hm.TypeVarSet) (hm.Type, error) { + t2, err := Monuple(t).Normalize(k, v) + if err != nil { + return nil, err + } + return Slice(t2), nil +} + +func (t Slice) Eq(other hm.Type) bool { + if ot, ok := other.(Slice); ok { + return ot.T.Eq(t.T) + } + return false +} + +func (t Slice) Monuple() Monuple { return Monuple(t) } + +// Linear is a linear type (i.e types that can only appear once) +type Linear Monuple + +func (t Linear) Name() string { return "Linear" } +func (t Linear) Apply(subs hm.Subs) hm.Substitutable { return Linear(Monuple(t).Apply(subs)) } +func (t Linear) FreeTypeVar() hm.TypeVarSet { return Monuple(t).FreeTypeVar() } +func (t Linear) Format(s fmt.State, c rune) { fmt.Fprintf(s, "Linear[%v]", t.T) } +func (t Linear) String() string { return fmt.Sprintf("%v", t) } +func (t Linear) Types() hm.Types { return hm.Types{t.T} } + +func (t Linear) Normalize(k, v hm.TypeVarSet) (hm.Type, error) { + t2, err := Monuple(t).Normalize(k, v) + if err != nil { + return nil, err + } + return Linear(t2), nil +} + +func (t Linear) Eq(other hm.Type) bool { + if ot, ok := other.(Linear); ok { + return ot.T.Eq(t.T) + } + return false +} + +func (t Linear) Monuple() Monuple { return Monuple(t) } + +// Ref is a reference type (think pointers) +type Ref Monuple + +func (t Ref) Name() string { return "Ref" } +func (t Ref) Apply(subs hm.Subs) hm.Substitutable { return Ref(Monuple(t).Apply(subs)) } +func (t Ref) FreeTypeVar() hm.TypeVarSet { return Monuple(t).FreeTypeVar() } +func (t Ref) Format(s fmt.State, c rune) { fmt.Fprintf(s, "*%v", t.T) } +func (t Ref) String() string { return fmt.Sprintf("%v", t) } +func (t Ref) Types() hm.Types { return hm.Types{t.T} } + +func (t Ref) Normalize(k, v hm.TypeVarSet) (hm.Type, error) { + t2, err := Monuple(t).Normalize(k, v) + if err != nil { + return nil, err + } + return Ref(t2), nil +} + +func (t Ref) Eq(other hm.Type) bool { + if ot, ok := other.(Ref); ok { + return ot.T.Eq(t.T) + } + return false +} + +func (t Ref) Monuple() Monuple { return Monuple(t) } diff --git a/types/pairs.go b/types/pairs.go new file mode 100644 index 0000000..13c94f8 --- /dev/null +++ b/types/pairs.go @@ -0,0 +1,119 @@ +package hmtypes + +import ( + "fmt" + + "github.com/chewxy/hm" +) + +var ( + _ hm.Type = &Choice{} + _ hm.Type = &Super{} + _ hm.Type = &Application{} +) + +// pair types + +// Choice is the type of choice of algorithm to use within a class method. +// +// Imagine how one would implement a class in an OOP language. +// Then imagine how one would implement method overloading for the class. +// The typical approach is name mangling followed by having a jump table. +// +// Now consider OOP classes and the ability to override methods, based on subclassing ability. +// The typical approach to this is to use a Vtable. +// +// Both overloading and overriding have a general notion: a jump table of sorts. +// How does one type such a table? +// +// By using Choice. +// +// The first type is the key of either the vtable or the name mangled table. +// The second type is the value of the table. +type Choice Pair + +func (t *Choice) Name() string { return ":" } +func (t *Choice) Apply(sub hm.Subs) hm.Substitutable { ((*Pair)(t)).Apply(sub); return t } +func (t *Choice) FreeTypeVar() hm.TypeVarSet { return ((*Pair)(t)).FreeTypeVar() } +func (t *Choice) Format(s fmt.State, c rune) { fmt.Fprintf(s, "%v : %v", t.A, t.B) } +func (t *Choice) String() string { return fmt.Sprintf("%v", t) } + +func (t *Choice) Normalize(k hm.TypeVarSet, v hm.TypeVarSet) (hm.Type, error) { + panic("not implemented") +} + +func (t *Choice) Types() hm.Types { return ((*Pair)(t)).Types() } + +func (t *Choice) Eq(other hm.Type) bool { + if ot, ok := other.(*Choice); ok { + return ot.A.Eq(t.A) && ot.B.Eq(t.B) + } + return false +} + +func (t *Choice) Clone() interface{} { return (*Choice)((*Pair)(t).Clone()) } + +func (t *Choice) Pair() *Pair { return (*Pair)(t) } + +// Super is the inverse of Choice. It allows for supertyping functions. +// +// Supertyping is typically implemented as a adding an entry to the vtable/mangled table. +// But there needs to be a separate accounting structure to keep account of the types. +// +// This is where Super comes in. +type Super Pair + +func (t *Super) Name() string { return "§" } +func (t *Super) Apply(sub hm.Subs) hm.Substitutable { ((*Pair)(t)).Apply(sub); return t } +func (t *Super) FreeTypeVar() hm.TypeVarSet { return ((*Pair)(t)).FreeTypeVar() } +func (t *Super) Format(s fmt.State, c rune) { fmt.Fprintf(s, "%v §: %v", t.A, t.B) } +func (t *Super) String() string { return fmt.Sprintf("%v", t) } + +func (t *Super) Normalize(k hm.TypeVarSet, v hm.TypeVarSet) (hm.Type, error) { + panic("not implemented") +} + +func (t *Super) Types() hm.Types { return ((*Pair)(t)).Types() } + +func (t *Super) Eq(other hm.Type) bool { + if ot, ok := other.(*Super); ok { + return ot.A.Eq(t.A) && ot.B.Eq(t.B) + } + return false +} + +func (t *Super) Clone() interface{} { return (*Super)((*Pair)(t).Clone()) } + +func (t *Super) Pair() *Pair { return (*Pair)(t) } + +// Application is the pre-unified type for a function application. +// In a simple HM system this would not be needed as the type of an +// application expression would be found during the unification phase of +// the expression. +// +// In advanced systems where unification may be done concurrently, this would +// be required, as a "thunk" of sorts for the type system. +type Application Pair + +func (t *Application) Name() string { return "•" } +func (t *Application) Apply(sub hm.Subs) hm.Substitutable { ((*Pair)(t)).Apply(sub); return t } +func (t *Application) FreeTypeVar() hm.TypeVarSet { return ((*Pair)(t)).FreeTypeVar() } +func (t *Application) Format(s fmt.State, c rune) { fmt.Fprintf(s, "%v • %v", t.A, t.B) } +func (t *Application) String() string { return fmt.Sprintf("%v", t) } + +func (t *Application) Normalize(k hm.TypeVarSet, v hm.TypeVarSet) (hm.Type, error) { + panic("not implemented") +} + +func (t *Application) Types() hm.Types { return ((*Pair)(t)).Types() } + +func (t *Application) Eq(other hm.Type) bool { + if ot, ok := other.(*Application); ok { + return ot.A.Eq(t.A) && ot.B.Eq(t.B) + } + return false +} + +func (t *Application) Clone() interface{} { return (*Application)((*Pair)(t).Clone()) } + +func (t *Application) Pair() *Pair { return (*Pair)(t) } diff --git a/types/perf.go b/types/perf.go new file mode 100644 index 0000000..e682236 --- /dev/null +++ b/types/perf.go @@ -0,0 +1,35 @@ +package hmtypes + +import ( + "sync" + "unsafe" +) + +var pairPool = &sync.Pool{ + New: func() interface{} { return new(Pair) }, +} + +func borrowPair() *Pair { + return pairPool.Get().(*Pair) +} + +func borrowFn() *Function { + got := pairPool.Get().(*Pair) + return (*Function)(unsafe.Pointer(got)) +} + +// ReturnFn returns a *FunctionType to the pool. NewFnType automatically borrows from the pool. USE WITH CAUTION +func ReturnFn(fnt *Function) { + if a, ok := fnt.A.(*Function); ok { + ReturnFn(a) + } + + if b, ok := fnt.B.(*Function); ok { + ReturnFn(b) + } + + fnt.A = nil + fnt.B = nil + p := (*Pair)(unsafe.Pointer(fnt)) + pairPool.Put(p) +} diff --git a/types/perf_test.go b/types/perf_test.go new file mode 100644 index 0000000..7d82844 --- /dev/null +++ b/types/perf_test.go @@ -0,0 +1,19 @@ +package hmtypes + +import "testing" + +func TestFnTypePool(t *testing.T) { + f := borrowFn() + f.A = NewFunction(proton, electron) + f.B = NewFunction(proton, neutron) + + ReturnFn(f) + f = borrowFn() + if f.A != nil { + t.Error("FunctionType not cleaned up: a is not nil") + } + if f.B != nil { + t.Error("FunctionType not cleaned up: b is not nil") + } + +} diff --git a/types/quantified.go b/types/quantified.go new file mode 100644 index 0000000..a54ae75 --- /dev/null +++ b/types/quantified.go @@ -0,0 +1,9 @@ +package hmtypes + +import "github.com/chewxy/hm" + +// Quantified is essentially a replacement scheme that is made into a Type +// TODO: implement hm.Type +type Quantified struct { + hm.Scheme +} diff --git a/types/record.go b/types/record.go new file mode 100644 index 0000000..6e19ab0 --- /dev/null +++ b/types/record.go @@ -0,0 +1,107 @@ +package hmtypes + +import ( + "fmt" + + "github.com/chewxy/hm" +) + +// Record is a basic record/tuple type. It takes an optional name. +type Record struct { + ts []hm.Type + name string +} + +// NewRecordType creates a new Record hm.Type +func NewRecordType(name string, ts ...hm.Type) *Record { + return &Record{ + ts: ts, + name: name, + } +} + +func (t *Record) Apply(subs hm.Subs) hm.Substitutable { + ts := make([]hm.Type, len(t.ts)) + for i, v := range t.ts { + ts[i] = v.Apply(subs).(hm.Type) + } + return NewRecordType(t.name, ts...) +} + +func (t *Record) FreeTypeVar() hm.TypeVarSet { + var tvs hm.TypeVarSet + for _, v := range t.ts { + tvs = v.FreeTypeVar().Union(tvs) + } + return tvs +} + +func (t *Record) Name() string { + if t.name != "" { + return t.name + } + return t.String() +} + +func (t *Record) Normalize(k, v hm.TypeVarSet) (hm.Type, error) { + ts := make([]hm.Type, len(t.ts)) + var err error + for i, tt := range t.ts { + if ts[i], err = tt.Normalize(k, v); err != nil { + return nil, err + } + } + return NewRecordType(t.name, ts...), nil +} + +func (t *Record) Types() hm.Types { + ts := hm.BorrowTypes(len(t.ts)) + copy(ts, t.ts) + return ts +} + +func (t *Record) Eq(other hm.Type) bool { + if ot, ok := other.(*Record); ok { + if len(ot.ts) != len(t.ts) { + return false + } + for i, v := range t.ts { + if !v.Eq(ot.ts[i]) { + return false + } + } + return true + } + return false +} + +func (t *Record) Format(f fmt.State, c rune) { + f.Write([]byte("(")) + for i, v := range t.ts { + if i < len(t.ts)-1 { + fmt.Fprintf(f, "%v, ", v) + } else { + fmt.Fprintf(f, "%v)", v) + } + } + +} + +func (t *Record) String() string { return fmt.Sprintf("%v", t) } + +// Clone implements Cloner +func (t *Record) Clone() interface{} { + retVal := new(Record) + ts := hm.BorrowTypes(len(t.ts)) + for i, tt := range t.ts { + if c, ok := tt.(Cloner); ok { + ts[i] = c.Clone().(hm.Type) + } else { + ts[i] = tt + } + } + retVal.ts = ts + retVal.name = t.name + + return retVal +} diff --git a/types/test_test.go b/types/test_test.go new file mode 100644 index 0000000..c3a67ba --- /dev/null +++ b/types/test_test.go @@ -0,0 +1,42 @@ +package hmtypes + +import "github.com/chewxy/hm" + +const ( + proton hm.TypeConst = "proton" + neutron hm.TypeConst = "neutron" + quark hm.TypeConst = "quark" + + electron hm.TypeConst = "electron" + positron hm.TypeConst = "positron" + muon hm.TypeConst = "muon" + + photon hm.TypeConst = "photon" + higgs hm.TypeConst = "higgs" +) + +// useful copy pasta from the hm package +type mSubs map[hm.TypeVariable]hm.Type + +func (s mSubs) Get(tv hm.TypeVariable) (hm.Type, bool) { retVal, ok := s[tv]; return retVal, ok } +func (s mSubs) Add(tv hm.TypeVariable, t hm.Type) hm.Subs { s[tv] = t; return s } +func (s mSubs) Remove(tv hm.TypeVariable) hm.Subs { delete(s, tv); return s } + +func (s mSubs) Iter() []hm.Substitution { + retVal := make([]hm.Substitution, len(s)) + var i int + for k, v := range s { + retVal[i] = hm.Substitution{k, v} + i++ + } + return retVal +} + +func (s mSubs) Size() int { return len(s) } +func (s mSubs) Clone() hm.Subs { + retVal := make(mSubs) + for k, v := range s { + retVal[k] = v + } + return retVal +}