Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add COPY FROM support #3

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
193 changes: 193 additions & 0 deletions Database/PostgreSQL/LibPQ.hsc
Expand Up @@ -143,6 +143,16 @@ module Database.PostgreSQL.LibPQ
, escapeByteaConn , escapeByteaConn
, unescapeBytea , unescapeBytea


-- * Using COPY FROM
-- $copyfrom
, CopyResult(..)
, putCopyData
, putCopyEnd

-- ** Formatting datums
, formatCopyRow
, putCopyRow

-- * Asynchronous Command Processing -- * Asynchronous Command Processing
-- $asynccommand -- $asynccommand
, sendQuery , sendQuery
Expand Down Expand Up @@ -203,6 +213,9 @@ import Prelude hiding ( print )
import Foreign import Foreign
import Foreign.C.Types import Foreign.C.Types
import Foreign.C.String import Foreign.C.String
#if __GLASGOW_HASKELL__ >= 702
import qualified Foreign.ForeignPtr.Unsafe as Unsafe
#endif
import qualified Foreign.Concurrent as FC import qualified Foreign.Concurrent as FC
import System.Posix.Types ( Fd(..) ) import System.Posix.Types ( Fd(..) )
import Data.List ( foldl' ) import Data.List ( foldl' )
Expand Down Expand Up @@ -290,7 +303,11 @@ newNullConnection = Conn `fmap` newForeignPtr_ nullPtr


