Skip to content

Commit

Permalink
Fix snap-server vs recent snap-core changes to rqBody
Browse files Browse the repository at this point in the history
  • Loading branch information
gregorycollins committed Apr 21, 2010
1 parent c70890f commit ba5f058
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 20 deletions.
39 changes: 26 additions & 13 deletions src/Snap/Internal/Http/Server.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import qualified Data.ByteString as S
import qualified Data.ByteString.Lazy as L
import Data.ByteString.Internal (c2w, w2c)
import qualified Data.ByteString.Nums.Careless.Int as Cvt
import Data.IORef
import Data.List (foldl')
import qualified Data.Map as Map
import Data.Maybe (catMaybes, fromMaybe)
Expand Down Expand Up @@ -248,7 +249,9 @@ httpSession writeEnd handler = do
(req',rsp) <- lift $ handler req

liftIO $ debug "Server.httpSession: handled, skipping request body"
lift $ joinIM $ rqBody req' skipToEof
srqEnum <- liftIO $ readIORef $ rqBody req'
let (SomeEnumerator rqEnum) = srqEnum
lift $ joinIM $ rqEnum skipToEof
liftIO $ debug "Server.httpSession: request body skipped, sending response"

date <- liftIO getDateString
Expand Down Expand Up @@ -277,7 +280,9 @@ receiveRequest = do

case mreq of
(Just ireq) -> do
req <- toRequest ireq >>= setEnumerator >>= parseForm
req' <- toRequest ireq
setEnumerator req'
req <- parseForm req'
checkConnectionClose (rqVersion req) (rqHeaders req)
return $ Just req

Expand All @@ -293,27 +298,30 @@ receiveRequest = do
--
-- if no content-length and no chunked encoding, enumerate the entire
-- socket and close afterwards
setEnumerator :: Request -> ServerMonad Request
setEnumerator :: Request -> ServerMonad ()
setEnumerator req =
if isChunked
then return req { rqBody = readChunkedTransferEncoding }
then liftIO $ writeIORef (rqBody req)
(SomeEnumerator readChunkedTransferEncoding)
else maybe noContentLength hasContentLength mbCL

where
isChunked = maybe False
((== ["chunked"]) . map toCI)
(Map.lookup "transfer-encoding" hdrs)

hasContentLength :: Int -> ServerMonad Request
hasContentLength :: Int -> ServerMonad ()
hasContentLength l = do
return $ req { rqBody = e }
liftIO $ writeIORef (rqBody req)
(SomeEnumerator e)
where
e :: Enumerator IO a
e = return . joinI . I.take l

noContentLength :: ServerMonad Request
noContentLength = do
return $ req { rqBody = return . joinI . I.take 0 }
noContentLength :: ServerMonad ()
noContentLength =
liftIO $ writeIORef (rqBody req)
(SomeEnumerator $ return . joinI . I.take 0 )


hdrs = rqHeaders req
Expand All @@ -328,11 +336,14 @@ receiveRequest = do

getIt :: ServerMonad Request
getIt = do
iter <- liftIO $ rqBody req stream2stream
senum <- liftIO $ readIORef $ rqBody req
let (SomeEnumerator enum) = senum
iter <- liftIO $ enum stream2stream
body <- lift iter
let newParams = parseUrlEncoded $ strictize $ fromWrap body
return $ req { rqBody = return
, rqParams = rqParams req `mappend` newParams }
liftIO $ writeIORef (rqBody req)
(SomeEnumerator $ return . I.joinI . I.take 0)
return $ req { rqParams = rqParams req `mappend` newParams }


toRequest (IRequest method uri version kvps) = do
Expand All @@ -347,6 +358,9 @@ receiveRequest = do
(liftM (parseHost . head)
(Map.lookup "host" hdrs))

-- will override in "setEnumerator"
enum <- liftIO $ newIORef $ SomeEnumerator return


return $ Request serverName
serverPort
Expand Down Expand Up @@ -394,7 +408,6 @@ receiveRequest = do
where
(a,b) = S.break (== (c2w ':')) h

enum = return -- will override in "setEnumerator"
params = parseUrlEncoded queryString

(pathInfo, queryString) = first dropLeadingSlash . second (S.drop 1) $
Expand Down
27 changes: 20 additions & 7 deletions test/suite/Snap/Internal/Http/Server/Tests.hs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@ testHttpRequest1 =
iter <- enumBS sampleRequest $
do
r <- liftM fromJust $ rsm receiveRequest
b <- liftM fromWrap $ joinIM $ rqBody r stream2stream
se <- liftIO $ readIORef (rqBody r)
let (SomeEnumerator e) = se
b <- liftM fromWrap $ joinIM $ e stream2stream
return (r,b)

(req,body) <- run iter
Expand Down Expand Up @@ -114,9 +116,13 @@ testMultiRequest =
iter <- (enumBS sampleRequest >. enumBS sampleRequest) $
do
r1 <- liftM fromJust $ rsm receiveRequest
b1 <- liftM fromWrap $ joinIM $ rqBody r1 stream2stream
se1 <- liftIO $ readIORef (rqBody r1)
let (SomeEnumerator e1) = se1
b1 <- liftM fromWrap $ joinIM $ e1 stream2stream
r2 <- liftM fromJust $ rsm receiveRequest
b2 <- liftM fromWrap $ joinIM $ rqBody r2 stream2stream
se2 <- liftIO $ readIORef (rqBody r2)
let (SomeEnumerator e2) = se2
b2 <- liftM fromWrap $ joinIM $ e2 stream2stream
return (r1,b1,r2,b2)

(req1,body1,req2,body2) <- run iter
Expand Down Expand Up @@ -187,7 +193,9 @@ testHttpRequest2 =
iter <- enumBS sampleRequest2 $
do
r <- liftM fromJust $ rsm receiveRequest
b <- liftM fromWrap $ joinIM $ rqBody r stream2stream
se <- liftIO $ readIORef (rqBody r)
let (SomeEnumerator e) = se
b <- liftM fromWrap $ joinIM $ e stream2stream
return (r,b)

(_,body) <- run iter
Expand All @@ -201,7 +209,9 @@ testHttpRequest3 =
iter <- enumBS sampleRequest3 $
do
r <- liftM fromJust $ rsm receiveRequest
b <- liftM fromWrap $ joinIM $ rqBody r stream2stream
se <- liftIO $ readIORef (rqBody r)
let (SomeEnumerator e) = se
b <- liftM fromWrap $ joinIM $ e stream2stream
return (r,b)

(req,body) <- run iter
Expand Down Expand Up @@ -297,10 +307,13 @@ testHttpResponse1 = testCase "HttpResponse1" $ do

echoServer :: Request -> Iteratee IO (Request,Response)
echoServer req = do
let i = joinIM $ rqBody req stream2stream
se <- liftIO $ readIORef (rqBody req)
let (SomeEnumerator enum) = se
let i = joinIM $ enum stream2stream
b <- liftM fromWrap i
let cl = L.length b
return (req {rqBody=return . joinI . take 0}, rsp b cl)
liftIO $ writeIORef (rqBody req) (SomeEnumerator $ return . joinI . take 0)
return (req, rsp b cl)
where
rsp s cl = emptyResponse { rspBody = enumLBS s
, rspContentLength = Just $ fromIntegral cl }
Expand Down

0 comments on commit ba5f058

Please sign in to comment.