Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle tls connection during cancellation #227

Merged
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ epgsql:cancel(connection()) -> ok.

PostgreSQL protocol supports [cancellation](https://www.postgresql.org/docs/current/protocol-flow.html#id-1.10.5.7.9)
of currently executing command. `cancel/1` sends a cancellation request via the
new temporary TCP connection asynchronously, it doesn't await for the command to
new temporary TCP/TLS_over_TCP connection asynchronously, it doesn't await for the command to
be cancelled. Instead, client should expect to get
`{error, #error{code = <<"57014">>, codename = query_canceled}}` back from
the command that was cancelled. However, normal response can still be received as well.
Expand Down
142 changes: 75 additions & 67 deletions src/commands/epgsql_cmd_connect.erl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
%%%
-module(epgsql_cmd_connect).
-behaviour(epgsql_command).
-export([hide_password/1, opts_hide_password/1]).
-export([hide_password/1, opts_hide_password/1, open_socket/2]).
-export([init/1, execute/2, handle_message/4]).
-export_type([response/0, connect_error/0]).

Expand Down Expand Up @@ -47,22 +47,59 @@
init(#{host := _, username := _} = Opts) ->
#connect{opts = Opts}.

execute(PgSock, #connect{opts = #{host := Host} = Opts, stage = connect} = State) ->
Timeout = maps:get(timeout, Opts, 5000),
Deadline = deadline(Timeout),
Port = maps:get(port, Opts, 5432),
execute(PgSock, #connect{opts = #{username := Username} = Opts, stage = connect} = State) ->
SockOpts = [{active, false}, {packet, raw}, binary, {nodelay, true}, {keepalive, true}],
case gen_tcp:connect(Host, Port, SockOpts, Timeout) of
{ok, Sock} ->
client_handshake(Sock, PgSock, State, Deadline);
FilteredOpts = filter_sensitive_info(Opts),
PgSock1 = epgsql_sock:set_attr(connect_opts, FilteredOpts, PgSock),
case open_socket(SockOpts, Opts) of
{ok, Mode, Sock} ->
PgSock2 = epgsql_sock:set_net_socket(Mode, Sock, PgSock1),
Opts2 = ["user", 0, Username, 0],
Opts3 = case maps:find(database, Opts) of
error -> Opts2;
{ok, Database} -> [Opts2 | ["database", 0, Database, 0]]
end,
{Opts4, PgSock3} =
case Opts of
#{replication := Replication} ->
{[Opts3 | ["replication", 0, Replication, 0]],
epgsql_sock:init_replication_state(PgSock2)};
_ -> {Opts3, PgSock2}
end,
Opts5 = case Opts of
#{application_name := ApplicationName} ->
[Opts4 | ["application_name", 0, ApplicationName, 0]];
enidgjoleka marked this conversation as resolved.
Show resolved Hide resolved
_ ->
Opts4
end,
ok = epgsql_sock:send(PgSock2, [<<196608:?int32>>, Opts5, 0]),
enidgjoleka marked this conversation as resolved.
Show resolved Hide resolved
PgSock4 = case Opts of
#{async := Async} ->
epgsql_sock:set_attr(async, Async, PgSock3);
_ -> PgSock3
end,
{ok, PgSock4, State#connect{stage = maybe_auth}};
{error, Reason} = Error ->
{stop, Reason, Error, PgSock}
end;
execute(PgSock, #connect{stage = auth, auth_send = {PacketId, Data}} = St) ->
ok = epgsql_sock:send(PgSock, PacketId, Data),
{ok, PgSock, St#connect{auth_send = undefined}}.

client_handshake(Sock, PgSock, #connect{opts = #{username := Username} = Opts} = State, Deadline) ->
-spec open_socket([{atom(), any()}], epgsql:connect_opts()) ->
{ok , gen_tcp | ssl, port() | ssl:sslsocket()} | {error, any()}.
open_socket(SockOpts, #{host := Host} = ConnectOpts) ->
Timeout = maps:get(timeout, ConnectOpts, 5000),
Deadline = deadline(Timeout),
Port = maps:get(port, ConnectOpts, 5432),
case gen_tcp:connect(Host, Port, SockOpts, Timeout) of
{ok, Sock} ->
client_handshake(Sock, ConnectOpts, Deadline);
{error, _Reason} = Error ->
Error
end.

client_handshake(Sock, ConnectOpts, Deadline) ->
%% Increase the buffer size. Following the recommendation in the inet man page:
%%
%% It is recommended to have val(buffer) >=
Expand All @@ -71,46 +108,45 @@ client_handshake(Sock, PgSock, #connect{opts = #{username := Username} = Opts} =
{ok, [{recbuf, RecBufSize}, {sndbuf, SndBufSize}]} =
inet:getopts(Sock, [recbuf, sndbuf]),
inet:setopts(Sock, [{buffer, max(RecBufSize, SndBufSize)}]),
maybe_ssl(Sock, maps:get(ssl, ConnectOpts, false), ConnectOpts, Deadline).

case maybe_ssl(Sock, maps:get(ssl, Opts, false), Opts, PgSock, Deadline) of
maybe_ssl(Sock, false, _ConnectOpts, _Deadline) ->
{ok, gen_tcp, Sock};
maybe_ssl(Sock, Flag, ConnectOpts, Deadline) ->
ok = gen_tcp:send(Sock, <<8:?int32, 80877103:?int32>>),
Timeout0 = timeout(Deadline),
case gen_tcp:recv(Sock, 1, Timeout0) of
{ok, <<$S>>} ->
SslOpts = maps:get(ssl_opts, ConnectOpts, []),
Timeout = timeout(Deadline),
case ssl:connect(Sock, SslOpts, Timeout) of
{ok, Sock2} ->
{ok, ssl, Sock2};
{error, Reason} ->
Err = {ssl_negotiation_failed, Reason},
{error, Err}
end;
{ok, <<$N>>} ->
case Flag of
true ->
{ok, gen_tcp, Sock};
required ->
{error, ssl_not_available}
end;
{error, Reason} ->
{stop, Reason, {error, Reason}, PgSock};
PgSock1 ->
Opts2 = ["user", 0, Username, 0],
Opts3 = case maps:find(database, Opts) of
error -> Opts2;
{ok, Database} -> [Opts2 | ["database", 0, Database, 0]]
end,

{Opts4, PgSock2} =
case Opts of
#{replication := Replication} ->
{[Opts3 | ["replication", 0, Replication, 0]],
epgsql_sock:init_replication_state(PgSock1)};
_ -> {Opts3, PgSock1}
end,
Opts5 = case Opts of
#{application_name := ApplicationName} ->
[Opts3 | ["application_name", 0, ApplicationName, 0]];
_ ->
Opts4
end,
ok = epgsql_sock:send(PgSock2, [<<196608:?int32>>, Opts5, 0]),
PgSock3 = case Opts of
#{async := Async} ->
epgsql_sock:set_attr(async, Async, PgSock2);
_ -> PgSock2
end,
{ok, PgSock3, State#connect{stage = maybe_auth}}
{error, Reason}
end.


%% @doc Replace `password' in Opts map with obfuscated one
opts_hide_password(#{password := Password} = Opts) ->
HiddenPassword = hide_password(Password),
Opts#{password => HiddenPassword};
opts_hide_password(Opts) -> Opts.

%% @doc password and username are sensitive data that should not be stored in a
%% permanent state that might crash during code upgrade
filter_sensitive_info(Opts0) ->
maps:without([password, username], Opts0).

%% @doc this function wraps plaintext password to a lambda function, so, if
%% epgsql_sock process crashes when executing `connect' command, password will
Expand All @@ -124,34 +160,6 @@ hide_password(Password) when is_list(Password);
hide_password(PasswordFun) when is_function(PasswordFun, 0) ->
PasswordFun.


maybe_ssl(S, false, _, PgSock, _Deadline) ->
epgsql_sock:set_net_socket(gen_tcp, S, PgSock);
maybe_ssl(S, Flag, Opts, PgSock, Deadline) ->
ok = gen_tcp:send(S, <<8:?int32, 80877103:?int32>>),
Timeout0 = timeout(Deadline),
case gen_tcp:recv(S, 1, Timeout0) of
{ok, <<$S>>} ->
SslOpts = maps:get(ssl_opts, Opts, []),
Timeout = timeout(Deadline),
case ssl:connect(S, SslOpts, Timeout) of
{ok, S2} ->
epgsql_sock:set_net_socket(ssl, S2, PgSock);
{error, Reason} ->
Err = {ssl_negotiation_failed, Reason},
{error, Err}
end;
{ok, <<$N>>} ->
case Flag of
true ->
epgsql_sock:set_net_socket(gen_tcp, S, PgSock);
required ->
{error, ssl_not_available}
end;
{error, Reason} ->
{error, Reason}
end.

%% Auth sub-protocol

auth_init(<<?AUTH_CLEARTEXT:?int32>>, Sock, St) ->
Expand Down
32 changes: 19 additions & 13 deletions src/epgsql_sock.erl
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@
sync_required :: boolean() | undefined,
txstatus :: byte() | undefined, % $I | $T | $E,
complete_status :: atom() | {atom(), integer()} | undefined,
repl :: repl_state() | undefined}).
repl :: repl_state() | undefined,
connect_opts :: epgsql:connect_opts() | undefined}).

-opaque pg_sock() :: #state{}.

Expand Down Expand Up @@ -158,7 +159,9 @@ set_attr(codec, Codec, State) ->
set_attr(sync_required, Value, State) ->
State#state{sync_required = Value};
set_attr(replication_state, Value, State) ->
State#state{repl = Value}.
State#state{repl = Value};
set_attr(connect_opts, ConnectOpts, State) ->
State#state{connect_opts = ConnectOpts}.

%% XXX: be careful!
-spec set_packet_handler(atom(), pg_sock()) -> pg_sock().
Expand Down Expand Up @@ -225,17 +228,20 @@ handle_cast(stop, State) ->
{stop, normal, flush_queue(State, {error, closed})};

handle_cast(cancel, State = #state{backend = {Pid, Key},
sock = TimedOutSock}) ->
{ok, {Addr, Port}} = case State#state.mod of
gen_tcp -> inet:peername(TimedOutSock);
ssl -> ssl:peername(TimedOutSock)
end,
connect_opts = ConnectOpts,
mod = Mode}) ->
SockOpts = [{active, false}, {packet, raw}, binary],
%% TODO timeout
{ok, Sock} = gen_tcp:connect(Addr, Port, SockOpts),
Msg = <<16:?int32, 80877102:?int32, Pid:?int32, Key:?int32>>,
ok = gen_tcp:send(Sock, Msg),
gen_tcp:close(Sock),
case epgsql_cmd_connect:open_socket(SockOpts, ConnectOpts) of
{ok, Mode, Sock} when Mode == gen_tcp ->
ok = gen_tcp:send(Sock, Msg),
gen_tcp:close(Sock);
{ok, Mode, Sock} when Mode == ssl ->
ok = ssl:send(Sock, Msg),
ssl:close(Sock);
enidgjoleka marked this conversation as resolved.
Show resolved Hide resolved
{error, _Reason} ->
noop
end,
{noreply, State}.

handle_info({Closed, Sock}, #state{sock = Sock} = State)
Expand Down Expand Up @@ -372,8 +378,8 @@ send(#state{mod = Mod, sock = Sock}, Type, Data) ->
-spec send_multi(pg_sock(), [{byte(), iodata()}]) -> ok | {error, any()}.
send_multi(#state{mod = Mod, sock = Sock}, List) ->
do_send(Mod, Sock, lists:map(fun({Type, Data}) ->
epgsql_wire:encode_command(Type, Data)
end, List)).
epgsql_wire:encode_command(Type, Data)
end, List)).

do_send(gen_tcp, Sock, Bin) ->
%% Why not gen_tcp:send/2?
Expand Down
73 changes: 73 additions & 0 deletions test/epgsql_SUITE.erl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ groups() ->
connect_to_invalid_database,
connect_with_other_error,
connect_with_ssl,
cancel_query_for_connection_with_ssl,
cancel_query_for_connection_with_gen_tcp,
connect_with_client_cert,
connect_with_invalid_client_cert,
connect_to_closed_port,
Expand Down Expand Up @@ -171,6 +173,16 @@ end_per_group(_GroupName, _Config) ->
{routine, _} | _]
}}).

-define(QUERY_CANCELED, {error, #error{
severity = error,
code = <<"57014">>,
codename = query_canceled,
message = <<"canceling statement due to user request">>,
extra = [{file, <<"postgres.c">>},
{line, _},
{routine, _} | _]
}}).

