Skip to content

Commit

Permalink
Merge PR #445.
Browse files Browse the repository at this point in the history
  • Loading branch information
kazu-yamamoto committed May 6, 2020
2 parents 54b872f + 99e7f9a commit 1d61d74
Show file tree
Hide file tree
Showing 7 changed files with 67 additions and 45 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ matrix:
addons: {apt: {packages: [ghc-ppa-tools,cabal-install-2.4,ghc-8.6.5], sources: [hvr-ghc]}}
- compiler: "ghc-8.8.3"
# env: TEST=--disable-tests BENCH=--disable-benchmarks
addons: {apt: {packages: [ghc-ppa-tools,cabal-install-3.0,ghc-8.8.1], sources: [hvr-ghc]}}
addons: {apt: {packages: [ghc-ppa-tools,cabal-install-3.0,ghc-8.8.3], sources: [hvr-ghc]}}
- compiler: "ghc-head"
# env: TEST=--disable-tests BENCH=--disable-benchmarks
addons: {apt: {packages: [ghc-ppa-tools,cabal-install-head,ghc-head], sources: [hvr-ghc]}}
Expand Down
19 changes: 7 additions & 12 deletions Network/Socket/Buffer.hsc
Original file line number Diff line number Diff line change
Expand Up @@ -257,28 +257,23 @@ recvBufMsg s bufsizs clen flags = do
allocaBytes clen $ \ctrlPtr ->
#if !defined(mingw32_HOST_OS)
withIOVec bufsizs $ \(iovsPtr, iovsLen) -> do
#else
withWSABuf bufsizs $ \(wsaBPtr, wsaBLen) -> do
#endif
let msgHdr = MsgHdr {
msgName = addrPtr
, msgNameLen = fromIntegral addrSize
#if !defined(mingw32_HOST_OS)
, msgIov = iovsPtr
, msgIovLen = fromIntegral iovsLen
, msgCtrl = castPtr ctrlPtr
, msgCtrlLen = fromIntegral clen
, msgFlags = 0
#else
withWSABuf bufsizs $ \(wsaBPtr, wsaBLen) -> do
let msgHdr = MsgHdr {
msgName = addrPtr
, msgNameLen = fromIntegral addrSize
, msgBuffer = wsaBPtr
, msgBufferLen = fromIntegral wsaBLen
#endif
#if !defined(mingw32_HOST_OS)
, msgCtrl = castPtr ctrlPtr
#else
, msgCtrl = if clen == 0 then nullPtr else castPtr ctrlPtr
#endif
, msgCtrlLen = fromIntegral clen
#if !defined(mingw32_HOST_OS)
, msgFlags = 0
#else
, msgFlags = fromIntegral $ fromMsgFlag flags
#endif
}
Expand Down
2 changes: 2 additions & 0 deletions Network/Socket/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,8 @@ throwSocketErrorIfMinus1_ name act = do
_ <- throwSocketErrorIfMinus1Retry name act
return ()

throwSocketErrorIfMinus1ButRetry :: (Eq a, Num a) =>
(CInt -> Bool) -> String -> IO a -> IO a
throwSocketErrorIfMinus1ButRetry exempt name act = do
r <- act
if (r == -1)
Expand Down
33 changes: 19 additions & 14 deletions Network/Socket/Posix/Cmsg.hsc
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}

module Network.Socket.Posix.Cmsg where

Expand Down Expand Up @@ -87,24 +89,27 @@ filterCmsg cid cmsgs = filter (\cmsg -> cmsgId cmsg == cid) cmsgs
-- Each control message type has a numeric 'CmsgId' and a 'Storable'
-- data representation.
class Storable a => ControlMessage a where
controlMessageId :: a -> CmsgId
controlMessageId :: CmsgId

encodeCmsg :: ControlMessage a => a -> Cmsg
encodeCmsg :: forall a . ControlMessage a => a -> Cmsg
encodeCmsg x = unsafeDupablePerformIO $ do
bs <- create siz $ \p0 -> do
let p = castPtr p0
poke p x
return $ Cmsg (controlMessageId x) bs
let cmsid = controlMessageId @a
return $ Cmsg cmsid bs
where
siz = sizeOf x

decodeCmsg :: forall a . Storable a => Cmsg -> Maybe a
decodeCmsg (Cmsg _ (PS fptr off len))
| len < siz = Nothing
| otherwise = unsafeDupablePerformIO $ withForeignPtr fptr $ \p0 -> do
decodeCmsg :: forall a . (ControlMessage a, Storable a) => Cmsg -> Maybe a
decodeCmsg (Cmsg cmsid (PS fptr off len))
| cid /= cmsid = Nothing
| len < siz = Nothing
| otherwise = unsafeDupablePerformIO $ withForeignPtr fptr $ \p0 -> do
let p = castPtr (p0 `plusPtr` off)
Just <$> peek p
where
cid = controlMessageId @a
siz = sizeOf (undefined :: a)

