Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 147 additions & 43 deletions core/Azure.Mcp.Core/src/Areas/Server/Commands/ServiceStartCommand.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

using System.CommandLine.Parsing;
using System.Net;
using Azure.Mcp.Core.Areas.Server.Options;
using Azure.Mcp.Core.Commands;
Expand All @@ -21,7 +22,7 @@ namespace Azure.Mcp.Core.Areas.Server.Commands;
/// This command is hidden from the main command list.
/// </summary>
[HiddenCommand]
public sealed class ServiceStartCommand : BaseCommand
public sealed class ServiceStartCommand : BaseCommand<ServiceStartOptions>
{
private const string CommandTitle = "Start MCP Server";

Expand Down Expand Up @@ -63,6 +64,53 @@ protected override void RegisterOptions(Command command)
command.Options.Add(ServiceOptionDefinitions.InsecureDisableElicitation);
}

/// <summary>
/// Binds the parsed command line arguments to the ServiceStartOptions object.
/// </summary>
/// <param name="parseResult">The parsed command line arguments.</param>
/// <returns>A configured ServiceStartOptions instance.</returns>
protected override ServiceStartOptions BindOptions(ParseResult parseResult)
{
var options = new ServiceStartOptions
{
Transport = parseResult.GetValueOrDefault<string>(ServiceOptionDefinitions.Transport.Name) ?? TransportTypes.StdIo,
Namespace = parseResult.GetValueOrDefault<string[]?>(ServiceOptionDefinitions.Namespace.Name),
Mode = parseResult.GetValueOrDefault<string?>(ServiceOptionDefinitions.Mode.Name),
ReadOnly = parseResult.GetValueOrDefault<bool?>(ServiceOptionDefinitions.ReadOnly.Name),
Debug = parseResult.GetValueOrDefault<bool>(ServiceOptionDefinitions.Debug.Name),
EnableInsecureTransports = parseResult.GetValueOrDefault<bool>(ServiceOptionDefinitions.EnableInsecureTransports.Name),
InsecureDisableElicitation = parseResult.GetValueOrDefault<bool>(ServiceOptionDefinitions.InsecureDisableElicitation.Name)
};
return options;
}

/// <summary>
/// Validates the command options and arguments.
/// </summary>
/// <param name="commandResult">The command result to validate.</param>
/// <param name="commandResponse">Optional response object to set error details.</param>
/// <returns>A ValidationResult indicating whether the validation passed.</returns>
public override ValidationResult Validate(CommandResult commandResult, CommandResponse? commandResponse)
{
// First run the base validation for required options and parser errors
var baseResult = base.Validate(commandResult, commandResponse);
if (!baseResult.IsValid)
{
return baseResult;
}

// Get option values directly from commandResult
var mode = commandResult.GetValueOrDefault(ServiceOptionDefinitions.Mode);
var transport = commandResult.GetValueOrDefault(ServiceOptionDefinitions.Transport);
var enableInsecureTransports = commandResult.GetValueOrDefault(ServiceOptionDefinitions.EnableInsecureTransports);

// Validate and return early on any failures
return ValidateMode(mode, commandResponse) ??
ValidateTransport(transport, commandResponse) ??
ValidateInsecureTransportsConfiguration(enableInsecureTransports, commandResponse) ??
new ValidationResult { IsValid = true };
}

