Skip to content

Commit

Permalink
Handle ?/value mismatch properly.
Browse files Browse the repository at this point in the history
  • Loading branch information
bos committed Apr 29, 2011
1 parent 520db84 commit 06eed69
Showing 1 changed file with 36 additions and 15 deletions.
51 changes: 36 additions & 15 deletions Database/MySQL/Simple.hs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
{-# LANGUAGE DeriveDataTypeable #-}

module Database.MySQL.Simple
(
execute
FormatError(fmtMessage, fmtQuery, fmtParams)
, Only(..)
, execute
, query
, query_
, formatQuery
Expand All @@ -9,31 +13,42 @@ module Database.MySQL.Simple
import Blaze.ByteString.Builder (fromByteString, toByteString)
import Control.Applicative ((<$>), pure)
import Control.DeepSeq (NFData(..))
import Control.Exception (Exception, throw)
import Control.Monad.Fix (fix)
import Data.ByteString (ByteString)
import Data.Int (Int64)
import Data.Monoid (mappend, mempty)
import Data.Monoid (mappend)
import Data.Typeable (Typeable)
import Database.MySQL.Base (Connection)
import Database.MySQL.Simple.Param (Action(..), inQuotes)
import Database.MySQL.Simple.QueryParams (QueryParams(..))
import Database.MySQL.Simple.QueryResults (QueryResults(..))
import Database.MySQL.Simple.Types (Query(..))
import Database.MySQL.Simple.Types (Only(..), Query(..))
import qualified Data.ByteString.Char8 as B
import qualified Database.MySQL.Base as Base

data FormatError = FormatError {
fmtMessage :: String
, fmtQuery :: Query
, fmtParams :: [ByteString]
} deriving (Eq, Show, Typeable)

instance Exception FormatError

formatQuery :: QueryParams q => Connection -> Query -> q -> IO ByteString
formatQuery conn (Query template) qs
| '?' `B.notElem` template = return template
| otherwise =
toByteString . zipParams (split template) <$> mapM sub (renderParams qs)
where sub (Plain b) = pure b
formatQuery conn q@(Query template) qs
| null xs && '?' `B.notElem` template = return template
| otherwise = toByteString . zipParams (split template) <$> mapM sub xs
where xs = renderParams qs
sub (Plain b) = pure b
sub (Escape s) = (inQuotes . fromByteString) <$> Base.escape conn s
split q = fromByteString h : if B.null t then [] else split (B.tail t)
where (h,t) = B.break (=='?') q
split s = fromByteString h : if B.null t then [] else split (B.tail t)
where (h,t) = B.break (=='?') s
zipParams (t:ts) (p:ps) = t `mappend` p `mappend` zipParams ts ps
zipParams [] [] = mempty
zipParams [] _ = fmtError "more parameters than '?' characters"
zipParams _ [] = fmtError "more '?' characters than parameters"
zipParams [t] [] = t
zipParams _ _ = fmtError (show (B.count '?' template) ++
" '?' characters, but " ++
show (length xs) ++ " parameters") q xs

execute :: (QueryParams q) => Connection -> Query -> q -> IO Int64
execute conn template qs = do
Expand Down Expand Up @@ -68,5 +83,11 @@ finishQuery conn = do
_ -> let c = convertResults fs row
in rnf c `seq` loop (c:acc)

fmtError :: String -> a
fmtError msg = error $ "Database.MySQL.formatQuery: " ++ msg
fmtError :: String -> Query -> [Action] -> a
fmtError msg q xs = throw FormatError {
fmtMessage = msg
, fmtQuery = q
, fmtParams = map twiddle xs
}
where twiddle (Plain b) = toByteString b
twiddle (Escape s) = s

0 comments on commit 06eed69

Please sign in to comment.