From 12a7b6f66bb2b634a2147c2c90c6de59d2b49045 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Ib=C3=A1=C3=B1ez?= Date: Sat, 1 May 2021 01:36:14 +0200 Subject: [PATCH] [main] Add -- cmd flag separator to start different Gost instances --- cmd/gost/main.go | 192 ++++++++++++++++++++++++++++++----------------- 1 file changed, 125 insertions(+), 67 deletions(-) diff --git a/cmd/gost/main.go b/cmd/gost/main.go index f08f4132b..5371d7903 100644 --- a/cmd/gost/main.go +++ b/cmd/gost/main.go @@ -7,6 +7,8 @@ import ( "fmt" "net/http" "os" + "sync" + "strings" "runtime" _ "net/http/pprof" @@ -16,56 +18,148 @@ import ( ) var ( - configureFile string - baseCfg = &baseConfig{} - pprofAddr string pprofEnabled = os.Getenv("PROFILING") != "" ) func init() { gost.SetLogger(&gost.LogLogger{}) + // TODO - Generate different certificates for each worker + generateTlsCertificate() +} + +func main() { + var wg sync.WaitGroup + wg.Add(1) // Gost must exit if any of the workers exit + + // Split os.Args using -- and create a worker with each slice + args := strings.Split(" " + strings.Join(os.Args[1:], " ") + " ", " -- ") + if strings.Join(args, "") == "" { + // Fix to show gost help if the resulting array is empty + args[0] = " " + } + for wid, wargs := range args { + if wargs != "" { + go worker(wid, wargs, &wg) + } + } + wg.Wait() +} + +func worker(id int, args string, wg *sync.WaitGroup) { + defer wg.Done() + var ( - printVersion bool + configureFile string + baseCfg = &baseConfig{} + pprofAddr string ) - flag.Var(&baseCfg.route.ChainNodes, "F", "forward address, can make a forward chain") - flag.Var(&baseCfg.route.ServeNodes, "L", "listen address, can listen on multiple ports (required)") - flag.StringVar(&configureFile, "C", "", "configure file") - flag.BoolVar(&baseCfg.Debug, "D", false, "enable debug log") - flag.BoolVar(&printVersion, "V", false, "print version") - if pprofEnabled { - flag.StringVar(&pprofAddr, "P", ":6060", "profiling HTTP server address") - } - flag.Parse() + init := func () error { + var printVersion bool + + wf := flag.NewFlagSet(os.Args[0], flag.ExitOnError) + + wf.Var(&baseCfg.route.ChainNodes, "F", "forward address, can make a forward chain") + wf.Var(&baseCfg.route.ServeNodes, "L", "listen address, can listen on multiple ports (required)") + wf.StringVar(&configureFile, "C", "", "configure file") + wf.BoolVar(&baseCfg.Debug, "D", false, "enable debug log") + wf.BoolVar(&printVersion, "V", false, "print version") + + if pprofEnabled { + // Every worker uses a different profiling server by default + wf.StringVar(&pprofAddr, "P", fmt.Sprintf(":606%d", id), "profiling HTTP server address") + } + + wf.Parse(strings.Fields(args)) + + if printVersion { + fmt.Fprintf(os.Stdout, "gost %s (%s %s/%s)\n", gost.Version, runtime.Version(), runtime.GOOS, runtime.GOARCH) + os.Exit(0) + } else if wf.NFlag() == 0 { + wf.Usage() + os.Exit(0) + } else if configureFile != "" { + err := parseBaseConfig(configureFile, baseCfg) + if err != nil { + return err + } + } - if printVersion { - fmt.Fprintf(os.Stdout, "gost %s (%s %s/%s)\n", - gost.Version, runtime.Version(), runtime.GOOS, runtime.GOARCH) - os.Exit(0) + if baseCfg.route.ServeNodes.String() == "[]" { + configErrMsg := "" + if configureFile != "" { + configErrMsg = " or ServeNodes inside config file (-C)" + } + fmt.Fprintf(os.Stderr, "\n[!] Error: Missing -L flag%s\n\n", configErrMsg) + wf.Usage() + os.Exit(1) + } + + return nil } - if configureFile != "" { - _, err := parseBaseConfig(configureFile) + start := func () error { + // TODO - Make debug worker independent + if ! gost.Debug { + gost.Debug = baseCfg.Debug + } + + var routers []router + rts, err := baseCfg.route.GenRouters() if err != nil { - log.Log(err) - os.Exit(1) + return err + } + routers = append(routers, rts...) + + for _, route := range baseCfg.Routes { + rts, err := route.GenRouters() + if err != nil { + return err + } + routers = append(routers, rts...) + } + + if len(routers) == 0 { + return errors.New("invalid config") } + for i := range routers { + go routers[i].Serve() + } + + return nil } - if flag.NFlag() == 0 { - flag.PrintDefaults() - os.Exit(0) + + main := func () error { + if pprofEnabled { + go func() { + log.Log("profiling server on", pprofAddr) + log.Log(http.ListenAndServe(pprofAddr, nil)) + }() + } + + err := start() + return err } -} -func main() { - if pprofEnabled { - go func() { - log.Log("profiling server on", pprofAddr) - log.Log(http.ListenAndServe(pprofAddr, nil)) - }() + if err := init(); err != nil { + log.Log(err) + return + } + if err := main(); err != nil { + log.Log(err) + return } + // Allow local functions to be garbage-collected + init = nil + main = nil + start = nil + + select {} +} + +func generateTlsCertificate() { // NOTE: as of 2.6, you can use custom cert/key files to initialize the default certificate. tlsConfig, err := tlsConfig(defaultCertFile, defaultKeyFile, "") if err != nil { @@ -81,41 +175,5 @@ func main() { } else { log.Log("load TLS certificate files OK") } - gost.DefaultTLSConfig = tlsConfig - - if err := start(); err != nil { - log.Log(err) - os.Exit(1) - } - - select {} -} - -func start() error { - gost.Debug = baseCfg.Debug - - var routers []router - rts, err := baseCfg.route.GenRouters() - if err != nil { - return err - } - routers = append(routers, rts...) - - for _, route := range baseCfg.Routes { - rts, err := route.GenRouters() - if err != nil { - return err - } - routers = append(routers, rts...) - } - - if len(routers) == 0 { - return errors.New("invalid config") - } - for i := range routers { - go routers[i].Serve() - } - - return nil }