Skip to content

Commit

Permalink
Microsoft.Data.Sqlite: Fix handling of queries with RETURNING clause (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
bricelam committed Jul 6, 2023
1 parent 9e49303 commit 71fe707
Show file tree
Hide file tree
Showing 4 changed files with 245 additions and 15 deletions.
25 changes: 15 additions & 10 deletions src/Microsoft.Data.Sqlite.Core/SqliteDataReader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ public override bool NextResult()
// It's a SELECT statement
if (sqlite3_column_count(stmt) != 0)
{
_record = new SqliteDataRecord(stmt, rc != SQLITE_DONE, _command.Connection);
_record = new SqliteDataRecord(stmt, rc != SQLITE_DONE, _command.Connection, AddChanges);

return true;
}
Expand All @@ -191,14 +191,7 @@ public override bool NextResult()
sqlite3_reset(stmt);

var changes = sqlite3_changes(_command.Connection.Handle);
if (_recordsAffected == -1)
{
_recordsAffected = changes;
}
else
{
_recordsAffected += changes;
}
AddChanges(changes);
}
catch
{
Expand All @@ -219,6 +212,18 @@ private static bool IsBusy(int rc)
|| rc == SQLITE_BUSY
|| rc == SQLITE_LOCKED_SHAREDCACHE;

private void AddChanges(int changes)
{
if (_recordsAffected == -1)
{
_recordsAffected = changes;
}
else
{
_recordsAffected += changes;
}
}

/// <summary>
/// Closes the data reader.
/// </summary>
Expand All @@ -242,14 +247,14 @@ protected override void Dispose(bool disposing)
_command.DataReader = null;

_record?.Dispose();
_record = null;

if (_stmtEnumerator != null)
{
try
{
while (NextResult())
{
_record!.Dispose();
}
}
catch
Expand Down
54 changes: 49 additions & 5 deletions src/Microsoft.Data.Sqlite.Core/SqliteDataRecord.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,22 @@ namespace Microsoft.Data.Sqlite
internal class SqliteDataRecord : SqliteValueReader, IDisposable
{
private readonly SqliteConnection _connection;
private readonly Action<int> _addChanges;
private byte[][]? _blobCache;
private int?[]? _typeCache;
private Dictionary<string, int>? _columnNameOrdinalCache;
private string[]? _columnNameCache;
private bool _stepped;
private int? _rowidOrdinal;
private bool _alreadyThrown;
private bool _alreadyAddedChanges;

public SqliteDataRecord(sqlite3_stmt stmt, bool hasRows, SqliteConnection connection)
public SqliteDataRecord(sqlite3_stmt stmt, bool hasRows, SqliteConnection connection, Action<int> addChanges)
{
Handle = stmt;
HasRows = hasRows;
_connection = connection;
_addChanges = addChanges;
}

public virtual object this[string name]
Expand Down Expand Up @@ -397,19 +401,59 @@ public bool Read()
return false;
}

var rc = sqlite3_step(Handle);
SqliteException.ThrowExceptionForRC(rc, _connection.Handle);
int rc;
try
{
rc = sqlite3_step(Handle);
SqliteException.ThrowExceptionForRC(rc, _connection.Handle);
}
catch
{
_alreadyThrown = true;

throw;
}

if (_blobCache != null)
{
Array.Clear(_blobCache, 0, _blobCache.Length);
}

return rc != SQLITE_DONE;
if (rc != SQLITE_DONE)
{
return true;
}

AddChanges();
_alreadyAddedChanges = true;

return false;
}

public void Dispose()
=> sqlite3_reset(Handle);
{
var rc = sqlite3_reset(Handle);
if (!_alreadyThrown)
{
SqliteException.ThrowExceptionForRC(rc, _connection.Handle);
}

if (!_alreadyAddedChanges)
{
AddChanges();
}
}

private void AddChanges()
{
if (sqlite3_stmt_readonly(Handle) != 0)
{
return;
}

var changes = sqlite3_changes(_connection.Handle);
_addChanges(changes);
}

private byte[] GetCachedBlob(int ordinal)
{
Expand Down
114 changes: 114 additions & 0 deletions test/Microsoft.Data.Sqlite.Tests/SqliteCommandTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -931,6 +931,120 @@ public async Task ExecuteReader_retries_when_busy()
}
}

[Fact]
public Task ExecuteScalar_throws_when_busy_with_returning()
=> Execute_throws_when_busy_with_returning(command =>
{
var ex = Assert.Throws<SqliteException>(
() => command.ExecuteScalar());
Assert.Equal(SQLITE_BUSY, ex.SqliteErrorCode);
});

