Skip to content

Commit

Permalink
apacheGH-37257: [Ruby][FlightSQL] Use the same options for auto prepa…
Browse files Browse the repository at this point in the history
…red statement close request
  • Loading branch information
kou committed Aug 18, 2023
1 parent 9fea4ee commit b375207
Show file tree
Hide file tree
Showing 9 changed files with 101 additions and 9 deletions.
2 changes: 1 addition & 1 deletion c_glib/arrow-flight-glib/client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ gaflight_call_options_clear_headers(GAFlightCallOptions *options)
* @func: (scope call): The user's callback function.
* @user_data: (closure): Data for @func.
*
* Iterates over all header in the options.
* Iterates over all headers in the options.
*
* Since: 9.0.0
*/
Expand Down
46 changes: 44 additions & 2 deletions c_glib/arrow-flight-glib/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -293,9 +293,11 @@ gaflight_message_reader_get_descriptor(GAFlightMessageReader *reader)
}


typedef struct GAFlightServerCallContextPrivate_ {
struct GAFlightServerCallContextPrivate {
arrow::flight::ServerCallContext *call_context;
} GAFlightServerCallContextPrivate;
std::string current_incoming_header_key;
std::string current_incoming_header_value;
};

enum {
PROP_CALL_CONTEXT = 1,
Expand All @@ -310,6 +312,15 @@ G_DEFINE_TYPE_WITH_PRIVATE(GAFlightServerCallContext,
gaflight_server_call_context_get_instance_private( \
GAFLIGHT_SERVER_CALL_CONTEXT(obj)))

static void
gaflight_server_call_context_finalize(GObject *object)
{
auto priv = GAFLIGHT_SERVER_CALL_CONTEXT_GET_PRIVATE(object);
priv->current_incoming_header_key.~basic_string();
priv->current_incoming_header_value.~basic_string();
G_OBJECT_CLASS(gaflight_server_call_context_parent_class)->finalize(object);
}

static void
gaflight_server_call_context_set_property(GObject *object,
guint prop_id,
Expand All @@ -333,13 +344,17 @@ gaflight_server_call_context_set_property(GObject *object,
static void
gaflight_server_call_context_init(GAFlightServerCallContext *object)
{
auto priv = GAFLIGHT_SERVER_CALL_CONTEXT_GET_PRIVATE(object);
new(&(priv->current_incoming_header_key)) std::string;
new(&(priv->current_incoming_header_value)) std::string;
}

static void
gaflight_server_call_context_class_init(GAFlightServerCallContextClass *klass)
{
auto gobject_class = G_OBJECT_CLASS(klass);

gobject_class->finalize = gaflight_server_call_context_finalize;
gobject_class->set_property = gaflight_server_call_context_set_property;

GParamSpec *spec;
Expand All @@ -351,6 +366,33 @@ gaflight_server_call_context_class_init(GAFlightServerCallContextClass *klass)
g_object_class_install_property(gobject_class, PROP_CALL_CONTEXT, spec);
}

/**
* gaflight_server_call_context_foreach_incoming_header:
* @context: A #GAFlightServerCallContext.
* @func: (scope call): The user's callback function.
* @user_data: (closure): Data for @func.
*
* Iterates over all incoming headers.
*
* Since: 14.0.0
*/
void
gaflight_server_call_context_foreach_incoming_header(
GAFlightServerCallContext *context,
GAFlightHeaderFunc func,
gpointer user_data)
{
auto priv = GAFLIGHT_SERVER_CALL_CONTEXT_GET_PRIVATE(context);
auto flight_context = gaflight_server_call_context_get_raw(context);
for (const auto &header : flight_context->incoming_headers()) {
priv->current_incoming_header_key = std::string(header.first);
priv->current_incoming_header_value = std::string(header.second);
func(priv->current_incoming_header_key.c_str(),
priv->current_incoming_header_value.c_str(),
user_data);
}
}


struct GAFlightServerAuthSenderPrivate {
arrow::flight::ServerAuthSender *sender;
Expand Down
7 changes: 7 additions & 0 deletions c_glib/arrow-flight-glib/server.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,13 @@ struct _GAFlightServerCallContextClass
GObjectClass parent_class;
};

GARROW_AVAILABLE_IN_14_0
void
gaflight_server_call_context_foreach_incoming_header(
GAFlightServerCallContext *context,
GAFlightHeaderFunc func,
gpointer user_data);


#define GAFLIGHT_TYPE_SERVER_AUTH_SENDER \
(gaflight_server_auth_sender_get_type())
Expand Down
4 changes: 4 additions & 0 deletions c_glib/doc/arrow-flight-glib/arrow-flight-glib-docs.xml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@
<title>Index of deprecated API</title>
<xi:include href="xml/api-index-deprecated.xml"><xi:fallback /></xi:include>
</index>
<index id="api-index-14-0-0" role="14.0.0">
<title>Index of new symbols in 14.0.0</title>
<xi:include href="xml/api-index-14.0.0.xml"><xi:fallback /></xi:include>
</index>
<index id="api-index-12-0-0" role="12.0.0">
<title>Index of new symbols in 12.0.0</title>
<xi:include href="xml/api-index-12.0.0.xml"><xi:fallback /></xi:include>
Expand Down
6 changes: 3 additions & 3 deletions ruby/red-arrow-flight-sql/lib/arrow-flight-sql/client.rb
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@
module ArrowFlightSQL
class Client
alias_method :prepare_raw, :prepare
def prepare(*args)
statement = prepare_raw(*args)
def prepare(query, options=nil)
statement = prepare_raw(query, options)
if block_given?
begin
yield(statement)
ensure
statement.close unless statement.closed?
statement.close(options) unless statement.closed?
end
else
statement
Expand Down
7 changes: 6 additions & 1 deletion ruby/red-arrow-flight-sql/test/helper/server.rb
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def virtual_do_do_get_statement(context, command)
end

def virtual_do_create_prepared_statement(context, request)
unless request.query == "INSERT INTO page_view_table VALUES (?, true)"
unless request.query == "INSERT INTO page_view_table VALUES ($1, true)"
raise Arrow::Error::Invalid.new("invalid SQL")
end
result = ArrowFlightSQL::CreatePreparedStatementResult.new
Expand All @@ -62,6 +62,11 @@ def virtual_do_close_prepared_statement(context, request)
unless request.handle.to_s == "valid-handle"
raise Arrow::Error::Invalid.new("invalid handle")
end
access_key = context.incoming_headers.assoc("x-access-key")
unless access_key == ["x-access-key", "secret"]
message = "invalid access key: #{access_key.inspect}"
raise Arrow::Error::Invalid.new(message)
end
end
end
end
6 changes: 4 additions & 2 deletions ruby/red-arrow-flight-sql/test/test-client.rb
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,11 @@ def test_execute
end

def test_prepare
insert_sql = "INSERT INTO page_view_table VALUES (?, true)"
insert_sql = "INSERT INTO page_view_table VALUES ($1, true)"
block_called = false
@sql_client.prepare(insert_sql) do |statement|
options = ArrowFlight::CallOptions.new
options.add_header("x-access-key", "secret")
@sql_client.prepare(insert_sql, options) do |statement|
block_called = true
assert_equal([
Arrow::Schema.new(count: :uint64, private: :boolean),
Expand Down
1 change: 1 addition & 0 deletions ruby/red-arrow-flight/lib/arrow-flight/loader.rb
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def require_libraries
require "arrow-flight/client-options"
require "arrow-flight/location"
require "arrow-flight/record-batch-reader"
require "arrow-flight/server-call-context"
require "arrow-flight/server-options"
require "arrow-flight/ticket"
end
Expand Down
31 changes: 31 additions & 0 deletions ruby/red-arrow-flight/lib/arrow-flight/server-call-context.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

module ArrowFlight
class ServerCallContext
def each_incoming_header
return to_enum(__method__) unless block_given?
foreach_incoming_header do |key, value|
yield(key, value)
end
end

def incoming_headers
each_incoming_header.to_a
end
end
end

0 comments on commit b375207

Please sign in to comment.