----------------------------------------------------------------
Expand All @@ -117,31 +122,31 @@ newtype IPv4TTL = IPv4TTL CInt deriving (Eq, Show, Storable)
#endif

instance ControlMessage IPv4TTL where
controlMessageId _ = CmsgIdIPv4TTL
controlMessageId = CmsgIdIPv4TTL

----------------------------------------------------------------

-- | Hop limit of IPv6.
newtype IPv6HopLimit = IPv6HopLimit CInt deriving (Eq, Show, Storable)

instance ControlMessage IPv6HopLimit where
controlMessageId _ = CmsgIdIPv6HopLimit
controlMessageId = CmsgIdIPv6HopLimit

----------------------------------------------------------------

-- | TOS of IPv4.
newtype IPv4TOS = IPv4TOS CChar deriving (Eq, Show, Storable)

instance ControlMessage IPv4TOS where
controlMessageId _ = CmsgIdIPv4TOS
controlMessageId = CmsgIdIPv4TOS

----------------------------------------------------------------

-- | Traffic class of IPv6.
newtype IPv6TClass = IPv6TClass CInt deriving (Eq, Show, Storable)

instance ControlMessage IPv6TClass where
controlMessageId _ = CmsgIdIPv6TClass
controlMessageId = CmsgIdIPv6TClass

----------------------------------------------------------------

Expand All @@ -152,7 +157,7 @@ instance Show IPv4PktInfo where
show (IPv4PktInfo n sa ha) = "IPv4PktInfo " ++ show n ++ " " ++ show (hostAddressToTuple sa) ++ " " ++ show (hostAddressToTuple ha)

instance ControlMessage IPv4PktInfo where
controlMessageId _ = CmsgIdIPv4PktInfo
controlMessageId = CmsgIdIPv4PktInfo

