Skip to content

Commit

Permalink
WIP on Numeric operations
Browse files Browse the repository at this point in the history
  • Loading branch information
lehins committed Jul 24, 2019
1 parent 28f35f2 commit d82597f
Show file tree
Hide file tree
Showing 9 changed files with 82 additions and 30 deletions.
55 changes: 55 additions & 0 deletions massiv-bench/bench/Arithmetic.hs
@@ -0,0 +1,55 @@
{-# LANGUAGE BangPatterns #-}
module Main where

import Criterion.Main
import Data.Massiv.Array as A
--import Data.Massiv.Core.Operations
import Data.Massiv.Array.SIMD
import Data.Massiv.Bench as A


main :: IO ()
main = do
let !sz = Sz2 16 16
!arr1' = arrRLightIx2 S Seq sz
!arr2' = computeAs S $ transpose arr1'
!arrV1' = arrRLightIx2 V Seq sz
!arrV2' = computeAs V $ transpose arr1'
defaultMain
[ env (return (arr1', arr2')) $ \ ~(arr1, arr2) ->
bgroup
"Addition"
[ bgroup
"S (Seq)"
[ bench "zipWith (+)" $ whnf (computeAs S . A.zipWith (+) arr1) arr2
, bench "(.+.)" $ whnfIO (computeAs S <$> (toManifest arr1 .+. toManifest arr2))
]
]
, env (return (setComp Par arr1', setComp Par arr2')) $ \ ~(arr1, arr2) ->
bgroup
"Addition"
[ bgroup
"S (Par)"
[ bench "zipWith (+)" $ whnf (computeAs S . A.zipWith (+) arr1) arr2
, bench "(.+.)" $ whnfIO (computeAs S <$> (toManifest arr1 .+. toManifest arr2))
]
]
, env (return (arrV1', arrV2')) $ \ ~(arr1, arr2) ->
bgroup
"Addition"
[ bgroup
"V (Seq)"
[ bench "zipWith (+)" $ whnf (computeAs V . A.zipWith (+) arr1) arr2
, bench "(.+.)" $ whnfIO (arr1 .+. arr2)
]
]
, env (return (setComp Par arrV1', setComp Par arrV2')) $ \ ~(arr1, arr2) ->
bgroup
"Addition"
[ bgroup
"V (Par)"
[ bench "zipWith (+)" $ whnf (computeAs V . A.zipWith (+) arr1) arr2
, bench "(.+.)" $ whnfIO (arr1 .+. arr2)
]
]
]
12 changes: 7 additions & 5 deletions massiv/src/Data/Massiv/Array/Delayed/Pull.hs
Expand Up @@ -70,7 +70,7 @@ instance Index ix => Construct D ix e where
instance Index ix => Source D ix e where
unsafeIndex = INDEX_CHECK("(Source D ix e).unsafeIndex", size, dIndex)
{-# INLINE unsafeIndex #-}
unsafeLinearSlice arr ix sz = unsafeExtract ix sz (unsafeResize (SafeSz (totalElem sz)) arr)
unsafeLinearSlice ix sz arr = unsafeExtract ix sz (unsafeResize sz arr)
{-# INLINE unsafeLinearSlice #-}


Expand Down Expand Up @@ -214,14 +214,16 @@ instance Num e => Numeric D e where
{-# INLINE absPointwise #-}
additionPointwise = liftDArray2 (+)
{-# INLINE additionPointwise #-}
subtractionPointwise = liftDArray2 (-)
{-# INLINE subtractionPointwise #-}
multiplicationPointwise = liftDArray2 (*)
{-# INLINE multiplicationPointwise #-}
powerPointwise arr pow = liftDArray (^ pow) arr
{-# INLINE powerPointwise #-}
powerSumArray p arr = sumArray p . powerPointwise arr
powerSumArray arr = sumArray . powerPointwise arr
{-# INLINE powerSumArray #-}
dotProduct p a1 a2 = sumArray p (multiplicationPointwise a1 a2)
{-# INLINE dotProduct #-}
unsafeDotProduct a1 a2 = sumArray (multiplicationPointwise a1 a2)
{-# INLINE unsafeDotProduct #-}

instance (Floating e, RealFrac e) => NumericFloat D e where
recipPointwise = liftDArray recip
Expand All @@ -234,7 +236,7 @@ instance (Floating e, RealFrac e) => NumericFloat D e where
{-# INLINE ceilingPointwise #-}
divisionPointwise = liftDArray2 (/)
{-# INLINE divisionPointwise #-}
divideScalar arr e = liftDArray (/e) arr
divideScalar arr e = liftDArray (/ e) arr
{-# INLINE divideScalar #-}


Expand Down
9 changes: 4 additions & 5 deletions massiv/src/Data/Massiv/Array/Manifest/Primitive.hs
Expand Up @@ -267,13 +267,12 @@ fromByteArrayM comp sz ba =
arr = PArray comp sz ba
{-# INLINE fromByteArrayM #-}

-- | See `fromByteArrayM`.
-- | /O(1)/ - Construct a flat Array from `ByteArray`
--
-- @since 0.2.1
fromByteArray :: (Index ix, Prim e) => Comp -> Sz ix -> ByteArray -> Maybe (Array P ix e)
fromByteArray = fromByteArrayM
-- @since 0.4.0
fromByteArray :: forall e . Prim e => Comp -> ByteArray -> Array P Ix1 e
fromByteArray comp ba = PArray comp (SafeSz (elemsBA (Proxy :: Proxy e) ba)) ba
{-# INLINE fromByteArray #-}
{-# DEPRECATED fromByteArray "In favor of more general `fromByteArrayM`" #-}


-- | /O(1)/ - Extract the internal `MutableByteArray`.
Expand Down
2 changes: 1 addition & 1 deletion massiv/src/Data/Massiv/Array/Numeric.hs
Expand Up @@ -110,6 +110,7 @@ liftArray2M f a1 a2
{-# INLINE liftArray2M #-}



-- | Add two arrays together pointwise. Throws `SizeMismatchException` if arrays sizes do
-- not match.
--
Expand Down Expand Up @@ -490,4 +491,3 @@ atan2A
=> Array r ix e -> Array r ix e -> Array D ix e
atan2A = liftArray2Matching atan2
{-# INLINE atan2A #-}

2 changes: 1 addition & 1 deletion massiv/src/Data/Massiv/Array/Ops/Fold.hs
Expand Up @@ -332,7 +332,7 @@ sum' ::
forall r ix e. (Source r ix e, Numeric r e)
=> Array r ix e
-> IO e
sum' = splitReduce (\_ -> pure . sumArray (Proxy :: Proxy r)) (\x y -> pure (x + y)) 0
sum' = splitReduce (\_ -> pure . sumArray) (\x y -> pure (x + y)) 0
{-# INLINE sum' #-}

-- | /O(n)/ - Compute sum of all elements.
Expand Down
6 changes: 3 additions & 3 deletions massiv/src/Data/Massiv/Array/Ops/Fold/Internal.hs
Expand Up @@ -335,7 +335,7 @@ ifoldlIO f !initAcc g !tAcc !arr = do
-- @since 0.3.7
splitReduce ::
(MonadUnliftIO m, Source r ix e)
=> (Scheduler m a -> Array (R r) Ix1 e -> m a)
=> (Scheduler m a -> Array r Ix1 e -> m a)
-> (b -> a -> m b) -- ^ Folding action that is applied to the results of a parallel fold
-> b -- ^ Accumulator for chunks folding
-> Array r ix e
Expand All @@ -347,10 +347,10 @@ splitReduce f g !tAcc !arr = do
withScheduler (getComp arr) $ \scheduler ->
splitLinearly (numWorkers scheduler) totalLength $ \chunkLength slackStart -> do
loopM_ 0 (< slackStart) (+ chunkLength) $ \ !start ->
scheduleWork scheduler $ f scheduler $ unsafeLinearSlice arr start (SafeSz chunkLength)
scheduleWork scheduler $ f scheduler $ unsafeLinearSlice start (SafeSz chunkLength) arr
when (slackStart < totalLength) $
scheduleWork scheduler $
f scheduler $ unsafeLinearSlice arr slackStart (SafeSz (totalLength - slackStart))
f scheduler $ unsafeLinearSlice slackStart (SafeSz (totalLength - slackStart)) arr
F.foldlM g tAcc results
{-# INLINE splitReduce #-}

Expand Down
1 change: 1 addition & 0 deletions massiv/src/Data/Massiv/Array/Unsafe.hs
Expand Up @@ -30,6 +30,7 @@ module Data.Massiv.Array.Unsafe
, unsafeSlice
, unsafeOuterSlice
, unsafeInnerSlice
, unsafeLinearSlice
-- * Mutable interface
, unsafeThaw
, unsafeFreeze
Expand Down
7 changes: 1 addition & 6 deletions massiv/src/Data/Massiv/Core/Common.hs
Expand Up @@ -208,12 +208,7 @@ class Load r ix e => Source r ix e where
-- | Source arrays also give us ability to look at their linear slices
--
-- @since 0.4.0
unsafeLinearSlice :: Array r ix e -> Ix1 -> Sz1 -> Array (R r) Ix1 e
unsafeLinearSlice = undefined
-- default unsafeLinearSlice :: (Extract r Ix1 e, Resize r ix) =>
-- Array r ix e -> Ix1 -> Sz1 -> Array (R r) Ix1 e
-- unsafeLinearSlice arr ix sz =
-- unsafeExtract ix sz (unsafeResize (SafeSz (totalElem sz)) arr)
unsafeLinearSlice :: Ix1 -> Sz1 -> Array r ix e -> Array r Ix1 e

-- | Any array that can be computed and loaded into memory
class (Typeable r, Index ix) => Load r ix e where
Expand Down
18 changes: 9 additions & 9 deletions massiv/src/Data/Massiv/Core/Operations.hs
Expand Up @@ -19,22 +19,22 @@ import Data.Massiv.Core.Common
import Data.Massiv.Array.Ops.Fold.Internal


class Num e => Numeric r e where
class (R r ~ r, Num e) => Numeric r e where

sumArray :: proxy r -> Array (R r) Ix1 e -> e
default sumArray :: Source (R r) Ix1 e => proxy r -> Array (R r) Ix1 e -> e
sumArray _ = foldlS (+) 0
sumArray :: Array r Ix1 e -> e
default sumArray :: Source r Ix1 e => Array r Ix1 e -> e
sumArray = foldlS (+) 0
{-# INLINE sumArray #-}

productArray :: proxy r -> Array (R r) Ix1 e -> e
default productArray :: Source (R r) Ix1 e => proxy r -> Array (R r) Ix1 e -> e
productArray _ = foldlS (*) 1
productArray :: Array r Ix1 e -> e
default productArray :: Source r Ix1 e => Array r Ix1 e -> e
productArray = foldlS (*) 1
{-# INLINE productArray #-}

-- | Raise each element in the array to some non-negative power and sum the results
powerSumArray :: proxy r -> Array (R r) Ix1 e -> Int -> e
powerSumArray :: Array r Ix1 e -> Int -> e

dotProduct :: proxy r -> Array (R r) Ix1 e -> Array (R r) Ix1 e -> e
unsafeDotProduct :: Array r Ix1 e -> Array r Ix1 e -> e

plusScalar :: Index ix => Array r ix e -> e -> Array r ix e

Expand Down

0 comments on commit d82597f

Please sign in to comment.