Skip to content
This repository was archived by the owner on Jun 1, 2023. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/routing/include/mysqlrouter/routing.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ extern const int kDefaultDestinationConnectionTimeout;
*/
extern const unsigned long long kDefaultMaxConnectErrors;

/** @brief Timeout then reset counter for connect or handshake errors per host
*
*/
extern const unsigned long long kDefaultMaxConnectErrorsTimeout;

/** @brief Default bind address
*
*/
Expand Down
85 changes: 70 additions & 15 deletions src/routing/src/mysql_routing.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
#include "plugin_config.h"
#include "mysqlrouter/routing.h"

#include "signal.h"

#include <ctime>
#include <algorithm>
#include <array>
#include <cmath>
Expand Down Expand Up @@ -57,13 +60,15 @@ MySQLRouting::MySQLRouting(routing::AccessMode mode, int port, const string &bin
int max_connections,
int destination_connect_timeout,
unsigned long long max_connect_errors,
unsigned long long max_connect_errors_timeout,
unsigned int client_connect_timeout,
unsigned int net_buffer_length)
: name(route_name),
mode_(mode),
max_connections_(set_max_connections(max_connections)),
destination_connect_timeout_(set_destination_connect_timeout(destination_connect_timeout)),
max_connect_errors_(max_connect_errors),
max_connect_errors_timeout_(max_connect_errors_timeout),
client_connect_timeout_(client_connect_timeout),
net_buffer_length_(net_buffer_length),
bind_address_(TCPAddress(bind_address, port)),
Expand All @@ -75,6 +80,21 @@ MySQLRouting::MySQLRouting(routing::AccessMode mode, int port, const string &bin
}
}

/* Catch Signal Handler functio */
void signal_callback_handler(int signum){
log_error("Unexpected error: caught signal SIGPIPE %d",signum);
}

bool check_socket_alive(int fd) {
int error_code;
socklen_t error_code_size = sizeof(error_code);
int res = getsockopt(fd, SOL_SOCKET, SO_ERROR, &error_code, &error_code_size);
if (res == 0 && error_code == 0) {
return true;
};
return false;
}

