Skip to content

Commit

Permalink
Use PQescapeByteaConn instead of hex for compatibility with Postgres …
Browse files Browse the repository at this point in the history
…< 9.0
  • Loading branch information
joeyadams committed May 8, 2012
1 parent a7bee4b commit 33432ce
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 29 deletions.
1 change: 0 additions & 1 deletion postgresql-simple.cabal
Expand Up @@ -37,7 +37,6 @@ Library
Build-depends: Build-depends:
attoparsec >= 0.8.5.3, attoparsec >= 0.8.5.3,
base < 5, base < 5,
base16-bytestring,
blaze-builder, blaze-builder,
blaze-textual, blaze-textual,
bytestring >= 0.9, bytestring >= 0.9,
Expand Down
36 changes: 26 additions & 10 deletions src/Database/PostgreSQL/Simple.hs
Expand Up @@ -133,6 +133,8 @@ import Database.PostgreSQL.Simple.Internal as Base
import qualified Database.PostgreSQL.LibPQ as PQ import qualified Database.PostgreSQL.LibPQ as PQ
import Text.Regex.PCRE.Light (compile, caseless, match) import Text.Regex.PCRE.Light (compile, caseless, match)
import qualified Data.ByteString.Char8 as B import qualified Data.ByteString.Char8 as B
import qualified Data.Text as T
import qualified Data.Text.Encoding as TE
import qualified Data.Vector as V import qualified Data.Vector as V
import Control.Monad.Trans.Reader import Control.Monad.Trans.Reader
import Control.Monad.Trans.State.Strict import Control.Monad.Trans.State.Strict
Expand Down Expand Up @@ -202,16 +204,29 @@ formatMany conn q@(Query template) qs = do
\([^?]*)$" \([^?]*)$"
[caseless] [caseless]


escapeStringConn :: Connection -> ByteString -> IO (Maybe ByteString) escapeStringConn :: Connection -> ByteString -> IO (Either ByteString ByteString)
escapeStringConn conn s = withConnection conn $ \c -> do escapeStringConn conn s =
PQ.escapeStringConn c s withConnection conn $ \c ->
PQ.escapeStringConn c s >>= checkError c

escapeByteaConn :: Connection -> ByteString -> IO (Either ByteString ByteString)
escapeByteaConn conn s =
withConnection conn $ \c ->
PQ.escapeByteaConn c s >>= checkError c

checkError :: PQ.Connection -> Maybe a -> IO (Either ByteString a)
checkError c (Just x) = return $ Right x
checkError c Nothing = Left . maybe "" id <$> PQ.errorMessage c


buildQuery :: Connection -> Query -> ByteString -> [Action] -> IO Builder buildQuery :: Connection -> Query -> ByteString -> [Action] -> IO Builder
buildQuery conn q template xs = zipParams (split template) <$> mapM sub xs buildQuery conn q template xs = zipParams (split template) <$> mapM sub xs
where quote = inQuotes . fromByteString . maybe undefined id where quote = either (\msg -> fmtError (utf8ToString msg) q xs)
sub (Plain b) = pure b (inQuotes . fromByteString)
sub (Escape s) = quote <$> escapeStringConn conn s utf8ToString = T.unpack . TE.decodeUtf8
sub (Many ys) = mconcat <$> mapM sub ys sub (Plain b) = pure b
sub (Escape s) = quote <$> escapeStringConn conn s
sub (EscapeBytea s) = quote <$> escapeByteaConn conn s
sub (Many ys) = mconcat <$> mapM sub ys
split s = fromByteString h : if B.null t then [] else split (B.tail t) split s = fromByteString h : if B.null t then [] else split (B.tail t)
where (h,t) = B.break (=='?') s where (h,t) = B.break (=='?') s
zipParams (t:ts) (p:ps) = t `mappend` p `mappend` zipParams ts ps zipParams (t:ts) (p:ps) = t `mappend` p `mappend` zipParams ts ps
Expand Down Expand Up @@ -611,9 +626,10 @@ fmtError msg q xs = throw FormatError {
, fmtQuery = q , fmtQuery = q
, fmtParams = map twiddle xs , fmtParams = map twiddle xs
} }
where twiddle (Plain b) = toByteString b where twiddle (Plain b) = toByteString b
twiddle (Escape s) = s twiddle (Escape s) = s
twiddle (Many ys) = B.concat (map twiddle ys) twiddle (EscapeBytea s) = s
twiddle (Many ys) = B.concat (map twiddle ys)


