diff --git a/beam-sqlite/ChangeLog.md b/beam-sqlite/ChangeLog.md index 8e27e8e4..aa2cc239 100644 --- a/beam-sqlite/ChangeLog.md +++ b/beam-sqlite/ChangeLog.md @@ -1,8 +1,17 @@ # Unreleased +## Added features + +* `runInsertReturningList` now uses SQLite's relatively new `RETURNING` clause. + +## Bux fixes + +* Fixed an issue where values inserted with conflicts did not return then when using `runInsertReturningList` (#774) + ## Updated dependencies -* Updated the upper bound on `time` to include `time-1.14` +* Updated the lower bound of `direct-sqlite` to `2.3.27`. +* Updated the upper bound on `time` to include `time-1.14`. # 0.5.4.1 diff --git a/beam-sqlite/Database/Beam/Sqlite/Connection.hs b/beam-sqlite/Database/Beam/Sqlite/Connection.hs index 2307fc5a..be1f0ab5 100644 --- a/beam-sqlite/Database/Beam/Sqlite/Connection.hs +++ b/beam-sqlite/Database/Beam/Sqlite/Connection.hs @@ -11,7 +11,6 @@ module Database.Beam.Sqlite.Connection , runBeamSqlite, runBeamSqliteDebug - -- * Emulated @INSERT RETURNING@ support , insertReturning, runInsertReturningList ) where @@ -39,22 +38,21 @@ import Database.Beam.Schema.Tables ( Beamable , DatabaseEntityDescriptor(..) , TableEntity , TableField(..) - , allBeamValues , changeBeamRep ) import Database.Beam.Sqlite.Syntax import Database.SQLite.Simple ( Connection, ToRow(..), FromRow(..) - , Query(..), SQLData(..), field - , execute, execute_ + , SQLData(..), field + , execute , withStatement, bind, nextRow - , query_, open, close ) + , open, close ) import Database.SQLite.Simple.FromField ( FromField(..), ResultError(..) , returnError, fieldData) import Database.SQLite.Simple.Internal (RowParser(RP), unRP) import Database.SQLite.Simple.Ok (Ok(..)) import Database.SQLite.Simple.Types (Null) -import Control.Exception (SomeException(..), bracket_, onException, mask) +import Control.Exception (SomeException(..)) import Control.Monad (forM_) import Control.Monad.Base (MonadBase) import Control.Monad.Fail (MonadFail(..)) @@ -71,20 +69,19 @@ import Data.ByteString.Builder (toLazyByteString) import qualified Data.ByteString.Char8 as BS import qualified Data.ByteString.Lazy.Char8 as BL import qualified Data.DList as D -import Data.Hashable (hash) import Data.Int +import Data.IORef (newIORef, atomicModifyIORef') import Data.Maybe (mapMaybe) import Data.Proxy (Proxy(..)) import Data.Scientific (Scientific) import Data.String (fromString) import qualified Data.Text as T -import qualified Data.Text.Encoding as T (decodeUtf8) import qualified Data.Text.Lazy as TL -import qualified Data.Text.Lazy.Encoding as TL (decodeUtf8) import Data.Time ( LocalTime, UTCTime, Day - , ZonedTime, utc, utcToLocalTime, getCurrentTime ) + , ZonedTime, utc, utcToLocalTime ) import Data.Typeable (cast) import Data.Word +import GHC.IORef (atomicModifyIORef'_) import GHC.TypeLits import Network.URI @@ -335,11 +332,64 @@ instance MonadBeam Sqlite SqliteM where Nothing -> pure Nothing Just (BeamSqliteRow row) -> pure row runReaderT (runSqliteM (action nextRow')) (logger, conn) - runReturningMany SqliteCommandInsert {} _ = - fail . mconcat $ - [ "runReturningMany{Sqlite}: sqlite does not support returning " - , "rows from an insert, use Database.Beam.Sqlite.insertReturning " - , "for emulation" ] + runReturningMany (SqliteCommandInsert (SqliteInsertSyntax tbl fields vs onConflict)) action + | SqliteInsertExpressions es <- vs, any (any (== SqliteExpressionDefault)) es = + -- SQLite's handling of default values differs from other DBMses because + -- it lacks support for DEFAULT. In order to insert a default value in a column, + -- the column's name should be omitted from the INSERT statement. + -- + -- This is problematic if you insert multiple rows, some of which have defaults; + -- you must use multiple INSERT statements. This is what we do below. + -- + -- However, to respect the 'runReturningMany' interface, be must accumulate the + -- results of all those inserts into an 'IORef [a]', and then feed the results + -- incrementally to 'action'. + SqliteM $ do + (logger, conn) <- ask + resultsRef <- liftIO (newIORef []) + forM_ es $ \row -> do + -- RETURNING is only supported by SQLite 3.35+, which requires direct-sqlite 2.3.27+ + let returningClause = emit " RETURNING " <> commas (map quotedIdentifier fields) + (insertFields, insertRow) = unzip $ filter ((/= SqliteExpressionDefault) . snd) $ zip fields row + SqliteSyntax cmd vals = formatSqliteInsertOnConflict tbl insertFields (SqliteInsertExpressions [ insertRow ]) onConflict <> returningClause + cmdString = BL.unpack (toLazyByteString (withPlaceholders cmd)) + + liftIO $ do + logger (cmdString ++ ";\n-- With values: " ++ show (D.toList vals)) + withStatement conn (fromString cmdString) $ \stmt -> + do bind stmt (BeamSqliteParams (D.toList vals)) + unfoldM (nextRow stmt) >>= \new -> atomicModifyIORef'_ resultsRef (new ++) + + -- We must reverse the list in the IORef because it has been constructed in reverse + -- order. We construct the list in reverse because it's faster to prepend to + -- a linked list + _ <- liftIO (atomicModifyIORef'_ resultsRef reverse) + let nextRow' = liftIO $ do + atomicModifyIORef' resultsRef $ \results -> case results of + (BeamSqliteRow h:rest) -> (rest, Just h) + [] -> ([], Nothing) + runSqliteM (action nextRow') + | otherwise = + SqliteM $ do + (logger, conn) <- ask + let returningClause = emit " RETURNING " <> commas (map quotedIdentifier fields) + SqliteSyntax cmd vals = formatSqliteInsertOnConflict tbl fields vs onConflict <> returningClause + cmdString = BL.unpack (toLazyByteString (withPlaceholders cmd)) + liftIO $ do + logger (cmdString ++ ";\n-- With values: " ++ show (D.toList vals)) + withStatement conn (fromString cmdString) $ \stmt -> + do bind stmt (BeamSqliteParams (D.toList vals)) + let nextRow' = liftIO (nextRow stmt) >>= \x -> + case x of + Nothing -> pure Nothing + Just (BeamSqliteRow row) -> pure row + runReaderT (runSqliteM (action nextRow')) (logger, conn) + + +unfoldM :: Monad m => m (Maybe a) -> m [a] +unfoldM f = go [] + where + go acc = f >>= maybe (pure acc) (\x -> go (x : acc)) instance Beam.MonadBeamInsertReturning Sqlite SqliteM where runInsertReturningList = runInsertReturningList @@ -361,7 +411,7 @@ runSqliteInsert logger conn (SqliteInsertSyntax tbl fields vs onConflict) logger (cmdString ++ ";\n-- With values: " ++ show (D.toList vals)) execute conn (fromString cmdString) (D.toList vals) --- * emulated INSERT returning support +-- * INSERT returning support -- | Build a 'SqliteInsertReturning' representing inserting the given values -- into the given table. Use 'runInsertReturningList' @@ -377,54 +427,7 @@ runInsertReturningList :: (Beamable table, FromBackendRow Sqlite (table Identity => SqlInsert Sqlite table -> SqliteM [ table Identity ] runInsertReturningList SqlInsertNoRows = pure [] -runInsertReturningList (SqlInsert tblSettings insertStmt_@(SqliteInsertSyntax nm _ _ _)) = - do (logger, conn) <- SqliteM ask - SqliteM . liftIO $ do - - -- We create a pseudo-random savepoint identification that can be referenced - -- throughout this operation. -- This used to be based on the process ID - -- (e.g. `System.Posix.Process.getProcessID` for UNIX), - -- but using timestamps is more portable; see #738 - -- - -- Note that `hash` can return negative numbers, hence the use of `abs`. - savepointId <- fromString . show . abs . hash <$> getCurrentTime - - let tableNameTxt = T.decodeUtf8 (BL.toStrict (sqliteRenderSyntaxScript (fromSqliteTableName nm))) - - startSavepoint = - execute_ conn (Query ("SAVEPOINT insert_savepoint_" <> savepointId)) - rollbackToSavepoint = - execute_ conn (Query ("ROLLBACK TRANSACTION TO SAVEPOINT insert_savepoint_" <> savepointId)) - releaseSavepoint = - execute_ conn (Query ("RELEASE SAVEPOINT insert_savepoint_" <> savepointId)) - - createInsertedValuesTable = - execute_ conn (Query ("CREATE TEMPORARY TABLE inserted_values_" <> savepointId <> " AS SELECT * FROM " <> tableNameTxt <> " LIMIT 0")) - dropInsertedValuesTable = - execute_ conn (Query ("DROP TABLE inserted_values_" <> savepointId)) - - createInsertTrigger = - execute_ conn (Query ("CREATE TEMPORARY TRIGGER insert_trigger_" <> savepointId <> " AFTER INSERT ON " <> tableNameTxt <> " BEGIN " <> - "INSERT INTO inserted_values_" <> savepointId <> " SELECT * FROM " <> tableNameTxt <> " WHERE ROWID=last_insert_rowid(); END" )) - dropInsertTrigger = - execute_ conn (Query ("DROP TRIGGER insert_trigger_" <> savepointId)) - - - mask $ \restore -> do - startSavepoint - flip onException rollbackToSavepoint . restore $ do - x <- bracket_ createInsertedValuesTable dropInsertedValuesTable $ - bracket_ createInsertTrigger dropInsertTrigger $ do - runSqliteInsert logger conn insertStmt_ - - let columns = TL.toStrict $ TL.decodeUtf8 $ - sqliteRenderSyntaxScript $ commas $ - allBeamValues (\(Columnar' projField) -> quotedIdentifier (_fieldName projField)) $ - tblSettings - - fmap (\(BeamSqliteRow r) -> r) <$> query_ conn (Query ("SELECT " <> columns <> " FROM inserted_values_" <> savepointId)) - releaseSavepoint - return x +runInsertReturningList (SqlInsert _ insertCommand) = runReturningList $ SqliteCommandInsert insertCommand instance Beam.BeamHasInsertOnConflict Sqlite where newtype SqlConflictTarget Sqlite table = SqliteConflictTarget diff --git a/beam-sqlite/beam-sqlite.cabal b/beam-sqlite/beam-sqlite.cabal index 448b7195..526bd20b 100644 --- a/beam-sqlite/beam-sqlite.cabal +++ b/beam-sqlite/beam-sqlite.cabal @@ -42,7 +42,9 @@ library aeson >=0.11 && <2.3, attoparsec >=0.13 && <0.15, transformers-base >=0.4 && <0.5, - direct-sqlite >=2.3.24 + -- Minimum version of direct-sqlite that includes a version + -- of SQLite that supports the RETURNING construct + direct-sqlite >=2.3.27 default-language: Haskell2010 default-extensions: ScopedTypeVariables, OverloadedStrings, MultiParamTypeClasses, RankNTypes, FlexibleInstances, DeriveDataTypeable, DeriveGeneric, StandaloneDeriving, TypeFamilies, GADTs, OverloadedStrings, @@ -72,6 +74,7 @@ test-suite beam-sqlite-tests other-modules: Database.Beam.Sqlite.Test Database.Beam.Sqlite.Test.Insert + Database.Beam.Sqlite.Test.InsertOnConflictReturning Database.Beam.Sqlite.Test.Migrate Database.Beam.Sqlite.Test.Select default-language: Haskell2010 diff --git a/beam-sqlite/test/Database/Beam/Sqlite/Test/Insert.hs b/beam-sqlite/test/Database/Beam/Sqlite/Test/Insert.hs index ab8883b9..2f41508a 100644 --- a/beam-sqlite/test/Database/Beam/Sqlite/Test/Insert.hs +++ b/beam-sqlite/test/Database/Beam/Sqlite/Test/Insert.hs @@ -61,7 +61,7 @@ testInsertReturningColumnOrder = testCase "runInsertReturningList with mismatchi , TestTable 1 "sally" "apple" ((val_ 56 + val_ 109) `div_` 5) currentTimestamp_ (val_ oneUtcTime) , TestTable 4 "blah" "blah" (-1) currentTimestamp_ (val_ now) ] - let dateJoined = ttDateJoined (head inserted) + let dateJoined : _ = ttDateJoined <$> inserted expected = [ TestTable 0 "jim" "smith" 19 dateJoined zeroUtcTime , TestTable 1 "sally" "apple" 33 dateJoined oneUtcTime diff --git a/beam-sqlite/test/Database/Beam/Sqlite/Test/InsertOnConflictReturning.hs b/beam-sqlite/test/Database/Beam/Sqlite/Test/InsertOnConflictReturning.hs new file mode 100644 index 00000000..082dde66 --- /dev/null +++ b/beam-sqlite/test/Database/Beam/Sqlite/Test/InsertOnConflictReturning.hs @@ -0,0 +1,132 @@ +{-# LANGUAGE DerivingStrategies #-} + +{- + This module is based on the bug reproducer from issue #773 + https://github.com/haskell-beam/beam/issues/773 +-} +module Database.Beam.Sqlite.Test.InsertOnConflictReturning (tests) where + +import Data.Int (Int32) +import Data.Text (Text) +import Database.Beam ( + Beamable, + Columnar, + Database, + DatabaseSettings, + Generic, + Identity, + SqlEq ((/=.)), + Table (..), + TableEntity, + current_, + defaultDbSettings, + insert, + insertValues, + runInsert, + (<-.), + ) +import Database.Beam.Backend.SQL.BeamExtensions ( + BeamHasInsertOnConflict ( + conflictingFields, + insertOnConflict, + onConflictUpdateSetWhere + ), + MonadBeamInsertReturning (runInsertReturningList), + ) +import Database.Beam.Sqlite (Sqlite, runBeamSqlite) +import Database.Beam.Sqlite.Test (withTestDb) +import Database.SQLite.Simple (execute_) +import Test.Tasty (TestTree, testGroup) +import Test.Tasty.HUnit (testCase, (@?=)) + +import Database.Beam +import Database.Beam.Backend.SQL.BeamExtensions ( + conflictingFields, + insertOnConflict, + onConflictUpdateSetWhere, + runInsertReturningList, + ) +import Database.Beam.Migrate (defaultMigratableDbSettings) +import Database.Beam.Migrate.Simple (CheckedDatabaseSettings, autoMigrate) +import Database.Beam.Sqlite (Sqlite, runBeamSqliteDebug) +import Database.Beam.Sqlite.Migrate (migrationBackend) +import Database.SQLite.Simple (open) + +tests :: TestTree +tests = + testGroup + "Insertion on conflict returning tests" + [testInsertOnConflictReturning] + +data TestDb f + = TestDb + { usersTable :: f (TableEntity User) + } + deriving stock (Generic) + +deriving anyclass instance Database be TestDb + +testDb :: DatabaseSettings be TestDb +testDb = defaultDbSettings + +checkedDb :: CheckedDatabaseSettings Sqlite TestDb +checkedDb = defaultMigratableDbSettings + +data User f + = User + { userId :: Columnar f Int32 + , userName :: Columnar f Text + } + deriving stock (Generic) + +deriving stock instance Show (PrimaryKey User Identity) +deriving stock instance Show (User Identity) + +instance Table User where + newtype PrimaryKey User f = UserId (Columnar f Int32) + deriving stock (Generic) + primaryKey = UserId . userId + +deriving anyclass instance Beamable User +deriving anyclass instance Beamable (PrimaryKey User) + +testInsertOnConflictReturning :: TestTree +testInsertOnConflictReturning = testCase "Check that conflicting values are returned by `runInsertReturningList`" $ + withTestDb $ \conn -> do + conflicts <- + runBeamSqlite conn $ do + autoMigrate migrationBackend checkedDb + + runInsert $ + insert + (usersTable testDb) + ( insertValues + [ User{userId = 0, userName = "user0"} + , User{userId = 2, userName = "user2"} + , User{userId = 5, userName = "user5"} + ] + ) + + let newUsers = + [ User{userId = 1, userName = "user1"} + , User{userId = 2, userName = "different_user2"} + ] + + runInsertReturningList $ + insertOnConflict + (usersTable testDb) + (insertValues newUsers) + (conflictingFields userId) + ( onConflictUpdateSetWhere + ( \(User{userName = fld}) + (User{userName = excl}) -> + fld <-. excl + ) + ( \(User{userName = fld}) + (User{userName = excl}) -> + current_ fld /=. excl + ) + ) + + -- Expecting that the conflicting user, User id 2, is also returned + userId <$> conflicts @?= [1, 2] diff --git a/beam-sqlite/test/Main.hs b/beam-sqlite/test/Main.hs index 13876e54..72f2d200 100644 --- a/beam-sqlite/test/Main.hs +++ b/beam-sqlite/test/Main.hs @@ -4,11 +4,13 @@ import Test.Tasty import qualified Database.Beam.Sqlite.Test.Migrate as Migrate import qualified Database.Beam.Sqlite.Test.Insert as Insert +import qualified Database.Beam.Sqlite.Test.InsertOnConflictReturning as InsertOnConflictReturning import qualified Database.Beam.Sqlite.Test.Select as Select main :: IO () main = defaultMain $ testGroup "beam-sqlite tests" - [ Migrate.tests - , Select.tests - , Insert.tests - ] + [ Migrate.tests + , Select.tests + , Insert.tests + , InsertOnConflictReturning.tests + ]