Skip to content

Commit

Permalink
WIP wai 3.0 support
Browse files Browse the repository at this point in the history
  • Loading branch information
snoyberg committed May 22, 2014
1 parent 0a1cbf5 commit a519123
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 30 deletions.
128 changes: 110 additions & 18 deletions Network/HTTP/ReverseProxy.hs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,11 @@ import Data.Default.Class (def)
import qualified Network.Wai as WAI
import qualified Network.HTTP.Client as HC
import Network.HTTP.Client (BodyReader, brRead)
#if MIN_VERSION_wai(3,0,0)
import Control.Exception (bracket)
#else
import Control.Exception (bracketOnError)
#endif
import Blaze.ByteString.Builder (fromByteString)
import Data.Word8 (isSpace, _colon, _cr)
import qualified Data.ByteString as S
Expand Down Expand Up @@ -69,6 +73,7 @@ import Data.Monoid (mappend, (<>), mconcat)
import Control.Exception.Lifted (try, SomeException, finally)
import Control.Applicative ((<$>), (<|>))
import Data.Set (Set)
import qualified Data.Conduit.List as CL

-- | Host\/port combination to which we want to proxy.
data ProxyDest = ProxyDest
Expand Down Expand Up @@ -151,7 +156,11 @@ rawProxyTo getDest appdata = do
-- | Sends a simple 502 bad gateway error message with the contents of the
-- exception.
defaultOnExc :: SomeException -> WAI.Application
#if MIN_VERSION_wai(3,0,0)
defaultOnExc exc _ sendResponse = sendResponse $ WAI.responseLBS
#else
defaultOnExc exc _ = return $ WAI.responseLBS
#endif
HT.status502
[("content-type", "text/plain")]
("Error connecting to gateway:\n\n" <> TLE.encodeUtf8 (TL.pack $ show exc))
Expand Down Expand Up @@ -241,6 +250,55 @@ instance Default WaiProxySettings where
(CI.mk <$> lookup "upgrade" (WAI.requestHeaders req)) == Just "websocket"
}

#if MIN_VERSION_wai(2, 1, 0)
renderHeaders :: WAI.Request -> HT.RequestHeaders -> Builder
renderHeaders req headers
= fromByteString (WAI.requestMethod req)
<> fromByteString " "
<> fromByteString (WAI.rawPathInfo req)
<> fromByteString (WAI.rawQueryString req)
<> (if WAI.httpVersion req == HT.http11
then fromByteString " HTTP/1.1"
else fromByteString " HTTP/1.0")
<> mconcat (map goHeader headers)
<> fromByteString "\r\n\r\n"
where
goHeader (x, y)
= fromByteString "\r\n"
<> fromByteString (CI.original x)
<> fromByteString ": "
<> fromByteString y
#endif

