Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Post-handshake authentication #366

Merged
merged 14 commits into from
Jun 25, 2019
3 changes: 2 additions & 1 deletion core/Network/TLS.hs
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,10 @@ module Network.TLS
-- ** Negotiated
, getNegotiatedProtocol
, getClientSNI
-- ** Updating keys
-- ** Post-handshake actions
, updateKey
, KeyUpdateRequest(..)
, requestCertificate

-- * Raw types
, ProtocolType(..)
Expand Down
8 changes: 8 additions & 0 deletions core/Network/TLS/Context.hs
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,15 @@ module Network.TLS.Context
import Network.TLS.Backend
import Network.TLS.Context.Internal
import Network.TLS.Struct
import Network.TLS.Struct13
import Network.TLS.State
import Network.TLS.Hooks
import Network.TLS.Record.State
import Network.TLS.Parameters
import Network.TLS.Measurement
import Network.TLS.Types (Role(..))
import Network.TLS.Handshake (handshakeClient, handshakeClientWith, handshakeServer, handshakeServerWith)
import Network.TLS.PostHandshake (postHandshakeAuthClientWith, postHandshakeAuthServerWith)
import Network.TLS.X509
import Network.TLS.RNG

Expand All @@ -92,6 +94,7 @@ class TLSParams a where
getTLSRole :: a -> Role
doHandshake :: a -> Context -> IO ()
doHandshakeWith :: a -> Context -> Handshake -> IO ()
doPostHandshakeAuthWith :: a -> Context -> Handshake13 -> IO ()

instance TLSParams ClientParams where
getTLSCommonParams cparams = ( clientSupported cparams
Expand All @@ -101,6 +104,7 @@ instance TLSParams ClientParams where
getTLSRole _ = ClientRole
doHandshake = handshakeClient
doHandshakeWith = handshakeClientWith
doPostHandshakeAuthWith = postHandshakeAuthClientWith

instance TLSParams ServerParams where
getTLSCommonParams sparams = ( serverSupported sparams
Expand All @@ -110,6 +114,7 @@ instance TLSParams ServerParams where
getTLSRole _ = ServerRole
doHandshake = handshakeServer
doHandshakeWith = handshakeServerWith
doPostHandshakeAuthWith = postHandshakeAuthServerWith

-- | create a new context using the backend and parameters specified.
contextNew :: (MonadIO m, HasBackend backend, TLSParams params)
Expand Down Expand Up @@ -144,6 +149,7 @@ contextNew backend params = liftIO $ do
rx <- newMVar newRecordState
hs <- newMVar Nothing
as <- newIORef []
crs <- newIORef []
lockWrite <- newMVar ()
lockRead <- newMVar ()
lockState <- newMVar ()
Expand All @@ -158,6 +164,7 @@ contextNew backend params = liftIO $ do
, ctxHandshake = hs
, ctxDoHandshake = doHandshake params
, ctxDoHandshakeWith = doHandshakeWith params
, ctxDoPostHandshakeAuthWith = doPostHandshakeAuthWith params
, ctxMeasurement = stats
, ctxEOF_ = eof
, ctxEstablished_ = established
Expand All @@ -168,6 +175,7 @@ contextNew backend params = liftIO $ do
, ctxLockRead = lockRead
, ctxLockState = lockState
, ctxPendingActions = as
, ctxCertRequests = crs
, ctxKeyLogger = debugKeyLogger debug
}

Expand Down
28 changes: 28 additions & 0 deletions core/Network/TLS/Context/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,12 @@ module Network.TLS.Context.Internal
, runRxState
, usingHState
, getHState
, saveHState
, restoreHState
, getStateRNG
, tls13orLater
, addCertRequest13
, getCertRequest13
) where

import Network.TLS.Backend
Expand All @@ -71,6 +75,8 @@ import Network.TLS.Record.State
import Network.TLS.Parameters
import Network.TLS.Measurement
import Network.TLS.Imports
import Network.TLS.Types
import Network.TLS.Util
import qualified Data.ByteString as B

import Control.Concurrent.MVar
Expand Down Expand Up @@ -111,12 +117,14 @@ data Context = Context
, ctxHandshake :: MVar (Maybe HandshakeState) -- ^ optional handshake state
, ctxDoHandshake :: Context -> IO ()
, ctxDoHandshakeWith :: Context -> Handshake -> IO ()
, ctxDoPostHandshakeAuthWith :: Context -> Handshake13 -> IO ()
, ctxHooks :: IORef Hooks -- ^ hooks for this context
, ctxLockWrite :: MVar () -- ^ lock to use for writing data (including updating the state)
, ctxLockRead :: MVar () -- ^ lock to use for reading data (including updating the state)
, ctxLockState :: MVar () -- ^ lock used during read/write when receiving and sending packet.
-- it is usually nested in a write or read lock.
, ctxPendingActions :: IORef [PendingAction]
, ctxCertRequests :: IORef [Handshake13] -- ^ pending PHA requests
, ctxKeyLogger :: String -> IO ()
}

Expand Down Expand Up @@ -225,6 +233,14 @@ usingHState ctx f = liftIO $ modifyMVar (ctxHandshake ctx) $ \mst ->
getHState :: MonadIO m => Context -> m (Maybe HandshakeState)
getHState ctx = liftIO $ readMVar (ctxHandshake ctx)

