Skip to content

Commit

Permalink
Rebased
Browse files Browse the repository at this point in the history
  • Loading branch information
schernichkin committed Apr 21, 2012
1 parent 8e0f089 commit 0799267
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 52 deletions.
82 changes: 41 additions & 41 deletions Data/Conduit/Cereal.hs
@@ -1,42 +1,42 @@
{-# LANGUAGE FlexibleContexts #-}

-- | Turn a 'Get' into a 'Sink' and a 'Put' into a 'Source'

module Data.Conduit.Cereal (GetError, sinkGet, conduitGet, sourcePut) where

import Control.Monad.Error
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as LBS
import qualified Data.Conduit as C
import Data.Conduit.Cereal.Internal
import Data.Conduit.List (sourceList)
import Data.Serialize hiding (get, put)

-- | Error raised by the sinks and conduits of this module; wraps a textual message.
data GetError
  = GetError String
  deriving (Eq, Show)

-- | Lets 'GetError' be built from a plain message; the message-less form is the empty string.
instance Error GetError where
  noMsg = strMsg ""
  strMsg message = GetError message

-- | Apply a 'Get' over and over to the incoming chunks, yielding each decoded value downstream.
--
-- A decoding failure is thrown as a 'GetError' carrying the parser's message.
conduitGet :: MonadError GetError m => Get output -> C.Conduit BS.ByteString m output
conduitGet = mkConduitGet onFailure
  where
    -- The handler's second argument is not needed here and is ignored.
    onFailure reason _ = pipeError (strMsg reason)

-- | Turn a 'Get' into a 'Sink'. Chunks are fed to the 'Get' until it reports 'Done' or 'Fail'.
--
-- When the 'Get' succeeds, the decoded value is returned together with the unconsumed input.
-- When it fails, either on malformed data or because the stream ended too early, an error is raised.
sinkGet :: MonadError GetError m => Get r -> C.Sink BS.ByteString m r
sinkGet = mkSinkGet onFailure onEarlyEnd
  where
    onFailure reason _ = pipeError (strMsg reason)

    -- Feeding the empty string to the pending parser is used to obtain its failure message.
    -- NOTE(review): the lazy pattern assumes the result is 'Fail'; any other
    -- constructor turns forcing the message into a pattern-match failure.
    onEarlyEnd resume _ =
      let Fail reason = resume BS.empty
       in pipeError (strMsg reason)

-- | Build a pipe that throws the given error, both when it is run and when it is finalized.
pipeError :: MonadError e m => e -> C.Pipe i o m r
pipeError e = C.PipeM (throwError e) (lift (throwError e))

-- | Convert a 'Put' into a 'Source'. Runs in constant memory.
sourcePut :: Monad m => Put -> C.Source m BS.ByteString
{-# LANGUAGE FlexibleContexts #-}

-- | Turn a 'Get' into a 'Sink' and a 'Put' into a 'Source'

module Data.Conduit.Cereal (GetError, sinkGet, conduitGet, sourcePut) where

import Control.Monad.Error
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as LBS
import qualified Data.Conduit as C
import Data.Conduit.Cereal.Internal
import Data.Conduit.List (sourceList)
import Data.Serialize hiding (get, put)

-- | Error raised by the sinks and conduits of this module; wraps a textual message.
data GetError
  = GetError String
  deriving (Eq, Show)

-- | Lets 'GetError' be built from a plain message; the message-less form is the empty string.
instance Error GetError where
  noMsg = strMsg ""
  strMsg message = GetError message

-- | Apply a 'Get' over and over to the incoming chunks, yielding each decoded value downstream.
--
-- Decoded values are passed through unchanged ('id' is the result mapper). A decoding
-- failure is thrown as a 'GetError' carrying the parser's message.
conduitGet :: MonadError GetError m => Get output -> C.Conduit BS.ByteString m output
conduitGet = mkConduitGet id (\reason _ -> pipeError (strMsg reason))

-- | Convert a 'Get' into a 'Sink'. The 'Get' is fed chunks until it returns 'Done' or 'Fail'.
--
-- If the 'Get' succeeds, the decoded value is returned and the unconsumed part of the
-- input is handed back as leftovers. If it fails, either because of a deserialization
-- error or because the input stream ended too early, a 'GetError' is raised.
sinkGet :: MonadError GetError m => Get r -> C.Sink BS.ByteString m r
sinkGet = mkSinkGet id errorHandler terminationHandler
  where
    errorHandler msg _ = pipeError $ strMsg msg

    -- The pending parser is resumed with the empty string to tell it that no more
    -- input will arrive. All three outcomes are handled explicitly: a parser that
    -- recovers at end of input (for instance through an alternative branch) yields
    -- 'Done', and matching only on 'Fail' would crash with a pattern-match failure.
    terminationHandler f _ = case f BS.empty of
      Fail msg -> pipeError $ strMsg msg
      Done r rest -> C.Done (if BS.null rest then Nothing else Just rest) r
      Partial _ -> pipeError $ strMsg "sinkGet: parser requested more input after end of stream"

-- | Build a pipe that throws the given error, both when it is run and when it is finalized.
pipeError :: MonadError e m => e -> C.Pipe i o m r
pipeError err =
  let raise = throwError err
   in C.PipeM raise (lift raise)

-- | Turn a 'Put' into a 'Source'. The lazy output of 'runPutLazy' is split into its
-- strict chunks, which are then emitted one by one.
sourcePut :: Monad m => Put -> C.Source m BS.ByteString
sourcePut = sourceList . LBS.toChunks . runPutLazy
63 changes: 63 additions & 0 deletions Data/Conduit/Cereal/Internal.hs
@@ -0,0 +1,63 @@
module Data.Conduit.Cereal.Internal
( ErrorHandler
, ResultMapper
, TerminationHandler

, mkConduitGet
, mkSinkGet
) where

import qualified Data.ByteString as BS
import qualified Data.Conduit as C
import Data.Serialize hiding (get, put)
import Data.Void

-- | Callback run when the parser reports 'Fail'. It receives the failure message and
-- the input consumed so far ('Nothing' when that input is empty) and returns the pipe
-- to continue with.
type ErrorHandler i o m r = String -> Maybe BS.ByteString -> C.Pipe i o m r

-- | Conversion applied to a successfully parsed value before it is returned or emitted.
type ResultMapper a b = a -> b

-- | Callback that 'mkSinkGet' installs as the no-more-input branch of 'C.NeedInput'
-- while the parser is still pending. It receives the pending parser continuation
-- (with results already mapped) and the input consumed so far ('Nothing' when empty).
type TerminationHandler i o m r = (BS.ByteString -> Result r) -> Maybe BS.ByteString -> C.Pipe i o m r

-- | Build a conduit that runs a 'Get' repeatedly over the incoming chunks.
--
-- Every parsed value is passed through the result mapper and emitted; the parser is then
-- restarted on whatever input was left over. A 'Fail' is delegated to the error handler
-- together with the input consumed by the current parse. Empty chunks are skipped.
--
-- The parser is first tried on the empty string. If it completes right away, so that it
-- needs no input at all, the mapped result is emitted endlessly without reading upstream.
mkConduitGet :: Monad m
             => ResultMapper a o
             -> ErrorHandler BS.ByteString o m ()
             -> Get a
             -> C.Conduit BS.ByteString m o
mkConduitGet resultMapper errorHandler get = step True freshParser [] BS.empty
  where
    freshParser = runGetPartial get

    -- Ask for another chunk; if none arrives, stop and hand back what was consumed.
    awaitChunk parser seen = C.NeedInput (feed parser seen) (finish seen)

    feed parser seen chunk
      | BS.null chunk = awaitChunk parser seen
      | otherwise     = step False parser seen chunk

    -- 'seen' holds the chunks of the current parse, most recent first.
    step isFirst parser seen chunk =
      let seen' = chunk : seen
       in case parser chunk of
            Fail msg -> errorHandler msg (chunkedStreamToMaybe seen')
            Partial next -> awaitChunk next seen'
            Done value rest
              | isFirst   -> repeatForever (resultMapper value)
              | otherwise -> C.HaveOutput (feed freshParser [] rest) (return ()) (resultMapper value)

    repeatForever out = C.HaveOutput (repeatForever out) (return ()) out

    finish seen = C.Done (chunkedStreamToMaybe seen) ()

-- | Build a sink that runs a 'Get' once over the incoming chunks.
--
-- On 'Done' the mapped result is returned and the unconsumed input becomes the leftover.
-- On 'Fail' the error handler receives the message and the input consumed so far. If the
-- pipe is told that no more input is coming while the parser is still pending, the
-- termination handler receives the pending continuation and the input consumed so far.
-- Empty chunks are skipped.
mkSinkGet :: Monad m
          => ResultMapper a r
          -> ErrorHandler BS.ByteString Void m r
          -> TerminationHandler BS.ByteString Void m r
          -> Get a
          -> C.Sink BS.ByteString m r
mkSinkGet resultMapper errorHandler terminationHandler get = run (runGetPartial get) [] BS.empty
  where
    await parser seen = C.NeedInput (feed parser seen) (atEnd parser seen)

    feed parser seen chunk =
      if BS.null chunk
        then await parser seen
        else run parser seen chunk

    -- 'seen' holds the chunks fed so far, most recent first.
    run parser seen chunk =
      case parser chunk of
        Fail msg -> errorHandler msg (chunkedStreamToMaybe seen')
        Partial next -> await next seen'
        Done value rest -> C.Done (streamToMaybe rest) (resultMapper value)
      where
        seen' = chunk : seen

    atEnd parser seen =
      terminationHandler (fmap resultMapper . parser) (chunkedStreamToMaybe seen)

-- | Join chunks that were accumulated most-recent-first back into stream order and
-- wrap the result with 'streamToMaybe'.
chunkedStreamToMaybe :: [BS.ByteString] -> Maybe BS.ByteString
chunkedStreamToMaybe chunks = streamToMaybe (BS.concat (reverse chunks))

-- | 'Nothing' for an empty string, otherwise the string wrapped in 'Just'.
streamToMaybe :: BS.ByteString -> Maybe BS.ByteString
streamToMaybe s
  | BS.null s = Nothing
  | otherwise = Just s
31 changes: 21 additions & 10 deletions Test/CerealConduit.hs
@@ -1,10 +1,13 @@
{-# LANGUAGE FlexibleContexts, RankNTypes #-}

module Test.CerealConduit where

import Control.Monad.Identity
import Control.Monad.Error
import Test.HUnit
import qualified Data.Conduit as C
import Data.Conduit.Cereal
import Data.Conduit.Cereal.Internal
import Data.Conduit.List as CL
import Data.Serialize
import qualified Data.ByteString as BS
Expand All @@ -23,48 +26,56 @@ twoItemGet = do
-- | An empty first chunk followed by a one-byte chunk: 'sinkGet' with 'getWord8' is
-- expected to produce @Right 1@.
sinktest1 :: Test
sinktest1 = TestCase (assertEqual "Handles starting with empty bytestring"
(Right 1)
(runIdentity $ (sourceList [BS.pack [], BS.pack [1]]) C.$$ (sinkGet getWord8)))
((sourceList [BS.pack [], BS.pack [1]]) C.$$ (sinkGet getWord8)))

-- | An empty chunk sits between two one-byte chunks: a parser reading two bytes is
-- expected to produce @Right [1, 3]@.
sinktest2 :: Test
sinktest2 = TestCase (assertEqual "Handles empty bytestring in middle"
(Right [1, 3])
(runIdentity $ (sourceList [BS.pack [1], BS.pack [], BS.pack [3]]) C.$$ (sinkGet (do
((sourceList [BS.pack [1], BS.pack [], BS.pack [3]]) C.$$ (sinkGet (do
x <- getWord8
y <- getWord8
return [x, y]))))

-- | An empty source fed to 'sinkGet' with 'getWord8': the test passes only when the
-- outcome is a 'Left'.
sinktest3 :: Test
sinktest3 = TestCase (assertBool "Handles no data"
(case (runIdentity $ (sourceList []) C.$$ (sinkGet getWord8)) of
(case (sourceList []) C.$$ (sinkGet getWord8) of
Right _ -> False
Left _ -> True))

-- | A parser that reads nothing (@return ()@) run against a one-chunk source is
-- expected to produce @Right ()@.
sinktest4 :: Test
sinktest4 = TestCase (assertEqual "Consumes no data"
(Right ())
(runIdentity $ (sourceList [BS.pack [1]]) C.$$ (sinkGet $ return ())))
((sourceList [BS.pack [1]]) C.$$ (sinkGet $ return ())))

-- | A parser that reads nothing (@return ()@) run against an empty source is
-- expected to produce @Right ()@.
sinktest5 :: Test
sinktest5 = TestCase (assertEqual "Empty list"
(Right ())
(runIdentity $ (sourceList []) C.$$ (sinkGet $ return ())))
((sourceList []) C.$$ (sinkGet $ return ())))

-- | After 'sinkGet' reads one byte from the chunks @[1, 2, 3]@ and @[4, 5]@, the
-- following 'CL.consume' is expected to collect the remaining bytes @[2, 3, 4, 5]@.
sinktest6 :: Test
sinktest6 = TestCase (assertEqual "Leftover input works"
(Right 1, BS.pack [2, 3, 4, 5])
(runIdentity $ (sourceList [BS.pack [1, 2, 3], BS.pack [4, 5]]) C.$$ (do
(Right (1, BS.pack [2, 3, 4, 5]))
((sourceList [BS.pack [1, 2, 3], BS.pack [4, 5]]) C.$$ (do
output <- sinkGet getWord8
output' <- CL.consume
return (output, BS.concat output'))))

-- | Non-terminating counterpart of 'sinkGet', built directly on 'mkSinkGet'.
--
-- The current 'sinkGet' terminates the pipe when an error occurs. This variant finishes
-- normally with 'Nothing' instead and passes the input it was handed back as leftovers,
-- so that later sinks can still get at it. A successful parse is wrapped in 'Just'.
sinkGetMaybe :: Monad m => Get output -> C.Sink BS.ByteString m (Maybe output)
sinkGetMaybe = mkSinkGet Just onFailure onEarlyEnd
  where
    -- Neither the failure message nor the pending continuation is needed.
    onFailure _ leftover = C.Done leftover Nothing
    onEarlyEnd _ leftover = C.Done leftover Nothing

-- | A parser that reads one byte and then fails is run against the single chunk
-- @[1, 2]@; the following 'CL.consume' is expected to see the whole chunk again.
-- The last case alternative binds a variable named @otherwise@, so it matches anything.
sinktest7 :: Test
sinktest7 = TestCase (assertBool "Leftover input with failure works"
(case runIdentity $ do
(sourceList [BS.pack [1, 2]]) C.$$ (do
output <- sinkGet (getWord8 >> fail "" :: Get Word8)
output <- sinkGetMaybe (getWord8 >> fail "" :: Get Word8)
output' <- CL.consume
return (output, BS.concat output')) of
(Left _, bs) -> bs == BS.pack [1, 2]
(Nothing, bs) -> bs == BS.pack [1, 2]
otherwise -> False))

conduittest1 :: Test
Expand Down Expand Up @@ -182,4 +193,4 @@ main = do
counts <- runTestTT hunittests
if errors counts == 0 && failures counts == 0
then exitSuccess
else exitFailure
else exitFailure
3 changes: 2 additions & 1 deletion cereal-conduit.cabal
Expand Up @@ -18,6 +18,7 @@ library
, cereal >= 0.3.1.0
, mtl
, bytestring
, void
exposed-modules: Data.Conduit.Cereal
ghc-options: -Wall

Expand All @@ -34,4 +35,4 @@ Test-Suite test-cereal-conduit

source-repository head
type: git
location: git://github.com/litherum/cereal-conduit.git
location: git://github.com/litherum/cereal-conduit.git

0 comments on commit 0799267

Please sign in to comment.