/** @brief Reads from sender and writes it back to receiver using select
*
* This function reads data from the sender socket and writes it back
Expand Down Expand Up @@ -136,7 +156,9 @@ int copy_mysql_protocol_packets(int sender, int receiver, fd_set *readfds,
// We got error from MySQL Server while handshaking
// We do not consider this a failed handshake
auto server_error = mysql_protocol::ErrorPacket(buffer);
write(receiver, server_error.data(), server_error.size());
if (check_socket_alive(receiver)) {
write(receiver, server_error.data(), server_error.size());
}
// receiver socket closed by caller
*curr_pktnr = 2; // we assume handshaking is done though there was an error
*report_bytes_read = bytes_read;
Expand All @@ -162,6 +184,9 @@ int copy_mysql_protocol_packets(int sender, int receiver, fd_set *readfds,
size_t bytes_to_write = bytes_read;
ssize_t written = 0;
while (bytes_to_write > 0) {
if (!check_socket_alive(receiver)) {
break;
}
if ((written = write(receiver, buffer.data(), bytes_to_write)) < 0) {
log_debug("Write error: %s", strerror(errno));
return -1;
Expand All @@ -176,22 +201,42 @@ int copy_mysql_protocol_packets(int sender, int receiver, fd_set *readfds,
return 0;
}

bool MySQLRouting::check_client_errors_time(const std::array<uint8_t, 16> &client_ip_array) {
size_t timediff;
if (max_connect_errors_timeout_ == 0) {
return false;
}
std::time_t curtime = std::time(nullptr);
timediff = curtime - auth_error_counters_[client_ip_array].last_attempt;
if (timediff > max_connect_errors_timeout_) {
auth_error_counters_[client_ip_array].count = 0;
return true;
};
return false;
}

bool MySQLRouting::block_client_host(const std::array<uint8_t, 16> &client_ip_array,
const string &client_ip_str, int server) {
bool blocked = false;
char *time_str;
std::lock_guard<std::mutex> lock(mutex_auth_errors_);

if (++auth_error_counters_[client_ip_array] >= max_connect_errors_) {
struct tm *curtime = localtime(&auth_error_counters_[client_ip_array].last_attempt);
auth_error_counters_[client_ip_array].last_attempt = std::time(0);
if (++auth_error_counters_[client_ip_array].count >= max_connect_errors_) {
log_warning("[%s] blocking client host %s", name.c_str(), client_ip_str.c_str());
blocked = true;
} else {
log_info("[%s] %d authentication errors for %s (max %d)",
name.c_str(), auth_error_counters_[client_ip_array], client_ip_str.c_str(), max_connect_errors_);
time_str = asctime(curtime);
time_str[strlen(time_str)-1] = '\0';
log_info("[%s] %d authentication errors for %s (max %d). last attempt: %s",
name.c_str(), auth_error_counters_[client_ip_array].count, client_ip_str.c_str(), max_connect_errors_, time_str);
}

if (server >= 0) {
auto fake_response = mysql_protocol::HandshakeResponsePacket(1, {}, "ROUTER", "", "fake_router_login");
write(server, fake_response.data(), fake_response.size());
if (check_socket_alive(server)) {
write(server, fake_response.data(), fake_response.size());
}
}

return blocked;
Expand All @@ -216,7 +261,9 @@ void MySQLRouting::routing_select_thread(int client, const in6_addr client_addr)
os << "Can't connect to MySQL server on ";
os << "'" << bind_address_.addr << "'";
auto server_error = mysql_protocol::ErrorPacket(0, 2003, os.str(), "HY000");
write(client, server_error.data(), server_error.size());
if (check_socket_alive(client)) {
write(client, server_error.data(), server_error.size());
}

shutdown(client, SHUT_RDWR);
shutdown(server, SHUT_RDWR);
Expand Down Expand Up @@ -303,6 +350,7 @@ void MySQLRouting::routing_select_thread(int client, const in6_addr client_addr)
if (!handshake_done) {
auto ip_array = in6_addr_to_array(client_addr);
log_debug("[%s] Routing failed for %s: %s", name.c_str(), c_ip.first.c_str(), extra_msg.c_str());
check_client_errors_time(ip_array);
block_client_host(ip_array, c_ip.first.c_str(), server);
}

Expand All @@ -321,6 +369,7 @@ void MySQLRouting::start() {
socklen_t sin_size = sizeof client_addr;
char client_ip[INET6_ADDRSTRLEN];
int opt_nodelay = 1;
signal(SIGPIPE, signal_callback_handler);

try {
setup_service();
Expand Down Expand Up @@ -348,18 +397,24 @@ void MySQLRouting::start() {
continue;
}

if (auth_error_counters_[in6_addr_to_array(client_addr.sin6_addr)] >= max_connect_errors_) {
std::stringstream os;
os << "Too many connection errors from " << get_peer_name(sock_client).first;
auto server_error = mysql_protocol::ErrorPacket(0, 1129, os.str(), "HY000");
write(sock_client, server_error.data(), server_error.size());
close(sock_client); // no shutdown() before close()
continue;
if (auth_error_counters_[in6_addr_to_array(client_addr.sin6_addr)].count >= max_connect_errors_) {
if (!check_client_errors_time(in6_addr_to_array(client_addr.sin6_addr))) {
std::stringstream os;
os << "Too many connection errors from " << get_peer_name(sock_client).first;
auto server_error = mysql_protocol::ErrorPacket(0, 1129, os.str(), "HY000");
if (check_socket_alive(sock_client)) {
write(sock_client, server_error.data(), server_error.size());
}
close(sock_client); // no shutdown() before close()
continue;
}
}

if (info_active_routes_.load(std::memory_order_relaxed) >= max_connections_) {
auto server_error = mysql_protocol::ErrorPacket(0, 1040, "Too many connections", "HY000");
write(sock_client, server_error.data(), server_error.size());
if (check_socket_alive(sock_client)) {
write(sock_client, server_error.data(), server_error.size());
}
close(sock_client); // no shutdown() before close()
log_warning("[%s] reached max active connections (%d)", name.c_str(), max_connections_);
continue;
Expand Down
14 changes: 13 additions & 1 deletion src/routing/src/mysql_routing.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "plugin_config.h"

#include <atomic>
#include <ctime>
#include <arpa/inet.h>
#include <array>
#include <iostream>
Expand Down Expand Up @@ -80,6 +81,7 @@ using mysqlrouter::URI;
* use 10.0.11.6 to setup the connection routing.
*
*/

