diff --git a/cmd/run.go b/cmd/run.go index 0b91ebd2..49edee4f 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -60,6 +60,11 @@ var runCmd = &cobra.Command{ os.Exit(gerr.FailedToLoadPluginConfig) } } + + // Load environment variables for the global configuration. + config.LoadEnvVars(pluginConfig) + + // Unmarshal the plugin configuration for easier access. var pConfig config.PluginConfig if err := pluginConfig.Unmarshal("", &pConfig); err != nil { DefaultLogger.Fatal().Err(err).Msg("Failed to unmarshal plugin configuration") @@ -84,9 +89,13 @@ var runCmd = &cobra.Command{ } } + // Load environment variables for the global configuration. + config.LoadEnvVars(globalConfig) + // Get hooks signature verification policy. hooksConfig.Verification = pConfig.GetVerificationPolicy() + // Unmarshal the global configuration for easier access. var gConfig config.GlobalConfig if err := globalConfig.Unmarshal("", &gConfig); err != nil { DefaultLogger.Fatal().Err(err).Msg("Failed to unmarshal global configuration") @@ -94,8 +103,8 @@ var runCmd = &cobra.Command{ os.Exit(gerr.FailedToLoadGlobalConfig) } - // The config will be passed to the hooks, and in turn to the plugins that - // register to this hook. + // The config will be passed to the plugins that register to the "OnConfigLoaded" hook. + // The plugins can modify the config and return it. updatedGlobalConfig, err := hooksConfig.Run( context.Background(), globalConfig.All(), @@ -105,6 +114,9 @@ var runCmd = &cobra.Command{ DefaultLogger.Error().Err(err).Msg("Failed to run OnConfigLoaded hooks") } + // If the config was modified by the plugins, merge it with the one loaded from the file. + // Only global configuration is merged, which means that plugins cannot modify the plugin + // configurations. if updatedGlobalConfig != nil { // Merge the config with the one loaded from the file (in memory). // The changes won't be persisted to disk. diff --git a/config/constants.go b/config/constants.go index 20232a99..12c6da78 100644 --- a/config/constants.go +++ b/config/constants.go @@ -11,11 +11,13 @@ type ( LogOutput uint ) +// Status is the status of the server. const ( Running Status = iota Stopped ) +// Policy is the policy for hook verification. const ( // Non-strict (permissive) mode. PassDown Policy = iota // Pass down the extra keys/values in result to the next plugins @@ -25,11 +27,13 @@ const ( Remove // Remove the hook from the list on error and continue ) +// CompatPolicy is the compatibility policy for plugins. const ( Strict CompatPolicy = iota Loose ) +// LogOutput is the output type for the logger. const ( Console LogOutput = iota Stdout @@ -40,7 +44,8 @@ const ( const ( // Config constants. - Default = "default" + Default = "default" + EnvPrefix = "GATEWAYD_" // Logger constants. DefaultLogFileName = "gatewayd.log" diff --git a/config/getters.go b/config/getters.go index 59c4fc3b..4ab949f0 100644 --- a/config/getters.go +++ b/config/getters.go @@ -5,52 +5,47 @@ import ( "github.com/rs/zerolog" ) -// verificationPolicy returns the hook verification policy from plugin config file. +// GetVerificationPolicy returns the hook verification policy from plugin config file. func (p PluginConfig) GetVerificationPolicy() Policy { - // vPolicy := pluginConfig.String("plugins.verificationPolicy") - verificationPolicy := PassDown // default switch p.VerificationPolicy { case "ignore": - verificationPolicy = Ignore + return Ignore case "abort": - verificationPolicy = Abort + return Abort case "remove": - verificationPolicy = Remove + return Remove + default: + return PassDown } - - return verificationPolicy } -// pluginCompatPolicy returns the plugin compatibility policy from plugin config file. +// GetPluginCompatPolicy returns the plugin compatibility policy from plugin config file. func (p PluginConfig) GetPluginCompatPolicy() CompatPolicy { - // vPolicy := pluginConfig.String("plugins.compatibilityPolicy") - compatPolicy := Strict // default switch p.CompatibilityPolicy { case "strict": - compatPolicy = Strict + return Strict case "loose": - compatPolicy = Loose + return Loose + default: + return Strict } - - return compatPolicy } -// loadBalancer returns the load balancing algorithm to use. +// GetLoadBalancer returns the load balancing algorithm to use. func (s Server) GetLoadBalancer() gnet.LoadBalancing { - loadBalancer := map[string]gnet.LoadBalancing{ - "roundrobin": gnet.RoundRobin, - "leastconnections": gnet.LeastConnections, - "sourceaddrhash": gnet.SourceAddrHash, - } - - if lb, ok := loadBalancer[s.LoadBalancer]; ok { - return lb + switch s.LoadBalancer { + case "roundrobin": + return gnet.RoundRobin + case "leastconnections": + return gnet.LeastConnections + case "sourceaddrhash": + return gnet.SourceAddrHash + default: + return gnet.RoundRobin } - - return gnet.RoundRobin } -// tcpNoDelay returns the TCP no delay option from config file. +// GetTCPNoDelay returns the TCP no delay option from config file. func (s Server) GetTCPNoDelay() gnet.TCPSocketOpt { if s.TCPNoDelay { return gnet.TCPNoDelay @@ -73,7 +68,7 @@ func (p Pool) GetSize() int { return p.Size } -// output returns the logger output from config file. +// GetOutput returns the logger output from config file. func (l Logger) GetOutput() LogOutput { switch l.Output { case "file": @@ -87,7 +82,7 @@ func (l Logger) GetOutput() LogOutput { } } -// timeFormat returns the logger time format from config file. +// GetTimeFormat returns the logger time format from config file. func (l Logger) GetTimeFormat() string { switch l.TimeFormat { case "unixms": @@ -103,7 +98,7 @@ func (l Logger) GetTimeFormat() string { } } -// level returns the logger level from config file. +// GetLevel returns the logger level from config file. func (l Logger) GetLevel() zerolog.Level { switch l.Level { case "debug": diff --git a/config/types.go b/config/types.go index 976fbc63..7da7650c 100644 --- a/config/types.go +++ b/config/types.go @@ -2,10 +2,12 @@ package config import ( "fmt" + "strings" "time" "github.com/knadh/koanf" "github.com/knadh/koanf/providers/confmap" + "github.com/knadh/koanf/providers/env" ) // // getPath returns the path to the referenced config value. @@ -143,6 +145,8 @@ func LoadGlobalConfigDefaults(cfg *koanf.Koanf) { } } +// LoadPluginConfigDefaults loads the default plugin configuration +// before loading the plugin config file. func LoadPluginConfigDefaults(cfg *koanf.Koanf) { defaultValues := confmap.Provider(map[string]interface{}{ "plugins": map[string]interface{}{ @@ -155,3 +159,13 @@ func LoadPluginConfigDefaults(cfg *koanf.Koanf) { panic(fmt.Errorf("failed to load default plugin configuration: %w", err)) } } + +// LoadEnvVars loads the environment variables into the configuration with the +// given prefix, "GATEWAYD_". +func LoadEnvVars(cfg *koanf.Koanf) { + if err := cfg.Load(env.Provider(EnvPrefix, ".", func(env string) string { + return strings.ReplaceAll(strings.ToLower(strings.TrimPrefix(env, EnvPrefix)), "_", ".") + }), nil); err != nil { + panic(fmt.Errorf("failed to load environment variables: %w", err)) + } +} diff --git a/network/server.go b/network/server.go index 37843de3..f12089ad 100644 --- a/network/server.go +++ b/network/server.go @@ -15,13 +15,6 @@ import ( "github.com/rs/zerolog" ) -type Status string - -const ( - Running Status = "running" - Stopped Status = "stopped" -) - type Server struct { gnet.BuiltinEventEngine engine gnet.Engine @@ -34,7 +27,7 @@ type Server struct { Options []gnet.Option SoftLimit uint64 HardLimit uint64 - Status Status + Status config.Status TickInterval time.Duration } @@ -47,7 +40,7 @@ func (s *Server) OnBoot(engine gnet.Engine) gnet.Action { // Run the OnBooting hooks. _, err := s.hooksConfig.Run( context.Background(), - map[string]interface{}{"status": string(s.Status)}, + map[string]interface{}{"status": fmt.Sprint(s.Status)}, plugin.OnBooting, s.hooksConfig.Verification) if err != nil { @@ -57,12 +50,12 @@ func (s *Server) OnBoot(engine gnet.Engine) gnet.Action { s.engine = engine // Set the server status to running. - s.Status = Running + s.Status = config.Running // Run the OnBooted hooks. _, err = s.hooksConfig.Run( context.Background(), - map[string]interface{}{"status": string(s.Status)}, + map[string]interface{}{"status": fmt.Sprint(s.Status)}, plugin.OnBooted, s.hooksConfig.Verification) if err != nil { @@ -163,7 +156,7 @@ func (s *Server) OnClose(gconn gnet.Conn, err error) gnet.Action { // Shutdown the server if there are no more connections and the server is stopped. // This is used to shutdown the server gracefully. - if uint64(s.engine.CountConnections()) == 0 && s.Status == Stopped { + if uint64(s.engine.CountConnections()) == 0 && s.Status == config.Stopped { return gnet.Shutdown } @@ -250,7 +243,7 @@ func (s *Server) OnShutdown(engine gnet.Engine) { s.proxy.Shutdown() // Set the server status to stopped. This is used to shutdown the server gracefully in OnClose. - s.Status = Stopped + s.Status = config.Stopped } // OnTick is called every TickInterval. It calls the OnTick hooks. @@ -322,12 +315,12 @@ func (s *Server) Shutdown() { s.proxy.Shutdown() // Set the server status to stopped. This is used to shutdown the server gracefully in OnClose. - s.Status = Stopped + s.Status = config.Stopped } // IsRunning returns true if the server is running. func (s *Server) IsRunning() bool { - return s.Status == Running + return s.Status == config.Running } // NewServer creates a new server. @@ -348,7 +341,7 @@ func NewServer( Address: address, Options: options, TickInterval: tickInterval, - Status: Stopped, + Status: config.Stopped, } // Try to resolve the address and log an error if it can't be resolved.