diff --git a/doc/samples/AzureKeyVaultProviderExample_2_0.cs b/doc/samples/AzureKeyVaultProviderExample_2_0.cs index 95733310a8..f241966458 100644 --- a/doc/samples/AzureKeyVaultProviderExample_2_0.cs +++ b/doc/samples/AzureKeyVaultProviderExample_2_0.cs @@ -242,7 +242,5 @@ public CustomerRecord(int id, string fName, string lName) } } } - - // - } +// diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs index 0b6cc4abf8..16d707f3fd 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs @@ -5837,11 +5837,22 @@ private SqlParameter BuildStoredProcedureStatementForColumnEncryption(string sto { Debug.Assert(CommandType == CommandType.StoredProcedure, "BuildStoredProcedureStatementForColumnEncryption() should only be called for stored procedures"); Debug.Assert(!string.IsNullOrWhiteSpace(storedProcedureName), "storedProcedureName cannot be null or empty in BuildStoredProcedureStatementForColumnEncryption"); - Debug.Assert(parameters != null, "parameters cannot be null in BuildStoredProcedureStatementForColumnEncryption"); StringBuilder execStatement = new StringBuilder(); execStatement.Append(@"EXEC "); + if (parameters is null) + { + execStatement.Append(ParseAndQuoteIdentifier(storedProcedureName, false)); + return new SqlParameter( + null, + ((execStatement.Length << 1) <= TdsEnums.TYPE_SIZE_LIMIT) ? SqlDbType.NVarChar : SqlDbType.NText, + execStatement.Length) + { + Value = execStatement.ToString() + }; + } + // Find the return value parameter (if any). SqlParameter returnValueParameter = null; foreach (SqlParameter parameter in parameters) diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCommand.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCommand.cs index 0d773dfd82..a4de7ead2b 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCommand.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCommand.cs @@ -6786,11 +6786,22 @@ private SqlParameter BuildStoredProcedureStatementForColumnEncryption(string sto { Debug.Assert(CommandType == CommandType.StoredProcedure, "BuildStoredProcedureStatementForColumnEncryption() should only be called for stored procedures"); Debug.Assert(!string.IsNullOrWhiteSpace(storedProcedureName), "storedProcedureName cannot be null or empty in BuildStoredProcedureStatementForColumnEncryption"); - Debug.Assert(parameters != null, "parameters cannot be null in BuildStoredProcedureStatementForColumnEncryption"); StringBuilder execStatement = new StringBuilder(); execStatement.Append(@"EXEC "); + if (parameters is null) + { + execStatement.Append(ParseAndQuoteIdentifier(storedProcedureName, false)); + return new SqlParameter( + null, + ((execStatement.Length << 1) <= TdsEnums.TYPE_SIZE_LIMIT) ? SqlDbType.NVarChar : SqlDbType.NText, + execStatement.Length) + { + Value = execStatement.ToString() + }; + } + // Find the return value parameter (if any). SqlParameter returnValueParameter = null; foreach (SqlParameter parameter in parameters) diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/AlwaysEncrypted/ApiShould.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/AlwaysEncrypted/ApiShould.cs index b92a642ae8..d27b985864 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/AlwaysEncrypted/ApiShould.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/AlwaysEncrypted/ApiShould.cs @@ -878,6 +878,53 @@ public void TestExecuteReaderWithCommandBehavior(string connection, CommandBehav }); } + [ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringSetupForAE))] + [ClassData(typeof(AEConnectionStringProvider))] + public void TestEnclaveStoredProceduresWithAndWithoutParameters(string connectionString) + { + using SqlConnection sqlConnection = new(connectionString); + sqlConnection.Open(); + + using SqlCommand sqlCommand = new("", sqlConnection, transaction: null, + columnEncryptionSetting: SqlCommandColumnEncryptionSetting.Enabled); + + string procWithoutParams = DataTestUtility.GetUniqueName("EnclaveWithoutParams", withBracket: false); + string procWithParam = DataTestUtility.GetUniqueName("EnclaveWithParams", withBracket: false); + + try + { + sqlCommand.CommandText = $"CREATE PROCEDURE {procWithoutParams} AS SELECT FirstName, LastName FROM [{_tableName}];"; + sqlCommand.ExecuteNonQuery(); + sqlCommand.CommandText = $"CREATE PROCEDURE {procWithParam} @id INT AS SELECT FirstName, LastName FROM [{_tableName}] WHERE CustomerId = @id"; + sqlCommand.ExecuteNonQuery(); + int expectedFields = 2; + + sqlCommand.CommandText = procWithoutParams; + sqlCommand.CommandType = CommandType.StoredProcedure; + using (SqlDataReader reader = sqlCommand.ExecuteReader()) + { + Assert.Equal(expectedFields, reader.VisibleFieldCount); + } + + sqlCommand.CommandText = procWithParam; + sqlCommand.CommandType = CommandType.StoredProcedure; + Exception ex = Assert.Throws(() => sqlCommand.ExecuteReader()); + string expectedMsg = $"Procedure or function '{procWithParam}' expects parameter '@id', which was not supplied."; + + Assert.Equal(expectedMsg, ex.Message); + + sqlCommand.Parameters.AddWithValue("@id", 0); + using (SqlDataReader reader = sqlCommand.ExecuteReader()) + { + Assert.Equal(expectedFields, reader.VisibleFieldCount); + } + } + finally + { + DropHelperProcedures(new[] { procWithoutParams, procWithParam }, connectionString); + } + } + [ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringSetupForAE))] [ClassData(typeof(AEConnectionStringProvider))] public void TestPrepareWithExecuteNonQuery(string connection) @@ -2262,7 +2309,8 @@ public void TestSystemProvidersHavePrecedenceOverInstanceLevelProviders(string c connection.Open(); using SqlCommand command = CreateCommandThatRequiresSystemKeyStoreProvider(connection); connection.RegisterColumnEncryptionKeyStoreProvidersOnConnection(customKeyStoreProviders); - command.ExecuteReader(); + SqlDataReader reader = command.ExecuteReader(); + Assert.Equal(3, reader.VisibleFieldCount); } using (SqlConnection connection = new(connectionString)) @@ -2270,7 +2318,8 @@ public void TestSystemProvidersHavePrecedenceOverInstanceLevelProviders(string c connection.Open(); using SqlCommand command = CreateCommandThatRequiresSystemKeyStoreProvider(connection); command.RegisterColumnEncryptionKeyStoreProvidersOnCommand(customKeyStoreProviders); - command.ExecuteReader(); + SqlDataReader reader = command.ExecuteReader(); + Assert.Equal(3, reader.VisibleFieldCount); } }