Skip to content

Commit

Permalink
Implement differenceBySorted API
Browse files Browse the repository at this point in the history
  • Loading branch information
rnjtranjan committed Feb 10, 2022
1 parent 11b1989 commit c3a11b3
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 6 deletions.
13 changes: 8 additions & 5 deletions src/Streamly/Internal/Data/Stream/IsStream/Top.hs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ module Streamly.Internal.Data.Stream.IsStream.Top
, intersectBy
, intersectBySorted
, differenceBy
, mergeDifferenceBy
, differenceBySorted
, unionBy
, unionBySorted

Expand Down Expand Up @@ -633,11 +633,14 @@ differenceBy eq s1 s2 =
--
-- Space: O(1)
--
-- /Unimplemented/
{-# INLINE mergeDifferenceBy #-}
mergeDifferenceBy :: -- (IsStream t, Monad m) =>
-- /Pre-release/
{-# INLINE differenceBySorted #-}
differenceBySorted :: (IsStream t, MonadIO m) =>
(a -> a -> Ordering) -> t m a -> t m a -> t m a
mergeDifferenceBy _eq _s1 _s2 = undefined
differenceBySorted eq s1 =
IsStream.fromStreamD
. StreamD.differenceBySorted eq (IsStream.toStreamD s1)
. IsStream.toStreamD

-- | This is essentially an append operation that appends all the extra
-- occurrences of elements from the second stream that are not already present
Expand Down
65 changes: 65 additions & 0 deletions src/Streamly/Internal/Data/Stream/StreamD/Nesting.hs
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ module Streamly.Internal.Data.Stream.StreamD.Nesting
, splitInnerBySuffix
, intersectBySorted
, unionBySorted
, differenceBySorted
)
where

Expand Down Expand Up @@ -3039,3 +3040,67 @@ unionBySorted cmp (Stream stepa ta) (Stream stepb tb) =
)

step _ (_, _, _, _, _, _, _) = return Stop

-------------------------------------------------------------------------------
-- Difference of sorted streams -----------------------------------------------
-------------------------------------------------------------------------------
{-# INLINE_NORMAL differenceBySorted #-}
differenceBySorted :: (Monad m) =>
(a -> a -> Ordering) -> Stream m a -> Stream m a -> Stream m a
differenceBySorted cmp (Stream stepa ta) (Stream stepb tb) =
Stream step (Just ta, Just tb, Nothing, Nothing, Nothing)

where
{-# INLINE_LATE step #-}

-- one of the values is missing, and the corresponding stream is running
step gst (Just sa, sb, Nothing, b, Nothing) = do
r <- stepa gst sa
return $ case r of
Yield a sa' -> Skip (Just sa', sb, Just a, b, Nothing)
Skip sa' -> Skip (Just sa', sb, Nothing, b, Nothing)
Stop -> Skip (Nothing, sb, Nothing, b, Nothing)

step gst (sa, Just sb, a, Nothing, Nothing) = do
r <- stepb gst sb
return $ case r of
Yield b sb' -> Skip (sa, Just sb', a, Just b, Nothing)
Skip sb' -> Skip (sa, Just sb', a, Nothing, Nothing)
Stop -> Skip (sa, Nothing, a, Nothing, Nothing)

-- Matching element
step gst (Just sa, Just sb, Nothing, _, Just _) = do
r1 <- stepa gst sa
r2 <- stepb gst sb
return $ case r1 of
Yield a sa' ->
case r2 of
Yield c sb' ->
Skip (Just sa', Just sb', Just a, Just c, Nothing)
Skip sb' ->
Skip (Just sa', Just sb', Just a, Just a, Nothing)
Stop ->
Yield a (Just sa', Just sb, Nothing, Nothing, Just a)
Skip sa' ->
case r2 of
Yield c sb' ->
Skip (Just sa', Just sb', Just c, Just c, Nothing)
Skip sb' ->
Skip (Just sa', Just sb', Nothing, Nothing, Nothing)
Stop ->
Stop
Stop ->
Stop

-- both the values are available
step _ (sa, sb, Just a, Just b, Nothing) = do
let res = cmp a b
return $ case res of
GT -> Skip (sa, sb, Just a, Nothing, Nothing)
LT -> Yield a (sa, sb, Nothing, Just b, Nothing)
EQ -> Skip (sa, sb, Nothing, Just b, Just b)

-- one of the values is missing, corresponding stream is done
step _ (sa, Nothing, Just a, Nothing, Nothing) =
return $ Yield a (sa, Nothing, Nothing, Nothing , Nothing)
step _ (_, _, _, _, _) = return Stop
23 changes: 22 additions & 1 deletion test/Streamly/Test/Data/Stream/Top.hs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module Main (main)
where

import Data.List (intersect, sort, union)
import Data.List (intersect, sort, union, (\\))
import Test.QuickCheck
( Gen
, Property
Expand Down Expand Up @@ -64,6 +64,26 @@ unionBySorted =
(S.fromList ls1)
let v2 = sort $ union ls0 ls1
assert (v1 == v2)

differenceBySorted :: Property
differenceBySorted =
forAll (listOf (chooseInt (min_value, max_value))) $ \ls0 ->
forAll (listOf (chooseInt (min_value, max_value))) $ \ls1 ->
monadicIO $ action (sort ls0) (sort ls1)

where

action ls0 ls1 = do
v1 <-
run
$ S.toList
$ Top.differenceBySorted
compare
(S.fromList ls0)
(S.fromList ls1)
let v2 = ls0 \\ ls1
assert (v1 == sort v2)

-------------------------------------------------------------------------------
moduleName :: String
moduleName = "Data.Stream.Top"
Expand All @@ -74,3 +94,4 @@ main = hspec $ do
-- intersect
prop "intersectBySorted" Main.intersectBySorted
prop "unionBySorted" Main.unionBySorted
prop "differenceBySorted" Main.differenceBySorted

0 comments on commit c3a11b3

Please sign in to comment.