Skip to content

Commit

Permalink
RawBearer API
Browse files Browse the repository at this point in the history
Lower-level send/receive API for Snockets, bypassing the normal Mux
protocol. We need this for KES secure forgetting, as we cannot store
secrets in intermediate data structures for serialization purposes; we
must copy data directly between secure memory and file descriptors.
  • Loading branch information
tdammers committed Mar 22, 2023
1 parent 7f6afd7 commit 82864b3
Show file tree
Hide file tree
Showing 7 changed files with 361 additions and 13 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Added
-----

- RawBearer API
- ToRawBearer typeclass
- ToRawBearer instances for `Socket`, `LocalSocket`, and `Simulation.Network.Snocket.FD`
2 changes: 2 additions & 0 deletions ouroboros-network-framework/ouroboros-network-framework.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ library
Ouroboros.Network.IOManager
Ouroboros.Network.Mux
Ouroboros.Network.MuxMode
Ouroboros.Network.RawBearer

Ouroboros.Network.Protocol.Handshake
Ouroboros.Network.Protocol.Handshake.Type
Expand Down Expand Up @@ -155,6 +156,7 @@ test-suite test
Test.Ouroboros.Network.Socket
Test.Ouroboros.Network.Subscription
Test.Ouroboros.Network.RateLimiting
Test.Ouroboros.Network.RawBearer
Test.Simulation.Network.Snocket

