diff --git a/core/Network/TLS/Context.hs b/core/Network/TLS/Context.hs index 6d8835a95..12efa0b53 100644 --- a/core/Network/TLS/Context.hs +++ b/core/Network/TLS/Context.hs @@ -163,6 +163,7 @@ contextNew backend params = liftIO $ do , ctxShared = shared , ctxSupported = supported , ctxState = stvar + , ctxFragmentSize = Just 16384 , ctxTxState = tx , ctxRxState = rx , ctxHandshake = hs diff --git a/core/Network/TLS/Context/Internal.hs b/core/Network/TLS/Context/Internal.hs index 097590344..d441e1e60 100644 --- a/core/Network/TLS/Context/Internal.hs +++ b/core/Network/TLS/Context/Internal.hs @@ -115,6 +115,7 @@ data Context = Context , ctxSSLv2ClientHello :: IORef Bool -- ^ enable the reception of compatibility SSLv2 client hello. -- the flag will be set to false regardless of its initial value -- after the first packet received. + , ctxFragmentSize :: Maybe Int -- ^ maximum size of plaintext fragments , ctxTxState :: MVar RecordState -- ^ current tx state , ctxRxState :: MVar RecordState -- ^ current rx state , ctxHandshake :: MVar (Maybe HandshakeState) -- ^ optional handshake state diff --git a/core/Network/TLS/Core.hs b/core/Network/TLS/Core.hs index df01c49ac..613595b0f 100644 --- a/core/Network/TLS/Core.hs +++ b/core/Network/TLS/Core.hs @@ -101,7 +101,8 @@ sendData ctx dataToSend = liftIO $ do -- All chunks are protected with the same write lock because we don't -- want to interleave writes from other threads in the middle of our -- possibly large write. - mapM_ (mapChunks_ 16384 sendP) (L.toChunks dataToSend) + let len = ctxFragmentSize ctx + mapM_ (mapChunks_ len sendP) (L.toChunks dataToSend) -- | Get data out of Data packet, and automatically renegotiate if a Handshake -- ClientHello is received. An empty result means EOF. diff --git a/core/Network/TLS/Credentials.hs b/core/Network/TLS/Credentials.hs index 7ac46abae..1a0d15102 100644 --- a/core/Network/TLS/Credentials.hs +++ b/core/Network/TLS/Credentials.hs @@ -34,10 +34,8 @@ type Credential = (CertificateChain, PrivKey) newtype Credentials = Credentials [Credential] -#if MIN_VERSION_base(4,9,0) instance Semigroup Credentials where Credentials l1 <> Credentials l2 = Credentials (l1 ++ l2) -#endif instance Monoid Credentials where mempty = Credentials [] @@ -82,7 +80,7 @@ credentialLoadX509ChainFromMemory :: ByteString -> [ByteString] -> ByteString -> Either String Credential -credentialLoadX509ChainFromMemory certData chainData privateData = do +credentialLoadX509ChainFromMemory certData chainData privateData = let x509 = readSignedObjectFromMemory certData chains = map readSignedObjectFromMemory chainData keys = readKeyFileFromMemory privateData diff --git a/core/Network/TLS/Handshake/Client.hs b/core/Network/TLS/Handshake/Client.hs index 0dd65039c..fd4d1ef90 100644 --- a/core/Network/TLS/Handshake/Client.hs +++ b/core/Network/TLS/Handshake/Client.hs @@ -102,7 +102,7 @@ handshakeClient' cparams ctx groups mparams = do | otherwise -> throwCore $ Error_Protocol ("server-selected group is not supported", True, IllegalParameter) Just _ -> error "handshakeClient': invalid KeyShare value" Nothing -> throwCore $ Error_Protocol ("key exchange not implemented in HRR, expected key_share extension", True, HandshakeFailure) - else do + else handshakeClient13 cparams ctx groupToSend else do when rtt0 $ @@ -311,7 +311,8 @@ handshakeClient' cparams ctx groups mparams = do let ClientTrafficSecret clientEarlySecret = pairClient earlyKey runPacketFlight ctx $ sendChangeCipherSpec13 ctx setTxState ctx usedHash usedCipher clientEarlySecret - mapChunks_ 16384 (sendPacket13 ctx . AppData13) earlyData + let len = ctxFragmentSize ctx + mapChunks_ len (sendPacket13 ctx . AppData13) earlyData usingHState ctx $ setTLS13RTT0Status RTT0Sent recvServerHello clientSession sentExts = runRecvState ctx recvState diff --git a/core/Network/TLS/IO.hs b/core/Network/TLS/IO.hs index 265af9c84..07885aa07 100644 --- a/core/Network/TLS/IO.hs +++ b/core/Network/TLS/IO.hs @@ -85,13 +85,19 @@ sendBytes ctx dataToSend = liftIO $ do ---------------------------------------------------------------- +exceeds :: Integral ty => Context -> Int -> ty -> Bool +exceeds ctx overhead actual = + case ctxFragmentSize ctx of + Nothing -> False + Just sz -> fromIntegral actual > sz + overhead + getRecord :: Context -> Int -> Header -> ByteString -> IO (Either TLSError (Record Plaintext)) getRecord ctx appDataOverhead header@(Header pt _ _) content = do withLog ctx $ \logging -> loggingIORecv logging header content runRxState ctx $ do r <- decodeRecordM header content let Record _ _ fragment = r - when (B.length (fragmentGetBytes fragment) > 16384 + overhead) $ + when (exceeds ctx overhead $ B.length (fragmentGetBytes fragment)) $ throwError contentSizeExceeded return r where overhead = if pt == ProtocolType_AppData then appDataOverhead else 0 @@ -153,8 +159,8 @@ recvRecord compatSSLv2 appDataOverhead ctx where recvLengthE = either (return . Left) recvLength recvLength header@(Header _ _ readlen) - | readlen > 16384 + 2048 = return $ Left maximumSizeExceeded - | otherwise = + | exceeds ctx 2048 readlen = return $ Left maximumSizeExceeded + | otherwise = readExactBytes ctx (fromIntegral readlen) >>= either (return . Left) (getRecord ctx appDataOverhead header) #ifdef SSLV2_COMPATIBLE @@ -221,8 +227,8 @@ recvRecord13 :: Context recvRecord13 ctx = readExactBytes ctx 5 >>= either (return . Left) (recvLengthE . decodeHeader) where recvLengthE = either (return . Left) recvLength recvLength header@(Header _ _ readlen) - | readlen > 16384 + 256 = return $ Left maximumSizeExceeded - | otherwise = + | exceeds ctx 256 readlen = return $ Left maximumSizeExceeded + | otherwise = readExactBytes ctx (fromIntegral readlen) >>= either (return . Left) (getRecord ctx 0 header) diff --git a/core/Network/TLS/Imports.hs b/core/Network/TLS/Imports.hs index 8a18ecf72..98b5649e9 100644 --- a/core/Network/TLS/Imports.hs +++ b/core/Network/TLS/Imports.hs @@ -1,6 +1,5 @@ {-# LANGUAGE CPP #-} {-# LANGUAGE NoImplicitPrelude #-} -{-# OPTIONS_GHC -fno-warn-dodgy-exports #-} -- Char8 -- | -- Module : Network.TLS.Imports @@ -13,7 +12,6 @@ module Network.TLS.Imports ( -- generic exports ByteString - , module Data.ByteString.Char8 -- instance , module Control.Applicative , module Control.Monad #if !MIN_VERSION_base(4,13,0) @@ -22,22 +20,15 @@ module Network.TLS.Imports , module Data.Bits , module Data.List , module Data.Maybe -#if MIN_VERSION_base(4,9,0) , module Data.Semigroup -#else - , module Data.Monoid -#endif , module Data.Ord , module Data.Word -#if !MIN_VERSION_base(4,8,0) - , sortOn -#endif -- project definition , showBytesHex ) where import Data.ByteString (ByteString) -import Data.ByteString.Char8 () +import Data.ByteString.Char8 () -- instance import Control.Applicative import Control.Monad @@ -47,25 +38,13 @@ import Control.Monad.Fail (MonadFail) import Data.Bits import Data.List import Data.Maybe hiding (fromJust) -#if MIN_VERSION_base(4,9,0) import Data.Semigroup -#else -import Data.Monoid -#endif import Data.Ord import Data.Word import Data.ByteArray.Encoding as B import qualified Prelude as P -#if !MIN_VERSION_base(4,8,0) -import Prelude ((.)) - -sortOn :: Ord b => (a -> b) -> [a] -> [a] -sortOn f = - map P.snd . sortBy (comparing P.fst) . map (\x -> let y = f x in y `P.seq` (y, x)) -#endif - showBytesHex :: ByteString -> P.String showBytesHex bs = P.show (B.convertToBase B.Base16 bs :: ByteString) diff --git a/core/Network/TLS/Sending.hs b/core/Network/TLS/Sending.hs index ddc1a63df..b42310127 100644 --- a/core/Network/TLS/Sending.hs +++ b/core/Network/TLS/Sending.hs @@ -39,7 +39,8 @@ encodePacket ctx pkt = do (ver, _) <- decideRecordVersion ctx let pt = packetType pkt mkRecord bs = Record pt ver (fragmentPlaintext bs) - records <- map mkRecord <$> packetToFragments ctx 16384 pkt + len = ctxFragmentSize ctx + records <- map mkRecord <$> packetToFragments ctx len pkt bs <- fmap B.concat <$> forEitherM records (encodeRecord ctx) when (pkt == ChangeCipherSpec) $ switchTxEncryption ctx return bs @@ -47,7 +48,7 @@ encodePacket ctx pkt = do -- Decompose handshake packets into fragments of the specified length. AppData -- packets are not fragmented here but by callers of sendPacket, so that the -- empty-packet countermeasure may be applied to each fragment independently. -packetToFragments :: Context -> Int -> Packet -> IO [ByteString] +packetToFragments :: Context -> Maybe Int -> Packet -> IO [ByteString] packetToFragments ctx len (Handshake hss) = getChunks len . B.concat <$> mapM (updateHandshake ctx ClientRole) hss packetToFragments _ _ (Alert a) = return [encodeAlerts a] diff --git a/core/Network/TLS/Sending13.hs b/core/Network/TLS/Sending13.hs index c19c27755..286cf6876 100644 --- a/core/Network/TLS/Sending13.hs +++ b/core/Network/TLS/Sending13.hs @@ -32,7 +32,8 @@ encodePacket13 :: Context -> Packet13 -> IO (Either TLSError ByteString) encodePacket13 ctx pkt = do let pt = contentType pkt mkRecord bs = Record pt TLS12 (fragmentPlaintext bs) - records <- map mkRecord <$> packetToFragments ctx 16384 pkt + len = ctxFragmentSize ctx + records <- map mkRecord <$> packetToFragments ctx len pkt fmap B.concat <$> forEitherM records (encodeRecord ctx) prepareRecord :: Context -> RecordM a -> IO (Either TLSError a) @@ -41,7 +42,7 @@ prepareRecord = runTxState encodeRecord :: Context -> Record Plaintext -> IO (Either TLSError ByteString) encodeRecord ctx = prepareRecord ctx . encodeRecordM -packetToFragments :: Context -> Int -> Packet13 -> IO [ByteString] +packetToFragments :: Context -> Maybe Int -> Packet13 -> IO [ByteString] packetToFragments ctx len (Handshake13 hss) = getChunks len . B.concat <$> mapM (updateHandshake13 ctx) hss packetToFragments _ _ (Alert13 a) = return [encodeAlerts a] diff --git a/core/Network/TLS/Util.hs b/core/Network/TLS/Util.hs index 1e6e0a383..18664ee80 100644 --- a/core/Network/TLS/Util.hs +++ b/core/Network/TLS/Util.hs @@ -86,15 +86,17 @@ forEitherM (x:xs) f = f x >>= doTail doTail (Left e) = return (Left e) mapChunks_ :: Monad m - => Int -> (B.ByteString -> m a) -> B.ByteString -> m () + => Maybe Int -> (B.ByteString -> m a) -> B.ByteString -> m () mapChunks_ len f = mapM_ f . getChunks len -getChunks :: Int -> B.ByteString -> [B.ByteString] -getChunks len bs - | B.length bs > len = - let (chunk, remain) = B.splitAt len bs - in chunk : getChunks len remain - | otherwise = [bs] +getChunks :: Maybe Int -> B.ByteString -> [B.ByteString] +getChunks Nothing = (: []) +getChunks (Just len) = go + where + go bs | B.length bs > len = + let (chunk, remain) = B.splitAt len bs + in chunk : go remain + | otherwise = [bs] -- | An opaque newtype wrapper to prevent from poking inside content that has -- been saved.