Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/mysql-mcp-server/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ A Python-based MCP (Model Context Protocol) server that provides a suite of tool
- `ragify_column`: Create/populate vector columns for embeddings
- `ask_ml_rag`: Retrieval-augmented generation from vector stores
- `heatwave_ask_help`: Answers questions about how to use HeatWave ML
- `ask_nl_sql`: Convert natural language questions into SQL queries and execute them automatically

- **Vector Store Management**
- List files in `secure_file_priv` (local mode)
Expand Down Expand Up @@ -213,6 +214,7 @@ python mysql_mcp_server.py
11. `list_all_compartments()`: List OCI compartments
12. `object_storage_list_buckets(compartment_name | compartment_id)`: List buckets in a compartment
13. `object_storage_list_objects(namespace, bucket_name)`: List objects in a bucket
14. `ask_nl_sql(connection_id, question)`: Convert natural language questions into SQL queries and execute them automatically

## Security

Expand All @@ -236,6 +238,7 @@ Here are example prompts you can use to interact with the MCP server, note that
```
"Generate a summary of error logs"
"Ask ml_rag: Show me refund policy from the vector store"
"What is the average delay incurred by flights?"
```

### 3. Object Storage
Expand Down
164 changes: 142 additions & 22 deletions src/mysql-mcp-server/mysql_mcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,13 @@
from fastmcp import FastMCP
from mysql import connector
from mysql.connector.abstracts import MySQLConnectionAbstract

from utils import DatabaseConnectionError, get_ssh_command, load_mysql_config, Mode, OciInfo
from utils import (
DatabaseConnectionError,
Mode,
OciInfo,
get_ssh_command,
load_mysql_config,
)

MIN_CONTEXT_SIZE = 10
DEFAULT_CONTEXT_SIZE = 20
Expand All @@ -29,20 +34,26 @@
try:
config = load_mysql_config()
except Exception as e:
config_error_msg = json.dumps({
"error" : f"Error loading config. Fix configuration file and try restarting MCP server {str(e)}."
})
config_error_msg = json.dumps(
{
"error": f"Error loading config. Fix configuration file and try restarting MCP server {str(e)}."
}
)

# Setup oci connection if applicable
oci_info: Optional[OciInfo] = None # None if not available, otherwise OCI config info
oci_error_msg: Optional[str] = None # None if OCI available, otherwise a json formatted string
oci_error_msg: Optional[str] = (
None # None if OCI available, otherwise a json formatted string
)
try:
oci_info = OciInfo()
except Exception as e:
oci_error_msg = json.dumps({
"error" : "object store unavailable. If object store is required, the MCP server must be restarted with a valid"
f" OCI config. OCI connection attempt yielded error {str(e)}."
})
oci_error_msg = json.dumps(
{
"error": "object store unavailable. If object store is required, the MCP server must be restarted with a valid"
f" OCI config. OCI connection attempt yielded error {str(e)}."
}
)

# Create mcp server
mcp = FastMCP("MySQL")
Expand All @@ -51,6 +62,7 @@
# Finish setup
###############################################################