/// <summary>
/// Executes the service start command, creating and starting the MCP server.
/// </summary>
Expand All @@ -71,58 +119,123 @@ protected override void RegisterOptions(Command command)
/// <returns>A command response indicating the result of the operation.</returns>
public override async Task<CommandResponse> ExecuteAsync(CommandContext context, ParseResult parseResult)
{
string[]? namespaces = parseResult.GetValueOrDefault<string[]?>(ServiceOptionDefinitions.Namespace.Name);
string? mode = parseResult.GetValueOrDefault<string?>(ServiceOptionDefinitions.Mode.Name);
bool? readOnly = parseResult.GetValueOrDefault<bool?>(ServiceOptionDefinitions.ReadOnly.Name);
if (!Validate(parseResult.CommandResult, context.Response).IsValid)
{
return context.Response;
}

var debug = parseResult.GetValueOrDefault<bool>(ServiceOptionDefinitions.Debug.Name);
var options = BindOptions(parseResult);

if (!IsValidMode(mode))
try
{
throw new ArgumentException($"Invalid mode '{mode}'. Valid modes are: {ModeTypes.SingleToolProxy}, {ModeTypes.NamespaceProxy}, {ModeTypes.All}.");
}
using var host = CreateHost(options);
await host.StartAsync(CancellationToken.None);
await host.WaitForShutdownAsync(CancellationToken.None);

var enableInsecureTransports = parseResult.GetValueOrDefault<bool>(ServiceOptionDefinitions.EnableInsecureTransports.Name);
return context.Response;
}
catch (Exception ex)
{
HandleException(context, ex);
return context.Response;
}
}

if (enableInsecureTransports)
/// <summary>
/// Validates if the provided mode is a valid mode type.
/// </summary>
/// <param name="mode">The mode to validate.</param>
/// <param name="commandResponse">Optional command response to update on failure.</param>
/// <returns>ValidationResult with error details if invalid, null if valid.</returns>
private static ValidationResult? ValidateMode(string? mode, CommandResponse? commandResponse)
{
if (mode == ModeTypes.SingleToolProxy ||
mode == ModeTypes.NamespaceProxy ||
mode == ModeTypes.All)
{
var includeProdCreds = EnvironmentHelpers.GetEnvironmentVariableAsBool("AZURE_MCP_INCLUDE_PRODUCTION_CREDENTIALS");
if (!includeProdCreds)
{
throw new InvalidOperationException("Using --enable-insecure-transport requires the host to have either Managed Identity or Workload Identity enabled. Please refer to the troubleshooting guidelines here at https://aka.ms/azmcp/troubleshooting.");
}
return null; // Success
}

var serverOptions = new ServiceStartOptions
var result = new ValidationResult
{
Transport = parseResult.GetValueOrDefault<string>(ServiceOptionDefinitions.Transport.Name) ?? TransportTypes.StdIo,
Namespace = namespaces,
Mode = mode,
ReadOnly = readOnly,
Debug = debug,
EnableInsecureTransports = enableInsecureTransports,
InsecureDisableElicitation = parseResult.GetValueOrDefault<bool>(ServiceOptionDefinitions.InsecureDisableElicitation.Name),
IsValid = false,
ErrorMessage = $"Invalid mode '{mode}'. Valid modes are: {ModeTypes.SingleToolProxy}, {ModeTypes.NamespaceProxy}, {ModeTypes.All}."
};

using var host = CreateHost(serverOptions);
await host.StartAsync(CancellationToken.None);
await host.WaitForShutdownAsync(CancellationToken.None);
SetValidationError(commandResponse, result.ErrorMessage!, HttpStatusCode.BadRequest);
return result;
}

/// <summary>
/// Validates if the provided transport is valid.
/// </summary>
/// <param name="transport">The transport to validate.</param>
/// <param name="commandResponse">Optional command response to update on failure.</param>
/// <returns>ValidationResult with error details if invalid, null if valid.</returns>
private static ValidationResult? ValidateTransport(string? transport, CommandResponse? commandResponse)
{
if (transport is null || transport == TransportTypes.StdIo)
{
return null; // Success
}

return context.Response;
var result = new ValidationResult
{
IsValid = false,
ErrorMessage = $"Invalid transport '{transport}'. Valid transports are: {TransportTypes.StdIo}."
};

SetValidationError(commandResponse, result.ErrorMessage!, HttpStatusCode.BadRequest);
return result;
}

