Skip to content

Commit

Permalink
Added transactions to DataSync
Browse files Browse the repository at this point in the history
This adds a new high-level `withTransaction` function and a lower-level `Transaction` object to deal with postgres transactions from DataSync.

The following operations can be used from within a transaction:
- createRecord, createRecords
- updateRecord, updateRecords
- deleteRecord, deleteRecords
- query
  • Loading branch information
mpscholten committed Feb 4, 2022
1 parent 87dbe2c commit dfc09f9
Show file tree
Hide file tree
Showing 6 changed files with 285 additions and 41 deletions.
147 changes: 130 additions & 17 deletions IHP/DataSync/Controller.hs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import qualified IHP.PGListener as PGListener
import IHP.ApplicationContext
import Data.Set (Set)
import qualified Data.Set as Set
import qualified Data.Pool as Pool

instance (
PG.ToField (PrimaryKey (GetTableName CurrentUserRecord))
Expand All @@ -36,7 +37,7 @@ instance (
initialState = DataSyncController

run = do
setState DataSyncReady { subscriptions = HashMap.empty }
setState DataSyncReady { subscriptions = HashMap.empty, transactions = HashMap.empty }

ensureRLSEnabled <- makeCachedEnsureRLSEnabled
installTableChangeTriggers <- ChangeNotifications.makeCachedInstallTableChangeTriggers
Expand All @@ -45,12 +46,12 @@ instance (

let
handleMessage :: DataSyncMessage -> IO ()
handleMessage DataSyncQuery { query, requestId } = do
handleMessage DataSyncQuery { query, requestId, transactionId } = do
ensureRLSEnabled (get #table query)

let (theQuery, theParams) = compileQuery query

result :: [[Field]] <- sqlQueryWithRLS theQuery theParams
result :: [[Field]] <- sqlQueryWithRLSAndTransactionId transactionId theQuery theParams

sendJSON DataSyncResult { result, requestId }

Expand Down Expand Up @@ -131,7 +132,7 @@ instance (
sendJSON DidDeleteDataSubscription { subscriptionId, requestId }
handleMessage CreateRecordMessage { table, record, requestId } = do
handleMessage CreateRecordMessage { table, record, requestId, transactionId } = do
ensureRLSEnabled table
let query = "INSERT INTO ? ? VALUES ? RETURNING *"
Expand All @@ -145,15 +146,15 @@ instance (
let params = (PG.Identifier table, PG.In (map PG.Identifier columns), PG.In values)
result :: [[Field]] <- sqlQueryWithRLS query params
result :: [[Field]] <- sqlQueryWithRLSAndTransactionId transactionId query params
case result of
[record] -> sendJSON DidCreateRecord { requestId, record }
otherwise -> error "Unexpected result in CreateRecordMessage handler"
pure ()
handleMessage CreateRecordsMessage { table, records, requestId } = do
handleMessage CreateRecordsMessage { table, records, requestId, transactionId } = do
ensureRLSEnabled table
let query = "INSERT INTO ? ? ? RETURNING *"
Expand All @@ -175,13 +176,13 @@ instance (
let params = (PG.Identifier table, PG.In (map PG.Identifier columns), PG.Values [] values)
records :: [[Field]] <- sqlQueryWithRLS query params
records :: [[Field]] <- sqlQueryWithRLSAndTransactionId transactionId query params
sendJSON DidCreateRecords { requestId, records }
pure ()
handleMessage UpdateRecordMessage { table, id, patch, requestId } = do
handleMessage UpdateRecordMessage { table, id, patch, requestId, transactionId } = do
ensureRLSEnabled table
let columns = patch
Expand All @@ -204,15 +205,15 @@ instance (
<> (join (map (\(key, value) -> [PG.toField key, value]) keyValues))
<> [PG.toField id]
result :: [[Field]] <- sqlQueryWithRLS (PG.Query query) params
result :: [[Field]] <- sqlQueryWithRLSAndTransactionId transactionId (PG.Query query) params
case result of
[record] -> sendJSON DidUpdateRecord { requestId, record }
otherwise -> error "Unexpected result in UpdateRecordMessage handler"
pure ()
handleMessage UpdateRecordsMessage { table, ids, patch, requestId } = do
handleMessage UpdateRecordsMessage { table, ids, patch, requestId, transactionId } = do
ensureRLSEnabled table
let columns = patch
Expand All @@ -235,26 +236,63 @@ instance (
<> (join (map (\(key, value) -> [PG.toField key, value]) keyValues))
<> [PG.toField (PG.In ids)]
records <- sqlQueryWithRLS (PG.Query query) params
records <- sqlQueryWithRLSAndTransactionId transactionId (PG.Query query) params
sendJSON DidUpdateRecords { requestId, records }
pure ()
handleMessage DeleteRecordMessage { table, id, requestId } = do
handleMessage DeleteRecordMessage { table, id, requestId, transactionId } = do
ensureRLSEnabled table
sqlExecWithRLS "DELETE FROM ? WHERE id = ?" (PG.Identifier table, id)
sqlExecWithRLSAndTransactionId transactionId "DELETE FROM ? WHERE id = ?" (PG.Identifier table, id)
sendJSON DidDeleteRecord { requestId }
handleMessage DeleteRecordsMessage { table, ids, requestId } = do
handleMessage DeleteRecordsMessage { table, ids, requestId, transactionId } = do
ensureRLSEnabled table
sqlExecWithRLS "DELETE FROM ? WHERE id IN ?" (PG.Identifier table, PG.In ids)
sqlExecWithRLSAndTransactionId transactionId "DELETE FROM ? WHERE id IN ?" (PG.Identifier table, PG.In ids)
sendJSON DidDeleteRecords { requestId }
handleMessage StartTransaction { requestId } = do
ensureBelowTransactionLimit
transactionId <- UUID.nextRandom
(connection, localPool) <- ?modelContext
|> get #connectionPool
|> Pool.takeResource
let transaction = DataSyncTransaction
{ id = transactionId
, connection
, releaseConnection = Pool.putResource localPool connection
}
let globalModelContext = ?modelContext
let ?modelContext = globalModelContext { transactionConnection = Just connection } in sqlExecWithRLS "BEGIN" ()
modifyIORef' ?state (\state -> state |> modify #transactions (HashMap.insert transactionId transaction))
sendJSON DidStartTransaction { requestId, transactionId }
handleMessage RollbackTransaction { requestId, id } = do
sqlExecWithRLSAndTransactionId (Just id) "ROLLBACK" ()
closeTransaction id
sendJSON DidRollbackTransaction { requestId, transactionId = id }
handleMessage CommitTransaction { requestId, id } = do
sqlExecWithRLSAndTransactionId (Just id) "COMMIT" ()
closeTransaction id
sendJSON DidCommitTransaction { requestId, transactionId = id }
forever do
message <- Aeson.eitherDecodeStrict' <$> receiveData @ByteString
Expand Down Expand Up @@ -289,13 +327,15 @@ cleanupAllSubscriptions = do
let pgListener = ?applicationContext |> get #pgListener
case state of
DataSyncReady { subscriptions } -> do
DataSyncReady { subscriptions, transactions } -> do
let channelSubscriptions = subscriptions
|> HashMap.elems
|> map (get #channelSubscription)
forEach channelSubscriptions \channelSubscription -> do
pgListener |> PGListener.unsubscribe channelSubscription
forEach (HashMap.elems transactions) (get #releaseConnection)
pure ()
_ -> pure ()
Expand All @@ -310,8 +350,81 @@ queryFieldNamesToColumnNames sqlQuery = sqlQuery
where
convertOrderByClause OrderByClause { orderByColumn, orderByDirection } = OrderByClause { orderByColumn = cs (fieldNameToColumnName (cs orderByColumn)), orderByDirection }
runInModelContextWithTransaction :: (?state :: IORef DataSyncController, _) => ((?modelContext :: ModelContext) => IO result) -> Maybe UUID -> IO result
runInModelContextWithTransaction function (Just transactionId) = do
let globalModelContext = ?modelContext
DataSyncTransaction { connection } <- findTransactionById transactionId
let
?modelContext = globalModelContext { transactionConnection = Just connection }
in
function
runInModelContextWithTransaction function Nothing = function
findTransactionById :: (?state :: IORef DataSyncController) => UUID -> IO DataSyncTransaction
findTransactionById transactionId = do
transactions <- get #transactions <$> readIORef ?state
case HashMap.lookup transactionId transactions of
Just transaction -> pure transaction
Nothing -> error "No transaction with that id"
closeTransaction transactionId = do
DataSyncTransaction { releaseConnection } <- findTransactionById transactionId
modifyIORef' ?state (\state -> state |> modify #transactions (HashMap.delete transactionId))
releaseConnection
-- | Allow max 10 concurrent transactions per connection to avoid running out of database connections
--
-- Each transaction removes a database connection from the connection pool. If we don't limit the transactions,
-- a single user could take down the application by starting more than 'IHP.FrameworkConfig.DBPoolMaxConnections'
-- concurrent transactions. Then all database connections are removed from the connection pool and further database
-- queries for other users will fail.
--
ensureBelowTransactionLimit :: (?state :: IORef DataSyncController) => IO ()
ensureBelowTransactionLimit = do
transactions <- get #transactions <$> readIORef ?state
let transactionCount = HashMap.size transactions
let maxTransactionsPerConnection = 10
when (transactionCount >= maxTransactionsPerConnection) do
error ("You've reached the transaction limit of " <> tshow maxTransactionsPerConnection <> " transactions")
sqlQueryWithRLSAndTransactionId ::
( ?modelContext :: ModelContext
, PG.ToRow parameters
, ?context :: ControllerContext
, userId ~ Id CurrentUserRecord
, Show (PrimaryKey (GetTableName CurrentUserRecord))
, HasNewSessionUrl CurrentUserRecord
, Typeable CurrentUserRecord
, ?context :: ControllerContext
, HasField "id" CurrentUserRecord (Id' (GetTableName CurrentUserRecord))
, PG.ToField userId
, FromRow result
, ?state :: IORef DataSyncController
) => Maybe UUID -> PG.Query -> parameters -> IO [result]
sqlQueryWithRLSAndTransactionId transactionId theQuery theParams = runInModelContextWithTransaction (sqlQueryWithRLS theQuery theParams) transactionId
sqlExecWithRLSAndTransactionId ::
( ?modelContext :: ModelContext
, PG.ToRow parameters
, ?context :: ControllerContext
, userId ~ Id CurrentUserRecord
, Show (PrimaryKey (GetTableName CurrentUserRecord))
, HasNewSessionUrl CurrentUserRecord
, Typeable CurrentUserRecord
, ?context :: ControllerContext
, HasField "id" CurrentUserRecord (Id' (GetTableName CurrentUserRecord))
, PG.ToField userId
, ?state :: IORef DataSyncController
) => Maybe UUID -> PG.Query -> parameters -> IO Int64
sqlExecWithRLSAndTransactionId transactionId theQuery theParams = runInModelContextWithTransaction (sqlExecWithRLS theQuery theParams) transactionId
$(deriveFromJSON defaultOptions 'DataSyncQuery)
$(deriveToJSON defaultOptions 'DataSyncResult)
instance SetField "subscriptions" DataSyncController (HashMap UUID Subscription) where
setField subscriptions record = record { subscriptions }
setField subscriptions record = record { subscriptions }
instance SetField "transactions" DataSyncController (HashMap UUID DataSyncTransaction) where
setField transactions record = record { transactions }
32 changes: 24 additions & 8 deletions IHP/DataSync/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,21 @@ import IHP.QueryBuilder
import IHP.DataSync.DynamicQuery
import Data.HashMap.Strict (HashMap)
import qualified IHP.PGListener as PGListener
import qualified Database.PostgreSQL.Simple as PG

data DataSyncMessage
= DataSyncQuery { query :: !DynamicSQLQuery, requestId :: !Int }
= DataSyncQuery { query :: !DynamicSQLQuery, requestId :: !Int, transactionId :: !(Maybe UUID) }
| CreateDataSubscription { query :: !DynamicSQLQuery, requestId :: !Int }
| DeleteDataSubscription { subscriptionId :: !UUID, requestId :: !Int }
| CreateRecordMessage { table :: !Text, record :: !(HashMap Text Value), requestId :: !Int }
| CreateRecordsMessage { table :: !Text, records :: ![HashMap Text Value], requestId :: !Int }
| UpdateRecordMessage { table :: !Text, id :: !UUID, patch :: !(HashMap Text Value), requestId :: !Int }
| UpdateRecordsMessage { table :: !Text, ids :: ![UUID], patch :: !(HashMap Text Value), requestId :: !Int }
| DeleteRecordMessage { table :: !Text, id :: !UUID, requestId :: !Int }
| DeleteRecordsMessage { table :: !Text, ids :: ![UUID], requestId :: !Int }
| CreateRecordMessage { table :: !Text, record :: !(HashMap Text Value), requestId :: !Int, transactionId :: !(Maybe UUID) }
| CreateRecordsMessage { table :: !Text, records :: ![HashMap Text Value], requestId :: !Int, transactionId :: !(Maybe UUID) }
| UpdateRecordMessage { table :: !Text, id :: !UUID, patch :: !(HashMap Text Value), requestId :: !Int, transactionId :: !(Maybe UUID) }
| UpdateRecordsMessage { table :: !Text, ids :: ![UUID], patch :: !(HashMap Text Value), requestId :: !Int, transactionId :: !(Maybe UUID) }
| DeleteRecordMessage { table :: !Text, id :: !UUID, requestId :: !Int, transactionId :: !(Maybe UUID) }
| DeleteRecordsMessage { table :: !Text, ids :: ![UUID], requestId :: !Int, transactionId :: !(Maybe UUID) }
| StartTransaction { requestId :: !Int }
| RollbackTransaction { requestId :: !Int, id :: !UUID }
| CommitTransaction { requestId :: !Int, id :: !UUID }
deriving (Eq, Show)

data DataSyncResponse
Expand All @@ -34,9 +38,21 @@ data DataSyncResponse
| DidUpdateRecords { requestId :: !Int, records :: ![[Field]] } -- ^ Response to 'UpdateRecordsMessage'
| DidDeleteRecord { requestId :: !Int }
| DidDeleteRecords { requestId :: !Int }
| DidStartTransaction { requestId :: !Int, transactionId :: !UUID }
| DidRollbackTransaction { requestId :: !Int, transactionId :: !UUID }
| DidCommitTransaction { requestId :: !Int, transactionId :: !UUID }

data Subscription = Subscription { id :: !UUID, channelSubscription :: !PGListener.Subscription }
data DataSyncTransaction
= DataSyncTransaction
{ id :: !UUID
, connection :: !PG.Connection
, releaseConnection :: IO ()
}

data DataSyncController
= DataSyncController
| DataSyncReady { subscriptions :: !(HashMap UUID Subscription) }
| DataSyncReady
{ subscriptions :: !(HashMap UUID Subscription)
, transactions :: !(HashMap UUID DataSyncTransaction)
}
Loading

0 comments on commit dfc09f9

Please sign in to comment.