Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

optimize load types #21

Merged
merged 6 commits into from
Feb 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion context.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ type Context struct {
External types.Importer // external import
Sizes types.Sizes
DebugFunc func(*DebugInfo)
Types map[types.Type]bool
}

func NewContext(mode Mode) *Context {
Expand All @@ -62,6 +63,7 @@ func NewContext(mode Mode) *Context {
Mode: mode,
ParserMode: parser.AllErrors,
BuilderMode: 0, //ssa.SanityCheckFunctions,
Types: make(map[types.Type]bool),
}
return ctx
}
Expand Down Expand Up @@ -162,7 +164,10 @@ func (c *Context) RunFunc(mainPkg *ssa.Package, fnname string, args ...Value) (r

func (c *Context) NewInterp(mainPkg *ssa.Package) (*Interp, error) {
r, err := NewInterp(c.Loader, mainPkg, c.Mode)
r.setDebug(c.DebugFunc)
if err == nil {
r.PreloadTypes(c.Types)
r.setDebug(c.DebugFunc)
}
return r, err
}

Expand Down Expand Up @@ -244,6 +249,16 @@ func (c *Context) RunTest(path string, args []string) error {
return c.TestPkg(pkgs, path, args)
}

func (ctx *Context) saveType(t types.Type) {
if tuple, ok := t.(*types.Tuple); ok {
for i := 0; i < tuple.Len(); i++ {
ctx.saveType(tuple.At(i).Type())
}
return
}
ctx.Types[t] = true
}

func (ctx *Context) BuildPackage(fset *token.FileSet, pkg *types.Package, files []*ast.File) (*ssa.Package, *types.Info, error) {
if fset == nil {
panic("no token.FileSet")
Expand All @@ -268,6 +283,9 @@ func (ctx *Context) BuildPackage(fset *token.FileSet, pkg *types.Package, files
if err := types.NewChecker(tc, fset, pkg, info).Files(files); err != nil {
return nil, nil, err
}
for _, v := range info.Types {
ctx.saveType(v.Type)
}
prog := ssa.NewProgram(fset, ctx.BuilderMode)

// Create SSA packages for all imports.
Expand Down
119 changes: 71 additions & 48 deletions interp.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,20 +99,20 @@ func (e runtimeError) Error() string {

// State shared between all interpreted goroutines.
type Interp struct {
fset *token.FileSet
prog *ssa.Program // the SSA program
mainpkg *ssa.Package // the SSA main package
globals map[ssa.Value]value // addresses of global variables (immutable)
mode Mode // interpreter options
sizes types.Sizes // the effective type-sizing function
goroutines int32 // atomically updated
types map[types.Type]reflect.Type
caller *frame
loader Loader
record *TypesRecord
typesMutex sync.RWMutex
callerMutex sync.RWMutex
fnDebug func(*DebugInfo)
fset *token.FileSet
prog *ssa.Program // the SSA program
mainpkg *ssa.Package // the SSA main package
globals map[ssa.Value]value // addresses of global variables (immutable)
mode Mode // interpreter options
sizes types.Sizes // the effective type-sizing function
goroutines int32 // atomically updated
types map[types.Type]reflect.Type
preloadTypes map[types.Type]reflect.Type
caller *frame
loader Loader
record *TypesRecord
typesMutex sync.RWMutex
fnDebug func(*DebugInfo)
}

func (i *Interp) setDebug(fn func(*DebugInfo)) {
Expand Down Expand Up @@ -142,10 +142,7 @@ func (i *Interp) FindMethod(mtyp reflect.Type, fn *types.Func) func([]reflect.Va
for i := 0; i < len(args); i++ {
iargs[i] = args[i].Interface()
}
i.callerMutex.RLock()
caller := i.caller
i.callerMutex.RUnlock()
r := call(i, caller, token.NoPos, f, iargs, nil)
r := i.call(nil, token.NoPos, f, iargs, nil)
switch mtyp.NumOut() {
case 0:
return nil
Expand Down Expand Up @@ -190,10 +187,27 @@ func (i *Interp) FindMethod(mtyp reflect.Type, fn *types.Func) func([]reflect.Va
return nil
}

func (i *Interp) preToType(typ types.Type) reflect.Type {
if t, ok := i.preloadTypes[typ]; ok {
return t
}
t := i.record.ToType(typ)
i.preloadTypes[typ] = t
return t
}

func (i *Interp) toType(typ types.Type) reflect.Type {
if t, ok := i.preloadTypes[typ]; ok {
return t
}
i.typesMutex.Lock()
defer i.typesMutex.Unlock()
return i.record.ToType(typ)
if t, ok := i.types[typ]; ok {
return t
}
t := i.record.ToType(typ)
i.types[typ] = t
return t
}

func (i *Interp) toFunc(fr *frame, typ reflect.Type, fn value) reflect.Value {
Expand All @@ -202,7 +216,7 @@ func (i *Interp) toFunc(fr *frame, typ reflect.Type, fn value) reflect.Value {
for i := 0; i < len(args); i++ {
iargs[i] = args[i].Interface()
}
r := call(i, fr, token.NoPos, fn, iargs, nil)
r := i.call(fr, token.NoPos, fn, iargs, nil)
if v, ok := r.(tuple); ok {
res := make([]reflect.Value, len(v))
for i := 0; i < len(v); i++ {
Expand Down Expand Up @@ -324,7 +338,7 @@ func (fr *frame) runDefer(d *deferred) {
fr.panic = recover()
}
}()
call(fr.i, fr, d.instr.Pos(), d.fn, d.args, d.ssaArgs)
fr.i.call(fr, d.instr.Pos(), d.fn, d.args, d.ssaArgs)
ok = true
}

Expand Down Expand Up @@ -449,7 +463,7 @@ func (i *Interp) visitInstr(fr *frame, instr ssa.Instruction) (func(), continuat
case *ssa.Call:
return func() {
fn, args := i.prepareCall(fr, &instr.Call)
fr.env[instr] = call(i, fr, instr.Pos(), fn, args, instr.Call.Args)
fr.env[instr] = i.call(fr, instr.Pos(), fn, args, instr.Call.Args)
}, kNext

case *ssa.ChangeInterface:
Expand Down Expand Up @@ -588,7 +602,7 @@ func (i *Interp) visitInstr(fr *frame, instr ssa.Instruction) (func(), continuat
fn, args := i.prepareCall(fr, &instr.Call)
atomic.AddInt32(&i.goroutines, 1)
go func() {
call(i, nil, instr.Pos(), fn, args, instr.Call.Args)
i.call(nil, instr.Pos(), fn, args, instr.Call.Args)
atomic.AddInt32(&i.goroutines, -1)
}()

Expand Down Expand Up @@ -913,26 +927,28 @@ func (i *Interp) prepareCall(fr *frame, call *ssa.CallCommon) (fn value, args []
// fn with arguments args, returning its result.
// callpos is the position of the callsite.
//
func call(i *Interp, caller *frame, callpos token.Pos, fn value, args []value, ssaArgs []ssa.Value) value {
i.callerMutex.Lock()
i.caller = caller
i.callerMutex.Unlock()
func (i *Interp) call(caller *frame, callpos token.Pos, fn value, args []value, ssaArgs []ssa.Value) value {
if caller == nil {
caller = i.caller
} else {
i.caller = caller
}
switch fn := fn.(type) {
case *ssa.Function:
if fn == nil {
panic("call of nil function") // nil of func type
}
return callSSA(i, caller, callpos, fn, args, nil)
return i.callSSA(caller, callpos, fn, args, nil)
case *closure:
if fn.Fn == nil {
panic("call of nil closure function") // nil of func type
}
return callSSA(i, caller, callpos, fn.Fn, args, fn.Env)
return i.callSSA(caller, callpos, fn.Fn, args, fn.Env)
case *ssa.Builtin:
return callBuiltin(i, caller, callpos, fn, args, ssaArgs)
return i.callBuiltin(caller, callpos, fn, args, ssaArgs)
default:
if f := reflect.ValueOf(fn); f.Kind() == reflect.Func {
return callReflect(i, caller, callpos, f, args, nil)
return i.callReflect(caller, callpos, f, args, nil)
}
}
panic(fmt.Sprintf("cannot call %T %v", fn, reflect.ValueOf(fn).Kind()))
Expand All @@ -949,7 +965,7 @@ func loc(fset *token.FileSet, pos token.Pos) string {
// and lexical environment env, returning its result.
// callpos is the position of the callsite.
//
func callSSA(i *Interp, caller *frame, callpos token.Pos, fn *ssa.Function, args []value, env []value) value {
func (i *Interp) callSSA(caller *frame, callpos token.Pos, fn *ssa.Function, args []value, env []value) value {
if i.mode&EnableTracing != 0 {
fset := fn.Prog.Fset
// TODO(adonovan): fix: loc() lies for external functions.
Expand All @@ -972,7 +988,7 @@ func callSSA(i *Interp, caller *frame, callpos token.Pos, fn *ssa.Function, args
if i.mode&EnableTracing != 0 {
log.Println("\t(external)")
}
return callReflect(i, caller, callpos, ext, args, nil)
return i.callReflect(caller, callpos, ext, args, nil)
}
if fn.Pkg != nil {
pkgPath := fn.Pkg.Pkg.Path()
Expand All @@ -982,21 +998,21 @@ func callSSA(i *Interp, caller *frame, callpos token.Pos, fn *ssa.Function, args
if i.mode&EnableTracing != 0 {
log.Println("\t(external func)")
}
return callReflect(i, caller, callpos, ext, args, nil)
return i.callReflect(caller, callpos, ext, args, nil)
}
} else if typ, ok := i.loader.LookupReflect(recv.Type()); ok {
//TODO maybe make full name for search
if m, ok := typ.MethodByName(fn.Name()); ok {
if i.mode&EnableTracing != 0 {
log.Println("\t(external reflect method)")
}
return callReflect(i, caller, callpos, m.Func, args, nil)
return i.callReflect(caller, callpos, m.Func, args, nil)
}
// if ext, ok := pkg.Methods[fullName]; ok {
// if i.mode&EnableTracing != 0 {
// log.Println("\t(external method)")
// }
// return callReflect(i, caller, callpos, ext, args, nil)
// return i.callReflect(caller, callpos, ext, args, nil)
// }
}
}
Expand All @@ -1006,7 +1022,7 @@ func callSSA(i *Interp, caller *frame, callpos token.Pos, fn *ssa.Function, args
if fn.Signature.Recv() != nil {
v := reflect.ValueOf(args[0])
if f, ok := v.Type().MethodByName(fn.Name()); ok {
return callReflect(i, caller, callpos, f.Func, args, nil)
return i.callReflect(caller, callpos, f.Func, args, nil)
}
}
if fn.Name() == "init" && fn.Type().String() == "func()" {
Expand Down Expand Up @@ -1040,7 +1056,7 @@ func callSSA(i *Interp, caller *frame, callpos token.Pos, fn *ssa.Function, args
return fr.result
}

func callReflect(i *Interp, caller *frame, callpos token.Pos, fn reflect.Value, args []value, env []value) value {
func (i *Interp) callReflect(caller *frame, callpos token.Pos, fn reflect.Value, args []value, env []value) value {
var ins []reflect.Value
typ := fn.Type()
isVariadic := fn.Type().IsVariadic()
Expand Down Expand Up @@ -1203,13 +1219,14 @@ func setGlobal(i *Interp, pkg *ssa.Package, name string, v value) {

func NewInterp(loader Loader, mainpkg *ssa.Package, mode Mode) (*Interp, error) {
i := &Interp{
fset: mainpkg.Prog.Fset,
prog: mainpkg.Prog,
mainpkg: mainpkg,
globals: make(map[ssa.Value]value),
mode: mode,
goroutines: 1,
types: make(map[types.Type]reflect.Type),
fset: mainpkg.Prog.Fset,
prog: mainpkg.Prog,
mainpkg: mainpkg,
globals: make(map[ssa.Value]value),
mode: mode,
goroutines: 1,
types: make(map[types.Type]reflect.Type),
preloadTypes: make(map[types.Type]reflect.Type),
}
i.loader = loader
i.record = NewTypesRecord(i.loader, i)
Expand All @@ -1219,7 +1236,7 @@ func NewInterp(loader Loader, mainpkg *ssa.Package, mode Mode) (*Interp, error)
for _, m := range pkg.Members {
switch v := m.(type) {
case *ssa.Global:
typ := i.toType(deref(v.Type()))
typ := i.preToType(deref(v.Type()))
i.globals[v] = reflect.New(typ).Interface()
}
}
Expand Down Expand Up @@ -1254,7 +1271,7 @@ func (i *Interp) RunFunc(name string, args ...Value) (r Value, err error) {
}
}()
if fn := i.mainpkg.Func(name); fn != nil {
r = call(i, nil, token.NoPos, fn, args, nil)
r = i.call(nil, token.NoPos, fn, args, nil)
} else {
err = fmt.Errorf("no function %v", name)
}
Expand Down Expand Up @@ -1286,7 +1303,7 @@ func (i *Interp) Run(entry string) (exitCode int, err error) {
}
}()
if mainFn := i.mainpkg.Func(entry); mainFn != nil {
call(i, nil, token.NoPos, mainFn, nil, nil)
i.call(nil, token.NoPos, mainFn, nil, nil)
exitCode = 0
} else {
err = fmt.Errorf("no function %v", entry)
Expand Down Expand Up @@ -1344,6 +1361,12 @@ func (i *Interp) GetType(key string) (reflect.Type, bool) {
return i.toType(t.Type()), true
}

func (i *Interp) PreloadTypes(ts map[types.Type]bool) {
for t := range ts {
i.preToType(t)
}
}

// deref returns a pointer's element type; otherwise it returns typ.
// TODO(adonovan): Import from ssa?
func deref(typ types.Type) types.Type {
Expand Down
Loading