Skip to content

Commit

Permalink
Win32: sync with linux changes
Browse files Browse the repository at this point in the history
  • Loading branch information
Mistuke committed May 3, 2020
1 parent f8cb3e2 commit 99e7f9a
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 33 deletions.
30 changes: 18 additions & 12 deletions Network/Socket/Win32/Cmsg.hsc
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@

{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}

module Network.Socket.Win32.Cmsg where

Expand Down Expand Up @@ -77,24 +80,27 @@ filterCmsg cid cmsgs = filter (\cmsg -> cmsgId cmsg == cid) cmsgs

-- | A class to encode and decode control message.
class Storable a => ControlMessage a where
controlMessageId :: a -> CmsgId
controlMessageId :: CmsgId

encodeCmsg :: ControlMessage a => a -> Cmsg
encodeCmsg :: forall a. ControlMessage a => a -> Cmsg
encodeCmsg x = unsafeDupablePerformIO $ do
bs <- create siz $ \p0 -> do
let p = castPtr p0
poke p x
return $ Cmsg (controlMessageId x) bs
let cmsid = controlMessageId @a
return $ Cmsg cmsid bs
where
siz = sizeOf x

decodeCmsg :: forall a . Storable a => Cmsg -> Maybe a
decodeCmsg (Cmsg _ (PS fptr off len))
| len < siz = Nothing
decodeCmsg :: forall a . (ControlMessage a, Storable a) => Cmsg -> Maybe a
decodeCmsg (Cmsg cmsid (PS fptr off len))
| cid /= cmsid = Nothing
| len < siz = Nothing
| otherwise = unsafeDupablePerformIO $ withForeignPtr fptr $ \p0 -> do
let p = castPtr (p0 `plusPtr` off)
Just <$> peek p
where
cid = controlMessageId @a
siz = sizeOf (undefined :: a)

----------------------------------------------------------------
Expand All @@ -103,31 +109,31 @@ decodeCmsg (Cmsg _ (PS fptr off len))
newtype IPv4TTL = IPv4TTL DWORD deriving (Eq, Show, Storable)

instance ControlMessage IPv4TTL where
controlMessageId _ = CmsgIdIPv4TTL
controlMessageId = CmsgIdIPv4TTL

----------------------------------------------------------------

-- | Hop limit of IPv6.
newtype IPv6HopLimit = IPv6HopLimit DWORD deriving (Eq, Show, Storable)

instance ControlMessage IPv6HopLimit where
controlMessageId _ = CmsgIdIPv6HopLimit
controlMessageId = CmsgIdIPv6HopLimit

----------------------------------------------------------------

-- | TOS of IPv4.
newtype IPv4TOS = IPv4TOS DWORD deriving (Eq, Show, Storable)

instance ControlMessage IPv4TOS where
controlMessageId _ = CmsgIdIPv4TOS
controlMessageId = CmsgIdIPv4TOS

----------------------------------------------------------------

-- | Traffic class of IPv6.
newtype IPv6TClass = IPv6TClass DWORD deriving (Eq, Show, Storable)

instance ControlMessage IPv6TClass where
controlMessageId _ = CmsgIdIPv6TClass
controlMessageId = CmsgIdIPv6TClass

----------------------------------------------------------------

Expand All @@ -138,7 +144,7 @@ instance Show IPv4PktInfo where
show (IPv4PktInfo n ha) = "IPv4PktInfo " ++ show n ++ " " ++ show (hostAddressToTuple ha)

instance ControlMessage IPv4PktInfo where
controlMessageId _ = CmsgIdIPv4PktInfo
controlMessageId = CmsgIdIPv4PktInfo

instance Storable IPv4PktInfo where
sizeOf = const #{size IN_PKTINFO}
Expand All @@ -160,7 +166,7 @@ instance Show IPv6PktInfo where
show (IPv6PktInfo n ha6) = "IPv6PktInfo " ++ show n ++ " " ++ show (hostAddress6ToTuple ha6)

instance ControlMessage IPv6PktInfo where
controlMessageId _ = CmsgIdIPv6PktInfo
controlMessageId = CmsgIdIPv6PktInfo

