Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix data race in StorageRabbitMQ #48845

Merged
merged 3 commits into from
Apr 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
138 changes: 73 additions & 65 deletions src/Storages/RabbitMQ/RabbitMQConsumer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <memory>
#include <Storages/RabbitMQ/RabbitMQConsumer.h>
#include <Storages/RabbitMQ/RabbitMQHandler.h>
#include <Storages/RabbitMQ/RabbitMQConnection.h>
#include <IO/ReadBufferFromMemory.h>
#include <Common/logger_useful.h>
#include "Poco/Timer.h"
Expand Down Expand Up @@ -34,7 +35,7 @@ RabbitMQConsumer::RabbitMQConsumer(
{
}

void RabbitMQConsumer::shutdown()
void RabbitMQConsumer::stop()
{
stopped = true;
cv.notify_one();
Expand All @@ -53,119 +54,126 @@ void RabbitMQConsumer::subscribe()
consumer_channel->consume(queue_name)
.onSuccess([&](const std::string & /* consumer_tag */)
{
LOG_TRACE(log, "Consumer on channel {} is subscribed to queue {}", channel_id, queue_name);

if (++subscribed == queues.size())
wait_subscription.store(false);
LOG_TRACE(
log, "Consumer on channel {} ({}/{}) is subscribed to queue {}",
channel_id, subscriptions_num, queues.size(), queue_name);
})
.onReceived([&](const AMQP::Message & message, uint64_t delivery_tag, bool redelivered)
{
if (message.bodySize())
{
String message_received = std::string(message.body(), message.body() + message.bodySize());

if (!received.push({message_received, message.hasMessageID() ? message.messageID() : "",
message.hasTimestamp() ? message.timestamp() : 0,
redelivered, AckTracker(delivery_tag, channel_id)}))
MessageData result{
.message = message_received,
.message_id = message.hasMessageID() ? message.messageID() : "",
.timestamp = message.hasTimestamp() ? message.timestamp() : 0,
.redelivered = redelivered,
.delivery_tag = delivery_tag,
.channel_id = channel_id};

if (!received.push(std::move(result)))
throw Exception(ErrorCodes::LOGICAL_ERROR, "Could not push to received queue");

cv.notify_one();
}
})
.onError([&](const char * message)
{
/* End up here either if channel ends up in an error state (then there will be resubscription) or consume call error, which
* arises from queue settings mismatch or queue level error, which should not happen as no one else is supposed to touch them
/* End up here either if channel ends up in an error state (then there will be resubscription)
* or consume call error, which arises from queue settings mismatch or queue level error,
* which should not happen as no one else is supposed to touch them
*/
LOG_ERROR(log, "Consumer failed on channel {}. Reason: {}", channel_id, message);
wait_subscription.store(false);
state = State::ERROR;
});
}
}


bool RabbitMQConsumer::ackMessages()
bool RabbitMQConsumer::ackMessages(const CommitInfo & commit_info)
{
AckTracker record_info = last_inserted_record_info;
if (state != State::OK)
return false;

/* Do not send ack to server if message's channel is not the same as current running channel because delivery tags are scoped per
* channel, so if channel fails, all previous delivery tags become invalid
*/
if (record_info.channel_id == channel_id && record_info.delivery_tag && record_info.delivery_tag > prev_tag)
{
/// Commit all received messages with delivery tags from last committed to last inserted
if (!consumer_channel->ack(record_info.delivery_tag, AMQP::multiple))
{
LOG_ERROR(log, "Failed to commit messages with delivery tags from last committed to {} on channel {}",
record_info.delivery_tag, channel_id);
return false;
}
/// Nothing to ack.
if (!commit_info.delivery_tag)
return false;

prev_tag = record_info.delivery_tag;
LOG_TRACE(log, "Consumer committed messages with deliveryTags up to {} on channel {}", record_info.delivery_tag, channel_id);
}
/// Do not send ack to server if message's channel is not the same as
/// current running channel because delivery tags are scoped per channel,
/// so if channel fails, all previous delivery tags become invalid.
if (commit_info.channel_id != channel_id)
return false;

return true;
}
/// Duplicate ack?
if (commit_info.delivery_tag > last_commited_delivery_tag
&& consumer_channel->ack(commit_info.delivery_tag, AMQP::multiple))
{
last_commited_delivery_tag = commit_info.delivery_tag;

LOG_TRACE(
log, "Consumer committed messages with deliveryTags up to {} on channel {}",
last_commited_delivery_tag, channel_id);

void RabbitMQConsumer::updateAckTracker(AckTracker record_info)
{
if (record_info.delivery_tag && channel_error.load())
return;
return true;
}

if (!record_info.delivery_tag)
prev_tag = 0;
LOG_ERROR(
log,
"Did not commit messages for {}:{}, (current commit point {}:{})",
commit_info.channel_id, commit_info.delivery_tag,
channel_id, last_commited_delivery_tag);

last_inserted_record_info = record_info;
return false;
}


