Skip to content

Commit

Permalink
Merge pull request #457 from kosmikus/fix-router-sharing
Browse files Browse the repository at this point in the history
Fix router sharing
  • Loading branch information
kosmikus committed Apr 12, 2016
2 parents d4c6f67 + b1a6d88 commit 15143cc
Show file tree
Hide file tree
Showing 10 changed files with 293 additions and 270 deletions.
9 changes: 5 additions & 4 deletions servant-server/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
0.7
---

* The `Router` type has been changed. There are now more situations where
servers will make use of static lookups to efficiently route the request
to the correct endpoint. Functions `layout` and `layoutWithContext` have
been added to visualize the router layout for debugging purposes. Test
* The `Router` type has been changed. Static router tables should now
be properly shared between requests, drastically increasing the
number of situations where servers will be able to route requests
efficiently. Functions `layout` and `layoutWithContext` have been
added to visualize the router layout for debugging purposes. Test
cases for expected router layouts have been added.
* Export `throwError` from module `Servant`
* Add `Handler` type synonym
Expand Down
22 changes: 6 additions & 16 deletions servant-server/src/Servant/Server.hs
Original file line number Diff line number Diff line change
Expand Up @@ -132,21 +132,13 @@ serve p = serveWithContext p EmptyContext

serveWithContext :: (HasServer layout context)
=> Proxy layout -> Context context -> Server layout -> Application
serveWithContext p context server = toApplication (runRouter (route p context d))
where
d = Delayed r r r r (\ _ _ _ -> Route server)
r = return (Route ())
serveWithContext p context server =
toApplication (runRouter (route p context (emptyDelayed (Route server))))