/// <summary>
/// Validates if the provided mode is a valid mode type.
/// Validates if the insecure transport configuration is valid.
/// </summary>
/// <param name="mode">The mode to validate.</param>
/// <returns>True if the mode is valid, otherwise false.</returns>
private static bool IsValidMode(string? mode)
/// <param name="enableInsecureTransports">Whether insecure transports are enabled.</param>
/// <param name="commandResponse">Optional command response to update on failure.</param>
/// <returns>ValidationResult with error details if invalid, null if valid.</returns>
private static ValidationResult? ValidateInsecureTransportsConfiguration(bool enableInsecureTransports, CommandResponse? commandResponse)
{
return mode == ModeTypes.SingleToolProxy ||
mode == ModeTypes.NamespaceProxy ||
mode == ModeTypes.All;
// If insecure transports are not enabled, configuration is valid
if (!enableInsecureTransports)
{
return null; // Success
}

// If insecure transports are enabled, check if proper credentials are configured
var hasCredentials = EnvironmentHelpers.GetEnvironmentVariableAsBool("AZURE_MCP_INCLUDE_PRODUCTION_CREDENTIALS");
if (hasCredentials)
{
return null; // Success
}

var result = new ValidationResult
{
IsValid = false,
ErrorMessage = "Using --enable-insecure-transport requires the host to have either Managed Identity or Workload Identity enabled. Please refer to the troubleshooting guidelines here at https://aka.ms/azmcp/troubleshooting."
};

SetValidationError(commandResponse, result.ErrorMessage!, HttpStatusCode.InternalServerError);
return result;
}

/// <summary>
/// Provides custom error messages for specific exception types to improve user experience.
/// </summary>
/// <param name="ex">The exception to format an error message for.</param>
/// <returns>A user-friendly error message.</returns>
protected override string GetErrorMessage(Exception ex) => ex switch
{
ArgumentException argEx when argEx.Message.Contains("Invalid transport") =>
"Invalid transport option specified. Use --transport stdio for the supported transport mechanism.",
ArgumentException argEx when argEx.Message.Contains("Invalid mode") =>
"Invalid mode option specified. Use --mode single, namespace, or all for the supported modes.",
InvalidOperationException invOpEx when invOpEx.Message.Contains("Using --enable-insecure-transport") =>
"Insecure transport configuration error. Ensure proper authentication configured with Managed Identity or Workload Identity.",
_ => base.GetErrorMessage(ex)
};

/// <summary>
/// Creates the host for the MCP server with the specified options.
/// </summary>
Expand Down Expand Up @@ -284,13 +397,4 @@ private static string GetSafeAspNetCoreUrl()

return url;
}

