Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stream endpoint support for servant #836

Merged
merged 11 commits into from
Dec 3, 2017
1 change: 1 addition & 0 deletions servant-client-core/src/Servant/Client/Core.hs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ module Servant.Client.Core
, Response(..)
, RunClient(..)
, module Servant.Client.Core.Internal.BaseUrl
, StreamingResponse(..)

-- * Writing HasClient instances
-- | These functions need not be re-exported by backend libraries.
Expand Down
57 changes: 57 additions & 0 deletions servant-client-core/src/Servant/Client/Core/Internal/HasClient.hs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
Expand All @@ -17,8 +18,12 @@ module Servant.Client.Core.Internal.HasClient where
import Prelude ()
import Prelude.Compat

import Control.Concurrent (newMVar, modifyMVar)
import Control.Monad (when)
import Data.Foldable (toList)
import qualified Data.ByteString.Lazy as BL
import Data.List (foldl')
import Data.Monoid ((<>))
import Data.Proxy (Proxy (Proxy))
import Data.Sequence (fromList)
import Data.String (fromString)
Expand All @@ -29,8 +34,11 @@ import Servant.API ((:<|>) ((:<|>)), (:>),
AuthProtect, BasicAuth,
BasicAuthData,
BuildHeadersTo (..),
BuildFromStream (..),
ByteStringParser (..),
Capture, CaptureAll,
Description, EmptyAPI,
FramingUnrender (..),
Header, Headers (..),
HttpVersion, IsSecure,
MimeRender (mimeRender),
Expand All @@ -40,6 +48,7 @@ import Servant.API ((:<|>) ((:<|>)), (:>),
QueryParams, Raw,
ReflectMethod (..),
RemoteHost, ReqBody,
Stream,
Summary, ToHttpApiData,
Vault, Verb,
WithNamedContext,
Expand Down Expand Up @@ -244,6 +253,54 @@ instance OVERLAPPING_
, getHeadersHList = buildHeadersTo . toList $ responseHeaders response
}

data ResultStream a = ResultStream ((forall b. (IO (Maybe (Either String a)) -> IO b) -> IO b))

instance OVERLAPPABLE_
( RunClient m, MimeUnrender ct a, ReflectMethod method,
FramingUnrender framing a, BuildFromStream a (f a)
) => HasClient m (Stream method framing ct (f a)) where

type Client m (Stream method framing ct (f a)) = m (ResultStream a)

clientWithRoute _pm Proxy req = do
sresp <- streamingRequest req
{ requestAccept = fromList [contentType (Proxy :: Proxy ct)]
, requestMethod = reflectMethod (Proxy :: Proxy method)
}
return $ ResultStream $ \k ->
runStreamingResponse sresp $ \(status,_headers,_httpversion,reader) -> do
when (H.statusCode status /= 200) $ error "bad status" --fixme
let unrender = unrenderFrames (Proxy :: Proxy framing) (Proxy :: Proxy a)
loop bs = do
res <- BL.fromStrict <$> reader
if BL.null res
then return $ parseEOF unrender res
else let sofar = (bs <> res)
in case parseIncremental unrender sofar of
Just x -> return x
Nothing -> loop sofar
(frameParser, remainder) <- loop BL.empty
state <- newMVar remainder
let frameLoop bs = do
res <- BL.fromStrict <$> reader
let addIsEmptyInfo (a, r) = (r, (a, BL.null r && BL.null res))
if BL.null res
then return . addIsEmptyInfo $ parseEOF frameParser res
else let sofar = (bs <> res)
in case parseIncremental frameParser res of
Just x -> return $ addIsEmptyInfo x
Nothing -> frameLoop sofar

go = processResult <$> modifyMVar state frameLoop
processResult (Right bs,isDone) =
if BL.null bs && isDone
then Nothing
else Just $ case mimeUnrender (Proxy :: Proxy ct) bs :: Either String a of
Left err -> Left err
Right x -> Right x
processResult (Left err, _) = Just (Left err)
k go


-- | If you use a 'Header' in one of your endpoints in your API,
-- the corresponding querying function will automatically take
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
Expand All @@ -15,6 +16,7 @@ import Prelude.Compat

import Control.Monad.Catch (Exception)
import qualified Data.ByteString.Builder as Builder
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as LBS
import Data.Semigroup ((<>))
import qualified Data.Sequence as Seq
Expand Down Expand Up @@ -70,6 +72,8 @@ data Response = Response
, responseHttpVersion :: HttpVersion
} deriving (Eq, Show, Generic, Typeable)

data StreamingResponse = StreamingResponse { runStreamingResponse :: forall a. ((Status, Seq.Seq Header, HttpVersion, IO BS.ByteString) -> IO a) -> IO a }

-- A GET request to the top-level path
defaultRequest :: Request
defaultRequest = Request
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@ import Servant.API (MimeUnrender,
contentTypes,
mimeUnrender)
import Servant.Client.Core.Internal.Request (Request, Response (..),
StreamingResponse (..),
ServantError (..))

class (Monad m) => RunClient m where
-- | How to make a request.
runRequest :: Request -> m Response
streamingRequest :: Request -> m StreamingResponse
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The name is a little bit confusing, as it's the response that's streamed to the client, not the request streamed to the server. Don't you think?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what name do you suggest?

Copy link
Contributor Author

@gbaz gbaz Nov 4, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

streamedRequest? that has the same "which direction" problem in reverse imho.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm. runRequestStreaming, streamingResponse, runStreaming? They all kind of suffer from the "which direction" problem though. So forget what I said. Unless someone comes up with a better suggestion, I'm fine with that name for now =)

throwServantError :: ServantError -> m a
catchServantError :: m a -> (ServantError -> m a) -> m a

Expand Down
12 changes: 12 additions & 0 deletions servant-client/src/Servant/Client/Internal/HttpClient.hs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ instance Alt ClientM where

instance RunClient ClientM where
runRequest = performRequest
streamingRequest = performStreamingRequest
throwServantError = throwError
catchServantError = catchError

Expand Down Expand Up @@ -115,6 +116,17 @@ performRequest req = do
throwError $ FailureResponse ourResponse
return ourResponse

performStreamingRequest :: Request -> ClientM StreamingResponse
performStreamingRequest req = do
m <- asks manager
burl <- asks baseUrl
let request = requestToClientRequest burl req

return $ StreamingResponse $
\k -> Client.withResponse request m $
\r ->
k (Client.responseStatus r, fromList $ Client.responseHeaders r, Client.responseVersion r, Client.responseBody r)

clientResponseToReponse :: Client.Response BSL.ByteString -> Response
clientResponseToReponse r = Response
{ responseStatusCode = Client.responseStatus r
Expand Down
1 change: 1 addition & 0 deletions servant-server/servant-server.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ library
, containers >= 0.5 && < 0.6
, exceptions >= 0.8 && < 0.9
, http-api-data >= 0.3 && < 0.4
, http-media >= 0.4 && < 0.8
, http-types >= 0.8 && < 0.10
, network-uri >= 2.6 && < 2.7
, monad-control >= 1.0.0.4 && < 1.1
Expand Down
83 changes: 79 additions & 4 deletions servant-server/src/Servant/Server/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

Expand All @@ -25,12 +26,15 @@ module Servant.Server.Internal
, module Servant.Server.Internal.ServantErr
) where

import Control.Monad (when)
import Control.Monad.Trans (liftIO)
import Control.Monad.Trans.Resource (runResourceT)
import qualified Data.ByteString as B
import qualified Data.ByteString.Builder as BB
import qualified Data.ByteString.Char8 as BC8
import qualified Data.ByteString.Lazy as BL
import Data.Maybe (fromMaybe, mapMaybe)
import Data.Maybe (fromMaybe, mapMaybe,
isNothing, maybeToList)
import Data.Either (partitionEithers)
import Data.String (fromString)
import Data.String.Conversions (cs, (<>))
Expand All @@ -40,13 +44,15 @@ import Data.Typeable
import GHC.TypeLits (KnownNat, KnownSymbol, natVal,
symbolVal)
import Network.HTTP.Types hiding (Header, ResponseHeaders)
import qualified Network.HTTP.Media as NHM
import Network.Socket (SockAddr)
import Network.Wai (Application, Request, Response,
httpVersion, isSecure,
lazyRequestBody,
rawQueryString, remoteHost,
requestHeaders, requestMethod,
responseLBS, vault)
responseLBS, responseStream,
vault)
import Prelude ()
import Prelude.Compat
import Web.HttpApiData (FromHttpApiData, parseHeader,
Expand All @@ -60,11 +66,16 @@ import Servant.API ((:<|>) (..), (:>), BasicAuth, Capt
QueryParam, QueryParams, Raw,
RemoteHost, ReqBody, Vault,
WithNamedContext,
Description, Summary)
Description, Summary,
Accept(..),
FramingRender(..), Stream,
StreamGenerator(..), ToStreamGenerator(..),
BoundaryStrategy(..))
import Servant.API.ContentTypes (AcceptHeader (..),
AllCTRender (..),
AllCTUnrender (..),
AllMime,
MimeRender(..),
canHandleAcceptH)
import Servant.API.ResponseHeaders (GetHeaders, Headers, getHeaders,
getResponse)
Expand Down Expand Up @@ -276,6 +287,70 @@ instance OVERLAPPING_
where method = reflectMethod (Proxy :: Proxy method)
status = toEnum . fromInteger $ natVal (Proxy :: Proxy status)


