Skip to content
Browse files

First draft of COPY IN/OUT support

  • Loading branch information...
1 parent 93c556f commit 491247f687fcbfd30c8b0f50bf7e4490a8140fba @lpsmith committed Jun 27, 2013
Showing with 173 additions and 5 deletions.
  1. +1 −0 postgresql-simple.cabal
  2. +155 −0 src/Database/PostgreSQL/Simple/Copy.hs
  3. +17 −5 src/Database/PostgreSQL/Simple/Internal.hs
View
1 postgresql-simple.cabal
@@ -23,6 +23,7 @@ Library
Database.PostgreSQL.Simple
Database.PostgreSQL.Simple.Arrays
Database.PostgreSQL.Simple.BuiltinTypes
+ Database.PostgreSQL.Simple.Copy
Database.PostgreSQL.Simple.FromField
Database.PostgreSQL.Simple.FromRow
Database.PostgreSQL.Simple.LargeObjects
View
155 src/Database/PostgreSQL/Simple/Copy.hs
@@ -0,0 +1,155 @@
+{-# LANGUAGE CPP #-}
+
+------------------------------------------------------------------------------
+-- |
+-- Module: Database.PostgreSQL.Simple.Copy
+-- Copyright: (c) 2013 Leon P Smith
+-- License: BSD3
+-- Maintainer: Leon P Smith <leon@melding-monads.com>
+-- Stability: experimental
+--
+-- mid-level support for COPY IN and COPY OUT.
+--
+------------------------------------------------------------------------------
+
+module Database.PostgreSQL.Simple.Copy
+ ( copy
+ , copy_
+ , getCopyData
+ , putCopyData
+ , putCopyEnd
+ , putCopyError
+ ) where
+
+import Control.Applicative
+import Control.Concurrent ( threadWaitRead, threadWaitWrite )
+import Control.Exception ( throwIO )
+import qualified Data.Attoparsec.ByteString.Char8 as P
+import Data.Int(Int64)
+import qualified Data.ByteString.Char8 as B
+import qualified Database.PostgreSQL.LibPQ as PQ
+import Database.PostgreSQL.Simple hiding
+ ( fold, fold_, forEach, forEach_ )
+import Database.PostgreSQL.Simple.Types
+import Database.PostgreSQL.Simple.Internal
+
+copy :: ( ToRow params ) => Connection -> Query -> params -> IO ()
+copy conn template qs = do
+ q <- formatQuery conn template qs
+ doCopy "Database.PostgreSQL.Simple.Copy.copy" conn template q
+
+copy_ :: Connection -> Query -> IO ()
+copy_ conn (Query q) = do
+ doCopy "Database.PostgreSQL.Simple.Copy.copy_" conn (Query q) q
+
+doCopy :: B.ByteString -> Connection -> Query -> B.ByteString -> IO ()
+doCopy funcName conn template q = do
+ result <- exec conn q
+ status <- PQ.resultStatus result
+ let err = throwIO $ QueryError
+ (B.unpack funcName ++ " " ++ show status)
+ template
+ case status of
+ PQ.EmptyQuery -> err
+ PQ.CommandOk -> err
+ PQ.TuplesOk -> err
+ PQ.CopyOut -> return ()
+ PQ.CopyIn -> return ()
+ PQ.BadResponse -> throwResultError funcName result status
+ PQ.NonfatalError -> throwResultError funcName result status
+ PQ.FatalError -> throwResultError funcName result status
+{-# INLINE doCopy #-}
+
+data CopyOutResult
+ = CopyOutRow !B.ByteString -- ^ Data representing exactly one row
+ -- of the result.
+ | CopyOutDone {-# UNPACK #-} !Int64 -- ^ No more rows, and a count of the
+ -- number of rows returned.
+
+getCopyData :: Connection -> IO CopyOutResult
+getCopyData conn = withConnection conn loop
+ where
+ funcName = "Database.PostgreSQL.Simple.Copy.getCopyData"
+ errCmdStatus = B.unpack funcName ++ ": failed to fetch command status"
+ errCmdStatusFmt = B.unpack funcName ++ ": failed to parse command status"
+ loop pqconn = do
+#if defined(mingw32_HOST_OS)
+ row <- PQ.getCopyData pqconn False
+#else
+ row <- PQ.getCopyData pqconn True
+#endif
+ case row of
+ PQ.CopyOutRow rowdata -> return (CopyOutRow rowdata)
+ PQ.CopyOutDone -> do
+ result <- maybe (fail errCmdStatus) return =<< PQ.getResult pqconn
+ cmdStat <- maybe (fail errCmdStatus) return =<< PQ.cmdStatus result
+ let rowCount = P.string "COPY " *> P.decimal
+ case P.parseOnly rowCount cmdStat of
+ Left _ -> fail errCmdStatusFmt
+ Right n -> return (CopyOutDone n)
+#if defined(mingw32_HOST_OS)
+ PQ.CopyOutWouldBlock -> do
+ fail (B.unpack funcName ++ ": the impossible happened")
+#else
+ PQ.CopyOutWouldBlock -> do
+ mfd <- PQ.socket pqconn
+ case mfd of
+ Nothing -> throwIO (fdError funcName)
+ Just fd -> do
+ threadWaitRead fd
+ _ <- PQ.consumeInput pqconn
+ loop pqconn
+#endif
+ PQ.CopyOutError -> do
+ mmsg <- PQ.errorMessage pqconn
+ throwIO SqlError {
+ sqlState = "",
+ sqlExecStatus = FatalError,
+ sqlErrorMsg = maybe "" id mmsg,
+ sqlErrorDetail = "",
+ sqlErrorHint = funcName
+ }
+
+putCopyData :: Connection -> B.ByteString -> IO ()
+putCopyData conn dat =
+ doCopyIn "Database.PostgreSQL.Simple.Copy.putCopyData"
+ (\c -> PQ.putCopyData c dat)
+ conn
+
+putCopyEnd :: Connection -> IO ()
+putCopyEnd conn = do
+ doCopyIn "Database.PostgreSQL.Simple.Copy.putCopyEnd"
+ (\c -> PQ.putCopyEnd c Nothing)
+ conn
+
+putCopyError :: Connection -> B.ByteString -> IO ()
+putCopyError conn err = do
+ doCopyIn "Database.PostgreSQL.Simple.Copy.putCopyError"
+ (\c -> PQ.putCopyEnd c (Just err))
+ conn
+
+doCopyIn :: B.ByteString -> (PQ.Connection -> IO PQ.CopyInResult)
+ -> Connection -> IO ()
+doCopyIn funcName action conn = withConnection conn loop
+ where
+ loop pqconn = do
+ stat <- action pqconn
+ case stat of
+ PQ.CopyInOk -> return ()
+ PQ.CopyInError -> do
+ mmsg <- PQ.errorMessage pqconn
+ throwIO SqlError {
+ sqlState = "",
+ sqlExecStatus = FatalError,
+ sqlErrorMsg = maybe "" id mmsg,
+ sqlErrorDetail = "",
+ sqlErrorHint = funcName
+ }
+ PQ.CopyInWouldBlock -> do
+ mfd <- PQ.socket pqconn
+ case mfd of
+ Nothing -> throwIO (fdError funcName)
+ Just fd -> do
+ threadWaitWrite fd
+ loop pqconn
+{-# INLINE doCopyIn #-}
View
22 src/Database/PostgreSQL/Simple/Internal.hs
@@ -2,14 +2,14 @@
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE RecordWildCards #-}
+
------------------------------------------------------------------------------
-- |
-- Module: Database.PostgreSQL.Simple.Internal
-- Copyright: (c) 2011-2012 Leon P Smith
-- License: BSD3
-- Maintainer: Leon P Smith <leon@melding-monads.com>
-- Stability: experimental
--- Portability: portable
--
-- Internal bits. This interface is less stable and can change at any time.
-- In particular this means that while the rest of the postgresql-simple
@@ -44,6 +44,7 @@ import Database.PostgreSQL.Simple.Types (Query(..))
import Database.PostgreSQL.Simple.TypeInfo.Types(TypeInfo)
import Control.Monad.Trans.State.Strict
import Control.Monad.Trans.Reader
+import GHC.IO.Exception
-- | A Field represents metadata about a particular field
--
@@ -54,8 +55,8 @@ import Control.Monad.Trans.Reader
data Field = Field {
result :: !PQ.Result
, column :: {-# UNPACK #-} !PQ.Column
- , typeOid :: {-# UNPACK #-} !PQ.Oid
- -- ^ This returns the type oid associated with the column. Analogous
+ , typeOid :: {-# UNPACK #-} !PQ.Oid
+ -- ^ This returns the type oid associated with the column. Analogous
-- to libpq's @PQftype@.
}
@@ -315,11 +316,11 @@ instance Alternative Conversion where
case oka of
Ok _ -> return oka
Errors _ -> (oka <|>) <$> runConversion mb conn
-
+
instance Monad Conversion where
return a = Conversion $ \_conn -> return (return a)
m >>= f = Conversion $ \conn -> do
- oka <- runConversion m conn
+ oka <- runConversion m conn
case oka of
Ok a -> runConversion (f a) conn
Errors err -> return (Errors err)
@@ -339,3 +340,14 @@ newTempName Connection{..} = do
!n <- atomicModifyIORef connectionTempNameCounter
(\n -> let !n' = n+1 in (n', n'))
return $! Query $ B8.pack $ "temp" ++ show n
+
+-- FIXME? What error should getNotification and getCopyData throw?
+fdError :: ByteString -> IOError
+fdError funcName = IOError {
+ ioe_handle = Nothing,
+ ioe_type = ResourceVanished,
+ ioe_location = B8.unpack funcName,
+ ioe_description = "failed to fetch file descriptor",
+ ioe_errno = Nothing,
+ ioe_filename = Nothing
+ }

0 comments on commit 491247f

Please sign in to comment.
Something went wrong with that request. Please try again.