Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Arrays #33

Closed
wants to merge 11 commits into from
2 changes: 2 additions & 0 deletions postgresql-simple.cabal
Expand Up @@ -21,6 +21,7 @@ Library
hs-source-dirs: src
Exposed-modules:
Database.PostgreSQL.Simple
Database.PostgreSQL.Simple.Arrays
Database.PostgreSQL.Simple.BuiltinTypes
Database.PostgreSQL.Simple.FromField
Database.PostgreSQL.Simple.FromRow
Expand Down Expand Up @@ -86,6 +87,7 @@ test-suite test
, OverloadedStrings
, Rank2Types
, RecordWildCards
, PatternGuards

build-depends: base
, base16-bytestring
Expand Down
56 changes: 35 additions & 21 deletions src/Database/PostgreSQL/Simple.hs
Expand Up @@ -4,6 +4,7 @@
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE QuasiQuotes #-}

------------------------------------------------------------------------------
-- |
Expand Down Expand Up @@ -140,6 +141,7 @@ import Database.PostgreSQL.Simple.ToRow (ToRow(..))
import Database.PostgreSQL.Simple.Types
( Binary(..), In(..), Only(..), Query(..), (:.)(..) )
import Database.PostgreSQL.Simple.Internal as Base
import Database.PostgreSQL.Simple.SqlQQ (sql)
import qualified Database.PostgreSQL.LibPQ as PQ
import qualified Data.ByteString.Char8 as B
import qualified Data.Text as T
Expand Down Expand Up @@ -549,19 +551,19 @@ finishQuery conn q result = do
PQ.TuplesOk -> do
ncols <- PQ.nfields result
let unCol (PQ.Col x) = fromIntegral x :: Int
typenames <- V.generateM (unCol ncols)
typeinfos <- V.generateM (unCol ncols)
(\(PQ.Col . fromIntegral -> col) -> do
getTypename conn =<< PQ.ftype result col)
getTypeInfo conn =<< PQ.ftype result col)
nrows <- PQ.ntuples result
ncols <- PQ.nfields result
forM' 0 (nrows-1) $ \row -> do
let rw = Row row typenames result
let rw = Row row typeinfos result
case runStateT (runReaderT (unRP fromRow) rw) 0 of
Ok (val,col) | col == ncols -> return val
| otherwise -> do
vals <- forM' 0 (ncols-1) $ \c -> do
v <- PQ.getvalue result row c
return ( typenames V.! unCol c
return ( typeinfos V.! unCol c
, fmap ellipsis v )
throw (ConversionFailed
(show (unCol ncols) ++ " values: " ++ show vals)
Expand Down Expand Up @@ -963,24 +965,36 @@ fmtError msg q xs = throw FormatError {
-- wrong results. In such cases, write a @newtype@ wrapper and a
-- custom 'Result' instance to handle your encoding.

getTypename :: Connection -> PQ.Oid -> IO ByteString
getTypename conn@Connection{..} oid =
getTypeInfo :: Connection -> PQ.Oid -> IO TypeInfo
getTypeInfo conn@Connection{..} oid =
case oid2typname oid of
Just name -> return name
Just name -> return $! TypeInfo { typ = NamedOid oid name
, typelem = Nothing
}
Nothing -> modifyMVar connectionObjects $ \oidmap -> do
case IntMap.lookup (oid2int oid) oidmap of
Just name -> return (oidmap, name)
Just typeinfo -> return (oidmap, typeinfo)
Nothing -> do
names <- query conn "SELECT typname FROM pg_type WHERE oid=?"
(Only oid)
name <- case names of
[] -> return $ throw SqlError {
sqlNativeError = -1,
sqlErrorMsg = "invalid type oid",
sqlState = ""
}
[Only x] -> return x
_ -> fail "typename query returned more than one result"
-- oid is a primary key, so the query should
-- never return more than one result
return (IntMap.insert (oid2int oid) name oidmap, name)
names <- query conn
[sql| SELECT p.oid, p.typname, c.oid, c.typname
FROM pg_type AS p LEFT OUTER JOIN pg_type AS c
ON c.oid = p.typelem
WHERE p.oid = ?
|] (Only oid)
typinf <- case names of
[] -> return $ throw SqlError {
sqlNativeError = -1,
sqlErrorMsg = "invalid type oid",
sqlState = ""
}
[(pOid, pTypName, mbCOid, mbCTypName)] ->
return $! TypeInfo { typ = NamedOid pOid pTypName
, typelem = do
cOid <- mbCOid
cTypName <- mbCTypName
return $ NamedOid cOid cTypName
}
_ -> fail "typename query returned more than one result"
-- oid is a primary key, so the query should
-- never return more than one result
return (IntMap.insert (oid2int oid) typinf oidmap, typinf)
94 changes: 94 additions & 0 deletions src/Database/PostgreSQL/Simple/Arrays.hs
@@ -0,0 +1,94 @@
{-# LANGUAGE PatternGuards #-}

------------------------------------------------------------------------------
-- |
-- Module: Database.PostgreSQL.Simple.Arrays
-- Copyright: (c) 2012 Leon P Smith
-- License: BSD3
-- Maintainer: Leon P Smith <leon@melding-monads.com>
-- Stability: experimental
-- Portability: portable
--
-- A Postgres array parser and pretty-printer.
------------------------------------------------------------------------------

module Database.PostgreSQL.Simple.Arrays where

import Control.Applicative (Applicative(..), Alternative(..), (<$>))
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Monoid
import Data.Attoparsec.Char8


-- | Parse one of three primitive field formats: array, quoted and plain.
arrayFormat :: Char -> Parser ArrayFormat
arrayFormat delim = Array <$> array delim
<|> Plain <$> plain delim
<|> Quoted <$> quoted

data ArrayFormat = Array [ArrayFormat]
| Plain ByteString
| Quoted ByteString
deriving (Eq, Show, Ord)

array :: Char -> Parser [ArrayFormat]
array delim = char '{' *> option [] (arrays <|> strings) <* char '}'
where
strings = sepBy1 (Quoted <$> quoted <|> Plain <$> plain delim) (char delim)
arrays = sepBy1 (Array <$> array delim) (char ',')
-- NB: Arrays seem to always be delimited by commas.

-- | Recognizes a quoted string.
quoted :: Parser ByteString
quoted = char '"' *> option "" contents <* char '"'
where
esc = char '\\' *> (char '\\' <|> char '"')
unQ = takeWhile1 (notInClass "\"\\")
contents = mconcat <$> many (unQ <|> B.singleton <$> esc)

-- | Recognizes a plain string literal, not containing quotes or brackets and
-- not containing the delimiter character.
plain :: Char -> Parser ByteString
plain delim = takeWhile1 (notInClass (delim:"\"{}"))

-- Mutually recursive 'fmt' and 'delimit' separate out value formatting
-- from the subtleties of delimiting.

-- | Format an array format item, using the delimiter character if the item is
-- itself an array.
fmt :: Char -> ArrayFormat -> ByteString
fmt = fmt' False

-- | Format a list of array format items, inserting the appropriate delimiter
-- between them. When the items are arrays, they will be delimited with
-- commas; otherwise, they are delimited with the passed-in-delimiter.
delimit :: Char -> [ArrayFormat] -> ByteString
delimit _ [] = ""
delimit c [x] = fmt' True c x
delimit c (x:y:z) = fmt' True c x `B.snoc` c' `mappend` delimit c (y:z)
where
c' | Array _ <- x = ','
| otherwise = c

-- | Format an array format item, using the delimiter character if the item is
-- itself an array, optionally applying quoting rules. Creates copies for
-- safety when used in 'FromField' instances.
fmt' :: Bool -> Char -> ArrayFormat -> ByteString
fmt' quoting c x =
case x of
Array items -> '{' `B.cons` delimit c items `B.snoc` '}'
Plain bytes -> B.copy bytes
Quoted q | quoting -> '"' `B.cons` esc q `B.snoc` '"'
| otherwise -> B.copy q
-- NB: The 'snoc' and 'cons' functions always copy.

-- | Escape a string according to Postgres double-quoted string format.
esc :: ByteString -> ByteString
esc = B.concatMap f
where
f '"' = "\\\""
f '\\' = "\\\\"
f c = B.singleton c
-- TODO: Implement easy performance improvements with unfoldr.

21 changes: 20 additions & 1 deletion src/Database/PostgreSQL/Simple/FromField.hs
Expand Up @@ -53,15 +53,19 @@ import Data.ByteString (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Int (Int16, Int32, Int64)
import Data.List (foldl')
import Data.Maybe (fromMaybe)
import Data.Ratio (Ratio)
import Data.Time ( UTCTime, ZonedTime, LocalTime, Day, TimeOfDay )
import Data.Typeable (Typeable, typeOf)
import Data.Vector (Vector)
import qualified Data.Vector as V
import Data.Word (Word64)
import Database.PostgreSQL.Simple.Internal
import Database.PostgreSQL.Simple.BuiltinTypes
import Database.PostgreSQL.Simple.Ok
import Database.PostgreSQL.Simple.Types (Binary(..), Null(..))
import Database.PostgreSQL.Simple.Time
import Database.PostgreSQL.Simple.Arrays
import qualified Database.PostgreSQL.LibPQ as PQ
import System.IO.Unsafe (unsafePerformIO)
import qualified Data.ByteString as SB
Expand Down Expand Up @@ -241,6 +245,21 @@ instance (FromField a, FromField b) => FromField (Either a b) where
fromField f dat = (Right <$> fromField f dat)
<|> (Left <$> fromField f dat)

instance (FromField a, Typeable a) => FromField (Vector a) where
fromField f dat = either (returnError ConversionFailed f)
(V.fromList <$>)
(parseOnly (fromArray ',' f) (maybe "" id dat))

fromArray :: (FromField a) => Char -> Field -> Parser (Ok [a])
fromArray delim f = sequence . (parseIt <$>) <$> array delim
where
fElem = f{ typeinfo = TypeInfo tElem Nothing }
tInfo = typeinfo f
tElem = fromMaybe (typ tInfo) (typelem tInfo)
parseIt item = (fromField f' . Just . fmt delim) item
where f' | Array _ <- item = f
| otherwise = fElem

newtype Compat = Compat Word64

mkCompats :: [BuiltinType] -> Compat
Expand Down Expand Up @@ -270,7 +289,7 @@ doFromField :: forall a . (Typeable a)
=> Field -> Compat -> (ByteString -> Ok a)
-> Maybe ByteString -> Ok a
doFromField f types cvt (Just bs)
| Just typ <- oid2builtin (typeOid f)
| Just typ <- oid2builtin (typoid $ typ $ typeinfo f)
, mkCompat typ `compat` types = cvt bs
| otherwise = returnError Incompatible f "types incompatible"
doFromField f _ _ _ = returnError UnexpectedNull f ""
Expand Down
6 changes: 3 additions & 3 deletions src/Database/PostgreSQL/Simple/FromRow.hs
Expand Up @@ -68,13 +68,13 @@ class FromRow a where
fieldWith :: FieldParser a -> RowParser a
fieldWith fieldP = RP $ do
let unCol (PQ.Col x) = fromIntegral x :: Int
Row{..} <- ask
r@Row{..} <- ask
column <- lift get
lift (put (column + 1))
let ncols = nfields rowresult
if (column >= ncols)
then do
let vals = map (\c -> ( typenames ! (unCol c)
let vals = map (\c -> ( typenames r ! (unCol c)
, fmap ellipsis (getvalue rowresult row c) ))
[0..ncols-1]
convertError = ConversionFailed
Expand All @@ -85,7 +85,7 @@ fieldWith fieldP = RP $ do
\convert and number in target type"
lift (lift (Errors [SomeException convertError]))
else do
let typename = typenames ! unCol column
let typeinfo = typeinfos ! unCol column
result = rowresult
field = Field{..}
lift (lift (fieldP field (getvalue result row column)))
Expand Down
23 changes: 18 additions & 5 deletions src/Database/PostgreSQL/Simple/Internal.hs
Expand Up @@ -52,9 +52,20 @@ import System.IO.Unsafe (unsafePerformIO)
data Field = Field {
result :: !PQ.Result
, column :: {-# UNPACK #-} !PQ.Column
, typename :: !ByteString
, typeinfo :: !TypeInfo
}

data NamedOid = NamedOid { typoid :: !PQ.Oid
, typname :: !ByteString
} deriving Show

data TypeInfo = TypeInfo { typ :: !NamedOid
, typelem :: !(Maybe NamedOid)
} deriving Show

typename :: Field -> ByteString
typename = typname . typ . typeinfo

name :: Field -> Maybe ByteString
name Field{..} = unsafePerformIO (PQ.fname result column)

Expand All @@ -71,12 +82,11 @@ format :: Field -> PQ.Format
format Field{..} = unsafePerformIO (PQ.fformat result column)

typeOid :: Field -> PQ.Oid
typeOid Field{..} = unsafePerformIO (PQ.ftype result column)

typeOid = typoid . typ . typeinfo

data Connection = Connection {
connectionHandle :: {-# UNPACK #-} !(MVar PQ.Connection)
, connectionObjects :: {-# UNPACK #-} !(MVar (IntMap.IntMap ByteString))
, connectionObjects :: {-# UNPACK #-} !(MVar (IntMap.IntMap TypeInfo))
}

data SqlType
Expand Down Expand Up @@ -301,10 +311,13 @@ newNullConnection = do

data Row = Row {
row :: {-# UNPACK #-} !PQ.Row
, typenames :: !(V.Vector ByteString)
, typeinfos :: !(V.Vector TypeInfo)
, rowresult :: !PQ.Result
}

typenames :: Row -> V.Vector ByteString
typenames = V.map (typname . typ) . typeinfos

newtype RowParser a = RP { unRP :: ReaderT Row (StateT PQ.Column Ok) a }
deriving ( Functor, Applicative, Alternative, Monad )

Expand Down
10 changes: 10 additions & 0 deletions src/Database/PostgreSQL/Simple/ToField.hs
Expand Up @@ -39,6 +39,8 @@ import qualified Data.ByteString.Lazy as LB
import qualified Data.Text as ST
import qualified Data.Text.Encoding as ST
import qualified Data.Text.Lazy as LT
import Data.Vector (Vector)
import qualified Data.Vector as V
import qualified Database.PostgreSQL.LibPQ as PQ
import Database.PostgreSQL.Simple.Time

Expand Down Expand Up @@ -221,6 +223,14 @@ instance ToField Date where
toField = Plain . inQuotes . dateToBuilder
{-# INLINE toField #-}

instance (ToField a) => ToField (Vector a) where
toField xs = Many $
Plain (fromByteString "ARRAY[") :
(intersperse (Plain (fromChar ',')) . map toField $ V.toList xs) ++
[Plain (fromChar ']')]
-- Because the ARRAY[...] input syntax is being used, it is possible
-- that the use of type-specific separator characters is unnecessary.

-- | Surround a string with single-quote characters: \"@'@\"
--
-- This function /does not/ perform any other escaping.
Expand Down