Skip to content

Commit

Permalink
Merge pull request #82 from belousovAV/fix_dao_finder
Browse files Browse the repository at this point in the history
Fix dao finder when table name does not match dao name. Support dao s…
  • Loading branch information
rous-gg committed Aug 25, 2023
2 parents 7277f98 + 2ccb114 commit aadeb07
Show file tree
Hide file tree
Showing 8 changed files with 70 additions and 111 deletions.
57 changes: 26 additions & 31 deletions ree_lib/lib/ree_lib/packages/ree_dao/package/ree_dao/association.rb
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ class Association
link :index_by, from: :ree_array
link :underscore, from: :ree_string

attr_reader :parent, :parent_dao_name, :list, :global_opts
attr_reader :parent, :parent_dao, :list, :global_opts

contract(ReeDao::Associations, Symbol, Array, Ksplat[RestKeys => Any] => Any)
def initialize(parent, parent_dao_name, list, **global_opts)
contract(ReeDao::Associations, Sequel::Dataset, Array, Ksplat[RestKeys => Any] => Any)
def initialize(parent, parent_dao, list, **global_opts)
@parent = parent
@parent_dao_name = parent_dao_name
@parent_dao = parent_dao
@list = list
@global_opts = global_opts
end
Expand All @@ -36,7 +36,7 @@ def handle_field(field_proc)
Or[:belongs_to, :has_one, :has_many],
Symbol,
Ksplat[RestKeys => Any],
Optblock => Nilor[Array]
Optblock => Array
)
def load_association(assoc_type, assoc_name, **__opts, &block)
__opts[:autoload_children] ||= false
Expand All @@ -47,21 +47,17 @@ def load_association(assoc_type, assoc_name, **__opts, &block)
**__opts
)

dao = find_dao(assoc_name, parent, __opts[:scope])

dao_name = if dao
dao.first_source_table
elsif __opts[:scope].is_a?(Array)
name = underscore(demodulize(__opts[:scope].first.class.name))
scope = __opts[:scope]

if name.end_with?("s")
"#{name}es"
else
"#{name}s"
end
dao = if scope.is_a?(Array)
return [] if scope.empty?
name = underscore(demodulize(scope.first.class.name))
find_dao(name, parent, nil)
else
find_dao(assoc_name, parent, scope)
end

process_block(assoc_index, __opts[:autoload_children], __opts[:to_dto], dao_name, &block) if block_given?
process_block(assoc_index, __opts[:autoload_children], __opts[:to_dto], dao, &block) if block_given?

list
end
Expand All @@ -75,7 +71,7 @@ def load_association_by_type(type, assoc_name, **__opts)
case type
when :belongs_to
one_to_one(
parent_dao_name,
parent_dao,
assoc_name,
list,
scope: __opts[:scope],
Expand All @@ -86,7 +82,7 @@ def load_association_by_type(type, assoc_name, **__opts)
)
when :has_one
one_to_one(
parent_dao_name,
parent_dao,
assoc_name,
list,
scope: __opts[:scope],
Expand All @@ -98,7 +94,7 @@ def load_association_by_type(type, assoc_name, **__opts)
)
when :has_many
one_to_many(
parent_dao_name,
parent_dao,
assoc_name,
list,
scope: __opts[:scope],
Expand All @@ -110,8 +106,8 @@ def load_association_by_type(type, assoc_name, **__opts)
end
end

contract(Or[Hash, Array], Bool, Nilor[Proc], Symbol, Block => Any)
def process_block(assoc, autoload_children, to_dto, parent_dao_name, &block)
contract(Or[Hash, Array], Bool, Nilor[Proc], Sequel::Dataset, Block => Any)
def process_block(assoc, autoload_children, to_dto, parent_dao, &block)
assoc_list = assoc.is_a?(Array) ? assoc : assoc.values.flatten

if to_dto
Expand All @@ -124,13 +120,12 @@ def process_block(assoc, autoload_children, to_dto, parent_dao_name, &block)
parent.agg_caller,
assoc_list,
parent.local_vars,
parent_dao_name,
parent_dao,
autoload_children,
**global_opts
)
parent_dao = find_dao(parent_dao_name, parent.agg_caller)

if dao_in_transaction?(parent_dao) || ReeDao::Associations.sync_mode?
if parent_dao.db.in_transaction? || ReeDao::Associations.sync_mode?
associations.instance_exec(assoc_list, &block)
else
threads = associations.instance_exec(assoc_list, &block)
Expand All @@ -150,7 +145,7 @@ def process_block(assoc, autoload_children, to_dto, parent_dao_name, &block)
end

