diff --git a/postgresql-simple.cabal b/postgresql-simple.cabal index 88e42f08..9a7c6069 100644 --- a/postgresql-simple.cabal +++ b/postgresql-simple.cabal @@ -37,7 +37,6 @@ Library Build-depends: attoparsec >= 0.8.5.3, base < 5, - base16-bytestring, blaze-builder, blaze-textual, bytestring >= 0.9, diff --git a/src/Database/PostgreSQL/Simple.hs b/src/Database/PostgreSQL/Simple.hs index 15fbe1fe..83a083d5 100644 --- a/src/Database/PostgreSQL/Simple.hs +++ b/src/Database/PostgreSQL/Simple.hs @@ -133,6 +133,8 @@ import Database.PostgreSQL.Simple.Internal as Base import qualified Database.PostgreSQL.LibPQ as PQ import Text.Regex.PCRE.Light (compile, caseless, match) import qualified Data.ByteString.Char8 as B +import qualified Data.Text as T +import qualified Data.Text.Encoding as TE import qualified Data.Vector as V import Control.Monad.Trans.Reader import Control.Monad.Trans.State.Strict @@ -202,16 +204,29 @@ formatMany conn q@(Query template) qs = do \([^?]*)$" [caseless] -escapeStringConn :: Connection -> ByteString -> IO (Maybe ByteString) -escapeStringConn conn s = withConnection conn $ \c -> do - PQ.escapeStringConn c s +escapeStringConn :: Connection -> ByteString -> IO (Either ByteString ByteString) +escapeStringConn conn s = + withConnection conn $ \c -> + PQ.escapeStringConn c s >>= checkError c + +escapeByteaConn :: Connection -> ByteString -> IO (Either ByteString ByteString) +escapeByteaConn conn s = + withConnection conn $ \c -> + PQ.escapeByteaConn c s >>= checkError c + +checkError :: PQ.Connection -> Maybe a -> IO (Either ByteString a) +checkError c (Just x) = return $ Right x +checkError c Nothing = Left . maybe "" id <$> PQ.errorMessage c buildQuery :: Connection -> Query -> ByteString -> [Action] -> IO Builder buildQuery conn q template xs = zipParams (split template) <$> mapM sub xs - where quote = inQuotes . fromByteString . maybe undefined id - sub (Plain b) = pure b - sub (Escape s) = quote <$> escapeStringConn conn s - sub (Many ys) = mconcat <$> mapM sub ys + where quote = either (\msg -> fmtError (utf8ToString msg) q xs) + (inQuotes . fromByteString) + utf8ToString = T.unpack . TE.decodeUtf8 + sub (Plain b) = pure b + sub (Escape s) = quote <$> escapeStringConn conn s + sub (EscapeBytea s) = quote <$> escapeByteaConn conn s + sub (Many ys) = mconcat <$> mapM sub ys split s = fromByteString h : if B.null t then [] else split (B.tail t) where (h,t) = B.break (=='?') s zipParams (t:ts) (p:ps) = t `mappend` p `mappend` zipParams ts ps @@ -611,9 +626,10 @@ fmtError msg q xs = throw FormatError { , fmtQuery = q , fmtParams = map twiddle xs } - where twiddle (Plain b) = toByteString b - twiddle (Escape s) = s - twiddle (Many ys) = B.concat (map twiddle ys) + where twiddle (Plain b) = toByteString b + twiddle (Escape s) = s + twiddle (EscapeBytea s) = s + twiddle (Many ys) = B.concat (map twiddle ys) -- $use -- diff --git a/src/Database/PostgreSQL/Simple/ToField.hs b/src/Database/PostgreSQL/Simple/ToField.hs index 07c8ed37..159b676c 100644 --- a/src/Database/PostgreSQL/Simple/ToField.hs +++ b/src/Database/PostgreSQL/Simple/ToField.hs @@ -21,13 +21,10 @@ module Database.PostgreSQL.Simple.ToField , inQuotes ) where -import Blaze.ByteString.Builder (Builder, fromByteString, fromLazyByteString, - toByteString) +import Blaze.ByteString.Builder (Builder, fromByteString, toByteString) import Blaze.ByteString.Builder.Char8 (fromChar) import Blaze.Text (integral, double, float) import Data.ByteString (ByteString) -import qualified Data.ByteString.Base16 as B16 -import qualified Data.ByteString.Base16.Lazy as L16 import Data.Int (Int8, Int16, Int32, Int64) import Data.List (intersperse) import Data.Monoid (mappend) @@ -58,14 +55,18 @@ data Action = -- ^ Escape and enclose in quotes before substituting. Use for all -- text-like types, and anything else that may contain unsafe -- characters when rendered. + | EscapeBytea ByteString + -- ^ Escape binary data for use as a @bytea@ literal. Include surrounding + -- quotes. This is used by the 'Binary' newtype wrapper. | Many [Action] -- ^ Concatenate a series of rendering actions. deriving (Typeable) instance Show Action where - show (Plain b) = "Plain " ++ show (toByteString b) - show (Escape b) = "Escape " ++ show b - show (Many b) = "Many " ++ show b + show (Plain b) = "Plain " ++ show (toByteString b) + show (Escape b) = "Escape " ++ show b + show (EscapeBytea b) = "EscapeBytea " ++ show b + show (Many b) = "Many " ++ show b -- | A type that may be used as a single parameter to a SQL query. class ToField a where @@ -88,16 +89,6 @@ instance (ToField a) => ToField (In [a]) where (intersperse (Plain (fromChar ',')) . map toField $ xs) ++ [Plain (fromChar ')')] -instance ToField (Binary SB.ByteString) where - toField (Binary bs) = Plain $ fromByteString "'\\x" `mappend` - fromByteString (B16.encode bs) `mappend` - fromChar '\'' - -instance ToField (Binary LB.ByteString) where - toField (Binary bs) = Plain $ fromByteString "'\\x" `mappend` - fromLazyByteString (L16.encode bs) `mappend` - fromChar '\'' - renderNull :: Action renderNull = Plain (fromByteString "null") @@ -168,6 +159,14 @@ instance ToField Double where | otherwise = Plain (double v) {-# INLINE toField #-} +instance ToField (Binary SB.ByteString) where + toField (Binary bs) = EscapeBytea bs + {-# INLINE toField #-} + +instance ToField (Binary LB.ByteString) where + toField (Binary bs) = (EscapeBytea . SB.concat . LB.toChunks) bs + {-# INLINE toField #-} + instance ToField SB.ByteString where toField = Escape {-# INLINE toField #-} diff --git a/src/Database/PostgreSQL/Simple/Types.hs b/src/Database/PostgreSQL/Simple/Types.hs index 36425c9c..fe015e34 100644 --- a/src/Database/PostgreSQL/Simple/Types.hs +++ b/src/Database/PostgreSQL/Simple/Types.hs @@ -107,7 +107,7 @@ newtype Only a = Only { newtype In a = In a deriving (Eq, Ord, Read, Show, Typeable, Functor) --- | Wrap a mostly-binary string to be escaped in hexadecimal. +-- | Wrap binary data for use as a @bytea@ value. newtype Binary a = Binary a deriving (Eq, Ord, Read, Show, Typeable, Functor)