Skip to content

Commit

Permalink
prepared statements can be disabled
Browse files Browse the repository at this point in the history
  • Loading branch information
tenderlove committed Feb 22, 2012
1 parent 349d5a6 commit fd39847
Show file tree
Hide file tree
Showing 8 changed files with 56 additions and 30 deletions.
Expand Up @@ -2,9 +2,11 @@ module ActiveRecord
module ConnectionAdapters # :nodoc:
module DatabaseStatements
# Converts an arel AST to SQL
def to_sql(arel)
def to_sql(arel, binds = [])
if arel.respond_to?(:ast)
visitor.accept(arel.ast)
visitor.accept(arel.ast) do
quote(*binds.shift.reverse)
end
else
arel
end
Expand All @@ -13,7 +15,7 @@ def to_sql(arel)
# Returns an array of record hashes with the column names as keys and
# column values as values.
def select_all(arel, name = nil, binds = [])
select(to_sql(arel), name, binds)
select(to_sql(arel, binds), name, binds)
end

# Returns a record hash with the column names as keys and column values
Expand All @@ -33,7 +35,7 @@ def select_value(arel, name = nil)
# Returns an array of the values of the first column in a select:
# select_values("SELECT id FROM companies LIMIT 3") => [1,2,3]
def select_values(arel, name = nil)
result = select_rows(to_sql(arel), name)
result = select_rows(to_sql(arel, []), name)
result.map { |v| v[0] }
end

Expand Down Expand Up @@ -84,19 +86,19 @@ def exec_update(sql, name, binds)
# If the next id was calculated in advance (as in Oracle), it should be
# passed in as +id_value+.
def insert(arel, name = nil, pk = nil, id_value = nil, sequence_name = nil, binds = [])
sql, binds = sql_for_insert(to_sql(arel), pk, id_value, sequence_name, binds)
sql, binds = sql_for_insert(to_sql(arel, binds), pk, id_value, sequence_name, binds)
value = exec_insert(sql, name, binds)
id_value || last_inserted_id(value)
end

# Executes the update statement and returns the number of rows affected.
def update(arel, name = nil, binds = [])
exec_update(to_sql(arel), name, binds)
exec_update(to_sql(arel, binds), name, binds)
end

# Executes the delete statement and returns the number of rows affected.
def delete(arel, name = nil, binds = [])
exec_delete(to_sql(arel), name, binds)
exec_delete(to_sql(arel, binds), name, binds)
end

# Checks whether there is currently no transaction active. This is done
Expand Down
Expand Up @@ -57,7 +57,7 @@ def clear_query_cache

def select_all(arel, name = nil, binds = [])
if @query_cache_enabled
sql = to_sql(arel)
sql = to_sql(arel, binds)
cache_sql(sql, binds) { super(sql, name, binds) }
else
super
Expand Down
@@ -1,4 +1,5 @@
require 'active_support/core_ext/object/blank'
require 'arel/visitors/bind_visitor'

module ActiveRecord
module ConnectionAdapters
Expand Down Expand Up @@ -122,12 +123,21 @@ def missing_default_forged_as_empty_string?(default)
:boolean => { :name => "tinyint", :limit => 1 }
}

class BindSubstitution < Arel::Visitors::MySQL # :nodoc:
include Arel::Visitors::BindVisitor
end

# FIXME: Make the first parameter more similar for the two adapters
def initialize(connection, logger, connection_options, config)
super(connection, logger)
@connection_options, @config = connection_options, config
@quoted_column_names, @quoted_table_names = {}, {}
@visitor = Arel::Visitors::MySQL.new self

if config.fetch(:prepared_statements) { true }
@visitor = Arel::Visitors::MySQL.new self
else
@visitor = BindSubstitution.new self
end
end

def adapter_name #:nodoc:
Expand Down
Expand Up @@ -32,6 +32,7 @@ def adapter

def initialize(connection, logger, connection_options, config)
super
@visitor = BindSubstitution.new self
configure_connection
end

Expand Down Expand Up @@ -65,10 +66,6 @@ def quote_string(string)
@connection.escape(string)
end

def substitute_at(column, index)
Arel::Nodes::BindParam.new "\0"
end

# CONNECTION MANAGEMENT ====================================

def active?
Expand All @@ -94,7 +91,7 @@ def disconnect!
# DATABASE STATEMENTS ======================================

def explain(arel, binds = [])
sql = "EXPLAIN #{to_sql(arel)}"
sql = "EXPLAIN #{to_sql(arel, binds.dup)}"
start = Time.now
result = exec_query(sql, 'EXPLAIN', binds)
elapsed = Time.now - start
Expand Down Expand Up @@ -220,8 +217,7 @@ def exec_query(sql, name = 'SQL', binds = [])
# Returns an array of record hashes with the column names as keys and
# column values as values.
def select(sql, name = nil, binds = [])
binds = binds.dup
exec_query(sql.gsub("\0") { quote(*binds.shift.reverse) }, name)
exec_query(sql, name)
end

