diff --git a/Data/Primitive/Array.hs b/Data/Primitive/Array.hs index f35d2236..4130caad 100644 --- a/Data/Primitive/Array.hs +++ b/Data/Primitive/Array.hs @@ -23,6 +23,7 @@ module Data.Primitive.Array ( cloneArray, cloneMutableArray, sizeofArray, sizeofMutableArray, fromListN, fromList, + mapArray', unsafeTraverseArray ) where @@ -559,6 +560,21 @@ unsafeTraverseArray f = \ !ary -> go 0 mary {-# INLINE unsafeTraverseArray #-} +-- | Strict map over the elements of the array. +mapArray' :: (a -> b) -> Array a -> Array b +mapArray' f a = + createArray (sizeofArray a) (die "mapArray'" "impossible") $ \mb -> + let go i | i == sizeofArray a + = return () + | otherwise + = do x <- indexArrayM a i + -- We use indexArrayM here so that we will perform the + -- indexing eagerly even if f is lazy. + let !y = f x + writeArray mb i y >> go (i+1) + in go 0 +{-# INLINE mapArray' #-} + arrayFromListN :: Int -> [a] -> Array a arrayFromListN n l = createArray n (die "fromListN" "uninitialized element") $ \sma -> diff --git a/Data/Primitive/SmallArray.hs b/Data/Primitive/SmallArray.hs index b0cd77e0..262a70d2 100644 --- a/Data/Primitive/SmallArray.hs +++ b/Data/Primitive/SmallArray.hs @@ -57,6 +57,7 @@ module Data.Primitive.SmallArray , sizeofSmallMutableArray , smallArrayFromList , smallArrayFromListN + , mapSmallArray' , unsafeTraverseSmallArray ) where @@ -436,6 +437,20 @@ unsafeTraverseSmallArray f (SmallArray ar) = SmallArray `liftM` unsafeTraverseAr #endif {-# INLINE unsafeTraverseSmallArray #-} +-- | Strict map over the elements of the array. +mapSmallArray' :: (a -> b) -> SmallArray a -> SmallArray b +#if HAVE_SMALL_ARRAY +mapSmallArray' f sa = createSmallArray (length sa) (die "mapSmallArray'" "impossible") $ \smb -> + fix ? 0 $ \go i -> + when (i < length sa) $ do + x <- indexSmallArrayM sa i + let !y = f x + writeSmallArray smb i y *> go (i+1) +#else +mapSmallArray' f (SmallArray ar) = SmallArray (mapArray' f ar) +#endif +{-# INLINE mapSmallArray' #-} + #ifndef HAVE_SMALL_ARRAY runSmallArray :: (forall s. ST s (SmallMutableArray s a)) diff --git a/test/main.hs b/test/main.hs index 44552d27..197e3c80 100644 --- a/test/main.hs +++ b/test/main.hs @@ -45,15 +45,16 @@ main = do , lawsToTest (QCC.ordLaws (Proxy :: Proxy (Array Int))) , lawsToTest (QCC.monoidLaws (Proxy :: Proxy (Array Int))) , lawsToTest (QCC.showReadLaws (Proxy :: Proxy (Array Int))) -#if MIN_VERSION_base(4,7,0) - , lawsToTest (QCC.isListLaws (Proxy :: Proxy (Array Int))) -#endif #if MIN_VERSION_base(4,9,0) || MIN_VERSION_transformers(0,4,0) , lawsToTest (QCC.functorLaws (Proxy1 :: Proxy1 Array)) , lawsToTest (QCC.applicativeLaws (Proxy1 :: Proxy1 Array)) , lawsToTest (QCC.monadLaws (Proxy1 :: Proxy1 Array)) , lawsToTest (QCC.foldableLaws (Proxy1 :: Proxy1 Array)) , lawsToTest (QCC.traversableLaws (Proxy1 :: Proxy1 Array)) +#endif +#if MIN_VERSION_base(4,7,0) + , lawsToTest (QCC.isListLaws (Proxy :: Proxy (Array Int))) + , TQC.testProperty "mapArray'" (QCCL.mapProp int16 int32 mapArray') #endif ] , testGroup "SmallArray" @@ -61,15 +62,16 @@ main = do , lawsToTest (QCC.ordLaws (Proxy :: Proxy (SmallArray Int))) , lawsToTest (QCC.monoidLaws (Proxy :: Proxy (SmallArray Int))) , lawsToTest (QCC.showReadLaws (Proxy :: Proxy (Array Int))) -#if MIN_VERSION_base(4,7,0) - , lawsToTest (QCC.isListLaws (Proxy :: Proxy (SmallArray Int))) -#endif #if MIN_VERSION_base(4,9,0) || MIN_VERSION_transformers(0,4,0) , lawsToTest (QCC.functorLaws (Proxy1 :: Proxy1 SmallArray)) , lawsToTest (QCC.applicativeLaws (Proxy1 :: Proxy1 SmallArray)) , lawsToTest (QCC.monadLaws (Proxy1 :: Proxy1 SmallArray)) , lawsToTest (QCC.foldableLaws (Proxy1 :: Proxy1 SmallArray)) , lawsToTest (QCC.traversableLaws (Proxy1 :: Proxy1 SmallArray)) +#endif +#if MIN_VERSION_base(4,7,0) + , lawsToTest (QCC.isListLaws (Proxy :: Proxy (SmallArray Int))) + , TQC.testProperty "mapSmallArray'" (QCCL.mapProp int16 int32 mapSmallArray') #endif ] , testGroup "ByteArray"