Skip to content

Commit

Permalink
Merge pull request k-bx#65 from factisresearch/wirePutWithSize
Browse files Browse the repository at this point in the history
Fix time complexity of serialization
  • Loading branch information
tvh committed Jun 23, 2018
2 parents 01f9cc8 + d7ef882 commit 5321d12
Show file tree
Hide file tree
Showing 33 changed files with 802 additions and 685 deletions.
7 changes: 6 additions & 1 deletion Text/ProtocolBuffers/Extensions.hs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ module Text.ProtocolBuffers.Extensions
, Key(..),ExtKey(..),MessageAPI(..)
, PackedSeq(..), EP(..)
-- * Internal types, functions, and classes
, wireSizeExtField,wirePutExtField,loadExtension,notExtension
, wireSizeExtField,wirePutExtField,wirePutExtFieldWithSize,loadExtension,notExtension
, wireGetKeyToUnPacked, wireGetKeyToPacked
, GPB,ExtField(..),ExtendMessage(..),ExtFieldValue(..)
) where
Expand Down Expand Up @@ -694,6 +694,11 @@ wirePutExtField (ExtField m) = mapM_ aPut (M.assocs m) where
aPut (fi,(ExtRepeated ft (GPDynSeq s))) = wirePutRep (toWireTag fi ft) ft s
aPut (fi,(ExtPacked ft (GPDynSeq s))) = wirePutPacked (toPackedWireTag fi) ft s

-- FIXME: implement this directly
-- | This is used by the generated code
wirePutExtFieldWithSize :: ExtField -> PutM WireSize
wirePutExtFieldWithSize m = wirePutExtField m >> return (wireSizeExtField m)

notExtension :: (ReflectDescriptor a, ExtendMessage a,Typeable a) => FieldId -> WireType -> a -> Get a
notExtension fieldId _wireType msg = throwError ("Field id "++show fieldId++" is not a valid extension field id for "++show (typeOf (undefined `asTypeOf` msg)))

Expand Down
12 changes: 8 additions & 4 deletions Text/ProtocolBuffers/Header.hs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ import Text.Parsec(choice, sepEndBy, spaces, try)

