Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 135 additions & 0 deletions core/model/mcp_cache_redis.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
package model

import (
"encoding"
"strconv"

"github.com/bytedance/sonic"
"github.com/labring/aiproxy/core/common/conv"
"github.com/redis/go-redis/v9"
)

var (
_ encoding.BinaryMarshaler = GroupMCPStatus(0)
_ redis.Scanner = (*GroupMCPStatus)(nil)
_ encoding.BinaryMarshaler = GroupMCPType("")
_ redis.Scanner = (*GroupMCPType)(nil)
_ encoding.BinaryMarshaler = PublicMCPStatus(0)
_ redis.Scanner = (*PublicMCPStatus)(nil)
_ encoding.BinaryMarshaler = PublicMCPType("")
_ redis.Scanner = (*PublicMCPType)(nil)
_ encoding.BinaryMarshaler = (*GroupMCPProxyConfig)(nil)
_ redis.Scanner = (*GroupMCPProxyConfig)(nil)
_ encoding.BinaryMarshaler = MCPPrice{}
_ redis.Scanner = (*MCPPrice)(nil)
_ encoding.BinaryMarshaler = (*PublicMCPProxyConfig)(nil)
_ redis.Scanner = (*PublicMCPProxyConfig)(nil)
_ encoding.BinaryMarshaler = (*MCPOpenAPIConfig)(nil)
_ redis.Scanner = (*MCPOpenAPIConfig)(nil)
_ encoding.BinaryMarshaler = (*MCPEmbeddingConfig)(nil)
_ redis.Scanner = (*MCPEmbeddingConfig)(nil)
)

func (s *GroupMCPStatus) ScanRedis(value string) error {
v, err := strconv.Atoi(value)
if err != nil {
return err
}

*s = GroupMCPStatus(v)

return nil
}

func (s GroupMCPStatus) MarshalBinary() ([]byte, error) {
return conv.StringToBytes(strconv.Itoa(int(s))), nil
}

func (t *GroupMCPType) ScanRedis(value string) error {
*t = GroupMCPType(value)
return nil
}

func (t GroupMCPType) MarshalBinary() ([]byte, error) {
return conv.StringToBytes(string(t)), nil
}

func (s *PublicMCPStatus) ScanRedis(value string) error {
v, err := strconv.Atoi(value)
if err != nil {
return err
}

*s = PublicMCPStatus(v)

return nil
}

func (s PublicMCPStatus) MarshalBinary() ([]byte, error) {
return conv.StringToBytes(strconv.Itoa(int(s))), nil
}

func (t *PublicMCPType) ScanRedis(value string) error {
*t = PublicMCPType(value)
return nil
}

func (t PublicMCPType) MarshalBinary() ([]byte, error) {
return conv.StringToBytes(string(t)), nil
}

func (c *GroupMCPProxyConfig) ScanRedis(value string) error {
return sonic.UnmarshalString(value, c)
}

func (c *GroupMCPProxyConfig) MarshalBinary() ([]byte, error) {
if c == nil {
return conv.StringToBytes("null"), nil
}

return sonic.Marshal(c)
}

func (p *MCPPrice) ScanRedis(value string) error {
return sonic.UnmarshalString(value, p)
}

func (p MCPPrice) MarshalBinary() ([]byte, error) {
return sonic.Marshal(p)
}

func (c *PublicMCPProxyConfig) ScanRedis(value string) error {
return sonic.UnmarshalString(value, c)
}

func (c *PublicMCPProxyConfig) MarshalBinary() ([]byte, error) {
if c == nil {
return conv.StringToBytes("null"), nil
}

return sonic.Marshal(c)
}

func (c *MCPOpenAPIConfig) ScanRedis(value string) error {
return sonic.UnmarshalString(value, c)
}

func (c *MCPOpenAPIConfig) MarshalBinary() ([]byte, error) {
if c == nil {
return conv.StringToBytes("null"), nil
}

return sonic.Marshal(c)
}

func (c *MCPEmbeddingConfig) ScanRedis(value string) error {
return sonic.UnmarshalString(value, c)
}

