Skip to content

Commit

Permalink
Bugfixes wrt closing sockets, bump to 0.9.5.0
Browse files Browse the repository at this point in the history
  • Loading branch information
jaspervdj committed May 31, 2015
1 parent 5c3e8ed commit 76fefa5
Show file tree
Hide file tree
Showing 8 changed files with 102 additions and 81 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG
@@ -1,3 +1,6 @@
- 0.9.5.0
* Bugfixes wrt closing sockets and streams

- 0.9.4.0
* Add `makePendingConnectionFromStream` function
* Bump `attoparsec` dependency
Expand Down
21 changes: 9 additions & 12 deletions src/Network/WebSockets/Client.hs
Expand Up @@ -12,8 +12,7 @@ module Network.WebSockets.Client

--------------------------------------------------------------------------------
import qualified Blaze.ByteString.Builder as Builder
import Control.Concurrent.MVar (newMVar)
import Control.Exception (finally, throwIO)
import Control.Exception (bracket, finally, throwIO)
import Data.IORef (newIORef)
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
Expand Down Expand Up @@ -100,18 +99,14 @@ runClientWithStream stream host path opts customHeaders app = do
Response _ _ <- return $ finishResponse protocol request response
parse <- decodeMessages protocol stream
write <- encodeMessages protocol ClientConnection stream
sentRef <- newIORef False

parseLock <- newMVar ()
writeLock <- newMVar ()
stream' <- newIORef (DecoderEncoder parse write)
sentRef <- newIORef False
app Connection
{ connectionOptions = opts
, connectionType = ClientConnection
, connectionProtocol = protocol
, connectionParseLock = parseLock
, connectionWriteLock = writeLock
, connectionStream = stream'
, connectionParse = parse
, connectionWrite = write
, connectionSentClose = sentRef
}
where
Expand All @@ -128,6 +123,8 @@ runClientWithSocket :: S.Socket -- ^ Socket
-> Headers -- ^ Custom headers to send
-> ClientApp a -- ^ Client application
-> IO a
runClientWithSocket sock host path opts customHeaders app = do
stream <- Stream.makeSocketStream sock
runClientWithStream stream host path opts customHeaders app
runClientWithSocket sock host path opts customHeaders app = bracket
(Stream.makeSocketStream sock)
Stream.close
(\stream ->
runClientWithStream stream host path opts customHeaders app)
59 changes: 16 additions & 43 deletions src/Network/WebSockets/Connection.hs
Expand Up @@ -9,7 +9,6 @@ module Network.WebSockets.Connection
, acceptRequestWith
, rejectRequest

, DecoderEncoder (..)
, Connection (..)

