Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion beam-sqlite/ChangeLog.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
# Unreleased

## Added features

* `runInsertReturningList` now uses SQLite's relatively new `RETURNING` clause.

## Bux fixes

* Fixed an issue where values inserted with conflicts did not return then when using `runInsertReturningList` (#774)

## Updated dependencies

* Updated the upper bound on `time` to include `time-1.14`
* Updated the lower bound of `direct-sqlite` to `2.3.27`.
* Updated the upper bound on `time` to include `time-1.14`.

# 0.5.4.1

Expand Down
131 changes: 67 additions & 64 deletions beam-sqlite/Database/Beam/Sqlite/Connection.hs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ module Database.Beam.Sqlite.Connection

, runBeamSqlite, runBeamSqliteDebug

-- * Emulated @INSERT RETURNING@ support
, insertReturning, runInsertReturningList
) where

Expand Down Expand Up @@ -39,22 +38,21 @@ import Database.Beam.Schema.Tables ( Beamable
, DatabaseEntityDescriptor(..)
, TableEntity
, TableField(..)
, allBeamValues
, changeBeamRep )
import Database.Beam.Sqlite.Syntax

import Database.SQLite.Simple ( Connection, ToRow(..), FromRow(..)
, Query(..), SQLData(..), field
, execute, execute_
, SQLData(..), field
, execute
, withStatement, bind, nextRow
, query_, open, close )
, open, close )
import Database.SQLite.Simple.FromField ( FromField(..), ResultError(..)
, returnError, fieldData)
import Database.SQLite.Simple.Internal (RowParser(RP), unRP)
import Database.SQLite.Simple.Ok (Ok(..))
import Database.SQLite.Simple.Types (Null)

