Permalink
Browse files

Transactions, execute statement, escaping.

  • Loading branch information...
1 parent 7a26af8 commit 6b1d0fa92338fd53a9c657f78a339cc0e129a583 @chrisdone committed Jun 3, 2011
Showing with 106 additions and 83 deletions.
  1. +69 −31 Database/PostgreSQL/Base.hs
  2. +1 −0 Database/PostgreSQL/Base/Types.hs
  3. +36 −52 Database/PostgreSQL/Simple.hs
View
@@ -34,27 +34,6 @@ import Network
import Prelude
import System.IO hiding (hPutStr)
--- | Escape a string for PostgreSQL.
-escape :: String -> String
-escape ('\'':cs) = '\'' : '\'' : escape cs
-escape (c:cs) = c : escape cs
-escape [] = []
-
--- -- FIXME:
--- insertID :: Connection -> IO Word64
--- insertID _ = return 0
-
--- FIXME:
--- | Turn autocommit on or off.
---
--- By default, PostgreSQL runs with autocommit mode enabled. In this
--- mode, as soon as you modify a table, PostgreSQL stores your
--- modification permanently.
--- autocommit :: Connection -> Bool -> IO ()
--- autocommit conn onOff = withConnection conn $ \ptr ->
--- mysql_autocommit ptr b >>= check "autocommit" conn
--- where b = if onOff then 1 else 0
-
--------------------------------------------------------------------------------
-- Exported values
@@ -100,14 +79,22 @@ connect connectInfo@ConnectInfo{..} = liftIO $ withSocketsDo $ do
withDB :: (MonadCatchIO m,MonadIO m) => ConnectInfo -> (Connection -> m a) -> m a
withDB connectInfo m = E.bracket (liftIO $ connect connectInfo) (liftIO . close) m
+-- | Rollback a transaction.
rollback :: (MonadCatchIO m,MonadIO m) => Connection -> m ()
rollback conn = do
- _ <- query conn (fromString ("ABORT;" :: String))
+ _ <- exec conn (fromString ("ABORT;" :: String))
return ()
+-- | Commit a transaction.
commit :: (MonadCatchIO m,MonadIO m) => Connection -> m ()
commit conn = do
- _ <- query conn (fromString ("COMMIT;" :: String))
+ _ <- exec conn (fromString ("COMMIT;" :: String))
+ return ()
+
+-- | Begin a transaction.
+begin :: (MonadCatchIO m,MonadIO m) => Connection -> m ()
+begin conn = do
+ _ <- exec conn (fromString ("BEGIN;" :: String))
return ()
-- | Close a connection. Can safely be called any number of times.
@@ -125,22 +112,53 @@ query :: MonadIO m
=> Connection -- ^ The connection.
-> ByteString -- ^ The query.
-> m ([Field],[[Maybe ByteString]])
-query conn sql = liftIO $ do
+query conn sql = do
+ result <- execQuery conn sql
+ case result of
+ (_,Just ok) -> return ok
+ _ -> error "query: No results returned."
+
+-- | Run a simple query on a connection.
+execQuery :: MonadIO m
+ => Connection -- ^ The connection.
+ -> ByteString -- ^ The query.
+ -> m (Integer,Maybe ([Field],[[Maybe ByteString]]))
+execQuery conn sql = liftIO $ do
withConnection conn $ \h -> do
types <- readMVar $ connectionObjects conn
Result{..} <- sendQuery types h sql
case resultType of
ErrorResponse -> error "TODO: query.ErrorResponse error"
EmptyQueryResponse -> error "TODO: query.EmptyQueryResponse error"
_ ->
- case resultDesc of
- Just fields -> return (fields,resultRows)
- Nothing -> error "TODO: query.Fields error"
+ let tagCount = fromMaybe 0 resultTagRows
+ in case resultDesc of
+ Just fields -> return $ (tagCount,Just (fields,resultRows))
+ Nothing -> return $ (tagCount,Nothing)
+
+exec :: MonadIO m
+ => Connection
+ -> ByteString
+ -> m Integer
+exec conn sql = do
+ result <- execQuery conn sql
+ case result of
+ (ok,_) -> return ok
-- | PostgreSQL protocol version supported by this library.
protocolVersion :: Int32
protocolVersion = 196608
+-- | Escape a string for PostgreSQL.
+escape :: String -> String
+escape ('\'':cs) = '\'' : '\'' : escape cs
+escape (c:cs) = c : escape cs
+escape [] = []
+
+-- | Escape a string for PostgreSQL.
+escapeBS :: ByteString -> ByteString
+escapeBS = fromString . escape . toString
+
--------------------------------------------------------------------------------
-- Authentication
@@ -185,11 +203,16 @@ objectIds h = do
ErrorResponse -> error "objectIds: ErrorResponse"
_ -> return $ M.fromList $ catMaybes $ flip map resultRows $ \row ->
case map toString $ catMaybes row of
- [typ,objId] -> Just $ (ObjectId (read objId),typ)
- _ -> Nothing
+ [typ,readMay -> Just objId] -> Just $ (ObjectId objId,typ)
+ _ -> Nothing
where q = fromString ("SELECT typname, oid FROM pg_type" :: String)
+readMay :: Read a => String -> Maybe a
+readMay x = case reads x of
+ [(v,"")] -> return v
+ _ -> Nothing
+
--------------------------------------------------------------------------------
-- Queries and commands
@@ -207,7 +230,8 @@ sendQuery types h sql = do
listenPassively -> do
case listenPassively of
EmptyQueryResponse -> setStatus
- CommandComplete -> setStatus
+ CommandComplete -> do setStatus
+ setCommandTag block
ErrorResponse -> do
modify $ \r -> r { resultError = Just block }
setStatus
@@ -218,9 +242,23 @@ sendQuery types h sql = do
continue
- where emptyResponse = Result [] Nothing Nothing [] UnknownMessageType
+ where emptyResponse = Result [] Nothing Nothing [] UnknownMessageType Nothing
listener m = execStateT (fix m) emptyResponse
+-- | CommandComplete returns a ‘tag’ which indicates how many rows were
+-- affected, or returned, as a result of the command.
+-- See http://developer.postgresql.org/pgdocs/postgres/protocol-message-formats.html
+setCommandTag :: MonadState Result m => L.ByteString -> m ()
+setCommandTag block = do
+ modify $ \r -> r { resultTagRows = rows }
+ where rows =
+ case tag block of
+ ["INSERT",_oid,readMay -> Just rows] -> return rows
+ [cmd,readMay -> Just rows] | cmd `elem` cmds -> return rows
+ _ -> Nothing
+ tag = words . concat . map toString . L.toChunks . runGet getString
+ cmds = ["DELETE","UPDATE","SELECT","MOVE","FETCH"]
+
-- | Update the row description of the result.
getRowDesc :: MonadState Result m => Map ObjectId String -> L.ByteString -> m ()
getRowDesc types block =
@@ -45,6 +45,7 @@ data Result =
,resultError :: Maybe L.ByteString
,resultNotices :: [String]
,resultType :: MessageType
+ ,resultTagRows :: Maybe Integer
} deriving Show
-- | An internal message type.
@@ -67,7 +67,7 @@ module Database.PostgreSQL.Simple
-- , forEach
-- , forEach_
-- * Statements that do not return results
- -- , execute
+ , execute
-- , execute_
-- , executeMany
-- , Base.insertID
@@ -84,7 +84,7 @@ module Database.PostgreSQL.Simple
import Blaze.ByteString.Builder (Builder, fromByteString, toByteString)
import Blaze.ByteString.Builder.Char8 (fromChar)
import Control.Applicative ((<$>), pure)
-import Control.Exception (Exception, throw)
+import Control.Exception (Exception, throw, onException)
import Control.Monad (forM)
import Data.ByteString (ByteString)
import Data.List (intersperse)
@@ -158,7 +158,7 @@ formatMany q@(Query template) qs = do
return . toByteString . mconcat $ fromByteString before :
intersperse (fromChar ',') bs ++
[fromByteString after]
- _ -> error "foo" -- FIXME:
+ _ -> error "formatMany: The query did not match the documented format."
where
re = compile "^([^?]+\\bvalues\\s*)\
\(\\(\\s*[?](?:\\s*,\\s*[?])*\\s*\\))\
@@ -168,7 +168,7 @@ formatMany q@(Query template) qs = do
buildQuery :: Query -> ByteString -> [Action] -> IO Builder
buildQuery q template xs = zipParams (split template) <$> mapM sub xs
where sub (Plain b) = pure b
- sub (Escape s) = pure $ (inQuotes . fromByteString . escape) s
+ sub (Escape s) = pure $ (inQuotes . fromByteString . Base.escapeBS) s
sub (Many ys) = mconcat <$> mapM sub ys
split s = fromByteString h : if B.null t then [] else split (B.tail t)
where (h,t) = B.break (=='?') s
@@ -177,48 +177,32 @@ buildQuery q template xs = zipParams (split template) <$> mapM sub xs
zipParams _ _ = fmtError (show (B.count '?' template) ++
" '?' characters, but " ++
show (length xs) ++ " parameters") q xs
--- FIXME:
-escape :: ByteString -> ByteString
-escape = id
-
--- FIXME:
--- -- | Execute an @INSERT@, @UPDATE@, or other SQL query that is not
--- -- expected to return results.
--- --
--- -- Returns the number of rows affected.
--- --
--- -- Throws 'FormatError' if the query could not be formatted correctly.
--- execute :: (QueryParams q) => Connection -> Query -> q -> IO Int64
--- execute conn template qs = do
--- Base.query conn =<< formatQuery template qs
--- finishExecute conn template
-
--- -- | A version of 'execute' that does not perform query substitution.
--- execute_ :: Connection -> Query -> IO Int64
--- execute_ conn q@(Query stmt) = do
--- Base.query conn stmt
--- finishExecute conn q
-
--- -- | Execute a multi-row @INSERT@, @UPDATE@, or other SQL query that is not
--- -- expected to return results.
--- --
--- -- Returns the number of rows affected.
--- --
--- -- Throws 'FormatError' if the query could not be formatted correctly.
--- executeMany :: (QueryParams q) => Connection -> Query -> [q] -> IO Int64
--- executeMany _ _ [] = return 0
--- executeMany conn q qs = do
--- Base.query conn =<< formatMany q qs
--- finishExecute conn q
-
--- finishExecute :: Connection -> Query -> IO Int64
--- finishExecute conn q = do
--- return 0
- -- ncols <- Base.fieldCount (Left conn)
- -- if ncols /= 0
- -- then throwIO $ QueryError ("execute resulted in " ++ show ncols ++
- -- "-column result") q
- -- else Base.affectedRows conn
+
+-- | Execute an @INSERT@, @UPDATE@, or other SQL query that is not
+-- expected to return results.
+--
+-- Returns the number of rows affected.
+--
+-- Throws 'FormatError' if the query could not be formatted correctly.
+execute :: (QueryParams q) => Connection -> Query -> q -> IO Integer
+execute conn template qs = do
+ Base.exec conn =<< formatQuery template qs
+
+-- | A version of 'execute' that does not perform query substitution.
+execute_ :: Connection -> Query -> IO Integer
+execute_ conn q@(Query stmt) = do
+ Base.exec conn stmt
+
+-- | Execute a multi-row @INSERT@, @UPDATE@, or other SQL query that is not
+-- expected to return results.
+--
+-- Returns the number of rows affected.
+--
+-- Throws 'FormatError' if the query could not be formatted correctly.
+executeMany :: (QueryParams q) => Connection -> Query -> [q] -> IO Integer
+executeMany _ _ [] = return 0
+executeMany conn q qs = do
+ Base.exec conn =<< formatMany q qs
-- | Perform a @SELECT@ or other SQL query that is expected to return
-- results. All results are retrieved and converted before this
@@ -260,12 +244,12 @@ query_ conn (Query q) = do
-- If the action throws /any/ kind of exception (not just a
-- MySQL-related exception), the transaction will be rolled back using
-- 'Base.rollback', then the exception will be rethrown.
--- withTransaction :: Connection -> IO a -> IO a
--- withTransaction conn act = do
--- _ <- execute_ conn "start transaction"
--- r <- act `onException` Base.rollback conn
--- Base.commit conn
--- return r
+withTransaction :: Connection -> IO a -> IO a
+withTransaction conn act = do
+ Base.begin conn
+ r <- act `onException` Base.rollback conn
+ Base.commit conn
+ return r
fmtError :: String -> Query -> [Action] -> a
fmtError msg q xs = throw FormatError {

0 comments on commit 6b1d0fa

Please sign in to comment.