instance Storable IPv6PktInfo where
sizeOf = const #{size IN6_PKTINFO}
Expand Down
14 changes: 8 additions & 6 deletions appveyor.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@ environment:
DOCTEST: YES
matrix:
- GHCVER: 8.0.2
# - GHCVER: 8.2.2
# - GHCVER: 8.4.4
# - GHCVER: 8.6.5
# - GHCVER: 8.8.3
- GHCVER: 8.2.2
- GHCVER: 8.4.4
- GHCVER: 8.6.5
# GHC 8.8.3 is broken due to a bug in process
# - GHCVER: 8.8.3

platform:
# - x86 # We may want to test x86 as well, but it would double the 23min build time.
Expand Down Expand Up @@ -54,8 +55,9 @@ before_build:
- cabal %CABOPTS% new-update -vverbose+nowrap
- IF EXIST configure.ac bash -c "autoreconf -i"

on_finish:
- ps: $blockRdp = $true; iex ((new-object net.webclient).DownloadString('https://raw.githubusercontent.com/appveyor/ci/master/scripts/enable-rdp.ps1'))
# Uncomment these lines to turn on remote desktop for AppVeyor
# on_finish:
# - ps: $blockRdp = $true; iex ((new-object net.webclient).DownloadString('https://raw.githubusercontent.com/appveyor/ci/master/scripts/enable-rdp.ps1'))

deploy: off

Expand Down
32 changes: 17 additions & 15 deletions tests/Network/Socket/ByteStringSpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,16 @@
module Network.Socket.ByteStringSpec (main, spec) where

import Data.Bits
import Data.Maybe
import Control.Monad
import qualified Data.ByteString as S
import qualified Data.ByteString.Char8 as C
import Network.Socket
import Network.Socket.ByteString
import Network.Test.Common

import System.Environment

import Test.Hspec

main :: IO ()
Expand Down Expand Up @@ -228,25 +232,23 @@ spec = do
udpTest client server

it "receives control messages for IPv4" $ do
-- This test behaves strange on AppVeyor and I don't know why so skip
-- TOS for now.
isAppVeyor <- isJust <$> lookupEnv "APPVEYOR"
let server sock = do
--whenSupported RecvIPv4TTL $ setSocketOption sock RecvIPv4TTL 1
--whenSupported RecvIPv4TOS $ setSocketOption sock RecvIPv4TOS 1
--whenSupported RecvIPv4PktInfo $ setSocketOption sock RecvIPv4PktInfo 1
whenSupported RecvIPv4TTL $ setSocketOption sock RecvIPv4TTL 1
whenSupported RecvIPv4PktInfo $ setSocketOption sock RecvIPv4PktInfo 1
whenSupported RecvIPv4TOS $ setSocketOption sock RecvIPv4TOS 1

(_, _, cmsgs, _) <- recvMsg sock 1024 128 mempty
print RecvIPv4TTL
print RecvIPv4TOS
print RecvIPv4PktInfo
print cmsgs
print =<< getSocketOption sock RecvIPv4TOS
print CmsgIdIPv4TTL
print CmsgIdIPv4TOS
print CmsgIdIPv4PktInfo
whenSupported RecvIPv4TTL $
((lookupCmsg CmsgIdIPv4TTL cmsgs >>= decodeCmsg) :: Maybe IPv4TTL) `shouldNotBe` Nothing
whenSupported RecvIPv4TOS $
((lookupCmsg CmsgIdIPv4TOS cmsgs >>= decodeCmsg) :: Maybe IPv4TOS) `shouldNotBe` Nothing

whenSupported RecvIPv4PktInfo $
((lookupCmsg CmsgIdIPv4PktInfo cmsgs >>= decodeCmsg) :: Maybe IPv4PktInfo) `shouldNotBe` Nothing
when (not isAppVeyor) $ do
whenSupported RecvIPv4TTL $
((lookupCmsg CmsgIdIPv4TTL cmsgs >>= decodeCmsg) :: Maybe IPv4TTL) `shouldNotBe` Nothing
whenSupported RecvIPv4TOS $
((lookupCmsg CmsgIdIPv4TOS cmsgs >>= decodeCmsg) :: Maybe IPv4TOS) `shouldNotBe` Nothing
client sock addr = sendTo sock seg addr

seg = C.pack "This is a test message"
Expand Down

0 comments on commit 99e7f9a

Please sign in to comment.