Skip to content

Commit

Permalink
Merge pull request #227 from enidgjoleka/handle-tls-connection-during…
Browse files Browse the repository at this point in the history
…-cancellation

Handle tls connection during cancellation
  • Loading branch information
seriyps committed Apr 27, 2020
2 parents d8ca6d4 + 6b166f5 commit f0d07e3
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 81 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ epgsql:cancel(connection()) -> ok.

PostgreSQL protocol supports [cancellation](https://www.postgresql.org/docs/current/protocol-flow.html#id-1.10.5.7.9)
of currently executing command. `cancel/1` sends a cancellation request via the
new temporary TCP connection asynchronously, it doesn't await for the command to
new temporary TCP/TLS_over_TCP connection asynchronously, it doesn't await for the command to
be cancelled. Instead, client should expect to get
`{error, #error{code = <<"57014">>, codename = query_canceled}}` back from
the command that was cancelled. However, normal response can still be received as well.
Expand Down
142 changes: 75 additions & 67 deletions src/commands/epgsql_cmd_connect.erl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
%%%
-module(epgsql_cmd_connect).
-behaviour(epgsql_command).
-export([hide_password/1, opts_hide_password/1]).
-export([hide_password/1, opts_hide_password/1, open_socket/2]).
-export([init/1, execute/2, handle_message/4]).
-export_type([response/0, connect_error/0]).

Expand Down Expand Up @@ -47,22 +47,59 @@
init(#{host := _, username := _} = Opts) ->
#connect{opts = Opts}.

execute(PgSock, #connect{opts = #{host := Host} = Opts, stage = connect} = State) ->
Timeout = maps:get(timeout, Opts, 5000),
Deadline = deadline(Timeout),
Port = maps:get(port, Opts, 5432),
execute(PgSock, #connect{opts = #{username := Username} = Opts, stage = connect} = State) ->
SockOpts = [{active, false}, {packet, raw}, binary, {nodelay, true}, {keepalive, true}],
case gen_tcp:connect(Host, Port, SockOpts, Timeout) of
{ok, Sock} ->
client_handshake(Sock, PgSock, State, Deadline);
FilteredOpts = filter_sensitive_info(Opts),
PgSock1 = epgsql_sock:set_attr(connect_opts, FilteredOpts, PgSock),
case open_socket(SockOpts, Opts) of
{ok, Mode, Sock} ->
PgSock2 = epgsql_sock:set_net_socket(Mode, Sock, PgSock1),
Opts2 = ["user", 0, Username, 0],
Opts3 = case maps:find(database, Opts) of
error -> Opts2;
{ok, Database} -> [Opts2 | ["database", 0, Database, 0]]
end,
{Opts4, PgSock3} =
case Opts of
#{replication := Replication} ->
{[Opts3 | ["replication", 0, Replication, 0]],
epgsql_sock:init_replication_state(PgSock2)};
_ -> {Opts3, PgSock2}
end,
Opts5 = case Opts of
#{application_name := ApplicationName} ->
[Opts4 | ["application_name", 0, ApplicationName, 0]];
_ ->
Opts4
end,
ok = epgsql_sock:send(PgSock3, [<<196608:?int32>>, Opts5, 0]),
PgSock4 = case Opts of
#{async := Async} ->
epgsql_sock:set_attr(async, Async, PgSock3);
_ -> PgSock3
end,
{ok, PgSock4, State#connect{stage = maybe_auth}};
{error, Reason} = Error ->
{stop, Reason, Error, PgSock}
end;
execute(PgSock, #connect{stage = auth, auth_send = {PacketId, Data}} = St) ->
ok = epgsql_sock:send(PgSock, PacketId, Data),
{ok, PgSock, St#connect{auth_send = undefined}}.

client_handshake(Sock, PgSock, #connect{opts = #{username := Username} = Opts} = State, Deadline) ->
-spec open_socket([{atom(), any()}], epgsql:connect_opts()) ->
{ok , gen_tcp | ssl, port() | ssl:sslsocket()} | {error, any()}.
open_socket(SockOpts, #{host := Host} = ConnectOpts) ->
Timeout = maps:get(timeout, ConnectOpts, 5000),
Deadline = deadline(Timeout),
Port = maps:get(port, ConnectOpts, 5432),
case gen_tcp:connect(Host, Port, SockOpts, Timeout) of
{ok, Sock} ->
client_handshake(Sock, ConnectOpts, Deadline);
{error, _Reason} = Error ->
Error
end.

client_handshake(Sock, ConnectOpts, Deadline) ->
%% Increase the buffer size. Following the recommendation in the inet man page:
%%
%% It is recommended to have val(buffer) >=
Expand All @@ -71,46 +108,45 @@ client_handshake(Sock, PgSock, #connect{opts = #{username := Username} = Opts} =
{ok, [{recbuf, RecBufSize}, {sndbuf, SndBufSize}]} =
inet:getopts(Sock, [recbuf, sndbuf]),
inet:setopts(Sock, [{buffer, max(RecBufSize, SndBufSize)}]),
maybe_ssl(Sock, maps:get(ssl, ConnectOpts, false), ConnectOpts, Deadline).

case maybe_ssl(Sock, maps:get(ssl, Opts, false), Opts, PgSock, Deadline) of
maybe_ssl(Sock, false, _ConnectOpts, _Deadline) ->
{ok, gen_tcp, Sock};
maybe_ssl(Sock, Flag, ConnectOpts, Deadline) ->
ok = gen_tcp:send(Sock, <<8:?int32, 80877103:?int32>>),
Timeout0 = timeout(Deadline),
case gen_tcp:recv(Sock, 1, Timeout0) of
{ok, <<$S>>} ->
SslOpts = maps:get(ssl_opts, ConnectOpts, []),
Timeout = timeout(Deadline),
case ssl:connect(Sock, SslOpts, Timeout) of
{ok, Sock2} ->
{ok, ssl, Sock2};
{error, Reason} ->
Err = {ssl_negotiation_failed, Reason},
{error, Err}
end;
{ok, <<$N>>} ->
case Flag of
true ->
{ok, gen_tcp, Sock};
required ->
{error, ssl_not_available}
end;
{error, Reason} ->
{stop, Reason, {error, Reason}, PgSock};
PgSock1 ->
Opts2 = ["user", 0, Username, 0],
Opts3 = case maps:find(database, Opts) of
error -> Opts2;
{ok, Database} -> [Opts2 | ["database", 0, Database, 0]]
end,

{Opts4, PgSock2} =
case Opts of
#{replication := Replication} ->
{[Opts3 | ["replication", 0, Replication, 0]],
epgsql_sock:init_replication_state(PgSock1)};
_ -> {Opts3, PgSock1}
end,
Opts5 = case Opts of
#{application_name := ApplicationName} ->
[Opts3 | ["application_name", 0, ApplicationName, 0]];
_ ->
Opts4
end,
ok = epgsql_sock:send(PgSock2, [<<196608:?int32>>, Opts5, 0]),
PgSock3 = case Opts of
#{async := Async} ->
epgsql_sock:set_attr(async, Async, PgSock2);
_ -> PgSock2
end,
{ok, PgSock3, State#connect{stage = maybe_auth}}
{error, Reason}
end.


%% @doc Replace `password' in Opts map with obfuscated one
opts_hide_password(#{password := Password} = Opts) ->
HiddenPassword = hide_password(Password),
Opts#{password => HiddenPassword};
opts_hide_password(Opts) -> Opts.

%% @doc password and username are sensitive data that should not be stored in a
%% permanent state that might crash during code upgrade
filter_sensitive_info(Opts0) ->
maps:without([password, username], Opts0).

%% @doc this function wraps plaintext password to a lambda function, so, if
%% epgsql_sock process crashes when executing `connect' command, password will
Expand All @@ -124,34 +160,6 @@ hide_password(Password) when is_list(Password);
hide_password(PasswordFun) when is_function(PasswordFun, 0) ->
PasswordFun.


maybe_ssl(S, false, _, PgSock, _Deadline) ->
epgsql_sock:set_net_socket(gen_tcp, S, PgSock);
maybe_ssl(S, Flag, Opts, PgSock, Deadline) ->
ok = gen_tcp:send(S, <<8:?int32, 80877103:?int32>>),
Timeout0 = timeout(Deadline),
case gen_tcp:recv(S, 1, Timeout0) of
{ok, <<$S>>} ->
SslOpts = maps:get(ssl_opts, Opts, []),
Timeout = timeout(Deadline),
case ssl:connect(S, SslOpts, Timeout) of
{ok, S2} ->
epgsql_sock:set_net_socket(ssl, S2, PgSock);
{error, Reason} ->
Err = {ssl_negotiation_failed, Reason},
{error, Err}
end;
{ok, <<$N>>} ->
case Flag of
true ->
epgsql_sock:set_net_socket(gen_tcp, S, PgSock);
required ->
{error, ssl_not_available}
end;
{error, Reason} ->
{error, Reason}
end.

%% Auth sub-protocol

auth_init(<<?AUTH_CLEARTEXT:?int32>>, Sock, St) ->
Expand Down
29 changes: 16 additions & 13 deletions src/epgsql_sock.erl
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@
sync_required :: boolean() | undefined,
txstatus :: byte() | undefined, % $I | $T | $E,
complete_status :: atom() | {atom(), integer()} | undefined,
repl :: repl_state() | undefined}).
repl :: repl_state() | undefined,
connect_opts :: epgsql:connect_opts() | undefined}).

-opaque pg_sock() :: #state{}.

Expand Down Expand Up @@ -158,7 +159,9 @@ set_attr(codec, Codec, State) ->
set_attr(sync_required, Value, State) ->
State#state{sync_required = Value};
set_attr(replication_state, Value, State) ->
State#state{repl = Value}.
State#state{repl = Value};
set_attr(connect_opts, ConnectOpts, State) ->
State#state{connect_opts = ConnectOpts}.

%% XXX: be careful!
-spec set_packet_handler(atom(), pg_sock()) -> pg_sock().
Expand Down Expand Up @@ -225,17 +228,17 @@ handle_cast(stop, State) ->
{stop, normal, flush_queue(State, {error, closed})};

handle_cast(cancel, State = #state{backend = {Pid, Key},
sock = TimedOutSock}) ->
{ok, {Addr, Port}} = case State#state.mod of
gen_tcp -> inet:peername(TimedOutSock);
ssl -> ssl:peername(TimedOutSock)
end,
connect_opts = ConnectOpts,
mod = Mode}) ->
SockOpts = [{active, false}, {packet, raw}, binary],
%% TODO timeout
{ok, Sock} = gen_tcp:connect(Addr, Port, SockOpts),
Msg = <<16:?int32, 80877102:?int32, Pid:?int32, Key:?int32>>,
ok = gen_tcp:send(Sock, Msg),
gen_tcp:close(Sock),
case epgsql_cmd_connect:open_socket(SockOpts, ConnectOpts) of
{ok, Mode, Sock} ->
ok = apply(Mode, send, [Sock, Msg]),
apply(Mode, close, [Sock]);
{error, _Reason} ->
noop
end,
{noreply, State}.

handle_info({Closed, Sock}, #state{sock = Sock} = State)
Expand Down Expand Up @@ -372,8 +375,8 @@ send(#state{mod = Mod, sock = Sock}, Type, Data) ->
-spec send_multi(pg_sock(), [{byte(), iodata()}]) -> ok | {error, any()}.
send_multi(#state{mod = Mod, sock = Sock}, List) ->
do_send(Mod, Sock, lists:map(fun({Type, Data}) ->
epgsql_wire:encode_command(Type, Data)
end, List)).
epgsql_wire:encode_command(Type, Data)
end, List)).

do_send(gen_tcp, Sock, Bin) ->
%% Why not gen_tcp:send/2?
Expand Down
64 changes: 64 additions & 0 deletions test/epgsql_SUITE.erl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ groups() ->
connect_to_invalid_database,
connect_with_other_error,
connect_with_ssl,
cancel_query_for_connection_with_ssl,
cancel_query_for_connection_with_gen_tcp,
connect_with_client_cert,
connect_with_invalid_client_cert,
connect_to_closed_port,
Expand Down Expand Up @@ -171,6 +173,16 @@ end_per_group(_GroupName, _Config) ->
{routine, _} | _]
}}).

-define(QUERY_CANCELED, {error, #error{
severity = error,
code = <<"57014">>,
codename = query_canceled,
message = <<"canceling statement due to user request">>,
extra = [{file, <<"postgres.c">>},
{line, _},
{routine, _} | _]
}}).

