Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

checking EOF in bye (#261). #262

Closed
wants to merge 18 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions core/Network/TLS/Backend.hs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ safeRecv s buf = do
instance HasBackend Network.Socket where
initializeBackend _ = return ()
getBackend sock = Backend (return ()) (Network.close sock) (Network.sendAll sock) recvAll
where recvAll n = B.concat `fmap` loop n
where recvAll n = B.concat <$> loop n
where loop 0 = return []
loop left = do
r <- safeRecv sock left
Expand All @@ -86,7 +86,7 @@ instance HasBackend Hans.Socket where
initializeBackend _ = return ()
getBackend sock = Backend (return ()) (Hans.close sock) sendAll recvAll
where sendAll x = do
amt <- fromIntegral `fmap` Hans.sendBytes sock (L.fromStrict x)
amt <- fromIntegral <$> Hans.sendBytes sock (L.fromStrict x)
if (amt == 0) || (amt == B.length x)
then return ()
else sendAll (B.drop amt x)
Expand Down
2 changes: 1 addition & 1 deletion core/Network/TLS/Cipher.hs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ instance Show BulkState where
show (BulkStateStream _) = "BulkStateStream"
show (BulkStateBlock _) = "BulkStateBlock"
show (BulkStateAEAD _) = "BulkStateAEAD"
show (BulkStateUninitialized) = "BulkStateUninitialized"
show BulkStateUninitialized = "BulkStateUninitialized"

newtype BulkStream = BulkStream (B.ByteString -> (B.ByteString, BulkStream))

Expand Down
2 changes: 1 addition & 1 deletion core/Network/TLS/Compression.hs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ instance Eq Compression where
-- the function keeps the list of compression in order, to be able to find quickly the prefered
-- compression.
compressionIntersectID :: [Compression] -> [Word8] -> [Compression]
compressionIntersectID l ids = filter (\c -> elem (compressionID c) ids) l
compressionIntersectID l ids = filter (\c -> compressionID c `elem` ids) l

-- | This is the default compression which is a NOOP.
data NullCompression = NullCompression
Expand Down
4 changes: 2 additions & 2 deletions core/Network/TLS/Context.hs
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ contextNew backend params = liftIO $ do

seed <- case debugSeed debug of
Nothing -> do seed <- seedNew
debugPrintSeed debug $ seed
debugPrintSeed debug seed
return seed
Just determ -> return determ
let rng = newStateRNG seed
Expand All @@ -145,7 +145,7 @@ contextNew backend params = liftIO $ do
lockRead <- newMVar ()
lockState <- newMVar ()

return $ Context
return Context
{ ctxConnection = getBackend backend
, ctxShared = shared
, ctxSupported = supported
Expand Down
2 changes: 1 addition & 1 deletion core/Network/TLS/Context/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ usingHState :: Context -> HandshakeM a -> IO a
usingHState ctx f = liftIO $ modifyMVar (ctxHandshake ctx) $ \mst ->
case mst of
Nothing -> throwCore $ Error_Misc "missing handshake"
Just st -> return $ swap (Just `fmap` runHandshake st f)
Just st -> return $ swap (Just <$> runHandshake st f)

getHState :: Context -> IO (Maybe HandshakeState)
getHState ctx = liftIO $ readMVar (ctxHandshake ctx)
Expand Down
6 changes: 4 additions & 2 deletions core/Network/TLS/Core.hs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ import Control.Monad.State.Strict
--
-- this doesn't actually close the handle
bye :: MonadIO m => Context -> m ()
bye ctx = sendPacket ctx $ Alert [(AlertLevel_Warning, CloseNotify)]
bye ctx = do
eof <- liftIO $ ctxEOF ctx
unless eof $ sendPacket ctx $ Alert [(AlertLevel_Warning, CloseNotify)]

-- | If the ALPN extensions have been used, this will
-- return get the protocol agreed upon.
Expand Down Expand Up @@ -91,7 +93,7 @@ recvData ctx = liftIO $ do
onError err =
terminate err AlertLevel_Fatal InternalError (show err)

process (Handshake [ch@(ClientHello {})]) =
process (Handshake [ch@ClientHello{}]) =
handshakeWith ctx ch >> recvData ctx
process (Handshake [hr@HelloRequest]) =
handshakeWith ctx hr >> recvData ctx
Expand Down
2 changes: 1 addition & 1 deletion core/Network/TLS/Credentials.hs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ credentialLoadX509ChainFromMemory certData chainData privateData = do
(k:_) -> Right (CertificateChain . concat $ x509 : chains, k)

credentialsListSigningAlgorithms :: Credentials -> [DigitalSignatureAlg]
credentialsListSigningAlgorithms (Credentials l) = catMaybes $ map credentialCanSign l
credentialsListSigningAlgorithms (Credentials l) = mapMaybe credentialCanSign l

credentialsFindForSigning :: DigitalSignatureAlg -> Credentials -> Maybe Credential
credentialsFindForSigning sigAlg (Credentials l) = find forSigning l
Expand Down
12 changes: 6 additions & 6 deletions core/Network/TLS/Crypto.hs
Original file line number Diff line number Diff line change
Expand Up @@ -170,11 +170,11 @@ generalizeRSAError (Left e) = Left (RSAError e)
generalizeRSAError (Right x) = Right x

kxEncrypt :: MonadRandom r => PublicKey -> ByteString -> r (Either KxError ByteString)
kxEncrypt (PubKeyRSA pk) b = generalizeRSAError `fmap` RSA.encrypt pk b
kxEncrypt (PubKeyRSA pk) b = generalizeRSAError <$> RSA.encrypt pk b
kxEncrypt _ _ = return (Left KxUnsupported)

kxDecrypt :: MonadRandom r => PrivateKey -> ByteString -> r (Either KxError ByteString)
kxDecrypt (PrivKeyRSA pk) b = generalizeRSAError `fmap` RSA.decryptSafer pk b
kxDecrypt (PrivKeyRSA pk) b = generalizeRSAError <$> RSA.decryptSafer pk b
kxDecrypt _ _ = return (Left KxUnsupported)

data RSAEncoding = RSApkcs1 | RSApss deriving (Show,Eq)
Expand Down Expand Up @@ -208,10 +208,10 @@ kxVerify (PubKeyDSA pk) DSSParams msg signBS =
Right asn1 ->
case asn1 of
Start Sequence:IntVal r:IntVal s:End Sequence:_ ->
Just $ DSA.Signature { DSA.sign_r = r, DSA.sign_s = s }
Just DSA.Signature { DSA.sign_r = r, DSA.sign_s = s }
_ ->
Nothing
kxVerify (PubKeyEC key) (ECDSAParams alg) msg sigBS = maybe False id $ do
kxVerify (PubKeyEC key) (ECDSAParams alg) msg sigBS = fromMaybe False $ do
-- get the curve name and the public key data
(curveName, pubBS) <- case key of
PubKeyEC_Named curveName' pub -> Just (curveName',pub)
Expand Down Expand Up @@ -274,9 +274,9 @@ kxSign :: MonadRandom r
-> ByteString
-> r (Either KxError ByteString)
kxSign (PrivKeyRSA pk) (RSAParams hashAlg RSApkcs1) msg =
generalizeRSAError `fmap` rsaSignHash hashAlg pk msg
generalizeRSAError <$> rsaSignHash hashAlg pk msg
kxSign (PrivKeyRSA pk) (RSAParams hashAlg RSApss) msg =
generalizeRSAError `fmap` rsapssSignHash hashAlg pk msg
generalizeRSAError <$> rsapssSignHash hashAlg pk msg
kxSign (PrivKeyDSA pk) DSSParams msg = do
sign <- DSA.sign pk H.SHA1 msg
return (Right $ encodeASN1' DER $ dsaSequence sign)
Expand Down
8 changes: 4 additions & 4 deletions core/Network/TLS/Extension.hs
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ instance Extension ServerName where
extensionEncode (ServerName l) = runPut $ putOpaque16 (runPut $ mapM_ encodeNameType l)
where encodeNameType (ServerNameHostName hn) = putWord8 0 >> putOpaque16 (BC.pack hn) -- FIXME: should be puny code conversion
encodeNameType (ServerNameOther (nt,opaque)) = putWord8 nt >> putBytes opaque
extensionDecode _ = runGetMaybe (getWord16 >>= \len -> getList (fromIntegral len) getServerName >>= return . ServerName)
extensionDecode _ = runGetMaybe (getWord16 >>= \len -> ServerName <$> getList (fromIntegral len) getServerName)
where getServerName = do
ty <- getWord8
sname <- getOpaque16
Expand Down Expand Up @@ -229,7 +229,7 @@ newtype NegotiatedGroups = NegotiatedGroups [Group] deriving (Show,Eq)
instance Extension NegotiatedGroups where
extensionID _ = extensionID_NegotiatedGroups
extensionEncode (NegotiatedGroups groups) = runPut $ putWords16 $ map fromEnumSafe16 groups
extensionDecode _ = runGetMaybe (NegotiatedGroups . catMaybes . map toEnumSafe16 <$> getWords16)
extensionDecode _ = runGetMaybe (NegotiatedGroups . mapMaybe toEnumSafe16 <$> getWords16)

newtype EcPointFormatsSupported = EcPointFormatsSupported [EcPointFormat] deriving (Show,Eq)

Expand All @@ -253,14 +253,14 @@ instance EnumSafe8 EcPointFormat where
instance Extension EcPointFormatsSupported where
extensionID _ = extensionID_EcPointFormats
extensionEncode (EcPointFormatsSupported formats) = runPut $ putWords8 $ map fromEnumSafe8 formats
extensionDecode _ = runGetMaybe (EcPointFormatsSupported . catMaybes . map toEnumSafe8 <$> getWords8)
extensionDecode _ = runGetMaybe (EcPointFormatsSupported . mapMaybe toEnumSafe8 <$> getWords8)

data SessionTicket = SessionTicket
deriving (Show,Eq)

instance Extension SessionTicket where
extensionID _ = extensionID_SessionTicket
extensionEncode (SessionTicket {}) = runPut $ return ()
extensionEncode SessionTicket{} = runPut $ return ()
extensionDecode _ = runGetMaybe (return SessionTicket)

newtype HeartBeat = HeartBeat HeartBeatMode deriving (Show,Eq)
Expand Down
5 changes: 3 additions & 2 deletions core/Network/TLS/Extra/Cipher.hs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ import qualified Data.ByteString as B

import Network.TLS.Types (Version(..))
import Network.TLS.Cipher
import Network.TLS.Imports
import Data.Tuple (swap)

import Crypto.Cipher.AES
Expand Down Expand Up @@ -123,7 +124,7 @@ noFail :: CryptoFailable a -> a
noFail = throwCryptoError

makeIV_ :: BlockCipher a => B.ByteString -> IV a
makeIV_ = maybe (error "makeIV_") id . makeIV
makeIV_ = fromMaybe (error "makeIV_") . makeIV

tripledes_ede :: BulkDirection -> BulkKey -> BulkBlock
tripledes_ede BulkEncrypt key =
Expand All @@ -134,7 +135,7 @@ tripledes_ede BulkDecrypt key =
in (\iv input -> let output = cbcDecrypt ctx (tripledes_iv iv) input in (output, takelast 8 input))

tripledes_iv :: BulkIV -> IV DES_EDE3
tripledes_iv iv = maybe (error "tripledes cipher iv internal error") id $ makeIV iv
tripledes_iv iv = fromMaybe (error "tripledes cipher iv internal error") $ makeIV iv

rc4 :: BulkDirection -> BulkKey -> BulkStream
rc4 _ bulkKey = BulkStream (combineRC4 $ RC4.initialize bulkKey)
Expand Down
7 changes: 4 additions & 3 deletions core/Network/TLS/Handshake.hs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import Network.TLS.Context.Internal
import Network.TLS.Struct
import Network.TLS.IO
import Network.TLS.Util (catchException)
import Network.TLS.Imports

import Network.TLS.Handshake.Common
import Network.TLS.Handshake.Client
Expand All @@ -30,17 +31,17 @@ import Control.Exception (IOException, catch, fromException)
-- This is to be called at the beginning of a connection, and during renegotiation
handshake :: MonadIO m => Context -> m ()
handshake ctx =
liftIO $ handleException ctx $ withRWLock ctx (ctxDoHandshake ctx $ ctx)
liftIO $ handleException ctx $ withRWLock ctx (ctxDoHandshake ctx ctx)

-- Handshake when requested by the remote end
-- This is called automatically by 'recvData'
handshakeWith :: MonadIO m => Context -> Handshake -> m ()
handshakeWith ctx hs =
liftIO $ handleException ctx $ withRWLock ctx ((ctxDoHandshakeWith ctx) ctx hs)
liftIO $ handleException ctx $ withRWLock ctx $ ctxDoHandshakeWith ctx ctx hs

handleException :: Context -> IO () -> IO ()
handleException ctx f = catchException f $ \exception -> do
let tlserror = maybe (Error_Misc $ show exception) id $ fromException exception
let tlserror = fromMaybe (Error_Misc $ show exception) $ fromException exception
setEstablished ctx False
sendPacket ctx (errorToAlert tlserror) `catch` ignoreIOErr
handshakeFailed tlserror
Expand Down
69 changes: 33 additions & 36 deletions core/Network/TLS/Handshake/Client.hs
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ handshakeClient cparams ctx = do
signatureAlgExtension = return $ Just $ toExtensionRaw $ SignatureAlgorithms $ supportedHashSignatures $ clientSupported cparams

sendClientHello = do
crand <- getStateRNG ctx 32 >>= return . ClientRandom
let clientSession = Session . maybe Nothing (Just . fst) $ clientWantSessionResume cparams
crand <- ClientRandom <$> getStateRNG ctx 32
let clientSession = Session . (fst <$>) $ clientWantSessionResume cparams
highestVer = maximum $ supportedVersions $ ctxSupported ctx
extensions <- catMaybes <$> getExtensions
startHandshake ctx highestVer crand
Expand Down Expand Up @@ -243,33 +243,30 @@ sendClientData cparams ctx = sendCertificate >> sendClientKeyXchg >> sendCertifi
-- Only send a certificate verify message when we
-- have sent a non-empty list of certificates.
--
certSent <- usingHState ctx $ getClientCertSent
case certSent of
True -> do
sigAlg <- getLocalSignatureAlg

mhashSig <- case usedVersion of
TLS12 -> do
Just (_, Just hashSigs, _) <- usingHState ctx $ getClientCertRequest
-- The values in the "signature_algorithms" extension
-- are in descending order of preference.
-- However here the algorithms are selected according
-- to client preference in 'supportedHashSignatures'.
let suppHashSigs = supportedHashSignatures $ ctxSupported ctx
matchHashSigs = filter (sigAlg `signatureCompatible`) suppHashSigs
hashSigs' = filter (\ a -> a `elem` hashSigs) matchHashSigs

when (null hashSigs') $
throwCore $ Error_Protocol ("no " ++ show sigAlg ++ " hash algorithm in common with the server", True, HandshakeFailure)
return $ Just $ head hashSigs'
_ -> return Nothing

-- Fetch all handshake messages up to now.
msgs <- usingHState ctx $ B.concat <$> getHandshakeMessages
sigDig <- createCertificateVerify ctx usedVersion sigAlg mhashSig msgs
sendPacket ctx $ Handshake [CertVerify sigDig]

_ -> return ()
certSent <- usingHState ctx getClientCertSent
when certSent $ do
sigAlg <- getLocalSignatureAlg

mhashSig <- case usedVersion of
TLS12 -> do
Just (_, Just hashSigs, _) <- usingHState ctx getClientCertRequest
-- The values in the "signature_algorithms" extension
-- are in descending order of preference.
-- However here the algorithms are selected according
-- to client preference in 'supportedHashSignatures'.
let suppHashSigs = supportedHashSignatures $ ctxSupported ctx
matchHashSigs = filter (sigAlg `signatureCompatible`) suppHashSigs
hashSigs' = filter (`elem` hashSigs) matchHashSigs

when (null hashSigs') $
throwCore $ Error_Protocol ("no " ++ show sigAlg ++ " hash algorithm in common with the server", True, HandshakeFailure)
return $ Just $ head hashSigs'
_ -> return Nothing

-- Fetch all handshake messages up to now.
msgs <- usingHState ctx $ B.concat <$> getHandshakeMessages
sigDig <- createCertificateVerify ctx usedVersion sigAlg mhashSig msgs
sendPacket ctx $ Handshake [CertVerify sigDig]

getLocalSignatureAlg = do
pk <- usingHState ctx getLocalPrivateKey
Expand Down Expand Up @@ -301,7 +298,7 @@ throwMiscErrorOnException msg e =
onServerHello :: Context -> ClientParams -> [ExtensionID] -> Handshake -> IO (RecvState IO)
onServerHello ctx cparams sentExts (ServerHello rver serverRan serverSession cipher compression exts) = do
when (rver == SSL2) $ throwCore $ Error_Protocol ("ssl2 is not supported", True, ProtocolVersion)
case find ((==) rver) (supportedVersions $ ctxSupported ctx) of
case find (== rver) (supportedVersions $ ctxSupported ctx) of
Nothing -> throwCore $ Error_Protocol ("server version " ++ show rver ++ " is not supported", True, ProtocolVersion)
Just _ -> return ()
-- find the compression and cipher methods that the server want to use.
Expand All @@ -314,7 +311,7 @@ onServerHello ctx cparams sentExts (ServerHello rver serverRan serverSession cip

-- intersect sent extensions in client and the received extensions from server.
-- if server returns extensions that we didn't request, fail.
when (not $ null $ filter (not . flip elem sentExts . (\(ExtensionRaw i _) -> i)) exts) $
unless (null $ filter (not . flip elem sentExts . (\(ExtensionRaw i _) -> i)) exts) $
throwCore $ Error_Protocol ("spurious extensions received", True, UnsupportedExtension)

let resumingSession =
Expand All @@ -327,11 +324,11 @@ onServerHello ctx cparams sentExts (ServerHello rver serverRan serverSession cip
setVersion rver
usingHState ctx $ setServerHelloParameters rver serverRan cipherAlg compressAlg

case extensionDecode False `fmap` (extensionLookup extensionID_ApplicationLayerProtocolNegotiation exts) of
case extensionDecode False <$> extensionLookup extensionID_ApplicationLayerProtocolNegotiation exts of
Just (Just (ApplicationLayerProtocolNegotiation [proto])) -> usingState_ ctx $ do
mprotos <- getClientALPNSuggest
case mprotos of
Just protos -> when (elem proto protos) $ do
Just protos -> when (proto `elem` protos) $ do
setExtensionALPN True
setNegotiatedProtocol proto
_ -> return ()
Expand All @@ -347,7 +344,7 @@ onServerHello _ _ _ p = unexpected (show p) (Just "server hello")
processCertificate :: ClientParams -> Context -> Handshake -> IO (RecvState IO)
processCertificate cparams ctx (Certificates certs) = do
-- run certificate recv hook
ctxWithHooks ctx (\hooks -> hookRecvCertificates hooks $ certs)
ctxWithHooks ctx (\hooks -> hookRecvCertificates hooks certs)
-- then run certificate validation
usage <- catchException (wrapCertificateChecks <$> checkCert) rejectOnException
case usage of
Expand Down Expand Up @@ -394,13 +391,13 @@ processServerKeyExchange ctx (ServerKeyXchg origSkx) = do
doDHESignature dhparams signature signatureType = do
-- FIXME verify if FF group is one of supported groups
verified <- digitallySignDHParamsVerify ctx dhparams signatureType signature
when (not verified) $ throwCore $ Error_Protocol ("bad " ++ show signatureType ++ " signature for dhparams " ++ show dhparams, True, HandshakeFailure)
unless verified $ throwCore $ Error_Protocol ("bad " ++ show signatureType ++ " signature for dhparams " ++ show dhparams, True, HandshakeFailure)
usingHState ctx $ setServerDHParams dhparams

doECDHESignature ecdhparams signature signatureType = do
-- FIXME verify if EC group is one of supported groups
verified <- digitallySignECDHParamsVerify ctx ecdhparams signatureType signature
when (not verified) $ throwCore $ Error_Protocol ("bad " ++ show signatureType ++ " signature for ecdhparams", True, HandshakeFailure)
unless verified $ throwCore $ Error_Protocol ("bad " ++ show signatureType ++ " signature for ecdhparams", True, HandshakeFailure)
usingHState ctx $ setServerECDHParams ecdhparams

processServerKeyExchange ctx p = processCertificateRequest ctx p
Expand Down
Loading