Permalink
Browse files

Merge pull request #16 from joeyadams/escapeByteaConn

bytea compatibility fix, tests, and error handling fixes
  • Loading branch information...
2 parents a7bee4b + aa84100 commit f9ee35767b7efa75a675822060c3d059c63c9c9f @lpsmith committed May 9, 2012
View
19 postgresql-simple.cabal
@@ -12,7 +12,7 @@ Copyright: (c) 2011 MailRank, Inc.
Category: Database
Build-type: Simple
-Cabal-version: >=1.6
+Cabal-version: >= 1.9.2
Library
hs-source-dirs: src
@@ -37,7 +37,6 @@ Library
Build-depends:
attoparsec >= 0.8.5.3,
base < 5,
- base16-bytestring,
blaze-builder,
blaze-textual,
bytestring >= 0.9,
@@ -64,3 +63,19 @@ source-repository this
type: git
location: http://github.com/lpsmith/postgresql-simple
tag: v0.1.1
+
+test-suite test
+ type: exitcode-stdio-1.0
+
+ hs-source-dirs: test
+ main-is: Main.hs
+ other-modules:
+ Bytea
+
+ build-depends: base
+ , base16-bytestring
+ , bytestring
+ , cryptohash
+ , HUnit
+ , postgresql-simple
+ , text
View
99 src/Database/PostgreSQL/Simple.hs
@@ -113,7 +113,6 @@ import Control.Exception
( Exception, onException, throw, throwIO, finally )
import Control.Monad (foldM)
import Data.ByteString (ByteString)
-import Data.Char(ord)
import Data.Int (Int64)
import qualified Data.IntMap as IntMap
import Data.List (intersperse)
@@ -133,6 +132,8 @@ import Database.PostgreSQL.Simple.Internal as Base
import qualified Database.PostgreSQL.LibPQ as PQ
import Text.Regex.PCRE.Light (compile, caseless, match)
import qualified Data.ByteString.Char8 as B
+import qualified Data.Text as T
+import qualified Data.Text.Encoding as TE
import qualified Data.Vector as V
import Control.Monad.Trans.Reader
import Control.Monad.Trans.State.Strict
@@ -148,15 +149,6 @@ data FormatError = FormatError {
instance Exception FormatError
--- | Exception thrown if 'query' is used to perform an @INSERT@-like
--- operation, or 'execute' is used to perform a @SELECT@-like operation.
-data QueryError = QueryError {
- qeMessage :: String
- , qeQuery :: Query
- } deriving (Eq, Show, Typeable)
-
-instance Exception QueryError
-
-- | Format a query string.
--
-- This function is exposed to help with debugging and logging. Do not
@@ -202,16 +194,29 @@ formatMany conn q@(Query template) qs = do
\([^?]*)$"
[caseless]
-escapeStringConn :: Connection -> ByteString -> IO (Maybe ByteString)
-escapeStringConn conn s = withConnection conn $ \c -> do
- PQ.escapeStringConn c s
+escapeStringConn :: Connection -> ByteString -> IO (Either ByteString ByteString)
+escapeStringConn conn s =
+ withConnection conn $ \c ->
+ PQ.escapeStringConn c s >>= checkError c
+
+escapeByteaConn :: Connection -> ByteString -> IO (Either ByteString ByteString)
+escapeByteaConn conn s =
+ withConnection conn $ \c ->
+ PQ.escapeByteaConn c s >>= checkError c
+
+checkError :: PQ.Connection -> Maybe a -> IO (Either ByteString a)
+checkError c (Just x) = return $ Right x
+checkError c Nothing = Left . maybe "" id <$> PQ.errorMessage c
buildQuery :: Connection -> Query -> ByteString -> [Action] -> IO Builder
buildQuery conn q template xs = zipParams (split template) <$> mapM sub xs
- where quote = inQuotes . fromByteString . maybe undefined id
- sub (Plain b) = pure b
- sub (Escape s) = quote <$> escapeStringConn conn s
- sub (Many ys) = mconcat <$> mapM sub ys
+ where quote = either (\msg -> fmtError (utf8ToString msg) q xs)
+ (inQuotes . fromByteString)
+ utf8ToString = T.unpack . TE.decodeUtf8
+ sub (Plain b) = pure b
+ sub (Escape s) = quote <$> escapeStringConn conn s
+ sub (EscapeBytea s) = quote <$> escapeByteaConn conn s
+ sub (Many ys) = mconcat <$> mapM sub ys
split s = fromByteString h : if B.null t then [] else split (B.tail t)
where (h,t) = B.break (=='?') s
zipParams (t:ts) (p:ps) = t `mappend` p `mappend` zipParams ts ps
@@ -231,12 +236,6 @@ execute conn template qs = do
result <- exec conn =<< formatQuery conn template qs
finishExecute conn template result
--- | A version of 'execute' that does not perform query substitution.
-execute_ :: Connection -> Query -> IO Int64
-execute_ conn q@(Query stmt) = do
- result <- exec conn stmt
- finishExecute conn q result
-
-- | Execute a multi-row @INSERT@, @UPDATE@, or other SQL query that is not
-- expected to return results.
--
@@ -249,44 +248,6 @@ executeMany conn q qs = do
result <- exec conn =<< formatMany conn q qs
finishExecute conn q result
-finishExecute :: Connection -> Query -> PQ.Result -> IO Int64
-finishExecute _conn q result = do
- status <- PQ.resultStatus result
- case status of
- PQ.CommandOk -> do
- ncols <- PQ.nfields result
- if ncols /= 0
- then throwIO $ QueryError ("execute resulted in " ++ show ncols ++
- "-column result") q
- else do
- nstr <- PQ.cmdTuples result
- return $ case nstr of
- Nothing -> 0 -- is this appropriate?
- Just str -> toInteger str
- PQ.TuplesOk -> do
- ncols <- PQ.nfields result
- throwIO $ QueryError ("execute resulted in " ++ show ncols ++
- "-column result") q
- PQ.CopyIn -> fail "FIXME: postgresql-simple does not currently support COPY IN"
- PQ.CopyOut -> fail "FIXME: postgresql-simple does not currently support COPY OUT"
- _ -> do
- errormsg <- maybe "" id <$> PQ.resultErrorMessage result
- statusmsg <- PQ.resStatus status
- state <- maybe "" id <$> PQ.resultErrorField result PQ.DiagSqlstate
- throwIO $ SqlError { sqlState = state
- , sqlNativeError = fromEnum status
- , sqlErrorMsg = B.concat [ "execute: ", statusmsg
- , ": ", errormsg ]}
- where
- toInteger str = B.foldl' delta 0 str
- where
- delta acc c =
- if '0' <= c && c <= '9'
- then 10 * acc + fromIntegral (ord c - ord '0')
- else error ("finishExecute: not an int: " ++ B.unpack str)
-
-
-
-- | Perform a @SELECT@ or other SQL query that is expected to return
-- results. All results are retrieved and converted before this
-- function returns.
@@ -467,6 +428,8 @@ finishQuery :: (FromRow r) => Connection -> Query -> PQ.Result -> IO [r]
finishQuery conn q result = do
status <- PQ.resultStatus result
case status of
+ PQ.EmptyQuery ->
+ throwIO $ QueryError "query: Empty query" q
PQ.CommandOk -> do
throwIO $ QueryError "query resulted in a command response" q
PQ.TuplesOk -> do
@@ -494,6 +457,13 @@ finishQuery conn q result = do
Errors [] -> throwIO $ ConversionFailed "" "" "unknown error"
Errors [x] -> throwIO x
Errors xs -> throwIO $ ManyErrors xs
+ PQ.CopyOut ->
+ throwIO $ QueryError "query: COPY TO is not supported" q
+ PQ.CopyIn ->
+ throwIO $ QueryError "query: COPY FROM is not supported" q
+ PQ.BadResponse -> throwResultError "query" result status
+ PQ.NonfatalError -> throwResultError "query" result status
+ PQ.FatalError -> throwResultError "query" result status
ellipsis :: ByteString -> ByteString
ellipsis bs
@@ -611,9 +581,10 @@ fmtError msg q xs = throw FormatError {
, fmtQuery = q
, fmtParams = map twiddle xs
}
- where twiddle (Plain b) = toByteString b
- twiddle (Escape s) = s
- twiddle (Many ys) = B.concat (map twiddle ys)
+ where twiddle (Plain b) = toByteString b
+ twiddle (Escape s) = s
+ twiddle (EscapeBytea s) = s
+ twiddle (Many ys) = B.concat (map twiddle ys)
-- $use
--
View
66 src/Database/PostgreSQL/Simple/Internal.hs
@@ -25,6 +25,10 @@ import Control.Applicative
import Control.Exception
import Control.Concurrent.MVar
import Data.ByteString(ByteString)
+import qualified Data.ByteString as B
+import qualified Data.ByteString.Char8 as B8
+import Data.Char (ord)
+import Data.Int (Int64)
import qualified Data.IntMap as IntMap
import Data.String
import Data.Typeable
@@ -33,6 +37,7 @@ import Database.PostgreSQL.LibPQ(Oid(..))
import qualified Database.PostgreSQL.LibPQ as PQ
import Database.PostgreSQL.Simple.BuiltinTypes (BuiltinType)
import Database.PostgreSQL.Simple.Ok
+import Database.PostgreSQL.Simple.Types (Query(..))
import Control.Monad.Trans.State.Strict
import Control.Monad.Trans.Reader
import qualified Data.Vector as V
@@ -86,6 +91,15 @@ data SqlError = SqlError {
instance Exception SqlError
+-- | Exception thrown if 'query' is used to perform an @INSERT@-like
+-- operation, or 'execute' is used to perform a @SELECT@-like operation.
+data QueryError = QueryError {
+ qeMessage :: String
+ , qeQuery :: Query
+ } deriving (Eq, Show, Typeable)
+
+instance Exception QueryError
+
data ConnectInfo = ConnectInfo {
connectHost :: String
, connectPort :: Word16
@@ -138,6 +152,8 @@ connectPostgreSQL connstr = do
PQ.ConnectionOk -> do
connectionHandle <- newMVar conn
connectionObjects <- newMVar (IntMap.empty)
+ let wconn = Connection{..}
+ _ <- execute_ wconn "SET standard_conforming_strings TO on"
return Connection{..}
_ -> do
msg <- maybe "connectPostgreSQL error" id <$> PQ.errorMessage conn
@@ -198,6 +214,56 @@ exec conn sql =
Just res -> do
return res
+-- | A version of 'execute' that does not perform query substitution.
+execute_ :: Connection -> Query -> IO Int64
+execute_ conn q@(Query stmt) = do
+ result <- exec conn stmt
+ finishExecute conn q result
+
+finishExecute :: Connection -> Query -> PQ.Result -> IO Int64
+finishExecute _conn q result = do
+ status <- PQ.resultStatus result
+ case status of
+ PQ.EmptyQuery -> throwIO $ QueryError "execute: Empty query" q
+ PQ.CommandOk -> do
+ ncols <- PQ.nfields result
+ if ncols /= 0
+ then throwIO $ QueryError ("execute resulted in " ++ show ncols ++
+ "-column result") q
+ else do
+ nstr <- PQ.cmdTuples result
+ return $ case nstr of
+ Nothing -> 0 -- is this appropriate?
+ Just str -> toInteger str
+ PQ.TuplesOk -> do
+ ncols <- PQ.nfields result
+ throwIO $ QueryError ("execute resulted in " ++ show ncols ++
+ "-column result") q
+ PQ.CopyOut ->
+ throwIO $ QueryError "execute: COPY TO is not supported" q
+ PQ.CopyIn ->
+ throwIO $ QueryError "execute: COPY FROM is not supported" q
+ PQ.BadResponse -> throwResultError "execute" result status
+ PQ.NonfatalError -> throwResultError "execute" result status
+ PQ.FatalError -> throwResultError "execute" result status
+ where
+ toInteger str = B8.foldl' delta 0 str
+ where
+ delta acc c =
+ if '0' <= c && c <= '9'
+ then 10 * acc + fromIntegral (ord c - ord '0')
+ else error ("finishExecute: not an int: " ++ B8.unpack str)
+
+throwResultError :: ByteString -> PQ.Result -> PQ.ExecStatus -> IO a
+throwResultError context result status = do
+ errormsg <- maybe "" id <$> PQ.resultErrorMessage result
+ statusmsg <- PQ.resStatus status
+ state <- maybe "" id <$> PQ.resultErrorField result PQ.DiagSqlstate
+ throwIO $ SqlError { sqlState = state
+ , sqlNativeError = fromEnum status
+ , sqlErrorMsg = B.concat [ context, ": ", statusmsg
+ , ": ", errormsg ]}
+
disconnectedError :: SqlError
disconnectedError = SqlError {
sqlNativeError = -1,
View
33 src/Database/PostgreSQL/Simple/ToField.hs
@@ -21,13 +21,10 @@ module Database.PostgreSQL.Simple.ToField
, inQuotes
) where
-import Blaze.ByteString.Builder (Builder, fromByteString, fromLazyByteString,
- toByteString)
+import Blaze.ByteString.Builder (Builder, fromByteString, toByteString)
import Blaze.ByteString.Builder.Char8 (fromChar)
import Blaze.Text (integral, double, float)
import Data.ByteString (ByteString)
-import qualified Data.ByteString.Base16 as B16
-import qualified Data.ByteString.Base16.Lazy as L16
import Data.Int (Int8, Int16, Int32, Int64)
import Data.List (intersperse)
import Data.Monoid (mappend)
@@ -58,14 +55,18 @@ data Action =
-- ^ Escape and enclose in quotes before substituting. Use for all
-- text-like types, and anything else that may contain unsafe
-- characters when rendered.
+ | EscapeBytea ByteString
+ -- ^ Escape binary data for use as a @bytea@ literal. Include surrounding
+ -- quotes. This is used by the 'Binary' newtype wrapper.
| Many [Action]
-- ^ Concatenate a series of rendering actions.
deriving (Typeable)
instance Show Action where
- show (Plain b) = "Plain " ++ show (toByteString b)
- show (Escape b) = "Escape " ++ show b
- show (Many b) = "Many " ++ show b
+ show (Plain b) = "Plain " ++ show (toByteString b)
+ show (Escape b) = "Escape " ++ show b
+ show (EscapeBytea b) = "EscapeBytea " ++ show b
+ show (Many b) = "Many " ++ show b
-- | A type that may be used as a single parameter to a SQL query.
class ToField a where
@@ -88,16 +89,6 @@ instance (ToField a) => ToField (In [a]) where
(intersperse (Plain (fromChar ',')) . map toField $ xs) ++
[Plain (fromChar ')')]
-instance ToField (Binary SB.ByteString) where
- toField (Binary bs) = Plain $ fromByteString "'\\x" `mappend`
- fromByteString (B16.encode bs) `mappend`
- fromChar '\''
-
-instance ToField (Binary LB.ByteString) where
- toField (Binary bs) = Plain $ fromByteString "'\\x" `mappend`
- fromLazyByteString (L16.encode bs) `mappend`
- fromChar '\''
-
renderNull :: Action
renderNull = Plain (fromByteString "null")
@@ -168,6 +159,14 @@ instance ToField Double where
| otherwise = Plain (double v)
{-# INLINE toField #-}
+instance ToField (Binary SB.ByteString) where
+ toField (Binary bs) = EscapeBytea bs
+ {-# INLINE toField #-}
+
+instance ToField (Binary LB.ByteString) where
+ toField (Binary bs) = (EscapeBytea . SB.concat . LB.toChunks) bs
+ {-# INLINE toField #-}
+
instance ToField SB.ByteString where
toField = Escape
{-# INLINE toField #-}
View
2 src/Database/PostgreSQL/Simple/Types.hs
@@ -107,7 +107,7 @@ newtype Only a = Only {
newtype In a = In a
deriving (Eq, Ord, Read, Show, Typeable, Functor)
--- | Wrap a mostly-binary string to be escaped in hexadecimal.
+-- | Wrap binary data for use as a @bytea@ value.
newtype Binary a = Binary a
deriving (Eq, Ord, Read, Show, Typeable, Functor)
View
36 test/Bytea.hs
@@ -0,0 +1,36 @@
+{-# LANGUAGE OverloadedStrings #-}
+module Bytea where
+
+import Data.ByteString (ByteString)
+import Data.Text (Text)
+import Database.PostgreSQL.Simple
+import Test.HUnit
+
+import qualified Crypto.Hash.MD5 as MD5
+import qualified Data.ByteString as B
+import qualified Data.ByteString.Base16 as Base16
+import qualified Data.Text.Encoding as TE
+
+testBytea :: Connection -> Test
+testBytea conn = TestList
+ [ testStr "empty" []
+ , testStr "\"hello\"" $ map (fromIntegral . fromEnum) ("hello" :: String)
+ , testStr "ascending" [0..255]
+ , testStr "descending" [255,254..0]
+ , testStr "ascending, doubled up" $ doubleUp [0..255]
+ , testStr "descending, doubled up" $ doubleUp [255,254..0]
+ ]
+ where
+ testStr label bytes = TestLabel label $ TestCase $ do
+ let bs = B.pack bytes
+
+ [Only h] <- query conn "SELECT md5(?::bytea)" [Binary bs]
+ assertBool "Haskell -> SQL conversion altered the string" $ md5 bs == h
+
+ [Only (Binary r)] <- query conn "SELECT ?::bytea" [Binary bs]
+ assertBool "SQL -> Haskell conversion altered the string" $ bs == r
+
+ doubleUp = concatMap (\x -> [x, x])
+
+md5 :: ByteString -> Text
+md5 = TE.decodeUtf8 . Base16.encode . MD5.hash
View
23 test/Main.hs
@@ -0,0 +1,23 @@
+{-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE NamedFieldPuns #-}
+
+import Control.Exception (bracket)
+import Control.Monad (when)
+import Database.PostgreSQL.Simple
+import System.Exit (exitFailure)
+import System.IO
+import Test.HUnit
+
+import Bytea
+
+tests :: Connection -> [Test]
+tests conn =
+ [ TestLabel "Bytea" $ testBytea conn
+ ]
+
+main :: IO ()
+main = do
+ mapM_ (`hSetBuffering` LineBuffering) [stdout, stderr]
+ bracket (connectPostgreSQL "") close $ \conn -> do
+ Counts{cases, tried, errors, failures} <- runTestTT $ TestList $ tests conn
+ when (cases /= tried || errors /= 0 || failures /= 0) $ exitFailure
View
15 test/WRITING-TESTS
@@ -0,0 +1,15 @@
+Main.hs is a small wrapper around HUnit that opens a database connection and
+starts the HUnit text-based test controller.
+
+To add a new module to the test suite, do the following:
+
+ * Create a new module that exports a function of type Connection -> Test.
+ 'Test' is basically a tree of 'Assertion's, where 'Assertion' is just a type
+ alias for IO (). See Bytea.hs for reference.
+
+ * Add an entry to 'tests' in Main.hs (along with the corresponding
+ module import) so the test driver will know about the new module.
+
+ * Add the module to postgresql-simple.cabal, under test-suite > other-modules.
+ Otherwise, the module will be left out of the tarball generated by
+ `cabal sdist`, and tests will fail to build when installing from Hackage.

0 comments on commit f9ee357

Please sign in to comment.