%% From uuid.erl in http://gitorious.org/avtobiff/erlang-uuid
uuid_to_bin_string(<<U0:32, U1:16, U2:16, U3:16, U4:48>>) ->
iolist_to_binary(io_lib:format(
Expand Down Expand Up @@ -284,6 +296,58 @@ connect_with_ssl(Config) ->
"epgsql_test",
[{ssl, true}]).

cancel_query_for_connection_with_ssl(Config) ->
Module = ?config(module, Config),
{Host, Port} = epgsql_ct:connection_data(Config),
Module = ?config(module, Config),
Args2 = [ {port, Port}, {database, "epgsql_test_db1"}
| [ {ssl, true}
, {timeout, 1000} ]
],
{ok, C} = Module:connect(Host, "epgsql_test", Args2),
?assertMatch({ok, _Cols, [{true}]},
Module:equery(C, "select ssl_is_used()")),
Self = self(),
spawn_link(fun() ->
?assertMatch(?QUERY_CANCELED, Module:equery(C, "SELECT pg_sleep(5)")),
Self ! done
end),
%% this timer is needed for the test not to be flaky
timer:sleep(1000),
epgsql:cancel(C),
receive done ->
?assert(true)
after 5000 ->
epgsql:close(C),
?assert(false)
end,
epgsql_ct:flush().

cancel_query_for_connection_with_gen_tcp(Config) ->
Module = ?config(module, Config),
{Host, Port} = epgsql_ct:connection_data(Config),
Module = ?config(module, Config),
Args2 = [ {port, Port}, {database, "epgsql_test_db1"}
| [ {timeout, 1000} ]
],
{ok, C} = Module:connect(Host, "epgsql_test", Args2),
process_flag(trap_exit, true),
Self = self(),
spawn_link(fun() ->
?assertMatch(?QUERY_CANCELED, Module:equery(C, "SELECT pg_sleep(5)")),
Self ! done
end),
%% this timer is needed for the test not to be flaky
timer:sleep(1000),
epgsql:cancel(C),
receive done ->
?assert(true)
after 5000 ->
epgsql:close(C),
?assert(false)
end,
epgsql_ct:flush().

connect_with_client_cert(Config) ->
Module = ?config(module, Config),
Dir = filename:join(code:lib_dir(epgsql), ?TEST_DATA_DIR),
Expand Down

0 comments on commit f0d07e3

Please sign in to comment.