[Fact]
public Task ExecuteNonQuery_throws_when_busy_with_returning()
=> Execute_throws_when_busy_with_returning(command =>
{
var ex = Assert.Throws<SqliteException>(
() => command.ExecuteNonQuery());
Assert.Equal(SQLITE_BUSY, ex.SqliteErrorCode);
});

[Fact]
public Task ExecuteReader_throws_when_busy_with_returning()
=> Execute_throws_when_busy_with_returning(command =>
{
var reader = command.ExecuteReader();
try
{
Assert.True(reader.Read());
Assert.Equal(2L, reader.GetInt64(0));
}
finally
{
var ex = Assert.Throws<SqliteException>(
() => reader.Dispose());
Assert.Equal(SQLITE_BUSY, ex.SqliteErrorCode);
}
});

[Fact]
public Task ExecuteReader_throws_when_busy_with_returning_while_draining()
=> Execute_throws_when_busy_with_returning(command =>
{
using var reader = command.ExecuteReader();
Assert.True(reader.Read());
Assert.Equal(2L, reader.GetInt64(0));
Assert.True(reader.Read());
Assert.Equal(3L, reader.GetInt64(0));
var ex = Assert.Throws<SqliteException>(
() => reader.Read());
Assert.Equal(SQLITE_BUSY, ex.SqliteErrorCode);
});

private static async Task Execute_throws_when_busy_with_returning(Action<SqliteCommand> action)
{
const string connectionString = "Data Source=returning.db";

var selectedSignal = new AutoResetEvent(initialState: false);

try
{
using var connection1 = new SqliteConnection(connectionString);

if (new Version(connection1.ServerVersion) < new Version(3, 35, 0))
{
// Skip. RETURNING clause not supported
return;
}

connection1.Open();

connection1.ExecuteNonQuery(
"CREATE TABLE Data (Value); INSERT INTO Data VALUES (0);");

await Task.WhenAll(
Task.Run(
async () =>
{
using var connection = new SqliteConnection(connectionString);
connection.Open();
using (connection.ExecuteReader("SELECT * FROM Data;"))
{
selectedSignal.Set();
await Task.Delay(1000);
}
}),
Task.Run(
() =>
{
using var connection = new SqliteConnection(connectionString);
connection.Open();
selectedSignal.WaitOne();
var command = connection.CreateCommand();
command.CommandText = "INSERT INTO Data VALUES (1),(2) RETURNING rowid;";
action(command);
}));

var count = connection1.ExecuteScalar<long>("SELECT COUNT(*) FROM Data;");
Assert.Equal(1L, count);
}
finally
{
SqliteConnection.ClearPool(new SqliteConnection(connectionString));
File.Delete("returning.db");
}
}

[Fact]
public void ExecuteReader_honors_CommandTimeout()
{
Expand Down
67 changes: 67 additions & 0 deletions test/Microsoft.Data.Sqlite.Tests/SqliteDataReaderTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1851,6 +1851,73 @@ public void RecordsAffected_works_during_enumeration()
}
}

[Fact]
public void RecordsAffected_works_with_returning()
{
using (var connection = new SqliteConnection("Data Source=:memory:"))
{
if (new Version(connection.ServerVersion) < new Version(3, 35, 0))
{
// Skip. RETURNING clause not supported
return;
}

connection.Open();
connection.ExecuteNonQuery("CREATE TABLE Test(Value);");

var reader = connection.ExecuteReader("INSERT INTO Test VALUES(1) RETURNING rowid;");
((IDisposable)reader).Dispose();

Assert.Equal(1, reader.RecordsAffected);
}
}

[Fact]
public void RecordsAffected_works_with_returning_before_dispose_after_draining()
{
using (var connection = new SqliteConnection("Data Source=:memory:"))
{
if (new Version(connection.ServerVersion) < new Version(3, 35, 0))
{
// Skip. RETURNING clause not supported
return;
}

connection.Open();
connection.ExecuteNonQuery("CREATE TABLE Test(Value);");

using (var reader = connection.ExecuteReader("INSERT INTO Test VALUES(1) RETURNING rowid;"))
{
while (reader.Read())
{
}

Assert.Equal(1, reader.RecordsAffected);
}
}
}

[Fact]
public void RecordsAffected_works_with_returning_multiple()
{
using (var connection = new SqliteConnection("Data Source=:memory:"))
{
if (new Version(connection.ServerVersion) < new Version(3, 35, 0))
{
// Skip. RETURNING clause not supported
return;
}

connection.Open();
connection.ExecuteNonQuery("CREATE TABLE Test(Value);");

var reader = connection.ExecuteReader("INSERT INTO Test VALUES(1),(2) RETURNING rowid;");
((IDisposable)reader).Dispose();

Assert.Equal(2, reader.RecordsAffected);
}
}

[Fact]
public void GetSchemaTable_works()
{
Expand Down

0 comments on commit 71fe707

Please sign in to comment.