diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 87f75e80..cc380338 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -92,13 +92,10 @@ jobs: run: | sudo apt-get update sudo apt-get install --yes --no-install-recommends postgresql-client - psql -h ${PGHOST} -p ${PGPORT} -U ${PGUSER} -c "CREATE DATABASE gatewayd_test;" - psql -h ${PGHOST} -p ${PGPORT} -U ${PGUSER} -d ${DBNAME} -c "CREATE TABLE test_table (id serial PRIMARY KEY, name varchar(255));" - psql -h ${PGHOST} -p ${PGPORT} -U ${PGUSER} -d ${DBNAME} -c "INSERT INTO test_table (name) VALUES ('test');" - psql -h ${PGHOST} -p ${PGPORT} -U ${PGUSER} -d ${DBNAME} -c "SELECT * FROM test_table;" | grep test + psql ${PGURL1} -c "CREATE DATABASE gatewayd_test;" + psql ${PGURL2} -c "CREATE TABLE test_table (id serial PRIMARY KEY, name varchar(255));" + psql ${PGURL2} -c "INSERT INTO test_table (name) VALUES ('test');" + psql ${PGURL2} -c "SELECT * FROM test_table;" | grep test || exit 1 env: - DBNAME: gatewayd_test - PGUSER: postgres - PGPASSWORD: postgres - PGHOST: localhost - PGPORT: 15432 + PGURL1: postgresql://postgres:postgres@localhost:15432/postgres + PGURL2: postgresql://postgres:postgres@localhost:15432/gatewayd_test diff --git a/client_test.py b/client_test.py deleted file mode 100644 index 9ffae34b..00000000 --- a/client_test.py +++ /dev/null @@ -1,76 +0,0 @@ -import os -from pprint import pprint -from concurrent.futures import ThreadPoolExecutor -import psycopg - - -def create_table(): - conn = None - try: - conn = psycopg.connect(host="localhost", port=15432, dbname="test", - user="postgres", password="postgres", sslmode="disable") - - conn.execute( - "CREATE TABLE IF NOT EXISTS test (id serial PRIMARY KEY, num integer, data varchar);") - conn.close() - except KeyboardInterrupt: - if conn: - conn.close() - os._exit(0) - except Exception as e: - print("Worker %s: %s" % (id, e)) - - return - - -def writer(id): - conn = None - try: - conn = psycopg.connect(host="localhost", port=15432, dbname="test", - user="postgres", password="postgres", sslmode="disable") - - conn.execute("INSERT INTO test (num, data) VALUES (%s, %s)", (id, "abc'def")) - - conn.close() - except KeyboardInterrupt: - if conn: - conn.close() - os._exit(0) - except Exception as e: - print("Worker %s: %s" % (id, e)) - - return - - -def reader(): - conn = None - try: - conn = psycopg.connect(host="localhost", port=15432, dbname="test", - user="postgres", password="postgres", sslmode="disable") - - for row in conn.execute("SELECT * FROM test;"): - print("ID=%s, NUM=%s, DATA=%s" % row) - # conn.execute("DROP TABLE test;") - conn.close() - except KeyboardInterrupt: - if conn: - conn.close() - os._exit(0) - except Exception as e: - print("Worker %s: %s" % (id, e)) - - return - - -if __name__ == '__main__': - with ThreadPoolExecutor(max_workers=10) as executor: - # Create 11 connections to the server and run queries in parallel - # This will cause the server to crash - executor.submit(create_table) - - for i in range(10): - executor.submit(writer, i) - - # Wait for all threads to finish - executor.submit(reader) - executor.shutdown(wait=True) diff --git a/cmd/config_parser.go b/cmd/config_parser.go index fdd29585..98126cc4 100644 --- a/cmd/config_parser.go +++ b/cmd/config_parser.go @@ -131,6 +131,10 @@ func proxyConfig() (bool, bool, *network.Client) { address := globalConfig.String(ref + ".address") receiveBufferSize := globalConfig.Int(ref + ".receiveBufferSize") + if receiveBufferSize <= 0 { + receiveBufferSize = network.DefaultBufferSize + } + return elastic, reuseElasticClients, &network.Client{ Network: net, Address: address, @@ -181,6 +185,26 @@ func getTCPNoDelay() gnet.TCPSocketOpt { } func serverConfig() *ServerConfig { + readBufferCap := globalConfig.Int("server.readBufferCap") + if readBufferCap <= 0 { + readBufferCap = network.DefaultBufferSize + } + + writeBufferCap := globalConfig.Int("server.writeBufferCap") + if writeBufferCap <= 0 { + writeBufferCap = network.DefaultBufferSize + } + + socketRecvBuffer := globalConfig.Int("server.socketRecvBuffer") + if socketRecvBuffer <= 0 { + socketRecvBuffer = network.DefaultBufferSize + } + + socketSendBuffer := globalConfig.Int("server.socketSendBuffer") + if socketSendBuffer <= 0 { + socketSendBuffer = network.DefaultBufferSize + } + return &ServerConfig{ Network: globalConfig.String("server.network"), Address: globalConfig.String("server.address"), @@ -191,10 +215,10 @@ func serverConfig() *ServerConfig { MultiCore: globalConfig.Bool("server.multiCore"), LockOSThread: globalConfig.Bool("server.lockOSThread"), LoadBalancer: getLoadBalancer(globalConfig.String("server.loadBalancer")), - ReadBufferCap: globalConfig.Int("server.readBufferCap"), - WriteBufferCap: globalConfig.Int("server.writeBufferCap"), - SocketRecvBuffer: globalConfig.Int("server.socketRecvBuffer"), - SocketSendBuffer: globalConfig.Int("server.socketSendBuffer"), + ReadBufferCap: readBufferCap, + WriteBufferCap: writeBufferCap, + SocketRecvBuffer: socketRecvBuffer, + SocketSendBuffer: socketSendBuffer, ReuseAddress: globalConfig.Bool("server.reuseAddress"), ReusePort: globalConfig.Bool("server.reusePort"), TCPKeepAlive: globalConfig.Duration("server.tcpKeepAlive"), diff --git a/cmd/run.go b/cmd/run.go index 3dc3fe15..5e917ece 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -114,8 +114,8 @@ var runCmd = &cobra.Command{ } // Create and initialize a pool of connections - pool := pool.NewPool() poolSize, clientConfig := poolConfig() + pool := pool.NewPool(poolSize) // Add clients to the pool for i := 0; i < poolSize; i++ { @@ -146,7 +146,10 @@ var runCmd = &cobra.Command{ } } - pool.Put(client.ID, client) + err = pool.Put(client.ID, client) + if err != nil { + logger.Error().Err(err).Msg("Failed to add client to the pool") + } } } diff --git a/errors/errors.go b/errors/errors.go index eb541e56..63b75c21 100644 --- a/errors/errors.go +++ b/errors/errors.go @@ -10,6 +10,13 @@ var ( ErrPluginNotFound = errors.New("plugin not found") ErrPluginNotReady = errors.New("plugin is not ready") + + ErrClientReceiveFailed = errors.New("couldn't receive data from the server") + ErrClientSendFailed = errors.New("couldn't send data to the server") + + ErrPutFailed = errors.New("failed to put in pool") + + ErrCastFailed = errors.New("failed to cast") ) const ( diff --git a/gatewayd.yaml b/gatewayd.yaml index 1463b0c6..2de46d2f 100644 --- a/gatewayd.yaml +++ b/gatewayd.yaml @@ -15,14 +15,14 @@ clients: client1: network: tcp address: localhost:5432 - receiveBufferSize: 4096 + # receiveBufferSize: 16777216 # Pool config pool: # Use the logger config passed here # i.e. don't assume it's the same as the logger config above logger: loggers.logger - size: 2 + size: 10 # Database configs for the connection pool client: clients.client1 @@ -52,10 +52,10 @@ server: multiCore: True lockOSThread: False loadBalancer: roundrobin - readBufferCap: 4096 - writeBufferCap: 4096 - socketRecvBuffer: 4096 - socketSendBuffer: 4096 + # readBufferCap: 16777216 + # writeBufferCap: 16777216 + # socketRecvBuffer: 16777216 + # socketSendBuffer: 16777216 reuseAddress: True reusePort: True tcpKeepAlive: 3s # seconds diff --git a/go.mod b/go.mod index d8bd6539..e81c0ed6 100644 --- a/go.mod +++ b/go.mod @@ -9,11 +9,11 @@ require ( github.com/hashicorp/go-plugin v1.4.8 github.com/knadh/koanf v1.4.4 github.com/mitchellh/mapstructure v1.5.0 - github.com/panjf2000/gnet/v2 v2.2.0 + github.com/panjf2000/gnet/v2 v2.2.1 github.com/rs/zerolog v1.28.0 github.com/spf13/cobra v1.6.1 github.com/stretchr/testify v1.8.1 - golang.org/x/exp v0.0.0-20221212164502-fae10dda9338 + golang.org/x/exp v0.0.0-20221217163422-3c43f8badb15 google.golang.org/grpc v1.51.0 google.golang.org/protobuf v1.28.1 ) @@ -21,7 +21,7 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/fatih/color v1.13.0 // indirect - github.com/fsnotify/fsnotify v1.4.9 // indirect + github.com/fsnotify/fsnotify v1.6.0 // indirect github.com/golang/protobuf v1.5.2 // indirect github.com/hashicorp/yamux v0.1.1 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect @@ -36,7 +36,7 @@ require ( github.com/spf13/pflag v1.0.5 // indirect github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8 // indirect go.uber.org/atomic v1.10.0 // indirect - go.uber.org/multierr v1.8.0 // indirect + go.uber.org/multierr v1.9.0 // indirect go.uber.org/zap v1.24.0 // indirect golang.org/x/net v0.4.0 // indirect golang.org/x/sys v0.3.0 // indirect diff --git a/go.sum b/go.sum index 2da9cb9b..09fffbda 100644 --- a/go.sum +++ b/go.sum @@ -55,6 +55,8 @@ github.com/fergusstrange/embedded-postgres v1.19.0 h1:NqDufJHeA03U7biULlPHZ0pZ10 github.com/fergusstrange/embedded-postgres v1.19.0/go.mod h1:0B+3bPsMvcNgR9nN+bdM2x9YaNYDnf3ksUqYp1OAub0= github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= +github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY= +github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= @@ -222,8 +224,11 @@ github.com/oklog/run v1.1.0 h1:GEenZ1cK0+q0+wsJew9qUg/DyD8k3JzYsZAi5gYi2mA= github.com/oklog/run v1.1.0/go.mod h1:sVPdnTZT1zYwAJeCMu2Th4T21pA3FPOQRfWjQlk7DVU= github.com/panjf2000/ants/v2 v2.4.8 h1:JgTbolX6K6RreZ4+bfctI0Ifs+3mrE5BIHudQxUDQ9k= github.com/panjf2000/ants/v2 v2.4.8/go.mod h1:f6F0NZVFsGCp5A7QW/Zj/m92atWwOkY0OIhFxRNFr4A= +github.com/panjf2000/ants/v2 v2.7.0 h1:Y3Bgpfo9HDkBoHNVFbMfY5mAvi5TAA17y3HbzQ74p5Y= github.com/panjf2000/gnet/v2 v2.2.0 h1:+6itXhRlHJpv5UGAyN1DebHzK1l0GbZMOsg2Spb1VS0= github.com/panjf2000/gnet/v2 v2.2.0/go.mod h1:unWr2B4jF0DQPJH3GsXBGQiDcAamM6+Pf5FiK705kc4= +github.com/panjf2000/gnet/v2 v2.2.1 h1:HJVK3vmD6rBgOeTnYkG4czW6jphVHygxLLWTEBU3nqU= +github.com/panjf2000/gnet/v2 v2.2.1/go.mod h1:y8xWR1EEK6pGDuAQ6XULY/WWmPv0Pgbsq2Q4lbXJ6JA= github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= github.com/pascaldekloe/goe v0.1.0/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= github.com/pelletier/go-toml v1.7.0 h1:7utD74fnzVc/cpcyy8sjrlFr5vYpypUixARcHIMIGuI= @@ -304,6 +309,8 @@ go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9i go.uber.org/multierr v1.7.0/go.mod h1:7EAYxJLBy9rStEaz58O2t4Uvip6FSURkq8/ppBp95ak= go.uber.org/multierr v1.8.0 h1:dg6GjLku4EH+249NNmoIciG9N/jURbDG+pFlTkhzIC8= go.uber.org/multierr v1.8.0/go.mod h1:7EAYxJLBy9rStEaz58O2t4Uvip6FSURkq8/ppBp95ak= +go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= +go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ= go.uber.org/zap v1.17.0/go.mod h1:MXVU+bhUf/A7Xi2HNOnopQOrmycQ5Ih87HtOu4q5SSo= go.uber.org/zap v1.21.0/go.mod h1:wjWOCqI0f2ZZrJF/UufIOkiC8ii6tm1iqIsLo76RfJw= go.uber.org/zap v1.24.0 h1:FiJd5l1UOLj0wCgbSE0rwwXHzEdAZS6hiiSnxJN/D60= @@ -316,6 +323,10 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20221212164502-fae10dda9338 h1:OvjRkcNHnf6/W5FZXSxODbxwD+X7fspczG7Jn/xQVD4= golang.org/x/exp v0.0.0-20221212164502-fae10dda9338/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= +golang.org/x/exp v0.0.0-20221215174704-0915cd710c24 h1:6w3iSY8IIkp5OQtbYj8NeuKG1jS9d+kYaubXqsoOiQ8= +golang.org/x/exp v0.0.0-20221215174704-0915cd710c24/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= +golang.org/x/exp v0.0.0-20221217163422-3c43f8badb15 h1:5oN1Pz/eDhCpbMbLstvIPa0b/BEQo6g6nwV3pLjfM6w= +golang.org/x/exp v0.0.0-20221217163422-3c43f8badb15/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= @@ -391,6 +402,7 @@ golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220224120231-95c6836cb0e7/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.3.0 h1:w8ZOecv6NaNa/zC8944JTU3vz4u6Lagfk4RPQxv92NQ= golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= diff --git a/network/client.go b/network/client.go index 1ac2200d..435e223f 100644 --- a/network/client.go +++ b/network/client.go @@ -1,9 +1,11 @@ package network import ( - "fmt" + "errors" + "io" "net" + gerr "github.com/gatewayd-io/gatewayd/errors" "github.com/rs/zerolog" ) @@ -58,33 +60,37 @@ func NewClient(network, address string, receiveBufferSize int, logger zerolog.Lo } client.Conn = conn - if client.ReceiveBufferSize == 0 { + if receiveBufferSize <= 0 { client.ReceiveBufferSize = DefaultBufferSize + } else { + client.ReceiveBufferSize = receiveBufferSize } + logger.Debug().Msgf("New client created: %s", client.Address) client.ID = GetID(conn.LocalAddr().Network(), conn.LocalAddr().String(), DefaultSeed, logger) return &client } -func (c *Client) Send(data []byte) error { - if _, err := c.Write(data); err != nil { +func (c *Client) Send(data []byte) (int, error) { + sent, err := c.Conn.Write(data) + if err != nil { c.logger.Error().Err(err).Msgf("Couldn't send data to the server: %s", err) - return fmt.Errorf("couldn't send data to the server: %w", err) + // TODO: Wrap the original error + return 0, gerr.ErrClientSendFailed } c.logger.Debug().Msgf("Sent %d bytes to %s", len(data), c.Address) - return nil + return sent, nil } func (c *Client) Receive() (int, []byte, error) { buf := make([]byte, c.ReceiveBufferSize) - read, err := c.Read(buf) - if err != nil { - c.logger.Error().Err(err).Msgf("Couldn't receive data from the server: %s", err) - return 0, nil, fmt.Errorf("couldn't receive data from the server: %w", err) + received, err := c.Conn.Read(buf) + if err != nil && errors.Is(err, io.EOF) { + c.logger.Error().Err(err).Msg("Couldn't receive data from the server") + return 0, nil, err //nolint:wrapcheck } - c.logger.Debug().Msgf("Received %d bytes from %s", read, c.Address) - return read, buf, nil + return received, buf, err //nolint:wrapcheck } func (c *Client) Close() { @@ -96,5 +102,26 @@ func (c *Client) Close() { c.Conn = nil c.Address = "" c.Network = "" - c.ReceiveBufferSize = 0 +} + +// Go returns io.EOF when the server closes the connection. +// So, if I read 0 bytes and the error is io.EOF or net.ErrClosed, I should reconnect. +func (c *Client) IsConnected() bool { + if c == nil { + return false + } + + if c != nil && c.Conn == nil || c.ID == "" { + c.Close() + return false + } + + buf := make([]byte, 0) + if _, err := c.Read(buf); errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) { + c.logger.Debug().Msgf("Connection to %s is closed", c.Address) + c.Close() + return false + } + + return true } diff --git a/network/client_test.go b/network/client_test.go index 3ae2b697..d30e1a09 100644 --- a/network/client_test.go +++ b/network/client_test.go @@ -65,8 +65,10 @@ func TestSend(t *testing.T) { defer client.Close() assert.NotNil(t, client) - err := client.Send(CreatePostgreSQLPacket('Q', []byte("select 1;"))) + packet := CreatePostgreSQLPacket('Q', []byte("select 1;")) + sent, err := client.Send(packet) assert.Nil(t, err) + assert.Equal(t, len(packet), sent) } func TestReceive(t *testing.T) { @@ -94,8 +96,10 @@ func TestReceive(t *testing.T) { defer client.Close() assert.NotNil(t, client) - err := client.Send(CreatePgStartupPacket()) + packet := CreatePgStartupPacket() + sent, err := client.Send(packet) assert.Nil(t, err) + assert.Equal(t, len(packet), sent) size, data, err := client.Receive() msg := "\x00\x00\x00\x03" @@ -136,5 +140,5 @@ func TestClose(t *testing.T) { assert.Equal(t, "", client.Network) assert.Equal(t, "", client.Address) assert.Nil(t, client.Conn) - assert.Equal(t, 0, client.ReceiveBufferSize) + assert.Equal(t, DefaultBufferSize, client.ReceiveBufferSize) } diff --git a/network/proxy.go b/network/proxy.go index b0506136..e0738d98 100644 --- a/network/proxy.go +++ b/network/proxy.go @@ -2,9 +2,6 @@ package network import ( "context" - "errors" - "fmt" - "io" gerr "github.com/gatewayd-io/gatewayd/errors" "github.com/gatewayd-io/gatewayd/plugin" @@ -15,12 +12,17 @@ import ( "google.golang.org/protobuf/types/known/structpb" ) +const ( + EmptyPoolCapacity int = 0 +) + type Proxy interface { Connect(gconn gnet.Conn) error Disconnect(gconn gnet.Conn) error PassThrough(gconn gnet.Conn) error - Reconnect(cl *Client) *Client + TryReconnect(cl *Client) (*Client, error) Shutdown() + IsExhausted() bool } type ProxyImpl struct { @@ -45,7 +47,7 @@ func NewProxy( ) *ProxyImpl { return &ProxyImpl{ availableConnections: p, - busyConnections: pool.NewPool(), + busyConnections: pool.NewPool(EmptyPoolCapacity), logger: logger, hookConfig: hookConfig, Elastic: elastic, @@ -66,8 +68,8 @@ func (pr *ProxyImpl) Connect(gconn gnet.Conn) error { }) var client *Client - if pr.availableConnections.Size() == 0 { - // Pool is exhausted + if pr.IsExhausted() { + // Pool is exhausted or is elastic if pr.Elastic { // Create a new client client = NewClient( @@ -81,45 +83,65 @@ func (pr *ProxyImpl) Connect(gconn gnet.Conn) error { return gerr.ErrPoolExhausted } } else { - // Get a client from the pool - pr.logger.Debug().Msgf("Available clients: %v", pr.availableConnections.Size()) + // Get the client from the pool with the given clientID if cl, ok := pr.availableConnections.Pop(clientID).(*Client); ok { client = cl } } - if clientID != "" || client.ID != "" { - pr.busyConnections.Put(gconn, client) - pr.logger.Debug().Msgf("Client %s has been assigned to %s", client.ID, gconn.RemoteAddr().String()) - } else { - return gerr.ErrClientNotConnected + client, err := pr.TryReconnect(client) + if err != nil { + pr.logger.Error().Err(err).Msgf("Failed to connect to the client") + } + + if err := pr.busyConnections.Put(gconn, client); err != nil { + // This should never happen + return gerr.ErrPutFailed } + pr.logger.Debug().Msgf( + "Client %s has been assigned to %s", client.ID, gconn.RemoteAddr().String()) - pr.logger.Debug().Msgf("[C] There are %d clients in the pool", pr.availableConnections.Size()) - pr.logger.Debug().Msgf("[C] There are %d clients in use", pr.busyConnections.Size()) + pr.logger.Debug().Str("function", "Proxy.Connect").Msgf( + "There are %d available clients", pr.availableConnections.Size()) + pr.logger.Debug().Str("function", "Proxy.Connect").Msgf( + "There are %d busy clients", pr.busyConnections.Size()) return nil } func (pr *ProxyImpl) Disconnect(gconn gnet.Conn) error { - var client *Client - if cl, ok := pr.busyConnections.Pop(gconn).(*Client); !ok { - client = cl - } - - // TODO: The connection is unstable when I put the client back in the pool - // If the client is not in the pool, put it back - if pr.Elastic && pr.ReuseElasticClients || !pr.Elastic { - client = pr.Reconnect(client) - if client != nil && client.ID != "" { - pr.availableConnections.Put(client.ID, client) + client := pr.busyConnections.Pop(gconn) + //nolint:nestif + if client != nil { + if client, ok := client.(*Client); ok { + if (pr.Elastic && pr.ReuseElasticClients) || !pr.Elastic { + if !client.IsConnected() { + _, err := pr.TryReconnect(client) + if err != nil { + pr.logger.Error().Err(err).Msgf("Failed to reconnect to the client") + } + } + // If the client is not in the pool, put it back + err := pr.availableConnections.Put(client.ID, client) + if err != nil { + pr.logger.Error().Err(err).Msgf("Failed to put the client back in the pool") + } + } else { + return gerr.ErrClientNotConnected + } + } else { + // This should never happen, but if it does, + // then there are some serious issues with the pool + return gerr.ErrCastFailed } } else { - client.Close() + return gerr.ErrClientNotFound } - pr.logger.Debug().Msgf("[D] There are %d clients in the pool", pr.availableConnections.Size()) - pr.logger.Debug().Msgf("[D] There are %d clients in use", pr.busyConnections.Size()) + pr.logger.Debug().Str("function", "Proxy.Disconnect").Msgf( + "There are %d available clients", pr.availableConnections.Size()) + pr.logger.Debug().Str("function", "Proxy.Disconnect").Msgf( + "There are %d busy clients", pr.busyConnections.Size()) return nil } @@ -133,10 +155,14 @@ func (pr *ProxyImpl) PassThrough(gconn gnet.Conn) error { // that listens for data from the server and sends it to the client var client *Client + if pr.busyConnections.Get(gconn) == nil { + return gerr.ErrClientNotFound + } + if cl, ok := pr.busyConnections.Get(gconn).(*Client); ok { client = cl } else { - return gerr.ErrClientNotFound + return gerr.ErrCastFailed } // buf contains the data from the client (, length, query) @@ -144,6 +170,13 @@ func (pr *ProxyImpl) PassThrough(gconn gnet.Conn) error { if err != nil { pr.logger.Error().Err(err).Msgf("Error reading from client: %v", err) } + pr.logger.Debug().Fields( + map[string]interface{}{ + "length": len(buf), + "local": gconn.LocalAddr().String(), + "remote": gconn.RemoteAddr().String(), + }, + ).Msg("Received data from client") addresses := map[string]interface{}{ "client": map[string]interface{}{ @@ -188,21 +221,33 @@ func (pr *ProxyImpl) PassThrough(gconn gnet.Conn) error { } } - // TODO: This is a very basic implementation of the gateway - // and it is synchronous. I should make it asynchronous. - pr.logger.Debug().Msgf("Received %d bytes from %s", len(buf), gconn.RemoteAddr().String()) - // Send the query to the server - err = client.Send(buf) + sent, err := client.Send(buf) if err != nil { - return err + pr.logger.Error().Err(err).Msgf("Error sending data to database") } + pr.logger.Debug().Fields( + map[string]interface{}{ + "function": "Proxy.PassThrough", + "length": sent, + "local": client.Conn.LocalAddr().String(), + "remote": client.Conn.RemoteAddr().String(), + }, + ).Msg("Sent data to database") // Receive the response from the server - size, response, err := client.Receive() + received, response, err := client.Receive() + pr.logger.Debug().Fields( + map[string]interface{}{ + "function": "Proxy.PassThrough", + "length": received, + "local": client.Conn.LocalAddr().String(), + "remote": client.Conn.RemoteAddr().String(), + }, + ).Msg("Received data from database") egress := map[string]interface{}{ - "response": response[:size], // Will be converted to base64-encoded string + "response": response[:received], // Will be converted to base64-encoded string "error": "", } if err != nil { @@ -233,46 +278,38 @@ func (pr *ProxyImpl) PassThrough(gconn gnet.Conn) error { } } - //nolint:gocritic - if err != nil && errors.Is(err, io.EOF) { - // The server has closed the connection - pr.logger.Error().Err(err).Msg("The client is not connected to the server anymore") - // Either the client is not connected to the server anymore or - // server forceful closed the connection - // Reconnect the client - client = pr.Reconnect(client) - // Put the client in the busy connections pool, effectively replacing the old one - pr.busyConnections.Put(gconn, client) + err = gconn.AsyncWrite(response[:received], func(gconn gnet.Conn, err error) error { + pr.logger.Debug().Fields( + map[string]interface{}{ + "function": "Proxy.PassThrough", + "length": received, + "local": gconn.LocalAddr().String(), + "remote": gconn.RemoteAddr().String(), + }, + ).Msg("Sent data to client") return err - } else if err != nil { - // Write the error to the client - _, err := gconn.Write(response[:size]) - if err != nil { - pr.logger.Error().Err(err).Msgf("Error writing the error to client: %v", err) - } - return fmt.Errorf("error receiving data from server: %w", err) - } else { - // Write the response to the incoming connection - _, err = gconn.Write(response[:size]) - if err != nil { - pr.logger.Error().Err(err).Msgf("Error writing to client: %v", err) - } + }) + if err != nil { + pr.logger.Error().Err(err).Msgf("Error writing to client") + return err //nolint:wrapcheck } return nil } -func (pr *ProxyImpl) Reconnect(cl *Client) *Client { - // Close the client - if cl != nil && cl.ID != "" { - cl.Close() +func (pr *ProxyImpl) TryReconnect(client *Client) (*Client, error) { + // TODO: try retriable connection? + + if pr.IsExhausted() { + pr.logger.Error().Msg("No more available connections :: TryReconnect") + return client, gerr.ErrPoolExhausted + } + + if !client.IsConnected() { + pr.logger.Error().Msg("Client is disconnected") } - return NewClient( - pr.ClientConfig.Network, - pr.ClientConfig.Address, - pr.ClientConfig.ReceiveBufferSize, - pr.logger, - ) + + return client, nil } func (pr *ProxyImpl) Shutdown() { @@ -297,3 +334,11 @@ func (pr *ProxyImpl) Shutdown() { pr.busyConnections.Clear() pr.logger.Debug().Msg("All busy connections have been closed") } + +func (pr *ProxyImpl) IsExhausted() bool { + if pr.Elastic { + return false + } + + return pr.availableConnections.Size() == 0 && pr.availableConnections.Cap() > 0 +} diff --git a/network/proxy_test.go b/network/proxy_test.go index 86ad32ed..e6ff4484 100644 --- a/network/proxy_test.go +++ b/network/proxy_test.go @@ -33,9 +33,10 @@ func TestNewProxy(t *testing.T) { logger := logging.NewLogger(cfg) // Create a connection pool - pool := pool.NewPool() + pool := pool.NewPool(EmptyPoolCapacity) client := NewClient("tcp", "localhost:5432", DefaultBufferSize, logger) - pool.Put(client.ID, client) + err := pool.Put(client.ID, client) + assert.Nil(t, err) // Create a proxy with a fixed buffer pool proxy := NewProxy(pool, plugin.NewHookConfig(), false, false, nil, logger) @@ -63,7 +64,7 @@ func TestNewProxyElastic(t *testing.T) { logger := logging.NewLogger(cfg) // Create a connection pool - pool := pool.NewPool() + pool := pool.NewPool(EmptyPoolCapacity) // Create a proxy with an elastic buffer pool proxy := NewProxy(pool, plugin.NewHookConfig(), true, false, &Client{ diff --git a/network/server.go b/network/server.go index c91fa310..bc4e8244 100644 --- a/network/server.go +++ b/network/server.go @@ -2,10 +2,13 @@ package network import ( "context" + "errors" "fmt" + "io" "os" "time" + gerr "github.com/gatewayd-io/gatewayd/errors" "github.com/gatewayd-io/gatewayd/plugin" "github.com/panjf2000/gnet/v2" "github.com/rs/zerolog" @@ -21,7 +24,7 @@ const ( DefaultTickInterval = 5 * time.Second DefaultPoolSize = 10 MinimumPoolSize = 2 - DefaultBufferSize = 4096 + DefaultBufferSize = 1 << 24 // 16777216 bytes ) type Server struct { @@ -113,7 +116,14 @@ func (s *Server) OnOpen(gconn gnet.Conn) ([]byte, gnet.Action) { } if err := s.proxy.Connect(gconn); err != nil { - return nil, gnet.Close + if errors.Is(err, gerr.ErrPoolExhausted) { + return nil, gnet.Close + } + + // This should never happen + // TODO: Send error to client or retry connection + s.logger.Error().Err(err).Msg("Failed to connect to proxy") + return nil, gnet.None } onOpenedData, err := structpb.NewStruct(map[string]interface{}{ @@ -160,14 +170,16 @@ func (s *Server) OnClose(gconn gnet.Conn, err error) gnet.Action { } } - if err := s.proxy.Disconnect(gconn); err != nil { - s.logger.Error().Err(err).Msg("Failed to disconnect from the client") - } - + // Shutdown the server if there are no more connections and the server is stopped if uint64(s.engine.CountConnections()) == 0 && s.Status == Stopped { return gnet.Shutdown } + if err := s.proxy.Disconnect(gconn); err != nil { + s.logger.Error().Err(err).Msg("Failed to disconnect the server connection") + return gnet.Close + } + data = map[string]interface{}{ "client": map[string]interface{}{ "local": gconn.LocalAddr().String(), @@ -213,8 +225,19 @@ func (s *Server) OnTraffic(gconn gnet.Conn) gnet.Action { if err := s.proxy.PassThrough(gconn); err != nil { s.logger.Error().Err(err).Msg("Failed to pass through traffic") // TODO: Close the connection *gracefully* - return gnet.Close + switch { + case errors.Is(err, gerr.ErrPoolExhausted): + case errors.Is(err, gerr.ErrCastFailed): + case errors.Is(err, gerr.ErrClientNotFound): + case errors.Is(err, gerr.ErrClientNotConnected): + case errors.Is(err, gerr.ErrClientSendFailed): + case errors.Is(err, gerr.ErrClientReceiveFailed): + case errors.Is(err, io.EOF): + return gnet.Close + } } + // Flush the connection to make sure all data is sent + gconn.Flush() return gnet.None } diff --git a/network/server_test.go b/network/server_test.go index b5f08bb8..e5bc6b7c 100644 --- a/network/server_test.go +++ b/network/server_test.go @@ -90,11 +90,13 @@ func TestRunServer(t *testing.T) { hooksConfig.Add(plugin.OnEgressTraffic, 1, onEgressTraffic) // Create a connection pool - pool := pool.NewPool() + pool := pool.NewPool(2) client1 := NewClient("tcp", "localhost:5432", DefaultBufferSize, logger) - pool.Put(client1.ID, client1) + err := pool.Put(client1.ID, client1) + assert.Nil(t, err) client2 := NewClient("tcp", "localhost:5432", DefaultBufferSize, logger) - pool.Put(client2.ID, client2) + err = pool.Put(client2.ID, client2) + assert.Nil(t, err) // Create a proxy with a fixed buffer pool proxy := NewProxy(pool, hooksConfig, false, false, &Client{ @@ -140,8 +142,9 @@ func TestRunServer(t *testing.T) { defer client.Close() assert.NotNil(t, client) - err := client.Send(CreatePgStartupPacket()) + sent, err := client.Send(CreatePgStartupPacket()) assert.Nil(t, err) + assert.Equal(t, len(CreatePgStartupPacket()), sent) // The server should respond with a 'R' packet size, data, err := client.Receive() diff --git a/plugin/hooks.go b/plugin/hooks.go index 079ace8d..f870c095 100644 --- a/plugin/hooks.go +++ b/plugin/hooks.go @@ -93,6 +93,7 @@ func (h *HookConfig) Run( verification Policy, opts ...grpc.CallOption, ) (*structpb.Struct, error) { + // TODO: accept args as map[string]interface{} and convert to structpb.Struct if ctx == nil { ctx = context.Background() } diff --git a/plugin/registry.go b/plugin/registry.go index 6311b2b8..d2349d78 100644 --- a/plugin/registry.go +++ b/plugin/registry.go @@ -17,6 +17,7 @@ const ( DefaultMinPort uint = 50000 DefaultMaxPort uint = 60000 PluginPriorityStart uint = 1000 + EmptyPoolCapacity int = 0 LoggerName string = "plugin" ) @@ -44,11 +45,15 @@ type RegistryImpl struct { var _ Registry = &RegistryImpl{} func NewRegistry(hooksConfig *HookConfig) *RegistryImpl { - return &RegistryImpl{plugins: pool.NewPool(), hooksConfig: hooksConfig} + return &RegistryImpl{plugins: pool.NewPool(EmptyPoolCapacity), hooksConfig: hooksConfig} } func (reg *RegistryImpl) Add(plugin *Impl) bool { - _, loaded := reg.plugins.GetOrPut(plugin.ID, plugin) + _, loaded, err := reg.plugins.GetOrPut(plugin.ID, plugin) + if err != nil { + reg.hooksConfig.Logger.Error().Err(err).Msg("Failed to add plugin to registry") + return false + } return loaded } diff --git a/pool/pool.go b/pool/pool.go index a823ccfa..202a4392 100644 --- a/pool/pool.go +++ b/pool/pool.go @@ -2,6 +2,12 @@ package pool import ( "sync" + + gerr "github.com/gatewayd-io/gatewayd/errors" +) + +const ( + EmptyPoolCapacity = 0 ) type Callback func(key, value interface{}) bool @@ -9,17 +15,19 @@ type Callback func(key, value interface{}) bool type Pool interface { ForEach(Callback) Pool() *sync.Map - Put(key, value interface{}) + Put(key, value interface{}) error Get(key interface{}) interface{} - GetOrPut(key, value interface{}) (interface{}, bool) + GetOrPut(key, value interface{}) (interface{}, bool, error) Pop(key interface{}) interface{} Remove(key interface{}) Size() int Clear() + Cap() int } type Impl struct { pool sync.Map + cap int } var _ Pool = &Impl{} @@ -32,32 +40,46 @@ func (p *Impl) Pool() *sync.Map { return &p.pool } -func (p *Impl) Put(key, value interface{}) { +func (p *Impl) Put(key, value interface{}) error { + if p.cap > 0 && p.Size() >= p.cap { + return gerr.ErrPoolExhausted + } p.pool.Store(key, value) + return nil } func (p *Impl) Get(key interface{}) interface{} { if value, ok := p.pool.Load(key); ok { return value } - return nil } -func (p *Impl) GetOrPut(key, value interface{}) (interface{}, bool) { - return p.pool.LoadOrStore(key, value) +func (p *Impl) GetOrPut(key, value interface{}) (interface{}, bool, error) { + if p.cap > 0 && p.Size() >= p.cap { + return nil, false, gerr.ErrPoolExhausted + } + val, loaded := p.pool.LoadOrStore(key, value) + return val, loaded, nil } func (p *Impl) Pop(key interface{}) interface{} { + if p.Size() == 0 { + return nil + } if value, ok := p.pool.LoadAndDelete(key); ok { return value } - return nil } func (p *Impl) Remove(key interface{}) { - p.pool.Delete(key) + if p.Size() == 0 { + return + } + if _, ok := p.pool.Load(key); ok { + p.pool.Delete(key) + } } func (p *Impl) Size() int { @@ -74,6 +96,11 @@ func (p *Impl) Clear() { p.pool = sync.Map{} } -func NewPool() *Impl { - return &Impl{pool: sync.Map{}} +func (p *Impl) Cap() int { + return p.cap +} + +//nolint:predeclared +func NewPool(cap int) *Impl { + return &Impl{pool: sync.Map{}, cap: cap} } diff --git a/pool/pool_test.go b/pool/pool_test.go index 34a8ac0e..834ad69c 100644 --- a/pool/pool_test.go +++ b/pool/pool_test.go @@ -7,7 +7,7 @@ import ( ) func TestNewPool(t *testing.T) { - pool := NewPool() + pool := NewPool(EmptyPoolCapacity) defer pool.Clear() assert.NotNil(t, pool) assert.NotNil(t, pool.Pool()) @@ -15,27 +15,31 @@ func TestNewPool(t *testing.T) { } func TestPool_Put(t *testing.T) { - pool := NewPool() + pool := NewPool(EmptyPoolCapacity) defer pool.Clear() assert.NotNil(t, pool) assert.NotNil(t, pool.Pool()) assert.Equal(t, 0, pool.Size()) - pool.Put("client1.ID", "client1") + err := pool.Put("client1.ID", "client1") + assert.Nil(t, err) assert.Equal(t, 1, pool.Size()) - pool.Put("client2.ID", "client2") + err = pool.Put("client2.ID", "client2") assert.Equal(t, 2, pool.Size()) + assert.Nil(t, err) } //nolint:dupl func TestPool_Pop(t *testing.T) { - pool := NewPool() + pool := NewPool(EmptyPoolCapacity) defer pool.Clear() assert.NotNil(t, pool) assert.NotNil(t, pool.Pool()) assert.Equal(t, 0, pool.Size()) - pool.Put("client1.ID", "client1") + err := pool.Put("client1.ID", "client1") + assert.Nil(t, err) assert.Equal(t, 1, pool.Size()) - pool.Put("client2.ID", "client2") + err = pool.Put("client2.ID", "client2") + assert.Nil(t, err) assert.Equal(t, 2, pool.Size()) if c1, ok := pool.Pop("client1.ID").(string); !ok { assert.Equal(t, c1, "client1") @@ -52,28 +56,32 @@ func TestPool_Pop(t *testing.T) { } func TestPool_Clear(t *testing.T) { - pool := NewPool() + pool := NewPool(EmptyPoolCapacity) defer pool.Clear() assert.NotNil(t, pool) assert.NotNil(t, pool.Pool()) assert.Equal(t, 0, pool.Size()) - pool.Put("client1.ID", "client1") + err := pool.Put("client1.ID", "client1") + assert.Nil(t, err) assert.Equal(t, 1, pool.Size()) - pool.Put("client2.ID", "client2") + err = pool.Put("client2.ID", "client2") + assert.Nil(t, err) assert.Equal(t, 2, pool.Size()) pool.Clear() assert.Equal(t, 0, pool.Size()) } func TestPool_ForEach(t *testing.T) { - pool := NewPool() + pool := NewPool(EmptyPoolCapacity) defer pool.Clear() assert.NotNil(t, pool) assert.NotNil(t, pool.Pool()) assert.Equal(t, 0, pool.Size()) - pool.Put("client1.ID", "client1") + err := pool.Put("client1.ID", "client1") + assert.Nil(t, err) assert.Equal(t, 1, pool.Size()) - pool.Put("client2.ID", "client2") + err = pool.Put("client2.ID", "client2") + assert.Nil(t, err) assert.Equal(t, 2, pool.Size()) pool.ForEach(func(key, value interface{}) bool { if c, ok := value.(string); ok { @@ -85,14 +93,16 @@ func TestPool_ForEach(t *testing.T) { //nolint:dupl func TestPool_Get(t *testing.T) { - pool := NewPool() + pool := NewPool(EmptyPoolCapacity) defer pool.Clear() assert.NotNil(t, pool) assert.NotNil(t, pool.Pool()) assert.Equal(t, 0, pool.Size()) - pool.Put("client1.ID", "client1") + err := pool.Put("client1.ID", "client1") + assert.Nil(t, err) assert.Equal(t, 1, pool.Size()) - pool.Put("client2.ID", "client2") + err = pool.Put("client2.ID", "client2") + assert.Nil(t, err) assert.Equal(t, 2, pool.Size()) if c1, ok := pool.Get("client1.ID").(string); !ok { assert.Equal(t, c1, "client1") @@ -109,16 +119,18 @@ func TestPool_Get(t *testing.T) { } func TestPool_GetOrPut(t *testing.T) { - pool := NewPool() + pool := NewPool(EmptyPoolCapacity) defer pool.Clear() assert.NotNil(t, pool) assert.NotNil(t, pool.Pool()) assert.Equal(t, 0, pool.Size()) - pool.Put("client1.ID", "client1") + err := pool.Put("client1.ID", "client1") + assert.Nil(t, err) assert.Equal(t, 1, pool.Size()) - pool.Put("client2.ID", "client2") + err = pool.Put("client2.ID", "client2") + assert.Nil(t, err) assert.Equal(t, 2, pool.Size()) - c1, loaded := pool.GetOrPut("client1.ID", "client1") + c1, loaded, err := pool.GetOrPut("client1.ID", "client1") assert.True(t, loaded) if c1, ok := c1.(string); !ok { assert.Equal(t, c1, "client1") @@ -126,7 +138,8 @@ func TestPool_GetOrPut(t *testing.T) { assert.Equal(t, "client1", c1) assert.Equal(t, 2, pool.Size()) } - c2, loaded := pool.GetOrPut("client2.ID", "client2") + assert.Nil(t, err) + c2, loaded, err := pool.GetOrPut("client2.ID", "client2") assert.True(t, loaded) if c2, ok := c2.(string); !ok { assert.Equal(t, c2, "client2") @@ -134,17 +147,20 @@ func TestPool_GetOrPut(t *testing.T) { assert.Equal(t, "client2", c2) assert.Equal(t, 2, pool.Size()) } + assert.Nil(t, err) } func TestPool_Remove(t *testing.T) { - pool := NewPool() + pool := NewPool(EmptyPoolCapacity) defer pool.Clear() assert.NotNil(t, pool) assert.NotNil(t, pool.Pool()) assert.Equal(t, 0, pool.Size()) - pool.Put("client1.ID", "client1") + err := pool.Put("client1.ID", "client1") + assert.Nil(t, err) assert.Equal(t, 1, pool.Size()) - pool.Put("client2.ID", "client2") + err = pool.Put("client2.ID", "client2") + assert.Nil(t, err) assert.Equal(t, 2, pool.Size()) pool.Remove("client1.ID") assert.Equal(t, 1, pool.Size()) @@ -153,14 +169,16 @@ func TestPool_Remove(t *testing.T) { } func TestPool_GetClientIDs(t *testing.T) { - pool := NewPool() + pool := NewPool(EmptyPoolCapacity) defer pool.Clear() assert.NotNil(t, pool) assert.NotNil(t, pool.Pool()) assert.Equal(t, 0, pool.Size()) - pool.Put("client1.ID", "client1") + err := pool.Put("client1.ID", "client1") + assert.Nil(t, err) assert.Equal(t, 1, pool.Size()) - pool.Put("client2.ID", "client2") + err = pool.Put("client2.ID", "client2") + assert.Nil(t, err) assert.Equal(t, 2, pool.Size()) var ids []string