diff --git a/Network/IRC/Client.hs b/Network/IRC/Client.hs index d3623c9..e20dc13 100644 --- a/Network/IRC/Client.hs +++ b/Network/IRC/Client.hs @@ -144,6 +144,15 @@ module Network.IRC.Client -- exception will be thrown, killing it. , Timeout(..) + -- * Concurrency + + -- | A client can manage a collection of threads, which get thrown + -- the 'Disconnect' exception whenever the client disconnects for + -- any reason (including a call to 'reconnect'). These can be + -- created from event handlers to manage long-running tasks. + , U.fork + , Disconnect(..) + -- * Lenses , module Network.IRC.Client.Lens @@ -158,6 +167,7 @@ import Control.Monad.IO.Class (MonadIO, liftIO) import Data.ByteString (ByteString) import qualified Data.Conduit.Network.TLS as TLS import Data.Conduit.TMChan (newTBMChanIO) +import qualified Data.Set as S import Data.Text (Text) import qualified Data.Text as T import Data.Version (showVersion) @@ -171,7 +181,9 @@ import qualified Network.TLS as TLS import Network.IRC.Client.Events import Network.IRC.Client.Internal import Network.IRC.Client.Lens -import Network.IRC.Client.Utils +-- I think exporting 'fork' with 'Disconnect' gives better documentation. +import Network.IRC.Client.Utils hiding (fork) +import qualified Network.IRC.Client.Utils as U import qualified Paths_irc_client as Paths @@ -276,16 +288,9 @@ newIRCState :: MonadIO m -> s -- ^ The initial value for the user state. -> m (IRCState s) -newIRCState cconf iconf ustate = liftIO $ do - ustvar <- newTVarIO ustate - ictvar <- newTVarIO iconf - cstvar <- newTVarIO Disconnected - squeue <- newTVarIO =<< newTBMChanIO 16 - - pure IRCState - { _connectionConfig = cconf - , _userState = ustvar - , _instanceConfig = ictvar - , _connectionState = cstvar - , _sendqueue = squeue - } +newIRCState cconf iconf ustate = liftIO $ IRCState cconf + <$> newTVarIO ustate + <*> newTVarIO iconf + <*> (newTVarIO =<< newTBMChanIO 16) + <*> newTVarIO Disconnected + <*> newTVarIO S.empty diff --git a/Network/IRC/Client/Events.hs b/Network/IRC/Client/Events.hs index feb5573..e68e133 100644 --- a/Network/IRC/Client/Events.hs +++ b/Network/IRC/Client/Events.hs @@ -11,7 +11,8 @@ -- Portability : CPP, OverloadedStrings, RankNTypes -- -- Events and event handlers. When a message is received from the --- server, all matching handlers are executed concurrently. +-- server, all matching handlers are executed sequentially in the +-- order that they appear in the 'handlers' list. module Network.IRC.Client.Events ( -- * Handlers EventHandler(..) diff --git a/Network/IRC/Client/Internal.hs b/Network/IRC/Client/Internal.hs index 9aad208..a6ddfa9 100644 --- a/Network/IRC/Client/Internal.hs +++ b/Network/IRC/Client/Internal.hs @@ -35,6 +35,7 @@ import Data.ByteString (ByteString) import Data.Conduit (Producer, Conduit, Consumer, (=$=), ($=), (=$), await, awaitForever, toProducer, yield) import Data.Conduit.TMChan (closeTBMChan, isClosedTBMChan, isEmptyTBMChan, sourceTBMChan, writeTBMChan, newTBMChan) import Data.IORef (IORef, newIORef, readIORef, writeIORef) +import qualified Data.Set as S import Data.Text (Text) import Data.Text.Encoding (decodeUtf8, encodeUtf8) import Data.Time.Clock (NominalDiffTime, UTCTime, addUTCTime, diffUTCTime, getCurrentTime) @@ -148,8 +149,7 @@ forgetful = awaitForever go where go (Left _) = return () go (Right b) = yield b --- | Block on receiving a message and invoke all matching handlers --- concurrently. +-- | Block on receiving a message and invoke all matching handlers. eventSink :: MonadIO m => IORef UTCTime -> IRCState s -> Consumer (Event ByteString) m () eventSink lastReceived ircstate = go where go = await >>= maybe (return ()) (\event -> do @@ -164,7 +164,7 @@ eventSink lastReceived ircstate = go where iconf <- snapshot instanceConfig ircstate forM_ (get handlers iconf) $ \(EventHandler matcher handler) -> maybe (pure ()) - (void . forkIO . flip runIRCAction ircstate . handler (_source event')) + (void . flip runIRCAction ircstate . handler (_source event')) (matcher event') -- If disconnected, do not loop. @@ -263,6 +263,12 @@ disconnect = do closeTBMChan =<< readTVar (_sendqueue s) writeTVar (_connectionState s) Disconnected + -- Kill all managed threads. Don't wait for them to terminate + -- here, as they might be masking exceptions and not pick up + -- the 'Disconnect' for a while; just clear the list. + mapM_ (`throwTo` Disconnect) =<< readTVarIO (_runningThreads s) + atomically $ writeTVar (_runningThreads s) S.empty + -- If already disconnected, or disconnecting, do nothing. _ -> pure () diff --git a/Network/IRC/Client/Internal/Types.hs b/Network/IRC/Client/Internal/Types.hs index 66d7e7e..5b761d6 100644 --- a/Network/IRC/Client/Internal/Types.hs +++ b/Network/IRC/Client/Internal/Types.hs @@ -18,6 +18,7 @@ -- of this library. module Network.IRC.Client.Internal.Types where +import Control.Concurrent (ThreadId) import Control.Concurrent.STM (TVar, atomically, readTVar, writeTVar) import Control.Monad.Catch (Exception, MonadThrow, MonadCatch, MonadMask, SomeException) import Control.Monad.IO.Class (MonadIO, liftIO) @@ -26,6 +27,7 @@ import Control.Monad.State (MonadState(..)) import Data.ByteString (ByteString) import Data.Conduit (Consumer, Producer) import Data.Conduit.TMChan (TBMChan) +import qualified Data.Set as S import Data.Text (Text) import Data.Time.Clock (NominalDiffTime) import Network.IRC.Conduit (Event(..), Message, Source) @@ -56,17 +58,20 @@ instance MonadState s (IRC s) where -- * State -- | The state of an IRC session. -data IRCState s = IRCState { _connectionConfig :: ConnectionConfig s - -- ^Read-only connection configuration - , _userState :: TVar s - -- ^Mutable user state - , _instanceConfig :: TVar (InstanceConfig s) - -- ^Mutable instance configuration in STM - , _sendqueue :: TVar (TBMChan (Message ByteString)) - -- ^ Message send queue. - , _connectionState :: TVar ConnectionState - -- ^State of the connection. - } +data IRCState s = IRCState + { _connectionConfig :: ConnectionConfig s + -- ^ Read-only connection configuration + , _userState :: TVar s + -- ^ Mutable user state + , _instanceConfig :: TVar (InstanceConfig s) + -- ^ Mutable instance configuration in STM + , _sendqueue :: TVar (TBMChan (Message ByteString)) + -- ^ Message send queue. + , _connectionState :: TVar ConnectionState + -- ^ State of the connection. + , _runningThreads :: TVar (S.Set ThreadId) + -- ^ Threads which will be killed when the client disconnects. + } -- | The static state of an IRC server connection. data ConnectionConfig s = ConnectionConfig @@ -113,7 +118,8 @@ data InstanceConfig s = InstanceConfig -- ^ The version is sent in response to the CTCP \"VERSION\" request by -- the default event handlers. , _handlers :: [EventHandler s] - -- ^ The registered event handlers + -- ^ The registered event handlers. The order in this list is the + -- order in which they are executed. , _ignore :: [(Text, Maybe Text)] -- ^ List of nicks (optionally restricted to channels) to ignore -- messages from. 'Nothing' ignores globally. @@ -148,3 +154,10 @@ data Timeout = Timeout deriving (Bounded, Enum, Eq, Ord, Read, Show) instance Exception Timeout + +-- | Exception thrown to all managed threads when the client +-- disconnects. +data Disconnect = Disconnect + deriving (Bounded, Enum, Eq, Ord, Read, Show) + +instance Exception Disconnect diff --git a/Network/IRC/Client/Utils.hs b/Network/IRC/Client/Utils.hs index f3ae47d..58eaf2a 100644 --- a/Network/IRC/Client/Utils.hs +++ b/Network/IRC/Client/Utils.hs @@ -30,6 +30,9 @@ module Network.IRC.Client.Utils , isDisconnected , snapConnState + -- * Concurrency + , fork + -- * Lenses , snapshot , snapshotModify @@ -38,8 +41,10 @@ module Network.IRC.Client.Utils , modify ) where +import Control.Concurrent (ThreadId, myThreadId, forkFinally) import Control.Concurrent.STM (TVar, STM, atomically, modifyTVar) import Control.Monad.IO.Class (liftIO) +import qualified Data.Set as S import Data.Text (Text) import qualified Data.Text as T import Network.IRC.Conduit (Event(..), Message(..), Source(..)) @@ -132,3 +137,19 @@ isDisconnected = (==Disconnected) <$> snapConnState -- | Snapshot the connection state. snapConnState :: IRC s ConnectionState snapConnState = liftIO . atomically . getConnectionState =<< getIRCState + + +------------------------------------------------------------------------------- +-- Concurrency + +-- | Fork a thread which will be thrown a 'Disconnect' exception when +-- the client disconnects. +fork :: IRC s () -> IRC s ThreadId +fork ma = do + s <- getIRCState + liftIO $ do + tid <- forkFinally (runIRCAction ma s) $ \_ -> do + tid <- myThreadId + atomically $ modifyTVar (_runningThreads s) (S.delete tid) + atomically $ modifyTVar (_runningThreads s) (S.insert tid) + pure tid diff --git a/irc-client.cabal b/irc-client.cabal index 3c16a45..b79953a 100644 --- a/irc-client.cabal +++ b/irc-client.cabal @@ -88,6 +88,7 @@ library -- Other library packages from which modules are imported. build-depends: base >=4.7 && <5 , bytestring >=0.10 && <0.11 + , containers >=0.1 && <1 , conduit >=1.2 && <1.3 , connection >=0.2 && <0.3 , contravariant >=0.1 && <1.5