Skip to content

Commit

Permalink
go/types, types2: preserve parent scope when substituting receivers
Browse files Browse the repository at this point in the history
Fixes #51920

Change-Id: I29e44a52dabee5c09e1761f9ec8fb2e8795f8901
Reviewed-on: https://go-review.googlesource.com/c/go/+/395538
Trust: Robert Findley <rfindley@google.com>
Run-TryBot: Robert Findley <rfindley@google.com>
Reviewed-by: Robert Griesemer <gri@golang.org>
TryBot-Result: Gopher Robot <gobot@golang.org>
  • Loading branch information
findleyr committed Mar 29, 2022
1 parent 1724077 commit 9b90838
Show file tree
Hide file tree
Showing 6 changed files with 215 additions and 14 deletions.
97 changes: 97 additions & 0 deletions src/cmd/compile/internal/types2/api_test.go
Expand Up @@ -2305,6 +2305,103 @@ func TestInstanceIdentity(t *testing.T) {
}
}

// TestInstantiatedObjects verifies properties of instantiated objects.
func TestInstantiatedObjects(t *testing.T) {
const src = `
package p
type T[P any] struct {
field P
}
func (recv *T[Q]) concreteMethod() {}
type FT[P any] func(ftp P) (ftrp P)
func F[P any](fp P) (frp P){ return }
type I[P any] interface {
interfaceMethod(P)
}
var (
t T[int]
ft FT[int]
f = F[int]
i I[int]
)
`
info := &Info{
Defs: make(map[*syntax.Name]Object),
}
f, err := parseSrc("p.go", src)
if err != nil {
t.Fatal(err)
}
conf := Config{}
pkg, err := conf.Check(f.PkgName.Value, []*syntax.File{f}, info)
if err != nil {
t.Fatal(err)
}

lookup := func(name string) Type { return pkg.Scope().Lookup(name).Type() }
tests := []struct {
ident string
obj Object
}{
{"field", lookup("t").Underlying().(*Struct).Field(0)},
{"concreteMethod", lookup("t").(*Named).Method(0)},
{"recv", lookup("t").(*Named).Method(0).Type().(*Signature).Recv()},
{"ftp", lookup("ft").Underlying().(*Signature).Params().At(0)},
{"ftrp", lookup("ft").Underlying().(*Signature).Results().At(0)},
{"fp", lookup("f").(*Signature).Params().At(0)},
{"frp", lookup("f").(*Signature).Results().At(0)},
{"interfaceMethod", lookup("i").Underlying().(*Interface).Method(0)},
}

// Collect all identifiers by name.
idents := make(map[string][]*syntax.Name)
syntax.Inspect(f, func(n syntax.Node) bool {
if id, ok := n.(*syntax.Name); ok {
idents[id.Value] = append(idents[id.Value], id)
}
return true
})

for _, test := range tests {
test := test
t.Run(test.ident, func(t *testing.T) {
if got := len(idents[test.ident]); got != 1 {
t.Fatalf("found %d identifiers named %s, want 1", got, test.ident)
}
ident := idents[test.ident][0]
def := info.Defs[ident]
if def == test.obj {
t.Fatalf("info.Defs[%s] contains the test object", test.ident)
}
if def.Pkg() != test.obj.Pkg() {
t.Errorf("Pkg() = %v, want %v", def.Pkg(), test.obj.Pkg())
}
if def.Name() != test.obj.Name() {
t.Errorf("Name() = %v, want %v", def.Name(), test.obj.Name())
}
if def.Pos() != test.obj.Pos() {
t.Errorf("Pos() = %v, want %v", def.Pos(), test.obj.Pos())
}
if def.Parent() != test.obj.Parent() {
t.Fatalf("Parent() = %v, want %v", def.Parent(), test.obj.Parent())
}
if def.Exported() != test.obj.Exported() {
t.Fatalf("Exported() = %v, want %v", def.Exported(), test.obj.Exported())
}
if def.Id() != test.obj.Id() {
t.Fatalf("Id() = %v, want %v", def.Id(), test.obj.Id())
}
// String and Type are expected to differ.
})
}
}

