Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,14 @@ public class AndroidxSqliteDriver(
): QueryResult<Long> {
createOrMigrateIfNeeded()

fun SQLiteConnection.getTotalChangedRows() =
prepare("SELECT changes()").use { statement ->
when {
statement.step() -> statement.getLong(0)
else -> 0
}
}

val transaction = currentTransaction()
if(transaction == null) {
val writerConnection = connectionPool.acquireWriterConnection()
Expand All @@ -362,7 +370,10 @@ public class AndroidxSqliteDriver(
)
},
binders = binders,
result = { execute() },
result = {
execute()
writerConnection.getTotalChangedRows()
},
)
} finally {
connectionPool.releaseWriterConnection()
Expand All @@ -379,7 +390,10 @@ public class AndroidxSqliteDriver(
)
},
binders = binders,
result = { execute() },
result = {
execute()
connection.getTotalChangedRows()
},
)
}
}
Expand Down Expand Up @@ -564,7 +578,7 @@ private fun SQLiteConnection.reportForeignKeyViolations(
}

internal interface AndroidxStatement : SqlPreparedStatement {
fun execute(): Long
fun execute()
fun <R> executeQuery(mapper: (SqlCursor) -> QueryResult<R>): R
fun reset()
fun close()
Expand Down Expand Up @@ -601,12 +615,11 @@ private class AndroidxPreparedStatement(
override fun <R> executeQuery(mapper: (SqlCursor) -> QueryResult<R>): R =
throw UnsupportedOperationException()

override fun execute(): Long {
override fun execute() {
var cont = true
while(cont) {
cont = statement.step()
}
return statement.getColumnCount().toLong()
}

override fun toString() = sql
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -282,4 +282,50 @@ abstract class AndroidxSqliteDriverTest {
}
}
}

@Test
fun `row count is correctly returned after an insert`() {
val rowCount = driver.execute(null, "INSERT INTO test VALUES (?, ?)", 2) {
bindLong(0, 1)
bindString(1, "42")
}.value

assertEquals(1, rowCount)
}

@Test
fun `row count is correctly returned after an update`() {
val rowCount = driver.execute(null, "UPDATE test SET value = ?", 1) {
bindString(0, "42")
}.value

assertEquals(0, rowCount)

driver.execute(null, "INSERT INTO test VALUES (?, ?)", 2) {
bindLong(0, 1)
bindString(1, "41")
}

val rowCount2 = driver.execute(null, "UPDATE test SET value = ?", 1) {
bindString(0, "42")
}.value

assertEquals(1, rowCount2)
}

@Test
fun `row count is correctly returned after a delete`() {
val rowCount = driver.execute(null, "DELETE FROM test", 0).value

assertEquals(0, rowCount)

driver.execute(null, "INSERT INTO test VALUES (?, ?)", 2) {
bindLong(0, 1)
bindString(1, "41")
}

val rowCount2 = driver.execute(null, "DELETE FROM test", 0).value

assertEquals(1, rowCount2)
}
}