Skip to content

Commit

Permalink
Snocket.Accept
Browse files Browse the repository at this point in the history
Rescue Alex Vieth's 'Accept' modification.  I couldn't cherry-pick the
commit since it was burried inside a merge commit.

There's no proper way to fix `Ouroboros.Network.Soocket.fromSnocket`,
but this is ok, as it will be removed in a later commit.
  • Loading branch information
coot committed May 12, 2021
1 parent ca52f8a commit 68de40a
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 54 deletions.
103 changes: 78 additions & 25 deletions ouroboros-network-framework/src/Ouroboros/Network/Snocket.hs
Expand Up @@ -8,6 +8,7 @@
module Ouroboros.Network.Snocket
( -- * Snocket Interface
Accept (..)
, Accepted (..)
, AddressFamily (..)
, Snocket (..)
-- ** Socket based Snocktes
Expand All @@ -32,6 +33,7 @@ import Control.Monad (when)
import Control.Monad.Class.MonadTime (DiffTime)
import Control.Tracer (Tracer)
import Data.Bifunctor (Bifunctor (..))
import Data.Bifoldable (Bifoldable (..))
import Data.Hashable
import GHC.Generics (Generic)
import Quiet (Quiet (..))
Expand Down Expand Up @@ -96,13 +98,26 @@ import Ouroboros.Network.IOManager
-- descriptor by `createNamedPipe`, see 'namedPipeSnocket'.
--
newtype Accept m fd addr = Accept
{ runAccept :: m (fd, addr, Accept m fd addr)
{ runAccept :: m (Accepted fd addr, Accept m fd addr)
}

instance Functor m => Bifunctor (Accept m) where
bimap f g ac = Accept $ h <$> runAccept ac
bimap f g (Accept ac) = Accept (h <$> ac)
where
h (fd, addr, next) = (f fd, g addr, bimap f g next)
h (accepted, next) = (bimap f g accepted, bimap f g next)


data Accepted fd addr where
AcceptFailure :: !SomeException -> Accepted fd addr
Accepted :: !fd -> !addr -> Accepted fd addr

instance Bifunctor Accepted where
bimap f g (Accepted fd addr) = Accepted (f fd) (g addr)
bimap _ _ (AcceptFailure err) = AcceptFailure err

instance Bifoldable Accepted where
bifoldMap f g (Accepted fd addr) = f fd <> g addr
bifoldMap _ _ (AcceptFailure _) = mempty


-- | BSD accept loop.
Expand All @@ -112,21 +127,35 @@ berkeleyAccept :: IOManager
-> Accept IO Socket SockAddr
berkeleyAccept ioManager sock = go
where
go = Accept $ do
(sock', addr') <-
go = Accept (acceptOne `catch` handleException)

acceptOne
:: IO ( Accepted Socket SockAddr
, Accept IO Socket SockAddr
)
acceptOne =
bracketOnError
#if !defined(mingw32_HOST_OS)
Socket.accept sock
(Socket.accept sock)
#else
Win32.Async.accept sock
(Win32.Async.accept sock)
#endif
associateWithIOManager ioManager (Right sock')
`catch` \(e :: IOException) -> do
Socket.close sock'
throwIO e
`catch` \(SomeAsyncException _) -> do
Socket.close sock'
throwIO e
return (sock', addr', go)
(Socket.close . fst)
$ \(sock', addr') -> do
associateWithIOManager ioManager (Right sock')
return (Accepted sock' addr', go)

-- Only non-async exceptions will be caught and put into the
-- AcceptFailure variant.
handleException
:: SomeException
-> IO ( Accepted Socket SockAddr
, Accept IO Socket SockAddr
)
handleException err =
case fromException err of
Just (SomeAsyncException _) -> throwIO err
Nothing -> pure (AcceptFailure err, go)

-- | Local address, on Unix is associated with `Socket.AF_UNIX` family, on
--
Expand Down Expand Up @@ -186,6 +215,9 @@ data Snocket m fd addr = Snocket {
, bind :: fd -> addr -> m ()
, listen :: fd -> m ()

-- SomeException is chosen here to avoid having to include it in the Snocket
-- type, and therefore refactoring a bunch of stuff.
-- FIXME probably a good idea to abstract it.
, accept :: fd -> Accept m fd addr

, close :: fd -> m ()
Expand Down Expand Up @@ -346,7 +378,7 @@ localSnocket ioManager path = Snocket {

, accept = \sock@(LocalSocket hpipe) -> Accept $ do
Win32.Async.connectNamedPipe hpipe
return (sock, localAddress, acceptNext)
return (Accepted sock localAddress, acceptNext)

-- Win32.closeHandle is not interrupible
, close = Win32.closeHandle . getLocalHandle
Expand All @@ -358,19 +390,40 @@ localSnocket ioManager path = Snocket {
localAddress = LocalAddress path

acceptNext :: Accept IO LocalSocket LocalAddress
acceptNext = Accept $ do
hpipe <- Win32.createNamedPipe
acceptNext = go
where
go = Accept (acceptOne `catch` handleIOException)

handleIOException
:: IOException
-> IO ( Accepted LocalSocket LocalAddress
, Accept IO LocalSocket LocalAddress
)
handleIOException err =
pure ( AcceptFailure (toException err)
, go
)

acceptOne
:: IO ( Accepted LocalSocket LocalAddress
, Accept IO LocalSocket LocalAddress
)
acceptOne =
bracketOnError
(Win32.createNamedPipe
path
(Win32.pIPE_ACCESS_DUPLEX .|. Win32.fILE_FLAG_OVERLAPPED)
(Win32.pIPE_TYPE_BYTE .|. Win32.pIPE_READMODE_BYTE)
Win32.pIPE_UNLIMITED_INSTANCES
65536 -- outbound pipe size
16384 -- inbound pipe size
0 -- default timeout
Nothing -- default security
associateWithIOManager ioManager (Left hpipe)
Win32.Async.connectNamedPipe hpipe
return (LocalSocket hpipe, localAddress, acceptNext)
65536 -- outbound pipe size
16384 -- inbound pipe size
0 -- default timeout
Nothing) -- default security
Win32.closeHandle
$ \hpipe -> do
associateWithIOManager ioManager (Left hpipe)
Win32.Async.connectNamedPipe hpipe
return (Accepted (LocalSocket hpipe) localAddress, go)

-- local snocket on unix
#else
Expand Down
15 changes: 10 additions & 5 deletions ouroboros-network-framework/src/Ouroboros/Network/Socket.hs
Expand Up @@ -429,11 +429,16 @@ fromSnocket tblVar sn sd = go (Snocket.accept sn sd)
where
go :: Snocket.Accept IO fd addr -> Server.Socket addr fd
go (Snocket.Accept accept) = Server.Socket $ do
(sd', remoteAddr, next) <- accept
-- TOOD: we don't need to that on each accept
localAddr <- Snocket.getLocalAddr sn sd'
atomically $ addConnection tblVar remoteAddr localAddr Nothing
pure (remoteAddr, sd', close remoteAddr localAddr sd', go next)
(result, next) <- accept
case result of
Snocket.Accepted sd' remoteAddr -> do
-- TOOD: we don't need to that on each accept
localAddr <- Snocket.getLocalAddr sn sd'
atomically $ addConnection tblVar remoteAddr localAddr Nothing
pure (remoteAddr, sd', close remoteAddr localAddr sd', go next)
Snocket.AcceptFailure err ->
-- the is no way to construct 'Server.Socket'; This will be removed in a later commit!
throwIO err

close remoteAddr localAddr sd' = do
removeConnection tblVar remoteAddr localAddr
Expand Down
50 changes: 26 additions & 24 deletions ouroboros-network-framework/test/Test/Ouroboros/Network/Socket.hs
Expand Up @@ -13,6 +13,7 @@ module Test.Ouroboros.Network.Socket (tests) where

import Data.Void (Void)
import Data.List (mapAccumL)
import Data.Bifoldable (bitraverse_)
import qualified Data.ByteString.Lazy as BL
import Data.Proxy (Proxy (..))
import Data.Time.Clock (UTCTime, getCurrentTime)
Expand Down Expand Up @@ -333,17 +334,19 @@ prop_socket_recv_error f rerr =
-- accept a connection and start mux on it
bracket
(runAccept $ accept snocket sd)
(\(sd', _, _) -> Socket.close sd')
$ \(sd', _, _) -> do
remoteAddress <- Socket.getPeerName sd'
let timeout = if rerr == RecvSDUTimeout then 0.10
else (-1) -- No timeout
bearer = Mx.socketAsMuxBearer timeout nullTracer sd'
connectionId = ConnectionId {
localAddress = Socket.addrAddress muxAddress,
remoteAddress
}
Mx.muxStart nullTracer (toApplication connectionId (continueForever (Proxy :: Proxy IO)) app) bearer
(bitraverse_ Socket.close pure . fst)
$ \(accepted, _acceptNext) -> case accepted of
AcceptFailure err -> throwIO err
Accepted sd' _ -> do
remoteAddress <- Socket.getPeerName sd'
let timeout = if rerr == RecvSDUTimeout then 0.10
else (-1) -- No timeout
bearer = Mx.socketAsMuxBearer timeout nullTracer sd'
connectionId = ConnectionId {
localAddress = Socket.addrAddress muxAddress,
remoteAddress
}
Mx.muxStart nullTracer (toApplication connectionId (continueForever (Proxy :: Proxy IO)) app) bearer
)
$ \muxAsync -> do

Expand Down Expand Up @@ -407,22 +410,21 @@ prop_socket_send_error rerr =
-- accept a connection and start mux on it
bracket
(runAccept $ accept snocket sd)
(\(sd', _, _) -> Socket.close sd')
(\(sd', _, _) ->
let sduTimeout = if rerr == SendSDUTimeout then 0.10
else (-1) -- No timeout
bearer = Mx.socketAsMuxBearer sduTimeout nullTracer sd'
blob = BL.pack $ replicate 0xffff 0xa5 in
withTimeoutSerial $ \timeout ->
-- send maximum mux sdus until we've filled the window.
replicateM 100 $ do
((), Nothing) <$ write bearer timeout (wrap blob ResponderDir (MiniProtocolNum 0))
)

(bitraverse_ Socket.close pure . fst)
$ \(accepted, _acceptNext) -> case accepted of
AcceptFailure err -> throwIO err
Accepted sd' _ -> do
let sduTimeout = if rerr == SendSDUTimeout then 0.10
else (-1) -- No timeout
bearer = Mx.socketAsMuxBearer sduTimeout nullTracer sd'
blob = BL.pack $ replicate 0xffff 0xa5
withTimeoutSerial $ \timeout ->
-- send maximum mux sdus until we've filled the window.
replicateM 100 $ do
((), Nothing) <$ write bearer timeout (wrap blob ResponderDir (MiniProtocolNum 0))
)
$ \muxAsync -> do


sd' <- openToConnect snocket addr
-- connect to muxAddress
_ <- connect snocket sd' addr
Expand Down

0 comments on commit 68de40a

Please sign in to comment.