From 41c6dfff73c23fb0024981fc300f8d4596433e3f Mon Sep 17 00:00:00 2001 From: Juancho Date: Fri, 31 Oct 2025 13:26:27 +0100 Subject: [PATCH 1/3] Return useful error message for AIs if we fail to deserialize custom data types --- .../src/Services/PostgresService.cs | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/tools/Azure.Mcp.Tools.Postgres/src/Services/PostgresService.cs b/tools/Azure.Mcp.Tools.Postgres/src/Services/PostgresService.cs index 36fa045671..e2b834048e 100644 --- a/tools/Azure.Mcp.Tools.Postgres/src/Services/PostgresService.cs +++ b/tools/Azure.Mcp.Tools.Postgres/src/Services/PostgresService.cs @@ -1,7 +1,9 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +using System.Net; using Azure.Core; +using Azure.Mcp.Core.Exceptions; using Azure.Mcp.Core.Services.Azure; using Azure.Mcp.Core.Services.Azure.ResourceGroup; using Azure.ResourceManager.PostgreSql.FlexibleServers; @@ -86,7 +88,20 @@ public async Task> ExecuteQueryAsync(string subscriptionId, string var row = new List(); for (int i = 0; i < reader.FieldCount; i++) { - row.Add(reader[i]?.ToString() ?? "NULL"); + try + { + row.Add(reader[i]?.ToString() ?? "NULL"); + } + catch (InvalidCastException) + { + throw new CommandValidationException($"E_QUERY_UNSUPPORTED_COMPLEX_TYPES. The PostgreSQL query failed because it returned one or more columns with non-standard data types (extension or user-defined) unsupported by the MCP agent.\nColumn that failed: '{columnNames[i]}'.\n" + + $"Action required:\n" + + $"1. Obtain the exact schema for all the tables involved in the query.\n" + + $"2. Identify which columns have non-standard data types.\n" + + $"3. Modify the query to convert them to a supported type (e.g. using CAST or converting to text, integer, or the appropriate standard type).\n" + + $"4. Re-execute the modified query.\n" + + $"Please perform steps 1-4 now and re-execute.", HttpStatusCode.BadRequest); + } } rows.Add(string.Join(", ", row)); } From 707454e84461e49a75ee9789673e9adb85d3abf0 Mon Sep 17 00:00:00 2001 From: Juancho Date: Fri, 7 Nov 2025 16:07:37 +0100 Subject: [PATCH 2/3] Added unit tests for ExecuteQueryAsync --- .../src/Azure.Mcp.Tools.Postgres.csproj | 1 - .../src/PostgresSetup.cs | 4 + .../src/Providers/DbProvider.cs | 23 +++ .../src/Providers/EntraTokenProvider.cs | 16 ++ .../src/Providers/IDbProvider.cs | 13 ++ .../src/Providers/IEntraTokenProvider.cs | 9 + .../src/Providers/IPostgresResource.cs | 9 + .../src/Providers/PostgresResource.cs | 31 +++ .../src/Services/PostgresService.cs | 67 ++---- .../PostgresServiceParameterizedQueryTests.cs | 21 +- .../Services/PostgresServiceTests.cs | 131 ++++++++++++ .../Services/Support/FakeDbDataReader.cs | 191 ++++++++++++++++++ .../Services/Support/InvalidCastItem.cs | 10 + 13 files changed, 479 insertions(+), 47 deletions(-) create mode 100644 tools/Azure.Mcp.Tools.Postgres/src/Providers/DbProvider.cs create mode 100644 tools/Azure.Mcp.Tools.Postgres/src/Providers/EntraTokenProvider.cs create mode 100644 tools/Azure.Mcp.Tools.Postgres/src/Providers/IDbProvider.cs create mode 100644 tools/Azure.Mcp.Tools.Postgres/src/Providers/IEntraTokenProvider.cs create mode 100644 tools/Azure.Mcp.Tools.Postgres/src/Providers/IPostgresResource.cs create mode 100644 tools/Azure.Mcp.Tools.Postgres/src/Providers/PostgresResource.cs create mode 100644 tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/Services/PostgresServiceTests.cs create mode 100644 tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/Services/Support/FakeDbDataReader.cs create mode 100644 tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/Services/Support/InvalidCastItem.cs diff --git a/tools/Azure.Mcp.Tools.Postgres/src/Azure.Mcp.Tools.Postgres.csproj b/tools/Azure.Mcp.Tools.Postgres/src/Azure.Mcp.Tools.Postgres.csproj index 6dac8fbe5e..56f88555e8 100644 --- a/tools/Azure.Mcp.Tools.Postgres/src/Azure.Mcp.Tools.Postgres.csproj +++ b/tools/Azure.Mcp.Tools.Postgres/src/Azure.Mcp.Tools.Postgres.csproj @@ -9,7 +9,6 @@ - diff --git a/tools/Azure.Mcp.Tools.Postgres/src/PostgresSetup.cs b/tools/Azure.Mcp.Tools.Postgres/src/PostgresSetup.cs index 547cbf6f8c..5b011fc0c3 100644 --- a/tools/Azure.Mcp.Tools.Postgres/src/PostgresSetup.cs +++ b/tools/Azure.Mcp.Tools.Postgres/src/PostgresSetup.cs @@ -3,9 +3,11 @@ using Azure.Mcp.Core.Areas; using Azure.Mcp.Core.Commands; +using Azure.Mcp.Tools.Postgres.Auth; using Azure.Mcp.Tools.Postgres.Commands.Database; using Azure.Mcp.Tools.Postgres.Commands.Server; using Azure.Mcp.Tools.Postgres.Commands.Table; +using Azure.Mcp.Tools.Postgres.Providers; using Azure.Mcp.Tools.Postgres.Services; using Microsoft.Extensions.DependencyInjection; @@ -19,6 +21,8 @@ public class PostgresSetup : IAreaSetup public void ConfigureServices(IServiceCollection services) { + services.AddSingleton(); + services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); diff --git a/tools/Azure.Mcp.Tools.Postgres/src/Providers/DbProvider.cs b/tools/Azure.Mcp.Tools.Postgres/src/Providers/DbProvider.cs new file mode 100644 index 0000000000..5bcd3ba0e8 --- /dev/null +++ b/tools/Azure.Mcp.Tools.Postgres/src/Providers/DbProvider.cs @@ -0,0 +1,23 @@ +using System.Data.Common; +using Npgsql; + +namespace Azure.Mcp.Tools.Postgres.Providers +{ + internal class DbProvider : IDbProvider + { + public async Task GetPostgresResource(string connectionString) + { + return await PostgresResource.CreateAsync(connectionString); + } + + public NpgsqlCommand GetCommand(string query, IPostgresResource postgresResource) + { + return new NpgsqlCommand(query, postgresResource.Connection); + } + + public async Task ExecuteReaderAsync(NpgsqlCommand command) + { + return await command.ExecuteReaderAsync(); + } + } +} diff --git a/tools/Azure.Mcp.Tools.Postgres/src/Providers/EntraTokenProvider.cs b/tools/Azure.Mcp.Tools.Postgres/src/Providers/EntraTokenProvider.cs new file mode 100644 index 0000000000..649cfdbb8a --- /dev/null +++ b/tools/Azure.Mcp.Tools.Postgres/src/Providers/EntraTokenProvider.cs @@ -0,0 +1,16 @@ +using Azure.Core; + +namespace Azure.Mcp.Tools.Postgres.Auth +{ + internal class EntraTokenProvider : IEntraTokenProvider + { + public async Task GetEntraToken(TokenCredential tokenCredential) + { + var tokenRequestContext = new TokenRequestContext(["https://ossrdbms-aad.database.windows.net/.default"]); + var accessToken = await tokenCredential + .GetTokenAsync(tokenRequestContext, CancellationToken.None) + .ConfigureAwait(false); + return accessToken; + } + } +} diff --git a/tools/Azure.Mcp.Tools.Postgres/src/Providers/IDbProvider.cs b/tools/Azure.Mcp.Tools.Postgres/src/Providers/IDbProvider.cs new file mode 100644 index 0000000000..dacadc26fa --- /dev/null +++ b/tools/Azure.Mcp.Tools.Postgres/src/Providers/IDbProvider.cs @@ -0,0 +1,13 @@ + +using System.Data.Common; +using Npgsql; + +namespace Azure.Mcp.Tools.Postgres.Providers +{ + public interface IDbProvider + { + Task GetPostgresResource(string connectionString); + NpgsqlCommand GetCommand(string query, IPostgresResource postgresResource); + Task ExecuteReaderAsync(NpgsqlCommand command); + } +} diff --git a/tools/Azure.Mcp.Tools.Postgres/src/Providers/IEntraTokenProvider.cs b/tools/Azure.Mcp.Tools.Postgres/src/Providers/IEntraTokenProvider.cs new file mode 100644 index 0000000000..8ca97b1769 --- /dev/null +++ b/tools/Azure.Mcp.Tools.Postgres/src/Providers/IEntraTokenProvider.cs @@ -0,0 +1,9 @@ +using Azure.Core; + +namespace Azure.Mcp.Tools.Postgres.Auth +{ + public interface IEntraTokenProvider + { + Task GetEntraToken(TokenCredential tokenCredential); + } +} diff --git a/tools/Azure.Mcp.Tools.Postgres/src/Providers/IPostgresResource.cs b/tools/Azure.Mcp.Tools.Postgres/src/Providers/IPostgresResource.cs new file mode 100644 index 0000000000..6703bcc484 --- /dev/null +++ b/tools/Azure.Mcp.Tools.Postgres/src/Providers/IPostgresResource.cs @@ -0,0 +1,9 @@ +using Npgsql; + +namespace Azure.Mcp.Tools.Postgres.Providers +{ + public interface IPostgresResource : IAsyncDisposable + { + NpgsqlConnection Connection { get; } + } +} diff --git a/tools/Azure.Mcp.Tools.Postgres/src/Providers/PostgresResource.cs b/tools/Azure.Mcp.Tools.Postgres/src/Providers/PostgresResource.cs new file mode 100644 index 0000000000..414b6110b7 --- /dev/null +++ b/tools/Azure.Mcp.Tools.Postgres/src/Providers/PostgresResource.cs @@ -0,0 +1,31 @@ +using Npgsql; + +namespace Azure.Mcp.Tools.Postgres.Providers +{ + internal class PostgresResource : IPostgresResource + { + public NpgsqlConnection Connection { get; } + private readonly NpgsqlDataSource _dataSource; + + public static async Task CreateAsync(string connectionString) + { + var dataSource = new NpgsqlSlimDataSourceBuilder(connectionString) + .EnableTransportSecurity() + .Build(); + var connection = await dataSource.OpenConnectionAsync(); + return new PostgresResource(dataSource, connection); + } + + public async ValueTask DisposeAsync() + { + await Connection.DisposeAsync(); + await _dataSource.DisposeAsync(); + } + + private PostgresResource(NpgsqlDataSource dataSource, NpgsqlConnection connection) + { + _dataSource = dataSource; + Connection = connection; + } + } +} diff --git a/tools/Azure.Mcp.Tools.Postgres/src/Services/PostgresService.cs b/tools/Azure.Mcp.Tools.Postgres/src/Services/PostgresService.cs index e2b834048e..61a961849c 100644 --- a/tools/Azure.Mcp.Tools.Postgres/src/Services/PostgresService.cs +++ b/tools/Azure.Mcp.Tools.Postgres/src/Services/PostgresService.cs @@ -1,11 +1,14 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +using System.Data; +using System.Data.Common; using System.Net; -using Azure.Core; using Azure.Mcp.Core.Exceptions; using Azure.Mcp.Core.Services.Azure; using Azure.Mcp.Core.Services.Azure.ResourceGroup; +using Azure.Mcp.Tools.Postgres.Auth; +using Azure.Mcp.Tools.Postgres.Providers; using Azure.ResourceManager.PostgreSql.FlexibleServers; using Npgsql; @@ -14,12 +17,16 @@ namespace Azure.Mcp.Tools.Postgres.Services; public class PostgresService : BaseAzureService, IPostgresService { private readonly IResourceGroupService _resourceGroupService; + private readonly IEntraTokenProvider _entraTokenAuth; + private readonly IDbProvider _dbProvider; private string? _cachedEntraIdAccessToken; private DateTime _tokenExpiryTime; - public PostgresService(IResourceGroupService resourceGroupService) + public PostgresService(IResourceGroupService resourceGroupService, IEntraTokenProvider entraTokenAuth, IDbProvider dbProvider) { _resourceGroupService = resourceGroupService ?? throw new ArgumentNullException(nameof(resourceGroupService)); + _entraTokenAuth = entraTokenAuth; + _dbProvider = dbProvider; } private async Task GetEntraIdAccessTokenAsync() @@ -29,11 +36,8 @@ private async Task GetEntraIdAccessTokenAsync() return _cachedEntraIdAccessToken; } - var tokenRequestContext = new TokenRequestContext(["https://ossrdbms-aad.database.windows.net/.default"]); var tokenCredential = await GetCredential(); - var accessToken = await tokenCredential - .GetTokenAsync(tokenRequestContext, CancellationToken.None) - .ConfigureAwait(false); + var accessToken = await _entraTokenAuth.GetEntraToken(tokenCredential); _cachedEntraIdAccessToken = accessToken.Token; _tokenExpiryTime = accessToken.ExpiresOn.UtcDateTime.AddSeconds(-60); // Subtract 60 seconds as a buffer. @@ -55,10 +59,10 @@ public async Task> ListDatabasesAsync(string subscriptionId, string var host = NormalizeServerName(server); var connectionString = $"Host={host};Database=postgres;Username={user};Password={entraIdAccessToken}"; - await using var resource = await PostgresResource.CreateAsync(connectionString); var query = "SELECT datname FROM pg_database WHERE datistemplate = false;"; - await using var command = new NpgsqlCommand(query, resource.Connection); - await using var reader = await command.ExecuteReaderAsync(); + await using IPostgresResource resource = await _dbProvider.GetPostgresResource(connectionString); + await using NpgsqlCommand command = _dbProvider.GetCommand(query, resource); + await using DbDataReader reader = await _dbProvider.ExecuteReaderAsync(command); var dbs = new List(); while (await reader.ReadAsync()) { @@ -73,9 +77,9 @@ public async Task> ExecuteQueryAsync(string subscriptionId, string var host = NormalizeServerName(server); var connectionString = $"Host={host};Database={database};Username={user};Password={entraIdAccessToken}"; - await using var resource = await PostgresResource.CreateAsync(connectionString); - await using var command = new NpgsqlCommand(query, resource.Connection); - await using var reader = await command.ExecuteReaderAsync(); + await using IPostgresResource resource = await _dbProvider.GetPostgresResource(connectionString); + await using NpgsqlCommand command = _dbProvider.GetCommand(query, resource); + await using DbDataReader reader = await _dbProvider.ExecuteReaderAsync(command); var rows = new List(); @@ -114,10 +118,10 @@ public async Task> ListTablesAsync(string subscriptionId, string re var host = NormalizeServerName(server); var connectionString = $"Host={host};Database={database};Username={user};Password={entraIdAccessToken}"; - await using var resource = await PostgresResource.CreateAsync(connectionString); var query = "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';"; - await using var command = new NpgsqlCommand(query, resource.Connection); - await using var reader = await command.ExecuteReaderAsync(); + await using IPostgresResource resource = await _dbProvider.GetPostgresResource(connectionString); + await using NpgsqlCommand command = _dbProvider.GetCommand(query, resource); + await using DbDataReader reader = await _dbProvider.ExecuteReaderAsync(command); var tables = new List(); while (await reader.ReadAsync()) { @@ -132,10 +136,10 @@ public async Task> GetTableSchemaAsync(string subscriptionId, strin var host = NormalizeServerName(server); var connectionString = $"Host={host};Database={database};Username={user};Password={entraIdAccessToken}"; - await using var resource = await PostgresResource.CreateAsync(connectionString); var query = $"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table}';"; - await using var command = new NpgsqlCommand(query, resource.Connection); - await using var reader = await command.ExecuteReaderAsync(); + await using IPostgresResource resource = await _dbProvider.GetPostgresResource(connectionString); + await using NpgsqlCommand command = _dbProvider.GetCommand(query, resource); + await using DbDataReader reader = await _dbProvider.ExecuteReaderAsync(command); var schema = new List(); while (await reader.ReadAsync()) { @@ -226,31 +230,4 @@ public async Task SetServerParameterAsync(string subscriptionId, string throw new Exception($"Failed to update parameter '{param}' to value '{value}'."); } } - - private sealed class PostgresResource : IAsyncDisposable - { - public NpgsqlConnection Connection { get; } - private readonly NpgsqlDataSource _dataSource; - - public static async Task CreateAsync(string connectionString) - { - var dataSource = new NpgsqlSlimDataSourceBuilder(connectionString) - .EnableTransportSecurity() - .Build(); - var connection = await dataSource.OpenConnectionAsync(); - return new PostgresResource(dataSource, connection); - } - - public async ValueTask DisposeAsync() - { - await Connection.DisposeAsync(); - await _dataSource.DisposeAsync(); - } - - private PostgresResource(NpgsqlDataSource dataSource, NpgsqlConnection connection) - { - _dataSource = dataSource; - Connection = connection; - } - } } diff --git a/tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/Services/PostgresServiceParameterizedQueryTests.cs b/tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/Services/PostgresServiceParameterizedQueryTests.cs index dd73861352..113e5b6e46 100644 --- a/tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/Services/PostgresServiceParameterizedQueryTests.cs +++ b/tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/Services/PostgresServiceParameterizedQueryTests.cs @@ -1,8 +1,12 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +using System.Data.Common; using Azure.Mcp.Core.Services.Azure.ResourceGroup; +using Azure.Mcp.Tools.Postgres.Auth; +using Azure.Mcp.Tools.Postgres.Providers; using Azure.Mcp.Tools.Postgres.Services; +using Npgsql; using NSubstitute; using Xunit; @@ -15,12 +19,27 @@ namespace Azure.Mcp.Tools.Postgres.UnitTests.Services; public class PostgresServiceParameterizedQueryTests { private readonly IResourceGroupService _resourceGroupService; + private readonly IEntraTokenProvider _entraTokenAuth; + private readonly IDbProvider _dbProvider; private readonly PostgresService _postgresService; public PostgresServiceParameterizedQueryTests() { _resourceGroupService = Substitute.For(); - _postgresService = new PostgresService(_resourceGroupService); + + _entraTokenAuth = Substitute.For(); + _entraTokenAuth.GetEntraToken(Arg.Any()) + .Returns(new Azure.Core.AccessToken("fake-token", DateTime.UtcNow.AddHours(1))); + + _dbProvider = Substitute.For(); + _dbProvider.GetPostgresResource(Arg.Any()) + .Returns(Substitute.For()); + _dbProvider.GetCommand(Arg.Any(), Arg.Any()) + .Returns(Substitute.For()); + _dbProvider.ExecuteReaderAsync(Arg.Any()) + .Returns(Substitute.For()); + + _postgresService = new PostgresService(_resourceGroupService, _entraTokenAuth, _dbProvider); } [Theory] diff --git a/tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/Services/PostgresServiceTests.cs b/tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/Services/PostgresServiceTests.cs new file mode 100644 index 0000000000..e0421bb36c --- /dev/null +++ b/tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/Services/PostgresServiceTests.cs @@ -0,0 +1,131 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Data.Common; +using Azure.Mcp.Core.Exceptions; +using Azure.Mcp.Core.Services.Azure.ResourceGroup; +using Azure.Mcp.Tools.Postgres.Auth; +using Azure.Mcp.Tools.Postgres.Providers; +using Azure.Mcp.Tools.Postgres.Services; +using Azure.Mcp.Tools.Postgres.UnitTests.Services.Support; +using Npgsql; +using NSubstitute; +using Xunit; + +namespace Azure.Mcp.Tools.Postgres.UnitTests.Services +{ + public class PostgresServiceTests + { + private readonly IResourceGroupService _resourceGroupService; + private readonly IEntraTokenProvider _entraTokenAuth; + private readonly IDbProvider _dbProvider; + private readonly PostgresService _postgresService; + + private string subscriptionId; + private string resourceGroup; + private string user; + private string server; + private string database; + private string query; + + public PostgresServiceTests() + { + _resourceGroupService = Substitute.For(); + + _entraTokenAuth = Substitute.For(); + _entraTokenAuth.GetEntraToken(Arg.Any()) + .Returns(new Azure.Core.AccessToken("fake-token", DateTime.UtcNow.AddHours(1))); + + _dbProvider = Substitute.For(); + _dbProvider.GetPostgresResource(Arg.Any()) + .Returns(Substitute.For()); + _dbProvider.GetCommand(Arg.Any(), Arg.Any()) + .Returns(Substitute.For()); + _dbProvider.ExecuteReaderAsync(Arg.Any()) + .Returns(Substitute.For()); + + _postgresService = new PostgresService(_resourceGroupService, _entraTokenAuth, _dbProvider); + + this.subscriptionId = "test-sub"; + this.resourceGroup = "test-rg"; + this.user = "test-user"; + this.server = "test-server"; + this.database = "test-db"; + this.query = "SELECT * FROM test-table;"; + } + + [Fact] + public async Task ExecuteQueryAsync_InvalidCastException_Test() + { + // This test verifies that queries that returns unsupported data types return an exception + // message that helps AI to understand the issue and fix the query. + + // Arrange + this._dbProvider.ExecuteReaderAsync(Arg.Any()) + .Returns(Task.FromResult(new FakeDbDataReader( + new object[][] { + new object[] { "row1", 1, new InvalidCastItem() }, + new object[] { "row2", 2, new InvalidCastItem() }, + new object[] { "row3", 3, new InvalidCastItem() } + }, + new[] { "string", "integer", "unsupported" }, + new[] { typeof(string), typeof(int), typeof(InvalidCastItem) }))); + + // Act + CommandValidationException exception = await Assert.ThrowsAsync(async () => + { + await _postgresService.ExecuteQueryAsync(subscriptionId, resourceGroup, user, server, database, query); + }); + + // Assert + Assert.Contains("The PostgreSQL query failed because it returned one or more columns with non-standard data types (extension or user-defined) unsupported by the MCP agent", exception.Message); + } + + [Fact] + public async Task ExecuteQueryAsync_MixedDataTypes_Test() + { + // This test verifies that queries that return supported data types work as expected. + + // Arrange + this._dbProvider.ExecuteReaderAsync(Arg.Any()) + .Returns(Task.FromResult(new FakeDbDataReader( + new object[][] { + new object[] { "row1", 1, }, + new object[] { "row2", 2, }, + new object[] { "row3", 3, } + }, + new[] { "string", "integer" }, + new[] { typeof(string), typeof(int), typeof(InvalidCastItem) }))); + + // Act + List rows = await _postgresService.ExecuteQueryAsync(subscriptionId, resourceGroup, user, server, database, query); + + // Assert + Assert.Equal(4, rows.Count); + Assert.Contains("string, integer", rows.ElementAt(0)); + Assert.Contains("row1, 1", rows.ElementAt(1)); + Assert.Contains("row2, 2", rows.ElementAt(2)); + Assert.Contains("row3, 3", rows.ElementAt(3)); + } + + [Fact] + public async Task ExecuteQueryAsync_NoRows_Test() + { + // This test verifies that if no elements are found, only the header row is returned. + + // Arrange + this._dbProvider.ExecuteReaderAsync(Arg.Any()) + .Returns(Task.FromResult(new FakeDbDataReader( + new object[][] { }, + new[] { "string", "integer" }, + new[] { typeof(string), typeof(int), typeof(InvalidCastItem) }))); + + // Act + List rows = await _postgresService.ExecuteQueryAsync(subscriptionId, resourceGroup, user, server, database, query); + + // Assert + Assert.Single(rows); + Assert.Contains("string, integer", rows.ElementAt(0)); + } + } +} diff --git a/tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/Services/Support/FakeDbDataReader.cs b/tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/Services/Support/FakeDbDataReader.cs new file mode 100644 index 0000000000..126b84be22 --- /dev/null +++ b/tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/Services/Support/FakeDbDataReader.cs @@ -0,0 +1,191 @@ +using System.Collections; +using System.Data.Common; +using System.Globalization; + +namespace Azure.Mcp.Tools.Postgres.UnitTests.Services.Support; + +/// +/// In-memory for tests supporting heterogeneous column types. +/// +internal sealed class FakeDbDataReader(object[][] rows, + string[] columnNames, + Type[]? columnTypes = null, + string[]? dataTypeNames = null) + : DbDataReader +{ + private readonly object[][] _rows = rows; + private readonly string[] _columnNames = columnNames; + private readonly Type[] _columnTypes = columnTypes ?? Enumerable.Repeat(typeof(string), columnNames.Length).ToArray(); + private readonly string[] _dataTypeNames = dataTypeNames ?? + columnTypes?.Select(t => GetFriendlyTypeName(t)).ToArray() ?? + Enumerable.Repeat("text", columnNames.Length).ToArray(); + + private int _index = -1; + private bool _isClosed; + + /// + /// Backwards-compatible convenience ctor for all-string data. + /// + public FakeDbDataReader(string[][] stringRows, string[] columnNames) + : this(stringRows.Select(r => r.Cast().ToArray()).ToArray(), + columnNames, + Enumerable.Repeat(typeof(string), columnNames.Length).ToArray(), + Enumerable.Repeat("text", columnNames.Length).ToArray()) + { + } + + public override int FieldCount => _columnNames.Length; + public override bool HasRows => _rows.Length > 0; + public override bool IsClosed => _isClosed; + public override int RecordsAffected => 0; + public override int Depth => 0; + + public override object this[int ordinal] => GetValue(ordinal); + public override object this[string name] => GetValue(GetOrdinal(name)); + + public override string GetName(int ordinal) => _columnNames[ordinal]; + + public override int GetOrdinal(string name) + { + for (int i = 0; i < _columnNames.Length; i++) + { + if (string.Equals(_columnNames[i], name, StringComparison.Ordinal)) + { + return i; + } + } + throw new IndexOutOfRangeException($"Column '{name}' not found."); + } + + public override string GetDataTypeName(int ordinal) => _dataTypeNames[ordinal]; + public override Type GetFieldType(int ordinal) => _columnTypes[ordinal]; + + public override object GetValue(int ordinal) + { + EnsurePositioned(); + return _rows[_index][ordinal]!; + } + + public override int GetValues(object[] values) + { + int count = Math.Min(values.Length, FieldCount); + for (int i = 0; i < count; i++) + values[i] = GetValue(i)!; + return count; + } + + public override bool IsDBNull(int ordinal) => GetValue(ordinal) is null or DBNull; + + // Typed getters with safe conversion fallback + public override string GetString(int ordinal) => ConvertTo(ordinal); + public override bool GetBoolean(int ordinal) => ConvertTo(ordinal); + public override short GetInt16(int ordinal) => ConvertTo(ordinal); + public override int GetInt32(int ordinal) => ConvertTo(ordinal); + public override long GetInt64(int ordinal) => ConvertTo(ordinal); + public override float GetFloat(int ordinal) => ConvertTo(ordinal); + public override double GetDouble(int ordinal) => ConvertTo(ordinal); + public override decimal GetDecimal(int ordinal) => ConvertTo(ordinal); + public override DateTime GetDateTime(int ordinal) => ConvertTo(ordinal); + public override Guid GetGuid(int ordinal) + { + var v = GetValue(ordinal); + return v switch + { + Guid g => g, + string s when Guid.TryParse(s, out var g2) => g2, + _ => throw new InvalidCastException(GetInvalidCastMessage(ordinal, typeof(Guid), v)) + }; + } + + public override long GetBytes(int ordinal, long dataOffset, byte[]? buffer, int bufferOffset, int length) => + throw new NotSupportedException("Binary data not supported in FakeDbDataReader."); + + public override long GetChars(int ordinal, long dataOffset, char[]? buffer, int bufferOffset, int length) => + throw new NotSupportedException("Char streaming not supported in FakeDbDataReader."); + + public override char GetChar(int ordinal) => + throw new NotSupportedException("GetChar not implemented for FakeDbDataReader."); + + public override byte GetByte(int ordinal) => ConvertTo(ordinal); + + public override bool Read() + { + if (_index + 1 >= _rows.Length) + return false; + _index++; + return true; + } + + public override async Task ReadAsync(CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + await Task.Yield(); + return Read(); + } + + public override Task NextResultAsync(CancellationToken cancellationToken) => Task.FromResult(false); + public override bool NextResult() => false; + + public override IEnumerator GetEnumerator() => _rows.GetEnumerator(); + + public override void Close() => _isClosed = true; + protected override void Dispose(bool disposing) => _isClosed = true; + +#if NET8_0_OR_GREATER + public override ValueTask DisposeAsync() + { + _isClosed = true; + return ValueTask.CompletedTask; + } +#endif + + private void EnsurePositioned() + { + if (_index < 0 || _index >= _rows.Length) + { + throw new InvalidOperationException("The reader is not positioned on a valid row. Call Read() first."); + } + } + + private T ConvertTo(int ordinal) + { + var v = GetValue(ordinal); + if (v is null or DBNull) + { + throw new InvalidCastException(GetInvalidCastMessage(ordinal, typeof(T), v)); + } + + if (v is T tv) + return tv; + + try + { + // Handle string conversions explicitly for Guid, DateTime etc already handled above where needed. + if (typeof(T) == typeof(string)) + { + return (T)(object)v.ToString()!; + } + return (T)Convert.ChangeType(v, typeof(T), CultureInfo.InvariantCulture); + } + catch (Exception ex) + { + throw new InvalidCastException(GetInvalidCastMessage(ordinal, typeof(T), v), ex); + } + } + + private string GetInvalidCastMessage(int ordinal, Type target, object? value) => + $"Cannot convert column '{GetName(ordinal)}' (ordinal {ordinal}, type '{GetFieldType(ordinal).Name}') value '{value ?? "NULL"}' to {target.Name}."; + + private static string GetFriendlyTypeName(Type t) => + t == typeof(string) ? "text" : + t == typeof(int) ? "int4" : + t == typeof(long) ? "int8" : + t == typeof(short) ? "int2" : + t == typeof(bool) ? "bool" : + t == typeof(decimal) ? "numeric" : + t == typeof(double) ? "float8" : + t == typeof(float) ? "float4" : + t == typeof(DateTime) ? "timestamp" : + t == typeof(Guid) ? "uuid" : + t.Name.ToLowerInvariant(); +} diff --git a/tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/Services/Support/InvalidCastItem.cs b/tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/Services/Support/InvalidCastItem.cs new file mode 100644 index 0000000000..c8b46eec0e --- /dev/null +++ b/tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/Services/Support/InvalidCastItem.cs @@ -0,0 +1,10 @@ +namespace Azure.Mcp.Tools.Postgres.UnitTests.Services.Support +{ + internal class InvalidCastItem + { + public override string ToString() + { + throw new InvalidCastException("This is an invalid cast item."); + } + } +} From d205ed0c61d960ac99bc1911c2ef1c00164fb504 Mon Sep 17 00:00:00 2001 From: Juancho Date: Tue, 11 Nov 2025 13:50:33 +0100 Subject: [PATCH 3/3] sorted usings --- .../Services/PostgresServiceParameterizedQueryTests.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/Services/PostgresServiceParameterizedQueryTests.cs b/tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/Services/PostgresServiceParameterizedQueryTests.cs index 86db9d028f..436aac5b4e 100644 --- a/tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/Services/PostgresServiceParameterizedQueryTests.cs +++ b/tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/Services/PostgresServiceParameterizedQueryTests.cs @@ -4,9 +4,9 @@ using System.Data.Common; using Azure.Core; using Azure.Mcp.Core.Services.Azure.ResourceGroup; +using Azure.Mcp.Core.Services.Azure.Tenant; using Azure.Mcp.Tools.Postgres.Auth; using Azure.Mcp.Tools.Postgres.Providers; -using Azure.Mcp.Core.Services.Azure.Tenant; using Azure.Mcp.Tools.Postgres.Services; using Npgsql; using NSubstitute;