From 76fefa582df559d5ce0e552c34ba45d5438d195c Mon Sep 17 00:00:00 2001 From: Jasper Van der Jeugt Date: Sun, 31 May 2015 16:51:45 +0200 Subject: [PATCH] Bugfixes wrt closing sockets, bump to 0.9.5.0 --- CHANGELOG | 3 + src/Network/WebSockets/Client.hs | 21 +++---- src/Network/WebSockets/Connection.hs | 59 +++++------------ src/Network/WebSockets/Server.hs | 14 +++-- src/Network/WebSockets/Stream.hs | 63 +++++++++++++++---- .../Network/WebSockets/Handshake/Tests.hs | 2 +- tests/haskell/Network/WebSockets/Tests.hs | 19 +++--- websockets.cabal | 2 +- 8 files changed, 102 insertions(+), 81 deletions(-) diff --git a/CHANGELOG b/CHANGELOG index f1d0dab..2141a3f 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -1,3 +1,6 @@ +- 0.9.5.0 + * Bugfixes wrt closing sockets and streams + - 0.9.4.0 * Add `makePendingConnectionFromStream` function * Bump `attoparsec` dependency diff --git a/src/Network/WebSockets/Client.hs b/src/Network/WebSockets/Client.hs index fba7616..65dbd54 100644 --- a/src/Network/WebSockets/Client.hs +++ b/src/Network/WebSockets/Client.hs @@ -12,8 +12,7 @@ module Network.WebSockets.Client -------------------------------------------------------------------------------- import qualified Blaze.ByteString.Builder as Builder -import Control.Concurrent.MVar (newMVar) -import Control.Exception (finally, throwIO) +import Control.Exception (bracket, finally, throwIO) import Data.IORef (newIORef) import qualified Data.Text as T import qualified Data.Text.Encoding as T @@ -100,18 +99,14 @@ runClientWithStream stream host path opts customHeaders app = do Response _ _ <- return $ finishResponse protocol request response parse <- decodeMessages protocol stream write <- encodeMessages protocol ClientConnection stream + sentRef <- newIORef False - parseLock <- newMVar () - writeLock <- newMVar () - stream' <- newIORef (DecoderEncoder parse write) - sentRef <- newIORef False app Connection { connectionOptions = opts , connectionType = ClientConnection , connectionProtocol = protocol - , connectionParseLock = parseLock - , connectionWriteLock = writeLock - , connectionStream = stream' + , connectionParse = parse + , connectionWrite = write , connectionSentClose = sentRef } where @@ -128,6 +123,8 @@ runClientWithSocket :: S.Socket -- ^ Socket -> Headers -- ^ Custom headers to send -> ClientApp a -- ^ Client application -> IO a -runClientWithSocket sock host path opts customHeaders app = do - stream <- Stream.makeSocketStream sock - runClientWithStream stream host path opts customHeaders app +runClientWithSocket sock host path opts customHeaders app = bracket + (Stream.makeSocketStream sock) + Stream.close + (\stream -> + runClientWithStream stream host path opts customHeaders app) diff --git a/src/Network/WebSockets/Connection.hs b/src/Network/WebSockets/Connection.hs index 2e01614..81ef445 100644 --- a/src/Network/WebSockets/Connection.hs +++ b/src/Network/WebSockets/Connection.hs @@ -9,7 +9,6 @@ module Network.WebSockets.Connection , acceptRequestWith , rejectRequest - , DecoderEncoder (..) , Connection (..) , ConnectionOptions (..) @@ -33,9 +32,8 @@ module Network.WebSockets.Connection -------------------------------------------------------------------------------- import qualified Blaze.ByteString.Builder as Builder import Control.Concurrent (forkIO, threadDelay) -import Control.Concurrent.MVar (MVar, newMVar, withMVar) import Control.Exception (AsyncException, fromException, - handle, onException, throwIO) + handle, throwIO) import Control.Monad (unless) import qualified Data.ByteString as B import Data.IORef (IORef, newIORef, readIORef, @@ -103,17 +101,13 @@ acceptRequestWith pc ar = case find (flip compatible request) protocols of parse <- decodeMessages protocol (pendingStream pc) write <- encodeMessages protocol ServerConnection (pendingStream pc) - parseLock <- newMVar () - writeLock <- newMVar () - stream <- newIORef (DecoderEncoder parse write) sentRef <- newIORef False let connection = Connection { connectionOptions = pendingOptions pc , connectionType = ServerConnection , connectionProtocol = protocol - , connectionParseLock = parseLock - , connectionWriteLock = writeLock - , connectionStream = stream + , connectionParse = parse + , connectionWrite = write , connectionSentClose = sentRef } @@ -130,21 +124,13 @@ rejectRequest :: PendingConnection -> B.ByteString -> IO () rejectRequest pc message = sendResponse pc $ response400 [] message --------------------------------------------------------------------------------- --- | Type representing an available or unavailable stream. -data DecoderEncoder a - = DecoderEncoder !(IO (Maybe a)) !(a -> IO ()) - | Closed - - -------------------------------------------------------------------------------- data Connection = Connection { connectionOptions :: !ConnectionOptions , connectionType :: !ConnectionType , connectionProtocol :: !Protocol - , connectionParseLock :: !(MVar ()) - , connectionWriteLock :: !(MVar ()) - , connectionStream :: !(IORef (DecoderEncoder Message)) + , connectionParse :: !(IO (Maybe Message)) + , connectionWrite :: !(Message -> IO ()) , connectionSentClose :: !(IORef Bool) -- ^ According to the RFC, both the client and the server MUST send -- a close control message to each other. Either party can initiate @@ -172,18 +158,11 @@ defaultConnectionOptions = ConnectionOptions -------------------------------------------------------------------------------- receive :: Connection -> IO Message -receive conn = withMVar (connectionParseLock conn) $ \() -> - tryParse `onException` (writeIORef (connectionStream conn) Closed) - where - tryParse = do - stream <- readIORef (connectionStream conn) - case stream of - Closed -> throwIO ConnectionClosed - DecoderEncoder parse _ -> do - mbMsg <- parse - case mbMsg of - Nothing -> throwIO ConnectionClosed - Just msg -> return msg +receive conn = do + mbMsg <- connectionParse conn + case mbMsg of + Nothing -> throwIO ConnectionClosed + Just msg -> return msg -------------------------------------------------------------------------------- @@ -227,18 +206,12 @@ receiveData conn = do -------------------------------------------------------------------------------- send :: Connection -> Message -> IO () -send conn msg = withMVar (connectionWriteLock conn) $ \() -> - trySend `onException` writeIORef (connectionStream conn) Closed - where - trySend = do - stream <- readIORef (connectionStream conn) - case msg of - (ControlMessage (Close _ _)) -> - writeIORef (connectionSentClose conn) True - _ -> return () - case stream of - Closed -> throwIO ConnectionClosed - DecoderEncoder _ write -> write msg +send conn msg = do + case msg of + (ControlMessage (Close _ _)) -> + writeIORef (connectionSentClose conn) True + _ -> return () + connectionWrite conn msg -------------------------------------------------------------------------------- diff --git a/src/Network/WebSockets/Server.hs b/src/Network/WebSockets/Server.hs index 68d7ad3..6ad8ca0 100644 --- a/src/Network/WebSockets/Server.hs +++ b/src/Network/WebSockets/Server.hs @@ -84,19 +84,23 @@ runApp :: Socket -> ConnectionOptions -> ServerApp -> IO () -runApp socket opts app = do - pending <- makePendingConnection socket opts - app pending +runApp socket opts app = + bracket + (makePendingConnection socket opts) + (Stream.close . pendingStream) + app -------------------------------------------------------------------------------- --- | Turns a socket, connected to some client, into a 'PendingConnection'. +-- | Turns a socket, connected to some client, into a 'PendingConnection'. The +-- 'PendingConnection' should be closed using 'Stream.close' later. makePendingConnection :: Socket -> ConnectionOptions -> IO PendingConnection makePendingConnection socket opts = do - stream <- Stream.makeSocketStream socket + stream <- Stream.makeSocketStream socket makePendingConnectionFromStream stream opts + -- | More general version of 'makePendingConnection' for 'Stream.Stream' -- instead of a 'Socket'. makePendingConnectionFromStream diff --git a/src/Network/WebSockets/Stream.hs b/src/Network/WebSockets/Stream.hs index e190984..6ed5371 100644 --- a/src/Network/WebSockets/Stream.hs +++ b/src/Network/WebSockets/Stream.hs @@ -11,14 +11,15 @@ module Network.WebSockets.Stream , close ) where -import Control.Applicative ((<$>)) -import qualified Control.Concurrent.Chan as Chan -import Control.Exception (throwIO) -import Control.Monad (forM_) +import Control.Concurrent.MVar (MVar, newEmptyMVar, newMVar, + putMVar, takeMVar, withMVar) +import Control.Exception (onException, throwIO) +import Control.Monad (forM_, when) import qualified Data.Attoparsec.ByteString as Atto import qualified Data.ByteString as B import qualified Data.ByteString.Lazy as BL -import Data.IORef (IORef, newIORef, readIORef, +import Data.IORef (IORef, atomicModifyIORef, + newIORef, readIORef, writeIORef) import qualified Network.Socket as S import qualified Network.Socket.ByteString as SB (recv) @@ -49,11 +50,51 @@ data Stream = Stream -------------------------------------------------------------------------------- +-- | Create a stream from a "receive" and "send" action. The following +-- properties apply: +-- +-- - Regardless of the provided "receive" and "send" functions, reading and +-- writing from the stream will be thread-safe, i.e. this function will create +-- a receive and write lock to be used internally. +-- +-- - Reading from or writing or to a closed 'Stream' will always throw an +-- exception, even if the underlying "receive" and "send" functions do not +-- (we do the bookkeeping). +-- +-- - Streams should always be closed. makeStream :: IO (Maybe B.ByteString) -- ^ Reading -> (Maybe BL.ByteString -> IO ()) -- ^ Writing -> IO Stream -- ^ Resulting stream -makeStream i o = Stream i o <$> newIORef (Open B.empty) +makeStream receive send = do + ref <- newIORef (Open B.empty) + receiveLock <- newMVar () + sendLock <- newMVar () + return $ Stream (receive' ref receiveLock) (send' ref sendLock) ref + where + closeRef :: IORef StreamState -> IO () + closeRef ref = atomicModifyIORef ref $ \state -> case state of + Open buf -> (Closed buf, ()) + Closed buf -> (Closed buf, ()) + + assertNotClosed :: IORef StreamState -> IO a -> IO a + assertNotClosed ref io = do + state <- readIORef ref + case state of + Closed _ -> throwIO ConnectionClosed + Open _ -> io + + receive' :: IORef StreamState -> MVar () -> IO (Maybe B.ByteString) + receive' ref lock = withMVar lock $ \() -> assertNotClosed ref $ do + mbBs <- onException receive (closeRef ref) + case mbBs of + Nothing -> closeRef ref >> return Nothing + Just bs -> return (Just bs) + + send' :: IORef StreamState -> MVar () -> (Maybe BL.ByteString -> IO ()) + send' ref lock mbBs = withMVar lock $ \() -> assertNotClosed ref $ do + when (mbBs == Nothing) (closeRef ref) + onException (send mbBs) (closeRef ref) -------------------------------------------------------------------------------- @@ -65,7 +106,7 @@ makeSocketStream socket = makeStream receive send return $ if B.null bs then Nothing else Just bs send Nothing = return () - send (Just bs) = + send (Just bs) = do #if !defined(mingw32_HOST_OS) SBL.sendAll socket bs #else @@ -76,10 +117,10 @@ makeSocketStream socket = makeStream receive send -------------------------------------------------------------------------------- makeEchoStream :: IO Stream makeEchoStream = do - chan <- Chan.newChan - makeStream (Chan.readChan chan) $ \mbBs -> case mbBs of - Nothing -> Chan.writeChan chan Nothing - Just bs -> forM_ (BL.toChunks bs) $ \c -> Chan.writeChan chan (Just c) + mvar <- newEmptyMVar + makeStream (takeMVar mvar) $ \mbBs -> case mbBs of + Nothing -> putMVar mvar Nothing + Just bs -> forM_ (BL.toChunks bs) $ \c -> putMVar mvar (Just c) -------------------------------------------------------------------------------- diff --git a/tests/haskell/Network/WebSockets/Handshake/Tests.hs b/tests/haskell/Network/WebSockets/Handshake/Tests.hs index 061563b..d0b1a20 100644 --- a/tests/haskell/Network/WebSockets/Handshake/Tests.hs +++ b/tests/haskell/Network/WebSockets/Handshake/Tests.hs @@ -22,7 +22,6 @@ import Network.WebSockets import Network.WebSockets.Connection import Network.WebSockets.Http import qualified Network.WebSockets.Stream as Stream -import Network.WebSockets.Tests.Util -------------------------------------------------------------------------------- @@ -43,6 +42,7 @@ testHandshake rq app = do _ <- app (PendingConnection defaultConnectionOptions rq nullify echo) return () mbRh <- Stream.parse echo decodeResponseHead + Stream.close echo case mbRh of Nothing -> fail "testHandshake: No response" Just rh -> return rh diff --git a/tests/haskell/Network/WebSockets/Tests.hs b/tests/haskell/Network/WebSockets/Tests.hs index 14f404c..3674b6f 100644 --- a/tests/haskell/Network/WebSockets/Tests.hs +++ b/tests/haskell/Network/WebSockets/Tests.hs @@ -9,7 +9,8 @@ module Network.WebSockets.Tests -------------------------------------------------------------------------------- import qualified Blaze.ByteString.Builder as Builder import Control.Applicative ((<$>)) -import Control.Monad (replicateM, forM_) +import Control.Concurrent (forkIO) +import Control.Monad (forM_, replicateM) import qualified Data.ByteString.Lazy as BL import Data.List (intersperse) import Data.Maybe (catMaybes) @@ -47,8 +48,9 @@ testSimpleEncodeDecode protocol = QC.monadicIO $ echo <- Stream.makeEchoStream parse <- decodeMessages protocol echo write <- encodeMessages protocol ClientConnection echo - forM_ msgs write + _ <- forkIO $ forM_ msgs write msgs' <- catMaybes <$> replicateM (length msgs) parse + Stream.close echo msgs @=? msgs' @@ -61,12 +63,13 @@ testFragmentedHybi13 = QC.monadicIO $ -- is' <- Streams.filter isDataMessage =<< Hybi13.decodeMessages is -- Simple hacky encoding of all frames - mapM_ (Stream.write echo) - [ Builder.toLazyByteString (Hybi13.encodeFrame Nothing f) - | FragmentedMessage _ frames <- fragmented - , f <- frames - ] - Stream.close echo + _ <- forkIO $ do + mapM_ (Stream.write echo) + [ Builder.toLazyByteString (Hybi13.encodeFrame Nothing f) + | FragmentedMessage _ frames <- fragmented + , f <- frames + ] + Stream.close echo -- Check if we got all data msgs <- filter isDataMessage <$> parseAll parse diff --git a/websockets.cabal b/websockets.cabal index f727c9e..92d522a 100644 --- a/websockets.cabal +++ b/websockets.cabal @@ -1,5 +1,5 @@ Name: websockets -Version: 0.9.4.0 +Version: 0.9.5.0 Synopsis: A sensible and clean way to write WebSocket-capable servers in Haskell.