-- $use -- $use
-- --
Expand Down
33 changes: 16 additions & 17 deletions src/Database/PostgreSQL/Simple/ToField.hs
Expand Up @@ -21,13 +21,10 @@ module Database.PostgreSQL.Simple.ToField
, inQuotes , inQuotes
) where ) where


import Blaze.ByteString.Builder (Builder, fromByteString, fromLazyByteString, import Blaze.ByteString.Builder (Builder, fromByteString, toByteString)
toByteString)
import Blaze.ByteString.Builder.Char8 (fromChar) import Blaze.ByteString.Builder.Char8 (fromChar)
import Blaze.Text (integral, double, float) import Blaze.Text (integral, double, float)
import Data.ByteString (ByteString) import Data.ByteString (ByteString)
import qualified Data.ByteString.Base16 as B16
import qualified Data.ByteString.Base16.Lazy as L16
import Data.Int (Int8, Int16, Int32, Int64) import Data.Int (Int8, Int16, Int32, Int64)
import Data.List (intersperse) import Data.List (intersperse)
import Data.Monoid (mappend) import Data.Monoid (mappend)
Expand Down Expand Up @@ -58,14 +55,18 @@ data Action =
-- ^ Escape and enclose in quotes before substituting. Use for all -- ^ Escape and enclose in quotes before substituting. Use for all
-- text-like types, and anything else that may contain unsafe -- text-like types, and anything else that may contain unsafe
-- characters when rendered. -- characters when rendered.
| EscapeBytea ByteString
-- ^ Escape binary data for use as a @bytea@ literal. Include surrounding
-- quotes. This is used by the 'Binary' newtype wrapper.
| Many [Action] | Many [Action]
-- ^ Concatenate a series of rendering actions. -- ^ Concatenate a series of rendering actions.
deriving (Typeable) deriving (Typeable)


instance Show Action where instance Show Action where
show (Plain b) = "Plain " ++ show (toByteString b) show (Plain b) = "Plain " ++ show (toByteString b)
show (Escape b) = "Escape " ++ show b show (Escape b) = "Escape " ++ show b
show (Many b) = "Many " ++ show b show (EscapeBytea b) = "EscapeBytea " ++ show b
show (Many b) = "Many " ++ show b


-- | A type that may be used as a single parameter to a SQL query. -- | A type that may be used as a single parameter to a SQL query.
class ToField a where class ToField a where
Expand All @@ -88,16 +89,6 @@ instance (ToField a) => ToField (In [a]) where
(intersperse (Plain (fromChar ',')) . map toField $ xs) ++ (intersperse (Plain (fromChar ',')) . map toField $ xs) ++
[Plain (fromChar ')')] [Plain (fromChar ')')]


instance ToField (Binary SB.ByteString) where
toField (Binary bs) = Plain $ fromByteString "'\\x" `mappend`
fromByteString (B16.encode bs) `mappend`
fromChar '\''

instance ToField (Binary LB.ByteString) where
toField (Binary bs) = Plain $ fromByteString "'\\x" `mappend`
fromLazyByteString (L16.encode bs) `mappend`
fromChar '\''

renderNull :: Action renderNull :: Action
renderNull = Plain (fromByteString "null") renderNull = Plain (fromByteString "null")


Expand Down Expand Up @@ -168,6 +159,14 @@ instance ToField Double where
| otherwise = Plain (double v) | otherwise = Plain (double v)
{-# INLINE toField #-} {-# INLINE toField #-}


instance ToField (Binary SB.ByteString) where
toField (Binary bs) = EscapeBytea bs
{-# INLINE toField #-}

instance ToField (Binary LB.ByteString) where
toField (Binary bs) = (EscapeBytea . SB.concat . LB.toChunks) bs
{-# INLINE toField #-}

instance ToField SB.ByteString where instance ToField SB.ByteString where
toField = Escape toField = Escape
{-# INLINE toField #-} {-# INLINE toField #-}
Expand Down
2 changes: 1 addition & 1 deletion src/Database/PostgreSQL/Simple/Types.hs
Expand Up @@ -107,7 +107,7 @@ newtype Only a = Only {
newtype In a = In a newtype In a = In a
deriving (Eq, Ord, Read, Show, Typeable, Functor) deriving (Eq, Ord, Read, Show, Typeable, Functor)


-- | Wrap a mostly-binary string to be escaped in hexadecimal. -- | Wrap binary data for use as a @bytea@ value.
newtype Binary a = Binary a newtype Binary a = Binary a
deriving (Eq, Ord, Read, Show, Typeable, Functor) deriving (Eq, Ord, Read, Show, Typeable, Functor)


Expand Down

0 comments on commit 33432ce

Please sign in to comment.