import Text.ProtocolBuffers.Basic -- all
import Text.ProtocolBuffers.Extensions
( wireSizeExtField,wirePutExtField,loadExtension,notExtension
( wireSizeExtField,wirePutExtField,wirePutExtFieldWithSize
, loadExtension,notExtension
, wireGetKeyToUnPacked, wireGetKeyToPacked
, GPB,Key(..),ExtField,ExtendMessage(..),MessageAPI(..),ExtKey(wireGetKey),PackedSeq )
import Text.ProtocolBuffers.Identifiers(FIName(..),MName(..),FName(..))
Expand All @@ -38,15 +39,18 @@ import Text.ProtocolBuffers.Reflections
, GetMessageInfo(GetMessageInfo),DescriptorInfo(extRanges),makePNF )
import Text.ProtocolBuffers.TextMessage -- all
import Text.ProtocolBuffers.Unknown
( UnknownField,UnknownMessage(..),wireSizeUnknownField,wirePutUnknownField,catch'Unknown )
( UnknownField,UnknownMessage(..),wireSizeUnknownField,wirePutUnknownField,wirePutUnknownFieldWithSize,catch'Unknown )
import Text.ProtocolBuffers.WireMessage
( Wire(..)
, prependMessageSize,putSize,splitWireTag
, runPutM
, wireSizeReq,wireSizeOpt,wireSizeRep
, wirePutReq,wirePutOpt,wirePutRep
, wirePutPacked,wireSizePacked
, wirePutReqWithSize,wirePutOptWithSize,wirePutRepWithSize
, sequencePutWithSize
, wirePutPacked,wirePutPackedWithSize,wireSizePacked
, getMessageWith,getBareMessageWith,wireGetEnum,wireGetPackedEnum
, wireSizeErr,wirePutErr,wireGetErr
, wireSizeErr,wirePutErr,wireGetErr,size'WireSize
, unknown,unknownField
, fieldIdOf)

Expand Down
7 changes: 6 additions & 1 deletion Text/ProtocolBuffers/Unknown.hs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
-- notice. Importer beware.
module Text.ProtocolBuffers.Unknown
( UnknownField(..),UnknownMessage(..),UnknownFieldValue(..)
, wireSizeUnknownField,wirePutUnknownField,catch'Unknown
, wireSizeUnknownField,wirePutUnknownField, wirePutUnknownFieldWithSize,catch'Unknown
) where

import qualified Data.ByteString.Lazy as L
Expand Down Expand Up @@ -53,6 +53,11 @@ wirePutUnknownField :: UnknownField -> Put
wirePutUnknownField (UnknownField m) = F.mapM_ aPut m where
aPut (UFV tag bs) = putVarUInt (getWireTag tag) >> putLazyByteString bs

-- | This is used by the generated code
wirePutUnknownFieldWithSize :: UnknownField -> PutM WireSize
wirePutUnknownFieldWithSize m =
wirePutUnknownField m >> return (wireSizeUnknownField m)

{-# INLINE catch'Unknown #-}
-- | This is used by the generated code
catch'Unknown :: (Typeable a, UnknownMessage a) => (WireTag -> a -> Get a) -> (WireTag -> a -> Get a)
Expand Down
84 changes: 64 additions & 20 deletions Text/ProtocolBuffers/WireMessage.hs
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,24 @@ module Text.ProtocolBuffers.WireMessage
-- ** Encoding to write or read a single message field (good for delimited messages or incremental use)
, messageAsFieldSize,messageAsFieldPutM,messageAsFieldGetM
-- ** The Put monad from the binary package, and a custom binary Get monad ("Text.ProtocolBuffers.Get")
, Put,Get,runPut,runGet,runGetOnLazy,getFromBS
, Put,PutM,Get,runPut,runPutM,runGet,runGetOnLazy,getFromBS
-- * The Wire monad itself. Users should beware that passing an incompatible 'FieldType' is a runtime error or fail
, Wire(..)
-- * The internal exports, for use by generated code and the "Text.ProtcolBuffer.Extensions" module
, size'WireTag,toWireType,toWireTag,toPackedWireTag,mkWireTag
, size'WireTag,size'WireSize,toWireType,toWireTag,toPackedWireTag,mkWireTag
, prependMessageSize,putSize,putVarUInt,getVarInt,putLazyByteString,splitWireTag,fieldIdOf
, wireSizeReq,wireSizeOpt,wireSizeRep,wireSizePacked
, wirePutReq,wirePutOpt,wirePutRep,wirePutPacked
, wirePutReqWithSize,wirePutOptWithSize,wirePutRepWithSize,wirePutPackedWithSize
, sequencePutWithSize
, wireSizeErr,wirePutErr,wireGetErr
, getMessageWith,getBareMessageWith,wireGetEnum,wireGetPackedEnum
, unknownField,unknown,wireGetFromWire
, castWord64ToDouble,castWord32ToFloat,castDoubleToWord64,castFloatToWord32
, zzEncode64,zzEncode32,zzDecode64,zzDecode32
) where

import Control.Monad(when)
import Control.Monad(when,foldM)
import Control.Monad.Error.Class(throwError)
import Control.Monad.ST
import Data.Array.ST(newArray,readArray)
Expand All @@ -45,7 +47,7 @@ import Data.Bits (Bits(..))
--import qualified Data.ByteString as S(last)
--import qualified Data.ByteString.Unsafe as S(unsafeIndex)
import qualified Data.ByteString.Lazy as BS (length)
import qualified Data.Foldable as F(foldl',forM_)
import qualified Data.Foldable as F(foldl', Foldable)
--import Data.List (genericLength)
import Data.Maybe(fromMaybe)
import Data.Sequence ((|>))
Expand All @@ -57,7 +59,7 @@ import Data.Typeable (Typeable,typeOf)
--import GHC.Exts (Double(D#),Float(F#),unsafeCoerce#)
--import GHC.Word (Word64(W64#)) -- ,Word32(W32#))
-- binary package
import Data.Binary.Put (Put,runPut,putWord8,putWord32le,putWord64le,putLazyByteString)
import Data.Binary.Put (Put,PutM,runPutM,runPut,putWord8,putWord32le,putWord64le,putLazyByteString)

import Text.ProtocolBuffers.Basic
import Text.ProtocolBuffers.Get as Get (Result(..),Get,runGet,runGetAll,bytesRead,isReallyEmpty,decode7unrolled
Expand Down Expand Up @@ -179,7 +181,7 @@ getFromBS parser bs = case runGetOnLazy parser bs of
Left msg -> error msg
Right (r,_) -> r

-- This is like 'runGet', without the ability to pass in more input
-- | This is like 'runGet', without the ability to pass in more input
-- beyond the initial ByteString. Thus the 'ByteString' argument is
-- taken to be the entire input. To be able to incrementally feed in
-- more input you should use 'runGet' and respond to 'Partial'
Expand All @@ -195,33 +197,69 @@ runGetOnLazy parser bs = resolve (runGetAll parser bs)
prependMessageSize :: WireSize -> WireSize
prependMessageSize n = n + size'WireSize n

{-# INLINE sequencePutWithSize #-}
-- | Used in generated code.
sequencePutWithSize :: F.Foldable f => f (PutM WireSize) -> PutM WireSize
sequencePutWithSize =
let combine size act =
do size2 <- act
return $! size + size2
in foldM combine 0

{-# INLINE wirePutReqWithSize #-}
-- | Used in generated code.
wirePutReqWithSize :: Wire v => WireTag -> FieldType -> v -> PutM WireSize
wirePutReqWithSize wireTag fieldType v =
let startTag = getWireTag wireTag
tagSize = size'WireTag wireTag
putTag tag = putVarUInt tag >> return tagSize
putAct = wirePutWithSize fieldType v
endTag = succ startTag
in case fieldType of
10 -> sequencePutWithSize [putTag startTag, putAct, putTag endTag]
_ -> sequencePutWithSize [putTag startTag, putAct]

{-# INLINE wirePutOptWithSize #-}
-- | Used in generated code.
wirePutOptWithSize :: Wire v => WireTag -> FieldType -> Maybe v -> PutM WireSize
wirePutOptWithSize _wireTag _fieldType Nothing = return 0
wirePutOptWithSize wireTag fieldType (Just v) = wirePutReqWithSize wireTag fieldType v

{-# INLINE wirePutRepWithSize #-}
-- | Used in generated code.
wirePutRepWithSize :: Wire v => WireTag -> FieldType -> Seq v -> PutM WireSize
wirePutRepWithSize wireTag fieldType vs =
sequencePutWithSize $ fmap (wirePutReqWithSize wireTag fieldType) vs

{-# INLINE wirePutPackedWithSize #-}
-- | Used in generated code.
wirePutPackedWithSize :: Wire v => WireTag -> FieldType -> Seq v -> PutM WireSize
wirePutPackedWithSize wireTag fieldType vs =
let actInner = wirePutRepWithSize wireTag fieldType vs
(size, _) = runPutM actInner -- This should be lazy enough not to allocate the ByteString
tagSize = size'WireTag wireTag
putTag tag = putVarUInt (getWireTag tag) >> return tagSize
in sequencePutWithSize [putTag wireTag, putSize size>>return (prependMessageSize size), actInner]

{-# INLINE wirePutReq #-}
-- | Used in generated code.
wirePutReq :: Wire v => WireTag -> FieldType -> v -> Put
wirePutReq wireTag 10 v = let startTag = getWireTag wireTag
endTag = succ startTag
in putVarUInt startTag >> wirePut 10 v >> putVarUInt endTag
wirePutReq wireTag fieldType v = putVarUInt (getWireTag wireTag) >> wirePut fieldType v
wirePutReq wireTag fieldType v = wirePutReqWithSize wireTag fieldType v >> return ()

{-# INLINE wirePutOpt #-}
-- | Used in generated code.
wirePutOpt :: Wire v => WireTag -> FieldType -> Maybe v -> Put
wirePutOpt _wireTag _fieldType Nothing = return ()
wirePutOpt wireTag fieldType (Just v) = wirePutReq wireTag fieldType v
wirePutOpt wireTag fieldType v = wirePutOptWithSize wireTag fieldType v >> return ()

{-# INLINE wirePutRep #-}
-- | Used in generated code.
wirePutRep :: Wire v => WireTag -> FieldType -> Seq v -> Put
wirePutRep wireTag fieldType vs = F.forM_ vs (\v -> wirePutReq wireTag fieldType v)
wirePutRep wireTag fieldType vs = wirePutRepWithSize wireTag fieldType vs >> return ()

{-# INLINE wirePutPacked #-}
-- | Used in generated code.
wirePutPacked :: Wire v => WireTag -> FieldType -> Seq v -> Put
wirePutPacked wireTag fieldType vs = do
putVarUInt (getWireTag wireTag)
let size = F.foldl' (\n v -> n + wireSize fieldType v) 0 vs
putSize size
F.forM_ vs (\v -> wirePut fieldType v)
wirePutPacked wireTag fieldType vs = wirePutPackedWithSize wireTag fieldType vs >> return ()

{-# INLINE wireSizeReq #-}
-- | Used in generated code.
Expand Down Expand Up @@ -431,7 +469,7 @@ castDoubleToWord64 x = runST (newArray (0::Int,0) x >>= castSTUArray >>= flip re
wireSizeErr :: Typeable a => FieldType -> a -> WireSize
wireSizeErr ft x = error $ concat [ "Impossible? wireSize field type mismatch error: Field type number ", show ft
, " does not match internal type ", show (typeOf x) ]
wirePutErr :: Typeable a => FieldType -> a -> Put
wirePutErr :: Typeable a => FieldType -> a -> PutM b
wirePutErr ft x = fail $ concat [ "Impossible? wirePut field type mismatch error: Field type number ", show ft
, " does not match internal type ", show (typeOf x) ]
wireGetErr :: Typeable a => FieldType -> Get a
Expand All @@ -449,8 +487,14 @@ wireGetErr ft = answer where
-- "Text.ProtocolBuffers.WireMessage" and exported to use user by
-- "Text.ProtocolBuffers". These are less likely to change.
class Wire b where
{-# MINIMAL wireGet, wireSize, (wirePut | wirePutWithSize) #-}
wireSize :: FieldType -> b -> WireSize
{-# INLINE wirePut #-}
wirePut :: FieldType -> b -> Put
wirePut ft x = wirePutWithSize ft x >> return ()
{-# INLINE wirePutWithSize #-}
wirePutWithSize :: FieldType -> b -> PutM WireSize
wirePutWithSize ft x = wirePut ft x >> return (wireSize ft x)
wireGet :: FieldType -> Get b
{-# INLINE wireGetPacked #-}
wireGetPacked :: FieldType -> Get (Seq b)
Expand Down Expand Up @@ -889,4 +933,4 @@ getVarInt = do
-- OPTIMIZE try inlinining getMessageWith and getBareMessageWith: bench-005, slower


-- OPTIMIZE try NO-inlining getMessageWith and getBareMessageWith
-- OPTIMIZE try NO-inlining getMessageWith and getBareMessageWith
13 changes: 7 additions & 6 deletions descriptor/src-auto-generated/Text/DescriptorProtos.hs

Large diffs are not rendered by default.

Loading

0 comments on commit 5321d12

Please sign in to comment.