175 changes: 148 additions & 27 deletions src/epgsql_sock.erl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@
%%% some conflicting low-level commands (such as `parse', `bind', `execute') are
%%% executed in a wrong order. In this case server and epgsql states become out of
%%% sync and {@link epgsql_cmd_sync} has to be executed in order to recover.
%%%
%%% {@link epgsql_cmd_copy_from_stdin} and {@link epgsql_cmd_start_replication} switch the
%%% "state machine" of the connection process to a special "COPY mode" subprotocol.
%%% See [https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-COPY].
%%% @see epgsql_cmd_connect. epgsql_cmd_connect for network connection and authentication setup
%%% @end
%%% Copyright (C) 2009 - Will Glozer. All rights reserved.
Expand All @@ -46,40 +50,44 @@
get_parameter/2,
set_notice_receiver/2,
get_cmd_status/1,
cancel/1]).
cancel/1,
copy_send_rows/3,
standby_status_update/3]).

-export([handle_call/3, handle_cast/2, handle_info/2]).
-export([init/1, code_change/3, terminate/2]).

%% loop callback
-export([on_message/3, on_replication/3]).
-export([on_message/3, on_replication/3, on_copy_from_stdin/3]).

%% Command's APIs
-export([set_net_socket/3, init_replication_state/1, set_attr/3, get_codec/1,
get_rows/1, get_results/1, notify/2, send/2, send/3, send_multi/2,
get_parameter_internal/2,
get_replication_state/1, set_packet_handler/2]).
get_subproto_state/1, set_packet_handler/2]).

-export_type([transport/0, pg_sock/0, error/0]).

-include("epgsql.hrl").
-include("protocol.hrl").
-include("epgsql_replication.hrl").
-include("epgsql_copy.hrl").

-type transport() :: {call, any()}
| {cast, pid(), reference()}
| {incremental, pid(), reference()}.

-type tcp_socket() :: port(). %gen_tcp:socket() isn't exported prior to erl 18
-type repl_state() :: #repl{}.
-type copy_state() :: #copy{}.

-type error() :: {error, sync_required | closed | sock_closed | sock_error}.

-record(state, {mod :: gen_tcp | ssl | undefined,
sock :: tcp_socket() | ssl:sslsocket() | undefined,
data = <<>>,
backend :: {Pid :: integer(), Key :: integer()} | undefined,
handler = on_message :: on_message | on_replication | undefined,
handler = on_message :: on_message | on_replication | on_copy_from_stdin | undefined,
codec :: epgsql_binary:codec() | undefined,
queue = queue:new() :: queue:queue({epgsql_command:command(), any(), transport()}),
current_cmd :: epgsql_command:command() | undefined,
Expand All @@ -92,11 +100,17 @@
sync_required :: boolean() | undefined,
txstatus :: byte() | undefined, % $I | $T | $E,
complete_status :: atom() | {atom(), integer()} | undefined,
repl :: repl_state() | undefined,
subproto_state :: repl_state() | copy_state() | undefined,
connect_opts :: epgsql:connect_opts() | undefined}).

-opaque pg_sock() :: #state{}.

-ifndef(OTP_RELEASE). % pre-OTP21
-define(WITH_STACKTRACE(T, R, S), T:R -> S = erlang:get_stacktrace(), ).
-else.
-define(WITH_STACKTRACE(T, R, S), T:R:S ->).
-endif.

%% -- client interface --

start_link() ->
Expand Down Expand Up @@ -131,6 +145,12 @@ get_cmd_status(C) ->
cancel(S) ->
gen_server:cast(S, cancel).

%% @doc Synchronously hands a batch of rows to the connection process `C' as a
%% `{copy_send_rows, Rows}' call, waiting at most `Timeout' for the reply.
copy_send_rows(C, Rows, Timeout) ->
    Request = {copy_send_rows, Rows},
    gen_server:call(C, Request, Timeout).

