Permalink
Browse files

Start work on a mock network layer for testing

  • Loading branch information...
1 parent 612df70 commit 2245ca08b0697890a2e2b8682334aed945c54062 @edsko edsko committed Oct 5, 2012
@@ -20,6 +20,10 @@ Source-Repository head
Location: https://github.com/haskell-distributed/distributed-process
SubDir: network-transport-tcp
+Flag use-mock-network
+ Description: Use mock network implementation (for testing)
+ Default: False
+
Library
Build-Depends: base >= 4.3 && < 5,
network-transport >= 0.3 && < 0.4,
@@ -32,6 +36,10 @@ Library
Extensions: CPP
ghc-options: -Wall -fno-warn-unused-do-bind
HS-Source-Dirs: src
+ If flag(use-mock-network)
+ CPP-Options: -DUSE_MOCK_NETWORK
+ Exposed-modules: Network.Transport.TCP.Mock.Socket
+ Network.Transport.TCP.Mock.Socket.ByteString
Test-Suite TestTCP
Type: exitcode-stdio-1.0
@@ -45,3 +53,5 @@ Test-Suite TestTCP
HS-Source-Dirs: tests
Extensions: CPP,
OverloadedStrings
+ If flag(use-mock-network)
+ CPP-Options: -DUSE_MOCK_NETWORK
@@ -55,7 +55,12 @@ import Network.Transport.Internal
, timeoutMaybe
, asyncWhenCancelled
)
+
+#ifdef USE_MOCK_NETWORK
+import qualified Network.Transport.TCP.Mock.Socket as N
+#else
import qualified Network.Socket as N
+#endif
( HostName
, ServiceName
, Socket
@@ -71,7 +76,13 @@ import qualified Network.Socket as N
, sOMAXCONN
, AddrInfo
)
+
+#ifdef USE_MOCK_NETWORK
+import Network.Transport.TCP.Mock.Socket.ByteString (sendMany)
+#else
import Network.Socket.ByteString (sendMany)
+#endif
+
import Control.Concurrent (forkIO, ThreadId, killThread, myThreadId)
import Control.Concurrent.Chan (Chan, newChan, readChan, writeChan)
import Control.Concurrent.MVar
@@ -12,7 +12,12 @@ import Prelude hiding (catch)
#endif
import Network.Transport.Internal (decodeInt32, void, tryIO, forkIOWithUnmask)
+
+#ifdef USE_MOCK_NETWORK
+import qualified Network.Transport.TCP.Mock.Socket as N
+#else
import qualified Network.Socket as N
+#endif
( HostName
, ServiceName
, Socket
@@ -30,7 +35,13 @@ import qualified Network.Socket as N
, accept
, sClose
)
+
+#ifdef USE_MOCK_NETWORK
+import qualified Network.Transport.TCP.Mock.Socket.ByteString as NBS (recv)
+#else
import qualified Network.Socket.ByteString as NBS (recv)
+#endif
+
import Control.Concurrent (ThreadId)
import Control.Monad (forever, when)
import Control.Exception (SomeException, catch, bracketOnError, throwIO, mask_)
@@ -0,0 +1,249 @@
+{-# LANGUAGE EmptyDataDecls #-}
+module Network.Transport.TCP.Mock.Socket
+ ( -- * Types
+ HostName
+ , ServiceName
+ , Socket
+ , SocketType(..)
+ , SocketOption(..)
+ , AddrInfo(..)
+ , Family
+ , SockAddr
+ , ProtocolNumber
+ , ShutdownCmd(..)
+ -- * Functions
+ , getAddrInfo
+ , socket
+ , bindSocket
+ , listen
+ , setSocketOption
+ , accept
+ , sClose
+ , connect
+ , shutdown
+ -- * Constants
+ , defaultHints
+ , defaultProtocol
+ , sOMAXCONN
+ -- * Internal API
+ , writeSocket
+ , readSocket
+ ) where
+
+import Data.Word (Word8)
+import Data.Map (Map)
+import qualified Data.Map as Map
+import Control.Category ((>>>))
+import Control.Applicative ((<$>))
+import Control.Concurrent.MVar
+import Control.Concurrent.Chan
+import System.IO.Unsafe (unsafePerformIO)
+import Data.Accessor (Accessor, accessor, (^=), (^.))
+import qualified Data.Accessor.Container as DAC (mapMaybe)
+
+--------------------------------------------------------------------------------
+-- Mock state --
+--------------------------------------------------------------------------------
+
+data MockState = MockState {
+ _boundSockets :: !(Map SockAddr Socket)
+ , _nextSocketId :: !Int
+ }
+
+initialMockState :: MockState
+initialMockState = MockState {
+ _boundSockets = Map.empty
+ , _nextSocketId = 0
+ }
+
+mockState :: MVar MockState
+{-# NOINLINE mockState #-}
+mockState = unsafePerformIO $ newMVar initialMockState
+
+get :: Accessor MockState a -> IO a
+get acc = withMVar mockState $ return . (^. acc)
+
+set :: Accessor MockState a -> a -> IO ()
+set acc val = modifyMVar_ mockState $ return . (acc ^= val)
+
+boundSockets :: Accessor MockState (Map SockAddr Socket)
+boundSockets = accessor _boundSockets (\bs st -> st { _boundSockets = bs })
+
+boundSocketAt :: SockAddr -> Accessor MockState (Maybe Socket)
+boundSocketAt addr = boundSockets >>> DAC.mapMaybe addr
+
+nextSocketId :: Accessor MockState Int
+nextSocketId = accessor _nextSocketId (\sid st -> st { _nextSocketId = sid })
+
+--------------------------------------------------------------------------------
+-- The public API (mirroring Network.Socket) --
+--------------------------------------------------------------------------------
+
+type HostName = String
+type ServiceName = String
+type PortNumber = String
+type HostAddress = String
+
+data SocketType = Stream
+data SocketOption = ReuseAddr
+data ShutdownCmd = ShutdownSend
+
+data Family
+data ProtocolNumber
+
+data Socket = Socket {
+ socketState :: MVar SocketState
+ , socketDescription :: String
+ }
+
+data SocketState =
+ Uninit
+ | BoundSocket { socketBacklog :: Chan (Socket, SockAddr, MVar Socket) }
+ | Connected { socketPeer :: Socket,socketBuff :: Chan Word8 }
+ | Closed
+
+data AddrInfo = AddrInfo {
+ addrFamily :: Family
+ , addrAddress :: !SockAddr
+ }
+
+data SockAddr = SockAddrInet PortNumber HostAddress
+ deriving (Eq, Ord, Show)
+
+instance Show AddrInfo where
+ show = show . addrAddress
+
+instance Show Socket where
+ show sock = "<<socket " ++ socketDescription sock ++ ">>"
+
+getAddrInfo :: Maybe AddrInfo -> Maybe HostName -> Maybe ServiceName -> IO [AddrInfo]
+getAddrInfo _ (Just host) (Just port) = return . return $ AddrInfo {
+ addrFamily = error "Family unused"
+ , addrAddress = SockAddrInet port host
+ }
+getAddrInfo _ _ _ = error "getAddrInfo: unsupported arguments"
+
+defaultHints :: AddrInfo
+defaultHints = error "defaultHints not implemented"
+
+socket :: Family -> SocketType -> ProtocolNumber -> IO Socket
+socket _ Stream _ = do
+ state <- newMVar Uninit
+ sid <- get nextSocketId
+ set nextSocketId (sid + 1)
+ return Socket {
+ socketState = state
+ , socketDescription = show sid
+ }
+
+bindSocket :: Socket -> SockAddr -> IO ()
+bindSocket sock addr = do
+ modifyMVar_ (socketState sock) $ \st -> case st of
+ Uninit -> do
+ backlog <- newChan
+ return BoundSocket {
+ socketBacklog = backlog
+ }
+ _ ->
+ error "bind: socket already initialized"
+ set (boundSocketAt addr) (Just sock)
+
+listen :: Socket -> Int -> IO ()
+listen _ _ = return ()
+
+defaultProtocol :: ProtocolNumber
+defaultProtocol = error "defaultProtocol not implemented"
+
+setSocketOption :: Socket -> SocketOption -> Int -> IO ()
+setSocketOption _ ReuseAddr 1 = return ()
+setSocketOption _ _ _ = error "setSocketOption: unsupported arguments"
+
+accept :: Socket -> IO (Socket, SockAddr)
+accept serverSock = do
+ backlog <- withMVar (socketState serverSock) $ \st -> case st of
+ BoundSocket {} ->
+ return (socketBacklog st)
+ _ ->
+ error "accept: socket not bound"
+ (them, theirAddress, reply) <- readChan backlog
+ buff <- newChan
+ ourState <- newMVar Connected {
+ socketPeer = them
+ , socketBuff = buff
+ }
+ let us = Socket {
+ socketState = ourState
+ , socketDescription = ""
+ }
+ putMVar reply us
+ return (us, theirAddress)
+
+sClose :: Socket -> IO ()
+sClose sock = do
+ mPeer <- modifyMVar (socketState sock) $ \st -> case st of
+ Connected {} ->
+ return (Closed, Just $ socketPeer st)
+ _ ->
+ return (Closed, Nothing)
+ case mPeer of
+ Just peer -> modifyMVar_ (socketState peer) $ const (return Closed)
+ Nothing -> return ()
+
+connect :: Socket -> SockAddr -> IO ()
+connect us serverAddr = do
+ mServer <- get (boundSocketAt serverAddr)
+ case mServer of
+ Just server -> do
+ serverBacklog <- withMVar (socketState server) $ \st -> case st of
+ BoundSocket {} ->
+ return (socketBacklog st)
+ _ ->
+ error "connect: server socket not bound"
+ reply <- newEmptyMVar
+ writeChan serverBacklog (us, SockAddrInet "" "", reply)
+ them <- readMVar reply
+ modifyMVar_ (socketState us) $ \st -> case st of
+ Uninit -> do
+ buff <- newChan
+ return Connected {
+ socketPeer = them
+ , socketBuff = buff
+ }
+ _ ->
+ error "connect: already connected"
+ Nothing -> error "connect: unknown address"
+
+sOMAXCONN :: Int
+sOMAXCONN = error "sOMAXCONN not implemented"
+
+shutdown :: Socket -> ShutdownCmd -> IO ()
+shutdown = error "shutdown not implemented"
+
+--------------------------------------------------------------------------------
+-- Functions with no direct public counterpart --
+--------------------------------------------------------------------------------
+
+writeSocket :: Socket -> Word8 -> IO ()
+writeSocket sock w = do
+ peer <- withMVar (socketState sock) $ \st -> case st of
+ Connected {} ->
+ return (socketPeer st)
+ _ ->
+ error "writeSocket: not connected"
+ theirBuff <- withMVar (socketState peer) $ \st -> case st of
+ Connected {} ->
+ return (socketBuff st)
+ _ ->
+ error "writeSocket: peer socket closed"
+ writeChan theirBuff w
+
+readSocket :: Socket -> IO (Maybe Word8)
+readSocket sock = do
+ mBuff <- withMVar (socketState sock) $ \st -> case st of
+ Connected {} ->
+ return (Just $ socketBuff st)
+ _ ->
+ return Nothing
+ case mBuff of
+ Just buff -> Just <$> readChan buff
+ Nothing -> return Nothing
@@ -0,0 +1,27 @@
+module Network.Transport.TCP.Mock.Socket.ByteString
+ ( sendMany
+ , recv
+ ) where
+
+import Data.ByteString (ByteString)
+import qualified Data.ByteString as BSS (pack, foldl)
+import Data.Word (Word8)
+import Control.Applicative ((<$>))
+import Network.Transport.TCP.Mock.Socket
+
+sendMany :: Socket -> [ByteString] -> IO ()
+sendMany sock = mapM_ (bsMapM_ $ writeSocket sock)
+ where
+ bsMapM_ :: (Word8 -> IO ()) -> ByteString -> IO ()
+ bsMapM_ p = BSS.foldl (\io w -> io >> p w) (return ())
+
+recv :: Socket -> Int -> IO ByteString
+recv sock = \n -> BSS.pack <$> go [] n
+ where
+ go :: [Word8] -> Int -> IO [Word8]
+ go acc 0 = return (reverse acc)
+ go acc n = do
+ mw <- readSocket sock
+ case mw of
+ Just w -> go (w : acc) (n - 1)
+ Nothing -> return (reverse acc)
Oops, something went wrong.

0 comments on commit 2245ca0

Please sign in to comment.