From c94a03c703fcabc8c0b1fdbca26b557fbc3170a7 Mon Sep 17 00:00:00 2001 From: Mostafa Moradian Date: Fri, 13 Jan 2023 20:53:33 +0100 Subject: [PATCH 01/16] Rename hook.Config to hook.Registry --- cmd/run.go | 46 +++++++++++----------- network/proxy.go | 4 +- network/proxy_test.go | 4 +- network/server.go | 52 ++++++++++++------------- network/server_test.go | 10 ++--- plugin/hook/constants.go | 24 ++++++------ plugin/hook/hooks.go | 19 ++++----- plugin/hook/hooks_test.go | 20 +++++----- plugin/plugin_registry.go | 70 +++++++++++++++++----------------- plugin/plugin_registry_test.go | 8 ++-- 10 files changed, 129 insertions(+), 128 deletions(-) diff --git a/cmd/run.go b/cmd/run.go index c8f230e9..58301938 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -24,14 +24,14 @@ import ( ) var ( - hooksConfig = hook.NewHookConfig() + hookRegistry = hook.NewRegistry() DefaultLogger = logging.NewLogger( logging.LoggerConfig{ Level: zerolog.InfoLevel, // Default log level NoColor: true, }, ) - pluginRegistry = plugin.NewRegistry(hooksConfig) + pluginRegistry = plugin.NewRegistry(hookRegistry) // Global koanf instance. Using "." as the key path delimiter. globalConfig = koanf.New(".") // Plugin koanf instance. Using "." as the key path delimiter. @@ -45,7 +45,7 @@ var runCmd = &cobra.Command{ Run: func(cmd *cobra.Command, args []string) { // The plugins are loaded and hooks registered // before the configuration is loaded. - hooksConfig.Logger = DefaultLogger + hookRegistry.Logger = DefaultLogger // Load default plugin configuration. config.LoadPluginConfigDefaults(pluginConfig) @@ -90,7 +90,7 @@ var runCmd = &cobra.Command{ config.LoadEnvVars(globalConfig) // Get hooks signature verification policy. - hooksConfig.Verification = pConfig.GetVerificationPolicy() + hookRegistry.Verification = pConfig.GetVerificationPolicy() // Unmarshal the global configuration for easier access. var gConfig config.GlobalConfig @@ -102,11 +102,11 @@ var runCmd = &cobra.Command{ // The config will be passed to the plugins that register to the "OnConfigLoaded" hook. // The plugins can modify the config and return it. - updatedGlobalConfig, err := hooksConfig.Run( + updatedGlobalConfig, err := hookRegistry.Run( context.Background(), globalConfig.All(), hook.OnConfigLoaded, - hooksConfig.Verification) + hookRegistry.Verification) if err != nil { DefaultLogger.Error().Err(err).Msg("Failed to run OnConfigLoaded hooks") } @@ -139,7 +139,7 @@ var runCmd = &cobra.Command{ }) // Replace the default logger with the new one from the config. - hooksConfig.Logger = logger + hookRegistry.Logger = logger // This is a notification hook, so we don't care about the result. data := map[string]interface{}{ @@ -150,8 +150,8 @@ var runCmd = &cobra.Command{ "fileName": loggerCfg.FileName, } // TODO: Use a context with a timeout - _, err = hooksConfig.Run( - context.Background(), data, hook.OnNewLogger, hooksConfig.Verification) + _, err = hookRegistry.Run( + context.Background(), data, hook.OnNewLogger, hookRegistry.Verification) if err != nil { logger.Error().Err(err).Msg("Failed to run OnNewLogger hooks") } @@ -179,11 +179,11 @@ var runCmd = &cobra.Command{ "tcpKeepAlive": client.TCPKeepAlive, "tcpKeepAlivePeriod": client.TCPKeepAlivePeriod.String(), } - _, err := hooksConfig.Run( + _, err := hookRegistry.Run( context.Background(), clientCfg, hook.OnNewClient, - hooksConfig.Verification) + hookRegistry.Verification) if err != nil { logger.Error().Err(err).Msg("Failed to run OnNewClient hooks") } @@ -207,11 +207,11 @@ var runCmd = &cobra.Command{ os.Exit(1) } - _, err = hooksConfig.Run( + _, err = hookRegistry.Run( context.Background(), map[string]interface{}{"size": poolSize}, hook.OnNewPool, - hooksConfig.Verification) + hookRegistry.Verification) if err != nil { logger.Error().Err(err).Msg("Failed to run OnNewPool hooks") } @@ -222,7 +222,7 @@ var runCmd = &cobra.Command{ healthCheckPeriod := gConfig.Proxy[config.Default].HealthCheckPeriod proxy := network.NewProxy( pool, - hooksConfig, + hookRegistry, elastic, reuseElasticClients, healthCheckPeriod, @@ -245,8 +245,8 @@ var runCmd = &cobra.Command{ "tcpKeepAlivePeriod": clientConfig.TCPKeepAlivePeriod.String(), }, } - _, err = hooksConfig.Run( - context.Background(), proxyCfg, hook.OnNewProxy, hooksConfig.Verification) + _, err = hookRegistry.Run( + context.Background(), proxyCfg, hook.OnNewProxy, hookRegistry.Verification) if err != nil { logger.Error().Err(err).Msg("Failed to run OnNewProxy hooks") } @@ -285,7 +285,7 @@ var runCmd = &cobra.Command{ }, proxy, logger, - hooksConfig, + hookRegistry, ) serverCfg := map[string]interface{}{ @@ -307,8 +307,8 @@ var runCmd = &cobra.Command{ "tcpKeepAlive": gConfig.Server.TCPKeepAlive.String(), "tcpNoDelay": gConfig.Server.TCPNoDelay, } - _, err = hooksConfig.Run( - context.Background(), serverCfg, hook.OnNewServer, hooksConfig.Verification) + _, err = hookRegistry.Run( + context.Background(), serverCfg, hook.OnNewServer, hookRegistry.Verification) if err != nil { logger.Error().Err(err).Msg("Failed to run OnNewServer hooks") } @@ -326,16 +326,16 @@ var runCmd = &cobra.Command{ ) signalsCh := make(chan os.Signal, 1) signal.Notify(signalsCh, signals...) - go func(hooksConfig *hook.Config) { + go func(hookRegistry *hook.Registry) { for sig := range signalsCh { for _, s := range signals { if sig != s { // Notify the hooks that the server is shutting down. - _, err := hooksConfig.Run( + _, err := hookRegistry.Run( context.Background(), map[string]interface{}{"signal": sig.String()}, hook.OnSignal, - hooksConfig.Verification, + hookRegistry.Verification, ) if err != nil { logger.Error().Err(err).Msg("Failed to run OnSignal hooks") @@ -347,7 +347,7 @@ var runCmd = &cobra.Command{ } } } - }(hooksConfig) + }(hookRegistry) // Run the server. if err := server.Run(); err != nil { diff --git a/network/proxy.go b/network/proxy.go index 896c1327..f6d7e9a7 100644 --- a/network/proxy.go +++ b/network/proxy.go @@ -26,7 +26,7 @@ type Proxy struct { availableConnections pool.IPool busyConnections pool.IPool logger zerolog.Logger - hookConfig *hook.Config + hookConfig *hook.Registry scheduler *gocron.Scheduler Elastic bool @@ -41,7 +41,7 @@ var _ IProxy = &Proxy{} // NewProxy creates a new proxy. func NewProxy( - connPool pool.IPool, hookConfig *hook.Config, + connPool pool.IPool, hookConfig *hook.Registry, elastic, reuseElasticClients bool, healthCheckPeriod time.Duration, clientConfig *config.Client, logger zerolog.Logger, diff --git a/network/proxy_test.go b/network/proxy_test.go index a4c198de..68fca5a3 100644 --- a/network/proxy_test.go +++ b/network/proxy_test.go @@ -54,7 +54,7 @@ func TestNewProxy(t *testing.T) { // Create a proxy with a fixed buffer pool proxy := NewProxy( - pool, hook.NewHookConfig(), false, false, config.DefaultHealthCheckPeriod, nil, logger) + pool, hook.NewRegistry(), false, false, config.DefaultHealthCheckPeriod, nil, logger) assert.NotNil(t, proxy) assert.Equal(t, 0, proxy.busyConnections.Size(), "Proxy should have no connected clients") @@ -83,7 +83,7 @@ func TestNewProxyElastic(t *testing.T) { pool := pool.NewPool(config.EmptyPoolCapacity) // Create a proxy with an elastic buffer pool - proxy := NewProxy(pool, hook.NewHookConfig(), true, false, config.DefaultHealthCheckPeriod, + proxy := NewProxy(pool, hook.NewRegistry(), true, false, config.DefaultHealthCheckPeriod, &config.Client{ Network: "tcp", Address: "localhost:5432", diff --git a/network/server.go b/network/server.go index fdbb1d37..215b383c 100644 --- a/network/server.go +++ b/network/server.go @@ -17,10 +17,10 @@ import ( type Server struct { gnet.BuiltinEventEngine - engine gnet.Engine - proxy IProxy - logger zerolog.Logger - hooksConfig *hook.Config + engine gnet.Engine + proxy IProxy + logger zerolog.Logger + hookRegistry *hook.Registry Network string // tcp/udp/unix Address string @@ -38,11 +38,11 @@ func (s *Server) OnBoot(engine gnet.Engine) gnet.Action { s.logger.Debug().Msg("GatewayD is booting...") // Run the OnBooting hooks. - _, err := s.hooksConfig.Run( + _, err := s.hookRegistry.Run( context.Background(), map[string]interface{}{"status": fmt.Sprint(s.Status)}, hook.OnBooting, - s.hooksConfig.Verification) + s.hookRegistry.Verification) if err != nil { s.logger.Error().Err(err).Msg("Failed to run OnBooting hook") } @@ -53,11 +53,11 @@ func (s *Server) OnBoot(engine gnet.Engine) gnet.Action { s.Status = config.Running // Run the OnBooted hooks. - _, err = s.hooksConfig.Run( + _, err = s.hookRegistry.Run( context.Background(), map[string]interface{}{"status": fmt.Sprint(s.Status)}, hook.OnBooted, - s.hooksConfig.Verification) + s.hookRegistry.Verification) if err != nil { s.logger.Error().Err(err).Msg("Failed to run OnBooted hook") } @@ -80,8 +80,8 @@ func (s *Server) OnOpen(gconn gnet.Conn) ([]byte, gnet.Action) { "remote": gconn.RemoteAddr().String(), }, } - _, err := s.hooksConfig.Run( - context.Background(), onOpeningData, hook.OnOpening, s.hooksConfig.Verification) + _, err := s.hookRegistry.Run( + context.Background(), onOpeningData, hook.OnOpening, s.hookRegistry.Verification) if err != nil { s.logger.Error().Err(err).Msg("Failed to run OnOpening hook") } @@ -121,8 +121,8 @@ func (s *Server) OnOpen(gconn gnet.Conn) ([]byte, gnet.Action) { "remote": gconn.RemoteAddr().String(), }, } - _, err = s.hooksConfig.Run( - context.Background(), onOpenedData, hook.OnOpened, s.hooksConfig.Verification) + _, err = s.hookRegistry.Run( + context.Background(), onOpenedData, hook.OnOpened, s.hookRegistry.Verification) if err != nil { s.logger.Error().Err(err).Msg("Failed to run OnOpened hook") } @@ -148,8 +148,8 @@ func (s *Server) OnClose(gconn gnet.Conn, err error) gnet.Action { if err != nil { data["error"] = err.Error() } - _, gatewaydErr := s.hooksConfig.Run( - context.Background(), data, hook.OnClosing, s.hooksConfig.Verification) + _, gatewaydErr := s.hookRegistry.Run( + context.Background(), data, hook.OnClosing, s.hookRegistry.Verification) if gatewaydErr != nil { s.logger.Error().Err(gatewaydErr).Msg("Failed to run OnClosing hook") } @@ -179,8 +179,8 @@ func (s *Server) OnClose(gconn gnet.Conn, err error) gnet.Action { if err != nil { data["error"] = err.Error() } - _, gatewaydErr = s.hooksConfig.Run( - context.Background(), data, hook.OnClosed, s.hooksConfig.Verification) + _, gatewaydErr = s.hookRegistry.Run( + context.Background(), data, hook.OnClosed, s.hookRegistry.Verification) if gatewaydErr != nil { s.logger.Error().Err(gatewaydErr).Msg("Failed to run OnClosed hook") } @@ -198,8 +198,8 @@ func (s *Server) OnTraffic(gconn gnet.Conn) gnet.Action { "remote": gconn.RemoteAddr().String(), }, } - _, err := s.hooksConfig.Run( - context.Background(), onTrafficData, hook.OnTraffic, s.hooksConfig.Verification) + _, err := s.hookRegistry.Run( + context.Background(), onTrafficData, hook.OnTraffic, s.hookRegistry.Verification) if err != nil { s.logger.Error().Err(err).Msg("Failed to run OnTraffic hook") } @@ -231,11 +231,11 @@ func (s *Server) OnShutdown(engine gnet.Engine) { s.logger.Debug().Msg("GatewayD is shutting down...") // Run the OnShutdown hooks. - _, err := s.hooksConfig.Run( + _, err := s.hookRegistry.Run( context.Background(), map[string]interface{}{"connections": s.engine.CountConnections()}, hook.OnShutdown, - s.hooksConfig.Verification) + s.hookRegistry.Verification) if err != nil { s.logger.Error().Err(err).Msg("Failed to run OnShutdown hook") } @@ -254,11 +254,11 @@ func (s *Server) OnTick() (time.Duration, gnet.Action) { "Active client connections") // Run the OnTick hooks. - _, err := s.hooksConfig.Run( + _, err := s.hookRegistry.Run( context.Background(), map[string]interface{}{"connections": s.engine.CountConnections()}, hook.OnTick, - s.hooksConfig.Verification) + s.hookRegistry.Verification) if err != nil { s.logger.Error().Err(err).Msg("Failed to run OnTick hook") } @@ -284,8 +284,8 @@ func (s *Server) Run() error { if err != nil && err.Unwrap() != nil { onRunData["error"] = err.OriginalError.Error() } - result, err := s.hooksConfig.Run( - context.Background(), onRunData, hook.OnRun, s.hooksConfig.Verification) + result, err := s.hookRegistry.Run( + context.Background(), onRunData, hook.OnRun, s.hookRegistry.Verification) if err != nil { s.logger.Error().Err(err).Msg("Failed to run the hook") } @@ -334,7 +334,7 @@ func NewServer( options []gnet.Option, proxy IProxy, logger zerolog.Logger, - hooksConfig *hook.Config, + hookRegistry *hook.Registry, ) *Server { // Create the server. server := Server{ @@ -390,7 +390,7 @@ func NewServer( server.proxy = proxy server.logger = logger - server.hooksConfig = hooksConfig + server.hookRegistry = hookRegistry return &server } diff --git a/network/server_test.go b/network/server_test.go index 2ca0d568..b3a054d8 100644 --- a/network/server_test.go +++ b/network/server_test.go @@ -39,7 +39,7 @@ func TestRunServer(t *testing.T) { logger := logging.NewLogger(cfg) - hooksConfig := hook.NewHookConfig() + hookRegistry := hook.NewRegistry() onTrafficFromClient := func( ctx context.Context, @@ -67,7 +67,7 @@ func TestRunServer(t *testing.T) { assert.Empty(t, paramsMap["error"]) return params, nil } - hooksConfig.Add(hook.OnTrafficFromClient, 1, onTrafficFromClient) + hookRegistry.Add(hook.OnTrafficFromClient, 1, onTrafficFromClient) onTrafficFromServer := func( ctx context.Context, @@ -92,7 +92,7 @@ func TestRunServer(t *testing.T) { assert.Empty(t, paramsMap["error"]) return params, nil } - hooksConfig.Add(hook.OnTrafficFromServer, 1, onTrafficFromServer) + hookRegistry.Add(hook.OnTrafficFromServer, 1, onTrafficFromServer) clientConfig := config.Client{ Network: "tcp", @@ -116,7 +116,7 @@ func TestRunServer(t *testing.T) { // Create a proxy with a fixed buffer pool. proxy := NewProxy( - pool, hooksConfig, false, false, config.DefaultHealthCheckPeriod, &clientConfig, logger) + pool, hookRegistry, false, false, config.DefaultHealthCheckPeriod, &clientConfig, logger) // Create a server. server := NewServer( @@ -132,7 +132,7 @@ func TestRunServer(t *testing.T) { }, proxy, logger, - hooksConfig, + hookRegistry, ) assert.NotNil(t, server) diff --git a/plugin/hook/constants.go b/plugin/hook/constants.go index b9c4e3e6..a3f8ccdf 100644 --- a/plugin/hook/constants.go +++ b/plugin/hook/constants.go @@ -5,24 +5,24 @@ const ( OnConfigLoaded Type = "onConfigLoaded" OnNewLogger Type = "onNewLogger" OnNewPool Type = "onNewPool" + OnNewClient Type = "onNewClient" OnNewProxy Type = "onNewProxy" OnNewServer Type = "onNewServer" OnSignal Type = "onSignal" // Server hooks (network/server.go). - OnRun Type = "onRun" - OnBooting Type = "onBooting" - OnBooted Type = "onBooted" - OnOpening Type = "onOpening" - OnOpened Type = "onOpened" - OnClosing Type = "onClosing" - OnClosed Type = "onClosed" - OnTraffic Type = "onTraffic" + OnRun Type = "onRun" + OnBooting Type = "onBooting" + OnBooted Type = "onBooted" + OnOpening Type = "onOpening" + OnOpened Type = "onOpened" + OnClosing Type = "onClosing" + OnClosed Type = "onClosed" + OnTraffic Type = "onTraffic" + OnShutdown Type = "onShutdown" + OnTick Type = "onTick" + // Proxy hooks (network/proxy.go). OnTrafficFromClient Type = "onTrafficFromClient" OnTrafficToServer Type = "onTrafficToServer" OnTrafficFromServer Type = "onTrafficFromServer" OnTrafficToClient Type = "onTrafficToClient" - OnShutdown Type = "onShutdown" - OnTick Type = "onTick" - // Pool hooks (network/pool.go). - OnNewClient Type = "onNewClient" ) diff --git a/plugin/hook/hooks.go b/plugin/hook/hooks.go index 20784cf8..9b2971d8 100644 --- a/plugin/hook/hooks.go +++ b/plugin/hook/hooks.go @@ -12,26 +12,27 @@ import ( "google.golang.org/protobuf/types/known/structpb" ) -type Config struct { - hooks map[Type]map[Priority]FunctionType +type Registry struct { + hooks map[Type]map[Priority]FunctionType + Logger zerolog.Logger Verification config.Policy } -// NewHookConfig returns a new Config. -func NewHookConfig() *Config { - return &Config{ +// NewRegistry returns a new Config. +func NewRegistry() *Registry { + return &Registry{ hooks: map[Type]map[Priority]FunctionType{}, } } // Hooks returns the hooks. -func (h *Config) Hooks() map[Type]map[Priority]FunctionType { +func (h *Registry) Hooks() map[Type]map[Priority]FunctionType { return h.hooks } // Add adds a hook with a priority to the hooks map. -func (h *Config) Add(hookType Type, prio Priority, hookFunc FunctionType) { +func (h *Registry) Add(hookType Type, prio Priority, hookFunc FunctionType) { if len(h.hooks[hookType]) == 0 { h.hooks[hookType] = map[Priority]FunctionType{prio: hookFunc} } else { @@ -48,7 +49,7 @@ func (h *Config) Add(hookType Type, prio Priority, hookFunc FunctionType) { } // Get returns the hooks of a specific type. -func (h *Config) Get(hookType Type) map[Priority]FunctionType { +func (h *Registry) Get(hookType Type) map[Priority]FunctionType { return h.hooks[hookType] } @@ -66,7 +67,7 @@ func (h *Config) Get(hookType Type) map[Priority]FunctionType { // The opts are passed to the hooks as well to allow them to use the grpc.CallOption. // //nolint:funlen -func (h *Config) Run( +func (h *Registry) Run( ctx context.Context, args map[string]interface{}, hookType Type, diff --git a/plugin/hook/hooks_test.go b/plugin/hook/hooks_test.go index 61ff4891..433385f2 100644 --- a/plugin/hook/hooks_test.go +++ b/plugin/hook/hooks_test.go @@ -12,13 +12,13 @@ import ( // Test_NewHookConfig tests the NewHookConfig function. func Test_NewHookConfig(t *testing.T) { - hc := NewHookConfig() + hc := NewRegistry() assert.NotNil(t, hc) } // Test_HookConfig_Add tests the Add function. func Test_HookConfig_Add(t *testing.T) { - hooks := NewHookConfig() + hooks := NewRegistry() testFunc := func( ctx context.Context, args *structpb.Struct, @@ -33,7 +33,7 @@ func Test_HookConfig_Add(t *testing.T) { // Test_HookConfig_Add_Multiple_Hooks tests the Add function with multiple hooks. func Test_HookConfig_Add_Multiple_Hooks(t *testing.T) { - hooks := NewHookConfig() + hooks := NewRegistry() hooks.Add(OnNewLogger, 0, func( ctx context.Context, args *structpb.Struct, @@ -54,7 +54,7 @@ func Test_HookConfig_Add_Multiple_Hooks(t *testing.T) { // Test_HookConfig_Get tests the Get function. func Test_HookConfig_Get(t *testing.T) { - hooks := NewHookConfig() + hooks := NewRegistry() testFunc := func( ctx context.Context, args *structpb.Struct, @@ -70,7 +70,7 @@ func Test_HookConfig_Get(t *testing.T) { // Test_HookConfig_Run tests the Run function. func Test_HookConfig_Run(t *testing.T) { - hooks := NewHookConfig() + hooks := NewRegistry() hooks.Add(OnNewLogger, 0, func( ctx context.Context, args *structpb.Struct, @@ -86,7 +86,7 @@ func Test_HookConfig_Run(t *testing.T) { // Test_HookConfig_Run_PassDown tests the Run function with the PassDown option. func Test_HookConfig_Run_PassDown(t *testing.T) { - hooks := NewHookConfig() + hooks := NewRegistry() // The result of the hook will be nil and will be passed down to the next hooks.Add(OnNewLogger, 0, func( ctx context.Context, @@ -122,7 +122,7 @@ func Test_HookConfig_Run_PassDown(t *testing.T) { // Test_HookConfig_Run_PassDown_2 tests the Run function with the PassDown option. func Test_HookConfig_Run_PassDown_2(t *testing.T) { - hooks := NewHookConfig() + hooks := NewRegistry() // The result of the hook will be nil and will be passed down to the next hooks.Add(OnNewLogger, 0, func( ctx context.Context, @@ -163,7 +163,7 @@ func Test_HookConfig_Run_PassDown_2(t *testing.T) { // Test_HookConfig_Run_Ignore tests the Run function with the Ignore option. func Test_HookConfig_Run_Ignore(t *testing.T) { - hooks := NewHookConfig() + hooks := NewRegistry() // This should not run, because the return value is not the same as the params hooks.Add(OnNewLogger, 0, func( ctx context.Context, @@ -199,7 +199,7 @@ func Test_HookConfig_Run_Ignore(t *testing.T) { // Test_HookConfig_Run_Abort tests the Run function with the Abort option. func Test_HookConfig_Run_Abort(t *testing.T) { - hooks := NewHookConfig() + hooks := NewRegistry() // This should not run, because the return value is not the same as the params hooks.Add(OnNewLogger, 0, func( ctx context.Context, @@ -229,7 +229,7 @@ func Test_HookConfig_Run_Abort(t *testing.T) { // Test_HookConfig_Run_Remove tests the Run function with the Remove option. func Test_HookConfig_Run_Remove(t *testing.T) { - hooks := NewHookConfig() + hooks := NewRegistry() // This should not run, because the return value is not the same as the params hooks.Add(OnNewLogger, 0, func( ctx context.Context, diff --git a/plugin/plugin_registry.go b/plugin/plugin_registry.go index 56b7af2b..b0c55ebe 100644 --- a/plugin/plugin_registry.go +++ b/plugin/plugin_registry.go @@ -29,17 +29,17 @@ type IPluginRegistry interface { type PluginRegistry struct { //nolint:golint,revive plugins pool.IPool - hooksConfig *hook.Config + hookRegistry *hook.Registry CompatPolicy config.CompatPolicy } var _ IPluginRegistry = &PluginRegistry{} // NewRegistry creates a new plugin registry. -func NewRegistry(hooksConfig *hook.Config) *PluginRegistry { +func NewRegistry(hookRegistry *hook.Registry) *PluginRegistry { return &PluginRegistry{ - plugins: pool.NewPool(config.EmptyPoolCapacity), - hooksConfig: hooksConfig, + plugins: pool.NewPool(config.EmptyPoolCapacity), + hookRegistry: hookRegistry, } } @@ -47,7 +47,7 @@ func NewRegistry(hooksConfig *hook.Config) *PluginRegistry { func (reg *PluginRegistry) Add(plugin *Plugin) bool { _, loaded, err := reg.plugins.GetOrPut(plugin.ID, plugin) if err != nil { - reg.hooksConfig.Logger.Error().Err(err).Msg("Failed to add plugin to registry") + reg.hookRegistry.Logger.Error().Err(err).Msg("Failed to add plugin to registry") return false } return loaded @@ -81,14 +81,14 @@ func (reg *PluginRegistry) Exists(name, version, remoteURL string) bool { // Parse the supplied version and the version in the registry. suppliedVer, err := semver.NewVersion(version) if err != nil { - reg.hooksConfig.Logger.Error().Err(err).Msg( + reg.hookRegistry.Logger.Error().Err(err).Msg( "Failed to parse supplied plugin version") return false } registryVer, err := semver.NewVersion(plugin.Version) if err != nil { - reg.hooksConfig.Logger.Error().Err(err).Msg( + reg.hookRegistry.Logger.Error().Err(err).Msg( "Failed to parse plugin version in registry") return false } @@ -99,7 +99,7 @@ func (reg *PluginRegistry) Exists(name, version, remoteURL string) bool { return true } - reg.hooksConfig.Logger.Debug().Str("name", name).Str("version", version).Msg( + reg.hookRegistry.Logger.Debug().Str("name", name).Str("version", version).Msg( "Supplied plugin version is greater than the version in registry") return false } @@ -141,7 +141,7 @@ func (reg *PluginRegistry) LoadPlugins(plugins []config.Plugin) { continue } - reg.hooksConfig.Logger.Debug().Str("name", pCfg.Name).Msg("Loading plugin") + reg.hookRegistry.Logger.Debug().Str("name", pCfg.Name).Msg("Loading plugin") plugin := &Plugin{ ID: Identifier{ Name: pCfg.Name, @@ -156,20 +156,20 @@ func (reg *PluginRegistry) LoadPlugins(plugins []config.Plugin) { // Is the plugin enabled? plugin.Enabled = pCfg.Enabled if !plugin.Enabled { - reg.hooksConfig.Logger.Debug().Str("name", plugin.ID.Name).Msg("Plugin is disabled") + reg.hookRegistry.Logger.Debug().Str("name", plugin.ID.Name).Msg("Plugin is disabled") continue } // File path of the plugin on disk. if plugin.LocalPath == "" { - reg.hooksConfig.Logger.Debug().Str("name", plugin.ID.Name).Msg( + reg.hookRegistry.Logger.Debug().Str("name", plugin.ID.Name).Msg( "Local file of the plugin doesn't exist or is not set") continue } // Checksum of the plugin. if plugin.ID.Checksum == "" { - reg.hooksConfig.Logger.Debug().Str("name", plugin.ID.Name).Msg( + reg.hookRegistry.Logger.Debug().Str("name", plugin.ID.Name).Msg( "Checksum of plugin doesn't exist or is not set") continue } @@ -177,11 +177,11 @@ func (reg *PluginRegistry) LoadPlugins(plugins []config.Plugin) { // Verify the checksum. // TODO: Load the plugin from a remote location if the checksum didn't match? if sum, err := utils.SHA256SUM(plugin.LocalPath); err != nil { - reg.hooksConfig.Logger.Debug().Str("name", plugin.ID.Name).Err(err).Msg( + reg.hookRegistry.Logger.Debug().Str("name", plugin.ID.Name).Err(err).Msg( "Failed to calculate checksum") continue } else if sum != plugin.ID.Checksum { - reg.hooksConfig.Logger.Debug().Fields( + reg.hookRegistry.Logger.Debug().Fields( map[string]interface{}{ "calculated": sum, "expected": plugin.ID.Checksum, @@ -197,7 +197,7 @@ func (reg *PluginRegistry) LoadPlugins(plugins []config.Plugin) { // have a priority of 1000 or greater. plugin.Priority = hook.Priority(config.PluginPriorityStart + uint(priority)) - logAdapter := logging.NewHcLogAdapter(®.hooksConfig.Logger, config.LoggerName) + logAdapter := logging.NewHcLogAdapter(®.hookRegistry.Logger, config.LoggerName) plugin.client = goplugin.NewClient( &goplugin.ClientConfig{ @@ -220,22 +220,22 @@ func (reg *PluginRegistry) LoadPlugins(plugins []config.Plugin) { }, ) - reg.hooksConfig.Logger.Debug().Str("name", plugin.ID.Name).Msg("Plugin loaded") + reg.hookRegistry.Logger.Debug().Str("name", plugin.ID.Name).Msg("Plugin loaded") if _, err := plugin.Start(); err != nil { - reg.hooksConfig.Logger.Debug().Str("name", plugin.ID.Name).Err(err).Msg( + reg.hookRegistry.Logger.Debug().Str("name", plugin.ID.Name).Err(err).Msg( "Failed to start plugin") } // Load metadata from the plugin. var metadata *structpb.Struct if pluginV1, err := plugin.Dispense(); err != nil { - reg.hooksConfig.Logger.Debug().Str("name", plugin.ID.Name).Err(err).Msg( + reg.hookRegistry.Logger.Debug().Str("name", plugin.ID.Name).Err(err).Msg( "Failed to dispense plugin") continue } else { if meta, origErr := pluginV1.GetPluginConfig( context.Background(), &structpb.Struct{}); err != nil { - reg.hooksConfig.Logger.Debug().Str("name", plugin.ID.Name).Err(origErr).Msg( + reg.hookRegistry.Logger.Debug().Str("name", plugin.ID.Name).Err(origErr).Msg( "Failed to get plugin metadata") continue } else { @@ -246,12 +246,12 @@ func (reg *PluginRegistry) LoadPlugins(plugins []config.Plugin) { // Retrieve plugin requirements. if err := mapstructure.Decode(metadata.Fields["requires"].GetListValue().AsSlice(), &plugin.Requires); err != nil { - reg.hooksConfig.Logger.Debug().Err(err).Msg("Failed to decode plugin requirements") + reg.hookRegistry.Logger.Debug().Err(err).Msg("Failed to decode plugin requirements") } // Too many requirements or not enough plugins loaded. if len(plugin.Requires) > reg.plugins.Size() { - reg.hooksConfig.Logger.Debug().Msg( + reg.hookRegistry.Logger.Debug().Msg( "The plugin has too many requirements, " + "and not enough of them exist in the registry, so it won't work properly") } @@ -259,19 +259,19 @@ func (reg *PluginRegistry) LoadPlugins(plugins []config.Plugin) { // Check if the plugin requirements are met. for _, req := range plugin.Requires { if !reg.Exists(req.Name, req.Version, req.RemoteURL) { - reg.hooksConfig.Logger.Debug().Fields( + reg.hookRegistry.Logger.Debug().Fields( map[string]interface{}{ "name": plugin.ID.Name, "requirement": req.Name, }, ).Msg("The plugin requirement is not met, so it won't work properly") if reg.CompatPolicy == config.Strict { - reg.hooksConfig.Logger.Debug().Str("name", plugin.ID.Name).Msg( + reg.hookRegistry.Logger.Debug().Str("name", plugin.ID.Name).Msg( "Registry is in strict compatibility mode, so the plugin won't be loaded") plugin.Stop() // Stop the plugin. continue } else { - reg.hooksConfig.Logger.Debug().Fields( + reg.hookRegistry.Logger.Debug().Fields( map[string]interface{}{ "name": plugin.ID.Name, "requirement": req.Name, @@ -290,12 +290,12 @@ func (reg *PluginRegistry) LoadPlugins(plugins []config.Plugin) { // Retrieve authors. if err := mapstructure.Decode(metadata.Fields["authors"].GetListValue().AsSlice(), &plugin.Authors); err != nil { - reg.hooksConfig.Logger.Debug().Err(err).Msg("Failed to decode plugin authors") + reg.hookRegistry.Logger.Debug().Err(err).Msg("Failed to decode plugin authors") } // Retrieve hooks. if err := mapstructure.Decode(metadata.Fields["hooks"].GetListValue().AsSlice(), &plugin.Hooks); err != nil { - reg.hooksConfig.Logger.Debug().Err(err).Msg("Failed to decode plugin hooks") + reg.hookRegistry.Logger.Debug().Err(err).Msg("Failed to decode plugin hooks") } // Retrieve plugin config. @@ -304,18 +304,18 @@ func (reg *PluginRegistry) LoadPlugins(plugins []config.Plugin) { if val, ok := value.(string); ok { plugin.Config[key] = val } else { - reg.hooksConfig.Logger.Debug().Str("key", key).Msg( + reg.hookRegistry.Logger.Debug().Str("key", key).Msg( "Failed to decode plugin config") } } - reg.hooksConfig.Logger.Trace().Msgf("Plugin metadata: %+v", plugin) + reg.hookRegistry.Logger.Trace().Msgf("Plugin metadata: %+v", plugin) reg.Add(plugin) - reg.hooksConfig.Logger.Debug().Str("name", plugin.ID.Name).Msg("Plugin metadata loaded") + reg.hookRegistry.Logger.Debug().Str("name", plugin.ID.Name).Msg("Plugin metadata loaded") reg.RegisterHooks(plugin.ID) - reg.hooksConfig.Logger.Debug().Str("name", plugin.ID.Name).Msg("Plugin hooks registered") + reg.hookRegistry.Logger.Debug().Str("name", plugin.ID.Name).Msg("Plugin hooks registered") } } @@ -324,12 +324,12 @@ func (reg *PluginRegistry) LoadPlugins(plugins []config.Plugin) { //nolint:funlen func (reg *PluginRegistry) RegisterHooks(id Identifier) { pluginImpl := reg.Get(id) - reg.hooksConfig.Logger.Debug().Str("name", pluginImpl.ID.Name).Msg( + reg.hookRegistry.Logger.Debug().Str("name", pluginImpl.ID.Name).Msg( "Registering hooks for plugin") var pluginV1 pluginV1.GatewayDPluginServiceClient var err *gerr.GatewayDError if pluginV1, err = pluginImpl.Dispense(); err != nil { - reg.hooksConfig.Logger.Debug().Str("name", pluginImpl.ID.Name).Err(err).Msg( + reg.hookRegistry.Logger.Debug().Str("name", pluginImpl.ID.Name).Err(err).Msg( "Failed to dispense plugin") return } @@ -380,7 +380,7 @@ func (reg *PluginRegistry) RegisterHooks(id Identifier) { case hook.OnNewClient: hookFunc = pluginV1.OnNewClient default: - reg.hooksConfig.Logger.Warn().Fields(map[string]interface{}{ + reg.hookRegistry.Logger.Warn().Fields(map[string]interface{}{ "hook": string(hookType), "priority": pluginImpl.Priority, "name": pluginImpl.ID.Name, @@ -388,11 +388,11 @@ func (reg *PluginRegistry) RegisterHooks(id Identifier) { "Unknown hook, skipping") continue } - reg.hooksConfig.Logger.Debug().Fields(map[string]interface{}{ + reg.hookRegistry.Logger.Debug().Fields(map[string]interface{}{ "hook": string(hookType), "priority": pluginImpl.Priority, "name": pluginImpl.ID.Name, }).Msg("Registering hook") - reg.hooksConfig.Add(hookType, pluginImpl.Priority, hookFunc) + reg.hookRegistry.Add(hookType, pluginImpl.Priority, hookFunc) } } diff --git a/plugin/plugin_registry_test.go b/plugin/plugin_registry_test.go index c1aaa7b1..5697e291 100644 --- a/plugin/plugin_registry_test.go +++ b/plugin/plugin_registry_test.go @@ -9,12 +9,12 @@ import ( // TestPluginRegistry tests the PluginRegistry. func TestPluginRegistry(t *testing.T) { - hooksConfig := hook.NewHookConfig() - assert.NotNil(t, hooksConfig) - reg := NewRegistry(hooksConfig) + hookRegistry := hook.NewRegistry() + assert.NotNil(t, hookRegistry) + reg := NewRegistry(hookRegistry) assert.NotNil(t, reg) assert.NotNil(t, reg.plugins) - assert.NotNil(t, reg.hooksConfig) + assert.NotNil(t, reg.hookRegistry) assert.Equal(t, 0, len(reg.List())) ident := Identifier{ From 08b65d3d088ef70fc0324e1ae7694166624fc3b3 Mon Sep 17 00:00:00 2001 From: Mostafa Moradian Date: Fri, 13 Jan 2023 20:59:12 +0100 Subject: [PATCH 02/16] Rename files --- plugin/hook/{hooks.go => hook_registry.go} | 0 plugin/hook/{hooks_test.go => hook_registry_test.go} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename plugin/hook/{hooks.go => hook_registry.go} (100%) rename plugin/hook/{hooks_test.go => hook_registry_test.go} (100%) diff --git a/plugin/hook/hooks.go b/plugin/hook/hook_registry.go similarity index 100% rename from plugin/hook/hooks.go rename to plugin/hook/hook_registry.go diff --git a/plugin/hook/hooks_test.go b/plugin/hook/hook_registry_test.go similarity index 100% rename from plugin/hook/hooks_test.go rename to plugin/hook/hook_registry_test.go From 63198725fe6bb7950fbc2fb34b279a9b4e24e895 Mon Sep 17 00:00:00 2001 From: Mostafa Moradian Date: Fri, 13 Jan 2023 23:04:40 +0100 Subject: [PATCH 03/16] Rename FunctionType to Method --- plugin/hook/hook_registry.go | 12 ++++++------ plugin/hook/types.go | 6 +++--- plugin/plugin_registry.go | 2 +- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/plugin/hook/hook_registry.go b/plugin/hook/hook_registry.go index 9b2971d8..9a580140 100644 --- a/plugin/hook/hook_registry.go +++ b/plugin/hook/hook_registry.go @@ -13,7 +13,7 @@ import ( ) type Registry struct { - hooks map[Type]map[Priority]FunctionType + hooks map[Type]map[Priority]Method Logger zerolog.Logger Verification config.Policy @@ -22,19 +22,19 @@ type Registry struct { // NewRegistry returns a new Config. func NewRegistry() *Registry { return &Registry{ - hooks: map[Type]map[Priority]FunctionType{}, + hooks: map[Type]map[Priority]Method{}, } } // Hooks returns the hooks. -func (h *Registry) Hooks() map[Type]map[Priority]FunctionType { +func (h *Registry) Hooks() map[Type]map[Priority]Method { return h.hooks } // Add adds a hook with a priority to the hooks map. -func (h *Registry) Add(hookType Type, prio Priority, hookFunc FunctionType) { +func (h *Registry) Add(hookType Type, prio Priority, hookFunc Method) { if len(h.hooks[hookType]) == 0 { - h.hooks[hookType] = map[Priority]FunctionType{prio: hookFunc} + h.hooks[hookType] = map[Priority]Method{prio: hookFunc} } else { if _, ok := h.hooks[hookType][prio]; ok { h.Logger.Warn().Fields( @@ -49,7 +49,7 @@ func (h *Registry) Add(hookType Type, prio Priority, hookFunc FunctionType) { } // Get returns the hooks of a specific type. -func (h *Registry) Get(hookType Type) map[Priority]FunctionType { +func (h *Registry) Get(hookType Type) map[Priority]Method { return h.hooks[hookType] } diff --git a/plugin/hook/types.go b/plugin/hook/types.go index c32e0467..d0135470 100644 --- a/plugin/hook/types.go +++ b/plugin/hook/types.go @@ -10,8 +10,8 @@ import ( type ( // Priority is the priority of a hook. // Smaller values are executed first (higher priority). - Priority uint - Type string - FunctionType func( + Priority uint + Type string + Method func( context.Context, *structpb.Struct, ...grpc.CallOption) (*structpb.Struct, error) ) diff --git a/plugin/plugin_registry.go b/plugin/plugin_registry.go index b0c55ebe..ea5424b5 100644 --- a/plugin/plugin_registry.go +++ b/plugin/plugin_registry.go @@ -335,7 +335,7 @@ func (reg *PluginRegistry) RegisterHooks(id Identifier) { } for _, hookType := range pluginImpl.Hooks { - var hookFunc hook.FunctionType + var hookFunc hook.Method switch hookType { case hook.OnConfigLoaded: hookFunc = pluginV1.OnConfigLoaded From 5b7288842527aa1751922e9f9e8c5800fad7e908 Mon Sep 17 00:00:00 2001 From: Mostafa Moradian Date: Fri, 13 Jan 2023 23:06:28 +0100 Subject: [PATCH 04/16] Add a new interface for hook registry --- plugin/hook/hook_registry.go | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/plugin/hook/hook_registry.go b/plugin/hook/hook_registry.go index 9a580140..f61a3cfc 100644 --- a/plugin/hook/hook_registry.go +++ b/plugin/hook/hook_registry.go @@ -12,6 +12,19 @@ import ( "google.golang.org/protobuf/types/known/structpb" ) +type IRegistry interface { + Hooks() map[Type]map[Priority]Method + Add(hookType Type, prio Priority, hookFunc Method) + Get(hookType Type) map[Priority]Method + Run( + ctx context.Context, + args map[string]interface{}, + hookType Type, + verification config.Policy, + opts ...grpc.CallOption, + ) (map[string]interface{}, *gerr.GatewayDError) +} + type Registry struct { hooks map[Type]map[Priority]Method @@ -19,6 +32,8 @@ type Registry struct { Verification config.Policy } +var _ IRegistry = &Registry{} + // NewRegistry returns a new Config. func NewRegistry() *Registry { return &Registry{ From 30599895c223d3b6610c2451f698a8825e3f4e3e Mon Sep 17 00:00:00 2001 From: Mostafa Moradian Date: Fri, 13 Jan 2023 23:11:08 +0100 Subject: [PATCH 05/16] Remove hook.Type and replace it with plain string --- plugin/hook/constants.go | 42 ++++++++++++++++++------------------ plugin/hook/hook_registry.go | 20 ++++++++--------- plugin/hook/types.go | 1 - plugin/plugin.go | 2 +- 4 files changed, 32 insertions(+), 33 deletions(-) diff --git a/plugin/hook/constants.go b/plugin/hook/constants.go index a3f8ccdf..8f22ff85 100644 --- a/plugin/hook/constants.go +++ b/plugin/hook/constants.go @@ -2,27 +2,27 @@ package hook const ( // Run command hooks (cmd/run.go). - OnConfigLoaded Type = "onConfigLoaded" - OnNewLogger Type = "onNewLogger" - OnNewPool Type = "onNewPool" - OnNewClient Type = "onNewClient" - OnNewProxy Type = "onNewProxy" - OnNewServer Type = "onNewServer" - OnSignal Type = "onSignal" + OnConfigLoaded string = "onConfigLoaded" + OnNewLogger string = "onNewLogger" + OnNewPool string = "onNewPool" + OnNewClient string = "onNewClient" + OnNewProxy string = "onNewProxy" + OnNewServer string = "onNewServer" + OnSignal string = "onSignal" // Server hooks (network/server.go). - OnRun Type = "onRun" - OnBooting Type = "onBooting" - OnBooted Type = "onBooted" - OnOpening Type = "onOpening" - OnOpened Type = "onOpened" - OnClosing Type = "onClosing" - OnClosed Type = "onClosed" - OnTraffic Type = "onTraffic" - OnShutdown Type = "onShutdown" - OnTick Type = "onTick" + OnRun string = "onRun" + OnBooting string = "onBooting" + OnBooted string = "onBooted" + OnOpening string = "onOpening" + OnOpened string = "onOpened" + OnClosing string = "onClosing" + OnClosed string = "onClosed" + OnTraffic string = "onTraffic" + OnShutdown string = "onShutdown" + OnTick string = "onTick" // Proxy hooks (network/proxy.go). - OnTrafficFromClient Type = "onTrafficFromClient" - OnTrafficToServer Type = "onTrafficToServer" - OnTrafficFromServer Type = "onTrafficFromServer" - OnTrafficToClient Type = "onTrafficToClient" + OnTrafficFromClient string = "onTrafficFromClient" + OnTrafficToServer string = "onTrafficToServer" + OnTrafficFromServer string = "onTrafficFromServer" + OnTrafficToClient string = "onTrafficToClient" ) diff --git a/plugin/hook/hook_registry.go b/plugin/hook/hook_registry.go index f61a3cfc..42c715f1 100644 --- a/plugin/hook/hook_registry.go +++ b/plugin/hook/hook_registry.go @@ -13,20 +13,20 @@ import ( ) type IRegistry interface { - Hooks() map[Type]map[Priority]Method - Add(hookType Type, prio Priority, hookFunc Method) - Get(hookType Type) map[Priority]Method + Hooks() map[string]map[Priority]Method + Add(hookType string, prio Priority, hookFunc Method) + Get(hookType string) map[Priority]Method Run( ctx context.Context, args map[string]interface{}, - hookType Type, + hookType string, verification config.Policy, opts ...grpc.CallOption, ) (map[string]interface{}, *gerr.GatewayDError) } type Registry struct { - hooks map[Type]map[Priority]Method + hooks map[string]map[Priority]Method Logger zerolog.Logger Verification config.Policy @@ -37,17 +37,17 @@ var _ IRegistry = &Registry{} // NewRegistry returns a new Config. func NewRegistry() *Registry { return &Registry{ - hooks: map[Type]map[Priority]Method{}, + hooks: map[string]map[Priority]Method{}, } } // Hooks returns the hooks. -func (h *Registry) Hooks() map[Type]map[Priority]Method { +func (h *Registry) Hooks() map[string]map[Priority]Method { return h.hooks } // Add adds a hook with a priority to the hooks map. -func (h *Registry) Add(hookType Type, prio Priority, hookFunc Method) { +func (h *Registry) Add(hookType string, prio Priority, hookFunc Method) { if len(h.hooks[hookType]) == 0 { h.hooks[hookType] = map[Priority]Method{prio: hookFunc} } else { @@ -64,7 +64,7 @@ func (h *Registry) Add(hookType Type, prio Priority, hookFunc Method) { } // Get returns the hooks of a specific type. -func (h *Registry) Get(hookType Type) map[Priority]Method { +func (h *Registry) Get(hookType string) map[Priority]Method { return h.hooks[hookType] } @@ -85,7 +85,7 @@ func (h *Registry) Get(hookType Type) map[Priority]Method { func (h *Registry) Run( ctx context.Context, args map[string]interface{}, - hookType Type, + hookType string, verification config.Policy, opts ...grpc.CallOption, ) (map[string]interface{}, *gerr.GatewayDError) { diff --git a/plugin/hook/types.go b/plugin/hook/types.go index d0135470..d3c98bb6 100644 --- a/plugin/hook/types.go +++ b/plugin/hook/types.go @@ -11,7 +11,6 @@ type ( // Priority is the priority of a hook. // Smaller values are executed first (higher priority). Priority uint - Type string Method func( context.Context, *structpb.Struct, ...grpc.CallOption) (*structpb.Struct, error) ) diff --git a/plugin/plugin.go b/plugin/plugin.go index 2ae76915..31caefdf 100644 --- a/plugin/plugin.go +++ b/plugin/plugin.go @@ -40,7 +40,7 @@ type Plugin struct { // internal and external config options Config map[string]string // hooks it attaches to - Hooks []hook.Type + Hooks []string Priority hook.Priority // required plugins to be loaded before this one // Built-in plugins are always loaded first From 4d7eeb1359a009b9cff92c78a1d72777dca0c0f4 Mon Sep 17 00:00:00 2001 From: Mostafa Moradian Date: Fri, 13 Jan 2023 23:13:35 +0100 Subject: [PATCH 06/16] Remove unnecessary conversion --- plugin/plugin_registry.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/plugin/plugin_registry.go b/plugin/plugin_registry.go index ea5424b5..77b10a72 100644 --- a/plugin/plugin_registry.go +++ b/plugin/plugin_registry.go @@ -381,7 +381,7 @@ func (reg *PluginRegistry) RegisterHooks(id Identifier) { hookFunc = pluginV1.OnNewClient default: reg.hookRegistry.Logger.Warn().Fields(map[string]interface{}{ - "hook": string(hookType), + "hook": hookType, "priority": pluginImpl.Priority, "name": pluginImpl.ID.Name, }).Msg( @@ -389,7 +389,7 @@ func (reg *PluginRegistry) RegisterHooks(id Identifier) { continue } reg.hookRegistry.Logger.Debug().Fields(map[string]interface{}{ - "hook": string(hookType), + "hook": hookType, "priority": pluginImpl.Priority, "name": pluginImpl.ID.Name, }).Msg("Registering hook") From d027aafc1d3e20322d8f866942a2d3f54722285a Mon Sep 17 00:00:00 2001 From: Mostafa Moradian Date: Fri, 13 Jan 2023 23:14:52 +0100 Subject: [PATCH 07/16] Rename hookType to hookName --- plugin/hook/hook_registry.go | 40 ++++++++++++++++++------------------ plugin/plugin_registry.go | 10 ++++----- 2 files changed, 25 insertions(+), 25 deletions(-) diff --git a/plugin/hook/hook_registry.go b/plugin/hook/hook_registry.go index 42c715f1..fbe0cac5 100644 --- a/plugin/hook/hook_registry.go +++ b/plugin/hook/hook_registry.go @@ -14,12 +14,12 @@ import ( type IRegistry interface { Hooks() map[string]map[Priority]Method - Add(hookType string, prio Priority, hookFunc Method) - Get(hookType string) map[Priority]Method + Add(hookName string, prio Priority, hookFunc Method) + Get(hookName string) map[Priority]Method Run( ctx context.Context, args map[string]interface{}, - hookType string, + hookName string, verification config.Policy, opts ...grpc.CallOption, ) (map[string]interface{}, *gerr.GatewayDError) @@ -47,25 +47,25 @@ func (h *Registry) Hooks() map[string]map[Priority]Method { } // Add adds a hook with a priority to the hooks map. -func (h *Registry) Add(hookType string, prio Priority, hookFunc Method) { - if len(h.hooks[hookType]) == 0 { - h.hooks[hookType] = map[Priority]Method{prio: hookFunc} +func (h *Registry) Add(hookName string, prio Priority, hookFunc Method) { + if len(h.hooks[hookName]) == 0 { + h.hooks[hookName] = map[Priority]Method{prio: hookFunc} } else { - if _, ok := h.hooks[hookType][prio]; ok { + if _, ok := h.hooks[hookName][prio]; ok { h.Logger.Warn().Fields( map[string]interface{}{ - "hookType": hookType, + "hookName": hookName, "priority": prio, }, ).Msg("Hook is replaced") } - h.hooks[hookType][prio] = hookFunc + h.hooks[hookName][prio] = hookFunc } } // Get returns the hooks of a specific type. -func (h *Registry) Get(hookType string) map[Priority]Method { - return h.hooks[hookType] +func (h *Registry) Get(hookName string) map[Priority]Method { + return h.hooks[hookName] } // Run runs the hooks of a specific type. The result of the previous hook is passed @@ -85,7 +85,7 @@ func (h *Registry) Get(hookType string) map[Priority]Method { func (h *Registry) Run( ctx context.Context, args map[string]interface{}, - hookType string, + hookName string, verification config.Policy, opts ...grpc.CallOption, ) (map[string]interface{}, *gerr.GatewayDError) { @@ -111,8 +111,8 @@ func (h *Registry) Run( } // Sort hooks by priority. - priorities := make([]Priority, 0, len(h.hooks[hookType])) - for prio := range h.hooks[hookType] { + priorities := make([]Priority, 0, len(h.hooks[hookName])) + for prio := range h.hooks[hookName] { priorities = append(priorities, prio) } sort.SliceStable(priorities, func(i, j int) bool { @@ -127,9 +127,9 @@ func (h *Registry) Run( var result *structpb.Struct var err error if idx == 0 { - result, err = h.hooks[hookType][prio](inheritedCtx, params, opts...) + result, err = h.hooks[hookName][prio](inheritedCtx, params, opts...) } else { - result, err = h.hooks[hookType][prio](inheritedCtx, returnVal, opts...) + result, err = h.hooks[hookName][prio](inheritedCtx, returnVal, opts...) } // This is done to ensure that the return value of the hook is always valid, @@ -149,7 +149,7 @@ func (h *Registry) Run( case config.Ignore: h.Logger.Error().Err(err).Fields( map[string]interface{}{ - "hookType": hookType, + "hookName": hookName, "priority": prio, }, ).Msg("Hook returned invalid value, ignoring") @@ -160,7 +160,7 @@ func (h *Registry) Run( case config.Abort: h.Logger.Error().Err(err).Fields( map[string]interface{}{ - "hookType": hookType, + "hookName": hookName, "priority": prio, }, ).Msg("Hook returned invalid value, aborting") @@ -172,7 +172,7 @@ func (h *Registry) Run( case config.Remove: h.Logger.Error().Err(err).Fields( map[string]interface{}{ - "hookType": hookType, + "hookName": hookName, "priority": prio, }, ).Msg("Hook returned invalid value, removing") @@ -188,7 +188,7 @@ func (h *Registry) Run( // Remove hooks that failed verification. for _, prio := range removeList { - delete(h.hooks[hookType], prio) + delete(h.hooks[hookName], prio) } return returnVal.AsMap(), nil diff --git a/plugin/plugin_registry.go b/plugin/plugin_registry.go index 77b10a72..aacd75bd 100644 --- a/plugin/plugin_registry.go +++ b/plugin/plugin_registry.go @@ -334,9 +334,9 @@ func (reg *PluginRegistry) RegisterHooks(id Identifier) { return } - for _, hookType := range pluginImpl.Hooks { + for _, hookName := range pluginImpl.Hooks { var hookFunc hook.Method - switch hookType { + switch hookName { case hook.OnConfigLoaded: hookFunc = pluginV1.OnConfigLoaded case hook.OnNewLogger: @@ -381,7 +381,7 @@ func (reg *PluginRegistry) RegisterHooks(id Identifier) { hookFunc = pluginV1.OnNewClient default: reg.hookRegistry.Logger.Warn().Fields(map[string]interface{}{ - "hook": hookType, + "hook": hookName, "priority": pluginImpl.Priority, "name": pluginImpl.ID.Name, }).Msg( @@ -389,10 +389,10 @@ func (reg *PluginRegistry) RegisterHooks(id Identifier) { continue } reg.hookRegistry.Logger.Debug().Fields(map[string]interface{}{ - "hook": hookType, + "hook": hookName, "priority": pluginImpl.Priority, "name": pluginImpl.ID.Name, }).Msg("Registering hook") - reg.hookRegistry.Add(hookType, pluginImpl.Priority, hookFunc) + reg.hookRegistry.Add(hookName, pluginImpl.Priority, hookFunc) } } From 733a019157b1281abb2e75ad45d5b2fd8b746ab4 Mon Sep 17 00:00:00 2001 From: Mostafa Moradian Date: Fri, 13 Jan 2023 23:15:24 +0100 Subject: [PATCH 08/16] Rename prio to priority --- plugin/hook/hook_registry.go | 34 +++++++++++++++---------------- plugin/hook/hook_registry_test.go | 10 ++++----- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/plugin/hook/hook_registry.go b/plugin/hook/hook_registry.go index fbe0cac5..bff75014 100644 --- a/plugin/hook/hook_registry.go +++ b/plugin/hook/hook_registry.go @@ -14,7 +14,7 @@ import ( type IRegistry interface { Hooks() map[string]map[Priority]Method - Add(hookName string, prio Priority, hookFunc Method) + Add(hookName string, priority Priority, hookFunc Method) Get(hookName string) map[Priority]Method Run( ctx context.Context, @@ -47,19 +47,19 @@ func (h *Registry) Hooks() map[string]map[Priority]Method { } // Add adds a hook with a priority to the hooks map. -func (h *Registry) Add(hookName string, prio Priority, hookFunc Method) { +func (h *Registry) Add(hookName string, priority Priority, hookFunc Method) { if len(h.hooks[hookName]) == 0 { - h.hooks[hookName] = map[Priority]Method{prio: hookFunc} + h.hooks[hookName] = map[Priority]Method{priority: hookFunc} } else { - if _, ok := h.hooks[hookName][prio]; ok { + if _, ok := h.hooks[hookName][priority]; ok { h.Logger.Warn().Fields( map[string]interface{}{ "hookName": hookName, - "priority": prio, + "priority": priority, }, ).Msg("Hook is replaced") } - h.hooks[hookName][prio] = hookFunc + h.hooks[hookName][priority] = hookFunc } } @@ -112,8 +112,8 @@ func (h *Registry) Run( // Sort hooks by priority. priorities := make([]Priority, 0, len(h.hooks[hookName])) - for prio := range h.hooks[hookName] { - priorities = append(priorities, prio) + for priority := range h.hooks[hookName] { + priorities = append(priorities, priority) } sort.SliceStable(priorities, func(i, j int) bool { return priorities[i] < priorities[j] @@ -123,13 +123,13 @@ func (h *Registry) Run( returnVal := &structpb.Struct{} var removeList []Priority // The signature of parameters and args MUST be the same for this to work. - for idx, prio := range priorities { + for idx, priority := range priorities { var result *structpb.Struct var err error if idx == 0 { - result, err = h.hooks[hookName][prio](inheritedCtx, params, opts...) + result, err = h.hooks[hookName][priority](inheritedCtx, params, opts...) } else { - result, err = h.hooks[hookName][prio](inheritedCtx, returnVal, opts...) + result, err = h.hooks[hookName][priority](inheritedCtx, returnVal, opts...) } // This is done to ensure that the return value of the hook is always valid, @@ -150,7 +150,7 @@ func (h *Registry) Run( h.Logger.Error().Err(err).Fields( map[string]interface{}{ "hookName": hookName, - "priority": prio, + "priority": priority, }, ).Msg("Hook returned invalid value, ignoring") if idx == 0 { @@ -161,7 +161,7 @@ func (h *Registry) Run( h.Logger.Error().Err(err).Fields( map[string]interface{}{ "hookName": hookName, - "priority": prio, + "priority": priority, }, ).Msg("Hook returned invalid value, aborting") if idx == 0 { @@ -173,10 +173,10 @@ func (h *Registry) Run( h.Logger.Error().Err(err).Fields( map[string]interface{}{ "hookName": hookName, - "priority": prio, + "priority": priority, }, ).Msg("Hook returned invalid value, removing") - removeList = append(removeList, prio) + removeList = append(removeList, priority) if idx == 0 { returnVal = params } @@ -187,8 +187,8 @@ func (h *Registry) Run( } // Remove hooks that failed verification. - for _, prio := range removeList { - delete(h.hooks[hookName], prio) + for _, priority := range removeList { + delete(h.hooks[hookName], priority) } return returnVal.AsMap(), nil diff --git a/plugin/hook/hook_registry_test.go b/plugin/hook/hook_registry_test.go index 433385f2..0e250ec0 100644 --- a/plugin/hook/hook_registry_test.go +++ b/plugin/hook/hook_registry_test.go @@ -62,10 +62,10 @@ func Test_HookConfig_Get(t *testing.T) { ) (*structpb.Struct, error) { return args, nil } - prio := Priority(0) - hooks.Add(OnNewLogger, prio, testFunc) + priority := Priority(0) + hooks.Add(OnNewLogger, priority, testFunc) assert.NotNil(t, hooks.Get(OnNewLogger)) - assert.ObjectsAreEqual(testFunc, hooks.Get(OnNewLogger)[prio]) + assert.ObjectsAreEqual(testFunc, hooks.Get(OnNewLogger)[priority]) } // Test_HookConfig_Run tests the Run function. @@ -109,7 +109,7 @@ func Test_HookConfig_Run_PassDown(t *testing.T) { }) // Although the first hook returns nil, and its signature doesn't match the params, - // so its result (nil) is passed down to the next hook in chain (prio 2). + // so its result (nil) is passed down to the next hook in chain (priority 2). // Then the second hook runs and returns a signature with a "test" key and value. result, err := hooks.Run( context.Background(), @@ -150,7 +150,7 @@ func Test_HookConfig_Run_PassDown_2(t *testing.T) { return args, nil }) // Although the first hook returns nil, and its signature doesn't match the params, - // so its result (nil) is passed down to the next hook in chain (prio 2). + // so its result (nil) is passed down to the next hook in chain (priority 2). // Then the second hook runs and returns a signature with a "test1" and "test2" key and value. result, err := hooks.Run( context.Background(), From 342e8ce88c86ce9290fb90e477a2ac18e8f9b5d2 Mon Sep 17 00:00:00 2001 From: Mostafa Moradian Date: Fri, 13 Jan 2023 23:16:20 +0100 Subject: [PATCH 09/16] Rename hookFunc to hookMethod --- plugin/hook/hook_registry.go | 8 +++---- plugin/plugin_registry.go | 46 ++++++++++++++++++------------------ 2 files changed, 27 insertions(+), 27 deletions(-) diff --git a/plugin/hook/hook_registry.go b/plugin/hook/hook_registry.go index bff75014..dda99991 100644 --- a/plugin/hook/hook_registry.go +++ b/plugin/hook/hook_registry.go @@ -14,7 +14,7 @@ import ( type IRegistry interface { Hooks() map[string]map[Priority]Method - Add(hookName string, priority Priority, hookFunc Method) + Add(hookName string, priority Priority, hookMethod Method) Get(hookName string) map[Priority]Method Run( ctx context.Context, @@ -47,9 +47,9 @@ func (h *Registry) Hooks() map[string]map[Priority]Method { } // Add adds a hook with a priority to the hooks map. -func (h *Registry) Add(hookName string, priority Priority, hookFunc Method) { +func (h *Registry) Add(hookName string, priority Priority, hookMethod Method) { if len(h.hooks[hookName]) == 0 { - h.hooks[hookName] = map[Priority]Method{priority: hookFunc} + h.hooks[hookName] = map[Priority]Method{priority: hookMethod} } else { if _, ok := h.hooks[hookName][priority]; ok { h.Logger.Warn().Fields( @@ -59,7 +59,7 @@ func (h *Registry) Add(hookName string, priority Priority, hookFunc Method) { }, ).Msg("Hook is replaced") } - h.hooks[hookName][priority] = hookFunc + h.hooks[hookName][priority] = hookMethod } } diff --git a/plugin/plugin_registry.go b/plugin/plugin_registry.go index aacd75bd..cf97bdd5 100644 --- a/plugin/plugin_registry.go +++ b/plugin/plugin_registry.go @@ -335,50 +335,50 @@ func (reg *PluginRegistry) RegisterHooks(id Identifier) { } for _, hookName := range pluginImpl.Hooks { - var hookFunc hook.Method + var hookMethod hook.Method switch hookName { case hook.OnConfigLoaded: - hookFunc = pluginV1.OnConfigLoaded + hookMethod = pluginV1.OnConfigLoaded case hook.OnNewLogger: - hookFunc = pluginV1.OnNewLogger + hookMethod = pluginV1.OnNewLogger case hook.OnNewPool: - hookFunc = pluginV1.OnNewPool + hookMethod = pluginV1.OnNewPool case hook.OnNewProxy: - hookFunc = pluginV1.OnNewProxy + hookMethod = pluginV1.OnNewProxy case hook.OnNewServer: - hookFunc = pluginV1.OnNewServer + hookMethod = pluginV1.OnNewServer case hook.OnSignal: - hookFunc = pluginV1.OnSignal + hookMethod = pluginV1.OnSignal case hook.OnRun: - hookFunc = pluginV1.OnRun + hookMethod = pluginV1.OnRun case hook.OnBooting: - hookFunc = pluginV1.OnBooting + hookMethod = pluginV1.OnBooting case hook.OnBooted: - hookFunc = pluginV1.OnBooted + hookMethod = pluginV1.OnBooted case hook.OnOpening: - hookFunc = pluginV1.OnOpening + hookMethod = pluginV1.OnOpening case hook.OnOpened: - hookFunc = pluginV1.OnOpened + hookMethod = pluginV1.OnOpened case hook.OnClosing: - hookFunc = pluginV1.OnClosing + hookMethod = pluginV1.OnClosing case hook.OnClosed: - hookFunc = pluginV1.OnClosed + hookMethod = pluginV1.OnClosed case hook.OnTraffic: - hookFunc = pluginV1.OnTraffic + hookMethod = pluginV1.OnTraffic case hook.OnTrafficFromClient: - hookFunc = pluginV1.OnTrafficFromClient + hookMethod = pluginV1.OnTrafficFromClient case hook.OnTrafficToServer: - hookFunc = pluginV1.OnTrafficToServer + hookMethod = pluginV1.OnTrafficToServer case hook.OnTrafficFromServer: - hookFunc = pluginV1.OnTrafficFromServer + hookMethod = pluginV1.OnTrafficFromServer case hook.OnTrafficToClient: - hookFunc = pluginV1.OnTrafficToClient + hookMethod = pluginV1.OnTrafficToClient case hook.OnShutdown: - hookFunc = pluginV1.OnShutdown + hookMethod = pluginV1.OnShutdown case hook.OnTick: - hookFunc = pluginV1.OnTick + hookMethod = pluginV1.OnTick case hook.OnNewClient: - hookFunc = pluginV1.OnNewClient + hookMethod = pluginV1.OnNewClient default: reg.hookRegistry.Logger.Warn().Fields(map[string]interface{}{ "hook": hookName, @@ -393,6 +393,6 @@ func (reg *PluginRegistry) RegisterHooks(id Identifier) { "priority": pluginImpl.Priority, "name": pluginImpl.ID.Name, }).Msg("Registering hook") - reg.hookRegistry.Add(hookName, pluginImpl.Priority, hookFunc) + reg.hookRegistry.Add(hookName, pluginImpl.Priority, hookMethod) } } From 644c857c593eb2bd5c0ad45ab5eb4467eca6f591 Mon Sep 17 00:00:00 2001 From: Mostafa Moradian Date: Fri, 13 Jan 2023 23:21:48 +0100 Subject: [PATCH 10/16] Use a constant for exit code --- cmd/run.go | 2 +- errors/errors.go | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/cmd/run.go b/cmd/run.go index 58301938..505abde0 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -204,7 +204,7 @@ var runCmd = &cobra.Command{ "the clients cannot connect due to no network connectivity " + "or the server is not running. exiting...") pluginRegistry.Shutdown() - os.Exit(1) + os.Exit(gerr.FailedToInitializePool) } _, err = hookRegistry.Run( diff --git a/errors/errors.go b/errors/errors.go index a84fdb67..2fdf3a8d 100644 --- a/errors/errors.go +++ b/errors/errors.go @@ -92,4 +92,5 @@ var ( const ( FailedToLoadPluginConfig = 1 FailedToLoadGlobalConfig = 2 + FailedToInitializePool = 3 ) From 7464754f0758e8c7bbc9daa51baf3675c80f1327 Mon Sep 17 00:00:00 2001 From: Mostafa Moradian Date: Fri, 13 Jan 2023 23:29:09 +0100 Subject: [PATCH 11/16] Rename hookconfig to hookregistry --- network/proxy.go | 22 ++++++++--------- plugin/hook/hook_registry_test.go | 40 +++++++++++++++---------------- 2 files changed, 31 insertions(+), 31 deletions(-) diff --git a/network/proxy.go b/network/proxy.go index f6d7e9a7..bc89c5cb 100644 --- a/network/proxy.go +++ b/network/proxy.go @@ -26,7 +26,7 @@ type Proxy struct { availableConnections pool.IPool busyConnections pool.IPool logger zerolog.Logger - hookConfig *hook.Registry + hookRegistry *hook.Registry scheduler *gocron.Scheduler Elastic bool @@ -41,7 +41,7 @@ var _ IProxy = &Proxy{} // NewProxy creates a new proxy. func NewProxy( - connPool pool.IPool, hookConfig *hook.Registry, + connPool pool.IPool, hookRegistry *hook.Registry, elastic, reuseElasticClients bool, healthCheckPeriod time.Duration, clientConfig *config.Client, logger zerolog.Logger, @@ -50,7 +50,7 @@ func NewProxy( availableConnections: connPool, busyConnections: pool.NewPool(config.EmptyPoolCapacity), logger: logger, - hookConfig: hookConfig, + hookRegistry: hookRegistry, scheduler: gocron.NewScheduler(time.UTC), Elastic: elastic, ReuseElasticClients: reuseElasticClients, @@ -363,7 +363,7 @@ func (pr *Proxy) PassThrough(gconn gnet.Conn) *gerr.GatewayDError { request, origErr := receiveTrafficFromClient() // Run the OnTrafficFromClient hooks. - result, err := pr.hookConfig.Run( + result, err := pr.hookRegistry.Run( context.Background(), trafficData( gconn, @@ -376,7 +376,7 @@ func (pr *Proxy) PassThrough(gconn gnet.Conn) *gerr.GatewayDError { }, origErr), hook.OnTrafficFromClient, - pr.hookConfig.Verification) + pr.hookRegistry.Verification) if err != nil { pr.logger.Error().Err(err).Msg("Error running hook") } @@ -396,7 +396,7 @@ func (pr *Proxy) PassThrough(gconn gnet.Conn) *gerr.GatewayDError { _, err = sendTrafficToServer(request) // Run the OnTrafficToServer hooks. - _, err = pr.hookConfig.Run( + _, err = pr.hookRegistry.Run( context.Background(), trafficData( gconn, @@ -409,7 +409,7 @@ func (pr *Proxy) PassThrough(gconn gnet.Conn) *gerr.GatewayDError { }, err), hook.OnTrafficToServer, - pr.hookConfig.Verification) + pr.hookRegistry.Verification) if err != nil { pr.logger.Error().Err(err).Msg("Error running hook") } @@ -448,7 +448,7 @@ func (pr *Proxy) PassThrough(gconn gnet.Conn) *gerr.GatewayDError { } // Run the OnTrafficFromServer hooks. - result, err = pr.hookConfig.Run( + result, err = pr.hookRegistry.Run( context.Background(), trafficData( gconn, @@ -465,7 +465,7 @@ func (pr *Proxy) PassThrough(gconn gnet.Conn) *gerr.GatewayDError { }, err), hook.OnTrafficFromServer, - pr.hookConfig.Verification) + pr.hookRegistry.Verification) if err != nil { pr.logger.Error().Err(err).Msg("Error running hook") } @@ -481,7 +481,7 @@ func (pr *Proxy) PassThrough(gconn gnet.Conn) *gerr.GatewayDError { errVerdict := sendTrafficToClient(response, received) // Run the OnTrafficToClient hooks. - _, err = pr.hookConfig.Run( + _, err = pr.hookRegistry.Run( context.Background(), trafficData( gconn, @@ -499,7 +499,7 @@ func (pr *Proxy) PassThrough(gconn gnet.Conn) *gerr.GatewayDError { err, ), hook.OnTrafficToClient, - pr.hookConfig.Verification) + pr.hookRegistry.Verification) if err != nil { pr.logger.Error().Err(err).Msg("Error running hook") } diff --git a/plugin/hook/hook_registry_test.go b/plugin/hook/hook_registry_test.go index 0e250ec0..429a5f92 100644 --- a/plugin/hook/hook_registry_test.go +++ b/plugin/hook/hook_registry_test.go @@ -10,14 +10,14 @@ import ( "google.golang.org/protobuf/types/known/structpb" ) -// Test_NewHookConfig tests the NewHookConfig function. -func Test_NewHookConfig(t *testing.T) { +// Test_NewHookRegistry tests the NewHookRegistry function. +func Test_NewHookRegistry(t *testing.T) { hc := NewRegistry() assert.NotNil(t, hc) } -// Test_HookConfig_Add tests the Add function. -func Test_HookConfig_Add(t *testing.T) { +// Test_HookRegistry_Add tests the Add function. +func Test_HookRegistry_Add(t *testing.T) { hooks := NewRegistry() testFunc := func( ctx context.Context, @@ -31,8 +31,8 @@ func Test_HookConfig_Add(t *testing.T) { assert.ObjectsAreEqual(testFunc, hooks.Hooks()[OnNewLogger][0]) } -// Test_HookConfig_Add_Multiple_Hooks tests the Add function with multiple hooks. -func Test_HookConfig_Add_Multiple_Hooks(t *testing.T) { +// Test_HookRegistry_Add_Multiple_Hooks tests the Add function with multiple hooks. +func Test_HookRegistry_Add_Multiple_Hooks(t *testing.T) { hooks := NewRegistry() hooks.Add(OnNewLogger, 0, func( ctx context.Context, @@ -52,8 +52,8 @@ func Test_HookConfig_Add_Multiple_Hooks(t *testing.T) { assert.NotNil(t, hooks.Hooks()[OnNewLogger][1]) } -// Test_HookConfig_Get tests the Get function. -func Test_HookConfig_Get(t *testing.T) { +// Test_HookRegistry_Get tests the Get function. +func Test_HookRegistry_Get(t *testing.T) { hooks := NewRegistry() testFunc := func( ctx context.Context, @@ -68,8 +68,8 @@ func Test_HookConfig_Get(t *testing.T) { assert.ObjectsAreEqual(testFunc, hooks.Get(OnNewLogger)[priority]) } -// Test_HookConfig_Run tests the Run function. -func Test_HookConfig_Run(t *testing.T) { +// Test_HookRegistry_Run tests the Run function. +func Test_HookRegistry_Run(t *testing.T) { hooks := NewRegistry() hooks.Add(OnNewLogger, 0, func( ctx context.Context, @@ -84,8 +84,8 @@ func Test_HookConfig_Run(t *testing.T) { assert.Nil(t, err) } -// Test_HookConfig_Run_PassDown tests the Run function with the PassDown option. -func Test_HookConfig_Run_PassDown(t *testing.T) { +// Test_HookRegistry_Run_PassDown tests the Run function with the PassDown option. +func Test_HookRegistry_Run_PassDown(t *testing.T) { hooks := NewRegistry() // The result of the hook will be nil and will be passed down to the next hooks.Add(OnNewLogger, 0, func( @@ -120,8 +120,8 @@ func Test_HookConfig_Run_PassDown(t *testing.T) { assert.NotNil(t, result) } -// Test_HookConfig_Run_PassDown_2 tests the Run function with the PassDown option. -func Test_HookConfig_Run_PassDown_2(t *testing.T) { +// Test_HookRegistry_Run_PassDown_2 tests the Run function with the PassDown option. +func Test_HookRegistry_Run_PassDown_2(t *testing.T) { hooks := NewRegistry() // The result of the hook will be nil and will be passed down to the next hooks.Add(OnNewLogger, 0, func( @@ -161,8 +161,8 @@ func Test_HookConfig_Run_PassDown_2(t *testing.T) { assert.NotNil(t, result) } -// Test_HookConfig_Run_Ignore tests the Run function with the Ignore option. -func Test_HookConfig_Run_Ignore(t *testing.T) { +// Test_HookRegistry_Run_Ignore tests the Run function with the Ignore option. +func Test_HookRegistry_Run_Ignore(t *testing.T) { hooks := NewRegistry() // This should not run, because the return value is not the same as the params hooks.Add(OnNewLogger, 0, func( @@ -197,8 +197,8 @@ func Test_HookConfig_Run_Ignore(t *testing.T) { assert.NotNil(t, result) } -// Test_HookConfig_Run_Abort tests the Run function with the Abort option. -func Test_HookConfig_Run_Abort(t *testing.T) { +// Test_HookRegistry_Run_Abort tests the Run function with the Abort option. +func Test_HookRegistry_Run_Abort(t *testing.T) { hooks := NewRegistry() // This should not run, because the return value is not the same as the params hooks.Add(OnNewLogger, 0, func( @@ -227,8 +227,8 @@ func Test_HookConfig_Run_Abort(t *testing.T) { assert.Equal(t, map[string]interface{}{}, result) } -// Test_HookConfig_Run_Remove tests the Run function with the Remove option. -func Test_HookConfig_Run_Remove(t *testing.T) { +// Test_HookRegistry_Run_Remove tests the Run function with the Remove option. +func Test_HookRegistry_Run_Remove(t *testing.T) { hooks := NewRegistry() // This should not run, because the return value is not the same as the params hooks.Add(OnNewLogger, 0, func( From 7d2a5a6e4d4ca322a3ee04b52e7681839f75f64d Mon Sep 17 00:00:00 2001 From: Mostafa Moradian Date: Sat, 14 Jan 2023 00:17:16 +0100 Subject: [PATCH 12/16] Merge the hook package with the plugin package --- cmd/run.go | 59 +++--- network/proxy.go | 32 ++-- network/proxy_test.go | 20 ++- network/server.go | 62 +++---- network/server_test.go | 12 +- plugin/{hook => }/constants.go | 2 +- plugin/hook/hook_registry.go | 195 -------------------- plugin/hook/hook_registry_test.go | 261 --------------------------- plugin/plugin.go | 3 +- plugin/plugin_registry.go | 289 ++++++++++++++++++++++++------ plugin/plugin_registry_test.go | 256 +++++++++++++++++++++++++- plugin/{hook => }/types.go | 2 +- 12 files changed, 579 insertions(+), 614 deletions(-) rename plugin/{hook => }/constants.go (98%) delete mode 100644 plugin/hook/hook_registry.go delete mode 100644 plugin/hook/hook_registry_test.go rename plugin/{hook => }/types.go (95%) diff --git a/cmd/run.go b/cmd/run.go index 505abde0..e43913e1 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -12,7 +12,6 @@ import ( "github.com/gatewayd-io/gatewayd/logging" "github.com/gatewayd-io/gatewayd/network" "github.com/gatewayd-io/gatewayd/plugin" - "github.com/gatewayd-io/gatewayd/plugin/hook" "github.com/gatewayd-io/gatewayd/pool" "github.com/knadh/koanf" "github.com/knadh/koanf/parsers/yaml" @@ -24,14 +23,14 @@ import ( ) var ( - hookRegistry = hook.NewRegistry() DefaultLogger = logging.NewLogger( logging.LoggerConfig{ Level: zerolog.InfoLevel, // Default log level NoColor: true, }, ) - pluginRegistry = plugin.NewRegistry(hookRegistry) + // The plugins are loaded and hooks registered before the configuration is loaded. + pluginRegistry = plugin.NewRegistry(config.Loose, config.PassDown, DefaultLogger) // Global koanf instance. Using "." as the key path delimiter. globalConfig = koanf.New(".") // Plugin koanf instance. Using "." as the key path delimiter. @@ -43,10 +42,6 @@ var runCmd = &cobra.Command{ Use: "run", Short: "Run a gatewayd instance", Run: func(cmd *cobra.Command, args []string) { - // The plugins are loaded and hooks registered - // before the configuration is loaded. - hookRegistry.Logger = DefaultLogger - // Load default plugin configuration. config.LoadPluginConfigDefaults(pluginConfig) @@ -90,7 +85,7 @@ var runCmd = &cobra.Command{ config.LoadEnvVars(globalConfig) // Get hooks signature verification policy. - hookRegistry.Verification = pConfig.GetVerificationPolicy() + pluginRegistry.Verification = pConfig.GetVerificationPolicy() // Unmarshal the global configuration for easier access. var gConfig config.GlobalConfig @@ -100,13 +95,13 @@ var runCmd = &cobra.Command{ os.Exit(gerr.FailedToLoadGlobalConfig) } - // The config will be passed to the plugins that register to the "OnConfigLoaded" hook. + // The config will be passed to the plugins that register to the "OnConfigLoaded" plugin. // The plugins can modify the config and return it. - updatedGlobalConfig, err := hookRegistry.Run( + updatedGlobalConfig, err := pluginRegistry.Run( context.Background(), globalConfig.All(), - hook.OnConfigLoaded, - hookRegistry.Verification) + plugin.OnConfigLoaded, + pluginRegistry.Verification) if err != nil { DefaultLogger.Error().Err(err).Msg("Failed to run OnConfigLoaded hooks") } @@ -139,7 +134,7 @@ var runCmd = &cobra.Command{ }) // Replace the default logger with the new one from the config. - hookRegistry.Logger = logger + pluginRegistry.Logger = logger // This is a notification hook, so we don't care about the result. data := map[string]interface{}{ @@ -150,8 +145,8 @@ var runCmd = &cobra.Command{ "fileName": loggerCfg.FileName, } // TODO: Use a context with a timeout - _, err = hookRegistry.Run( - context.Background(), data, hook.OnNewLogger, hookRegistry.Verification) + _, err = pluginRegistry.Run( + context.Background(), data, plugin.OnNewLogger, pluginRegistry.Verification) if err != nil { logger.Error().Err(err).Msg("Failed to run OnNewLogger hooks") } @@ -179,11 +174,11 @@ var runCmd = &cobra.Command{ "tcpKeepAlive": client.TCPKeepAlive, "tcpKeepAlivePeriod": client.TCPKeepAlivePeriod.String(), } - _, err := hookRegistry.Run( + _, err := pluginRegistry.Run( context.Background(), clientCfg, - hook.OnNewClient, - hookRegistry.Verification) + plugin.OnNewClient, + pluginRegistry.Verification) if err != nil { logger.Error().Err(err).Msg("Failed to run OnNewClient hooks") } @@ -207,11 +202,11 @@ var runCmd = &cobra.Command{ os.Exit(gerr.FailedToInitializePool) } - _, err = hookRegistry.Run( + _, err = pluginRegistry.Run( context.Background(), map[string]interface{}{"size": poolSize}, - hook.OnNewPool, - hookRegistry.Verification) + plugin.OnNewPool, + pluginRegistry.Verification) if err != nil { logger.Error().Err(err).Msg("Failed to run OnNewPool hooks") } @@ -222,7 +217,7 @@ var runCmd = &cobra.Command{ healthCheckPeriod := gConfig.Proxy[config.Default].HealthCheckPeriod proxy := network.NewProxy( pool, - hookRegistry, + pluginRegistry, elastic, reuseElasticClients, healthCheckPeriod, @@ -245,8 +240,8 @@ var runCmd = &cobra.Command{ "tcpKeepAlivePeriod": clientConfig.TCPKeepAlivePeriod.String(), }, } - _, err = hookRegistry.Run( - context.Background(), proxyCfg, hook.OnNewProxy, hookRegistry.Verification) + _, err = pluginRegistry.Run( + context.Background(), proxyCfg, plugin.OnNewProxy, pluginRegistry.Verification) if err != nil { logger.Error().Err(err).Msg("Failed to run OnNewProxy hooks") } @@ -285,7 +280,7 @@ var runCmd = &cobra.Command{ }, proxy, logger, - hookRegistry, + pluginRegistry, ) serverCfg := map[string]interface{}{ @@ -307,8 +302,8 @@ var runCmd = &cobra.Command{ "tcpKeepAlive": gConfig.Server.TCPKeepAlive.String(), "tcpNoDelay": gConfig.Server.TCPNoDelay, } - _, err = hookRegistry.Run( - context.Background(), serverCfg, hook.OnNewServer, hookRegistry.Verification) + _, err = pluginRegistry.Run( + context.Background(), serverCfg, plugin.OnNewServer, pluginRegistry.Verification) if err != nil { logger.Error().Err(err).Msg("Failed to run OnNewServer hooks") } @@ -326,16 +321,16 @@ var runCmd = &cobra.Command{ ) signalsCh := make(chan os.Signal, 1) signal.Notify(signalsCh, signals...) - go func(hookRegistry *hook.Registry) { + go func(pluginRegistry *plugin.PluginRegistry) { for sig := range signalsCh { for _, s := range signals { if sig != s { // Notify the hooks that the server is shutting down. - _, err := hookRegistry.Run( + _, err := pluginRegistry.Run( context.Background(), map[string]interface{}{"signal": sig.String()}, - hook.OnSignal, - hookRegistry.Verification, + plugin.OnSignal, + pluginRegistry.Verification, ) if err != nil { logger.Error().Err(err).Msg("Failed to run OnSignal hooks") @@ -347,7 +342,7 @@ var runCmd = &cobra.Command{ } } } - }(hookRegistry) + }(pluginRegistry) // Run the server. if err := server.Run(); err != nil { diff --git a/network/proxy.go b/network/proxy.go index bc89c5cb..22945987 100644 --- a/network/proxy.go +++ b/network/proxy.go @@ -6,7 +6,7 @@ import ( "github.com/gatewayd-io/gatewayd/config" gerr "github.com/gatewayd-io/gatewayd/errors" - "github.com/gatewayd-io/gatewayd/plugin/hook" + "github.com/gatewayd-io/gatewayd/plugin" "github.com/gatewayd-io/gatewayd/pool" "github.com/go-co-op/gocron" "github.com/panjf2000/gnet/v2" @@ -26,7 +26,7 @@ type Proxy struct { availableConnections pool.IPool busyConnections pool.IPool logger zerolog.Logger - hookRegistry *hook.Registry + pluginRegistry *plugin.PluginRegistry scheduler *gocron.Scheduler Elastic bool @@ -41,7 +41,7 @@ var _ IProxy = &Proxy{} // NewProxy creates a new proxy. func NewProxy( - connPool pool.IPool, hookRegistry *hook.Registry, + connPool pool.IPool, pluginRegistry *plugin.PluginRegistry, elastic, reuseElasticClients bool, healthCheckPeriod time.Duration, clientConfig *config.Client, logger zerolog.Logger, @@ -50,7 +50,7 @@ func NewProxy( availableConnections: connPool, busyConnections: pool.NewPool(config.EmptyPoolCapacity), logger: logger, - hookRegistry: hookRegistry, + pluginRegistry: pluginRegistry, scheduler: gocron.NewScheduler(time.UTC), Elastic: elastic, ReuseElasticClients: reuseElasticClients, @@ -363,7 +363,7 @@ func (pr *Proxy) PassThrough(gconn gnet.Conn) *gerr.GatewayDError { request, origErr := receiveTrafficFromClient() // Run the OnTrafficFromClient hooks. - result, err := pr.hookRegistry.Run( + result, err := pr.pluginRegistry.Run( context.Background(), trafficData( gconn, @@ -375,8 +375,8 @@ func (pr *Proxy) PassThrough(gconn gnet.Conn) *gerr.GatewayDError { }, }, origErr), - hook.OnTrafficFromClient, - pr.hookRegistry.Verification) + plugin.OnTrafficFromClient, + pr.pluginRegistry.Verification) if err != nil { pr.logger.Error().Err(err).Msg("Error running hook") } @@ -396,7 +396,7 @@ func (pr *Proxy) PassThrough(gconn gnet.Conn) *gerr.GatewayDError { _, err = sendTrafficToServer(request) // Run the OnTrafficToServer hooks. - _, err = pr.hookRegistry.Run( + _, err = pr.pluginRegistry.Run( context.Background(), trafficData( gconn, @@ -408,8 +408,8 @@ func (pr *Proxy) PassThrough(gconn gnet.Conn) *gerr.GatewayDError { }, }, err), - hook.OnTrafficToServer, - pr.hookRegistry.Verification) + plugin.OnTrafficToServer, + pr.pluginRegistry.Verification) if err != nil { pr.logger.Error().Err(err).Msg("Error running hook") } @@ -448,7 +448,7 @@ func (pr *Proxy) PassThrough(gconn gnet.Conn) *gerr.GatewayDError { } // Run the OnTrafficFromServer hooks. - result, err = pr.hookRegistry.Run( + result, err = pr.pluginRegistry.Run( context.Background(), trafficData( gconn, @@ -464,8 +464,8 @@ func (pr *Proxy) PassThrough(gconn gnet.Conn) *gerr.GatewayDError { }, }, err), - hook.OnTrafficFromServer, - pr.hookRegistry.Verification) + plugin.OnTrafficFromServer, + pr.pluginRegistry.Verification) if err != nil { pr.logger.Error().Err(err).Msg("Error running hook") } @@ -481,7 +481,7 @@ func (pr *Proxy) PassThrough(gconn gnet.Conn) *gerr.GatewayDError { errVerdict := sendTrafficToClient(response, received) // Run the OnTrafficToClient hooks. - _, err = pr.hookRegistry.Run( + _, err = pr.pluginRegistry.Run( context.Background(), trafficData( gconn, @@ -498,8 +498,8 @@ func (pr *Proxy) PassThrough(gconn gnet.Conn) *gerr.GatewayDError { }, err, ), - hook.OnTrafficToClient, - pr.hookRegistry.Verification) + plugin.OnTrafficToClient, + pr.pluginRegistry.Verification) if err != nil { pr.logger.Error().Err(err).Msg("Error running hook") } diff --git a/network/proxy_test.go b/network/proxy_test.go index 68fca5a3..4d5b4ccc 100644 --- a/network/proxy_test.go +++ b/network/proxy_test.go @@ -6,7 +6,7 @@ import ( embeddedpostgres "github.com/fergusstrange/embedded-postgres" "github.com/gatewayd-io/gatewayd/config" "github.com/gatewayd-io/gatewayd/logging" - "github.com/gatewayd-io/gatewayd/plugin/hook" + "github.com/gatewayd-io/gatewayd/plugin" "github.com/gatewayd-io/gatewayd/pool" "github.com/rs/zerolog" "github.com/stretchr/testify/assert" @@ -53,8 +53,13 @@ func TestNewProxy(t *testing.T) { assert.Nil(t, err) // Create a proxy with a fixed buffer pool - proxy := NewProxy( - pool, hook.NewRegistry(), false, false, config.DefaultHealthCheckPeriod, nil, logger) + proxy := NewProxy(pool, + plugin.NewRegistry(config.Loose, config.PassDown, logger), + false, + false, + config.DefaultHealthCheckPeriod, + nil, + logger) assert.NotNil(t, proxy) assert.Equal(t, 0, proxy.busyConnections.Size(), "Proxy should have no connected clients") @@ -83,7 +88,11 @@ func TestNewProxyElastic(t *testing.T) { pool := pool.NewPool(config.EmptyPoolCapacity) // Create a proxy with an elastic buffer pool - proxy := NewProxy(pool, hook.NewRegistry(), true, false, config.DefaultHealthCheckPeriod, + proxy := NewProxy(pool, + plugin.NewRegistry(config.Loose, config.PassDown, logger), + true, + false, + config.DefaultHealthCheckPeriod, &config.Client{ Network: "tcp", Address: "localhost:5432", @@ -93,7 +102,8 @@ func TestNewProxyElastic(t *testing.T) { SendDeadline: config.DefaultSendDeadline, TCPKeepAlive: false, TCPKeepAlivePeriod: config.DefaultTCPKeepAlivePeriod, - }, logger) + }, + logger) assert.NotNil(t, proxy) assert.Equal(t, 0, proxy.busyConnections.Size()) diff --git a/network/server.go b/network/server.go index 215b383c..77de3e71 100644 --- a/network/server.go +++ b/network/server.go @@ -10,17 +10,17 @@ import ( "github.com/gatewayd-io/gatewayd/config" gerr "github.com/gatewayd-io/gatewayd/errors" - "github.com/gatewayd-io/gatewayd/plugin/hook" + "github.com/gatewayd-io/gatewayd/plugin" "github.com/panjf2000/gnet/v2" "github.com/rs/zerolog" ) type Server struct { gnet.BuiltinEventEngine - engine gnet.Engine - proxy IProxy - logger zerolog.Logger - hookRegistry *hook.Registry + engine gnet.Engine + proxy IProxy + logger zerolog.Logger + pluginRegistry *plugin.PluginRegistry Network string // tcp/udp/unix Address string @@ -38,11 +38,11 @@ func (s *Server) OnBoot(engine gnet.Engine) gnet.Action { s.logger.Debug().Msg("GatewayD is booting...") // Run the OnBooting hooks. - _, err := s.hookRegistry.Run( + _, err := s.pluginRegistry.Run( context.Background(), map[string]interface{}{"status": fmt.Sprint(s.Status)}, - hook.OnBooting, - s.hookRegistry.Verification) + plugin.OnBooting, + s.pluginRegistry.Verification) if err != nil { s.logger.Error().Err(err).Msg("Failed to run OnBooting hook") } @@ -53,11 +53,11 @@ func (s *Server) OnBoot(engine gnet.Engine) gnet.Action { s.Status = config.Running // Run the OnBooted hooks. - _, err = s.hookRegistry.Run( + _, err = s.pluginRegistry.Run( context.Background(), map[string]interface{}{"status": fmt.Sprint(s.Status)}, - hook.OnBooted, - s.hookRegistry.Verification) + plugin.OnBooted, + s.pluginRegistry.Verification) if err != nil { s.logger.Error().Err(err).Msg("Failed to run OnBooted hook") } @@ -80,8 +80,8 @@ func (s *Server) OnOpen(gconn gnet.Conn) ([]byte, gnet.Action) { "remote": gconn.RemoteAddr().String(), }, } - _, err := s.hookRegistry.Run( - context.Background(), onOpeningData, hook.OnOpening, s.hookRegistry.Verification) + _, err := s.pluginRegistry.Run( + context.Background(), onOpeningData, plugin.OnOpening, s.pluginRegistry.Verification) if err != nil { s.logger.Error().Err(err).Msg("Failed to run OnOpening hook") } @@ -121,8 +121,8 @@ func (s *Server) OnOpen(gconn gnet.Conn) ([]byte, gnet.Action) { "remote": gconn.RemoteAddr().String(), }, } - _, err = s.hookRegistry.Run( - context.Background(), onOpenedData, hook.OnOpened, s.hookRegistry.Verification) + _, err = s.pluginRegistry.Run( + context.Background(), onOpenedData, plugin.OnOpened, s.pluginRegistry.Verification) if err != nil { s.logger.Error().Err(err).Msg("Failed to run OnOpened hook") } @@ -148,8 +148,8 @@ func (s *Server) OnClose(gconn gnet.Conn, err error) gnet.Action { if err != nil { data["error"] = err.Error() } - _, gatewaydErr := s.hookRegistry.Run( - context.Background(), data, hook.OnClosing, s.hookRegistry.Verification) + _, gatewaydErr := s.pluginRegistry.Run( + context.Background(), data, plugin.OnClosing, s.pluginRegistry.Verification) if gatewaydErr != nil { s.logger.Error().Err(gatewaydErr).Msg("Failed to run OnClosing hook") } @@ -179,8 +179,8 @@ func (s *Server) OnClose(gconn gnet.Conn, err error) gnet.Action { if err != nil { data["error"] = err.Error() } - _, gatewaydErr = s.hookRegistry.Run( - context.Background(), data, hook.OnClosed, s.hookRegistry.Verification) + _, gatewaydErr = s.pluginRegistry.Run( + context.Background(), data, plugin.OnClosed, s.pluginRegistry.Verification) if gatewaydErr != nil { s.logger.Error().Err(gatewaydErr).Msg("Failed to run OnClosed hook") } @@ -198,8 +198,8 @@ func (s *Server) OnTraffic(gconn gnet.Conn) gnet.Action { "remote": gconn.RemoteAddr().String(), }, } - _, err := s.hookRegistry.Run( - context.Background(), onTrafficData, hook.OnTraffic, s.hookRegistry.Verification) + _, err := s.pluginRegistry.Run( + context.Background(), onTrafficData, plugin.OnTraffic, s.pluginRegistry.Verification) if err != nil { s.logger.Error().Err(err).Msg("Failed to run OnTraffic hook") } @@ -231,11 +231,11 @@ func (s *Server) OnShutdown(engine gnet.Engine) { s.logger.Debug().Msg("GatewayD is shutting down...") // Run the OnShutdown hooks. - _, err := s.hookRegistry.Run( + _, err := s.pluginRegistry.Run( context.Background(), map[string]interface{}{"connections": s.engine.CountConnections()}, - hook.OnShutdown, - s.hookRegistry.Verification) + plugin.OnShutdown, + s.pluginRegistry.Verification) if err != nil { s.logger.Error().Err(err).Msg("Failed to run OnShutdown hook") } @@ -254,11 +254,11 @@ func (s *Server) OnTick() (time.Duration, gnet.Action) { "Active client connections") // Run the OnTick hooks. - _, err := s.hookRegistry.Run( + _, err := s.pluginRegistry.Run( context.Background(), map[string]interface{}{"connections": s.engine.CountConnections()}, - hook.OnTick, - s.hookRegistry.Verification) + plugin.OnTick, + s.pluginRegistry.Verification) if err != nil { s.logger.Error().Err(err).Msg("Failed to run OnTick hook") } @@ -284,8 +284,8 @@ func (s *Server) Run() error { if err != nil && err.Unwrap() != nil { onRunData["error"] = err.OriginalError.Error() } - result, err := s.hookRegistry.Run( - context.Background(), onRunData, hook.OnRun, s.hookRegistry.Verification) + result, err := s.pluginRegistry.Run( + context.Background(), onRunData, plugin.OnRun, s.pluginRegistry.Verification) if err != nil { s.logger.Error().Err(err).Msg("Failed to run the hook") } @@ -334,7 +334,7 @@ func NewServer( options []gnet.Option, proxy IProxy, logger zerolog.Logger, - hookRegistry *hook.Registry, + pluginRegistry *plugin.PluginRegistry, ) *Server { // Create the server. server := Server{ @@ -390,7 +390,7 @@ func NewServer( server.proxy = proxy server.logger = logger - server.hookRegistry = hookRegistry + server.pluginRegistry = pluginRegistry return &server } diff --git a/network/server_test.go b/network/server_test.go index b3a054d8..4193f45f 100644 --- a/network/server_test.go +++ b/network/server_test.go @@ -9,7 +9,7 @@ import ( embeddedpostgres "github.com/fergusstrange/embedded-postgres" "github.com/gatewayd-io/gatewayd/config" "github.com/gatewayd-io/gatewayd/logging" - "github.com/gatewayd-io/gatewayd/plugin/hook" + "github.com/gatewayd-io/gatewayd/plugin" "github.com/gatewayd-io/gatewayd/pool" "github.com/panjf2000/gnet/v2" "github.com/rs/zerolog" @@ -39,7 +39,7 @@ func TestRunServer(t *testing.T) { logger := logging.NewLogger(cfg) - hookRegistry := hook.NewRegistry() + pluginRegistry := plugin.NewRegistry(config.Loose, config.PassDown, logger) onTrafficFromClient := func( ctx context.Context, @@ -67,7 +67,7 @@ func TestRunServer(t *testing.T) { assert.Empty(t, paramsMap["error"]) return params, nil } - hookRegistry.Add(hook.OnTrafficFromClient, 1, onTrafficFromClient) + pluginRegistry.AddHook(plugin.OnTrafficFromClient, 1, onTrafficFromClient) onTrafficFromServer := func( ctx context.Context, @@ -92,7 +92,7 @@ func TestRunServer(t *testing.T) { assert.Empty(t, paramsMap["error"]) return params, nil } - hookRegistry.Add(hook.OnTrafficFromServer, 1, onTrafficFromServer) + pluginRegistry.AddHook(plugin.OnTrafficFromServer, 1, onTrafficFromServer) clientConfig := config.Client{ Network: "tcp", @@ -116,7 +116,7 @@ func TestRunServer(t *testing.T) { // Create a proxy with a fixed buffer pool. proxy := NewProxy( - pool, hookRegistry, false, false, config.DefaultHealthCheckPeriod, &clientConfig, logger) + pool, pluginRegistry, false, false, config.DefaultHealthCheckPeriod, &clientConfig, logger) // Create a server. server := NewServer( @@ -132,7 +132,7 @@ func TestRunServer(t *testing.T) { }, proxy, logger, - hookRegistry, + pluginRegistry, ) assert.NotNil(t, server) diff --git a/plugin/hook/constants.go b/plugin/constants.go similarity index 98% rename from plugin/hook/constants.go rename to plugin/constants.go index 8f22ff85..122ed96f 100644 --- a/plugin/hook/constants.go +++ b/plugin/constants.go @@ -1,4 +1,4 @@ -package hook +package plugin const ( // Run command hooks (cmd/run.go). diff --git a/plugin/hook/hook_registry.go b/plugin/hook/hook_registry.go deleted file mode 100644 index dda99991..00000000 --- a/plugin/hook/hook_registry.go +++ /dev/null @@ -1,195 +0,0 @@ -package hook - -import ( - "context" - "sort" - - "github.com/gatewayd-io/gatewayd/config" - gerr "github.com/gatewayd-io/gatewayd/errors" - "github.com/gatewayd-io/gatewayd/plugin/utils" - "github.com/rs/zerolog" - "google.golang.org/grpc" - "google.golang.org/protobuf/types/known/structpb" -) - -type IRegistry interface { - Hooks() map[string]map[Priority]Method - Add(hookName string, priority Priority, hookMethod Method) - Get(hookName string) map[Priority]Method - Run( - ctx context.Context, - args map[string]interface{}, - hookName string, - verification config.Policy, - opts ...grpc.CallOption, - ) (map[string]interface{}, *gerr.GatewayDError) -} - -type Registry struct { - hooks map[string]map[Priority]Method - - Logger zerolog.Logger - Verification config.Policy -} - -var _ IRegistry = &Registry{} - -// NewRegistry returns a new Config. -func NewRegistry() *Registry { - return &Registry{ - hooks: map[string]map[Priority]Method{}, - } -} - -// Hooks returns the hooks. -func (h *Registry) Hooks() map[string]map[Priority]Method { - return h.hooks -} - -// Add adds a hook with a priority to the hooks map. -func (h *Registry) Add(hookName string, priority Priority, hookMethod Method) { - if len(h.hooks[hookName]) == 0 { - h.hooks[hookName] = map[Priority]Method{priority: hookMethod} - } else { - if _, ok := h.hooks[hookName][priority]; ok { - h.Logger.Warn().Fields( - map[string]interface{}{ - "hookName": hookName, - "priority": priority, - }, - ).Msg("Hook is replaced") - } - h.hooks[hookName][priority] = hookMethod - } -} - -// Get returns the hooks of a specific type. -func (h *Registry) Get(hookName string) map[Priority]Method { - return h.hooks[hookName] -} - -// Run runs the hooks of a specific type. The result of the previous hook is passed -// to the next hook as the argument, aka. chained. The context is passed to the -// hooks as well to allow them to cancel the execution. The args are passed to the -// first hook as the argument. The result of the first hook is passed to the second -// hook, and so on. The result of the last hook is eventually returned. The verification -// mode is used to determine how to handle errors. If the verification mode is set to -// Abort, the execution is aborted on the first error. If the verification mode is set -// to Remove, the hook is removed from the list of hooks on the first error. If the -// verification mode is set to Ignore, the error is ignored and the execution continues. -// If the verification mode is set to PassDown, the extra keys/values in the result -// are passed down to the next The verification mode is set to PassDown by default. -// The opts are passed to the hooks as well to allow them to use the grpc.CallOption. -// -//nolint:funlen -func (h *Registry) Run( - ctx context.Context, - args map[string]interface{}, - hookName string, - verification config.Policy, - opts ...grpc.CallOption, -) (map[string]interface{}, *gerr.GatewayDError) { - if ctx == nil { - return nil, gerr.ErrNilContext - } - - // Inherit context. - inheritedCtx, cancel := context.WithCancel(ctx) - defer cancel() - - // Cast custom fields to their primitive types, like time.Duration to float64. - args = utils.CastToPrimitiveTypes(args) - - // Create structpb.Struct from args. - var params *structpb.Struct - if len(args) == 0 { - params = &structpb.Struct{} - } else if casted, err := structpb.NewStruct(args); err == nil { - params = casted - } else { - return nil, gerr.ErrCastFailed.Wrap(err) - } - - // Sort hooks by priority. - priorities := make([]Priority, 0, len(h.hooks[hookName])) - for priority := range h.hooks[hookName] { - priorities = append(priorities, priority) - } - sort.SliceStable(priorities, func(i, j int) bool { - return priorities[i] < priorities[j] - }) - - // Run hooks, passing the result of the previous hook to the next one. - returnVal := &structpb.Struct{} - var removeList []Priority - // The signature of parameters and args MUST be the same for this to work. - for idx, priority := range priorities { - var result *structpb.Struct - var err error - if idx == 0 { - result, err = h.hooks[hookName][priority](inheritedCtx, params, opts...) - } else { - result, err = h.hooks[hookName][priority](inheritedCtx, returnVal, opts...) - } - - // This is done to ensure that the return value of the hook is always valid, - // and that the hook does not return any unexpected values. - // If the verification mode is non-strict (permissive), let the plugin pass - // extra keys/values to the next plugin in chain. - if utils.Verify(params, result) || verification == config.PassDown { - // Update the last return value with the current result - returnVal = result - continue - } - - // At this point, the hook returned an invalid value, so we need to handle it. - // The result of the current hook will be ignored, regardless of the policy. - switch verification { - // Ignore the result of this plugin, log an error and execute the next - case config.Ignore: - h.Logger.Error().Err(err).Fields( - map[string]interface{}{ - "hookName": hookName, - "priority": priority, - }, - ).Msg("Hook returned invalid value, ignoring") - if idx == 0 { - returnVal = params - } - // Abort execution of the plugins, log the error and return the result of the last - case config.Abort: - h.Logger.Error().Err(err).Fields( - map[string]interface{}{ - "hookName": hookName, - "priority": priority, - }, - ).Msg("Hook returned invalid value, aborting") - if idx == 0 { - return args, nil - } - return returnVal.AsMap(), nil - // Remove the hook from the registry, log the error and execute the next - case config.Remove: - h.Logger.Error().Err(err).Fields( - map[string]interface{}{ - "hookName": hookName, - "priority": priority, - }, - ).Msg("Hook returned invalid value, removing") - removeList = append(removeList, priority) - if idx == 0 { - returnVal = params - } - case config.PassDown: - default: - returnVal = result - } - } - - // Remove hooks that failed verification. - for _, priority := range removeList { - delete(h.hooks[hookName], priority) - } - - return returnVal.AsMap(), nil -} diff --git a/plugin/hook/hook_registry_test.go b/plugin/hook/hook_registry_test.go deleted file mode 100644 index 429a5f92..00000000 --- a/plugin/hook/hook_registry_test.go +++ /dev/null @@ -1,261 +0,0 @@ -package hook - -import ( - "context" - "testing" - - "github.com/gatewayd-io/gatewayd/config" - "github.com/stretchr/testify/assert" - "google.golang.org/grpc" - "google.golang.org/protobuf/types/known/structpb" -) - -// Test_NewHookRegistry tests the NewHookRegistry function. -func Test_NewHookRegistry(t *testing.T) { - hc := NewRegistry() - assert.NotNil(t, hc) -} - -// Test_HookRegistry_Add tests the Add function. -func Test_HookRegistry_Add(t *testing.T) { - hooks := NewRegistry() - testFunc := func( - ctx context.Context, - args *structpb.Struct, - opts ...grpc.CallOption, - ) (*structpb.Struct, error) { - return args, nil - } - hooks.Add(OnNewLogger, 0, testFunc) - assert.NotNil(t, hooks.Hooks()[OnNewLogger][0]) - assert.ObjectsAreEqual(testFunc, hooks.Hooks()[OnNewLogger][0]) -} - -// Test_HookRegistry_Add_Multiple_Hooks tests the Add function with multiple hooks. -func Test_HookRegistry_Add_Multiple_Hooks(t *testing.T) { - hooks := NewRegistry() - hooks.Add(OnNewLogger, 0, func( - ctx context.Context, - args *structpb.Struct, - opts ...grpc.CallOption, - ) (*structpb.Struct, error) { - return args, nil - }) - hooks.Add(OnNewLogger, 1, func( - ctx context.Context, - args *structpb.Struct, - opts ...grpc.CallOption, - ) (*structpb.Struct, error) { - return args, nil - }) - assert.NotNil(t, hooks.Hooks()[OnNewLogger][0]) - assert.NotNil(t, hooks.Hooks()[OnNewLogger][1]) -} - -// Test_HookRegistry_Get tests the Get function. -func Test_HookRegistry_Get(t *testing.T) { - hooks := NewRegistry() - testFunc := func( - ctx context.Context, - args *structpb.Struct, - opts ...grpc.CallOption, - ) (*structpb.Struct, error) { - return args, nil - } - priority := Priority(0) - hooks.Add(OnNewLogger, priority, testFunc) - assert.NotNil(t, hooks.Get(OnNewLogger)) - assert.ObjectsAreEqual(testFunc, hooks.Get(OnNewLogger)[priority]) -} - -// Test_HookRegistry_Run tests the Run function. -func Test_HookRegistry_Run(t *testing.T) { - hooks := NewRegistry() - hooks.Add(OnNewLogger, 0, func( - ctx context.Context, - args *structpb.Struct, - opts ...grpc.CallOption, - ) (*structpb.Struct, error) { - return args, nil - }) - result, err := hooks.Run( - context.Background(), map[string]interface{}{}, OnNewLogger, config.Ignore) - assert.NotNil(t, result) - assert.Nil(t, err) -} - -// Test_HookRegistry_Run_PassDown tests the Run function with the PassDown option. -func Test_HookRegistry_Run_PassDown(t *testing.T) { - hooks := NewRegistry() - // The result of the hook will be nil and will be passed down to the next - hooks.Add(OnNewLogger, 0, func( - ctx context.Context, - args *structpb.Struct, - opts ...grpc.CallOption, - ) (*structpb.Struct, error) { - return nil, nil //nolint:nilnil - }) - // The consolidated result should be {"test": "test"}. - hooks.Add(OnNewLogger, 1, func( - ctx context.Context, - args *structpb.Struct, - opts ...grpc.CallOption, - ) (*structpb.Struct, error) { - output, err := structpb.NewStruct(map[string]interface{}{ - "test": "test", - }) - assert.Nil(t, err) - return output, nil - }) - - // Although the first hook returns nil, and its signature doesn't match the params, - // so its result (nil) is passed down to the next hook in chain (priority 2). - // Then the second hook runs and returns a signature with a "test" key and value. - result, err := hooks.Run( - context.Background(), - map[string]interface{}{"test": "test"}, - OnNewLogger, - config.PassDown) - assert.Nil(t, err) - assert.NotNil(t, result) -} - -// Test_HookRegistry_Run_PassDown_2 tests the Run function with the PassDown option. -func Test_HookRegistry_Run_PassDown_2(t *testing.T) { - hooks := NewRegistry() - // The result of the hook will be nil and will be passed down to the next - hooks.Add(OnNewLogger, 0, func( - ctx context.Context, - args *structpb.Struct, - opts ...grpc.CallOption, - ) (*structpb.Struct, error) { - args.Fields["test1"] = &structpb.Value{ - Kind: &structpb.Value_StringValue{ //nolint:nosnakecase - StringValue: "test1", - }, - } - return args, nil - }) - // The consolidated result should be {"test1": "test1", "test2": "test2"}. - hooks.Add(OnNewLogger, 1, func( - ctx context.Context, - args *structpb.Struct, - opts ...grpc.CallOption, - ) (*structpb.Struct, error) { - args.Fields["test2"] = &structpb.Value{ - Kind: &structpb.Value_StringValue{ //nolint:nosnakecase - StringValue: "test2", - }, - } - return args, nil - }) - // Although the first hook returns nil, and its signature doesn't match the params, - // so its result (nil) is passed down to the next hook in chain (priority 2). - // Then the second hook runs and returns a signature with a "test1" and "test2" key and value. - result, err := hooks.Run( - context.Background(), - map[string]interface{}{"test": "test"}, - OnNewLogger, - config.PassDown) - assert.Nil(t, err) - assert.NotNil(t, result) -} - -// Test_HookRegistry_Run_Ignore tests the Run function with the Ignore option. -func Test_HookRegistry_Run_Ignore(t *testing.T) { - hooks := NewRegistry() - // This should not run, because the return value is not the same as the params - hooks.Add(OnNewLogger, 0, func( - ctx context.Context, - args *structpb.Struct, - opts ...grpc.CallOption, - ) (*structpb.Struct, error) { - return nil, nil //nolint:nilnil - }) - // This should run, because the return value is the same as the params - hooks.Add(OnNewLogger, 1, func( - ctx context.Context, - args *structpb.Struct, - opts ...grpc.CallOption, - ) (*structpb.Struct, error) { - args.Fields["test"] = &structpb.Value{ - Kind: &structpb.Value_StringValue{ //nolint:nosnakecase - StringValue: "test", - }, - } - return args, nil - }) - // The first hook returns nil, and its signature doesn't match the params, - // so its result is ignored. - // Then the second hook runs and returns a signature with a "test" key and value. - result, err := hooks.Run( - context.Background(), - map[string]interface{}{"test": "test"}, - OnNewLogger, - config.Ignore) - assert.Nil(t, err) - assert.NotNil(t, result) -} - -// Test_HookRegistry_Run_Abort tests the Run function with the Abort option. -func Test_HookRegistry_Run_Abort(t *testing.T) { - hooks := NewRegistry() - // This should not run, because the return value is not the same as the params - hooks.Add(OnNewLogger, 0, func( - ctx context.Context, - args *structpb.Struct, - opts ...grpc.CallOption, - ) (*structpb.Struct, error) { - return nil, nil //nolint:nilnil - }) - // This should not run, because the first hook returns nil, and its result is ignored. - hooks.Add(OnNewLogger, 1, func( - ctx context.Context, - args *structpb.Struct, - opts ...grpc.CallOption, - ) (*structpb.Struct, error) { - output, err := structpb.NewStruct(map[string]interface{}{ - "test": "test", - }) - assert.Nil(t, err) - return output, nil - }) - // The first hook returns nil, and it aborts the execution of the rest of the - result, err := hooks.Run( - context.Background(), map[string]interface{}{}, OnNewLogger, config.Abort) - assert.Nil(t, err) - assert.Equal(t, map[string]interface{}{}, result) -} - -// Test_HookRegistry_Run_Remove tests the Run function with the Remove option. -func Test_HookRegistry_Run_Remove(t *testing.T) { - hooks := NewRegistry() - // This should not run, because the return value is not the same as the params - hooks.Add(OnNewLogger, 0, func( - ctx context.Context, - args *structpb.Struct, - opts ...grpc.CallOption, - ) (*structpb.Struct, error) { - return nil, nil //nolint:nilnil - }) - // This should not run, because the first hook returns nil, and its result is ignored. - hooks.Add(OnNewLogger, 1, func( - ctx context.Context, - args *structpb.Struct, - opts ...grpc.CallOption, - ) (*structpb.Struct, error) { - output, err := structpb.NewStruct(map[string]interface{}{ - "test": "test", - }) - assert.Nil(t, err) - return output, nil - }) - // The first hook returns nil, and its signature doesn't match the params, - // so its result is ignored. The failing hook is removed from the list and - // the execution continues with the next hook in the list. - result, err := hooks.Run( - context.Background(), map[string]interface{}{}, OnNewLogger, config.Remove) - assert.Nil(t, err) - assert.Equal(t, map[string]interface{}{}, result) - assert.Equal(t, 1, len(hooks.Hooks()[OnNewLogger])) -} diff --git a/plugin/plugin.go b/plugin/plugin.go index 31caefdf..27aadea3 100644 --- a/plugin/plugin.go +++ b/plugin/plugin.go @@ -4,7 +4,6 @@ import ( "net" gerr "github.com/gatewayd-io/gatewayd/errors" - "github.com/gatewayd-io/gatewayd/plugin/hook" pluginV1 "github.com/gatewayd-io/gatewayd/plugin/v1" goplugin "github.com/hashicorp/go-plugin" ) @@ -41,7 +40,7 @@ type Plugin struct { Config map[string]string // hooks it attaches to Hooks []string - Priority hook.Priority + Priority Priority // required plugins to be loaded before this one // Built-in plugins are always loaded first Requires []Identifier diff --git a/plugin/plugin_registry.go b/plugin/plugin_registry.go index cf97bdd5..52a167f7 100644 --- a/plugin/plugin_registry.go +++ b/plugin/plugin_registry.go @@ -2,17 +2,19 @@ package plugin import ( "context" + "sort" semver "github.com/Masterminds/semver/v3" "github.com/gatewayd-io/gatewayd/config" gerr "github.com/gatewayd-io/gatewayd/errors" "github.com/gatewayd-io/gatewayd/logging" - "github.com/gatewayd-io/gatewayd/plugin/hook" "github.com/gatewayd-io/gatewayd/plugin/utils" pluginV1 "github.com/gatewayd-io/gatewayd/plugin/v1" "github.com/gatewayd-io/gatewayd/pool" goplugin "github.com/hashicorp/go-plugin" "github.com/mitchellh/mapstructure" + "github.com/rs/zerolog" + "google.golang.org/grpc" "google.golang.org/protobuf/types/known/structpb" ) @@ -23,23 +25,44 @@ type IPluginRegistry interface { Exists(name, version, remoteURL string) bool Remove(id Identifier) Shutdown() + + AddHook(hookName string, priority Priority, hookMethod Method) + Hooks() map[string]map[Priority]Method + Run( + ctx context.Context, + args map[string]interface{}, + hookName string, + verification config.Policy, + opts ...grpc.CallOption, + ) (map[string]interface{}, *gerr.GatewayDError) + LoadPlugins(plugins []config.Plugin) RegisterHooks(id Identifier) } type PluginRegistry struct { //nolint:golint,revive - plugins pool.IPool - hookRegistry *hook.Registry + plugins pool.IPool + hooks map[string]map[Priority]Method + + Logger zerolog.Logger CompatPolicy config.CompatPolicy + Verification config.Policy } var _ IPluginRegistry = &PluginRegistry{} // NewRegistry creates a new plugin registry. -func NewRegistry(hookRegistry *hook.Registry) *PluginRegistry { +func NewRegistry( + compatPolicy config.CompatPolicy, + verification config.Policy, + logger zerolog.Logger, +) *PluginRegistry { return &PluginRegistry{ plugins: pool.NewPool(config.EmptyPoolCapacity), - hookRegistry: hookRegistry, + hooks: map[string]map[Priority]Method{}, + Logger: logger, + CompatPolicy: compatPolicy, + Verification: verification, } } @@ -47,7 +70,7 @@ func NewRegistry(hookRegistry *hook.Registry) *PluginRegistry { func (reg *PluginRegistry) Add(plugin *Plugin) bool { _, loaded, err := reg.plugins.GetOrPut(plugin.ID, plugin) if err != nil { - reg.hookRegistry.Logger.Error().Err(err).Msg("Failed to add plugin to registry") + reg.Logger.Error().Err(err).Msg("Failed to add plugin to registry") return false } return loaded @@ -81,14 +104,14 @@ func (reg *PluginRegistry) Exists(name, version, remoteURL string) bool { // Parse the supplied version and the version in the registry. suppliedVer, err := semver.NewVersion(version) if err != nil { - reg.hookRegistry.Logger.Error().Err(err).Msg( + reg.Logger.Error().Err(err).Msg( "Failed to parse supplied plugin version") return false } registryVer, err := semver.NewVersion(plugin.Version) if err != nil { - reg.hookRegistry.Logger.Error().Err(err).Msg( + reg.Logger.Error().Err(err).Msg( "Failed to parse plugin version in registry") return false } @@ -99,7 +122,7 @@ func (reg *PluginRegistry) Exists(name, version, remoteURL string) bool { return true } - reg.hookRegistry.Logger.Debug().Str("name", name).Str("version", version).Msg( + reg.Logger.Debug().Str("name", name).Str("version", version).Msg( "Supplied plugin version is greater than the version in registry") return false } @@ -127,6 +150,154 @@ func (reg *PluginRegistry) Shutdown() { goplugin.CleanupClients() } +// Hooks returns the hooks map. +func (reg *PluginRegistry) Hooks() map[string]map[Priority]Method { + return reg.hooks +} + +// Add adds a hook with a priority to the hooks map. +func (reg *PluginRegistry) AddHook(hookName string, priority Priority, hookMethod Method) { + if len(reg.hooks[hookName]) == 0 { + reg.hooks[hookName] = map[Priority]Method{priority: hookMethod} + } else { + if _, ok := reg.hooks[hookName][priority]; ok { + reg.Logger.Warn().Fields( + map[string]interface{}{ + "hookName": hookName, + "priority": priority, + }, + ).Msg("Hook is replaced") + } + reg.hooks[hookName][priority] = hookMethod + } +} + +// Run runs the hooks of a specific type. The result of the previous hook is passed +// to the next hook as the argument, aka. chained. The context is passed to the +// hooks as well to allow them to cancel the execution. The args are passed to the +// first hook as the argument. The result of the first hook is passed to the second +// hook, and so on. The result of the last hook is eventually returned. The verification +// mode is used to determine how to handle errors. If the verification mode is set to +// Abort, the execution is aborted on the first error. If the verification mode is set +// to Remove, the hook is removed from the list of hooks on the first error. If the +// verification mode is set to Ignore, the error is ignored and the execution continues. +// If the verification mode is set to PassDown, the extra keys/values in the result +// are passed down to the next The verification mode is set to PassDown by default. +// The opts are passed to the hooks as well to allow them to use the grpc.CallOption. +// +//nolint:funlen +func (reg *PluginRegistry) Run( + ctx context.Context, + args map[string]interface{}, + hookName string, + verification config.Policy, + opts ...grpc.CallOption, +) (map[string]interface{}, *gerr.GatewayDError) { + if ctx == nil { + return nil, gerr.ErrNilContext + } + + // Inherit context. + inheritedCtx, cancel := context.WithCancel(ctx) + defer cancel() + + // Cast custom fields to their primitive types, like time.Duration to float64. + args = utils.CastToPrimitiveTypes(args) + + // Create structpb.Struct from args. + var params *structpb.Struct + if len(args) == 0 { + params = &structpb.Struct{} + } else if casted, err := structpb.NewStruct(args); err == nil { + params = casted + } else { + return nil, gerr.ErrCastFailed.Wrap(err) + } + + // Sort hooks by priority. + priorities := make([]Priority, 0, len(reg.hooks[hookName])) + for priority := range reg.hooks[hookName] { + priorities = append(priorities, priority) + } + sort.SliceStable(priorities, func(i, j int) bool { + return priorities[i] < priorities[j] + }) + + // Run hooks, passing the result of the previous hook to the next one. + returnVal := &structpb.Struct{} + var removeList []Priority + // The signature of parameters and args MUST be the same for this to work. + for idx, priority := range priorities { + var result *structpb.Struct + var err error + if idx == 0 { + result, err = reg.hooks[hookName][priority](inheritedCtx, params, opts...) + } else { + result, err = reg.hooks[hookName][priority](inheritedCtx, returnVal, opts...) + } + + // This is done to ensure that the return value of the hook is always valid, + // and that the hook does not return any unexpected values. + // If the verification mode is non-strict (permissive), let the plugin pass + // extra keys/values to the next plugin in chain. + if utils.Verify(params, result) || verification == config.PassDown { + // Update the last return value with the current result + returnVal = result + continue + } + + // At this point, the hook returned an invalid value, so we need to handle it. + // The result of the current hook will be ignored, regardless of the policy. + switch verification { + // Ignore the result of this plugin, log an error and execute the next + case config.Ignore: + reg.Logger.Error().Err(err).Fields( + map[string]interface{}{ + "hookName": hookName, + "priority": priority, + }, + ).Msg("Hook returned invalid value, ignoring") + if idx == 0 { + returnVal = params + } + // Abort execution of the plugins, log the error and return the result of the last + case config.Abort: + reg.Logger.Error().Err(err).Fields( + map[string]interface{}{ + "hookName": hookName, + "priority": priority, + }, + ).Msg("Hook returned invalid value, aborting") + if idx == 0 { + return args, nil + } + return returnVal.AsMap(), nil + // Remove the hook from the registry, log the error and execute the next + case config.Remove: + reg.Logger.Error().Err(err).Fields( + map[string]interface{}{ + "hookName": hookName, + "priority": priority, + }, + ).Msg("Hook returned invalid value, removing") + removeList = append(removeList, priority) + if idx == 0 { + returnVal = params + } + case config.PassDown: + default: + returnVal = result + } + } + + // Remove hooks that failed verification. + for _, priority := range removeList { + delete(reg.hooks[hookName], priority) + } + + return returnVal.AsMap(), nil +} + // LoadPlugins loads plugins from the config file. // //nolint:funlen @@ -141,7 +312,7 @@ func (reg *PluginRegistry) LoadPlugins(plugins []config.Plugin) { continue } - reg.hookRegistry.Logger.Debug().Str("name", pCfg.Name).Msg("Loading plugin") + reg.Logger.Debug().Str("name", pCfg.Name).Msg("Loading plugin") plugin := &Plugin{ ID: Identifier{ Name: pCfg.Name, @@ -156,20 +327,20 @@ func (reg *PluginRegistry) LoadPlugins(plugins []config.Plugin) { // Is the plugin enabled? plugin.Enabled = pCfg.Enabled if !plugin.Enabled { - reg.hookRegistry.Logger.Debug().Str("name", plugin.ID.Name).Msg("Plugin is disabled") + reg.Logger.Debug().Str("name", plugin.ID.Name).Msg("Plugin is disabled") continue } // File path of the plugin on disk. if plugin.LocalPath == "" { - reg.hookRegistry.Logger.Debug().Str("name", plugin.ID.Name).Msg( + reg.Logger.Debug().Str("name", plugin.ID.Name).Msg( "Local file of the plugin doesn't exist or is not set") continue } // Checksum of the plugin. if plugin.ID.Checksum == "" { - reg.hookRegistry.Logger.Debug().Str("name", plugin.ID.Name).Msg( + reg.Logger.Debug().Str("name", plugin.ID.Name).Msg( "Checksum of plugin doesn't exist or is not set") continue } @@ -177,11 +348,11 @@ func (reg *PluginRegistry) LoadPlugins(plugins []config.Plugin) { // Verify the checksum. // TODO: Load the plugin from a remote location if the checksum didn't match? if sum, err := utils.SHA256SUM(plugin.LocalPath); err != nil { - reg.hookRegistry.Logger.Debug().Str("name", plugin.ID.Name).Err(err).Msg( + reg.Logger.Debug().Str("name", plugin.ID.Name).Err(err).Msg( "Failed to calculate checksum") continue } else if sum != plugin.ID.Checksum { - reg.hookRegistry.Logger.Debug().Fields( + reg.Logger.Debug().Fields( map[string]interface{}{ "calculated": sum, "expected": plugin.ID.Checksum, @@ -195,9 +366,9 @@ func (reg *PluginRegistry) LoadPlugins(plugins []config.Plugin) { // in the config file. Built-in plugins are loaded first, followed by user-defined // plugins. Built-in plugins have a priority of 0 to 999, and user-defined plugins // have a priority of 1000 or greater. - plugin.Priority = hook.Priority(config.PluginPriorityStart + uint(priority)) + plugin.Priority = Priority(config.PluginPriorityStart + uint(priority)) - logAdapter := logging.NewHcLogAdapter(®.hookRegistry.Logger, config.LoggerName) + logAdapter := logging.NewHcLogAdapter(®.Logger, config.LoggerName) plugin.client = goplugin.NewClient( &goplugin.ClientConfig{ @@ -220,22 +391,22 @@ func (reg *PluginRegistry) LoadPlugins(plugins []config.Plugin) { }, ) - reg.hookRegistry.Logger.Debug().Str("name", plugin.ID.Name).Msg("Plugin loaded") + reg.Logger.Debug().Str("name", plugin.ID.Name).Msg("Plugin loaded") if _, err := plugin.Start(); err != nil { - reg.hookRegistry.Logger.Debug().Str("name", plugin.ID.Name).Err(err).Msg( + reg.Logger.Debug().Str("name", plugin.ID.Name).Err(err).Msg( "Failed to start plugin") } // Load metadata from the plugin. var metadata *structpb.Struct if pluginV1, err := plugin.Dispense(); err != nil { - reg.hookRegistry.Logger.Debug().Str("name", plugin.ID.Name).Err(err).Msg( + reg.Logger.Debug().Str("name", plugin.ID.Name).Err(err).Msg( "Failed to dispense plugin") continue } else { if meta, origErr := pluginV1.GetPluginConfig( context.Background(), &structpb.Struct{}); err != nil { - reg.hookRegistry.Logger.Debug().Str("name", plugin.ID.Name).Err(origErr).Msg( + reg.Logger.Debug().Str("name", plugin.ID.Name).Err(origErr).Msg( "Failed to get plugin metadata") continue } else { @@ -246,12 +417,12 @@ func (reg *PluginRegistry) LoadPlugins(plugins []config.Plugin) { // Retrieve plugin requirements. if err := mapstructure.Decode(metadata.Fields["requires"].GetListValue().AsSlice(), &plugin.Requires); err != nil { - reg.hookRegistry.Logger.Debug().Err(err).Msg("Failed to decode plugin requirements") + reg.Logger.Debug().Err(err).Msg("Failed to decode plugin requirements") } // Too many requirements or not enough plugins loaded. if len(plugin.Requires) > reg.plugins.Size() { - reg.hookRegistry.Logger.Debug().Msg( + reg.Logger.Debug().Msg( "The plugin has too many requirements, " + "and not enough of them exist in the registry, so it won't work properly") } @@ -259,19 +430,19 @@ func (reg *PluginRegistry) LoadPlugins(plugins []config.Plugin) { // Check if the plugin requirements are met. for _, req := range plugin.Requires { if !reg.Exists(req.Name, req.Version, req.RemoteURL) { - reg.hookRegistry.Logger.Debug().Fields( + reg.Logger.Debug().Fields( map[string]interface{}{ "name": plugin.ID.Name, "requirement": req.Name, }, ).Msg("The plugin requirement is not met, so it won't work properly") if reg.CompatPolicy == config.Strict { - reg.hookRegistry.Logger.Debug().Str("name", plugin.ID.Name).Msg( + reg.Logger.Debug().Str("name", plugin.ID.Name).Msg( "Registry is in strict compatibility mode, so the plugin won't be loaded") plugin.Stop() // Stop the plugin. continue } else { - reg.hookRegistry.Logger.Debug().Fields( + reg.Logger.Debug().Fields( map[string]interface{}{ "name": plugin.ID.Name, "requirement": req.Name, @@ -290,12 +461,12 @@ func (reg *PluginRegistry) LoadPlugins(plugins []config.Plugin) { // Retrieve authors. if err := mapstructure.Decode(metadata.Fields["authors"].GetListValue().AsSlice(), &plugin.Authors); err != nil { - reg.hookRegistry.Logger.Debug().Err(err).Msg("Failed to decode plugin authors") + reg.Logger.Debug().Err(err).Msg("Failed to decode plugin authors") } // Retrieve hooks. if err := mapstructure.Decode(metadata.Fields["hooks"].GetListValue().AsSlice(), &plugin.Hooks); err != nil { - reg.hookRegistry.Logger.Debug().Err(err).Msg("Failed to decode plugin hooks") + reg.Logger.Debug().Err(err).Msg("Failed to decode plugin hooks") } // Retrieve plugin config. @@ -304,18 +475,18 @@ func (reg *PluginRegistry) LoadPlugins(plugins []config.Plugin) { if val, ok := value.(string); ok { plugin.Config[key] = val } else { - reg.hookRegistry.Logger.Debug().Str("key", key).Msg( + reg.Logger.Debug().Str("key", key).Msg( "Failed to decode plugin config") } } - reg.hookRegistry.Logger.Trace().Msgf("Plugin metadata: %+v", plugin) + reg.Logger.Trace().Msgf("Plugin metadata: %+v", plugin) reg.Add(plugin) - reg.hookRegistry.Logger.Debug().Str("name", plugin.ID.Name).Msg("Plugin metadata loaded") + reg.Logger.Debug().Str("name", plugin.ID.Name).Msg("Plugin metadata loaded") reg.RegisterHooks(plugin.ID) - reg.hookRegistry.Logger.Debug().Str("name", plugin.ID.Name).Msg("Plugin hooks registered") + reg.Logger.Debug().Str("name", plugin.ID.Name).Msg("Plugin hooks registered") } } @@ -324,63 +495,63 @@ func (reg *PluginRegistry) LoadPlugins(plugins []config.Plugin) { //nolint:funlen func (reg *PluginRegistry) RegisterHooks(id Identifier) { pluginImpl := reg.Get(id) - reg.hookRegistry.Logger.Debug().Str("name", pluginImpl.ID.Name).Msg( + reg.Logger.Debug().Str("name", pluginImpl.ID.Name).Msg( "Registering hooks for plugin") var pluginV1 pluginV1.GatewayDPluginServiceClient var err *gerr.GatewayDError if pluginV1, err = pluginImpl.Dispense(); err != nil { - reg.hookRegistry.Logger.Debug().Str("name", pluginImpl.ID.Name).Err(err).Msg( + reg.Logger.Debug().Str("name", pluginImpl.ID.Name).Err(err).Msg( "Failed to dispense plugin") return } for _, hookName := range pluginImpl.Hooks { - var hookMethod hook.Method + var hookMethod Method switch hookName { - case hook.OnConfigLoaded: + case OnConfigLoaded: hookMethod = pluginV1.OnConfigLoaded - case hook.OnNewLogger: + case OnNewLogger: hookMethod = pluginV1.OnNewLogger - case hook.OnNewPool: + case OnNewPool: hookMethod = pluginV1.OnNewPool - case hook.OnNewProxy: + case OnNewProxy: hookMethod = pluginV1.OnNewProxy - case hook.OnNewServer: + case OnNewServer: hookMethod = pluginV1.OnNewServer - case hook.OnSignal: + case OnSignal: hookMethod = pluginV1.OnSignal - case hook.OnRun: + case OnRun: hookMethod = pluginV1.OnRun - case hook.OnBooting: + case OnBooting: hookMethod = pluginV1.OnBooting - case hook.OnBooted: + case OnBooted: hookMethod = pluginV1.OnBooted - case hook.OnOpening: + case OnOpening: hookMethod = pluginV1.OnOpening - case hook.OnOpened: + case OnOpened: hookMethod = pluginV1.OnOpened - case hook.OnClosing: + case OnClosing: hookMethod = pluginV1.OnClosing - case hook.OnClosed: + case OnClosed: hookMethod = pluginV1.OnClosed - case hook.OnTraffic: + case OnTraffic: hookMethod = pluginV1.OnTraffic - case hook.OnTrafficFromClient: + case OnTrafficFromClient: hookMethod = pluginV1.OnTrafficFromClient - case hook.OnTrafficToServer: + case OnTrafficToServer: hookMethod = pluginV1.OnTrafficToServer - case hook.OnTrafficFromServer: + case OnTrafficFromServer: hookMethod = pluginV1.OnTrafficFromServer - case hook.OnTrafficToClient: + case OnTrafficToClient: hookMethod = pluginV1.OnTrafficToClient - case hook.OnShutdown: + case OnShutdown: hookMethod = pluginV1.OnShutdown - case hook.OnTick: + case OnTick: hookMethod = pluginV1.OnTick - case hook.OnNewClient: + case OnNewClient: hookMethod = pluginV1.OnNewClient default: - reg.hookRegistry.Logger.Warn().Fields(map[string]interface{}{ + reg.Logger.Warn().Fields(map[string]interface{}{ "hook": hookName, "priority": pluginImpl.Priority, "name": pluginImpl.ID.Name, @@ -388,11 +559,11 @@ func (reg *PluginRegistry) RegisterHooks(id Identifier) { "Unknown hook, skipping") continue } - reg.hookRegistry.Logger.Debug().Fields(map[string]interface{}{ + reg.Logger.Debug().Fields(map[string]interface{}{ "hook": hookName, "priority": pluginImpl.Priority, "name": pluginImpl.ID.Name, }).Msg("Registering hook") - reg.hookRegistry.Add(hookName, pluginImpl.Priority, hookMethod) + reg.AddHook(hookName, pluginImpl.Priority, hookMethod) } } diff --git a/plugin/plugin_registry_test.go b/plugin/plugin_registry_test.go index 5697e291..ad3bcd23 100644 --- a/plugin/plugin_registry_test.go +++ b/plugin/plugin_registry_test.go @@ -1,20 +1,37 @@ package plugin import ( + "context" "testing" - "github.com/gatewayd-io/gatewayd/plugin/hook" + "github.com/gatewayd-io/gatewayd/config" + "github.com/gatewayd-io/gatewayd/logging" + "github.com/rs/zerolog" "github.com/stretchr/testify/assert" + "google.golang.org/grpc" + "google.golang.org/protobuf/types/known/structpb" ) +func NewPluginRegistry(t *testing.T) *PluginRegistry { + t.Helper() + + cfg := logging.LoggerConfig{ + Output: config.Console, + TimeFormat: zerolog.TimeFormatUnix, + Level: zerolog.DebugLevel, + NoColor: true, + } + logger := logging.NewLogger(cfg) + reg := NewRegistry(config.Loose, config.PassDown, logger) + return reg +} + // TestPluginRegistry tests the PluginRegistry. func TestPluginRegistry(t *testing.T) { - hookRegistry := hook.NewRegistry() - assert.NotNil(t, hookRegistry) - reg := NewRegistry(hookRegistry) + reg := NewPluginRegistry(t) assert.NotNil(t, reg) assert.NotNil(t, reg.plugins) - assert.NotNil(t, reg.hookRegistry) + assert.NotNil(t, reg.hooks) assert.Equal(t, 0, len(reg.List())) ident := Identifier{ @@ -36,3 +53,232 @@ func TestPluginRegistry(t *testing.T) { reg.Shutdown() } + +// Test_HookRegistry_Add tests the Add function. +func Test_PluginRegistry_AddHook(t *testing.T) { + testFunc := func( + ctx context.Context, + args *structpb.Struct, + opts ...grpc.CallOption, + ) (*structpb.Struct, error) { + return args, nil + } + + reg := NewPluginRegistry(t) + reg.AddHook(OnNewLogger, 0, testFunc) + assert.NotNil(t, reg.Hooks()[OnNewLogger][0]) + assert.ObjectsAreEqual(testFunc, reg.Hooks()[OnNewLogger][0]) +} + +// Test_HookRegistry_Add_Multiple_Hooks tests the Add function with multiple hooks. +func Test_PluginRegistry_AddHook_Multiple(t *testing.T) { + reg := NewPluginRegistry(t) + reg.AddHook(OnNewLogger, 0, func( + ctx context.Context, + args *structpb.Struct, + opts ...grpc.CallOption, + ) (*structpb.Struct, error) { + return args, nil + }) + reg.AddHook(OnNewLogger, 1, func( + ctx context.Context, + args *structpb.Struct, + opts ...grpc.CallOption, + ) (*structpb.Struct, error) { + return args, nil + }) + assert.NotNil(t, reg.Hooks()[OnNewLogger][0]) + assert.NotNil(t, reg.Hooks()[OnNewLogger][1]) +} + +// Test_HookRegistry_Run tests the Run function. +func Test_PluginRegistry_Run(t *testing.T) { + reg := NewPluginRegistry(t) + reg.AddHook(OnNewLogger, 0, func( + ctx context.Context, + args *structpb.Struct, + opts ...grpc.CallOption, + ) (*structpb.Struct, error) { + return args, nil + }) + result, err := reg.Run( + context.Background(), map[string]interface{}{}, OnNewLogger, config.Ignore) + assert.NotNil(t, result) + assert.Nil(t, err) +} + +// Test_HookRegistry_Run_PassDown tests the Run function with the PassDown option. +func Test_PluginRegistry_Run_PassDown(t *testing.T) { + reg := NewPluginRegistry(t) + // The result of the hook will be nil and will be passed down to the next + reg.AddHook(OnNewLogger, 0, func( + ctx context.Context, + args *structpb.Struct, + opts ...grpc.CallOption, + ) (*structpb.Struct, error) { + return nil, nil //nolint:nilnil + }) + // The consolidated result should be {"test": "test"}. + reg.AddHook(OnNewLogger, 1, func( + ctx context.Context, + args *structpb.Struct, + opts ...grpc.CallOption, + ) (*structpb.Struct, error) { + output, err := structpb.NewStruct(map[string]interface{}{ + "test": "test", + }) + assert.Nil(t, err) + return output, nil + }) + + // Although the first hook returns nil, and its signature doesn't match the params, + // so its result (nil) is passed down to the next hook in chain (priority 2). + // Then the second hook runs and returns a signature with a "test" key and value. + result, err := reg.Run( + context.Background(), + map[string]interface{}{"test": "test"}, + OnNewLogger, + config.PassDown) + assert.Nil(t, err) + assert.NotNil(t, result) +} + +// Test_HookRegistry_Run_PassDown_2 tests the Run function with the PassDown option. +func Test_HookRegistry_Run_PassDown_2(t *testing.T) { + reg := NewPluginRegistry(t) + // The result of the hook will be nil and will be passed down to the next + reg.AddHook(OnNewLogger, 0, func( + ctx context.Context, + args *structpb.Struct, + opts ...grpc.CallOption, + ) (*structpb.Struct, error) { + args.Fields["test1"] = &structpb.Value{ + Kind: &structpb.Value_StringValue{ //nolint:nosnakecase + StringValue: "test1", + }, + } + return args, nil + }) + // The consolidated result should be {"test1": "test1", "test2": "test2"}. + reg.AddHook(OnNewLogger, 1, func( + ctx context.Context, + args *structpb.Struct, + opts ...grpc.CallOption, + ) (*structpb.Struct, error) { + args.Fields["test2"] = &structpb.Value{ + Kind: &structpb.Value_StringValue{ //nolint:nosnakecase + StringValue: "test2", + }, + } + return args, nil + }) + // Although the first hook returns nil, and its signature doesn't match the params, + // so its result (nil) is passed down to the next hook in chain (priority 2). + // Then the second hook runs and returns a signature with a "test1" and "test2" key and value. + result, err := reg.Run( + context.Background(), + map[string]interface{}{"test": "test"}, + OnNewLogger, + config.PassDown) + assert.Nil(t, err) + assert.NotNil(t, result) +} + +// Test_HookRegistry_Run_Ignore tests the Run function with the Ignore option. +func Test_HookRegistry_Run_Ignore(t *testing.T) { + reg := NewPluginRegistry(t) + // This should not run, because the return value is not the same as the params + reg.AddHook(OnNewLogger, 0, func( + ctx context.Context, + args *structpb.Struct, + opts ...grpc.CallOption, + ) (*structpb.Struct, error) { + return nil, nil //nolint:nilnil + }) + // This should run, because the return value is the same as the params + reg.AddHook(OnNewLogger, 1, func( + ctx context.Context, + args *structpb.Struct, + opts ...grpc.CallOption, + ) (*structpb.Struct, error) { + args.Fields["test"] = &structpb.Value{ + Kind: &structpb.Value_StringValue{ //nolint:nosnakecase + StringValue: "test", + }, + } + return args, nil + }) + // The first hook returns nil, and its signature doesn't match the params, + // so its result is ignored. + // Then the second hook runs and returns a signature with a "test" key and value. + result, err := reg.Run( + context.Background(), + map[string]interface{}{"test": "test"}, + OnNewLogger, + config.Ignore) + assert.Nil(t, err) + assert.NotNil(t, result) +} + +// Test_HookRegistry_Run_Abort tests the Run function with the Abort option. +func Test_HookRegistry_Run_Abort(t *testing.T) { + reg := NewPluginRegistry(t) + // This should not run, because the return value is not the same as the params + reg.AddHook(OnNewLogger, 0, func( + ctx context.Context, + args *structpb.Struct, + opts ...grpc.CallOption, + ) (*structpb.Struct, error) { + return nil, nil //nolint:nilnil + }) + // This should not run, because the first hook returns nil, and its result is ignored. + reg.AddHook(OnNewLogger, 1, func( + ctx context.Context, + args *structpb.Struct, + opts ...grpc.CallOption, + ) (*structpb.Struct, error) { + output, err := structpb.NewStruct(map[string]interface{}{ + "test": "test", + }) + assert.Nil(t, err) + return output, nil + }) + // The first hook returns nil, and it aborts the execution of the rest of the + result, err := reg.Run( + context.Background(), map[string]interface{}{}, OnNewLogger, config.Abort) + assert.Nil(t, err) + assert.Equal(t, map[string]interface{}{}, result) +} + +// Test_HookRegistry_Run_Remove tests the Run function with the Remove option. +func Test_HookRegistry_Run_Remove(t *testing.T) { + reg := NewPluginRegistry(t) + // This should not run, because the return value is not the same as the params + reg.AddHook(OnNewLogger, 0, func( + ctx context.Context, + args *structpb.Struct, + opts ...grpc.CallOption, + ) (*structpb.Struct, error) { + return nil, nil //nolint:nilnil + }) + // This should not run, because the first hook returns nil, and its result is ignored. + reg.AddHook(OnNewLogger, 1, func( + ctx context.Context, + args *structpb.Struct, + opts ...grpc.CallOption, + ) (*structpb.Struct, error) { + output, err := structpb.NewStruct(map[string]interface{}{ + "test": "test", + }) + assert.Nil(t, err) + return output, nil + }) + // The first hook returns nil, and its signature doesn't match the params, + // so its result is ignored. The failing hook is removed from the list and + // the execution continues with the next hook in the list. + result, err := reg.Run( + context.Background(), map[string]interface{}{}, OnNewLogger, config.Remove) + assert.Nil(t, err) + assert.Equal(t, map[string]interface{}{}, result) + assert.Equal(t, 1, len(reg.Hooks()[OnNewLogger])) +} diff --git a/plugin/hook/types.go b/plugin/types.go similarity index 95% rename from plugin/hook/types.go rename to plugin/types.go index d3c98bb6..f88c26ea 100644 --- a/plugin/hook/types.go +++ b/plugin/types.go @@ -1,4 +1,4 @@ -package hook +package plugin import ( "context" From aa673480834d3a11fe02684a2ceed71ec46eae7c Mon Sep 17 00:00:00 2001 From: Mostafa Moradian Date: Sat, 14 Jan 2023 00:19:01 +0100 Subject: [PATCH 13/16] Rename PluginRegistry and its interface to Registry --- cmd/run.go | 2 +- network/proxy.go | 4 ++-- network/server.go | 4 ++-- plugin/plugin_registry.go | 32 ++++++++++++++++---------------- plugin/plugin_registry_test.go | 2 +- 5 files changed, 22 insertions(+), 22 deletions(-) diff --git a/cmd/run.go b/cmd/run.go index e43913e1..790cbc07 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -321,7 +321,7 @@ var runCmd = &cobra.Command{ ) signalsCh := make(chan os.Signal, 1) signal.Notify(signalsCh, signals...) - go func(pluginRegistry *plugin.PluginRegistry) { + go func(pluginRegistry *plugin.Registry) { for sig := range signalsCh { for _, s := range signals { if sig != s { diff --git a/network/proxy.go b/network/proxy.go index 22945987..cb3ec7e1 100644 --- a/network/proxy.go +++ b/network/proxy.go @@ -26,7 +26,7 @@ type Proxy struct { availableConnections pool.IPool busyConnections pool.IPool logger zerolog.Logger - pluginRegistry *plugin.PluginRegistry + pluginRegistry *plugin.Registry scheduler *gocron.Scheduler Elastic bool @@ -41,7 +41,7 @@ var _ IProxy = &Proxy{} // NewProxy creates a new proxy. func NewProxy( - connPool pool.IPool, pluginRegistry *plugin.PluginRegistry, + connPool pool.IPool, pluginRegistry *plugin.Registry, elastic, reuseElasticClients bool, healthCheckPeriod time.Duration, clientConfig *config.Client, logger zerolog.Logger, diff --git a/network/server.go b/network/server.go index 77de3e71..46215ef8 100644 --- a/network/server.go +++ b/network/server.go @@ -20,7 +20,7 @@ type Server struct { engine gnet.Engine proxy IProxy logger zerolog.Logger - pluginRegistry *plugin.PluginRegistry + pluginRegistry *plugin.Registry Network string // tcp/udp/unix Address string @@ -334,7 +334,7 @@ func NewServer( options []gnet.Option, proxy IProxy, logger zerolog.Logger, - pluginRegistry *plugin.PluginRegistry, + pluginRegistry *plugin.Registry, ) *Server { // Create the server. server := Server{ diff --git a/plugin/plugin_registry.go b/plugin/plugin_registry.go index 52a167f7..e0ef7daa 100644 --- a/plugin/plugin_registry.go +++ b/plugin/plugin_registry.go @@ -18,7 +18,7 @@ import ( "google.golang.org/protobuf/types/known/structpb" ) -type IPluginRegistry interface { +type IRegistry interface { Add(plugin *Plugin) bool Get(id Identifier) *Plugin List() []Identifier @@ -40,7 +40,7 @@ type IPluginRegistry interface { RegisterHooks(id Identifier) } -type PluginRegistry struct { //nolint:golint,revive +type Registry struct { //nolint:golint,revive plugins pool.IPool hooks map[string]map[Priority]Method @@ -49,15 +49,15 @@ type PluginRegistry struct { //nolint:golint,revive Verification config.Policy } -var _ IPluginRegistry = &PluginRegistry{} +var _ IRegistry = &Registry{} // NewRegistry creates a new plugin registry. func NewRegistry( compatPolicy config.CompatPolicy, verification config.Policy, logger zerolog.Logger, -) *PluginRegistry { - return &PluginRegistry{ +) *Registry { + return &Registry{ plugins: pool.NewPool(config.EmptyPoolCapacity), hooks: map[string]map[Priority]Method{}, Logger: logger, @@ -67,7 +67,7 @@ func NewRegistry( } // Add adds a plugin to the registry. -func (reg *PluginRegistry) Add(plugin *Plugin) bool { +func (reg *Registry) Add(plugin *Plugin) bool { _, loaded, err := reg.plugins.GetOrPut(plugin.ID, plugin) if err != nil { reg.Logger.Error().Err(err).Msg("Failed to add plugin to registry") @@ -77,7 +77,7 @@ func (reg *PluginRegistry) Add(plugin *Plugin) bool { } // Get returns a plugin from the registry. -func (reg *PluginRegistry) Get(id Identifier) *Plugin { +func (reg *Registry) Get(id Identifier) *Plugin { if plugin, ok := reg.plugins.Get(id).(*Plugin); ok { return plugin } @@ -86,7 +86,7 @@ func (reg *PluginRegistry) Get(id Identifier) *Plugin { } // List returns a list of all plugins in the registry. -func (reg *PluginRegistry) List() []Identifier { +func (reg *Registry) List() []Identifier { var plugins []Identifier reg.plugins.ForEach(func(key, _ interface{}) bool { if id, ok := key.(Identifier); ok { @@ -98,7 +98,7 @@ func (reg *PluginRegistry) List() []Identifier { } // Exists checks if a plugin exists in the registry. -func (reg *PluginRegistry) Exists(name, version, remoteURL string) bool { +func (reg *Registry) Exists(name, version, remoteURL string) bool { for _, plugin := range reg.List() { if plugin.Name == name && plugin.RemoteURL == remoteURL { // Parse the supplied version and the version in the registry. @@ -132,12 +132,12 @@ func (reg *PluginRegistry) Exists(name, version, remoteURL string) bool { } // Remove removes a plugin from the registry. -func (reg *PluginRegistry) Remove(id Identifier) { +func (reg *Registry) Remove(id Identifier) { reg.plugins.Remove(id) } // Shutdown shuts down all plugins in the registry. -func (reg *PluginRegistry) Shutdown() { +func (reg *Registry) Shutdown() { reg.plugins.ForEach(func(key, value interface{}) bool { if id, ok := key.(Identifier); ok { if plugin, ok := value.(*Plugin); ok { @@ -151,12 +151,12 @@ func (reg *PluginRegistry) Shutdown() { } // Hooks returns the hooks map. -func (reg *PluginRegistry) Hooks() map[string]map[Priority]Method { +func (reg *Registry) Hooks() map[string]map[Priority]Method { return reg.hooks } // Add adds a hook with a priority to the hooks map. -func (reg *PluginRegistry) AddHook(hookName string, priority Priority, hookMethod Method) { +func (reg *Registry) AddHook(hookName string, priority Priority, hookMethod Method) { if len(reg.hooks[hookName]) == 0 { reg.hooks[hookName] = map[Priority]Method{priority: hookMethod} } else { @@ -186,7 +186,7 @@ func (reg *PluginRegistry) AddHook(hookName string, priority Priority, hookMetho // The opts are passed to the hooks as well to allow them to use the grpc.CallOption. // //nolint:funlen -func (reg *PluginRegistry) Run( +func (reg *Registry) Run( ctx context.Context, args map[string]interface{}, hookName string, @@ -301,7 +301,7 @@ func (reg *PluginRegistry) Run( // LoadPlugins loads plugins from the config file. // //nolint:funlen -func (reg *PluginRegistry) LoadPlugins(plugins []config.Plugin) { +func (reg *Registry) LoadPlugins(plugins []config.Plugin) { // TODO: Append built-in plugins to the list of plugins // Built-in plugins are plugins that are compiled and shipped with the gatewayd binary. @@ -493,7 +493,7 @@ func (reg *PluginRegistry) LoadPlugins(plugins []config.Plugin) { // RegisterHooks registers the hooks for the given plugin. // //nolint:funlen -func (reg *PluginRegistry) RegisterHooks(id Identifier) { +func (reg *Registry) RegisterHooks(id Identifier) { pluginImpl := reg.Get(id) reg.Logger.Debug().Str("name", pluginImpl.ID.Name).Msg( "Registering hooks for plugin") diff --git a/plugin/plugin_registry_test.go b/plugin/plugin_registry_test.go index ad3bcd23..48f91ad9 100644 --- a/plugin/plugin_registry_test.go +++ b/plugin/plugin_registry_test.go @@ -12,7 +12,7 @@ import ( "google.golang.org/protobuf/types/known/structpb" ) -func NewPluginRegistry(t *testing.T) *PluginRegistry { +func NewPluginRegistry(t *testing.T) *Registry { t.Helper() cfg := logging.LoggerConfig{ From 24b7392d4b4e64d0547bcadd4e0011f721f378a2 Mon Sep 17 00:00:00 2001 From: Mostafa Moradian Date: Sat, 14 Jan 2023 00:22:03 +0100 Subject: [PATCH 14/16] Move plugin/utils package into plugin --- plugin/plugin_registry.go | 9 ++-- plugin/utils/functions.go | 82 ------------------------------ plugin/utils/functions_test.go | 91 ---------------------------------- 3 files changed, 4 insertions(+), 178 deletions(-) delete mode 100644 plugin/utils/functions.go delete mode 100644 plugin/utils/functions_test.go diff --git a/plugin/plugin_registry.go b/plugin/plugin_registry.go index e0ef7daa..8ed3412e 100644 --- a/plugin/plugin_registry.go +++ b/plugin/plugin_registry.go @@ -8,7 +8,6 @@ import ( "github.com/gatewayd-io/gatewayd/config" gerr "github.com/gatewayd-io/gatewayd/errors" "github.com/gatewayd-io/gatewayd/logging" - "github.com/gatewayd-io/gatewayd/plugin/utils" pluginV1 "github.com/gatewayd-io/gatewayd/plugin/v1" "github.com/gatewayd-io/gatewayd/pool" goplugin "github.com/hashicorp/go-plugin" @@ -202,7 +201,7 @@ func (reg *Registry) Run( defer cancel() // Cast custom fields to their primitive types, like time.Duration to float64. - args = utils.CastToPrimitiveTypes(args) + args = CastToPrimitiveTypes(args) // Create structpb.Struct from args. var params *structpb.Struct @@ -240,7 +239,7 @@ func (reg *Registry) Run( // and that the hook does not return any unexpected values. // If the verification mode is non-strict (permissive), let the plugin pass // extra keys/values to the next plugin in chain. - if utils.Verify(params, result) || verification == config.PassDown { + if Verify(params, result) || verification == config.PassDown { // Update the last return value with the current result returnVal = result continue @@ -347,7 +346,7 @@ func (reg *Registry) LoadPlugins(plugins []config.Plugin) { // Verify the checksum. // TODO: Load the plugin from a remote location if the checksum didn't match? - if sum, err := utils.SHA256SUM(plugin.LocalPath); err != nil { + if sum, err := SHA256SUM(plugin.LocalPath); err != nil { reg.Logger.Debug().Str("name", plugin.ID.Name).Err(err).Msg( "Failed to calculate checksum") continue @@ -374,7 +373,7 @@ func (reg *Registry) LoadPlugins(plugins []config.Plugin) { &goplugin.ClientConfig{ HandshakeConfig: pluginV1.Handshake, Plugins: pluginV1.GetPluginMap(plugin.ID.Name), - Cmd: utils.NewCommand(plugin.LocalPath, plugin.Args, plugin.Env), + Cmd: NewCommand(plugin.LocalPath, plugin.Args, plugin.Env), AllowedProtocols: []goplugin.Protocol{ goplugin.ProtocolGRPC, }, diff --git a/plugin/utils/functions.go b/plugin/utils/functions.go deleted file mode 100644 index c9ecf5f5..00000000 --- a/plugin/utils/functions.go +++ /dev/null @@ -1,82 +0,0 @@ -package utils - -import ( - "bufio" - "crypto/sha256" - "errors" - "fmt" - "io" - "os" - "os/exec" - "time" - - "github.com/gatewayd-io/gatewayd/config" - gerr "github.com/gatewayd-io/gatewayd/errors" - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" - "google.golang.org/protobuf/types/known/structpb" -) - -// SHA256SUM returns the sha256 checksum of a file. -// Ref: https://github.com/codingsince1985/checksum -// A little copying is better than a little dependency. -func SHA256SUM(filename string) (string, *gerr.GatewayDError) { - if info, err := os.Stat(filename); err != nil || info.IsDir() { - return "", gerr.ErrFileNotFound.Wrap(err) - } - - file, err := os.Open(filename) - if err != nil { - return "", gerr.ErrFileOpenFailed.Wrap(err) - } - defer func() { _ = file.Close() }() - - hashAlgorithm := sha256.New() - - buf := make([]byte, config.ChecksumBufferSize) - for { - n, err := bufio.NewReader(file).Read(buf) - //nolint:gocritic - if err == nil { - hashAlgorithm.Write(buf[:n]) - } else if errors.Is(err, io.EOF) { - return fmt.Sprintf("%x", hashAlgorithm.Sum(nil)), nil - } else { - return "", gerr.ErrFileReadFailed.Wrap(err) - } - } -} - -// Verify compares two structs and returns true if they are equal. -func Verify(params, returnVal *structpb.Struct) bool { - return cmp.Equal(params.AsMap(), returnVal.AsMap(), cmp.Options{ - cmpopts.SortMaps(func(a, b string) bool { - return a < b - }), - cmpopts.EquateEmpty(), - }) -} - -// NewCommand returns a command with the given arguments and environment variables. -func NewCommand(cmd string, args []string, env []string) *exec.Cmd { - command := exec.Command(cmd, args...) - if env != nil { - command.Env = append(command.Env, env...) - } - return command -} - -// CastToPrimitiveTypes casts the values of a map to its primitive type -// (e.g. time.Duration to float64) to prevent structpb invalid type(s) errors. -func CastToPrimitiveTypes(args map[string]interface{}) map[string]interface{} { - for key, value := range args { - switch value := value.(type) { - case time.Duration: - args[key] = value.String() - // TODO: Add more types here as needed. - default: - args[key] = value - } - } - return args -} diff --git a/plugin/utils/functions_test.go b/plugin/utils/functions_test.go deleted file mode 100644 index 6d9578f1..00000000 --- a/plugin/utils/functions_test.go +++ /dev/null @@ -1,91 +0,0 @@ -package utils - -import ( - "testing" - - "github.com/stretchr/testify/assert" - "google.golang.org/protobuf/types/known/structpb" -) - -// Test_sha256sum tests the sha256sum function. -func Test_sha256sum(t *testing.T) { - checksum, err := SHA256SUM("../../LICENSE") - assert.Nil(t, err) - assert.Equal(t, - "8486a10c4393cee1c25392769ddd3b2d6c242d6ec7928e1414efff7dfb2f07ef", - checksum, - ) -} - -// Test_sha256sum_fail tests the sha256sum function with a file that does not exist. -func Test_sha256sum_fail(t *testing.T) { - _, err := SHA256SUM("not_a_file") - assert.NotNil(t, err) -} - -// Test_Verify tests the Verify function. -func Test_Verify(t *testing.T) { - params, err := structpb.NewStruct( - map[string]interface{}{ - "test": "test", - }, - ) - assert.Nil(t, err) - - returnVal, err := structpb.NewStruct( - map[string]interface{}{ - "test": "test", - }, - ) - assert.Nil(t, err) - - assert.True(t, Verify(params, returnVal)) -} - -// Test_Verify_fail tests the Verify function with different parameters to -// ensure it returns false on verification errors. -func Test_Verify_fail(t *testing.T) { - data := [][]map[string]interface{}{ - { - { - "test": "test", - }, - { - "test": "test", - "test2": "test2", - }, - }, - { - { - "test": "test", - "test2": "test2", - }, - { - "test": "test", - }, - }, - { - { - "test": "test", - "test2": "test2", - }, - { - "test": "test", - "test3": "test3", - }, - }, - } - - for _, d := range data { - params, err := structpb.NewStruct(d[0]) - assert.Nil(t, err) - returnVal, err := structpb.NewStruct(d[1]) - assert.Nil(t, err) - assert.False(t, Verify(params, returnVal)) - } -} - -// Test_Verify_nil tests the Verify function with nil parameters. -func Test_Verify_nil(t *testing.T) { - assert.True(t, Verify(nil, nil)) -} From 9fb83b3da57673cf909a2263a5ed36e4fbf56df1 Mon Sep 17 00:00:00 2001 From: Mostafa Moradian Date: Sat, 14 Jan 2023 00:22:38 +0100 Subject: [PATCH 15/16] Rename functions to utils --- plugin/utils.go | 82 +++++++++++++++++++++++++++++++++++++++ plugin/utils_test.go | 91 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 173 insertions(+) create mode 100644 plugin/utils.go create mode 100644 plugin/utils_test.go diff --git a/plugin/utils.go b/plugin/utils.go new file mode 100644 index 00000000..e8c0539c --- /dev/null +++ b/plugin/utils.go @@ -0,0 +1,82 @@ +package plugin + +import ( + "bufio" + "crypto/sha256" + "errors" + "fmt" + "io" + "os" + "os/exec" + "time" + + "github.com/gatewayd-io/gatewayd/config" + gerr "github.com/gatewayd-io/gatewayd/errors" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "google.golang.org/protobuf/types/known/structpb" +) + +// SHA256SUM returns the sha256 checksum of a file. +// Ref: https://github.com/codingsince1985/checksum +// A little copying is better than a little dependency. +func SHA256SUM(filename string) (string, *gerr.GatewayDError) { + if info, err := os.Stat(filename); err != nil || info.IsDir() { + return "", gerr.ErrFileNotFound.Wrap(err) + } + + file, err := os.Open(filename) + if err != nil { + return "", gerr.ErrFileOpenFailed.Wrap(err) + } + defer func() { _ = file.Close() }() + + hashAlgorithm := sha256.New() + + buf := make([]byte, config.ChecksumBufferSize) + for { + n, err := bufio.NewReader(file).Read(buf) + //nolint:gocritic + if err == nil { + hashAlgorithm.Write(buf[:n]) + } else if errors.Is(err, io.EOF) { + return fmt.Sprintf("%x", hashAlgorithm.Sum(nil)), nil + } else { + return "", gerr.ErrFileReadFailed.Wrap(err) + } + } +} + +// Verify compares two structs and returns true if they are equal. +func Verify(params, returnVal *structpb.Struct) bool { + return cmp.Equal(params.AsMap(), returnVal.AsMap(), cmp.Options{ + cmpopts.SortMaps(func(a, b string) bool { + return a < b + }), + cmpopts.EquateEmpty(), + }) +} + +// NewCommand returns a command with the given arguments and environment variables. +func NewCommand(cmd string, args []string, env []string) *exec.Cmd { + command := exec.Command(cmd, args...) + if env != nil { + command.Env = append(command.Env, env...) + } + return command +} + +// CastToPrimitiveTypes casts the values of a map to its primitive type +// (e.g. time.Duration to float64) to prevent structpb invalid type(s) errors. +func CastToPrimitiveTypes(args map[string]interface{}) map[string]interface{} { + for key, value := range args { + switch value := value.(type) { + case time.Duration: + args[key] = value.String() + // TODO: Add more types here as needed. + default: + args[key] = value + } + } + return args +} diff --git a/plugin/utils_test.go b/plugin/utils_test.go new file mode 100644 index 00000000..b42f4f8c --- /dev/null +++ b/plugin/utils_test.go @@ -0,0 +1,91 @@ +package plugin + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/types/known/structpb" +) + +// Test_sha256sum tests the sha256sum function. +func Test_sha256sum(t *testing.T) { + checksum, err := SHA256SUM("../LICENSE") + assert.Nil(t, err) + assert.Equal(t, + "8486a10c4393cee1c25392769ddd3b2d6c242d6ec7928e1414efff7dfb2f07ef", + checksum, + ) +} + +// Test_sha256sum_fail tests the sha256sum function with a file that does not exist. +func Test_sha256sum_fail(t *testing.T) { + _, err := SHA256SUM("not_a_file") + assert.NotNil(t, err) +} + +// Test_Verify tests the Verify function. +func Test_Verify(t *testing.T) { + params, err := structpb.NewStruct( + map[string]interface{}{ + "test": "test", + }, + ) + assert.Nil(t, err) + + returnVal, err := structpb.NewStruct( + map[string]interface{}{ + "test": "test", + }, + ) + assert.Nil(t, err) + + assert.True(t, Verify(params, returnVal)) +} + +// Test_Verify_fail tests the Verify function with different parameters to +// ensure it returns false on verification errors. +func Test_Verify_fail(t *testing.T) { + data := [][]map[string]interface{}{ + { + { + "test": "test", + }, + { + "test": "test", + "test2": "test2", + }, + }, + { + { + "test": "test", + "test2": "test2", + }, + { + "test": "test", + }, + }, + { + { + "test": "test", + "test2": "test2", + }, + { + "test": "test", + "test3": "test3", + }, + }, + } + + for _, d := range data { + params, err := structpb.NewStruct(d[0]) + assert.Nil(t, err) + returnVal, err := structpb.NewStruct(d[1]) + assert.Nil(t, err) + assert.False(t, Verify(params, returnVal)) + } +} + +// Test_Verify_nil tests the Verify function with nil parameters. +func Test_Verify_nil(t *testing.T) { + assert.True(t, Verify(nil, nil)) +} From 381424a5567be73ab46fdace3fe8ba56ce82e01e Mon Sep 17 00:00:00 2001 From: Mostafa Moradian Date: Sat, 14 Jan 2023 00:27:44 +0100 Subject: [PATCH 16/16] Fix linter errors --- plugin/plugin_registry.go | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/plugin/plugin_registry.go b/plugin/plugin_registry.go index 8ed3412e..57c2b03d 100644 --- a/plugin/plugin_registry.go +++ b/plugin/plugin_registry.go @@ -17,14 +17,19 @@ import ( "google.golang.org/protobuf/types/known/structpb" ) +//nolint:interfacebloat type IRegistry interface { + // Plugin management Add(plugin *Plugin) bool Get(id Identifier) *Plugin List() []Identifier Exists(name, version, remoteURL string) bool Remove(id Identifier) Shutdown() + LoadPlugins(plugins []config.Plugin) + RegisterHooks(id Identifier) + // Hook management AddHook(hookName string, priority Priority, hookMethod Method) Hooks() map[string]map[Priority]Method Run( @@ -34,12 +39,9 @@ type IRegistry interface { verification config.Policy, opts ...grpc.CallOption, ) (map[string]interface{}, *gerr.GatewayDError) - - LoadPlugins(plugins []config.Plugin) - RegisterHooks(id Identifier) } -type Registry struct { //nolint:golint,revive +type Registry struct { plugins pool.IPool hooks map[string]map[Priority]Method