func TestImplements(t *testing.T) {
const src = `
package p
Expand Down
2 changes: 1 addition & 1 deletion src/cmd/compile/internal/types2/named.go
Expand Up @@ -191,7 +191,7 @@ func (t *Named) instantiateMethod(i int) *Func {
rtyp = t
}

sig.recv = NewParam(origSig.recv.pos, origSig.recv.pkg, origSig.recv.name, rtyp)
sig.recv = substVar(origSig.recv, rtyp)
return NewFunc(origm.pos, origm.pkg, origm.name, sig)
}

Expand Down
15 changes: 9 additions & 6 deletions src/cmd/compile/internal/types2/subst.go
Expand Up @@ -290,14 +290,18 @@ func (subst *subster) typOrNil(typ Type) Type {
func (subst *subster) var_(v *Var) *Var {
if v != nil {
if typ := subst.typ(v.typ); typ != v.typ {
copy := *v
copy.typ = typ
return &copy
return substVar(v, typ)
}
}
return v
}

func substVar(v *Var, typ Type) *Var {
copy := *v
copy.typ = typ
return &copy
}

func (subst *subster) tuple(t *Tuple) *Tuple {
if t != nil {
if vars, copied := subst.varList(t.vars); copied {
Expand Down Expand Up @@ -410,9 +414,8 @@ func replaceRecvType(in []*Func, old, new Type) (out []*Func, copied bool) {
copied = true
}
newsig := *sig
sig = &newsig
sig.recv = NewVar(sig.recv.pos, sig.recv.pkg, "", new)
out[i] = NewFunc(method.pos, method.pkg, method.name, sig)
newsig.recv = substVar(sig.recv, new)
out[i] = NewFunc(method.pos, method.pkg, method.name, &newsig)
}
}
return
Expand Down
98 changes: 98 additions & 0 deletions src/go/types/api_test.go
Expand Up @@ -2305,6 +2305,104 @@ func TestInstanceIdentity(t *testing.T) {
}
}

// TestInstantiatedObjects verifies properties of instantiated objects.
func TestInstantiatedObjects(t *testing.T) {
const src = `
package p
type T[P any] struct {
field P
}
func (recv *T[Q]) concreteMethod() {}
type FT[P any] func(ftp P) (ftrp P)
func F[P any](fp P) (frp P){ return }
type I[P any] interface {
interfaceMethod(P)
}
var (
t T[int]
ft FT[int]
f = F[int]
i I[int]
)
`
info := &Info{
Defs: make(map[*ast.Ident]Object),
}
fset := token.NewFileSet()
f, err := parser.ParseFile(fset, "p.go", src, 0)
if err != nil {
t.Fatal(err)
}
conf := Config{}
pkg, err := conf.Check(f.Name.Name, fset, []*ast.File{f}, info)
if err != nil {
t.Fatal(err)
}

lookup := func(name string) Type { return pkg.Scope().Lookup(name).Type() }
tests := []struct {
name string
obj Object
}{
{"field", lookup("t").Underlying().(*Struct).Field(0)},
{"concreteMethod", lookup("t").(*Named).Method(0)},
{"recv", lookup("t").(*Named).Method(0).Type().(*Signature).Recv()},
{"ftp", lookup("ft").Underlying().(*Signature).Params().At(0)},
{"ftrp", lookup("ft").Underlying().(*Signature).Results().At(0)},
{"fp", lookup("f").(*Signature).Params().At(0)},
{"frp", lookup("f").(*Signature).Results().At(0)},
{"interfaceMethod", lookup("i").Underlying().(*Interface).Method(0)},
}

// Collect all identifiers by name.
idents := make(map[string][]*ast.Ident)
ast.Inspect(f, func(n ast.Node) bool {
if id, ok := n.(*ast.Ident); ok {
idents[id.Name] = append(idents[id.Name], id)
}
return true
})

for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
if got := len(idents[test.name]); got != 1 {
t.Fatalf("found %d identifiers named %s, want 1", got, test.name)
}
ident := idents[test.name][0]
def := info.Defs[ident]
if def == test.obj {
t.Fatalf("info.Defs[%s] contains the test object", test.name)
}
if def.Pkg() != test.obj.Pkg() {
t.Errorf("Pkg() = %v, want %v", def.Pkg(), test.obj.Pkg())
}
if def.Name() != test.obj.Name() {
t.Errorf("Name() = %v, want %v", def.Name(), test.obj.Name())
}
if def.Pos() != test.obj.Pos() {
t.Errorf("Pos() = %v, want %v", def.Pos(), test.obj.Pos())
}
if def.Parent() != test.obj.Parent() {
t.Fatalf("Parent() = %v, want %v", def.Parent(), test.obj.Parent())
}
if def.Exported() != test.obj.Exported() {
t.Fatalf("Exported() = %v, want %v", def.Exported(), test.obj.Exported())
}
if def.Id() != test.obj.Id() {
t.Fatalf("Id() = %v, want %v", def.Id(), test.obj.Id())
}
// String and Type are expected to differ.
})
}
}

func TestImplements(t *testing.T) {
const src = `
package p
Expand Down
2 changes: 1 addition & 1 deletion src/go/types/named.go
Expand Up @@ -193,7 +193,7 @@ func (t *Named) instantiateMethod(i int) *Func {
rtyp = t
}

sig.recv = NewParam(origSig.recv.pos, origSig.recv.pkg, origSig.recv.name, rtyp)
sig.recv = substVar(origSig.recv, rtyp)
return NewFunc(origm.pos, origm.pkg, origm.name, sig)
}

Expand Down
15 changes: 9 additions & 6 deletions src/go/types/subst.go
Expand Up @@ -290,14 +290,18 @@ func (subst *subster) typOrNil(typ Type) Type {
func (subst *subster) var_(v *Var) *Var {
if v != nil {
if typ := subst.typ(v.typ); typ != v.typ {
copy := *v
copy.typ = typ
return &copy
return substVar(v, typ)
}
}
return v
}

func substVar(v *Var, typ Type) *Var {
copy := *v
copy.typ = typ
return &copy
}

func (subst *subster) tuple(t *Tuple) *Tuple {
if t != nil {
if vars, copied := subst.varList(t.vars); copied {
Expand Down Expand Up @@ -410,9 +414,8 @@ func replaceRecvType(in []*Func, old, new Type) (out []*Func, copied bool) {
copied = true
}
newsig := *sig
sig = &newsig
sig.recv = NewVar(sig.recv.pos, sig.recv.pkg, "", new)
out[i] = NewFunc(method.pos, method.pkg, method.name, sig)
newsig.recv = substVar(sig.recv, new)
out[i] = NewFunc(method.pos, method.pkg, method.name, &newsig)
}
}
return
Expand Down

0 comments on commit 9b90838

Please sign in to comment.