import Control.Exception (SomeException(..), bracket_, onException, mask)
import Control.Exception (SomeException(..))
import Control.Monad (forM_)
import Control.Monad.Base (MonadBase)
import Control.Monad.Fail (MonadFail(..))
Expand All @@ -71,20 +69,19 @@ import Data.ByteString.Builder (toLazyByteString)
import qualified Data.ByteString.Char8 as BS
import qualified Data.ByteString.Lazy.Char8 as BL
import qualified Data.DList as D
import Data.Hashable (hash)
import Data.Int
import Data.IORef (newIORef, atomicModifyIORef')
import Data.Maybe (mapMaybe)
import Data.Proxy (Proxy(..))
import Data.Scientific (Scientific)
import Data.String (fromString)
import qualified Data.Text as T
import qualified Data.Text.Encoding as T (decodeUtf8)
import qualified Data.Text.Lazy as TL
import qualified Data.Text.Lazy.Encoding as TL (decodeUtf8)
import Data.Time ( LocalTime, UTCTime, Day
, ZonedTime, utc, utcToLocalTime, getCurrentTime )
, ZonedTime, utc, utcToLocalTime )
import Data.Typeable (cast)
import Data.Word
import GHC.IORef (atomicModifyIORef'_)
import GHC.TypeLits

import Network.URI
Expand Down Expand Up @@ -335,11 +332,64 @@ instance MonadBeam Sqlite SqliteM where
Nothing -> pure Nothing
Just (BeamSqliteRow row) -> pure row
runReaderT (runSqliteM (action nextRow')) (logger, conn)
runReturningMany SqliteCommandInsert {} _ =
fail . mconcat $
[ "runReturningMany{Sqlite}: sqlite does not support returning "
, "rows from an insert, use Database.Beam.Sqlite.insertReturning "
, "for emulation" ]
runReturningMany (SqliteCommandInsert (SqliteInsertSyntax tbl fields vs onConflict)) action
| SqliteInsertExpressions es <- vs, any (any (== SqliteExpressionDefault)) es =
-- SQLite's handling of default values differs from other DBMses because
-- it lacks support for DEFAULT. In order to insert a default value in a column,
-- the column's name should be omitted from the INSERT statement.
--
-- This is problematic if you insert multiple rows, some of which have defaults;
-- you must use multiple INSERT statements. This is what we do below.
--
-- However, to respect the 'runReturningMany' interface, be must accumulate the
-- results of all those inserts into an 'IORef [a]', and then feed the results
-- incrementally to 'action'.
SqliteM $ do
(logger, conn) <- ask
resultsRef <- liftIO (newIORef [])
forM_ es $ \row -> do
-- RETURNING is only supported by SQLite 3.35+, which requires direct-sqlite 2.3.27+
let returningClause = emit " RETURNING " <> commas (map quotedIdentifier fields)
(insertFields, insertRow) = unzip $ filter ((/= SqliteExpressionDefault) . snd) $ zip fields row
SqliteSyntax cmd vals = formatSqliteInsertOnConflict tbl insertFields (SqliteInsertExpressions [ insertRow ]) onConflict <> returningClause
cmdString = BL.unpack (toLazyByteString (withPlaceholders cmd))
Comment on lines +350 to +355
Copy link

@sheaf sheaf Oct 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only a minor comment (that probably should be punted to a separate ticket), but: would it be possible to instead classify the rows by which fields have DEFAULT values? For example, if all the rows to be inserted have DEFAULT in the same position, then we can still insert all the rows at once (although we still need to filter out the DEFAULT fields).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes that's very true.
I considered doing something like it, but I started with this simpler implementation instead because the order of insertion is preserved this way. If you grouped rows by the positions of their default values, then you would have to reconstruct the return order.
This isn't impossible, but this is significantly more complex. As you mention, this could be punted to a new ticket.

Actually, I see you have made a prototype implementation for this below; would you like to submit a PR?


liftIO $ do
logger (cmdString ++ ";\n-- With values: " ++ show (D.toList vals))
withStatement conn (fromString cmdString) $ \stmt ->
do bind stmt (BeamSqliteParams (D.toList vals))
unfoldM (nextRow stmt) >>= \new -> atomicModifyIORef'_ resultsRef (new ++)

-- We must reverse the list in the IORef because it has been constructed in reverse
-- order. We construct the list in reverse because it's faster to prepend to
-- a linked list
_ <- liftIO (atomicModifyIORef'_ resultsRef reverse)
let nextRow' = liftIO $ do
atomicModifyIORef' resultsRef $ \results -> case results of
(BeamSqliteRow h:rest) -> (rest, Just h)
[] -> ([], Nothing)
runSqliteM (action nextRow')
| otherwise =
SqliteM $ do
(logger, conn) <- ask
let returningClause = emit " RETURNING " <> commas (map quotedIdentifier fields)
SqliteSyntax cmd vals = formatSqliteInsertOnConflict tbl fields vs onConflict <> returningClause
cmdString = BL.unpack (toLazyByteString (withPlaceholders cmd))
liftIO $ do
logger (cmdString ++ ";\n-- With values: " ++ show (D.toList vals))
withStatement conn (fromString cmdString) $ \stmt ->
do bind stmt (BeamSqliteParams (D.toList vals))
let nextRow' = liftIO (nextRow stmt) >>= \x ->
case x of
Nothing -> pure Nothing
Just (BeamSqliteRow row) -> pure row
runReaderT (runSqliteM (action nextRow')) (logger, conn)


unfoldM :: Monad m => m (Maybe a) -> m [a]
unfoldM f = go []
where
go acc = f >>= maybe (pure acc) (\x -> go (x : acc))

instance Beam.MonadBeamInsertReturning Sqlite SqliteM where
runInsertReturningList = runInsertReturningList
Expand All @@ -361,7 +411,7 @@ runSqliteInsert logger conn (SqliteInsertSyntax tbl fields vs onConflict)
logger (cmdString ++ ";\n-- With values: " ++ show (D.toList vals))
execute conn (fromString cmdString) (D.toList vals)

-- * emulated INSERT returning support
-- * INSERT returning support

-- | Build a 'SqliteInsertReturning' representing inserting the given values
-- into the given table. Use 'runInsertReturningList'
Expand All @@ -377,54 +427,7 @@ runInsertReturningList :: (Beamable table, FromBackendRow Sqlite (table Identity
=> SqlInsert Sqlite table
-> SqliteM [ table Identity ]
runInsertReturningList SqlInsertNoRows = pure []
runInsertReturningList (SqlInsert tblSettings insertStmt_@(SqliteInsertSyntax nm _ _ _)) =
do (logger, conn) <- SqliteM ask
SqliteM . liftIO $ do

-- We create a pseudo-random savepoint identification that can be referenced
-- throughout this operation. -- This used to be based on the process ID
-- (e.g. `System.Posix.Process.getProcessID` for UNIX),
-- but using timestamps is more portable; see #738
--
-- Note that `hash` can return negative numbers, hence the use of `abs`.
savepointId <- fromString . show . abs . hash <$> getCurrentTime

let tableNameTxt = T.decodeUtf8 (BL.toStrict (sqliteRenderSyntaxScript (fromSqliteTableName nm)))

startSavepoint =
execute_ conn (Query ("SAVEPOINT insert_savepoint_" <> savepointId))
rollbackToSavepoint =
execute_ conn (Query ("ROLLBACK TRANSACTION TO SAVEPOINT insert_savepoint_" <> savepointId))
releaseSavepoint =
execute_ conn (Query ("RELEASE SAVEPOINT insert_savepoint_" <> savepointId))

createInsertedValuesTable =
execute_ conn (Query ("CREATE TEMPORARY TABLE inserted_values_" <> savepointId <> " AS SELECT * FROM " <> tableNameTxt <> " LIMIT 0"))
dropInsertedValuesTable =
execute_ conn (Query ("DROP TABLE inserted_values_" <> savepointId))

createInsertTrigger =
execute_ conn (Query ("CREATE TEMPORARY TRIGGER insert_trigger_" <> savepointId <> " AFTER INSERT ON " <> tableNameTxt <> " BEGIN " <>
"INSERT INTO inserted_values_" <> savepointId <> " SELECT * FROM " <> tableNameTxt <> " WHERE ROWID=last_insert_rowid(); END" ))
dropInsertTrigger =
execute_ conn (Query ("DROP TRIGGER insert_trigger_" <> savepointId))


mask $ \restore -> do
startSavepoint
flip onException rollbackToSavepoint . restore $ do
x <- bracket_ createInsertedValuesTable dropInsertedValuesTable $
bracket_ createInsertTrigger dropInsertTrigger $ do
runSqliteInsert logger conn insertStmt_

let columns = TL.toStrict $ TL.decodeUtf8 $
sqliteRenderSyntaxScript $ commas $
allBeamValues (\(Columnar' projField) -> quotedIdentifier (_fieldName projField)) $
tblSettings

fmap (\(BeamSqliteRow r) -> r) <$> query_ conn (Query ("SELECT " <> columns <> " FROM inserted_values_" <> savepointId))
releaseSavepoint
return x
runInsertReturningList (SqlInsert _ insertCommand) = runReturningList $ SqliteCommandInsert insertCommand

instance Beam.BeamHasInsertOnConflict Sqlite where
newtype SqlConflictTarget Sqlite table = SqliteConflictTarget
Expand Down
5 changes: 4 additions & 1 deletion beam-sqlite/beam-sqlite.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ library
aeson >=0.11 && <2.3,
attoparsec >=0.13 && <0.15,
transformers-base >=0.4 && <0.5,
direct-sqlite >=2.3.24
-- Minimum version of direct-sqlite that includes a version
-- of SQLite that supports the RETURNING construct
direct-sqlite >=2.3.27
default-language: Haskell2010
default-extensions: ScopedTypeVariables, OverloadedStrings, MultiParamTypeClasses, RankNTypes, FlexibleInstances,
DeriveDataTypeable, DeriveGeneric, StandaloneDeriving, TypeFamilies, GADTs, OverloadedStrings,
Expand Down Expand Up @@ -72,6 +74,7 @@ test-suite beam-sqlite-tests
other-modules:
Database.Beam.Sqlite.Test
Database.Beam.Sqlite.Test.Insert
Database.Beam.Sqlite.Test.InsertOnConflictReturning
Database.Beam.Sqlite.Test.Migrate
Database.Beam.Sqlite.Test.Select
default-language: Haskell2010
Expand Down
2 changes: 1 addition & 1 deletion beam-sqlite/test/Database/Beam/Sqlite/Test/Insert.hs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ testInsertReturningColumnOrder = testCase "runInsertReturningList with mismatchi
, TestTable 1 "sally" "apple" ((val_ 56 + val_ 109) `div_` 5) currentTimestamp_ (val_ oneUtcTime)
, TestTable 4 "blah" "blah" (-1) currentTimestamp_ (val_ now) ]

let dateJoined = ttDateJoined (head inserted)
let dateJoined : _ = ttDateJoined <$> inserted

expected = [ TestTable 0 "jim" "smith" 19 dateJoined zeroUtcTime
, TestTable 1 "sally" "apple" 33 dateJoined oneUtcTime
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
{-# LANGUAGE DerivingStrategies #-}

{-
This module is based on the bug reproducer from issue #773
https://github.com/haskell-beam/beam/issues/773
-}
module Database.Beam.Sqlite.Test.InsertOnConflictReturning (tests) where

import Data.Int (Int32)
import Data.Text (Text)
import Database.Beam (
Beamable,
Columnar,
Database,
DatabaseSettings,
Generic,
Identity,
SqlEq ((/=.)),
Table (..),
TableEntity,
current_,
defaultDbSettings,
insert,
insertValues,
runInsert,
(<-.),
)
import Database.Beam.Backend.SQL.BeamExtensions (
BeamHasInsertOnConflict (
conflictingFields,
insertOnConflict,
onConflictUpdateSetWhere
),
MonadBeamInsertReturning (runInsertReturningList),
)
import Database.Beam.Sqlite (Sqlite, runBeamSqlite)
import Database.Beam.Sqlite.Test (withTestDb)
import Database.SQLite.Simple (execute_)
import Test.Tasty (TestTree, testGroup)
import Test.Tasty.HUnit (testCase, (@?=))

import Database.Beam
import Database.Beam.Backend.SQL.BeamExtensions (
conflictingFields,
insertOnConflict,
onConflictUpdateSetWhere,
runInsertReturningList,
)
import Database.Beam.Migrate (defaultMigratableDbSettings)
import Database.Beam.Migrate.Simple (CheckedDatabaseSettings, autoMigrate)
import Database.Beam.Sqlite (Sqlite, runBeamSqliteDebug)
import Database.Beam.Sqlite.Migrate (migrationBackend)
import Database.SQLite.Simple (open)

tests :: TestTree
tests =
testGroup
"Insertion on conflict returning tests"
[testInsertOnConflictReturning]

data TestDb f
= TestDb
{ usersTable :: f (TableEntity User)
}
deriving stock (Generic)

deriving anyclass instance Database be TestDb

testDb :: DatabaseSettings be TestDb
testDb = defaultDbSettings

checkedDb :: CheckedDatabaseSettings Sqlite TestDb
checkedDb = defaultMigratableDbSettings

data User f
= User
{ userId :: Columnar f Int32
, userName :: Columnar f Text
}
deriving stock (Generic)

deriving stock instance Show (PrimaryKey User Identity)
deriving stock instance Show (User Identity)

instance Table User where
newtype PrimaryKey User f = UserId (Columnar f Int32)
deriving stock (Generic)
primaryKey = UserId . userId

deriving anyclass instance Beamable User
deriving anyclass instance Beamable (PrimaryKey User)

testInsertOnConflictReturning :: TestTree
testInsertOnConflictReturning = testCase "Check that conflicting values are returned by `runInsertReturningList`" $
withTestDb $ \conn -> do
conflicts <-
runBeamSqlite conn $ do
autoMigrate migrationBackend checkedDb

runInsert $
insert
(usersTable testDb)
( insertValues
[ User{userId = 0, userName = "user0"}
, User{userId = 2, userName = "user2"}
, User{userId = 5, userName = "user5"}
]
)

let newUsers =
[ User{userId = 1, userName = "user1"}
, User{userId = 2, userName = "different_user2"}
]

runInsertReturningList $
insertOnConflict
(usersTable testDb)
(insertValues newUsers)
(conflictingFields userId)
( onConflictUpdateSetWhere
( \(User{userName = fld})
(User{userName = excl}) ->
fld <-. excl
)
( \(User{userName = fld})
(User{userName = excl}) ->
current_ fld /=. excl
)
)

-- Expecting that the conflicting user, User id 2, is also returned
userId <$> conflicts @?= [1, 2]
10 changes: 6 additions & 4 deletions beam-sqlite/test/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ import Test.Tasty

import qualified Database.Beam.Sqlite.Test.Migrate as Migrate
import qualified Database.Beam.Sqlite.Test.Insert as Insert
import qualified Database.Beam.Sqlite.Test.InsertOnConflictReturning as InsertOnConflictReturning
import qualified Database.Beam.Sqlite.Test.Select as Select

main :: IO ()
main = defaultMain $ testGroup "beam-sqlite tests"
[ Migrate.tests
, Select.tests
, Insert.tests
]
[ Migrate.tests
, Select.tests
, Insert.tests
, InsertOnConflictReturning.tests
]
Loading