-- | The function 'layout' produces a textual description of the internal
-- router layout for debugging purposes. Note that the router layout is
-- determined just by the API, not by the handlers.
--
-- This function makes certain assumptions about the well-behavedness of
-- the 'HasServer' instances of the combinators which should be ok for the
-- core servant constructions, but might not be satisfied for some other
-- combinators provided elsewhere. It is possible that the function may
-- crash for these.
--
-- Example:
--
-- For the following API
Expand All @@ -168,7 +160,7 @@ serveWithContext p context server = toApplication (runRouter (route p context d)
-- > │ └─ e/
-- > │ └─•
-- > ├─ b/
-- > │ └─ <dyn>/
-- > │ └─ <capture>/
-- > │ ├─•
-- > │ ┆
-- > │ └─•
Expand All @@ -185,7 +177,7 @@ serveWithContext p context server = toApplication (runRouter (route p context d)
--
-- [@─•@] Leaves reflect endpoints.
--
-- [@\<dyn\>/@] This is a delayed capture of a path component.
-- [@\<capture\>/@] This is a delayed capture of a path component.
--
-- [@\<raw\>@] This is a part of the API we do not know anything about.
--
Expand All @@ -200,10 +192,8 @@ layout p = layoutWithContext p EmptyContext
-- | Variant of 'layout' that takes an additional 'Context'.
layoutWithContext :: (HasServer layout context)
=> Proxy layout -> Context context -> Text
layoutWithContext p context = routerLayout (route p context d)
where
d = Delayed r r r r (\ _ _ _ -> FailFatal err501)
r = return (Route ())
layoutWithContext p context =
routerLayout (route p context (emptyDelayed (FailFatal err501)))

-- Documentation

Expand Down
18 changes: 11 additions & 7 deletions servant-server/src/Servant/Server/Experimental/Auth.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

module Servant.Server.Experimental.Auth where

import Control.Monad.Trans (liftIO)
import Control.Monad.Trans.Except (runExceptT)
import Data.Proxy (Proxy (Proxy))
import Data.Typeable (Typeable)
Expand All @@ -24,10 +25,11 @@ import Servant.Server.Internal (HasContextEntry,
HasServer, ServerT,
getContextEntry,
route)
import Servant.Server.Internal.Router (Router' (WithRequest))
import Servant.Server.Internal.RoutingApplication (RouteResult (FailFatal, Route),
addAuthCheck)
import Servant.Server.Internal.ServantErr (ServantErr, Handler)
import Servant.Server.Internal.RoutingApplication (addAuthCheck,
delayedFailFatal,
DelayedIO,
withRequest)
import Servant.Server.Internal.ServantErr (Handler)

-- * General Auth

Expand Down Expand Up @@ -57,8 +59,10 @@ instance ( HasServer api context
type ServerT (AuthProtect tag :> api) m =
AuthServerData (AuthProtect tag) -> ServerT api m

route Proxy context subserver = WithRequest $ \ request ->
route (Proxy :: Proxy api) context (subserver `addAuthCheck` authCheck request)
route Proxy context subserver =
route (Proxy :: Proxy api) context (subserver `addAuthCheck` withRequest authCheck)
where
authHandler :: Request -> Handler (AuthServerData (AuthProtect tag))
authHandler = unAuthHandler (getContextEntry context)
authCheck = fmap (either FailFatal Route) . runExceptT . authHandler
authCheck :: Request -> DelayedIO (AuthServerData (AuthProtect tag))
authCheck = (>>= either delayedFailFatal return) . liftIO . runExceptT . authHandler
111 changes: 58 additions & 53 deletions servant-server/src/Servant/Server/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ module Servant.Server.Internal
, module Servant.Server.Internal.ServantErr
) where

import Control.Monad.Trans (liftIO)
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as BC8
import qualified Data.ByteString.Lazy as BL
Expand Down Expand Up @@ -70,7 +71,11 @@ import Servant.Server.Internal.ServantErr
class HasServer layout context where
type ServerT layout (m :: * -> *) :: *

route :: Proxy layout -> Context context -> Delayed (Server layout) -> Router
route ::
Proxy layout
-> Context context
-> Delayed env (Server layout)
-> Router env

type Server layout = ServerT layout Handler

Expand All @@ -92,7 +97,7 @@ instance (HasServer a context, HasServer b context) => HasServer (a :<|> b) cont
type ServerT (a :<|> b) m = ServerT a m :<|> ServerT b m

route Proxy context server = choice (route pa context ((\ (a :<|> _) -> a) <$> server))
(route pb context ((\ (_ :<|> b) -> b) <$> server))
(route pb context ((\ (_ :<|> b) -> b) <$> server))
where pa = Proxy :: Proxy a
pb = Proxy :: Proxy b

Expand Down Expand Up @@ -120,12 +125,12 @@ instance (KnownSymbol capture, FromHttpApiData a, HasServer sublayout context)
a -> ServerT sublayout m

route Proxy context d =
DynamicRouter $ \ first ->
CaptureRouter $
route (Proxy :: Proxy sublayout)
context
(addCapture d $ case parseUrlPieceMaybe first :: Maybe a of
Nothing -> return $ Fail err400
Just v -> return $ Route v
(addCapture d $ \ txt -> case parseUrlPieceMaybe txt :: Maybe a of
Nothing -> delayedFail err400
Just v -> return v
)

allowedMethodHead :: Method -> Request -> Bool
Expand All @@ -144,41 +149,41 @@ processMethodRouter handleA status method headers request = case handleA of
bdy = if allowedMethodHead method request then "" else body
hdrs = (hContentType, cs contentT) : (fromMaybe [] headers)

methodCheck :: Method -> Request -> IO (RouteResult ())
methodCheck :: Method -> Request -> DelayedIO ()
methodCheck method request
| allowedMethod method request = return $ Route ()
| otherwise = return $ Fail err405
| allowedMethod method request = return ()
| otherwise = delayedFail err405

acceptCheck :: (AllMime list) => Proxy list -> B.ByteString -> IO (RouteResult ())
acceptCheck :: (AllMime list) => Proxy list -> B.ByteString -> DelayedIO ()
acceptCheck proxy accH
| canHandleAcceptH proxy (AcceptHeader accH) = return $ Route ()
| otherwise = return $ FailFatal err406
| canHandleAcceptH proxy (AcceptHeader accH) = return ()
| otherwise = delayedFailFatal err406

methodRouter :: (AllCTRender ctypes a)
=> Method -> Proxy ctypes -> Status
-> Delayed (Handler a)
-> Router
-> Delayed env (Handler a)
-> Router env
methodRouter method proxy status action = leafRouter route'
where
route' request respond =
route' env request respond =
let accH = fromMaybe ct_wildcard $ lookup hAccept $ requestHeaders request
in runAction (action `addMethodCheck` methodCheck method request
`addAcceptCheck` acceptCheck proxy accH
) respond $ \ output -> do
) env request respond $ \ output -> do
let handleA = handleAcceptH proxy (AcceptHeader accH) output
processMethodRouter handleA status method Nothing request

methodRouterHeaders :: (GetHeaders (Headers h v), AllCTRender ctypes v)
=> Method -> Proxy ctypes -> Status
-> Delayed (Handler (Headers h v))
-> Router
-> Delayed env (Handler (Headers h v))
-> Router env
methodRouterHeaders method proxy status action = leafRouter route'
where
route' request respond =
route' env request respond =
let accH = fromMaybe ct_wildcard $ lookup hAccept $ requestHeaders request
in runAction (action `addMethodCheck` methodCheck method request
`addAcceptCheck` acceptCheck proxy accH
) respond $ \ output -> do
) env request respond $ \ output -> do
let headers = getHeaders output
handleA = handleAcceptH proxy (AcceptHeader accH) (getResponse output)
processMethodRouter handleA status method (Just headers) request
Expand Down Expand Up @@ -230,8 +235,8 @@ instance (KnownSymbol sym, FromHttpApiData a, HasServer sublayout context)
type ServerT (Header sym a :> sublayout) m =
Maybe a -> ServerT sublayout m

route Proxy context subserver = WithRequest $ \ request ->
let mheader = parseHeaderMaybe =<< lookup str (requestHeaders request)
route Proxy context subserver =
let mheader req = parseHeaderMaybe =<< lookup str (requestHeaders req)
in route (Proxy :: Proxy sublayout) context (passToServer subserver mheader)
where str = fromString $ symbolVal (Proxy :: Proxy sym)

Expand Down Expand Up @@ -262,10 +267,10 @@ instance (KnownSymbol sym, FromHttpApiData a, HasServer sublayout context)
type ServerT (QueryParam sym a :> sublayout) m =
Maybe a -> ServerT sublayout m

route Proxy context subserver = WithRequest $ \ request ->
let querytext = parseQueryText $ rawQueryString request
param =
case lookup paramname querytext of
route Proxy context subserver =
let querytext r = parseQueryText $ rawQueryString r
param r =
case lookup paramname (querytext r) of
Nothing -> Nothing -- param absent from the query string
Just Nothing -> Nothing -- param present with no value -> Nothing
Just (Just v) -> parseQueryParamMaybe v -- if present, we try to convert to
Expand Down Expand Up @@ -298,13 +303,13 @@ instance (KnownSymbol sym, FromHttpApiData a, HasServer sublayout context)
type ServerT (QueryParams sym a :> sublayout) m =
[a] -> ServerT sublayout m

route Proxy context subserver = WithRequest $ \ request ->
let querytext = parseQueryText $ rawQueryString request
route Proxy context subserver =
let querytext r = parseQueryText $ rawQueryString r
-- if sym is "foo", we look for query string parameters
-- named "foo" or "foo[]" and call parseQueryParam on the
-- corresponding values
parameters = filter looksLikeParam querytext
values = mapMaybe (convert . snd) parameters
parameters r = filter looksLikeParam (querytext r)
values r = mapMaybe (convert . snd) (parameters r)
in route (Proxy :: Proxy sublayout) context (passToServer subserver values)
where paramname = cs $ symbolVal (Proxy :: Proxy sym)
looksLikeParam (name, _) = name == paramname || name == (paramname <> "[]")
Expand All @@ -329,9 +334,9 @@ instance (KnownSymbol sym, HasServer sublayout context)
type ServerT (QueryFlag sym :> sublayout) m =
Bool -> ServerT sublayout m

route Proxy context subserver = WithRequest $ \ request ->
let querytext = parseQueryText $ rawQueryString request
param = case lookup paramname querytext of
route Proxy context subserver =
let querytext r = parseQueryText $ rawQueryString r
param r = case lookup paramname (querytext r) of
Just Nothing -> True -- param is there, with no value
Just (Just v) -> examine v -- param with a value
Nothing -> False -- param not in the query string
Expand All @@ -352,8 +357,8 @@ instance HasServer Raw context where

type ServerT Raw m = Application

route Proxy _ rawApplication = RawRouter $ \ request respond -> do
r <- runDelayed rawApplication
route Proxy _ rawApplication = RawRouter $ \ env request respond -> do
r <- runDelayed rawApplication env request
case r of
Route app -> app request (respond . Route)
Fail a -> respond $ Fail a
Expand Down Expand Up @@ -386,22 +391,22 @@ instance ( AllCTUnrender list a, HasServer sublayout context
type ServerT (ReqBody list a :> sublayout) m =
a -> ServerT sublayout m

route Proxy context subserver = WithRequest $ \ request ->
route (Proxy :: Proxy sublayout) context (addBodyCheck subserver (bodyCheck request))
route Proxy context subserver =
route (Proxy :: Proxy sublayout) context (addBodyCheck subserver bodyCheck)
where
bodyCheck request = do
bodyCheck = withRequest $ \ request -> do
-- See HTTP RFC 2616, section 7.2.1
-- http://www.w3.org/Protocols/rfc2616/rfc2616-sec7.html#sec7.2.1
-- See also "W3C Internet Media Type registration, consistency of use"
-- http://www.w3.org/2001/tag/2002/0129-mime
let contentTypeH = fromMaybe "application/octet-stream"
$ lookup hContentType $ requestHeaders request
mrqbody <- handleCTypeH (Proxy :: Proxy list) (cs contentTypeH)
<$> lazyRequestBody request
<$> liftIO (lazyRequestBody request)
case mrqbody of
Nothing -> return $ FailFatal err415
Just (Left e) -> return $ FailFatal err400 { errBody = cs e }
Just (Right v) -> return $ Route v
Nothing -> delayedFailFatal err415
Just (Left e) -> delayedFailFatal err400 { errBody = cs e }
Just (Right v) -> return v

-- | Make sure the incoming request starts with @"/path"@, strip it and
-- pass the rest of the request path to @sublayout@.
Expand All @@ -418,28 +423,28 @@ instance (KnownSymbol path, HasServer sublayout context) => HasServer (path :> s
instance HasServer api context => HasServer (RemoteHost :> api) context where
type ServerT (RemoteHost :> api) m = SockAddr -> ServerT api m

route Proxy context subserver = WithRequest $ \req ->
route (Proxy :: Proxy api) context (passToServer subserver $ remoteHost req)
route Proxy context subserver =
route (Proxy :: Proxy api) context (passToServer subserver remoteHost)

instance HasServer api context => HasServer (IsSecure :> api) context where
type ServerT (IsSecure :> api) m = IsSecure -> ServerT api m

route Proxy context subserver = WithRequest $ \req ->
route (Proxy :: Proxy api) context (passToServer subserver $ secure req)
route Proxy context subserver =
route (Proxy :: Proxy api) context (passToServer subserver secure)

where secure req = if isSecure req then Secure else NotSecure

instance HasServer api context => HasServer (Vault :> api) context where
type ServerT (Vault :> api) m = Vault -> ServerT api m

route Proxy context subserver = WithRequest $ \req ->
route (Proxy :: Proxy api) context (passToServer subserver $ vault req)
route Proxy context subserver =
route (Proxy :: Proxy api) context (passToServer subserver vault)

instance HasServer api context => HasServer (HttpVersion :> api) context where
type ServerT (HttpVersion :> api) m = HttpVersion -> ServerT api m

route Proxy context subserver = WithRequest $ \req ->
route (Proxy :: Proxy api) context (passToServer subserver $ httpVersion req)
route Proxy context subserver =
route (Proxy :: Proxy api) context (passToServer subserver httpVersion)

-- | Basic Authentication
instance ( KnownSymbol realm
Expand All @@ -450,12 +455,12 @@ instance ( KnownSymbol realm

type ServerT (BasicAuth realm usr :> api) m = usr -> ServerT api m

route Proxy context subserver = WithRequest $ \ request ->
route (Proxy :: Proxy api) context (subserver `addAuthCheck` authCheck request)
route Proxy context subserver =
route (Proxy :: Proxy api) context (subserver `addAuthCheck` authCheck)
where
realm = BC8.pack $ symbolVal (Proxy :: Proxy realm)
basicAuthContext = getContextEntry context
authCheck req = runBasicAuth req realm basicAuthContext
authCheck = withRequest $ \ req -> runBasicAuth req realm basicAuthContext

-- * helpers

Expand Down
11 changes: 6 additions & 5 deletions servant-server/src/Servant/Server/Internal/BasicAuth.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
module Servant.Server.Internal.BasicAuth where

import Control.Monad (guard)
import Control.Monad.Trans (liftIO)
import qualified Data.ByteString as BS
import Data.ByteString.Base64 (decodeLenient)
import Data.Monoid ((<>))
Expand Down Expand Up @@ -57,13 +58,13 @@ decodeBAHdr req = do

-- | Run and check basic authentication, returning the appropriate http error per
-- the spec.
runBasicAuth :: Request -> BS.ByteString -> BasicAuthCheck usr -> IO (RouteResult usr)
runBasicAuth :: Request -> BS.ByteString -> BasicAuthCheck usr -> DelayedIO usr
runBasicAuth req realm (BasicAuthCheck ba) =
case decodeBAHdr req of
Nothing -> plzAuthenticate
Just e -> ba e >>= \res -> case res of
Just e -> liftIO (ba e) >>= \res -> case res of
BadPassword -> plzAuthenticate
NoSuchUser -> plzAuthenticate
Unauthorized -> return $ FailFatal err403
Authorized usr -> return $ Route usr
where plzAuthenticate = return $ FailFatal err401 { errHeaders = [mkBAChallengerHdr realm] }
Unauthorized -> delayedFailFatal err403
Authorized usr -> return usr
where plzAuthenticate = delayedFailFatal err401 { errHeaders = [mkBAChallengerHdr realm] }

0 comments on commit 15143cc

Please sign in to comment.