Skip to content

Commit

Permalink
Updates to the new multithreaded pub/sub implementation
Browse files Browse the repository at this point in the history
- Minor fixes suggested by reviewers on github
- Added ability to unsubscribe from just the handlers
  you register, instead of unregistering all handlers
  for a given channel or pattern channel name
  • Loading branch information
wuzzeb committed Aug 5, 2016
1 parent d3dcd24 commit da95f9b
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 29 deletions.
8 changes: 7 additions & 1 deletion src/Database/Redis/ProtocolPipelining.hs
Expand Up @@ -15,7 +15,7 @@
--
module Database.Redis.ProtocolPipelining (
Connection,
connect, disconnect, request, send, recv,
connect, disconnect, request, send, recv, flush,
ConnectionLostException(..),
HostName, PortID(..)
) where
Expand Down Expand Up @@ -92,6 +92,12 @@ recv Conn{..} = do
writeIORef connReplies rs
return r

-- | Flush the socket. Normally, the socket is flushed in 'recv' (actually 'conGetReplies'), but
-- for the multithreaded pub/sub code, the sending thread needs to explicitly flush the subscription
-- change requests.
flush :: Connection -> IO ()
flush Conn{..} = hFlush connHandle

-- |Send a request and receive the corresponding reply
request :: Connection -> S.ByteString -> IO Reply
request conn req = send conn req >> recv conn
Expand Down
125 changes: 99 additions & 26 deletions src/Database/Redis/PubSub.hs
@@ -1,5 +1,5 @@
{-# LANGUAGE CPP, OverloadedStrings, RecordWildCards, EmptyDataDecls,
FlexibleInstances, FlexibleContexts #-}
FlexibleInstances, FlexibleContexts, GeneralizedNewtypeDeriving #-}

module Database.Redis.PubSub (
publish,
Expand All @@ -16,7 +16,8 @@ module Database.Redis.PubSub (
pubSubForever,
RedisChannel, RedisPChannel, MessageCallback, PMessageCallback,
PubSubController, newPubSubController, currentChannels, currentPChannels,
addChannels, addChannelsAndWait, removeChannels, removeChannelsAndWait
addChannels, addChannelsAndWait, removeChannels, removeChannelsAndWait,
UnregisterCallbacksAction
) where

#if __GLASGOW_HASKELL__ < 710
Expand Down Expand Up @@ -101,9 +102,13 @@ sendCmd cmd = do
lift $ Core.send (redisCmd cmd : changes cmd)
modifyPending (updatePending cmd)

cmdCount :: Cmd a b -> Int
cmdCount DoNothing = 0
cmdCount (Cmd c) = length c

totalPendingChanges :: PubSub -> Int
totalPendingChanges (PubSub{..}) =
updatePending subs $ updatePending unsubs $ updatePending psubs $ updatePending punsubs 0
cmdCount subs + cmdCount unsubs + cmdCount psubs + cmdCount punsubs

rawSendCmd :: (Command (Cmd a b)) => PP.Connection -> Cmd a b -> IO ()
rawSendCmd _ DoNothing = return ()
Expand Down Expand Up @@ -249,7 +254,7 @@ type RedisPChannel = ByteString
--
-- Messages are processed synchronously in the receiving thread, so if the callback
-- takes a long time it will block other callbacks and other messages from being
-- received. If you need to move long-running work to a different thread, I suggest
-- received. If you need to move long-running work to a different thread, we suggest
-- you use 'TBQueue' with a reasonable bound, so that if messages are arriving faster
-- than you can process them, you do eventually block.
--
Expand All @@ -267,17 +272,32 @@ type MessageCallback = ByteString -> IO ()
-- are rethrown from 'pubSubForever'.
type PMessageCallback = RedisChannel -> ByteString -> IO ()

-- | An action that when executed will unregister the callbacks. It is returned from 'addChannels'
-- or 'addChannelsAndWait' and typically you would use it in 'bracket' to guarantee that you
-- unsubscribe from channels. For example, if you are using websockets to distribute messages to
-- clients, you could use something such as:
--
-- > websocketConn <- Network.WebSockets.acceptRequest pending
-- > let mycallback msg = Network.WebSockets.sendTextData websocketConn msg
-- > bracket (addChannelsAndWait ctrl [("hello", mycallback)] []) id $ const $ do
-- > {- loop here calling Network.WebSockets.receiveData -}
type UnregisterCallbacksAction = IO ()

newtype UnregisterHandle = UnregisterHandle Integer
deriving (Eq, Show, Num)

-- | A controller that stores a set of channels, pattern channels, and callbacks.
-- It allows you to manage Pub/Sub subscriptions and pattern subscriptions and alter them at
-- any time throughout the life of your program.
-- You should typically create the controller at the start of your program and then store it
-- through the life of your program, using 'addChannels' and 'removeChannels' to update the
-- current subscriptions.
data PubSubController = PubSubController
{ callbacks :: TVar (HM.HashMap RedisChannel [MessageCallback])
, pcallbacks :: TVar (HM.HashMap RedisPChannel [PMessageCallback])
{ callbacks :: TVar (HM.HashMap RedisChannel [(UnregisterHandle, MessageCallback)])
, pcallbacks :: TVar (HM.HashMap RedisPChannel [(UnregisterHandle, PMessageCallback)])
, sendChanges :: TBQueue PubSub
, pendingCnt :: TVar Int
, lastUsedCallbackId :: TVar UnregisterHandle
}

-- | Create a new 'PubSubController'. Note that this does not subscribe to any channels, it just
Expand All @@ -286,11 +306,12 @@ newPubSubController :: MonadIO m => [(RedisChannel, MessageCallback)] -- ^ the i
-> [(RedisPChannel, PMessageCallback)] -- ^ the initial pattern subscriptions
-> m PubSubController
newPubSubController x y = liftIO $ do
cbs <- newTVarIO (fmap (\z -> [z]) $ HM.fromList x)
pcbs <- newTVarIO (fmap (\z -> [z]) $ HM.fromList y)
cbs <- newTVarIO (HM.map (\z -> [(0,z)]) $ HM.fromList x)
pcbs <- newTVarIO (HM.map (\z -> [(0,z)]) $ HM.fromList y)
c <- newTBQueueIO 10
pending <- newTVarIO 0
return $ PubSubController cbs pcbs c pending
lastId <- newTVarIO 0
return $ PubSubController cbs pcbs c pending lastId

-- | Get the list of current channels in the 'PubSubController'. WARNING! This might not
-- exactly reflect the subscribed channels in the Redis server, because there is a delay
Expand All @@ -313,21 +334,29 @@ currentPChannels ctrl = HM.keys <$> (liftIO $ atomically $ readTVar $ pcallbacks
--
-- You can subscribe to the same channel or pattern channel multiple times; the 'PubSubController' keeps
-- a list of callbacks and executes each callback in response to a message.
--
-- The return value is an action 'UnregisterCallbacksAction' which will unregister the callbacks,
-- which should typically used with 'bracket'.
addChannels :: MonadIO m => PubSubController
-> [(RedisChannel, MessageCallback)] -- ^ the channels to subscribe to
-> [(RedisPChannel, PMessageCallback)] -- ^ the channels to pattern subscribe to
-> m ()
addChannels _ [] [] = return ()
addChannels ctrl newChans newPChans = liftIO $ atomically $ do
cm <- readTVar $ callbacks ctrl
pm <- readTVar $ pcallbacks ctrl
let newChans' = [ n | (n,_) <- newChans, not $ HM.member n cm]
newPChans' = [ n | (n, _) <- newPChans, not $ HM.member n pm]
ps = subscribe newChans' `mappend` psubscribe newPChans'
writeTBQueue (sendChanges ctrl) ps
writeTVar (callbacks ctrl) (HM.unionWith (++) cm (fmap (\z -> [z]) $ HM.fromList newChans))
writeTVar (pcallbacks ctrl) (HM.unionWith (++) pm (fmap (\z -> [z]) $ HM.fromList newPChans))
modifyTVar (pendingCnt ctrl) (+ totalPendingChanges ps)
-> m UnregisterCallbacksAction
addChannels _ [] [] = return $ return ()
addChannels ctrl newChans newPChans = liftIO $ do
ident <- atomically $ do
modifyTVar (lastUsedCallbackId ctrl) (+1)
ident <- readTVar $ lastUsedCallbackId ctrl
cm <- readTVar $ callbacks ctrl
pm <- readTVar $ pcallbacks ctrl
let newChans' = [ n | (n,_) <- newChans, not $ HM.member n cm]
newPChans' = [ n | (n, _) <- newPChans, not $ HM.member n pm]
ps = subscribe newChans' `mappend` psubscribe newPChans'
writeTBQueue (sendChanges ctrl) ps
writeTVar (callbacks ctrl) (HM.unionWith (++) cm (fmap (\z -> [(ident,z)]) $ HM.fromList newChans))
writeTVar (pcallbacks ctrl) (HM.unionWith (++) pm (fmap (\z -> [(ident,z)]) $ HM.fromList newPChans))
modifyTVar (pendingCnt ctrl) (+ totalPendingChanges ps)
return ident
return $ unsubChannels ctrl (map fst newChans) (map fst newPChans) ident

-- | Call 'addChannels' and then wait for Redis to acknowledge that the channels are actually subscribed.
--
Expand All @@ -346,13 +375,14 @@ addChannels ctrl newChans newPChans = liftIO $ atomically $ do
addChannelsAndWait :: MonadIO m => PubSubController
-> [(RedisChannel, MessageCallback)] -- ^ the channels to subscribe to
-> [(RedisPChannel, PMessageCallback)] -- ^ the channels to psubscribe to
-> m ()
addChannelsAndWait _ [] [] = return ()
-> m UnregisterCallbacksAction
addChannelsAndWait _ [] [] = return $ return ()
addChannelsAndWait ctrl newChans newPChans = do
addChannels ctrl newChans newPChans
unreg <- addChannels ctrl newChans newPChans
liftIO $ atomically $ do
r <- readTVar (pendingCnt ctrl)
when (r > 0) retry
return unreg

-- | Remove channels from the 'PubSubController', and if there is an active 'pubSubForever', send the
-- unsubscribe commands to Redis. Note that as soon as this function returns, no more callbacks will be
Expand Down Expand Up @@ -380,6 +410,46 @@ removeChannels ctrl remChans remPChans = liftIO $ atomically $ do
writeTVar (pcallbacks ctrl) (foldl' (flip HM.delete) pm remPChans')
modifyTVar (pendingCnt ctrl) (+ totalPendingChanges ps)

-- | Internal function to unsubscribe only from those channels matching the given handle.
unsubChannels :: PubSubController -> [RedisChannel] -> [RedisPChannel] -> UnregisterHandle -> IO ()
unsubChannels ctrl chans pchans h = liftIO $ atomically $ do
cm <- readTVar $ callbacks ctrl
pm <- readTVar $ pcallbacks ctrl

-- only worry about channels that exist
let remChans = filter (\n -> HM.member n cm) chans
remPChans = filter (\n -> HM.member n pm) pchans

-- helper functions to filter out handlers that match
let filterHandle :: Maybe [(UnregisterHandle,a)] -> Maybe [(UnregisterHandle,a)]
filterHandle Nothing = Nothing
filterHandle (Just lst) = case filter (\x -> fst x /= h) lst of
[] -> Nothing
xs -> Just xs
let removeHandles :: HM.HashMap ByteString [(UnregisterHandle,a)]
-> ByteString
-> HM.HashMap ByteString [(UnregisterHandle,a)]
removeHandles m k = case filterHandle (HM.lookup k m) of -- recent versions of unordered-containers have alter
Nothing -> HM.delete k m
Just v -> HM.insert k v m

-- maps after taking out channels matching the handle
let cm' = foldl' removeHandles cm remChans
pm' = foldl' removeHandles pm remPChans

-- the channels to unsubscribe are those that no longer exist in cm' and pm'
let remChans' = filter (\n -> not $ HM.member n cm') remChans
remPChans' = filter (\n -> not $ HM.member n pm') remPChans
ps = (if null remChans' then mempty else unsubscribe remChans')
`mappend` (if null remPChans' then mempty else punsubscribe remPChans')

-- do the unsubscribe
writeTBQueue (sendChanges ctrl) ps
writeTVar (callbacks ctrl) cm'
writeTVar (pcallbacks ctrl) pm'
modifyTVar (pendingCnt ctrl) (+ totalPendingChanges ps)
return ()

-- | Call 'removeChannels' and then wait for all pending subscription change requests to be acknowledged
-- by Redis. This uses the same waiting logic as 'addChannelsAndWait'. Since 'removeChannels' immediately
-- notifies the 'PubSubController' to start discarding messages, you likely don't need this function and
Expand All @@ -406,12 +476,12 @@ listenThread ctrl rawConn = forever $ do
cm <- atomically $ readTVar (callbacks ctrl)
case HM.lookup channel cm of
Nothing -> return ()
Just c -> mapM_ ($ msgCt) c
Just c -> mapM_ (\(_,x) -> x msgCt) c
Msg (PMessage pattern channel msgCt) -> do
pm <- atomically $ readTVar (pcallbacks ctrl)
case HM.lookup pattern pm of
Nothing -> return ()
Just c -> mapM_ (\x -> x channel msgCt) c
Just c -> mapM_ (\(_,x) -> x channel msgCt) c
Subscribed -> atomically $
modifyTVar (pendingCnt ctrl) (\x -> x - 1)
Unsubscribed _ -> atomically $
Expand All @@ -427,6 +497,9 @@ sendThread ctrl rawConn = forever $ do
rawSendCmd rawConn unsubs
rawSendCmd rawConn psubs
rawSendCmd rawConn punsubs
-- normally, the socket is flushed during 'recv', but
-- 'recv' could currently be blocking on a message.
PP.flush rawConn

-- | Open a connection to the Redis server, register to all channels in the 'PubSubController',
-- and process messages and subscription change requests forever. The only way this will ever
Expand Down
25 changes: 23 additions & 2 deletions test/ManualPubSub.hs
Expand Up @@ -45,6 +45,11 @@ msgHandler msg = hPutStrLn stderr $ "Saw msg: " ++ unpack (decodeUtf8 msg)
pmsgHandler :: RedisChannel -> ByteString -> IO ()
pmsgHandler channel msg = hPutStrLn stderr $ "Saw pmsg: " ++ unpack (decodeUtf8 channel) ++ unpack (decodeUtf8 msg)

showChannels :: Connection -> IO ()
showChannels c = do
resp :: Either Reply [ByteString] <- runRedis c $ sendRequest ["PUBSUB", "CHANNELS"]
liftIO $ hPutStrLn stderr $ "Current redis channels: " ++ show resp

main :: IO ()
main = do
ctrl <- newPubSubController [("foo", msgHandler)] []
Expand All @@ -54,10 +59,10 @@ main = do
withAsync (handlerThread conn ctrl) $ \_handlerT -> do

void $ hPutStrLn stderr "Press enter to subscribe to bar" >> getLine
addChannels ctrl [("bar", msgHandler)] []
void $ addChannels ctrl [("bar", msgHandler)] []

void $ hPutStrLn stderr "Press enter to subscribe to baz:*" >> getLine
addChannels ctrl [] [("baz:*", pmsgHandler)]
void $ addChannels ctrl [] [("baz:*", pmsgHandler)]

void $ hPutStrLn stderr "Press enter to unsub from foo" >> getLine
removeChannels ctrl ["foo"] []
Expand All @@ -68,4 +73,20 @@ main = do
void $ hPutStrLn stderr "Press enter to unsub from baz:*" >> getLine
removeChannels ctrl [] ["baz:*"]

void $ hPutStrLn stderr "Press enter to sub to foo and baz:*" >> getLine
unsub1 <- addChannelsAndWait ctrl [("foo", msgHandler)] [("baz:*", pmsgHandler)]
showChannels conn

void $ hPutStrLn stderr "Press enter to sub to foo again and baz:1" >> getLine
unsub2 <- addChannelsAndWait ctrl [("foo", msgHandler), ("baz:1", msgHandler)] []
showChannels conn

void $ hPutStrLn stderr "Press enter to unsub to foo and baz:1" >> getLine
unsub2

void $ hPutStrLn stderr "Press enter to unsub to foo and baz:*" >> getLine
showChannels conn
unsub1

void $ hPutStrLn stderr "Press enter to exit" >> getLine
showChannels conn

0 comments on commit da95f9b

Please sign in to comment.