#if MIN_VERSION_wai(3, 0, 0)
tryWebSockets :: WaiProxySettings -> ByteString -> Int -> WAI.Request -> (WAI.Response -> IO b) -> IO b -> IO b
tryWebSockets wps host port req sendResponse fallback
| wpsUpgradeToRaw wps req =
sendResponse $ flip WAI.responseRaw backup $ \fromClientBody toClient ->
DCN.runTCPClient settings $ \server ->
let toServer = DCN.appSink server
fromServer = DCN.appSource server
fromClient = do
mapM_ yield $ L.toChunks $ toLazyByteString headers
let loop = do
bs <- liftIO fromClientBody
unless (S.null bs) $ do
yield bs
loop
loop
toClient' = awaitForever $ liftIO . toClient
headers = renderHeaders req $ fixReqHeaders wps req
in void $ concurrently
(fromClient $$ toServer)
(fromServer $$ toClient')
| otherwise = fallback
where
backup = WAI.responseLBS HT.status500 [("Content-Type", "text/plain")]
"http-reverse-proxy detected WebSockets request, but server does not support responseRaw"
settings = DCN.clientSettings port host

#else

tryWebSockets :: WaiProxySettings -> ByteString -> Int -> WAI.Request -> IO WAI.Response -> IO WAI.Response
#if MIN_VERSION_wai(2, 1, 0)
tryWebSockets wps host port req fallback
Expand All @@ -261,28 +319,12 @@ tryWebSockets wps host port req fallback
backup = WAI.responseLBS HT.status500 [("Content-Type", "text/plain")]
"http-reverse-proxy detected WebSockets request, but server does not support responseRaw"
settings = DCN.clientSettings port host

renderHeaders :: WAI.Request -> HT.RequestHeaders -> Builder
renderHeaders req headers
= fromByteString (WAI.requestMethod req)
<> fromByteString " "
<> fromByteString (WAI.rawPathInfo req)
<> fromByteString (WAI.rawQueryString req)
<> (if WAI.httpVersion req == HT.http11
then fromByteString " HTTP/1.1"
else fromByteString " HTTP/1.0")
<> mconcat (map goHeader headers)
<> fromByteString "\r\n\r\n"
where
goHeader (x, y)
= fromByteString "\r\n"
<> fromByteString (CI.original x)
<> fromByteString ": "
<> fromByteString y
#else
tryWebSockets _ _ _ _ = id
#endif

#endif

strippedHeaders :: Set HT.HeaderName
strippedHeaders = Set.fromList
["content-length", "transfer-encoding", "accept-encoding", "content-encoding"]
Expand All @@ -306,6 +348,55 @@ waiProxyToSettings :: (WAI.Request -> IO WaiProxyResponse)
-> WaiProxySettings
-> HC.Manager
-> WAI.Application
#if MIN_VERSION_wai(3,0,0)
waiProxyToSettings getDest wps manager req0 sendResponse = do
edest' <- getDest req0
let edest =
case edest' of
WPRResponse res -> Left res
WPRProxyDest pd -> Right (pd, req0)
WPRModifiedRequest req pd -> Right (pd, req)
case edest of
Left response -> sendResponse response
Right (ProxyDest host port, req) -> tryWebSockets wps host port req sendResponse $ do
let req' = def
{ HC.method = WAI.requestMethod req
, HC.host = host
, HC.port = port
, HC.path = WAI.rawPathInfo req
, HC.queryString = WAI.rawQueryString req
, HC.requestHeaders = fixReqHeaders wps req
, HC.requestBody = body
, HC.redirectCount = 0
, HC.checkStatus = \_ _ _ -> Nothing
, HC.responseTimeout = wpsTimeout wps
}
body =
case WAI.requestBodyLength req of
WAI.KnownLength i -> HC.RequestBodyStream
(fromIntegral i)
($ WAI.requestBody req)
WAI.ChunkedBody -> HC.RequestBodyStreamChunked ($ WAI.requestBody req)
bracket
(try $ HC.responseOpen req' manager)
(either (const $ return ()) HC.responseClose)
$ \ex -> do
case ex of
Left e -> wpsOnExc wps e req sendResponse
Right res -> do
let conduit =
case wpsProcessBody wps $ fmap (const ()) res of
Nothing -> awaitForever (\bs -> yield (Chunk $ fromByteString bs) >> yield Flush)
Just conduit' -> conduit'
src = bodyReaderSource $ HC.responseBody res
sendResponse $ WAI.responseStream
(HC.responseStatus res)
(filter (\(key, _) -> not $ key `Set.member` strippedHeaders) $ HC.responseHeaders res)
(\sendChunk flush -> src $= conduit $$ CL.mapM_ (\mb ->
case mb of
Flush -> flush
Chunk b -> sendChunk b))
#else
waiProxyToSettings getDest wps manager req0 = do
edest' <- getDest req0
let edest =
Expand Down Expand Up @@ -356,6 +447,7 @@ waiProxyToSettings getDest wps manager req0 = do
, filter (\(key, _) -> not $ key `Set.member` strippedHeaders) $ HC.responseHeaders res
, src $= conduit
)
#endif

-- | Get the HTTP headers for the first request on the stream, returning on
-- consumed bytes as leftovers. Has built-in limits on how many bytes it will
Expand Down
2 changes: 1 addition & 1 deletion http-reverse-proxy.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ library
, word8 >= 0.0
, blaze-builder >= 0.3
, http-client >= 0.1
, wai >= 2.0
, wai >= 3.0
, network
, conduit >= 0.5
, conduit-extra
Expand Down
26 changes: 15 additions & 11 deletions test/main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@ import Control.Concurrent (forkIO, killThread, newEmptyMVar,
putMVar, takeMVar, threadDelay)
import Control.Exception (IOException, bracket,
onException, try)
import Control.Monad (forever)
import Control.Monad (forever, unless)
import Control.Monad.IO.Class (liftIO)
import Control.Monad.Trans.Resource (runResourceT)
import qualified Data.ByteString as S
import qualified Data.ByteString.Char8 as S8
import qualified Data.ByteString.Lazy.Char8 as L8
import Data.Char (toUpper)
import Data.Conduit (Flush (..), await, yield, ($$),
($$+-), (=$))
($$+-), (=$), awaitForever)
import qualified Data.Conduit.Binary as CB
import qualified Data.Conduit.List as CL
import Data.Conduit.Network (HostPreference, ServerSettings,
Expand All @@ -32,7 +32,7 @@ import Network.HTTP.ReverseProxy (ProxyDest (..),
import Network.HTTP.Types (status200, status500)
import Network.Socket (sClose)
import Network.Wai (rawPathInfo, responseLBS,
responseSource)
responseStream)
import qualified Network.Wai
import Network.Wai.Handler.Warp (defaultSettings, runSettings,
setBeforeMainLoop, setPort)
Expand Down Expand Up @@ -89,14 +89,14 @@ main = hspec $ do
it "works" $
let content = "mainApp"
in withMan $ \manager ->
withWApp (const $ return $ responseLBS status200 [] content) $ \port1 ->
withWApp (\_ f -> f $ responseLBS status200 [] content) $ \port1 ->
withWApp (waiProxyTo (const $ return $ WPRProxyDest $ ProxyDest "127.0.0.1" port1) defaultOnExc manager) $ \port2 ->
withCApp (rawProxyTo (const $ return $ Right $ ProxyDest "127.0.0.1" port2)) $ \port3 -> do
lbs <- HC.simpleHttp $ "http://127.0.0.1:" ++ show port3
lbs `shouldBe` content
it "modified path" $
let content = "/somepath"
app req = return $ responseLBS status200 [] $ L8.fromChunks [rawPathInfo req]
app req f = f $ responseLBS status200 [] $ L8.fromChunks [rawPathInfo req]
modReq pdest req = return $ WPRModifiedRequest
(req { rawPathInfo = content })
pdest
Expand All @@ -107,9 +107,9 @@ main = hspec $ do
lbs <- HC.simpleHttp $ "http://127.0.0.1:" ++ show port3
S8.concat (L8.toChunks lbs) `shouldBe` content
it "deals with streaming data" $
let app _ = return $ responseSource status200 [] $ forever $ do
yield $ Chunk $ fromByteString "hello"
yield Flush
let app _ f = f $ responseStream status200 [] $ \sendChunk flush -> forever $ do
sendChunk $ fromByteString "hello"
flush
liftIO $ threadDelay 10000000
in withMan $ \manager ->
withWApp app $ \port1 ->
Expand All @@ -120,7 +120,7 @@ main = hspec $ do
HC.responseBody res $$+- await
mbs `shouldBe` Just (Just "hello")
it "passes on body length" $
let app req = return $ responseLBS
let app req f = f $ responseLBS
status200
[("uplength", show' $ Network.Wai.requestBodyLength req)]
""
Expand All @@ -142,8 +142,12 @@ main = hspec $ do
$ fromIntegral
$ S.length body)
it "upgrade to raw" $
let app _ = return $ flip Network.Wai.responseRaw fallback $ \src sink ->
src $$ CL.iterM print =$ CL.map (S8.map toUpper) =$ sink
let app _ f = f $ flip Network.Wai.responseRaw fallback $ \src sink -> do
let src' = do
bs <- liftIO src
unless (S8.null bs) $ yield bs >> src'
sink' = awaitForever $ liftIO . sink
src' $$ CL.iterM print =$ CL.map (S8.map toUpper) =$ sink'
fallback = responseLBS status500 [] "fallback used"
in withMan $ \manager ->
withWApp app $ \port1 ->
Expand Down

0 comments on commit a519123

Please sign in to comment.