/
many_to_many_set.cr
143 lines (125 loc) · 5 KB
/
many_to_many_set.cr
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
require "./set"
module Marten
  module DB
    module Query
      # Represents a query set resulting from a many-to-many relation.
      #
      # A `ManyToManySet` is bound to a single source record (`@instance`). Unless an explicit query is provided, it
      # only targets the `M` records linked to that instance through the many-to-many field's "through" model. Rows
      # of the "through" model are created and deleted by `#add`, `#remove` and `#clear`.
      class ManyToManySet(M) < Set(M)
        # Lazily-resolved field objects (see the private accessors at the bottom of the class).
        @m2m_field : Field::Base? = nil
        @m2m_through_from_field : Field::Base? = nil
        @m2m_through_to_field : Field::Base? = nil

        # Creates a many-to-many query set.
        #
        # - `instance`: the source record owning the many-to-many field.
        # - `field_id`: identifier of the many-to-many field on `instance`'s model.
        # - `through_related_name`: name used to traverse from `M` to the "through" model in query predicates.
        # - `through_model_from_field_id` / `through_model_to_field_id`: identifiers of the "through" model's
        #   relation fields pointing respectively to the source model and to `M`.
        # - `query`: an already-built query to reuse (for example when cloning). When `nil`, a fresh query
        #   filtering on the records related to `instance` is built.
        def initialize(
          @instance : Marten::DB::Model,
          @field_id : String,
          @through_related_name : String,
          @through_model_from_field_id : String,
          @through_model_to_field_id : String,
          query : SQL::Query(M)? = nil
        )
          @query = if query
                     query
                   else
                     # Only target the records linked to the source instance through the "through" model.
                     default_query = SQL::Query(M).new
                     default_query.add_query_node(
                       Node.new({"#{@through_related_name}__#{@through_model_from_field_id}" => @instance})
                     )
                     default_query
                   end
        end

        # Adds the given objects to the many-to-many relationship.
        #
        # If the objects specified in `objs` are already in the relationship, they will be skipped and not added again.
        def add(*objs : M)
          add(objs.to_a)
        end

        # :ditto:
        def add(objs : Enumerable(M) | Iterable(M))
          # Materialize the input once. It is traversed several times below, which would silently drop objects for
          # single-pass inputs such as iterators. Objects specified more than once (same primary key) are only
          # considered once so that the same relationship row is never inserted twice.
          objs_to_add = objs.to_a.uniq(&.pk)

          query.connection.transaction do
            through_model = m2m_field.as(Field::ManyToMany).through

            # Identify which objects are already added to the many to many relationship and skip them. The plucked
            # values are the primary keys of the targeted `M` records stored in the "through" model's "to" field.
            existing_object_ids = through_model._base_queryset
              .using(query.using)
              .filter(
                Query::Node.new(
                  {
                    m2m_through_from_field.id        => @instance.pk.as(Field::Any),
                    "#{m2m_through_to_field.id}__in" => objs_to_add.map(&.pk!.as(Field::Any)),
                  }
                )
              )
              .pluck([m2m_through_to_field.id]).flatten

            # Build a "through" record for each object that was not already in the relationship. The comparison must
            # use the primary key (not `id`): that is the value the "to" field stores and the value plucked above.
            through_objs_to_add = objs_to_add.compact_map do |obj|
              next if existing_object_ids.includes?(obj.pk)

              through_obj = through_model.new
              through_obj.set_field_value(m2m_through_from_field.id, @instance.pk)
              through_obj.set_field_value(m2m_through_to_field.id, obj.pk)
              through_obj
            end

            unless through_objs_to_add.empty?
              through_model.all.using(query.using).unsafe_bulk_create(through_objs_to_add)
            end

            reset_result_cache
          end
        end

        # Clears the many-to-many relationship.
        #
        # Deletes the "through" model rows linking the source instance to the records of this set. When this set was
        # further filtered, only the rows pointing to the records matched by the current query are deleted.
        def clear : Nil
          query.connection.transaction do
            deletion_qs = m2m_field.as(Field::ManyToMany).through._base_queryset
              .using(query.using)
              .filter(Query::Node.new({m2m_through_from_field.id => @instance.pk.as(Field::Any)}))

            # NOTE(review): a predicate node with more than one child is taken to mean that filters were added on top
            # of the initial relation node built in `#initialize` — confirm against `SQL::Query`'s node handling.
            if (query.predicate_node.try(&.children.size) || 1) > 1
              # If the m2m queryset was filtered we need to target the right objects for deletion.
              deletion_qs = deletion_qs
                .filter(Query::Node.new({"#{m2m_through_to_field.id}__in" => pluck(:pk).flatten}))
            end

            deletion_qs.delete
            reset_result_cache
          end
        end

        # Removes the given objects from the many-to-many relationship.
        #
        # Only the "through" model rows linking the source instance to the given objects are deleted; the objects
        # themselves are left untouched.
        def remove(*objs : M) : Nil
          remove(objs.to_a)
        end

        # :ditto:
        def remove(objs : Enumerable(M) | Iterable(M)) : Nil
          query.connection.transaction do
            m2m_field.as(Field::ManyToMany).through._base_queryset
              .using(query.using)
              .filter(
                Query::Node.new(
                  {
                    m2m_through_from_field.id        => @instance.pk.as(Field::Any),
                    "#{m2m_through_to_field.id}__in" => objs.map(&.pk!.as(Field::Any)).to_a,
                  }
                )
              )
              .delete

            reset_result_cache
          end
        end

        # Returns a new many-to-many set bound to the same instance and relation, using `other_query` when given or
        # a clone of the current query otherwise.
        protected def clone(other_query = nil)
          ManyToManySet(M).new(
            instance: @instance,
            field_id: @field_id,
            through_related_name: @through_related_name,
            through_model_from_field_id: @through_model_from_field_id,
            through_model_to_field_id: @through_model_to_field_id,
            query: other_query || @query.clone
          )
        end

        # The many-to-many field of the source instance's model, resolved once and memoized.
        private def m2m_field
          @m2m_field ||= @instance.class.get_field(@field_id)
        end

        # The "through" model's relation field pointing to the source model, resolved once and memoized.
        private def m2m_through_from_field
          @m2m_through_from_field ||= m2m_field.as(Field::ManyToMany)
            .through
            .get_local_relation_field(@through_model_from_field_id)
        end

        # The "through" model's relation field pointing to `M`, resolved once and memoized.
        private def m2m_through_to_field
          @m2m_through_to_field ||= m2m_field.as(Field::ManyToMany)
            .through
            .get_local_relation_field(@through_model_to_field_id)
        end
      end
    end
  end
end