build-depends: base >=4.14 && <4.17
Expand Down
50 changes: 50 additions & 0 deletions ouroboros-network-framework/src/Ouroboros/Network/RawBearer.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
{-# LANGUAGE CPP #-}
{-# LANGUAGE MultiParamTypeClasses #-}

module Ouroboros.Network.RawBearer
where

import Network.Socket (Socket)
import qualified Network.Socket as Socket
import Foreign.Ptr (Ptr)
import Data.Word (Word8)

#if defined(mingw32_HOST_OS)
import Data.Bits
import Foreign.Ptr (IntPtr (..), ptrToIntPtr)
import qualified System.Win32 as Win32
import qualified System.Win32.Async as Win32.Async
import qualified System.Win32.NamedPipes as Win32
#endif

-- | Generalized API for sending and receiving raw bytes over a file
-- descriptor, socket, or similar object.
data RawBearer m =
RawBearer
{ send :: Ptr Word8 -> Int -> m Int
, recv :: Ptr Word8 -> Int -> m Int
}

class ToRawBearer m fd where
toRawBearer :: fd -> m (RawBearer m)

instance ToRawBearer IO Socket where
toRawBearer s =
return RawBearer
{ send = Socket.sendBuf s
, recv = Socket.recvBuf s
}

#if defined(mingw32_HOST_OS)

-- | We cannot declare an @instance ToRawBearer Win32.HANDLE@, because
-- 'Win32.Handle' is just a type alias for @Ptr ()@. So instead, we provide
-- this function, which can be used to implement 'ToRawBearer' elsewhere (e.g.
-- over a newtype).
win32HandleToRawBearer :: Win32.HANDLE -> RawBearer IO
win32HandleToRawBearer s =
RawBearer
{ send = \buf size -> fromIntegral <$> Win32.win32_WriteFile s (castPtr buf) (fromIntegral size)
, recv = \buf size -> fromIntegral <$> Win32.win32_ReadFile s (castPtr buf) (fromIntegral size)
}
#endif
8 changes: 8 additions & 0 deletions ouroboros-network-framework/src/Ouroboros/Network/Snocket.hs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ import qualified Network.Socket as Socket

import Network.Mux.Bearer

import Ouroboros.Network.RawBearer
import Ouroboros.Network.IOManager


Expand Down Expand Up @@ -349,10 +350,17 @@ data LocalSocket = LocalSocket { getLocalHandle :: !LocalHandle
}
deriving (Eq, Generic)
deriving Show via Quiet LocalSocket

instance ToRawBearer IO LocalSocket where
toRawBearer = return . win32HandleToRawBearer . getLocalHandle

#else
newtype LocalSocket = LocalSocket { getLocalHandle :: LocalHandle }
deriving (Eq, Generic)
deriving Show via Quiet LocalSocket

instance ToRawBearer IO LocalSocket where
toRawBearer = toRawBearer . getLocalHandle
#endif

makeLocalBearer :: MakeBearer IO LocalSocket
Expand Down
117 changes: 104 additions & 13 deletions ouroboros-network-framework/src/Simulation/Network/Snocket.hs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ import Control.Monad (when)
import Control.Monad.Class.MonadThrow
import Control.Monad.Class.MonadTime
import Control.Monad.Class.MonadTimer
import Control.Monad.Class.MonadST
import Control.Monad.Class.MonadSay
import Control.Monad.ST.Unsafe (unsafeIOToST)
import Control.Tracer (Tracer, contramap, contramapM, traceWith)

import GHC.IO.Exception
Expand All @@ -62,8 +65,12 @@ import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import Data.Typeable (Typeable)
import Foreign.C.Error
import Foreign.Ptr (castPtr)
import Foreign.Marshal (copyBytes)
import Numeric.Natural (Natural)
import Text.Printf (printf)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as LBS

import Data.Monoid.Synchronisation (FirstToFinish (..))
import Data.Wedge
Expand All @@ -74,6 +81,7 @@ import Network.Mux.Bearer.AttenuatedChannel
import Ouroboros.Network.ConnectionId
import Ouroboros.Network.ConnectionManager.Types (AddressType (..))
import Ouroboros.Network.Snocket
import Ouroboros.Network.RawBearer

import Ouroboros.Network.Testing.Data.Script (Script (..),
stepScriptSTM)
Expand Down Expand Up @@ -526,8 +534,92 @@ instance Show addr => Show (FD_ m addr) where

-- | File descriptor type.
--
newtype FD m peerAddr = FD { fdVar :: (StrictTVar m (FD_ m peerAddr)) }

newtype FD m peerAddr = FD { fdVar :: StrictTVar m (FD_ m peerAddr) }

instance ( MonadST m
, MonadThrow m
, MonadSay m
, MonadLabelledSTM m
, Show addr
) => ToRawBearer m (FD m (TestAddress addr)) where
toRawBearer = makeRawFDBearer

-- | Make a 'RawBearer' from an 'FD'. Since this is only used for testing, we
-- can bypass the requirement of moving raw bytes directly between file
-- descriptors and provided memory buffers, and we can instead covertly use
-- plain old 'ByteString' under the hood. This allows us to use the
-- 'AttenuatedChannel' inside the `FD_`, even though its send and receive
-- methods do not have the right format.
makeRawFDBearer :: forall addr m.
( MonadST m
, MonadLabelledSTM m
, MonadThrow m
, MonadSay m
, Show addr
)
=> FD m (TestAddress addr)
-> m (RawBearer m)
makeRawFDBearer (FD {fdVar}) = do
(bufVar :: StrictTMVar m LBS.ByteString) <- newTMVarIO LBS.empty
return RawBearer
{ send = \src srcSize -> do
labelTVarIO fdVar "sender"
say $ "Sending " ++ show srcSize ++ " bytes"
fd_ <- readTVarIO fdVar
case fd_ of
FDConnected _ conn -> do
bs <- withLiftST $ \liftST ->
liftST . unsafeIOToST $ BS.packCStringLen (castPtr src, srcSize)
let bsl = LBS.fromStrict bs
acWrite (connChannelLocal conn) bsl
say $ "Sent " ++ show srcSize ++ " bytes"
return srcSize
_ ->
throwIO (invalidError fd_)
, recv = \dst size -> do
labelTVarIO fdVar "receiver"
let size64 = fromIntegral size
say $ "Receiving " ++ show size ++ " bytes"
fd_ <- readTVarIO fdVar
case fd_ of
FDConnected _ conn -> do
say $ "Checking buffer"
bytesFromBuffer <- atomically $ takeTMVar bufVar
say $ "Buffer: " ++ show bytesFromBuffer
(lhs, rhs) <- if not (LBS.null bytesFromBuffer) then do
say $ "Reading up to " ++ show size ++ " bytes from buffer"
return (LBS.take size64 bytesFromBuffer, LBS.drop size64 bytesFromBuffer)
else do
bytesRead <- acRead (connChannelLocal conn)
say $ "Received " ++ show (LBS.length $ LBS.take size64 bytesRead) ++ " or more bytes"
return (LBS.take size64 bytesRead, LBS.drop size64 bytesRead)
say $ "Updating buffer; use: " ++ show lhs ++ " keep: " ++ show (LBS.take 10 rhs) ++
if (LBS.length . LBS.take 11 $ rhs) == 11 then "..." else ""
atomically $ putTMVar bufVar rhs
say $ "Receive: buffer updated"
if LBS.null lhs then do
say $ "Receive: End of stream"
return 0
else do
say $ "Receive: copying."
let bs = LBS.toStrict lhs
withLiftST $ \liftST ->
liftST . unsafeIOToST $ BS.useAsCStringLen bs $ \(src, srcSize) -> do
copyBytes dst (castPtr src) srcSize
return srcSize
_ ->
throwIO (invalidError fd_)
}
where
invalidError :: FD_ m (TestAddress addr) -> IOError
invalidError fd_ = IOError
{ ioe_handle = Nothing
, ioe_type = InvalidArgument
, ioe_location = "Ouroboros.Network.Snocket.Sim.toRawBearer"
, ioe_description = printf "Invalid argument (%s)" (show fd_)
, ioe_errno = Nothing
, ioe_filename = Nothing
}

makeFDBearer :: forall addr m.
( MonadMonotonicTime m
Expand All @@ -551,17 +643,16 @@ makeFDBearer = MakeBearer $ \sduTimeout muxTracer FD { fdVar } -> do
(connChannelLocal conn)
FDClosed {} ->
throwIO (invalidError fd_)
where
-- io errors
invalidError :: FD_ m (TestAddress addr) -> IOError
invalidError fd_ = IOError
{ ioe_handle = Nothing
, ioe_type = InvalidArgument
, ioe_location = "Ouroboros.Network.Snocket.Sim.toBearer"
, ioe_description = printf "Invalid argument (%s)" (show fd_)
, ioe_errno = Nothing
, ioe_filename = Nothing
}
where
invalidError :: FD_ m (TestAddress addr) -> IOError
invalidError fd_ = IOError
{ ioe_handle = Nothing
, ioe_type = InvalidArgument
, ioe_location = "Ouroboros.Network.Snocket.Sim.toBearer"
, ioe_description = printf "Invalid argument (%s)" (show fd_)
, ioe_errno = Nothing
, ioe_filename = Nothing
}

--
-- Simulated snockets
Expand Down
2 changes: 2 additions & 0 deletions ouroboros-network-framework/test/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import qualified Test.Ouroboros.Network.Server2 as Server2
import qualified Test.Ouroboros.Network.Socket as Socket
import qualified Test.Ouroboros.Network.Subscription as Subscription
import qualified Test.Simulation.Network.Snocket as Snocket
import qualified Test.Ouroboros.Network.RawBearer as RawBearer

main :: IO ()
main = defaultMain tests
Expand All @@ -23,6 +24,7 @@ tests =
, Subscription.tests
, RateLimiting.tests
, Snocket.tests
, RawBearer.tests
]


Loading

0 comments on commit 82864b3

Please sign in to comment.