Skip to content

Commit

Permalink
skip caching datetime functions (#82)
Browse files Browse the repository at this point in the history
- added IsCacheNeeded function
- added test case for TestPluginDateFunctionInQuery
  • Loading branch information
sinadarbouy committed Jun 16, 2024
1 parent 5497d4c commit ceb0e9b
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 1 deletion.
30 changes: 29 additions & 1 deletion plugin/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,22 @@ type CachePlugin struct {
Impl Plugin
}

// Define a set for PostgreSQL date/time functions
// https://www.postgresql.org/docs/8.2/functions-datetime.html
var pgDateTimeFunctions = map[string]struct{}{
"AGE": {},
"CLOCK_TIMESTAMP": {},
"CURRENT_DATE": {},
"CURRENT_TIME": {},
"CURRENT_TIMESTAMP": {},
"LOCALTIME": {},
"LOCALTIMESTAMP": {},
"NOW": {},
"STATEMENT_TIMESTAMP": {},
"TIMEOFDAY": {},
"TRANSACTION_TIMESTAMP": {},
}

// NewCachePlugin returns a new instance of the CachePlugin.
func NewCachePlugin(impl Plugin) *CachePlugin {
return &CachePlugin{
Expand Down Expand Up @@ -164,6 +180,18 @@ func (p *Plugin) OnTrafficFromClient(
return req, nil
}

// IsCacheNeeded determines if caching is needed.
func IsCacheNeeded(upperQuery string) bool {
// Iterate over each function name in the set of PostgreSQL date/time functions.
for function := range pgDateTimeFunctions {
if strings.Contains(upperQuery, function) {
// If the query contains a date/time function, caching is not needed.
return false
}
}
return true
}

func (p *Plugin) UpdateCache(ctx context.Context) {
for {
serverResponse, ok := <-p.UpdateCacheChannel
Expand Down Expand Up @@ -219,7 +247,7 @@ func (p *Plugin) UpdateCache(ctx context.Context) {
}

cacheKey := strings.Join([]string{server["remote"], database, string(request)}, ":")
if errorResponse == "" && rowDescription != "" && dataRow != nil && len(dataRow) > 0 {
if errorResponse == "" && rowDescription != "" && dataRow != nil && len(dataRow) > 0 && IsCacheNeeded(cacheKey) {
// The request was successful and the response contains data. Cache the response.
if err := p.RedisClient.Set(ctx, cacheKey, response, p.Expiry).Err(); err != nil {
CacheMissesCounter.Inc()
Expand Down
89 changes: 89 additions & 0 deletions plugin/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,22 @@ func testQueryRequest() (string, []byte) {
return query, queryBytes
}

func testQueryRequestWithDateFucntion() (string, []byte) {
query := `SELECT
user_id,
username,
last_login,
NOW() AS current_time
FROM
users
WHERE
last_login >= CURRENT_DATE;`
queryMsg := pgproto3.Query{String: query}
// Encode the data to base64.
queryBytes, _ := queryMsg.Encode(nil)
return query, queryBytes
}

func testStartupRequest() []byte {
startupMsg := pgproto3.StartupMessage{
ProtocolVersion: 196608,
Expand Down Expand Up @@ -180,3 +196,76 @@ func Test_Plugin(t *testing.T) {
assert.Equal(t, resultMap["response"], response)
assert.Contains(t, resultMap, sdkAct.Signals)
}

func TestPluginDateFunctionInQuery(t *testing.T) {
// Initialize a new mock Redis server.
mockRedisServer := miniredis.RunT(t)
redisURL := "redis://" + mockRedisServer.Addr() + "/0"
redisConfig, err := redis.ParseURL(redisURL)
redisClient := redis.NewClient(redisConfig)

cacheUpdateChannel := make(chan *v1.Struct, 10)

// Create and initialize a new plugin.
logger := hclog.New(&hclog.LoggerOptions{
Level: logging.GetLogLevel("error"),
Output: os.Stdout,
})
plugin := NewCachePlugin(Plugin{
Logger: logger,
RedisURL: redisURL,
RedisClient: redisClient,
UpdateCacheChannel: cacheUpdateChannel,
})

// Use a WaitGroup to wait for the goroutine to finish.
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
plugin.Impl.UpdateCache(context.Background())
}()

// Test the plugin's OnTrafficFromClient method with a StartupMessage.
clientArgs := map[string]interface{}{
"request": testStartupRequest(),
"client": map[string]interface{}{
"local": "localhost:15432",
"remote": "localhost:45320",
},
"server": map[string]interface{}{
"local": "localhost:54321",
"remote": "localhost:5432",
},
"error": "",
}
clientRequest, err := v1.NewStruct(clientArgs)
plugin.Impl.OnTrafficFromClient(context.Background(), clientRequest)

// Test the plugin's OnTrafficFromServer method with a query request.
_, queryRequest := testQueryRequestWithDateFucntion()
queryResponse, err := base64.StdEncoding.DecodeString("VAAAABsAAWlkAAAAQAQAAQAAABcABP////8AAEQAAAALAAEAAAABMUMAAAANU0VMRUNUIDEAWgAAAAVJ")
assert.Nil(t, err)
queryArgs := map[string]interface{}{
"request": queryRequest,
"response": queryResponse,
"client": map[string]interface{}{
"local": "localhost:15432",
"remote": "localhost:45320",
},
"server": map[string]interface{}{
"local": "localhost:54321",
"remote": "localhost:5432",
},
"error": "",
}
serverRequest, err := v1.NewStruct(queryArgs)
plugin.Impl.OnTrafficFromServer(context.Background(), serverRequest)

// Close the channel and wait for the cache updater to return gracefully.
close(cacheUpdateChannel)
wg.Wait()

keys, _ := redisClient.Keys(context.Background(), "*").Result()
assert.Equal(t, 1, len(keys)) // Only one key (representing the database name) should be present.
}

0 comments on commit ceb0e9b

Please sign in to comment.