diff --git a/src/cowboy_req.erl b/src/cowboy_req.erl index 8f0a04b52..f25817474 100644 --- a/src/cowboy_req.erl +++ b/src/cowboy_req.erl @@ -67,6 +67,8 @@ -export([has_resp_header/2]). -export([has_resp_body/1]). -export([delete_resp_header/2]). +-export([set_cors_headers/2]). +-export([set_cors_preflight_headers/2]). -export([reply/2]). -export([reply/3]). -export([reply/4]). @@ -86,6 +88,11 @@ -export([lock/1]). -export([to_list/1]). +-type cors_allowed_origins() :: [binary()] | binary(). +-type cors_allowed_methods() :: [binary()]. +-type cors_allowed_headers() :: [binary()]. +-type cors_max_age() :: non_neg_integer() | max. + -type cookie_opts() :: cow_cookie:cookie_opts(). -export_type([cookie_opts/0]). @@ -666,6 +673,122 @@ delete_resp_header(Name, Req=#http_req{resp_headers=RespHeaders}) -> RespHeaders2 = lists:keydelete(Name, 1, RespHeaders), Req#http_req{resp_headers=RespHeaders2}. +-spec set_cors_headers(map(), Req) -> Req when Req :: req(). +set_cors_headers(M, Req) -> + try + Origin = + match_cors_origin( + header(<<"origin">>, Req), + maps:get(origins, M, [])), + + Req2 = set_cors_allow_credentials(maps:get(credentials, M, false), Origin, Req), + set_cors_exposed_headers(maps:get(exposed_headers, M, []), Req2) + catch throw:_Reason -> + Req + end. + +-spec set_cors_preflight_headers(map(), Req) -> Req when Req :: req(). +set_cors_preflight_headers(M, Req) -> + try + Origin = + match_cors_origin( + header(<<"origin">>, Req), + maps:get(origins, M, [])), + Method = + match_cors_method( + header(<<"access-control-request-method">>, Req), + maps:get(methods, M, [])), + Headers = + match_cors_headers( + header(<<"access-control-request-headers">>, Req), + maps:get(headers, M, [])), + + Req2 = set_cors_allow_credentials(maps:get(credentials, M, false), Origin, Req), + Req3 = set_cors_max_age(maps:get(max_age, M, undefined), Req2), + Req4 = set_cors_allowed_methods([Method], Req3), + set_cors_allowed_headers(Headers, Req4) + catch throw:_Reason -> + Req + end. + +-spec set_cors_allow_credentials(boolean(), binary(), Req) -> Req when Req :: req(). +set_cors_allow_credentials(Credentials, Origin, Req) -> + case match_cors_credentials(Credentials, Origin) of + true -> + Req2 = set_resp_header(<<"access-control-allow-origin">>, Origin, Req), + set_resp_header(<<"access-control-allow-credentials">>, <<"true">>, Req2); + _ -> + set_resp_header(<<"access-control-allow-origin">>, Origin, Req) + end. + +-spec set_cors_max_age(cors_max_age(), Req) -> Req when Req :: req(). +set_cors_max_age(undefined, Req) -> + Req; +set_cors_max_age(max, Req) -> + set_resp_header(<<"access-control-max-age">>, <<"1728000">>, Req); +set_cors_max_age(Val, Req) -> + set_resp_header(<<"access-control-max-age">>, integer_to_binary(Val), Req). + +-spec set_cors_allowed_methods(cors_allowed_methods(), Req) -> Req when Req :: req(). +%% NOTE: just to make dialyzer happy. We would need this statement +%% if we decided to return an entire list of allowed methods +%% instead of single one passed with the particular request. +%% set_cors_allowed_methods([], Req) -> +%% Req; +set_cors_allowed_methods(Val, Req) -> + set_resp_header(<<"access-control-allow-methods">>, binary_join(Val, <<$,>>), Req). + +-spec set_cors_allowed_headers(cors_allowed_headers(), Req) -> Req when Req :: req(). +set_cors_allowed_headers([], Req) -> + Req; +set_cors_allowed_headers(Val, Req) -> + set_resp_header(<<"access-control-allow-headers">>, binary_join(Val, <<$,>>), Req). + +-spec set_cors_exposed_headers(cors_allowed_headers(), Req) -> Req when Req :: req(). +set_cors_exposed_headers([], Req) -> + Req; +set_cors_exposed_headers(L, Req) -> + set_resp_header(<<"access-control-expose-headers">>, binary_join(L, <<$,>>), Req). + +-spec match_cors_origin(binary() | undefined, cors_allowed_origins()) -> binary(). +match_cors_origin(undefined, Origins) -> + throw({bad_origin, undefined, Origins}); +match_cors_origin(Val, Val) -> + Val; +match_cors_origin(Val, <<$*>>) -> + Val; +match_cors_origin(Val, Origins) when is_list(Origins) -> + case lists:member(Val, Origins) of + true -> Val; + _ -> throw({nomatch_origin, Val, Origins}) + end; +match_cors_origin(Val, Origins) -> + throw({nomatch_origin, Val, Origins}). + +-spec match_cors_method(binary() | undefined, cors_allowed_methods()) -> binary(). +match_cors_method(undefined, Methods) -> + throw({bad_method, undefined, Methods}); +match_cors_method(Val, Methods) -> + case lists:member(Val, Methods) of + true -> Val; + _ -> throw({nomatch_method, Val, Methods}) + end. + +-spec match_cors_headers(binary() | undefined, cors_allowed_headers()) -> cors_allowed_headers(). +match_cors_headers(undefined, _) -> + []; +match_cors_headers(Val, Headers) -> + [case lists:member(Header, Headers) of + false -> throw({nomatch_header, Header, Headers}); + _ -> Header + end || Header <- binary:split(Val, [<<$,>>, <<", ">>], [global, trim_all])]. + +-spec match_cors_credentials(boolean(), binary()) -> boolean(). +match_cors_credentials(true, <<$*>>) -> + throw({bad_credentials, true, <<$*>>}); +match_cors_credentials(Val, _) -> + Val. + -spec reply(cowboy:http_status(), Req) -> Req when Req::req(). reply(Status, Req=#http_req{resp_body=Body}) -> reply(Status, [], Body, Req). @@ -1244,6 +1367,15 @@ filter_constraints(Tail, Map, Key, Value, Constraints) -> filter(Tail, Map#{Key => Value2}) end. +-spec binary_join(binary() | [binary()], binary()) -> binary(). +binary_join([H|T], Sep) -> + lists:foldl( + fun(Val, Acc) -> + <> + end, H, T). +%%binary_join([], _) -> <<>>; +%%binary_join(L, _) -> L. + %% Tests. -ifdef(TEST). @@ -1298,4 +1430,14 @@ merge_headers_test_() -> {<<"server">>,<<"Cowboy">>}]} ], [fun() -> Res = merge_headers(L,R) end || {L, R, Res} <- Tests]. + +binary_join_test_() -> + Sep = <<$,>>, + Test = + [%%{<<$b>>, <<"b">>}, + %%{[], <<>>}, + {[<<$a>>], <<$a>>}, + {[<<$a>>, <<$b>>], <<"a,b">>}], + [fun() -> Output = binary_join(Input, Sep) end || {Input, Output} <- Test]. + -endif. diff --git a/test/cors_SUITE.erl b/test/cors_SUITE.erl new file mode 100644 index 000000000..79eefb806 --- /dev/null +++ b/test/cors_SUITE.erl @@ -0,0 +1,257 @@ +%% Copyright (c) 2016, Andrei Nesterov +%% +%% Permission to use, copy, modify, and/or distribute this software for any +%% purpose with or without fee is hereby granted, provided that the above +%% copyright notice and this permission notice appear in all copies. +%% +%% THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES +%% WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +%% MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR +%% ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +%% WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +%% ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF +%% OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +-module(cors_SUITE). +-compile(export_all). + +-import(ct_helper, [config/2]). +-import(cowboy_test, [gun_open/1]). +-import(cowboy_test, [gun_open/2]). +-import(cowboy_test, [gun_down/1]). + +%% Definitions. +-define(ORIGIN_URI, <<"http://example.org">>). +-define(REQUEST_METHOD, <<"PUT">>). + +%% ct. + +all() -> + [ + {group, http}, + {group, https} + ]. + +groups() -> + Tests = ct_helper:all(?MODULE), + [ + {http, [parallel], Tests}, + {https, [parallel], Tests} + ]. + +init_per_group(Name = http, Config) -> + cowboy_test:init_http(Name, [ + {env, [{dispatch, init_dispatch(Config)}]} + ], Config); +init_per_group(Name = https, Config) -> + cowboy_test:init_https(Name, [ + {env, [{dispatch, init_dispatch(Config)}]} + ], Config). + +end_per_group(Name, _) -> + ok = cowboy:stop_listener(Name). + +%% Dispatch configuration. + +init_dispatch(_Config) -> + OriginsVal = ?ORIGIN_URI, + OriginsAny = <<$*>>, + OriginsList = + [<<"http://example.com">>, + <<"http://example.org:80">>, + <<"httpx://example.org">>, + ?ORIGIN_URI], + Methods = [<<"GET">>, <<"PUT">>], + ExposedHeaders = Headers = [<<"H1">>, <<"H2">>, <<"H3">>], + MaxAge = 0, + + cowboy_router:compile([ + {"localhost", [ + {"/origins/val", cors_echo, + [{hs, #{origins => OriginsVal}}, + {phs, #{origins => OriginsVal, methods => Methods}}]}, + {"/origins/any", cors_echo, + [{hs, #{origins => OriginsAny}}, + {phs, #{origins => OriginsAny, methods => Methods}}]}, + {"/origins/list", cors_echo, + [{hs, #{origins => OriginsList}}, + {phs, #{origins => OriginsList, methods => Methods}}]}, + {"/credentials/false", cors_echo, + [{hs, #{origins => OriginsVal, credentials => false}}, + {phs, #{origins => OriginsVal, credentials => false, methods => Methods}}]}, + {"/credentials/true", cors_echo, + [{hs, #{origins => OriginsVal, credentials => true}}, + {phs, #{origins => OriginsVal, credentials => true, methods => Methods}}]}, + {"/credentials/true/origins/any", cors_echo, + [{hs, #{origins => OriginsAny, credentials => true}}, + {phs, #{origins => OriginsAny, credentials => true, methods => Methods}}]}, + {"/exposed_headers/undef", cors_echo, + [{hs, #{origins => OriginsVal}}, + {phs, #{origins => OriginsVal, methods => Methods}}]}, + {"/exposed_headers/list", cors_echo, + [{hs, #{origins => OriginsVal, exposed_headers => ExposedHeaders}}, + {phs, #{origins => OriginsVal, methods => Methods}}]}, + {"/max_age/undef", cors_echo, + [{hs, #{origins => OriginsVal}}, + {phs, #{origins => OriginsVal, methods => Methods}}]}, + {"/max_age/val", cors_echo, + [{hs, #{origins => OriginsVal}}, + {phs, #{origins => OriginsVal, max_age => MaxAge, methods => Methods}}]}, + {"/methods/list", cors_echo, + [{hs, #{origins => OriginsVal}}, + {phs, #{origins => OriginsVal, methods => Methods}}]}, + {"/headers/list", cors_echo, + [{hs, #{origins => OriginsVal}}, + {phs, #{origins => OriginsVal, headers => Headers, methods => Methods}}]} + ]} + ]). + +%% Convenience functions. + +do_request(Path, Headers, Config) -> + do_request(?REQUEST_METHOD, Path, Headers, Config). + +do_request(Method, Path, Headers, Config) -> + ConnPid = gun_open(Config), + Ref = gun:request(ConnPid, Method, Path, Headers), + {response, fin, 200, RespHeaders} = gun:await(ConnPid, Ref), + RespHeaders. + +do_preflight_request(Path, Headers, Config) -> + do_preflight_request(?REQUEST_METHOD, Path, Headers, Config). + +do_preflight_request(Method, Path, Headers, Config) -> + Headers2 = [{<<"access-control-request-method">>, Method}|Headers], + do_request(<<"OPTIONS">>, Path, Headers2, Config). + +do_find_header(Key, Headers) -> + case lists:keyfind(Key, 1, Headers) of + false -> error; + {_, Val} -> {ok, Val} + end. + +%% Tests. + +origins(Config) -> + OriginH = {<<"origin">>, ?ORIGIN_URI}, + OriginNoMatchH = {<<"origin">>, <<>>}, + Tests = + [%% Origin isn't presented + {"/origins/val", [], error}, + %% Origin isn't allowed + {"/origins/val", [OriginNoMatchH], error}, + %% Single origin value is allowed + {"/origins/val", [OriginH], {ok, ?ORIGIN_URI}}, + %% Any origin is allowed + {"/origins/any", [OriginH], {ok, ?ORIGIN_URI}}, + %% Origin is presented in the allowed origins list + {"/origins/list", [OriginH], {ok, ?ORIGIN_URI}}, + %% Origin isn't presented in the allowed origins list + {"/origins/list", [OriginNoMatchH], error}], + + %% cors requests + [begin + RespHeaders = do_request(Path, Headers, Config), + Output = do_find_header(<<"access-control-allow-origin">>, RespHeaders) + end || {Path, Headers, Output} <- Tests], + + %% cors preflight requests + [begin + RespHeaders = do_preflight_request(Path, Headers, Config), + MaybeOrigin = do_find_header(<<"access-control-allow-origin">>, RespHeaders) + end || {Path, Headers, MaybeOrigin} <- Tests]. + +credentials(Config) -> + OriginH = {<<"origin">>, ?ORIGIN_URI}, + Tests = + [%% Credentials aren't supported + {"/credentials/false", [OriginH], error}, + %% Credentials are supported for this particular origin + {"/credentials/true", [OriginH], {ok, <<"true">>}}, + %% Credentials are supported for any origin + {"/credentials/true/origins/any", [OriginH], {ok, <<"true">>}}], + + %% cors requests + [begin + RespHeaders = do_request(Path, Headers, Config), + MaybeCredentials = do_find_header(<<"access-control-allow-credentials">>, RespHeaders), + {ok, ?ORIGIN_URI} = do_find_header(<<"access-control-allow-origin">>, RespHeaders) + end || {Path, Headers, MaybeCredentials} <- Tests], + + %% cors preflight requests + [begin + RespHeaders = do_preflight_request(Path, Headers, Config), + MaybeCredentials = do_find_header(<<"access-control-allow-credentials">>, RespHeaders), + {ok, ?ORIGIN_URI} = do_find_header(<<"access-control-allow-origin">>, RespHeaders) + end || {Path, Headers, MaybeCredentials} <- Tests]. + +exposed_headers(Config) -> + OriginH = {<<"origin">>, ?ORIGIN_URI}, + Tests = + [%% Exposed headers isn't set + {"/exposed_headers/undef", [OriginH], error}, + %% Exposed headers is set + {"/exposed_headers/list", [OriginH], {ok, <<"H1,H2,H3">>}}], + + %% cors requests + [begin + RespHeaders = do_request(Path, Headers, Config), + MaybeExposedHeaders = do_find_header(<<"access-control-expose-headers">>, RespHeaders), + {ok, ?ORIGIN_URI} = do_find_header(<<"access-control-allow-origin">>, RespHeaders) + end || {Path, Headers, MaybeExposedHeaders} <- Tests]. + +max_age(Config) -> + OriginH = {<<"origin">>, ?ORIGIN_URI}, + Tests = + [%% Max age isn't set + {"/max_age/undef", [OriginH], error}, + %% Max age is set + {"/max_age/val", [OriginH], {ok, <<"0">>}}], + + %% cors preflight requests + [begin + RespHeaders = do_preflight_request(Path, Headers, Config), + MaybeMaxAge = do_find_header(<<"access-control-max-age">>, RespHeaders), + {ok, ?ORIGIN_URI} = do_find_header(<<"access-control-allow-origin">>, RespHeaders) + end || {Path, Headers, MaybeMaxAge} <- Tests]. + +methods(Config) -> + OriginH = {<<"origin">>, ?ORIGIN_URI}, + MethodH = fun(Val) -> {<<"access-control-request-method">>, Val} end, + Tests = + [%% Method isn't presented + {"/methods/list", [OriginH], error, error}, + %% Method isn't allowed + {"/methods/list", [OriginH, MethodH(<<"PATCH">>)], error, error}, + %% Method is allowed + {"/methods/list", [OriginH, MethodH(?REQUEST_METHOD)], {ok, ?REQUEST_METHOD}, {ok, ?ORIGIN_URI}}], + + %% cors preflight requests + [begin + RespHeaders = do_request(<<"OPTIONS">>, Path, Headers, Config), + MaybeMethods = do_find_header(<<"access-control-allow-methods">>, RespHeaders), + MaybeOrigin = do_find_header(<<"access-control-allow-origin">>, RespHeaders) + end || {Path, Headers, MaybeMethods, MaybeOrigin} <- Tests]. + +headers(Config) -> + OriginH = {<<"origin">>, ?ORIGIN_URI}, + HeadersH = fun(Val) -> {<<"access-control-request-headers">>, Val} end, + Tests = + [%% Headers aren't presented + {"/headers/list", [OriginH], error, {ok, ?ORIGIN_URI}}, + %% Headers arent't allowed + {"/headers/list", [OriginH, HeadersH(<<"H8">>)], error, error}, + {"/headers/list", [OriginH, HeadersH(<<"H8,H9">>)], error, error}, + {"/headers/list", [OriginH, HeadersH(<<"H1,H9">>)], error, error}, + %% Headers are allowed + {"/headers/list", [OriginH, HeadersH(<<>>)], error, {ok, ?ORIGIN_URI}}, + {"/headers/list", [OriginH, HeadersH(<<"H1">>)], {ok, <<"H1">>}, {ok, ?ORIGIN_URI}}, + {"/headers/list", [OriginH, HeadersH(<<"H1,H2">>)], {ok, <<"H1,H2">>}, {ok, ?ORIGIN_URI}}], + + %% cors preflight requests + [begin + RespHeaders = do_preflight_request(Path, Headers, Config), + MaybeHeaders = do_find_header(<<"access-control-allow-headers">>, RespHeaders), + MaybeOrigin = do_find_header(<<"access-control-allow-origin">>, RespHeaders) + end || {Path, Headers, MaybeHeaders, MaybeOrigin} <- Tests]. + diff --git a/test/cors_SUITE_data/cors_echo.erl b/test/cors_SUITE_data/cors_echo.erl new file mode 100644 index 000000000..541c8c165 --- /dev/null +++ b/test/cors_SUITE_data/cors_echo.erl @@ -0,0 +1,16 @@ +%% Feel free to use, reuse and abuse the code in this file. + +-module(cors_echo). + +-export([init/2]). + +init(Req, Opts) -> + {_, Hs} = lists:keyfind(hs, 1, Opts), + {_, PHs} = lists:keyfind(phs, 1, Opts), + Req2 = + case cowboy_req:method(Req) of + <<"OPTIONS">> -> cowboy_req:set_cors_preflight_headers(PHs, Req); + _ -> cowboy_req:set_cors_headers(Hs, Req) + end, + {ok, cowboy_req:reply(200, Req2), Opts}. +