diff --git a/Data/Text/Encoding.hs b/Data/Text/Encoding.hs index 8758551d..25fe17de 100644 --- a/Data/Text/Encoding.hs +++ b/Data/Text/Encoding.hs @@ -73,7 +73,7 @@ import Data.ByteString.Internal as B hiding (c2w) import Data.Text.Encoding.Error (OnDecodeError, UnicodeException, strictDecode) import Data.Text.Internal (Text(..), safe, text) import Data.Text.Internal.Private (runText) -import Data.Text.Internal.Unsafe.Char (ord, unsafeWrite) +import Data.Text.Internal.Unsafe.Char (ord, unsafeWrite, unsafeWriteSingle) import Data.Text.Internal.Unsafe.Shift (shiftR) import Data.Text.Show () import Data.Text.Unsafe (unsafeDupablePerformIO) @@ -131,6 +131,11 @@ decodeLatin1 (PS fp off len) = text a 0 len return dest -- | Decode a 'ByteString' containing UTF-8 encoded text. +-- +-- Note that, for the provided 'OnDecodeError' function, it is invalid +-- to provide a 'Char' value as a result which is greater than +-- @'\xffff'@. This is usually not a concern. Supplying such a +-- character makes decoding fail with a call to 'error'. decodeUtf8With :: OnDecodeError -> ByteString -> Text decodeUtf8With onErr (PS fp off len) = runText $ \done -> do let go dest = withForeignPtr fp $ \ptr -> @@ -146,12 +151,15 @@ decodeUtf8With onErr (PS fp off len) = runText $ \done -> do x <- peek curPtr' case onErr desc (Just x) of Nothing -> loop $ curPtr' `plusPtr` 1 - Just c -> do - destOff <- peek destOffPtr - w <- unsafeSTToIO $ - unsafeWrite dest (fromIntegral destOff) (safe c) - poke destOffPtr (destOff + fromIntegral w) - loop $ curPtr' `plusPtr` 1 + Just c + | c >= '\x10000' -> + error "decodeUtf8With: cannot supply characters beyond \\xffff" + | otherwise -> do + destOff <- peek destOffPtr + unsafeSTToIO $ + unsafeWriteSingle dest (fromIntegral destOff) (safe c) + poke destOffPtr (destOff + 1) + loop $ curPtr' `plusPtr` 1 loop (ptr `plusPtr` off) (unsafeIOToST . 
go) =<< A.new len where diff --git a/Data/Text/Internal/Unsafe/Char.hs b/Data/Text/Internal/Unsafe/Char.hs index d208e3f0..ce3975a6 100644 --- a/Data/Text/Internal/Unsafe/Char.hs +++ b/Data/Text/Internal/Unsafe/Char.hs @@ -23,6 +23,7 @@ module Data.Text.Internal.Unsafe.Char , unsafeChr8 , unsafeChr32 , unsafeWrite + , unsafeWriteSingle -- , unsafeWriteRev ) where @@ -52,6 +53,19 @@ unsafeChr32 :: Word32 -> Char unsafeChr32 (W32# w#) = C# (chr# (word2Int# w#)) {-# INLINE unsafeChr32 #-} +-- | Write a character into the array at the given offset, assuming it can be represented by a single 'Word16', i.e. that it is < 0x10000. +-- Unlike 'unsafeWrite', nothing is returned: exactly one 'Word16' is written. +unsafeWriteSingle :: A.MArray s -> Int -> Char -> ST s () +unsafeWriteSingle marr i c = do +#if defined(ASSERTS) + assert (i >= 0) + . assert (i < A.length marr) + . assert (n < 0x10000) + $ return () +#endif + A.unsafeWrite marr i (fromIntegral n) + where n = ord c + -- | Write a character into the array at the given offset. Returns -- the number of 'Word16's written. unsafeWrite :: A.MArray s -> Int -> Char -> ST s Int diff --git a/tests/Tests/Properties.hs b/tests/Tests/Properties.hs index ec9a1fdf..acb4ef24 100644 --- a/tests/Tests/Properties.hs +++ b/tests/Tests/Properties.hs @@ -122,22 +122,38 @@ data Badness = Solo | Leading | Trailing instance Arbitrary Badness where arbitrary = elements [Solo, Leading, Trailing] -t_utf8_err :: Badness -> DecodeErr -> Property -t_utf8_err bad de = do +t_utf8_err :: Badness -> Maybe DecodeErr -> Property +t_utf8_err bad mde = do let gen = case bad of Solo -> genInvalidUTF8 Leading -> B.append <$> genInvalidUTF8 <*> genUTF8 Trailing -> B.append <$> genUTF8 <*> genInvalidUTF8 genUTF8 = E.encodeUtf8 <$> genUnicode - forAll gen $ \bs -> MkProperty $ do - onErr <- genDecodeErr de - unProperty . 
monadicIO $ do - l <- run $ let len = T.length (E.decodeUtf8With onErr bs) - in (len `seq` return (Right len)) `Exception.catch` - (\(e::UnicodeException) -> return (Left e)) - assert $ case l of - Left err -> length (show err) >= 0 - Right _ -> de /= Strict + forAll gen $ \bs -> MkProperty $ + case mde of + -- invalid handler: always answers with a character above '\xffff'; the property expects the "cannot supply characters" error + Nothing -> do + c <- choose ('\x10000', maxBound) + let onErr _ _ = Just c + unProperty . monadicIO $ do + l <- run $ let len = T.length (E.decodeUtf8With onErr bs) + in (len `seq` return (Right len)) `Exception.catch` + (\(e::Exception.SomeException) -> return (Left e)) + assert $ case l of + Left err -> + "cannot supply characters" `T.isInfixOf` T.pack (show err) + Right _ -> False + + -- valid handler: onErr is generated from the requested 'DecodeErr' strategy + Just de -> do + onErr <- genDecodeErr de + unProperty . monadicIO $ do + l <- run $ let len = T.length (E.decodeUtf8With onErr bs) + in (len `seq` return (Right len)) `Exception.catch` + (\(e::UnicodeException) -> return (Left e)) + assert $ case l of + Left err -> length (show err) >= 0 + Right _ -> de /= Strict t_utf8_err' :: B.ByteString -> Property t_utf8_err' bs = monadicIO . assert $ case E.decodeUtf8' bs of @@ -203,9 +219,10 @@ t_decode_with_error4' = case E.streamDecodeUtf8With (\_ _ -> Just 'x') (B.pack [0xC2, 97, 97, 97]) of E.Some x _ _ -> x === "xaaa" -t_infix_concat bs1 text bs2 rep = +t_infix_concat bs1 text bs2 = + forAll (genDecodeErr Replace) $ \onErr -> text `T.isInfixOf` - E.decodeUtf8With (\_ _ -> rep) (B.concat [bs1, E.encodeUtf8 text, bs2]) + E.decodeUtf8With onErr (B.concat [bs1, E.encodeUtf8 text, bs2]) s_Eq s = (s==) `eq` ((S.streamList s==) . 
S.streamList) where _types = s :: String diff --git a/tests/Tests/QuickCheckUtils.hs b/tests/Tests/QuickCheckUtils.hs index 851b6588..24da94a0 100644 --- a/tests/Tests/QuickCheckUtils.hs +++ b/tests/Tests/QuickCheckUtils.hs @@ -210,7 +210,10 @@ genDecodeErr :: DecodeErr -> Gen T.OnDecodeError genDecodeErr Lenient = return T.lenientDecode genDecodeErr Ignore = return T.ignore genDecodeErr Strict = return T.strictDecode -genDecodeErr Replace = arbitrary +genDecodeErr Replace = (\c _ _ -> c) <$> frequency + [ (1, return Nothing) + , (50, Just <$> choose ('\x1', '\xffff')) + ] instance Arbitrary DecodeErr where arbitrary = elements [Lenient, Ignore, Strict, Replace] diff --git a/text.cabal b/text.cabal index 27fa6463..45564783 100644 --- a/text.cabal +++ b/text.cabal @@ -246,7 +246,7 @@ test-suite tests build-depends: HUnit >= 1.2, - QuickCheck >= 2.7 && < 2.10, + QuickCheck >= 2.7 && < 2.11, array, base, binary,