diff --git a/core/Network/TLS.hs b/core/Network/TLS.hs index ca0905b56..2dae6b0da 100644 --- a/core/Network/TLS.hs +++ b/core/Network/TLS.hs @@ -56,6 +56,7 @@ module Network.TLS , Handshake , Logging(..) , contextHookSetHandshakeRecv + , contextHookSetHandshake13Recv , contextHookSetCertificateRecv , contextHookSetLogging , contextModifyHooks diff --git a/core/Network/TLS/Context.hs b/core/Network/TLS/Context.hs index 32e142a71..4e72fa0d4 100644 --- a/core/Network/TLS/Context.hs +++ b/core/Network/TLS/Context.hs @@ -49,6 +49,7 @@ module Network.TLS.Context -- * Context hooks , contextHookSetHandshakeRecv + , contextHookSetHandshake13Recv , contextHookSetCertificateRecv , contextHookSetLogging @@ -201,6 +202,10 @@ contextHookSetHandshakeRecv :: Context -> (Handshake -> IO Handshake) -> IO () contextHookSetHandshakeRecv context f = contextModifyHooks context (\hooks -> hooks { hookRecvHandshake = f }) +contextHookSetHandshake13Recv :: Context -> (Handshake13 -> IO Handshake13) -> IO () +contextHookSetHandshake13Recv context f = + contextModifyHooks context (\hooks -> hooks { hookRecvHandshake13 = f }) + contextHookSetCertificateRecv :: Context -> (CertificateChain -> IO ()) -> IO () contextHookSetCertificateRecv context f = contextModifyHooks context (\hooks -> hooks { hookRecvCertificates = f }) diff --git a/core/Network/TLS/Handshake/Client.hs b/core/Network/TLS/Handshake/Client.hs index da41bb37d..55e61b4ad 100644 --- a/core/Network/TLS/Handshake/Client.hs +++ b/core/Network/TLS/Handshake/Client.hs @@ -99,7 +99,7 @@ handshakeClient' cparams ctx groups mcrand = do Just _ -> error "handshakeClient': invalid KeyShare value" Nothing -> throwCore $ Error_Protocol ("key exchange not implemented in HRR, expected key_share extension", True, HandshakeFailure) else do - handshakeClient13 cparams ctx + handshakeClient13 cparams ctx groupToSend else do sessionResuming <- usingState_ ctx isSessionResuming if sessionResuming @@ -112,6 +112,7 @@ handshakeClient' cparams ctx groups mcrand = do compressions = supportedCompressions $ ctxSupported ctx highestVer = maximum $ supportedVersions $ ctxSupported ctx tls13 = highestVer >= TLS13 + groupToSend = listToMaybe groups getExtensions pskInfo rtt0 = sequence [ sniExtension , secureReneg @@ -166,9 +167,9 @@ handshakeClient' cparams ctx groups mcrand = do -- FIXME keyshareExtension - | tls13 = case groups of - [] -> return Nothing - grp:_ -> do + | tls13 = case groupToSend of + Nothing -> return Nothing + Just grp -> do (cpri, ent) <- makeClientKeyShare ctx grp usingHState ctx $ setGroupPrivate cpri return $ Just $ toExtensionRaw $ KeyShareClientHello [ent] @@ -792,14 +793,14 @@ requiredCertKeyUsage cipher = , KeyUsage_keyAgreement ] -handshakeClient13 :: ClientParams -> Context -> IO () -handshakeClient13 _cparams ctx = do +handshakeClient13 :: ClientParams -> Context -> Maybe Group -> IO () +handshakeClient13 cparams ctx groupSent = do usedCipher <- usingHState ctx getPendingCipher let usedHash = cipherHash usedCipher - handshakeClient13' _cparams ctx usedCipher usedHash + handshakeClient13' cparams ctx groupSent usedCipher usedHash -handshakeClient13' :: ClientParams -> Context -> Cipher -> Hash -> IO () -handshakeClient13' cparams ctx usedCipher usedHash = do +handshakeClient13' :: ClientParams -> Context -> Maybe Group -> Cipher -> Hash -> IO () +handshakeClient13' cparams ctx groupSent usedCipher usedHash = do (resuming, handshakeSecret, clientHandshakeTrafficSecret, serverHandshakeTrafficSecret) <- switchToHandshakeSecret rtt0accepted <- runRecvHandshake13 $ do accepted <- recvHandshake13preUpdate ctx expectEncryptedExtensions @@ -849,7 +850,8 @@ handshakeClient13' cparams ctx usedCipher usedHash = do Just _ -> error "calcSharedKey: invalid KeyShare value" Nothing -> throwCore $ Error_Protocol ("key exchange not implemented, expected key_share extension", True, HandshakeFailure) let grp = keyShareEntryGroup serverKeyShare - checkSupportedGroup ctx grp + unless (groupSent == Just grp) $ + throwCore $ Error_Protocol ("received incompatible group for (EC)DHE", True, IllegalParameter) usingHState ctx $ setNegotiatedGroup grp usingHState ctx getGroupPrivate >>= fromServerKeyShare serverKeyShare diff --git a/core/Network/TLS/Handshake/Server.hs b/core/Network/TLS/Handshake/Server.hs index 6f4f57ad2..64448a70e 100644 --- a/core/Network/TLS/Handshake/Server.hs +++ b/core/Network/TLS/Handshake/Server.hs @@ -886,18 +886,25 @@ doHandshake13 sparams ctx allCreds chosenVersion usedCipher exts usedHash client loadPacket13 ctx $ Handshake13 [vrfy] sendExtensions rtt0OK = do - extensions' <- liftIO $ applicationProtocol ctx exts sparams + protoExt <- liftIO $ applicationProtocol ctx exts sparams msni <- liftIO $ usingState_ ctx getClientSNI - let extensions'' = case msni of + let sniExtension = case msni of -- RFC6066: In this event, the server SHALL include -- an extension of type "server_name" in the -- (extended) server hello. The "extension_data" -- field of this extension SHALL be empty. - Just _ -> ExtensionRaw extensionID_ServerName "" : extensions' - Nothing -> extensions' - let extensions - | rtt0OK = ExtensionRaw extensionID_EarlyData (extensionEncode (EarlyDataIndication Nothing)) : extensions'' - | otherwise = extensions'' + Just _ -> Just $ ExtensionRaw extensionID_ServerName "" + Nothing -> Nothing + mgroup <- usingHState ctx getNegotiatedGroup + let serverGroups = supportedGroups (ctxSupported ctx) + groupExtension + | null serverGroups = Nothing + | maybe True (== head serverGroups) mgroup = Nothing + | otherwise = Just $ ExtensionRaw extensionID_NegotiatedGroups $ extensionEncode (NegotiatedGroups serverGroups) + let earlyDataExtension + | rtt0OK = Just $ ExtensionRaw extensionID_EarlyData $ extensionEncode (EarlyDataIndication Nothing) + | otherwise = Nothing + let extensions = catMaybes [earlyDataExtension, groupExtension, sniExtension] ++ protoExt loadPacket13 ctx $ Handshake13 [EncryptedExtensions13 extensions] sendNewSessionTicket masterSecret sfSentTime = when sendNST $ do diff --git a/core/Network/TLS/Hooks.hs b/core/Network/TLS/Hooks.hs index 942523a1c..8ca05e3aa 100644 --- a/core/Network/TLS/Hooks.hs +++ b/core/Network/TLS/Hooks.hs @@ -12,7 +12,8 @@ module Network.TLS.Hooks ) where import qualified Data.ByteString as B -import Network.TLS.Struct (Header, Handshake(..)) +import Network.TLS.Struct (Header, Handshake) +import Network.TLS.Struct13 (Handshake13) import Network.TLS.X509 (CertificateChain) import Data.Default.Class @@ -41,6 +42,8 @@ instance Default Logging where data Hooks = Hooks { -- | called at each handshake message received hookRecvHandshake :: Handshake -> IO Handshake + -- | called at each handshake message received for TLS 1.3 + , hookRecvHandshake13 :: Handshake13 -> IO Handshake13 -- | called at each certificate chain message received , hookRecvCertificates :: CertificateChain -> IO () -- | hooks on IO and packets, receiving and sending. @@ -50,6 +53,7 @@ data Hooks = Hooks defaultHooks :: Hooks defaultHooks = Hooks { hookRecvHandshake = return + , hookRecvHandshake13 = return , hookRecvCertificates = return . const () , hookLogging = def } diff --git a/core/Network/TLS/IO.hs b/core/Network/TLS/IO.hs index d56824536..8859db3a1 100644 --- a/core/Network/TLS/IO.hs +++ b/core/Network/TLS/IO.hs @@ -203,12 +203,17 @@ recvPacket13 ctx = liftIO $ do _ -> return $ Left err Left err -> return $ Left err Right record -> do - pkt <- processPacket13 ctx record - if isEmptyHandshake13 pkt then + pktRecv <- processPacket13 ctx record + if isEmptyHandshake13 pktRecv then -- When a handshake record is fragmented we continue receiving -- in order to feed stHandshakeRecordCont13 recvPacket13 ctx else do + pkt <- case pktRecv of + Right (Handshake13 hss) -> + ctxWithHooks ctx $ \hooks -> + Right . Handshake13 <$> mapM (hookRecvHandshake13 hooks) hss + _ -> return pktRecv case pkt of Right p -> withLog ctx $ \logging -> loggingPacketRecv logging $ show p _ -> return () diff --git a/core/Network/TLS/Internal.hs b/core/Network/TLS/Internal.hs index 2287d12ce..16e166e04 100644 --- a/core/Network/TLS/Internal.hs +++ b/core/Network/TLS/Internal.hs @@ -8,7 +8,9 @@ -- module Network.TLS.Internal ( module Network.TLS.Struct + , module Network.TLS.Struct13 , module Network.TLS.Packet + , module Network.TLS.Packet13 , module Network.TLS.Receiving , module Network.TLS.Sending , module Network.TLS.Wire @@ -17,7 +19,9 @@ module Network.TLS.Internal ) where import Network.TLS.Struct +import Network.TLS.Struct13 import Network.TLS.Packet +import Network.TLS.Packet13 import Network.TLS.Receiving import Network.TLS.Sending import Network.TLS.Wire diff --git a/core/Tests/Certificate.hs b/core/Tests/Certificate.hs index bfc4a101e..620d34cf9 100644 --- a/core/Tests/Certificate.hs +++ b/core/Tests/Certificate.hs @@ -5,6 +5,7 @@ module Certificate ( arbitraryX509 , arbitraryX509WithKey , arbitraryX509WithKeyAndUsage + , arbitraryDN , arbitraryKeyUsage , simpleCertificate , simpleX509 diff --git a/core/Tests/Marshalling.hs b/core/Tests/Marshalling.hs index 07ac1b450..6803f1ee9 100644 --- a/core/Tests/Marshalling.hs +++ b/core/Tests/Marshalling.hs @@ -1,5 +1,10 @@ {-# OPTIONS_GHC -fno-warn-orphans #-} -module Marshalling where +module Marshalling + ( someWords8 + , prop_header_marshalling_id + , prop_handshake_marshalling_id + , prop_handshake13_marshalling_id + ) where import Control.Monad import Control.Applicative @@ -9,14 +14,14 @@ import Network.TLS import qualified Data.ByteString as B import Data.Word -import Data.X509 +import Data.X509 (CertificateChain(..)) import Certificate genByteString :: Int -> Gen B.ByteString genByteString i = B.pack <$> vector i instance Arbitrary Version where - arbitrary = elements [ SSL2, SSL3, TLS10, TLS11, TLS12 ] + arbitrary = elements [ SSL2, SSL3, TLS10, TLS11, TLS12, TLS13 ] instance Arbitrary ProtocolType where arbitrary = elements @@ -41,6 +46,34 @@ instance Arbitrary Session where 2 -> Session . Just <$> genByteString 32 _ -> return $ Session Nothing +instance Arbitrary HashAlgorithm where + arbitrary = elements + [ Network.TLS.HashNone + , Network.TLS.HashMD5 + , Network.TLS.HashSHA1 + , Network.TLS.HashSHA224 + , Network.TLS.HashSHA256 + , Network.TLS.HashSHA384 + , Network.TLS.HashSHA512 + , Network.TLS.HashIntrinsic + ] + +instance Arbitrary SignatureAlgorithm where + arbitrary = elements + [ SignatureAnonymous + , SignatureRSA + , SignatureDSS + , SignatureECDSA + , SignatureRSApssRSAeSHA256 + , SignatureRSApssRSAeSHA384 + , SignatureRSApssRSAeSHA512 + , SignatureEd25519 + , SignatureEd448 + , SignatureRSApsspssSHA256 + , SignatureRSApsspssSHA384 + , SignatureRSApsspssSHA512 + ] + instance Arbitrary DigitallySigned where arbitrary = DigitallySigned Nothing <$> genByteString 32 @@ -53,6 +86,16 @@ arbitraryCompressionIDs = choose (0,200) >>= vector someWords8 :: Int -> Gen [Word8] someWords8 = vector +instance Arbitrary ExtensionRaw where + arbitrary = + let arbitraryContent = choose (0,40) >>= genByteString + in ExtensionRaw <$> arbitrary <*> arbitraryContent + +arbitraryHelloExtensions :: Version -> Gen [ExtensionRaw] +arbitraryHelloExtensions ver + | ver >= SSL3 = arbitrary + | otherwise = return [] -- no hello extension with SSLv2 + instance Arbitrary CertificateType where arbitrary = elements [ CertificateType_RSA_Sign, CertificateType_DSS_Sign @@ -62,31 +105,54 @@ instance Arbitrary CertificateType where instance Arbitrary Handshake where arbitrary = oneof - [ ClientHello + [ arbitrary >>= \ver -> ClientHello ver <$> arbitrary <*> arbitrary - <*> arbitrary <*> arbitraryCiphersIDs <*> arbitraryCompressionIDs - <*> return [] + <*> arbitraryHelloExtensions ver <*> return Nothing - , ServerHello + , arbitrary >>= \ver -> ServerHello ver <$> arbitrary <*> arbitrary <*> arbitrary <*> arbitrary - <*> arbitrary - <*> return [] + <*> arbitraryHelloExtensions ver , Certificates . CertificateChain <$> resize 2 (listOf arbitraryX509) , pure HelloRequest , pure ServerHelloDone , ClientKeyXchg . CKX_RSA <$> genByteString 48 --, liftM ServerKeyXchg - , liftM3 CertRequest arbitrary (return Nothing) (return []) + , liftM3 CertRequest arbitrary (return Nothing) (listOf arbitraryDN) , CertVerify <$> arbitrary , Finished <$> genByteString 12 ] +arbitraryCertReqContext :: Gen B.ByteString +arbitraryCertReqContext = oneof [ return B.empty, genByteString 32 ] + +instance Arbitrary Handshake13 where + arbitrary = oneof + [ NewSessionTicket13 + <$> arbitrary + <*> arbitrary + <*> genByteString 32 -- nonce + <*> genByteString 32 -- session ID + <*> arbitrary + , pure EndOfEarlyData13 + , EncryptedExtensions13 <$> arbitrary + , CertRequest13 + <$> arbitraryCertReqContext + <*> arbitrary + , resize 2 (listOf arbitraryX509) >>= \certs -> Certificate13 + <$> arbitraryCertReqContext + <*> return (CertificateChain certs) + <*> replicateM (length certs) arbitrary + , CertVerify13 <$> arbitrary <*> genByteString 32 + , Finished13 <$> genByteString 12 + , KeyUpdate13 <$> elements [ UpdateNotRequested, UpdateRequested ] + ] + {- quickcheck property -} prop_header_marshalling_id :: Header -> Bool @@ -94,9 +160,17 @@ prop_header_marshalling_id x = decodeHeader (encodeHeader x) == Right x prop_handshake_marshalling_id :: Handshake -> Bool prop_handshake_marshalling_id x = decodeHs (encodeHandshake x) == Right x - where decodeHs b = case decodeHandshakeRecord b of - GotPartial _ -> error "got partial" - GotError e -> error ("got error: " ++ show e) - GotSuccessRemaining _ _ -> error "got remaining byte left" - GotSuccess (ty, content) -> decodeHandshake cp ty content + where decodeHs b = verifyResult (decodeHandshake cp) $ decodeHandshakeRecord b cp = CurrentParams { cParamsVersion = TLS10, cParamsKeyXchgType = Just CipherKeyExchange_RSA } + +prop_handshake13_marshalling_id :: Handshake13 -> Bool +prop_handshake13_marshalling_id x = decodeHs (encodeHandshake13 x) == Right x + where decodeHs b = verifyResult decodeHandshake13 $ decodeHandshakeRecord13 b + +verifyResult :: (t -> b -> r) -> GetResult (t, b) -> r +verifyResult fn result = + case result of + GotPartial _ -> error "got partial" + GotError e -> error ("got error: " ++ show e) + GotSuccessRemaining _ _ -> error "got remaining byte left" + GotSuccess (ty, content) -> fn ty content diff --git a/core/Tests/Tests.hs b/core/Tests/Tests.hs index 62dc6760e..ae5511053 100644 --- a/core/Tests/Tests.hs +++ b/core/Tests/Tests.hs @@ -20,6 +20,7 @@ import qualified Data.ByteString.Char8 as C8 import qualified Data.ByteString.Lazy as L import Network.TLS import Network.TLS.Extra +import Network.TLS.Internal import Control.Applicative import Control.Concurrent import Control.Concurrent.Async @@ -116,6 +117,30 @@ runTLSPipeSimple13 params mode mEarlyData = runTLSPipe params tlsServer tlsClien Just mode `assertEq` (minfo >>= infoTLS13HandshakeMode) byeBye ctx +runTLSPipeCapture13 :: (ClientParams, ServerParams) -> PropertyM IO ([Handshake13], [Handshake13]) +runTLSPipeCapture13 params = do + sRef <- run $ newIORef [] + cRef <- run $ newIORef [] + runTLSPipe params (tlsServer sRef) (tlsClient cRef) + sReceived <- run $ readIORef sRef + cReceived <- run $ readIORef cRef + return (reverse sReceived, reverse cReceived) + where tlsServer ref ctx queue = do + installHook ctx ref + handshake ctx + d <- recvData ctx + writeChan queue [d] + bye ctx + tlsClient ref queue ctx = do + installHook ctx ref + handshake ctx + d <- readChan queue + sendData ctx (L.fromChunks [d]) + byeBye ctx + installHook ctx ref = + let recv hss = modifyIORef ref (hss :) >> return hss + in contextHookSetHandshake13Recv ctx recv + runTLSPipeSimpleKeyUpdate :: (ClientParams, ServerParams) -> PropertyM IO () runTLSPipeSimpleKeyUpdate params = runTLSPipeN 3 params tlsServer tlsClient where tlsServer ctx queue = do @@ -415,6 +440,20 @@ prop_handshake13_rtt0_length = do | otherwise = (RTT0, Just earlyData) runTLSPipeSimple13 params2 mode mEarlyData +prop_handshake13_ee_groups :: PropertyM IO () +prop_handshake13_ee_groups = do + (cli, srv) <- pick arbitraryPairParams13 + let cliSupported = (clientSupported cli) { supportedGroups = [P256,X25519] } + svrSupported = (serverSupported srv) { supportedGroups = [X25519,P256] } + params = (cli { clientSupported = cliSupported } + ,srv { serverSupported = svrSupported } + ) + (_, serverMessages) <- runTLSPipeCapture13 params + let isNegotiatedGroups (ExtensionRaw eid _) = eid == 0xa + eeMessagesHaveExt = [ any isNegotiatedGroups exts | + EncryptedExtensions13 exts <- serverMessages ] + [True] `assertEq` eeMessagesHaveExt -- one EE message with extension + prop_handshake_ciphersuites :: PropertyM IO () prop_handshake_ciphersuites = do tls13 <- pick arbitrary @@ -833,6 +872,7 @@ main = defaultMain $ testGroup "tls" tests_marshalling = testGroup "Marshalling" [ testProperty "Header" prop_header_marshalling_id , testProperty "Handshake" prop_handshake_marshalling_id + , testProperty "Handshake13" prop_handshake13_marshalling_id ] tests_ciphers = testGroup "Ciphers" [ testProperty "Bulk" propertyBulkFunctional ] @@ -864,6 +904,7 @@ main = defaultMain $ testGroup "tls" , testProperty "TLS 1.3 RTT0" (monadicIO prop_handshake13_rtt0) , testProperty "TLS 1.3 RTT0 -> PSK" (monadicIO prop_handshake13_rtt0_fallback) , testProperty "TLS 1.3 RTT0 length" (monadicIO prop_handshake13_rtt0_length) + , testProperty "TLS 1.3 EE groups" (monadicIO prop_handshake13_ee_groups) , testProperty "TLS 1.3 Post-handshake auth" (monadicIO prop_post_handshake_auth) ]