Skip to content

Commit

Permalink
Use underlying enum type when writing binary parameters. Fixes #1421
Browse files Browse the repository at this point in the history
  • Loading branch information
bgrainger committed Dec 18, 2023
1 parent 8748a72 commit 2d44a00
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 42 deletions.
94 changes: 52 additions & 42 deletions src/MySqlConnector/MySqlParameter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -637,116 +637,124 @@ internal void AppendBinary(ByteBufferWriter writer, StatementPreparerOptions opt
{
// stored in "null bitmap" only
}
else if (Value is string stringValue)
else
{
AppendBinary(writer, Value, options);
}
}

private void AppendBinary(ByteBufferWriter writer, object value, StatementPreparerOptions options)
{
if (value is string stringValue)
{
writer.WriteLengthEncodedString(stringValue);
}
else if (Value is char charValue)
else if (value is char charValue)
{
writer.WriteLengthEncodedString(charValue.ToString());
}
else if (Value is sbyte sbyteValue)
else if (value is sbyte sbyteValue)
{
writer.Write(unchecked((byte) sbyteValue));
}
else if (Value is byte byteValue)
else if (value is byte byteValue)
{
writer.Write(byteValue);
}
else if (Value is bool boolValue)
else if (value is bool boolValue)
{
writer.Write((byte) (boolValue ? 1 : 0));
}
else if (Value is short shortValue)
else if (value is short shortValue)
{
writer.Write(unchecked((ushort) shortValue));
}
else if (Value is ushort ushortValue)
else if (value is ushort ushortValue)
{
writer.Write(ushortValue);
}
else if (Value is int intValue)
else if (value is int intValue)
{
writer.Write(intValue);
}
else if (Value is uint uintValue)
else if (value is uint uintValue)
{
writer.Write(uintValue);
}
else if (Value is long longValue)
else if (value is long longValue)
{
writer.Write(unchecked((ulong) longValue));
}
else if (Value is ulong ulongValue)
else if (value is ulong ulongValue)
{
writer.Write(ulongValue);
}
else if (Value is byte[] byteArrayValue)
else if (value is byte[] byteArrayValue)
{
writer.WriteLengthEncodedInteger(unchecked((ulong) byteArrayValue.Length));
writer.Write(byteArrayValue);
}
else if (Value is ReadOnlyMemory<byte> readOnlyMemoryValue)
else if (value is ReadOnlyMemory<byte> readOnlyMemoryValue)
{
writer.WriteLengthEncodedInteger(unchecked((ulong) readOnlyMemoryValue.Length));
writer.Write(readOnlyMemoryValue.Span);
}
else if (Value is Memory<byte> memoryValue)
else if (value is Memory<byte> memoryValue)
{
writer.WriteLengthEncodedInteger(unchecked((ulong) memoryValue.Length));
writer.Write(memoryValue.Span);
}
else if (Value is ArraySegment<byte> arraySegmentValue)
else if (value is ArraySegment<byte> arraySegmentValue)
{
writer.WriteLengthEncodedInteger(unchecked((ulong) arraySegmentValue.Count));
writer.Write(arraySegmentValue);
}
else if (Value is MySqlGeometry geometry)
else if (value is MySqlGeometry geometry)
{
writer.WriteLengthEncodedInteger(unchecked((ulong) geometry.ValueSpan.Length));
writer.Write(geometry.ValueSpan);
}
else if (Value is MemoryStream memoryStream)
else if (value is MemoryStream memoryStream)
{
if (!memoryStream.TryGetBuffer(out var streamBuffer))
streamBuffer = new ArraySegment<byte>(memoryStream.ToArray());
writer.WriteLengthEncodedInteger(unchecked((ulong) streamBuffer.Count));
writer.Write(streamBuffer);
}
else if (Value is float floatValue)
else if (value is float floatValue)
{
writer.Write(BitConverter.GetBytes(floatValue));
}
else if (Value is double doubleValue)
else if (value is double doubleValue)
{
writer.Write(unchecked((ulong) BitConverter.DoubleToInt64Bits(doubleValue)));
}
else if (Value is decimal decimalValue)
else if (value is decimal decimalValue)
{
writer.WriteLengthEncodedAsciiString(decimalValue.ToString(CultureInfo.InvariantCulture));
}
else if (Value is BigInteger bigInteger)
else if (value is BigInteger bigInteger)
{
writer.WriteLengthEncodedAsciiString(bigInteger.ToString(CultureInfo.InvariantCulture));
}
else if (Value is MySqlDateTime mySqlDateTimeValue)
else if (value is MySqlDateTime mySqlDateTimeValue)
{
if (mySqlDateTimeValue.IsValidDateTime)
WriteDateTime(writer, mySqlDateTimeValue.GetDateTime());
else
writer.Write((byte) 0);
}
else if (Value is MySqlDecimal mySqlDecimal)
else if (value is MySqlDecimal mySqlDecimal)
{
writer.WriteLengthEncodedAsciiString(mySqlDecimal.ToString());
}
#if NET6_0_OR_GREATER
else if (Value is DateOnly dateOnlyValue)
else if (value is DateOnly dateOnlyValue)
{
WriteDateOnly(writer, dateOnlyValue);
}
#endif
else if (Value is DateTime dateTimeValue)
else if (value is DateTime dateTimeValue)
{
if ((options & StatementPreparerOptions.DateTimeUtc) != 0 && dateTimeValue.Kind == DateTimeKind.Local)
throw new MySqlException($"DateTime.Kind must not be Local when DateTimeKind setting is Utc (parameter name: {ParameterName})");
Expand All @@ -755,22 +763,22 @@ internal void AppendBinary(ByteBufferWriter writer, StatementPreparerOptions opt

WriteDateTime(writer, dateTimeValue);
}
else if (Value is DateTimeOffset dateTimeOffsetValue)
else if (value is DateTimeOffset dateTimeOffsetValue)
{
// store as UTC as it will be read as such when deserialized from a timespan column
WriteDateTime(writer, dateTimeOffsetValue.UtcDateTime);
}
#if NET6_0_OR_GREATER
else if (Value is TimeOnly timeOnlyValue)
else if (value is TimeOnly timeOnlyValue)
{
WriteTime(writer, timeOnlyValue.ToTimeSpan());
}
#endif
else if (Value is TimeSpan ts)
else if (value is TimeSpan ts)
{
WriteTime(writer, ts);
}
else if (Value is Guid guidValue)
else if (value is Guid guidValue)
{
StatementPreparerOptions guidOptions = options & StatementPreparerOptions.GuidFormatMask;
if (guidOptions is StatementPreparerOptions.GuidFormatBinary16 or StatementPreparerOptions.GuidFormatTimeSwapBinary16 or StatementPreparerOptions.GuidFormatLittleEndianBinary16)
Expand Down Expand Up @@ -817,53 +825,55 @@ internal void AppendBinary(ByteBufferWriter writer, StatementPreparerOptions opt
writer.Advance(guidLength);
}
}
else if (Value is ReadOnlyMemory<char> readOnlyMemoryChar)
else if (value is ReadOnlyMemory<char> readOnlyMemoryChar)
{
writer.WriteLengthEncodedString(readOnlyMemoryChar.Span);
}
else if (Value is Memory<char> memoryChar)
else if (value is Memory<char> memoryChar)
{
writer.WriteLengthEncodedString(memoryChar.Span);
}
else if (Value is StringBuilder stringBuilder)
else if (value is StringBuilder stringBuilder)
{
writer.WriteLengthEncodedString(stringBuilder);
}
else if ((MySqlDbType is MySqlDbType.String or MySqlDbType.VarChar) && HasSetDbType && Value is Enum stringEnumValue)
else if ((MySqlDbType is MySqlDbType.String or MySqlDbType.VarChar) && HasSetDbType && value is Enum stringEnumValue)
{
writer.WriteLengthEncodedString(stringEnumValue.ToString("G"));
}
else if (Value is Enum)
else if (value is Enum)
{
writer.Write(Convert.ToInt32(Value, CultureInfo.InvariantCulture));
// using the underlying type matches the log in TypeMapper.GetDbTypeMapping, which controls the column type value that was sent to the server
var underlyingValue = Convert.ChangeType(value, Enum.GetUnderlyingType(value.GetType()), CultureInfo.InvariantCulture);
AppendBinary(writer, underlyingValue, options);
}
else if (MySqlDbType == MySqlDbType.Int16)
{
writer.Write((ushort) (short) Value);
writer.Write((ushort) (short) value);
}
else if (MySqlDbType == MySqlDbType.UInt16)
{
writer.Write((ushort) Value);
writer.Write((ushort) value);
}
else if (MySqlDbType == MySqlDbType.Int32)
{
writer.Write((int) Value);
writer.Write((int) value);
}
else if (MySqlDbType == MySqlDbType.UInt32)
{
writer.Write((uint) Value);
writer.Write((uint) value);
}
else if (MySqlDbType == MySqlDbType.Int64)
{
writer.Write((ulong) (long) Value);
writer.Write((ulong) (long) value);
}
else if (MySqlDbType == MySqlDbType.UInt64)
{
writer.Write((ulong) Value);
writer.Write((ulong) value);
}
else
{
throw new NotSupportedException($"Parameter type {Value.GetType().Name} is not supported; see https://fl.vu/mysql-param-type. Value: {Value}");
throw new NotSupportedException($"Parameter type {value.GetType().Name} is not supported; see https://fl.vu/mysql-param-type. Value: {value}");
}
}

Expand Down
57 changes: 57 additions & 0 deletions tests/IntegrationTests/PreparedCommandTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,63 @@ public static IEnumerable<object[]> GetDifferentTypeInsertAndQueryData()
}
}

