Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

Expose low-level asynchronous versions of accept, connect, read, writ…

…e and shutdown.

Ignore-this: 7e7fa4c6730a45bbafec997a2032d1ec

darcs-hash:20110730123554-3a530-fd665ebb2e6af900a4c3c6a2efe2c0f03a9e8f50.gz
  • Loading branch information...
commit a2a77f87691a3ea15572eb033550882be46d6179 1 parent 98e9b49
@mvv mvv authored
Showing with 105 additions and 61 deletions.
  1. +105 −61 OpenSSL/Session.hsc
View
166 OpenSSL/Session.hsc
@@ -1,4 +1,8 @@
{-# OPTIONS_GHC -fno-warn-name-shadowing #-}
+{-# LANGUAGE DeriveFunctor #-}
+{-# LANGUAGE DeriveFoldable #-}
+{-# LANGUAGE DeriveTraversable #-}
+{-# LANGUAGE NamedFieldPuns #-}
-- | Functions for handling SSL connections. These functions use GHC specific
-- calls to cooperative the with the scheduler so that 'blocking' functions
-- only actually block the Haskell thread, not a whole OS thread.
@@ -21,15 +25,21 @@ module OpenSSL.Session
-- * SSL connections
, SSL
+ , SSLResult(..)
, connection
, fdConnection
, accept
+ , tryAccept
, connect
+ , tryConnect
, read
+ , tryRead
, write
+ , tryWrite
, lazyRead
, lazyWrite
, shutdown
+ , tryShutdown
, ShutdownType(..)
, getPeerCertificate
, getVerifyResult
@@ -50,13 +60,16 @@ module OpenSSL.Session
#include "openssl/ssl.h"
-import Prelude hiding (catch, read, ioError)
+import Prelude hiding (catch, read, ioError, mapM)
import Control.Concurrent (threadWaitWrite, threadWaitRead)
import Control.Concurrent.QSem
import Control.Exception
-import Control.Monad
+import Control.Applicative ((<$>), (<$))
+import Control.Monad (void, unless)
import Data.Typeable
-import Foreign
+import Data.Foldable (Foldable)
+import Data.Traversable (Traversable, forM)
+import Foreign hiding (void)
import Foreign.C
import qualified Data.ByteString as B
import qualified Data.ByteString.Internal as B
@@ -253,7 +266,11 @@ data SSL_
-- times. Thus multiple OS threads can be 'blocked' inside IO in the same SSL
-- object at a time, because they aren't really in the SSL object, they are
-- waiting for the RTS to wake the Haskell thread.
-newtype SSL = SSL (QSem, ForeignPtr SSL_, Fd, Maybe Socket)
+data SSL = SSL { sslSem :: QSem
+ , sslPtr :: ForeignPtr SSL_
+ , sslFd :: Fd
+ , sslSocket :: Maybe Socket
+ }
foreign import ccall unsafe "SSL_new" _ssl_new :: Ptr SSLContext_ -> IO (Ptr SSL_)
foreign import ccall unsafe "&SSL_free" _ssl_free :: FunPtr (Ptr SSL_ -> IO ())
@@ -267,7 +284,11 @@ connection' context fd@(Fd fdInt) sock = do
_ssl_set_fd ssl fdInt
return ssl
fpssl <- newForeignPtr _ssl_free ssl
- return $ SSL (sem, fpssl, fd, sock)
+ return $ SSL { sslSem = sem
+ , sslPtr = fpssl
+ , sslFd = fd
+ , sslSocket = sock
+ }
-- | Wrap a Socket in an SSL connection. Reading and writing to the Socket
-- after this will cause weird errors in the SSL code. The SSL object
@@ -282,9 +303,9 @@ fdConnection :: SSLContext -> Fd -> IO SSL
fdConnection context fd = connection' context fd Nothing
withSSL :: SSL -> (Ptr SSL_ -> IO a) -> IO a
-withSSL (SSL (sem, ssl, _, _)) action = do
- waitQSem sem
- finally (withForeignPtr ssl action) $ signalQSem sem
+withSSL (SSL {sslSem, sslPtr}) action = do
+ waitQSem sslSem
+ finally (withForeignPtr sslPtr action) $ signalQSem sslSem
foreign import ccall "SSL_accept" _ssl_accept :: Ptr SSL_ -> IO CInt
foreign import ccall "SSL_connect" _ssl_connect :: Ptr SSL_ -> IO CInt
@@ -302,40 +323,54 @@ throwSSLException x = throw (UnknownError (fromIntegral x))
-- | This is the type of an SSL IO operation. EOF and termination are handled
-- by exceptions while everything else is one of these. Note that reading
-- from an SSL socket can result in WantWrite and vice versa.
-data SSLIOResult = Done CInt -- ^ successfully mananged *n* bytes
+data SSLResult a = SSLDone a -- ^ operation finished successfully
| WantRead -- ^ needs more data from the network
| WantWrite -- ^ needs more outgoing buffer space
- deriving (Eq)
+ deriving (Eq, Show, Functor, Foldable, Traversable)
+-- | Block until the operation is finished.
+sslBlock :: (SSL -> IO (SSLResult a)) -> SSL -> IO a
+sslBlock action ssl = do
+ result <- action ssl
+ case result of
+ SSLDone r -> return r
+ WantRead -> threadWaitRead (sslFd ssl) >> sslBlock action ssl
+ WantWrite -> threadWaitWrite (sslFd ssl) >> sslBlock action ssl
-- | Perform an SSL operation which can return non-blocking error codes, thus
-- requesting that the operation be performed when data or buffer space is
-- availible.
-sslDoHandshake :: (Ptr SSL_ -> IO CInt) -> SSL -> IO CInt
-sslDoHandshake action ssl@(SSL (_, _, fd, _)) = do
- let f ssl = do
- n <- action ssl
- case n of
- n | n >= 0 -> return $ Done n
- _ -> do
- err <- _ssl_get_error ssl n
- case err of
- (#const SSL_ERROR_WANT_READ) -> return WantRead
- (#const SSL_ERROR_WANT_WRITE) -> return WantWrite
- _ -> throwSSLException err
- result <- withSSL ssl f
- case result of
- Done n -> return n
- WantRead -> threadWaitRead fd >> sslDoHandshake action ssl
- WantWrite -> threadWaitWrite fd >> sslDoHandshake action ssl
+sslTryHandshake :: (Ptr SSL_ -> IO CInt) -> SSL -> IO (SSLResult CInt)
+sslTryHandshake action = flip withSSL $ \pSsl -> do
+ n <- action pSsl
+ if n >= 0
+ then return $ SSLDone n
+ else do
+ err <- _ssl_get_error pSsl n
+ case err of
+ (#const SSL_ERROR_WANT_READ) -> return WantRead
+ (#const SSL_ERROR_WANT_WRITE) -> return WantWrite
+ _ -> throwSSLException err
-- | Perform an SSL server handshake
accept :: SSL -> IO ()
-accept ssl = sslDoHandshake _ssl_accept ssl >>= failIf_ (/= 1)
+accept = sslBlock tryAccept
+
+-- | Try to perform an SSL server handshake without blocking
+tryAccept :: SSL -> IO (SSLResult ())
+tryAccept ssl = do
+ result <- sslTryHandshake _ssl_accept ssl
+ forM result $ failIf_ (/= 1)
-- | Perform an SSL client handshake
connect :: SSL -> IO ()
-connect ssl = sslDoHandshake _ssl_connect ssl >>= failIf_ (/= 1)
+connect = sslBlock tryConnect
+
+-- | Try to perform an SSL client handshake without blocking
+tryConnect :: SSL -> IO (SSLResult ())
+tryConnect ssl = do
+ result <- sslTryHandshake _ssl_connect ssl
+ forM result $ failIf_ (/= 1)
foreign import ccall "SSL_read" _ssl_read :: Ptr SSL_ -> Ptr Word8 -> CInt -> IO CInt
foreign import ccall unsafe "SSL_get_shutdown" _ssl_get_shutdown :: Ptr SSL_ -> IO CInt
@@ -353,11 +388,11 @@ sslIOInner :: (Ptr SSL_ -> Ptr Word8 -> CInt -> IO CInt) -- ^ the SSL IO functi
-> Ptr CChar -- ^ the buffer to pass
-> Int -- ^ the length to pass
-> Ptr SSL_
- -> IO SSLIOResult
+ -> IO (SSLResult CInt)
sslIOInner f ptr nbytes ssl = do
n <- f ssl (castPtr ptr) $ fromIntegral nbytes
case n of
- n | n > 0 -> return $ Done $ fromIntegral n
+ n | n > 0 -> return $ SSLDone $ fromIntegral n
| n == 0 -> do
shutdown <- _ssl_get_shutdown ssl
if shutdown .&. (#const SSL_RECEIVED_SHUTDOWN) == 0
@@ -370,37 +405,42 @@ sslIOInner f ptr nbytes ssl = do
(#const SSL_ERROR_WANT_WRITE) -> return WantWrite
_ -> throwSSLException err
--- | Try the read the given number of bytes from an SSL connection. On EOF an
+catchEOF :: a -> IO a -> IO a
+catchEOF x m = m `catch` \e -> if isEOFError e then return x else throwIO e
+
+-- | Try to read the given number of bytes from an SSL connection. On EOF an
-- empty ByteString is returned. If the connection dies without a graceful
-- SSL shutdown, an exception is raised.
read :: SSL -> Int -> IO B.ByteString
-read ssl@(SSL (_, _, fd, _)) nbytes = B.createAndTrim nbytes $ f ssl
- where
- f ssl ptr
- = do result <- withSSL ssl $ sslIOInner _ssl_read (castPtr ptr) nbytes
- case result of
- Done n -> return $ fromIntegral n
- WantRead -> threadWaitRead fd >> f ssl ptr
- WantWrite -> threadWaitWrite fd >> f ssl ptr
- `catch`
- \ ioe ->
- if isEOFError ioe then
- return 0
- else
- ioError ioe -- rethrow
+read ssl nbytes = B.createAndTrim nbytes $ \ptr -> do
+ let doRead = withSSL ssl $ sslIOInner _ssl_read (castPtr ptr) nbytes
+ catchEOF 0 $ fromIntegral <$> sslBlock (const doRead) ssl
+
+-- | Try to read the given number of bytes from an SSL connection
+-- without blocking.
+tryRead :: SSL -> Int -> IO (SSLResult B.ByteString)
+tryRead ssl nbytes = do
+ (bs, result) <- B.createAndTrim' nbytes $ \ptr -> do
+ result <- catchEOF (SSLDone 0) $ withSSL ssl $
+ sslIOInner _ssl_read (castPtr ptr) nbytes
+ return $ case result of
+ SSLDone n -> (0, fromIntegral n, SSLDone ())
+ WantRead -> (0, 0, WantRead)
+ WantWrite -> (0, 0, WantWrite)
+ return $ bs <$ result
foreign import ccall "SSL_write" _ssl_write :: Ptr SSL_ -> Ptr Word8 -> CInt -> IO CInt
-- | Write a given ByteString to the SSL connection. Either all the data is
-- written or an exception is raised because of an error
write :: SSL -> B.ByteString -> IO ()
-write ssl@(SSL (_, _, fd, _)) bs = B.unsafeUseAsCStringLen bs $ f ssl where
- f ssl (ptr, len) = do
- result <- withSSL ssl $ sslIOInner _ssl_write ptr len
- case result of
- Done _ -> return ()
- WantRead -> threadWaitRead fd >> f ssl (ptr, len)
- WantWrite -> threadWaitWrite fd >> f ssl (ptr, len)
+write ssl bs = void $ sslBlock (`tryWrite` bs) ssl
+
+-- | Try to write a given ByteString to the SSL connection without blocking.
+tryWrite :: SSL -> B.ByteString -> IO (SSLResult ())
+tryWrite ssl bs =
+ B.unsafeUseAsCStringLen bs $ \(ptr, len) ->
+ ((() <$) <$>) $ withSSL ssl $ sslIOInner _ssl_write ptr len
-- | Lazily read all data until reaching EOF. If the connection dies
-- without a graceful SSL shutdown, an exception is raised.
@@ -439,12 +479,18 @@ data ShutdownType = Bidirectional -- ^ wait for the peer to also shutdown
-- This can either just send a shutdown, or can send and wait for the peer's
-- shutdown message.
shutdown :: SSL -> ShutdownType -> IO ()
-shutdown ssl ty = do
- n <- sslDoHandshake _ssl_shutdown ssl
- case ty of
- Unidirectional -> return ()
- Bidirectional -> unless (n == 1)
- $ shutdown ssl ty
+shutdown ssl ty = sslBlock (`tryShutdown` ty) ssl
+
+-- | Try to cleanly shutdown an SSL connection without blocking.
+tryShutdown :: SSL -> ShutdownType -> IO (SSLResult ())
+tryShutdown ssl ty = do
+ result <- sslTryHandshake _ssl_shutdown ssl
+ case result of
+ SSLDone n -> case ty of
+ Bidirectional | n /= 1 -> tryShutdown ssl ty
+ _ -> return $ SSLDone ()
+ WantRead -> return WantRead
+ WantWrite -> return WantWrite
foreign import ccall "SSL_get_peer_certificate" _ssl_get_peer_cert :: Ptr SSL_ -> IO (Ptr X509_)
@@ -477,11 +523,9 @@ getVerifyResult ssl =
-- | Get the socket underlying an SSL connection
sslSocket :: SSL -> Maybe Socket
-sslSocket (SSL (_, _, _, socket)) = socket
-- | Get the underlying socket Fd
sslFd :: SSL -> Fd
-sslFd (SSL (_, _, fd, _)) = fd
-- | The root exception type for all SSL exceptions.
data SomeSSLException
Please sign in to comment.
Something went wrong with that request. Please try again.