%% From uuid.erl in http://gitorious.org/avtobiff/erlang-uuid
uuid_to_bin_string(<<U0:32, U1:16, U2:16, U3:16, U4:48>>) ->
iolist_to_binary(io_lib:format(
Expand Down Expand Up @@ -284,6 +296,67 @@ connect_with_ssl(Config) ->
"epgsql_test",
[{ssl, true}]).

cancel_query_for_connection_with_ssl(Config) ->
Module = ?config(module, Config),
{Host, Port} = epgsql_ct:connection_data(Config),
Module = ?config(module, Config),
Args2 = [ {port, Port}, {database, "epgsql_test_db1"}
| [ {ssl, true}
, {timeout, 1000} ]
],
{ok, C} = Module:connect(Host, "epgsql_test", Args2),
?assertMatch({ok, _Cols, [{true}]},
Module:equery(C, "select ssl_is_used()")),
process_flag(trap_exit, true),
enidgjoleka marked this conversation as resolved.
Show resolved Hide resolved
Self = self(),
spawn_link(fun() ->
?assertMatch(?QUERY_CANCELED, Module:equery(C, "SELECT pg_sleep(5)")),
Self ! done
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

another option is to send the result to Self and assert in receiver:

Self ! {done, Module:equery(C, "..")}
end,
<..>
receive {done, Result} ->
    ?assertMatch(?QUERY_CANCELLED, Result)
after ....

end),
%% this will never match but introduces 1 second latency needed
%% for the test not to be flaky
receive none ->
noop
after 1000 ->
enidgjoleka marked this conversation as resolved.
Show resolved Hide resolved
epgsql:cancel(C),
receive done ->
?assert(true)
after 5000 ->
epgsql:close(C),
?assert(false)
end
end,
epgsql_ct:flush().

cancel_query_for_connection_with_gen_tcp(Config) ->
Module = ?config(module, Config),
{Host, Port} = epgsql_ct:connection_data(Config),
Module = ?config(module, Config),
Args2 = [ {port, Port}, {database, "epgsql_test_db1"}
| [ {timeout, 1000} ]
],
{ok, C} = Module:connect(Host, "epgsql_test", Args2),
process_flag(trap_exit, true),
Self = self(),
spawn_link(fun() ->
?assertMatch(?QUERY_CANCELED, Module:equery(C, "SELECT pg_sleep(5)")),
Self ! done
end),
%% this is will never match but it introduces the 1 second latency needed
%% for the test not to be flaky
receive none ->
noop
after 1000 ->
epgsql:cancel(C),
receive done ->
?assert(true)
after 5000 ->
epgsql:close(C),
?assert(false)
end
end,
epgsql_ct:flush().

connect_with_client_cert(Config) ->
Module = ?config(module, Config),
Dir = filename:join(code:lib_dir(epgsql), ?TEST_DATA_DIR),
Expand Down