Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

checking EOF in bye (#261). #262

Closed
wants to merge 18 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions core/Network/TLS/Backend.hs
Original file line number Diff line number Diff line change
@@ -72,7 +72,7 @@ safeRecv s buf = do
instance HasBackend Network.Socket where
initializeBackend _ = return ()
getBackend sock = Backend (return ()) (Network.close sock) (Network.sendAll sock) recvAll
where recvAll n = B.concat `fmap` loop n
where recvAll n = B.concat <$> loop n
where loop 0 = return []
loop left = do
r <- safeRecv sock left
@@ -86,7 +86,7 @@ instance HasBackend Hans.Socket where
initializeBackend _ = return ()
getBackend sock = Backend (return ()) (Hans.close sock) sendAll recvAll
where sendAll x = do
amt <- fromIntegral `fmap` Hans.sendBytes sock (L.fromStrict x)
amt <- fromIntegral <$> Hans.sendBytes sock (L.fromStrict x)
if (amt == 0) || (amt == B.length x)
then return ()
else sendAll (B.drop amt x)
2 changes: 1 addition & 1 deletion core/Network/TLS/Cipher.hs
Original file line number Diff line number Diff line change
@@ -53,7 +53,7 @@ instance Show BulkState where
show (BulkStateStream _) = "BulkStateStream"
show (BulkStateBlock _) = "BulkStateBlock"
show (BulkStateAEAD _) = "BulkStateAEAD"
show (BulkStateUninitialized) = "BulkStateUninitialized"
show BulkStateUninitialized = "BulkStateUninitialized"

newtype BulkStream = BulkStream (B.ByteString -> (B.ByteString, BulkStream))

2 changes: 1 addition & 1 deletion core/Network/TLS/Compression.hs
Original file line number Diff line number Diff line change
@@ -59,7 +59,7 @@ instance Eq Compression where
-- the function keeps the list of compression in order, to be able to find quickly the prefered
-- compression.
compressionIntersectID :: [Compression] -> [Word8] -> [Compression]
compressionIntersectID l ids = filter (\c -> elem (compressionID c) ids) l
compressionIntersectID l ids = filter (\c -> compressionID c `elem` ids) l

-- | This is the default compression which is a NOOP.
data NullCompression = NullCompression
4 changes: 2 additions & 2 deletions core/Network/TLS/Context.hs
Original file line number Diff line number Diff line change
@@ -121,7 +121,7 @@ contextNew backend params = liftIO $ do

seed <- case debugSeed debug of
Nothing -> do seed <- seedNew
debugPrintSeed debug $ seed
debugPrintSeed debug seed
return seed
Just determ -> return determ
let rng = newStateRNG seed
@@ -145,7 +145,7 @@ contextNew backend params = liftIO $ do
lockRead <- newMVar ()
lockState <- newMVar ()

return $ Context
return Context
{ ctxConnection = getBackend backend
, ctxShared = shared
, ctxSupported = supported
2 changes: 1 addition & 1 deletion core/Network/TLS/Context/Internal.hs
Original file line number Diff line number Diff line change
@@ -195,7 +195,7 @@ usingHState :: Context -> HandshakeM a -> IO a
usingHState ctx f = liftIO $ modifyMVar (ctxHandshake ctx) $ \mst ->
case mst of
Nothing -> throwCore $ Error_Misc "missing handshake"
Just st -> return $ swap (Just `fmap` runHandshake st f)
Just st -> return $ swap (Just <$> runHandshake st f)

getHState :: Context -> IO (Maybe HandshakeState)
getHState ctx = liftIO $ readMVar (ctxHandshake ctx)
6 changes: 4 additions & 2 deletions core/Network/TLS/Core.hs
Original file line number Diff line number Diff line change
@@ -51,7 +51,9 @@ import Control.Monad.State.Strict
--
-- this doesn't actually close the handle
bye :: MonadIO m => Context -> m ()
bye ctx = sendPacket ctx $ Alert [(AlertLevel_Warning, CloseNotify)]
bye ctx = do
eof <- liftIO $ ctxEOF ctx
unless eof $ sendPacket ctx $ Alert [(AlertLevel_Warning, CloseNotify)]

-- | If the ALPN extensions have been used, this will
-- return get the protocol agreed upon.
@@ -91,7 +93,7 @@ recvData ctx = liftIO $ do
onError err =
terminate err AlertLevel_Fatal InternalError (show err)

process (Handshake [ch@(ClientHello {})]) =
process (Handshake [ch@ClientHello{}]) =
handshakeWith ctx ch >> recvData ctx
process (Handshake [hr@HelloRequest]) =
handshakeWith ctx hr >> recvData ctx
2 changes: 1 addition & 1 deletion core/Network/TLS/Credentials.hs
Original file line number Diff line number Diff line change
@@ -83,7 +83,7 @@ credentialLoadX509ChainFromMemory certData chainData privateData = do
(k:_) -> Right (CertificateChain . concat $ x509 : chains, k)

credentialsListSigningAlgorithms :: Credentials -> [DigitalSignatureAlg]
credentialsListSigningAlgorithms (Credentials l) = catMaybes $ map credentialCanSign l
credentialsListSigningAlgorithms (Credentials l) = mapMaybe credentialCanSign l

credentialsFindForSigning :: DigitalSignatureAlg -> Credentials -> Maybe Credential
credentialsFindForSigning sigAlg (Credentials l) = find forSigning l
12 changes: 6 additions & 6 deletions core/Network/TLS/Crypto.hs
Original file line number Diff line number Diff line change
@@ -170,11 +170,11 @@ generalizeRSAError (Left e) = Left (RSAError e)
generalizeRSAError (Right x) = Right x

kxEncrypt :: MonadRandom r => PublicKey -> ByteString -> r (Either KxError ByteString)
kxEncrypt (PubKeyRSA pk) b = generalizeRSAError `fmap` RSA.encrypt pk b
kxEncrypt (PubKeyRSA pk) b = generalizeRSAError <$> RSA.encrypt pk b
kxEncrypt _ _ = return (Left KxUnsupported)

kxDecrypt :: MonadRandom r => PrivateKey -> ByteString -> r (Either KxError ByteString)
kxDecrypt (PrivKeyRSA pk) b = generalizeRSAError `fmap` RSA.decryptSafer pk b
kxDecrypt (PrivKeyRSA pk) b = generalizeRSAError <$> RSA.decryptSafer pk b
kxDecrypt _ _ = return (Left KxUnsupported)

data RSAEncoding = RSApkcs1 | RSApss deriving (Show,Eq)
@@ -208,10 +208,10 @@ kxVerify (PubKeyDSA pk) DSSParams msg signBS =
Right asn1 ->
case asn1 of
Start Sequence:IntVal r:IntVal s:End Sequence:_ ->
Just $ DSA.Signature { DSA.sign_r = r, DSA.sign_s = s }
Just DSA.Signature { DSA.sign_r = r, DSA.sign_s = s }
_ ->
Nothing
kxVerify (PubKeyEC key) (ECDSAParams alg) msg sigBS = maybe False id $ do
kxVerify (PubKeyEC key) (ECDSAParams alg) msg sigBS = fromMaybe False $ do
-- get the curve name and the public key data
(curveName, pubBS) <- case key of
PubKeyEC_Named curveName' pub -> Just (curveName',pub)
@@ -274,9 +274,9 @@ kxSign :: MonadRandom r
-> ByteString
-> r (Either KxError ByteString)
kxSign (PrivKeyRSA pk) (RSAParams hashAlg RSApkcs1) msg =
generalizeRSAError `fmap` rsaSignHash hashAlg pk msg
generalizeRSAError <$> rsaSignHash hashAlg pk msg
kxSign (PrivKeyRSA pk) (RSAParams hashAlg RSApss) msg =
generalizeRSAError `fmap` rsapssSignHash hashAlg pk msg
generalizeRSAError <$> rsapssSignHash hashAlg pk msg
kxSign (PrivKeyDSA pk) DSSParams msg = do
sign <- DSA.sign pk H.SHA1 msg
return (Right $ encodeASN1' DER $ dsaSequence sign)
8 changes: 4 additions & 4 deletions core/Network/TLS/Extension.hs
Original file line number Diff line number Diff line change
@@ -163,7 +163,7 @@ instance Extension ServerName where
extensionEncode (ServerName l) = runPut $ putOpaque16 (runPut $ mapM_ encodeNameType l)
where encodeNameType (ServerNameHostName hn) = putWord8 0 >> putOpaque16 (BC.pack hn) -- FIXME: should be puny code conversion
encodeNameType (ServerNameOther (nt,opaque)) = putWord8 nt >> putBytes opaque
extensionDecode _ = runGetMaybe (getWord16 >>= \len -> getList (fromIntegral len) getServerName >>= return . ServerName)
extensionDecode _ = runGetMaybe (getWord16 >>= \len -> ServerName <$> getList (fromIntegral len) getServerName)
where getServerName = do
ty <- getWord8
sname <- getOpaque16
@@ -229,7 +229,7 @@ newtype NegotiatedGroups = NegotiatedGroups [Group] deriving (Show,Eq)
instance Extension NegotiatedGroups where
extensionID _ = extensionID_NegotiatedGroups
extensionEncode (NegotiatedGroups groups) = runPut $ putWords16 $ map fromEnumSafe16 groups
extensionDecode _ = runGetMaybe (NegotiatedGroups . catMaybes . map toEnumSafe16 <$> getWords16)
extensionDecode _ = runGetMaybe (NegotiatedGroups . mapMaybe toEnumSafe16 <$> getWords16)

newtype EcPointFormatsSupported = EcPointFormatsSupported [EcPointFormat] deriving (Show,Eq)

@@ -253,14 +253,14 @@ instance EnumSafe8 EcPointFormat where
instance Extension EcPointFormatsSupported where
extensionID _ = extensionID_EcPointFormats
extensionEncode (EcPointFormatsSupported formats) = runPut $ putWords8 $ map fromEnumSafe8 formats
extensionDecode _ = runGetMaybe (EcPointFormatsSupported . catMaybes . map toEnumSafe8 <$> getWords8)
extensionDecode _ = runGetMaybe (EcPointFormatsSupported . mapMaybe toEnumSafe8 <$> getWords8)

data SessionTicket = SessionTicket
deriving (Show,Eq)

instance Extension SessionTicket where
extensionID _ = extensionID_SessionTicket
extensionEncode (SessionTicket {}) = runPut $ return ()
extensionEncode SessionTicket{} = runPut $ return ()
extensionDecode _ = runGetMaybe (return SessionTicket)

newtype HeartBeat = HeartBeat HeartBeatMode deriving (Show,Eq)
5 changes: 3 additions & 2 deletions core/Network/TLS/Extra/Cipher.hs
Original file line number Diff line number Diff line change
@@ -56,6 +56,7 @@ import qualified Data.ByteString as B

import Network.TLS.Types (Version(..))
import Network.TLS.Cipher
import Network.TLS.Imports
import Data.Tuple (swap)

import Crypto.Cipher.AES
@@ -123,7 +124,7 @@ noFail :: CryptoFailable a -> a
noFail = throwCryptoError

makeIV_ :: BlockCipher a => B.ByteString -> IV a
makeIV_ = maybe (error "makeIV_") id . makeIV
makeIV_ = fromMaybe (error "makeIV_") . makeIV

tripledes_ede :: BulkDirection -> BulkKey -> BulkBlock
tripledes_ede BulkEncrypt key =
@@ -134,7 +135,7 @@ tripledes_ede BulkDecrypt key =
in (\iv input -> let output = cbcDecrypt ctx (tripledes_iv iv) input in (output, takelast 8 input))

tripledes_iv :: BulkIV -> IV DES_EDE3
tripledes_iv iv = maybe (error "tripledes cipher iv internal error") id $ makeIV iv
tripledes_iv iv = fromMaybe (error "tripledes cipher iv internal error") $ makeIV iv

rc4 :: BulkDirection -> BulkKey -> BulkStream
rc4 _ bulkKey = BulkStream (combineRC4 $ RC4.initialize bulkKey)
7 changes: 4 additions & 3 deletions core/Network/TLS/Handshake.hs
Original file line number Diff line number Diff line change
@@ -18,6 +18,7 @@ import Network.TLS.Context.Internal
import Network.TLS.Struct
import Network.TLS.IO
import Network.TLS.Util (catchException)
import Network.TLS.Imports

import Network.TLS.Handshake.Common
import Network.TLS.Handshake.Client
@@ -30,17 +31,17 @@ import Control.Exception (IOException, catch, fromException)
-- This is to be called at the beginning of a connection, and during renegotiation
handshake :: MonadIO m => Context -> m ()
handshake ctx =
liftIO $ handleException ctx $ withRWLock ctx (ctxDoHandshake ctx $ ctx)
liftIO $ handleException ctx $ withRWLock ctx (ctxDoHandshake ctx ctx)

-- Handshake when requested by the remote end
-- This is called automatically by 'recvData'
handshakeWith :: MonadIO m => Context -> Handshake -> m ()
handshakeWith ctx hs =
liftIO $ handleException ctx $ withRWLock ctx ((ctxDoHandshakeWith ctx) ctx hs)
liftIO $ handleException ctx $ withRWLock ctx $ ctxDoHandshakeWith ctx ctx hs

handleException :: Context -> IO () -> IO ()
handleException ctx f = catchException f $ \exception -> do
let tlserror = maybe (Error_Misc $ show exception) id $ fromException exception
let tlserror = fromMaybe (Error_Misc $ show exception) $ fromException exception
setEstablished ctx False
sendPacket ctx (errorToAlert tlserror) `catch` ignoreIOErr
handshakeFailed tlserror
69 changes: 33 additions & 36 deletions core/Network/TLS/Handshake/Client.hs
Original file line number Diff line number Diff line change
@@ -99,8 +99,8 @@ handshakeClient cparams ctx = do
signatureAlgExtension = return $ Just $ toExtensionRaw $ SignatureAlgorithms $ supportedHashSignatures $ clientSupported cparams

sendClientHello = do
crand <- getStateRNG ctx 32 >>= return . ClientRandom
let clientSession = Session . maybe Nothing (Just . fst) $ clientWantSessionResume cparams
crand <- ClientRandom <$> getStateRNG ctx 32
let clientSession = Session . (fst <$>) $ clientWantSessionResume cparams
highestVer = maximum $ supportedVersions $ ctxSupported ctx
extensions <- catMaybes <$> getExtensions
startHandshake ctx highestVer crand
@@ -243,33 +243,30 @@ sendClientData cparams ctx = sendCertificate >> sendClientKeyXchg >> sendCertifi
-- Only send a certificate verify message when we
-- have sent a non-empty list of certificates.
--
certSent <- usingHState ctx $ getClientCertSent
case certSent of
True -> do
sigAlg <- getLocalSignatureAlg

mhashSig <- case usedVersion of
TLS12 -> do
Just (_, Just hashSigs, _) <- usingHState ctx $ getClientCertRequest
-- The values in the "signature_algorithms" extension
-- are in descending order of preference.
-- However here the algorithms are selected according
-- to client preference in 'supportedHashSignatures'.
let suppHashSigs = supportedHashSignatures $ ctxSupported ctx
matchHashSigs = filter (sigAlg `signatureCompatible`) suppHashSigs
hashSigs' = filter (\ a -> a `elem` hashSigs) matchHashSigs

when (null hashSigs') $
throwCore $ Error_Protocol ("no " ++ show sigAlg ++ " hash algorithm in common with the server", True, HandshakeFailure)
return $ Just $ head hashSigs'
_ -> return Nothing

-- Fetch all handshake messages up to now.
msgs <- usingHState ctx $ B.concat <$> getHandshakeMessages
sigDig <- createCertificateVerify ctx usedVersion sigAlg mhashSig msgs
sendPacket ctx $ Handshake [CertVerify sigDig]

_ -> return ()
certSent <- usingHState ctx getClientCertSent
when certSent $ do
sigAlg <- getLocalSignatureAlg

mhashSig <- case usedVersion of
TLS12 -> do
Just (_, Just hashSigs, _) <- usingHState ctx getClientCertRequest
-- The values in the "signature_algorithms" extension
-- are in descending order of preference.
-- However here the algorithms are selected according
-- to client preference in 'supportedHashSignatures'.
let suppHashSigs = supportedHashSignatures $ ctxSupported ctx
matchHashSigs = filter (sigAlg `signatureCompatible`) suppHashSigs
hashSigs' = filter (`elem` hashSigs) matchHashSigs

when (null hashSigs') $
throwCore $ Error_Protocol ("no " ++ show sigAlg ++ " hash algorithm in common with the server", True, HandshakeFailure)
return $ Just $ head hashSigs'
_ -> return Nothing

-- Fetch all handshake messages up to now.
msgs <- usingHState ctx $ B.concat <$> getHandshakeMessages
sigDig <- createCertificateVerify ctx usedVersion sigAlg mhashSig msgs
sendPacket ctx $ Handshake [CertVerify sigDig]

getLocalSignatureAlg = do
pk <- usingHState ctx getLocalPrivateKey
@@ -301,7 +298,7 @@ throwMiscErrorOnException msg e =
onServerHello :: Context -> ClientParams -> [ExtensionID] -> Handshake -> IO (RecvState IO)
onServerHello ctx cparams sentExts (ServerHello rver serverRan serverSession cipher compression exts) = do
when (rver == SSL2) $ throwCore $ Error_Protocol ("ssl2 is not supported", True, ProtocolVersion)
case find ((==) rver) (supportedVersions $ ctxSupported ctx) of
case find (== rver) (supportedVersions $ ctxSupported ctx) of
Nothing -> throwCore $ Error_Protocol ("server version " ++ show rver ++ " is not supported", True, ProtocolVersion)
Just _ -> return ()
-- find the compression and cipher methods that the server want to use.
@@ -314,7 +311,7 @@ onServerHello ctx cparams sentExts (ServerHello rver serverRan serverSession cip

-- intersect sent extensions in client and the received extensions from server.
-- if server returns extensions that we didn't request, fail.
when (not $ null $ filter (not . flip elem sentExts . (\(ExtensionRaw i _) -> i)) exts) $
unless (null $ filter (not . flip elem sentExts . (\(ExtensionRaw i _) -> i)) exts) $
throwCore $ Error_Protocol ("spurious extensions received", True, UnsupportedExtension)

let resumingSession =
@@ -327,11 +324,11 @@ onServerHello ctx cparams sentExts (ServerHello rver serverRan serverSession cip
setVersion rver
usingHState ctx $ setServerHelloParameters rver serverRan cipherAlg compressAlg

case extensionDecode False `fmap` (extensionLookup extensionID_ApplicationLayerProtocolNegotiation exts) of
case extensionDecode False <$> extensionLookup extensionID_ApplicationLayerProtocolNegotiation exts of
Just (Just (ApplicationLayerProtocolNegotiation [proto])) -> usingState_ ctx $ do
mprotos <- getClientALPNSuggest
case mprotos of
Just protos -> when (elem proto protos) $ do
Just protos -> when (proto `elem` protos) $ do
setExtensionALPN True
setNegotiatedProtocol proto
_ -> return ()
@@ -347,7 +344,7 @@ onServerHello _ _ _ p = unexpected (show p) (Just "server hello")
processCertificate :: ClientParams -> Context -> Handshake -> IO (RecvState IO)
processCertificate cparams ctx (Certificates certs) = do
-- run certificate recv hook
ctxWithHooks ctx (\hooks -> hookRecvCertificates hooks $ certs)
ctxWithHooks ctx (\hooks -> hookRecvCertificates hooks certs)
-- then run certificate validation
usage <- catchException (wrapCertificateChecks <$> checkCert) rejectOnException
case usage of
@@ -394,13 +391,13 @@ processServerKeyExchange ctx (ServerKeyXchg origSkx) = do
doDHESignature dhparams signature signatureType = do
-- FIXME verify if FF group is one of supported groups
verified <- digitallySignDHParamsVerify ctx dhparams signatureType signature
when (not verified) $ throwCore $ Error_Protocol ("bad " ++ show signatureType ++ " signature for dhparams " ++ show dhparams, True, HandshakeFailure)
unless verified $ throwCore $ Error_Protocol ("bad " ++ show signatureType ++ " signature for dhparams " ++ show dhparams, True, HandshakeFailure)
usingHState ctx $ setServerDHParams dhparams

doECDHESignature ecdhparams signature signatureType = do
-- FIXME verify if EC group is one of supported groups
verified <- digitallySignECDHParamsVerify ctx ecdhparams signatureType signature
when (not verified) $ throwCore $ Error_Protocol ("bad " ++ show signatureType ++ " signature for ecdhparams", True, HandshakeFailure)
unless verified $ throwCore $ Error_Protocol ("bad " ++ show signatureType ++ " signature for ecdhparams", True, HandshakeFailure)
usingHState ctx $ setServerECDHParams ecdhparams

processServerKeyExchange ctx p = processCertificateRequest ctx p
8 changes: 4 additions & 4 deletions core/Network/TLS/Handshake/Common.hs
Original file line number Diff line number Diff line change
@@ -44,12 +44,12 @@ errorToAlert :: TLSError -> Packet
errorToAlert (Error_Protocol (_, _, ad)) = Alert [(AlertLevel_Fatal, ad)]
errorToAlert _ = Alert [(AlertLevel_Fatal, InternalError)]

unexpected :: String -> Maybe [Char] -> IO a
unexpected :: String -> Maybe String -> IO a
unexpected msg expected = throwCore $ Error_Packet_unexpected msg (maybe "" (" expected: " ++) expected)

newSession :: Context -> IO Session
newSession ctx
| supportedSession $ ctxSupported ctx = getStateRNG ctx 32 >>= return . Session . Just
| supportedSession $ ctxSupported ctx = Session . Just <$> getStateRNG ctx 32
| otherwise = return $ Session Nothing

-- | when a new handshake is done, wrap up & clean up.
@@ -116,7 +116,7 @@ onRecvStateHandshake ctx (RecvStateHandshake f) (x:xs) = do
onRecvStateHandshake _ _ _ = unexpected "spurious handshake" Nothing

runRecvState :: Context -> RecvState IO -> IO ()
runRecvState _ (RecvStateDone) = return ()
runRecvState _ RecvStateDone = return ()
runRecvState ctx (RecvStateNext f) = recvPacket ctx >>= either throwCore f >>= runRecvState ctx
runRecvState ctx iniState = recvPacketHandshake ctx >>= onRecvStateHandshake ctx iniState >>= runRecvState ctx

@@ -128,7 +128,7 @@ getSessionData ctx = do
tx <- liftIO $ readMVar (ctxTxState ctx)
case mms of
Nothing -> return Nothing
Just ms -> return $ Just $ SessionData
Just ms -> return $ Just SessionData
{ sessionVersion = ver
, sessionCipher = cipherID $ fromJust "cipher" $ stCipher tx
, sessionCompression = compressionID $ stCompression tx
51 changes: 24 additions & 27 deletions core/Network/TLS/Handshake/Server.hs
Original file line number Diff line number Diff line change
@@ -203,7 +203,7 @@ handshakeServerWith sparams ctx clientHello@(ClientHello clientVersion _ clientS

cred <- case cipherKeyExchange usedCipher of
CipherKeyExchange_RSA -> return $ credentialsFindForDecrypting creds
CipherKeyExchange_DH_Anon -> return $ Nothing
CipherKeyExchange_DH_Anon -> return Nothing
CipherKeyExchange_DHE_RSA -> return $ credentialsFindForSigning RSA signatureCreds
CipherKeyExchange_DHE_DSS -> return $ credentialsFindForSigning DSS signatureCreds
CipherKeyExchange_ECDHE_RSA -> return $ credentialsFindForSigning RSA signatureCreds
@@ -491,31 +491,28 @@ recvClientData sparams ctx = runRecvState ctx (RecvStateHandshake processClientC

verif <- checkCertificateVerify ctx usedVersion sigAlgExpected msgs dsig

case verif of
True -> do
-- When verification succeeds, commit the
-- client certificate chain to the context.
--
Just certs <- usingHState ctx getClientCertChain
usingState_ ctx $ setClientCertificateChain certs
return ()

False -> do
-- Either verification failed because of an
-- invalid format (with an error message), or
-- the signature is wrong. In either case,
-- ask the application if it wants to
-- proceed, we will do that.
res <- liftIO $ onUnverifiedClientCert (serverHooks sparams)
if res
then do
-- When verification fails, but the
-- application callbacks accepts, we
-- also commit the client certificate
-- chain to the context.
Just certs <- usingHState ctx getClientCertChain
usingState_ ctx $ setClientCertificateChain certs
else throwCore $ Error_Protocol ("verification failed", True, BadCertificate)
if verif then do
-- When verification succeeds, commit the
-- client certificate chain to the context.
--
Just certs <- usingHState ctx getClientCertChain
usingState_ ctx $ setClientCertificateChain certs
return ()
else do
-- Either verification failed because of an
-- invalid format (with an error message), or
-- the signature is wrong. In either case,
-- ask the application if it wants to
-- proceed, we will do that.
res <- liftIO $ onUnverifiedClientCert (serverHooks sparams)
if res then do
-- When verification fails, but the
-- application callbacks accepts, we
-- also commit the client certificate
-- chain to the context.
Just certs <- usingHState ctx getClientCertChain
usingState_ ctx $ setClientCertificateChain certs
else throwCore $ Error_Protocol ("verification failed", True, BadCertificate)
return $ RecvStateNext expectChangeCipher

processCertificateVerify p = do
@@ -535,7 +532,7 @@ recvClientData sparams ctx = runRecvState ctx (RecvStateHandshake processClientC
_ -> throwCore $ Error_Protocol ("unsupported remote public key type", True, HandshakeFailure)

expectChangeCipher ChangeCipherSpec = do
return $ RecvStateHandshake $ expectFinish
return $ RecvStateHandshake expectFinish

expectChangeCipher p = unexpected (show p) (Just "change cipher")

4 changes: 2 additions & 2 deletions core/Network/TLS/Handshake/Signature.hs
Original file line number Diff line number Diff line change
@@ -129,7 +129,7 @@ signatureCreateWithCertVerifyData :: Context
-> CertVerifyData
-> IO DigitallySigned
signatureCreateWithCertVerifyData ctx malg (sigParam, toSign) = do
cc <- usingState_ ctx $ isClientContext
cc <- usingState_ ctx isClientContext
DigitallySigned malg <$> signPrivate ctx cc sigParam toSign

signatureVerify :: Context -> DigitallySigned -> DigitalSignatureAlg -> ByteString -> IO Bool
@@ -149,7 +149,7 @@ signatureVerifyWithCertVerifyData :: Context
-> CertVerifyData
-> IO Bool
signatureVerifyWithCertVerifyData ctx (DigitallySigned _ bs) (sigParam, toVerify) = do
cc <- usingState_ ctx $ isClientContext
cc <- usingState_ ctx isClientContext
verifyPublic ctx cc sigParam toVerify bs

digitallySignParams :: Context -> ByteString -> DigitalSignatureAlg -> Maybe HashAndSignatureAlgorithm -> IO DigitallySigned
4 changes: 2 additions & 2 deletions core/Network/TLS/Handshake/State.hs
Original file line number Diff line number Diff line change
@@ -68,7 +68,7 @@ data HandshakeKeyState = HandshakeKeyState
} deriving (Show)

data HandshakeState = HandshakeState
{ hstClientVersion :: !(Version)
{ hstClientVersion :: !Version
, hstClientRandom :: !ClientRandom
, hstServerRandom :: !(Maybe ServerRandom)
, hstMasterSecret :: !(Maybe ByteString)
@@ -98,7 +98,7 @@ newtype HandshakeM a = HandshakeM { runHandshakeM :: State HandshakeState a }

instance MonadState HandshakeState HandshakeM where
put x = HandshakeM (put x)
get = HandshakeM (get)
get = HandshakeM get
#if MIN_VERSION_mtl(2,1,0)
state f = HandshakeM (state f)
#endif
10 changes: 5 additions & 5 deletions core/Network/TLS/Hooks.hs
Original file line number Diff line number Diff line change
@@ -28,10 +28,10 @@ data Logging = Logging

defaultLogging :: Logging
defaultLogging = Logging
{ loggingPacketSent = (\_ -> return ())
, loggingPacketRecv = (\_ -> return ())
, loggingIOSent = (\_ -> return ())
, loggingIORecv = (\_ _ -> return ())
{ loggingPacketSent = \_ -> return ()
, loggingPacketRecv = \_ -> return ()
, loggingIOSent = \_ -> return ()
, loggingIORecv = \_ _ -> return ()
}

instance Default Logging where
@@ -49,7 +49,7 @@ data Hooks = Hooks

defaultHooks :: Hooks
defaultHooks = Hooks
{ hookRecvHandshake = \hs -> return hs
{ hookRecvHandshake = return
, hookRecvCertificates = return . const ()
, hookLogging = def
}
4 changes: 2 additions & 2 deletions core/Network/TLS/IO.hs
Original file line number Diff line number Diff line change
@@ -44,7 +44,7 @@ readExact ctx sz = do
return . Left $
if B.null hdrbs
then Error_EOF
else Error_Packet ("partial packet: expecting " ++ show sz ++ " bytes, got: " ++ (show $B.length hdrbs))
else Error_Packet ("partial packet: expecting " ++ show sz ++ " bytes, got: " ++ show (B.length hdrbs))


-- | recvRecord receive a full TLS record (header + data), from the other side.
@@ -103,7 +103,7 @@ recvPacket ctx = liftIO $ do
pkt <- case pktRecv of
Right (Handshake hss) ->
ctxWithHooks ctx $ \hooks ->
(mapM (hookRecvHandshake hooks) hss) >>= return . Right . Handshake
Right . Handshake <$> mapM (hookRecvHandshake hooks) hss
_ -> return pktRecv
case pkt of
Right p -> withLog ctx $ \logging -> loggingPacketRecv logging $ show p
18 changes: 9 additions & 9 deletions core/Network/TLS/Packet.hs
Original file line number Diff line number Diff line change
@@ -147,7 +147,7 @@ decodeAlert = do
(_, Nothing) -> fail "cannot decode alert description"

decodeAlerts :: ByteString -> Either TLSError [(AlertLevel, AlertDescription)]
decodeAlerts = runGetErr "alerts" $ loop
decodeAlerts = runGetErr "alerts" loop
where loop = do
r <- remaining
if r == 0
@@ -212,7 +212,7 @@ decodeClientHello = do
compressions <- getWords8
r <- remaining
exts <- if hasHelloExtensions ver && r > 0
then fmap fromIntegral getWord16 >>= getExtensions
then fromIntegral <$> getWord16 >>= getExtensions
else return []
return $ ClientHello ver random session ciphers compressions exts Nothing

@@ -225,7 +225,7 @@ decodeServerHello = do
compressionid <- getWord8
r <- remaining
exts <- if hasHelloExtensions ver && r > 0
then fmap fromIntegral getWord16 >>= getExtensions
then fromIntegral <$> getWord16 >>= getExtensions
else return []
return $ ServerHello ver random session cipherid compressionid exts

@@ -380,22 +380,22 @@ encodeHandshakeContent (ServerKeyXchg skg) =
SKX_Unparsed bytes -> putBytes bytes
_ -> error ("encodeHandshakeContent: cannot handle: " ++ show skg)

encodeHandshakeContent (HelloRequest) = return ()
encodeHandshakeContent (ServerHelloDone) = return ()
encodeHandshakeContent HelloRequest = return ()
encodeHandshakeContent ServerHelloDone = return ()

encodeHandshakeContent (CertRequest certTypes sigAlgs certAuthorities) = do
putWords8 (map valOfType certTypes)
case sigAlgs of
Nothing -> return ()
Just l -> putWords16 $ map (\(x,y) -> (fromIntegral $ valOfType x) * 256 + (fromIntegral $ valOfType y)) l
Just l -> putWords16 $ map (\(x,y) -> fromIntegral (valOfType x) * 256 + fromIntegral (valOfType y)) l
encodeCertAuthorities certAuthorities
where -- Convert a distinguished name to its DER encoding.
encodeCA dn = return $ encodeASN1' DER (toASN1 dn []) --B.concat $ L.toChunks $ encodeDN dn

-- Encode a list of distinguished names.
encodeCertAuthorities certAuths = do
enc <- mapM encodeCA certAuths
let totLength = sum $ map (((+) 2) . B.length) enc
let totLength = sum $ map ((+) 2 . B.length) enc
putWord16 (fromIntegral totLength)
mapM_ (\ b -> putWord16 (fromIntegral (B.length b)) >> putBytes b) enc

@@ -540,11 +540,11 @@ getPRF :: Version -> Cipher -> PRF
getPRF ver ciph
| ver < TLS12 = prf_MD5SHA1
| maybe True (< TLS12) (cipherMinVer ciph) = prf_SHA256
| otherwise = prf_TLS ver $ maybe SHA256 id $ cipherPRFHash ciph
| otherwise = prf_TLS ver $ fromMaybe SHA256 $ cipherPRFHash ciph

generateMasterSecret_SSL :: ByteArrayAccess preMaster => preMaster -> ClientRandom -> ServerRandom -> ByteString
generateMasterSecret_SSL premasterSecret (ClientRandom c) (ServerRandom s) =
B.concat $ map (computeMD5) ["A","BB","CCC"]
B.concat $ map computeMD5 ["A","BB","CCC"]
where computeMD5 label = hash MD5 $ B.concat [ B.convert premasterSecret, computeSHA1 label ]
computeSHA1 label = hash SHA1 $ B.concat [ label, B.convert premasterSecret, c, s ]

6 changes: 3 additions & 3 deletions core/Network/TLS/Parameters.hs
Original file line number Diff line number Diff line change
@@ -247,10 +247,10 @@ data GroupUsage =

defaultGroupUsage :: DHParams -> DHPublic -> IO GroupUsage
defaultGroupUsage params public
| not $ odd (dhParamsGetP params) = return $ GroupUsageUnsupported "invalid odd prime"
| even $ dhParamsGetP params = return $ GroupUsageUnsupported "invalid odd prime"
| not $ dhValid params (dhParamsGetG params) = return $ GroupUsageUnsupported "invalid generator"
| not $ dhValid params (dhUnwrapPublic public) = return $ GroupUsageInvalidPublic
| otherwise = return $ GroupUsageValid
| not $ dhValid params (dhUnwrapPublic public) = return GroupUsageInvalidPublic
| otherwise = return GroupUsageValid

-- | A set of callbacks run by the clients for various corners of TLS establishment
data ClientHooks = ClientHooks
9 changes: 5 additions & 4 deletions core/Network/TLS/Receiving.hs
Original file line number Diff line number Diff line change
@@ -27,12 +27,13 @@ import Network.TLS.State
import Network.TLS.Handshake.State
import Network.TLS.Cipher
import Network.TLS.Util
import Network.TLS.Imports

processPacket :: Context -> Record Plaintext -> IO (Either TLSError Packet)

processPacket _ (Record ProtocolType_AppData _ fragment) = return $ Right $ AppData $ fragmentGetBytes fragment

processPacket _ (Record ProtocolType_Alert _ fragment) = return (Alert `fmapEither` (decodeAlerts $ fragmentGetBytes fragment))
processPacket _ (Record ProtocolType_Alert _ fragment) = return (Alert `fmapEither` decodeAlerts (fragmentGetBytes fragment))

processPacket ctx (Record ProtocolType_ChangeCipherSpec _ fragment) =
case decodeChangeCipherSpec $ fragmentGetBytes fragment of
@@ -41,7 +42,7 @@ processPacket ctx (Record ProtocolType_ChangeCipherSpec _ fragment) =
return $ Right ChangeCipherSpec

processPacket ctx (Record ProtocolType_Handshake ver fragment) = do
keyxchg <- getHState ctx >>= \hs -> return $ (hs >>= hstPendingCipher >>= Just . cipherKeyExchange)
keyxchg <- getHState ctx >>= \hs -> return (hs >>= hstPendingCipher >>= Just . cipherKeyExchange)
usingState ctx $ do
let currentParams = CurrentParams
{ cParamsVersion = ver
@@ -53,15 +54,15 @@ processPacket ctx (Record ProtocolType_Handshake ver fragment) = do
hss <- parseMany currentParams mCont (fragmentGetBytes fragment)
return $ Handshake hss
where parseMany currentParams mCont bs =
case maybe decodeHandshakeRecord id mCont $ bs of
case fromMaybe decodeHandshakeRecord mCont bs of
GotError err -> throwError err
GotPartial cont -> modify (\st -> st { stHandshakeRecordCont = Just cont }) >> return []
GotSuccess (ty,content) ->
either throwError (return . (:[])) $ decodeHandshake currentParams ty content
GotSuccessRemaining (ty,content) left ->
case decodeHandshake currentParams ty content of
Left err -> throwError err
Right hh -> (hh:) `fmap` parseMany currentParams Nothing left
Right hh -> (hh:) <$> parseMany currentParams Nothing left

processPacket _ (Record ProtocolType_DeprecatedHandshake _ fragment) =
case decodeDeprecatedHandshake $ fragmentGetBytes fragment of
12 changes: 6 additions & 6 deletions core/Network/TLS/Record/Disengage.hs
Original file line number Diff line number Diff line change
@@ -62,7 +62,7 @@ getCipherData (Record pt ver _) cdata = do
Just pad -> do
cver <- getRecordVersion
let b = B.length pad - 1
return (if cver < TLS10 then True else B.replicate (B.length pad) (fromIntegral b) `bytesEq` pad)
return (cver < TLS10 || B.replicate (B.length pad) (fromIntegral b) `bytesEq` pad)

unless (macValid &&! paddingValid) $ do
throwError $ Error_Protocol ("bad record mac", True, BadRecordMac)
@@ -87,7 +87,7 @@ decryptData ver record econtent tst = decryptOf (cstKey cst)
let minContent = (if explicitIV then bulkIVSize bulk else 0) + max (macSize + 1) blockSize

-- check if we have enough bytes to cover the minimum for this cipher
when ((econtentLen `mod` blockSize) /= 0 || econtentLen < minContent) $ sanityCheckError
when ((econtentLen `mod` blockSize) /= 0 || econtentLen < minContent) sanityCheckError

{- update IV -}
(iv, econtent') <- if explicitIV
@@ -99,22 +99,22 @@ decryptData ver record econtent tst = decryptOf (cstKey cst)
let paddinglength = fromIntegral (B.last content') + 1
let contentlen = B.length content' - paddinglength - macSize
(content, mac, padding) <- get3 content' (contentlen, macSize, paddinglength)
getCipherData record $ CipherData
getCipherData record CipherData
{ cipherDataContent = content
, cipherDataMAC = Just mac
, cipherDataPadding = Just padding
}

decryptOf (BulkStateStream (BulkStream decryptF)) = do
-- check if we have enough bytes to cover the minimum for this cipher
when (econtentLen < macSize) $ sanityCheckError
when (econtentLen < macSize) sanityCheckError

let (content', bulkStream') = decryptF econtent
{- update Ctx -}
let contentlen = B.length content' - macSize
(content, mac) <- get2 content' (contentlen, macSize)
modify $ \txs -> txs { stCryptState = cst { cstKey = BulkStateStream bulkStream' } }
getCipherData record $ CipherData
getCipherData record CipherData
{ cipherDataContent = content
, cipherDataMAC = Just mac
, cipherDataPadding = Nothing
@@ -126,7 +126,7 @@ decryptData ver record econtent tst = decryptOf (cstKey cst)
cipherLen = econtentLen - authTagLen - nonceExpLen

-- check if we have enough bytes to cover the minimum for this cipher
when (econtentLen < (authTagLen + nonceExpLen)) $ sanityCheckError
when (econtentLen < (authTagLen + nonceExpLen)) sanityCheckError

(enonce, econtent', authTag) <- get3 econtent (nonceExpLen, cipherLen, authTagLen)
let encodedSeq = encodeWord64 $ msSequence $ stMacState tst
4 changes: 2 additions & 2 deletions core/Network/TLS/Sending.hs
Original file line number Diff line number Diff line change
@@ -37,7 +37,7 @@ makeRecord pkt = do
return $ Record (packetType pkt) ver (fragmentPlaintext $ writePacketContent pkt)
where writePacketContent (Handshake hss) = encodeHandshakes hss
writePacketContent (Alert a) = encodeAlerts a
writePacketContent (ChangeCipherSpec) = encodeChangeCipherSpec
writePacketContent ChangeCipherSpec = encodeChangeCipherSpec
writePacketContent (AppData x) = x

-- | marshall packet data
@@ -69,7 +69,7 @@ prepareRecord :: Context -> RecordM a -> IO (Either TLSError a)
prepareRecord ctx f = do
ver <- usingState_ ctx (getVersionWithDefault $ maximum $ supportedVersions $ ctxSupported ctx)
txState <- readMVar $ ctxTxState ctx
let sz = case stCipher $ txState of
let sz = case stCipher txState of
Nothing -> 0
Just cipher -> if hasRecordIV $ bulkF $ cipherBulk cipher
then bulkIVSize $ cipherBulk cipher
4 changes: 2 additions & 2 deletions core/Network/TLS/State.hs
Original file line number Diff line number Diff line change
@@ -173,10 +173,10 @@ setVersionIfUnset ver = modify maybeSet
Just _ -> st

getVersion :: TLSSt Version
getVersion = maybe (error $ "internal error: version hasn't been set yet") id <$> gets stVersion
getVersion = fromMaybe (error "internal error: version hasn't been set yet") <$> gets stVersion

getVersionWithDefault :: Version -> TLSSt Version
getVersionWithDefault defaultVer = maybe defaultVer id <$> gets stVersion
getVersionWithDefault defaultVer = fromMaybe defaultVer <$> gets stVersion

setSecureRenegotiation :: Bool -> TLSSt ()
setSecureRenegotiation b = modify (\st -> st { stSecureRenegotiation = b })
20 changes: 10 additions & 10 deletions core/Network/TLS/Struct.hs
Original file line number Diff line number Diff line change
@@ -326,16 +326,16 @@ packetType ChangeCipherSpec = ProtocolType_ChangeCipherSpec
packetType (AppData _) = ProtocolType_AppData

typeOfHandshake :: Handshake -> HandshakeType
typeOfHandshake (ClientHello {}) = HandshakeType_ClientHello
typeOfHandshake (ServerHello {}) = HandshakeType_ServerHello
typeOfHandshake (Certificates {}) = HandshakeType_Certificate
typeOfHandshake HelloRequest = HandshakeType_HelloRequest
typeOfHandshake (ServerHelloDone) = HandshakeType_ServerHelloDone
typeOfHandshake (ClientKeyXchg {}) = HandshakeType_ClientKeyXchg
typeOfHandshake (ServerKeyXchg {}) = HandshakeType_ServerKeyXchg
typeOfHandshake (CertRequest {}) = HandshakeType_CertRequest
typeOfHandshake (CertVerify {}) = HandshakeType_CertVerify
typeOfHandshake (Finished {}) = HandshakeType_Finished
typeOfHandshake ClientHello{} = HandshakeType_ClientHello
typeOfHandshake ServerHello{} = HandshakeType_ServerHello
typeOfHandshake Certificates{} = HandshakeType_Certificate
typeOfHandshake HelloRequest = HandshakeType_HelloRequest
typeOfHandshake ServerHelloDone = HandshakeType_ServerHelloDone
typeOfHandshake ClientKeyXchg{} = HandshakeType_ClientKeyXchg
typeOfHandshake ServerKeyXchg{} = HandshakeType_ServerKeyXchg
typeOfHandshake CertRequest{} = HandshakeType_CertRequest
typeOfHandshake CertVerify{} = HandshakeType_CertVerify
typeOfHandshake Finished{} = HandshakeType_Finished

numericalVer :: Version -> (Word8, Word8)
numericalVer SSL2 = (2, 0)
6 changes: 3 additions & 3 deletions core/Network/TLS/Wire.hs
Original file line number Diff line number Diff line change
@@ -121,12 +121,12 @@ getInteger16 = os2ip <$> getOpaque16
getBigNum16 :: Get BigNum
getBigNum16 = BigNum <$> getOpaque16

getList :: Int -> (Get (Int, a)) -> Get [a]
getList :: Int -> Get (Int, a) -> Get [a]
getList totalLen getElement = isolate totalLen (getElements totalLen)
where getElements len
| len < 0 = error "list consumed too much data. should never happen with isolate."
| len == 0 = return []
| otherwise = getElement >>= \(elementLen, a) -> liftM ((:) a) (getElements (len - elementLen))
| otherwise = getElement >>= \(elementLen, a) -> (:) a <$> getElements (len - elementLen)

processBytes :: Int -> Get a -> Get a
processBytes i f = isolate i f
@@ -144,7 +144,7 @@ putWord32 = putWord32be

putWords16 :: [Word16] -> Put
putWords16 l = do
putWord16 $ 2 * (fromIntegral $ length l)
putWord16 $ 2 * fromIntegral (length l)
mapM_ putWord16 l

putWord24 :: Int -> Put
6 changes: 3 additions & 3 deletions core/Network/TLS/X509.hs
Original file line number Diff line number Diff line change
@@ -53,7 +53,7 @@ data CertificateUsage =
wrapCertificateChecks :: [FailedReason] -> CertificateUsage
wrapCertificateChecks [] = CertificateUsageAccept
wrapCertificateChecks l
| Expired `elem` l = CertificateUsageReject $ CertificateRejectExpired
| InFuture `elem` l = CertificateUsageReject $ CertificateRejectExpired
| UnknownCA `elem` l = CertificateUsageReject $ CertificateRejectUnknownCA
| Expired `elem` l = CertificateUsageReject CertificateRejectExpired
| InFuture `elem` l = CertificateUsageReject CertificateRejectExpired
| UnknownCA `elem` l = CertificateUsageReject CertificateRejectUnknownCA
| otherwise = CertificateUsageReject $ CertificateRejectOther (show l)