Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions src/ModelContextProtocol.Analyzers/XmlToDescriptionGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -321,20 +321,25 @@ private static void AppendMethodDeclaration(
writer.WriteLine($"[return: Description(\"{EscapeString(xmlDocs.Returns)}\")]");
}

// Copy modifiers from original method syntax.
// Copy modifiers from original method syntax, excluding 'async' which is invalid on partial declarations (CS1994).
// Add return type (without nullable annotations).
// Add method name.
writer.Write(string.Join(" ", methodDeclaration.Modifiers.Select(m => m.Text)));
var modifiers = methodDeclaration.Modifiers
.Where(m => !m.IsKind(SyntaxKind.AsyncKeyword))
.Select(m => m.Text);
writer.Write(string.Join(" ", modifiers));
writer.Write(' ');
writer.Write(methodSymbol.ReturnType.ToDisplayString(SymbolDisplayFormat.MinimallyQualifiedFormat));
writer.Write(' ');
writer.Write(methodSymbol.Name);

// Add parameters with their Description attributes.
writer.Write("(");
var parameterSyntaxList = methodDeclaration.ParameterList.Parameters;
for (int i = 0; i < methodSymbol.Parameters.Length; i++)
{
IParameterSymbol param = methodSymbol.Parameters[i];
ParameterSyntax? paramSyntax = i < parameterSyntaxList.Count ? parameterSyntaxList[i] : null;

if (i > 0)
{
Expand All @@ -352,6 +357,13 @@ private static void AppendMethodDeclaration(
writer.Write(param.Type.ToDisplayString(SymbolDisplayFormat.MinimallyQualifiedFormat));
writer.Write(' ');
writer.Write(param.Name);

// Preserve default parameter values from the original syntax.
if (paramSyntax?.Default is { } defaultValue)
{
writer.Write(' ');
writer.Write(defaultValue.ToFullString().Trim());
}
}
writer.WriteLine(");");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1463,6 +1463,260 @@ partial class GlobalTools
AssertGeneratedSourceEquals(expected, result.GeneratedSources[0].SourceText.ToString());
}

[Fact]
public void Generator_WithDefaultParameterValues_PreservesDefaults()
{
var result = RunGenerator("""
using ModelContextProtocol.Server;
using System.ComponentModel;

namespace Test;

[McpServerToolType]
public partial class TestTools
{
/// <summary>
/// Test tool with defaults
/// </summary>
/// <param name="project">The project name</param>
/// <param name="flag">Enable flag</param>
/// <param name="count">Item count</param>
[McpServerTool]
public static partial string TestMethod(
string? project = null,
bool flag = false,
int count = 42)
{
return project ?? "default";
}
}
""");

Assert.True(result.Success);
Assert.Single(result.GeneratedSources);

var expected = $$"""
// <auto-generated/>
// ModelContextProtocol.Analyzers {{typeof(XmlToDescriptionGenerator).Assembly.GetName().Version}}

#pragma warning disable

using System.ComponentModel;
using ModelContextProtocol.Server;

namespace Test
{
partial class TestTools
{
[Description("Test tool with defaults")]
public static partial string TestMethod([Description("The project name")] string? project = null, [Description("Enable flag")] bool flag = false, [Description("Item count")] int count = 42);
}
}
""";

AssertGeneratedSourceEquals(expected, result.GeneratedSources[0].SourceText.ToString());
}

[Fact]
public void Generator_WithAsyncMethod_ExcludesAsyncModifier()
{
var result = RunGenerator("""
using ModelContextProtocol.Server;
using System.ComponentModel;
using System.Threading.Tasks;

namespace Test;

[McpServerToolType]
public partial class TestTools
{
/// <summary>
/// Async tool
/// </summary>
[McpServerTool]
public async partial Task<string> DoWorkAsync(string input)
{
await Task.Delay(100);
return input;
}
}
""");

Assert.True(result.Success);
Assert.Single(result.GeneratedSources);

var expected = $$"""
// <auto-generated/>
// ModelContextProtocol.Analyzers {{typeof(XmlToDescriptionGenerator).Assembly.GetName().Version}}

#pragma warning disable

using System.ComponentModel;
using ModelContextProtocol.Server;

namespace Test
{
partial class TestTools
{
[Description("Async tool")]
public partial Task<string> DoWorkAsync(string input);
}
}
""";

AssertGeneratedSourceEquals(expected, result.GeneratedSources[0].SourceText.ToString());
}

[Fact]
public void Generator_WithAsyncStaticMethod_ExcludesAsyncModifier()
{
var result = RunGenerator("""
using ModelContextProtocol.Server;
using System.ComponentModel;
using System.Threading.Tasks;

namespace Test;

[McpServerToolType]
public partial class TestTools
{
/// <summary>
/// Static async tool
/// </summary>
[McpServerTool]
public static async partial Task<string> StaticAsyncMethod(string input)
{
await Task.Delay(100);
return input;
}
}
""");

Assert.True(result.Success);
Assert.Single(result.GeneratedSources);

var expected = $$"""
// <auto-generated/>
// ModelContextProtocol.Analyzers {{typeof(XmlToDescriptionGenerator).Assembly.GetName().Version}}

#pragma warning disable

using System.ComponentModel;
using ModelContextProtocol.Server;

namespace Test
{
partial class TestTools
{
[Description("Static async tool")]
public static partial Task<string> StaticAsyncMethod(string input);
}
}
""";

AssertGeneratedSourceEquals(expected, result.GeneratedSources[0].SourceText.ToString());
}

[Fact]
public void Generator_WithDefaultParameterValuesAndAsync_HandlesBothCorrectly()
{
var result = RunGenerator("""
using ModelContextProtocol.Server;
using System.ComponentModel;
using System.Threading.Tasks;

namespace Test;

[McpServerToolType]
public partial class TestTools
{
/// <summary>
/// Async tool with defaults
/// </summary>
/// <param name="input">The input</param>
/// <param name="timeout">Timeout in ms</param>
[McpServerTool]
public static async partial Task<string> AsyncWithDefaults(string input, int timeout = 1000)
{
await Task.Delay(timeout);
return input;
}
}
""");

Assert.True(result.Success);
Assert.Single(result.GeneratedSources);

var expected = $$"""
// <auto-generated/>
// ModelContextProtocol.Analyzers {{typeof(XmlToDescriptionGenerator).Assembly.GetName().Version}}

#pragma warning disable

using System.ComponentModel;
using ModelContextProtocol.Server;

namespace Test
{
partial class TestTools
{
[Description("Async tool with defaults")]
public static partial Task<string> AsyncWithDefaults([Description("The input")] string input, [Description("Timeout in ms")] int timeout = 1000);
}
}
""";

AssertGeneratedSourceEquals(expected, result.GeneratedSources[0].SourceText.ToString());
}

[Fact]
public void Generator_WithStringDefaultValue_PreservesQuotedDefault()
{
var result = RunGenerator("""
using ModelContextProtocol.Server;
using System.ComponentModel;

namespace Test;

[McpServerToolType]
public partial class TestTools
{
/// <summary>
/// Test tool with string default
/// </summary>
[McpServerTool]
public static partial string TestMethod(string name = "default value")
{
return name;
}
}
""");

Assert.True(result.Success);
Assert.Single(result.GeneratedSources);

var expected = $$"""
// <auto-generated/>
// ModelContextProtocol.Analyzers {{typeof(XmlToDescriptionGenerator).Assembly.GetName().Version}}

#pragma warning disable

using System.ComponentModel;
using ModelContextProtocol.Server;

namespace Test
{
partial class TestTools
{
[Description("Test tool with string default")]
public static partial string TestMethod(string name = "default value");
}
}
""";

AssertGeneratedSourceEquals(expected, result.GeneratedSources[0].SourceText.ToString());
}

private GeneratorRunResult RunGenerator([StringSyntax("C#-test")] string source)
{
var syntaxTree = CSharpSyntaxTree.ParseText(source);
Expand Down