void RabbitMQConsumer::setupChannel()
void RabbitMQConsumer::updateChannel(RabbitMQConnection & connection)
{
if (!consumer_channel)
return;

wait_subscription.store(true);
state = State::INITIALIZING;
last_commited_delivery_tag = 0;

consumer_channel = connection.createChannel();
consumer_channel->onReady([&]()
{
/* First number indicates current consumer buffer; second number indicates serial number of created channel for current buffer,
* i.e. if channel fails - another one is created and its serial number is incremented; channel_base is to guarantee that
* channel_id is unique for each table
*/
channel_id = std::to_string(channel_id_base) + "_" + std::to_string(channel_id_counter++) + "_" + channel_base;
LOG_TRACE(log, "Channel {} is created", channel_id);

subscribed = 0;
subscribe();
channel_error.store(false);
try
{
/// 1. channel_id_base - indicates current consumer buffer.
/// 2. channel_id_counter - indicates serial number of created channel for current buffer
/// (incremented on each channel update).
/// 3. channel_base is to guarantee that channel_id is unique for each table.
channel_id = fmt::format("{}_{}_{}", channel_id_base, channel_id_counter++, channel_base);

LOG_TRACE(log, "Channel {} is successfully created", channel_id);

subscriptions_num = 0;
subscribe();

state = State::OK;
}
catch (...)
{
state = State::ERROR;
tryLogCurrentException(__PRETTY_FUNCTION__);
}
});

consumer_channel->onError([&](const char * message)
{
LOG_ERROR(log, "Channel {} error: {}", channel_id, message);

channel_error.store(true);
wait_subscription.store(false);
LOG_ERROR(log, "Channel {} in an error state: {}", channel_id, message);
state = State::ERROR;
});
}


bool RabbitMQConsumer::needChannelUpdate()
{
if (wait_subscription)
return false;

return channel_error || !consumer_channel || !consumer_channel->usable();
chassert(consumer_channel);
return state == State::ERROR;
}


/// Thin wrapper that forwards to `event_handler.iterateLoop()`.
/// NOTE(review): presumably this drives a single pass of the handler's event loop
/// so that pending channel callbacks get a chance to run — confirm against RabbitMQHandler.
void RabbitMQConsumer::iterateEventLoop()
{
event_handler.iterateLoop();
}

ReadBufferPtr RabbitMQConsumer::consume()
{
if (stopped || !received.tryPop(current))
Expand Down
70 changes: 36 additions & 34 deletions src/Storages/RabbitMQ/RabbitMQConsumer.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ namespace DB
{

class RabbitMQHandler;
class RabbitMQConnection;
using ChannelPtr = std::unique_ptr<AMQP::TcpChannel>;
static constexpr auto SANITY_TIMEOUT = 1000 * 60 * 10; /// 10min.

Expand All @@ -27,54 +28,43 @@ class RabbitMQConsumer

public:
RabbitMQConsumer(
RabbitMQHandler & event_handler_,
std::vector<String> & queues_,
size_t channel_id_base_,
const String & channel_base_,
Poco::Logger * log_,
uint32_t queue_size_);

struct AckTracker
RabbitMQHandler & event_handler_,
std::vector<String> & queues_,
size_t channel_id_base_,
const String & channel_base_,
Poco::Logger * log_,
uint32_t queue_size_);

struct CommitInfo
{
UInt64 delivery_tag;
UInt64 delivery_tag = 0;
String channel_id;

AckTracker() = default;
AckTracker(UInt64 tag, String id) : delivery_tag(tag), channel_id(id) {}
};

struct MessageData
{
String message;
String message_id;
uint64_t timestamp = 0;
UInt64 timestamp = 0;
bool redelivered = false;
AckTracker track{};
UInt64 delivery_tag = 0;
String channel_id;
};
const MessageData & currentMessage() { return current; }

/// Return read buffer containing next available message
/// or nullptr if there are no messages to process.
ReadBufferPtr consume();

ChannelPtr & getChannel() { return consumer_channel; }
void setupChannel();
bool needChannelUpdate();
void shutdown();

void updateQueues(std::vector<String> & queues_) { queues = queues_; }
size_t queuesCount() { return queues.size(); }
void updateChannel(RabbitMQConnection & connection);

void stop();
bool isConsumerStopped() const { return stopped.load(); }
bool ackMessages();
void updateAckTracker(AckTracker record = AckTracker());

bool hasPendingMessages() { return !received.empty(); }
bool ackMessages(const CommitInfo & commit_info);

auto getChannelID() const { return current.track.channel_id; }
auto getDeliveryTag() const { return current.track.delivery_tag; }
auto getRedelivered() const { return current.redelivered; }
auto getMessageID() const { return current.message_id; }
auto getTimestamp() const { return current.timestamp; }
bool hasPendingMessages() { return !received.empty(); }

void waitForMessages(std::optional<uint64_t> timeout_ms = std::nullopt)
{
Expand All @@ -88,24 +78,36 @@ class RabbitMQConsumer

private:
void subscribe();
void iterateEventLoop();
bool isChannelUsable();
void updateCommitInfo(CommitInfo record);

ChannelPtr consumer_channel;
RabbitMQHandler & event_handler; /// Used concurrently, but is thread safe.
std::vector<String> queues;

const std::vector<String> queues;
const String channel_base;
const size_t channel_id_base;

Poco::Logger * log;
std::atomic<bool> stopped;

String channel_id;
std::atomic<bool> channel_error = true, wait_subscription = false;
UInt64 channel_id_counter = 0;

/// State of the consumer channel. Stored in the atomic `state` member, written by
/// updateChannel() and the channel/consumer callbacks, and read by ackMessages()
/// and needChannelUpdate().
enum class State
{
/// Initial value of `state`, before updateChannel() has assigned anything else.
NONE,
/// Assigned at the start of updateChannel(), before the new channel is created.
INITIALIZING,
/// Assigned in the channel's onReady callback after subscribe() returned without throwing.
/// ackMessages() returns false unless the state is OK.
OK,
/// Assigned by the channel's and the consumer's onError callbacks, and when the onReady
/// callback catches an exception. needChannelUpdate() returns true only in this state.
ERROR,
};
std::atomic<State> state = State::NONE;
size_t subscriptions_num = 0;

ConcurrentBoundedQueue<MessageData> received;
MessageData current;
size_t subscribed = 0;

AckTracker last_inserted_record_info;
UInt64 prev_tag = 0, channel_id_counter = 0;
UInt64 last_commited_delivery_tag;

std::condition_variable cv;
std::mutex mutex;
Expand Down