Skip to content

Commit

Permalink
implement decodeMessageDelimitedH (decoding from a file handle); fix #61
Browse files Browse the repository at this point in the history
 (#324)
  • Loading branch information
ulysses4ever authored and judah committed Jun 14, 2019
1 parent 4e8369d commit 5eb4ea1
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 9 deletions.
11 changes: 11 additions & 0 deletions proto-lens-tests/package.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,17 @@ tests:
- Proto.Any
- Proto.Any_Fields

decode_delimited_test:
main: decode_delimited_test.hs
source-dirs: tests
dependencies:
- proto-lens-protobuf-types
- proto-lens-tests
- temporary
other-modules:
- Proto.DecodeDelimited
- Proto.DecodeDelimited_Fields

service_test:
main: service_test.hs
source-dirs: tests
Expand Down
8 changes: 8 additions & 0 deletions proto-lens-tests/tests/decode_delimited.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
syntax = "proto3";
package test.decode_delimited;

message Foo {
int32 a = 1;
string b = 2;
}

43 changes: 43 additions & 0 deletions proto-lens-tests/tests/decode_delimited_test.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
{-# LANGUAGE OverloadedStrings #-}

module Main (main) where

import qualified Data.ByteString as B
import qualified Data.Text as T
import Data.ProtoLens
import Data.ProtoLens.Encoding.Bytes (runBuilder)
import Lens.Family2 ((&), (.~))
import Test.Framework.Providers.HUnit (testCase)
import Test.HUnit ((@=?))

import Data.ProtoLens.TestUtil
import Proto.DecodeDelimited
import Proto.DecodeDelimited_Fields

import System.IO (openBinaryFile, hClose, IOMode(ReadMode))
import System.IO.Temp (withSystemTempFile)

filename_template :: String
filename_template = "test_decode_delimited"

main :: IO ()
main = testMain
[ testCase "buildDelimited/decodeDelimited-short" (testWithMessage foo1)
, testCase "buildDelimited/decodeDelimited-long" (testWithMessage foo2)
]
where
foo1 = defMessage & a .~ 42 & b .~ "hello" :: Foo
foo2 = defMessage
& a .~ 43
& b .~ (T.pack . take 300 . repeat $ 'x') :: Foo

testWithMessage :: (Eq msg, Show msg, Message msg) => msg -> IO ()
testWithMessage msg =
let bs = runBuilder . buildMessageDelimited $ msg
in
withSystemTempFile filename_template $ \fname h -> do
B.hPut h bs
hClose h
h' <- openBinaryFile fname ReadMode
any1 <- decodeMessageDelimitedH h'
Right msg @=? any1
12 changes: 12 additions & 0 deletions proto-lens/src/Data/ProtoLens/Encoding.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,17 @@ module Data.ProtoLens.Encoding (
-- ** Delimited messages
buildMessageDelimited,
parseMessageDelimited,
decodeMessageDelimitedH,
) where

import System.IO (Handle)

import Data.ProtoLens.Message (Message(..))
import Data.ProtoLens.Encoding.Bytes (Parser, Builder)
import qualified Data.ProtoLens.Encoding.Bytes as Bytes

import Control.Monad.IO.Class (liftIO)
import Control.Monad.Trans.Except (runExceptT, ExceptT(..))
import qualified Data.ByteString as B
import Data.Semigroup ((<>))

Expand Down Expand Up @@ -47,3 +52,10 @@ parseMessageDelimited = do
len <- Bytes.getVarInt
bytes <- Bytes.getBytes $ fromIntegral len
either fail return $ decodeMessage bytes

-- | Same as @decodeMessage@ but for delimited messages read through a Handle
decodeMessageDelimitedH :: Message msg => Handle -> IO (Either String msg)
decodeMessageDelimitedH h = runExceptT $
Bytes.getVarIntH h >>=
liftIO . B.hGet h . fromIntegral >>=
ExceptT . return . decodeMessage
54 changes: 45 additions & 9 deletions proto-lens/src/Data/ProtoLens/Encoding/Bytes.hs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE LambdaCase #-}

-- | Utility functions for parsing and encoding individual types.
module Data.ProtoLens.Encoding.Bytes(
Expand All @@ -23,6 +24,7 @@ module Data.ProtoLens.Encoding.Bytes(
putBytes,
-- * Integral types
getVarInt,
getVarIntH,
putVarInt,
getFixed32,
getFixed64,
Expand All @@ -45,6 +47,8 @@ module Data.ProtoLens.Encoding.Bytes(
foldMapBuilder,
) where

import Control.Monad.IO.Class (liftIO)
import Control.Monad.Trans.Except (throwE, ExceptT)
import Data.Bits
import Data.ByteString (ByteString)
import Data.ByteString.Lazy.Builder as Builder
Expand All @@ -53,13 +57,16 @@ import qualified Data.ByteString.Lazy as L
import Data.Int (Int32, Int64)
import Data.Monoid ((<>))
import qualified Data.Vector.Generic as V
import Data.Word (Word32, Word64)
import Data.Word (Word8, Word32, Word64)
import Foreign.Marshal (malloc, free)
import Foreign.Storable (peek)
import System.IO (Handle, hGetBuf)
#if MIN_VERSION_base(4,11,0)
import qualified GHC.Float as Float
#else
import Foreign.Ptr (castPtr)
import Foreign.Marshal.Alloc (alloca)
import Foreign.Storable (Storable, peek, poke)
import Foreign.Storable (Storable, poke)
import System.IO.Unsafe (unsafePerformIO)
#endif

Expand All @@ -76,14 +83,43 @@ putBytes = Builder.byteString
-- VarInts are inherently unsigned; there are different ways of encoding
-- negative numbers for int32/64 and sint32/64.
getVarInt :: Parser Word64
getVarInt = loop 1 0
getVarInt = loopStart 0 1
where
loop !s !n = do
b <- getWord8
let n' = n + s * fromIntegral (b .&. 127)
if (b .&. 128) == 0
then return $! n'
else loop (128*s) n'
loopStart !n !s = getWord8 >>= getVarIntLoopFinish loopStart n s

-- Same as getVarInt but reads from a Handle
getVarIntH :: Handle -> ExceptT String IO Word64
getVarIntH h = do
buf <- liftIO malloc
let loopStart !n !s =
(liftIO $ hGetBuf h buf 1) >>=
\case
1 -> (liftIO $ peek buf) >>=
getVarIntLoopFinish loopStart n s
_ -> throwE "Unexpected end of file"
res <- loopStart 0 1
liftIO $ free buf
return res

getVarIntLoopFinish
:: (Monad m)
=> (Word64 -> Word64 -> m Word64) -- "loop start" callback
-> Word64
-> Word64
-> Word8
-> m Word64
getVarIntLoopFinish ls !n !s !b = do
let n' = decodeVarIntStep n s b
if testMsb b
then ls n' (128*s)
else return $! n'

-- n -- result of previous step; s -- 128^{step index}; b -- step byte
decodeVarIntStep :: Word64 -> Word64 -> Word8 -> Word64
decodeVarIntStep n s b = n + s * fromIntegral (b .&. 127)

testMsb :: Word8 -> Bool
testMsb b = (b .&. 128) /= 0

putVarInt :: Word64 -> Builder
putVarInt n
Expand Down

0 comments on commit 5eb4ea1

Please sign in to comment.