From 3b346d52ef4ce1375880609dfce8e3fc0fd746d3 Mon Sep 17 00:00:00 2001 From: zijiren233 Date: Fri, 24 Apr 2026 16:42:09 +0800 Subject: [PATCH] fix: mcp redis cache --- core/model/mcp_cache_redis.go | 135 +++++++++++++++++++++++++++++ core/model/mcp_cache_redis_test.go | 131 ++++++++++++++++++++++++++++ 2 files changed, 266 insertions(+) create mode 100644 core/model/mcp_cache_redis.go create mode 100644 core/model/mcp_cache_redis_test.go diff --git a/core/model/mcp_cache_redis.go b/core/model/mcp_cache_redis.go new file mode 100644 index 00000000..279af6db --- /dev/null +++ b/core/model/mcp_cache_redis.go @@ -0,0 +1,135 @@ +package model + +import ( + "encoding" + "strconv" + + "github.com/bytedance/sonic" + "github.com/labring/aiproxy/core/common/conv" + "github.com/redis/go-redis/v9" +) + +var ( + _ encoding.BinaryMarshaler = GroupMCPStatus(0) + _ redis.Scanner = (*GroupMCPStatus)(nil) + _ encoding.BinaryMarshaler = GroupMCPType("") + _ redis.Scanner = (*GroupMCPType)(nil) + _ encoding.BinaryMarshaler = PublicMCPStatus(0) + _ redis.Scanner = (*PublicMCPStatus)(nil) + _ encoding.BinaryMarshaler = PublicMCPType("") + _ redis.Scanner = (*PublicMCPType)(nil) + _ encoding.BinaryMarshaler = (*GroupMCPProxyConfig)(nil) + _ redis.Scanner = (*GroupMCPProxyConfig)(nil) + _ encoding.BinaryMarshaler = MCPPrice{} + _ redis.Scanner = (*MCPPrice)(nil) + _ encoding.BinaryMarshaler = (*PublicMCPProxyConfig)(nil) + _ redis.Scanner = (*PublicMCPProxyConfig)(nil) + _ encoding.BinaryMarshaler = (*MCPOpenAPIConfig)(nil) + _ redis.Scanner = (*MCPOpenAPIConfig)(nil) + _ encoding.BinaryMarshaler = (*MCPEmbeddingConfig)(nil) + _ redis.Scanner = (*MCPEmbeddingConfig)(nil) +) + +func (s *GroupMCPStatus) ScanRedis(value string) error { + v, err := strconv.Atoi(value) + if err != nil { + return err + } + + *s = GroupMCPStatus(v) + + return nil +} + +func (s GroupMCPStatus) MarshalBinary() ([]byte, error) { + return conv.StringToBytes(strconv.Itoa(int(s))), nil +} + +func (t *GroupMCPType) ScanRedis(value string) error { + *t = GroupMCPType(value) + return nil +} + +func (t GroupMCPType) MarshalBinary() ([]byte, error) { + return conv.StringToBytes(string(t)), nil +} + +func (s *PublicMCPStatus) ScanRedis(value string) error { + v, err := strconv.Atoi(value) + if err != nil { + return err + } + + *s = PublicMCPStatus(v) + + return nil +} + +func (s PublicMCPStatus) MarshalBinary() ([]byte, error) { + return conv.StringToBytes(strconv.Itoa(int(s))), nil +} + +func (t *PublicMCPType) ScanRedis(value string) error { + *t = PublicMCPType(value) + return nil +} + +func (t PublicMCPType) MarshalBinary() ([]byte, error) { + return conv.StringToBytes(string(t)), nil +} + +func (c *GroupMCPProxyConfig) ScanRedis(value string) error { + return sonic.UnmarshalString(value, c) +} + +func (c *GroupMCPProxyConfig) MarshalBinary() ([]byte, error) { + if c == nil { + return conv.StringToBytes("null"), nil + } + + return sonic.Marshal(c) +} + +func (p *MCPPrice) ScanRedis(value string) error { + return sonic.UnmarshalString(value, p) +} + +func (p MCPPrice) MarshalBinary() ([]byte, error) { + return sonic.Marshal(p) +} + +func (c *PublicMCPProxyConfig) ScanRedis(value string) error { + return sonic.UnmarshalString(value, c) +} + +func (c *PublicMCPProxyConfig) MarshalBinary() ([]byte, error) { + if c == nil { + return conv.StringToBytes("null"), nil + } + + return sonic.Marshal(c) +} + +func (c *MCPOpenAPIConfig) ScanRedis(value string) error { + return sonic.UnmarshalString(value, c) +} + +func (c *MCPOpenAPIConfig) MarshalBinary() ([]byte, error) { + if c == nil { + return conv.StringToBytes("null"), nil + } + + return sonic.Marshal(c) +} + +func (c *MCPEmbeddingConfig) ScanRedis(value string) error { + return sonic.UnmarshalString(value, c) +} + +func (c *MCPEmbeddingConfig) MarshalBinary() ([]byte, error) { + if c == nil { + return conv.StringToBytes("null"), nil + } + + return sonic.Marshal(c) +} diff --git a/core/model/mcp_cache_redis_test.go b/core/model/mcp_cache_redis_test.go new file mode 100644 index 00000000..018a8dcf --- /dev/null +++ b/core/model/mcp_cache_redis_test.go @@ -0,0 +1,131 @@ +//nolint:testpackage +package model + +import ( + "context" + "net" + "path/filepath" + "testing" + + "github.com/labring/aiproxy/core/common" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/wait" +) + +func TestCacheSetAndGetPublicMCPViaRedis(t *testing.T) { + withTestMCPRedisEnv(t, func(ctx context.Context, client *redis.Client) { + cache := &PublicMCPCache{ + ID: "public-mcp-redis", + Status: PublicMCPStatusEnabled, + Type: PublicMCPTypeOpenAPI, + Price: MCPPrice{ + DefaultToolsCallPrice: 1.23, + ToolsCallPrices: map[string]float64{ + "tool-a": 2.34, + }, + }, + OpenAPIConfig: &MCPOpenAPIConfig{ + OpenAPISpec: "https://example.com/openapi.json", + }, + EmbedConfig: &MCPEmbeddingConfig{ + Init: map[string]string{ + "foo": "bar", + }, + }, + } + + require.NoError(t, CacheSetPublicMCP(cache)) + + modelLocalCache.Flush() + + got, err := CacheGetPublicMCP(cache.ID) + require.NoError(t, err) + require.NotNil(t, got) + assert.Equal(t, cache.ID, got.ID) + assert.Equal(t, PublicMCPStatusEnabled, got.Status) + assert.Equal(t, PublicMCPTypeOpenAPI, got.Type) + assert.Equal(t, cache.Price.DefaultToolsCallPrice, got.Price.DefaultToolsCallPrice) + assert.Equal(t, cache.Price.ToolsCallPrices, got.Price.ToolsCallPrices) + require.NotNil(t, got.OpenAPIConfig) + assert.Equal(t, cache.OpenAPIConfig.OpenAPISpec, got.OpenAPIConfig.OpenAPISpec) + require.NotNil(t, got.EmbedConfig) + assert.Equal(t, cache.EmbedConfig.Init, got.EmbedConfig.Init) + + exists, err := client.Exists(ctx, getPublicMCPCacheKey(cache.ID)).Result() + require.NoError(t, err) + assert.EqualValues(t, 1, exists) + }) +} + +func withTestMCPRedisEnv(t *testing.T, fn func(context.Context, *redis.Client)) { + t.Helper() + + ctx := context.Background() + + oldDB := DB + oldLogDB := LogDB + oldRDB := common.RDB + oldRedisEnabled := common.RedisEnabled + oldUsingSQLite := common.UsingSQLite + + db, err := OpenSQLite(filepath.Join(t.TempDir(), "mcp_cache_redis_test.db")) + require.NoError(t, err) + require.NoError(t, db.AutoMigrate(&PublicMCP{}, &PublicMCPReusingParam{}, &GroupMCP{})) + + req := testcontainers.ContainerRequest{ + Image: "redis:7-alpine", + ExposedPorts: []string{"6379/tcp"}, + WaitingFor: wait.ForListeningPort("6379/tcp"), + } + + container, err := testcontainers.GenericContainer( + ctx, + testcontainers.GenericContainerRequest{ + ContainerRequest: req, + Started: true, + }, + ) + require.NoError(t, err) + + host, err := container.Host(ctx) + require.NoError(t, err) + + port, err := container.MappedPort(ctx, "6379") + require.NoError(t, err) + + client := redis.NewClient(&redis.Options{ + Addr: net.JoinHostPort(host, port.Port()), + DB: 0, + }) + require.NoError(t, client.Ping(ctx).Err()) + + DB = db + LogDB = db + common.RDB = client + common.RedisEnabled = true + common.UsingSQLite = true + + modelLocalCache.Flush() + + t.Cleanup(func() { + DB = oldDB + LogDB = oldLogDB + common.RDB = oldRDB + common.RedisEnabled = oldRedisEnabled + common.UsingSQLite = oldUsingSQLite + + modelLocalCache.Flush() + + _ = client.Close() + _ = container.Terminate(ctx) + + sqlDB, sqlErr := db.DB() + require.NoError(t, sqlErr) + require.NoError(t, sqlDB.Close()) + }) + + fn(ctx, client) +}