Skip to content

Commit

Permalink
Feed data frames (request body) into the Request.
Browse files Browse the repository at this point in the history
The Request data type has a Source for the request body.
Create a Source with a channel linked to it, and feed the channel when we
get data frames.
  • Loading branch information
kolmodin committed Mar 11, 2012
1 parent f194d75 commit 5ce9d4f
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 24 deletions.
2 changes: 1 addition & 1 deletion Network/SPDY.hs
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ onSynStreamFrame state sId pri nvh = do
putStrLn "Constructed frame:"
print ("syn_reply", sId, nvh')
return (SynReplyControlFrame 0 sId nvhReply) :: IO Frame
enqueueFrame state $ return $ DataFrame sId 1 $ S.concat ("<html><h1>hello from spdy</h1><br/>" : S.concat ([ C8.pack (show b ++ "<br/>") | b <- nvh ]) : "</html>" : [])
enqueueFrame state $ return $ DataFrame 1 sId $ S.concat ("<html><h1>hello from spdy</h1><br/>" : S.concat ([ C8.pack (show b ++ "<br/>") | b <- nvh ]) : "</html>" : [])
where
utf8 (s,t) = (decodeUtf8 s, decodeUtf8 t)

Expand Down
2 changes: 1 addition & 1 deletion Network/SPDY/Frame.hs
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ ourSPDYVersion = 2

data Frame
= DataFrame {
dataFrameStreamID :: Word32,
dataFrameFlags :: Word8,
dataFrameStreamID :: Word32,
dataFramePayload :: B.ByteString }
| SynStreamControlFrame {
controlFrameFlags :: Word8,
Expand Down
93 changes: 74 additions & 19 deletions Network/Wai/Handler/Hope.hs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{-# LANGUAGE RecordWildCards, OverloadedStrings, DeriveDataTypeable #-}
{-# LANGUAGE RecordWildCards, OverloadedStrings, DeriveDataTypeable, NamedFieldPuns #-}
module Network.Wai.Handler.Hope where

import Network.Wai
Expand All @@ -24,13 +24,17 @@ import qualified Data.ByteString.Char8 as C8 ( pack ) -- Also IsString instance
import qualified Data.ByteString as S
import qualified Data.ByteString.Lazy as L

import Data.Bits

import Network.Socket hiding ( recv, Closed )

import Control.Applicative
import Control.Concurrent
import Control.Concurrent.STM.TVar
import Control.Monad.STM

import Control.Monad ( when )

import Control.Exception ( Exception, throwIO, Handler(..), catches )
import Data.Typeable

Expand Down Expand Up @@ -138,6 +142,7 @@ data StreamState = StreamState
{ streamStateID :: Word32
, streamStatePriority :: Word8
, streamStateReplyThread :: ThreadId
, streamStateBodyChan :: Maybe (Chan (Maybe S.ByteString))
}

data SPDYException
Expand All @@ -159,12 +164,28 @@ frameHandler app sockaddr state frame = do
print frame
case frame of
SynStreamControlFrame flags sId assId pri nvh -> do
state' <- createStream app sockaddr state sId pri nvh
state' <- createStream app sockaddr state flags sId pri nvh
return state'
RstStreamControlFrame flags sId status -> do
putStrLn "RstStream... we're screwed."
-- TODO: remove all knowledge of this stream. empty send buffer.
return state
DataFrame flags sId payload -> do
let flag_fin = testBit flags 0
streamM <- getStreamState state sId
case streamM of
Nothing -> do sendRstStream state sId 2 -- 2 == INVALID_STREAM
return state
Just s -> do
let bodyChan = streamStateBodyChan s
case bodyChan of
Nothing -> do sendRstStream state sId 2 -- which error code?
return state
Just chan -> do writeChan chan (Just payload)
let s' | flag_fin = s { streamStateBodyChan = Nothing }
| otherwise = s
state' = updateStreamState state s'
return state'
PingControlFrame pingId -> do
enqueueFrame state $ return (PingControlFrame pingId)
return state
Expand All @@ -175,14 +196,28 @@ frameHandler app sockaddr state frame = do
NoopControlFrame -> do
return state

getStreamState :: SessionState -> Word32 -> IO (Maybe StreamState)
getStreamState state sId = do
let streamStates = sessionStateStreamStates state
case filter (\s -> streamStateID s == sId) streamStates of
[s] -> return (Just s)
[] -> return Nothing

updateStreamState :: SessionState -> StreamState -> SessionState
updateStreamState state stream =
let streamStates = sessionStateStreamStates state
sId = streamStateID stream
streams = filter (\s -> streamStateID s /= sId) streamStates
in state { sessionStateStreamStates = stream : streams }

enqueueFrame :: SessionState -> IO Frame -> IO ()
enqueueFrame SessionState { sessionStateSendQueue = queue } frame =
atomically $ do
q <- readTVar queue
writeTVar queue (q ++ [frame])

createStream :: Application -> SockAddr -> SessionState -> Word32 -> Word8 -> S.ByteString -> IO SessionState
createStream app sockaddr state@(SessionState { sessionStateNVHReceiveZContext = zInflate }) sId pri nvhBytes = do
createStream :: Application -> SockAddr -> SessionState -> Word8 -> Word32 -> Word8 -> S.ByteString -> IO SessionState
createStream app sockaddr state@(SessionState { sessionStateNVHReceiveZContext = zInflate }) flags sId pri nvhBytes = do
putStrLn $ "Creating stream context, id = " ++ show sId
nvhChunks <- do a <- withInflateInput zInflate nvhBytes popper
b <- flushInflate zInflate
Expand All @@ -192,8 +227,8 @@ createStream app sockaddr state@(SessionState { sessionStateNVHReceiveZContext =
Fail _ _ msg -> throwIO (SPDYNVHException Nothing msg)
Partial _ -> throwIO (SPDYNVHException Nothing "Could not parse NVH block, returned Partial.")
print (sId, pri, nvh)
tId <- onSynStreamFrame app sockaddr state sId pri nvh
let streamState = StreamState sId pri tId
(tId, bodyChan) <- onSynStreamFrame app sockaddr state flags sId pri nvh
let streamState = StreamState sId pri tId bodyChan
return state { sessionStateStreamStates = streamState : sessionStateStreamStates state }
where
feedAll r [] = r
Expand All @@ -212,16 +247,20 @@ sendGoAway :: SessionState -> Word32 -> IO ()
sendGoAway state sId = do
enqueueFrame state $ return $ GoAwayFrame 0 sId

sendRstStream :: SessionState -> Word8 -> Word32 -> Word32 -> IO ()
sendRstStream state flags sId status = do
enqueueFrame state $ return $ RstStreamControlFrame flags sId status
sendRstStream :: SessionState -> Word32 -> Word32 -> IO ()
sendRstStream state sId status = do
enqueueFrame state $ return $ RstStreamControlFrame 0 sId status

onSynStreamFrame :: Application -> SockAddr -> SessionState -> Word32 -> Word8 -> NameValueHeaderBlock -> IO ThreadId
onSynStreamFrame app sockaddr state sId pri nvh = do
req <- case buildReq sockaddr nvh of -- catch errors, return protocol_error on stream
onSynStreamFrame :: Application -> SockAddr -> SessionState -> Word8 -> Word32 -> Word8 -> NameValueHeaderBlock -> IO (ThreadId, Maybe (Chan (Maybe S.ByteString)))
onSynStreamFrame app sockaddr state flags sId pri nvh = do
(bodySource,bodyChan) <- mkChanSource
req <- case buildReq sockaddr bodySource nvh of -- catch errors, return protocol_error on stream
Right req -> return req
Left err -> throwIO (SPDYNVHException (Just sId) err)
forkIO $ runResourceT $ do
let flag_fin = testBit flags 0 -- other side said no more frames from their side
when flag_fin $ do
writeChan bodyChan Nothing
tId <- forkIO $ runResourceT $ do
resp <- app req
let (status, responseHeaders, source) = responseSource resp
headerStatus = ("status", showStatus status)
Expand All @@ -238,10 +277,11 @@ onSynStreamFrame app sockaddr state sId pri nvh = do
print ("syn_reply" :: String, sId, nvh')
return (SynReplyControlFrame 0 sId nvhReply) :: IO Frame
source $$ enqueueFrameSink
return (tId, if flag_fin then Nothing else Just bodyChan)
where
utf8 (s,t) = (decodeUtf8 s, decodeUtf8 t)
showStatus (Status statusCode statusMessage) = S.concat [C8.pack (show statusCode), " ", statusMessage]
mkDataFrame = DataFrame sId 0
mkDataFrame = DataFrame 0 sId
enqueueFrameSink =
sinkState
()
Expand All @@ -250,10 +290,10 @@ onSynStreamFrame app sockaddr state sId pri nvh = do
(Chunk inpBuilder) -> liftIO $ enqueueFrame state $ return $ mkDataFrame (toByteString inpBuilder)
Flush -> return ()
return (StateProcessing ()))
(\_ -> liftIO $ enqueueFrame state $ return $ DataFrame sId 1 "")
(\_ -> liftIO $ enqueueFrame state $ return $ DataFrame 1 sId "")

buildReq :: SockAddr -> NameValueHeaderBlock -> Either String Request
buildReq sockaddr nvh = do
buildReq :: SockAddr -> Source IO S.ByteString -> NameValueHeaderBlock -> Either String Request
buildReq sockaddr bodySource nvh = do
method <- case lookup (decodeUtf8 "method") nvh of
Just m -> return m
Nothing -> Left "no method in NVH block"
Expand Down Expand Up @@ -292,7 +332,7 @@ buildReq sockaddr nvh = do
, isSecure = True
, remoteHost = sockaddr
, queryString = H.parseQuery (encodeUtf8 query)
, requestBody = sourceState () (\_ -> return StateClosed)
, requestBody = bodySource
, vault = V.empty
}
where
Expand All @@ -303,6 +343,21 @@ buildReq sockaddr nvh = do
A.endOfInput
return $ HttpVersion x y)

mkChanSource :: ResourceIO m => IO (Source m S.ByteString, Chan (Maybe S.ByteString))
mkChanSource = do
chan <- newChan
return (chan2source chan, chan)

chan2source :: ResourceIO m => Chan (Maybe S.ByteString) -> Source m S.ByteString
chan2source chan =
sourceIO
(return ())
(\_ -> return ())
(\_ -> do v <- liftIO $ readChan chan
case v of
Nothing -> return IOClosed
Just bs -> return (IOOpen bs))

sender :: TLSCtx a -> TVar [IO Frame] -> IO ()
sender tlsctx queue = go
where
Expand Down Expand Up @@ -334,7 +389,7 @@ sessionHandler handler tlsctx sockaddr = do
SPDYParseException str -> do putStrLn ("Caught this! " ++ show e)
sendGoAway initS 0
SPDYNVHException (Just sId) str -> do putStrLn ("Caught this! " ++ show e)
sendRstStream initS 0 sId 1
sendRstStream initS sId 1
SPDYNVHException Nothing str -> do putStrLn ("Caught this! " ++ show e)
sendGoAway initS 0)
, Handler (\e ->
Expand Down
5 changes: 2 additions & 3 deletions tests/QC.hs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ arbitraryDataFrame = do
sId <- arbitraryWord31be
flags <- arbitrary
payload <- genPayload
return (DataFrame sId flags payload)
return (DataFrame flags sId payload)

arbitrarySynStreamFrame :: Gen Frame
arbitrarySynStreamFrame = do
Expand All @@ -112,10 +112,9 @@ arbitrarySynReplyStreamFrame = do

arbitraryRstStreamFrame :: Gen Frame
arbitraryRstStreamFrame = do
flags <- arbitrary
sId <- arbitraryWord31be
status <- arbitrary
return (RstStreamControlFrame flags sId status)
return (RstStreamControlFrame 0 sId status)

arbitraryGoAwayFrame :: Gen Frame
arbitraryGoAwayFrame = do
Expand Down

0 comments on commit 5ce9d4f

Please sign in to comment.