contract(
Symbol,
Sequel::Dataset,
Symbol,
Array,
Kwargs[
Expand All @@ -162,7 +157,7 @@ def process_block(assoc, autoload_children, to_dto, parent_dao_name, &block)
reverse: Bool
] => Or[Hash, Array]
)
def one_to_one(parent_assoc_name, assoc_name, list, scope: nil, primary_key: :id, foreign_key: nil, setter: nil, to_dto: nil, reverse: true)
def one_to_one(parent_dao, assoc_name, list, scope: nil, primary_key: :id, foreign_key: nil, setter: nil, to_dto: nil, reverse: true)
return {} if list.empty?

primary_key ||= :id
Expand All @@ -175,7 +170,7 @@ def one_to_one(parent_assoc_name, assoc_name, list, scope: nil, primary_key: :id
if reverse
# has_one
if !foreign_key
foreign_key = "#{parent_assoc_name.to_s.gsub(/s$/,'')}_id".to_sym
foreign_key = "#{parent_dao.first_source_table.to_s.gsub(/s$/,'')}_id".to_sym
end

root_ids = list.map(&:id).uniq
Expand Down Expand Up @@ -211,7 +206,7 @@ def one_to_one(parent_assoc_name, assoc_name, list, scope: nil, primary_key: :id
end

contract(
Symbol,
Sequel::Dataset,
Symbol,
Array,
Kwargs[
Expand All @@ -222,7 +217,7 @@ def one_to_one(parent_assoc_name, assoc_name, list, scope: nil, primary_key: :id
to_dto: Nilor[Proc]
] => Or[Hash, Array]
)
def one_to_many(parent_assoc_name, assoc_name, list, primary_key: nil, foreign_key: nil, scope: nil, setter: nil, to_dto: nil)
def one_to_many(parent_dao, assoc_name, list, primary_key: nil, foreign_key: nil, scope: nil, setter: nil, to_dto: nil)
return {} if list.empty?

primary_key ||= :id
Expand All @@ -233,7 +228,7 @@ def one_to_many(parent_assoc_name, assoc_name, list, primary_key: nil, foreign_k
assoc_dao = nil
assoc_dao = find_dao(assoc_name, parent, scope)

foreign_key ||= "#{parent_assoc_name.to_s.gsub(/s$/, '')}_id".to_sym
foreign_key ||= "#{parent_dao.first_source_table.to_s.gsub(/s$/, '')}_id".to_sym

root_ids = list.map(&:"#{primary_key}")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,20 @@ def self.extended(base)
end

module InstanceMethods
SUFFIXES = ["", "s", "es", "dao", "s_dao", "es_dao"].freeze

def find_dao(assoc_name, parent_caller, scope = nil)
dao_from_name = parent_caller.instance_variable_get("@#{assoc_name}") || parent_caller.instance_variable_get("@#{assoc_name}s")
return dao_from_name if dao_from_name

raise ArgumentError, "can't find DAO for :#{assoc_name}, provide correct scope or association name" if scope.nil?
return nil if scope.is_a?(Array)

table_name = scope.first_source_table
dao_from_scope = parent_caller.instance_variable_get("@#{table_name}")
return dao_from_scope if dao_from_scope

SUFFIXES.each do |suffix|
dao_from_name = parent_caller.instance_variable_get("@#{assoc_name}#{suffix}")
return dao_from_name if dao_from_name
end

if scope.is_a?(Sequel::Dataset)
return scope.unfiltered
end

raise ArgumentError, "can't find DAO for :#{assoc_name}, provide correct scope or association name"
end

def dao_in_transaction?(dao)
return false if dao.nil?

dao.db.in_transaction?
end
end
end
end
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,23 @@ class Associations
include Ree::LinkDSL
include ReeDao::AssociationMethods

attr_reader :agg_caller, :list, :local_vars, :only, :except, :parent_dao_name, :autoload_children, :global_opts
attr_reader :agg_caller, :list, :local_vars, :only, :except, :parent_dao, :autoload_children, :global_opts

def initialize(agg_caller, list, local_vars, parent_dao_name, autoload_children = false, **opts)
def initialize(agg_caller, list, local_vars, parent_dao, autoload_children = false, **opts)
@agg_caller = agg_caller
@list = list
@local_vars = local_vars
@global_opts = opts || {}
@only = opts[:only] if opts[:only]
@except = opts[:except] if opts[:except]
@parent_dao_name = parent_dao_name
@parent_dao = parent_dao
@autoload_children = autoload_children

if @only && @except
shared_keys = @only.intersection(@except)

if shared_keys.size > 0
raise ArgumentError.new("you can't use both :only and :except for #{shared_keys.map { "\"#{_1}\"" }.join(", ")} keys")
raise ArgumentError.new("you can't use both :only and :except for #{shared_keys.map { "\"#{_1}\"" }.join(", ")} keys")
end
end

Expand Down Expand Up @@ -87,12 +87,10 @@ def field(assoc_name, proc)
Optblock => Any
)
def association(assoc_type, assoc_name, __opts, &block)
parent_dao = find_dao(parent_dao_name, agg_caller)

if dao_in_transaction?(parent_dao) || self.class.sync_mode?
if parent_dao.db.in_transaction? || self.class.sync_mode?
return if association_is_not_included?(assoc_name) || list.empty?

association = Association.new(self, parent_dao_name, list, **global_opts)
association = Association.new(self, parent_dao, list, **global_opts)

if assoc_type == :field
association.handle_field(assoc_name, __opts)
Expand All @@ -104,7 +102,7 @@ def association(assoc_type, assoc_name, __opts, &block)
return { association_threads: @assoc_threads, field_threads: @field_threads }
end

association = Association.new(self, parent_dao_name, list, **global_opts)
association = Association.new(self, parent_dao, list, **global_opts)

if assoc_type == :field
field_proc = __opts
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,15 @@

class ReeDao::Agg
include Ree::FnDSL
include ReeDao::AssociationMethods

fn :agg do
link :demodulize, from: :ree_string
link :underscore, from: :ree_string
link "ree_dao/associations", -> { Associations }
link "ree_dao/contract/dao_dataset_contract", -> { DaoDatasetContract }
link "ree_dao/contract/entity_contract", -> { EntityContract }
end

contract(
Nilor[DaoDatasetContract],
DaoDatasetContract,
Or[Sequel::Dataset, ArrayOf[Integer], ArrayOf[EntityContract], Integer],
Ksplat[
only?: ArrayOf[Symbol],
Expand All @@ -23,26 +20,16 @@ class ReeDao::Agg
],
Optblock => ArrayOf[Any]
)
def call(dao = nil, ids_or_scope, **opts, &block)
def call(dao, ids_or_scope, **opts, &block)
scope = if ids_or_scope.is_a?(Array) && ids_or_scope.any? { _1.is_a?(Integer) }
raise ArgumentError.new("Dao should be provided") if dao.nil?
return [] if ids_or_scope.empty?

dao.where(id: ids_or_scope)
elsif ids_or_scope.is_a?(Integer)
raise ArgumentError.new("Dao should be provided") if dao.nil?

dao.where(id: ids_or_scope)
else
ids_or_scope
end

if dao
dao_name = dao.first_source_table
else
dao_name = underscore(demodulize(scope.first.class.name)).to_sym
end

list = scope.is_a?(Sequel::Dataset) ? scope.all : scope

if opts[:to_dto]
Expand All @@ -51,7 +38,7 @@ def call(dao = nil, ids_or_scope, **opts, &block)
end
end

load_associations(dao_name, list, **opts, &block) if block_given?
load_associations(dao.unfiltered, list, **opts, &block) if block_given?

if ids_or_scope.is_a?(Array)
list.sort_by { ids_or_scope.index(_1.id) }
Expand All @@ -62,7 +49,7 @@ def call(dao = nil, ids_or_scope, **opts, &block)

private

def load_associations(dao_name, list, **opts, &block)
def load_associations(dao, list, **opts, &block)
return if list.empty?

local_vars = block.binding.eval(<<-CODE, __FILE__, __LINE__ + 1)
Expand All @@ -72,10 +59,9 @@ def load_associations(dao_name, list, **opts, &block)

agg_caller = block.binding.eval("self")

associations = Associations.new(agg_caller, list, local_vars, dao_name, **opts).instance_exec(list, &block)
dao = find_dao(dao_name, agg_caller)
associations = Associations.new(agg_caller, list, local_vars, dao, **opts).instance_exec(list, &block)

if dao_in_transaction?(dao) || ReeDao.load_sync_associations_enabled?
if dao.db.in_transaction? || ReeDao.load_sync_associations_enabled?
associations
else
associations[:association_threads].map do |association, assoc_type, assoc_name, opts, block|
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
"args": [
{
"arg": "dao",
"arg_type": "opt",
"type": "Nilor[PackageName::DaoName::Dao: \"SELECT * FROM `table`\"]"
"arg_type": "req",
"type": "PackageName::DaoName::Dao: \"SELECT * FROM `table`\""
},
{
"arg": "ids_or_scope",
Expand All @@ -38,21 +38,6 @@
}
],
"links": [
{
"target": "demodulize",
"package_name": "ree_string",
"as": "demodulize",
"imports": [

]
},
{
"target": "underscore",
"package_name": "ree_string",
"as": "underscore",
"imports": [

]
}
]
}
Loading

0 comments on commit aadeb07

Please sign in to comment.