def _validate_name(name: str) -> str:
"""
Validate that the string is a legal SQL identifier (letters, digits, underscores).
Expand Down Expand Up @@ -81,9 +93,7 @@ def _get_mode(connection_id: str) -> Mode:
Returns:
Mode: The resolved provider mode.
"""
provider_result = _execute_sql_tool(
connection_id, "SELECT @@rapid_cloud_provider;"
)
provider_result = _execute_sql_tool(connection_id, "SELECT @@rapid_cloud_provider;")
if check_error(provider_result):
raise Exception(
f"Exception occurred while fetching cloud provider {str(provider_result)}"
Expand Down Expand Up @@ -230,7 +240,7 @@ def list_all_connections() -> str:
{
"key": connection_id,
"error": str(e),
"hint": f"Bastion/jump host may be down. Try starting it with {get_ssh_command(config)}"
"hint": f"Bastion/jump host may be down. Try starting it with {get_ssh_command(config)}",
}
)
return json.dumps({"valid keys": valid_keys, "invalid keys": invalid_keys})
Expand Down Expand Up @@ -258,6 +268,19 @@ def execute_sql_tool_by_connection_id(
return _execute_sql_tool(connection_id, sql_script, params=params)


from datetime import date, datetime
from decimal import Decimal


class CustomJSONEncoder(json.JSONEncoder):
def default(self, o):
if isinstance(o, Decimal):
return str(o)
if isinstance(o, (date, datetime)):
return o.isoformat()
return super().default(o)


def _execute_sql_tool(
connection: Union[str, MySQLConnectionAbstract],
sql_script: str,
Expand Down Expand Up @@ -309,7 +332,7 @@ def _execute_sql_tool(

db_connection.commit()

return json.dumps(results)
return json.dumps(results, cls=CustomJSONEncoder)

except Exception as e:
return json.dumps(
Expand Down Expand Up @@ -565,7 +588,9 @@ def load_vector_store_oci(


@mcp.tool()
def ask_ml_rag_vector_store(connection_id: str, question: str, context_size: int = DEFAULT_CONTEXT_SIZE) -> str:
def ask_ml_rag_vector_store(
connection_id: str, question: str, context_size: int = DEFAULT_CONTEXT_SIZE
) -> str:
"""
[MCP Tool] Retrieve segments from the default vector store (skip_generate=true).

Expand All @@ -586,16 +611,26 @@ def ask_ml_rag_vector_store(connection_id: str, question: str, context_size: int
arguments: {"connection_id": "example_local_server", "question": "Find information about refunds."}
"""
if context_size < MIN_CONTEXT_SIZE or MAX_CONTEXT_SIZE < context_size:
return json.dumps({"error": f"Error choose a context_size in [{MIN_CONTEXT_SIZE}, {MAX_CONTEXT_SIZE}]"})
return json.dumps(
{
"error": f"Error choose a context_size in [{MIN_CONTEXT_SIZE}, {MAX_CONTEXT_SIZE}]"
}
)

return _ask_ml_rag_helper(
connection_id, question, f"JSON_OBJECT('skip_generate', true, 'n_citations', {context_size})"
connection_id,
question,
f"JSON_OBJECT('skip_generate', true, 'n_citations', {context_size})",
)


@mcp.tool()
def ask_ml_rag_innodb(
connection_id: str, question: str, segment_col: str, embedding_col: str, context_size: int = DEFAULT_CONTEXT_SIZE
connection_id: str,
question: str,
segment_col: str,
embedding_col: str,
context_size: int = DEFAULT_CONTEXT_SIZE,
) -> str:
"""
[MCP Tool] Retrieve segments from InnoDB tables using specified segment and embedding columns.
Expand Down Expand Up @@ -626,7 +661,11 @@ def ask_ml_rag_innodb(
arguments: {"connection_id": "example_local_server", "question": "Search product docs", "segment_col": "body", "embedding_col": "embedding"}
"""
if context_size < MIN_CONTEXT_SIZE or MAX_CONTEXT_SIZE < context_size:
return json.dumps({"error": f"Error choose a context_size in [{MIN_CONTEXT_SIZE}, {MAX_CONTEXT_SIZE}]"})
return json.dumps(
{
"error": f"Error choose a context_size in [{MIN_CONTEXT_SIZE}, {MAX_CONTEXT_SIZE}]"
}
)

try:
# prevent possible injection
Expand Down Expand Up @@ -732,6 +771,84 @@ def heatwave_ask_help(connection_id: str, question: str) -> str:
return json.dumps({"error": f"Error with NL2ML: {str(e)}"})


@mcp.tool()
def ask_nl_sql(connection_id: str, question: str) -> str:
"""
[MCP Tool] Convert natural language questions into SQL queries and execute them automatically.

This tool is ideal for database exploration using plain English questions like:
- "What tables are available?"
- "Show me the average price by category"
- "How many users registered last month?"
- "What are the column names in the customers table?"

Args:
connection_id (str): MySQL connection key.
question (str): Natural language query.

Returns:
JSON object containing:

sql_response(str): The response from executing the generated SQL query.
sql_query(str): The generated SQL query
schemas(json): The schemas where metadata was retrieved
tables(json): The tables where metadata was retrieved
is_sql_valid(bool): Whether the generated SQL statement is valid
model_id(str): The LLM used for generation


MCP usage example:
- name: ask_nl_sql
arguments: {"connection_id": "example_local_server", "question": "How many singers are there?"}

Here is the what part of the return JSON looks like;
{
"tables": [
"singer.singer",
"singer.song",
"concert_singer.singer",
"concert_singer.stadium",
"music_2.Songs",
"music_2.Instruments",
"music_2.Band",
"music_2.Vocals",
"music_2.Tracklists"
],
"schemas": [
"concert_singer",
"music_2",
"singer"
],
"sql_query": "SELECT COUNT(`Singer_ID`) FROM `concert_singer`.`singer`;",
"is_sql_valid": 1
}
"""
with _get_database_connection_cm(connection_id) as db_connection:
# Execute the heatwave chat query
set_response = _execute_sql_tool(db_connection, "SET @response = NULL;")
if check_error(set_response):
return json.dumps({"error": f"Error with NL_SQL: {set_response}"})

nl2sql_response = _execute_sql_tool(
db_connection,
f"CALL sys.NL_SQL(%s, @response, NULL)",
params=[question],
)
if check_error(nl2sql_response):
return json.dumps({"error": f"Error with NL_SQL: {nl2sql_response}"})

fetch_response = _execute_sql_tool(db_connection, "SELECT @response;")
if check_error(fetch_response):
return json.dumps({"error": f"Error with ML_RAG: {fetch_response}"})

try:
response = json.loads(fetch_one(fetch_response))
response["sql_response"] = nl2sql_response
return json.dumps(response)
except:
return json.dumps({"error": "Unexpected response format from NL_SQL"})


"""
Object store
"""
Expand All @@ -745,7 +862,7 @@ def verify_compartment_access(compartments):
"compartment_id": compartment.id,
"object_storage": False,
"databases": False,
"errors": []
"errors": [],
}

# Test Object Storage
Expand All @@ -756,10 +873,13 @@ def verify_compartment_access(compartments):
)
access_report[compartment.name]["object_storage"] = True
except Exception as e:
access_report[compartment.name]["errors"].append(f"Object Storage: {str(e)}")
access_report[compartment.name]["errors"].append(
f"Object Storage: {str(e)}"
)

return access_report


@mcp.tool()
def list_all_compartments() -> str:
"""
Expand Down