/// <summary>
/// Hosted service for running the MCP server using standard input/output.
/// </summary>
private sealed class StdioMcpServerHostedService(IMcpServer session) : BackgroundService
{
/// <inheritdoc />
protected override Task ExecuteAsync(CancellationToken stoppingToken) => session.RunAsync(stoppingToken);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

using System.Diagnostics;
using System.Net;
using System.Text.Json.Nodes;
using Azure.Mcp.Core.Areas.Server.Models;
using Azure.Mcp.Core.Commands;
Expand Down Expand Up @@ -196,7 +197,7 @@ public async ValueTask<CallToolResult> CallToolHandler(RequestContext<CallToolRe
{
var commandResponse = await command.ExecuteAsync(commandContext, commandOptions);
var jsonResponse = JsonSerializer.Serialize(commandResponse, ModelsJsonContext.Default.CommandResponse);
var isError = commandResponse.Status < 200 || commandResponse.Status >= 300;
var isError = commandResponse.Status < HttpStatusCode.OK || commandResponse.Status >= HttpStatusCode.Ambiguous;

return new CallToolResult
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
namespace Azure.Mcp.Core.Areas.Tools.Commands;

[HiddenCommand]
public sealed class ToolsListCommand(ILogger<ToolsListCommand> logger) : BaseCommand()
public sealed class ToolsListCommand(ILogger<ToolsListCommand> logger) : BaseCommand<EmptyOptions>
{
private const string CommandTitle = "List Available Tools";

Expand All @@ -33,6 +33,8 @@ arguments. Use this to explore the CLI's functionality or to build interactive c
Secret = false
};

protected override EmptyOptions BindOptions(ParseResult parseResult) => new();

public override async Task<CommandResponse> ExecuteAsync(CommandContext context, ParseResult parseResult)
{
try
Expand Down
50 changes: 39 additions & 11 deletions core/Azure.Mcp.Core/src/Commands/BaseCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@

using System.CommandLine.Parsing;
using System.Diagnostics;
using System.Net;
using Azure.Mcp.Core.Exceptions;
using Azure.Mcp.Core.Helpers;
using static Azure.Mcp.Core.Services.Telemetry.TelemetryConstants;

namespace Azure.Mcp.Core.Commands;

public abstract class BaseCommand : IBaseCommand
public abstract class BaseCommand<TOptions> : IBaseCommand where TOptions : class, new()
Comment thread
jongio marked this conversation as resolved.
{
private const string MissingRequiredOptionsPrefix = "Missing Required options: ";
private const int ValidationErrorStatusCode = 400;
private const string TroubleshootingUrl = "https://aka.ms/azmcp/troubleshooting";

private readonly Command _command;
Expand All @@ -34,6 +34,14 @@ protected virtual void RegisterOptions(Command command)
{
}

/// <summary>
/// Binds the parsed command line arguments to a strongly-typed options object.
/// Implement this method in derived classes to provide option binding logic.
/// </summary>
/// <param name="parseResult">The parsed command line arguments.</param>
/// <returns>An options object containing the bound options.</returns>
protected abstract TOptions BindOptions(ParseResult parseResult);

public abstract Task<CommandResponse> ExecuteAsync(CommandContext context, ParseResult parseResult);

protected virtual void HandleException(CommandContext context, Exception ex)
Expand Down Expand Up @@ -73,19 +81,20 @@ protected virtual void HandleException(CommandContext context, Exception ex)
response.Results = ResponseResult.Create(result, JsonSourceGenerationContext.Default.ExceptionResult);
Comment thread
jongio marked this conversation as resolved.
}

internal record ExceptionResult(
string Message,
string? StackTrace,
string Type);

protected virtual string GetErrorMessage(Exception ex) => ex.Message;

protected virtual int GetStatusCode(Exception ex) => 500;
protected virtual HttpStatusCode GetStatusCode(Exception ex) => ex switch
{
ArgumentException => HttpStatusCode.BadRequest, // Bad Request for invalid arguments
InvalidOperationException => HttpStatusCode.UnprocessableEntity, // Unprocessable Entity for configuration errors
_ => HttpStatusCode.InternalServerError // Internal Server Error for unexpected errors
};

public virtual ValidationResult Validate(CommandResult commandResult, CommandResponse? commandResponse = null)
{
var result = new ValidationResult { IsValid = true };

// First, check for missing required options
var missingOptions = commandResult.Command.Options
.Where(o => o.Required && !o.HasDefaultValue && !commandResult.HasOptionResult(o))
.Select(o => $"--{NameNormalization.NormalizeOptionName(o.Name)}")
Expand All @@ -101,8 +110,7 @@ public virtual ValidationResult Validate(CommandResult commandResult, CommandRes
return result;
}

// If no missing required options, propagate parser/validator errors as-is.
// Commands can throw CommandValidationException for structured handling.
// Check for parser/validator errors
if (commandResult.Errors != null && commandResult.Errors.Any())
{
result.IsValid = false;
Expand All @@ -118,9 +126,29 @@ static void SetValidationError(CommandResponse? response, string errorMessage)
{
if (response != null)
{
response.Status = ValidationErrorStatusCode;
response.Status = HttpStatusCode.BadRequest;
response.Message = errorMessage;
}
}
}

/// <summary>
/// Sets validation error details on the command response with a custom status code.
/// </summary>
/// <param name="response">The command response to update.</param>
/// <param name="errorMessage">The error message.</param>
/// <param name="statusCode">The HTTP status code (defaults to ValidationErrorStatusCode).</param>
protected static void SetValidationError(CommandResponse? response, string errorMessage, HttpStatusCode statusCode)
{
if (response != null)
{
response.Status = statusCode;
response.Message = errorMessage;
}
}
}

internal record ExceptionResult(
string Message,
string? StackTrace,
string Type);
Loading