instance OVERLAPPABLE_
( MimeRender ctype a, ReflectMethod method,
FramingRender framing ctype, ToStreamGenerator f a
) => HasServer (Stream method framing ctype (f a)) context where

type ServerT (Stream method framing ctype (f a)) m = m (f a)
hoistServerWithContext _ _ nt s = nt s

route Proxy _ = streamRouter ([],) method (Proxy :: Proxy framing) (Proxy :: Proxy ctype)
where method = reflectMethod (Proxy :: Proxy method)

instance OVERLAPPING_
( MimeRender ctype a, ReflectMethod method,
FramingRender framing ctype, ToStreamGenerator f a,
GetHeaders (Headers h (f a))
) => HasServer (Stream method framing ctype (Headers h (f a))) context where

type ServerT (Stream method framing ctype (Headers h (f a))) m = m (Headers h (f a))
hoistServerWithContext _ _ nt s = nt s

route Proxy _ = streamRouter (\x -> (getHeaders x, getResponse x)) method (Proxy :: Proxy framing) (Proxy :: Proxy ctype)
where method = reflectMethod (Proxy :: Proxy method)


streamRouter :: (MimeRender ctype a, FramingRender framing ctype, ToStreamGenerator f a) =>
(b -> ([(HeaderName, B.ByteString)], f a))
-> Method
-> Proxy framing
-> Proxy ctype
-> Delayed env (Handler b)
-> Router env
streamRouter splitHeaders method framingproxy ctypeproxy action = leafRouter $ \env request respond ->
let accH = fromMaybe ct_wildcard $ lookup hAccept $ requestHeaders request
cmediatype = NHM.matchAccept [contentType ctypeproxy] accH
accCheck = when (isNothing cmediatype) $ delayedFail err406
contentHeader = (hContentType, NHM.renderHeader . maybeToList $ cmediatype)
in runAction (action `addMethodCheck` methodCheck method request
`addAcceptCheck` accCheck
) env request respond $ \ output ->
let (headers, fa) = splitHeaders output
k = getStreamGenerator . toStreamGenerator $ fa in
Route $ responseStream status200 (contentHeader : headers) $ \write flush -> do
write . BB.lazyByteString $ header framingproxy ctypeproxy
case boundary framingproxy ctypeproxy of
BoundaryStrategyBracket f ->
let go x = let bs = mimeRender ctypeproxy $ x
(before, after) = f bs
in write ( BB.lazyByteString before
<> BB.lazyByteString bs
<> BB.lazyByteString after) >> flush
in k go go
BoundaryStrategyIntersperse sep -> k
(\x -> do
write . BB.lazyByteString . mimeRender ctypeproxy $ x
flush)
(\x -> do
write . (BB.lazyByteString sep <>) . BB.lazyByteString . mimeRender ctypeproxy $ x
flush)
BoundaryStrategyGeneral f ->
let go = (>> flush) . write . BB.lazyByteString . f . mimeRender ctypeproxy
in k go go
write . BB.lazyByteString $ trailer framingproxy ctypeproxy