%% @doc Synchronously passes the flushed and applied LSN values to the connection
%% process `C' as a `{standby_status_update, FlushedLSN, AppliedLSN}' call.
%% No explicit timeout is given, so the `gen_server:call/2' default applies.
standby_status_update(C, FlushedLSN, AppliedLSN) ->
    Request = {standby_status_update, FlushedLSN, AppliedLSN},
    gen_server:call(C, Request).


%% -- command APIs --

Expand All @@ -145,7 +165,7 @@ set_net_socket(Mod, Socket, State) ->

-spec init_replication_state(pg_sock()) -> pg_sock().
init_replication_state(State) ->
State#state{repl = #repl{}}.
State#state{subproto_state = #repl{}}.

-spec set_attr(atom(), any(), pg_sock()) -> pg_sock().
set_attr(backend, {_Pid, _Key} = Backend, State) ->
Expand All @@ -158,8 +178,8 @@ set_attr(codec, Codec, State) ->
State#state{codec = Codec};
set_attr(sync_required, Value, State) ->
State#state{sync_required = Value};
set_attr(replication_state, Value, State) ->
State#state{repl = Value};
set_attr(subproto_state, Value, State) ->
State#state{subproto_state = Value};
set_attr(connect_opts, ConnectOpts, State) ->
State#state{connect_opts = ConnectOpts}.

Expand All @@ -172,9 +192,9 @@ set_packet_handler(Handler, State) ->
get_codec(#state{codec = Codec}) ->
Codec.

-spec get_replication_state(pg_sock()) -> repl_state().
get_replication_state(#state{repl = Repl}) ->
Repl.
%% Accessor for the `subproto_state' field of the connection state
%% (a `#repl{}' record, a `#copy{}' record or `undefined').
-spec get_subproto_state(pg_sock()) -> repl_state() | copy_state() | undefined.
get_subproto_state(#state{subproto_state = SubprotoState}) ->
    SubprotoState.

-spec get_rows(pg_sock()) -> [tuple()].
get_rows(#state{rows = Rows}) ->
Expand All @@ -197,6 +217,10 @@ get_parameter_internal(Name, #state{parameters = Parameters}) ->
%% gen_server callback: the process starts with a `#state{}' record holding
%% only the record's default field values.
init([]) ->
    InitialState = #state{},
    {ok, InitialState}.

handle_call({command, Command, Args}, From, State) ->
Transport = {call, From},
command_new(Transport, Command, Args, State);

handle_call({get_parameter, Name}, _From, State) ->
{reply, {ok, get_parameter_internal(Name, State)}, State};

Expand All @@ -208,14 +232,16 @@ handle_call(get_cmd_status, _From, #state{complete_status = Status} = State) ->

handle_call({standby_status_update, FlushedLSN, AppliedLSN}, _From,
#state{handler = on_replication,
repl = #repl{last_received_lsn = ReceivedLSN} = Repl} = State) ->
subproto_state = #repl{last_received_lsn = ReceivedLSN} = Repl} = State) ->
send(State, ?COPY_DATA, epgsql_wire:encode_standby_status_update(ReceivedLSN, FlushedLSN, AppliedLSN)),
Repl1 = Repl#repl{last_flushed_lsn = FlushedLSN,
last_applied_lsn = AppliedLSN},
{reply, ok, State#state{repl = Repl1}};
handle_call({command, Command, Args}, From, State) ->
Transport = {call, From},
command_new(Transport, Command, Args, State).
{reply, ok, State#state{subproto_state = Repl1}};

handle_call({copy_send_rows, Rows}, _From,
#state{handler = Handler, subproto_state = CopyState} = State) ->
Response = handle_copy_send_rows(Rows, Handler, CopyState, State),
{reply, Response, State}.

handle_cast({{Method, From, Ref} = Transport, Command, Args}, State)
when ((Method == cast) or (Method == incremental)),
Expand All @@ -241,6 +267,10 @@ handle_cast(cancel, State = #state{backend = {Pid, Key},
end,
{noreply, State}.

handle_info({DataTag, Sock, Data2}, #state{data = Data, sock = Sock} = State)
when DataTag == tcp; DataTag == ssl ->
loop(State#state{data = <<Data/binary, Data2/binary>>});

handle_info({Closed, Sock}, #state{sock = Sock} = State)
when Closed == tcp_closed; Closed == ssl_closed ->
{stop, sock_closed, flush_queue(State#state{sock = undefined}, {error, sock_closed})};
Expand All @@ -256,8 +286,10 @@ handle_info({inet_reply, _, ok}, State) ->
handle_info({inet_reply, _, Status}, State) ->
{stop, Status, flush_queue(State, {error, Status})};

handle_info({_, Sock, Data2}, #state{data = Data, sock = Sock} = State) ->
loop(State#state{data = <<Data/binary, Data2/binary>>}).
handle_info({io_request, From, ReplyAs, Request}, State) ->
Response = handle_io_request(Request, State),
io_reply(Response, From, ReplyAs),
{noreply, State}.

terminate(_Reason, #state{sock = undefined}) -> ok;
terminate(_Reason, #state{mod = gen_tcp, sock = Sock}) -> gen_tcp:close(Sock);
Expand Down Expand Up @@ -398,7 +430,7 @@ do_send(gen_tcp, Sock, Bin) ->
do_send(ssl, Sock, Bin) ->
ssl:send(Sock, Bin).

loop(#state{data = Data, handler = Handler, repl = Repl} = State) ->
loop(#state{data = Data, handler = Handler, subproto_state = Repl} = State) ->
case epgsql_wire:decode_message(Data) of
{Type, Payload, Tail} ->
case ?MODULE:Handler(Type, Payload, State#state{data = Tail}) of
Expand All @@ -409,14 +441,16 @@ loop(#state{data = Data, handler = Handler, repl = Repl} = State) ->
end;
_ ->
%% in replication mode send feedback after each batch of messages
case (Repl =/= undefined) andalso (Repl#repl.feedback_required) of
case Handler == on_replication
andalso (Repl =/= undefined)
andalso (Repl#repl.feedback_required) of
true ->
#repl{last_received_lsn = LastReceivedLSN,
last_flushed_lsn = LastFlushedLSN,
last_applied_lsn = LastAppliedLSN} = Repl,
send(State, ?COPY_DATA, epgsql_wire:encode_standby_status_update(
LastReceivedLSN, LastFlushedLSN, LastAppliedLSN)),
{noreply, State#state{repl = Repl#repl{feedback_required = false}}};
{noreply, State#state{subproto_state = Repl#repl{feedback_required = false}}};
_ ->
{noreply, State}
end
Expand Down Expand Up @@ -486,6 +520,74 @@ flush_queue(#state{current_cmd = undefined} = State, _) ->
flush_queue(State, Error) ->
flush_queue(finish(State, Error), Error).

%% @doc Handler for IO protocol version of COPY FROM STDIN
%%
%% COPY FROM STDIN is implemented as Erlang
%% <a href="https://erlang.org/doc/apps/stdlib/io_protocol.html">io protocol</a>.
%%
%% Takes the `Request' part of an `{io_request, From, ReplyAs, Request}' message
%% and the current connection state; returns the term to be sent back in the
%% `io_reply' message:
%% <ul>
%%  <li>`{error, not_in_copy_mode}' - the connection is not in COPY FROM STDIN mode</li>
%%  <li>`{error, Err}' - the server has already reported error `Err' for this COPY</li>
%%  <li>result of `send/3' - for `put_chars' requests</li>
%%  <li>`{error, {fun_return_not_characters, Other}}' /
%%      `{error, {fun_exception, {Class, Reason, Stack}}}' - for the M:F(A) form of
%%      `put_chars' when the function misbehaves</li>
%%  <li>`{error, request}' - for any request that is not supported</li>
%% </ul>
handle_io_request(_, #state{handler = Handler}) when Handler =/= on_copy_from_stdin ->
    %% Received an IO request when `epgsql_cmd_copy_from_stdin' has not been called yet,
    %% or it was terminated with an error and `ReadyForQuery' has already been received
    {error, not_in_copy_mode};
handle_io_request(_, #state{subproto_state = #copy{last_error = Err}}) when Err =/= undefined ->
    %% Server has asynchronously reported an error in the data stream
    {error, Err};
handle_io_request({put_chars, Encoding, Chars}, State) ->
    send(State, ?COPY_DATA, encode_chars(Encoding, Chars));
handle_io_request({put_chars, Encoding, Mod, Fun, Args}, State) ->
    %% Only the user-supplied function call is protected by `try'; the actual
    %% sending happens in the unprotected `of' section
    try apply(Mod, Fun, Args) of
        Chars when is_binary(Chars);
                   is_list(Chars) ->
            handle_io_request({put_chars, Encoding, Chars}, State);
        Other ->
            {error, {fun_return_not_characters, Other}}
    catch ?WITH_STACKTRACE(T, R, S)
            {error, {fun_exception, {T, R, S}}}
    end;
handle_io_request({setopts, _}, _State) ->
    {error, request};
handle_io_request(getopts, _State) ->
    {error, request};
handle_io_request({requests, Requests}, State) ->
    try_requests(Requests, State, ok);
handle_io_request(_UnsupportedRequest, _State) ->
    %% The IO protocol requires an `{error, request}' reply to every request the
    %% IO server does not recognize (input requests, `get_geometry', legacy
    %% requests without encoding, ...). Without this clause such a request would
    %% crash the whole connection process with `function_clause'.
    {error, request}.

%% Runs a list of IO requests one after another.
%% Stops at the first request that yields `{error, _}' and returns that error;
%% otherwise returns the result of the last request (or the initial accumulator
%% when the list is empty).
try_requests([], _State, LastResult) ->
    LastResult;
try_requests([Request | Rest], State, _PreviousResult) ->
    case handle_io_request(Request, State) of
        {error, _} = Failure ->
            Failure;
        Result ->
            try_requests(Rest, State, Result)
    end.

%% Sends the IO protocol reply message `{io_reply, ReplyAs, Result}' to `From'.
io_reply(Result, From, ReplyAs) ->
    Reply = {io_reply, ReplyAs, Result},
    erlang:send(From, Reply).

%% @doc Implementation of the `copy_send_rows' call
%%
%% Rows are only accepted when the COPY was started in binary format.
%% Text / csv formats could be supported in theory, but that would need extra
%% callbacks in the `epgsql_type' behaviour (eg, `encode_text').
handle_copy_send_rows(_Rows, Handler, _CopyState, _State)
  when Handler =/= on_copy_from_stdin ->
    {error, not_in_copy_mode};
handle_copy_send_rows(_Rows, _Handler, #copy{format = Fmt}, _State)
  when Fmt =/= binary ->
    %% anything but the "binary" format is rejected
    {error, not_binary_format};
handle_copy_send_rows(_Rows, _Handler, #copy{last_error = Reported}, _State)
  when Reported =/= undefined ->
    %% an error for this data stream has already arrived from the server
    {error, Reported};
handle_copy_send_rows(Rows, _Handler, #copy{binary_types = ColumnTypes}, State) ->
    Codec = get_codec(State),
    Encoded = [epgsql_wire:encode_copy_row(Row, ColumnTypes, Codec) || Row <- Rows],
    ok = send(State, ?COPY_DATA, Encoded).

%% Converts the payload of a `put_chars' IO request to a binary.
%% Binaries are passed through untouched regardless of the declared encoding;
%% character lists declared as `unicode' or `latin1' are converted with
%% `unicode:characters_to_binary/2' (input encoding as declared, UTF-8 output).
encode_chars(_Encoding, Chars) when is_binary(Chars) ->
    Chars;
encode_chars(Encoding, Chars) when is_list(Chars), Encoding =:= unicode;
                                   is_list(Chars), Encoding =:= latin1 ->
    unicode:characters_to_binary(Chars, Encoding).


%% Returns the argument as a binary: lists go through `list_to_binary/1',
%% binaries are returned unchanged.
to_binary(Chars) when is_list(Chars) ->
    list_to_binary(Chars);
to_binary(Bin) when is_binary(Bin) ->
    Bin.

Expand Down Expand Up @@ -547,12 +649,31 @@ on_message(?NOTIFICATION, <<Pid:?int32, Strings/binary>>, State) ->
on_message(Msg, Payload, State) ->
command_handle_message(Msg, Payload, State).

%% @doc Packet handler for the "copy subprotocol" of COPY .. FROM STDIN
%%
%% Activated by `epgsql_cmd_copy_from_stdin', deactivated by `epgsql_cmd_copy_done' or error
on_copy_from_stdin(?READY_FOR_QUERY, <<TxStatus:8>>,
                   #state{subproto_state = #copy{last_error = Reason,
                                                 initiator = Initiator}} = State0)
  when Reason =/= undefined ->
    %% The error is delivered to the initiator here rather than in the ?ERROR clause,
    %% so it's easier to be in sync state
    Initiator ! {epgsql, self(), {error, Reason}},
    State1 = State0#state{subproto_state = undefined,
                          handler = on_message,
                          txstatus = TxStatus},
    {noreply, State1};
on_copy_from_stdin(?ERROR, Payload, #state{subproto_state = CopyState0} = State0) ->
    %% Only remember the error; it is reported when ?READY_FOR_QUERY arrives
    Decoded = epgsql_wire:decode_error(Payload),
    CopyState1 = CopyState0#copy{last_error = Decoded},
    {noreply, State0#state{subproto_state = CopyState1}};
on_copy_from_stdin(Type, Payload, State) when Type == ?NOTICE;
                                              Type == ?NOTIFICATION;
                                              Type == ?PARAMETER_STATUS ->
    %% These packets are handled the same way as outside of copy mode
    on_message(Type, Payload, State).


%% CopyData for Replication mode
on_replication(?COPY_DATA, <<?PRIMARY_KEEPALIVE_MESSAGE:8, LSN:?int64, _Timestamp:?int64, ReplyRequired:8>>,
#state{repl = #repl{last_flushed_lsn = LastFlushedLSN,
last_applied_lsn = LastAppliedLSN,
align_lsn = AlignLsn} = Repl} = State) ->
#state{subproto_state = #repl{last_flushed_lsn = LastFlushedLSN,
last_applied_lsn = LastAppliedLSN,
align_lsn = AlignLsn} = Repl} = State) ->
Repl1 =
case ReplyRequired of
1 when AlignLsn ->
Expand All @@ -569,14 +690,14 @@ on_replication(?COPY_DATA, <<?PRIMARY_KEEPALIVE_MESSAGE:8, LSN:?int64, _Timestam
Repl#repl{feedback_required = true,
last_received_lsn = LSN}
end,
{noreply, State#state{repl = Repl1}};
{noreply, State#state{subproto_state = Repl1}};

%% CopyData for Replication mode
on_replication(?COPY_DATA, <<?X_LOG_DATA, StartLSN:?int64, EndLSN:?int64,
_Timestamp:?int64, WALRecord/binary>>,
#state{repl = Repl} = State) ->
#state{subproto_state = Repl} = State) ->
Repl1 = handle_xlog_data(StartLSN, EndLSN, WALRecord, Repl),
{noreply, State#state{repl = Repl1}};
{noreply, State#state{subproto_state = Repl1}};
on_replication(?ERROR, Err, State) ->
Reason = epgsql_wire:decode_error(Err),
{stop, {error, Reason}, State};
Expand Down
50 changes: 48 additions & 2 deletions src/epgsql_wire.erl
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,16 @@
encode_formats/1,
format/2,
encode_parameters/2,
encode_standby_status_update/3]).
encode_standby_status_update/3,
encode_copy_header/0,
encode_copy_row/3,
encode_copy_trailer/0]).
%% Encoders for Client -> Server packets
-export([encode_query/1,
encode_parse/3,
encode_describe/2,
encode_bind/4,
encode_copy_done/0,
encode_execute/2,
encode_close/2,
encode_flush/0,
Expand Down Expand Up @@ -213,6 +217,7 @@ decode_complete(Bin) ->
["DELETE", Rows] -> {delete, list_to_integer(Rows)};
["MOVE", Rows] -> {move, list_to_integer(Rows)};
["FETCH", Rows] -> {fetch, list_to_integer(Rows)};
["COPY", Rows] -> {copy, list_to_integer(Rows)};
[Type | _Rest] -> lower_atom(Type)
end.

Expand Down Expand Up @@ -251,7 +256,8 @@ format(#column{oid = Oid}, Codec) ->
end.

%% @doc encode parameters for 'Bind'
-spec encode_parameters([], epgsql_binary:codec()) -> iolist().
-spec encode_parameters([{epgsql:epgsql_type(), epgsql:bind_param()}],
epgsql_binary:codec()) -> iolist().
encode_parameters(Parameters, Codec) ->
encode_parameters(Parameters, 0, <<>>, [], Codec).

Expand Down Expand Up @@ -310,6 +316,41 @@ encode_standby_status_update(ReceivedLSN, FlushedLSN, AppliedLSN) ->
Timestamp = ((MegaSecs * 1000000 + Secs) * 1000000 + MicroSecs) - 946684800*1000000,
<<$r:8, ReceivedLSN:?int64, FlushedLSN:?int64, AppliedLSN:?int64, Timestamp:?int64, 0:8>>.

%% @doc builds the header of a binary-format COPY data stream
%%
%% See [https://www.postgresql.org/docs/current/sql-copy.html#id-1.9.3.55.9.4.5]
encode_copy_header() ->
    Signature = <<"PGCOPY\n", 8#377, "\r\n", 0>>,
    Flags = <<0:?int32>>,
    %% length of the extensions area
    ExtensionLength = <<0:?int32>>,
    <<Signature/binary, Flags/binary, ExtensionLength/binary>>.

%% @doc builds one row / tuple of a binary-format COPY data stream
%%
%% The row may be given as a tuple or as a list of values; `Types' lists the
%% column types in the same order. The result is an iolist: the number of
%% columns as int16, followed by one encoded field per column. Values that the
%% codec treats as null are written as the int32 `-1'.
%%
%% See [https://www.postgresql.org/docs/current/sql-copy.html#id-1.9.3.55.9.4.6]
encode_copy_row(ValuesTuple, Types, Codec) when is_tuple(ValuesTuple) ->
    encode_copy_row(tuple_to_list(ValuesTuple), Types, Codec);
encode_copy_row(Values, Types, Codec) ->
    FieldCount = length(Types),
    EncodeField =
        fun(Type, Value) ->
                case epgsql_binary:is_null(Value, Codec) of
                    true ->
                        <<-1:?int32>>;
                    false ->
                        epgsql_binary:encode(Type, Value, Codec)
                end
        end,
    Fields = lists:zipwith(EncodeField, Types, Values),
    [<<FieldCount:?int16>> | Fields].

%% @doc encode binary copy data file trailer
%%
%% The trailer is a single int16 field with the value `-1'.
%%
%% See [https://www.postgresql.org/docs/current/sql-copy.html#id-1.9.3.55.9.4.7]
encode_copy_trailer() ->
<<-1:?int16>>.

%%
%% Encoders for various PostgreSQL protocol client-side packets
%% See https://www.postgresql.org/docs/current/protocol-message-formats.html
Expand Down Expand Up @@ -390,5 +431,10 @@ encode_flush() ->
encode_sync() ->
{?SYNC, []}.

%% @doc encodes `CopyDone' packet.
%%
%% Returns the `?COPY_DONE' packet type together with an empty payload.
-spec encode_copy_done() -> {packet_type(), iodata()}.
encode_copy_done() ->
{?COPY_DONE, []}.

%% Maps the object kind atom (`statement' or `portal') to the value of the
%% `?PREPARED_STATEMENT' / `?PORTAL' macro. Any other argument fails with
%% `function_clause'.
obj_atom_to_byte(statement) -> ?PREPARED_STATEMENT;
obj_atom_to_byte(portal) -> ?PORTAL.
347 changes: 347 additions & 0 deletions test/epgsql_copy_SUITE.erl
Original file line number Diff line number Diff line change
@@ -0,0 +1,347 @@
-module(epgsql_copy_SUITE).
-include_lib("common_test/include/ct.hrl").
-include_lib("stdlib/include/assert.hrl").
-include("epgsql.hrl").

-export([
init_per_suite/1,
all/0,
end_per_suite/1,

from_stdin_text/1,
from_stdin_csv/1,
from_stdin_binary/1,
from_stdin_io_apis/1,
from_stdin_with_terminator/1,
from_stdin_corrupt_data/1
]).

init_per_suite(Config) ->
[{module, epgsql}|Config].

end_per_suite(_Config) ->
ok.

all() ->
[
from_stdin_text,
from_stdin_csv,
from_stdin_binary,
from_stdin_io_apis,
from_stdin_with_terminator,
from_stdin_corrupt_data
].

%% @doc Test that COPY in text format works
from_stdin_text(Config) ->
Module = ?config(module, Config),
epgsql_ct:with_connection(
Config,
fun(C) ->
?assertEqual(
{ok, [text, text]},
Module:copy_from_stdin(
C, "COPY test_table1 (id, value) FROM STDIN WITH (FORMAT text)")),
?assertEqual(
ok,
io:put_chars(C,
"10\thello world\n"
"11\t\\N\n"
"12\tline 12\n")),
?assertEqual(
ok,
io:put_chars(C, "13\tline 13\n")),
?assertEqual(
ok,
io:put_chars(C, "14\tli")),
?assertEqual(
ok,
io:put_chars(C, "ne 14\n")),
?assertEqual(
{ok, 5},
Module:copy_done(C)),
?assertMatch(
{ok, _, [{10, <<"hello world">>},
{11, null},
{12, <<"line 12">>},
{13, <<"line 13">>},
{14, <<"line 14">>}]},
Module:equery(C,
"SELECT id, value FROM test_table1"
" WHERE id IN (10, 11, 12, 13, 14) ORDER BY id"))
end).

%% @doc Test that COPY in CSV format works
from_stdin_csv(Config) ->
Module = ?config(module, Config),
epgsql_ct:with_connection(
Config,
fun(C) ->
?assertEqual(
{ok, [text, text]},
Module:copy_from_stdin(
C, "COPY test_table1 (id, value) FROM STDIN WITH (FORMAT csv, QUOTE '''')")),
?assertEqual(
ok,
io:put_chars(C,
"20,'hello world'\n"
"21,\n"
"22,line 22\n")),
?assertEqual(
ok,
io:put_chars(C, "23,'line 23'\n")),
?assertEqual(
ok,
io:put_chars(C, "24,'li")),
?assertEqual(
ok,
io:put_chars(C, "ne 24'\n")),
?assertEqual(
{ok, 5},
Module:copy_done(C)),
?assertMatch(
{ok, _, [{20, <<"hello world">>},
{21, null},
{22, <<"line 22">>},
{23, <<"line 23">>},
{24, <<"line 24">>}]},
Module:equery(C,
"SELECT id, value FROM test_table1"
" WHERE id IN (20, 21, 22, 23, 24) ORDER BY id"))
end).

%% @doc Test that COPY in binary format works
from_stdin_binary(Config) ->
Module = ?config(module, Config),
epgsql_ct:with_connection(
Config,
fun(C) ->
?assertEqual(
{ok, [binary, binary]},
Module:copy_from_stdin(
C, "COPY test_table1 (id, value) FROM STDIN WITH (FORMAT binary)",
{binary, [int4, text]})),
%% Batch of rows
?assertEqual(
ok,
Module:copy_send_rows(
C,
[{60, <<"hello world">>},
{61, null},
{62, "line 62"}],
5000)),
%% Single row
?assertEqual(
ok,
Module:copy_send_rows(
C,
[{63, <<"line 63">>}],
1000)),
%% Rows as lists
?assertEqual(
ok,
Module:copy_send_rows(
C,
[
[64, <<"line 64">>],
[65, <<"line 65">>]
],
infinity)),
?assertEqual({ok, 6}, Module:copy_done(C)),
?assertMatch(
{ok, _, [{60, <<"hello world">>},
{61, null},
{62, <<"line 62">>},
{63, <<"line 63">>},
{64, <<"line 64">>},
{65, <<"line 65">>}]},
Module:equery(C,
"SELECT id, value FROM test_table1"
" WHERE id IN (60, 61, 62, 63, 64, 65) ORDER BY id"))
end).

%% @doc Tests that different IO-protocol APIs work
from_stdin_io_apis(Config) ->
Module = ?config(module, Config),
epgsql_ct:with_connection(
Config,
fun(C) ->
?assertEqual(
{ok, [text, text]},
Module:copy_from_stdin(
C, "COPY test_table1 (id, value) FROM STDIN WITH (FORMAT text)")),
?assertEqual(ok, io:format(C, "30\thello world\n", [])),
?assertEqual(ok, io:format(C, "~b\t~s\n", [31, "line 31"])),
%% Output "32\thello\n" in multiple calls
?assertEqual(ok, io:write(C, 32)),
?assertEqual(ok, io:put_chars(C, "\t")),
?assertEqual(ok, io:write(C, hello)),
?assertEqual(ok, io:nl(C)),
%% Using `file` API
?assertEqual(ok, file:write(C, "33\tline 33\n34\tline 34\n")),
%% Binary
?assertEqual(ok, io:put_chars(C, <<"35\tline 35\n">>)),
?assertEqual(ok, file:write(C, <<"36\tline 36\n">>)),
%% IoData
?assertEqual(ok, io:put_chars(C, [<<"37">>, $\t, <<"line 37">>, <<$\n>>])),
?assertEqual(ok, file:write(C, [["38", <<$\t>>], [<<"line 38">>, $\n]])),
%% Raw IO-protocol message-passing
Ref = erlang:make_ref(),
C ! {io_request, self(), Ref, {put_chars, unicode, "39\tline 39\n"}},
?assertEqual(ok, receive {io_reply, Ref, Resp} -> Resp
after 5000 ->
timeout
end),
%% Not documented!
?assertEqual(ok, io:requests(
C,
[{put_chars, unicode, "40\tline 40\n"},
{put_chars, latin1, "41\tline 41\n"},
{format, "~w\t~s", [42, "line 42"]},
nl])),
?assertEqual(
{ok, 13},
Module:copy_done(C)),
?assertMatch(
{ok, _, [{30, <<"hello world">>},
{31, <<"line 31">>},
{32, <<"hello">>},
{33, <<"line 33">>},
{34, <<"line 34">>},
{35, <<"line 35">>},
{36, <<"line 36">>},
{37, <<"line 37">>},
{38, <<"line 38">>},
{39, <<"line 39">>},
{40, <<"line 40">>},
{41, <<"line 41">>},
{42, <<"line 42">>}
]},
Module:equery(
C,
"SELECT id, value FROM test_table1"
" WHERE id IN (30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42)"
" ORDER BY id"))
end).

%% @doc Tests that "end-of-data" terminator is successfully ignored
from_stdin_with_terminator(Config) ->
Module = ?config(module, Config),
epgsql_ct:with_connection(
Config,
fun(C) ->
%% TEXT
?assertEqual(
{ok, [text, text]},
Module:copy_from_stdin(
C, "COPY test_table1 (id, value) FROM STDIN WITH (FORMAT text)")),
?assertEqual(ok, io:put_chars(
C,
"50\tline 50\n"
"51\tline 51\n"
"\\.\n")),
?assertEqual({ok, 2}, Module:copy_done(C)),
%% CSV
?assertEqual(
{ok, [text, text]},
Module:copy_from_stdin(
C, "COPY test_table1 (id, value) FROM STDIN WITH (FORMAT csv)")),
?assertEqual(ok, io:put_chars(
C,
"52,line 52\n"
"53,line 53\n"
"\\.\n")),
?assertEqual({ok, 2}, Module:copy_done(C)),
?assertMatch(
{ok, _, [{50, <<"line 50">>},
{51, <<"line 51">>},
{52, <<"line 52">>},
{53, <<"line 53">>}
]},
Module:equery(C,
"SELECT id, value FROM test_table1"
" WHERE id IN (50, 51, 52, 53) ORDER BY id"))
end).

%% @doc Tests error reporting for broken IO requests and malformed COPY data
%%
%% Covers, in order: IO requests that fail before anything is sent (bad format
%% arguments, M:F(A) returning a non-character term), then malformed payloads
%% for the text, csv and binary formats. For malformed payloads the error is
%% expected as an `{epgsql, C, {error, #error{}}}' message in the test process.
%% Finally checks that the same connection can still run a query.
from_stdin_corrupt_data(Config) ->
Module = ?config(module, Config),
epgsql_ct:with_connection(
Config,
fun(C) ->
?assertEqual(
{ok, [text, text]},
Module:copy_from_stdin(
C, "COPY test_table1 (id, value) FROM STDIN WITH (FORMAT text)")),
%% Wrong number of arguments to io:format
Fmt = "~w\t~s\n",
?assertMatch({error, {fun_exception, {error, badarg, _Stack}}},
io:request(C, {format, Fmt, []})),
?assertError(badarg, io:format(C, Fmt, [])),
%% Wrong return value from IO function
?assertEqual({error, {fun_return_not_characters, node()}},
io:request(C, {put_chars, unicode, erlang, node, []})),
%% Nothing was sent to the server so far, so zero rows are expected
?assertEqual({ok, 0}, Module:copy_done(C)),
%%
%% Corrupt text format
?assertEqual(
{ok, [text, text]},
Module:copy_from_stdin(
C, "COPY test_table1 (id, value) FROM STDIN WITH (FORMAT text)")),
?assertEqual(ok, io:put_chars(
C,
"42\n43\nwasd\n")),
?assertMatch(
#error{codename = bad_copy_file_format,
severity = error},
receive
{epgsql, C, {error, Err}} ->
Err
after 5000 ->
timeout
end),
%% Once the error message has been received, IO requests are rejected
?assertEqual({error, not_in_copy_mode},
io:request(C, {put_chars, unicode, "queque\n"})),
?assertError(badarg, io:format(C, "~w\n~s\n", [60, "wasd"])),
%%
%% Corrupt CSV format
?assertEqual(
{ok, [text, text]},
Module:copy_from_stdin(
C, "COPY test_table1 (id, value) FROM STDIN WITH (FORMAT csv)")),
?assertEqual(ok, io:put_chars(
C,
"42\n43\nwasd\n")),
?assertMatch(
#error{codename = bad_copy_file_format,
severity = error},
receive
{epgsql, C, {error, Err}} ->
Err
after 5000 ->
timeout
end),
%%
%% Corrupt binary format
?assertEqual(
{ok, [binary, binary]},
Module:copy_from_stdin(
C, "COPY test_table1 (id, value) FROM STDIN WITH (FORMAT binary)",
{binary, [int4, text]})),
?assertEqual(
ok,
Module:copy_send_rows(C, [{44, <<"line 44">>}], 1000)),
%% Text-formatted line written into a binary-format stream
?assertEqual(ok, io:put_chars(C, "45\tThis is not ok!\n")),
?assertMatch(
#error{codename = bad_copy_file_format,
severity = error},
receive
{epgsql, C, {error, Err}} ->
Err
after 5000 ->
timeout
end),
%% Connection is still usable
?assertMatch(
{ok, _, [{1}]},
Module:equery(C, "SELECT 1", []))
end).