saveHState :: Context -> IO (Saved (Maybe HandshakeState))
saveHState ctx = saveMVar (ctxHandshake ctx)

restoreHState :: Context
-> Saved (Maybe HandshakeState)
-> IO (Saved (Maybe HandshakeState))
restoreHState ctx = restoreMVar (ctxHandshake ctx)

runTxState :: Context -> RecordM a -> IO (Either TLSError a)
runTxState ctx f = do
ver <- usingState_ ctx (getVersionWithDefault $ maximum $ supportedVersions $ ctxSupported ctx)
Expand Down Expand Up @@ -270,3 +286,15 @@ tls13orLater ctx = do
return $ case ev of
Left _ -> False
Right v -> v >= TLS13

addCertRequest13 :: Context -> Handshake13 -> IO ()
addCertRequest13 ctx certReq = modifyIORef (ctxCertRequests ctx) (certReq:)

getCertRequest13 :: Context -> CertReqContext -> IO (Maybe Handshake13)
getCertRequest13 ctx context = do
let ref = ctxCertRequests ctx
l <- readIORef ref
let (matched, others) = partition (\(CertRequest13 c _) -> context == c) l
case matched of
[] -> return Nothing
(certReq:_) -> writeIORef ref others >> return (Just certReq)
6 changes: 6 additions & 0 deletions core/Network/TLS/Core.hs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ module Network.TLS.Core
, recvData'
, updateKey
, KeyUpdateRequest(..)
, requestCertificate
) where

import Network.TLS.Cipher
Expand All @@ -46,6 +47,7 @@ import Network.TLS.Handshake.Common13
import Network.TLS.Handshake.Process
import Network.TLS.Handshake.State
import Network.TLS.Handshake.State13
import Network.TLS.PostHandshake
import Network.TLS.KeySchedule
import Network.TLS.Types (Role(..), HostName)
import Network.TLS.Util (catchException, mapChunks_)
Expand Down Expand Up @@ -221,6 +223,10 @@ recvData13 ctx = do
else do
let reason = "received key update before established"
terminate (Error_Misc reason) AlertLevel_Fatal UnexpectedMessage reason
loopHandshake13 (h@CertRequest13{}:hs) =
postHandshakeAuthWith ctx h >> loopHandshake13 hs
loopHandshake13 (h@Certificate13{}:hs) =
postHandshakeAuthWith ctx h >> loopHandshake13 hs
loopHandshake13 (h:hs) = do
mPendingAction <- popPendingAction ctx
case mPendingAction of
Expand Down
11 changes: 11 additions & 0 deletions core/Network/TLS/Extension.hs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ module Network.TLS.Extension
, KeyShare(..)
, KeyShareEntry(..)
, MessageType(..)
, PostHandshakeAuth(..)
, PskKexMode(..)
, PskKeyExchangeModes(..)
, PskIdentity(..)
Expand Down Expand Up @@ -466,6 +467,16 @@ decodeSignatureAlgorithms = runGetMaybe $ do

------------------------------------------------------------

data PostHandshakeAuth = PostHandshakeAuth deriving (Show,Eq)

instance Extension PostHandshakeAuth where
extensionID _ = extensionID_PostHandshakeAuth
extensionEncode _ = B.empty
extensionDecode MsgTClientHello = runGetMaybe (pure PostHandshakeAuth)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe that return is safer than pure from the backward compatibility point of view.

extensionDecode _ = error "extensionDecode: PostHandshakeAuth"

------------------------------------------------------------

newtype SignatureAlgorithmsCert = SignatureAlgorithmsCert [HashAndSignatureAlgorithm] deriving (Show,Eq)

instance Extension SignatureAlgorithmsCert where
Expand Down
20 changes: 0 additions & 20 deletions core/Network/TLS/Handshake.hs
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,12 @@ module Network.TLS.Handshake

import Network.TLS.Context.Internal
import Network.TLS.Struct
import Network.TLS.Struct13
import Network.TLS.IO
import Network.TLS.Util (catchException)
import Network.TLS.Imports

import Network.TLS.Handshake.Common
import Network.TLS.Handshake.Client
import Network.TLS.Handshake.Server

import Control.Monad.State.Strict
import Control.Exception (IOException, handle, fromException)

-- | Handshake for a new TLS connection
-- This is to be called at the beginning of a connection, and during renegotiation
Expand All @@ -41,18 +36,3 @@ handshake ctx =
handshakeWith :: MonadIO m => Context -> Handshake -> m ()
handshakeWith ctx hs =
liftIO $ withWriteLock ctx $ handleException ctx $ ctxDoHandshakeWith ctx ctx hs

handleException :: Context -> IO () -> IO ()
handleException ctx f = catchException f $ \exception -> do
let tlserror = fromMaybe (Error_Misc $ show exception) $ fromException exception
setEstablished ctx NotEstablished
handle ignoreIOErr $ do
tls13 <- tls13orLater ctx
if tls13 then
sendPacket13 ctx $ Alert13 $ errorToAlert tlserror
else
sendPacket ctx $ Alert $ errorToAlert tlserror
handshakeFailed tlserror
where
ignoreIOErr :: IOException -> IO ()
ignoreIOErr _ = return ()
Loading