, ConnectionOptions (..)
Expand All @@ -33,9 +32,8 @@ module Network.WebSockets.Connection
--------------------------------------------------------------------------------
import qualified Blaze.ByteString.Builder as Builder
import Control.Concurrent (forkIO, threadDelay)
import Control.Concurrent.MVar (MVar, newMVar, withMVar)
import Control.Exception (AsyncException, fromException,
handle, onException, throwIO)
handle, throwIO)
import Control.Monad (unless)
import qualified Data.ByteString as B
import Data.IORef (IORef, newIORef, readIORef,
Expand Down Expand Up @@ -103,17 +101,13 @@ acceptRequestWith pc ar = case find (flip compatible request) protocols of
parse <- decodeMessages protocol (pendingStream pc)
write <- encodeMessages protocol ServerConnection (pendingStream pc)

parseLock <- newMVar ()
writeLock <- newMVar ()
stream <- newIORef (DecoderEncoder parse write)
sentRef <- newIORef False
let connection = Connection
{ connectionOptions = pendingOptions pc
, connectionType = ServerConnection
, connectionProtocol = protocol
, connectionParseLock = parseLock
, connectionWriteLock = writeLock
, connectionStream = stream
, connectionParse = parse
, connectionWrite = write
, connectionSentClose = sentRef
}

Expand All @@ -130,21 +124,13 @@ rejectRequest :: PendingConnection -> B.ByteString -> IO ()
rejectRequest pc message = sendResponse pc $ response400 [] message


--------------------------------------------------------------------------------
-- | Type representing an available or unavailable stream.
data DecoderEncoder a
= DecoderEncoder !(IO (Maybe a)) !(a -> IO ())
| Closed


--------------------------------------------------------------------------------
data Connection = Connection
{ connectionOptions :: !ConnectionOptions
, connectionType :: !ConnectionType
, connectionProtocol :: !Protocol
, connectionParseLock :: !(MVar ())
, connectionWriteLock :: !(MVar ())
, connectionStream :: !(IORef (DecoderEncoder Message))
, connectionParse :: !(IO (Maybe Message))
, connectionWrite :: !(Message -> IO ())
, connectionSentClose :: !(IORef Bool)
-- ^ According to the RFC, both the client and the server MUST send
-- a close control message to each other. Either party can initiate
Expand Down Expand Up @@ -172,18 +158,11 @@ defaultConnectionOptions = ConnectionOptions

--------------------------------------------------------------------------------
receive :: Connection -> IO Message
receive conn = withMVar (connectionParseLock conn) $ \() ->
tryParse `onException` (writeIORef (connectionStream conn) Closed)
where
tryParse = do
stream <- readIORef (connectionStream conn)
case stream of
Closed -> throwIO ConnectionClosed
DecoderEncoder parse _ -> do
mbMsg <- parse
case mbMsg of
Nothing -> throwIO ConnectionClosed
Just msg -> return msg
receive conn = do
mbMsg <- connectionParse conn
case mbMsg of
Nothing -> throwIO ConnectionClosed
Just msg -> return msg


--------------------------------------------------------------------------------
Expand Down Expand Up @@ -227,18 +206,12 @@ receiveData conn = do

--------------------------------------------------------------------------------
send :: Connection -> Message -> IO ()
send conn msg = withMVar (connectionWriteLock conn) $ \() ->
trySend `onException` writeIORef (connectionStream conn) Closed
where
trySend = do
stream <- readIORef (connectionStream conn)
case msg of
(ControlMessage (Close _ _)) ->
writeIORef (connectionSentClose conn) True
_ -> return ()
case stream of
Closed -> throwIO ConnectionClosed
DecoderEncoder _ write -> write msg
send conn msg = do
case msg of
(ControlMessage (Close _ _)) ->
writeIORef (connectionSentClose conn) True
_ -> return ()
connectionWrite conn msg


--------------------------------------------------------------------------------
Expand Down
14 changes: 9 additions & 5 deletions src/Network/WebSockets/Server.hs
Expand Up @@ -84,19 +84,23 @@ runApp :: Socket
-> ConnectionOptions
-> ServerApp
-> IO ()
runApp socket opts app = do
pending <- makePendingConnection socket opts
app pending
runApp socket opts app =
bracket
(makePendingConnection socket opts)
(Stream.close . pendingStream)
app


--------------------------------------------------------------------------------
-- | Turns a socket, connected to some client, into a 'PendingConnection'.
-- | Turns a socket, connected to some client, into a 'PendingConnection'. The
-- 'PendingConnection' should be closed using 'Stream.close' later.
makePendingConnection
:: Socket -> ConnectionOptions -> IO PendingConnection
makePendingConnection socket opts = do
stream <- Stream.makeSocketStream socket
stream <- Stream.makeSocketStream socket
makePendingConnectionFromStream stream opts


-- | More general version of 'makePendingConnection' for 'Stream.Stream'
-- instead of a 'Socket'.
makePendingConnectionFromStream
Expand Down
63 changes: 52 additions & 11 deletions src/Network/WebSockets/Stream.hs
Expand Up @@ -11,14 +11,15 @@ module Network.WebSockets.Stream
, close
) where

import Control.Applicative ((<$>))
import qualified Control.Concurrent.Chan as Chan
import Control.Exception (throwIO)
import Control.Monad (forM_)
import Control.Concurrent.MVar (MVar, newEmptyMVar, newMVar,
putMVar, takeMVar, withMVar)
import Control.Exception (onException, throwIO)
import Control.Monad (forM_, when)
import qualified Data.Attoparsec.ByteString as Atto
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL
import Data.IORef (IORef, newIORef, readIORef,
import Data.IORef (IORef, atomicModifyIORef,
newIORef, readIORef,
writeIORef)
import qualified Network.Socket as S
import qualified Network.Socket.ByteString as SB (recv)
Expand Down Expand Up @@ -49,11 +50,51 @@ data Stream = Stream


--------------------------------------------------------------------------------
-- | Create a stream from a "receive" and "send" action. The following
-- properties apply:
--
-- - Regardless of the provided "receive" and "send" functions, reading and
-- writing from the stream will be thread-safe, i.e. this function will create
-- a receive and write lock to be used internally.
--
-- - Reading from or writing or to a closed 'Stream' will always throw an
-- exception, even if the underlying "receive" and "send" functions do not
-- (we do the bookkeeping).
--
-- - Streams should always be closed.
makeStream
:: IO (Maybe B.ByteString) -- ^ Reading
-> (Maybe BL.ByteString -> IO ()) -- ^ Writing
-> IO Stream -- ^ Resulting stream
makeStream i o = Stream i o <$> newIORef (Open B.empty)
makeStream receive send = do
ref <- newIORef (Open B.empty)
receiveLock <- newMVar ()
sendLock <- newMVar ()
return $ Stream (receive' ref receiveLock) (send' ref sendLock) ref
where
closeRef :: IORef StreamState -> IO ()
closeRef ref = atomicModifyIORef ref $ \state -> case state of
Open buf -> (Closed buf, ())
Closed buf -> (Closed buf, ())

assertNotClosed :: IORef StreamState -> IO a -> IO a
assertNotClosed ref io = do
state <- readIORef ref
case state of
Closed _ -> throwIO ConnectionClosed
Open _ -> io

receive' :: IORef StreamState -> MVar () -> IO (Maybe B.ByteString)
receive' ref lock = withMVar lock $ \() -> assertNotClosed ref $ do
mbBs <- onException receive (closeRef ref)
case mbBs of
Nothing -> closeRef ref >> return Nothing
Just bs -> return (Just bs)

send' :: IORef StreamState -> MVar () -> (Maybe BL.ByteString -> IO ())
send' ref lock mbBs = withMVar lock $ \() -> assertNotClosed ref $ do
when (mbBs == Nothing) (closeRef ref)
onException (send mbBs) (closeRef ref)


--------------------------------------------------------------------------------
Expand All @@ -65,7 +106,7 @@ makeSocketStream socket = makeStream receive send
return $ if B.null bs then Nothing else Just bs

send Nothing = return ()
send (Just bs) =
send (Just bs) = do
#if !defined(mingw32_HOST_OS)
SBL.sendAll socket bs
#else
Expand All @@ -76,10 +117,10 @@ makeSocketStream socket = makeStream receive send
--------------------------------------------------------------------------------
makeEchoStream :: IO Stream
makeEchoStream = do
chan <- Chan.newChan
makeStream (Chan.readChan chan) $ \mbBs -> case mbBs of
Nothing -> Chan.writeChan chan Nothing
Just bs -> forM_ (BL.toChunks bs) $ \c -> Chan.writeChan chan (Just c)
mvar <- newEmptyMVar
makeStream (takeMVar mvar) $ \mbBs -> case mbBs of
Nothing -> putMVar mvar Nothing
Just bs -> forM_ (BL.toChunks bs) $ \c -> putMVar mvar (Just c)


--------------------------------------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion tests/haskell/Network/WebSockets/Handshake/Tests.hs
Expand Up @@ -22,7 +22,6 @@ import Network.WebSockets
import Network.WebSockets.Connection
import Network.WebSockets.Http
import qualified Network.WebSockets.Stream as Stream
import Network.WebSockets.Tests.Util


--------------------------------------------------------------------------------
Expand All @@ -43,6 +42,7 @@ testHandshake rq app = do
_ <- app (PendingConnection defaultConnectionOptions rq nullify echo)
return ()
mbRh <- Stream.parse echo decodeResponseHead
Stream.close echo
case mbRh of
Nothing -> fail "testHandshake: No response"
Just rh -> return rh
Expand Down
19 changes: 11 additions & 8 deletions tests/haskell/Network/WebSockets/Tests.hs
Expand Up @@ -9,7 +9,8 @@ module Network.WebSockets.Tests
--------------------------------------------------------------------------------
import qualified Blaze.ByteString.Builder as Builder
import Control.Applicative ((<$>))
import Control.Monad (replicateM, forM_)
import Control.Concurrent (forkIO)
import Control.Monad (forM_, replicateM)
import qualified Data.ByteString.Lazy as BL
import Data.List (intersperse)
import Data.Maybe (catMaybes)
Expand Down Expand Up @@ -47,8 +48,9 @@ testSimpleEncodeDecode protocol = QC.monadicIO $
echo <- Stream.makeEchoStream
parse <- decodeMessages protocol echo
write <- encodeMessages protocol ClientConnection echo
forM_ msgs write
_ <- forkIO $ forM_ msgs write
msgs' <- catMaybes <$> replicateM (length msgs) parse
Stream.close echo
msgs @=? msgs'


Expand All @@ -61,12 +63,13 @@ testFragmentedHybi13 = QC.monadicIO $
-- is' <- Streams.filter isDataMessage =<< Hybi13.decodeMessages is

-- Simple hacky encoding of all frames
mapM_ (Stream.write echo)
[ Builder.toLazyByteString (Hybi13.encodeFrame Nothing f)
| FragmentedMessage _ frames <- fragmented
, f <- frames
]
Stream.close echo
_ <- forkIO $ do
mapM_ (Stream.write echo)
[ Builder.toLazyByteString (Hybi13.encodeFrame Nothing f)
| FragmentedMessage _ frames <- fragmented
, f <- frames
]
Stream.close echo

-- Check if we got all data
msgs <- filter isDataMessage <$> parseAll parse
Expand Down
2 changes: 1 addition & 1 deletion websockets.cabal
@@ -1,5 +1,5 @@
Name: websockets
Version: 0.9.4.0
Version: 0.9.5.0

Synopsis:
A sensible and clean way to write WebSocket-capable servers in Haskell.
Expand Down

0 comments on commit 76fefa5

Please sign in to comment.