Permalink
Browse files

Implement decompression of lazy bytestrings.

  • Loading branch information...
1 parent 773f0a2 commit 0bd9071109ee48b3831f655c88e9b87c2352a707 @bos committed Mar 25, 2011
@@ -21,10 +21,9 @@ module Codec.Compression.Snappy
, decompress
) where
-import Codec.Compression.Snappy.Internal (maxCompressedLength)
-import Control.Monad (when)
+import Codec.Compression.Snappy.Internal (check, maxCompressedLength)
import Data.ByteString.Internal (ByteString(..), mallocByteString)
-import Data.Word (Word8)
+import Data.Word (Word8, Word32)
import Foreign.C.Types (CInt, CSize)
import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Marshal.Alloc (alloca)
@@ -36,7 +35,7 @@ import qualified Data.ByteString as B
-- | Compress data into the Snappy format.
compress :: ByteString -> ByteString
-compress bs@(PS sfp off len) = unsafePerformIO $ do
+compress (PS sfp off len) = unsafePerformIO $ do
let dlen0 = maxCompressedLength len
dfp <- mallocByteString dlen0
withForeignPtr sfp $ \sptr ->
@@ -54,21 +53,22 @@ decompress (PS sfp off slen) = unsafePerformIO $
withForeignPtr sfp $ \sptr0 -> do
let sptr = sptr0 `plusPtr` off
len = fromIntegral slen
- let check ok = when (ok == 0) $
- fail "Codec.Compression.Snappy.decompress: corrupt input"
alloca $ \dlenPtr -> do
- check =<< c_GetUncompressedLength sptr len dlenPtr
+ check "decompress" $ c_GetUncompressedLength sptr len dlenPtr
dlen <- fromIntegral `fmap` peek dlenPtr
- dfp <- mallocByteString dlen
- withForeignPtr dfp $ \dptr -> do
- check =<< c_RawUncompress sptr len dptr
- return (PS dfp 0 dlen)
+ if dlen == 0
+ then return B.empty
+ else do
+ dfp <- mallocByteString dlen
+ withForeignPtr dfp $ \dptr -> do
+ check "decompress" $ c_RawUncompress sptr len dptr
+ return (PS dfp 0 dlen)
foreign import ccall unsafe "hs_snappy.h _hsnappy_RawCompress"
c_RawCompress :: Ptr a -> CSize -> Ptr Word8 -> Ptr CSize -> IO ()
foreign import ccall unsafe "hs_snappy.h _hsnappy_GetUncompressedLength"
- c_GetUncompressedLength :: Ptr a -> CSize -> Ptr CSize -> IO CInt
+ c_GetUncompressedLength :: Ptr a -> CSize -> Ptr Word32 -> IO CInt
foreign import ccall unsafe "hs_snappy.h _hsnappy_RawUncompress"
c_RawUncompress :: Ptr a -> CSize -> Ptr Word8 -> IO CInt
@@ -17,14 +17,23 @@
module Codec.Compression.Snappy.Internal
(
- maxCompressedLength
+ check
+ , maxCompressedLength
) where
+import Control.Monad (when)
import Foreign.C.Types (CSize)
maxCompressedLength :: Int -> Int
maxCompressedLength = fromIntegral . c_MaxCompressedLength . fromIntegral
{-# INLINE maxCompressedLength #-}
+check :: (Integral a) => String -> IO a -> IO ()
+check func act = do
+ ok <- act
+ when (ok == 0) . fail $ "Codec.Compression.Snappy." ++ func ++
+ ": corrupt input "
+{-# INLINE check #-}
+
foreign import ccall unsafe "hs_snappy.h _hsnappy_MaxCompressedLength"
c_MaxCompressedLength :: CSize -> CSize
@@ -29,56 +29,89 @@ module Codec.Compression.Snappy.Lazy
#include "hs_snappy.h"
-import Codec.Compression.Snappy.Internal (maxCompressedLength)
+import Codec.Compression.Snappy.Internal (check, maxCompressedLength)
+import Control.Exception (bracket)
import Data.ByteString.Internal hiding (ByteString)
import Data.ByteString.Lazy.Internal (ByteString(..))
-import Data.Word (Word8)
+import Data.Word (Word8, Word32)
import Foreign.C.Types (CInt, CSize)
import Foreign.ForeignPtr (touchForeignPtr, withForeignPtr)
+import Foreign.Marshal.Alloc (alloca)
import Foreign.Marshal.Array (withArray)
import Foreign.Marshal.Utils (with)
import Foreign.Ptr (Ptr, plusPtr)
import Foreign.Storable (Storable(..))
import System.IO.Unsafe (unsafePerformIO)
-import qualified Codec.Compression.Snappy as S
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as L
newtype BS = BS B.ByteString
+data BSSource
+
instance Storable BS where
sizeOf _ = (#size struct BS)
alignment _ = alignment (undefined :: Ptr CInt)
poke ptr (BS (PS fp off len)) = withForeignPtr fp $ \p -> do
(#poke struct BS, ptr) ptr (p `plusPtr` off)
+ (#poke struct BS, off) ptr (0::CSize)
(#poke struct BS, len) ptr len
{-# INLINE poke #-}
-- | Compress data into the Snappy format.
compress :: ByteString -> ByteString
-compress bs = unsafePerformIO $ do
- let len = fromIntegral (L.length bs)
+compress bs = unsafePerformIO . withChunks bs $ \chunkPtr numChunks len -> do
let dlen0 = maxCompressedLength len
dfp <- mallocByteString dlen0
withForeignPtr dfp $ \dptr -> do
- let chunks = L.toChunks bs
- withArray (map BS chunks) $ \chunkPtr ->
- with (fromIntegral dlen0) $ \dlenPtr -> do
- c_CompressChunks chunkPtr (fromIntegral (length chunks))
- (fromIntegral len) dptr dlenPtr
- foldr (\(PS fp _ _) _ -> touchForeignPtr fp) (return ()) chunks
- dlen <- fromIntegral `fmap` peek dlenPtr
- if dlen == 0
- then return Empty
- else return (Chunk (PS dfp 0 dlen) Empty)
+ with (fromIntegral dlen0) $ \dlenPtr -> do
+ c_CompressChunks chunkPtr (fromIntegral numChunks)
+ (fromIntegral len) dptr dlenPtr
+ dlen <- fromIntegral `fmap` peek dlenPtr
+ if dlen == 0
+ then return Empty
+ else return (Chunk (PS dfp 0 dlen) Empty)
-- | Decompress data in the Snappy format.
--
-- If the input is not compressed or is corrupt, an exception will be
-- thrown.
decompress :: ByteString -> ByteString
-decompress = L.fromChunks . (:[]) . S.decompress . B.concat . L.toChunks
+decompress bs = unsafePerformIO . withChunks bs $ \chunkPtr numChunks len ->
+ bracket (c_NewSource chunkPtr (fromIntegral numChunks) (fromIntegral len))
+ c_DeleteSource $ \srcPtr -> do
+ alloca $ \dlenPtr -> do
+ check "Lazy.decompress" $ c_GetUncompressedLengthChunks srcPtr dlenPtr
+ dlen <- fromIntegral `fmap` peek dlenPtr
+ if dlen == 0
+ then return L.empty
+ else do
+ dfp <- mallocByteString dlen
+ withForeignPtr dfp $ \dptr -> do
+ check "Lazy.decompress" $ c_UncompressChunks srcPtr dptr
+ return (Chunk (PS dfp 0 dlen) Empty)
+
+withChunks :: ByteString -> (Ptr BS -> Int -> Int -> IO a) -> IO a
+withChunks bs act = do
+ let len = fromIntegral (L.length bs)
+ let chunks = L.toChunks bs
+ r <- withArray (map BS chunks) $ \chunkPtr ->
+ act chunkPtr (length chunks) len
+ foldr (\(PS fp _ _) _ -> touchForeignPtr fp) (return ()) chunks
+ return r
foreign import ccall unsafe "hs_snappy.h _hsnappy_CompressChunks"
c_CompressChunks :: Ptr BS -> CSize -> CSize -> Ptr Word8 -> Ptr CSize
-> IO ()
+
+foreign import ccall unsafe "hs_snappy.h _hsnappy_NewSource"
+ c_NewSource :: Ptr BS -> CSize -> CSize -> IO (Ptr BSSource)
+
+foreign import ccall unsafe "hs_snappy.h _hsnappy_DeleteSource"
+ c_DeleteSource :: Ptr BSSource -> IO ()
+
+foreign import ccall unsafe "hs_snappy.h _hsnappy_UncompressChunks"
+ c_UncompressChunks :: Ptr BSSource -> Ptr Word8 -> IO Int
+
+foreign import ccall unsafe "hs_snappy.h _hsnappy_GetUncompressedLengthChunks"
+ c_GetUncompressedLengthChunks :: Ptr BSSource -> Ptr Word32 -> IO Int
View
@@ -28,38 +28,49 @@ int _hsnappy_RawUncompress(const char *compressed, size_t compressed_length,
return RawUncompress(compressed, compressed_length, uncompressed);
}
-class BSSource : public Source
+class BSSource : public Source
{
public:
- BSSource(BS *chunks, size_t nchunks, size_t left)
- : chunks_(chunks), nchunks_(nchunks), cur_(chunks), left_(left) { }
-
+ BSSource(BS *chunks, size_t nchunks, size_t size)
+ : chunks_(chunks), nchunks_(nchunks), size_(size), cur_(chunks),
+ left_(size) { }
+
size_t Available() const { return left_; }
-
+
const char *Peek(size_t *len) {
- *len = cur_->len;
- return cur_->ptr;
+ if (left_ > 0) {
+ *len = cur_->len - cur_->off;
+ return cur_->ptr + cur_->off;
+ } else {
+ *len = 0;
+ return NULL;
+ }
}
void Skip(size_t n) {
- left_ -= n;
- while (n >= cur_->len) {
- n -= cur_->len;
- cur_++;
- }
if (n > 0) {
- cur_->len -= n;
- cur_->ptr += n;
+ left_ -= n;
+ cur_->off += n;
+ if (cur_->off == cur_->len)
+ cur_++;
}
}
+ void Rewind() {
+ left_ = size_;
+ cur_ = chunks_;
+ for (size_t i = 0; i < nchunks_ && chunks_[i].off > 0; i++)
+ chunks_[i].off = 0;
+ }
+
private:
BS *chunks_;
- const int nchunks_;
+ const size_t nchunks_;
+ const size_t size_;
BS *cur_;
size_t left_;
};
-
+
void _hsnappy_CompressChunks(BS *chunks, size_t nchunks, size_t length,
char *compressed, size_t *compressed_length)
{
@@ -70,3 +81,25 @@ void _hsnappy_CompressChunks(BS *chunks, size_t nchunks, size_t length,
*compressed_length = writer.CurrentDestination() - compressed;
}
+
+BSSource *_hsnappy_NewSource(BS *chunks, size_t nchunks, size_t length)
+{
+ return new BSSource(chunks, nchunks, length);
+}
+
+void _hsnappy_DeleteSource(BSSource *src)
+{
+ delete src;
+}
+
+int _hsnappy_UncompressChunks(BSSource *reader, char *uncompressed)
+{
+ return RawUncompress(reader, uncompressed);
+}
+
+int _hsnappy_GetUncompressedLengthChunks(BSSource *reader, uint32_t *result)
+{
+ int n = GetUncompressedLength(reader, result);
+ reader->Rewind();
+ return n;
+}
View
@@ -2,17 +2,19 @@
#define _hs_snappy_h
#include <stddef.h>
+#include <stdint.h>
#ifdef __cplusplus
-extern "C"
+extern "C"
{
#endif
struct BS {
const char *ptr;
+ size_t off;
size_t len;
};
-
+
size_t _hsnappy_MaxCompressedLength(size_t);
void _hsnappy_RawCompress(const char *input, size_t input_length,
@@ -25,10 +27,24 @@ int _hsnappy_GetUncompressedLength(const char *compressed,
int _hsnappy_RawUncompress(const char *compressed, size_t compressed_length,
char *uncompressed);
+struct BS;
+
void _hsnappy_CompressChunks(struct BS *chunks, size_t count,
size_t length, char *compressed,
size_t *compressed_length);
+struct BSSource;
+
+struct BSSource *_hsnappy_NewSource(struct BS *chunks, size_t nchunks,
+ size_t length);
+
+void _hsnappy_DeleteSource(struct BSSource *reader);
+
+int _hsnappy_UncompressChunks(struct BSSource *reader, char *uncompressed);
+
+int _hsnappy_GetUncompressedLengthChunks(struct BSSource *reader,
+ uint32_t *result);
+
#ifdef __cplusplus
}
#endif
View
@@ -24,13 +24,15 @@ extra-source-files:
library
c-sources: cbits/hs_snappy.cpp
- cc-options: -g -O0
include-dirs: include
extra-libraries: snappy stdc++
- build-depends: base < 5, bytestring
+ cc-options: -Wall
+ ghc-options: -Wall
+
+ build-depends: base < 5, bytestring
if impl(ghc >= 6.10)
- build-depends: base >= 4
+ build-depends: base >= 4
exposed-modules:
Codec.Compression.Snappy
View
@@ -1,3 +1,6 @@
+{-# LANGUAGE FlexibleInstances #-}
+
+import Control.Applicative
import qualified Codec.Compression.Snappy as B
import qualified Codec.Compression.Snappy.Lazy as L
import Test.Framework (defaultMain, testGroup)
@@ -6,15 +9,44 @@ import Test.QuickCheck (Arbitrary(..))
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as L
-s_roundtrip s = B.decompress (B.compress bs) == bs
- where bs = B.pack s
+instance Arbitrary B.ByteString where
+ arbitrary = B.pack <$> arbitrary
+
+instance Arbitrary L.ByteString where
+ arbitrary = rechunk <$> arbitrary <*> arbitrary
+
+s_roundtrip bs = B.decompress (B.compress bs) == bs
+
+newtype Compressed a = Compressed { compressed :: a }
+ deriving (Eq, Ord)
+
+instance Show a => Show (Compressed a)
+ where show (Compressed a) = "Compressed " ++ show a
+
+instance Arbitrary (Compressed B.ByteString) where
+ arbitrary = (Compressed . B.compress) <$> arbitrary
+
+compress_eq n bs = L.fromChunks [B.compress bs] == L.compress (rechunk n bs)
+decompress_eq n bs0 =
+ L.fromChunks [B.decompress bs] == L.decompress (rechunk n bs)
+ where bs = B.compress bs0
+
+rechunk :: Int -> B.ByteString -> L.ByteString
+rechunk n = L.fromChunks . go
+ where go bs | B.null bs = []
+ | otherwise = case B.splitAt ((n `mod` 63) + 1) bs of
+ (x,y) -> x : go y
+
+t_rechunk n bs = L.fromChunks [bs] == rechunk n bs
-l_roundtrip s = L.decompress (L.compress bs) == bs
- where bs = L.pack s
+l_roundtrip bs = L.decompress (L.compress bs) == bs
main = defaultMain tests
tests = [
testProperty "s_roundtrip" s_roundtrip
+ , testProperty "t_rechunk" t_rechunk
+ , testProperty "compress_eq" compress_eq
+ , testProperty "decompress_eq" decompress_eq
, testProperty "l_roundtrip" l_roundtrip
]

0 comments on commit 0bd9071

Please sign in to comment.