diff --git a/src/mochiweb_http.erl b/src/mochiweb_http.erl index 3a2a67a..08cf999 100644 --- a/src/mochiweb_http.erl +++ b/src/mochiweb_http.erl @@ -96,22 +96,31 @@ default_body(Req) -> default_body(Req, Req:get(method), Req:get(path)). loop(Socket, Body) -> - mochiweb_socket:setopts(Socket, [{packet, http}]), - request(Socket, Body). + ok = mochiweb_socket:setopts(Socket, [{packet, line}]), + request(Socket, Body, <<>>). -request(Socket, Body) -> - mochiweb_socket:setopts(Socket, [{active, once}]), +request(Socket, Body, Prev) -> + ok = mochiweb_socket:setopts(Socket, [{active, once}]), receive - {Protocol, _, {http_request, Method, Path, Version}} when Protocol == http orelse Protocol == ssl -> - mochiweb_socket:setopts(Socket, [{packet, httph}]), - headers(Socket, {Method, Path, Version}, [], Body, 0); - {Protocol, _, {http_error, "\r\n"}} when Protocol == http orelse Protocol == ssl -> - request(Socket, Body); - {Protocol, _, {http_error, "\n"}} when Protocol == http orelse Protocol == ssl -> - request(Socket, Body); + {Protocol, _, Bin} when Protocol =:= tcp orelse Protocol =:= ssl -> + FullBin = <>, + case erlang:decode_packet(http, FullBin, []) of + {ok, {http_request, Method, Path, Version}, <<>>} -> + collect_headers(Socket, {Method, Path, Version}, Body, + <<>>, false, 0); + {ok, {http_error, "\r\n"}, <<>>} -> + request(Socket, Body, <<>>); + {ok, {http_error, "\n"}, <<>>} -> + request(Socket, Body, <<>>); + {more, _} -> + request(Socket, Body, FullBin) + end; {tcp_closed, _} -> mochiweb_socket:close(Socket), exit(normal); + {ssl_closed, _} -> + mochiweb_socket:close(Socket), + exit(normal); _Other -> handle_invalid_request(Socket) after ?REQUEST_RECV_TIMEOUT -> @@ -124,30 +133,64 @@ reentry(Body) -> ?MODULE:after_response(Body, Req) end. -headers(Socket, Request, Headers, _Body, ?MAX_HEADERS) -> +collect_headers(Socket, Request, _Body, _Collected, _Trunc, ?MAX_HEADERS) -> %% Too many headers sent, bad request. - mochiweb_socket:setopts(Socket, [{packet, raw}]), - handle_invalid_request(Socket, Request, Headers); -headers(Socket, Request, Headers, Body, HeaderCount) -> - mochiweb_socket:setopts(Socket, [{active, once}]), + handle_invalid_request(Socket, Request, []); +collect_headers(Socket, Request, Body, Collected, Trunc, HeaderCount) -> + ok = mochiweb_socket:setopts(Socket, [{active, once}]), receive - {Protocol, _, http_eoh} when Protocol == http orelse Protocol == ssl -> - Req = new_request(Socket, Request, Headers), - call_body(Body, Req), - ?MODULE:after_response(Body, Req); - {Protocol, _, {http_header, _, Name, _, Value}} when Protocol == http orelse Protocol == ssl -> - headers(Socket, Request, [{Name, Value} | Headers], Body, - 1 + HeaderCount); + {Protocol, _, More} when Protocol =:= tcp orelse Protocol =:= ssl -> + case {Trunc, More} of + {false, <<"\n">>} -> + ok = mochiweb_socket:setopts(Socket, [{packet, raw}]), + parse_headers(Socket, Request, Body, + <>, []); + {false, <<"\r\n">>} -> + ok = mochiweb_socket:setopts(Socket, [{packet, raw}]), + parse_headers(Socket, Request, Body, + <>, []); + {_, More} -> + NewBin = <>, + AllButOne= size(More) - 1, + {Truncated, NewHdrCount} = + case More of + <<_:AllButOne/binary, "\n">> -> + {false, 1 + HeaderCount}; + _ -> + {true, HeaderCount} + end, + collect_headers(Socket, Request, Body, NewBin, + Truncated, NewHdrCount) + end; {tcp_closed, _} -> mochiweb_socket:close(Socket), exit(normal); + {ssl_closed, _} -> + mochiweb_socket:close(Socket), + exit(normal); _Other -> - handle_invalid_request(Socket, Request, Headers) + handle_invalid_request(Socket, Request, []) after ?HEADERS_RECV_TIMEOUT -> mochiweb_socket:close(Socket), exit(normal) end. +parse_headers(Socket, Request, Body, <<"\r\n">>, Headers) -> + Req = new_request(Socket, Request, lists:reverse(Headers)), + call_body(Body, Req), + ?MODULE:after_response(Body, Req); +parse_headers(Socket, Request, Body, Bin, Headers) -> + case erlang:decode_packet(httph, Bin, []) of + {ok, {http_header, _, Name, _, Value}, More} -> + parse_headers(Socket, Request, Body, More, + [{Name, Value} | Headers]); + {more, _} -> + handle_invalid_request(Socket, Request, Headers); + {error, _Reason} -> + mochiweb_socket:close(Socket), + exit(normal) + end. + call_body({M, F, A}, Req) -> erlang:apply(M, F, [Req | A]); call_body({M, F}, Req) -> @@ -290,4 +333,55 @@ range_skip_length_test() -> range_skip_length({BodySize, none}, BodySize)), ok. +long_request_line_test() -> + {ok, LS} = gen_tcp:listen(0, [binary, {active, false}]), + {ok, Port} = inet:port(LS), + spawn_link(fun() -> + {ok, S} = gen_tcp:accept(LS), + try + loop(S, {?MODULE, default_body}) + after + gen_tcp:close(S), + gen_tcp:close(LS) + end + end), + {ok, S} = gen_tcp:connect("localhost", Port, [binary, {active, false}]), + try + Req = "GET /" ++ string:chars($X, 8192) ++ " HTTP/1.1\r\n" + ++ "Host: localhost\r\n\r\n", + ok = gen_tcp:send(S, Req), + inet:setopts(S, [{packet, http}]), + ?assertEqual({ok, {http_response, {1,1}, 200, "OK"}}, + gen_tcp:recv(S, 0)), + ok + after + gen_tcp:close(S) + end. + +long_header_test() -> + {ok, LS} = gen_tcp:listen(0, [binary, {active, false}]), + {ok, Port} = inet:port(LS), + spawn_link(fun() -> + {ok, S} = gen_tcp:accept(LS), + try + loop(S, {?MODULE, default_body}) + after + gen_tcp:close(S), + gen_tcp:close(LS) + end + end), + {ok, S} = gen_tcp:connect("localhost", Port, [binary, {active, false}]), + try + Req = "GET / HTTP/1.1\r\n" + ++ "Host: localhost\r\n" + ++ "Link: /" ++ string:chars($X, 8192) ++ "\r\n\r\n", + ok = gen_tcp:send(S, Req), + inet:setopts(S, [{packet, http}]), + ?assertEqual({ok, {http_response, {1,1}, 200, "OK"}}, + gen_tcp:recv(S, 0)), + ok + after + gen_tcp:close(S) + end. + -endif.