diff --git a/servant-client-core/src/Servant/Client/Core.hs b/servant-client-core/src/Servant/Client/Core.hs index a926c169e..73160abfa 100644 --- a/servant-client-core/src/Servant/Client/Core.hs +++ b/servant-client-core/src/Servant/Client/Core.hs @@ -43,6 +43,7 @@ module Servant.Client.Core , Response(..) , RunClient(..) , module Servant.Client.Core.Internal.BaseUrl + , StreamingResponse(..) -- * Writing HasClient instances -- | These functions need not be re-exported by backend libraries. diff --git a/servant-client-core/src/Servant/Client/Core/Internal/HasClient.hs b/servant-client-core/src/Servant/Client/Core/Internal/HasClient.hs index 42d61d589..ef5bdce4b 100644 --- a/servant-client-core/src/Servant/Client/Core/Internal/HasClient.hs +++ b/servant-client-core/src/Servant/Client/Core/Internal/HasClient.hs @@ -17,8 +17,11 @@ module Servant.Client.Core.Internal.HasClient where import Prelude () import Prelude.Compat +import Control.Concurrent (newMVar, modifyMVar) import Data.Foldable (toList) +import qualified Data.ByteString.Lazy as BL import Data.List (foldl') +import Data.Monoid ((<>)) import Data.Proxy (Proxy (Proxy)) import Data.Sequence (fromList) import Data.String (fromString) @@ -29,8 +32,11 @@ import Servant.API ((:<|>) ((:<|>)), (:>), AuthProtect, BasicAuth, BasicAuthData, BuildHeadersTo (..), + BuildFromStream (..), + ByteStringParser (..), Capture, CaptureAll, Description, EmptyAPI, + FramingUnrender (..), Header, Headers (..), HttpVersion, IsSecure, MimeRender (mimeRender), @@ -40,6 +46,8 @@ import Servant.API ((:<|>) ((:<|>)), (:>), QueryParams, Raw, ReflectMethod (..), RemoteHost, ReqBody, + ResultStream(..), + Stream, Summary, ToHttpApiData, Vault, Verb, WithNamedContext, @@ -244,6 +252,53 @@ instance OVERLAPPING_ , getHeadersHList = buildHeadersTo . toList $ responseHeaders response } +instance OVERLAPPABLE_ + ( RunClient m, MimeUnrender ct a, ReflectMethod method, + FramingUnrender framing a, BuildFromStream a (f a) + ) => HasClient m (Stream method framing ct (f a)) where + + type Client m (Stream method framing ct (f a)) = m (f a) + + clientWithRoute _pm Proxy req = do + sresp <- streamingRequest req + { requestAccept = fromList [contentType (Proxy :: Proxy ct)] + , requestMethod = reflectMethod (Proxy :: Proxy method) + } + return . buildFromStream $ ResultStream $ \k -> + runStreamingResponse sresp $ \(_status,_headers,_httpversion,reader) -> do + let unrender = unrenderFrames (Proxy :: Proxy framing) (Proxy :: Proxy a) + loop bs = do + res <- BL.fromStrict <$> reader + if BL.null res + then return $ parseEOF unrender res + else let sofar = (bs <> res) + in case parseIncremental unrender sofar of + Just x -> return x + Nothing -> loop sofar + (frameParser, remainder) <- loop BL.empty + state <- newMVar remainder + let frameLoop bs = do + res <- BL.fromStrict <$> reader + let addIsEmptyInfo (a, r) = (r, (a, BL.null r && BL.null res)) + if BL.null res + then if BL.null bs + then return ("", (Right "", True)) + else return . addIsEmptyInfo $ parseEOF frameParser bs + else let sofar = (bs <> res) + in case parseIncremental frameParser sofar of + Just x -> return $ addIsEmptyInfo x + Nothing -> frameLoop sofar + + go = processResult <$> modifyMVar state frameLoop + processResult (Right bs,isDone) = + if BL.null bs && isDone + then Nothing + else Just $ case mimeUnrender (Proxy :: Proxy ct) bs :: Either String a of + Left err -> Left err + Right x -> Right x + processResult (Left err, _) = Just (Left err) + k go + -- | If you use a 'Header' in one of your endpoints in your API, -- the corresponding querying function will automatically take diff --git a/servant-client-core/src/Servant/Client/Core/Internal/Request.hs b/servant-client-core/src/Servant/Client/Core/Internal/Request.hs index 458219b93..b120c7f77 100644 --- a/servant-client-core/src/Servant/Client/Core/Internal/Request.hs +++ b/servant-client-core/src/Servant/Client/Core/Internal/Request.hs @@ -4,6 +4,7 @@ {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE RankNTypes #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeFamilies #-} @@ -15,6 +16,7 @@ import Prelude.Compat import Control.Monad.Catch (Exception) import qualified Data.ByteString.Builder as Builder +import qualified Data.ByteString as BS import qualified Data.ByteString.Lazy as LBS import Data.Semigroup ((<>)) import qualified Data.Sequence as Seq @@ -70,6 +72,8 @@ data Response = Response , responseHttpVersion :: HttpVersion } deriving (Eq, Show, Generic, Typeable) +data StreamingResponse = StreamingResponse { runStreamingResponse :: forall a. ((Status, Seq.Seq Header, HttpVersion, IO BS.ByteString) -> IO a) -> IO a } + -- A GET request to the top-level path defaultRequest :: Request defaultRequest = Request diff --git a/servant-client-core/src/Servant/Client/Core/Internal/RunClient.hs b/servant-client-core/src/Servant/Client/Core/Internal/RunClient.hs index 564cbb39d..88b39a041 100644 --- a/servant-client-core/src/Servant/Client/Core/Internal/RunClient.hs +++ b/servant-client-core/src/Servant/Client/Core/Internal/RunClient.hs @@ -19,11 +19,13 @@ import Servant.API (MimeUnrender, contentTypes, mimeUnrender) import Servant.Client.Core.Internal.Request (Request, Response (..), + StreamingResponse (..), ServantError (..)) class (Monad m) => RunClient m where -- | How to make a request. runRequest :: Request -> m Response + streamingRequest :: Request -> m StreamingResponse throwServantError :: ServantError -> m a catchServantError :: m a -> (ServantError -> m a) -> m a diff --git a/servant-client/servant-client.cabal b/servant-client/servant-client.cabal index fd83a7809..695c28664 100644 --- a/servant-client/servant-client.cabal +++ b/servant-client/servant-client.cabal @@ -72,6 +72,7 @@ test-suite spec hspec-discover:hspec-discover other-modules: Servant.ClientSpec + Servant.StreamSpec build-depends: base == 4.* , aeson diff --git a/servant-client/src/Servant/Client/Internal/HttpClient.hs b/servant-client/src/Servant/Client/Internal/HttpClient.hs index 9a479c42b..b52d83ca7 100644 --- a/servant-client/src/Servant/Client/Internal/HttpClient.hs +++ b/servant-client/src/Servant/Client/Internal/HttpClient.hs @@ -89,6 +89,7 @@ instance Alt ClientM where instance RunClient ClientM where runRequest = performRequest + streamingRequest = performStreamingRequest throwServantError = throwError catchServantError = catchError @@ -111,13 +112,28 @@ performRequest req = do Right response -> do let status = Client.responseStatus response status_code = statusCode status - ourResponse = clientResponseToReponse response + ourResponse = clientResponseToResponse response unless (status_code >= 200 && status_code < 300) $ throwError $ FailureResponse ourResponse return ourResponse -clientResponseToReponse :: Client.Response BSL.ByteString -> Response -clientResponseToReponse r = Response +performStreamingRequest :: Request -> ClientM StreamingResponse +performStreamingRequest req = do + m <- asks manager + burl <- asks baseUrl + let request = requestToClientRequest burl req + return $ StreamingResponse $ + \k -> Client.withResponse request m $ + \r -> do + let status = Client.responseStatus r + status_code = statusCode status + unless (status_code >= 200 && status_code < 300) $ do + b <- BSL.fromChunks <$> Client.brConsume (Client.responseBody r) + throw $ FailureResponse $ Response status b (fromList $ Client.responseHeaders r) (Client.responseVersion r) + k (status, fromList $ Client.responseHeaders r, Client.responseVersion r, Client.responseBody r) + +clientResponseToResponse :: Client.Response BSL.ByteString -> Response +clientResponseToResponse r = Response { responseStatusCode = Client.responseStatus r , responseBody = Client.responseBody r , responseHeaders = fromList $ Client.responseHeaders r diff --git a/servant-client/test/Servant/ClientSpec.hs b/servant-client/test/Servant/ClientSpec.hs index 7cced886c..342593e25 100644 --- a/servant-client/test/Servant/ClientSpec.hs +++ b/servant-client/test/Servant/ClientSpec.hs @@ -24,7 +24,7 @@ {-# OPTIONS_GHC -fno-warn-name-shadowing #-} #include "overlapping-compat.h" -module Servant.ClientSpec (spec) where +module Servant.ClientSpec (spec, Person(..), startWaiApp, endWaiApp) where import Prelude () import Prelude.Compat diff --git a/servant-client/test/Servant/StreamSpec.hs b/servant-client/test/Servant/StreamSpec.hs new file mode 100644 index 000000000..df9003abf --- /dev/null +++ b/servant-client/test/Servant/StreamSpec.hs @@ -0,0 +1,113 @@ +{-# LANGUAGE CPP #-} +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +#if __GLASGOW_HASKELL__ >= 800 +{-# OPTIONS_GHC -freduction-depth=100 #-} +#else +{-# OPTIONS_GHC -fcontext-stack=100 #-} +#endif +{-# OPTIONS_GHC -fno-warn-orphans #-} +{-# OPTIONS_GHC -fno-warn-name-shadowing #-} + +#include "overlapping-compat.h" +module Servant.StreamSpec (spec) where + +import Prelude () +import Prelude.Compat +import Data.Proxy +import qualified Network.HTTP.Client as C +import System.IO.Unsafe (unsafePerformIO) +import Test.Hspec + +import Servant.API ((:<|>) ((:<|>)), + (:>), + EmptyAPI, JSON, + StreamGet, + NewlineFraming, + NetstringFraming, + ResultStream(..), + StreamGenerator(..)) +import Servant.Client +import Servant.Server +import qualified Servant.ClientSpec as CS +import Servant.ClientSpec (Person(..)) + + +spec :: Spec +spec = describe "Servant.Stream" $ do + streamSpec + +type StreamApi f = + "streamGetNewline" :> StreamGet NewlineFraming JSON (f Person) + :<|> "streamGetNetstring" :> StreamGet NetstringFraming JSON (f Person) + :<|> EmptyAPI + + +capi :: Proxy (StreamApi ResultStream) +capi = Proxy + +sapi :: Proxy (StreamApi StreamGenerator) +sapi = Proxy + + +getGetNL :<|> getGetNS :<|> EmptyClient = client capi + + +getGetNL :: ClientM (ResultStream Person) +getGetNS :: ClientM (ResultStream Person) + +alice :: Person +alice = Person "Alice" 42 + +bob :: Person +bob = Person "Bob" 25 + +server :: Application +server = serve sapi ( + (return (StreamGenerator (\f r -> f alice >> r bob >> r alice)) + :: Handler (StreamGenerator Person)) + :<|> + (return (StreamGenerator (\f r -> f alice >> r bob >> r alice)) + :: Handler (StreamGenerator Person)) + :<|> + emptyServer) + + +{-# NOINLINE manager' #-} +manager' :: C.Manager +manager' = unsafePerformIO $ C.newManager C.defaultManagerSettings + +runClient :: ClientM a -> BaseUrl -> IO (Either ServantError a) +runClient x baseUrl' = runClientM x (ClientEnv manager' baseUrl') + +runResultStream :: ResultStream a -> IO (Maybe (Either String a), Maybe (Either String a), Maybe (Either String a), Maybe (Either String a)) +runResultStream (ResultStream k) = k $ \act -> (,,,) <$> act <*> act <*> act <*> act + +streamSpec :: Spec +streamSpec = beforeAll (CS.startWaiApp server) $ afterAll CS.endWaiApp $ do + + it "Servant.API.StreamGet.Newline" $ \(_, baseUrl) -> do + Right res <- runClient getGetNL baseUrl + let jra = Just (Right alice) + jrb = Just (Right bob) + runResultStream res `shouldReturn` (jra, jrb, jra, Nothing) + + it "Servant.API.StreamGet.Netstring" $ \(_, baseUrl) -> do + Right res <- runClient getGetNS baseUrl + let jra = Just (Right alice) + jrb = Just (Right bob) + runResultStream res `shouldReturn` (jra, jrb, jra, Nothing) diff --git a/servant-server/servant-server.cabal b/servant-server/servant-server.cabal index 94d1352fd..27590b85d 100644 --- a/servant-server/servant-server.cabal +++ b/servant-server/servant-server.cabal @@ -64,6 +64,7 @@ library , containers >= 0.5 && < 0.6 , exceptions >= 0.8 && < 0.9 , http-api-data >= 0.3 && < 0.4 + , http-media >= 0.4 && < 0.8 , http-types >= 0.8 && < 0.11 , network-uri >= 2.6 && < 2.7 , monad-control >= 1.0.0.4 && < 1.1 diff --git a/servant-server/src/Servant/Server/Internal.hs b/servant-server/src/Servant/Server/Internal.hs index d63c4c544..cf454fe3b 100644 --- a/servant-server/src/Servant/Server/Internal.hs +++ b/servant-server/src/Servant/Server/Internal.hs @@ -26,12 +26,15 @@ module Servant.Server.Internal , module Servant.Server.Internal.ServantErr ) where +import Control.Monad (when) import Control.Monad.Trans (liftIO) import Control.Monad.Trans.Resource (runResourceT) import qualified Data.ByteString as B +import qualified Data.ByteString.Builder as BB import qualified Data.ByteString.Char8 as BC8 import qualified Data.ByteString.Lazy as BL -import Data.Maybe (fromMaybe, mapMaybe) +import Data.Maybe (fromMaybe, mapMaybe, + isNothing, maybeToList) import Data.Either (partitionEithers) import Data.String (fromString) import Data.String.Conversions (cs, (<>)) @@ -41,13 +44,15 @@ import Data.Typeable import GHC.TypeLits (KnownNat, KnownSymbol, natVal, symbolVal) import Network.HTTP.Types hiding (Header, ResponseHeaders) +import qualified Network.HTTP.Media as NHM import Network.Socket (SockAddr) import Network.Wai (Application, Request, httpVersion, isSecure, lazyRequestBody, rawQueryString, remoteHost, requestHeaders, requestMethod, - responseLBS, vault) + responseLBS, responseStream, + vault) import Prelude () import Prelude.Compat import Web.HttpApiData (FromHttpApiData, parseHeader, @@ -61,11 +66,16 @@ import Servant.API ((:<|>) (..), (:>), BasicAuth, Capt QueryParam, QueryParams, Raw, RemoteHost, ReqBody, Vault, WithNamedContext, - Description, Summary) + Description, Summary, + Accept(..), + FramingRender(..), Stream, + StreamGenerator(..), ToStreamGenerator(..), + BoundaryStrategy(..)) import Servant.API.ContentTypes (AcceptHeader (..), AllCTRender (..), AllCTUnrender (..), AllMime, + MimeRender(..), canHandleAcceptH) import Servant.API.ResponseHeaders (GetHeaders, Headers, getHeaders, getResponse) @@ -257,6 +267,70 @@ instance OVERLAPPING_ where method = reflectMethod (Proxy :: Proxy method) status = toEnum . fromInteger $ natVal (Proxy :: Proxy status) + +instance OVERLAPPABLE_ + ( MimeRender ctype a, ReflectMethod method, + FramingRender framing ctype, ToStreamGenerator f a + ) => HasServer (Stream method framing ctype (f a)) context where + + type ServerT (Stream method framing ctype (f a)) m = m (f a) + hoistServerWithContext _ _ nt s = nt s + + route Proxy _ = streamRouter ([],) method (Proxy :: Proxy framing) (Proxy :: Proxy ctype) + where method = reflectMethod (Proxy :: Proxy method) + +instance OVERLAPPING_ + ( MimeRender ctype a, ReflectMethod method, + FramingRender framing ctype, ToStreamGenerator f a, + GetHeaders (Headers h (f a)) + ) => HasServer (Stream method framing ctype (Headers h (f a))) context where + + type ServerT (Stream method framing ctype (Headers h (f a))) m = m (Headers h (f a)) + hoistServerWithContext _ _ nt s = nt s + + route Proxy _ = streamRouter (\x -> (getHeaders x, getResponse x)) method (Proxy :: Proxy framing) (Proxy :: Proxy ctype) + where method = reflectMethod (Proxy :: Proxy method) + + +streamRouter :: (MimeRender ctype a, FramingRender framing ctype, ToStreamGenerator f a) => + (b -> ([(HeaderName, B.ByteString)], f a)) + -> Method + -> Proxy framing + -> Proxy ctype + -> Delayed env (Handler b) + -> Router env +streamRouter splitHeaders method framingproxy ctypeproxy action = leafRouter $ \env request respond -> + let accH = fromMaybe ct_wildcard $ lookup hAccept $ requestHeaders request + cmediatype = NHM.matchAccept [contentType ctypeproxy] accH + accCheck = when (isNothing cmediatype) $ delayedFail err406 + contentHeader = (hContentType, NHM.renderHeader . maybeToList $ cmediatype) + in runAction (action `addMethodCheck` methodCheck method request + `addAcceptCheck` accCheck + ) env request respond $ \ output -> + let (headers, fa) = splitHeaders output + k = getStreamGenerator . toStreamGenerator $ fa in + Route $ responseStream status200 (contentHeader : headers) $ \write flush -> do + write . BB.lazyByteString $ header framingproxy ctypeproxy + case boundary framingproxy ctypeproxy of + BoundaryStrategyBracket f -> + let go x = let bs = mimeRender ctypeproxy $ x + (before, after) = f bs + in write ( BB.lazyByteString before + <> BB.lazyByteString bs + <> BB.lazyByteString after) >> flush + in k go go + BoundaryStrategyIntersperse sep -> k + (\x -> do + write . BB.lazyByteString . mimeRender ctypeproxy $ x + flush) + (\x -> do + write . (BB.lazyByteString sep <>) . BB.lazyByteString . mimeRender ctypeproxy $ x + flush) + BoundaryStrategyGeneral f -> + let go = (>> flush) . write . BB.lazyByteString . f . mimeRender ctypeproxy + in k go go + write . BB.lazyByteString $ trailer framingproxy ctypeproxy + -- | If you use 'Header' in one of the endpoints for your API, -- this automatically requires your server-side handler to be a function -- that takes an argument of the type specified by 'Header'. @@ -299,7 +373,7 @@ instance (KnownSymbol sym, FromHttpApiData a, HasServer api context) <> fromString headerName <> " failed: " <> e } - Right header -> return $ Just header + Right hdr -> return $ Just hdr -- | If you use @'QueryParam' "author" Text@ in one of the endpoints for your API, -- this automatically requires your server-side handler to be a function diff --git a/servant/servant.cabal b/servant/servant.cabal index df0f1cca3..bc1aa0133 100644 --- a/servant/servant.cabal +++ b/servant/servant.cabal @@ -51,6 +51,7 @@ library Servant.API.IsSecure Servant.API.QueryParam Servant.API.Raw + Servant.API.Stream Servant.API.RemoteHost Servant.API.ReqBody Servant.API.ResponseHeaders diff --git a/servant/src/Servant/API.hs b/servant/src/Servant/API.hs index 88e4d9348..84f3d861d 100644 --- a/servant/src/Servant/API.hs +++ b/servant/src/Servant/API.hs @@ -31,6 +31,9 @@ module Servant.API ( -- * Actual endpoints, distinguished by HTTP method module Servant.API.Verbs, + -- * Streaming endpoints, distinguished by HTTP method + module Servant.API.Stream, + -- * Authentication module Servant.API.BasicAuth, @@ -80,6 +83,15 @@ import Servant.API.IsSecure (IsSecure (..)) import Servant.API.QueryParam (QueryFlag, QueryParam, QueryParams) import Servant.API.Raw (Raw) +import Servant.API.Stream (Stream, StreamGet, StreamPost, + StreamGenerator (..), + ToStreamGenerator (..), + ResultStream(..), BuildFromStream (..), + ByteStringParser (..), + FramingRender (..), BoundaryStrategy (..), + FramingUnrender (..), + NewlineFraming, + NetstringFraming) import Servant.API.RemoteHost (RemoteHost) import Servant.API.ReqBody (ReqBody) import Servant.API.ResponseHeaders (AddHeader, addHeader, noHeader, diff --git a/servant/src/Servant/API/Stream.hs b/servant/src/Servant/API/Stream.hs new file mode 100644 index 000000000..073e0ce1a --- /dev/null +++ b/servant/src/Servant/API/Stream.hs @@ -0,0 +1,117 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveDataTypeable #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE TupleSections #-} +{-# OPTIONS_HADDOCK not-home #-} + +module Servant.API.Stream where + +import Data.ByteString.Lazy (ByteString, empty) +import qualified Data.ByteString.Lazy.Char8 as LB +import Data.Monoid ((<>)) +import Data.Proxy (Proxy) +import Data.Typeable (Typeable) +import GHC.Generics (Generic) +import Text.Read (readMaybe) +import Control.Arrow (first) +import Network.HTTP.Types.Method (StdMethod (..)) + +-- | A Stream endpoint for a given method emits a stream of encoded values at a given Content-Type, delimited by a framing strategy. Steam endpoints always return response code 200 on success. Type synonyms are provided for standard methods. +data Stream (method :: k1) (framing :: *) (contentType :: *) a + deriving (Typeable, Generic) + +type StreamGet = Stream 'GET +type StreamPost = Stream 'POST + +-- | Stream endpoints may be implemented as producing a @StreamGenerator@ -- a function that itself takes two emit functions -- the first to be used on the first value the stream emits, and the second to be used on all subsequent values (to allow interspersed framing strategies such as comma separation). +newtype StreamGenerator a = StreamGenerator {getStreamGenerator :: (a -> IO ()) -> (a -> IO ()) -> IO ()} + +-- | ToStreamGenerator is intended to be implemented for types such as Conduit, Pipe, etc. By implementing this class, all such streaming abstractions can be used directly as endpoints. +class ToStreamGenerator f a where + toStreamGenerator :: f a -> StreamGenerator a + +instance ToStreamGenerator StreamGenerator a + where toStreamGenerator x = x + +-- | Clients reading from streaming endpoints can be implemented as producing a @ResultStream@ that captures the setup, takedown, and incremental logic for a read, being an IO continuation that takes a producer of Just either values or errors that terminates with a Nothing. +data ResultStream a = ResultStream ((forall b. (IO (Maybe (Either String a)) -> IO b) -> IO b)) + +-- | BuildFromStream is intended to be implemented for types such as Conduit, Pipe, etc. By implementing this class, all such streaming abstractions can be used directly on the client side for talking to streaming endpoints. +class BuildFromStream a b where + buildFromStream :: ResultStream a -> b + +instance BuildFromStream a (ResultStream a) + where buildFromStream x = x + +-- | The FramingRender class provides the logic for emitting a framing strategy. The strategy emits a header, followed by boundary-delimited data, and finally a termination character. For many strategies, some of these will just be empty bytestrings. +class FramingRender strategy a where + header :: Proxy strategy -> Proxy a -> ByteString + boundary :: Proxy strategy -> Proxy a -> BoundaryStrategy + trailer :: Proxy strategy -> Proxy a -> ByteString + +-- | The bracketing strategy generates things to precede and follow the content, as with netstrings. +-- The intersperse strategy inserts seperators between things, as with newline framing. +-- Finally, the general strategy performs an arbitrary rewrite on the content, to allow escaping rules and such. +data BoundaryStrategy = BoundaryStrategyBracket (ByteString -> (ByteString,ByteString)) + | BoundaryStrategyIntersperse ByteString + | BoundaryStrategyGeneral (ByteString -> ByteString) + +-- | A type of parser that can never fail, and has different parsing strategies (incremental, or EOF) depending if more input can be sent. The incremental parser should return `Nothing` if it would like to be sent a longer ByteString. If it returns a value, it also returns the remainder following that value. +data ByteStringParser a = ByteStringParser { + parseIncremental :: ByteString -> Maybe (a, ByteString), + parseEOF :: ByteString -> (a, ByteString) +} + +-- | The FramingUnrender class provides the logic for parsing a framing strategy. The outer @ByteStringParser@ strips the header from a stream of bytes, and yields a parser that can handle the remainder, stepwise. Each frame may be a ByteString, or a String indicating the error state for that frame. Such states are per-frame, so that protocols that can resume after errors are able to do so. Eventually this returns an empty ByteString to indicate termination. +class FramingUnrender strategy a where + unrenderFrames :: Proxy strategy -> Proxy a -> ByteStringParser (ByteStringParser (Either String ByteString)) + + +-- | A simple framing strategy that has no header or termination, and inserts a newline character between each frame. +-- This assumes that it is used with a Content-Type that encodes without newlines (e.g. JSON). +data NewlineFraming + +instance FramingRender NewlineFraming a where + header _ _ = empty + boundary _ _ = BoundaryStrategyIntersperse "\n" + trailer _ _ = empty + +instance FramingUnrender NewlineFraming a where + unrenderFrames _ _ = ByteStringParser (Just . (go,)) (go,) + where go = ByteStringParser + (\x -> case LB.break (== '\n') x of + (h,r) -> if not (LB.null r) then Just (Right h, LB.drop 1 r) else Nothing + ) + (\x -> case LB.break (== '\n') x of + (h,r) -> (Right h, LB.drop 1 r) + ) +-- | The netstring framing strategy as defined by djb: +data NetstringFraming + +instance FramingRender NetstringFraming a where + header _ _ = empty + boundary _ _ = BoundaryStrategyBracket $ \b -> ((<> ":") . LB.pack . show . LB.length $ b, ",") + trailer _ _ = empty + + +instance FramingUnrender NetstringFraming a where + unrenderFrames _ _ = ByteStringParser (Just . (go,)) (go,) + where go = ByteStringParser + (\b -> let (i,r) = LB.break (==':') b + in case readMaybe (LB.unpack i) of + Just len -> if LB.length r > len + then Just . first Right . fmap (LB.drop 1) $ LB.splitAt len . LB.drop 1 $ r + else Nothing + Nothing -> Just (Left ("Bad netstring frame, couldn't parse value as integer value: " ++ LB.unpack i), LB.drop 1 . LB.dropWhile (/= ',') $ r)) + (\b -> let (i,r) = LB.break (==':') b + in case readMaybe (LB.unpack i) of + Just len -> if LB.length r > len + then first Right . fmap (LB.drop 1) $ LB.splitAt len . LB.drop 1 $ r + else (Right $ LB.take len r, LB.empty) + Nothing -> (Left ("Bad netstring frame, couldn't parse value as integer value: " ++ LB.unpack i), LB.drop 1 . LB.dropWhile (/= ',') $ r)) diff --git a/servant/src/Servant/Utils/Links.hs b/servant/src/Servant/Utils/Links.hs index d39e4a61e..25eb2ad1b 100644 --- a/servant/src/Servant/Utils/Links.hs +++ b/servant/src/Servant/Utils/Links.hs @@ -119,6 +119,7 @@ import Servant.API.RemoteHost ( RemoteHost ) import Servant.API.Verbs ( Verb ) import Servant.API.Sub ( type (:>) ) import Servant.API.Raw ( Raw ) +import Servant.API.Stream ( Stream ) import Servant.API.TypeLevel import Servant.API.Experimental.Auth ( AuthProtect ) @@ -337,6 +338,10 @@ instance HasLink Raw where type MkLink Raw = Link toLink _ = id +instance HasLink (Stream m fr ct a) where + type MkLink (Stream m fr ct a) = Link + toLink _ = id + -- AuthProtext instances instance HasLink sub => HasLink (AuthProtect tag :> sub) where type MkLink (AuthProtect tag :> sub) = MkLink sub