Skip to content

Commit

Permalink
Merge pull request #836 from gbaz/gb-streaming
Browse files Browse the repository at this point in the history
Stream endpoint support for servant
  • Loading branch information
phadej committed Dec 3, 2017
2 parents 1398642 + db13077 commit cbd3862
Show file tree
Hide file tree
Showing 14 changed files with 410 additions and 8 deletions.
1 change: 1 addition & 0 deletions servant-client-core/src/Servant/Client/Core.hs
Expand Up @@ -43,6 +43,7 @@ module Servant.Client.Core
, Response(..)
, RunClient(..)
, module Servant.Client.Core.Internal.BaseUrl
, StreamingResponse(..)

-- * Writing HasClient instances
-- | These functions need not be re-exported by backend libraries.
Expand Down
55 changes: 55 additions & 0 deletions servant-client-core/src/Servant/Client/Core/Internal/HasClient.hs
Expand Up @@ -17,8 +17,11 @@ module Servant.Client.Core.Internal.HasClient where
import Prelude ()
import Prelude.Compat

import Control.Concurrent (newMVar, modifyMVar)
import Data.Foldable (toList)
import qualified Data.ByteString.Lazy as BL
import Data.List (foldl')
import Data.Monoid ((<>))
import Data.Proxy (Proxy (Proxy))
import Data.Sequence (fromList)
import Data.String (fromString)
Expand All @@ -29,8 +32,11 @@ import Servant.API ((:<|>) ((:<|>)), (:>),
AuthProtect, BasicAuth,
BasicAuthData,
BuildHeadersTo (..),
BuildFromStream (..),
ByteStringParser (..),
Capture, CaptureAll,
Description, EmptyAPI,
FramingUnrender (..),
Header, Headers (..),
HttpVersion, IsSecure,
MimeRender (mimeRender),
Expand All @@ -40,6 +46,8 @@ import Servant.API ((:<|>) ((:<|>)), (:>),
QueryParams, Raw,
ReflectMethod (..),
RemoteHost, ReqBody,
ResultStream(..),
Stream,
Summary, ToHttpApiData,
Vault, Verb,
WithNamedContext,
Expand Down Expand Up @@ -244,6 +252,53 @@ instance OVERLAPPING_
, getHeadersHList = buildHeadersTo . toList $ responseHeaders response
}

instance OVERLAPPABLE_
( RunClient m, MimeUnrender ct a, ReflectMethod method,
FramingUnrender framing a, BuildFromStream a (f a)
) => HasClient m (Stream method framing ct (f a)) where

type Client m (Stream method framing ct (f a)) = m (f a)

clientWithRoute _pm Proxy req = do
sresp <- streamingRequest req
{ requestAccept = fromList [contentType (Proxy :: Proxy ct)]
, requestMethod = reflectMethod (Proxy :: Proxy method)
}
return . buildFromStream $ ResultStream $ \k ->
runStreamingResponse sresp $ \(_status,_headers,_httpversion,reader) -> do
let unrender = unrenderFrames (Proxy :: Proxy framing) (Proxy :: Proxy a)
loop bs = do
res <- BL.fromStrict <$> reader
if BL.null res
then return $ parseEOF unrender res
else let sofar = (bs <> res)
in case parseIncremental unrender sofar of
Just x -> return x
Nothing -> loop sofar
(frameParser, remainder) <- loop BL.empty
state <- newMVar remainder
let frameLoop bs = do
res <- BL.fromStrict <$> reader
let addIsEmptyInfo (a, r) = (r, (a, BL.null r && BL.null res))
if BL.null res
then if BL.null bs
then return ("", (Right "", True))
else return . addIsEmptyInfo $ parseEOF frameParser bs
else let sofar = (bs <> res)
in case parseIncremental frameParser sofar of
Just x -> return $ addIsEmptyInfo x
Nothing -> frameLoop sofar

go = processResult <$> modifyMVar state frameLoop
processResult (Right bs,isDone) =
if BL.null bs && isDone
then Nothing
else Just $ case mimeUnrender (Proxy :: Proxy ct) bs :: Either String a of
Left err -> Left err
Right x -> Right x
processResult (Left err, _) = Just (Left err)
k go


-- | If you use a 'Header' in one of your endpoints in your API,
-- the corresponding querying function will automatically take
Expand Down
Expand Up @@ -4,6 +4,7 @@
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
Expand All @@ -15,6 +16,7 @@ import Prelude.Compat

import Control.Monad.Catch (Exception)
import qualified Data.ByteString.Builder as Builder
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as LBS
import Data.Semigroup ((<>))
import qualified Data.Sequence as Seq
Expand Down Expand Up @@ -70,6 +72,8 @@ data Response = Response
, responseHttpVersion :: HttpVersion
} deriving (Eq, Show, Generic, Typeable)

data StreamingResponse = StreamingResponse { runStreamingResponse :: forall a. ((Status, Seq.Seq Header, HttpVersion, IO BS.ByteString) -> IO a) -> IO a }

-- A GET request to the top-level path
defaultRequest :: Request
defaultRequest = Request
Expand Down
Expand Up @@ -19,11 +19,13 @@ import Servant.API (MimeUnrender,
contentTypes,
mimeUnrender)
import Servant.Client.Core.Internal.Request (Request, Response (..),
StreamingResponse (..),
ServantError (..))

class (Monad m) => RunClient m where
-- | How to make a request.
runRequest :: Request -> m Response
streamingRequest :: Request -> m StreamingResponse
throwServantError :: ServantError -> m a
catchServantError :: m a -> (ServantError -> m a) -> m a

Expand Down
1 change: 1 addition & 0 deletions servant-client/servant-client.cabal
Expand Up @@ -72,6 +72,7 @@ test-suite spec
hspec-discover:hspec-discover
other-modules:
Servant.ClientSpec
Servant.StreamSpec
build-depends:
base == 4.*
, aeson
Expand Down
22 changes: 19 additions & 3 deletions servant-client/src/Servant/Client/Internal/HttpClient.hs
Expand Up @@ -89,6 +89,7 @@ instance Alt ClientM where

instance RunClient ClientM where
runRequest = performRequest
streamingRequest = performStreamingRequest
throwServantError = throwError
catchServantError = catchError

Expand All @@ -111,13 +112,28 @@ performRequest req = do
Right response -> do
let status = Client.responseStatus response
status_code = statusCode status
ourResponse = clientResponseToReponse response
ourResponse = clientResponseToResponse response
unless (status_code >= 200 && status_code < 300) $
throwError $ FailureResponse ourResponse
return ourResponse

clientResponseToReponse :: Client.Response BSL.ByteString -> Response
clientResponseToReponse r = Response
performStreamingRequest :: Request -> ClientM StreamingResponse
performStreamingRequest req = do
m <- asks manager
burl <- asks baseUrl
let request = requestToClientRequest burl req
return $ StreamingResponse $
\k -> Client.withResponse request m $
\r -> do
let status = Client.responseStatus r
status_code = statusCode status
unless (status_code >= 200 && status_code < 300) $ do
b <- BSL.fromChunks <$> Client.brConsume (Client.responseBody r)
throw $ FailureResponse $ Response status b (fromList $ Client.responseHeaders r) (Client.responseVersion r)
k (status, fromList $ Client.responseHeaders r, Client.responseVersion r, Client.responseBody r)

clientResponseToResponse :: Client.Response BSL.ByteString -> Response
clientResponseToResponse r = Response
{ responseStatusCode = Client.responseStatus r
, responseBody = Client.responseBody r
, responseHeaders = fromList $ Client.responseHeaders r
Expand Down
2 changes: 1 addition & 1 deletion servant-client/test/Servant/ClientSpec.hs
Expand Up @@ -24,7 +24,7 @@
{-# OPTIONS_GHC -fno-warn-name-shadowing #-}

#include "overlapping-compat.h"
module Servant.ClientSpec (spec) where
module Servant.ClientSpec (spec, Person(..), startWaiApp, endWaiApp) where

import Prelude ()
import Prelude.Compat
Expand Down
113 changes: 113 additions & 0 deletions servant-client/test/Servant/StreamSpec.hs
@@ -0,0 +1,113 @@
{-# LANGUAGE CPP #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
#if __GLASGOW_HASKELL__ >= 800
{-# OPTIONS_GHC -freduction-depth=100 #-}
#else
{-# OPTIONS_GHC -fcontext-stack=100 #-}
#endif
{-# OPTIONS_GHC -fno-warn-orphans #-}
{-# OPTIONS_GHC -fno-warn-name-shadowing #-}

#include "overlapping-compat.h"
module Servant.StreamSpec (spec) where

import Prelude ()
import Prelude.Compat
import Data.Proxy
import qualified Network.HTTP.Client as C
import System.IO.Unsafe (unsafePerformIO)
import Test.Hspec

import Servant.API ((:<|>) ((:<|>)),
(:>),
EmptyAPI, JSON,
StreamGet,
NewlineFraming,
NetstringFraming,
ResultStream(..),
StreamGenerator(..))
import Servant.Client
import Servant.Server
import qualified Servant.ClientSpec as CS
import Servant.ClientSpec (Person(..))


spec :: Spec
spec = describe "Servant.Stream" $ do
streamSpec

type StreamApi f =
"streamGetNewline" :> StreamGet NewlineFraming JSON (f Person)
:<|> "streamGetNetstring" :> StreamGet NetstringFraming JSON (f Person)
:<|> EmptyAPI


capi :: Proxy (StreamApi ResultStream)
capi = Proxy

sapi :: Proxy (StreamApi StreamGenerator)
sapi = Proxy


getGetNL :<|> getGetNS :<|> EmptyClient = client capi


getGetNL :: ClientM (ResultStream Person)
getGetNS :: ClientM (ResultStream Person)

alice :: Person
alice = Person "Alice" 42

bob :: Person
bob = Person "Bob" 25

server :: Application
server = serve sapi (
(return (StreamGenerator (\f r -> f alice >> r bob >> r alice))
:: Handler (StreamGenerator Person))
:<|>
(return (StreamGenerator (\f r -> f alice >> r bob >> r alice))
:: Handler (StreamGenerator Person))
:<|>
emptyServer)


{-# NOINLINE manager' #-}
manager' :: C.Manager
manager' = unsafePerformIO $ C.newManager C.defaultManagerSettings

runClient :: ClientM a -> BaseUrl -> IO (Either ServantError a)
runClient x baseUrl' = runClientM x (ClientEnv manager' baseUrl')

runResultStream :: ResultStream a -> IO (Maybe (Either String a), Maybe (Either String a), Maybe (Either String a), Maybe (Either String a))
runResultStream (ResultStream k) = k $ \act -> (,,,) <$> act <*> act <*> act <*> act

streamSpec :: Spec
streamSpec = beforeAll (CS.startWaiApp server) $ afterAll CS.endWaiApp $ do

it "Servant.API.StreamGet.Newline" $ \(_, baseUrl) -> do
Right res <- runClient getGetNL baseUrl
let jra = Just (Right alice)
jrb = Just (Right bob)
runResultStream res `shouldReturn` (jra, jrb, jra, Nothing)

it "Servant.API.StreamGet.Netstring" $ \(_, baseUrl) -> do
Right res <- runClient getGetNS baseUrl
let jra = Just (Right alice)
jrb = Just (Right bob)
runResultStream res `shouldReturn` (jra, jrb, jra, Nothing)
1 change: 1 addition & 0 deletions servant-server/servant-server.cabal
Expand Up @@ -64,6 +64,7 @@ library
, containers >= 0.5 && < 0.6
, exceptions >= 0.8 && < 0.9
, http-api-data >= 0.3 && < 0.4
, http-media >= 0.4 && < 0.8
, http-types >= 0.8 && < 0.11
, network-uri >= 2.6 && < 2.7
, monad-control >= 1.0.0.4 && < 1.1
Expand Down

0 comments on commit cbd3862

Please sign in to comment.