instance Storable IPv4PktInfo where
sizeOf _ = (#size struct in_pktinfo)
Expand All @@ -176,7 +181,7 @@ instance Show IPv6PktInfo where
show (IPv6PktInfo n ha6) = "IPv6PktInfo " ++ show n ++ " " ++ show (hostAddress6ToTuple ha6)

instance ControlMessage IPv6PktInfo where
controlMessageId _ = CmsgIdIPv6PktInfo
controlMessageId = CmsgIdIPv6PktInfo

instance Storable IPv6PktInfo where
sizeOf _ = (#size struct in6_pktinfo)
Expand All @@ -192,4 +197,4 @@ instance Storable IPv6PktInfo where
----------------------------------------------------------------

instance ControlMessage Fd where
controlMessageId _ = CmsgIdFd
controlMessageId = CmsgIdFd
30 changes: 18 additions & 12 deletions Network/Socket/Win32/Cmsg.hsc
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@

{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}

module Network.Socket.Win32.Cmsg where

Expand Down Expand Up @@ -77,24 +80,27 @@ filterCmsg cid cmsgs = filter (\cmsg -> cmsgId cmsg == cid) cmsgs

-- | A class to encode and decode control message.
class Storable a => ControlMessage a where
controlMessageId :: a -> CmsgId
controlMessageId :: CmsgId

encodeCmsg :: ControlMessage a => a -> Cmsg
encodeCmsg :: forall a. ControlMessage a => a -> Cmsg
encodeCmsg x = unsafeDupablePerformIO $ do
bs <- create siz $ \p0 -> do
let p = castPtr p0
poke p x
return $ Cmsg (controlMessageId x) bs
let cmsid = controlMessageId @a
return $ Cmsg cmsid bs
where
siz = sizeOf x

decodeCmsg :: forall a . Storable a => Cmsg -> Maybe a
decodeCmsg (Cmsg _ (PS fptr off len))
| len < siz = Nothing
decodeCmsg :: forall a . (ControlMessage a, Storable a) => Cmsg -> Maybe a
decodeCmsg (Cmsg cmsid (PS fptr off len))
| cid /= cmsid = Nothing
| len < siz = Nothing
| otherwise = unsafeDupablePerformIO $ withForeignPtr fptr $ \p0 -> do
let p = castPtr (p0 `plusPtr` off)
Just <$> peek p
where
cid = controlMessageId @a
siz = sizeOf (undefined :: a)

----------------------------------------------------------------
Expand All @@ -103,31 +109,31 @@ decodeCmsg (Cmsg _ (PS fptr off len))
newtype IPv4TTL = IPv4TTL DWORD deriving (Eq, Show, Storable)

instance ControlMessage IPv4TTL where
controlMessageId _ = CmsgIdIPv4TTL
controlMessageId = CmsgIdIPv4TTL

----------------------------------------------------------------

-- | Hop limit of IPv6.
newtype IPv6HopLimit = IPv6HopLimit DWORD deriving (Eq, Show, Storable)

instance ControlMessage IPv6HopLimit where
controlMessageId _ = CmsgIdIPv6HopLimit
controlMessageId = CmsgIdIPv6HopLimit

----------------------------------------------------------------

-- | TOS of IPv4.
newtype IPv4TOS = IPv4TOS DWORD deriving (Eq, Show, Storable)

instance ControlMessage IPv4TOS where
controlMessageId _ = CmsgIdIPv4TOS
controlMessageId = CmsgIdIPv4TOS

----------------------------------------------------------------

-- | Traffic class of IPv6.
newtype IPv6TClass = IPv6TClass DWORD deriving (Eq, Show, Storable)

instance ControlMessage IPv6TClass where
controlMessageId _ = CmsgIdIPv6TClass
controlMessageId = CmsgIdIPv6TClass

----------------------------------------------------------------

Expand All @@ -138,7 +144,7 @@ instance Show IPv4PktInfo where
show (IPv4PktInfo n ha) = "IPv4PktInfo " ++ show n ++ " " ++ show (hostAddressToTuple ha)

instance ControlMessage IPv4PktInfo where
controlMessageId _ = CmsgIdIPv4PktInfo
controlMessageId = CmsgIdIPv4PktInfo

instance Storable IPv4PktInfo where
sizeOf = const #{size IN_PKTINFO}
Expand All @@ -160,7 +166,7 @@ instance Show IPv6PktInfo where
show (IPv6PktInfo n ha6) = "IPv6PktInfo " ++ show n ++ " " ++ show (hostAddress6ToTuple ha6)

instance ControlMessage IPv6PktInfo where
controlMessageId _ = CmsgIdIPv6PktInfo
controlMessageId = CmsgIdIPv6PktInfo

instance Storable IPv6PktInfo where
sizeOf = const #{size IN6_PKTINFO}
Expand Down
7 changes: 6 additions & 1 deletion appveyor.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ environment:
- GHCVER: 8.2.2
- GHCVER: 8.4.4
- GHCVER: 8.6.5
- GHCVER: 8.8.3
# GHC 8.8.3 is broken due to a bug in process
# - GHCVER: 8.8.3

platform:
# - x86 # We may want to test x86 as well, but it would double the 23min build time.
Expand Down Expand Up @@ -54,6 +55,10 @@ before_build:
- cabal %CABOPTS% new-update -vverbose+nowrap
- IF EXIST configure.ac bash -c "autoreconf -i"

# Uncomment these lines to turn on remote desktop for AppVeyor
# on_finish:
# - ps: $blockRdp = $true; iex ((new-object net.webclient).DownloadString('https://raw.githubusercontent.com/appveyor/ci/master/scripts/enable-rdp.ps1'))

deploy: off

build_script:
Expand Down
19 changes: 14 additions & 5 deletions tests/Network/Socket/ByteStringSpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,16 @@
module Network.Socket.ByteStringSpec (main, spec) where

import Data.Bits
import Data.Maybe
import Control.Monad
import qualified Data.ByteString as S
import qualified Data.ByteString.Char8 as C
import Network.Socket
import Network.Socket.ByteString
import Network.Test.Common

import System.Environment

import Test.Hspec

main :: IO ()
Expand Down Expand Up @@ -228,18 +232,23 @@ spec = do
udpTest client server

it "receives control messages for IPv4" $ do
-- This test behaves strange on AppVeyor and I don't know why so skip
-- TOS for now.
isAppVeyor <- isJust <$> lookupEnv "APPVEYOR"
let server sock = do
whenSupported RecvIPv4TTL $ setSocketOption sock RecvIPv4TTL 1
whenSupported RecvIPv4TOS $ setSocketOption sock RecvIPv4TOS 1
whenSupported RecvIPv4PktInfo $ setSocketOption sock RecvIPv4PktInfo 1
whenSupported RecvIPv4TOS $ setSocketOption sock RecvIPv4TOS 1

(_, _, cmsgs, _) <- recvMsg sock 1024 128 mempty

whenSupported RecvIPv4TTL $
((lookupCmsg CmsgIdIPv4TTL cmsgs >>= decodeCmsg) :: Maybe IPv4TTL) `shouldNotBe` Nothing
whenSupported RecvIPv4TOS $
((lookupCmsg CmsgIdIPv4TOS cmsgs >>= decodeCmsg) :: Maybe IPv4TOS) `shouldNotBe` Nothing
whenSupported RecvIPv4PktInfo $
((lookupCmsg CmsgIdIPv4PktInfo cmsgs >>= decodeCmsg) :: Maybe IPv4PktInfo) `shouldNotBe` Nothing
when (not isAppVeyor) $ do
whenSupported RecvIPv4TTL $
((lookupCmsg CmsgIdIPv4TTL cmsgs >>= decodeCmsg) :: Maybe IPv4TTL) `shouldNotBe` Nothing
whenSupported RecvIPv4TOS $
((lookupCmsg CmsgIdIPv4TOS cmsgs >>= decodeCmsg) :: Maybe IPv4TOS) `shouldNotBe` Nothing
client sock addr = sendTo sock seg addr

seg = C.pack "This is a test message"
Expand Down

0 comments on commit 1d61d74

Please sign in to comment.