Skip to content

Commit

Permalink
Add support of CORS
Browse files Browse the repository at this point in the history
  • Loading branch information
manifest committed Feb 23, 2016
1 parent dbb6360 commit ade828e
Show file tree
Hide file tree
Showing 3 changed files with 415 additions and 0 deletions.
142 changes: 142 additions & 0 deletions src/cowboy_req.erl
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@
-export([has_resp_header/2]).
-export([has_resp_body/1]).
-export([delete_resp_header/2]).
-export([set_cors_headers/2]).
-export([set_cors_preflight_headers/2]).
-export([reply/2]).
-export([reply/3]).
-export([reply/4]).
Expand All @@ -86,6 +88,11 @@
-export([lock/1]).
-export([to_list/1]).

-type cors_allowed_origins() :: [binary()] | binary().
-type cors_allowed_methods() :: [binary()].
-type cors_allowed_headers() :: [binary()].
-type cors_max_age() :: non_neg_integer() | max.

-type cookie_opts() :: cow_cookie:cookie_opts().
-export_type([cookie_opts/0]).

Expand Down Expand Up @@ -666,6 +673,122 @@ delete_resp_header(Name, Req=#http_req{resp_headers=RespHeaders}) ->
RespHeaders2 = lists:keydelete(Name, 1, RespHeaders),
Req#http_req{resp_headers=RespHeaders2}.

-spec set_cors_headers(map(), Req) -> Req when Req :: req().
set_cors_headers(M, Req) ->
try
Origin =
match_cors_origin(
header(<<"origin">>, Req),
maps:get(origins, M, [])),

Req2 = set_cors_allow_credentials(maps:get(credentials, M, false), Origin, Req),
set_cors_exposed_headers(maps:get(exposed_headers, M, []), Req2)
catch throw:_Reason ->
Req
end.

-spec set_cors_preflight_headers(map(), Req) -> Req when Req :: req().
set_cors_preflight_headers(M, Req) ->
try
Origin =
match_cors_origin(
header(<<"origin">>, Req),
maps:get(origins, M, [])),
Method =
match_cors_method(
header(<<"access-control-request-method">>, Req),
maps:get(methods, M, [])),
Headers =
match_cors_headers(
header(<<"access-control-request-headers">>, Req),
maps:get(headers, M, [])),

Req2 = set_cors_allow_credentials(maps:get(credentials, M, false), Origin, Req),
Req3 = set_cors_max_age(maps:get(max_age, M, undefined), Req2),
Req4 = set_cors_allowed_methods([Method], Req3),
set_cors_allowed_headers(Headers, Req4)
catch throw:_Reason ->
Req
end.

-spec set_cors_allow_credentials(boolean(), binary(), Req) -> Req when Req :: req().
set_cors_allow_credentials(Credentials, Origin, Req) ->
case match_cors_credentials(Credentials, Origin) of
true ->
Req2 = set_resp_header(<<"access-control-allow-origin">>, Origin, Req),
set_resp_header(<<"access-control-allow-credentials">>, <<"true">>, Req2);
_ ->
set_resp_header(<<"access-control-allow-origin">>, Origin, Req)
end.

-spec set_cors_max_age(cors_max_age(), Req) -> Req when Req :: req().
set_cors_max_age(undefined, Req) ->
Req;
set_cors_max_age(max, Req) ->
set_resp_header(<<"access-control-max-age">>, <<"1728000">>, Req);
set_cors_max_age(Val, Req) ->
set_resp_header(<<"access-control-max-age">>, integer_to_binary(Val), Req).

-spec set_cors_allowed_methods(cors_allowed_methods(), Req) -> Req when Req :: req().
%% NOTE: just to make dialyzer happy. We would need this statement
%% if we decided to return an entire list of allowed methods
%% instead of single one passed with the particular request.
%% set_cors_allowed_methods([], Req) ->
%% Req;
set_cors_allowed_methods(Val, Req) ->
set_resp_header(<<"access-control-allow-methods">>, binary_join(Val, <<$,>>), Req).

-spec set_cors_allowed_headers(cors_allowed_headers(), Req) -> Req when Req :: req().
set_cors_allowed_headers([], Req) ->
Req;
set_cors_allowed_headers(Val, Req) ->
set_resp_header(<<"access-control-allow-headers">>, binary_join(Val, <<$,>>), Req).

-spec set_cors_exposed_headers(cors_allowed_headers(), Req) -> Req when Req :: req().
set_cors_exposed_headers([], Req) ->
Req;
set_cors_exposed_headers(L, Req) ->
set_resp_header(<<"access-control-expose-headers">>, binary_join(L, <<$,>>), Req).

-spec match_cors_origin(binary() | undefined, cors_allowed_origins()) -> binary().
match_cors_origin(undefined, Origins) ->
throw({bad_origin, undefined, Origins});
match_cors_origin(Val, Val) ->
Val;
match_cors_origin(Val, <<$*>>) ->
Val;
match_cors_origin(Val, Origins) when is_list(Origins) ->
case lists:member(Val, Origins) of
true -> Val;
_ -> throw({nomatch_origin, Val, Origins})
end;
match_cors_origin(Val, Origins) ->
throw({nomatch_origin, Val, Origins}).

-spec match_cors_method(binary() | undefined, cors_allowed_methods()) -> binary().
match_cors_method(undefined, Methods) ->
throw({bad_method, undefined, Methods});
match_cors_method(Val, Methods) ->
case lists:member(Val, Methods) of
true -> Val;
_ -> throw({nomatch_method, Val, Methods})
end.

-spec match_cors_headers(binary() | undefined, cors_allowed_headers()) -> cors_allowed_headers().
match_cors_headers(undefined, _) ->
[];
match_cors_headers(Val, Headers) ->
[case lists:member(Header, Headers) of
false -> throw({nomatch_header, Header, Headers});
_ -> Header
end || Header <- binary:split(Val, [<<$,>>, <<", ">>], [global, trim_all])].

-spec match_cors_credentials(boolean(), binary()) -> boolean().
match_cors_credentials(true, <<$*>>) ->
throw({bad_credentials, true, <<$*>>});
match_cors_credentials(Val, _) ->
Val.

-spec reply(cowboy:http_status(), Req) -> Req when Req::req().
reply(Status, Req=#http_req{resp_body=Body}) ->
reply(Status, [], Body, Req).
Expand Down Expand Up @@ -1244,6 +1367,15 @@ filter_constraints(Tail, Map, Key, Value, Constraints) ->
filter(Tail, Map#{Key => Value2})
end.

-spec binary_join(binary() | [binary()], binary()) -> binary().
binary_join([H|T], Sep) ->
lists:foldl(
fun(Val, Acc) ->
<<Acc/binary, Sep/binary, Val/binary>>
end, H, T).
%%binary_join([], _) -> <<>>;
%%binary_join(L, _) -> L.

%% Tests.

-ifdef(TEST).
Expand Down Expand Up @@ -1298,4 +1430,14 @@ merge_headers_test_() ->
{<<"server">>,<<"Cowboy">>}]}
],
[fun() -> Res = merge_headers(L,R) end || {L, R, Res} <- Tests].

binary_join_test_() ->
Sep = <<$,>>,
Test =
[%%{<<$b>>, <<"b">>},
%%{[], <<>>},
{[<<$a>>], <<$a>>},
{[<<$a>>, <<$b>>], <<"a,b">>}],
[fun() -> Output = binary_join(Input, Sep) end || {Input, Output} <- Test].

-endif.

0 comments on commit ade828e

Please sign in to comment.