From b6ee0c8c089c3886d7870d95fcf481527f061e95 Mon Sep 17 00:00:00 2001 From: Jean Luc Mongrain Date: Wed, 23 Apr 2025 17:22:45 -0400 Subject: [PATCH] add support for multiple db --- config/config.go | 48 +++++++++++++++++++++++++++++++++++++------- config/connect.go | 39 ++++++++++++++++++++++++++++++++++- main.go | 35 ++++++++++++++++++++------------ tools/execute.go | 24 +++++++++++++++++----- tools/query.go | 26 ++++++++++++++++++------ tools/schema.go | 22 ++++++++++++++++---- tools/transaction.go | 26 ++++++++++++++++++------ 7 files changed, 178 insertions(+), 42 deletions(-) diff --git a/config/config.go b/config/config.go index 911816e..be7332a 100644 --- a/config/config.go +++ b/config/config.go @@ -18,15 +18,49 @@ const ( // Config holds the database configuration. type Config struct { - Host string - Port string - Name string - User string - Password string - SSLMode string + Host string `json:"host"` + Port string `json:"port"` + Name string `json:"name"` + User string `json:"user"` + Password string `json:"password"` + SSLMode string `json:"sslmode"` } -// LoadConfig loads the configuration from environment variables. +// LoadConfigs loads multiple database configurations from a JSON string in environment variable. +func LoadConfigs(envVar string) (map[string]Config, error) { + jsonConfig := os.Getenv(envVar) + if jsonConfig == "" { + return nil, fmt.Errorf("environment variable %s is empty", envVar) + } + + var configs map[string]Config + err := json.Unmarshal([]byte(jsonConfig), &configs) + if err != nil { + return nil, fmt.Errorf("failed to parse JSON config: %v", err) + } + + // Validate each config + for name, cfg := range configs { + if cfg.Name == "" || cfg.User == "" || cfg.Password == "" { + return nil, fmt.Errorf("missing required fields in config for database %s", name) + } + // Set defaults + if cfg.Host == "" { + cfg.Host = "localhost" + } + if cfg.Port == "" { + cfg.Port = "5432" + } + if cfg.SSLMode == "" { + cfg.SSLMode = "disable" + } + configs[name] = cfg + } + + return configs, nil +} + +// LoadConfig loads a single database configuration from environment variables. func LoadConfig() (*Config, error) { // Load .env file if it exists godotenv.Load() diff --git a/config/connect.go b/config/connect.go index a3ab016..b7654a5 100644 --- a/config/connect.go +++ b/config/connect.go @@ -7,7 +7,44 @@ import ( _ "github.com/lib/pq" ) -// ConnectDB establishes a connection to the PostgreSQL database. +// ConnectDBs establishes connections to multiple PostgreSQL databases. +func ConnectDBs(configs map[string]Config) (map[string]*sql.DB, error) { + dbs := make(map[string]*sql.DB) + var firstErr error + + for name, config := range configs { + connStr := fmt.Sprintf( + "host=%s port=%s user=%s password=%s dbname=%s sslmode=%s", + config.Host, config.Port, config.User, config.Password, config.Name, config.SSLMode, + ) + + db, err := sql.Open("postgres", connStr) + if err != nil { + if firstErr == nil { + firstErr = fmt.Errorf("failed to connect to database %s: %v", name, err) + } + continue + } + + if err := db.Ping(); err != nil { + db.Close() + if firstErr == nil { + firstErr = fmt.Errorf("failed to ping database %s: %v", name, err) + } + continue + } + + dbs[name] = db + } + + if len(dbs) == 0 && firstErr != nil { + return nil, firstErr + } + + return dbs, firstErr +} + +// ConnectDB establishes a connection to a single PostgreSQL database. func ConnectDB(config *Config) (*sql.DB, error) { connStr := fmt.Sprintf( "host=%s port=%s user=%s password=%s dbname=%s sslmode=%s", diff --git a/main.go b/main.go index 397a329..004a136 100644 --- a/main.go +++ b/main.go @@ -10,20 +10,29 @@ import ( ) func main() { - // Load configuration - cfg, err := config.LoadConfig() + // Load configurations + configs, err := config.LoadConfigs("POSTGRES_DBS") if err != nil { - log.Fatalf("Failed to load configuration: %v", err) + log.Fatalf("Failed to load configurations: %v", err) } - // Connect to the database - db, err := config.ConnectDB(cfg) + // Connect to all databases + dbs, err := config.ConnectDBs(configs) if err != nil { - log.Fatalf("Failed to connect to the database: %v", err) + log.Printf("Warning: some database connections failed: %v", err) + if len(dbs) == 0 { + log.Fatal("No database connections established") + } } - defer db.Close() - log.Println("Successfully connected to the PostgreSQL database!") + // Close all connections on exit + defer func() { + for _, db := range dbs { + db.Close() + } + }() + + log.Printf("Successfully connected to %d PostgreSQL database(s)!", len(dbs)) // Initialize the MCP server s := server.NewMCPServer( @@ -32,11 +41,11 @@ func main() { server.WithLogging(), ) - // Register the PostgreSQL tools - tools.RegisterExecuteTool(s, db) - tools.RegisterQueryTool(s, db) - tools.RegisterSchemaTool(s, db) - tools.RegisterTransactionTool(s, db) + // Register the PostgreSQL tools with all databases + tools.RegisterExecuteTool(s, dbs) + tools.RegisterQueryTool(s, dbs) + tools.RegisterSchemaTool(s, dbs) + tools.RegisterTransactionTool(s, dbs) // Start the server log.Println("Starting MCP server...") diff --git a/tools/execute.go b/tools/execute.go index 6d789ef..998f7f0 100644 --- a/tools/execute.go +++ b/tools/execute.go @@ -9,12 +9,16 @@ import ( "github.com/mark3labs/mcp-go/server" ) -func RegisterExecuteTool(s *server.MCPServer, db *sql.DB) { +func RegisterExecuteTool(s *server.MCPServer, dbs map[string]*sql.DB) { executeTool := mcp.NewTool("execute_tool", - mcp.WithDescription("Execute statement"), + mcp.WithDescription("Execute SQL statement on specified database"), + mcp.WithString("database", + mcp.Required(), + mcp.Description("Name of the database to execute on"), + ), mcp.WithString("statement", mcp.Required(), - mcp.Description("Statement to be executed"), + mcp.Description("SQL statement(s) to execute. For multiple statements, separate them with semicolons (;)"), ), mcp.WithArray("arguments", mcp.Required(), @@ -23,11 +27,21 @@ func RegisterExecuteTool(s *server.MCPServer, db *sql.DB) { ) s.AddTool(executeTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return handleExecute(ctx, request, db) + return handleExecute(ctx, request, dbs) }) } -func handleExecute(ctx context.Context, req mcp.CallToolRequest, db *sql.DB) (*mcp.CallToolResult, error) { +func handleExecute(ctx context.Context, req mcp.CallToolRequest, dbs map[string]*sql.DB) (*mcp.CallToolResult, error) { + database, ok := req.Params.Arguments["database"].(string) + if !ok { + return nil, fmt.Errorf("database parameter should be a string") + } + + db, exists := dbs[database] + if !exists { + return nil, fmt.Errorf("database '%s' not found in available connections", database) + } + statement, ok := req.Params.Arguments["statement"].(string) if !ok { return nil, fmt.Errorf("statement should be a string") diff --git a/tools/query.go b/tools/query.go index fce6c62..5083068 100644 --- a/tools/query.go +++ b/tools/query.go @@ -9,25 +9,39 @@ import ( "github.com/mark3labs/mcp-go/server" ) -func RegisterQueryTool(s *server.MCPServer, db *sql.DB) { +func RegisterQueryTool(s *server.MCPServer, dbs map[string]*sql.DB) { queryTool := mcp.NewTool("query_tool", - mcp.WithDescription("Make query"), + mcp.WithDescription("Execute SQL query on specified database"), + mcp.WithString("database", + mcp.Required(), + mcp.Description("Name of the database to query"), + ), mcp.WithString("statement", mcp.Required(), - mcp.Description("Statement to be executed"), + mcp.Description("SQL query to execute"), ), mcp.WithArray("arguments", mcp.Required(), - mcp.Description("Arguments for the statent provided"), + mcp.Description("Arguments for the query provided"), ), ) s.AddTool(queryTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return handleQuery(ctx, request, db) + return handleQuery(ctx, request, dbs) }) } -func handleQuery(ctx context.Context, req mcp.CallToolRequest, db *sql.DB) (*mcp.CallToolResult, error) { +func handleQuery(ctx context.Context, req mcp.CallToolRequest, dbs map[string]*sql.DB) (*mcp.CallToolResult, error) { + database, ok := req.Params.Arguments["database"].(string) + if !ok { + return nil, fmt.Errorf("database parameter should be a string") + } + + db, exists := dbs[database] + if !exists { + return nil, fmt.Errorf("database '%s' not found in available connections", database) + } + statement, ok := req.Params.Arguments["statement"].(string) if !ok { return nil, fmt.Errorf("statement should be a string") diff --git a/tools/schema.go b/tools/schema.go index 1482e08..a8d6d26 100644 --- a/tools/schema.go +++ b/tools/schema.go @@ -9,21 +9,35 @@ import ( "github.com/mark3labs/mcp-go/server" ) -func RegisterSchemaTool(s *server.MCPServer, db *sql.DB) { +func RegisterSchemaTool(s *server.MCPServer, dbs map[string]*sql.DB) { schemaTool := mcp.NewTool("schema_tool", - mcp.WithDescription("Inspect the table schema"), + mcp.WithDescription("Inspect table schema on specified database"), + mcp.WithString("database", + mcp.Required(), + mcp.Description("Name of the database to inspect"), + ), mcp.WithString("table_name", mcp.Required(), mcp.Description("The name of the table to inspect"), )) s.AddTool(schemaTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return handleSchema(ctx, request, db) + return handleSchema(ctx, request, dbs) }) } // Execute handles the schema request -func handleSchema(ctx context.Context, request mcp.CallToolRequest, db *sql.DB) (*mcp.CallToolResult, error) { +func handleSchema(ctx context.Context, request mcp.CallToolRequest, dbs map[string]*sql.DB) (*mcp.CallToolResult, error) { + database, ok := request.Params.Arguments["database"].(string) + if !ok { + return nil, fmt.Errorf("database parameter should be a string") + } + + db, exists := dbs[database] + if !exists { + return nil, fmt.Errorf("database '%s' not found in available connections", database) + } + tableName, ok := request.Params.Arguments["table_name"].(string) if !ok || tableName == "" { return nil, fmt.Errorf("table_name must be a non-empty string") diff --git a/tools/transaction.go b/tools/transaction.go index 9ddb883..ada842c 100644 --- a/tools/transaction.go +++ b/tools/transaction.go @@ -9,26 +9,40 @@ import ( "github.com/mark3labs/mcp-go/server" ) -func RegisterTransactionTool(s *server.MCPServer, db *sql.DB) { +func RegisterTransactionTool(s *server.MCPServer, dbs map[string]*sql.DB) { transactionTool := mcp.NewTool("transaction_tool", - mcp.WithDescription("Make queries in a transaction"), + mcp.WithDescription("Execute queries in a transaction on specified database"), + mcp.WithString("database", + mcp.Required(), + mcp.Description("Name of the database for the transaction"), + ), mcp.WithArray("statements", mcp.Required(), mcp.Description("Statements to be executed in the transaction"), ), mcp.WithArray("arguments", mcp.Required(), - mcp.Description("Arguments for the statents provided"), + mcp.Description("Arguments for the statements provided"), ), ) s.AddTool(transactionTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return handleTransaction(ctx, request, db) + return handleTransaction(ctx, request, dbs) }) } -// Execute handles the schema request. -func handleTransaction(ctx context.Context, request mcp.CallToolRequest, db *sql.DB) (*mcp.CallToolResult, error) { +// Execute handles the transaction request. +func handleTransaction(ctx context.Context, request mcp.CallToolRequest, dbs map[string]*sql.DB) (*mcp.CallToolResult, error) { + database, ok := request.Params.Arguments["database"].(string) + if !ok { + return nil, fmt.Errorf("database parameter should be a string") + } + + db, exists := dbs[database] + if !exists { + return nil, fmt.Errorf("database '%s' not found in available connections", database) + } + statements, ok := request.Params.Arguments["statements"].([]string) if !ok { return nil, fmt.Errorf("statements should be an array of strings")