[Theory, MemberData(nameof(GetDifferentWidthData))]
public void InsertDifferentWidthParameters(string dataType1, object value1, string dataType2, object value2)
{
using var connection = CreateConnection();
using var command = connection.CreateCommand();
command.CommandText = $"""
drop table if exists parameter_width;
create table parameter_width(value1 {dataType1} not null, value2 {dataType2} not null);
""";
command.ExecuteNonQuery();

command.CommandText = "insert into parameter_width(value1, value2) values(@value1, @value2);";
command.Parameters.AddWithValue("@value1", value1);
command.Parameters.AddWithValue("@value2", value2);
command.Prepare();
command.ExecuteNonQuery();

using var queryCommand = connection.CreateCommand();
queryCommand.CommandText = "select value1, value2 from parameter_width;";
using var reader = queryCommand.ExecuteReader();
Assert.True(reader.Read());
Assert.Equal(Convert.ToInt32(value1), reader.GetInt32(0));
Assert.Equal(Convert.ToInt32(value2), reader.GetInt32(1));
Assert.False(reader.Read());
}

public static IEnumerable<object[]> GetDifferentWidthData()
{
var dataTypes = new string[] { "TINYINT", "SMALLINT", "INT" };
var values = new object[] { (sbyte) 100, (short) 110, 120, OneByteEnum.Value, TwoByteEnum.Value };

foreach (var dataType1 in dataTypes)
{
foreach (var value1 in values)
{
foreach (var dataType2 in dataTypes)
{
foreach (var value2 in values)
{
yield return new object[] { dataType1, value1, dataType2, value2 };
}
}
}
}
}

private enum OneByteEnum : sbyte
{
Value = 101,
}

private enum TwoByteEnum : short
{
Value = 111,
}


private static MySqlConnection CreateConnection()
{
var connection = new MySqlConnection(AppConfig.ConnectionString);
Expand Down

0 comments on commit 2d44a00

Please sign in to comment.