class MySQLRouting {
public:
/** @brief Default constructor
Expand All @@ -97,9 +99,15 @@ class MySQLRouting {
int max_connections = routing::kDefaultMaxConnections,
int destination_connect_timeout = routing::kDefaultDestinationConnectionTimeout,
unsigned long long max_connect_errors = routing::kDefaultMaxConnectErrors,
unsigned long long max_connect_errors_timeout = routing::kDefaultMaxConnectErrorsTimeout,
unsigned int connect_timeout = routing::kDefaultClientConnectTimeout,
unsigned int net_buffer_length = routing::kDefaultNetBufferLength);

struct AuthErrorCounter {
size_t count;
std::time_t last_attempt;
};

/** @brief Starts the service and accept incoming connections
*
* Starts the connection routing service and start accepting incoming
Expand Down Expand Up @@ -194,6 +202,8 @@ class MySQLRouting {
bool block_client_host(const std::array<uint8_t, 16> &client_ip_array,
const string &client_ip_str, int server = -1);

bool check_client_errors_time(const std::array<uint8_t, 16> &client_ip_array);

/** @brief Returns a copy of the list of blocked client hosts
*
* Returns a copy of the list of the blocked client hosts.
Expand Down Expand Up @@ -253,6 +263,8 @@ class MySQLRouting {
int destination_connect_timeout_;
/** @brief Max connect errors blocking hosts when handshake not completed */
unsigned long long max_connect_errors_;
/** @brief Timeout fot reset counter for connect errors blocking hosts when handshake not completed */
unsigned long long max_connect_errors_timeout_;
/** @brief Timeout waiting for handshake response from client */
unsigned int client_connect_timeout_;
/** @brief Size of buffer to store receiving packets */
Expand All @@ -272,7 +284,7 @@ class MySQLRouting {

/** @brief Authentication error counters for IPv4 or IPv6 hosts */
std::mutex mutex_auth_errors_;
std::map<std::array<uint8_t, 16>, size_t> auth_error_counters_;
std::map<std::array<uint8_t, 16>, AuthErrorCounter> auth_error_counters_;
std::vector<std::array<uint8_t, 16>> blocked_client_hosts_;
};

Expand Down
1 change: 1 addition & 0 deletions src/routing/src/plugin_config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ string RoutingPluginConfig::get_default(const string &option) {
{"connect_timeout", to_string(routing::kDefaultDestinationConnectionTimeout)},
{"max_connections", to_string(routing::kDefaultMaxConnections)},
{"max_connect_errors", to_string(routing::kDefaultMaxConnectErrors)},
{"max_connect_errors_timeout", to_string(routing::kDefaultMaxConnectErrorsTimeout)},
{"client_connect_timeout", to_string(routing::kDefaultClientConnectTimeout)},
{"net_buffer_length", to_string(routing::kDefaultNetBufferLength)},
};
Expand Down
3 changes: 3 additions & 0 deletions src/routing/src/plugin_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class RoutingPluginConfig final : public mysqlrouter::BasePluginConfig {
mode(get_option_mode(section, "mode")),
max_connections(get_uint_option<uint16_t>(section, "max_connections", 1)),
max_connect_errors(get_uint_option<uint>(section, "max_connect_errors", 1, UINT32_MAX)),
max_connect_errors_timeout(get_uint_option<uint>(section, "max_connect_errors_timeout", 0, UINT32_MAX)), // 5 minutes for reset error attempts
client_connect_timeout(get_uint_option<uint>(section, "client_connect_timeout", 2, 31536000)),
net_buffer_length(get_uint_option<uint>(section, "net_buffer_length", 1024, 1048576)) { }

Expand All @@ -74,6 +75,8 @@ class RoutingPluginConfig final : public mysqlrouter::BasePluginConfig {
const int max_connections;
/** @brief `max_connect_errors` option read from configuration section */
const unsigned long long max_connect_errors;
/** @brief `max_connect_errors_timeout` option read from configuration section */
const unsigned long long max_connect_errors_timeout;
/** @brief `client_connect_timeout` option read from configuration section */
const unsigned int client_connect_timeout;
/** @brief Size of buffer to receive packets */
Expand Down
1 change: 1 addition & 0 deletions src/routing/src/routing.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ const int kDefaultDestinationConnectionTimeout = 1;
const string kDefaultBindAddress = "127.0.0.1";
const unsigned int kDefaultNetBufferLength = 16384; // Default defined in latest MySQL Server
const unsigned long long kDefaultMaxConnectErrors = 100; // Similar to MySQL Server
const unsigned long long kDefaultMaxConnectErrorsTimeout = 0; // 5 minutes
const unsigned int kDefaultClientConnectTimeout = 9; // Default connect_timeout MySQL Server minus 1

const std::map<string, AccessMode> kAccessModeNames = {
Expand Down