Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 41 additions & 7 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,49 @@ const (

// Config holds the database configuration.
type Config struct {
Host string
Port string
Name string
User string
Password string
SSLMode string
Host string `json:"host"`
Port string `json:"port"`
Name string `json:"name"`
User string `json:"user"`
Password string `json:"password"`
SSLMode string `json:"sslmode"`
}

// LoadConfig loads the configuration from environment variables.
// LoadConfigs loads multiple database configurations from a JSON string in environment variable.
func LoadConfigs(envVar string) (map[string]Config, error) {
jsonConfig := os.Getenv(envVar)
if jsonConfig == "" {
return nil, fmt.Errorf("environment variable %s is empty", envVar)
}

var configs map[string]Config
err := json.Unmarshal([]byte(jsonConfig), &configs)
if err != nil {
return nil, fmt.Errorf("failed to parse JSON config: %v", err)
}

// Validate each config
for name, cfg := range configs {
if cfg.Name == "" || cfg.User == "" || cfg.Password == "" {
return nil, fmt.Errorf("missing required fields in config for database %s", name)
}
// Set defaults
if cfg.Host == "" {
cfg.Host = "localhost"
}
if cfg.Port == "" {
cfg.Port = "5432"
}
if cfg.SSLMode == "" {
cfg.SSLMode = "disable"
}
configs[name] = cfg
}

return configs, nil
}

// LoadConfig loads a single database configuration from environment variables.
func LoadConfig() (*Config, error) {
// Load .env file if it exists
godotenv.Load()
Expand Down
39 changes: 38 additions & 1 deletion config/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,44 @@ import (
_ "github.com/lib/pq"
)

// ConnectDB establishes a connection to the PostgreSQL database.
// ConnectDBs establishes connections to multiple PostgreSQL databases.
func ConnectDBs(configs map[string]Config) (map[string]*sql.DB, error) {
dbs := make(map[string]*sql.DB)
var firstErr error
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd suggest to have a bag of errors here (var errBag []error) to process a case when more than one DB is throwing an error.


for name, config := range configs {
connStr := fmt.Sprintf(
"host=%s port=%s user=%s password=%s dbname=%s sslmode=%s",
config.Host, config.Port, config.User, config.Password, config.Name, config.SSLMode,
)

db, err := sql.Open("postgres", connStr)
if err != nil {
if firstErr == nil {
firstErr = fmt.Errorf("failed to connect to database %s: %v", name, err)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and errBag = append(errBag, fmt.Errorf(...)) here

}
continue
}

if err := db.Ping(); err != nil {
db.Close()
if firstErr == nil {
firstErr = fmt.Errorf("failed to ping database %s: %v", name, err)
}
continue
}

dbs[name] = db
}

if len(dbs) == 0 && firstErr != nil {
return nil, firstErr
}

return dbs, firstErr
}

// ConnectDB establishes a connection to a single PostgreSQL database.
func ConnectDB(config *Config) (*sql.DB, error) {
connStr := fmt.Sprintf(
"host=%s port=%s user=%s password=%s dbname=%s sslmode=%s",
Expand Down
35 changes: 22 additions & 13 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,29 @@ import (
)

func main() {
// Load configuration
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could be a file if the environment variable is undefined

Copy link
Owner

@habuvo habuvo Apr 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`Yes, it could be a file, but in general it supposed to work with environmental variables provided by the tool using the server (Roo Code, Cline etc)
But let's have it implemented this way for the future extension 👍

cfg, err := config.LoadConfig()
// Load configurations
configs, err := config.LoadConfigs("POSTGRES_DBS")
if err != nil {
log.Fatalf("Failed to load configuration: %v", err)
log.Fatalf("Failed to load configurations: %v", err)
}

// Connect to the database
db, err := config.ConnectDB(cfg)
// Connect to all databases
dbs, err := config.ConnectDBs(configs)
if err != nil {
log.Fatalf("Failed to connect to the database: %v", err)
log.Printf("Warning: some database connections failed: %v", err)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe %+v here for better representation for bag of error
or even helper function to iterate over a bug and construct pseudo table or something like this, wdyt?

if len(dbs) == 0 {
log.Fatal("No database connections established")
}
}
defer db.Close()

log.Println("Successfully connected to the PostgreSQL database!")
// Close all connections on exit
defer func() {
for _, db := range dbs {
db.Close()
}
}()

log.Printf("Successfully connected to %d PostgreSQL database(s)!", len(dbs))

// Initialize the MCP server
s := server.NewMCPServer(
Expand All @@ -32,11 +41,11 @@ func main() {
server.WithLogging(),
)

// Register the PostgreSQL tools
tools.RegisterExecuteTool(s, db)
tools.RegisterQueryTool(s, db)
tools.RegisterSchemaTool(s, db)
tools.RegisterTransactionTool(s, db)
// Register the PostgreSQL tools with all databases
tools.RegisterExecuteTool(s, dbs)
tools.RegisterQueryTool(s, dbs)
tools.RegisterSchemaTool(s, dbs)
tools.RegisterTransactionTool(s, dbs)

// Start the server
log.Println("Starting MCP server...")
Expand Down
24 changes: 19 additions & 5 deletions tools/execute.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,16 @@ import (
"github.com/mark3labs/mcp-go/server"
)

func RegisterExecuteTool(s *server.MCPServer, db *sql.DB) {
func RegisterExecuteTool(s *server.MCPServer, dbs map[string]*sql.DB) {
executeTool := mcp.NewTool("execute_tool",
mcp.WithDescription("Execute statement"),
mcp.WithDescription("Execute SQL statement on specified database"),
mcp.WithString("database",
mcp.Required(),
mcp.Description("Name of the database to execute on"),
),
mcp.WithString("statement",
mcp.Required(),
mcp.Description("Statement to be executed"),
mcp.Description("SQL statement(s) to execute. For multiple statements, separate them with semicolons (;)"),
),
mcp.WithArray("arguments",
mcp.Required(),
Expand All @@ -23,11 +27,21 @@ func RegisterExecuteTool(s *server.MCPServer, db *sql.DB) {
)

s.AddTool(executeTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return handleExecute(ctx, request, db)
return handleExecute(ctx, request, dbs)
})
}

func handleExecute(ctx context.Context, req mcp.CallToolRequest, db *sql.DB) (*mcp.CallToolResult, error) {
func handleExecute(ctx context.Context, req mcp.CallToolRequest, dbs map[string]*sql.DB) (*mcp.CallToolResult, error) {
database, ok := req.Params.Arguments["database"].(string)
if !ok {
return nil, fmt.Errorf("database parameter should be a string")
}

db, exists := dbs[database]
if !exists {
return nil, fmt.Errorf("database '%s' not found in available connections", database)
}

statement, ok := req.Params.Arguments["statement"].(string)
if !ok {
return nil, fmt.Errorf("statement should be a string")
Expand Down
26 changes: 20 additions & 6 deletions tools/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,39 @@ import (
"github.com/mark3labs/mcp-go/server"
)

func RegisterQueryTool(s *server.MCPServer, db *sql.DB) {
func RegisterQueryTool(s *server.MCPServer, dbs map[string]*sql.DB) {
queryTool := mcp.NewTool("query_tool",
mcp.WithDescription("Make query"),
mcp.WithDescription("Execute SQL query on specified database"),
mcp.WithString("database",
mcp.Required(),
mcp.Description("Name of the database to query"),
),
mcp.WithString("statement",
mcp.Required(),
mcp.Description("Statement to be executed"),
mcp.Description("SQL query to execute"),
),
mcp.WithArray("arguments",
mcp.Required(),
mcp.Description("Arguments for the statement provided"),
mcp.Description("Arguments for the query provided"),
),
)

s.AddTool(queryTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return handleQuery(ctx, request, db)
return handleQuery(ctx, request, dbs)
})
}

func handleQuery(ctx context.Context, req mcp.CallToolRequest, db *sql.DB) (*mcp.CallToolResult, error) {
func handleQuery(ctx context.Context, req mcp.CallToolRequest, dbs map[string]*sql.DB) (*mcp.CallToolResult, error) {
database, ok := req.Params.Arguments["database"].(string)
if !ok {
return nil, fmt.Errorf("database parameter should be a string")
}

db, exists := dbs[database]
if !exists {
return nil, fmt.Errorf("database '%s' not found in available connections", database)
}

statement, ok := req.Params.Arguments["statement"].(string)
if !ok {
return nil, fmt.Errorf("statement should be a string")
Expand Down
22 changes: 18 additions & 4 deletions tools/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,35 @@ import (
"github.com/mark3labs/mcp-go/server"
)

func RegisterSchemaTool(s *server.MCPServer, db *sql.DB) {
func RegisterSchemaTool(s *server.MCPServer, dbs map[string]*sql.DB) {
schemaTool := mcp.NewTool("schema_tool",
mcp.WithDescription("Inspect the table schema"),
mcp.WithDescription("Inspect table schema on specified database"),
mcp.WithString("database",
mcp.Required(),
mcp.Description("Name of the database to inspect"),
),
mcp.WithString("table_name",
mcp.Required(),
mcp.Description("The name of the table to inspect"),
))

s.AddTool(schemaTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return handleSchema(ctx, request, db)
return handleSchema(ctx, request, dbs)
})
}

// Execute handles the schema request
func handleSchema(ctx context.Context, request mcp.CallToolRequest, db *sql.DB) (*mcp.CallToolResult, error) {
func handleSchema(ctx context.Context, request mcp.CallToolRequest, dbs map[string]*sql.DB) (*mcp.CallToolResult, error) {
database, ok := request.Params.Arguments["database"].(string)
if !ok {
return nil, fmt.Errorf("database parameter should be a string")
}

db, exists := dbs[database]
if !exists {
return nil, fmt.Errorf("database '%s' not found in available connections", database)
}

tableName, ok := request.Params.Arguments["table_name"].(string)
if !ok || tableName == "" {
return nil, fmt.Errorf("table_name must be a non-empty string")
Expand Down
26 changes: 20 additions & 6 deletions tools/transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,40 @@ import (
"github.com/mark3labs/mcp-go/server"
)

func RegisterTransactionTool(s *server.MCPServer, db *sql.DB) {
func RegisterTransactionTool(s *server.MCPServer, dbs map[string]*sql.DB) {
transactionTool := mcp.NewTool("transaction_tool",
mcp.WithDescription("Make queries in a transaction"),
mcp.WithDescription("Execute queries in a transaction on specified database"),
mcp.WithString("database",
mcp.Required(),
mcp.Description("Name of the database for the transaction"),
),
mcp.WithArray("statements",
mcp.Required(),
mcp.Description("Statements to be executed in the transaction"),
),
mcp.WithArray("arguments",
mcp.Required(),
mcp.Description("Arguments for the statents provided"),
mcp.Description("Arguments for the statements provided"),
),
)

s.AddTool(transactionTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return handleTransaction(ctx, request, db)
return handleTransaction(ctx, request, dbs)
})
}

// Execute handles the schema request.
func handleTransaction(ctx context.Context, request mcp.CallToolRequest, db *sql.DB) (*mcp.CallToolResult, error) {
// Execute handles the transaction request.
func handleTransaction(ctx context.Context, request mcp.CallToolRequest, dbs map[string]*sql.DB) (*mcp.CallToolResult, error) {
database, ok := request.Params.Arguments["database"].(string)
if !ok {
return nil, fmt.Errorf("database parameter should be a string")
}

db, exists := dbs[database]
if !exists {
return nil, fmt.Errorf("database '%s' not found in available connections", database)
}

statements, ok := request.Params.Arguments["statements"].([]string)
if !ok {
return nil, fmt.Errorf("statements should be an array of strings")
Expand Down