From de93318569920a801974a2b5269b37048c15c138 Mon Sep 17 00:00:00 2001 From: Michael Penick Date: Tue, 13 Sep 2016 15:06:08 -0700 Subject: [PATCH 1/7] Renamed 'Handler' to 'RequestCallback' --- src/auth_requests.cpp | 4 +- src/auth_requests.hpp | 4 +- src/batch_request.cpp | 14 +- src/batch_request.hpp | 2 +- src/connection.cpp | 190 +++++++++--------- src/connection.hpp | 34 ++-- src/control_connection.cpp | 112 +++++------ src/control_connection.hpp | 25 ++- src/execute_request.cpp | 30 +-- src/execute_request.hpp | 8 +- src/multiple_request_handler.cpp | 66 ------ src/multiple_request_handler.hpp | 83 -------- src/options_request.hpp | 2 +- src/pool.cpp | 71 ++++++- src/pool.hpp | 4 +- src/prepare_handler.cpp | 82 -------- src/prepare_handler.hpp | 48 ----- src/prepare_request.cpp | 2 +- src/prepare_request.hpp | 2 +- src/query_request.cpp | 32 +-- src/query_request.hpp | 8 +- src/register_request.cpp | 2 +- src/register_request.hpp | 2 +- src/request.hpp | 4 +- src/{handler.cpp => request_callback.cpp} | 51 ++++- src/{handler.hpp => request_callback.hpp} | 62 +++++- src/request_handler.cpp | 81 +++++++- src/request_handler.hpp | 6 +- ...handler.cpp => schema_change_callback.cpp} | 32 +-- ...handler.hpp => schema_change_callback.hpp} | 9 +- src/set_keyspace_handler.cpp | 78 ------- src/set_keyspace_handler.hpp | 49 ----- src/startup_request.cpp | 2 +- src/startup_request.hpp | 2 +- src/statement.cpp | 6 +- src/statement.hpp | 6 +- 36 files changed, 518 insertions(+), 697 deletions(-) delete mode 100644 src/multiple_request_handler.cpp delete mode 100644 src/multiple_request_handler.hpp delete mode 100644 src/prepare_handler.cpp delete mode 100644 src/prepare_handler.hpp rename src/{handler.cpp => request_callback.cpp} (69%) rename src/{handler.hpp => request_callback.hpp} (62%) rename src/{schema_change_handler.cpp => schema_change_callback.cpp} (81%) rename src/{schema_change_handler.hpp => schema_change_callback.hpp} (84%) delete mode 100644 src/set_keyspace_handler.cpp delete mode 100644 src/set_keyspace_handler.hpp diff --git a/src/auth_requests.cpp b/src/auth_requests.cpp index d459e4c7f..731c07ab5 100644 --- a/src/auth_requests.cpp +++ b/src/auth_requests.cpp @@ -18,7 +18,7 @@ namespace cass { -int CredentialsRequest::encode(int version, Handler* handler, BufferVec* bufs) const { +int CredentialsRequest::encode(int version, RequestCallback* callback, BufferVec* bufs) const { if (version != 1) { return -1; } @@ -39,7 +39,7 @@ int CredentialsRequest::encode(int version, Handler* handler, BufferVec* bufs) c return length; } -int AuthResponseRequest::encode(int version, Handler* handler, BufferVec* bufs) const { +int AuthResponseRequest::encode(int version, RequestCallback* callback, BufferVec* bufs) const { if (version < 2) { return -1; } diff --git a/src/auth_requests.hpp b/src/auth_requests.hpp index a87bfabf2..e797e4281 100644 --- a/src/auth_requests.hpp +++ b/src/auth_requests.hpp @@ -31,7 +31,7 @@ class CredentialsRequest : public Request { , credentials_(credentials) { } private: - int encode(int version, Handler* handler, BufferVec* bufs) const; + int encode(int version, RequestCallback* callback, BufferVec* bufs) const; private: V1Authenticator::Credentials credentials_; @@ -48,7 +48,7 @@ class AuthResponseRequest : public Request { const SharedRefPtr& auth() const { return auth_; } private: - int encode(int version, Handler* handler, BufferVec* bufs) const; + int encode(int version, RequestCallback* callback, BufferVec* bufs) const; private: std::string token_; diff --git a/src/batch_request.cpp b/src/batch_request.cpp index 3373f8b60..bf66db3f0 100644 --- a/src/batch_request.cpp +++ b/src/batch_request.cpp @@ -79,7 +79,7 @@ CassError cass_batch_add_statement(CassBatch* batch, CassStatement* statement) { namespace cass { -int BatchRequest::encode(int version, Handler* handler, BufferVec* bufs) const { +int BatchRequest::encode(int version, RequestCallback* callback, BufferVec* bufs) const { int length = 0; uint8_t flags = 0; @@ -104,11 +104,11 @@ int BatchRequest::encode(int version, Handler* handler, BufferVec* bufs) const { end = statements_.end(); i != end; ++i) { const SharedRefPtr& statement(*i); if (statement->has_names_for_values()) { - handler->on_error(CASS_ERROR_LIB_BAD_PARAMS, + callback->on_error(CASS_ERROR_LIB_BAD_PARAMS, "Batches cannot contain queries with named values"); return ENCODE_ERROR_BATCH_WITH_NAMED_VALUES; } - int32_t result = (*i)->encode_batch(version, bufs, handler); + int32_t result = (*i)->encode_batch(version, bufs, callback); if (result < 0) { return result; } @@ -127,7 +127,7 @@ int BatchRequest::encode(int version, Handler* handler, BufferVec* bufs) const { flags |= CASS_QUERY_FLAG_SERIAL_CONSISTENCY; } - if (handler->timestamp() != CASS_INT64_MIN) { + if (callback->timestamp() != CASS_INT64_MIN) { buf_size += sizeof(int64_t); // [long] flags |= CASS_QUERY_FLAG_DEFAULT_TIMESTAMP; } @@ -135,7 +135,7 @@ int BatchRequest::encode(int version, Handler* handler, BufferVec* bufs) const { Buffer buf(buf_size); - size_t pos = buf.encode_uint16(0, handler->consistency()); + size_t pos = buf.encode_uint16(0, callback->consistency()); if (version >= 3) { pos = buf.encode_byte(pos, flags); @@ -143,8 +143,8 @@ int BatchRequest::encode(int version, Handler* handler, BufferVec* bufs) const { pos = buf.encode_uint16(pos, serial_consistency()); } - if (handler->timestamp() != CASS_INT64_MIN) { - pos = buf.encode_int64(pos, handler->timestamp()); + if (callback->timestamp() != CASS_INT64_MIN) { + pos = buf.encode_int64(pos, callback->timestamp()); } } diff --git a/src/batch_request.hpp b/src/batch_request.hpp index 5a4960744..7b0355052 100644 --- a/src/batch_request.hpp +++ b/src/batch_request.hpp @@ -50,7 +50,7 @@ class BatchRequest : public RoutableRequest { virtual bool get_routing_key(std::string* routing_key, EncodingCache* cache) const; private: - int encode(int version, Handler* handler, BufferVec* bufs) const; + int encode(int version, RequestCallback* callback, BufferVec* bufs) const; private: typedef std::map PreparedMap; diff --git a/src/connection.cpp b/src/connection.cpp index 1d91c55b8..9d9e7f324 100644 --- a/src/connection.cpp +++ b/src/connection.cpp @@ -59,20 +59,20 @@ namespace cass { -static void cleanup_pending_handlers(List* pending) { +static void cleanup_pending_callbacks(List* pending) { while (!pending->is_empty()) { - Handler* handler = pending->front(); - pending->remove(handler); - if (handler->state() == Handler::REQUEST_STATE_WRITING || - handler->state() == Handler::REQUEST_STATE_READING) { - handler->on_timeout(); - handler->stop_timer(); + RequestCallback* callback = pending->front(); + pending->remove(callback); + if (callback->state() == RequestCallback::REQUEST_STATE_WRITING || + callback->state() == RequestCallback::REQUEST_STATE_READING) { + callback->on_timeout(); + callback->stop_timer(); } - handler->dec_ref(); + callback->dec_ref(); } } -void Connection::StartupHandler::on_set(ResponseMessage* response) { +void Connection::StartupCallback::on_set(ResponseMessage* response) { switch (response->opcode()) { case CQL_OPCODE_SUPPORTED: connection_->on_supported(response); @@ -129,7 +129,7 @@ void Connection::StartupHandler::on_set(ResponseMessage* response) { } } -void Connection::StartupHandler::on_error(CassError code, +void Connection::StartupCallback::on_error(CassError code, const std::string& message) { std::ostringstream ss; ss << "Error: '" << message @@ -137,13 +137,13 @@ void Connection::StartupHandler::on_error(CassError code, connection_->notify_error(ss.str()); } -void Connection::StartupHandler::on_timeout() { +void Connection::StartupCallback::on_timeout() { if (!connection_->is_closing()) { connection_->notify_error("Timed out", CONNECTION_ERROR_TIMEOUT); } } -void Connection::StartupHandler::on_result_response(ResponseMessage* response) { +void Connection::StartupCallback::on_result_response(ResponseMessage* response) { ResultResponse* result = static_cast(response->response_body().get()); switch (result->kind()) { @@ -156,25 +156,25 @@ void Connection::StartupHandler::on_result_response(ResponseMessage* response) { } } -Connection::HeartbeatHandler::HeartbeatHandler(Connection* connection) - : Handler(new OptionsRequest()) { +Connection::HeartbeatCallback::HeartbeatCallback(Connection* connection) + : RequestCallback(new OptionsRequest()) { set_connection(connection); } -void Connection::HeartbeatHandler::on_set(ResponseMessage* response) { +void Connection::HeartbeatCallback::on_set(ResponseMessage* response) { LOG_TRACE("Heartbeat completed on host %s", connection_->address_string().c_str()); connection_->heartbeat_outstanding_ = false; } -void Connection::HeartbeatHandler::on_error(CassError code, const std::string& message) { +void Connection::HeartbeatCallback::on_error(CassError code, const std::string& message) { LOG_WARN("An error occurred on host %s during a heartbeat request: %s", connection_->address_string().c_str(), message.c_str()); connection_->heartbeat_outstanding_ = false; } -void Connection::HeartbeatHandler::on_timeout() { +void Connection::HeartbeatCallback::on_timeout() { LOG_WARN("Heartbeat request timed out on host %s", connection_->address_string().c_str()); connection_->heartbeat_outstanding_ = false; @@ -240,23 +240,23 @@ void Connection::connect() { } } -bool Connection::write(Handler* handler, bool flush_immediately) { - bool result = internal_write(handler, flush_immediately); +bool Connection::write(RequestCallback* callback, bool flush_immediately) { + bool result = internal_write(callback, flush_immediately); if (result) { restart_heartbeat_timer(); } return result; } -bool Connection::internal_write(Handler* handler, bool flush_immediately) { - int stream = stream_manager_.acquire(handler); +bool Connection::internal_write(RequestCallback* callback, bool flush_immediately) { + int stream = stream_manager_.acquire(callback); if (stream < 0) { return false; } - handler->inc_ref(); // Connection reference - handler->set_connection(this); - handler->set_stream(stream); + callback->inc_ref(); // Connection reference + callback->set_connection(this); + callback->set_stream(stream); if (pending_writes_.is_empty() || pending_writes_.back()->is_flushed()) { if (ssl_session_) { @@ -268,7 +268,7 @@ bool Connection::internal_write(Handler* handler, bool flush_immediately) { PendingWriteBase *pending_write = pending_writes_.back(); - int32_t request_size = pending_write->write(handler); + int32_t request_size = pending_write->write(callback); if (request_size < 0) { stream_manager_.release(stream); switch (request_size) { @@ -278,11 +278,11 @@ bool Connection::internal_write(Handler* handler, bool flush_immediately) { break; default: - handler->on_error(CASS_ERROR_LIB_MESSAGE_ENCODE, + callback->on_error(CASS_ERROR_LIB_MESSAGE_ENCODE, "Operation unsupported by this protocol version"); break; } - handler->dec_ref(); + callback->dec_ref(); return true; // Don't retry } @@ -297,14 +297,14 @@ bool Connection::internal_write(Handler* handler, bool flush_immediately) { } LOG_TRACE("Sending message type %s with stream %d", - opcode_to_string(handler->request()->opcode()).c_str(), stream); + opcode_to_string(callback->request()->opcode()).c_str(), stream); - handler->set_state(Handler::REQUEST_STATE_WRITING); - uint64_t request_timeout_ms = handler->request_timeout_ms(config_); + callback->set_state(RequestCallback::REQUEST_STATE_WRITING); + uint64_t request_timeout_ms = callback->request_timeout_ms(config_); if (request_timeout_ms > 0) { // 0 means no timeout - handler->start_timer(loop_, + callback->start_timer(loop_, request_timeout_ms, - handler, + callback, Connection::on_timeout); } @@ -321,8 +321,8 @@ void Connection::flush() { pending_writes_.back()->flush(); } -void Connection::schedule_schema_agreement(const SharedRefPtr& handler, uint64_t wait) { - PendingSchemaAgreement* pending_schema_agreement = new PendingSchemaAgreement(handler); +void Connection::schedule_schema_agreement(const SharedRefPtr& callback, uint64_t wait) { + PendingSchemaAgreement* pending_schema_agreement = new PendingSchemaAgreement(callback); pending_schema_agreements_.add_to_back(pending_schema_agreement); pending_schema_agreement->timer.start(loop_, wait, @@ -462,36 +462,36 @@ void Connection::consume(char* input, size_t size) { opcode_to_string(response->opcode())); } } else { - Handler* handler = NULL; - if (stream_manager_.get_pending_and_release(response->stream(), handler)) { - switch (handler->state()) { - case Handler::REQUEST_STATE_READING: + RequestCallback* callback = NULL; + if (stream_manager_.get_pending_and_release(response->stream(), callback)) { + switch (callback->state()) { + case RequestCallback::REQUEST_STATE_READING: maybe_set_keyspace(response.get()); - pending_reads_.remove(handler); - handler->stop_timer(); - handler->set_state(Handler::REQUEST_STATE_DONE); - handler->on_set(response.get()); - handler->dec_ref(); + pending_reads_.remove(callback); + callback->stop_timer(); + callback->set_state(RequestCallback::REQUEST_STATE_DONE); + callback->on_set(response.get()); + callback->dec_ref(); break; - case Handler::REQUEST_STATE_WRITING: + case RequestCallback::REQUEST_STATE_WRITING: // There are cases when the read callback will happen // before the write callback. If this happens we have // to allow the write callback to cleanup. maybe_set_keyspace(response.get()); - handler->set_state(Handler::REQUEST_STATE_READ_BEFORE_WRITE); - handler->on_set(response.get()); + callback->set_state(RequestCallback::REQUEST_STATE_READ_BEFORE_WRITE); + callback->on_set(response.get()); break; - case Handler::REQUEST_STATE_TIMEOUT: - pending_reads_.remove(handler); - handler->set_state(Handler::REQUEST_STATE_DONE); - handler->dec_ref(); + case RequestCallback::REQUEST_STATE_TIMEOUT: + pending_reads_.remove(callback); + callback->set_state(RequestCallback::REQUEST_STATE_DONE); + callback->dec_ref(); break; - case Handler::REQUEST_STATE_TIMEOUT_WRITE_OUTSTANDING: + case RequestCallback::REQUEST_STATE_TIMEOUT_WRITE_OUTSTANDING: // We must wait for the write callback before we can do the cleanup - handler->set_state(Handler::REQUEST_STATE_READ_BEFORE_WRITE); + callback->set_state(RequestCallback::REQUEST_STATE_READ_BEFORE_WRITE); break; default: @@ -577,7 +577,7 @@ void Connection::on_close(uv_handle_t* handle) { static_cast(connection), connection->host_->address_string().c_str()); - cleanup_pending_handlers(&connection->pending_reads_); + cleanup_pending_callbacks(&connection->pending_reads_); while (!connection->pending_writes_.is_empty()) { PendingWriteBase* pending_write @@ -591,7 +591,7 @@ void Connection::on_close(uv_handle_t* handle) { = connection->pending_schema_agreements_.front(); connection->pending_schema_agreements_.remove(pending_schema_aggreement); pending_schema_aggreement->stop_timer(); - pending_schema_aggreement->handler->on_closing(); + pending_schema_aggreement->callback->on_closing(); delete pending_schema_aggreement; } @@ -726,22 +726,22 @@ void Connection::on_read_ssl(uv_stream_t* client, ssize_t nread, const uv_buf_t* } void Connection::on_timeout(Timer* timer) { - Handler* handler = static_cast(timer->data()); - Connection* connection = handler->connection(); + RequestCallback* callback = static_cast(timer->data()); + Connection* connection = callback->connection(); LOG_INFO("Request timed out to host %s on connection(%p)", connection->host_->address_string().c_str(), static_cast(connection)); // TODO (mpenick): We need to handle the case where we have too many // timeout requests and we run out of stream ids. The java-driver // uses a threshold to defunct the connection. - handler->set_state(Handler::REQUEST_STATE_TIMEOUT); - handler->on_timeout(); + callback->set_state(RequestCallback::REQUEST_STATE_TIMEOUT); + callback->on_timeout(); connection->metrics_->request_timeouts.inc(); } void Connection::on_connected() { - internal_write(new StartupHandler(this, new OptionsRequest())); + internal_write(new StartupCallback(this, new OptionsRequest())); } void Connection::on_authenticate(const std::string& class_name) { @@ -761,7 +761,7 @@ void Connection::on_auth_challenge(const AuthResponseRequest* request, } AuthResponseRequest* auth_response = new AuthResponseRequest(response, request->auth()); - internal_write(new StartupHandler(this, auth_response)); + internal_write(new StartupCallback(this, auth_response)); } void Connection::on_auth_success(const AuthResponseRequest* request, @@ -776,14 +776,14 @@ void Connection::on_auth_success(const AuthResponseRequest* request, void Connection::on_ready() { if (state_ == CONNECTION_STATE_CONNECTED && listener_->event_types() != 0) { set_state(CONNECTION_STATE_REGISTERING_EVENTS); - internal_write(new StartupHandler(this, new RegisterRequest(listener_->event_types()))); + internal_write(new StartupCallback(this, new RegisterRequest(listener_->event_types()))); return; } if (keyspace_.empty()) { notify_ready(); } else { - internal_write(new StartupHandler(this, new QueryRequest("USE \"" + keyspace_ + "\""))); + internal_write(new StartupCallback(this, new QueryRequest("USE \"" + keyspace_ + "\""))); } } @@ -798,15 +798,15 @@ void Connection::on_supported(ResponseMessage* response) { // TODO(mstump) do something with the supported info (void)supported; - internal_write(new StartupHandler(this, new StartupRequest())); + internal_write(new StartupCallback(this, new StartupRequest())); } void Connection::on_pending_schema_agreement(Timer* timer) { PendingSchemaAgreement* pending_schema_agreement = static_cast(timer->data()); - Connection* connection = pending_schema_agreement->handler->connection(); + Connection* connection = pending_schema_agreement->callback->connection(); connection->pending_schema_agreements_.remove(pending_schema_agreement); - pending_schema_agreement->handler->execute(); + pending_schema_agreement->callback->execute(); delete pending_schema_agreement; } @@ -865,7 +865,7 @@ void Connection::send_credentials(const std::string& class_name) { if (v1_auth) { V1Authenticator::Credentials credentials; v1_auth->get_credentials(&credentials); - internal_write(new StartupHandler(this, new CredentialsRequest(credentials))); + internal_write(new StartupCallback(this, new CredentialsRequest(credentials))); } else { send_initial_auth_response(class_name); } @@ -882,7 +882,7 @@ void Connection::send_initial_auth_response(const std::string& class_name) { return; } AuthResponseRequest* auth_response = new AuthResponseRequest(response, auth); - internal_write(new StartupHandler(this, auth_response)); + internal_write(new StartupCallback(this, auth_response)); } } @@ -898,7 +898,7 @@ void Connection::on_heartbeat(Timer* timer) { Connection* connection = static_cast(timer->data()); if (!connection->heartbeat_outstanding_) { - if (!connection->internal_write(new HeartbeatHandler(connection))) { + if (!connection->internal_write(new HeartbeatCallback(connection))) { // Recycling only this connection with a timeout error. This is unlikely and // it means the connection ran out of stream IDs as a result of requests // that never returned and as a result timed out. @@ -936,19 +936,19 @@ void Connection::PendingSchemaAgreement::stop_timer() { } Connection::PendingWriteBase::~PendingWriteBase() { - cleanup_pending_handlers(&handlers_); + cleanup_pending_callbacks(&callbacks_); } -int32_t Connection::PendingWriteBase::write(Handler* handler) { +int32_t Connection::PendingWriteBase::write(RequestCallback* callback) { size_t last_buffer_size = buffers_.size(); - int32_t request_size = handler->encode(connection_->protocol_version_, 0x00, &buffers_); + int32_t request_size = callback->encode(connection_->protocol_version_, 0x00, &buffers_); if (request_size < 0) { buffers_.resize(last_buffer_size); // rollback return request_size; } size_ += request_size; - handlers_.add_to_back(handler); + callbacks_.add_to_back(callback); return request_size; } @@ -958,16 +958,16 @@ void Connection::PendingWriteBase::on_write(uv_write_t* req, int status) { Connection* connection = static_cast(pending_write->connection_); - while (!pending_write->handlers_.is_empty()) { - Handler* handler = pending_write->handlers_.front(); + while (!pending_write->callbacks_.is_empty()) { + RequestCallback* callback = pending_write->callbacks_.front(); - pending_write->handlers_.remove(handler); + pending_write->callbacks_.remove(callback); - switch (handler->state()) { - case Handler::REQUEST_STATE_WRITING: + switch (callback->state()) { + case RequestCallback::REQUEST_STATE_WRITING: if (status == 0) { - handler->set_state(Handler::REQUEST_STATE_READING); - connection->pending_reads_.add_to_back(handler); + callback->set_state(RequestCallback::REQUEST_STATE_READING); + connection->pending_reads_.add_to_back(callback); } else { if (!connection->is_closing()) { connection->notify_error("Write error '" + @@ -976,33 +976,33 @@ void Connection::PendingWriteBase::on_write(uv_write_t* req, int status) { connection->defunct(); } - connection->stream_manager_.release(handler->stream()); - handler->stop_timer(); - handler->set_state(Handler::REQUEST_STATE_DONE); - handler->on_error(CASS_ERROR_LIB_WRITE_ERROR, + connection->stream_manager_.release(callback->stream()); + callback->stop_timer(); + callback->set_state(RequestCallback::REQUEST_STATE_DONE); + callback->on_error(CASS_ERROR_LIB_WRITE_ERROR, "Unable to write to socket"); - handler->dec_ref(); + callback->dec_ref(); } break; - case Handler::REQUEST_STATE_TIMEOUT_WRITE_OUTSTANDING: + case RequestCallback::REQUEST_STATE_TIMEOUT_WRITE_OUTSTANDING: // The read may still come back, handle cleanup there - handler->set_state(Handler::REQUEST_STATE_TIMEOUT); - connection->pending_reads_.add_to_back(handler); + callback->set_state(RequestCallback::REQUEST_STATE_TIMEOUT); + connection->pending_reads_.add_to_back(callback); break; - case Handler::REQUEST_STATE_READ_BEFORE_WRITE: + case RequestCallback::REQUEST_STATE_READ_BEFORE_WRITE: // The read callback happened before the write callback // returned. This is now responsible for cleanup. - handler->stop_timer(); - handler->set_state(Handler::REQUEST_STATE_DONE); - handler->dec_ref(); + callback->stop_timer(); + callback->set_state(RequestCallback::REQUEST_STATE_DONE); + callback->dec_ref(); break; - case Handler::REQUEST_STATE_RETRY_WRITE_OUTSTANDING: - handler->stop_timer(); - handler->retry(); - handler->dec_ref(); + case RequestCallback::REQUEST_STATE_RETRY_WRITE_OUTSTANDING: + callback->stop_timer(); + callback->retry(); + callback->dec_ref(); break; default: diff --git a/src/connection.hpp b/src/connection.hpp index 6bdc2984a..74f092008 100644 --- a/src/connection.hpp +++ b/src/connection.hpp @@ -19,7 +19,7 @@ #include "buffer.hpp" #include "cassandra.h" -#include "handler.hpp" +#include "request_callback.hpp" #include "hash.hpp" #include "host.hpp" #include "list.hpp" @@ -28,7 +28,7 @@ #include "ref_counted.hpp" #include "request.hpp" #include "response.hpp" -#include "schema_change_handler.hpp" +#include "schema_change_callback.hpp" #include "scoped_ptr.hpp" #include "ssl.hpp" #include "stream_manager.hpp" @@ -104,10 +104,10 @@ class Connection { void connect(); - bool write(Handler* request, bool flush_immediately = true); + bool write(RequestCallback* request, bool flush_immediately = true); void flush(); - void schedule_schema_agreement(const SharedRefPtr& handler, uint64_t wait); + void schedule_schema_agreement(const SharedRefPtr& callback, uint64_t wait); const Config& config() const { return config_; } Metrics* metrics() { return metrics_; } @@ -163,10 +163,10 @@ class Connection { char buf_[MAX_BUFFER_SIZE]; }; - class StartupHandler : public Handler { + class StartupCallback : public RequestCallback { public: - StartupHandler(Connection* connection, Request* request) - : Handler(request) { + StartupCallback(Connection* connection, Request* request) + : RequestCallback(request) { set_connection(connection); } @@ -178,9 +178,9 @@ class Connection { void on_result_response(ResponseMessage* response); }; - class HeartbeatHandler : public Handler { + class HeartbeatCallback : public RequestCallback { public: - HeartbeatHandler(Connection* connection); + HeartbeatCallback(Connection* connection); virtual void on_set(ResponseMessage* response); virtual void on_error(CassError code, const std::string& message); @@ -206,7 +206,7 @@ class Connection { return size_; } - int32_t write(Handler* handler); + int32_t write(RequestCallback* callback); virtual void flush() = 0; @@ -218,7 +218,7 @@ class Connection { bool is_flushed_; size_t size_; BufferVec buffers_; - List handlers_; + List callbacks_; }; class PendingWrite : public PendingWriteBase { @@ -245,16 +245,16 @@ class Connection { struct PendingSchemaAgreement : public List::Node { - PendingSchemaAgreement(const SharedRefPtr& handler) - : handler(handler) { } + PendingSchemaAgreement(const SharedRefPtr& callback) + : callback(callback) { } void stop_timer(); - SharedRefPtr handler; + SharedRefPtr callback; Timer timer; }; - bool internal_write(Handler* request, bool flush_immediately = true); + bool internal_write(RequestCallback* request, bool flush_immediately = true); void internal_close(ConnectionState close_state); void set_state(ConnectionState state); void consume(char* input, size_t size); @@ -309,7 +309,7 @@ class Connection { size_t pending_writes_size_; List pending_writes_; - List pending_reads_; + List pending_reads_; List pending_schema_agreements_; uv_loop_t* loop_; @@ -321,7 +321,7 @@ class Connection { Listener* listener_; ScopedPtr response_; - StreamManager stream_manager_; + StreamManager stream_manager_; uv_tcp_t socket_; Timer connect_timer_; diff --git a/src/control_connection.cpp b/src/control_connection.cpp index bbddffa8f..1a14736c3 100644 --- a/src/control_connection.cpp +++ b/src/control_connection.cpp @@ -406,18 +406,18 @@ void ControlConnection::on_event(EventResponse* response) { } void ControlConnection::query_meta_hosts() { - ScopedRefPtr > handler( - new ControlMultipleRequestHandler(this, ControlConnection::on_query_hosts, UnusedData())); + ScopedRefPtr > callback( + new ControlMultipleRequestCallback(this, ControlConnection::on_query_hosts, UnusedData())); // This needs to happen before other schema metadata queries so that we have // a valid Cassandra version because this version determines which follow up // schema metadata queries are executed. - handler->execute_query("local", token_aware_routing_ ? SELECT_LOCAL_TOKENS : SELECT_LOCAL); - handler->execute_query("peers", token_aware_routing_ ? SELECT_PEERS_TOKENS : SELECT_PEERS); + callback->execute_query("local", token_aware_routing_ ? SELECT_LOCAL_TOKENS : SELECT_LOCAL); + callback->execute_query("peers", token_aware_routing_ ? SELECT_PEERS_TOKENS : SELECT_PEERS); } void ControlConnection::on_query_hosts(ControlConnection* control_connection, const UnusedData& data, - const MultipleRequestHandler::ResponseMap& responses) { + const MultipleRequestCallback::ResponseMap& responses) { Connection* connection = control_connection->connection_; if (connection == NULL) { return; @@ -443,7 +443,7 @@ void ControlConnection::on_query_hosts(ControlConnection* control_connection, host->set_mark(session->current_host_mark_); ResultResponse* local_result; - if (MultipleRequestHandler::get_result_response(responses, "local", &local_result) && + if (MultipleRequestCallback::get_result_response(responses, "local", &local_result) && local_result->row_count() > 0) { control_connection->update_node_info(host, &local_result->first_row(), ADD_HOST); control_connection->cassandra_version_ = host->cassandra_version(); @@ -463,7 +463,7 @@ void ControlConnection::on_query_hosts(ControlConnection* control_connection, { ResultResponse* peers_result; - if (MultipleRequestHandler::get_result_response(responses, "peers", &peers_result)) { + if (MultipleRequestCallback::get_result_response(responses, "peers", &peers_result)) { ResultIterator rows(peers_result); while (rows.next()) { Address address; @@ -509,35 +509,35 @@ void ControlConnection::on_query_hosts(ControlConnection* control_connection, //TODO: query and callbacks should be in Metadata // punting for now because of tight coupling of Session and CC state void ControlConnection::query_meta_schema() { - ScopedRefPtr > handler( - new ControlMultipleRequestHandler(this, ControlConnection::on_query_meta_schema, UnusedData())); + ScopedRefPtr > callback( + new ControlMultipleRequestCallback(this, ControlConnection::on_query_meta_schema, UnusedData())); if (cassandra_version_ >= VersionNumber(3, 0, 0)) { if (use_schema_ || token_aware_routing_) { - handler->execute_query("keyspaces", SELECT_KEYSPACES_30); + callback->execute_query("keyspaces", SELECT_KEYSPACES_30); } if (use_schema_) { - handler->execute_query("tables", SELECT_TABLES_30); - handler->execute_query("views", SELECT_VIEWS_30); - handler->execute_query("columns", SELECT_COLUMNS_30); - handler->execute_query("indexes", SELECT_INDEXES_30); - handler->execute_query("user_types", SELECT_USERTYPES_30); - handler->execute_query("functions", SELECT_FUNCTIONS_30); - handler->execute_query("aggregates", SELECT_AGGREGATES_30); + callback->execute_query("tables", SELECT_TABLES_30); + callback->execute_query("views", SELECT_VIEWS_30); + callback->execute_query("columns", SELECT_COLUMNS_30); + callback->execute_query("indexes", SELECT_INDEXES_30); + callback->execute_query("user_types", SELECT_USERTYPES_30); + callback->execute_query("functions", SELECT_FUNCTIONS_30); + callback->execute_query("aggregates", SELECT_AGGREGATES_30); } } else { if (use_schema_ || token_aware_routing_) { - handler->execute_query("keyspaces", SELECT_KEYSPACES_20); + callback->execute_query("keyspaces", SELECT_KEYSPACES_20); } if (use_schema_) { - handler->execute_query("tables", SELECT_COLUMN_FAMILIES_20); - handler->execute_query("columns", SELECT_COLUMNS_20); + callback->execute_query("tables", SELECT_COLUMN_FAMILIES_20); + callback->execute_query("columns", SELECT_COLUMNS_20); if (cassandra_version_ >= VersionNumber(2, 1, 0)) { - handler->execute_query("user_types", SELECT_USERTYPES_21); + callback->execute_query("user_types", SELECT_USERTYPES_21); } if (cassandra_version_ >= VersionNumber(2, 2, 0)) { - handler->execute_query("functions", SELECT_FUNCTIONS_22); - handler->execute_query("aggregates", SELECT_AGGREGATES_22); + callback->execute_query("functions", SELECT_FUNCTIONS_22); + callback->execute_query("aggregates", SELECT_AGGREGATES_22); } } } @@ -545,7 +545,7 @@ void ControlConnection::query_meta_schema() { void ControlConnection::on_query_meta_schema(ControlConnection* control_connection, const UnusedData& unused, - const MultipleRequestHandler::ResponseMap& responses) { + const MultipleRequestCallback::ResponseMap& responses) { Connection* connection = control_connection->connection_; if (connection == NULL) { return; @@ -559,7 +559,7 @@ void ControlConnection::on_query_meta_schema(ControlConnection* control_connecti if (session->token_map_) { ResultResponse* keyspaces_result; - if (MultipleRequestHandler::get_result_response(responses, "keyspaces", &keyspaces_result)) { + if (MultipleRequestCallback::get_result_response(responses, "keyspaces", &keyspaces_result)) { session->token_map_->clear_replicas_and_strategies(); // Only clear replicas once we have the new keyspaces session->token_map_->add_keyspaces(cassandra_version, keyspaces_result); } @@ -570,42 +570,42 @@ void ControlConnection::on_query_meta_schema(ControlConnection* control_connecti session->metadata().clear_and_update_back(cassandra_version); ResultResponse* keyspaces_result; - if (MultipleRequestHandler::get_result_response(responses, "keyspaces", &keyspaces_result)) { + if (MultipleRequestCallback::get_result_response(responses, "keyspaces", &keyspaces_result)) { session->metadata().update_keyspaces(protocol_version, cassandra_version, keyspaces_result); } ResultResponse* tables_result; - if (MultipleRequestHandler::get_result_response(responses, "tables", &tables_result)) { + if (MultipleRequestCallback::get_result_response(responses, "tables", &tables_result)) { session->metadata().update_tables(protocol_version, cassandra_version, tables_result); } ResultResponse* views_result; - if (MultipleRequestHandler::get_result_response(responses, "views", &views_result)) { + if (MultipleRequestCallback::get_result_response(responses, "views", &views_result)) { session->metadata().update_views(protocol_version, cassandra_version, views_result); } ResultResponse* columns_result = NULL; - if (MultipleRequestHandler::get_result_response(responses, "columns", &columns_result)) { + if (MultipleRequestCallback::get_result_response(responses, "columns", &columns_result)) { session->metadata().update_columns(protocol_version, cassandra_version, columns_result); } ResultResponse* indexes_result; - if (MultipleRequestHandler::get_result_response(responses, "indexes", &indexes_result)) { + if (MultipleRequestCallback::get_result_response(responses, "indexes", &indexes_result)) { session->metadata().update_indexes(protocol_version, cassandra_version, indexes_result); } ResultResponse* user_types_result; - if (MultipleRequestHandler::get_result_response(responses, "user_types", &user_types_result)) { + if (MultipleRequestCallback::get_result_response(responses, "user_types", &user_types_result)) { session->metadata().update_user_types(protocol_version, cassandra_version, user_types_result); } ResultResponse* functions_result; - if (MultipleRequestHandler::get_result_response(responses, "functions", &functions_result)) { + if (MultipleRequestCallback::get_result_response(responses, "functions", &functions_result)) { session->metadata().update_functions(protocol_version, cassandra_version, functions_result); } ResultResponse* aggregates_result; - if (MultipleRequestHandler::get_result_response(responses, "aggregates", &aggregates_result)) { + if (MultipleRequestCallback::get_result_response(responses, "aggregates", &aggregates_result)) { session->metadata().update_aggregates(protocol_version, cassandra_version, aggregates_result); } @@ -631,7 +631,7 @@ void ControlConnection::refresh_node_info(SharedRefPtr host, bool is_connected_host = host->address().compare(connection_->address()) == 0; std::string query; - ControlHandler::ResponseCallback response_callback; + ControlCallback::ResponseCallback response_callback; bool token_query = token_aware_routing_ && (host->was_just_added() || query_tokens); if (is_connected_host || !host->listen_address().empty()) { @@ -652,12 +652,12 @@ void ControlConnection::refresh_node_info(SharedRefPtr host, LOG_DEBUG("refresh_node_info: %s", query.c_str()); RefreshNodeData data(host, is_new_node); - ScopedRefPtr > handler( - new ControlHandler(new QueryRequest(query), + ScopedRefPtr > callback( + new ControlCallback(new QueryRequest(query), this, response_callback, data)); - if (!connection_->write(handler.get())) { + if (!connection_->write(callback.get())) { LOG_ERROR("No more stream available while attempting to refresh node info"); } } @@ -807,7 +807,7 @@ void ControlConnection::refresh_keyspace(const StringRef& keyspace_name) { LOG_DEBUG("Refreshing keyspace %s", query.c_str()); connection_->write( - new ControlHandler(new QueryRequest(query), + new ControlCallback(new QueryRequest(query), this, ControlConnection::on_refresh_keyspace, keyspace_name.to_string())); @@ -874,31 +874,31 @@ void ControlConnection::refresh_table_or_view(const StringRef& keyspace_name, LOG_DEBUG("Refreshing table %s; %s", table_query.c_str(), column_query.c_str()); } - ScopedRefPtr > handler( - new ControlMultipleRequestHandler(this, + ScopedRefPtr > callback( + new ControlMultipleRequestCallback(this, ControlConnection::on_refresh_table_or_view, RefreshTableData(keyspace_name.to_string(), table_or_view_name.to_string()))); - handler->execute_query("tables", table_query); + callback->execute_query("tables", table_query); if (!view_query.empty()) { - handler->execute_query("views", view_query); + callback->execute_query("views", view_query); } - handler->execute_query("columns", column_query); + callback->execute_query("columns", column_query); if (!index_query.empty()) { - handler->execute_query("indexes", index_query); + callback->execute_query("indexes", index_query); } } void ControlConnection::on_refresh_table_or_view(ControlConnection* control_connection, const RefreshTableData& data, - const MultipleRequestHandler::ResponseMap& responses) { + const MultipleRequestCallback::ResponseMap& responses) { ResultResponse* tables_result; Session* session = control_connection->session_; int protocol_version = control_connection->protocol_version_; const VersionNumber& cassandra_version = control_connection->cassandra_version_; - if (!MultipleRequestHandler::get_result_response(responses, "tables", &tables_result) || + if (!MultipleRequestCallback::get_result_response(responses, "tables", &tables_result) || tables_result->row_count() == 0) { ResultResponse* views_result; - if (!MultipleRequestHandler::get_result_response(responses, "views", &views_result) || + if (!MultipleRequestCallback::get_result_response(responses, "views", &views_result) || views_result->row_count() == 0) { LOG_ERROR("No row found for table (or view) %s.%s in system schema tables.", data.keyspace_name.c_str(), data.table_or_view_name.c_str()); @@ -910,12 +910,12 @@ void ControlConnection::on_refresh_table_or_view(ControlConnection* control_conn } ResultResponse* columns_result; - if (MultipleRequestHandler::get_result_response(responses, "columns", &columns_result)) { + if (MultipleRequestCallback::get_result_response(responses, "columns", &columns_result)) { session->metadata().update_columns(protocol_version, cassandra_version, columns_result); } ResultResponse* indexes_result; - if (MultipleRequestHandler::get_result_response(responses, "indexes", &indexes_result)) { + if (MultipleRequestCallback::get_result_response(responses, "indexes", &indexes_result)) { session->metadata().update_indexes(protocol_version, cassandra_version, indexes_result); } } @@ -937,7 +937,7 @@ void ControlConnection::refresh_type(const StringRef& keyspace_name, LOG_DEBUG("Refreshing type %s", query.c_str()); connection_->write( - new ControlHandler >(new QueryRequest(query), + new ControlCallback >(new QueryRequest(query), this, ControlConnection::on_refresh_type, std::make_pair(keyspace_name.to_string(), type_name.to_string()))); @@ -1003,7 +1003,7 @@ void ControlConnection::refresh_function(const StringRef& keyspace_name, request->set(2, signature.get()); connection_->write( - new ControlHandler(request.get(), + new ControlCallback(request.get(), this, ControlConnection::on_refresh_function, RefreshFunctionData(keyspace_name, function_name, arg_types, is_aggregate))); @@ -1091,21 +1091,21 @@ void ControlConnection::on_reconnect(Timer* timer) { } template -void ControlConnection::ControlMultipleRequestHandler::execute_query( +void ControlConnection::ControlMultipleRequestCallback::execute_query( const std::string& index, const std::string& query) { // We need to update the loop time to prevent new requests from timing out // in cases where a callback took a long time to execute. In the future, // we might improve this by executing the these long running callbacks // on a seperate thread. uv_update_time(control_connection_->session_->loop()); - MultipleRequestHandler::execute_query(index, query); + MultipleRequestCallback::execute_query(index, query); } template -void ControlConnection::ControlMultipleRequestHandler::on_set( - const MultipleRequestHandler::ResponseMap& responses) { +void ControlConnection::ControlMultipleRequestCallback::on_set( + const MultipleRequestCallback::ResponseMap& responses) { bool has_error = false; - for (MultipleRequestHandler::ResponseMap::const_iterator it = responses.begin(), + for (MultipleRequestCallback::ResponseMap::const_iterator it = responses.begin(), end = responses.end(); it != end; ++it) { if (control_connection_->handle_query_invalid_response(it->second.get())) { has_error = true; diff --git a/src/control_connection.hpp b/src/control_connection.hpp index 011e345e9..fd4b2f0e8 100644 --- a/src/control_connection.hpp +++ b/src/control_connection.hpp @@ -19,11 +19,10 @@ #include "address.hpp" #include "connection.hpp" -#include "handler.hpp" +#include "request_callback.hpp" #include "host.hpp" #include "load_balancing.hpp" #include "macros.hpp" -#include "multiple_request_handler.hpp" #include "response.hpp" #include "scoped_ptr.hpp" #include "token_map.hpp" @@ -75,21 +74,21 @@ class ControlConnection : public Connection::Listener { private: template - class ControlMultipleRequestHandler : public MultipleRequestHandler { + class ControlMultipleRequestCallback : public MultipleRequestCallback { public: - typedef void (*ResponseCallback)(ControlConnection*, const T&, const MultipleRequestHandler::ResponseMap&); + typedef void (*ResponseCallback)(ControlConnection*, const T&, const MultipleRequestCallback::ResponseMap&); - ControlMultipleRequestHandler(ControlConnection* control_connection, + ControlMultipleRequestCallback(ControlConnection* control_connection, ResponseCallback response_callback, const T& data) - : MultipleRequestHandler(control_connection->connection_) + : MultipleRequestCallback(control_connection->connection_) , control_connection_(control_connection) , response_callback_(response_callback) , data_(data) {} void execute_query(const std::string& index, const std::string& query); - virtual void on_set(const MultipleRequestHandler::ResponseMap& responses); + virtual void on_set(const MultipleRequestCallback::ResponseMap& responses); virtual void on_error(CassError code, const std::string& message) { control_connection_->handle_query_failure(code, message); @@ -117,15 +116,15 @@ class ControlConnection : public Connection::Listener { struct UnusedData {}; template - class ControlHandler : public Handler { + class ControlCallback : public RequestCallback { public: typedef void (*ResponseCallback)(ControlConnection*, const T&, Response*); - ControlHandler(const Request* request, + ControlCallback(const Request* request, ControlConnection* control_connection, ResponseCallback response_callback, const T& data) - : Handler(request) + : RequestCallback(request) , control_connection_(control_connection) , response_callback_(response_callback) , data_(data) {} @@ -202,12 +201,12 @@ class ControlConnection : public Connection::Listener { void query_meta_hosts(); static void on_query_hosts(ControlConnection* control_connection, const UnusedData& data, - const MultipleRequestHandler::ResponseMap& responses); + const MultipleRequestCallback::ResponseMap& responses); void query_meta_schema(); static void on_query_meta_schema(ControlConnection* control_connection, const UnusedData& data, - const MultipleRequestHandler::ResponseMap& responses); + const MultipleRequestCallback::ResponseMap& responses); void refresh_node_info(SharedRefPtr host, bool is_new_node, @@ -228,7 +227,7 @@ class ControlConnection : public Connection::Listener { const StringRef& table_name); static void on_refresh_table_or_view(ControlConnection* control_connection, const RefreshTableData& data, - const MultipleRequestHandler::ResponseMap& responses); + const MultipleRequestCallback::ResponseMap& responses); void refresh_type(const StringRef& keyspace_name, const StringRef& type_name); diff --git a/src/execute_request.cpp b/src/execute_request.cpp index 12296a6bd..343183c25 100644 --- a/src/execute_request.cpp +++ b/src/execute_request.cpp @@ -17,11 +17,11 @@ #include "execute_request.hpp" #include "constants.hpp" -#include "handler.hpp" +#include "request_callback.hpp" namespace cass { -int32_t ExecuteRequest::encode_batch(int version, BufferVec* bufs, Handler* handler) const { +int32_t ExecuteRequest::encode_batch(int version, BufferVec* bufs, RequestCallback* callback) const { int32_t length = 0; const std::string& id(prepared_->id()); @@ -37,7 +37,7 @@ int32_t ExecuteRequest::encode_batch(int version, BufferVec* bufs, Handler* hand buf.encode_uint16(pos, elements_count()); if (elements_count() > 0) { - int32_t result = copy_buffers(version, bufs, handler); + int32_t result = copy_buffers(version, bufs, callback); if (result < 0) return result; length += result; } @@ -45,15 +45,15 @@ int32_t ExecuteRequest::encode_batch(int version, BufferVec* bufs, Handler* hand return length; } -int ExecuteRequest::encode(int version, Handler* handler, BufferVec* bufs) const { +int ExecuteRequest::encode(int version, RequestCallback* callback, BufferVec* bufs) const { if (version == 1) { - return internal_encode_v1(handler, bufs); + return internal_encode_v1(callback, bufs); } else { - return internal_encode(version, handler, bufs); + return internal_encode(version, callback, bufs); } } -int ExecuteRequest::internal_encode_v1(Handler* handler, BufferVec* bufs) const { +int ExecuteRequest::internal_encode_v1(RequestCallback* callback, BufferVec* bufs) const { size_t length = 0; const int version = 1; @@ -73,7 +73,7 @@ int ExecuteRequest::internal_encode_v1(Handler* handler, BufferVec* bufs) const prepared_id.size()); buf.encode_uint16(pos, elements_count()); // ... - int32_t result = copy_buffers(version, bufs, handler); + int32_t result = copy_buffers(version, bufs, callback); if (result < 0) return result; length += result; } @@ -83,7 +83,7 @@ int ExecuteRequest::internal_encode_v1(Handler* handler, BufferVec* bufs) const size_t buf_size = sizeof(uint16_t); Buffer buf(buf_size); - buf.encode_uint16(0, handler->consistency()); + buf.encode_uint16(0, callback->consistency()); bufs->push_back(buf); length += buf_size; } @@ -91,7 +91,7 @@ int ExecuteRequest::internal_encode_v1(Handler* handler, BufferVec* bufs) const return length; } -int ExecuteRequest::internal_encode(int version, Handler* handler, BufferVec* bufs) const { +int ExecuteRequest::internal_encode(int version, RequestCallback* callback, BufferVec* bufs) const { int length = 0; uint8_t flags = this->flags(); @@ -122,7 +122,7 @@ int ExecuteRequest::internal_encode(int version, Handler* handler, BufferVec* bu flags |= CASS_QUERY_FLAG_SERIAL_CONSISTENCY; } - if (version >= 3 && handler->timestamp() != CASS_INT64_MIN) { + if (version >= 3 && callback->timestamp() != CASS_INT64_MIN) { paging_buf_size += sizeof(int64_t); // [long] flags |= CASS_QUERY_FLAG_DEFAULT_TIMESTAMP; } @@ -135,12 +135,12 @@ int ExecuteRequest::internal_encode(int version, Handler* handler, BufferVec* bu size_t pos = buf.encode_string(0, prepared_id.data(), prepared_id.size()); - pos = buf.encode_uint16(pos, handler->consistency()); + pos = buf.encode_uint16(pos, callback->consistency()); pos = buf.encode_byte(pos, flags); if (elements_count() > 0) { buf.encode_uint16(pos, elements_count()); - int32_t result = copy_buffers(version, bufs, handler); + int32_t result = copy_buffers(version, bufs, callback); if (result < 0) return result; length += result; } @@ -165,8 +165,8 @@ int ExecuteRequest::internal_encode(int version, Handler* handler, BufferVec* bu pos = buf.encode_uint16(pos, serial_consistency()); } - if (version >= 3 && handler->timestamp() != CASS_INT64_MIN) { - pos = buf.encode_int64(pos, handler->timestamp()); + if (version >= 3 && callback->timestamp() != CASS_INT64_MIN) { + pos = buf.encode_int64(pos, callback->timestamp()); } } diff --git a/src/execute_request.hpp b/src/execute_request.hpp index 164b79a65..f35f4b2a2 100644 --- a/src/execute_request.hpp +++ b/src/execute_request.hpp @@ -54,12 +54,12 @@ class ExecuteRequest : public Statement { return metadata_->get_column_definition(index).data_type; } - virtual int32_t encode_batch(int version, BufferVec* bufs, Handler* handler) const; + virtual int32_t encode_batch(int version, BufferVec* bufs, RequestCallback* callback) const; private: - int encode(int version, Handler* handler, BufferVec* bufs) const; - int internal_encode_v1(Handler* handler, BufferVec* bufs) const; - int internal_encode(int version, Handler* handler, BufferVec* bufs) const; + int encode(int version, RequestCallback* callback, BufferVec* bufs) const; + int internal_encode_v1(RequestCallback* callback, BufferVec* bufs) const; + int internal_encode(int version, RequestCallback* callback, BufferVec* bufs) const; private: SharedRefPtr prepared_; diff --git a/src/multiple_request_handler.cpp b/src/multiple_request_handler.cpp deleted file mode 100644 index 065ef30aa..000000000 --- a/src/multiple_request_handler.cpp +++ /dev/null @@ -1,66 +0,0 @@ -/* - Copyright (c) 2014-2016 DataStax - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -#include "multiple_request_handler.hpp" - -#include "connection.hpp" -#include "query_request.hpp" - -namespace cass { - -bool MultipleRequestHandler::get_result_response(const ResponseMap& responses, - const std::string& index, - ResultResponse** response) { - ResponseMap::const_iterator it = responses.find(index); - if (it == responses.end() || it->second->opcode() != CQL_OPCODE_RESULT) { - return false; - } - *response = static_cast(it->second.get()); - return true; -} - -void MultipleRequestHandler::execute_query(const std::string& index, const std::string& query) { - if (has_errors_or_timeouts_) return; - responses_[index] = SharedRefPtr(); - SharedRefPtr handler(new InternalHandler(this, new QueryRequest(query), index)); - remaining_++; - if (!connection_->write(handler.get())) { - on_error(CASS_ERROR_LIB_NO_STREAMS, "No more streams available"); - } -} - -void MultipleRequestHandler::InternalHandler::on_set(ResponseMessage* response) { - parent_->responses_[index_] = response->response_body(); - if (--parent_->remaining_ == 0 && !parent_->has_errors_or_timeouts_) { - parent_->on_set(parent_->responses_); - } -} - -void MultipleRequestHandler::InternalHandler::on_error(CassError code, const std::string& message) { - if (!parent_->has_errors_or_timeouts_) { - parent_->on_error(code, message); - } - parent_->has_errors_or_timeouts_ = true; -} - -void MultipleRequestHandler::InternalHandler::on_timeout() { - if (!parent_->has_errors_or_timeouts_) { - parent_->on_timeout(); - } - parent_->has_errors_or_timeouts_ = true; -} - -} // namespace cass diff --git a/src/multiple_request_handler.hpp b/src/multiple_request_handler.hpp deleted file mode 100644 index f796a9d1a..000000000 --- a/src/multiple_request_handler.hpp +++ /dev/null @@ -1,83 +0,0 @@ -/* - Copyright (c) 2014-2016 DataStax - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -#ifndef __CASS_MULTIPLE_REQUEST_HANDLER_HPP_INCLUDED__ -#define __CASS_MULTIPLE_REQUEST_HANDLER_HPP_INCLUDED__ - -#include "handler.hpp" -#include "ref_counted.hpp" -#include "request.hpp" - -#include -#include - -namespace cass { - -class Connection; -class Response; -class ResultResponse; - -class MultipleRequestHandler : public RefCounted { -public: - typedef std::map > ResponseMap; - - MultipleRequestHandler(Connection* connection) - : connection_(connection) - , has_errors_or_timeouts_(false) - , remaining_(0) { } - - virtual ~MultipleRequestHandler() { } - - static bool get_result_response(const ResponseMap& responses, - const std::string& index, - ResultResponse** response); - - void execute_query(const std::string& index, const std::string& query); - - virtual void on_set(const ResponseMap& responses) = 0; - virtual void on_error(CassError code, const std::string& message) = 0; - virtual void on_timeout() = 0; - - Connection* connection() { - return connection_; - } - -private: - class InternalHandler : public Handler { - public: - InternalHandler(MultipleRequestHandler* parent, const Request* request, const std::string& index) - : Handler(request) - , parent_(parent) - , index_(index) { } - - virtual void on_set(ResponseMessage* response); - virtual void on_error(CassError code, const std::string& message); - virtual void on_timeout(); - - private: - ScopedRefPtr parent_; - std::string index_; - }; - - Connection* connection_; - bool has_errors_or_timeouts_; - int remaining_; - ResponseMap responses_; -}; - -} // namespace cass - -#endif diff --git a/src/options_request.hpp b/src/options_request.hpp index 6dab80f07..f182180b6 100644 --- a/src/options_request.hpp +++ b/src/options_request.hpp @@ -28,7 +28,7 @@ class OptionsRequest : public Request { : Request(CQL_OPCODE_OPTIONS) {} private: - int encode(int version, Handler* handler, BufferVec* bufs) const { return 0; } + int encode(int version, RequestCallback* callback, BufferVec* bufs) const { return 0; } }; } // namespace cass diff --git a/src/pool.cpp b/src/pool.cpp index 205aa41b5..95976a446 100644 --- a/src/pool.cpp +++ b/src/pool.cpp @@ -20,9 +20,8 @@ #include "error_response.hpp" #include "io_worker.hpp" #include "logger.hpp" -#include "prepare_handler.hpp" +#include "query_request.hpp" #include "session.hpp" -#include "set_keyspace_handler.hpp" #include "request_handler.hpp" #include "result_response.hpp" #include "timer.hpp" @@ -35,6 +34,72 @@ static bool least_busy_comp(Connection* a, Connection* b) { return a->pending_request_count() < b->pending_request_count(); } +class SetKeyspaceCallback : public RequestCallback { +public: + SetKeyspaceCallback(Connection* connection, + const std::string& keyspace, + RequestHandler* request_handler); + + virtual void on_set(ResponseMessage* response); + virtual void on_error(CassError code, const std::string& message); + virtual void on_timeout(); + +private: + void on_result_response(ResponseMessage* response); + +private: + ScopedRefPtr request_handler_; +}; + +SetKeyspaceCallback::SetKeyspaceCallback(Connection* connection, + const std::string& keyspace, + RequestHandler* request_handler) + : RequestCallback(new QueryRequest("USE \"" + keyspace + "\"")) + , request_handler_(request_handler) { + set_connection(connection); +} + +void SetKeyspaceCallback::on_set(ResponseMessage* response) { + switch (response->opcode()) { + case CQL_OPCODE_RESULT: + on_result_response(response); + break; + case CQL_OPCODE_ERROR: + connection_->defunct(); + request_handler_->on_error(CASS_ERROR_LIB_UNABLE_TO_SET_KEYSPACE, + "Unable to set keyspace"); + break; + default: + break; + } +} + +void SetKeyspaceCallback::on_error(CassError code, const std::string& message) { + connection_->defunct(); + request_handler_->on_error(CASS_ERROR_LIB_UNABLE_TO_SET_KEYSPACE, + "Unable to set keyspace"); +} + +void SetKeyspaceCallback::on_timeout() { + // TODO(mpenick): What to do here? + request_handler_->on_timeout(); +} + +void SetKeyspaceCallback::on_result_response(ResponseMessage* response) { + ResultResponse* result = + static_cast(response->response_body().get()); + if (result->kind() == CASS_RESULT_KIND_SET_KEYSPACE) { + if (!connection_->write(request_handler_.get())) { + // Try on the same host but a different connection + request_handler_->retry(); + } + } else { + connection_->defunct(); + request_handler_->on_error(CASS_ERROR_LIB_UNABLE_TO_SET_KEYSPACE, + "Unable to set keyspace"); + } +} + Pool::Pool(IOWorker* io_worker, const Host::ConstPtr& host, bool is_initial_connection) @@ -210,7 +275,7 @@ bool Pool::write(Connection* connection, RequestHandler* request_handler) { io_worker_->keyspace()->c_str(), static_cast(connection), static_cast(this)); - if (!connection->write(new SetKeyspaceHandler(connection, *io_worker_->keyspace(), + if (!connection->write(new SetKeyspaceCallback(connection, *io_worker_->keyspace(), request_handler), false)) { return false; } diff --git a/src/pool.hpp b/src/pool.hpp index 5aee1fd3a..1331c7fed 100644 --- a/src/pool.hpp +++ b/src/pool.hpp @@ -23,7 +23,7 @@ #include "metrics.hpp" #include "ref_counted.hpp" #include "request.hpp" -#include "request_handler.hpp" +#include "request_callback.hpp" #include "scoped_ptr.hpp" #include "timer.hpp" @@ -113,7 +113,7 @@ class Pool : public RefCounted Connection::ConnectionError error_code_; ConnectionVec connections_; ConnectionVec pending_connections_; - List pending_requests_; + List pending_requests_; int available_connection_count_; bool is_available_; bool is_initial_connection_; diff --git a/src/prepare_handler.cpp b/src/prepare_handler.cpp deleted file mode 100644 index f6b80beb1..000000000 --- a/src/prepare_handler.cpp +++ /dev/null @@ -1,82 +0,0 @@ -/* - Copyright (c) 2014-2016 DataStax - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -#include "prepare_handler.hpp" - -#include "batch_request.hpp" -#include "constants.hpp" -#include "error_response.hpp" -#include "execute_request.hpp" -#include "prepare_request.hpp" -#include "request_handler.hpp" -#include "response.hpp" -#include "result_response.hpp" - -namespace cass { - -bool PrepareHandler::init(const std::string& prepared_id) { - PrepareRequest* prepare = - static_cast(new PrepareRequest()); - request_.reset(prepare); - if (request_handler_->request()->opcode() == CQL_OPCODE_EXECUTE) { - const ExecuteRequest* execute = static_cast( - request_handler_->request()); - prepare->set_query(execute->prepared()->statement()); - return true; - } else if (request_handler_->request()->opcode() == CQL_OPCODE_BATCH) { - const BatchRequest* batch = static_cast( - request_handler_->request()); - std::string prepared_statement; - if (batch->prepared_statement(prepared_id, &prepared_statement)) { - prepare->set_query(prepared_statement); - return true; - } - } - return false; // Invalid request type -} - -void PrepareHandler::on_set(ResponseMessage* response) { - switch (response->opcode()) { - case CQL_OPCODE_RESULT: { - ResultResponse* result = - static_cast(response->response_body().get()); - if (result->kind() == CASS_RESULT_KIND_PREPARED) { - request_handler_->retry(); - } else { - request_handler_->next_host(); - request_handler_->retry(); - } - } break; - case CQL_OPCODE_ERROR: - request_handler_->next_host(); - request_handler_->retry(); - break; - default: - break; - } -} - -void PrepareHandler::on_error(CassError code, const std::string& message) { - request_handler_->next_host(); - request_handler_->retry(); -} - -void PrepareHandler::on_timeout() { - request_handler_->next_host(); - request_handler_->retry(); -} - -} // namespace cass diff --git a/src/prepare_handler.hpp b/src/prepare_handler.hpp deleted file mode 100644 index 5f5ef5d02..000000000 --- a/src/prepare_handler.hpp +++ /dev/null @@ -1,48 +0,0 @@ -/* - Copyright (c) 2014-2016 DataStax - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -#ifndef __CASS_PREPARE_HANDLER_HPP_INCLUDED__ -#define __CASS_PREPARE_HANDLER_HPP_INCLUDED__ - -#include "handler.hpp" -#include "scoped_ptr.hpp" -#include "ref_counted.hpp" -#include "request_handler.hpp" - -namespace cass { - -class ResponseMessage; -class Request; - -class PrepareHandler : public Handler { -public: - PrepareHandler(RequestHandler* request_handler) - : Handler(NULL) - , request_handler_(request_handler) {} - - bool init(const std::string& prepared_id); - - virtual void on_set(ResponseMessage* response); - virtual void on_error(CassError code, const std::string& message); - virtual void on_timeout(); - -private: - ScopedRefPtr request_handler_; -}; - -} // namespace cass - -#endif diff --git a/src/prepare_request.cpp b/src/prepare_request.cpp index c0f573079..a603c59e6 100644 --- a/src/prepare_request.cpp +++ b/src/prepare_request.cpp @@ -20,7 +20,7 @@ namespace cass { -int PrepareRequest::encode(int version, Handler* handler, BufferVec* bufs) const { +int PrepareRequest::encode(int version, RequestCallback* callback, BufferVec* bufs) const { // [long string] size_t length = sizeof(int32_t) + query_.size(); bufs->push_back(Buffer(length)); diff --git a/src/prepare_request.hpp b/src/prepare_request.hpp index 2e7565d39..3141ad7ed 100644 --- a/src/prepare_request.hpp +++ b/src/prepare_request.hpp @@ -39,7 +39,7 @@ class PrepareRequest : public Request { } private: - int encode(int version, Handler* handler, BufferVec* bufs) const; + int encode(int version, RequestCallback* callback, BufferVec* bufs) const; private: std::string query_; diff --git a/src/query_request.cpp b/src/query_request.cpp index 7da18c8f5..1514eb780 100644 --- a/src/query_request.cpp +++ b/src/query_request.cpp @@ -17,13 +17,13 @@ #include "query_request.hpp" #include "constants.hpp" -#include "handler.hpp" +#include "request_callback.hpp" #include "logger.hpp" #include "serialization.hpp" namespace cass { -int32_t QueryRequest::encode_batch(int version, BufferVec* bufs, Handler* handler) const { +int32_t QueryRequest::encode_batch(int version, BufferVec* bufs, RequestCallback* callback) const { int32_t length = 0; const std::string& query(query_); @@ -43,11 +43,11 @@ int32_t QueryRequest::encode_batch(int version, BufferVec* bufs, Handler* handle return ENCODE_ERROR_UNSUPPORTED_PROTOCOL; } buf.encode_uint16(pos, value_names_.size()); - length += copy_buffers_with_names(version, bufs, handler->encoding_cache()); + length += copy_buffers_with_names(version, bufs, callback->encoding_cache()); } else { buf.encode_uint16(pos, elements_count()); if (elements_count() > 0) { - int32_t result = copy_buffers(version, bufs, handler); + int32_t result = copy_buffers(version, bufs, callback); if (result < 0) return result; length += result; } @@ -89,27 +89,27 @@ int32_t QueryRequest::copy_buffers_with_names(int version, return size; } -int QueryRequest::encode(int version, Handler* handler, BufferVec* bufs) const { +int QueryRequest::encode(int version, RequestCallback* callback, BufferVec* bufs) const { if (version == 1) { - return internal_encode_v1(handler, bufs); + return internal_encode_v1(callback, bufs); } else { - return internal_encode(version, handler, bufs); + return internal_encode(version, callback, bufs); } } -int QueryRequest::internal_encode_v1(Handler* handler, BufferVec* bufs) const { +int QueryRequest::internal_encode_v1(RequestCallback* callback, BufferVec* bufs) const { // [long string] + [short] size_t length = sizeof(int32_t) + query_.size() + sizeof(uint16_t); Buffer buf(length); size_t pos = buf.encode_long_string(0, query_.data(), query_.size()); - buf.encode_uint16(pos, handler->consistency()); + buf.encode_uint16(pos, callback->consistency()); bufs->push_back(buf); return length; } -int QueryRequest::internal_encode(int version, Handler* handler, BufferVec* bufs) const { +int QueryRequest::internal_encode(int version, RequestCallback* callback, BufferVec* bufs) const { int length = 0; uint8_t flags = this->flags(); @@ -138,7 +138,7 @@ int QueryRequest::internal_encode(int version, Handler* handler, BufferVec* bufs flags |= CASS_QUERY_FLAG_SERIAL_CONSISTENCY; } - if (version >= 3 && handler->timestamp() != CASS_INT64_MIN) { + if (version >= 3 && callback->timestamp() != CASS_INT64_MIN) { paging_buf_size += sizeof(int64_t); // [long] flags |= CASS_QUERY_FLAG_DEFAULT_TIMESTAMP; } @@ -149,7 +149,7 @@ int QueryRequest::internal_encode(int version, Handler* handler, BufferVec* bufs Buffer& buf = bufs->back(); size_t pos = buf.encode_long_string(0, query_.data(), query_.size()); - pos = buf.encode_uint16(pos, handler->consistency()); + pos = buf.encode_uint16(pos, callback->consistency()); pos = buf.encode_byte(pos, flags); if (has_names_for_values()) { @@ -158,10 +158,10 @@ int QueryRequest::internal_encode(int version, Handler* handler, BufferVec* bufs return ENCODE_ERROR_UNSUPPORTED_PROTOCOL; } buf.encode_uint16(pos, value_names_.size()); - length += copy_buffers_with_names(version, bufs, handler->encoding_cache()); + length += copy_buffers_with_names(version, bufs, callback->encoding_cache()); } else if (elements_count() > 0) { buf.encode_uint16(pos, elements_count()); - int32_t result = copy_buffers(version, bufs, handler); + int32_t result = copy_buffers(version, bufs, callback); if (result < 0) return result; length += result; } @@ -186,8 +186,8 @@ int QueryRequest::internal_encode(int version, Handler* handler, BufferVec* bufs pos = buf.encode_uint16(pos, serial_consistency()); } - if (version >= 3 && handler->timestamp() != CASS_INT64_MIN) { - pos = buf.encode_int64(pos, handler->timestamp()); + if (version >= 3 && callback->timestamp() != CASS_INT64_MIN) { + pos = buf.encode_int64(pos, callback->timestamp()); } } diff --git a/src/query_request.hpp b/src/query_request.hpp index e60882f59..03ed03c9b 100644 --- a/src/query_request.hpp +++ b/src/query_request.hpp @@ -60,7 +60,7 @@ class QueryRequest : public Statement { , query_(query, query_length) , value_names_(value_count) { } - virtual int32_t encode_batch(int version, BufferVec* bufs, Handler* handler) const; + virtual int32_t encode_batch(int version, BufferVec* bufs, RequestCallback* callback) const; private: virtual size_t get_indices(StringRef name, @@ -73,9 +73,9 @@ class QueryRequest : public Statement { private: int32_t copy_buffers_with_names(int version, BufferVec* bufs, EncodingCache* cache) const; - int encode(int version, Handler* handler, BufferVec* bufs) const; - int internal_encode_v1(Handler* handler, BufferVec* bufs) const; - int internal_encode(int version, Handler* handler, BufferVec* bufs) const; + int encode(int version, RequestCallback* callback, BufferVec* bufs) const; + int internal_encode_v1(RequestCallback* callback, BufferVec* bufs) const; + int internal_encode(int version, RequestCallback* callback, BufferVec* bufs) const; private: std::string query_; diff --git a/src/register_request.cpp b/src/register_request.cpp index 3b81589ca..fa93edaa0 100644 --- a/src/register_request.cpp +++ b/src/register_request.cpp @@ -20,7 +20,7 @@ namespace cass { -int RegisterRequest::encode(int version, Handler* handler, BufferVec* bufs) const { +int RegisterRequest::encode(int version, RequestCallback* callback, BufferVec* bufs) const { // [string list] size_t length = sizeof(uint16_t); std::vector events; diff --git a/src/register_request.hpp b/src/register_request.hpp index 07c8a8014..174cc6ec3 100644 --- a/src/register_request.hpp +++ b/src/register_request.hpp @@ -29,7 +29,7 @@ class RegisterRequest : public Request { , event_types_(event_types) {} private: - int encode(int version, Handler* handler, BufferVec* bufs) const; + int encode(int version, RequestCallback* callback, BufferVec* bufs) const; int event_types_; }; diff --git a/src/request.hpp b/src/request.hpp index 128f83dda..cd4b98ea7 100644 --- a/src/request.hpp +++ b/src/request.hpp @@ -31,7 +31,7 @@ namespace cass { -class Handler; +class RequestCallback; class RequestMessage; class CustomPayload : public RefCounted { @@ -111,7 +111,7 @@ class Request : public RefCounted { custom_payload_.reset(payload); } - virtual int encode(int version, Handler* handler, BufferVec* bufs) const = 0; + virtual int encode(int version, RequestCallback* callback, BufferVec* bufs) const = 0; private: uint8_t opcode_; diff --git a/src/handler.cpp b/src/request_callback.cpp similarity index 69% rename from src/handler.cpp rename to src/request_callback.cpp index 6c8fe0fb5..de4714057 100644 --- a/src/handler.cpp +++ b/src/request_callback.cpp @@ -14,19 +14,20 @@ limitations under the License. */ -#include "handler.hpp" +#include "request_callback.hpp" #include "config.hpp" #include "connection.hpp" #include "constants.hpp" #include "logger.hpp" +#include "query_request.hpp" #include "request.hpp" #include "result_response.hpp" #include "serialization.hpp" namespace cass { -int32_t Handler::encode(int version, int flags, BufferVec* bufs) { +int32_t RequestCallback::encode(int version, int flags, BufferVec* bufs) { if (version < 1 || version > 4) { return Request::ENCODE_ERROR_UNSUPPORTED_PROTOCOL; } @@ -67,7 +68,7 @@ int32_t Handler::encode(int version, int flags, BufferVec* bufs) { return length + header_size; } -void Handler::set_state(Handler::State next_state) { +void RequestCallback::set_state(RequestCallback::State next_state) { switch (state_) { case REQUEST_STATE_NEW: if (next_state == REQUEST_STATE_NEW) { @@ -143,7 +144,7 @@ void Handler::set_state(Handler::State next_state) { } -uint64_t Handler::request_timeout_ms(const Config& config) const { +uint64_t RequestCallback::request_timeout_ms(const Config& config) const { uint64_t request_timeout_ms = request_->request_timeout_ms(); if (request_timeout_ms == CASS_UINT64_MAX) { return config.request_timeout_ms(); @@ -151,4 +152,46 @@ uint64_t Handler::request_timeout_ms(const Config& config) const { return request_timeout_ms; } +bool MultipleRequestCallback::get_result_response(const ResponseMap& responses, + const std::string& index, + ResultResponse** response) { + ResponseMap::const_iterator it = responses.find(index); + if (it == responses.end() || it->second->opcode() != CQL_OPCODE_RESULT) { + return false; + } + *response = static_cast(it->second.get()); + return true; +} + +void MultipleRequestCallback::execute_query(const std::string& index, const std::string& query) { + if (has_errors_or_timeouts_) return; + responses_[index] = SharedRefPtr(); + SharedRefPtr callback(new InternalCallback(this, new QueryRequest(query), index)); + remaining_++; + if (!connection_->write(callback.get())) { + on_error(CASS_ERROR_LIB_NO_STREAMS, "No more streams available"); + } +} + +void MultipleRequestCallback::InternalCallback::on_set(ResponseMessage* response) { + parent_->responses_[index_] = response->response_body(); + if (--parent_->remaining_ == 0 && !parent_->has_errors_or_timeouts_) { + parent_->on_set(parent_->responses_); + } +} + +void MultipleRequestCallback::InternalCallback::on_error(CassError code, const std::string& message) { + if (!parent_->has_errors_or_timeouts_) { + parent_->on_error(code, message); + } + parent_->has_errors_or_timeouts_ = true; +} + +void MultipleRequestCallback::InternalCallback::on_timeout() { + if (!parent_->has_errors_or_timeouts_) { + parent_->on_timeout(); + } + parent_->has_errors_or_timeouts_ = true; +} + } // namespace cass diff --git a/src/handler.hpp b/src/request_callback.hpp similarity index 62% rename from src/handler.hpp rename to src/request_callback.hpp index 2449c8b88..3aecfd050 100644 --- a/src/handler.hpp +++ b/src/request_callback.hpp @@ -14,8 +14,8 @@ limitations under the License. */ -#ifndef __CASS_HANDLER_HPP_INCLUDED__ -#define __CASS_HANDLER_HPP_INCLUDED__ +#ifndef __CASS_REQUEST_CALLBACK_HPP_INCLUDED__ +#define __CASS_REQUEST_CALLBACK_HPP_INCLUDED__ #include "buffer.hpp" #include "constants.hpp" @@ -33,11 +33,13 @@ namespace cass { class Config; class Connection; +class Response; class ResponseMessage; +class ResultResponse; typedef std::vector UvBufVec; -class Handler : public RefCounted, public List::Node { +class RequestCallback : public RefCounted, public List::Node { public: enum State { REQUEST_STATE_NEW, @@ -50,7 +52,7 @@ class Handler : public RefCounted, public List::Node { REQUEST_STATE_DONE }; - Handler(const Request* request) + RequestCallback(const Request* request) : request_(request) , connection_(NULL) , stream_(-1) @@ -59,7 +61,7 @@ class Handler : public RefCounted, public List::Node { , timestamp_(CASS_INT64_MIN) , start_time_ns_(0) { } - virtual ~Handler() {} + virtual ~RequestCallback() {} int32_t encode(int version, int flags, BufferVec* bufs); @@ -130,7 +132,55 @@ class Handler : public RefCounted, public List::Node { Request::EncodingCache encoding_cache_; private: - DISALLOW_COPY_AND_ASSIGN(Handler); + DISALLOW_COPY_AND_ASSIGN(RequestCallback); +}; + +class MultipleRequestCallback : public RefCounted { +public: + typedef std::map > ResponseMap; + + MultipleRequestCallback(Connection* connection) + : connection_(connection) + , has_errors_or_timeouts_(false) + , remaining_(0) { } + + virtual ~MultipleRequestCallback() { } + + static bool get_result_response(const ResponseMap& responses, + const std::string& index, + ResultResponse** response); + + void execute_query(const std::string& index, const std::string& query); + + virtual void on_set(const ResponseMap& responses) = 0; + virtual void on_error(CassError code, const std::string& message) = 0; + virtual void on_timeout() = 0; + + Connection* connection() { + return connection_; + } + +private: + class InternalCallback : public RequestCallback { + public: + InternalCallback(MultipleRequestCallback* parent, const Request* request, const std::string& index) + : RequestCallback(request) + , parent_(parent) + , index_(index) { } + + virtual void on_set(ResponseMessage* response); + virtual void on_error(CassError code, const std::string& message); + virtual void on_timeout(); + + private: + ScopedRefPtr parent_; + std::string index_; + }; + + Connection* connection_; + bool has_errors_or_timeouts_; + int remaining_; + ResponseMap responses_; }; } // namespace cass diff --git a/src/request_handler.cpp b/src/request_handler.cpp index 842f5120f..8cb6b971c 100644 --- a/src/request_handler.cpp +++ b/src/request_handler.cpp @@ -16,21 +16,92 @@ #include "request_handler.hpp" +#include "batch_request.hpp" #include "connection.hpp" #include "constants.hpp" +#include "error_response.hpp" #include "execute_request.hpp" #include "io_worker.hpp" #include "pool.hpp" -#include "prepare_handler.hpp" +#include "prepare_request.hpp" +#include "response.hpp" #include "result_response.hpp" #include "row.hpp" -#include "schema_change_handler.hpp" +#include "schema_change_callback.hpp" #include "session.hpp" #include namespace cass { +class PrepareCallback : public RequestCallback { +public: + PrepareCallback(RequestHandler* request_handler) + : RequestCallback(NULL) + , request_handler_(request_handler) {} + + bool init(const std::string& prepared_id); + + virtual void on_set(ResponseMessage* response); + virtual void on_error(CassError code, const std::string& message); + virtual void on_timeout(); + +private: + ScopedRefPtr request_handler_; +}; + +bool PrepareCallback::init(const std::string& prepared_id) { + PrepareRequest* prepare = + static_cast(new PrepareRequest()); + request_.reset(prepare); + if (request_handler_->request()->opcode() == CQL_OPCODE_EXECUTE) { + const ExecuteRequest* execute = static_cast( + request_handler_->request()); + prepare->set_query(execute->prepared()->statement()); + return true; + } else if (request_handler_->request()->opcode() == CQL_OPCODE_BATCH) { + const BatchRequest* batch = static_cast( + request_handler_->request()); + std::string prepared_statement; + if (batch->prepared_statement(prepared_id, &prepared_statement)) { + prepare->set_query(prepared_statement); + return true; + } + } + return false; // Invalid request type +} + +void PrepareCallback::on_set(ResponseMessage* response) { + switch (response->opcode()) { + case CQL_OPCODE_RESULT: { + ResultResponse* result = + static_cast(response->response_body().get()); + if (result->kind() == CASS_RESULT_KIND_PREPARED) { + request_handler_->retry(); + } else { + request_handler_->next_host(); + request_handler_->retry(); + } + } break; + case CQL_OPCODE_ERROR: + request_handler_->next_host(); + request_handler_->retry(); + break; + default: + break; + } +} + +void PrepareCallback::on_error(CassError code, const std::string& message) { + request_handler_->next_host(); + request_handler_->retry(); +} + +void PrepareCallback::on_timeout() { + request_handler_->next_host(); + request_handler_->retry(); +} + void RequestHandler::on_set(ResponseMessage* response) { assert(connection_ != NULL); assert(!is_query_plan_exhausted_ && "Tried to set on a non-existent host"); @@ -150,8 +221,8 @@ void RequestHandler::on_result_response(ResponseMessage* response) { break; case CASS_RESULT_KIND_SCHEMA_CHANGE: { - SharedRefPtr schema_change_handler( - new SchemaChangeHandler(connection_, + SharedRefPtr schema_change_handler( + new SchemaChangeCallback(connection_, this, response->response_body())); schema_change_handler->execute(); @@ -214,7 +285,7 @@ void RequestHandler::on_error_response(ResponseMessage* response) { } void RequestHandler::on_error_unprepared(ErrorResponse* error) { - ScopedRefPtr prepare_handler(new PrepareHandler(this)); + ScopedRefPtr prepare_handler(new PrepareCallback(this)); if (prepare_handler->init(error->prepared_id().to_string())) { if (!connection_->write(prepare_handler.get())) { // Try to prepare on the same host but on a different connection diff --git a/src/request_handler.hpp b/src/request_handler.hpp index d63bf65dc..2f110ba9a 100644 --- a/src/request_handler.hpp +++ b/src/request_handler.hpp @@ -20,7 +20,7 @@ #include "constants.hpp" #include "error_response.hpp" #include "future.hpp" -#include "handler.hpp" +#include "request_callback.hpp" #include "host.hpp" #include "load_balancing.hpp" #include "metadata.hpp" @@ -87,12 +87,12 @@ class ResponseFuture : public Future { }; -class RequestHandler : public Handler { +class RequestHandler : public RequestCallback { public: RequestHandler(const Request* request, ResponseFuture* future, RetryPolicy* retry_policy) - : Handler(request) + : RequestCallback(request) , future_(future) , retry_policy_(retry_policy) , num_retries_(0) diff --git a/src/schema_change_handler.cpp b/src/schema_change_callback.cpp similarity index 81% rename from src/schema_change_handler.cpp rename to src/schema_change_callback.cpp index 972af7caa..a0cdc1e6a 100644 --- a/src/schema_change_handler.cpp +++ b/src/schema_change_callback.cpp @@ -14,7 +14,7 @@ limitations under the License. */ -#include "schema_change_handler.hpp" +#include "schema_change_callback.hpp" #include "address.hpp" #include "connection.hpp" @@ -34,27 +34,27 @@ namespace cass { -SchemaChangeHandler::SchemaChangeHandler(Connection* connection, +SchemaChangeCallback::SchemaChangeCallback(Connection* connection, RequestHandler* request_handler, const SharedRefPtr& response, uint64_t elapsed) - : MultipleRequestHandler(connection) + : MultipleRequestCallback(connection) , request_handler_(request_handler) , request_response_(response) , start_ms_(get_time_since_epoch_ms()) , elapsed_ms_(elapsed) {} -void SchemaChangeHandler::execute() { +void SchemaChangeCallback::execute() { execute_query("local", "SELECT schema_version FROM system.local WHERE key='local'"); execute_query("peers", "SELECT peer, rpc_address, schema_version FROM system.peers"); } -bool SchemaChangeHandler::has_schema_agreement(const ResponseMap& responses) { +bool SchemaChangeCallback::has_schema_agreement(const ResponseMap& responses) { StringRef current_version; ResultResponse* local_result; - if (MultipleRequestHandler::get_result_response(responses, "local", &local_result) && + if (MultipleRequestCallback::get_result_response(responses, "local", &local_result) && local_result->row_count() > 0) { const Row* row = &local_result->first_row(); @@ -68,7 +68,7 @@ bool SchemaChangeHandler::has_schema_agreement(const ResponseMap& responses) { } ResultResponse* peers_result; - if (MultipleRequestHandler::get_result_response(responses, "peers", &peers_result)) { + if (MultipleRequestCallback::get_result_response(responses, "peers", &peers_result)) { ResultIterator rows(peers_result); while (rows.next()) { const Row* row = rows.row(); @@ -95,13 +95,13 @@ bool SchemaChangeHandler::has_schema_agreement(const ResponseMap& responses) { return true; } -void SchemaChangeHandler::on_set(const ResponseMap& responses) { +void SchemaChangeCallback::on_set(const ResponseMap& responses) { elapsed_ms_ += get_time_since_epoch_ms() - start_ms_; bool has_error = false; - for (MultipleRequestHandler::ResponseMap::const_iterator it = responses.begin(), + for (MultipleRequestCallback::ResponseMap::const_iterator it = responses.begin(), end = responses.end(); it != end; ++it) { - if (check_error_or_invalid_response("SchemaChangeHandler", CQL_OPCODE_RESULT, it->second.get())) { + if (check_error_or_invalid_response("SchemaChangeCallback", CQL_OPCODE_RESULT, it->second.get())) { has_error = true; } } @@ -123,16 +123,16 @@ void SchemaChangeHandler::on_set(const ResponseMap& responses) { "Trying again in %d ms", RETRY_SCHEMA_AGREEMENT_WAIT_MS); // Try again - SharedRefPtr handler( - new SchemaChangeHandler(connection(), + SharedRefPtr callback( + new SchemaChangeCallback(connection(), request_handler_.get(), request_response_, elapsed_ms_)); - connection()->schedule_schema_agreement(handler, + connection()->schedule_schema_agreement(callback, RETRY_SCHEMA_AGREEMENT_WAIT_MS); } -void SchemaChangeHandler::on_error(CassError code, const std::string& message) { +void SchemaChangeCallback::on_error(CassError code, const std::string& message) { std::ostringstream ss; ss << "An error occurred waiting for schema agreement: '" << message << "' (0x" << std::hex << std::uppercase << std::setw(8) << std::setfill('0') << code << ")"; @@ -140,12 +140,12 @@ void SchemaChangeHandler::on_error(CassError code, const std::string& message) { request_handler_->set_response(request_response_); } -void SchemaChangeHandler::on_timeout() { +void SchemaChangeCallback::on_timeout() { LOG_ERROR("A timeout occurred waiting for schema agreement"); request_handler_->set_response(request_response_); } -void SchemaChangeHandler::on_closing() { +void SchemaChangeCallback::on_closing() { LOG_WARN("Connection closed while waiting for schema agreement"); request_handler_->set_response(request_response_); } diff --git a/src/schema_change_handler.hpp b/src/schema_change_callback.hpp similarity index 84% rename from src/schema_change_handler.hpp rename to src/schema_change_callback.hpp index 70856a663..9566c881d 100644 --- a/src/schema_change_handler.hpp +++ b/src/schema_change_callback.hpp @@ -14,10 +14,9 @@ limitations under the License. */ -#ifndef __CASS_SCHEMA_CHANGE_HANDLER_HPP_INCLUDED__ -#define __CASS_SCHEMA_CHANGE_HANDLER_HPP_INCLUDED__ +#ifndef __CASS_SCHEMA_CHANGE_CALLBACK_HPP_INCLUDED__ +#define __CASS_SCHEMA_CHANGE_CALLBACK_HPP_INCLUDED__ -#include "multiple_request_handler.hpp" #include "ref_counted.hpp" #include "request_handler.hpp" #include "scoped_ptr.hpp" @@ -30,9 +29,9 @@ namespace cass { class Connection; class Response; -class SchemaChangeHandler : public MultipleRequestHandler { +class SchemaChangeCallback : public MultipleRequestCallback { public: - SchemaChangeHandler(Connection* connection, + SchemaChangeCallback(Connection* connection, RequestHandler* request_handler, const SharedRefPtr& response, uint64_t elapsed = 0); diff --git a/src/set_keyspace_handler.cpp b/src/set_keyspace_handler.cpp deleted file mode 100644 index a5bdc5a6b..000000000 --- a/src/set_keyspace_handler.cpp +++ /dev/null @@ -1,78 +0,0 @@ -/* - Copyright (c) 2014-2016 DataStax - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -#include "set_keyspace_handler.hpp" - -#include "connection.hpp" -#include "io_worker.hpp" -#include "prepare_request.hpp" -#include "query_request.hpp" -#include "request_handler.hpp" -#include "response.hpp" -#include "result_response.hpp" - -namespace cass { - -SetKeyspaceHandler::SetKeyspaceHandler(Connection* connection, - const std::string& keyspace, - RequestHandler* request_handler) - : Handler(new QueryRequest("USE \"" + keyspace + "\"")) - , request_handler_(request_handler) { - set_connection(connection); -} - -void SetKeyspaceHandler::on_set(ResponseMessage* response) { - switch (response->opcode()) { - case CQL_OPCODE_RESULT: - on_result_response(response); - break; - case CQL_OPCODE_ERROR: - connection_->defunct(); - request_handler_->on_error(CASS_ERROR_LIB_UNABLE_TO_SET_KEYSPACE, - "Unable to set keyspace"); - break; - default: - break; - } -} - -void SetKeyspaceHandler::on_error(CassError code, const std::string& message) { - connection_->defunct(); - request_handler_->on_error(CASS_ERROR_LIB_UNABLE_TO_SET_KEYSPACE, - "Unable to set keyspace"); -} - -void SetKeyspaceHandler::on_timeout() { - // TODO(mpenick): What to do here? - request_handler_->on_timeout(); -} - -void SetKeyspaceHandler::on_result_response(ResponseMessage* response) { - ResultResponse* result = - static_cast(response->response_body().get()); - if (result->kind() == CASS_RESULT_KIND_SET_KEYSPACE) { - if (!connection_->write(request_handler_.get())) { - // Try on the same host but a different connection - request_handler_->retry(); - } - } else { - connection_->defunct(); - request_handler_->on_error(CASS_ERROR_LIB_UNABLE_TO_SET_KEYSPACE, - "Unable to set keyspace"); - } -} - -} // namespace cass diff --git a/src/set_keyspace_handler.hpp b/src/set_keyspace_handler.hpp deleted file mode 100644 index 6d5a31d41..000000000 --- a/src/set_keyspace_handler.hpp +++ /dev/null @@ -1,49 +0,0 @@ -/* - Copyright (c) 2014-2016 DataStax - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -#ifndef __CASS_SET_KEYSPACE_HANDLER_HPP_INCLUDED__ -#define __CASS_SET_KEYSPACE_HANDLER_HPP_INCLUDED__ - -#include "query_request.hpp" -#include "ref_counted.hpp" -#include "request_handler.hpp" -#include "scoped_ptr.hpp" - -namespace cass { - -class ResponseMessage; -class Connection; - -class SetKeyspaceHandler : public Handler { -public: - SetKeyspaceHandler(Connection* connection, - const std::string& keyspace, - RequestHandler* request_handler); - - virtual void on_set(ResponseMessage* response); - virtual void on_error(CassError code, const std::string& message); - virtual void on_timeout(); - -private: - void on_result_response(ResponseMessage* response); - -private: - ScopedRefPtr request_handler_; -}; - -} // namespace cass - -#endif diff --git a/src/startup_request.cpp b/src/startup_request.cpp index 8180ac008..47370b7ae 100644 --- a/src/startup_request.cpp +++ b/src/startup_request.cpp @@ -20,7 +20,7 @@ namespace cass { -int StartupRequest::encode(int version, Handler* handler, BufferVec* bufs) const { +int StartupRequest::encode(int version, RequestCallback* callback, BufferVec* bufs) const { // [string map] size_t length = sizeof(uint16_t); diff --git a/src/startup_request.hpp b/src/startup_request.hpp index b10b5c969..26e5abff9 100644 --- a/src/startup_request.hpp +++ b/src/startup_request.hpp @@ -39,7 +39,7 @@ class StartupRequest : public Request { const std::string compression() const { return compression_; } private: - int encode(int version, Handler* handler, BufferVec* bufs) const; + int encode(int version, RequestCallback* callback, BufferVec* bufs) const; private: typedef std::map OptionsMap; diff --git a/src/statement.cpp b/src/statement.cpp index 2329003c8..9020057b1 100644 --- a/src/statement.cpp +++ b/src/statement.cpp @@ -240,19 +240,19 @@ CassError cass_statement_bind_custom_by_name_n(CassStatement* statement, namespace cass { -int32_t Statement::copy_buffers(int version, BufferVec* bufs, Handler* handler) const { +int32_t Statement::copy_buffers(int version, BufferVec* bufs, RequestCallback* callback) const { int32_t size = 0; for (size_t i = 0; i < elements().size(); ++i) { const Element& element = elements()[i]; if (!element.is_unset()) { - bufs->push_back(element.get_buffer_cached(version, handler->encoding_cache(), false)); + bufs->push_back(element.get_buffer_cached(version, callback->encoding_cache(), false)); } else { if (version >= 4) { bufs->push_back(cass::encode_with_length(CassUnset())); } else { std::stringstream ss; ss << "Query parameter at index " << i << " was not set"; - handler->on_error(CASS_ERROR_LIB_PARAMETER_UNSET, ss.str()); + callback->on_error(CASS_ERROR_LIB_PARAMETER_UNSET, ss.str()); return Request::ENCODE_ERROR_PARAMETER_UNSET; } } diff --git a/src/statement.hpp b/src/statement.hpp index b435dabc0..6ee936efc 100644 --- a/src/statement.hpp +++ b/src/statement.hpp @@ -30,7 +30,7 @@ namespace cass { -class Handler; +class RequestCallback; class Statement : public RoutableRequest, public AbstractData { public: @@ -95,10 +95,10 @@ class Statement : public RoutableRequest, public AbstractData { virtual bool get_routing_key(std::string* routing_key, EncodingCache* cache) const; - virtual int32_t encode_batch(int version, BufferVec* bufs, Handler* handler) const = 0; + virtual int32_t encode_batch(int version, BufferVec* bufs, RequestCallback* callback) const = 0; protected: - int32_t copy_buffers(int version, BufferVec* bufs, Handler* handler) const; + int32_t copy_buffers(int version, BufferVec* bufs, RequestCallback* callback) const; private: uint8_t flags_; From 1cfbe00b4e19984e7783bdc3b03ba51aad405084 Mon Sep 17 00:00:00 2001 From: Michael Penick Date: Thu, 22 Sep 2016 15:41:16 -0700 Subject: [PATCH 2/7] Improvement: SharedRefPtr<> cleanup --- src/auth.hpp | 29 ++++++++--- src/auth_requests.hpp | 6 +-- src/batch_request.cpp | 4 +- src/batch_request.hpp | 4 +- src/blacklist_dc_policy.cpp | 2 +- src/blacklist_dc_policy.hpp | 2 +- src/blacklist_policy.cpp | 2 +- src/blacklist_policy.hpp | 2 +- src/cluster.cpp | 4 +- src/collection.cpp | 2 +- src/config.hpp | 14 ++--- src/connection.cpp | 35 ++++++++----- src/connection.hpp | 8 +-- src/control_connection.cpp | 52 +++++++++---------- src/control_connection.hpp | 12 ++--- src/data_type.cpp | 24 ++++----- src/data_type.hpp | 25 ++++----- src/data_type_parser.cpp | 18 +++---- src/data_type_parser.hpp | 5 +- src/dc_aware_policy.cpp | 28 +++++----- src/dc_aware_policy.hpp | 18 +++---- src/execute_request.hpp | 6 +-- src/future.cpp | 16 +++--- src/future.hpp | 1 + src/host.cpp | 4 +- src/host.hpp | 18 +++---- src/io_worker.cpp | 8 +-- src/io_worker.hpp | 8 +-- src/latency_aware_policy.cpp | 18 +++---- src/latency_aware_policy.hpp | 14 ++--- src/list_policy.cpp | 14 ++--- src/list_policy.hpp | 14 ++--- src/load_balancing.hpp | 24 +++++---- src/metadata.cpp | 54 +++++++++---------- src/metadata.hpp | 26 +++++----- src/periodic_task.hpp | 19 ++++--- src/pool.cpp | 8 +-- src/pool.hpp | 2 + src/prepared.cpp | 5 +- src/prepared.hpp | 8 +-- src/ref_counted.hpp | 86 +++++++++---------------------- src/request.hpp | 10 ++-- src/request_callback.cpp | 6 ++- src/request_callback.hpp | 13 ++--- src/request_handler.cpp | 19 ++++--- src/request_handler.hpp | 24 +++++---- src/response.hpp | 12 +++-- src/result_metadata.hpp | 2 + src/result_response.cpp | 26 +++++----- src/result_response.hpp | 12 +++-- src/retry_policy.hpp | 6 ++- src/round_robin_policy.cpp | 18 +++---- src/round_robin_policy.hpp | 14 ++--- src/schema_change_callback.cpp | 14 ++--- src/schema_change_callback.hpp | 10 ++-- src/session.cpp | 94 +++++++++++++++++++--------------- src/session.hpp | 38 +++++++------- src/ssl.cpp | 6 +-- src/ssl.hpp | 4 +- src/ssl/ssl_no_impl.cpp | 4 +- src/ssl/ssl_no_impl.hpp | 2 +- src/ssl/ssl_openssl_impl.cpp | 4 +- src/ssl/ssl_openssl_impl.hpp | 2 +- src/statement.hpp | 2 + src/token_aware_policy.cpp | 10 ++-- src/token_aware_policy.hpp | 4 +- src/tuple.cpp | 2 +- src/tuple.hpp | 4 +- src/user_type_value.cpp | 2 +- src/whitelist_dc_policy.cpp | 2 +- src/whitelist_dc_policy.hpp | 2 +- src/whitelist_policy.cpp | 2 +- src/whitelist_policy.hpp | 2 +- 73 files changed, 526 insertions(+), 495 deletions(-) diff --git a/src/auth.hpp b/src/auth.hpp index d5e4baaea..4209589a0 100644 --- a/src/auth.hpp +++ b/src/auth.hpp @@ -41,6 +41,8 @@ class V1Authenticator { class Authenticator : public RefCounted { public: + typedef SharedRefPtr Ptr; + Authenticator() { } virtual ~Authenticator() { } @@ -78,13 +80,22 @@ class PlainTextAuthenticator : public V1Authenticator, public Authenticator { class AuthProvider : public RefCounted { public: + typedef SharedRefPtr Ptr; + AuthProvider() : RefCounted() { } virtual ~AuthProvider() { } - virtual V1Authenticator* new_authenticator_v1(const Host::ConstPtr& host, const std::string& class_name) const { return NULL; } - virtual Authenticator* new_authenticator(const Host::ConstPtr& host, const std::string& class_name) const { return NULL; } + virtual V1Authenticator* new_authenticator_v1(const Host::ConstPtr& host, + const std::string& class_name) const { + return NULL; + } + + virtual Authenticator::Ptr new_authenticator(const Host::ConstPtr& host, + const std::string& class_name) const { + return Authenticator::Ptr(); + } private: DISALLOW_COPY_AND_ASSIGN(AuthProvider); @@ -136,9 +147,9 @@ class ExternalAuthProvider : public AuthProvider { } } - virtual V1Authenticator* new_authenticator_v1(const Host::ConstPtr& host, const std::string& class_name) const { return NULL; } - virtual Authenticator* new_authenticator(const Host::ConstPtr& host, const std::string& class_name) const { - return new ExternalAuthenticator(host, class_name, &exchange_callbacks_, data_); + virtual Authenticator::Ptr new_authenticator(const Host::ConstPtr& host, + const std::string& class_name) const { + return Authenticator::Ptr(new ExternalAuthenticator(host, class_name, &exchange_callbacks_, data_)); } private: @@ -154,12 +165,14 @@ class PlainTextAuthProvider : public AuthProvider { : username_(username) , password_(password) { } - virtual V1Authenticator* new_authenticator_v1(const Host::ConstPtr& host, const std::string& class_name) const { + virtual V1Authenticator* new_authenticator_v1(const Host::ConstPtr& host, + const std::string& class_name) const { return new PlainTextAuthenticator(username_, password_); } - virtual Authenticator* new_authenticator(const Host::ConstPtr& host, const std::string& class_name) const { - return new PlainTextAuthenticator(username_, password_); + virtual Authenticator::Ptr new_authenticator(const Host::ConstPtr& host, + const std::string& class_name) const { + return Authenticator::Ptr(new PlainTextAuthenticator(username_, password_)); } private: diff --git a/src/auth_requests.hpp b/src/auth_requests.hpp index e797e4281..4824c38bc 100644 --- a/src/auth_requests.hpp +++ b/src/auth_requests.hpp @@ -40,19 +40,19 @@ class CredentialsRequest : public Request { class AuthResponseRequest : public Request { public: AuthResponseRequest(const std::string& token, - const SharedRefPtr& auth) + const Authenticator::Ptr& auth) : Request(CQL_OPCODE_AUTH_RESPONSE) , token_(token) , auth_(auth) { } - const SharedRefPtr& auth() const { return auth_; } + const Authenticator::Ptr& auth() const { return auth_; } private: int encode(int version, RequestCallback* callback, BufferVec* bufs) const; private: std::string token_; - SharedRefPtr auth_; + Authenticator::Ptr auth_; }; } // namespace cass diff --git a/src/batch_request.cpp b/src/batch_request.cpp index bf66db3f0..9717dd40f 100644 --- a/src/batch_request.cpp +++ b/src/batch_request.cpp @@ -102,7 +102,7 @@ int BatchRequest::encode(int version, RequestCallback* callback, BufferVec* bufs for (BatchRequest::StatementList::const_iterator i = statements_.begin(), end = statements_.end(); i != end; ++i) { - const SharedRefPtr& statement(*i); + const Statement::Ptr& statement(*i); if (statement->has_names_for_values()) { callback->on_error(CASS_ERROR_LIB_BAD_PARAMS, "Batches cannot contain queries with named values"); @@ -160,7 +160,7 @@ void BatchRequest::add_statement(Statement* statement) { ExecuteRequest* execute_request = static_cast(statement); prepared_statements_[execute_request->prepared()->id()] = execute_request; } - statements_.push_back(SharedRefPtr(statement)); + statements_.push_back(Statement::Ptr(statement)); } bool BatchRequest::prepared_statement(const std::string& id, diff --git a/src/batch_request.hpp b/src/batch_request.hpp index 7b0355052..b82bb3781 100644 --- a/src/batch_request.hpp +++ b/src/batch_request.hpp @@ -21,6 +21,7 @@ #include "constants.hpp" #include "request.hpp" #include "ref_counted.hpp" +#include "statement.hpp" #include #include @@ -28,12 +29,11 @@ namespace cass { -class Statement; class ExecuteRequest; class BatchRequest : public RoutableRequest { public: - typedef std::list > StatementList; + typedef std::list StatementList; BatchRequest(uint8_t type_) : RoutableRequest(CQL_OPCODE_BATCH) diff --git a/src/blacklist_dc_policy.cpp b/src/blacklist_dc_policy.cpp index 75ef61334..454099709 100644 --- a/src/blacklist_dc_policy.cpp +++ b/src/blacklist_dc_policy.cpp @@ -18,7 +18,7 @@ namespace cass { -bool BlacklistDCPolicy::is_valid_host(const SharedRefPtr& host) const { +bool BlacklistDCPolicy::is_valid_host(const Host::Ptr& host) const { const std::string& host_dc = host->dc(); for (DcList::const_iterator it = dcs_.begin(), end = dcs_.end(); it != end; ++it) { diff --git a/src/blacklist_dc_policy.hpp b/src/blacklist_dc_policy.hpp index efa6d8853..1b54fae71 100644 --- a/src/blacklist_dc_policy.hpp +++ b/src/blacklist_dc_policy.hpp @@ -38,7 +38,7 @@ class BlacklistDCPolicy : public ListPolicy { } private: - bool is_valid_host(const SharedRefPtr& host) const; + bool is_valid_host(const Host::Ptr& host) const; DcList dcs_; diff --git a/src/blacklist_policy.cpp b/src/blacklist_policy.cpp index ef30c86fc..8fb801cdd 100644 --- a/src/blacklist_policy.cpp +++ b/src/blacklist_policy.cpp @@ -18,7 +18,7 @@ namespace cass { -bool BlacklistPolicy::is_valid_host(const SharedRefPtr& host) const { +bool BlacklistPolicy::is_valid_host(const Host::Ptr& host) const { const std::string& host_address = host->address().to_string(false); for (ContactPointList::const_iterator it = hosts_.begin(), end = hosts_.end(); diff --git a/src/blacklist_policy.hpp b/src/blacklist_policy.hpp index b1368e0d8..b4ba28d93 100644 --- a/src/blacklist_policy.hpp +++ b/src/blacklist_policy.hpp @@ -38,7 +38,7 @@ class BlacklistPolicy : public ListPolicy { } private: - bool is_valid_host(const SharedRefPtr& host) const; + bool is_valid_host(const Host::Ptr& host) const; ContactPointList hosts_; diff --git a/src/cluster.cpp b/src/cluster.cpp index 217cfb86b..031407687 100644 --- a/src/cluster.cpp +++ b/src/cluster.cpp @@ -376,7 +376,9 @@ CassError cass_cluster_set_authenticator_callbacks(CassCluster* cluster, const CassAuthenticatorCallbacks* exchange_callbacks, CassAuthenticatorDataCleanupCallback cleanup_callback, void* data) { - cluster->config().set_auth_provider(new cass::ExternalAuthProvider(exchange_callbacks, cleanup_callback, data)); + cluster->config().set_auth_provider(cass::AuthProvider::Ptr( + new cass::ExternalAuthProvider(exchange_callbacks, + cleanup_callback, data))); return CASS_OK; } diff --git a/src/collection.cpp b/src/collection.cpp index dc075aab8..883d7b7f8 100644 --- a/src/collection.cpp +++ b/src/collection.cpp @@ -38,7 +38,7 @@ CassCollection* cass_collection_new_from_data_type(const CassDataType* data_type return NULL; } cass::Collection* collection - = new cass::Collection(cass::SharedRefPtr(data_type), + = new cass::Collection(cass::DataType::ConstPtr(data_type), item_count); collection->inc_ref(); return CassCollection::to(collection); diff --git a/src/config.hpp b/src/config.hpp index e3538ccb7..5cf20c5d4 100644 --- a/src/config.hpp +++ b/src/config.hpp @@ -228,10 +228,10 @@ class Config { log_data_ = data; } - const SharedRefPtr& auth_provider() const { return auth_provider_; } + const AuthProvider::Ptr& auth_provider() const { return auth_provider_; } - void set_auth_provider(AuthProvider* auth_provider) { - auth_provider_.reset(auth_provider == NULL ? new AuthProvider() : auth_provider); + void set_auth_provider(const AuthProvider::Ptr& auth_provider) { + auth_provider_ = (!auth_provider ? AuthProvider::Ptr(new AuthProvider()) : auth_provider); } void set_credentials(const std::string& username, const std::string& password) { @@ -389,9 +389,9 @@ class Config { CassLogLevel log_level_; CassLogCallback log_callback_; void* log_data_; - SharedRefPtr auth_provider_; - SharedRefPtr load_balancing_policy_; - SharedRefPtr ssl_context_; + AuthProvider::Ptr auth_provider_; + LoadBalancingPolicy::Ptr load_balancing_policy_; + SslContext::Ptr ssl_context_; bool token_aware_routing_; bool latency_aware_routing_; LatencyAwarePolicy::Settings latency_aware_routing_settings_; @@ -405,7 +405,7 @@ class Config { unsigned connection_idle_timeout_secs_; unsigned connection_heartbeat_interval_secs_; SharedRefPtr timestamp_gen_; - SharedRefPtr retry_policy_; + RetryPolicy::Ptr retry_policy_; bool use_schema_; bool use_hostname_resolution_; bool use_randomized_contact_points_; diff --git a/src/connection.cpp b/src/connection.cpp index 9d9e7f324..89954754c 100644 --- a/src/connection.cpp +++ b/src/connection.cpp @@ -157,7 +157,7 @@ void Connection::StartupCallback::on_result_response(ResponseMessage* response) } Connection::HeartbeatCallback::HeartbeatCallback(Connection* connection) - : RequestCallback(new OptionsRequest()) { + : RequestCallback(Request::ConstPtr(new OptionsRequest())) { set_connection(connection); } @@ -321,7 +321,7 @@ void Connection::flush() { pending_writes_.back()->flush(); } -void Connection::schedule_schema_agreement(const SharedRefPtr& callback, uint64_t wait) { +void Connection::schedule_schema_agreement(const SchemaChangeCallback::Ptr& callback, uint64_t wait) { PendingSchemaAgreement* pending_schema_agreement = new PendingSchemaAgreement(callback); pending_schema_agreements_.add_to_back(pending_schema_agreement); pending_schema_agreement->timer.start(loop_, @@ -741,7 +741,7 @@ void Connection::on_timeout(Timer* timer) { } void Connection::on_connected() { - internal_write(new StartupCallback(this, new OptionsRequest())); + internal_write(new StartupCallback(this, Request::ConstPtr(new OptionsRequest()))); } void Connection::on_authenticate(const std::string& class_name) { @@ -759,9 +759,10 @@ void Connection::on_auth_challenge(const AuthResponseRequest* request, notify_error("Failed evaluating challenge token: " + request->auth()->error(), CONNECTION_ERROR_AUTH); return; } - AuthResponseRequest* auth_response = new AuthResponseRequest(response, - request->auth()); - internal_write(new StartupCallback(this, auth_response)); + internal_write(new StartupCallback(this, + Request::ConstPtr( + new AuthResponseRequest(response, + request->auth())))); } void Connection::on_auth_success(const AuthResponseRequest* request, @@ -776,14 +777,18 @@ void Connection::on_auth_success(const AuthResponseRequest* request, void Connection::on_ready() { if (state_ == CONNECTION_STATE_CONNECTED && listener_->event_types() != 0) { set_state(CONNECTION_STATE_REGISTERING_EVENTS); - internal_write(new StartupCallback(this, new RegisterRequest(listener_->event_types()))); + internal_write(new StartupCallback(this, + Request::ConstPtr( + new RegisterRequest(listener_->event_types())))); return; } if (keyspace_.empty()) { notify_ready(); } else { - internal_write(new StartupCallback(this, new QueryRequest("USE \"" + keyspace_ + "\""))); + internal_write(new StartupCallback(this, + Request::ConstPtr( + new QueryRequest("USE \"" + keyspace_ + "\"")))); } } @@ -798,7 +803,8 @@ void Connection::on_supported(ResponseMessage* response) { // TODO(mstump) do something with the supported info (void)supported; - internal_write(new StartupCallback(this, new StartupRequest())); + internal_write(new StartupCallback(this, + Request::ConstPtr(new StartupRequest()))); } void Connection::on_pending_schema_agreement(Timer* timer) { @@ -865,14 +871,16 @@ void Connection::send_credentials(const std::string& class_name) { if (v1_auth) { V1Authenticator::Credentials credentials; v1_auth->get_credentials(&credentials); - internal_write(new StartupCallback(this, new CredentialsRequest(credentials))); + internal_write(new StartupCallback(this, + Request::ConstPtr( + new CredentialsRequest(credentials)))); } else { send_initial_auth_response(class_name); } } void Connection::send_initial_auth_response(const std::string& class_name) { - SharedRefPtr auth(config_.auth_provider()->new_authenticator(host_, class_name)); + Authenticator::Ptr auth(config_.auth_provider()->new_authenticator(host_, class_name)); if (!auth) { notify_error("Authentication required but no auth provider set", CONNECTION_ERROR_AUTH); } else { @@ -881,8 +889,9 @@ void Connection::send_initial_auth_response(const std::string& class_name) { notify_error("Failed creating initial response token: " + auth->error(), CONNECTION_ERROR_AUTH); return; } - AuthResponseRequest* auth_response = new AuthResponseRequest(response, auth); - internal_write(new StartupCallback(this, auth_response)); + internal_write(new StartupCallback(this, + Request::ConstPtr( + new AuthResponseRequest(response, auth)))); } } diff --git a/src/connection.hpp b/src/connection.hpp index 74f092008..017340153 100644 --- a/src/connection.hpp +++ b/src/connection.hpp @@ -107,7 +107,7 @@ class Connection { bool write(RequestCallback* request, bool flush_immediately = true); void flush(); - void schedule_schema_agreement(const SharedRefPtr& callback, uint64_t wait); + void schedule_schema_agreement(const SchemaChangeCallback::Ptr& callback, uint64_t wait); const Config& config() const { return config_; } Metrics* metrics() { return metrics_; } @@ -165,7 +165,7 @@ class Connection { class StartupCallback : public RequestCallback { public: - StartupCallback(Connection* connection, Request* request) + StartupCallback(Connection* connection, const Request::ConstPtr& request) : RequestCallback(request) { set_connection(connection); } @@ -245,12 +245,12 @@ class Connection { struct PendingSchemaAgreement : public List::Node { - PendingSchemaAgreement(const SharedRefPtr& callback) + PendingSchemaAgreement(const SchemaChangeCallback::Ptr& callback) : callback(callback) { } void stop_timer(); - SharedRefPtr callback; + SchemaChangeCallback::Ptr callback; Timer timer; }; diff --git a/src/control_connection.cpp b/src/control_connection.cpp index 1a14736c3..828c6905e 100644 --- a/src/control_connection.cpp +++ b/src/control_connection.cpp @@ -67,9 +67,9 @@ class ControlStartupQueryPlan : public QueryPlan { std::transform(hosts.begin(), hosts.end(), std::back_inserter(hosts_), GetHost()); } - virtual SharedRefPtr compute_next() { + virtual Host::Ptr compute_next() { const size_t size = hosts_.size(); - if (count_ >= size) return SharedRefPtr(); + if (count_ >= size) return Host::Ptr(); size_t index = (index_ + count_) % size; ++count_; return hosts_[index]; @@ -122,7 +122,7 @@ ControlConnection::ControlConnection() , use_schema_(false) , token_aware_routing_(false) { } -const SharedRefPtr& ControlConnection::connected_host() const { +const Host::Ptr& ControlConnection::connected_host() const { return current_host_; } @@ -279,7 +279,7 @@ void ControlConnection::on_event(EventResponse* response) { switch (response->topology_change()) { case EventResponse::NEW_NODE: { LOG_INFO("New node %s added", address_str.c_str()); - SharedRefPtr host = session_->get_host(response->affected_node()); + Host::Ptr host = session_->get_host(response->affected_node()); if (!host) { host = session_->add_host(response->affected_node()); refresh_node_info(host, true, true); @@ -289,7 +289,7 @@ void ControlConnection::on_event(EventResponse* response) { case EventResponse::REMOVED_NODE: { LOG_INFO("Node %s removed", address_str.c_str()); - SharedRefPtr host = session_->get_host(response->affected_node()); + Host::Ptr host = session_->get_host(response->affected_node()); if (host) { session_->on_remove(host); if (session_->token_map_) { @@ -303,7 +303,7 @@ void ControlConnection::on_event(EventResponse* response) { case EventResponse::MOVED_NODE: LOG_INFO("Node %s moved", address_str.c_str()); - SharedRefPtr host = session_->get_host(response->affected_node()); + Host::Ptr host = session_->get_host(response->affected_node()); if (host) { refresh_node_info(host, false, true); } else { @@ -406,7 +406,7 @@ void ControlConnection::on_event(EventResponse* response) { } void ControlConnection::query_meta_hosts() { - ScopedRefPtr > callback( + SharedRefPtr > callback( new ControlMultipleRequestCallback(this, ControlConnection::on_query_hosts, UnusedData())); // This needs to happen before other schema metadata queries so that we have // a valid Cassandra version because this version determines which follow up @@ -438,7 +438,7 @@ void ControlConnection::on_query_hosts(ControlConnection* control_connection, // versions of Cassandra. If this happens we defunct the connection and move // to the next node in the query plan. { - SharedRefPtr host = session->get_host(connection->address()); + Host::Ptr host = session->get_host(connection->address()); if (host) { host->set_mark(session->current_host_mark_); @@ -475,7 +475,7 @@ void ControlConnection::on_query_hosts(ControlConnection* control_connection, continue; } - SharedRefPtr host = session->get_host(address); + Host::Ptr host = session->get_host(address); bool is_new = false; if (!host) { is_new = true; @@ -509,7 +509,7 @@ void ControlConnection::on_query_hosts(ControlConnection* control_connection, //TODO: query and callbacks should be in Metadata // punting for now because of tight coupling of Session and CC state void ControlConnection::query_meta_schema() { - ScopedRefPtr > callback( + SharedRefPtr > callback( new ControlMultipleRequestCallback(this, ControlConnection::on_query_meta_schema, UnusedData())); if (cassandra_version_ >= VersionNumber(3, 0, 0)) { @@ -621,7 +621,7 @@ void ControlConnection::on_query_meta_schema(ControlConnection* control_connecti } } -void ControlConnection::refresh_node_info(SharedRefPtr host, +void ControlConnection::refresh_node_info(Host::Ptr host, bool is_new_node, bool query_tokens) { if (connection_ == NULL) { @@ -652,11 +652,11 @@ void ControlConnection::refresh_node_info(SharedRefPtr host, LOG_DEBUG("refresh_node_info: %s", query.c_str()); RefreshNodeData data(host, is_new_node); - ScopedRefPtr > callback( - new ControlCallback(new QueryRequest(query), - this, - response_callback, - data)); + SharedRefPtr > callback( + new ControlCallback(Request::ConstPtr(new QueryRequest(query)), + this, + response_callback, + data)); if (!connection_->write(callback.get())) { LOG_ERROR("No more stream available while attempting to refresh node info"); } @@ -729,7 +729,7 @@ void ControlConnection::on_refresh_node_info_all(ControlConnection* control_conn } } -void ControlConnection::update_node_info(SharedRefPtr host, const Row* row, UpdateHostType type) { +void ControlConnection::update_node_info(Host::Ptr host, const Row* row, UpdateHostType type) { const Value* v; std::string rack; @@ -807,7 +807,7 @@ void ControlConnection::refresh_keyspace(const StringRef& keyspace_name) { LOG_DEBUG("Refreshing keyspace %s", query.c_str()); connection_->write( - new ControlCallback(new QueryRequest(query), + new ControlCallback(Request::ConstPtr(new QueryRequest(query)), this, ControlConnection::on_refresh_keyspace, keyspace_name.to_string())); @@ -874,10 +874,10 @@ void ControlConnection::refresh_table_or_view(const StringRef& keyspace_name, LOG_DEBUG("Refreshing table %s; %s", table_query.c_str(), column_query.c_str()); } - ScopedRefPtr > callback( + SharedRefPtr > callback( new ControlMultipleRequestCallback(this, - ControlConnection::on_refresh_table_or_view, - RefreshTableData(keyspace_name.to_string(), table_or_view_name.to_string()))); + ControlConnection::on_refresh_table_or_view, + RefreshTableData(keyspace_name.to_string(), table_or_view_name.to_string()))); callback->execute_query("tables", table_query); if (!view_query.empty()) { callback->execute_query("views", view_query); @@ -937,7 +937,7 @@ void ControlConnection::refresh_type(const StringRef& keyspace_name, LOG_DEBUG("Refreshing type %s", query.c_str()); connection_->write( - new ControlCallback >(new QueryRequest(query), + new ControlCallback >(Request::ConstPtr(new QueryRequest(query)), this, ControlConnection::on_refresh_type, std::make_pair(keyspace_name.to_string(), type_name.to_string()))); @@ -1000,10 +1000,10 @@ void ControlConnection::refresh_function(const StringRef& keyspace_name, request->set(0, CassString(keyspace_name.data(), keyspace_name.size())); request->set(1, CassString(function_name.data(), function_name.size())); - request->set(2, signature.get()); + request->set(2, signature); connection_->write( - new ControlCallback(request.get(), + new ControlCallback(request, this, ControlConnection::on_refresh_function, RefreshFunctionData(keyspace_name, function_name, arg_types, is_aggregate))); @@ -1057,7 +1057,7 @@ void ControlConnection::handle_query_timeout() { } void ControlConnection::on_up(const Address& address) { - SharedRefPtr host = session_->get_host(address); + Host::Ptr host = session_->get_host(address); if (host) { if (host->is_up()) return; @@ -1074,7 +1074,7 @@ void ControlConnection::on_up(const Address& address) { } void ControlConnection::on_down(const Address& address) { - SharedRefPtr host = session_->get_host(address); + Host::Ptr host = session_->get_host(address); if (host) { if (host->is_down()) return; diff --git a/src/control_connection.hpp b/src/control_connection.hpp index fd4b2f0e8..c41b7a301 100644 --- a/src/control_connection.hpp +++ b/src/control_connection.hpp @@ -62,7 +62,7 @@ class ControlConnection : public Connection::Listener { return cassandra_version_; } - const SharedRefPtr& connected_host() const; + const Host::Ptr& connected_host() const; void clear(); @@ -120,7 +120,7 @@ class ControlConnection : public Connection::Listener { public: typedef void (*ResponseCallback)(ControlConnection*, const T&, Response*); - ControlCallback(const Request* request, + ControlCallback(const Request::ConstPtr& request, ControlConnection* control_connection, ResponseCallback response_callback, const T& data) @@ -152,11 +152,11 @@ class ControlConnection : public Connection::Listener { }; struct RefreshNodeData { - RefreshNodeData(const SharedRefPtr& host, + RefreshNodeData(const Host::Ptr& host, bool is_new_node) : host(host) , is_new_node(is_new_node) {} - SharedRefPtr host; + Host::Ptr host; bool is_new_node; }; @@ -208,7 +208,7 @@ class ControlConnection : public Connection::Listener { const UnusedData& data, const MultipleRequestCallback::ResponseMap& responses); - void refresh_node_info(SharedRefPtr host, + void refresh_node_info(Host::Ptr host, bool is_new_node, bool query_tokens = false); static void on_refresh_node_info(ControlConnection* control_connection, @@ -218,7 +218,7 @@ class ControlConnection : public Connection::Listener { const RefreshNodeData& data, Response* response); - void update_node_info(SharedRefPtr host, const Row* row, UpdateHostType type); + void update_node_info(Host::Ptr host, const Row* row, UpdateHostType type); void refresh_keyspace(const StringRef& keyspace_name); static void on_refresh_keyspace(ControlConnection* control_connection, const std::string& keyspace_name, Response* response); diff --git a/src/data_type.cpp b/src/data_type.cpp index 679628cc1..eef6d7d5b 100644 --- a/src/data_type.cpp +++ b/src/data_type.cpp @@ -62,9 +62,9 @@ CassDataType* cass_data_type_new(CassValueType type) { } CassDataType* cass_data_type_new_from_existing(const CassDataType* data_type) { - cass::DataType* copy = data_type->copy(); + cass::DataType::Ptr copy = data_type->copy(); copy->inc_ref(); - return CassDataType::to(copy); + return CassDataType::to(copy.get()); } CassDataType* cass_data_type_new_tuple(size_t item_count) { @@ -75,9 +75,9 @@ CassDataType* cass_data_type_new_tuple(size_t item_count) { } CassDataType* cass_data_type_new_udt(size_t field_count) { - cass::UserType* user_type = new cass::UserType(field_count); - user_type->inc_ref(); - return CassDataType::to(user_type); + cass::DataType* data_type = new cass::UserType(field_count); + data_type->inc_ref(); + return CassDataType::to(data_type); } const CassDataType* cass_data_type_sub_data_type(const CassDataType* data_type, @@ -297,18 +297,18 @@ CassError cass_data_type_add_sub_type(CassDataType* data_type, if (composite_type->types().size() >= 1) { return CASS_ERROR_LIB_BAD_PARAMS; } - composite_type->types().push_back(cass::SharedRefPtr(sub_data_type)); + composite_type->types().push_back(cass::DataType::ConstPtr(sub_data_type)); break; case CASS_VALUE_TYPE_MAP: if (composite_type->types().size() >= 2) { return CASS_ERROR_LIB_BAD_PARAMS; } - composite_type->types().push_back(cass::SharedRefPtr(sub_data_type)); + composite_type->types().push_back(cass::DataType::ConstPtr(sub_data_type)); break; case CASS_VALUE_TYPE_TUPLE: - composite_type->types().push_back(cass::SharedRefPtr(sub_data_type)); + composite_type->types().push_back(cass::DataType::ConstPtr(sub_data_type)); break; default: @@ -339,7 +339,7 @@ CassError cass_data_type_add_sub_type_by_name_n(CassDataType* data_type, = static_cast(data_type->from()); user_type->add_field(std::string(name, name_length), - cass::SharedRefPtr(sub_data_type)); + cass::DataType::ConstPtr(sub_data_type)); return CASS_OK; @@ -347,7 +347,7 @@ CassError cass_data_type_add_sub_type_by_name_n(CassDataType* data_type, CassError cass_data_type_add_sub_value_type(CassDataType* data_type, CassValueType sub_value_type) { - cass::SharedRefPtr sub_data_type( + cass::DataType::ConstPtr sub_data_type( new cass::DataType(sub_value_type)); return cass_data_type_add_sub_type(data_type, CassDataType::to(sub_data_type.get())); @@ -357,7 +357,7 @@ CassError cass_data_type_add_sub_value_type(CassDataType* data_type, CassError cass_data_type_add_sub_value_type_by_name(CassDataType* data_type, const char* name, CassValueType sub_value_type) { - cass::SharedRefPtr sub_data_type( + cass::DataType::ConstPtr sub_data_type( new cass::DataType(sub_value_type)); return cass_data_type_add_sub_type_by_name(data_type, name, CassDataType::to(sub_data_type.get())); @@ -367,7 +367,7 @@ CassError cass_data_type_add_sub_value_type_by_name_n(CassDataType* data_type, const char* name, size_t name_length, CassValueType sub_value_type) { - cass::SharedRefPtr sub_data_type( + cass::DataType::ConstPtr sub_data_type( new cass::DataType(sub_value_type)); return cass_data_type_add_sub_type_by_name_n(data_type, name, name_length, CassDataType::to(sub_data_type.get())); diff --git a/src/data_type.hpp b/src/data_type.hpp index 3e6350e1c..ba131d207 100644 --- a/src/data_type.hpp +++ b/src/data_type.hpp @@ -65,6 +65,7 @@ inline bool equals_both_not_empty(const std::string& s1, class DataType : public RefCounted { public: + typedef SharedRefPtr Ptr; typedef SharedRefPtr ConstPtr; typedef std::vector Vec; @@ -103,8 +104,8 @@ class DataType : public RefCounted { } } - virtual DataType* copy() const { - return new DataType(value_type_); + virtual DataType::Ptr copy() const { + return Ptr(new DataType(value_type_)); } virtual std::string to_string() const { @@ -168,12 +169,12 @@ class CustomType : public DataType { if (data_type->value_type() != CASS_VALUE_TYPE_CUSTOM) { return false; } - const SharedRefPtr& custom_type(data_type); + const ConstPtr& custom_type(data_type); return equals_both_not_empty(class_name_, custom_type->class_name_); } - virtual DataType* copy() const { - return new CustomType(class_name_); + virtual DataType::Ptr copy() const { + return DataType::Ptr(new CustomType(class_name_)); } virtual std::string to_string() const { @@ -265,8 +266,8 @@ class CollectionType : public CompositeType { return true; } - virtual DataType* copy() const { - return new CollectionType(value_type(), types_, is_frozen()); + virtual DataType::Ptr copy() const { + return DataType::Ptr(new CollectionType(value_type(), types_, is_frozen())); } public: @@ -312,7 +313,7 @@ class TupleType : public CompositeType { return false; } - const SharedRefPtr& tuple_type(data_type); + const ConstPtr& tuple_type(data_type); // Only compare sub-types if both have sub-types if(!types_.empty() && !tuple_type->types_.empty()) { @@ -329,8 +330,8 @@ class TupleType : public CompositeType { return true; } - virtual DataType* copy() const { - return new TupleType(types_, is_frozen()); + virtual DataType::Ptr copy() const { + return DataType::Ptr(new TupleType(types_, is_frozen())); } }; @@ -431,8 +432,8 @@ class UserType : public DataType { return true; } - virtual DataType* copy() const { - return new UserType(keyspace_, type_name_, fields_.entries(), is_frozen()); + virtual DataType::Ptr copy() const { + return DataType::Ptr(new UserType(keyspace_, type_name_, fields_.entries(), is_frozen())); } virtual std::string to_string() const { diff --git a/src/data_type_parser.cpp b/src/data_type_parser.cpp index 09d037677..bf380edaf 100644 --- a/src/data_type_parser.cpp +++ b/src/data_type_parser.cpp @@ -405,7 +405,7 @@ DataType::ConstPtr DataTypeClassNameParser::parse_one(const std::string& type, c return DataType::ConstPtr(new CustomType(next)); } -SharedRefPtr DataTypeClassNameParser::parse_with_composite(const std::string& type, const NativeDataTypes& native_types) { +ParseResult::Ptr DataTypeClassNameParser::parse_with_composite(const std::string& type, const NativeDataTypes& native_types) { Parser parser(type, 0); std::string next; @@ -414,20 +414,20 @@ SharedRefPtr DataTypeClassNameParser::parse_with_composite(const st if (!is_composite(next)) { DataType::ConstPtr data_type = parse_one(type, native_types); if (!data_type) { - return SharedRefPtr(); + return ParseResult::Ptr(); } - return SharedRefPtr(new ParseResult(data_type, is_reversed(next))); + return ParseResult::Ptr(new ParseResult(data_type, is_reversed(next))); } TypeParamsVec sub_class_names; if (!parser.get_type_params(&sub_class_names)) { - return SharedRefPtr(); + return ParseResult::Ptr(); } if (sub_class_names.empty()) { LOG_ERROR("Expected at least one subclass type for a composite type"); - return SharedRefPtr(); + return ParseResult::Ptr(); } ParseResult::CollectionMap collections; @@ -440,14 +440,14 @@ SharedRefPtr DataTypeClassNameParser::parse_with_composite(const st collection_parser.get_next_name(); NameAndTypeParamsVec params; if (!collection_parser.get_collection_params(¶ms)) { - return SharedRefPtr(); + return ParseResult::Ptr(); } for (NameAndTypeParamsVec::const_iterator i = params.begin(), end = params.end(); i != end; ++i) { DataType::ConstPtr data_type = parse_one(i->second, native_types); if (!data_type) { - return SharedRefPtr(); + return ParseResult::Ptr(); } collections[i->first] = data_type; } @@ -458,13 +458,13 @@ SharedRefPtr DataTypeClassNameParser::parse_with_composite(const st for (size_t i = 0; i < count; ++i) { DataType::ConstPtr data_type = parse_one(sub_class_names[i], native_types); if (!data_type) { - return SharedRefPtr(); + return ParseResult::Ptr(); } types.push_back(data_type); reversed.push_back(is_reversed(sub_class_names[i])); } - return SharedRefPtr(new ParseResult(true, types, reversed, collections)); + return ParseResult::Ptr(new ParseResult(true, types, reversed, collections)); } bool DataTypeClassNameParser::get_nested_class_name(const std::string& type, std::string* class_name) { diff --git a/src/data_type_parser.hpp b/src/data_type_parser.hpp index b7873d001..740efc933 100644 --- a/src/data_type_parser.hpp +++ b/src/data_type_parser.hpp @@ -106,8 +106,9 @@ class DataTypeCqlNameParser { class ParseResult : public RefCounted { public: + typedef SharedRefPtr Ptr; typedef std::vector ReversedVec; - typedef std::map CollectionMap; + typedef std::map CollectionMap; ParseResult(DataType::ConstPtr type, bool reversed) : is_composite_(false) { @@ -147,7 +148,7 @@ class DataTypeClassNameParser { static bool is_tuple_type(const std::string& type); static DataType::ConstPtr parse_one(const std::string& type, const NativeDataTypes& native_types); - static SharedRefPtr parse_with_composite(const std::string& type, const NativeDataTypes& native_types); + static ParseResult::Ptr parse_with_composite(const std::string& type, const NativeDataTypes& native_types); private: static bool get_nested_class_name(const std::string& type, std::string* class_name); diff --git a/src/dc_aware_policy.cpp b/src/dc_aware_policy.cpp index 036975c8e..42451932c 100644 --- a/src/dc_aware_policy.cpp +++ b/src/dc_aware_policy.cpp @@ -26,7 +26,7 @@ namespace cass { static const CopyOnWriteHostVec NO_HOSTS(new HostVec()); -void DCAwarePolicy::init(const SharedRefPtr& connected_host, +void DCAwarePolicy::init(const Host::Ptr& connected_host, const HostMap& hosts, Random* random) { if (local_dc_.empty() && connected_host && !connected_host->dc().empty()) { @@ -45,7 +45,7 @@ void DCAwarePolicy::init(const SharedRefPtr& connected_host, } } -CassHostDistance DCAwarePolicy::distance(const SharedRefPtr& host) const { +CassHostDistance DCAwarePolicy::distance(const Host::Ptr& host) const { if (local_dc_.empty() || host->dc() == local_dc_) { return CASS_HOST_DISTANCE_LOCAL; } @@ -69,7 +69,7 @@ QueryPlan* DCAwarePolicy::new_query_plan(const std::string& connected_keyspace, return new DCAwareQueryPlan(this, cl, index_++); } -void DCAwarePolicy::on_add(const SharedRefPtr& host) { +void DCAwarePolicy::on_add(const Host::Ptr& host) { const std::string& dc = host->dc(); if (local_dc_.empty() && !dc.empty()) { LOG_INFO("Using '%s' for local data center " @@ -85,7 +85,7 @@ void DCAwarePolicy::on_add(const SharedRefPtr& host) { } } -void DCAwarePolicy::on_remove(const SharedRefPtr& host) { +void DCAwarePolicy::on_remove(const Host::Ptr& host) { const std::string& dc = host->dc(); if (dc == local_dc_) { remove_host(local_dc_live_hosts_, host); @@ -94,15 +94,15 @@ void DCAwarePolicy::on_remove(const SharedRefPtr& host) { } } -void DCAwarePolicy::on_up(const SharedRefPtr& host) { +void DCAwarePolicy::on_up(const Host::Ptr& host) { on_add(host); } -void DCAwarePolicy::on_down(const SharedRefPtr& host) { +void DCAwarePolicy::on_down(const Host::Ptr& host) { on_remove(host); } -void DCAwarePolicy::PerDCHostMap::add_host_to_dc(const std::string& dc, const SharedRefPtr& host) { +void DCAwarePolicy::PerDCHostMap::add_host_to_dc(const std::string& dc, const Host::Ptr& host) { ScopedWriteLock wl(&rwlock_); Map::iterator i = map_.find(dc); if (i == map_.end()) { @@ -114,7 +114,7 @@ void DCAwarePolicy::PerDCHostMap::add_host_to_dc(const std::string& dc, const Sh } } -void DCAwarePolicy::PerDCHostMap::remove_host_from_dc(const std::string& dc, const SharedRefPtr& host) { +void DCAwarePolicy::PerDCHostMap::remove_host_from_dc(const std::string& dc, const Host::Ptr& host) { ScopedWriteLock wl(&rwlock_); Map::iterator i = map_.find(dc); if (i != map_.end()) { @@ -138,7 +138,7 @@ void DCAwarePolicy::PerDCHostMap::copy_dcs(KeySet* dcs) const { } // Helper method to prevent copy (Notice: "const CopyOnWriteHostVec&") -static const SharedRefPtr& get_next_host(const CopyOnWriteHostVec& hosts, size_t index) { +static const Host::Ptr& get_next_host(const CopyOnWriteHostVec& hosts, size_t index) { return (*hosts)[index % hosts->size()]; } @@ -157,17 +157,17 @@ DCAwarePolicy::DCAwareQueryPlan::DCAwareQueryPlan(const DCAwarePolicy* policy, , remote_remaining_(0) , index_(start_index) {} -SharedRefPtr DCAwarePolicy::DCAwareQueryPlan::compute_next() { +Host::Ptr DCAwarePolicy::DCAwareQueryPlan::compute_next() { while (local_remaining_ > 0) { --local_remaining_; - const SharedRefPtr& host(get_next_host(hosts_, index_++)); + const Host::Ptr& host(get_next_host(hosts_, index_++)); if (host->is_up()) { return host; } } if (policy_->skip_remote_dcs_for_local_cl_ && is_dc_local(cl_)) { - return SharedRefPtr(); + return Host::Ptr(); } if (!remote_dcs_) { @@ -178,7 +178,7 @@ SharedRefPtr DCAwarePolicy::DCAwareQueryPlan::compute_next() { while (true) { while (remote_remaining_ > 0) { --remote_remaining_; - const SharedRefPtr& host(get_next_host(hosts_, index_++)); + const Host::Ptr& host(get_next_host(hosts_, index_++)); if (host->is_up()) { return host; } @@ -194,7 +194,7 @@ SharedRefPtr DCAwarePolicy::DCAwareQueryPlan::compute_next() { remote_dcs_->erase(i); } - return SharedRefPtr(); + return Host::Ptr(); } } // namespace cass diff --git a/src/dc_aware_policy.hpp b/src/dc_aware_policy.hpp index 21b001540..214b8aeea 100644 --- a/src/dc_aware_policy.hpp +++ b/src/dc_aware_policy.hpp @@ -46,22 +46,22 @@ class DCAwarePolicy : public LoadBalancingPolicy { , local_dc_live_hosts_(new HostVec) , index_(0) {} - virtual void init(const SharedRefPtr& connected_host, const HostMap& hosts, Random* random); + virtual void init(const Host::Ptr& connected_host, const HostMap& hosts, Random* random); - virtual CassHostDistance distance(const SharedRefPtr& host) const; + virtual CassHostDistance distance(const Host::Ptr& host) const; virtual QueryPlan* new_query_plan(const std::string& connected_keyspace, const Request* request, const TokenMap* token_map, Request::EncodingCache* cache); - virtual void on_add(const SharedRefPtr& host); + virtual void on_add(const Host::Ptr& host); - virtual void on_remove(const SharedRefPtr& host); + virtual void on_remove(const Host::Ptr& host); - virtual void on_up(const SharedRefPtr& host); + virtual void on_up(const Host::Ptr& host); - virtual void on_down(const SharedRefPtr& host); + virtual void on_down(const Host::Ptr& host); virtual LoadBalancingPolicy* new_instance() { return new DCAwarePolicy(local_dc_, @@ -78,8 +78,8 @@ class DCAwarePolicy : public LoadBalancingPolicy { PerDCHostMap() { uv_rwlock_init(&rwlock_); } ~PerDCHostMap() { uv_rwlock_destroy(&rwlock_); } - void add_host_to_dc(const std::string& dc, const SharedRefPtr& host); - void remove_host_from_dc(const std::string& dc, const SharedRefPtr& host); + void add_host_to_dc(const std::string& dc, const Host::Ptr& host); + void remove_host_from_dc(const std::string& dc, const Host::Ptr& host); const CopyOnWriteHostVec& get_hosts(const std::string& dc) const; void copy_dcs(KeySet* dcs) const; @@ -100,7 +100,7 @@ class DCAwarePolicy : public LoadBalancingPolicy { CassConsistency cl, size_t start_index); - virtual SharedRefPtr compute_next(); + virtual Host::Ptr compute_next(); private: const DCAwarePolicy* policy_; diff --git a/src/execute_request.hpp b/src/execute_request.hpp index f35f4b2a2..094e70fb4 100644 --- a/src/execute_request.hpp +++ b/src/execute_request.hpp @@ -43,7 +43,7 @@ class ExecuteRequest : public Statement { } } - const SharedRefPtr& prepared() const { return prepared_; } + const Prepared::ConstPtr& prepared() const { return prepared_; } private: virtual size_t get_indices(StringRef name, IndexVec* indices) { @@ -62,8 +62,8 @@ class ExecuteRequest : public Statement { int internal_encode(int version, RequestCallback* callback, BufferVec* bufs) const; private: - SharedRefPtr prepared_; - SharedRefPtr metadata_; + Prepared::ConstPtr prepared_; + ResultMetadata::Ptr metadata_; }; } // namespace cass diff --git a/src/future.cpp b/src/future.cpp index 1db89ac26..c80587738 100644 --- a/src/future.cpp +++ b/src/future.cpp @@ -53,10 +53,9 @@ const CassResult* cass_future_get_result(CassFuture* future) { if (future->type() != cass::CASS_FUTURE_TYPE_RESPONSE) { return NULL; } - cass::ResponseFuture* response_future = - static_cast(future->from()); - cass::SharedRefPtr result(response_future->response()); + cass::SharedRefPtr result( + static_cast(future->from())->response()); if (!result) return NULL; result->inc_ref(); @@ -78,7 +77,7 @@ const CassPrepared* cass_future_get_prepared(CassFuture* future) { cass::Prepared* prepared = new cass::Prepared(result, response_future->statement, response_future->schema_metadata); - if (prepared) prepared->inc_ref(); + prepared->inc_ref(); return CassPrepared::to(prepared); } @@ -86,10 +85,9 @@ const CassErrorResult* cass_future_get_error_result(CassFuture* future) { if (future->type() != cass::CASS_FUTURE_TYPE_RESPONSE) { return NULL; } - cass::ResponseFuture* response_future = - static_cast(future->from()); - cass::SharedRefPtr response(response_future->response()); + cass::Response::Ptr response( + static_cast(future->from())->response()); if (!response || response->opcode() != CQL_OPCODE_ERROR) { return NULL; } @@ -126,7 +124,7 @@ size_t cass_future_custom_payload_item_count(CassFuture* future) { if (future->type() != cass::CASS_FUTURE_TYPE_RESPONSE) { return 0; } - cass::SharedRefPtr response( + cass::Response::Ptr response( static_cast(future->from())->response()); if (!response) return 0; return response->custom_payload().size(); @@ -141,7 +139,7 @@ CassError cass_future_custom_payload_item(CassFuture* future, if (future->type() != cass::CASS_FUTURE_TYPE_RESPONSE) { return CASS_ERROR_LIB_INVALID_FUTURE_TYPE; } - cass::SharedRefPtr response( + cass::Response::Ptr response( static_cast(future->from())->response()); if (!response) return CASS_ERROR_LIB_NO_CUSTOM_PAYLOAD; diff --git a/src/future.hpp b/src/future.hpp index c190e4099..742a3fcee 100644 --- a/src/future.hpp +++ b/src/future.hpp @@ -40,6 +40,7 @@ enum FutureType { class Future : public RefCounted { public: + typedef SharedRefPtr Ptr; typedef void (*Callback)(CassFuture*, void*); struct Error { diff --git a/src/host.cpp b/src/host.cpp index 217625efb..cd6169fed 100644 --- a/src/host.cpp +++ b/src/host.cpp @@ -18,7 +18,7 @@ namespace cass { -void add_host(CopyOnWriteHostVec& hosts, const SharedRefPtr& host) { +void add_host(CopyOnWriteHostVec& hosts, const Host::Ptr& host) { HostVec::iterator i; for (i = hosts->begin(); i != hosts->end(); ++i) { if ((*i)->address() == host->address()) { @@ -31,7 +31,7 @@ void add_host(CopyOnWriteHostVec& hosts, const SharedRefPtr& host) { } } -void remove_host(CopyOnWriteHostVec& hosts, const SharedRefPtr& host) { +void remove_host(CopyOnWriteHostVec& hosts, const Host::Ptr& host) { HostVec::iterator i; for (i = hosts->begin(); i != hosts->end(); ++i) { if ((*i)->address() == host->address()) { diff --git a/src/host.hpp b/src/host.hpp index c0a0a9a51..55e625d3e 100644 --- a/src/host.hpp +++ b/src/host.hpp @@ -100,10 +100,10 @@ class Host : public RefCounted { class StateListener { public: virtual ~StateListener() { } - virtual void on_add(const SharedRefPtr& host) = 0; - virtual void on_remove(const SharedRefPtr& host) = 0; - virtual void on_up(const SharedRefPtr& host) = 0; - virtual void on_down(const SharedRefPtr& host) = 0; + virtual void on_add(const Ptr& host) = 0; + virtual void on_remove(const Ptr& host) = 0; + virtual void on_up(const Ptr& host) = 0; + virtual void on_down(const Ptr& host) = 0; }; enum HostState { @@ -246,19 +246,19 @@ class Host : public RefCounted { DISALLOW_COPY_AND_ASSIGN(Host); }; -typedef std::map > HostMap; +typedef std::map HostMap; struct GetHost { typedef std::pair Pair; Host::Ptr operator()(const Pair& pair) const { return pair.second; } }; -typedef std::pair > HostPair; -typedef std::vector > HostVec; +typedef std::pair HostPair; +typedef std::vector HostVec; typedef CopyOnWritePtr CopyOnWriteHostVec; -void add_host(CopyOnWriteHostVec& hosts, const SharedRefPtr& host); -void remove_host(CopyOnWriteHostVec& hosts, const SharedRefPtr& host); +void add_host(CopyOnWriteHostVec& hosts, const Host::Ptr& host); +void remove_host(CopyOnWriteHostVec& hosts, const Host::Ptr& host); } // namespace cass diff --git a/src/io_worker.cpp b/src/io_worker.cpp index 0f16275cb..7e0d45e94 100644 --- a/src/io_worker.cpp +++ b/src/io_worker.cpp @@ -122,7 +122,7 @@ void IOWorker::add_pool(const Host::ConstPtr& host, bool is_initial_connection) set_host_is_available(address, false); - SharedRefPtr pool(new Pool(this, host, is_initial_connection)); + Pool::Ptr pool(new Pool(this, host, is_initial_connection)); pools_[address] = pool; pool->connect(); } else { @@ -150,7 +150,7 @@ void IOWorker::retry(RequestHandler* request_handler) { PoolMap::const_iterator it = pools_.find(address); if (it != pools_.end() && it->second->is_ready()) { - const SharedRefPtr& pool = it->second; + const Pool::Ptr& pool = it->second; Connection* connection = pool->borrow_connection(); if (connection != NULL) { if (!pool->write(connection, request_handler)) { @@ -209,7 +209,7 @@ void IOWorker::notify_pool_closed(Pool* pool) { } void IOWorker::add_pending_flush(Pool* pool) { - pools_pending_flush_.push_back(SharedRefPtr(pool)); + pools_pending_flush_.push_back(Pool::Ptr(pool)); } void IOWorker::maybe_close() { @@ -317,7 +317,7 @@ void IOWorker::schedule_reconnect(const Host::ConstPtr& host) { host->address_string().c_str(), config_.reconnect_wait_time_ms(), static_cast(this)); - SharedRefPtr pool(new Pool(this, host, false)); + Pool::Ptr pool(new Pool(this, host, false)); pools_[host->address()] = pool; pool->delayed_connect(); } diff --git a/src/io_worker.hpp b/src/io_worker.hpp index 64fb5d909..29aaf5a95 100644 --- a/src/io_worker.hpp +++ b/src/io_worker.hpp @@ -26,6 +26,7 @@ #include "host.hpp" #include "logger.hpp" #include "metrics.hpp" +#include "pool.hpp" #include "spsc_queue.hpp" #include "timer.hpp" @@ -37,7 +38,6 @@ namespace cass { class Config; -class Pool; class RequestHandler; class Session; class SSLContext; @@ -65,6 +65,8 @@ class IOWorker : public EventThread , public RefCounted { public: + typedef SharedRefPtr Ptr; + enum State { IO_WORKER_STATE_READY, IO_WORKER_STATE_CLOSING, @@ -131,8 +133,8 @@ class IOWorker #endif private: - typedef sparsehash::dense_hash_map, AddressHash> PoolMap; - typedef std::vector > PoolVec; + typedef sparsehash::dense_hash_map PoolMap; + typedef std::vector PoolVec; void schedule_reconnect(const Host::ConstPtr& host); diff --git a/src/latency_aware_policy.cpp b/src/latency_aware_policy.cpp index ce0061f3b..7044006e3 100644 --- a/src/latency_aware_policy.cpp +++ b/src/latency_aware_policy.cpp @@ -24,7 +24,7 @@ namespace cass { -void LatencyAwarePolicy::init(const SharedRefPtr& connected_host, +void LatencyAwarePolicy::init(const Host::Ptr& connected_host, const HostMap& hosts, Random* random) { hosts_->reserve(hosts.size()); @@ -45,7 +45,7 @@ void LatencyAwarePolicy::register_handles(uv_loop_t* loop) { } void LatencyAwarePolicy::close_handles() { - if (calculate_min_average_task_ != NULL) { + if (calculate_min_average_task_) { PeriodicTask::stop(calculate_min_average_task_); } } @@ -59,33 +59,33 @@ QueryPlan* LatencyAwarePolicy::new_query_plan(const std::string& connected_keysp token_map, cache)); } -void LatencyAwarePolicy::on_add(const SharedRefPtr& host) { +void LatencyAwarePolicy::on_add(const Host::Ptr& host) { host->enable_latency_tracking(settings_.scale_ns, settings_.min_measured); add_host(hosts_, host); ChainedLoadBalancingPolicy::on_add(host); } -void LatencyAwarePolicy::on_remove(const SharedRefPtr& host) { +void LatencyAwarePolicy::on_remove(const Host::Ptr& host) { remove_host(hosts_, host); ChainedLoadBalancingPolicy::on_remove(host); } -void LatencyAwarePolicy::on_up(const SharedRefPtr& host) { +void LatencyAwarePolicy::on_up(const Host::Ptr& host) { add_host(hosts_, host); ChainedLoadBalancingPolicy::on_up(host); } -void LatencyAwarePolicy::on_down(const SharedRefPtr& host) { +void LatencyAwarePolicy::on_down(const Host::Ptr& host) { remove_host(hosts_, host); ChainedLoadBalancingPolicy::on_down(host); } -SharedRefPtr LatencyAwarePolicy::LatencyAwareQueryPlan::compute_next() { +Host::Ptr LatencyAwarePolicy::LatencyAwareQueryPlan::compute_next() { int64_t min = policy_->min_average_.load(); const Settings& settings = policy_->settings_; uint64_t now = uv_hrtime(); - SharedRefPtr host; + Host::Ptr host; while ((host = child_plan_->compute_next())) { TimestampedAverage latency = host->get_current_average(); @@ -107,7 +107,7 @@ SharedRefPtr LatencyAwarePolicy::LatencyAwareQueryPlan::compute_next() { return skipped_[skipped_index_++]; } - return SharedRefPtr(); + return Host::Ptr(); } void LatencyAwarePolicy::on_work(PeriodicTask* task) { diff --git a/src/latency_aware_policy.hpp b/src/latency_aware_policy.hpp index 824608e70..eb439b258 100644 --- a/src/latency_aware_policy.hpp +++ b/src/latency_aware_policy.hpp @@ -51,7 +51,7 @@ class LatencyAwarePolicy : public ChainedLoadBalancingPolicy { virtual ~LatencyAwarePolicy() {} - virtual void init(const SharedRefPtr& connected_host, const HostMap& hosts, Random* random); + virtual void init(const Host::Ptr& connected_host, const HostMap& hosts, Random* random); virtual void register_handles(uv_loop_t* loop); virtual void close_handles(); @@ -65,10 +65,10 @@ class LatencyAwarePolicy : public ChainedLoadBalancingPolicy { return new LatencyAwarePolicy(child_policy_->new_instance(), settings_); } - virtual void on_add(const SharedRefPtr& host); - virtual void on_remove(const SharedRefPtr& host); - virtual void on_up(const SharedRefPtr& host); - virtual void on_down(const SharedRefPtr& host); + virtual void on_add(const Host::Ptr& host); + virtual void on_remove(const Host::Ptr& host); + virtual void on_up(const Host::Ptr& host); + virtual void on_down(const Host::Ptr& host); public: // Testing only @@ -84,7 +84,7 @@ class LatencyAwarePolicy : public ChainedLoadBalancingPolicy { , child_plan_(child_plan) , skipped_index_(0) {} - SharedRefPtr compute_next(); + Host::Ptr compute_next(); private: LatencyAwarePolicy* policy_; @@ -98,7 +98,7 @@ class LatencyAwarePolicy : public ChainedLoadBalancingPolicy { static void on_after_work(PeriodicTask* task); Atomic min_average_; - PeriodicTask* calculate_min_average_task_; + PeriodicTask::Ptr calculate_min_average_task_; Settings settings_; CopyOnWriteHostVec hosts_; diff --git a/src/list_policy.cpp b/src/list_policy.cpp index aea9b9ddb..0c9ae2732 100644 --- a/src/list_policy.cpp +++ b/src/list_policy.cpp @@ -20,13 +20,13 @@ namespace cass { -void ListPolicy::init(const SharedRefPtr& connected_host, +void ListPolicy::init(const Host::Ptr& connected_host, const HostMap& hosts, Random* random) { HostMap valid_hosts; for (HostMap::const_iterator i = hosts.begin(), end = hosts.end(); i != end; ++i) { - const SharedRefPtr& host = i->second; + const Host::Ptr& host = i->second; if (is_valid_host(host)) { valid_hosts.insert(HostPair(i->first, host)); } @@ -39,7 +39,7 @@ void ListPolicy::init(const SharedRefPtr& connected_host, ChainedLoadBalancingPolicy::init(connected_host, valid_hosts, random); } -CassHostDistance ListPolicy::distance(const SharedRefPtr& host) const { +CassHostDistance ListPolicy::distance(const Host::Ptr& host) const { if (is_valid_host(host)) { return child_policy_->distance(host); } @@ -56,25 +56,25 @@ QueryPlan* ListPolicy::new_query_plan(const std::string& connected_keyspace, cache); } -void ListPolicy::on_add(const SharedRefPtr& host) { +void ListPolicy::on_add(const Host::Ptr& host) { if (is_valid_host(host)) { child_policy_->on_add(host); } } -void ListPolicy::on_remove(const SharedRefPtr& host) { +void ListPolicy::on_remove(const Host::Ptr& host) { if (is_valid_host(host)) { child_policy_->on_remove(host); } } -void ListPolicy::on_up(const SharedRefPtr& host) { +void ListPolicy::on_up(const Host::Ptr& host) { if (is_valid_host(host)) { child_policy_->on_up(host); } } -void ListPolicy::on_down(const SharedRefPtr& host) { +void ListPolicy::on_down(const Host::Ptr& host) { if (is_valid_host(host)) { child_policy_->on_down(host); } diff --git a/src/list_policy.hpp b/src/list_policy.hpp index 32cff537a..0956dbc17 100644 --- a/src/list_policy.hpp +++ b/src/list_policy.hpp @@ -30,24 +30,24 @@ class ListPolicy : public ChainedLoadBalancingPolicy { virtual ~ListPolicy() {} - virtual void init(const SharedRefPtr& connected_host, const HostMap& hosts, Random* random); + virtual void init(const Host::Ptr& connected_host, const HostMap& hosts, Random* random); - virtual CassHostDistance distance(const SharedRefPtr& host) const; + virtual CassHostDistance distance(const Host::Ptr& host) const; virtual QueryPlan* new_query_plan(const std::string& connected_keyspace, const Request* request, const TokenMap* token_map, Request::EncodingCache* cache); - virtual void on_add(const SharedRefPtr& host); - virtual void on_remove(const SharedRefPtr& host); - virtual void on_up(const SharedRefPtr& host); - virtual void on_down(const SharedRefPtr& host); + virtual void on_add(const Host::Ptr& host); + virtual void on_remove(const Host::Ptr& host); + virtual void on_up(const Host::Ptr& host); + virtual void on_down(const Host::Ptr& host); virtual ListPolicy* new_instance() = 0; private: - virtual bool is_valid_host(const SharedRefPtr& host) const = 0; + virtual bool is_valid_host(const Host::Ptr& host) const = 0; }; diff --git a/src/load_balancing.hpp b/src/load_balancing.hpp index 44e89e39d..4d52833b1 100644 --- a/src/load_balancing.hpp +++ b/src/load_balancing.hpp @@ -62,10 +62,10 @@ inline bool is_dc_local(CassConsistency cl) { class QueryPlan { public: virtual ~QueryPlan() {} - virtual SharedRefPtr compute_next() = 0; + virtual Host::Ptr compute_next() = 0; bool compute_next(Address* address) { - SharedRefPtr host = compute_next(); + Host::Ptr host = compute_next(); if (host) { *address = host->address(); return true; @@ -76,17 +76,19 @@ class QueryPlan { class LoadBalancingPolicy : public Host::StateListener, public RefCounted { public: + typedef SharedRefPtr Ptr; + LoadBalancingPolicy() : RefCounted() {} virtual ~LoadBalancingPolicy() {} - virtual void init(const SharedRefPtr& connected_host, const HostMap& hosts, Random* random) = 0; + virtual void init(const Host::Ptr& connected_host, const HostMap& hosts, Random* random) = 0; virtual void register_handles(uv_loop_t* loop) {} virtual void close_handles() {} - virtual CassHostDistance distance(const SharedRefPtr& host) const = 0; + virtual CassHostDistance distance(const Host::Ptr& host) const = 0; virtual QueryPlan* new_query_plan(const std::string& connected_keyspace, const Request* request, @@ -104,22 +106,22 @@ class ChainedLoadBalancingPolicy : public LoadBalancingPolicy { virtual ~ChainedLoadBalancingPolicy() {} - virtual void init(const SharedRefPtr& connected_host, const HostMap& hosts, Random* random) { + virtual void init(const Host::Ptr& connected_host, const HostMap& hosts, Random* random) { return child_policy_->init(connected_host, hosts, random); } - virtual CassHostDistance distance(const SharedRefPtr& host) const { return child_policy_->distance(host); } + virtual CassHostDistance distance(const Host::Ptr& host) const { return child_policy_->distance(host); } - virtual void on_add(const SharedRefPtr& host) { child_policy_->on_add(host); } + virtual void on_add(const Host::Ptr& host) { child_policy_->on_add(host); } - virtual void on_remove(const SharedRefPtr& host) { child_policy_->on_remove(host); } + virtual void on_remove(const Host::Ptr& host) { child_policy_->on_remove(host); } - virtual void on_up(const SharedRefPtr& host) { child_policy_->on_up(host); } + virtual void on_up(const Host::Ptr& host) { child_policy_->on_up(host); } - virtual void on_down(const SharedRefPtr& host) { child_policy_->on_down(host); } + virtual void on_down(const Host::Ptr& host) { child_policy_->on_down(host); } protected: - ScopedRefPtr child_policy_; + LoadBalancingPolicy::Ptr child_policy_; }; } // namespace cass diff --git a/src/metadata.cpp b/src/metadata.cpp index 9e223e439..fb9b27a43 100644 --- a/src/metadata.cpp +++ b/src/metadata.cpp @@ -983,7 +983,7 @@ std::string MetadataBase::get_string_field(const std::string& name) const { return value->to_string(); } -const Value* MetadataBase::add_field(const SharedRefPtr& buffer, const Row* row, const std::string& name) { +const Value* MetadataBase::add_field(const RefBuffer::Ptr& buffer, const Row* row, const std::string& name) { const Value* value = row->get_by_name(name); if (value == NULL) return NULL; if (value->size() <= 0) { @@ -995,7 +995,7 @@ const Value* MetadataBase::add_field(const SharedRefPtr& buffer, cons } } -void MetadataBase::add_field(const SharedRefPtr& buffer, const Value& value, const std::string& name) { +void MetadataBase::add_field(const RefBuffer::Ptr& buffer, const Value& value, const std::string& name) { fields_[name] = MetadataField(name, value, buffer); } @@ -1026,14 +1026,14 @@ void MetadataBase::add_json_list_field(int protocol_version, const Row* row, con return; } - Collection collection(CollectionType::list(SharedRefPtr(new DataType(CASS_VALUE_TYPE_TEXT)), false), + Collection collection(CollectionType::list(DataType::Ptr(new DataType(CASS_VALUE_TYPE_TEXT)), false), d.Size()); for (rapidjson::Value::ConstValueIterator i = d.Begin(); i != d.End(); ++i) { collection.append(cass::CassString(i->GetString(), i->GetStringLength())); } size_t encoded_size = collection.get_items_size(protocol_version); - SharedRefPtr encoded(RefBuffer::create(encoded_size)); + RefBuffer::Ptr encoded(RefBuffer::create(encoded_size)); collection.encode_items(protocol_version, encoded->data()); @@ -1071,8 +1071,8 @@ const Value* MetadataBase::add_json_map_field(int protocol_version, const Row* r return (fields_[name] = MetadataField(name)).value(); } - Collection collection(CollectionType::map(SharedRefPtr(new DataType(CASS_VALUE_TYPE_TEXT)), - SharedRefPtr(new DataType(CASS_VALUE_TYPE_TEXT)), + Collection collection(CollectionType::map(DataType::Ptr(new DataType(CASS_VALUE_TYPE_TEXT)), + DataType::Ptr(new DataType(CASS_VALUE_TYPE_TEXT)), false), 2 * d.MemberCount()); for (rapidjson::Value::ConstMemberIterator i = d.MemberBegin(); i != d.MemberEnd(); ++i) { @@ -1081,7 +1081,7 @@ const Value* MetadataBase::add_json_map_field(int protocol_version, const Row* r } size_t encoded_size = collection.get_items_size(protocol_version); - SharedRefPtr encoded(RefBuffer::create(encoded_size)); + RefBuffer::Ptr encoded(RefBuffer::create(encoded_size)); collection.encode_items(protocol_version, encoded->data()); @@ -1165,7 +1165,7 @@ const UserType* KeyspaceMetadata::get_user_type(const std::string& name) const { return i->second.get(); } -void KeyspaceMetadata::update(int protocol_version, const VersionNumber& cassandra_version, const SharedRefPtr& buffer, const Row* row) { +void KeyspaceMetadata::update(int protocol_version, const VersionNumber& cassandra_version, const RefBuffer::Ptr& buffer, const Row* row) { add_field(buffer, row, "keyspace_name"); add_field(buffer, row, "durable_writes"); if (cassandra_version >= VersionNumber(3, 0, 0)) { @@ -1230,7 +1230,7 @@ void KeyspaceMetadata::drop_aggregate(const std::string& full_aggregate_name) { } TableMetadataBase::TableMetadataBase(int protocol_version, const VersionNumber& cassandra_version, - const std::string& name, const SharedRefPtr& buffer, const Row* row) + const std::string& name, const RefBuffer::Ptr& buffer, const Row* row) : MetadataBase(name) { add_field(buffer, row, "keyspace_name"); add_field(buffer, row, "bloom_filter_fp_chance"); @@ -1349,7 +1349,7 @@ void TableMetadataBase::build_keys_and_sort(int protocol_version, const VersionN } } - SharedRefPtr key_validator + ParseResult::Ptr key_validator = DataTypeClassNameParser::parse_with_composite(get_string_field("key_validator"), native_types); size_t size = key_validator->types().size(); partition_key_.reserve(size); @@ -1383,7 +1383,7 @@ void TableMetadataBase::build_keys_and_sort(int protocol_version, const VersionN } // TODO: Figure out how to test these special cases and properly document them here - SharedRefPtr comparator + ParseResult::Ptr comparator = DataTypeClassNameParser::parse_with_composite(get_string_field("comparator"), native_types); size_t size = comparator->types().size(); if (comparator->is_composite()) { @@ -1431,7 +1431,7 @@ void TableMetadataBase::build_keys_and_sort(int protocol_version, const VersionN const TableMetadata::Ptr TableMetadata::NIL; TableMetadata::TableMetadata(int protocol_version, const VersionNumber& cassandra_version, - const std::string& name, const SharedRefPtr& buffer, const Row* row) + const std::string& name, const RefBuffer::Ptr& buffer, const Row* row) : TableMetadataBase(protocol_version, cassandra_version, name, buffer, row) { add_field(buffer, row, table_column_name(cassandra_version)); if (cassandra_version >= VersionNumber(3, 0, 0)) { @@ -1470,7 +1470,7 @@ void TableMetadata::key_aliases(const NativeDataTypes& native_types, KeyAliases* } } if (output->empty()) {// C* 1.2 tables created via CQL2 or thrift don't have col meta or key aliases - SharedRefPtr key_validator_type + ParseResult::Ptr key_validator_type = DataTypeClassNameParser::parse_with_composite(get_string_field("key_validator"), native_types); const size_t count = key_validator_type->types().size(); std::ostringstream ss("key"); @@ -1488,7 +1488,7 @@ const ViewMetadata::Ptr ViewMetadata::NIL; ViewMetadata::ViewMetadata(int protocol_version, const VersionNumber& cassandra_version, TableMetadata* table, - const std::string& name, const SharedRefPtr& buffer, const Row* row) + const std::string& name, const RefBuffer::Ptr& buffer, const Row* row) : TableMetadataBase(protocol_version, cassandra_version, name, buffer, row) , base_table_(table) { add_field(buffer, row, "keyspace_name"); @@ -1519,7 +1519,7 @@ void TableMetadata::clear_indexes() { FunctionMetadata::FunctionMetadata(int protocol_version, const VersionNumber& cassandra_version, const NativeDataTypes native_types, const std::string& name, const Value* signature, KeyspaceMetadata* keyspace, - const SharedRefPtr& buffer, const Row* row) + const RefBuffer::Ptr& buffer, const Row* row) : MetadataBase(Metadata::full_function_name(name, signature->as_stringlist())) , simple_name_(name) { const Value* value1; @@ -1591,7 +1591,7 @@ const DataType* FunctionMetadata::get_arg_type(StringRef name) const { AggregateMetadata::AggregateMetadata(int protocol_version, const VersionNumber& cassandra_version, const NativeDataTypes native_types, const std::string& name, const Value* signature, KeyspaceMetadata* keyspace, - const SharedRefPtr& buffer, const Row* row) + const RefBuffer::Ptr& buffer, const Row* row) : MetadataBase(Metadata::full_function_name(name, signature->as_stringlist())) , simple_name_(name) { const Value* value; @@ -1674,7 +1674,7 @@ AggregateMetadata::AggregateMetadata(int protocol_version, const VersionNumber& } IndexMetadata::Ptr IndexMetadata::from_row(const std::string& index_name, - const SharedRefPtr& buffer, const Row* row) { + const RefBuffer::Ptr& buffer, const Row* row) { IndexMetadata::Ptr index(new IndexMetadata(index_name)); StringRef kind; @@ -1708,7 +1708,7 @@ void IndexMetadata::update(StringRef kind, const Value* options) { IndexMetadata::Ptr IndexMetadata::from_legacy(int protocol_version, const std::string& index_name, const ColumnMetadata* column, - const SharedRefPtr& buffer, const Row* row) { + const RefBuffer::Ptr& buffer, const Row* row) { IndexMetadata::Ptr index(new IndexMetadata(index_name)); index->add_field(buffer, row, "index_name"); @@ -1771,7 +1771,7 @@ CassIndexType IndexMetadata::index_type_from_string(StringRef index_type) { ColumnMetadata::ColumnMetadata(int protocol_version, const VersionNumber& cassandra_version, const NativeDataTypes native_types, const std::string& name, KeyspaceMetadata* keyspace, - const SharedRefPtr& buffer, const Row* row) + const RefBuffer::Ptr& buffer, const Row* row) : MetadataBase(name) , type_(CASS_COLUMN_TYPE_REGULAR) , position_(0) @@ -1862,7 +1862,7 @@ ColumnMetadata::ColumnMetadata(int protocol_version, const VersionNumber& cassan void Metadata::InternalData::update_keyspaces(int protocol_version, const VersionNumber& cassandra_version, ResultResponse* result) { - SharedRefPtr buffer = result->buffer(); + RefBuffer::Ptr buffer = result->buffer(); ResultIterator rows(result); while (rows.next()) { @@ -1881,7 +1881,7 @@ void Metadata::InternalData::update_keyspaces(int protocol_version, const Versio void Metadata::InternalData::update_tables(int protocol_version, const VersionNumber& cassandra_version, ResultResponse* result) { - SharedRefPtr buffer = result->buffer(); + RefBuffer::Ptr buffer = result->buffer(); ResultIterator rows(result); @@ -1910,7 +1910,7 @@ void Metadata::InternalData::update_tables(int protocol_version, const VersionNu void Metadata::InternalData::update_views(int protocol_version, const VersionNumber& cassandra_version, ResultResponse* result) { - SharedRefPtr buffer = result->buffer(); + RefBuffer::Ptr buffer = result->buffer(); ResultIterator rows(result); @@ -2043,7 +2043,7 @@ void Metadata::InternalData::update_user_types(int protocol_version, const Versi void Metadata::InternalData::update_functions(int protocol_version, const VersionNumber& cassandra_version, const NativeDataTypes native_types, ResultResponse* result) { - SharedRefPtr buffer = result->buffer(); + RefBuffer::Ptr buffer = result->buffer(); ResultIterator rows(result); @@ -2077,7 +2077,7 @@ void Metadata::InternalData::update_functions(int protocol_version, const Versio } void Metadata::InternalData::update_aggregates(int protocol_version, const VersionNumber& cassandra_version, const NativeDataTypes native_types, ResultResponse* result) { - SharedRefPtr buffer = result->buffer(); + RefBuffer::Ptr buffer = result->buffer(); ResultIterator rows(result); @@ -2139,7 +2139,7 @@ void Metadata::InternalData::drop_aggregate(const std::string& keyspace_name, co } void Metadata::InternalData::update_columns(int protocol_version, const VersionNumber& cassandra_version, const NativeDataTypes native_types, ResultResponse* result) { - SharedRefPtr buffer = result->buffer(); + RefBuffer::Ptr buffer = result->buffer(); ResultIterator rows(result); @@ -2196,7 +2196,7 @@ void Metadata::InternalData::update_columns(int protocol_version, const VersionN } void Metadata::InternalData::update_legacy_indexes(int protocol_version, const VersionNumber& cassandra_version, ResultResponse* result) { - SharedRefPtr buffer = result->buffer(); + RefBuffer::Ptr buffer = result->buffer(); ResultIterator rows(result); @@ -2248,7 +2248,7 @@ void Metadata::InternalData::update_legacy_indexes(int protocol_version, const V } void Metadata::InternalData::update_indexes(int protocol_version, const VersionNumber& cassandra_version, ResultResponse* result) { - SharedRefPtr buffer = result->buffer(); + RefBuffer::Ptr buffer = result->buffer(); ResultIterator rows(result); diff --git a/src/metadata.hpp b/src/metadata.hpp index 5b36339c5..730118a8f 100644 --- a/src/metadata.hpp +++ b/src/metadata.hpp @@ -108,7 +108,7 @@ class MetadataField { MetadataField(const std::string& name, const Value& value, - const SharedRefPtr& buffer) + const RefBuffer::Ptr& buffer) : name_(name) , value_(value) , buffer_(buffer) { } @@ -124,7 +124,7 @@ class MetadataField { private: std::string name_; Value value_; - SharedRefPtr buffer_; + RefBuffer::Ptr buffer_; }; class MetadataFieldIterator : public Iterator { @@ -158,8 +158,8 @@ class MetadataBase { } protected: - const Value* add_field(const SharedRefPtr& buffer, const Row* row, const std::string& name); - void add_field(const SharedRefPtr& buffer, const Value& value, const std::string& name); + const Value* add_field(const RefBuffer::Ptr& buffer, const Row* row, const std::string& name); + void add_field(const RefBuffer::Ptr& buffer, const Value& value, const std::string& name); void add_json_list_field(int version, const Row* row, const std::string& name); const Value* add_json_map_field(int version, const Row* row, const std::string& name); @@ -203,7 +203,7 @@ class FunctionMetadata : public MetadataBase, public RefCounted& buffer, const Row* row); + const RefBuffer::Ptr& buffer, const Row* row); const std::string& simple_name() const { return simple_name_; } const Argument::Vec& args() const { return args_; } @@ -236,7 +236,7 @@ class AggregateMetadata : public MetadataBase, public RefCounted& buffer, const Row* row); + const RefBuffer::Ptr& buffer, const Row* row); const std::string& simple_name() const { return simple_name_; } const DataType::Vec arg_types() const { return arg_types_; } @@ -271,12 +271,12 @@ class IndexMetadata : public MetadataBase, public RefCounted { , type_(CASS_INDEX_TYPE_UNKNOWN) { } static IndexMetadata::Ptr from_row(const std::string& index_name, - const SharedRefPtr& buffer, const Row* row); + const RefBuffer::Ptr& buffer, const Row* row); void update(StringRef index_type, const Value* options); static IndexMetadata::Ptr from_legacy(int protocol_version, const std::string& index_name, const ColumnMetadata* column, - const SharedRefPtr& buffer, const Row* row); + const RefBuffer::Ptr& buffer, const Row* row); void update_legacy(StringRef index_type, const ColumnMetadata* column, const Value* options); @@ -319,7 +319,7 @@ class ColumnMetadata : public MetadataBase, public RefCounted { ColumnMetadata(int protocol_version, const VersionNumber& cassandra_version, const NativeDataTypes native_types, const std::string& name, KeyspaceMetadata* keyspace, - const SharedRefPtr& buffer, const Row* row); + const RefBuffer::Ptr& buffer, const Row* row); CassColumnType type() const { return type_; } int32_t position() const { return position_; } @@ -353,7 +353,7 @@ class TableMetadataBase : public MetadataBase, public RefCounted& buffer, const Row* row); + const std::string& name, const RefBuffer::Ptr& buffer, const Row* row); virtual ~TableMetadataBase() { } @@ -390,7 +390,7 @@ class ViewMetadata : public TableMetadataBase { ViewMetadata(int protocol_version, const VersionNumber& cassandra_version, TableMetadata* table, const std::string& name, - const SharedRefPtr& buffer, const Row* row); + const RefBuffer::Ptr& buffer, const Row* row); const TableMetadata* base_table() const { return base_table_; } TableMetadata* base_table() { return base_table_; } @@ -466,7 +466,7 @@ class TableMetadata : public TableMetadataBase { }; TableMetadata(int protocol_version, const VersionNumber& cassandra_version, const std::string& name, - const SharedRefPtr& buffer, const Row* row); + const RefBuffer::Ptr& buffer, const Row* row); const ViewMetadata::Vec& views() const { return views_; } const IndexMetadata::Vec& indexes() const { return indexes_; } @@ -532,7 +532,7 @@ class KeyspaceMetadata : public MetadataBase { , aggregates_(new AggregateMetadata::Map) { } void update(int protocol_version, const VersionNumber& cassandra_version, - const SharedRefPtr& buffer, const Row* row); + const RefBuffer::Ptr& buffer, const Row* row); const FunctionMetadata::Map& functions() const { return *functions_; } const UserType::Map& user_types() const { return *user_types_; } diff --git a/src/periodic_task.hpp b/src/periodic_task.hpp index 9a91287b0..be1891cdc 100644 --- a/src/periodic_task.hpp +++ b/src/periodic_task.hpp @@ -26,26 +26,29 @@ namespace cass { class PeriodicTask : public RefCounted { public: + typedef SharedRefPtr Ptr; + typedef void (*Callback)(PeriodicTask*); void* data() { return data_; } - static PeriodicTask* start(uv_loop_t* loop, uint64_t repeat, void* data, + static Ptr start(uv_loop_t* loop, uint64_t repeat, void* data, Callback work_cb, Callback after_work_cb) { - PeriodicTask* task = new PeriodicTask(data, work_cb, after_work_cb); - task->inc_ref(); + Ptr task(new PeriodicTask(data, work_cb, after_work_cb)); + + task->inc_ref(); // Timer reference uv_timer_init(loop, &task->timer_handle_); uv_timer_start(&task->timer_handle_, on_timeout, repeat, repeat); return task; } - static void stop(PeriodicTask* task) { + static void stop(const Ptr& task) { uv_timer_stop(&task->timer_handle_); close(task); } private: - static void close(PeriodicTask* task) { + static void close(const Ptr& task) { uv_close(copy_cast(&task->timer_handle_), on_close); } @@ -58,14 +61,14 @@ class PeriodicTask : public RefCounted { if (task->is_running_) return; - task->inc_ref(); + task->inc_ref(); // Work reference task->is_running_ = true; uv_queue_work(handle->loop, &task->work_request_, on_work, on_after_work); } static void on_close(uv_handle_t* handle) { PeriodicTask* task = static_cast(handle->data); - task->dec_ref(); + task->dec_ref(); // Remove timer reference } static void on_work(uv_work_t* request) { @@ -77,7 +80,7 @@ class PeriodicTask : public RefCounted { PeriodicTask* task = static_cast(request->data); task->after_work_cb_(task); task->is_running_ = false; - task->dec_ref(); + task->dec_ref(); // Remove work reference } private: diff --git a/src/pool.cpp b/src/pool.cpp index 95976a446..99ed92634 100644 --- a/src/pool.cpp +++ b/src/pool.cpp @@ -48,13 +48,13 @@ class SetKeyspaceCallback : public RequestCallback { void on_result_response(ResponseMessage* response); private: - ScopedRefPtr request_handler_; + RequestHandler::Ptr request_handler_; }; SetKeyspaceCallback::SetKeyspaceCallback(Connection* connection, - const std::string& keyspace, - RequestHandler* request_handler) - : RequestCallback(new QueryRequest("USE \"" + keyspace + "\"")) + const std::string& keyspace, + RequestHandler* request_handler) + : RequestCallback(Request::ConstPtr(new QueryRequest("USE \"" + keyspace + "\""))) , request_handler_(request_handler) { set_connection(connection); } diff --git a/src/pool.hpp b/src/pool.hpp index 1331c7fed..25f9a42af 100644 --- a/src/pool.hpp +++ b/src/pool.hpp @@ -36,6 +36,8 @@ class Config; class Pool : public RefCounted , public Connection::Listener { public: + typedef SharedRefPtr Ptr; + enum PoolState { POOL_STATE_NEW, POOL_STATE_CONNECTING, diff --git a/src/prepared.cpp b/src/prepared.cpp index a0bbe4786..518dd22f2 100644 --- a/src/prepared.cpp +++ b/src/prepared.cpp @@ -27,8 +27,7 @@ void cass_prepared_free(const CassPrepared* prepared) { } CassStatement* cass_prepared_bind(const CassPrepared* prepared) { - cass::ExecuteRequest* execute - = new cass::ExecuteRequest(prepared); + cass::ExecuteRequest* execute = new cass::ExecuteRequest(prepared); execute->inc_ref(); return CassStatement::to(execute); } @@ -79,7 +78,7 @@ const CassDataType* cass_prepared_parameter_data_type_by_name_n(const CassPrepar namespace cass { -Prepared::Prepared(const SharedRefPtr& result, +Prepared::Prepared(const ResultResponse::Ptr& result, const std::string& statement, const Metadata::SchemaSnapshot& schema_metadata) : result_(result) diff --git a/src/prepared.hpp b/src/prepared.hpp index e4a39b1cc..8117b0769 100644 --- a/src/prepared.hpp +++ b/src/prepared.hpp @@ -28,17 +28,19 @@ namespace cass { class Prepared : public RefCounted { public: - Prepared(const SharedRefPtr& result, + typedef SharedRefPtr ConstPtr; + + Prepared(const ResultResponse::Ptr& result, const std::string& statement, const Metadata::SchemaSnapshot& schema_metadata); - const SharedRefPtr& result() const { return result_; } + const ResultResponse::ConstPtr& result() const { return result_; } const std::string& id() const { return id_; } const std::string& statement() const { return statement_; } const ResultResponse::PKIndexVec& key_indices() const { return key_indices_; } private: - SharedRefPtr result_; + ResultResponse::ConstPtr result_; std::string id_; std::string statement_; ResultResponse::PKIndexVec key_indices_; diff --git a/src/ref_counted.hpp b/src/ref_counted.hpp index 08330bbbf..9102c0dd0 100644 --- a/src/ref_counted.hpp +++ b/src/ref_counted.hpp @@ -26,13 +26,13 @@ namespace cass { -struct RefCountedBase {}; +struct RefCountedBase { }; template class RefCounted : public RefCountedBase { public: RefCounted() - : ref_count_(0) {} + : ref_count_(0) { } int ref_count() const { return ref_count_.load(MEMORY_ORDER_ACQUIRE); @@ -56,37 +56,6 @@ class RefCounted : public RefCountedBase { DISALLOW_COPY_AND_ASSIGN(RefCounted); }; -class RefBuffer : public RefCounted { -public: - static RefBuffer* create(size_t size) { -#if defined(_WIN32) -#pragma warning(push) -#pragma warning(disable: 4291) //Invalid warning thrown RefBuffer has a delete function -#endif - return new (size) RefBuffer(); -#if defined(_WIN32) -#pragma warning(pop) -#endif - } - - char* data() { - return reinterpret_cast(this) + sizeof(RefBuffer); - } - - void operator delete(void* ptr) { - ::operator delete(ptr); - } - -private: - RefBuffer() {} - - void* operator new(size_t size, size_t extra) { - return ::operator new(size + extra); - } - - DISALLOW_COPY_AND_ASSIGN(RefBuffer); -}; - template class SharedRefPtr { public: @@ -161,46 +130,37 @@ class SharedRefPtr { T* ptr_; }; -template -class ScopedRefPtr { +class RefBuffer : public RefCounted { public: - typedef T type; + typedef SharedRefPtr Ptr; - explicit ScopedRefPtr(type* ptr = NULL) - : ptr_(ptr) { - if (ptr_ != NULL) { - ptr_->inc_ref(); - } + static RefBuffer* create(size_t size) { +#if defined(_WIN32) +#pragma warning(push) +#pragma warning(disable: 4291) //Invalid warning thrown RefBuffer has a delete function +#endif + return new (size) RefBuffer(); +#if defined(_WIN32) +#pragma warning(pop) +#endif } - ~ScopedRefPtr() { - if (ptr_ != NULL) { - ptr_->dec_ref(); - } + char* data() { + return reinterpret_cast(this) + sizeof(RefBuffer); } - void reset(type* ptr = NULL) { - if (ptr == ptr_) return; - if (ptr != NULL) { - ptr->inc_ref(); - } - type* temp = ptr_; - ptr_ = ptr; - if (temp != NULL) { - temp->dec_ref(); - } + void operator delete(void* ptr) { + ::operator delete(ptr); } - type* get() const { return ptr_; } - type& operator*() const { return *ptr_; } - type* operator->() const { return ptr_; } - operator bool() const { return ptr_ != NULL; } - private: - type* ptr_; + RefBuffer() { } -private: - DISALLOW_COPY_AND_ASSIGN(ScopedRefPtr); + void* operator new(size_t size, size_t extra) { + return ::operator new(size + extra); + } + + DISALLOW_COPY_AND_ASSIGN(RefBuffer); }; } // namespace cass diff --git a/src/request.hpp b/src/request.hpp index cd4b98ea7..efb574059 100644 --- a/src/request.hpp +++ b/src/request.hpp @@ -36,6 +36,8 @@ class RequestMessage; class CustomPayload : public RefCounted { public: + typedef SharedRefPtr ConstPtr; + virtual ~CustomPayload() { } void set(const char* name, size_t name_length, @@ -54,6 +56,8 @@ class CustomPayload : public RefCounted { class Request : public RefCounted { public: + typedef SharedRefPtr ConstPtr; + enum { ENCODE_ERROR_UNSUPPORTED_PROTOCOL = -1, ENCODE_ERROR_BATCH_WITH_NAMED_VALUES = -2, @@ -103,7 +107,7 @@ class Request : public RefCounted { retry_policy_.reset(retry_policy); } - const SharedRefPtr& custom_payload() const { + const CustomPayload::ConstPtr& custom_payload() const { return custom_payload_; } @@ -119,8 +123,8 @@ class Request : public RefCounted { CassConsistency serial_consistency_; int64_t timestamp_; uint64_t request_timeout_ms_; - SharedRefPtr retry_policy_; - SharedRefPtr custom_payload_; + RetryPolicy::Ptr retry_policy_; + CustomPayload::ConstPtr custom_payload_; private: DISALLOW_COPY_AND_ASSIGN(Request); diff --git a/src/request_callback.cpp b/src/request_callback.cpp index de4714057..b21fdafd3 100644 --- a/src/request_callback.cpp +++ b/src/request_callback.cpp @@ -165,8 +165,10 @@ bool MultipleRequestCallback::get_result_response(const ResponseMap& responses, void MultipleRequestCallback::execute_query(const std::string& index, const std::string& query) { if (has_errors_or_timeouts_) return; - responses_[index] = SharedRefPtr(); - SharedRefPtr callback(new InternalCallback(this, new QueryRequest(query), index)); + responses_[index] = Response::Ptr(); + SharedRefPtr callback( + new InternalCallback(Ptr(this), + Request::ConstPtr(new QueryRequest(query)), index)); remaining_++; if (!connection_->write(callback.get())) { on_error(CASS_ERROR_LIB_NO_STREAMS, "No more streams available"); diff --git a/src/request_callback.hpp b/src/request_callback.hpp index 3aecfd050..82dd7b2ae 100644 --- a/src/request_callback.hpp +++ b/src/request_callback.hpp @@ -23,6 +23,7 @@ #include "utils.hpp" #include "list.hpp" #include "request.hpp" +#include "response.hpp" #include "scoped_ptr.hpp" #include "timer.hpp" @@ -33,7 +34,6 @@ namespace cass { class Config; class Connection; -class Response; class ResponseMessage; class ResultResponse; @@ -52,7 +52,7 @@ class RequestCallback : public RefCounted, public List, public List request_; + SharedRefPtr request_; Connection* connection_; private: @@ -137,7 +137,8 @@ class RequestCallback : public RefCounted, public List { public: - typedef std::map > ResponseMap; + typedef SharedRefPtr Ptr; + typedef std::map ResponseMap; MultipleRequestCallback(Connection* connection) : connection_(connection) @@ -163,7 +164,7 @@ class MultipleRequestCallback : public RefCounted { private: class InternalCallback : public RequestCallback { public: - InternalCallback(MultipleRequestCallback* parent, const Request* request, const std::string& index) + InternalCallback(const Ptr& parent, const Request::ConstPtr& request, const std::string& index) : RequestCallback(request) , parent_(parent) , index_(index) { } @@ -173,7 +174,7 @@ class MultipleRequestCallback : public RefCounted { virtual void on_timeout(); private: - ScopedRefPtr parent_; + Ptr parent_; std::string index_; }; diff --git a/src/request_handler.cpp b/src/request_handler.cpp index 8cb6b971c..1345bbad8 100644 --- a/src/request_handler.cpp +++ b/src/request_handler.cpp @@ -36,8 +36,8 @@ namespace cass { class PrepareCallback : public RequestCallback { public: - PrepareCallback(RequestHandler* request_handler) - : RequestCallback(NULL) + PrepareCallback(const RequestHandler::Ptr& request_handler) + : RequestCallback(Request::ConstPtr()) , request_handler_(request_handler) {} bool init(const std::string& prepared_id); @@ -47,7 +47,7 @@ class PrepareCallback : public RequestCallback { virtual void on_timeout(); private: - ScopedRefPtr request_handler_; + RequestHandler::Ptr request_handler_; }; bool PrepareCallback::init(const std::string& prepared_id) { @@ -136,7 +136,6 @@ void RequestHandler::on_timeout() { } void RequestHandler::set_io_worker(IOWorker* io_worker) { - future_->set_loop(io_worker->loop()); io_worker_ = io_worker; } @@ -164,7 +163,7 @@ bool RequestHandler::is_host_up(const Address& address) const { return io_worker_->is_host_up(address); } -void RequestHandler::set_response(const SharedRefPtr& response) { +void RequestHandler::set_response(const Response::Ptr& response) { uint64_t elapsed = uv_hrtime() - start_time_ns(); current_host_->update_latency(elapsed); connection_->metrics()->record_request(elapsed); @@ -181,7 +180,7 @@ void RequestHandler::set_error(CassError code, const std::string& message) { return_connection_and_finish(); } -void RequestHandler::set_error_with_error_response(const SharedRefPtr& error, +void RequestHandler::set_error_with_error_response(const Response::Ptr& error, CassError code, const std::string& message) { future_->set_error_with_response(current_host_->address(), error, code, message); return_connection_and_finish(); @@ -221,9 +220,9 @@ void RequestHandler::on_result_response(ResponseMessage* response) { break; case CASS_RESULT_KIND_SCHEMA_CHANGE: { - SharedRefPtr schema_change_handler( + SchemaChangeCallback::Ptr schema_change_handler( new SchemaChangeCallback(connection_, - this, + Ptr(this), response->response_body())); schema_change_handler->execute(); break; @@ -285,7 +284,7 @@ void RequestHandler::on_error_response(ResponseMessage* response) { } void RequestHandler::on_error_unprepared(ErrorResponse* error) { - ScopedRefPtr prepare_handler(new PrepareCallback(this)); + SharedRefPtr prepare_handler(new PrepareCallback(RequestHandler::Ptr(this))); if (prepare_handler->init(error->prepared_id().to_string())) { if (!connection_->write(prepare_handler.get())) { // Try to prepare on the same host but on a different connection @@ -325,7 +324,7 @@ void RequestHandler::handle_retry_decision(ResponseMessage* response, break; case RetryPolicy::RetryDecision::IGNORE: - set_response(SharedRefPtr(new ResultResponse())); + set_response(Response::Ptr(new ResultResponse())); break; } num_retries_++; diff --git a/src/request_handler.hpp b/src/request_handler.hpp index 2f110ba9a..c6b12841c 100644 --- a/src/request_handler.hpp +++ b/src/request_handler.hpp @@ -41,18 +41,20 @@ class Timer; class ResponseFuture : public Future { public: + typedef SharedRefPtr Ptr; + ResponseFuture(int protocol_version, const VersionNumber& cassandra_version, const Metadata& metadata) : Future(CASS_FUTURE_TYPE_RESPONSE) , schema_metadata(metadata.schema_snapshot(protocol_version, cassandra_version)) { } - void set_response(Address address, const SharedRefPtr& response) { + void set_response(Address address, const Response::Ptr& response) { ScopedMutex lock(&mutex_); address_ = address; response_ = response; internal_set(lock); } - const SharedRefPtr& response() { + const Response::Ptr& response() { ScopedMutex lock(&mutex_); internal_wait(lock); return response_; @@ -64,7 +66,7 @@ class ResponseFuture : public Future { internal_set_error(code, message, lock); } - void set_error_with_response(Address address, const SharedRefPtr& response, + void set_error_with_response(Address address, const Response::Ptr& response, CassError code, const std::string& message) { ScopedMutex lock(&mutex_); address_ = address; @@ -83,14 +85,16 @@ class ResponseFuture : public Future { private: Address address_; - SharedRefPtr response_; + Response::Ptr response_; }; class RequestHandler : public RequestCallback { public: - RequestHandler(const Request* request, - ResponseFuture* future, + typedef SharedRefPtr Ptr; + + RequestHandler(const Request::ConstPtr& request, + const ResponseFuture::Ptr& future, RetryPolicy* retry_policy) : RequestCallback(request) , future_(future) @@ -125,11 +129,11 @@ class RequestHandler : public RequestCallback { bool is_host_up(const Address& address) const; - void set_response(const SharedRefPtr& response); + void set_response(const Response::Ptr& response); private: void set_error(CassError code, const std::string& message); - void set_error_with_error_response(const SharedRefPtr& error, + void set_error_with_error_response(const Response::Ptr& error, CassError code, const std::string& message); void return_connection(); void return_connection_and_finish(); @@ -141,11 +145,11 @@ class RequestHandler : public RequestCallback { void handle_retry_decision(ResponseMessage* response, const RetryPolicy::RetryDecision& decision); - ScopedRefPtr future_; + SharedRefPtr future_; RetryPolicy* retry_policy_; int num_retries_; bool is_query_plan_exhausted_; - SharedRefPtr current_host_; + Host::Ptr current_host_; ScopedPtr query_plan_; IOWorker* io_worker_; Pool* pool_; diff --git a/src/response.hpp b/src/response.hpp index c3deea8e3..a6fa9a35d 100644 --- a/src/response.hpp +++ b/src/response.hpp @@ -30,6 +30,8 @@ namespace cass { class Response : public RefCounted { public: + typedef SharedRefPtr Ptr; + struct CustomPayloadItem { CustomPayloadItem(StringRef name, StringRef value) : name(name) @@ -49,10 +51,10 @@ class Response : public RefCounted { char* data() const { return buffer_->data(); } - const SharedRefPtr& buffer() const { return buffer_; } + const RefBuffer::Ptr& buffer() const { return buffer_; } void set_buffer(size_t size) { - buffer_ = SharedRefPtr(RefBuffer::create(size)); + buffer_ = RefBuffer::Ptr(RefBuffer::create(size)); } const CustomPayloadVec& custom_payload() const { return custom_payload_; } @@ -65,7 +67,7 @@ class Response : public RefCounted { private: uint8_t opcode_; - SharedRefPtr buffer_; + RefBuffer::Ptr buffer_; CustomPayloadVec custom_payload_; private: @@ -94,7 +96,7 @@ class ResponseMessage { int16_t stream() const { return stream_; } - const SharedRefPtr& response_body() { return response_body_; } + const Response::Ptr& response_body() { return response_body_; } bool is_body_ready() const { return is_body_ready_; } @@ -118,7 +120,7 @@ class ResponseMessage { bool is_body_ready_; bool is_body_error_; - SharedRefPtr response_body_; + Response::Ptr response_body_; char* body_buffer_pos_; private: diff --git a/src/result_metadata.hpp b/src/result_metadata.hpp index 6488d5c9a..06eb18811 100644 --- a/src/result_metadata.hpp +++ b/src/result_metadata.hpp @@ -43,6 +43,8 @@ struct ColumnDefinition : public HashTableEntry { class ResultMetadata : public RefCounted { public: + typedef SharedRefPtr Ptr; + ResultMetadata(size_t column_count); const ColumnDefinition& get_column_definition(size_t index) const { return defs_[index]; } diff --git a/src/result_response.cpp b/src/result_response.cpp index 0670ff44a..2d4bd27a6 100644 --- a/src/result_response.cpp +++ b/src/result_response.cpp @@ -108,7 +108,7 @@ class DataTypeDecoder { char* buffer() const { return buffer_; } - SharedRefPtr decode() { + DataType::Ptr decode() { uint16_t value_type; buffer_ = decode_uint16(buffer_, value_type); @@ -132,7 +132,7 @@ class DataTypeDecoder { if (data_type_cache_[value_type]) { return data_type_cache_[value_type]; } else { - SharedRefPtr data_type( + DataType::Ptr data_type( new DataType(static_cast(value_type))); data_type_cache_[value_type] = data_type; return data_type; @@ -141,26 +141,26 @@ class DataTypeDecoder { break; } - return SharedRefPtr(); + return DataType::Ptr(); } private: - SharedRefPtr decode_custom() { + DataType::Ptr decode_custom() { StringRef class_name; buffer_ = decode_string(buffer_, &class_name); - return SharedRefPtr(new CustomType(class_name.to_string())); + return DataType::Ptr(new CustomType(class_name.to_string())); } - SharedRefPtr decode_collection(CassValueType collection_type) { + DataType::Ptr decode_collection(CassValueType collection_type) { DataType::Vec types; types.push_back(decode()); if (collection_type == CASS_VALUE_TYPE_MAP) { types.push_back(decode()); } - return SharedRefPtr(new CollectionType(collection_type, types, false)); + return DataType::Ptr(new CollectionType(collection_type, types, false)); } - SharedRefPtr decode_user_type() { + DataType::Ptr decode_user_type() { StringRef keyspace; buffer_ = decode_string(buffer_, &keyspace); @@ -176,13 +176,13 @@ class DataTypeDecoder { buffer_ = decode_string(buffer_, &field_name); fields.push_back(UserType::Field(field_name.to_string(), decode())); } - return SharedRefPtr(new UserType(keyspace.to_string(), + return DataType::Ptr(new UserType(keyspace.to_string(), type_name.to_string(), fields, false)); } - SharedRefPtr decode_tuple() { + DataType::Ptr decode_tuple() { uint16_t n; buffer_ = decode_uint16(buffer_, n); @@ -190,12 +190,12 @@ class DataTypeDecoder { for (uint16_t i = 0; i < n; ++i) { types.push_back(decode()); } - return SharedRefPtr(new TupleType(types, false)); + return DataType::Ptr(new TupleType(types, false)); } private: char* buffer_; - SharedRefPtr data_type_cache_[CASS_VALUE_TYPE_LAST_ENTRY]; + DataType::Ptr data_type_cache_[CASS_VALUE_TYPE_LAST_ENTRY]; }; bool ResultResponse::decode(int version, char* input, size_t size) { @@ -230,7 +230,7 @@ bool ResultResponse::decode(int version, char* input, size_t size) { return false; } -char* ResultResponse::decode_metadata(char* input, SharedRefPtr* metadata, +char* ResultResponse::decode_metadata(char* input, ResultMetadata::Ptr* metadata, bool has_pk_indices) { int32_t flags = 0; char* buffer = decode_int32(input, flags); diff --git a/src/result_response.hpp b/src/result_response.hpp index 0a174da82..74ccf9972 100644 --- a/src/result_response.hpp +++ b/src/result_response.hpp @@ -35,6 +35,8 @@ class ResultIterator; class ResultResponse : public Response { public: + typedef SharedRefPtr Ptr; + typedef SharedRefPtr ConstPtr; typedef std::vector PKIndexVec; ResultResponse() @@ -57,14 +59,14 @@ class ResultResponse : public Response { bool no_metadata() const { return !metadata_; } - const SharedRefPtr& metadata() const { return metadata_; } + const ResultMetadata::Ptr& metadata() const { return metadata_; } void set_metadata(ResultMetadata* metadata) { metadata_.reset(metadata); decode_first_row(); } - const SharedRefPtr& result_metadata() const { return result_metadata_; } + const ResultMetadata::Ptr& result_metadata() const { return result_metadata_; } StringRef paging_state() const { return paging_state_; } StringRef prepared() const { return prepared_; } @@ -82,7 +84,7 @@ class ResultResponse : public Response { bool decode(int version, char* input, size_t size); private: - char* decode_metadata(char* input, SharedRefPtr* metadata, + char* decode_metadata(char* input, ResultMetadata::Ptr* metadata, bool has_pk_indices = false); void decode_first_row(); @@ -99,8 +101,8 @@ class ResultResponse : public Response { int protocol_version_; int32_t kind_; bool has_more_pages_; // row data - SharedRefPtr metadata_; - SharedRefPtr result_metadata_; + ResultMetadata::Ptr metadata_; + ResultMetadata::Ptr result_metadata_; StringRef paging_state_; // row paging StringRef prepared_; // prepared result StringRef change_; // schema change diff --git a/src/retry_policy.hpp b/src/retry_policy.hpp index fd4c83b66..152251949 100644 --- a/src/retry_policy.hpp +++ b/src/retry_policy.hpp @@ -30,6 +30,8 @@ namespace cass { class RetryPolicy : public RefCounted { public: + typedef SharedRefPtr Ptr; + enum Type { DEFAULT, DOWNGRADING, @@ -123,7 +125,7 @@ class FallthroughRetryPolicy : public RetryPolicy { class LoggingRetryPolicy : public RetryPolicy { public: - LoggingRetryPolicy(const SharedRefPtr& retry_policy) + LoggingRetryPolicy(const RetryPolicy::Ptr& retry_policy) : RetryPolicy(LOGGING) , retry_policy_(retry_policy) { } @@ -132,7 +134,7 @@ class LoggingRetryPolicy : public RetryPolicy { virtual RetryDecision on_unavailable(CassConsistency cl, int required, int alive, int num_retries) const; private: - SharedRefPtr retry_policy_; + RetryPolicy::Ptr retry_policy_; }; } // namespace cass diff --git a/src/round_robin_policy.cpp b/src/round_robin_policy.cpp index f4405f65f..fd98d95bb 100644 --- a/src/round_robin_policy.cpp +++ b/src/round_robin_policy.cpp @@ -21,7 +21,7 @@ namespace cass { -void RoundRobinPolicy::init(const SharedRefPtr& connected_host, +void RoundRobinPolicy::init(const Host::Ptr& connected_host, const HostMap& hosts, Random* random) { hosts_->reserve(hosts.size()); @@ -31,7 +31,7 @@ void RoundRobinPolicy::init(const SharedRefPtr& connected_host, } } -CassHostDistance RoundRobinPolicy::distance(const SharedRefPtr& host) const { +CassHostDistance RoundRobinPolicy::distance(const Host::Ptr& host) const { return CASS_HOST_DISTANCE_LOCAL; } @@ -42,31 +42,31 @@ QueryPlan* RoundRobinPolicy::new_query_plan(const std::string& connected_keyspac return new RoundRobinQueryPlan(hosts_, index_++); } -void RoundRobinPolicy::on_add(const SharedRefPtr& host) { +void RoundRobinPolicy::on_add(const Host::Ptr& host) { add_host(hosts_, host); } -void RoundRobinPolicy::on_remove(const SharedRefPtr& host) { +void RoundRobinPolicy::on_remove(const Host::Ptr& host) { remove_host(hosts_, host); } -void RoundRobinPolicy::on_up(const SharedRefPtr& host) { +void RoundRobinPolicy::on_up(const Host::Ptr& host) { on_add(host); } -void RoundRobinPolicy::on_down(const SharedRefPtr& host) { +void RoundRobinPolicy::on_down(const Host::Ptr& host) { on_remove(host); } -SharedRefPtr RoundRobinPolicy::RoundRobinQueryPlan::compute_next() { +Host::Ptr RoundRobinPolicy::RoundRobinQueryPlan::compute_next() { while (remaining_ > 0) { --remaining_; - const SharedRefPtr& host((*hosts_)[index_++ % hosts_->size()]); + const Host::Ptr& host((*hosts_)[index_++ % hosts_->size()]); if (host->is_up()) { return host; } } - return SharedRefPtr(); + return Host::Ptr(); } } // namespace cass diff --git a/src/round_robin_policy.hpp b/src/round_robin_policy.hpp index 14fa780a8..32d38a54e 100644 --- a/src/round_robin_policy.hpp +++ b/src/round_robin_policy.hpp @@ -31,19 +31,19 @@ class RoundRobinPolicy : public LoadBalancingPolicy { : hosts_(new HostVec) , index_(0) { } - virtual void init(const SharedRefPtr& connected_host, const HostMap& hosts, Random* random); + virtual void init(const Host::Ptr& connected_host, const HostMap& hosts, Random* random); - virtual CassHostDistance distance(const SharedRefPtr& host) const; + virtual CassHostDistance distance(const Host::Ptr& host) const; virtual QueryPlan* new_query_plan(const std::string& connected_keyspace, const Request* request, const TokenMap* token_map, Request::EncodingCache* cache); - virtual void on_add(const SharedRefPtr& host); - virtual void on_remove(const SharedRefPtr& host); - virtual void on_up(const SharedRefPtr& host); - virtual void on_down(const SharedRefPtr& host); + virtual void on_add(const Host::Ptr& host); + virtual void on_remove(const Host::Ptr& host); + virtual void on_up(const Host::Ptr& host); + virtual void on_down(const Host::Ptr& host); virtual LoadBalancingPolicy* new_instance() { return new RoundRobinPolicy(); } @@ -55,7 +55,7 @@ class RoundRobinPolicy : public LoadBalancingPolicy { , index_(start_index) , remaining_(hosts->size()) { } - virtual SharedRefPtr compute_next(); + virtual Host::Ptr compute_next(); private: const CopyOnWriteHostVec hosts_; diff --git a/src/schema_change_callback.cpp b/src/schema_change_callback.cpp index a0cdc1e6a..64b27d2b8 100644 --- a/src/schema_change_callback.cpp +++ b/src/schema_change_callback.cpp @@ -35,9 +35,9 @@ namespace cass { SchemaChangeCallback::SchemaChangeCallback(Connection* connection, - RequestHandler* request_handler, - const SharedRefPtr& response, - uint64_t elapsed) + const RequestHandler::Ptr& request_handler, + const Response::Ptr& response, + uint64_t elapsed) : MultipleRequestCallback(connection) , request_handler_(request_handler) , request_response_(response) @@ -123,11 +123,11 @@ void SchemaChangeCallback::on_set(const ResponseMap& responses) { "Trying again in %d ms", RETRY_SCHEMA_AGREEMENT_WAIT_MS); // Try again - SharedRefPtr callback( + Ptr callback( new SchemaChangeCallback(connection(), - request_handler_.get(), - request_response_, - elapsed_ms_)); + request_handler_, + request_response_, + elapsed_ms_)); connection()->schedule_schema_agreement(callback, RETRY_SCHEMA_AGREEMENT_WAIT_MS); } diff --git a/src/schema_change_callback.hpp b/src/schema_change_callback.hpp index 9566c881d..a74322a4c 100644 --- a/src/schema_change_callback.hpp +++ b/src/schema_change_callback.hpp @@ -31,9 +31,11 @@ class Response; class SchemaChangeCallback : public MultipleRequestCallback { public: + typedef SharedRefPtr Ptr; + SchemaChangeCallback(Connection* connection, - RequestHandler* request_handler, - const SharedRefPtr& response, + const RequestHandler::Ptr& request_handler, + const Response::Ptr& response, uint64_t elapsed = 0); void execute(); @@ -46,8 +48,8 @@ class SchemaChangeCallback : public MultipleRequestCallback { private: bool has_schema_agreement(const ResponseMap& responses); - ScopedRefPtr request_handler_; - SharedRefPtr request_response_; + RequestHandler::Ptr request_handler_; + Response::Ptr request_response_; uint64_t start_ms_; uint64_t elapsed_ms_; }; diff --git a/src/session.cpp b/src/session.cpp index b44f2fe18..8c2211b20 100644 --- a/src/session.cpp +++ b/src/session.cpp @@ -36,7 +36,7 @@ void cass_session_free(CassSession* session) { // hang indefinitely otherwise. This causes minimal delay // if the session is already closed. cass::SharedRefPtr future(new cass::SessionFuture()); - session->close_async(future.get(), true); + session->close_async(future, true); future->wait(); delete session->from(); @@ -59,17 +59,17 @@ CassFuture* cass_session_connect_keyspace_n(CassSession* session, const CassCluster* cluster, const char* keyspace, size_t keyspace_length) { - cass::SessionFuture* connect_future = new cass::SessionFuture(); - connect_future->inc_ref(); + cass::SessionFuture::Ptr connect_future(new cass::SessionFuture()); session->connect_async(cluster->config(), std::string(keyspace, keyspace_length), connect_future); - return CassFuture::to(connect_future); + connect_future->inc_ref(); + return CassFuture::to(connect_future.get()); } CassFuture* cass_session_close(CassSession* session) { - cass::SessionFuture* close_future = new cass::SessionFuture(); - close_future->inc_ref(); + cass::SessionFuture::Ptr close_future(new cass::SessionFuture()); session->close_async(close_future); - return CassFuture::to(close_future); + close_future->inc_ref(); + return CassFuture::to(close_future.get()); } CassFuture* cass_session_prepare(CassSession* session, const char* query) { @@ -79,16 +79,22 @@ CassFuture* cass_session_prepare(CassSession* session, const char* query) { CassFuture* cass_session_prepare_n(CassSession* session, const char* query, size_t query_length) { - return CassFuture::to(session->prepare(query, query_length)); + cass::Future::Ptr future(session->prepare(query, query_length)); + future->inc_ref(); + return CassFuture::to(future.get()); } CassFuture* cass_session_execute(CassSession* session, const CassStatement* statement) { - return CassFuture::to(session->execute(statement->from())); + cass::Future::Ptr future(session->execute(cass::Request::ConstPtr(statement->from()))); + future->inc_ref(); + return CassFuture::to(future.get()); } CassFuture* cass_session_execute_batch(CassSession* session, const CassBatch* batch) { - return CassFuture::to(session->execute(batch->from())); + cass::Future::Ptr future(session->execute(cass::Request::ConstPtr(batch->from()))); + future->inc_ref(); + return CassFuture::to(future.get()); } const CassSchemaMeta* cass_session_get_schema_meta(const CassSession* session) { @@ -181,7 +187,7 @@ int Session::init() { if (rc != 0) return rc; for (unsigned int i = 0; i < config_.thread_count_io(); ++i) { - SharedRefPtr io_worker(new IOWorker(this)); + IOWorker::Ptr io_worker(new IOWorker(this)); int rc = io_worker->init(); if (rc != 0) return rc; io_workers_.push_back(io_worker); @@ -204,19 +210,19 @@ void Session::broadcast_keyspace_change(const std::string& keyspace, keyspace_ = CopyOnWritePtr(new std::string(keyspace)); } -SharedRefPtr Session::get_host(const Address& address) { +Host::Ptr Session::get_host(const Address& address) { // Lock hosts. This can be called on a non-session thread. ScopedMutex l(&hosts_mutex_); HostMap::iterator it = hosts_.find(address); if (it == hosts_.end()) { - return SharedRefPtr(); + return Host::Ptr(); } return it->second; } -SharedRefPtr Session::add_host(const Address& address) { +Host::Ptr Session::add_host(const Address& address) { LOG_DEBUG("Adding new host: %s", address.to_string().c_str()); - SharedRefPtr host(new Host(address, !current_host_mark_)); + Host::Ptr host(new Host(address, !current_host_mark_)); { // Lock hosts ScopedMutex l(&hosts_mutex_); hosts_[address] = host; @@ -281,7 +287,7 @@ bool Session::notify_down_async(const Address& address) { return send_event_async(event); } -void Session::connect_async(const Config& config, const std::string& keyspace, Future* future) { +void Session::connect_async(const Config& config, const std::string& keyspace, const Future::Ptr& future) { ScopedMutex l(&state_mutex_); if (state_.load(MEMORY_ORDER_RELAXED) != SESSION_STATE_CLOSED) { @@ -310,7 +316,7 @@ void Session::connect_async(const Config& config, const std::string& keyspace, F LOG_DEBUG("Issued connect event"); state_.store(SESSION_STATE_CONNECTING, MEMORY_ORDER_RELAXED); - connect_future_.reset(future); + connect_future_ = future; if (!keyspace.empty()) { broadcast_keyspace_change(keyspace, NULL); @@ -323,7 +329,7 @@ void Session::connect_async(const Config& config, const std::string& keyspace, F run(); } -void Session::close_async(Future* future, bool force) { +void Session::close_async(const Future::Ptr& future, bool force) { ScopedMutex l(&state_mutex_); State state = state_.load(MEMORY_ORDER_RELAXED); @@ -335,7 +341,7 @@ void Session::close_async(Future* future, bool force) { } state_.store(SESSION_STATE_CLOSING, MEMORY_ORDER_RELAXED); - close_future_.reset(future); + close_future_ = future; if (!wait_for_connect_to_finish) { internal_close(); @@ -503,7 +509,7 @@ void Session::on_resolve(MultiResolver::Resolver* resolver) { if (resolver->is_success()) { AddressVec addresses = resolver->addresses(); for (AddressVec::iterator it = addresses.begin(); it != addresses.end(); ++it) { - SharedRefPtr host = session->add_host(*it); + Host::Ptr host = session->add_host(*it); host->set_hostname(resolver->hostname()); } } else if (resolver->is_timed_out()) { @@ -519,11 +525,13 @@ void Session::on_resolve_done(MultiResolver* resolver) { resolver->data()->internal_connect(); } -void Session::execute(RequestHandler* request_handler) { +void Session::execute(const RequestHandler::Ptr& request_handler) { + request_handler->inc_ref(); if (state_.load(MEMORY_ORDER_ACQUIRE) != SESSION_STATE_CONNECTED) { request_handler->on_error(CASS_ERROR_LIB_NO_HOSTS_AVAILABLE, "Session is not connected"); - } else if (!request_queue_->enqueue(request_handler)) { + return; + } else if (!request_queue_->enqueue(request_handler.get())) { request_handler->on_error(CASS_ERROR_LIB_REQUEST_QUEUE_FULL, "The request queue has reached capacity"); } @@ -533,7 +541,7 @@ void Session::execute(RequestHandler* request_handler) { void Session::on_resolve_name(MultiResolver::NameResolver* resolver) { Session* session = resolver->data()->data(); if (resolver->is_success()) { - SharedRefPtr host = session->add_host(resolver->address()); + Host::Ptr host = session->add_host(resolver->address()); host->set_hostname(resolver->hostname()); } else if (resolver->is_timed_out()) { LOG_ERROR("Timed out attempting to resolve hostname for host %s\n", @@ -575,23 +583,24 @@ void Session::on_control_connection_error(CassError code, const std::string& mes notify_connect_error(code, message); } -Future* Session::prepare(const char* statement, size_t length) { - PrepareRequest* prepare = new PrepareRequest(); +Future::Ptr Session::prepare(const char* statement, size_t length) { + SharedRefPtr prepare(new PrepareRequest()); prepare->set_query(statement, length); - ResponseFuture* future = new ResponseFuture(protocol_version(), cassandra_version(), metadata_); - future->inc_ref(); // External reference + ResponseFuture::Ptr future( + new ResponseFuture(protocol_version(), + cassandra_version(), + metadata_)); future->statement.assign(statement, length); - RequestHandler* request_handler = new RequestHandler(prepare, future, NULL); - request_handler->inc_ref(); // IOWorker reference + RequestHandler::Ptr request_handler(new RequestHandler(prepare, future, NULL)); execute(request_handler); return future; } -void Session::on_add(SharedRefPtr host, bool is_initial_connection) { +void Session::on_add(Host::Ptr host, bool is_initial_connection) { #if UV_VERSION_MAJOR >= 1 if (config_.use_hostname_resolution() && host->hostname().empty()) { NameResolver::resolve(loop(), @@ -606,7 +615,7 @@ void Session::on_add(SharedRefPtr host, bool is_initial_connection) { #endif } -void Session::internal_on_add(SharedRefPtr host, bool is_initial_connection) { +void Session::internal_on_add(Host::Ptr host, bool is_initial_connection) { host->set_up(); if (load_balancing_policy_->distance(host) == CASS_HOST_DISTANCE_IGNORE) { @@ -625,7 +634,7 @@ void Session::internal_on_add(SharedRefPtr host, bool is_initial_connectio } } -void Session::on_remove(SharedRefPtr host) { +void Session::on_remove(Host::Ptr host) { load_balancing_policy_->on_remove(host); { // Lock hosts ScopedMutex l(&hosts_mutex_); @@ -637,7 +646,7 @@ void Session::on_remove(SharedRefPtr host) { } } -void Session::on_up(SharedRefPtr host) { +void Session::on_up(Host::Ptr host) { host->set_up(); if (load_balancing_policy_->distance(host) == CASS_HOST_DISTANCE_IGNORE) { @@ -652,7 +661,7 @@ void Session::on_up(SharedRefPtr host) { } } -void Session::on_down(SharedRefPtr host) { +void Session::on_down(Host::Ptr host) { host->set_down(); load_balancing_policy_->on_down(host); @@ -668,18 +677,19 @@ void Session::on_down(SharedRefPtr host) { } } -Future* Session::execute(const RoutableRequest* request) { - ResponseFuture* future = new ResponseFuture(protocol_version(), cassandra_version(), metadata_); - future->inc_ref(); // External reference +Future::Ptr Session::execute(const Request::ConstPtr& request) { + ResponseFuture::Ptr future( + new ResponseFuture(protocol_version(), + cassandra_version(), + metadata_)); RetryPolicy* retry_policy = request->retry_policy() != NULL ? request->retry_policy() : config().retry_policy(); - RequestHandler* request_handler = new RequestHandler(request, - future, - retry_policy); - request_handler->inc_ref(); // IOWorker reference + RequestHandler::Ptr request_handler(new RequestHandler(request, + future, + retry_policy)); execute(request_handler); @@ -718,7 +728,7 @@ void Session::on_execute(uv_async_t* data) { size_t start = session->current_io_worker_; for (size_t i = 0, size = session->io_workers_.size(); i < size; ++i) { - const SharedRefPtr& io_worker = session->io_workers_[start % size]; + const IOWorker::Ptr& io_worker = session->io_workers_[start % size]; if (io_worker->is_host_available(address) && io_worker->execute(request_handler)) { session->current_io_worker_ = (start + 1) % size; diff --git a/src/session.hpp b/src/session.hpp index 56e6c357b..7ff5665db 100644 --- a/src/session.hpp +++ b/src/session.hpp @@ -89,7 +89,7 @@ class Session : public EventThread { void broadcast_keyspace_change(const std::string& keyspace, const IOWorker* calling_io_worker); - SharedRefPtr get_host(const Address& address); + Host::Ptr get_host(const Address& address); bool notify_ready_async(); bool notify_keyspace_error_async(); @@ -97,11 +97,11 @@ class Session : public EventThread { bool notify_up_async(const Address& address); bool notify_down_async(const Address& address); - void connect_async(const Config& config, const std::string& keyspace, Future* future); - void close_async(Future* future, bool force = false); + void connect_async(const Config& config, const std::string& keyspace, const Future::Ptr& future); + void close_async(const Future::Ptr& future, bool force = false); - Future* prepare(const char* statement, size_t length); - Future* execute(const RoutableRequest* statement); + Future::Ptr prepare(const char* statement, size_t length); + Future::Ptr execute(const Request::ConstPtr& request); const Metadata& metadata() const { return metadata_; } @@ -126,7 +126,7 @@ class Session : public EventThread { void notify_connect_error(CassError code, const std::string& message); void notify_closed(); - void execute(RequestHandler* request_handler); + void execute(const RequestHandler::Ptr& request_handler); virtual void on_run(); virtual void on_after_run(); @@ -138,13 +138,13 @@ class Session : public EventThread { #if UV_VERSION_MAJOR >= 1 struct ResolveNameData { ResolveNameData(Session* session, - const SharedRefPtr& host, + const Host::Ptr& host, bool is_initial_connection) : session(session) , host(host) , is_initial_connection(is_initial_connection) { } Session* session; - SharedRefPtr host; + Host::Ptr host; bool is_initial_connection; }; typedef cass::NameResolver NameResolver; @@ -167,7 +167,7 @@ class Session : public EventThread { // TODO(mpenick): Consider removing friend access to session friend class ControlConnection; - SharedRefPtr add_host(const Address& address); + Host::Ptr add_host(const Address& address); void purge_hosts(bool is_initial_connection); Metadata& metadata() { return metadata_; } @@ -175,26 +175,26 @@ class Session : public EventThread { void on_control_connection_ready(); void on_control_connection_error(CassError code, const std::string& message); - void on_add(SharedRefPtr host, bool is_initial_connection); - void internal_on_add(SharedRefPtr host, bool is_initial_connection); + void on_add(Host::Ptr host, bool is_initial_connection); + void internal_on_add(Host::Ptr host, bool is_initial_connection); - void on_remove(SharedRefPtr host); - void on_up(SharedRefPtr host); - void on_down(SharedRefPtr host); + void on_remove(Host::Ptr host); + void on_up(Host::Ptr host); + void on_down(Host::Ptr host); private: - typedef std::vector > IOWorkerVec; + typedef std::vector IOWorkerVec; Atomic state_; uv_mutex_t state_mutex_; Config config_; ScopedPtr metrics_; - ScopedRefPtr load_balancing_policy_; + LoadBalancingPolicy::Ptr load_balancing_policy_; CassError connect_error_code_; std::string connect_error_message_; - ScopedRefPtr connect_future_; - ScopedRefPtr close_future_; + Future::Ptr connect_future_; + Future::Ptr close_future_; HostMap hosts_; uv_mutex_t hosts_mutex_; @@ -216,6 +216,8 @@ class Session : public EventThread { class SessionFuture : public Future { public: + typedef SharedRefPtr Ptr; + SessionFuture() : Future(CASS_FUTURE_TYPE_SESSION) {} }; diff --git a/src/ssl.cpp b/src/ssl.cpp index 2d5b91347..c620930cf 100644 --- a/src/ssl.cpp +++ b/src/ssl.cpp @@ -24,9 +24,9 @@ extern "C" { CassSsl* cass_ssl_new() { - cass::SslContext* ssl_context = cass::SslContextFactory::create(); + cass::SslContext::Ptr ssl_context(cass::SslContextFactory::create()); ssl_context->inc_ref(); - return CassSsl::to(ssl_context); + return CassSsl::to(ssl_context.get()); } void cass_ssl_free(CassSsl* ssl) { @@ -81,7 +81,7 @@ namespace cass { static uv_once_t ssl_init_guard = UV_ONCE_INIT; template -SslContext* SslContextFactoryBase::create() { +SslContext::Ptr SslContextFactoryBase::create() { init(); return T::create(); } diff --git a/src/ssl.hpp b/src/ssl.hpp index 0b69cc4e1..cdf4b2b02 100644 --- a/src/ssl.hpp +++ b/src/ssl.hpp @@ -71,6 +71,8 @@ class SslSession { class SslContext : public RefCounted { public: + typedef SharedRefPtr Ptr; + SslContext() : verify_flags_(CASS_SSL_VERIFY_PEER_CERT) {} @@ -95,7 +97,7 @@ class SslContext : public RefCounted { template class SslContextFactoryBase { public: - static SslContext* create(); + static SslContext::Ptr create(); static void init(); }; diff --git a/src/ssl/ssl_no_impl.cpp b/src/ssl/ssl_no_impl.cpp index 9a0aeb1fb..6edda906d 100644 --- a/src/ssl/ssl_no_impl.cpp +++ b/src/ssl/ssl_no_impl.cpp @@ -45,8 +45,8 @@ CassError cass::NoSslContext::set_private_key(const char* key, return CASS_ERROR_LIB_NOT_IMPLEMENTED; } -SslContext* NoSslContextFactory::create() { - return new NoSslContext(); +SslContext::Ptr NoSslContextFactory::create() { + return SslContext::Ptr(new NoSslContext()); } diff --git a/src/ssl/ssl_no_impl.hpp b/src/ssl/ssl_no_impl.hpp index cdc579f21..bc90ab3ba 100644 --- a/src/ssl/ssl_no_impl.hpp +++ b/src/ssl/ssl_no_impl.hpp @@ -46,7 +46,7 @@ class NoSslContext : public SslContext { class NoSslContextFactory : SslContextFactoryBase { public: - static SslContext* create(); + static SslContext::Ptr create(); static void init() {} }; diff --git a/src/ssl/ssl_openssl_impl.cpp b/src/ssl/ssl_openssl_impl.cpp index 9471760c8..49dd03bfb 100644 --- a/src/ssl/ssl_openssl_impl.cpp +++ b/src/ssl/ssl_openssl_impl.cpp @@ -574,8 +574,8 @@ CassError OpenSslContext::set_private_key(const char* key, return CASS_OK; } -SslContext* OpenSslContextFactory::create() { - return new OpenSslContext(); +SslContext::Ptr OpenSslContextFactory::create() { + return SslContext::Ptr(new OpenSslContext()); } void OpenSslContextFactory::init() { diff --git a/src/ssl/ssl_openssl_impl.hpp b/src/ssl/ssl_openssl_impl.hpp index 5ae1eac3c..1a5ad5e73 100644 --- a/src/ssl/ssl_openssl_impl.hpp +++ b/src/ssl/ssl_openssl_impl.hpp @@ -70,7 +70,7 @@ class OpenSslContext : public SslContext { class OpenSslContextFactory : SslContextFactoryBase { public: - static SslContext* create(); + static SslContext::Ptr create(); static void init(); }; diff --git a/src/statement.hpp b/src/statement.hpp index 6ee936efc..5be30f495 100644 --- a/src/statement.hpp +++ b/src/statement.hpp @@ -34,6 +34,8 @@ class RequestCallback; class Statement : public RoutableRequest, public AbstractData { public: + typedef SharedRefPtr Ptr; + Statement(uint8_t opcode, uint8_t kind, size_t values_count = 0) : RoutableRequest(opcode) , AbstractData(values_count) diff --git a/src/token_aware_policy.cpp b/src/token_aware_policy.cpp index 47110db02..e4d0645bc 100644 --- a/src/token_aware_policy.cpp +++ b/src/token_aware_policy.cpp @@ -34,7 +34,7 @@ static inline bool contains(const CopyOnWriteHostVec& replicas, const Address& a return false; } -void TokenAwarePolicy::init(const SharedRefPtr& connected_host, +void TokenAwarePolicy::init(const Host::Ptr& connected_host, const HostMap& hosts, Random* random) { if (random != NULL) { @@ -79,23 +79,23 @@ QueryPlan* TokenAwarePolicy::new_query_plan(const std::string& connected_keyspac return child_policy_->new_query_plan(connected_keyspace, request, token_map, cache); } -SharedRefPtr TokenAwarePolicy::TokenAwareQueryPlan::compute_next() { +Host::Ptr TokenAwarePolicy::TokenAwareQueryPlan::compute_next() { while (remaining_ > 0) { --remaining_; - const SharedRefPtr& host((*replicas_)[index_++ % replicas_->size()]); + const Host::Ptr& host((*replicas_)[index_++ % replicas_->size()]); if (host->is_up() && child_policy_->distance(host) == CASS_HOST_DISTANCE_LOCAL) { return host; } } - SharedRefPtr host; + Host::Ptr host; while ((host = child_plan_->compute_next())) { if (!contains(replicas_, host->address()) || child_policy_->distance(host) != CASS_HOST_DISTANCE_LOCAL) { return host; } } - return SharedRefPtr(); + return Host::Ptr(); } } // namespace cass diff --git a/src/token_aware_policy.hpp b/src/token_aware_policy.hpp index 2b86f66fb..eb1d0aa35 100644 --- a/src/token_aware_policy.hpp +++ b/src/token_aware_policy.hpp @@ -32,7 +32,7 @@ class TokenAwarePolicy : public ChainedLoadBalancingPolicy { virtual ~TokenAwarePolicy() {} - virtual void init(const SharedRefPtr& connected_host, const HostMap& hosts, Random* random); + virtual void init(const Host::Ptr& connected_host, const HostMap& hosts, Random* random); virtual QueryPlan* new_query_plan(const std::string& connected_keyspace, const Request* request, @@ -51,7 +51,7 @@ class TokenAwarePolicy : public ChainedLoadBalancingPolicy { , index_(start_index) , remaining_(replicas->size()) {} - SharedRefPtr compute_next(); + Host::Ptr compute_next(); private: LoadBalancingPolicy* child_policy_; diff --git a/src/tuple.cpp b/src/tuple.cpp index 5f8579292..6f0f3da30 100644 --- a/src/tuple.cpp +++ b/src/tuple.cpp @@ -36,7 +36,7 @@ CassTuple* cass_tuple_new_from_data_type(const CassDataType* data_type) { } return CassTuple::to( new cass::Tuple( - cass::SharedRefPtr(data_type))); + cass::DataType::ConstPtr(data_type))); } void cass_tuple_free(CassTuple* tuple) { diff --git a/src/tuple.hpp b/src/tuple.hpp index ea3ba145c..2acbfe156 100644 --- a/src/tuple.hpp +++ b/src/tuple.hpp @@ -44,7 +44,7 @@ class Tuple { : data_type_(data_type) , items_(data_type_->types().size()) { } - const SharedRefPtr& data_type() const { return data_type_; } + const TupleType::ConstPtr& data_type() const { return data_type_; } const BufferVec& items() const { return items_; } #define SET_TYPE(Type) \ @@ -99,7 +99,7 @@ class Tuple { void encode_buffers(size_t pos, Buffer* buf) const; private: - SharedRefPtr data_type_; + TupleType::ConstPtr data_type_; BufferVec items_; private: diff --git a/src/user_type_value.cpp b/src/user_type_value.cpp index ddfad02b5..9c2f53f5c 100644 --- a/src/user_type_value.cpp +++ b/src/user_type_value.cpp @@ -31,7 +31,7 @@ CassUserType* cass_user_type_new_from_data_type(const CassDataType* data_type) { } return CassUserType::to( new cass::UserTypeValue( - cass::SharedRefPtr(data_type))); + cass::DataType::ConstPtr(data_type))); } void cass_user_type_free(CassUserType* user_type) { diff --git a/src/whitelist_dc_policy.cpp b/src/whitelist_dc_policy.cpp index 42c58f8db..f154cb384 100644 --- a/src/whitelist_dc_policy.cpp +++ b/src/whitelist_dc_policy.cpp @@ -18,7 +18,7 @@ namespace cass { -bool WhitelistDCPolicy::is_valid_host(const SharedRefPtr& host) const { +bool WhitelistDCPolicy::is_valid_host(const Host::Ptr& host) const { const std::string& host_dc = host->dc(); for (DcList::const_iterator it = dcs_.begin(), end = dcs_.end(); it != end; ++it) { diff --git a/src/whitelist_dc_policy.hpp b/src/whitelist_dc_policy.hpp index 16d244e97..863385a25 100644 --- a/src/whitelist_dc_policy.hpp +++ b/src/whitelist_dc_policy.hpp @@ -38,7 +38,7 @@ class WhitelistDCPolicy : public ListPolicy { } private: - bool is_valid_host(const SharedRefPtr& host) const; + bool is_valid_host(const Host::Ptr& host) const; DcList dcs_; diff --git a/src/whitelist_policy.cpp b/src/whitelist_policy.cpp index 998e426f1..f2ba5cf34 100644 --- a/src/whitelist_policy.cpp +++ b/src/whitelist_policy.cpp @@ -18,7 +18,7 @@ namespace cass { -bool WhitelistPolicy::is_valid_host(const SharedRefPtr& host) const { +bool WhitelistPolicy::is_valid_host(const Host::Ptr& host) const { const std::string& host_address = host->address().to_string(false); for (ContactPointList::const_iterator it = hosts_.begin(), end = hosts_.end(); diff --git a/src/whitelist_policy.hpp b/src/whitelist_policy.hpp index b95d70931..07156562c 100644 --- a/src/whitelist_policy.hpp +++ b/src/whitelist_policy.hpp @@ -38,7 +38,7 @@ class WhitelistPolicy : public ListPolicy { } private: - bool is_valid_host(const SharedRefPtr& host) const; + bool is_valid_host(const Host::Ptr& host) const; ContactPointList hosts_; From 1ce771916bcdc2947eac9be97753c77897fa1c57 Mon Sep 17 00:00:00 2001 From: Michael Penick Date: Wed, 28 Sep 2016 16:20:06 -0700 Subject: [PATCH 3/7] Improvement: Remove executing future callbacks on thread pool --- src/future.cpp | 34 ++++------------------------------ src/future.hpp | 22 ++++++++-------------- 2 files changed, 12 insertions(+), 44 deletions(-) diff --git a/src/future.cpp b/src/future.cpp index c80587738..42e934cd1 100644 --- a/src/future.cpp +++ b/src/future.cpp @@ -180,39 +180,13 @@ void Future::internal_set(ScopedMutex& lock) { is_set_ = true; uv_cond_broadcast(&cond_); if (callback_) { - if (loop_.load() == NULL) { - Callback callback = callback_; - void* data = data_; - lock.unlock(); - callback(CassFuture::to(this), data); - } else { - run_callback_on_work_thread(); - } + Callback callback = callback_; + void* data = data_; + lock.unlock(); + callback(CassFuture::to(this), data); } } -void Future::run_callback_on_work_thread() { - inc_ref(); // Keep the future alive for the callback - work_.data = this; - uv_queue_work(loop_.load(), &work_, on_work, on_after_work); -} - -void Future::on_work(uv_work_t* work) { - Future* future = static_cast(work->data); - - ScopedMutex lock(&future->mutex_); - Callback callback = future->callback_; - void* data = future->data_; - lock.unlock(); - - callback(CassFuture::to(future), data); -} - -void Future::on_after_work(uv_work_t* work, int status) { - Future* future = static_cast(work->data); - future->dec_ref(); -} - } // namespace cass diff --git a/src/future.hpp b/src/future.hpp index 742a3fcee..52e8342ff 100644 --- a/src/future.hpp +++ b/src/future.hpp @@ -55,7 +55,6 @@ class Future : public RefCounted { Future(FutureType type) : is_set_(false) , type_(type) - , loop_(NULL) , callback_(NULL) { uv_mutex_init(&mutex_); uv_cond_init(&cond_); @@ -94,18 +93,20 @@ class Future : public RefCounted { internal_set(lock); } - void set_error(CassError code, const std::string& message) { + bool set_error(CassError code, const std::string& message) { ScopedMutex lock(&mutex_); - internal_set_error(code, message, lock); - } - - void set_loop(uv_loop_t* loop) { - loop_.store(loop); + if (!is_set_) { + internal_set_error(code, message, lock); + return true; + } + return false; } bool set_callback(Callback callback, void* data); protected: + bool is_set() const { return is_set_; } + void internal_wait(ScopedMutex& lock) { while (!is_set_) { uv_cond_wait(&cond_, lock.get()); @@ -130,18 +131,11 @@ class Future : public RefCounted { uv_mutex_t mutex_; -private: - void run_callback_on_work_thread(); - static void on_work(uv_work_t* work); - static void on_after_work(uv_work_t* work, int status); - private: bool is_set_; uv_cond_t cond_; FutureType type_; ScopedPtr error_; - Atomic loop_; - uv_work_t work_; Callback callback_; void* data_; From 845cefbf6552b43c03e95e983725a7d101e05924 Mon Sep 17 00:00:00 2001 From: Michael Penick Date: Wed, 28 Sep 2016 16:20:21 -0700 Subject: [PATCH 4/7] Speculative execution --- examples/perf/perf.c | 22 +- include/cassandra.h | 63 ++++- src/batch_request.cpp | 10 +- src/cluster.cpp | 21 +- src/config.hpp | 12 + src/connection.cpp | 278 +++++++++++---------- src/connection.hpp | 14 +- src/control_connection.cpp | 23 +- src/control_connection.hpp | 15 +- src/io_worker.cpp | 65 ++--- src/io_worker.hpp | 8 +- src/pool.cpp | 150 ++++++------ src/pool.hpp | 11 +- src/prepare_request.hpp | 5 +- src/query_request.cpp | 4 +- src/request.cpp | 9 + src/request.hpp | 16 +- src/request_callback.cpp | 127 ++++++---- src/request_callback.hpp | 140 ++++++----- src/request_handler.cpp | 424 +++++++++++++++++++++------------ src/request_handler.hpp | 174 ++++++++++---- src/retry_policy.cpp | 97 ++++++-- src/retry_policy.hpp | 74 ++++-- src/schema_change_callback.cpp | 18 +- src/schema_change_callback.hpp | 4 +- src/session.cpp | 77 +++--- src/session.hpp | 5 +- src/speculative_execution.hpp | 97 ++++++++ src/statement.cpp | 13 +- src/testing.cpp | 2 +- 30 files changed, 1316 insertions(+), 662 deletions(-) create mode 100644 src/speculative_execution.hpp diff --git a/examples/perf/perf.c b/examples/perf/perf.c index b41074170..8eb8b7d3e 100644 --- a/examples/perf/perf.c +++ b/examples/perf/perf.c @@ -120,7 +120,7 @@ CassCluster* create_cluster(const char* hosts) { CassError connect_session(CassSession* session, const CassCluster* cluster) { CassError rc = CASS_OK; - CassFuture* future = cass_session_connect_keyspace(session, cluster, "examples"); + CassFuture* future = cass_session_connect(session, cluster); cass_future_wait(future); rc = cass_future_error_code(future); @@ -227,7 +227,7 @@ void run_insert_queries(void* data) { CassSession* session = (CassSession*)data; const CassPrepared* insert_prepared = NULL; - const char* insert_query = "INSERT INTO songs (id, title, album, artist, tags) VALUES (?, ?, ?, ?, ?);"; + const char* insert_query = "INSERT INTO stress.songs (id, title, album, artist, tags) VALUES (?, ?, ?, ?, ?);"; #if USE_PREPARED if (prepare_query(session, insert_query, &insert_prepared) == CASS_OK) { @@ -256,6 +256,8 @@ void select_from_perf(CassSession* session, const char* query, const CassPrepare statement = cass_statement_new(query, 0); } + cass_statement_set_is_idempotent(statement, cass_true); + futures[i] = cass_session_execute(session, statement); cass_statement_free(statement); @@ -279,7 +281,7 @@ void run_select_queries(void* data) { int i; CassSession* session = (CassSession*)data; const CassPrepared* select_prepared = NULL; - const char* select_query = "SELECT * FROM songs WHERE id = a98d21b2-1900-11e4-b97b-e5e358e71e0d"; + const char* select_query = "SELECT * FROM stress.songs WHERE id = a98d21b2-1900-11e4-b97b-e5e358e71e0d"; #if USE_PREPARED if (prepare_query(session, select_query, &select_prepared) == CASS_OK) { @@ -301,7 +303,6 @@ int main(int argc, char* argv[]) { uv_thread_t threads[NUM_THREADS]; CassCluster* cluster = NULL; CassSession* session = NULL; - CassFuture* close_future = NULL; char* hosts = "127.0.0.1"; if (argc > 1) { hosts = argv[1]; @@ -321,8 +322,15 @@ int main(int argc, char* argv[]) { return -1; } + execute_query(session, "CREATE KEYSPACE IF NOT EXISTS stress WITH " + "replication = { 'class': 'SimpleStrategy', 'replication_factor': '1'}"); + + execute_query(session, "CREATE TABLE IF NOT EXISTS stress.songs (id uuid PRIMARY KEY, " + "title text, album text, artist text, " + "tags set, data blob)"); + execute_query(session, - "INSERT INTO songs (id, title, album, artist, tags) VALUES " + "INSERT INTO stress.songs (id, title, album, artist, tags) VALUES " "(a98d21b2-1900-11e4-b97b-e5e358e71e0d, " "'La Petite Tonkinoise', 'Bye Bye Blackbird', 'Joséphine Baker', { 'jazz', '2013' });"); @@ -355,10 +363,8 @@ int main(int argc, char* argv[]) { uv_thread_join(&threads[i]); } - close_future = cass_session_close(session); - cass_future_wait(close_future); - cass_future_free(close_future); cass_cluster_free(cluster); + cass_session_free(session); cass_uuid_gen_free(uuid_gen); status_destroy(&status); diff --git a/include/cassandra.h b/include/cassandra.h index 753e3eda0..e6faee40f 100644 --- a/include/cassandra.h +++ b/include/cassandra.h @@ -1714,13 +1714,39 @@ cass_cluster_set_use_hostname_resolution(CassCluster* cluster, * @param[in] cluster * @param[in] enabled * @return CASS_OK if successful, otherwise an error occurred - * - * @see cass_cluster_set_resolve_timeout() */ CASS_EXPORT CassError cass_cluster_set_use_randomized_contact_points(CassCluster* cluster, cass_bool_t enabled); +/** + * Enable constant speculative executions with the supplied settings. + * + * @public @memberof CassCluster + * + * @param[in] cluster + * @param[in] constant_delay_ms + * @param[in] max_speculative_executions + * @return CASS_OK if successful, otherwise an error occurred + */ +CASS_EXPORT CassError +cass_cluster_set_constant_speculative_execution_policy(CassCluster* cluster, + cass_int64_t constant_delay_ms, + int max_speculative_executions); + +/** + * Disable speculative executions + * + * Default: This is the default speculative execution policy. + * + * @public @memberof CassCluster + * + * @param[in] cluster + * @return CASS_OK if successful, otherwise an error occurred + */ +CASS_EXPORT CassError +cass_cluster_set_no_speculative_execution_policy(CassCluster* cluster); + /*********************************************************************************** * * Session @@ -3938,6 +3964,22 @@ CASS_EXPORT CassError cass_statement_set_request_timeout(CassStatement* statement, cass_uint64_t timeout_ms); +/** + * Sets whether the statement is idempotent. Idempotent statements are able to be + * automatically retried after timeouts/errors and can be speculatively executed. + * + * @public @memberof CassStatement + * + * @param[in] statement + * @param[in] is_idempotent + * @return CASS_OK if successful, otherwise an error occurred. + * + * @see cass_cluster_set_constant_speculative_execution_policy() + */ +CASS_EXPORT CassError +cass_statement_set_is_idempotent(CassStatement* statement, + cass_bool_t is_idempotent); + /** * Sets the statement's retry policy. * @@ -5214,6 +5256,23 @@ CASS_EXPORT CassError cass_batch_set_request_timeout(CassBatch* batch, cass_uint64_t timeout_ms); +/** + * Sets whether the statements in a batch are idempotent. Idempotent batches + * are able to be automatically retried after timeouts/errors and can be + * speculatively executed. + * + * @public @memberof CassBatch + * + * @param[in] batch + * @param[in] is_idempotent + * @return CASS_OK if successful, otherwise an error occurred. + * + * @see cass_cluster_set_constant_speculative_execution_policy() + */ +CASS_EXPORT CassError +cass_batch_set_is_idempotent(CassBatch* batch, + cass_bool_t is_idempotent); + /** * Sets the batch's retry policy. * diff --git a/src/batch_request.cpp b/src/batch_request.cpp index 9717dd40f..750f4633c 100644 --- a/src/batch_request.cpp +++ b/src/batch_request.cpp @@ -58,6 +58,12 @@ CassError cass_batch_set_request_timeout(CassBatch *batch, return CASS_OK; } +CassError cass_batch_set_is_idempotent(CassBatch* batch, + cass_bool_t is_idempotent) { + batch->set_is_idempotent(is_idempotent == cass_true); + return CASS_OK; +} + CassError cass_batch_set_retry_policy(CassBatch* batch, CassRetryPolicy* retry_policy) { batch->set_retry_policy(retry_policy); @@ -84,7 +90,7 @@ int BatchRequest::encode(int version, RequestCallback* callback, BufferVec* bufs uint8_t flags = 0; if (version == 1) { - return ENCODE_ERROR_UNSUPPORTED_PROTOCOL; + return REQUEST_ERROR_UNSUPPORTED_PROTOCOL; } { @@ -106,7 +112,7 @@ int BatchRequest::encode(int version, RequestCallback* callback, BufferVec* bufs if (statement->has_names_for_values()) { callback->on_error(CASS_ERROR_LIB_BAD_PARAMS, "Batches cannot contain queries with named values"); - return ENCODE_ERROR_BATCH_WITH_NAMED_VALUES; + return REQUEST_ERROR_BATCH_WITH_NAMED_VALUES; } int32_t result = (*i)->encode_batch(version, bufs, callback); if (result < 0) { diff --git a/src/cluster.cpp b/src/cluster.cpp index 031407687..4b34b8fa0 100644 --- a/src/cluster.cpp +++ b/src/cluster.cpp @@ -17,9 +17,10 @@ #include "cluster.hpp" #include "dc_aware_policy.hpp" +#include "external_types.hpp" #include "logger.hpp" #include "round_robin_policy.hpp" -#include "external_types.hpp" +#include "speculative_execution.hpp" #include "utils.hpp" #include @@ -423,6 +424,24 @@ CassError cass_cluster_set_use_randomized_contact_points(CassCluster* cluster, return CASS_OK; } +CassError cass_cluster_set_constant_speculative_execution_policy(CassCluster* cluster, + cass_int64_t constant_delay_ms, + int max_speculative_executions) { + if (constant_delay_ms < 0 || max_speculative_executions < 0) { + return CASS_ERROR_LIB_BAD_PARAMS; + } + cluster->config().set_speculative_execution_policy( + new cass::ConstantSpeculativeExecutionPolicy(constant_delay_ms, + max_speculative_executions)); + return CASS_OK; +} + +CassError cass_cluster_set_no_speculative_execution_policy(CassCluster* cluster) { + cluster->config().set_speculative_execution_policy( + new cass::NoSpeculativeExecutionPolicy()); + return CASS_OK; +} + void cass_cluster_free(CassCluster* cluster) { delete cluster->from(); } diff --git a/src/config.hpp b/src/config.hpp index 5cf20c5d4..6cc016d61 100644 --- a/src/config.hpp +++ b/src/config.hpp @@ -29,6 +29,7 @@ #include "blacklist_policy.hpp" #include "whitelist_dc_policy.hpp" #include "blacklist_dc_policy.hpp" +#include "speculative_execution.hpp" #include #include @@ -64,6 +65,7 @@ class Config { , log_data_(NULL) , auth_provider_(new AuthProvider()) , load_balancing_policy_(new DCAwarePolicy()) + , speculative_execution_policy_(new NoSpeculativeExecutionPolicy()) , token_aware_routing_(true) , latency_aware_routing_(false) , tcp_nodelay_enable_(true) @@ -268,6 +270,15 @@ class Config { load_balancing_policy_.reset(lbp); } + SpeculativeExecutionPolicy* speculative_execution_policy() const { + return speculative_execution_policy_->new_instance(); + } + + void set_speculative_execution_policy(SpeculativeExecutionPolicy* sep) { + if (sep == NULL) return; + speculative_execution_policy_.reset(sep); + } + SslContext* ssl_context() const { return ssl_context_.get(); } void set_ssl_context(SslContext* ssl_context) { @@ -391,6 +402,7 @@ class Config { void* log_data_; AuthProvider::Ptr auth_provider_; LoadBalancingPolicy::Ptr load_balancing_policy_; + SharedRefPtr speculative_execution_policy_; SslContext::Ptr ssl_context_; bool token_aware_routing_; bool latency_aware_routing_; diff --git a/src/connection.cpp b/src/connection.cpp index 89954754c..c9ab928a8 100644 --- a/src/connection.cpp +++ b/src/connection.cpp @@ -61,21 +61,57 @@ namespace cass { static void cleanup_pending_callbacks(List* pending) { while (!pending->is_empty()) { - RequestCallback* callback = pending->front(); - pending->remove(callback); - if (callback->state() == RequestCallback::REQUEST_STATE_WRITING || - callback->state() == RequestCallback::REQUEST_STATE_READING) { - callback->on_timeout(); - callback->stop_timer(); + RequestCallback::Ptr callback(pending->front()); + + pending->remove(callback.get()); + + switch (callback->state()) { + case RequestCallback::REQUEST_STATE_NEW: + case RequestCallback::REQUEST_STATE_FINISHED: + assert(false && "Request state is invalid in cleanup"); + break; + + case RequestCallback::REQUEST_STATE_CANCELLED: + callback->finish(); + break; + + case RequestCallback::REQUEST_STATE_CANCELLED_READ_BEFORE_WRITE: + case RequestCallback::REQUEST_STATE_CANCELLED_WRITE_OUTSTANDING: + callback->set_state(RequestCallback::REQUEST_STATE_CANCELLED); + callback->finish(); + break; + + case RequestCallback::REQUEST_STATE_WRITING: + case RequestCallback::REQUEST_STATE_READING: + case RequestCallback::REQUEST_STATE_READ_BEFORE_WRITE: + if (callback->request()->is_idempotent()) { + callback->set_state(RequestCallback::REQUEST_STATE_FINISHED); + callback->finish_with_retry(true); + } else { + callback->set_state(RequestCallback::REQUEST_STATE_FINISHED); + callback->on_error(CASS_ERROR_LIB_REQUEST_TIMED_OUT, + "Request timed out"); + callback->finish(); + } + break; + + case RequestCallback::REQUEST_STATE_RETRY_WRITE_OUTSTANDING: + callback->set_state(RequestCallback::REQUEST_STATE_FINISHED); + callback->finish_with_retry(false); + break; } - callback->dec_ref(); } } +Connection::StartupCallback::StartupCallback(Connection* connection, const Request::ConstPtr& request) + : SimpleRequestCallback(connection->loop(), + connection->config().request_timeout_ms(), + request) { } + void Connection::StartupCallback::on_set(ResponseMessage* response) { switch (response->opcode()) { case CQL_OPCODE_SUPPORTED: - connection_->on_supported(response); + connection()->on_supported(response); break; case CQL_OPCODE_ERROR: { @@ -92,31 +128,31 @@ void Connection::StartupCallback::on_set(ResponseMessage* response) { error->message().find("does not exist") != StringRef::npos) { error_code = CONNECTION_ERROR_KEYSPACE; } - connection_->notify_error("Received error response " + error->error_message(), error_code); + connection()->notify_error("Received error response " + error->error_message(), error_code); break; } case CQL_OPCODE_AUTHENTICATE: { AuthenticateResponse* auth = static_cast(response->response_body().get()); - connection_->on_authenticate(auth->class_name()); + connection()->on_authenticate(auth->class_name()); break; } case CQL_OPCODE_AUTH_CHALLENGE: - connection_->on_auth_challenge( - static_cast(request_.get()), + connection()->on_auth_challenge( + static_cast(request()), static_cast(response->response_body().get())->token()); break; case CQL_OPCODE_AUTH_SUCCESS: - connection_->on_auth_success( - static_cast(request_.get()), + connection()->on_auth_success( + static_cast(request()), static_cast(response->response_body().get())->token()); break; case CQL_OPCODE_READY: - connection_->on_ready(); + connection()->on_ready(); break; case CQL_OPCODE_RESULT: @@ -124,7 +160,7 @@ void Connection::StartupCallback::on_set(ResponseMessage* response) { break; default: - connection_->notify_error("Invalid opcode"); + connection()->notify_error("Invalid opcode"); break; } } @@ -134,12 +170,12 @@ void Connection::StartupCallback::on_error(CassError code, std::ostringstream ss; ss << "Error: '" << message << "' (0x" << std::hex << std::uppercase << std::setw(8) << std::setfill('0') << code << ")"; - connection_->notify_error(ss.str()); + connection()->notify_error(ss.str()); } void Connection::StartupCallback::on_timeout() { - if (!connection_->is_closing()) { - connection_->notify_error("Timed out", CONNECTION_ERROR_TIMEOUT); + if (!connection()->is_closing()) { + connection()->notify_error("Timed out", CONNECTION_ERROR_TIMEOUT); } } @@ -148,36 +184,36 @@ void Connection::StartupCallback::on_result_response(ResponseMessage* response) static_cast(response->response_body().get()); switch (result->kind()) { case CASS_RESULT_KIND_SET_KEYSPACE: - connection_->on_set_keyspace(); + connection()->on_set_keyspace(); break; default: - connection_->notify_error("Invalid result response. Expected set keyspace."); + connection()->notify_error("Invalid result response. Expected set keyspace."); break; } } Connection::HeartbeatCallback::HeartbeatCallback(Connection* connection) - : RequestCallback(Request::ConstPtr(new OptionsRequest())) { - set_connection(connection); -} + : SimpleRequestCallback(connection->loop(), + connection->config().request_timeout_ms(), + Request::ConstPtr(new OptionsRequest())) { } void Connection::HeartbeatCallback::on_set(ResponseMessage* response) { LOG_TRACE("Heartbeat completed on host %s", - connection_->address_string().c_str()); - connection_->heartbeat_outstanding_ = false; + connection()->address_string().c_str()); + connection()->heartbeat_outstanding_ = false; } void Connection::HeartbeatCallback::on_error(CassError code, const std::string& message) { LOG_WARN("An error occurred on host %s during a heartbeat request: %s", - connection_->address_string().c_str(), + connection()->address_string().c_str(), message.c_str()); - connection_->heartbeat_outstanding_ = false; + connection()->heartbeat_outstanding_ = false; } void Connection::HeartbeatCallback::on_timeout() { LOG_WARN("Heartbeat request timed out on host %s", - connection_->address_string().c_str()); - connection_->heartbeat_outstanding_ = false; + connection()->address_string().c_str()); + connection()->heartbeat_outstanding_ = false; } Connection::Connection(uv_loop_t* loop, @@ -240,23 +276,26 @@ void Connection::connect() { } } -bool Connection::write(RequestCallback* callback, bool flush_immediately) { - bool result = internal_write(callback, flush_immediately); - if (result) { +bool Connection::write(const RequestCallback::Ptr& callback, bool flush_immediately) { + int32_t result = internal_write(callback, flush_immediately); + if (result > 0) { restart_heartbeat_timer(); } - return result; + return result != Request::REQUEST_ERROR_NO_AVAILABLE_STREAM_IDS; } -bool Connection::internal_write(RequestCallback* callback, bool flush_immediately) { - int stream = stream_manager_.acquire(callback); +int32_t Connection::internal_write(const RequestCallback::Ptr& callback, bool flush_immediately) { + if (callback->state() == RequestCallback::REQUEST_STATE_CANCELLED) { + return Request::REQUEST_ERROR_CANCELLED; + } + + int stream = stream_manager_.acquire(callback.get()); if (stream < 0) { - return false; + return Request::REQUEST_ERROR_NO_AVAILABLE_STREAM_IDS; } callback->inc_ref(); // Connection reference - callback->set_connection(this); - callback->set_stream(stream); + callback->start(this, stream); if (pending_writes_.is_empty() || pending_writes_.back()->is_flushed()) { if (ssl_session_) { @@ -268,12 +307,12 @@ bool Connection::internal_write(RequestCallback* callback, bool flush_immediatel PendingWriteBase *pending_write = pending_writes_.back(); - int32_t request_size = pending_write->write(callback); + int32_t request_size = pending_write->write(callback.get()); if (request_size < 0) { stream_manager_.release(stream); switch (request_size) { - case Request::ENCODE_ERROR_BATCH_WITH_NAMED_VALUES: - case Request::ENCODE_ERROR_PARAMETER_UNSET: + case Request::REQUEST_ERROR_BATCH_WITH_NAMED_VALUES: + case Request::REQUEST_ERROR_PARAMETER_UNSET: // Already handled break; @@ -282,8 +321,8 @@ bool Connection::internal_write(RequestCallback* callback, bool flush_immediatel "Operation unsupported by this protocol version"); break; } - callback->dec_ref(); - return true; // Don't retry + callback->finish(); + return request_size; } pending_writes_size_ += request_size; @@ -300,14 +339,6 @@ bool Connection::internal_write(RequestCallback* callback, bool flush_immediatel opcode_to_string(callback->request()->opcode()).c_str(), stream); callback->set_state(RequestCallback::REQUEST_STATE_WRITING); - uint64_t request_timeout_ms = callback->request_timeout_ms(config_); - if (request_timeout_ms > 0) { // 0 means no timeout - callback->start_timer(loop_, - request_timeout_ms, - callback, - Connection::on_timeout); - } - if (flush_immediately) { pending_write->flush(); } @@ -462,16 +493,18 @@ void Connection::consume(char* input, size_t size) { opcode_to_string(response->opcode())); } } else { - RequestCallback* callback = NULL; - if (stream_manager_.get_pending_and_release(response->stream(), callback)) { + RequestCallback* temp = NULL; + + if (stream_manager_.get_pending_and_release(response->stream(), temp)) { + RequestCallback::Ptr callback(temp); + switch (callback->state()) { case RequestCallback::REQUEST_STATE_READING: maybe_set_keyspace(response.get()); - pending_reads_.remove(callback); - callback->stop_timer(); - callback->set_state(RequestCallback::REQUEST_STATE_DONE); + pending_reads_.remove(callback.get()); + callback->set_state(RequestCallback::REQUEST_STATE_FINISHED); callback->on_set(response.get()); - callback->dec_ref(); + callback->finish(); break; case RequestCallback::REQUEST_STATE_WRITING: @@ -483,15 +516,14 @@ void Connection::consume(char* input, size_t size) { callback->on_set(response.get()); break; - case RequestCallback::REQUEST_STATE_TIMEOUT: - pending_reads_.remove(callback); - callback->set_state(RequestCallback::REQUEST_STATE_DONE); - callback->dec_ref(); + case RequestCallback::REQUEST_STATE_CANCELLED: + pending_reads_.remove(callback.get()); + callback->finish(); break; - case RequestCallback::REQUEST_STATE_TIMEOUT_WRITE_OUTSTANDING: + case RequestCallback::REQUEST_STATE_CANCELLED_WRITE_OUTSTANDING: // We must wait for the write callback before we can do the cleanup - callback->set_state(RequestCallback::REQUEST_STATE_READ_BEFORE_WRITE); + callback->set_state(RequestCallback::REQUEST_STATE_CANCELLED_READ_BEFORE_WRITE); break; default: @@ -725,23 +757,11 @@ void Connection::on_read_ssl(uv_stream_t* client, ssize_t nread, const uv_buf_t* } } -void Connection::on_timeout(Timer* timer) { - RequestCallback* callback = static_cast(timer->data()); - Connection* connection = callback->connection(); - LOG_INFO("Request timed out to host %s on connection(%p)", - connection->host_->address_string().c_str(), - static_cast(connection)); - // TODO (mpenick): We need to handle the case where we have too many - // timeout requests and we run out of stream ids. The java-driver - // uses a threshold to defunct the connection. - callback->set_state(RequestCallback::REQUEST_STATE_TIMEOUT); - callback->on_timeout(); - - connection->metrics_->request_timeouts.inc(); -} - void Connection::on_connected() { - internal_write(new StartupCallback(this, Request::ConstPtr(new OptionsRequest()))); + internal_write(RequestCallback::Ptr( + new StartupCallback(this, + Request::ConstPtr( + new OptionsRequest())))); } void Connection::on_authenticate(const std::string& class_name) { @@ -759,10 +779,10 @@ void Connection::on_auth_challenge(const AuthResponseRequest* request, notify_error("Failed evaluating challenge token: " + request->auth()->error(), CONNECTION_ERROR_AUTH); return; } - internal_write(new StartupCallback(this, - Request::ConstPtr( - new AuthResponseRequest(response, - request->auth())))); + internal_write(RequestCallback::Ptr( + new StartupCallback(this, + Request::ConstPtr( + new AuthResponseRequest(response, request->auth()))))); } void Connection::on_auth_success(const AuthResponseRequest* request, @@ -777,18 +797,20 @@ void Connection::on_auth_success(const AuthResponseRequest* request, void Connection::on_ready() { if (state_ == CONNECTION_STATE_CONNECTED && listener_->event_types() != 0) { set_state(CONNECTION_STATE_REGISTERING_EVENTS); - internal_write(new StartupCallback(this, - Request::ConstPtr( - new RegisterRequest(listener_->event_types())))); + internal_write(RequestCallback::Ptr( + new StartupCallback(this, + Request::ConstPtr( + new RegisterRequest(listener_->event_types()))))); return; } if (keyspace_.empty()) { notify_ready(); } else { - internal_write(new StartupCallback(this, - Request::ConstPtr( - new QueryRequest("USE \"" + keyspace_ + "\"")))); + internal_write(RequestCallback::Ptr( + new StartupCallback(this, + Request::ConstPtr( + new QueryRequest("USE \"" + keyspace_ + "\""))))); } } @@ -803,8 +825,10 @@ void Connection::on_supported(ResponseMessage* response) { // TODO(mstump) do something with the supported info (void)supported; - internal_write(new StartupCallback(this, - Request::ConstPtr(new StartupRequest()))); + internal_write(RequestCallback::Ptr( + new StartupCallback(this, + Request::ConstPtr( + new StartupRequest())))); } void Connection::on_pending_schema_agreement(Timer* timer) { @@ -871,9 +895,10 @@ void Connection::send_credentials(const std::string& class_name) { if (v1_auth) { V1Authenticator::Credentials credentials; v1_auth->get_credentials(&credentials); - internal_write(new StartupCallback(this, - Request::ConstPtr( - new CredentialsRequest(credentials)))); + internal_write(RequestCallback::Ptr( + new StartupCallback(this, + Request::ConstPtr( + new CredentialsRequest(credentials))))); } else { send_initial_auth_response(class_name); } @@ -889,9 +914,10 @@ void Connection::send_initial_auth_response(const std::string& class_name) { notify_error("Failed creating initial response token: " + auth->error(), CONNECTION_ERROR_AUTH); return; } - internal_write(new StartupCallback(this, - Request::ConstPtr( - new AuthResponseRequest(response, auth)))); + internal_write(RequestCallback::Ptr( + new StartupCallback(this, + Request::ConstPtr( + new AuthResponseRequest(response, auth))))); } } @@ -907,7 +933,7 @@ void Connection::on_heartbeat(Timer* timer) { Connection* connection = static_cast(timer->data()); if (!connection->heartbeat_outstanding_) { - if (!connection->internal_write(new HeartbeatCallback(connection))) { + if (!connection->internal_write(RequestCallback::Ptr(new HeartbeatCallback(connection)))) { // Recycling only this connection with a timeout error. This is unlikely and // it means the connection ran out of stream IDs as a result of requests // that never returned and as a result timed out. @@ -967,16 +993,23 @@ void Connection::PendingWriteBase::on_write(uv_write_t* req, int status) { Connection* connection = static_cast(pending_write->connection_); + connection->pending_writes_size_ -= pending_write->size(); + if (connection->pending_writes_size_ < + connection->config_.write_bytes_low_water_mark() && + connection->state_ == CONNECTION_STATE_OVERWHELMED) { + connection->set_state(CONNECTION_STATE_READY); + } + while (!pending_write->callbacks_.is_empty()) { - RequestCallback* callback = pending_write->callbacks_.front(); + RequestCallback::Ptr callback(pending_write->callbacks_.front()); - pending_write->callbacks_.remove(callback); + pending_write->callbacks_.remove(callback.get()); switch (callback->state()) { case RequestCallback::REQUEST_STATE_WRITING: if (status == 0) { callback->set_state(RequestCallback::REQUEST_STATE_READING); - connection->pending_reads_.add_to_back(callback); + connection->pending_reads_.add_to_back(callback.get()); } else { if (!connection->is_closing()) { connection->notify_error("Write error '" + @@ -986,32 +1019,36 @@ void Connection::PendingWriteBase::on_write(uv_write_t* req, int status) { } connection->stream_manager_.release(callback->stream()); - callback->stop_timer(); - callback->set_state(RequestCallback::REQUEST_STATE_DONE); + callback->set_state(RequestCallback::REQUEST_STATE_FINISHED); callback->on_error(CASS_ERROR_LIB_WRITE_ERROR, "Unable to write to socket"); - callback->dec_ref(); + callback->finish(); } break; - case RequestCallback::REQUEST_STATE_TIMEOUT_WRITE_OUTSTANDING: - // The read may still come back, handle cleanup there - callback->set_state(RequestCallback::REQUEST_STATE_TIMEOUT); - connection->pending_reads_.add_to_back(callback); - break; - case RequestCallback::REQUEST_STATE_READ_BEFORE_WRITE: // The read callback happened before the write callback // returned. This is now responsible for cleanup. - callback->stop_timer(); - callback->set_state(RequestCallback::REQUEST_STATE_DONE); - callback->dec_ref(); + callback->set_state(RequestCallback::REQUEST_STATE_FINISHED); + callback->finish(); break; case RequestCallback::REQUEST_STATE_RETRY_WRITE_OUTSTANDING: - callback->stop_timer(); - callback->retry(); - callback->dec_ref(); + callback->set_state(RequestCallback::REQUEST_STATE_FINISHED); + callback->finish_with_retry(false); + break; + + case RequestCallback::REQUEST_STATE_CANCELLED_READ_BEFORE_WRITE: + // The read callback happened before the write callback + // returned. This is now responsible for cleanup. + callback->set_state(RequestCallback::REQUEST_STATE_CANCELLED); + callback->finish(); + break; + + case RequestCallback::REQUEST_STATE_CANCELLED_WRITE_OUTSTANDING: + // The read may still come back, handle cleanup there + callback->set_state(RequestCallback::REQUEST_STATE_CANCELLED); + connection->pending_reads_.add_to_back(callback.get()); break; default: @@ -1020,13 +1057,6 @@ void Connection::PendingWriteBase::on_write(uv_write_t* req, int status) { } } - connection->pending_writes_size_ -= pending_write->size(); - if (connection->pending_writes_size_ < - connection->config_.write_bytes_low_water_mark() && - connection->state_ == CONNECTION_STATE_OVERWHELMED) { - connection->set_state(CONNECTION_STATE_READY); - } - connection->pending_writes_.remove(pending_write); delete pending_write; diff --git a/src/connection.hpp b/src/connection.hpp index 017340153..955d882ff 100644 --- a/src/connection.hpp +++ b/src/connection.hpp @@ -104,11 +104,12 @@ class Connection { void connect(); - bool write(RequestCallback* request, bool flush_immediately = true); + bool write(const RequestCallback::Ptr& request, bool flush_immediately = true); void flush(); void schedule_schema_agreement(const SchemaChangeCallback::Ptr& callback, uint64_t wait); + uv_loop_t* loop() { return loop_; } const Config& config() const { return config_; } Metrics* metrics() { return metrics_; } const Address& address() const { return host_->address(); } @@ -163,12 +164,9 @@ class Connection { char buf_[MAX_BUFFER_SIZE]; }; - class StartupCallback : public RequestCallback { + class StartupCallback : public SimpleRequestCallback { public: - StartupCallback(Connection* connection, const Request::ConstPtr& request) - : RequestCallback(request) { - set_connection(connection); - } + StartupCallback(Connection* connection, const Request::ConstPtr& request); virtual void on_set(ResponseMessage* response); virtual void on_error(CassError code, const std::string& message); @@ -178,7 +176,7 @@ class Connection { void on_result_response(ResponseMessage* response); }; - class HeartbeatCallback : public RequestCallback { + class HeartbeatCallback : public SimpleRequestCallback { public: HeartbeatCallback(Connection* connection); @@ -254,7 +252,7 @@ class Connection { Timer timer; }; - bool internal_write(RequestCallback* request, bool flush_immediately = true); + int32_t internal_write(const RequestCallback::Ptr& request, bool flush_immediately = true); void internal_close(ConnectionState close_state); void set_state(ConnectionState state); void consume(char* input, size_t size); diff --git a/src/control_connection.cpp b/src/control_connection.cpp index 828c6905e..6f4cadff1 100644 --- a/src/control_connection.cpp +++ b/src/control_connection.cpp @@ -657,7 +657,7 @@ void ControlConnection::refresh_node_info(Host::Ptr host, this, response_callback, data)); - if (!connection_->write(callback.get())) { + if (!connection_->write(callback)) { LOG_ERROR("No more stream available while attempting to refresh node info"); } } @@ -806,11 +806,11 @@ void ControlConnection::refresh_keyspace(const StringRef& keyspace_name) { LOG_DEBUG("Refreshing keyspace %s", query.c_str()); - connection_->write( + connection_->write(RequestCallback::Ptr( new ControlCallback(Request::ConstPtr(new QueryRequest(query)), this, ControlConnection::on_refresh_keyspace, - keyspace_name.to_string())); + keyspace_name.to_string()))); } void ControlConnection::on_refresh_keyspace(ControlConnection* control_connection, @@ -936,11 +936,12 @@ void ControlConnection::refresh_type(const StringRef& keyspace_name, LOG_DEBUG("Refreshing type %s", query.c_str()); - connection_->write( - new ControlCallback >(Request::ConstPtr(new QueryRequest(query)), - this, - ControlConnection::on_refresh_type, - std::make_pair(keyspace_name.to_string(), type_name.to_string()))); + connection_->write(RequestCallback::Ptr( + new ControlCallback >( + Request::ConstPtr(new QueryRequest(query)), + this, + ControlConnection::on_refresh_type, + std::make_pair(keyspace_name.to_string(), type_name.to_string())))); } void ControlConnection::on_refresh_type(ControlConnection* control_connection, @@ -1002,11 +1003,11 @@ void ControlConnection::refresh_function(const StringRef& keyspace_name, request->set(1, CassString(function_name.data(), function_name.size())); request->set(2, signature); - connection_->write( + connection_->write(RequestCallback::Ptr( new ControlCallback(request, this, ControlConnection::on_refresh_function, - RefreshFunctionData(keyspace_name, function_name, arg_types, is_aggregate))); + RefreshFunctionData(keyspace_name, function_name, arg_types, is_aggregate)))); } void ControlConnection::on_refresh_function(ControlConnection* control_connection, @@ -1118,4 +1119,6 @@ void ControlConnection::ControlMultipleRequestCallback::on_set( Address ControlConnection::bind_any_ipv4_("0.0.0.0", 0); Address ControlConnection::bind_any_ipv6_("::", 0); + + } // namespace cass diff --git a/src/control_connection.hpp b/src/control_connection.hpp index c41b7a301..fc23a82df 100644 --- a/src/control_connection.hpp +++ b/src/control_connection.hpp @@ -18,6 +18,7 @@ #define __CASS_CONTROL_CONNECTION_HPP_INCLUDED__ #include "address.hpp" +#include "config.hpp" #include "connection.hpp" #include "request_callback.hpp" #include "host.hpp" @@ -116,18 +117,20 @@ class ControlConnection : public Connection::Listener { struct UnusedData {}; template - class ControlCallback : public RequestCallback { + class ControlCallback : public SimpleRequestCallback { public: typedef void (*ResponseCallback)(ControlConnection*, const T&, Response*); ControlCallback(const Request::ConstPtr& request, - ControlConnection* control_connection, - ResponseCallback response_callback, - const T& data) - : RequestCallback(request) + ControlConnection* control_connection, + ResponseCallback response_callback, + const T& data) + : SimpleRequestCallback(control_connection->connection_->loop(), + control_connection->connection_->config().request_timeout_ms(), + request) , control_connection_(control_connection) , response_callback_(response_callback) - , data_(data) {} + , data_(data) { } virtual void on_set(ResponseMessage* response) { Response* response_body = response->response_body().get(); diff --git a/src/io_worker.cpp b/src/io_worker.cpp index 7e0d45e94..9a297cc8a 100644 --- a/src/io_worker.cpp +++ b/src/io_worker.cpp @@ -135,38 +135,39 @@ void IOWorker::add_pool(const Host::ConstPtr& host, bool is_initial_connection) } } -bool IOWorker::execute(RequestHandler* request_handler) { - return request_queue_.enqueue(request_handler); -} - -void IOWorker::retry(RequestHandler* request_handler) { - Address address; - if (!request_handler->get_current_host_address(&address)) { - request_handler->on_error(CASS_ERROR_LIB_NO_HOSTS_AVAILABLE, - "All hosts in current policy attempted " - "and were either unavailable or failed"); - return; +bool IOWorker::execute(const RequestHandler::Ptr& request_handler) { + request_handler->inc_ref(); // Queue reference + if (!request_queue_.enqueue(request_handler.get())) { + request_handler->dec_ref(); + return false; } + return true; +} - PoolMap::const_iterator it = pools_.find(address); - if (it != pools_.end() && it->second->is_ready()) { - const Pool::Ptr& pool = it->second; - Connection* connection = pool->borrow_connection(); - if (connection != NULL) { - if (!pool->write(connection, request_handler)) { - request_handler->next_host(); - retry(request_handler); +void IOWorker::retry(const SpeculativeExecution::Ptr& speculative_execution) { + while (speculative_execution->current_host()) { + PoolMap::const_iterator it = pools_.find(speculative_execution->current_host()->address()); + if (it != pools_.end() && it->second->is_ready()) { + const Pool::Ptr& pool = it->second; + Connection* connection = pool->borrow_connection(); + if (connection != NULL) { + if (pool->write(connection, speculative_execution)) { + return; // Success + } + } else { // Too busy, or no connections + pool->wait_for_connection(speculative_execution); + return; // Waiting for connection } - } else { // Too busy, or no connections - pool->wait_for_connection(request_handler); } - } else { - request_handler->next_host(); - retry(request_handler); + speculative_execution->next_host(); } + + speculative_execution->on_error(CASS_ERROR_LIB_NO_HOSTS_AVAILABLE, + "All hosts in current policy attempted " + "and were either unavailable or failed"); } -void IOWorker::request_finished(RequestHandler* request_handler) { +void IOWorker::request_finished() { pending_request_count_--; maybe_close(); request_queue_.send(); @@ -281,13 +282,17 @@ void IOWorker::on_execute(uv_async_t* async) { #endif IOWorker* io_worker = static_cast(async->data); - RequestHandler* request_handler = NULL; + RequestHandler* temp = NULL; size_t remaining = io_worker->config().max_requests_per_flush(); - while (remaining != 0 && io_worker->request_queue_.dequeue(request_handler)) { - if (request_handler != NULL) { + while (remaining != 0 && io_worker->request_queue_.dequeue(temp)) { + RequestHandler::Ptr request_handler(temp); + if (request_handler) { + request_handler->dec_ref(); // Queue reference io_worker->pending_request_count_++; - request_handler->set_io_worker(io_worker); - request_handler->retry(); + request_handler->start_request(io_worker); + SpeculativeExecution::Ptr speculative_execution(new SpeculativeExecution(request_handler, + request_handler->first_host())); + speculative_execution->execute(); } else { io_worker->state_ = IO_WORKER_STATE_CLOSING; } diff --git a/src/io_worker.hpp b/src/io_worker.hpp index 29aaf5a95..3d803d165 100644 --- a/src/io_worker.hpp +++ b/src/io_worker.hpp @@ -27,6 +27,7 @@ #include "logger.hpp" #include "metrics.hpp" #include "pool.hpp" +#include "request_handler.hpp" #include "spsc_queue.hpp" #include "timer.hpp" @@ -38,6 +39,7 @@ namespace cass { class Config; +class Pool; class RequestHandler; class Session; class SSLContext; @@ -104,10 +106,10 @@ class IOWorker bool remove_pool_async(const Host::ConstPtr& host, bool cancel_reconnect); void close_async(); - bool execute(RequestHandler* request_handler); + bool execute(const RequestHandler::Ptr& request_handler); - void retry(RequestHandler* request_handler); - void request_finished(RequestHandler* request_handler); + void retry(const SpeculativeExecution::Ptr& speculative_execution); + void request_finished(); void notify_pool_ready(Pool* pool); void notify_pool_closed(Pool* pool); diff --git a/src/pool.cpp b/src/pool.cpp index 99ed92634..a9bad9301 100644 --- a/src/pool.cpp +++ b/src/pool.cpp @@ -34,11 +34,11 @@ static bool least_busy_comp(Connection* a, Connection* b) { return a->pending_request_count() < b->pending_request_count(); } -class SetKeyspaceCallback : public RequestCallback { +class SetKeyspaceCallback : public SimpleRequestCallback { public: SetKeyspaceCallback(Connection* connection, const std::string& keyspace, - RequestHandler* request_handler); + const SpeculativeExecution::Ptr& speculative_execution); virtual void on_set(ResponseMessage* response); virtual void on_error(CassError code, const std::string& message); @@ -48,16 +48,16 @@ class SetKeyspaceCallback : public RequestCallback { void on_result_response(ResponseMessage* response); private: - RequestHandler::Ptr request_handler_; + SpeculativeExecution::Ptr speculative_execution_; }; SetKeyspaceCallback::SetKeyspaceCallback(Connection* connection, - const std::string& keyspace, - RequestHandler* request_handler) - : RequestCallback(Request::ConstPtr(new QueryRequest("USE \"" + keyspace + "\""))) - , request_handler_(request_handler) { - set_connection(connection); -} + const std::string& keyspace, + const SpeculativeExecution::Ptr& speculative_execution) + : SimpleRequestCallback(connection->loop(), + connection->config().request_timeout_ms(), + Request::ConstPtr(new QueryRequest("USE \"" + keyspace + "\""))) + , speculative_execution_(speculative_execution) { } void SetKeyspaceCallback::on_set(ResponseMessage* response) { switch (response->opcode()) { @@ -65,9 +65,9 @@ void SetKeyspaceCallback::on_set(ResponseMessage* response) { on_result_response(response); break; case CQL_OPCODE_ERROR: - connection_->defunct(); - request_handler_->on_error(CASS_ERROR_LIB_UNABLE_TO_SET_KEYSPACE, - "Unable to set keyspace"); + connection()->defunct(); + speculative_execution_->on_error(CASS_ERROR_LIB_UNABLE_TO_SET_KEYSPACE, + "Unable to set keyspace"); break; default: break; @@ -75,28 +75,27 @@ void SetKeyspaceCallback::on_set(ResponseMessage* response) { } void SetKeyspaceCallback::on_error(CassError code, const std::string& message) { - connection_->defunct(); - request_handler_->on_error(CASS_ERROR_LIB_UNABLE_TO_SET_KEYSPACE, - "Unable to set keyspace"); + connection()->defunct(); + speculative_execution_->on_error(CASS_ERROR_LIB_UNABLE_TO_SET_KEYSPACE, + "Unable to set keyspace"); } void SetKeyspaceCallback::on_timeout() { - // TODO(mpenick): What to do here? - request_handler_->on_timeout(); + speculative_execution_->retry_next_host(); } void SetKeyspaceCallback::on_result_response(ResponseMessage* response) { ResultResponse* result = static_cast(response->response_body().get()); if (result->kind() == CASS_RESULT_KIND_SET_KEYSPACE) { - if (!connection_->write(request_handler_.get())) { + if (!connection()->write(speculative_execution_)) { // Try on the same host but a different connection - request_handler_->retry(); + speculative_execution_->retry_current_host(); } } else { - connection_->defunct(); - request_handler_->on_error(CASS_ERROR_LIB_UNABLE_TO_SET_KEYSPACE, - "Unable to set keyspace"); + connection()->defunct(); + speculative_execution_->on_error(CASS_ERROR_LIB_UNABLE_TO_SET_KEYSPACE, + "Unable to set keyspace"); } } @@ -121,12 +120,12 @@ Pool::~Pool() { static_cast(this), static_cast(pending_requests_.size())); while (!pending_requests_.is_empty()) { - RequestHandler* request_handler - = static_cast(pending_requests_.front()); - pending_requests_.remove(request_handler); - request_handler->stop_timer(); - request_handler->next_host(); - request_handler->retry(); + SpeculativeExecution::Ptr speculative_execution( + static_cast(pending_requests_.front())); + pending_requests_.remove(speculative_execution.get()); + speculative_execution->dec_ref(); + speculative_execution->stop_pending_request(); + speculative_execution->retry_next_host(); } } @@ -211,40 +210,22 @@ Connection* Pool::borrow_connection() { } void Pool::return_connection(Connection* connection) { - if (!connection->is_ready() || pending_requests_.is_empty()) return; - RequestHandler* request_handler - = static_cast(pending_requests_.front()); - remove_pending_request(request_handler); - request_handler->stop_timer(); - if (!write(connection, request_handler)) { - request_handler->next_host(); - request_handler->retry(); - } -} + while (connection->is_ready() && !pending_requests_.is_empty()) { + SpeculativeExecution::Ptr speculative_execution( + static_cast(pending_requests_.front())); -void Pool::add_pending_request(RequestHandler* request_handler) { - pending_requests_.add_to_back(request_handler); - - if (pending_requests_.size() % 10 == 0) { - LOG_DEBUG("%u request%s pending on %s pool(%p)", - static_cast(pending_requests_.size() + 1), - pending_requests_.size() > 0 ? "s":"", - host_->address_string().c_str(), - static_cast(this)); - } + remove_pending_request(speculative_execution.get()); + speculative_execution->stop_pending_request(); - if (pending_requests_.size() > config_.pending_requests_high_water_mark()) { - LOG_WARN("Exceeded pending requests water mark (current: %u water mark: %u) for host %s", - static_cast(pending_requests_.size()), - config_.pending_requests_high_water_mark(), - host_->address_string().c_str()); - set_is_available(false); - metrics_->exceeded_pending_requests_water_mark.inc(); + if (!write(connection, speculative_execution)) { + speculative_execution->retry_next_host(); + } } } -void Pool::remove_pending_request(RequestHandler* request_handler) { - pending_requests_.remove(request_handler); +void Pool::remove_pending_request(SpeculativeExecution* speculative_execution) { + pending_requests_.remove(speculative_execution); + speculative_execution->dec_ref(); set_is_available(true); } @@ -264,10 +245,10 @@ void Pool::set_is_available(bool is_available) { } } -bool Pool::write(Connection* connection, RequestHandler* request_handler) { - request_handler->set_pool(this); +bool Pool::write(Connection* connection, const SpeculativeExecution::Ptr& speculative_execution) { + speculative_execution->set_pool(this); if (*io_worker_->keyspace() == connection->keyspace()) { - if (!connection->write(request_handler, false)) { + if (!connection->write(speculative_execution, false)) { return false; } } else { @@ -275,8 +256,9 @@ bool Pool::write(Connection* connection, RequestHandler* request_handler) { io_worker_->keyspace()->c_str(), static_cast(connection), static_cast(this)); - if (!connection->write(new SetKeyspaceCallback(connection, *io_worker_->keyspace(), - request_handler), false)) { + if (!connection->write(RequestCallback::Ptr( + new SetKeyspaceCallback(connection, *io_worker_->keyspace(), + speculative_execution)), false)) { return false; } } @@ -446,25 +428,45 @@ void Pool::on_availability_change(Connection* connection) { } void Pool::on_pending_request_timeout(Timer* timer) { - RequestHandler* request_handler = static_cast(timer->data()); - Pool* pool = request_handler->pool(); + SpeculativeExecution::Ptr speculative_execution( + static_cast(timer->data())); + Pool* pool = speculative_execution->pool(); pool->metrics_->pending_request_timeouts.inc(); - pool->remove_pending_request(request_handler); - request_handler->next_host(); - request_handler->retry(); + pool->remove_pending_request(speculative_execution.get()); + speculative_execution->retry_next_host(); LOG_DEBUG("Timeout waiting for connection to %s pool(%p)", pool->host_->address_string().c_str(), static_cast(pool)); pool->maybe_close(); } -void Pool::wait_for_connection(RequestHandler* request_handler) { - request_handler->set_pool(this); - request_handler->start_timer(loop_, - config_.connect_timeout_ms(), - request_handler, - Pool::on_pending_request_timeout); - add_pending_request(request_handler); +void Pool::wait_for_connection(const SpeculativeExecution::Ptr& speculative_execution) { + if (speculative_execution->state() == RequestCallback::REQUEST_STATE_CANCELLED) { + return; + } + + speculative_execution->inc_ref(); + pending_requests_.add_to_back(speculative_execution.get()); + + speculative_execution->start_pending_request(this, + Pool::on_pending_request_timeout); + + if (pending_requests_.size() % 10 == 0) { + LOG_DEBUG("%u request%s pending on %s pool(%p)", + static_cast(pending_requests_.size() + 1), + pending_requests_.size() > 0 ? "s":"", + host_->address_string().c_str(), + static_cast(this)); + } + + if (pending_requests_.size() > config_.pending_requests_high_water_mark()) { + LOG_WARN("Exceeded pending requests water mark (current: %u water mark: %u) for host %s", + static_cast(pending_requests_.size()), + config_.pending_requests_high_water_mark(), + host_->address_string().c_str()); + set_is_available(false); + metrics_->exceeded_pending_requests_water_mark.inc(); + } } void Pool::on_partial_reconnect(Timer* timer) { diff --git a/src/pool.hpp b/src/pool.hpp index 25f9a42af..48ec90c11 100644 --- a/src/pool.hpp +++ b/src/pool.hpp @@ -24,13 +24,13 @@ #include "ref_counted.hpp" #include "request.hpp" #include "request_callback.hpp" +#include "request_handler.hpp" #include "scoped_ptr.hpp" #include "timer.hpp" namespace cass { class IOWorker; -class RequestHandler; class Config; class Pool : public RefCounted @@ -56,13 +56,15 @@ class Pool : public RefCounted void delayed_connect(); void close(bool cancel_reconnect = false); - bool write(Connection* connection, RequestHandler* request_handler); + bool write(Connection* connection, const SpeculativeExecution::Ptr& speculative_execution); void flush(); - void wait_for_connection(RequestHandler* request_handler); + void wait_for_connection(const SpeculativeExecution::Ptr& speculative_execution); Connection* borrow_connection(); const Host::ConstPtr& host() const { return host_; } + uv_loop_t* loop() { return loop_; } + const Config& config() const { return config_; } bool is_initial_connection() const { return is_initial_connection_; } bool is_ready() const { return state_ == POOL_STATE_READY; } @@ -81,8 +83,7 @@ class Pool : public RefCounted void return_connection(Connection* connection); private: - void add_pending_request(RequestHandler* request_handler); - void remove_pending_request(RequestHandler* request_handler); + void remove_pending_request(SpeculativeExecution* speculative_execution); void set_is_available(bool is_available); void maybe_notify_ready(); diff --git a/src/prepare_request.hpp b/src/prepare_request.hpp index 3141ad7ed..695ed9310 100644 --- a/src/prepare_request.hpp +++ b/src/prepare_request.hpp @@ -26,8 +26,9 @@ namespace cass { class PrepareRequest : public Request { public: - PrepareRequest() - : Request(CQL_OPCODE_PREPARE) {} + PrepareRequest(const std::string& query) + : Request(CQL_OPCODE_PREPARE) + , query_(query) { } const std::string& query() const { return query_; } diff --git a/src/query_request.cpp b/src/query_request.cpp index 1514eb780..76c4ff5de 100644 --- a/src/query_request.cpp +++ b/src/query_request.cpp @@ -40,7 +40,7 @@ int32_t QueryRequest::encode_batch(int version, BufferVec* bufs, RequestCallback if (has_names_for_values()) { if (version < 3) { LOG_ERROR("Protocol version %d does not support named values", version); - return ENCODE_ERROR_UNSUPPORTED_PROTOCOL; + return REQUEST_ERROR_UNSUPPORTED_PROTOCOL; } buf.encode_uint16(pos, value_names_.size()); length += copy_buffers_with_names(version, bufs, callback->encoding_cache()); @@ -155,7 +155,7 @@ int QueryRequest::internal_encode(int version, RequestCallback* callback, Buffer if (has_names_for_values()) { if (version < 3) { LOG_ERROR("Protocol version %d does not support named values", version); - return ENCODE_ERROR_UNSUPPORTED_PROTOCOL; + return REQUEST_ERROR_UNSUPPORTED_PROTOCOL; } buf.encode_uint16(pos, value_names_.size()); length += copy_buffers_with_names(version, bufs, callback->encoding_cache()); diff --git a/src/request.cpp b/src/request.cpp index 051015be4..78beed2c6 100644 --- a/src/request.cpp +++ b/src/request.cpp @@ -16,6 +16,7 @@ #include "request.hpp" +#include "config.hpp" #include "external_types.hpp" extern "C" { @@ -83,4 +84,12 @@ int32_t CustomPayload::encode(BufferVec* bufs) const { return length; } +uint64_t Request::request_timeout_ms(uint64_t default_timeout_ms) const { + uint64_t request_timeout_ms = request_timeout_ms_; + if (request_timeout_ms == CASS_UINT64_MAX) { + return default_timeout_ms; + } + return request_timeout_ms; +} + } // namespace cass diff --git a/src/request.hpp b/src/request.hpp index efb574059..4ae0fc1d3 100644 --- a/src/request.hpp +++ b/src/request.hpp @@ -59,9 +59,11 @@ class Request : public RefCounted { typedef SharedRefPtr ConstPtr; enum { - ENCODE_ERROR_UNSUPPORTED_PROTOCOL = -1, - ENCODE_ERROR_BATCH_WITH_NAMED_VALUES = -2, - ENCODE_ERROR_PARAMETER_UNSET = -3 + REQUEST_ERROR_UNSUPPORTED_PROTOCOL = -1, + REQUEST_ERROR_BATCH_WITH_NAMED_VALUES = -2, + REQUEST_ERROR_PARAMETER_UNSET = -3, + REQUEST_ERROR_NO_AVAILABLE_STREAM_IDS = -4, + REQUEST_ERROR_CANCELLED = -5 }; static const CassConsistency DEFAULT_CONSISTENCY = CASS_CONSISTENCY_LOCAL_ONE; @@ -73,6 +75,7 @@ class Request : public RefCounted { , consistency_(DEFAULT_CONSISTENCY) , serial_consistency_(CASS_CONSISTENCY_ANY) , timestamp_(CASS_INT64_MIN) + , is_idempotent_(false) , request_timeout_ms_(CASS_UINT64_MAX) { } // Disabled (use the cluster-level timeout) virtual ~Request() { } @@ -93,7 +96,11 @@ class Request : public RefCounted { void set_timestamp(int64_t timestamp) { timestamp_ = timestamp; } - uint64_t request_timeout_ms() const { return request_timeout_ms_; } + bool is_idempotent() const { return is_idempotent_; } + + void set_is_idempotent(bool is_idempotent) { is_idempotent_ = is_idempotent; } + + uint64_t request_timeout_ms(uint64_t default_timeout_ms) const; void set_request_timeout_ms(uint64_t request_timeout_ms) { request_timeout_ms_ = request_timeout_ms; @@ -122,6 +129,7 @@ class Request : public RefCounted { CassConsistency consistency_; CassConsistency serial_consistency_; int64_t timestamp_; + bool is_idempotent_; uint64_t request_timeout_ms_; RetryPolicy::Ptr retry_policy_; CustomPayload::ConstPtr custom_payload_; diff --git a/src/request_callback.cpp b/src/request_callback.cpp index b21fdafd3..f73b6419a 100644 --- a/src/request_callback.cpp +++ b/src/request_callback.cpp @@ -29,7 +29,7 @@ namespace cass { int32_t RequestCallback::encode(int version, int flags, BufferVec* bufs) { if (version < 1 || version > 4) { - return Request::ENCODE_ERROR_UNSUPPORTED_PROTOCOL; + return Request::REQUEST_ERROR_UNSUPPORTED_PROTOCOL; } size_t index = bufs->size(); @@ -71,11 +71,9 @@ int32_t RequestCallback::encode(int version, int flags, BufferVec* bufs) { void RequestCallback::set_state(RequestCallback::State next_state) { switch (state_) { case REQUEST_STATE_NEW: - if (next_state == REQUEST_STATE_NEW) { - state_ = next_state; - stream_ = -1; - } else if (next_state == REQUEST_STATE_WRITING) { - start_time_ns_ = uv_hrtime(); + if (next_state == REQUEST_STATE_NEW || + next_state == REQUEST_STATE_CANCELLED || + next_state == REQUEST_STATE_WRITING) { state_ = next_state; } else { assert(false && "Invalid request state after new"); @@ -83,78 +81,89 @@ void RequestCallback::set_state(RequestCallback::State next_state) { break; case REQUEST_STATE_WRITING: - if (next_state == REQUEST_STATE_READING) { // Success - state_ = next_state; - } else if (next_state == REQUEST_STATE_READ_BEFORE_WRITE || - next_state == REQUEST_STATE_DONE) { - stop_timer(); + if(next_state == REQUEST_STATE_READING || + next_state == REQUEST_STATE_READ_BEFORE_WRITE || + next_state == REQUEST_STATE_FINISHED) { state_ = next_state; - } else if (next_state == REQUEST_STATE_TIMEOUT) { - state_ = REQUEST_STATE_TIMEOUT_WRITE_OUTSTANDING; + } else if (next_state == REQUEST_STATE_CANCELLED) { + state_ = REQUEST_STATE_CANCELLED_WRITE_OUTSTANDING; } else { assert(false && "Invalid request state after writing"); } break; case REQUEST_STATE_READING: - if (next_state == REQUEST_STATE_DONE) { // Success - stop_timer(); - state_ = next_state; - } else if (next_state == REQUEST_STATE_TIMEOUT) { + if(next_state == REQUEST_STATE_FINISHED || + next_state == REQUEST_STATE_CANCELLED) { state_ = next_state; } else { assert(false && "Invalid request state after reading"); } break; - case REQUEST_STATE_TIMEOUT: - assert(next_state == REQUEST_STATE_DONE && - "Invalid request state after timeout"); - state_ = next_state; + case REQUEST_STATE_READ_BEFORE_WRITE: + if (next_state == REQUEST_STATE_RETRY_WRITE_OUTSTANDING || + next_state == REQUEST_STATE_FINISHED) { + state_ = next_state; + } else if (next_state == REQUEST_STATE_CANCELLED) { + state_ = REQUEST_STATE_CANCELLED_READ_BEFORE_WRITE; + } else { + assert(false && "Invalid request state after read before write"); + } break; - case REQUEST_STATE_TIMEOUT_WRITE_OUTSTANDING: - assert((next_state == REQUEST_STATE_TIMEOUT || - next_state == REQUEST_STATE_READ_BEFORE_WRITE) && - "Invalid request state after timeout (write outstanding)"); - state_ = next_state; + case REQUEST_STATE_RETRY_WRITE_OUTSTANDING: + if (next_state == REQUEST_STATE_FINISHED) { + state_ = next_state; + } else if (next_state == REQUEST_STATE_CANCELLED) { + state_ = REQUEST_STATE_CANCELLED_WRITE_OUTSTANDING; + } else { + assert(false && "Invalid request state after retry"); + } break; - case REQUEST_STATE_READ_BEFORE_WRITE: - assert((next_state == REQUEST_STATE_DONE || - next_state == REQUEST_STATE_RETRY_WRITE_OUTSTANDING) && - "Invalid request state after read before write"); - state_ = next_state; + case REQUEST_STATE_FINISHED: + if (next_state == REQUEST_STATE_NEW || + next_state == REQUEST_STATE_CANCELLED) { + state_ = next_state; + } else { + assert(false && "Invalid request state after finished"); + } break; - case REQUEST_STATE_RETRY_WRITE_OUTSTANDING: - assert(next_state == REQUEST_STATE_NEW && "Invalid request state after retry"); - state_ = next_state; + case REQUEST_STATE_CANCELLED: + assert((next_state == REQUEST_STATE_FINISHED && + next_state == REQUEST_STATE_CANCELLED) || + "Invalid request state after cancelled"); + // Ignore. Leave the request in the cancelled state. break; - case REQUEST_STATE_DONE: - assert(next_state == REQUEST_STATE_NEW && "Invalid request state after done"); - state_ = next_state; + + case REQUEST_STATE_CANCELLED_READ_BEFORE_WRITE: + if (next_state == REQUEST_STATE_CANCELLED) { + state_ = next_state; + } else { + assert(false && "Invalid request state after cancelled (read before write)"); + } + + case REQUEST_STATE_CANCELLED_WRITE_OUTSTANDING: + if (next_state == REQUEST_STATE_CANCELLED || + next_state == REQUEST_STATE_CANCELLED_READ_BEFORE_WRITE) { + state_ = next_state; + } else { + assert(false && "Invalid request state after cancelled (write outstanding)"); + } break; default: assert(false && "Invalid request state"); break; } - -} - -uint64_t RequestCallback::request_timeout_ms(const Config& config) const { - uint64_t request_timeout_ms = request_->request_timeout_ms(); - if (request_timeout_ms == CASS_UINT64_MAX) { - return config.request_timeout_ms(); - } - return request_timeout_ms; } bool MultipleRequestCallback::get_result_response(const ResponseMap& responses, - const std::string& index, - ResultResponse** response) { + const std::string& index, + ResultResponse** response) { ResponseMap::const_iterator it = responses.find(index); if (it == responses.end() || it->second->opcode() != CQL_OPCODE_RESULT) { return false; @@ -170,11 +179,20 @@ void MultipleRequestCallback::execute_query(const std::string& index, const std: new InternalCallback(Ptr(this), Request::ConstPtr(new QueryRequest(query)), index)); remaining_++; - if (!connection_->write(callback.get())) { + if (!connection_->write(callback)) { on_error(CASS_ERROR_LIB_NO_STREAMS, "No more streams available"); } } +MultipleRequestCallback::InternalCallback::InternalCallback(const MultipleRequestCallback::Ptr& parent, + const Request::ConstPtr& request, + const std::string& index) + : SimpleRequestCallback(parent->connection()->loop(), + parent->connection()->config().request_timeout_ms(), + request) + , parent_(parent) + , index_(index) { } + void MultipleRequestCallback::InternalCallback::on_set(ResponseMessage* response) { parent_->responses_[index_] = response->response_body(); if (--parent_->remaining_ == 0 && !parent_->has_errors_or_timeouts_) { @@ -196,4 +214,15 @@ void MultipleRequestCallback::InternalCallback::on_timeout() { parent_->has_errors_or_timeouts_ = true; } +SimpleRequestCallback::SimpleRequestCallback(uv_loop_t* loop, + uint64_t request_timeout_ms, + const Request::ConstPtr& request) + : RequestCallback() + , request_(request) { + timer_.start(loop, + request->request_timeout_ms(request_timeout_ms), + this, + on_timeout); +} + } // namespace cass diff --git a/src/request_callback.hpp b/src/request_callback.hpp index 82dd7b2ae..f78a6b497 100644 --- a/src/request_callback.hpp +++ b/src/request_callback.hpp @@ -41,98 +41,125 @@ typedef std::vector UvBufVec; class RequestCallback : public RefCounted, public List::Node { public: + typedef SharedRefPtr Ptr; + enum State { REQUEST_STATE_NEW, REQUEST_STATE_WRITING, REQUEST_STATE_READING, - REQUEST_STATE_TIMEOUT, - REQUEST_STATE_TIMEOUT_WRITE_OUTSTANDING, REQUEST_STATE_READ_BEFORE_WRITE, REQUEST_STATE_RETRY_WRITE_OUTSTANDING, - REQUEST_STATE_DONE + REQUEST_STATE_FINISHED, + REQUEST_STATE_CANCELLED, + REQUEST_STATE_CANCELLED_READ_BEFORE_WRITE, + REQUEST_STATE_CANCELLED_WRITE_OUTSTANDING }; - RequestCallback(const Request::ConstPtr& request) - : request_(request) - , connection_(NULL) + RequestCallback() + : connection_(NULL) , stream_(-1) , state_(REQUEST_STATE_NEW) - , cl_(CASS_CONSISTENCY_UNKNOWN) - , timestamp_(CASS_INT64_MIN) - , start_time_ns_(0) { } + , cl_(CASS_CONSISTENCY_UNKNOWN) { } + + virtual ~RequestCallback() { } - virtual ~RequestCallback() {} + void start(Connection* connection, int stream) { + connection_ = connection; + stream_ = stream; + on_start(); + } + + void finish() { + on_finish(); + connection_ = NULL; + stream_ = -1; + dec_ref(); + } + + void finish_with_retry(bool use_next_host) { + on_finish_with_retry(use_next_host); + connection_ = NULL; + stream_ = -1; + dec_ref(); + } int32_t encode(int version, int flags, BufferVec* bufs); virtual void on_set(ResponseMessage* response) = 0; virtual void on_error(CassError code, const std::string& message) = 0; - virtual void on_timeout() = 0; - virtual void retry() { } + virtual const Request* request() const = 0; + virtual Request::EncodingCache* encoding_cache() = 0; - const Request* request() const { return request_.get(); } + virtual int64_t timestamp() const { return request()->timestamp(); } Connection* connection() const { return connection_; } - void set_connection(Connection* connection) { - connection_ = connection; - } - int stream() const { return stream_; } - void set_stream(int stream) { - stream_ = stream; - } - State state() const { return state_; } - void set_state(State next_state); - void start_timer(uv_loop_t* loop, uint64_t timeout, void* data, - Timer::Callback cb) { - timer_.start(loop, timeout, data, cb); - } - - void stop_timer() { - timer_.stop(); - } - CassConsistency consistency() const { return cl_ != CASS_CONSISTENCY_UNKNOWN ? cl_ : request()->consistency(); } void set_consistency(CassConsistency cl) { cl_ = cl; } - int64_t timestamp() const { - return timestamp_; +protected: + // Called right before a request is written to a host. + virtual void on_start() = 0; + + // One of theses methods will always be called when a connection is finished + // with a request regardless of the outcome. + virtual void on_finish() = 0; + virtual void on_finish_with_retry(bool use_next_host) = 0; + +private: + Connection* connection_; + int stream_; + State state_; + CassConsistency cl_; + +private: + DISALLOW_COPY_AND_ASSIGN(RequestCallback); +}; + +class SimpleRequestCallback : public RequestCallback { +public: + SimpleRequestCallback(uv_loop_t* loop, + uint64_t request_timeout_ms, + const Request::ConstPtr& request); + + virtual void on_start() { + // Ignore } - void set_timestamp(int64_t timestamp) { - timestamp_ = timestamp; + virtual void on_finish() { + timer_.stop(); } - uint64_t request_timeout_ms(const Config& config) const; + virtual void on_finish_with_retry(bool use_next_host) { + timer_.stop(); + on_timeout(); + } - uint64_t start_time_ns() const { return start_time_ns_; } + virtual const Request* request() const { return request_.get(); } + virtual Request::EncodingCache* encoding_cache() { return &encoding_cache_; } - Request::EncodingCache* encoding_cache() { return &encoding_cache_; } + virtual void on_timeout() = 0; -protected: - SharedRefPtr request_; - Connection* connection_; +private: + static void on_timeout(Timer* timer) { + SimpleRequestCallback* callback = static_cast(timer->data()); + callback->set_state(RequestCallback::REQUEST_STATE_CANCELLED); + callback->on_timeout(); + } private: Timer timer_; - int stream_; - State state_; - CassConsistency cl_; - int64_t timestamp_; - uint64_t start_time_ns_; + const Request::ConstPtr request_; Request::EncodingCache encoding_cache_; - -private: - DISALLOW_COPY_AND_ASSIGN(RequestCallback); }; class MultipleRequestCallback : public RefCounted { @@ -157,24 +184,21 @@ class MultipleRequestCallback : public RefCounted { virtual void on_error(CassError code, const std::string& message) = 0; virtual void on_timeout() = 0; - Connection* connection() { - return connection_; - } + Connection* connection() { return connection_; } private: - class InternalCallback : public RequestCallback { + class InternalCallback : public SimpleRequestCallback { public: - InternalCallback(const Ptr& parent, const Request::ConstPtr& request, const std::string& index) - : RequestCallback(request) - , parent_(parent) - , index_(index) { } + InternalCallback(const MultipleRequestCallback::Ptr& parent, + const Request::ConstPtr& request, + const std::string& index); virtual void on_set(ResponseMessage* response); virtual void on_error(CassError code, const std::string& message); virtual void on_timeout(); private: - Ptr parent_; + MultipleRequestCallback::Ptr parent_; std::string index_; }; diff --git a/src/request_handler.cpp b/src/request_handler.cpp index 1345bbad8..b431c780e 100644 --- a/src/request_handler.cpp +++ b/src/request_handler.cpp @@ -34,58 +34,35 @@ namespace cass { -class PrepareCallback : public RequestCallback { +class PrepareCallback : public SimpleRequestCallback { public: - PrepareCallback(const RequestHandler::Ptr& request_handler) - : RequestCallback(Request::ConstPtr()) - , request_handler_(request_handler) {} - - bool init(const std::string& prepared_id); + PrepareCallback(const std::string& query, SpeculativeExecution* speculative_execution) + : SimpleRequestCallback(speculative_execution->connection()->loop(), + speculative_execution->connection()->config().request_timeout_ms(), + Request::ConstPtr(new PrepareRequest(query))) + , speculative_execution_(speculative_execution) { } virtual void on_set(ResponseMessage* response); virtual void on_error(CassError code, const std::string& message); virtual void on_timeout(); private: - RequestHandler::Ptr request_handler_; + SpeculativeExecution::Ptr speculative_execution_; }; -bool PrepareCallback::init(const std::string& prepared_id) { - PrepareRequest* prepare = - static_cast(new PrepareRequest()); - request_.reset(prepare); - if (request_handler_->request()->opcode() == CQL_OPCODE_EXECUTE) { - const ExecuteRequest* execute = static_cast( - request_handler_->request()); - prepare->set_query(execute->prepared()->statement()); - return true; - } else if (request_handler_->request()->opcode() == CQL_OPCODE_BATCH) { - const BatchRequest* batch = static_cast( - request_handler_->request()); - std::string prepared_statement; - if (batch->prepared_statement(prepared_id, &prepared_statement)) { - prepare->set_query(prepared_statement); - return true; - } - } - return false; // Invalid request type -} - void PrepareCallback::on_set(ResponseMessage* response) { switch (response->opcode()) { case CQL_OPCODE_RESULT: { ResultResponse* result = static_cast(response->response_body().get()); if (result->kind() == CASS_RESULT_KIND_PREPARED) { - request_handler_->retry(); + speculative_execution_->retry_current_host(); } else { - request_handler_->next_host(); - request_handler_->retry(); + speculative_execution_->retry_next_host(); } } break; case CQL_OPCODE_ERROR: - request_handler_->next_host(); - request_handler_->retry(); + speculative_execution_->retry_next_host(); break; default: break; @@ -93,18 +70,113 @@ void PrepareCallback::on_set(ResponseMessage* response) { } void PrepareCallback::on_error(CassError code, const std::string& message) { - request_handler_->next_host(); - request_handler_->retry(); + speculative_execution_->retry_next_host(); } void PrepareCallback::on_timeout() { - request_handler_->next_host(); - request_handler_->retry(); + speculative_execution_->retry_next_host(); +} + +void RequestHandler::add_execution(SpeculativeExecution* speculative_execution) { + running_executions_++; + speculative_execution->inc_ref(); + speculative_executions_.push_back(speculative_execution); } -void RequestHandler::on_set(ResponseMessage* response) { - assert(connection_ != NULL); - assert(!is_query_plan_exhausted_ && "Tried to set on a non-existent host"); +void RequestHandler::schedule_next_execution(const Host::Ptr& current_host) { + int64_t timeout = execution_plan_->next_execution(current_host); + if (timeout >= 0) { + SpeculativeExecution::Ptr speculative_execution( + new SpeculativeExecution(RequestHandler::Ptr(this))); + speculative_execution->schedule_next(timeout); + } +} + +void RequestHandler::start_request(IOWorker* io_worker) { + io_worker_ = io_worker; + timer_.start(io_worker->loop(), + request_->request_timeout_ms(io_worker->config().request_timeout_ms()), + this, + on_timeout); +} + +void RequestHandler::set_response(const Host::Ptr& host, + const Response::Ptr& response) { + if (future_->set_response(host->address(), response)) { + stop_request(); + } +} + +void RequestHandler::set_error(CassError code, + const std::string& message) { + if (future_->set_error(code, message)) { + stop_request(); + } +} + +void RequestHandler::set_error(const Host::Ptr& host, + CassError code, const std::string& message) { + bool skip = (code == CASS_ERROR_LIB_NO_HOSTS_AVAILABLE && --running_executions_ > 0); + if (!skip) { + if (host) { + if (future_->set_error_with_address(host->address(), code, message)) { + stop_request(); + } + } else { + set_error(code, message); + } + } +} + +void RequestHandler::set_error_with_error_response(const Host::Ptr& host, + const Response::Ptr& error, + CassError code, const std::string& message) { + if (future_->set_error_with_response(host->address(), error, code, message)) { + stop_request(); + } +} + +void RequestHandler::on_timeout(Timer* timer) { + RequestHandler* request_handler = + static_cast(timer->data()); + request_handler->set_error(CASS_ERROR_LIB_REQUEST_TIMED_OUT, + "Request timed out"); +} + +void RequestHandler::stop_request() { + timer_.stop(); + for (SpeculativeExecutionVec::const_iterator i = speculative_executions_.begin(), + end = speculative_executions_.end(); i != end; ++i) { + SpeculativeExecution* speculative_execution = *i; + speculative_execution->cancel(); + speculative_execution->dec_ref(); + } + if (io_worker_ != NULL) { + io_worker_->request_finished(); + io_worker_ = NULL; + } +} + +SpeculativeExecution::SpeculativeExecution(const RequestHandler::Ptr& request_handler, + const Host::Ptr& current_host) + : RequestCallback() + , request_handler_(request_handler) + , current_host_(current_host) + , pool_(NULL) + , num_retries_(0) + , start_time_ns_(0) { + request_handler_->add_execution(this); +} + +void SpeculativeExecution::on_execute(Timer* timer) { + SpeculativeExecution* speculative_execution = static_cast(timer->data()); + speculative_execution->next_host(); + speculative_execution->execute(); +} + +void SpeculativeExecution::on_set(ResponseMessage* response) { + assert(connection() != NULL); + assert(current_host_ && "Tried to set on a non-existent host"); switch (response->opcode()) { case CQL_OPCODE_RESULT: on_result_response(response); @@ -113,102 +185,96 @@ void RequestHandler::on_set(ResponseMessage* response) { on_error_response(response); break; default: - connection_->defunct(); + connection()->defunct(); set_error(CASS_ERROR_LIB_UNEXPECTED_RESPONSE, "Unexpected response"); break; } } -void RequestHandler::on_error(CassError code, const std::string& message) { +void SpeculativeExecution::on_error(CassError code, const std::string& message) { + // Handle recoverable errors by retrying with the next host if (code == CASS_ERROR_LIB_WRITE_ERROR || code == CASS_ERROR_LIB_UNABLE_TO_SET_KEYSPACE) { - next_host(); - retry(); - return_connection(); + retry_next_host(); } else { set_error(code, message); } } -void RequestHandler::on_timeout() { - assert(!is_query_plan_exhausted_ && "Tried to timeout on a non-existent host"); - set_error(CASS_ERROR_LIB_REQUEST_TIMED_OUT, "Request timed out"); +void SpeculativeExecution::on_start() { + start_time_ns_ = uv_hrtime(); } -void RequestHandler::set_io_worker(IOWorker* io_worker) { - io_worker_ = io_worker; +void SpeculativeExecution::on_finish() { + return_connection(); } -void RequestHandler::retry() { - // Reset the request so it can be executed again - set_state(REQUEST_STATE_NEW); - pool_ = NULL; - io_worker_->retry(this); -} +void SpeculativeExecution::on_finish_with_retry(bool use_next_host) { + return_connection(); -bool RequestHandler::get_current_host_address(Address* address) { - if (is_query_plan_exhausted_) { - return false; + if (use_next_host) { + retry_next_host(); + } else { + retry_current_host(); } - *address = current_host_->address(); - return true; -} - -void RequestHandler::next_host() { - current_host_ = query_plan_->compute_next(); - is_query_plan_exhausted_ = !current_host_; } -bool RequestHandler::is_host_up(const Address& address) const { - return io_worker_->is_host_up(address); +void SpeculativeExecution::start_pending_request(Pool* pool, Timer::Callback cb) { + pool_ = pool; + pending_request_timer_.start(pool->loop(), pool->config().connect_timeout_ms(), this, cb); } -void RequestHandler::set_response(const Response::Ptr& response) { - uint64_t elapsed = uv_hrtime() - start_time_ns(); - current_host_->update_latency(elapsed); - connection_->metrics()->record_request(elapsed); - future_->set_response(current_host_->address(), response); - return_connection_and_finish(); +void SpeculativeExecution::stop_pending_request() { + pending_request_timer_.stop(); } -void RequestHandler::set_error(CassError code, const std::string& message) { - if (is_query_plan_exhausted_) { - future_->set_error(code, message); - } else { - future_->set_error_with_host_address(current_host_->address(), code, message); +void SpeculativeExecution::retry_current_host() { + if (state() == REQUEST_STATE_CANCELLED) { + return; } - return_connection_and_finish(); + + // Reset the request so it can be executed again + set_state(REQUEST_STATE_NEW); + pool_ = NULL; + request_handler_->io_worker()->retry(RequestCallback::Ptr(this)); } -void RequestHandler::set_error_with_error_response(const Response::Ptr& error, - CassError code, const std::string& message) { - future_->set_error_with_response(current_host_->address(), error, code, message); - return_connection_and_finish(); +void SpeculativeExecution::retry_next_host() { + next_host(); + retry_current_host(); } -void RequestHandler::return_connection() { - if (pool_ != NULL && connection_ != NULL) { - pool_->return_connection(connection_); +void SpeculativeExecution::execute() { + if (request()->is_idempotent()) { + request_handler_->schedule_next_execution(current_host_); } + request_handler_->io_worker()->retry(RequestCallback::Ptr(this)); + } -void RequestHandler::return_connection_and_finish() { - return_connection(); - if (io_worker_ != NULL) { - io_worker_->request_finished(this); +void SpeculativeExecution::schedule_next(int64_t timeout) { + if (timeout > 0) { + schedule_timer_.start(request_handler_->io_worker()->loop(), timeout, this, on_execute); + } else { + next_host(); + execute(); } - dec_ref(); } -void RequestHandler::on_result_response(ResponseMessage* response) { +void SpeculativeExecution::cancel() { + schedule_timer_.stop(); + set_state(REQUEST_STATE_CANCELLED); +} + +void SpeculativeExecution::on_result_response(ResponseMessage* response) { ResultResponse* result = static_cast(response->response_body().get()); switch (result->kind()) { case CASS_RESULT_KIND_ROWS: // Execute statements with no metadata get their metadata from // result_metadata() returned when the statement was prepared. - if (request_->opcode() == CQL_OPCODE_EXECUTE && result->no_metadata()) { - const ExecuteRequest* execute = static_cast(request_.get()); + if (request()->opcode() == CQL_OPCODE_EXECUTE && result->no_metadata()) { + const ExecuteRequest* execute = static_cast(request()); if (!execute->skip_metadata()) { // Caused by a race condition in C* 2.1.0 on_error(CASS_ERROR_LIB_UNEXPECTED_RESPONSE, "Expected metadata but no metadata in response (see CASSANDRA-8054)"); @@ -221,7 +287,7 @@ void RequestHandler::on_result_response(ResponseMessage* response) { case CASS_RESULT_KIND_SCHEMA_CHANGE: { SchemaChangeCallback::Ptr schema_change_handler( - new SchemaChangeCallback(connection_, + new SchemaChangeCallback(connection(), Ptr(this), response->response_body())); schema_change_handler->execute(); @@ -229,7 +295,7 @@ void RequestHandler::on_result_response(ResponseMessage* response) { } case CASS_RESULT_KIND_SET_KEYSPACE: - io_worker_->broadcast_keyspace_change(result->keyspace().to_string()); + request_handler_->io_worker()->broadcast_keyspace_change(result->keyspace().to_string()); set_response(response->response_body()); break; @@ -239,70 +305,82 @@ void RequestHandler::on_result_response(ResponseMessage* response) { } } -void RequestHandler::on_error_response(ResponseMessage* response) { +void SpeculativeExecution::on_error_response(ResponseMessage* response) { ErrorResponse* error = static_cast(response->response_body().get()); + RetryPolicy::RetryDecision decision = RetryPolicy::RetryDecision::return_error(); - switch(error->code()) { - case CQL_ERROR_UNPREPARED: - on_error_unprepared(error); - break; + switch(error->code()) { case CQL_ERROR_READ_TIMEOUT: - handle_retry_decision(response, - retry_policy_->on_read_timeout(error->consistency(), - error->received(), - error->required(), - error->data_present() > 0, - num_retries_)); + decision = request_handler_->retry_policy()->on_read_timeout(request(), + error->consistency(), + error->received(), + error->required(), + error->data_present() > 0, + num_retries_); break; case CQL_ERROR_WRITE_TIMEOUT: - handle_retry_decision(response, - retry_policy_->on_write_timeout(error->consistency(), - error->received(), - error->required(), - error->write_type(), - num_retries_)); + if (request()->is_idempotent()) { + decision = request_handler_->retry_policy()->on_write_timeout(request(), + error->consistency(), + error->received(), + error->required(), + error->write_type(), + num_retries_); + } break; case CQL_ERROR_UNAVAILABLE: - handle_retry_decision(response, - retry_policy_->on_unavailable(error->consistency(), - error->required(), - error->received(), - num_retries_)); + decision = request_handler_->retry_policy()->on_unavailable(request(), + error->consistency(), + error->required(), + error->received(), + num_retries_); break; - default: - set_error(static_cast(CASS_ERROR( - CASS_ERROR_SOURCE_SERVER, error->code())), - error->message().to_string()); + case CQL_ERROR_OVERLOADED: + LOG_WARN("Host %s is overloaded.", + connection()->address_string().c_str()); + if (request()->is_idempotent()) { + decision = request_handler_->retry_policy()->on_request_error(request(), + request()->consistency(), + error, + num_retries_); + } break; - } -} -void RequestHandler::on_error_unprepared(ErrorResponse* error) { - SharedRefPtr prepare_handler(new PrepareCallback(RequestHandler::Ptr(this))); - if (prepare_handler->init(error->prepared_id().to_string())) { - if (!connection_->write(prepare_handler.get())) { - // Try to prepare on the same host but on a different connection - retry(); - } - } else { - connection_->defunct(); - set_error(CASS_ERROR_LIB_UNEXPECTED_RESPONSE, - "Received unprepared error for invalid " - "request type or invalid prepared id"); - } -} + case CQL_ERROR_SERVER_ERROR: + LOG_WARN("Received server error '%s' from host %s. Defuncting the connection...", + error->message().to_string().c_str(), + connection()->address_string().c_str()); + connection()->defunct(); + if (request()->is_idempotent()) { + decision = request_handler_->retry_policy()->on_request_error(request(), + request()->consistency(), + error, + num_retries_); + } + break; -void RequestHandler::handle_retry_decision(ResponseMessage* response, - const RetryPolicy::RetryDecision& decision) { - ErrorResponse* error = - static_cast(response->response_body().get()); + case CQL_ERROR_IS_BOOTSTRAPPING: + LOG_ERROR("Query sent to bootstrapping host %s. Retrying on the next host...", + connection()->address_string().c_str()); + retry_next_host(); + return; // Done + case CQL_ERROR_UNPREPARED: + on_error_unprepared(error); + return; // Done + + default: + // Return the error response + break; + } + + // Process retry decision switch(decision.type()) { case RetryPolicy::RetryDecision::RETURN_ERROR: set_error_with_error_response(response->response_body(), @@ -316,18 +394,72 @@ void RequestHandler::handle_retry_decision(ResponseMessage* response, if (!decision.retry_current_host()) { next_host(); } - if (state() == REQUEST_STATE_DONE) { - retry(); + if (state() == REQUEST_STATE_FINISHED) { + retry_current_host(); } else { set_state(REQUEST_STATE_RETRY_WRITE_OUTSTANDING); } + num_retries_++; break; case RetryPolicy::RetryDecision::IGNORE: set_response(Response::Ptr(new ResultResponse())); break; } - num_retries_++; +} + +void SpeculativeExecution::on_error_unprepared(ErrorResponse* error) { + std::string prepared_statement; + + if (request()->opcode() == CQL_OPCODE_EXECUTE) { + const ExecuteRequest* execute = static_cast(request()); + prepared_statement = execute->prepared()->statement(); + } else if (request()->opcode() == CQL_OPCODE_BATCH) { + const BatchRequest* batch = static_cast(request()); + if (!batch->prepared_statement(error->prepared_id().to_string(), &prepared_statement)) { + set_error(CASS_ERROR_LIB_UNEXPECTED_RESPONSE, + "Unable to find prepared statement in batch statement"); + return; + } + } else { + connection()->defunct(); + set_error(CASS_ERROR_LIB_UNEXPECTED_RESPONSE, + "Received unprepared error for invalid " + "request type or invalid prepared id"); + return; + } + + if (!connection()->write(RequestCallback::Ptr( + new PrepareCallback(prepared_statement, this)))) { + // Try to prepare on the same host but on a different connection + retry_current_host(); + } +} + +bool SpeculativeExecution::is_host_up(const Address& address) const { + return request_handler_->io_worker()->is_host_up(address); +} + +void SpeculativeExecution::set_response(const Response::Ptr& response) { + uint64_t elapsed = uv_hrtime() - start_time_ns_; + current_host_->update_latency(elapsed); + connection()->metrics()->record_request(elapsed); + request_handler_->set_response(current_host_, response); +} + +void SpeculativeExecution::set_error(CassError code, const std::string& message) { + request_handler_->set_error(current_host_, code, message); +} + +void SpeculativeExecution::set_error_with_error_response(const Response::Ptr& error, + CassError code, const std::string& message) { + request_handler_->set_error_with_error_response(current_host_, error, code, message); +} + +void SpeculativeExecution::return_connection() { + if (pool_ != NULL && connection() != NULL) { + pool_->return_connection(connection()); + } } } // namespace cass diff --git a/src/request_handler.hpp b/src/request_handler.hpp index c6b12841c..07aeeb431 100644 --- a/src/request_handler.hpp +++ b/src/request_handler.hpp @@ -28,6 +28,7 @@ #include "response.hpp" #include "retry_policy.hpp" #include "scoped_ptr.hpp" +#include "speculative_execution.hpp" #include #include @@ -47,11 +48,15 @@ class ResponseFuture : public Future { : Future(CASS_FUTURE_TYPE_RESPONSE) , schema_metadata(metadata.schema_snapshot(protocol_version, cassandra_version)) { } - void set_response(Address address, const Response::Ptr& response) { + bool set_response(Address address, const Response::Ptr& response) { ScopedMutex lock(&mutex_); - address_ = address; - response_ = response; - internal_set(lock); + if (!is_set()) { + address_ = address; + response_ = response; + internal_set(lock); + return true; + } + return false; } const Response::Ptr& response() { @@ -60,21 +65,29 @@ class ResponseFuture : public Future { return response_; } - void set_error_with_host_address(Address address, CassError code, const std::string& message) { + bool set_error_with_address(Address address, CassError code, const std::string& message) { ScopedMutex lock(&mutex_); - address_ = address; - internal_set_error(code, message, lock); + if (!is_set()) { + address_ = address; + internal_set_error(code, message, lock); + return true; + } + return false; } - void set_error_with_response(Address address, const Response::Ptr& response, + bool set_error_with_response(Address address, const Response::Ptr& response, CassError code, const std::string& message) { ScopedMutex lock(&mutex_); - address_ = address; - response_ = response; - internal_set_error(code, message, lock); + if (!is_set()) { + address_ = address; + response_ = response; + internal_set_error(code, message, lock); + return true; + } + return false; } - Address get_host_address() { + Address address() { ScopedMutex lock(&mutex_); internal_wait(lock); return address_; @@ -88,71 +101,144 @@ class ResponseFuture : public Future { Response::Ptr response_; }; +class SpeculativeExecution; -class RequestHandler : public RequestCallback { +class RequestHandler : public RefCounted { public: typedef SharedRefPtr Ptr; RequestHandler(const Request::ConstPtr& request, const ResponseFuture::Ptr& future, RetryPolicy* retry_policy) - : RequestCallback(request) - , future_(future) - , retry_policy_(retry_policy) - , num_retries_(0) - , is_query_plan_exhausted_(true) - , io_worker_(NULL) - , pool_(NULL) { - set_timestamp(request->timestamp()); - } + : request_(request) + , timestamp_(request->timestamp()) + , future_(future) + , retry_policy_(retry_policy) + , io_worker_(NULL) + , running_executions_(0) { } - virtual void on_set(ResponseMessage* response); - virtual void on_error(CassError code, const std::string& message); - virtual void on_timeout(); + const Request* request() const { return request_.get(); } + + int64_t timestamp() const { return timestamp_; } + void set_timestamp(int64_t timestamp) { timestamp_ = timestamp; } - virtual void retry(); + Request::EncodingCache* encoding_cache() { return &encoding_cache_; } + + RetryPolicy* retry_policy() { return retry_policy_; } void set_query_plan(QueryPlan* query_plan) { query_plan_.reset(query_plan); + first_host_ = next_host(); + } + + void set_execution_plan(SpeculativeExecutionPlan* execution_plan) { + execution_plan_.reset(execution_plan); } - void set_io_worker(IOWorker* io_worker); + const Host::Ptr& first_host() const { return first_host_; } + const Host::Ptr next_host() { return query_plan_->compute_next(); } + + IOWorker* io_worker() { return io_worker_; } + + void start_request(IOWorker* io_worker); + + void set_response(const Host::Ptr& host, + const Response::Ptr& response); + void set_error(CassError code, const std::string& message); + void set_error(const Host::Ptr& host, + CassError code, const std::string& message); + void set_error_with_error_response(const Host::Ptr& host, + const Response::Ptr& error, + CassError code, const std::string& message); + +private: + static void on_timeout(Timer* timer); + +private: + friend class SpeculativeExecution; + + void add_execution(SpeculativeExecution* speculative_execution); + void schedule_next_execution(const Host::Ptr& first_host); + void stop_request(); + +private: + typedef std::vector SpeculativeExecutionVec; + + const Request::ConstPtr request_; + int64_t timestamp_; + SharedRefPtr future_; + RetryPolicy* retry_policy_; + ScopedPtr query_plan_; + ScopedPtr execution_plan_; + Host::Ptr first_host_; + IOWorker* io_worker_; + Timer timer_; + int running_executions_; + SpeculativeExecutionVec speculative_executions_; + Request::EncodingCache encoding_cache_; +}; + +class SpeculativeExecution : public RequestCallback { +public: + typedef SharedRefPtr Ptr; + + SpeculativeExecution(const RequestHandler::Ptr& request_handler, + const Host::Ptr& current_host = Host::Ptr()); + + virtual void on_set(ResponseMessage* response); + virtual void on_error(CassError code, const std::string& message); + + virtual const Request* request() const { return request_handler_->request(); } + virtual int64_t timestamp() const { return request_handler_->timestamp(); } + virtual Request::EncodingCache* encoding_cache() { return request_handler_->encoding_cache(); } Pool* pool() const { return pool_; } + void set_pool(Pool* pool) { pool_ = pool; } - void set_pool(Pool* pool) { - pool_ = pool; - } + const Host::Ptr& current_host() const { return current_host_; } + void next_host() { current_host_ = request_handler_->next_host(); } - bool get_current_host_address(Address* address); - void next_host(); + void start_pending_request(Pool* pool, Timer::Callback cb); + void stop_pending_request(); - bool is_host_up(const Address& address) const; + void retry_current_host(); + void retry_next_host(); - void set_response(const Response::Ptr& response); + void execute(); + void schedule_next(int64_t timeout = 0); + void cancel(); + +private: + static void on_execute(Timer* timer); + + virtual void on_start(); + virtual void on_finish(); + virtual void on_finish_with_retry(bool use_next_host); private: + friend class SchemaChangeCallback; + + bool is_host_up(const Address& address) const; + + void set_response(const Response::Ptr& response); void set_error(CassError code, const std::string& message); void set_error_with_error_response(const Response::Ptr& error, CassError code, const std::string& message); + void return_connection(); - void return_connection_and_finish(); void on_result_response(ResponseMessage* response); void on_error_response(ResponseMessage* response); void on_error_unprepared(ErrorResponse* error); - void handle_retry_decision(ResponseMessage* response, - const RetryPolicy::RetryDecision& decision); - - SharedRefPtr future_; - RetryPolicy* retry_policy_; - int num_retries_; - bool is_query_plan_exhausted_; +private: + RequestHandler::Ptr request_handler_; Host::Ptr current_host_; - ScopedPtr query_plan_; - IOWorker* io_worker_; Pool* pool_; + Timer schedule_timer_; + Timer pending_request_timer_; + int num_retries_; + uint64_t start_time_ns_; }; } // namespace cass diff --git a/src/retry_policy.cpp b/src/retry_policy.cpp index 3c643adc3..0a6ded0e6 100644 --- a/src/retry_policy.cpp +++ b/src/retry_policy.cpp @@ -18,6 +18,7 @@ #include "external_types.hpp" #include "logger.hpp" +#include "request.hpp" extern "C" { @@ -72,7 +73,9 @@ inline RetryPolicy::RetryDecision max_likely_to_work(int received) { // Default retry policy -RetryPolicy::RetryDecision DefaultRetryPolicy::on_read_timeout(CassConsistency cl, int received, int required, bool data_recevied, int num_retries) const { +RetryPolicy::RetryDecision DefaultRetryPolicy::on_read_timeout(const Request* request, CassConsistency cl, + int received, int required, + bool data_recevied, int num_retries) const { if (num_retries != 0) { return RetryDecision::return_error(); } @@ -84,7 +87,9 @@ RetryPolicy::RetryDecision DefaultRetryPolicy::on_read_timeout(CassConsistency c } } -RetryPolicy::RetryDecision DefaultRetryPolicy::on_write_timeout(CassConsistency cl, int received, int required, CassWriteType write_type, int num_retries) const { +RetryPolicy::RetryDecision DefaultRetryPolicy::on_write_timeout(const Request* request, CassConsistency cl, + int received, int required, + CassWriteType write_type, int num_retries) const { if (num_retries != 0) { return RetryDecision::return_error(); } @@ -96,7 +101,9 @@ RetryPolicy::RetryDecision DefaultRetryPolicy::on_write_timeout(CassConsistency } } -RetryPolicy::RetryDecision DefaultRetryPolicy::on_unavailable(CassConsistency cl, int required, int alive, int num_retries) const { +RetryPolicy::RetryDecision DefaultRetryPolicy::on_unavailable(const Request* request, CassConsistency cl, + int required, int alive, + int num_retries) const { if (num_retries == 0) { return RetryDecision::retry_next_host(cl); } else { @@ -104,9 +111,16 @@ RetryPolicy::RetryDecision DefaultRetryPolicy::on_unavailable(CassConsistency cl } } +RetryPolicy::RetryDecision DefaultRetryPolicy::on_request_error(const Request* request, CassConsistency cl, + const ErrorResponse* error, int num_retries) const { + return RetryDecision::retry_next_host(cl); +} + // Downgrading retry policy -RetryPolicy::RetryDecision DowngradingConsistencyRetryPolicy::on_read_timeout(CassConsistency cl, int received, int required, bool data_recevied, int num_retries) const { +RetryPolicy::RetryDecision DowngradingConsistencyRetryPolicy::on_read_timeout(const Request* request, CassConsistency cl, + int received, int required, + bool data_recevied, int num_retries) const { if (num_retries != 0) { return RetryDecision::return_error(); } @@ -127,7 +141,9 @@ RetryPolicy::RetryDecision DowngradingConsistencyRetryPolicy::on_read_timeout(Ca } } -RetryPolicy::RetryDecision DowngradingConsistencyRetryPolicy::on_write_timeout(CassConsistency cl, int received, int required, CassWriteType write_type, int num_retries) const { +RetryPolicy::RetryDecision DowngradingConsistencyRetryPolicy::on_write_timeout(const Request* request, CassConsistency cl, + int received, int required, + CassWriteType write_type, int num_retries) const { if (num_retries != 0) { return RetryDecision::return_error(); } @@ -152,31 +168,51 @@ RetryPolicy::RetryDecision DowngradingConsistencyRetryPolicy::on_write_timeout(C } } -RetryPolicy::RetryDecision DowngradingConsistencyRetryPolicy::on_unavailable(CassConsistency cl, int required, int alive, int num_retries) const { +RetryPolicy::RetryDecision DowngradingConsistencyRetryPolicy::on_unavailable(const Request* request, CassConsistency cl, + int required, int alive, + int num_retries) const { if (num_retries != 0) { return RetryDecision::return_error(); } return max_likely_to_work(alive); } +RetryPolicy::RetryDecision DowngradingConsistencyRetryPolicy::on_request_error(const Request* request, CassConsistency cl, + const ErrorResponse* error, int num_retries) const { + return RetryDecision::retry_next_host(cl); +} + // Fallthrough retry policy -RetryPolicy::RetryDecision FallthroughRetryPolicy::on_read_timeout(CassConsistency cl, int received, int required, bool data_recevied, int num_retries) const { +RetryPolicy::RetryDecision FallthroughRetryPolicy::on_read_timeout(const Request* request, CassConsistency cl, + int received, int required, + bool data_recevied, int num_retries) const { + return RetryDecision::return_error(); +} + +RetryPolicy::RetryDecision FallthroughRetryPolicy::on_write_timeout(const Request* request, CassConsistency cl, + int received, int required, + CassWriteType write_type, int num_retries) const { return RetryDecision::return_error(); } -RetryPolicy::RetryDecision FallthroughRetryPolicy::on_write_timeout(CassConsistency cl, int received, int required, CassWriteType write_type, int num_retries) const { +RetryPolicy::RetryDecision FallthroughRetryPolicy::on_unavailable(const Request* request, CassConsistency cl, + int required, int alive, + int num_retries) const { return RetryDecision::return_error(); } -RetryPolicy::RetryDecision FallthroughRetryPolicy::on_unavailable(CassConsistency cl, int required, int alive, int num_retries) const { +RetryPolicy::RetryDecision FallthroughRetryPolicy::on_request_error(const Request* request, CassConsistency cl, + const ErrorResponse* error, int num_retries) const { return RetryDecision::return_error(); } // Logging retry policy -RetryPolicy::RetryDecision LoggingRetryPolicy::on_read_timeout(CassConsistency cl, int received, int required, bool data_recevied, int num_retries) const { - RetryDecision decision = retry_policy_->on_read_timeout(cl, received, required, data_recevied, num_retries); +RetryPolicy::RetryDecision LoggingRetryPolicy::on_read_timeout(const Request* request, CassConsistency cl, + int received, int required, + bool data_recevied, int num_retries) const { + RetryDecision decision = retry_policy_->on_read_timeout(request, cl, received, required, data_recevied, num_retries); switch (decision.type()) { case RetryDecision::IGNORE: @@ -201,8 +237,10 @@ RetryPolicy::RetryDecision LoggingRetryPolicy::on_read_timeout(CassConsistency c return decision; } -RetryPolicy::RetryDecision LoggingRetryPolicy::on_write_timeout(CassConsistency cl, int received, int required, CassWriteType write_type, int num_retries) const { - RetryDecision decision = retry_policy_->on_write_timeout(cl, received, required, write_type, num_retries); +RetryPolicy::RetryDecision LoggingRetryPolicy::on_write_timeout(const Request* request, CassConsistency cl, + int received, int required, + CassWriteType write_type, int num_retries) const { + RetryDecision decision = retry_policy_->on_write_timeout(request, cl, received, required, write_type, num_retries); switch (decision.type()) { case RetryDecision::IGNORE: @@ -227,8 +265,10 @@ RetryPolicy::RetryDecision LoggingRetryPolicy::on_write_timeout(CassConsistency return decision; } -RetryPolicy::RetryDecision LoggingRetryPolicy::on_unavailable(CassConsistency cl, int required, int alive, int num_retries) const { - RetryDecision decision = retry_policy_->on_unavailable(cl, required, alive, num_retries); +RetryPolicy::RetryDecision LoggingRetryPolicy::on_unavailable(const Request* request, CassConsistency cl, + int required, int alive, + int num_retries) const { + RetryDecision decision = retry_policy_->on_unavailable(request, cl, required, alive, num_retries); switch (decision.type()) { case RetryDecision::IGNORE: @@ -252,4 +292,31 @@ RetryPolicy::RetryDecision LoggingRetryPolicy::on_unavailable(CassConsistency cl return decision; } +RetryPolicy::RetryDecision LoggingRetryPolicy::on_request_error(const Request* request, CassConsistency cl, + const ErrorResponse* error, int num_retries) const { + RetryDecision decision = retry_policy_->on_request_error(request, cl, error, num_retries); + + switch (decision.type()) { + case RetryDecision::IGNORE: + LOG_INFO("Ignoring request error (initial consistency: %s, error: %s, retries: %d)", + cass_consistency_string(cl), + error->message().to_string().c_str(), + num_retries); + break; + + case RetryDecision::RETRY: + LOG_INFO("Retrying on request error at consistency %s (initial consistency: %s, error: %s, retries: %d)", + cass_consistency_string(decision.retry_consistency()), + cass_consistency_string(cl), + error->message().to_string().c_str(), + num_retries); + break; + + default: + break; + } + + return decision; +} + } // namespace cass diff --git a/src/retry_policy.hpp b/src/retry_policy.hpp index 152251949..c61a86bb7 100644 --- a/src/retry_policy.hpp +++ b/src/retry_policy.hpp @@ -18,6 +18,7 @@ #define __CASS_RETRY_POLICY_HPP_INCLUDED__ #include "cassandra.h" +#include "error_response.hpp" #include "ref_counted.hpp" #ifdef _WIN32 @@ -28,6 +29,9 @@ namespace cass { +class ErrorResponse; +class Request; + class RetryPolicy : public RefCounted { public: typedef SharedRefPtr Ptr; @@ -85,9 +89,17 @@ class RetryPolicy : public RefCounted { Type type() const { return type_; } - virtual RetryDecision on_read_timeout(CassConsistency cl, int received, int required, bool data_recevied, int num_retries) const = 0; - virtual RetryDecision on_write_timeout(CassConsistency cl, int received, int required, CassWriteType write_type, int num_retries) const = 0; - virtual RetryDecision on_unavailable(CassConsistency cl, int required, int alive, int num_retries) const = 0; + virtual RetryDecision on_read_timeout(const Request* request, CassConsistency cl, + int received, int required, + bool data_recevied, int num_retries) const = 0; + virtual RetryDecision on_write_timeout(const Request* request, CassConsistency cl, + int received, int required, + CassWriteType write_type, int num_retries) const = 0; + virtual RetryDecision on_unavailable(const Request* request, CassConsistency cl, + int required, int alive, + int num_retries) const = 0; + virtual RetryDecision on_request_error(const Request* request, CassConsistency cl, + const ErrorResponse* error, int num_retries) const = 0; private: Type type_; @@ -98,9 +110,17 @@ class DefaultRetryPolicy : public RetryPolicy { DefaultRetryPolicy() : RetryPolicy(DEFAULT) { } - virtual RetryDecision on_read_timeout(CassConsistency cl, int received, int required, bool data_recevied, int num_retries) const; - virtual RetryDecision on_write_timeout(CassConsistency cl, int received, int required, CassWriteType write_type, int num_retries) const; - virtual RetryDecision on_unavailable(CassConsistency cl, int required, int alive, int num_retries) const; + virtual RetryDecision on_read_timeout(const Request* request, CassConsistency cl, + int received, int required, + bool data_recevied, int num_retries) const; + virtual RetryDecision on_write_timeout(const Request* request, CassConsistency cl, + int received, int required, + CassWriteType write_type, int num_retries) const; + virtual RetryDecision on_unavailable(const Request* request, CassConsistency cl, + int required, int alive, + int num_retries) const; + virtual RetryDecision on_request_error(const Request* request, CassConsistency cl, + const ErrorResponse* error, int num_retries) const; }; class DowngradingConsistencyRetryPolicy : public RetryPolicy { @@ -108,9 +128,17 @@ class DowngradingConsistencyRetryPolicy : public RetryPolicy { DowngradingConsistencyRetryPolicy() : RetryPolicy(DOWNGRADING) { } - virtual RetryDecision on_read_timeout(CassConsistency cl, int received, int required, bool data_recevied, int num_retries) const; - virtual RetryDecision on_write_timeout(CassConsistency cl, int received, int required, CassWriteType write_type, int num_retries) const; - virtual RetryDecision on_unavailable(CassConsistency cl, int required, int alive, int num_retries) const; + virtual RetryDecision on_read_timeout(const Request* request, CassConsistency cl, + int received, int required, + bool data_recevied, int num_retries) const; + virtual RetryDecision on_write_timeout(const Request* request, CassConsistency cl, + int received, int required, + CassWriteType write_type, int num_retries) const; + virtual RetryDecision on_unavailable(const Request* request, CassConsistency cl, + int required, int alive, + int num_retries) const; + virtual RetryDecision on_request_error(const Request* request, CassConsistency cl, + const ErrorResponse* error, int num_retries) const; }; class FallthroughRetryPolicy : public RetryPolicy { @@ -118,9 +146,17 @@ class FallthroughRetryPolicy : public RetryPolicy { FallthroughRetryPolicy() : RetryPolicy(FALLTHROUGH) { } - virtual RetryDecision on_read_timeout(CassConsistency cl, int received, int required, bool data_recevied, int num_retries) const; - virtual RetryDecision on_write_timeout(CassConsistency cl, int received, int required, CassWriteType write_type, int num_retries) const; - virtual RetryDecision on_unavailable(CassConsistency cl, int required, int alive, int num_retries) const; + virtual RetryDecision on_read_timeout(const Request* request, CassConsistency cl, + int received, int required, + bool data_recevied, int num_retries) const; + virtual RetryDecision on_write_timeout(const Request* request, CassConsistency cl, + int received, int required, + CassWriteType write_type, int num_retries) const; + virtual RetryDecision on_unavailable(const Request* request, CassConsistency cl, + int required, int alive, + int num_retries) const; + virtual RetryDecision on_request_error(const Request* request, CassConsistency cl, + const ErrorResponse* error, int num_retries) const; }; class LoggingRetryPolicy : public RetryPolicy { @@ -129,9 +165,17 @@ class LoggingRetryPolicy : public RetryPolicy { : RetryPolicy(LOGGING) , retry_policy_(retry_policy) { } - virtual RetryDecision on_read_timeout(CassConsistency cl, int received, int required, bool data_recevied, int num_retries) const; - virtual RetryDecision on_write_timeout(CassConsistency cl, int received, int required, CassWriteType write_type, int num_retries) const; - virtual RetryDecision on_unavailable(CassConsistency cl, int required, int alive, int num_retries) const; + virtual RetryDecision on_read_timeout(const Request* request, CassConsistency cl, + int received, int required, + bool data_recevied, int num_retries) const; + virtual RetryDecision on_write_timeout(const Request* request, CassConsistency cl, + int received, int required, + CassWriteType write_type, int num_retries) const; + virtual RetryDecision on_unavailable(const Request* request, CassConsistency cl, + int required, int alive, + int num_retries) const; + virtual RetryDecision on_request_error(const Request* request, CassConsistency cl, + const ErrorResponse* response, int num_retries) const; private: RetryPolicy::Ptr retry_policy_; diff --git a/src/schema_change_callback.cpp b/src/schema_change_callback.cpp index 64b27d2b8..940abcb3f 100644 --- a/src/schema_change_callback.cpp +++ b/src/schema_change_callback.cpp @@ -35,11 +35,11 @@ namespace cass { SchemaChangeCallback::SchemaChangeCallback(Connection* connection, - const RequestHandler::Ptr& request_handler, + const SpeculativeExecution::Ptr& speculative_execution, const Response::Ptr& response, uint64_t elapsed) : MultipleRequestCallback(connection) - , request_handler_(request_handler) + , speculative_execution_(speculative_execution) , request_response_(response) , start_ms_(get_time_since_epoch_ms()) , elapsed_ms_(elapsed) {} @@ -80,7 +80,7 @@ bool SchemaChangeCallback::has_schema_agreement(const ResponseMap& responses) { row->get_by_name("rpc_address"), &address); - if (is_valid_address && request_handler_->is_host_up(address)) { + if (is_valid_address && speculative_execution_->is_host_up(address)) { const Value* v = row->get_by_name("schema_version"); if (!row->get_by_name("rpc_address")->is_null() && !v->is_null()) { StringRef version(v->to_string_ref()); @@ -109,13 +109,13 @@ void SchemaChangeCallback::on_set(const ResponseMap& responses) { if (!has_error && has_schema_agreement(responses)) { LOG_DEBUG("Found schema agreement in %llu ms", static_cast(elapsed_ms_)); - request_handler_->set_response(request_response_); + speculative_execution_->set_response(request_response_); return; } else if (elapsed_ms_ >= MAX_SCHEMA_AGREEMENT_WAIT_MS) { LOG_WARN("No schema agreement on live nodes after %llu ms. " "Schema may not be up-to-date on some nodes.", static_cast(elapsed_ms_)); - request_handler_->set_response(request_response_); + speculative_execution_->set_response(request_response_); return; } @@ -125,7 +125,7 @@ void SchemaChangeCallback::on_set(const ResponseMap& responses) { // Try again Ptr callback( new SchemaChangeCallback(connection(), - request_handler_, + speculative_execution_, request_response_, elapsed_ms_)); connection()->schedule_schema_agreement(callback, @@ -137,17 +137,17 @@ void SchemaChangeCallback::on_error(CassError code, const std::string& message) ss << "An error occurred waiting for schema agreement: '" << message << "' (0x" << std::hex << std::uppercase << std::setw(8) << std::setfill('0') << code << ")"; LOG_ERROR("%s", ss.str().c_str()); - request_handler_->set_response(request_response_); + speculative_execution_->set_response(request_response_); } void SchemaChangeCallback::on_timeout() { LOG_ERROR("A timeout occurred waiting for schema agreement"); - request_handler_->set_response(request_response_); + speculative_execution_->set_response(request_response_); } void SchemaChangeCallback::on_closing() { LOG_WARN("Connection closed while waiting for schema agreement"); - request_handler_->set_response(request_response_); + speculative_execution_->set_response(request_response_); } } // namespace cass diff --git a/src/schema_change_callback.hpp b/src/schema_change_callback.hpp index a74322a4c..78b309c4a 100644 --- a/src/schema_change_callback.hpp +++ b/src/schema_change_callback.hpp @@ -34,7 +34,7 @@ class SchemaChangeCallback : public MultipleRequestCallback { typedef SharedRefPtr Ptr; SchemaChangeCallback(Connection* connection, - const RequestHandler::Ptr& request_handler, + const SpeculativeExecution::Ptr& speculative_execution, const Response::Ptr& response, uint64_t elapsed = 0); @@ -48,7 +48,7 @@ class SchemaChangeCallback : public MultipleRequestCallback { private: bool has_schema_agreement(const ResponseMap& responses); - RequestHandler::Ptr request_handler_; + SpeculativeExecution::Ptr speculative_execution_; Response::Ptr request_response_; uint64_t start_ms_; uint64_t elapsed_ms_; diff --git a/src/session.cpp b/src/session.cpp index 8c2211b20..f3e3a7df3 100644 --- a/src/session.cpp +++ b/src/session.cpp @@ -20,7 +20,6 @@ #include "constants.hpp" #include "logger.hpp" #include "prepare_request.hpp" -#include "request_handler.hpp" #include "scoped_lock.hpp" #include "timer.hpp" #include "external_types.hpp" @@ -162,6 +161,7 @@ void Session::clear(const Config& config) { random_.reset(); metrics_.reset(new Metrics(config_.thread_count_io() + 1)); load_balancing_policy_.reset(config.load_balancing_policy()); + speculative_execution_policy_.reset(config.speculative_execution_policy()); connect_future_.reset(); close_future_.reset(); { // Lock hosts @@ -526,14 +526,17 @@ void Session::on_resolve_done(MultiResolver* resolver) { } void Session::execute(const RequestHandler::Ptr& request_handler) { - request_handler->inc_ref(); if (state_.load(MEMORY_ORDER_ACQUIRE) != SESSION_STATE_CONNECTED) { - request_handler->on_error(CASS_ERROR_LIB_NO_HOSTS_AVAILABLE, - "Session is not connected"); + request_handler->set_error(CASS_ERROR_LIB_NO_HOSTS_AVAILABLE, + "Session is not connected"); return; - } else if (!request_queue_->enqueue(request_handler.get())) { - request_handler->on_error(CASS_ERROR_LIB_REQUEST_QUEUE_FULL, - "The request queue has reached capacity"); + } + + request_handler->inc_ref(); // Queue reference + if (!request_queue_->enqueue(request_handler.get())) { + request_handler->dec_ref(); + request_handler->set_error(CASS_ERROR_LIB_REQUEST_QUEUE_FULL, + "The request queue has reached capacity"); } } @@ -584,18 +587,14 @@ void Session::on_control_connection_error(CassError code, const std::string& mes } Future::Ptr Session::prepare(const char* statement, size_t length) { - SharedRefPtr prepare(new PrepareRequest()); - prepare->set_query(statement, length); + SharedRefPtr prepare(new PrepareRequest(std::string(statement, length))); - ResponseFuture::Ptr future( - new ResponseFuture(protocol_version(), - cassandra_version(), - metadata_)); + ResponseFuture::Ptr future(new ResponseFuture(protocol_version(), + cassandra_version(), + metadata_)); future->statement.assign(statement, length); - RequestHandler::Ptr request_handler(new RequestHandler(prepare, future, NULL)); - - execute(request_handler); + execute(RequestHandler::Ptr(new RequestHandler(prepare, future, NULL))); return future; } @@ -678,20 +677,17 @@ void Session::on_down(Host::Ptr host) { } Future::Ptr Session::execute(const Request::ConstPtr& request) { - ResponseFuture::Ptr future( - new ResponseFuture(protocol_version(), - cassandra_version(), - metadata_)); + ResponseFuture::Ptr future(new ResponseFuture(protocol_version(), + cassandra_version(), + metadata_)); RetryPolicy* retry_policy = request->retry_policy() != NULL ? request->retry_policy() : config().retry_policy(); - RequestHandler::Ptr request_handler(new RequestHandler(request, - future, - retry_policy)); - - execute(request_handler); + execute(RequestHandler::Ptr(new RequestHandler(request, + future, + retry_policy))); return future; } @@ -705,11 +701,16 @@ void Session::on_execute(uv_async_t* data) { bool is_closing = false; - RequestHandler* request_handler = NULL; - while (session->request_queue_->dequeue(request_handler)) { - if (request_handler != NULL) { - request_handler->set_query_plan(session->new_query_plan(request_handler->request(), - request_handler->encoding_cache())); + RequestHandler* temp = NULL; + while (session->request_queue_->dequeue(temp)) { + RequestHandler::Ptr request_handler(temp); + if (request_handler) { + request_handler->dec_ref(); // Queue reference + + request_handler->set_query_plan( + session->new_query_plan(request_handler->request(), + request_handler->encoding_cache())); + request_handler->set_execution_plan(session->new_execution_plan(request_handler->request())); if (request_handler->timestamp() == CASS_INT64_MIN) { request_handler->set_timestamp(session->config_.timestamp_gen()->next()); @@ -717,19 +718,16 @@ void Session::on_execute(uv_async_t* data) { bool is_done = false; while (!is_done) { - request_handler->next_host(); - - Address address; - if (!request_handler->get_current_host_address(&address)) { - request_handler->on_error(CASS_ERROR_LIB_NO_HOSTS_AVAILABLE, - "All connections on all I/O threads are busy"); + if (!request_handler->first_host()) { + request_handler->set_error(CASS_ERROR_LIB_NO_HOSTS_AVAILABLE, + "All connections on all I/O threads are busy"); break; } size_t start = session->current_io_worker_; for (size_t i = 0, size = session->io_workers_.size(); i < size; ++i) { const IOWorker::Ptr& io_worker = session->io_workers_[start % size]; - if (io_worker->is_host_available(address) && + if (io_worker->is_host_available(request_handler->first_host()->address()) && io_worker->execute(request_handler)) { session->current_io_worker_ = (start + 1) % size; is_done = true; @@ -758,4 +756,9 @@ QueryPlan* Session::new_query_plan(const Request* request, Request::EncodingCach return load_balancing_policy_->new_query_plan(*keyspace, request, token_map_.get(), cache); } +SpeculativeExecutionPlan* Session::new_execution_plan(const Request* request) { + const CopyOnWritePtr keyspace(keyspace_); + return speculative_execution_policy_->new_plan(*keyspace, request); +} + } // namespace cass diff --git a/src/session.hpp b/src/session.hpp index 7ff5665db..4d4bbaed5 100644 --- a/src/session.hpp +++ b/src/session.hpp @@ -29,10 +29,12 @@ #include "mpmc_queue.hpp" #include "random.hpp" #include "ref_counted.hpp" +#include "request_handler.hpp" #include "resolver.hpp" #include "row.hpp" #include "scoped_lock.hpp" #include "scoped_ptr.hpp" +#include "speculative_execution.hpp" #include "token_map.hpp" #include @@ -44,7 +46,6 @@ namespace cass { -class RequestHandler; class Future; class IOWorker; class Request; @@ -160,6 +161,7 @@ class Session : public EventThread { #endif QueryPlan* new_query_plan(const Request* request = NULL, Request::EncodingCache* cache = NULL); + SpeculativeExecutionPlan* new_execution_plan(const Request* request); void on_reconnect(Timer* timer); @@ -191,6 +193,7 @@ class Session : public EventThread { Config config_; ScopedPtr metrics_; LoadBalancingPolicy::Ptr load_balancing_policy_; + SharedRefPtr speculative_execution_policy_; CassError connect_error_code_; std::string connect_error_message_; Future::Ptr connect_future_; diff --git a/src/speculative_execution.hpp b/src/speculative_execution.hpp new file mode 100644 index 000000000..d3328df74 --- /dev/null +++ b/src/speculative_execution.hpp @@ -0,0 +1,97 @@ +/* + Copyright (c) 2014-2016 DataStax + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +#ifndef __CASS_SPECULATIVE_EXECUTION_HPP_INCLUDED__ +#define __CASS_SPECULATIVE_EXECUTION_HPP_INCLUDED__ + +#include "host.hpp" +#include "ref_counted.hpp" + +#include +#include + +namespace cass { + +class Request; + +class SpeculativeExecutionPlan { +public: + virtual ~SpeculativeExecutionPlan() { } + + virtual int64_t next_execution(const Host::Ptr& current_host) = 0; +}; + +class SpeculativeExecutionPolicy : public RefCounted { +public: + virtual ~SpeculativeExecutionPolicy() { } + + virtual SpeculativeExecutionPlan* new_plan(const std::string& keyspace, + const Request* request) = 0; + + virtual SpeculativeExecutionPolicy* new_instance() = 0; +}; + +class NoSpeculativeExecutionPlan : public SpeculativeExecutionPlan { +public: + virtual int64_t next_execution(const Host::Ptr& current_host) { return -1; } +}; + +class NoSpeculativeExecutionPolicy : public SpeculativeExecutionPolicy { +public: + virtual SpeculativeExecutionPlan* new_plan(const std::string& keyspace, + const Request* request) { + return new NoSpeculativeExecutionPlan(); + } + + virtual SpeculativeExecutionPolicy* new_instance() { return this; } +}; + +class ConstantSpeculativeExecutionPlan : public SpeculativeExecutionPlan { +public: + ConstantSpeculativeExecutionPlan(int64_t constant_delay_ms, int count) + : constant_delay_ms_(constant_delay_ms) + , count_(count) { } + + virtual int64_t next_execution(const Host::Ptr& current_host) { + return --count_ >= 0 ? constant_delay_ms_ : -1; + } + +private: + const int64_t constant_delay_ms_; + int count_; +}; + +class ConstantSpeculativeExecutionPolicy : public SpeculativeExecutionPolicy { +public: + ConstantSpeculativeExecutionPolicy(int64_t constant_delay_ms, int max_speculative_executions) + : constant_delay_ms_(constant_delay_ms) + , max_speculative_executions_(max_speculative_executions) { } + + virtual SpeculativeExecutionPlan* new_plan(const std::string& keyspace, + const Request* request) { + return new ConstantSpeculativeExecutionPlan(constant_delay_ms_, + max_speculative_executions_); + } + + virtual SpeculativeExecutionPolicy* new_instance() { return this; } + + const int64_t constant_delay_ms_; + const int max_speculative_executions_; +}; + +} // namespace cass + +#endif diff --git a/src/statement.cpp b/src/statement.cpp index 9020057b1..3be8dfafa 100644 --- a/src/statement.cpp +++ b/src/statement.cpp @@ -115,13 +115,18 @@ CassError cass_statement_set_timestamp(CassStatement* statement, return CASS_OK; } -CassError -cass_statement_set_request_timeout(CassStatement* statement, - cass_uint64_t timeout_ms) { +CassError cass_statement_set_request_timeout(CassStatement* statement, + cass_uint64_t timeout_ms) { statement->set_request_timeout_ms(timeout_ms); return CASS_OK; } +CassError cass_statement_set_is_idempotent(CassStatement* statement, + cass_bool_t is_idempotent) { + statement->set_is_idempotent(is_idempotent == cass_true); + return CASS_OK; +} + CassError cass_statement_set_custom_payload(CassStatement* statement, const CassCustomPayload* payload) { statement->set_custom_payload(payload); @@ -253,7 +258,7 @@ int32_t Statement::copy_buffers(int version, BufferVec* bufs, RequestCallback* c std::stringstream ss; ss << "Query parameter at index " << i << " was not set"; callback->on_error(CASS_ERROR_LIB_PARAMETER_UNSET, ss.str()); - return Request::ENCODE_ERROR_PARAMETER_UNSET; + return Request::REQUEST_ERROR_PARAMETER_UNSET; } } size += bufs->back().size(); diff --git a/src/testing.cpp b/src/testing.cpp index b1153499b..2af611873 100644 --- a/src/testing.cpp +++ b/src/testing.cpp @@ -32,7 +32,7 @@ std::string get_host_from_future(CassFuture* future) { } cass::ResponseFuture* response_future = static_cast(future->from()); - return response_future->get_host_address().to_string(); + return response_future->address().to_string(); } unsigned get_connect_timeout_from_cluster(CassCluster* cluster) { From 8edc79a14e61fb04eca34496c1f702642c5a6dc3 Mon Sep 17 00:00:00 2001 From: Michael Penick Date: Thu, 29 Sep 2016 08:27:20 -0700 Subject: [PATCH 5/7] Fixed unit tests broken by specultive execution changes --- test/unit_tests/src/test_retry_policies.cpp | 48 ++++++++++----------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/test/unit_tests/src/test_retry_policies.cpp b/test/unit_tests/src/test_retry_policies.cpp index a1d8640ff..45561f593 100644 --- a/test/unit_tests/src/test_retry_policies.cpp +++ b/test/unit_tests/src/test_retry_policies.cpp @@ -39,19 +39,19 @@ void check_default(cass::RetryPolicy& policy) { // Read timeout { // Retry because data wasn't present - check_decision(policy.on_read_timeout(CASS_CONSISTENCY_QUORUM, 3, 3, false, 0), + check_decision(policy.on_read_timeout(NULL, CASS_CONSISTENCY_QUORUM, 3, 3, false, 0), RetryDecision::RETRY, CASS_CONSISTENCY_QUORUM, true); // Return error because recieved < required - check_decision(policy.on_read_timeout(CASS_CONSISTENCY_QUORUM, 2, 3, false, 0), + check_decision(policy.on_read_timeout(NULL, CASS_CONSISTENCY_QUORUM, 2, 3, false, 0), RetryDecision::RETURN_ERROR, CASS_CONSISTENCY_UNKNOWN, false); // Return error because a retry has already happened - check_decision(policy.on_read_timeout(CASS_CONSISTENCY_QUORUM, 3, 3, false, 1), + check_decision(policy.on_read_timeout(NULL, CASS_CONSISTENCY_QUORUM, 3, 3, false, 1), RetryDecision::RETURN_ERROR, CASS_CONSISTENCY_UNKNOWN, false); @@ -60,13 +60,13 @@ void check_default(cass::RetryPolicy& policy) { // Write timeout { // Retry because of batch log failed to write - check_decision(policy.on_write_timeout(CASS_CONSISTENCY_QUORUM, 3, 3, CASS_WRITE_TYPE_BATCH_LOG, 0), + check_decision(policy.on_write_timeout(NULL, CASS_CONSISTENCY_QUORUM, 3, 3, CASS_WRITE_TYPE_BATCH_LOG, 0), RetryDecision::RETRY, CASS_CONSISTENCY_QUORUM, true); // Return error because a retry has already happened - check_decision(policy.on_write_timeout(CASS_CONSISTENCY_QUORUM, 3, 3, CASS_WRITE_TYPE_BATCH_LOG, 1), + check_decision(policy.on_write_timeout(NULL, CASS_CONSISTENCY_QUORUM, 3, 3, CASS_WRITE_TYPE_BATCH_LOG, 1), RetryDecision::RETURN_ERROR, CASS_CONSISTENCY_UNKNOWN, false); @@ -75,13 +75,13 @@ void check_default(cass::RetryPolicy& policy) { // Unavailable { // Retry with next host - check_decision(policy.on_unavailable(CASS_CONSISTENCY_QUORUM, 3, 3, 0), + check_decision(policy.on_unavailable(NULL, CASS_CONSISTENCY_QUORUM, 3, 3, 0), RetryDecision::RETRY, CASS_CONSISTENCY_QUORUM, false); // Return error because a retry has already happened - check_decision(policy.on_unavailable(CASS_CONSISTENCY_QUORUM, 3, 3, 1), + check_decision(policy.on_unavailable(NULL, CASS_CONSISTENCY_QUORUM, 3, 3, 1), RetryDecision::RETURN_ERROR, CASS_CONSISTENCY_UNKNOWN, false); @@ -103,37 +103,37 @@ BOOST_AUTO_TEST_CASE(downgrading) // Read timeout { // Retry because data wasn't present - check_decision(policy.on_read_timeout(CASS_CONSISTENCY_QUORUM, 3, 3, false, 0), + check_decision(policy.on_read_timeout(NULL, CASS_CONSISTENCY_QUORUM, 3, 3, false, 0), RetryDecision::RETRY, CASS_CONSISTENCY_QUORUM, true); // Downgrade consistency to three - check_decision(policy.on_read_timeout(CASS_CONSISTENCY_QUORUM, 3, 4, false, 0), + check_decision(policy.on_read_timeout(NULL, CASS_CONSISTENCY_QUORUM, 3, 4, false, 0), RetryDecision::RETRY, CASS_CONSISTENCY_THREE, true); // Downgrade consistency to two - check_decision(policy.on_read_timeout(CASS_CONSISTENCY_QUORUM, 2, 4, false, 0), + check_decision(policy.on_read_timeout(NULL, CASS_CONSISTENCY_QUORUM, 2, 4, false, 0), RetryDecision::RETRY, CASS_CONSISTENCY_TWO, true); // Downgrade consistency to one - check_decision(policy.on_read_timeout(CASS_CONSISTENCY_QUORUM, 1, 4, false, 0), + check_decision(policy.on_read_timeout(NULL, CASS_CONSISTENCY_QUORUM, 1, 4, false, 0), RetryDecision::RETRY, CASS_CONSISTENCY_ONE, true); // Return error because no copies - check_decision(policy.on_read_timeout(CASS_CONSISTENCY_QUORUM, 0, 4, false, 0), + check_decision(policy.on_read_timeout(NULL, CASS_CONSISTENCY_QUORUM, 0, 4, false, 0), RetryDecision::RETURN_ERROR, CASS_CONSISTENCY_UNKNOWN, false); // Return error because a retry has already happened - check_decision(policy.on_read_timeout(CASS_CONSISTENCY_QUORUM, 3, 3, false, 1), + check_decision(policy.on_read_timeout(NULL, CASS_CONSISTENCY_QUORUM, 3, 3, false, 1), RetryDecision::RETURN_ERROR, CASS_CONSISTENCY_UNKNOWN, false); @@ -142,37 +142,37 @@ BOOST_AUTO_TEST_CASE(downgrading) // Write timeout { // Ignore if at least one copy - check_decision(policy.on_write_timeout(CASS_CONSISTENCY_QUORUM, 1, 3, CASS_WRITE_TYPE_SIMPLE, 0), + check_decision(policy.on_write_timeout(NULL, CASS_CONSISTENCY_QUORUM, 1, 3, CASS_WRITE_TYPE_SIMPLE, 0), RetryDecision::IGNORE, CASS_CONSISTENCY_UNKNOWN, false); // Ignore if at least one copy - check_decision(policy.on_write_timeout(CASS_CONSISTENCY_QUORUM, 1, 3, CASS_WRITE_TYPE_BATCH, 0), + check_decision(policy.on_write_timeout(NULL, CASS_CONSISTENCY_QUORUM, 1, 3, CASS_WRITE_TYPE_BATCH, 0), RetryDecision::IGNORE, CASS_CONSISTENCY_UNKNOWN, false); // Return error if no copies - check_decision(policy.on_write_timeout(CASS_CONSISTENCY_QUORUM, 0, 3, CASS_WRITE_TYPE_SIMPLE, 0), + check_decision(policy.on_write_timeout(NULL, CASS_CONSISTENCY_QUORUM, 0, 3, CASS_WRITE_TYPE_SIMPLE, 0), RetryDecision::RETURN_ERROR, CASS_CONSISTENCY_UNKNOWN, false); // Downgrade consistency to two - check_decision(policy.on_write_timeout(CASS_CONSISTENCY_QUORUM, 2, 3, CASS_WRITE_TYPE_UNLOGGED_BATCH, 0), + check_decision(policy.on_write_timeout(NULL, CASS_CONSISTENCY_QUORUM, 2, 3, CASS_WRITE_TYPE_UNLOGGED_BATCH, 0), RetryDecision::RETRY, CASS_CONSISTENCY_TWO, true); // Retry because of batch log failed to write - check_decision(policy.on_write_timeout(CASS_CONSISTENCY_QUORUM, 3, 3, CASS_WRITE_TYPE_BATCH_LOG, 0), + check_decision(policy.on_write_timeout(NULL, CASS_CONSISTENCY_QUORUM, 3, 3, CASS_WRITE_TYPE_BATCH_LOG, 0), RetryDecision::RETRY, CASS_CONSISTENCY_QUORUM, true); // Return error because a retry has already happened - check_decision(policy.on_write_timeout(CASS_CONSISTENCY_QUORUM, 3, 3, CASS_WRITE_TYPE_BATCH_LOG, 1), + check_decision(policy.on_write_timeout(NULL, CASS_CONSISTENCY_QUORUM, 3, 3, CASS_WRITE_TYPE_BATCH_LOG, 1), RetryDecision::RETURN_ERROR, CASS_CONSISTENCY_UNKNOWN, false); @@ -181,13 +181,13 @@ BOOST_AUTO_TEST_CASE(downgrading) // Unavailable { // Retry with next host - check_decision(policy.on_unavailable(CASS_CONSISTENCY_QUORUM, 3, 2, 0), + check_decision(policy.on_unavailable(NULL, CASS_CONSISTENCY_QUORUM, 3, 2, 0), RetryDecision::RETRY, CASS_CONSISTENCY_TWO, true); // Return error because a retry has already happened - check_decision(policy.on_unavailable(CASS_CONSISTENCY_QUORUM, 3, 3, 1), + check_decision(policy.on_unavailable(NULL, CASS_CONSISTENCY_QUORUM, 3, 3, 1), RetryDecision::RETURN_ERROR, CASS_CONSISTENCY_UNKNOWN, false); @@ -200,17 +200,17 @@ BOOST_AUTO_TEST_CASE(fallthrough) // Always fail - check_decision(policy.on_read_timeout(CASS_CONSISTENCY_QUORUM, 3, 3, false, 0), + check_decision(policy.on_read_timeout(NULL, CASS_CONSISTENCY_QUORUM, 3, 3, false, 0), RetryDecision::RETURN_ERROR, CASS_CONSISTENCY_UNKNOWN, false); - check_decision(policy.on_write_timeout(CASS_CONSISTENCY_QUORUM, 3, 3, CASS_WRITE_TYPE_SIMPLE, 0), + check_decision(policy.on_write_timeout(NULL, CASS_CONSISTENCY_QUORUM, 3, 3, CASS_WRITE_TYPE_SIMPLE, 0), RetryDecision::RETURN_ERROR, CASS_CONSISTENCY_UNKNOWN, false); - check_decision(policy.on_unavailable(CASS_CONSISTENCY_QUORUM, 3, 3, 0), + check_decision(policy.on_unavailable(NULL, CASS_CONSISTENCY_QUORUM, 3, 3, 0), RetryDecision::RETURN_ERROR, CASS_CONSISTENCY_UNKNOWN, false); From ecdd469938de481a30905f286110b31f72fa09f0 Mon Sep 17 00:00:00 2001 From: Michael Penick Date: Fri, 30 Sep 2016 16:09:36 -0700 Subject: [PATCH 6/7] Speculative execution fixes --- examples/perf/perf.c | 6 +- src/connection.cpp | 139 +++++++++--------- src/connection.hpp | 20 +-- src/control_connection.hpp | 11 +- src/future.cpp | 11 +- src/io_worker.cpp | 2 +- src/pool.cpp | 33 ++--- src/request_callback.cpp | 79 ++++++---- src/request_callback.hpp | 95 ++++++------ src/request_handler.cpp | 89 +++++------ src/request_handler.hpp | 51 ++++--- src/session.cpp | 6 +- .../src/test_consistency.cpp | 7 +- test/integration_tests/src/test_future.cpp | 4 +- test/integration_tests/src/test_ssl.cpp | 5 +- 15 files changed, 288 insertions(+), 270 deletions(-) diff --git a/examples/perf/perf.c b/examples/perf/perf.c index 8eb8b7d3e..539b1129f 100644 --- a/examples/perf/perf.c +++ b/examples/perf/perf.c @@ -198,6 +198,8 @@ void insert_into_perf(CassSession* session, const char* query, const CassPrepare statement = cass_statement_new(query, 5); } + cass_statement_set_is_idempotent(statement, cass_true); + cass_uuid_gen_time(uuid_gen, &id); cass_statement_bind_uuid(statement, 0, id); cass_statement_bind_string(statement, 1, big_string); @@ -322,8 +324,10 @@ int main(int argc, char* argv[]) { return -1; } + execute_query(session, "DROP KEYSPACE stress"); + execute_query(session, "CREATE KEYSPACE IF NOT EXISTS stress WITH " - "replication = { 'class': 'SimpleStrategy', 'replication_factor': '1'}"); + "replication = { 'class': 'SimpleStrategy', 'replication_factor': '3'}"); execute_query(session, "CREATE TABLE IF NOT EXISTS stress.songs (id uuid PRIMARY KEY, " "title text, album text, artist text, " diff --git a/src/connection.cpp b/src/connection.cpp index c9ab928a8..fe3869490 100644 --- a/src/connection.cpp +++ b/src/connection.cpp @@ -68,47 +68,43 @@ static void cleanup_pending_callbacks(List* pending) { switch (callback->state()) { case RequestCallback::REQUEST_STATE_NEW: case RequestCallback::REQUEST_STATE_FINISHED: - assert(false && "Request state is invalid in cleanup"); - break; - case RequestCallback::REQUEST_STATE_CANCELLED: - callback->finish(); + assert(false && "Request state is invalid in cleanup"); break; - case RequestCallback::REQUEST_STATE_CANCELLED_READ_BEFORE_WRITE: - case RequestCallback::REQUEST_STATE_CANCELLED_WRITE_OUTSTANDING: - callback->set_state(RequestCallback::REQUEST_STATE_CANCELLED); - callback->finish(); + case RequestCallback::REQUEST_STATE_READ_BEFORE_WRITE: + callback->set_state(RequestCallback::REQUEST_STATE_FINISHED); + // Use the response saved in the read callback + callback->on_set(callback->read_before_write_response()); break; case RequestCallback::REQUEST_STATE_WRITING: case RequestCallback::REQUEST_STATE_READING: - case RequestCallback::REQUEST_STATE_READ_BEFORE_WRITE: + callback->set_state(RequestCallback::REQUEST_STATE_FINISHED); if (callback->request()->is_idempotent()) { - callback->set_state(RequestCallback::REQUEST_STATE_FINISHED); - callback->finish_with_retry(true); + callback->on_retry(true); } else { - callback->set_state(RequestCallback::REQUEST_STATE_FINISHED); callback->on_error(CASS_ERROR_LIB_REQUEST_TIMED_OUT, "Request timed out"); - callback->finish(); } break; - case RequestCallback::REQUEST_STATE_RETRY_WRITE_OUTSTANDING: - callback->set_state(RequestCallback::REQUEST_STATE_FINISHED); - callback->finish_with_retry(false); + case RequestCallback::REQUEST_STATE_CANCELLED_WRITING: + case RequestCallback::REQUEST_STATE_CANCELLED_READING: + case RequestCallback::REQUEST_STATE_CANCELLED_READ_BEFORE_WRITE: + callback->set_state(RequestCallback::REQUEST_STATE_CANCELLED); + callback->on_cancel(); break; } + + callback->dec_ref(); } } -Connection::StartupCallback::StartupCallback(Connection* connection, const Request::ConstPtr& request) - : SimpleRequestCallback(connection->loop(), - connection->config().request_timeout_ms(), - request) { } +Connection::StartupCallback::StartupCallback(const Request::ConstPtr& request) + : SimpleRequestCallback(request) { } -void Connection::StartupCallback::on_set(ResponseMessage* response) { +void Connection::StartupCallback::on_internal_set(ResponseMessage* response) { switch (response->opcode()) { case CQL_OPCODE_SUPPORTED: connection()->on_supported(response); @@ -165,15 +161,15 @@ void Connection::StartupCallback::on_set(ResponseMessage* response) { } } -void Connection::StartupCallback::on_error(CassError code, - const std::string& message) { +void Connection::StartupCallback::on_internal_error(CassError code, + const std::string& message) { std::ostringstream ss; ss << "Error: '" << message << "' (0x" << std::hex << std::uppercase << std::setw(8) << std::setfill('0') << code << ")"; connection()->notify_error(ss.str()); } -void Connection::StartupCallback::on_timeout() { +void Connection::StartupCallback::on_internal_timeout() { if (!connection()->is_closing()) { connection()->notify_error("Timed out", CONNECTION_ERROR_TIMEOUT); } @@ -192,25 +188,23 @@ void Connection::StartupCallback::on_result_response(ResponseMessage* response) } } -Connection::HeartbeatCallback::HeartbeatCallback(Connection* connection) - : SimpleRequestCallback(connection->loop(), - connection->config().request_timeout_ms(), - Request::ConstPtr(new OptionsRequest())) { } +Connection::HeartbeatCallback::HeartbeatCallback() + : SimpleRequestCallback(Request::ConstPtr(new OptionsRequest())) { } -void Connection::HeartbeatCallback::on_set(ResponseMessage* response) { +void Connection::HeartbeatCallback::on_internal_set(ResponseMessage* response) { LOG_TRACE("Heartbeat completed on host %s", connection()->address_string().c_str()); connection()->heartbeat_outstanding_ = false; } -void Connection::HeartbeatCallback::on_error(CassError code, const std::string& message) { +void Connection::HeartbeatCallback::on_internal_error(CassError code, const std::string& message) { LOG_WARN("An error occurred on host %s during a heartbeat request: %s", connection()->address_string().c_str(), message.c_str()); connection()->heartbeat_outstanding_ = false; } -void Connection::HeartbeatCallback::on_timeout() { +void Connection::HeartbeatCallback::on_internal_timeout() { LOG_WARN("Heartbeat request timed out on host %s", connection()->address_string().c_str()); connection()->heartbeat_outstanding_ = false; @@ -310,6 +304,7 @@ int32_t Connection::internal_write(const RequestCallback::Ptr& callback, bool fl int32_t request_size = pending_write->write(callback.get()); if (request_size < 0) { stream_manager_.release(stream); + switch (request_size) { case Request::REQUEST_ERROR_BATCH_WITH_NAMED_VALUES: case Request::REQUEST_ERROR_PARAMETER_UNSET: @@ -318,10 +313,11 @@ int32_t Connection::internal_write(const RequestCallback::Ptr& callback, bool fl default: callback->on_error(CASS_ERROR_LIB_MESSAGE_ENCODE, - "Operation unsupported by this protocol version"); + "Operation unsupported by this protocol version"); break; } - callback->finish(); + + callback->dec_ref(); return request_size; } @@ -335,8 +331,10 @@ int32_t Connection::internal_write(const RequestCallback::Ptr& callback, bool fl set_state(CONNECTION_STATE_OVERWHELMED); } - LOG_TRACE("Sending message type %s with stream %d", - opcode_to_string(callback->request()->opcode()).c_str(), stream); + LOG_TRACE("Sending message type %s with stream %d on host %s", + opcode_to_string(callback->request()->opcode()).c_str(), + stream, + address_string().c_str()); callback->set_state(RequestCallback::REQUEST_STATE_WRITING); if (flush_immediately) { @@ -500,29 +498,33 @@ void Connection::consume(char* input, size_t size) { switch (callback->state()) { case RequestCallback::REQUEST_STATE_READING: - maybe_set_keyspace(response.get()); pending_reads_.remove(callback.get()); callback->set_state(RequestCallback::REQUEST_STATE_FINISHED); + maybe_set_keyspace(response.get()); callback->on_set(response.get()); - callback->finish(); + callback->dec_ref(); break; case RequestCallback::REQUEST_STATE_WRITING: // There are cases when the read callback will happen // before the write callback. If this happens we have - // to allow the write callback to cleanup. - maybe_set_keyspace(response.get()); + // to allow the write callback to finish the request. callback->set_state(RequestCallback::REQUEST_STATE_READ_BEFORE_WRITE); - callback->on_set(response.get()); + // Save the response for the write callback + callback->set_read_before_write_response(response.release()); // Transfer ownership break; - case RequestCallback::REQUEST_STATE_CANCELLED: + case RequestCallback::REQUEST_STATE_CANCELLED_READING: pending_reads_.remove(callback.get()); - callback->finish(); + callback->set_state(RequestCallback::REQUEST_STATE_CANCELLED); + callback->on_cancel(); + callback->dec_ref(); break; - case RequestCallback::REQUEST_STATE_CANCELLED_WRITE_OUTSTANDING: - // We must wait for the write callback before we can do the cleanup + case RequestCallback::REQUEST_STATE_CANCELLED_WRITING: + // There are cases when the read callback will happen + // before the write callback. If this happens we have + // to allow the write callback to finish the request. callback->set_state(RequestCallback::REQUEST_STATE_CANCELLED_READ_BEFORE_WRITE); break; @@ -759,8 +761,7 @@ void Connection::on_read_ssl(uv_stream_t* client, ssize_t nread, const uv_buf_t* void Connection::on_connected() { internal_write(RequestCallback::Ptr( - new StartupCallback(this, - Request::ConstPtr( + new StartupCallback(Request::ConstPtr( new OptionsRequest())))); } @@ -780,8 +781,7 @@ void Connection::on_auth_challenge(const AuthResponseRequest* request, return; } internal_write(RequestCallback::Ptr( - new StartupCallback(this, - Request::ConstPtr( + new StartupCallback(Request::ConstPtr( new AuthResponseRequest(response, request->auth()))))); } @@ -798,8 +798,7 @@ void Connection::on_ready() { if (state_ == CONNECTION_STATE_CONNECTED && listener_->event_types() != 0) { set_state(CONNECTION_STATE_REGISTERING_EVENTS); internal_write(RequestCallback::Ptr( - new StartupCallback(this, - Request::ConstPtr( + new StartupCallback(Request::ConstPtr( new RegisterRequest(listener_->event_types()))))); return; } @@ -808,8 +807,7 @@ void Connection::on_ready() { notify_ready(); } else { internal_write(RequestCallback::Ptr( - new StartupCallback(this, - Request::ConstPtr( + new StartupCallback(Request::ConstPtr( new QueryRequest("USE \"" + keyspace_ + "\""))))); } } @@ -826,8 +824,7 @@ void Connection::on_supported(ResponseMessage* response) { (void)supported; internal_write(RequestCallback::Ptr( - new StartupCallback(this, - Request::ConstPtr( + new StartupCallback(Request::ConstPtr( new StartupRequest())))); } @@ -896,8 +893,7 @@ void Connection::send_credentials(const std::string& class_name) { V1Authenticator::Credentials credentials; v1_auth->get_credentials(&credentials); internal_write(RequestCallback::Ptr( - new StartupCallback(this, - Request::ConstPtr( + new StartupCallback(Request::ConstPtr( new CredentialsRequest(credentials))))); } else { send_initial_auth_response(class_name); @@ -915,8 +911,7 @@ void Connection::send_initial_auth_response(const std::string& class_name) { return; } internal_write(RequestCallback::Ptr( - new StartupCallback(this, - Request::ConstPtr( + new StartupCallback(Request::ConstPtr( new AuthResponseRequest(response, auth))))); } } @@ -933,7 +928,7 @@ void Connection::on_heartbeat(Timer* timer) { Connection* connection = static_cast(timer->data()); if (!connection->heartbeat_outstanding_) { - if (!connection->internal_write(RequestCallback::Ptr(new HeartbeatCallback(connection)))) { + if (!connection->internal_write(RequestCallback::Ptr(new HeartbeatCallback()))) { // Recycling only this connection with a timeout error. This is unlikely and // it means the connection ran out of stream IDs as a result of requests // that never returned and as a result timed out. @@ -1021,34 +1016,32 @@ void Connection::PendingWriteBase::on_write(uv_write_t* req, int status) { connection->stream_manager_.release(callback->stream()); callback->set_state(RequestCallback::REQUEST_STATE_FINISHED); callback->on_error(CASS_ERROR_LIB_WRITE_ERROR, - "Unable to write to socket"); - callback->finish(); + "Unable to write to socket"); + callback->dec_ref(); } break; case RequestCallback::REQUEST_STATE_READ_BEFORE_WRITE: // The read callback happened before the write callback - // returned. This is now responsible for cleanup. + // returned. This is now responsible for finishing the request. callback->set_state(RequestCallback::REQUEST_STATE_FINISHED); - callback->finish(); + // Use the response saved in the read callback + connection->maybe_set_keyspace(callback->read_before_write_response()); + callback->on_set(callback->read_before_write_response()); + callback->dec_ref(); break; - case RequestCallback::REQUEST_STATE_RETRY_WRITE_OUTSTANDING: - callback->set_state(RequestCallback::REQUEST_STATE_FINISHED); - callback->finish_with_retry(false); + case RequestCallback::REQUEST_STATE_CANCELLED_WRITING: + callback->set_state(RequestCallback::REQUEST_STATE_CANCELLED_READING); + connection->pending_reads_.add_to_back(callback.get()); break; case RequestCallback::REQUEST_STATE_CANCELLED_READ_BEFORE_WRITE: // The read callback happened before the write callback // returned. This is now responsible for cleanup. callback->set_state(RequestCallback::REQUEST_STATE_CANCELLED); - callback->finish(); - break; - - case RequestCallback::REQUEST_STATE_CANCELLED_WRITE_OUTSTANDING: - // The read may still come back, handle cleanup there - callback->set_state(RequestCallback::REQUEST_STATE_CANCELLED); - connection->pending_reads_.add_to_back(callback.get()); + callback->on_cancel(); + callback->dec_ref(); break; default: diff --git a/src/connection.hpp b/src/connection.hpp index 955d882ff..6e28bb1d4 100644 --- a/src/connection.hpp +++ b/src/connection.hpp @@ -111,7 +111,6 @@ class Connection { uv_loop_t* loop() { return loop_; } const Config& config() const { return config_; } - Metrics* metrics() { return metrics_; } const Address& address() const { return host_->address(); } const std::string& address_string() const { return host_->address_string(); } const std::string& keyspace() const { return keyspace_; } @@ -166,23 +165,24 @@ class Connection { class StartupCallback : public SimpleRequestCallback { public: - StartupCallback(Connection* connection, const Request::ConstPtr& request); - - virtual void on_set(ResponseMessage* response); - virtual void on_error(CassError code, const std::string& message); - virtual void on_timeout(); + StartupCallback(const Request::ConstPtr& request); private: + virtual void on_internal_set(ResponseMessage* response); + virtual void on_internal_error(CassError code, const std::string& message); + virtual void on_internal_timeout(); + void on_result_response(ResponseMessage* response); }; class HeartbeatCallback : public SimpleRequestCallback { public: - HeartbeatCallback(Connection* connection); + HeartbeatCallback(); - virtual void on_set(ResponseMessage* response); - virtual void on_error(CassError code, const std::string& message); - virtual void on_timeout(); + private: + virtual void on_internal_set(ResponseMessage* response); + virtual void on_internal_error(CassError code, const std::string& message); + virtual void on_internal_timeout(); }; class PendingWriteBase : public List::Node { diff --git a/src/control_connection.hpp b/src/control_connection.hpp index fc23a82df..1ba4542b4 100644 --- a/src/control_connection.hpp +++ b/src/control_connection.hpp @@ -125,14 +125,13 @@ class ControlConnection : public Connection::Listener { ControlConnection* control_connection, ResponseCallback response_callback, const T& data) - : SimpleRequestCallback(control_connection->connection_->loop(), - control_connection->connection_->config().request_timeout_ms(), - request) + : SimpleRequestCallback(request) , control_connection_(control_connection) , response_callback_(response_callback) , data_(data) { } - virtual void on_set(ResponseMessage* response) { + private: + virtual void on_internal_set(ResponseMessage* response) { Response* response_body = response->response_body().get(); if (control_connection_->handle_query_invalid_response(response_body)) { return; @@ -140,11 +139,11 @@ class ControlConnection : public Connection::Listener { response_callback_(control_connection_, data_, response_body); } - virtual void on_error(CassError code, const std::string& message) { + virtual void on_internal_error(CassError code, const std::string& message) { control_connection_->handle_query_failure(code, message); } - virtual void on_timeout() { + virtual void on_internal_timeout() { control_connection_->handle_query_timeout(); } diff --git a/src/future.cpp b/src/future.cpp index 42e934cd1..6e17e9e71 100644 --- a/src/future.cpp +++ b/src/future.cpp @@ -54,12 +54,15 @@ const CassResult* cass_future_get_result(CassFuture* future) { return NULL; } - cass::SharedRefPtr result( + cass::Response::Ptr response( static_cast(future->from())->response()); - if (!result) return NULL; + if (!response || response->opcode() == CQL_OPCODE_ERROR) { + return NULL; + } - result->inc_ref(); - return CassResult::to(result.get()); + response->inc_ref(); + return CassResult::to( + static_cast(response.get())); } const CassPrepared* cass_future_get_prepared(CassFuture* future) { diff --git a/src/io_worker.cpp b/src/io_worker.cpp index 9a297cc8a..f2ab630c9 100644 --- a/src/io_worker.cpp +++ b/src/io_worker.cpp @@ -291,7 +291,7 @@ void IOWorker::on_execute(uv_async_t* async) { io_worker->pending_request_count_++; request_handler->start_request(io_worker); SpeculativeExecution::Ptr speculative_execution(new SpeculativeExecution(request_handler, - request_handler->first_host())); + request_handler->current_host())); speculative_execution->execute(); } else { io_worker->state_ = IO_WORKER_STATE_CLOSING; diff --git a/src/pool.cpp b/src/pool.cpp index a9bad9301..1482ce7e6 100644 --- a/src/pool.cpp +++ b/src/pool.cpp @@ -36,30 +36,26 @@ static bool least_busy_comp(Connection* a, Connection* b) { class SetKeyspaceCallback : public SimpleRequestCallback { public: - SetKeyspaceCallback(Connection* connection, - const std::string& keyspace, - const SpeculativeExecution::Ptr& speculative_execution); - - virtual void on_set(ResponseMessage* response); - virtual void on_error(CassError code, const std::string& message); - virtual void on_timeout(); + SetKeyspaceCallback(const std::string& keyspace, + const SpeculativeExecution::Ptr& speculative_execution); private: + virtual void on_internal_set(ResponseMessage* response); + virtual void on_internal_error(CassError code, const std::string& message); + virtual void on_internal_timeout(); + void on_result_response(ResponseMessage* response); private: SpeculativeExecution::Ptr speculative_execution_; }; -SetKeyspaceCallback::SetKeyspaceCallback(Connection* connection, - const std::string& keyspace, - const SpeculativeExecution::Ptr& speculative_execution) - : SimpleRequestCallback(connection->loop(), - connection->config().request_timeout_ms(), - Request::ConstPtr(new QueryRequest("USE \"" + keyspace + "\""))) +SetKeyspaceCallback::SetKeyspaceCallback(const std::string& keyspace, + const SpeculativeExecution::Ptr& speculative_execution) + : SimpleRequestCallback(Request::ConstPtr(new QueryRequest("USE \"" + keyspace + "\""))) , speculative_execution_(speculative_execution) { } -void SetKeyspaceCallback::on_set(ResponseMessage* response) { +void SetKeyspaceCallback::on_internal_set(ResponseMessage* response) { switch (response->opcode()) { case CQL_OPCODE_RESULT: on_result_response(response); @@ -74,13 +70,13 @@ void SetKeyspaceCallback::on_set(ResponseMessage* response) { } } -void SetKeyspaceCallback::on_error(CassError code, const std::string& message) { +void SetKeyspaceCallback::on_internal_error(CassError code, const std::string& message) { connection()->defunct(); speculative_execution_->on_error(CASS_ERROR_LIB_UNABLE_TO_SET_KEYSPACE, "Unable to set keyspace"); } -void SetKeyspaceCallback::on_timeout() { +void SetKeyspaceCallback::on_internal_timeout() { speculative_execution_->retry_next_host(); } @@ -257,8 +253,9 @@ bool Pool::write(Connection* connection, const SpeculativeExecution::Ptr& specul static_cast(connection), static_cast(this)); if (!connection->write(RequestCallback::Ptr( - new SetKeyspaceCallback(connection, *io_worker_->keyspace(), - speculative_execution)), false)) { + new SetKeyspaceCallback(*io_worker_->keyspace(), + speculative_execution)), + false)) { return false; } } diff --git a/src/request_callback.cpp b/src/request_callback.cpp index f73b6419a..126db66ce 100644 --- a/src/request_callback.cpp +++ b/src/request_callback.cpp @@ -27,6 +27,12 @@ namespace cass { +void RequestCallback::start(Connection* connection, int stream) { + connection_ = connection; + stream_ = stream; + on_start(); +} + int32_t RequestCallback::encode(int version, int flags, BufferVec* bufs) { if (version < 1 || version > 4) { return Request::REQUEST_ERROR_UNSUPPORTED_PROTOCOL; @@ -69,6 +75,8 @@ int32_t RequestCallback::encode(int version, int flags, BufferVec* bufs) { } void RequestCallback::set_state(RequestCallback::State next_state) { + state_history_.push_back(state_); + switch (state_) { case REQUEST_STATE_NEW: if (next_state == REQUEST_STATE_NEW || @@ -86,24 +94,24 @@ void RequestCallback::set_state(RequestCallback::State next_state) { next_state == REQUEST_STATE_FINISHED) { state_ = next_state; } else if (next_state == REQUEST_STATE_CANCELLED) { - state_ = REQUEST_STATE_CANCELLED_WRITE_OUTSTANDING; + state_ = REQUEST_STATE_CANCELLED_WRITING; } else { assert(false && "Invalid request state after writing"); } break; case REQUEST_STATE_READING: - if(next_state == REQUEST_STATE_FINISHED || - next_state == REQUEST_STATE_CANCELLED) { + if(next_state == REQUEST_STATE_FINISHED) { state_ = next_state; + } else if (next_state == REQUEST_STATE_CANCELLED) { + state_ = REQUEST_STATE_CANCELLED_READING; } else { assert(false && "Invalid request state after reading"); } break; case REQUEST_STATE_READ_BEFORE_WRITE: - if (next_state == REQUEST_STATE_RETRY_WRITE_OUTSTANDING || - next_state == REQUEST_STATE_FINISHED) { + if (next_state == REQUEST_STATE_FINISHED) { state_ = next_state; } else if (next_state == REQUEST_STATE_CANCELLED) { state_ = REQUEST_STATE_CANCELLED_READ_BEFORE_WRITE; @@ -112,16 +120,6 @@ void RequestCallback::set_state(RequestCallback::State next_state) { } break; - case REQUEST_STATE_RETRY_WRITE_OUTSTANDING: - if (next_state == REQUEST_STATE_FINISHED) { - state_ = next_state; - } else if (next_state == REQUEST_STATE_CANCELLED) { - state_ = REQUEST_STATE_CANCELLED_WRITE_OUTSTANDING; - } else { - assert(false && "Invalid request state after retry"); - } - break; - case REQUEST_STATE_FINISHED: if (next_state == REQUEST_STATE_NEW || next_state == REQUEST_STATE_CANCELLED) { @@ -146,8 +144,17 @@ void RequestCallback::set_state(RequestCallback::State next_state) { assert(false && "Invalid request state after cancelled (read before write)"); } - case REQUEST_STATE_CANCELLED_WRITE_OUTSTANDING: + case REQUEST_STATE_CANCELLED_READING: + if (next_state == REQUEST_STATE_CANCELLED) { + state_ = next_state; + } else { + assert(false && "Invalid request state after cancelled (read outstanding)"); + } + break; + + case REQUEST_STATE_CANCELLED_WRITING: if (next_state == REQUEST_STATE_CANCELLED || + next_state == REQUEST_STATE_CANCELLED_READING || next_state == REQUEST_STATE_CANCELLED_READ_BEFORE_WRITE) { state_ = next_state; } else { @@ -187,42 +194,56 @@ void MultipleRequestCallback::execute_query(const std::string& index, const std: MultipleRequestCallback::InternalCallback::InternalCallback(const MultipleRequestCallback::Ptr& parent, const Request::ConstPtr& request, const std::string& index) - : SimpleRequestCallback(parent->connection()->loop(), - parent->connection()->config().request_timeout_ms(), - request) + : SimpleRequestCallback(request) , parent_(parent) , index_(index) { } -void MultipleRequestCallback::InternalCallback::on_set(ResponseMessage* response) { +void MultipleRequestCallback::InternalCallback::on_internal_set(ResponseMessage* response) { parent_->responses_[index_] = response->response_body(); if (--parent_->remaining_ == 0 && !parent_->has_errors_or_timeouts_) { parent_->on_set(parent_->responses_); } } -void MultipleRequestCallback::InternalCallback::on_error(CassError code, const std::string& message) { +void MultipleRequestCallback::InternalCallback::on_internal_error(CassError code, + const std::string& message) { if (!parent_->has_errors_or_timeouts_) { parent_->on_error(code, message); } parent_->has_errors_or_timeouts_ = true; } -void MultipleRequestCallback::InternalCallback::on_timeout() { +void MultipleRequestCallback::InternalCallback::on_internal_timeout() { if (!parent_->has_errors_or_timeouts_) { parent_->on_timeout(); } parent_->has_errors_or_timeouts_ = true; } -SimpleRequestCallback::SimpleRequestCallback(uv_loop_t* loop, - uint64_t request_timeout_ms, - const Request::ConstPtr& request) - : RequestCallback() - , request_(request) { - timer_.start(loop, - request->request_timeout_ms(request_timeout_ms), +void SimpleRequestCallback::on_start() { + timer_.start(connection()->loop(), + request()->request_timeout_ms(connection()->config().request_timeout_ms()), this, on_timeout); } +void SimpleRequestCallback::on_set(ResponseMessage* response) { + timer_.stop(); + on_internal_set(response); +} + +void SimpleRequestCallback::on_error(CassError code, const std::string& message) { + timer_.stop(); + on_internal_error(code, message); +} + +void SimpleRequestCallback::on_retry(bool use_next_host) { + timer_.stop(); + on_internal_timeout(); // Retries are unhandled so timeout +} + +void SimpleRequestCallback::on_cancel() { + timer_.stop(); +} + } // namespace cass diff --git a/src/request_callback.hpp b/src/request_callback.hpp index f78a6b497..9fba6e0d2 100644 --- a/src/request_callback.hpp +++ b/src/request_callback.hpp @@ -48,11 +48,11 @@ class RequestCallback : public RefCounted, public List, public Listtimestamp(); } @@ -106,20 +89,31 @@ class RequestCallback : public RefCounted, public List read_before_write_response_; + + typedef std::vector StateVec; + StateVec state_history_; private: DISALLOW_COPY_AND_ASSIGN(RequestCallback); @@ -127,33 +121,31 @@ class RequestCallback : public RefCounted, public List(timer->data()); callback->set_state(RequestCallback::REQUEST_STATE_CANCELLED); - callback->on_timeout(); + callback->on_internal_timeout(); } private: @@ -193,9 +185,10 @@ class MultipleRequestCallback : public RefCounted { const Request::ConstPtr& request, const std::string& index); - virtual void on_set(ResponseMessage* response); - virtual void on_error(CassError code, const std::string& message); - virtual void on_timeout(); + private: + virtual void on_internal_set(ResponseMessage* response); + virtual void on_internal_error(CassError code, const std::string& message); + virtual void on_internal_timeout(); private: MultipleRequestCallback::Ptr parent_; diff --git a/src/request_handler.cpp b/src/request_handler.cpp index b431c780e..44c62aa67 100644 --- a/src/request_handler.cpp +++ b/src/request_handler.cpp @@ -37,20 +37,19 @@ namespace cass { class PrepareCallback : public SimpleRequestCallback { public: PrepareCallback(const std::string& query, SpeculativeExecution* speculative_execution) - : SimpleRequestCallback(speculative_execution->connection()->loop(), - speculative_execution->connection()->config().request_timeout_ms(), - Request::ConstPtr(new PrepareRequest(query))) + : SimpleRequestCallback(Request::ConstPtr(new PrepareRequest(query))) , speculative_execution_(speculative_execution) { } - virtual void on_set(ResponseMessage* response); - virtual void on_error(CassError code, const std::string& message); - virtual void on_timeout(); +private: + virtual void on_internal_set(ResponseMessage* response); + virtual void on_internal_error(CassError code, const std::string& message); + virtual void on_internal_timeout(); private: SpeculativeExecution::Ptr speculative_execution_; }; -void PrepareCallback::on_set(ResponseMessage* response) { +void PrepareCallback::on_internal_set(ResponseMessage* response) { switch (response->opcode()) { case CQL_OPCODE_RESULT: { ResultResponse* result = @@ -69,11 +68,11 @@ void PrepareCallback::on_set(ResponseMessage* response) { } } -void PrepareCallback::on_error(CassError code, const std::string& message) { +void PrepareCallback::on_internal_error(CassError code, const std::string& message) { speculative_execution_->retry_next_host(); } -void PrepareCallback::on_timeout() { +void PrepareCallback::on_internal_timeout() { speculative_execution_->retry_next_host(); } @@ -103,6 +102,7 @@ void RequestHandler::start_request(IOWorker* io_worker) { void RequestHandler::set_response(const Host::Ptr& host, const Response::Ptr& response) { if (future_->set_response(host->address(), response)) { + io_worker_->metrics()->record_request(uv_hrtime() - start_time_ns_); stop_request(); } } @@ -139,6 +139,8 @@ void RequestHandler::set_error_with_error_response(const Host::Ptr& host, void RequestHandler::on_timeout(Timer* timer) { RequestHandler* request_handler = static_cast(timer->data()); + LOG_DEBUG("Request timed out on host %s", + request_handler->current_host_->address_string().c_str()); request_handler->set_error(CASS_ERROR_LIB_REQUEST_TIMED_OUT, "Request timed out"); } @@ -174,9 +176,16 @@ void SpeculativeExecution::on_execute(Timer* timer) { speculative_execution->execute(); } +void SpeculativeExecution::on_start() { + start_time_ns_ = uv_hrtime(); +} + void SpeculativeExecution::on_set(ResponseMessage* response) { assert(connection() != NULL); assert(current_host_ && "Tried to set on a non-existent host"); + + return_connection(); + switch (response->opcode()) { case CQL_OPCODE_RESULT: on_result_response(response); @@ -192,6 +201,8 @@ void SpeculativeExecution::on_set(ResponseMessage* response) { } void SpeculativeExecution::on_error(CassError code, const std::string& message) { + return_connection(); + // Handle recoverable errors by retrying with the next host if (code == CASS_ERROR_LIB_WRITE_ERROR || code == CASS_ERROR_LIB_UNABLE_TO_SET_KEYSPACE) { @@ -201,15 +212,7 @@ void SpeculativeExecution::on_error(CassError code, const std::string& message) } } -void SpeculativeExecution::on_start() { - start_time_ns_ = uv_hrtime(); -} - -void SpeculativeExecution::on_finish() { - return_connection(); -} - -void SpeculativeExecution::on_finish_with_retry(bool use_next_host) { +void SpeculativeExecution::on_retry(bool use_next_host) { return_connection(); if (use_next_host) { @@ -219,13 +222,8 @@ void SpeculativeExecution::on_finish_with_retry(bool use_next_host) { } } -void SpeculativeExecution::start_pending_request(Pool* pool, Timer::Callback cb) { - pool_ = pool; - pending_request_timer_.start(pool->loop(), pool->config().connect_timeout_ms(), this, cb); -} - -void SpeculativeExecution::stop_pending_request() { - pending_request_timer_.stop(); +void SpeculativeExecution::on_cancel() { + return_connection(); } void SpeculativeExecution::retry_current_host() { @@ -244,6 +242,15 @@ void SpeculativeExecution::retry_next_host() { retry_current_host(); } +void SpeculativeExecution::start_pending_request(Pool* pool, Timer::Callback cb) { + pool_ = pool; + pending_request_timer_.start(pool->loop(), pool->config().connect_timeout_ms(), this, cb); +} + +void SpeculativeExecution::stop_pending_request() { + pending_request_timer_.stop(); +} + void SpeculativeExecution::execute() { if (request()->is_idempotent()) { request_handler_->schedule_next_execution(current_host_); @@ -269,15 +276,19 @@ void SpeculativeExecution::cancel() { void SpeculativeExecution::on_result_response(ResponseMessage* response) { ResultResponse* result = static_cast(response->response_body().get()); + switch (result->kind()) { case CASS_RESULT_KIND_ROWS: + current_host_->update_latency(uv_hrtime() - start_time_ns_); + // Execute statements with no metadata get their metadata from // result_metadata() returned when the statement was prepared. if (request()->opcode() == CQL_OPCODE_EXECUTE && result->no_metadata()) { const ExecuteRequest* execute = static_cast(request()); if (!execute->skip_metadata()) { // Caused by a race condition in C* 2.1.0 - on_error(CASS_ERROR_LIB_UNEXPECTED_RESPONSE, "Expected metadata but no metadata in response (see CASSANDRA-8054)"); + on_error(CASS_ERROR_LIB_UNEXPECTED_RESPONSE, + "Expected metadata but no metadata in response (see CASSANDRA-8054)"); return; } result->set_metadata(execute->prepared()->result()->result_metadata().get()); @@ -288,8 +299,8 @@ void SpeculativeExecution::on_result_response(ResponseMessage* response) { case CASS_RESULT_KIND_SCHEMA_CHANGE: { SchemaChangeCallback::Ptr schema_change_handler( new SchemaChangeCallback(connection(), - Ptr(this), - response->response_body())); + Ptr(this), + response->response_body())); schema_change_handler->execute(); break; } @@ -391,13 +402,10 @@ void SpeculativeExecution::on_error_response(ResponseMessage* response) { case RetryPolicy::RetryDecision::RETRY: set_consistency(decision.retry_consistency()); - if (!decision.retry_current_host()) { - next_host(); - } - if (state() == REQUEST_STATE_FINISHED) { + if (decision.retry_current_host()) { retry_current_host(); } else { - set_state(REQUEST_STATE_RETRY_WRITE_OUTSTANDING); + retry_next_host(); } num_retries_++; break; @@ -436,14 +444,17 @@ void SpeculativeExecution::on_error_unprepared(ErrorResponse* error) { } } +void SpeculativeExecution::return_connection() { + if (pool_ != NULL && connection() != NULL) { + pool_->return_connection(connection()); + } +} + bool SpeculativeExecution::is_host_up(const Address& address) const { return request_handler_->io_worker()->is_host_up(address); } void SpeculativeExecution::set_response(const Response::Ptr& response) { - uint64_t elapsed = uv_hrtime() - start_time_ns_; - current_host_->update_latency(elapsed); - connection()->metrics()->record_request(elapsed); request_handler_->set_response(current_host_, response); } @@ -456,10 +467,4 @@ void SpeculativeExecution::set_error_with_error_response(const Response::Ptr& er request_handler_->set_error_with_error_response(current_host_, error, code, message); } -void SpeculativeExecution::return_connection() { - if (pool_ != NULL && connection() != NULL) { - pool_->return_connection(connection()); - } -} - } // namespace cass diff --git a/src/request_handler.hpp b/src/request_handler.hpp index 07aeeb431..8ab30abec 100644 --- a/src/request_handler.hpp +++ b/src/request_handler.hpp @@ -115,7 +115,8 @@ class RequestHandler : public RefCounted { , future_(future) , retry_policy_(retry_policy) , io_worker_(NULL) - , running_executions_(0) { } + , running_executions_(0) + , start_time_ns_(uv_hrtime()) { } const Request* request() const { return request_.get(); } @@ -126,17 +127,17 @@ class RequestHandler : public RefCounted { RetryPolicy* retry_policy() { return retry_policy_; } - void set_query_plan(QueryPlan* query_plan) { - query_plan_.reset(query_plan); - first_host_ = next_host(); - } + void set_query_plan(QueryPlan* query_plan) { query_plan_.reset(query_plan); } void set_execution_plan(SpeculativeExecutionPlan* execution_plan) { execution_plan_.reset(execution_plan); } - const Host::Ptr& first_host() const { return first_host_; } - const Host::Ptr next_host() { return query_plan_->compute_next(); } + const Host::Ptr& current_host() const { return current_host_; } + const Host::Ptr& next_host() { + current_host_ = query_plan_->compute_next(); + return current_host_; + } IOWorker* io_worker() { return io_worker_; } @@ -158,7 +159,7 @@ class RequestHandler : public RefCounted { friend class SpeculativeExecution; void add_execution(SpeculativeExecution* speculative_execution); - void schedule_next_execution(const Host::Ptr& first_host); + void schedule_next_execution(const Host::Ptr& current_host); void stop_request(); private: @@ -170,12 +171,13 @@ class RequestHandler : public RefCounted { RetryPolicy* retry_policy_; ScopedPtr query_plan_; ScopedPtr execution_plan_; - Host::Ptr first_host_; + Host::Ptr current_host_; IOWorker* io_worker_; Timer timer_; int running_executions_; SpeculativeExecutionVec speculative_executions_; Request::EncodingCache encoding_cache_; + uint64_t start_time_ns_; }; class SpeculativeExecution : public RequestCallback { @@ -185,9 +187,6 @@ class SpeculativeExecution : public RequestCallback { SpeculativeExecution(const RequestHandler::Ptr& request_handler, const Host::Ptr& current_host = Host::Ptr()); - virtual void on_set(ResponseMessage* response); - virtual void on_error(CassError code, const std::string& message); - virtual const Request* request() const { return request_handler_->request(); } virtual int64_t timestamp() const { return request_handler_->timestamp(); } virtual Request::EncodingCache* encoding_cache() { return request_handler_->encoding_cache(); } @@ -198,22 +197,33 @@ class SpeculativeExecution : public RequestCallback { const Host::Ptr& current_host() const { return current_host_; } void next_host() { current_host_ = request_handler_->next_host(); } - void start_pending_request(Pool* pool, Timer::Callback cb); - void stop_pending_request(); - void retry_current_host(); void retry_next_host(); + void start_pending_request(Pool* pool, Timer::Callback cb); + void stop_pending_request(); + void execute(); void schedule_next(int64_t timeout = 0); void cancel(); + virtual void on_error(CassError code, const std::string& message); + private: static void on_execute(Timer* timer); virtual void on_start(); - virtual void on_finish(); - virtual void on_finish_with_retry(bool use_next_host); + + virtual void on_retry(bool use_next_host); + + virtual void on_set(ResponseMessage* response); + virtual void on_cancel(); + + void on_result_response(ResponseMessage* response); + void on_error_response(ResponseMessage* response); + void on_error_unprepared(ErrorResponse* error); + + void return_connection(); private: friend class SchemaChangeCallback; @@ -225,16 +235,11 @@ class SpeculativeExecution : public RequestCallback { void set_error_with_error_response(const Response::Ptr& error, CassError code, const std::string& message); - void return_connection(); - - void on_result_response(ResponseMessage* response); - void on_error_response(ResponseMessage* response); - void on_error_unprepared(ErrorResponse* error); - private: RequestHandler::Ptr request_handler_; Host::Ptr current_host_; Pool* pool_; + Connection* connection_; Timer schedule_timer_; Timer pending_request_timer_; int num_retries_; diff --git a/src/session.cpp b/src/session.cpp index f3e3a7df3..d93b059fa 100644 --- a/src/session.cpp +++ b/src/session.cpp @@ -718,7 +718,9 @@ void Session::on_execute(uv_async_t* data) { bool is_done = false; while (!is_done) { - if (!request_handler->first_host()) { + request_handler->next_host(); + + if (!request_handler->current_host()) { request_handler->set_error(CASS_ERROR_LIB_NO_HOSTS_AVAILABLE, "All connections on all I/O threads are busy"); break; @@ -727,7 +729,7 @@ void Session::on_execute(uv_async_t* data) { size_t start = session->current_io_worker_; for (size_t i = 0, size = session->io_workers_.size(); i < size; ++i) { const IOWorker::Ptr& io_worker = session->io_workers_[start % size]; - if (io_worker->is_host_available(request_handler->first_host()->address()) && + if (io_worker->is_host_available(request_handler->current_host()->address()) && io_worker->execute(request_handler)) { session->current_io_worker_ = (start + 1) % size; is_done = true; diff --git a/test/integration_tests/src/test_consistency.cpp b/test/integration_tests/src/test_consistency.cpp index 350e4dda0..f512aab23 100644 --- a/test/integration_tests/src/test_consistency.cpp +++ b/test/integration_tests/src/test_consistency.cpp @@ -221,7 +221,7 @@ BOOST_AUTO_TEST_CASE(retry_policy_downgrading) BOOST_CHECK_EQUAL(query_result, CASS_OK); } - ccm->pause_node(2); + ccm->stop_node(2); { CassError init_result; @@ -236,7 +236,7 @@ BOOST_AUTO_TEST_CASE(retry_policy_downgrading) BOOST_CHECK_EQUAL(query_result, CASS_OK); } - ccm->pause_node(3); + ccm->stop_node(3); { CassError init_result; @@ -264,9 +264,6 @@ BOOST_AUTO_TEST_CASE(retry_policy_downgrading) BOOST_CHECK_EQUAL(query_result, CASS_OK); } - ccm->resume_node(2); - ccm->resume_node(3); - cass_retry_policy_free(downgrading_policy); // Ensure the keyspace is dropped diff --git a/test/integration_tests/src/test_future.cpp b/test/integration_tests/src/test_future.cpp index 31c0179ef..76d1954f6 100644 --- a/test/integration_tests/src/test_future.cpp +++ b/test/integration_tests/src/test_future.cpp @@ -49,8 +49,8 @@ BOOST_AUTO_TEST_CASE(error) // Should not be set BOOST_CHECK(cass_future_get_result(future.get()) == NULL); - BOOST_CHECK(cass_future_get_error_result(future.get()) == NULL); BOOST_CHECK(cass_future_get_prepared(future.get()) == NULL); + BOOST_CHECK(cass_future_get_error_result(future.get()) != NULL); BOOST_CHECK_EQUAL(cass_future_custom_payload_item_count(future.get()), 0); { @@ -60,7 +60,7 @@ BOOST_AUTO_TEST_CASE(error) BOOST_REQUIRE_EQUAL(cass_future_custom_payload_item(future.get(), 0, &name, &name_length, &value, &value_size), - CASS_ERROR_LIB_NO_CUSTOM_PAYLOAD); + CASS_ERROR_LIB_INDEX_OUT_OF_BOUNDS); } } diff --git a/test/integration_tests/src/test_ssl.cpp b/test/integration_tests/src/test_ssl.cpp index dd8961e5d..c67dd7b33 100644 --- a/test/integration_tests/src/test_ssl.cpp +++ b/test/integration_tests/src/test_ssl.cpp @@ -131,11 +131,11 @@ struct TestSSL { * @param is_failure True if test is supposed to fail; false otherwise * (default: false) * @param nodes Number of nodes for the cluster (default: 1) - * @param protocol_version Protocol version to use for connection (default: 2) */ - void setup(bool is_ssl = true, bool is_client_authentication = false, bool is_failure = false, unsigned int nodes = 1, unsigned int protocol_version = 2) { + void setup(bool is_ssl = true, bool is_client_authentication = false, bool is_failure = false, unsigned int nodes = 1) { //Create a n-node cluster ccm_->create_cluster(nodes, 0, false, is_ssl, is_client_authentication); + ccm_->start_cluster(); //Initialize the cpp-driver cluster_ = cass_cluster_new(); @@ -145,7 +145,6 @@ struct TestSSL { cass_cluster_set_num_threads_io(cluster_, 1); cass_cluster_set_core_connections_per_host(cluster_, 2); cass_cluster_set_max_connections_per_host(cluster_, 4); - cass_cluster_set_protocol_version(cluster_, protocol_version); cass_cluster_set_ssl(cluster_, ssl_); //Establish the connection (if ssl) From b77ccce0cba3046c3e0100c4ef74ac812d1674ef Mon Sep 17 00:00:00 2001 From: Michael Penick Date: Wed, 5 Oct 2016 12:01:39 -0700 Subject: [PATCH 7/7] Fix: Remove debug code and use small opt. vector for specultive executions --- src/request_callback.cpp | 2 -- src/request_callback.hpp | 1 - src/request_handler.hpp | 3 ++- 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/request_callback.cpp b/src/request_callback.cpp index 126db66ce..9e5dcc565 100644 --- a/src/request_callback.cpp +++ b/src/request_callback.cpp @@ -75,8 +75,6 @@ int32_t RequestCallback::encode(int version, int flags, BufferVec* bufs) { } void RequestCallback::set_state(RequestCallback::State next_state) { - state_history_.push_back(state_); - switch (state_) { case REQUEST_STATE_NEW: if (next_state == REQUEST_STATE_NEW || diff --git a/src/request_callback.hpp b/src/request_callback.hpp index 9fba6e0d2..d01ce68b1 100644 --- a/src/request_callback.hpp +++ b/src/request_callback.hpp @@ -113,7 +113,6 @@ class RequestCallback : public RefCounted, public List read_before_write_response_; typedef std::vector StateVec; - StateVec state_history_; private: DISALLOW_COPY_AND_ASSIGN(RequestCallback); diff --git a/src/request_handler.hpp b/src/request_handler.hpp index 8ab30abec..d29c7aa06 100644 --- a/src/request_handler.hpp +++ b/src/request_handler.hpp @@ -19,6 +19,7 @@ #include "constants.hpp" #include "error_response.hpp" +#include "fixed_vector.hpp" #include "future.hpp" #include "request_callback.hpp" #include "host.hpp" @@ -163,7 +164,7 @@ class RequestHandler : public RefCounted { void stop_request(); private: - typedef std::vector SpeculativeExecutionVec; + typedef FixedVector SpeculativeExecutionVec; const Request::ConstPtr request_; int64_t timestamp_;