-- | If you use 'Header' in one of the endpoints for your API,
-- this automatically requires your server-side handler to be a function
-- that takes an argument of the type specified by 'Header'.
Expand Down Expand Up @@ -318,7 +393,7 @@ instance (KnownSymbol sym, FromHttpApiData a, HasServer api context)
<> fromString headerName
<> " failed: " <> e
}
Right header -> return $ Just header
Right hdr -> return $ Just hdr

-- | If you use @'QueryParam' "author" Text@ in one of the endpoints for your API,
-- this automatically requires your server-side handler to be a function
Expand Down
1 change: 1 addition & 0 deletions servant/servant.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ library
Servant.API.IsSecure
Servant.API.QueryParam
Servant.API.Raw
Servant.API.Stream
Servant.API.RemoteHost
Servant.API.ReqBody
Servant.API.ResponseHeaders
Expand Down
9 changes: 9 additions & 0 deletions servant/src/Servant/API.hs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ module Servant.API (
-- * Actual endpoints, distinguished by HTTP method
module Servant.API.Verbs,

-- * Streaming endpoints, distinguished by HTTP method
module Servant.API.Stream,

-- * Authentication
module Servant.API.BasicAuth,

Expand Down Expand Up @@ -80,6 +83,12 @@ import Servant.API.IsSecure (IsSecure (..))
import Servant.API.QueryParam (QueryFlag, QueryParam,
QueryParams)
import Servant.API.Raw (Raw)
import Servant.API.Stream (Stream, StreamGenerator (..),
ToStreamGenerator (..), BuildFromStream (..),
ByteStringParser (..),
FramingRender (..), BoundaryStrategy (..),
FramingUnrender (..),
NewlineFraming)
import Servant.API.RemoteHost (RemoteHost)
import Servant.API.ReqBody (ReqBody)
import Servant.API.ResponseHeaders (AddHeader, addHeader, noHeader,
Expand Down
Loading