-- | Test if a connection is the Null Connection. -- | Test if a connection is the Null Connection.
isNullConnection :: Connection -> Bool isNullConnection :: Connection -> Bool
#if __GLASGOW_HASKELL__ >= 702
isNullConnection (Conn x) = Unsafe.unsafeForeignPtrToPtr x == nullPtr
#else
isNullConnection (Conn x) = unsafeForeignPtrToPtr x == nullPtr isNullConnection (Conn x) = unsafeForeignPtrToPtr x == nullPtr
#endif
{-# INLINE isNullConnection #-} {-# INLINE isNullConnection #-}


-- | If 'connectStart' succeeds, the next stage is to poll libpq so -- | If 'connectStart' succeeds, the next stage is to poll libpq so
Expand Down Expand Up @@ -1499,6 +1516,161 @@ unescapeBytea bs =
return $ Just $ B.fromForeignPtr tofp 0 $ fromIntegral l return $ Just $ B.fromForeignPtr tofp 0 $ fromIntegral l




-- $copyfrom
--
-- This provides support for PostgreSQL's @COPY FROM@ facility. When inserting
-- rows in bulk, @COPY FROM@ is faster than individual @INSERT@ statements for
-- each row.
--
-- For more information, see:
--
-- * <http://www.postgresql.org/docs/current/static/sql-copy.html>
--
-- * <http://www.postgresql.org/docs/current/static/libpq-copy.html>
--
-- The following example illustrates the procedure for using @COPY FROM@ with
-- libpq:
--
-- >-- Put the connection in the COPY_IN state
-- >-- by executing a COPY ... FROM query.
-- >Just result <- exec conn "COPY foo (a, b) FROM stdin"
-- >CopyIn <- resultStatus result
-- >
-- >-- Send a couple lines of COPY data
-- >CopyOk <- putCopyRow conn [Just ("1", Text), Just ("one", Text)]
-- >CopyOk <- putCopyRow conn [Just ("2", Text), Nothing]
-- >
-- >-- Send end-of-data indication
-- >CopyOk <- putCopyEnd conn Nothing
-- >
-- >-- Get the final result status of the copy command
-- >Just result <- getResult conn
-- >CommandOk <- resultStatus result

-- | Outcome of handing data, or the end-of-data indication, to the server
-- while the connection is in the @COPY_IN@ state.
data CopyResult
    = CopyOk          -- ^ The data was sent.
    | CopyError       -- ^ An error occurred (use 'errorMessage'
                      --   to retrieve details).
    | CopyWouldBlock  -- ^ The data was not sent because the attempt
                      --   would block (this case is only possible if
                      --   the connection is in nonblocking mode).
                      --   Wait for write-ready (e.g. by using
                      --   'Control.Concurrent.threadWaitWrite' on the
                      --   'socket') and try again.
    deriving (Eq, Show)


-- | Decode the integer status returned by the libpq COPY calls: a negative
-- value is an error, zero means the call would block, and any positive
-- value means the data was sent.
toCopyResult :: CInt -> CopyResult
toCopyResult status =
    case compare status 0 of
        LT -> CopyError
        EQ -> CopyWouldBlock
        GT -> CopyOk


-- | Send raw @COPY@ data to the server during the @COPY_IN@ state.
--
-- The contents of the 'B.ByteString' are passed on unchanged; no escaping
-- or row termination is added here.
putCopyData :: Connection -> B.ByteString -> IO CopyResult
putCopyData conn bs =
    B.unsafeUseAsCStringLen bs (\chunk -> putCopyCString conn chunk)


-- | Pass a buffer (pointer plus byte count) to @PQputCopyData@ and decode
-- the status it returns with 'toCopyResult'.
putCopyCString :: Connection -> CStringLen -> IO CopyResult
putCopyCString conn (buf, nbytes) = do
    status <- withConn conn $ \connPtr ->
        c_PQputCopyData connPtr buf (fromIntegral nbytes)
    return (toCopyResult status)


-- | Send end-of-data indication to the server during the @COPY_IN@ state.
--
-- * @putCopyEnd conn Nothing@ ends the @COPY_IN@ operation successfully.
--
-- * @putCopyEnd conn (Just errormsg)@ forces the @COPY@ to fail, with
--   @errormsg@ used as the error message.
--
-- After 'putCopyEnd' returns 'CopyOk', call 'getResult' to obtain the final
-- result status of the @COPY@ command.  Then return to normal operation.
putCopyEnd :: Connection -> Maybe B.ByteString -> IO CopyResult
putCopyEnd conn merrormsg =
    fmap toCopyResult $
        withMessage $ \msgPtr ->
            withConn conn $ \connPtr -> c_PQputCopyEnd connPtr msgPtr
  where
    -- No message: pass a null pointer.  Otherwise pass the message as a
    -- C string for the duration of the call.
    withMessage = case merrormsg of
        Nothing       -> ($ nullPtr)
        Just errormsg -> B.useAsCString errormsg


-- | A combination of 'putCopyData' and 'formatCopyRow'. This should be
-- slightly more efficient than:
--
-- >putCopyData conn =<< formatCopyRow params
--
-- as it does not allocate an intermediate 'B.ByteString'.
--
-- Each list element is one column datum; 'Nothing' is sent as the null
-- string.
putCopyRow :: Connection -> [Maybe (B.ByteString, Format)] -> IO CopyResult

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does the format travel with the value? Shouldn't it be the same for all rows after the COPY FROM statement is executed?

-- The escaped row is built in a temporary buffer and handed straight to
-- 'putCopyCString'.
putCopyRow conn params = withFormatCopyRow params $ putCopyCString conn


-- | Escape a row of data for use with a COPY FROM statement.
-- The result includes a trailing newline at the end.
--
-- A 'Nothing' datum is written as the null string, a 'Text' datum is
-- escaped for COPY, and a 'Binary' datum is escaped for the BYTEA input
-- function and then for COPY.
--
-- This assumes text format (rather than BINARY or CSV) with the default
-- delimiter (tab) and default null string (\\N). A suitable query looks like:
--

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are you relying on this assumption?

It seems like this needs to either be (1) a "rosetta stone" of COPY formats (that is, support exactly the formats that postgres supports) and get the necessary information from the caller, e.g. text/binary, delimiter, CSV mode, etc.; or (2) it needs to control the initial query so it can control the format.

To step back: what is the point in generating "COPY TO"-formatted data only to send the output to "COPY FROM"? How much faster is it than, say, a prepared INSERT statement executed again and again? Is there precedent for this kind of thing in other client libraries?

-- >COPY tablename (id, col1, col2) FROM stdin;
formatCopyRow :: [Maybe (B.ByteString, Format)] -> IO B.ByteString
-- The row is escaped into a temporary buffer, which is then packed into the
-- returned 'B.ByteString'.
formatCopyRow params = withFormatCopyRow params B.packCStringLen


-- | Escape a row into a temporary buffer and run the continuation on the
-- escaped bytes (pointer and length).  The buffer is only valid while the
-- continuation runs.
withFormatCopyRow :: [Maybe (B.ByteString, Format)]
                  -> (CStringLen -> IO a)
                  -> IO a
withFormatCopyRow params inner =
    allocaBytes bufsize $ \buf -> do
        end <- emitParams buf params
        let written = end `minusPtr` buf
        -- Sanity check: the per-datum upper bounds from 'paramSize' must
        -- cover everything that was just written.
        if written > bufsize
            then error $ "formatCopyRow: Buffer overrun (buffer is "
                      ++ show bufsize ++ " bytes, but "
                      ++ show written ++ " bytes were written into it)"
            else inner (castPtr buf, written)
  where
    -- An empty row still needs one byte for the terminating newline.
    bufsize = case params of
        [] -> 1
        _  -> sum (map paramSize params)


-- | Upper bound on the number of bytes one escaped datum occupies,
-- counting the tab or newline that follows it.
paramSize :: Maybe (B.ByteString, Format) -> Int
paramSize datum = case datum of
    Nothing          -> 3                    -- Length of "\\N\t"
    Just (s, Text)   -> 2 * B.length s + 1   -- at most two bytes per input byte
    Just (s, Binary) -> 5 * B.length s + 1   -- at most five bytes per input byte


-- | Write one escaped datum at the given position and return the position
-- just past it.  No separator is written here.
emitParam :: Ptr CUChar -> Maybe (B.ByteString, Format) -> IO (Ptr CUChar)
emitParam out datum = case datum of
    Nothing -> do
        -- The null string: a backslash followed by a capital N.
        pokeArray out [92, 78]
        return (out `advancePtr` 2)
    Just (s, Text)   -> escapeWith c_escape_copy_text s
    Just (s, Binary) -> escapeWith c_escape_copy_bytea s
  where
    -- Run one of the C escaping routines over the bytes of the datum.
    escapeWith escaper bytes =
        B.unsafeUseAsCStringLen bytes $ \(src, n) ->
            escaper (castPtr src) (fromIntegral n) out


-- | Write every datum of a row, separated by tabs and terminated by a
-- newline.  Returns the position just past the newline.
emitParams :: Ptr CUChar -> [Maybe (B.ByteString, Format)] -> IO (Ptr CUChar)
emitParams out row = case row of
    []       -> terminate out
    [x]      -> emitParam out x >>= terminate
    (x:rest) -> do
        next <- emitParam out x
        poke next 9 -- tab
        emitParams (next `plusPtr` 1) rest
  where
    -- Finish the row with a newline byte.
    terminate :: Ptr CUChar -> IO (Ptr CUChar)
    terminate p = do
        poke p 10 -- newline
        return (p `plusPtr` 1)


-- $asynccommand -- $asynccommand
-- The 'exec' function is adequate for submitting commands in normal, -- The 'exec' function is adequate for submitting commands in normal,
-- synchronous applications. It has a couple of deficiencies, however, -- synchronous applications. It has a couple of deficiencies, however,
Expand Down Expand Up @@ -2287,6 +2459,12 @@ type PGVerbosity = CInt
foreign import ccall unsafe "libpq-fe.h PQsetErrorVerbosity" foreign import ccall unsafe "libpq-fe.h PQsetErrorVerbosity"
c_PQsetErrorVerbosity :: Ptr PGconn -> PGVerbosity -> IO PGVerbosity c_PQsetErrorVerbosity :: Ptr PGconn -> PGVerbosity -> IO PGVerbosity


-- Binding to libpq's PQputCopyData: connection, data buffer, and buffer
-- length in bytes.  The integer result is decoded by 'toCopyResult'.
-- NOTE(review): imported without @unsafe@, presumably because the call can
-- block on the connection -- confirm.
foreign import ccall "libpq-fe.h PQputCopyData"
c_PQputCopyData :: Ptr PGconn -> Ptr CChar -> CInt -> IO CInt

-- Binding to libpq's PQputCopyEnd: connection and either a null pointer or
-- an error message as a C string.  The integer result is decoded by
-- 'toCopyResult'.
foreign import ccall "libpq-fe.h PQputCopyEnd"
c_PQputCopyEnd :: Ptr PGconn -> CString -> IO CInt

foreign import ccall "libpq-fe.h PQsendQuery" foreign import ccall "libpq-fe.h PQsendQuery"
c_PQsendQuery :: Ptr PGconn -> CString ->IO CInt c_PQsendQuery :: Ptr PGconn -> CString ->IO CInt


Expand Down Expand Up @@ -2496,3 +2674,18 @@ foreign import ccall "libpq-fs.h lo_close"


foreign import ccall "libpq-fs.h lo_unlink" foreign import ccall "libpq-fs.h lo_unlink"
c_lo_unlink :: Ptr PGconn -> Oid -> IO CInt c_lo_unlink :: Ptr PGconn -> Oid -> IO CInt

------------------------------------------------------------------------
-- cbits imports

-- Escapes a text datum for COPY FROM; implemented in cbits/escape-copy.c.
-- The output buffer should hold at least 2*in_size bytes.
foreign import ccall unsafe
c_escape_copy_text :: Ptr CUChar -- ^ const unsigned char *in
-> CInt -- ^ int in_size
-> Ptr CUChar -- ^ unsigned char *out
-> IO (Ptr CUChar) -- ^ Returns pointer to end of written data

-- Escapes a binary datum for the BYTEA input function and then for
-- COPY FROM; implemented in cbits/escape-copy.c.  The output buffer should
-- hold at least 5*in_size bytes.
foreign import ccall unsafe
c_escape_copy_bytea :: Ptr CUChar -- ^ const unsigned char *in
-> CInt -- ^ int in_size
-> Ptr CUChar -- ^ unsigned char *out
-> IO (Ptr CUChar) -- ^ Returns pointer to end of written data
94 changes: 94 additions & 0 deletions cbits/escape-copy.c
@@ -0,0 +1,94 @@
/*
 * Escape a datum for COPY FROM. Tab, newline, carriage return and
 * backslash are each written as a backslash followed by 't', 'n', 'r' or
 * another backslash; every other byte is copied unchanged.
 *
 * Each input byte produces at most two output bytes, so the buffer pointed
 * to by @out should be at least 2*in_size bytes long. No terminator is
 * appended to the output.
 *
 * Return a pointer to the end of the bytes emitted.
 */
unsigned char *c_escape_copy_text(const unsigned char *in, int in_size, unsigned char *out)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does this need to be written in C? Is this code copied from somewhere?

{
/* Consume exactly in_size bytes; the input is not treated as NUL-terminated. */
while (in_size-- > 0) {
unsigned char c = *in++;

switch (c) {
case '\t':
*out++ = '\\';
*out++ = 't';
break;
case '\n':
*out++ = '\\';
*out++ = 'n';
break;
case '\r':
*out++ = '\\';
*out++ = 'r';
break;
case '\\':
*out++ = '\\';
*out++ = '\\';
break;

/* Anything else passes through as-is. */
default:
*out++ = c;
}
}

return out;
}

/*
 * Counterpart of c_escape_copy_text for binary data: the datum is escaped
 * so that PostgreSQL's BYTEA input function will accept it.  The hex format
 * introduced by PostgreSQL 9.0 is deliberately not used, as it is readable
 * only by PostgreSQL 9.0 and up.
 *
 * Two layers of escaping are applied:
 *
 *  * the raw bytes are converted to the format accepted by the BYTEA
 *    input function, and
 *
 *  * that result is escaped again for use in COPY FROM data.
 *
 * The buffer pointed to by @out should be at least 5*in_size bytes long.
 * Returns a pointer to the end of the bytes emitted.
 */
unsigned char *c_escape_copy_bytea(const unsigned char *in, int in_size, unsigned char *out)
{
    int i;

    for (i = 0; i < in_size; i++) {
        const unsigned char c = in[i];

        /*
         * Printable characters other than backslash need neither BYTEA
         * escaping nor COPY FROM escaping.
         */
        if (c != '\\' && c >= 32 && c <= 126) {
            *out++ = c;
            continue;
        }

        /*
         * Everything else starts with two backslashes: a single backslash
         * for BYTEA, escaped once more for COPY FROM.
         */
        *out++ = '\\';
        *out++ = '\\';

        if (c == '\\') {
            /* A literal backslash is escaped at both levels: four in total. */
            *out++ = '\\';
            *out++ = '\\';
        } else {
            /*
             * Octal form: three digits [0-7].  The letter escapes \t, \n
             * and \r are not used because the BYTEA input function does not
             * understand them; and since all other non-printable bytes are
             * escaped for BYTEA anyway, 9, 10 and 13 get no special
             * treatment.
             */
            *out++ = '0' + ((c >> 6) & 0x7);
            *out++ = '0' + ((c >> 3) & 0x7);
            *out++ = '0' + (c & 0x7);
        }
    }

    return out;
}
11 changes: 10 additions & 1 deletion postgresql-libpq.cabal
Expand Up @@ -18,11 +18,20 @@ Copyright: (c) 2010 Grant Monroe
(c) 2011 Leon P Smith (c) 2011 Leon P Smith
Category: Database Category: Database
Build-type: Custom Build-type: Custom
-- Extra-source-files:
Cabal-version: >=1.8 Cabal-version: >=1.8

Extra-source-files:
testing/copy-from.hs
testing/copy-from-example.hs
testing/run-copy-from
testing/run-copy-from.expected

Library Library
Exposed-modules: Database.PostgreSQL.LibPQ Exposed-modules: Database.PostgreSQL.LibPQ


C-Sources:
cbits/escape-copy.c

Build-depends: base >= 4 && < 5 Build-depends: base >= 4 && < 5
, bytestring , bytestring


Expand Down
39 changes: 39 additions & 0 deletions testing/copy-from-example.hs
@@ -0,0 +1,39 @@
{-# LANGUAGE OverloadedStrings #-}
{-# OPTIONS_GHC -fno-warn-name-shadowing #-}
import Control.Monad (forM_)
import Database.PostgreSQL.LibPQ

-- | Walk through a complete @COPY FROM@ round trip: create a temporary
-- table, copy two rows into it, then read the rows back and print them.
main :: IO ()
main = do
    conn <- connectdb ""

    -- Temporary table used only by this example.
    Just created <- exec conn "CREATE TEMPORARY TABLE foo (a INT, b TEXT)"
    CommandOk <- resultStatus created

    -- Executing COPY ... FROM puts the connection in the COPY_IN state.
    Just copyStarted <- exec conn "COPY foo (a, b) FROM stdin"
    CopyIn <- resultStatus copyStarted

    -- Two rows of COPY data; the second has no value for column b.
    CopyOk <- putCopyRow conn [Just ("1", Text), Just ("one", Text)]
    CopyOk <- putCopyRow conn [Just ("2", Text), Nothing]

    -- Signal that there is no more data.
    CopyOk <- putCopyEnd conn Nothing

    -- Fetch the final result status of the COPY command itself.
    Just copyFinished <- getResult conn
    CommandOk <- resultStatus copyFinished

    -- Read the table back and print one tab-separated line per row.
    Just rows <- exec conn "SELECT * FROM foo"
    TuplesOk <- resultStatus rows
    rowCount <- ntuples rows
    forM_ [0 .. rowCount - 1] $ \row -> do
        colA <- getvalue rows row 0
        colB <- getvalue rows row 1
        putStrLn (show colA ++ "\t" ++ show colB)

    finish conn