def insert_sql(sql, name = nil, pk = nil, id_value = nil, sequence_name = nil)
Expand All @@ -231,17 +227,11 @@ def insert_sql(sql, name = nil, pk = nil, id_value = nil, sequence_name = nil)
alias :create :insert_sql

def exec_insert(sql, name, binds)
binds = binds.dup

# Pretend to support bind parameters
execute sql.gsub("\0") { quote(*binds.shift.reverse) }, name
execute to_sql(sql, binds), name
end

def exec_delete(sql, name, binds)
binds = binds.dup

# Pretend to support bind parameters
execute sql.gsub("\0") { quote(*binds.shift.reverse) }, name
execute to_sql(sql, binds), name
@connection.affected_rows
end
alias :exec_update :exec_delete
Expand Down
Expand Up @@ -2,6 +2,7 @@
require 'active_support/core_ext/object/blank'
require 'active_record/connection_adapters/statement_pool'
require 'active_record/connection_adapters/postgresql/oid'
require 'arel/visitors/bind_visitor'

# Make sure we're using pg high enough for PGResult#values
gem 'pg', '~> 0.11'
Expand Down Expand Up @@ -373,11 +374,23 @@ def connection_active?
end
end

class BindSubstitution < Arel::Visitors::PostgreSQL # :nodoc:
include Arel::Visitors::BindVisitor
end

# Initializes and connects a PostgreSQL adapter.
def initialize(connection, logger, connection_parameters, config)
super(connection, logger)

if config.fetch(:prepared_statements) { true }
@visitor = Arel::Visitors::PostgreSQL.new self
else
@visitor = BindSubstitution.new self
end

connection_parameters.delete :prepared_statements

@connection_parameters, @config = connection_parameters, config
@visitor = Arel::Visitors::PostgreSQL.new self

# @local_tz is initialized as nil to avoid warnings when connect tries to use it
@local_tz = nil
Expand Down Expand Up @@ -599,7 +612,7 @@ def disable_referential_integrity #:nodoc:
# DATABASE STATEMENTS ======================================

def explain(arel, binds = [])
sql = "EXPLAIN #{to_sql(arel)}"
sql = "EXPLAIN #{to_sql(arel, binds)}"
ExplainPrettyPrinter.new.pp(exec_query(sql, 'EXPLAIN', binds))
end

Expand Down
@@ -1,5 +1,6 @@
require 'active_record/connection_adapters/abstract_adapter'
require 'active_record/connection_adapters/statement_pool'
require 'arel/visitors/bind_visitor'

module ActiveRecord
module ConnectionAdapters #:nodoc:
Expand Down Expand Up @@ -68,12 +69,21 @@ def dealloc(stmt)
end
end

class BindSubstitution < Arel::Visitors::SQLite # :nodoc:
include Arel::Visitors::BindVisitor
end

def initialize(connection, logger, config)
super(connection, logger)
@statements = StatementPool.new(@connection,
config.fetch(:statement_limit) { 1000 })
@config = config
@visitor = Arel::Visitors::SQLite.new self

if config.fetch(:prepared_statements) { true }
@visitor = Arel::Visitors::SQLite.new self
else
@visitor = BindSubstitution.new self
end
end

def adapter_name #:nodoc:
Expand Down Expand Up @@ -201,7 +211,7 @@ def type_cast(value, column) # :nodoc:
# DATABASE STATEMENTS ======================================

def explain(arel, binds = [])
sql = "EXPLAIN QUERY PLAN #{to_sql(arel)}"
sql = "EXPLAIN QUERY PLAN #{to_sql(arel, binds)}"
ExplainPrettyPrinter.new.pp(exec_query(sql, 'EXPLAIN', binds))
end

Expand Down
3 changes: 2 additions & 1 deletion activerecord/lib/active_record/relation.rb
Expand Up @@ -78,6 +78,7 @@ def new(*args, &block)
end

def initialize_copy(other)
@bind_values = @bind_values.dup
reset
end

Expand Down Expand Up @@ -454,7 +455,7 @@ def reset
end

def to_sql
@to_sql ||= klass.connection.to_sql(arel)
@to_sql ||= klass.connection.to_sql(arel, @bind_values.dup)
end

def where_values_hash
Expand Down
2 changes: 1 addition & 1 deletion activerecord/lib/active_record/relation/finder_methods.rb
Expand Up @@ -208,7 +208,7 @@ def exists?(id = false)
def find_with_associations
join_dependency = construct_join_dependency_for_association_find
relation = construct_relation_for_association_find(join_dependency)
rows = connection.select_all(relation, 'SQL', relation.bind_values)
rows = connection.select_all(relation, 'SQL', relation.bind_values.dup)
join_dependency.instantiate(rows)
rescue ThrowResult
[]
Expand Down

0 comments on commit fd39847

Please sign in to comment.