Skip to content

Commit

Permalink
fix(server): Avoid blocking call to mria_lb when probing core node
Browse files Browse the repository at this point in the history
  • Loading branch information
ieQu1 committed Jun 19, 2023
1 parent 064bf23 commit a57958f
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 56 deletions.
2 changes: 1 addition & 1 deletion src/mria_config.erl
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ set_replay_batch_size(N) ->

-spec lb_timeout() -> timeout().
lb_timeout() ->
application:get_env(mria, rlog_lb_update_timeout, 300).
application:get_env(mria, rlog_lb_update_timeout, 3000).

-spec lb_poll_interval() -> non_neg_integer().
lb_poll_interval() ->
Expand Down
82 changes: 36 additions & 46 deletions src/mria_lb.erl
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,9 @@
%% Type declarations
%%================================================================================

-type core_protocol_versions() :: #{node() => integer()}.

-record(s,
{ core_protocol_versions :: core_protocol_versions()
, core_nodes :: [node()]
{ core_nodes :: [node()]
, node_info :: #{node() => node_info()}
}).

-type node_info() ::
Expand Down Expand Up @@ -108,15 +106,15 @@ init(_) ->
logger:set_process_metadata(#{domain => [mria, rlog, lb]}),
start_timer(0),
mria_membership:monitor(membership, self(), true),
State = #s{ core_protocol_versions = #{}
State = #s{ node_info = #{}
, core_nodes = []
},
{ok, State}.

handle_info(?update, St0) ->
T0 = erlang:system_time(millisecond),
T0 = erlang:monotonic_time(millisecond),
St = do_update(St0),
T1 = erlang:system_time(millisecond),
T1 = erlang:monotonic_time(millisecond),
start_timer(T1 - T0),
{noreply, St};
handle_info({membership, Event}, St) ->
Expand All @@ -134,27 +132,6 @@ handle_cast(Cast, St) ->
?unexpected_event_tp(#{cast => Cast, state => St}),
{noreply, St}.

handle_call({probe, Node, Shard}, _From, St0 = #s{core_protocol_versions = ProtoVSNs}) ->
LastVSNChecked = maps:get(Node, ProtoVSNs, undefined),
MyVersion = mria_rlog:get_protocol_version(),
ProbeResult = mria_lib:rpc_call_nothrow({Node, Shard}, mria_rlog_server, do_probe, [Shard]),
{Reply, ServerVersion} =
case ProbeResult of
{true, MyVersion} ->
{true, MyVersion};
{true, CurrentVersion} when CurrentVersion =/= LastVSNChecked ->
?tp(warning, "Different Mria version on the core node",
#{ my_version => MyVersion
, server_version => CurrentVersion
, last_version => LastVSNChecked
, node => Node
}),
{false, CurrentVersion};
_ ->
{false, LastVSNChecked}
end,
St = St0#s{core_protocol_versions = ProtoVSNs#{Node => ServerVersion}},
{reply, Reply, St};
handle_call(core_nodes, _From, St = #s{core_nodes = CoreNodes}) ->
{reply, CoreNodes, St};
handle_call(Call, From, St) ->
Expand All @@ -172,7 +149,7 @@ terminate(_Reason, St) ->
%% Internal functions
%%================================================================================

do_update(State = #s{core_nodes = OldCoreNodes}) ->
do_update(State = #s{core_nodes = OldCoreNodes, node_info = OldNodeInfo}) ->
DiscoveredNodes = discover_nodes(),
%% Get information about core nodes:
{NodeInfo0, _BadNodes} = rpc:multicall( DiscoveredNodes
Expand All @@ -181,6 +158,7 @@ do_update(State = #s{core_nodes = OldCoreNodes}) ->
),
NodeInfo1 = [I || I = {_, #{whoami := core, running := true}} <- NodeInfo0],
NodeInfo = maps:from_list(NodeInfo1),
maybe_report_changes(OldNodeInfo, NodeInfo),
%% Find partitions of the core cluster, and if the core cluster is
%% partitioned choose the best partition to connect to:
Clusters = find_clusters(NodeInfo),
Expand All @@ -204,7 +182,7 @@ do_update(State = #s{core_nodes = OldCoreNodes}) ->
}),
DiscoveredReplicants = discover_replicants(NewCoreNodes),
ping_new_nodes(NewCoreNodes, DiscoveredReplicants),
State#s{core_nodes = NewCoreNodes}.
State#s{core_nodes = NewCoreNodes, node_info = NodeInfo}.

%% Find fully connected clusters (i.e. cliques of nodes)
-spec find_clusters(#{node() => node_info()}) -> [[node()]].
Expand Down Expand Up @@ -233,8 +211,10 @@ find_clusters([Node|Rest], NodeInfo, Acc) ->
-spec shard_badness(#{node() => node_info()}) -> #{mria_rlog:shard() => {node(), Badness}}
when Badness :: float().
shard_badness(NodeInfo) ->
MyProtoVersion = mria_rlog:get_protocol_version(),
maps:fold(
fun(Node, #{shard_badness := Shards}, Acc) ->
fun(Node, #{shard_badness := Shards, protocol_version := ProtoVsn}, Acc)
when ProtoVsn =:= MyProtoVersion ->
lists:foldl(
fun({Shard, Badness}, Acc1) ->
maps:update_with(Shard,
Expand All @@ -247,7 +227,9 @@ shard_badness(NodeInfo) ->
Acc1)
end,
Acc,
Shards)
Shards);
(_Node, _NodeInfo, Acc) ->
Acc
end,
#{},
NodeInfo).
Expand Down Expand Up @@ -432,6 +414,12 @@ ping_nodes(Nodes) ->
mria_membership:ping(Node, LocalMember)
end, Nodes).

%% Hook that is handed the node info map from the previous update
%% round and the one gathered in the current round. Both arguments are
%% ignored for now and the function unconditionally returns `ok'.
-spec maybe_report_changes(NodeInfoMap, NodeInfoMap) -> ok
          when NodeInfoMap :: #{node() => node_info()}.
maybe_report_changes(_Previous, _Current) ->
    %% TODO: reporting of differences between the two maps is not implemented yet
    ok.

%%================================================================================
%% Unit tests
%%================================================================================
Expand All @@ -440,33 +428,35 @@ ping_nodes(Nodes) ->
-include_lib("eunit/include/eunit.hrl").

find_clusters_test_() ->
Vsn = mria_rlog:get_protocol_version(),
[ ?_assertMatch( [[1, 2, 3]]
, lists:sort(find_clusters(#{ 1 => #{db_nodes => [1, 2, 3]}
, 2 => #{db_nodes => [2, 1, 3]}
, 3 => #{db_nodes => [2, 3, 1]}
, lists:sort(find_clusters(#{ 1 => #{db_nodes => [1, 2, 3], protocol_version => Vsn}
, 2 => #{db_nodes => [2, 1, 3], protocol_version => Vsn}
, 3 => #{db_nodes => [2, 3, 1], protocol_version => Vsn}
}))
)
, ?_assertMatch( [[1], [2, 3]]
, lists:sort(find_clusters(#{ 1 => #{db_nodes => [1, 2, 3]}
, 2 => #{db_nodes => [2, 3]}
, 3 => #{db_nodes => [3, 2]}
, lists:sort(find_clusters(#{ 1 => #{db_nodes => [1, 2, 3], protocol_version => Vsn}
, 2 => #{db_nodes => [2, 3], protocol_version => Vsn}
, 3 => #{db_nodes => [3, 2], protocol_version => Vsn}
}))
)
, ?_assertMatch( [[1, 2, 3], [4, 5], [6]]
, lists:sort(find_clusters(#{ 1 => #{db_nodes => [1, 2, 3]}
, 2 => #{db_nodes => [1, 2, 3]}
, 3 => #{db_nodes => [3, 2, 1]}
, 4 => #{db_nodes => [4, 5]}
, 5 => #{db_nodes => [4, 5]}
, 6 => #{db_nodes => [6, 4, 5]}
, lists:sort(find_clusters(#{ 1 => #{db_nodes => [1, 2, 3], protocol_version => Vsn}
, 2 => #{db_nodes => [1, 2, 3], protocol_version => Vsn}
, 3 => #{db_nodes => [3, 2, 1], protocol_version => Vsn}
, 4 => #{db_nodes => [4, 5], protocol_version => Vsn}
, 5 => #{db_nodes => [4, 5], protocol_version => Vsn}
, 6 => #{db_nodes => [6, 4, 5], protocol_version => Vsn}
}))
)
].

shard_badness_test_() ->
Vsn = mria_rlog:get_protocol_version(),
[ ?_assertMatch( #{foo := {n1, 1}, bar := {n2, 2}}
, shard_badness(#{ n1 => #{shard_badness => [{foo, 1}]}
, n2 => #{shard_badness => [{foo, 2}, {bar, 2}]}
, shard_badness(#{ n1 => #{shard_badness => [{foo, 1}], protocol_version => Vsn}
, n2 => #{shard_badness => [{foo, 2}, {bar, 2}], protocol_version => Vsn}
})
)
].
Expand Down
8 changes: 6 additions & 2 deletions src/mria_rlog_server.erl
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,11 @@ start_link(Parent, Shard) ->
%% server is lost or delayed due to network congestion.
-spec probe(node(), mria_rlog:shard()) -> boolean().
probe(Node, Shard) ->
mria_lb:probe(Node, Shard).
Vsn = mria_rlog:get_protocol_version(),
case mria_lib:rpc_call_nothrow({Node, Shard}, ?MODULE, do_probe, [Shard]) of
{true, Vsn} -> true;
_ -> false
end.

-spec subscribe(mria_rlog:shard(), mria_lib:subscriber(), checkpoint()) ->
{ ok
Expand Down Expand Up @@ -292,4 +296,4 @@ do_bootstrap(Shard, Subscriber) ->

-spec do_probe(mria_rlog:shard()) -> {true, integer()}.
do_probe(Shard) ->
{gen_server:call(Shard, probe, 1000), mria_rlog:get_protocol_version()}.
{gen_server:call(Shard, probe, infinity), mria_rlog:get_protocol_version()}.
14 changes: 7 additions & 7 deletions test/mria_lb_SUITE.erl
Original file line number Diff line number Diff line change
Expand Up @@ -58,25 +58,25 @@ t_probe(_Config) ->
ok = rpc:call(N1, meck, expect, [mria_rlog, get_protocol_version,
fun() -> ExpectedVersion + 1 end]),
?tp(call_probe, #{}),
false = rpc:call(N2, mria_lb, probe, [N1, test_shard]),
false = rpc:call(N2, mria_rlog_server, probe, [N1, test_shard]),
%% 2. last version is cached; should not log
?tp(call_probe, #{}),
false = rpc:call(N2, mria_lb, probe, [N1, test_shard]),
false = rpc:call(N2, mria_rlog_server, probe, [N1, test_shard]),
%% 3. probing a new node for the first time; should log
ok = rpc:call(N3, meck, expect, [mria_rlog, get_protocol_version,
fun() -> ExpectedVersion + 1 end]),
?tp(call_probe, #{}),
false = rpc:call(N2, mria_lb, probe, [N3, test_shard]),
false = rpc:call(N2, mria_rlog_server, probe, [N3, test_shard]),
%% 4. change of versions; should log
ok = rpc:call(N1, meck, expect, [mria_rlog, get_protocol_version,
fun() -> ExpectedVersion + 2 end]),
?tp(call_probe, #{}),
false = rpc:call(N2, mria_lb, probe, [N1, test_shard]),
false = rpc:call(N2, mria_rlog_server, probe, [N1, test_shard]),
%% 5. correct version; should not log
ok = rpc:call(N1, meck, expect, [mria_rlog, get_protocol_version,
fun() -> ExpectedVersion end]),
?tp(call_probe, #{}),
true = rpc:call(N2, mria_lb, probe, [N1, test_shard]),
true = rpc:call(N2, mria_rlog_server, probe, [N1, test_shard]),
?tp(test_end, #{}),
{ExpectedVersion, [N1, N2, N3]}
after
Expand Down Expand Up @@ -125,9 +125,9 @@ t_probe_pure_mnesia(_Config) ->
#{timetrap => 30000},
try
[N1, N2, N3] = mria_ct:start_cluster(mria, Cluster),
?assert(erpc:call(N3, mria_lb, probe, [N1, test_shard])),
?assert(erpc:call(N3, mria_rlog_server, probe, [N1, test_shard])),
%% should return false, since it's a pure mnesia node
?assertNot(erpc:call(N3, mria_lb, probe, [N2, test_shard])),
?assertNot(erpc:call(N3, mria_rlog_server, probe, [N2, test_shard])),
ok
after
mria_ct:teardown_cluster(Cluster)
Expand Down

0 comments on commit a57958f

Please sign in to comment.