func (c *MCPEmbeddingConfig) MarshalBinary() ([]byte, error) {
if c == nil {
return conv.StringToBytes("null"), nil
}

return sonic.Marshal(c)
}
131 changes: 131 additions & 0 deletions core/model/mcp_cache_redis_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
//nolint:testpackage
package model

import (
"context"
"net"
"path/filepath"
"testing"

"github.com/labring/aiproxy/core/common"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/wait"
)

func TestCacheSetAndGetPublicMCPViaRedis(t *testing.T) {
withTestMCPRedisEnv(t, func(ctx context.Context, client *redis.Client) {
cache := &PublicMCPCache{
ID: "public-mcp-redis",
Status: PublicMCPStatusEnabled,
Type: PublicMCPTypeOpenAPI,
Price: MCPPrice{
DefaultToolsCallPrice: 1.23,
ToolsCallPrices: map[string]float64{
"tool-a": 2.34,
},
},
OpenAPIConfig: &MCPOpenAPIConfig{
OpenAPISpec: "https://example.com/openapi.json",
},
EmbedConfig: &MCPEmbeddingConfig{
Init: map[string]string{
"foo": "bar",
},
},
}

require.NoError(t, CacheSetPublicMCP(cache))

modelLocalCache.Flush()

got, err := CacheGetPublicMCP(cache.ID)
require.NoError(t, err)
require.NotNil(t, got)
assert.Equal(t, cache.ID, got.ID)
assert.Equal(t, PublicMCPStatusEnabled, got.Status)
assert.Equal(t, PublicMCPTypeOpenAPI, got.Type)
assert.Equal(t, cache.Price.DefaultToolsCallPrice, got.Price.DefaultToolsCallPrice)
assert.Equal(t, cache.Price.ToolsCallPrices, got.Price.ToolsCallPrices)
require.NotNil(t, got.OpenAPIConfig)
assert.Equal(t, cache.OpenAPIConfig.OpenAPISpec, got.OpenAPIConfig.OpenAPISpec)
require.NotNil(t, got.EmbedConfig)
assert.Equal(t, cache.EmbedConfig.Init, got.EmbedConfig.Init)

exists, err := client.Exists(ctx, getPublicMCPCacheKey(cache.ID)).Result()
require.NoError(t, err)
assert.EqualValues(t, 1, exists)
})
}

func withTestMCPRedisEnv(t *testing.T, fn func(context.Context, *redis.Client)) {
t.Helper()

ctx := context.Background()

oldDB := DB
oldLogDB := LogDB
oldRDB := common.RDB
oldRedisEnabled := common.RedisEnabled
oldUsingSQLite := common.UsingSQLite

db, err := OpenSQLite(filepath.Join(t.TempDir(), "mcp_cache_redis_test.db"))
require.NoError(t, err)
require.NoError(t, db.AutoMigrate(&PublicMCP{}, &PublicMCPReusingParam{}, &GroupMCP{}))

req := testcontainers.ContainerRequest{
Image: "redis:7-alpine",
ExposedPorts: []string{"6379/tcp"},
WaitingFor: wait.ForListeningPort("6379/tcp"),
}

container, err := testcontainers.GenericContainer(
ctx,
testcontainers.GenericContainerRequest{
ContainerRequest: req,
Started: true,
},
)
require.NoError(t, err)

host, err := container.Host(ctx)
require.NoError(t, err)

port, err := container.MappedPort(ctx, "6379")
require.NoError(t, err)

client := redis.NewClient(&redis.Options{
Addr: net.JoinHostPort(host, port.Port()),
DB: 0,
})
require.NoError(t, client.Ping(ctx).Err())

DB = db
LogDB = db
common.RDB = client
common.RedisEnabled = true
common.UsingSQLite = true

modelLocalCache.Flush()

t.Cleanup(func() {
DB = oldDB
LogDB = oldLogDB
common.RDB = oldRDB
common.RedisEnabled = oldRedisEnabled
common.UsingSQLite = oldUsingSQLite

modelLocalCache.Flush()

_ = client.Close()
_ = container.Terminate(ctx)

sqlDB, sqlErr := db.DB()
require.NoError(t, sqlErr)
require.NoError(t, sqlDB.Close())
})

fn(ctx, client)
}
Loading