diff --git a/mksyscall_windows.go b/mksyscall_windows.go index a76bb4414c..82393ca367 100644 --- a/mksyscall_windows.go +++ b/mksyscall_windows.go @@ -57,6 +57,9 @@ import ( "io/ioutil" "log" "os" + "path/filepath" + "runtime" + "sort" "strconv" "strings" "text/template" @@ -65,6 +68,7 @@ import ( var ( filename = flag.String("output", "", "output file name (standard output if omitted)") printTraceFlag = flag.Bool("trace", false, "generate print statement after every syscall") + systemDLL = flag.Bool("systemdll", true, "whether all DLLs should be loaded from the Windows system directory") ) func trim(s string) string { @@ -277,7 +281,7 @@ func (r *Rets) SetReturnValuesCode() string { func (r *Rets) useLongHandleErrorCode(retvar string) string { const code = `if %s { if e1 != 0 { - err = error(e1) + err = errnoErr(e1) } else { err = %sEINVAL } @@ -607,7 +611,6 @@ func (f *Fn) IsNotDuplicate() bool { uniqDllFuncName[funcName] = true return true } - return false } @@ -621,8 +624,20 @@ func (f *Fn) HelperName() string { // Source files and functions. type Source struct { - Funcs []*Fn - Files []string + Funcs []*Fn + Files []string + StdLibImports []string + ExternalImports []string +} + +func (src *Source) Import(pkg string) { + src.StdLibImports = append(src.StdLibImports, pkg) + sort.Strings(src.StdLibImports) +} + +func (src *Source) ExternalImport(pkg string) { + src.ExternalImports = append(src.ExternalImports, pkg) + sort.Strings(src.ExternalImports) } // ParseFiles parses files listed in fs and extracts all syscall @@ -632,6 +647,10 @@ func ParseFiles(fs []string) (*Source, error) { src := &Source{ Funcs: make([]*Fn, 0), Files: make([]string, 0), + StdLibImports: []string{ + "unsafe", + }, + ExternalImports: make([]string, 0), } for _, file := range fs { if err := src.ParseFile(file); err != nil { @@ -702,14 +721,81 @@ func (src *Source) ParseFile(path string) error { return nil } +// IsStdRepo returns true if src is part of standard library. +func (src *Source) IsStdRepo() (bool, error) { + if len(src.Files) == 0 { + return false, errors.New("no input files provided") + } + abspath, err := filepath.Abs(src.Files[0]) + if err != nil { + return false, err + } + goroot := runtime.GOROOT() + if runtime.GOOS == "windows" { + abspath = strings.ToLower(abspath) + goroot = strings.ToLower(goroot) + } + sep := string(os.PathSeparator) + if !strings.HasSuffix(goroot, sep) { + goroot += sep + } + return strings.HasPrefix(abspath, goroot), nil +} + // Generate output source file from a source set src. func (src *Source) Generate(w io.Writer) error { + const ( + pkgStd = iota // any package in std library + pkgXSysWindows // x/sys/windows package + pkgOther + ) + isStdRepo, err := src.IsStdRepo() + if err != nil { + return err + } + var pkgtype int + switch { + case isStdRepo: + pkgtype = pkgStd + case packageName == "windows": + // TODO: this needs better logic than just using package name + pkgtype = pkgXSysWindows + default: + pkgtype = pkgOther + } + if *systemDLL { + switch pkgtype { + case pkgStd: + src.Import("internal/syscall/windows/sysdll") + case pkgXSysWindows: + default: + src.ExternalImport("golang.org/x/sys/windows") + } + } + src.ExternalImport("github.com/Microsoft/go-winio") + if packageName != "syscall" { + src.Import("syscall") + } funcMap := template.FuncMap{ "packagename": packagename, "syscalldot": syscalldot, + "newlazydll": func(dll string) string { + arg := "\"" + dll + ".dll\"" + if !*systemDLL { + return syscalldot() + "NewLazyDLL(" + arg + ")" + } + switch pkgtype { + case pkgStd: + return syscalldot() + "NewLazyDLL(sysdll.Add(" + arg + "))" + case pkgXSysWindows: + return "NewLazySystemDLL(" + arg + ")" + default: + return "windows.NewLazySystemDLL(" + arg + ")" + } + }, } t := template.Must(template.New("main").Funcs(funcMap).Parse(srcTemplate)) - err := t.Execute(w, src) + err = t.Execute(w, src) if err != nil { return errors.New("Failed to execute template: " + err.Error()) } @@ -761,12 +847,41 @@ const srcTemplate = ` package {{packagename}} -import "github.com/Microsoft/go-winio" -import "unsafe"{{if syscalldot}} -import "syscall"{{end}} +import ( +{{range .StdLibImports}}"{{.}}" +{{end}} + +{{range .ExternalImports}}"{{.}}" +{{end}} +) var _ unsafe.Pointer +// Do the interface allocations only once for common +// Errno values. +const ( + errnoERROR_IO_PENDING = 997 +) + +var ( + errERROR_IO_PENDING error = {{syscalldot}}Errno(errnoERROR_IO_PENDING) +) + +// errnoErr returns common boxed Errno values, to prevent +// allocations at runtime. +func errnoErr(e {{syscalldot}}Errno) error { + switch e { + case 0: + return nil + case errnoERROR_IO_PENDING: + return errERROR_IO_PENDING + } + // TODO: add more here, after collecting data on the common + // error values see on Windows. (perhaps when running + // all.bat?) + return e +} + var ( {{template "dlls" .}} {{template "funcnames" .}}) @@ -775,7 +890,7 @@ var ( {{/* help functions */}} -{{define "dlls"}}{{range .DLLs}} mod{{.}} = {{syscalldot}}NewLazyDLL("{{.}}.dll") +{{define "dlls"}}{{range .DLLs}} mod{{.}} = {{newlazydll .}} {{end}}{{end}} {{define "funcnames"}}{{range .Funcs}}{{if .IsNotDuplicate}} proc{{.DLLFuncName}} = mod{{.DLLName}}.NewProc("{{.DLLFuncName}}"){{end}} @@ -802,12 +917,13 @@ func {{.HelperName}}({{.HelperParamList}}) {{template "results" .}}{ {{define "results"}}{{if .Rets.List}}{{.Rets.List}} {{end}}{{end}} +{{define "syscall"}}{{.Rets.SetReturnValuesCode}}{{.Syscall}}(proc{{.DLLFuncName}}.Addr(), {{.ParamCount}}, {{.SyscallParamList}}){{end}} + {{define "syscallcheck"}}{{if .ConfirmProc}}if {{.Rets.ErrorVarName}} = proc{{.DLLFuncName}}.Find(); {{.Rets.ErrorVarName}} != nil { return } {{end}}{{end}} -{{define "syscall"}}{{.Rets.SetReturnValuesCode}}{{.Syscall}}(proc{{.DLLFuncName}}.Addr(), {{.ParamCount}}, {{.SyscallParamList}}){{end}} {{define "seterror"}}{{if .Rets.SetErrorCode}} {{.Rets.SetErrorCode}} {{end}}{{end}} diff --git a/zhcsshim.go b/zhcsshim.go index 7c3b7bdf5f..5d1a851ae8 100644 --- a/zhcsshim.go +++ b/zhcsshim.go @@ -2,16 +2,45 @@ package hcsshim -import "github.com/Microsoft/go-winio" -import "unsafe" -import "syscall" +import ( + "syscall" + "unsafe" + + "github.com/Microsoft/go-winio" + "golang.org/x/sys/windows" +) var _ unsafe.Pointer +// Do the interface allocations only once for common +// Errno values. +const ( + errnoERROR_IO_PENDING = 997 +) + +var ( + errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) +) + +// errnoErr returns common boxed Errno values, to prevent +// allocations at runtime. +func errnoErr(e syscall.Errno) error { + switch e { + case 0: + return nil + case errnoERROR_IO_PENDING: + return errERROR_IO_PENDING + } + // TODO: add more here, after collecting data on the common + // error values see on Windows. (perhaps when running + // all.bat?) + return e +} + var ( - modole32 = syscall.NewLazyDLL("ole32.dll") - modiphlpapi = syscall.NewLazyDLL("iphlpapi.dll") - modvmcompute = syscall.NewLazyDLL("vmcompute.dll") + modole32 = windows.NewLazySystemDLL("ole32.dll") + modiphlpapi = windows.NewLazySystemDLL("iphlpapi.dll") + modvmcompute = windows.NewLazySystemDLL("vmcompute.dll") procCoTaskMemFree = modole32.NewProc("CoTaskMemFree") procSetCurrentThreadCompartmentId = modiphlpapi.NewProc("SetCurrentThreadCompartmentId")