forked from delitescere/squealer
/
target.rb
187 lines (150 loc) · 4.74 KB
/
target.rb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
require 'delegate'
require 'singleton'
#TODO: Use logger and log throughout
#TODO: Counters and timers
module Squealer
  # Exports one hashmap (a variable in the caller's scope with an `_id` key)
  # as a row of a relational table.
  #
  # The constructor's block runs with the Target yielded and calls #assign for
  # each column. When the block returns, the row is written either with a
  # single upsert (INSERT ... ON DUPLICATE KEY UPDATE) when
  # Database.instance.upsertable?, or with an UPDATE followed by an INSERT.
  class Target
    # The SQL text generated for this target (statements joined by "; ").
    attr_reader :sql

    # @return [Target, nil] the innermost target whose block is currently running.
    def self.current
      Queue.instance.current
    end

    # @param database_connection [Object] not used by this class; kept for
    #   interface compatibility.
    # @param table_name [#to_s] table to write to. A local variable with the
    #   same name must be visible to the block and reference a hashmap that has
    #   an '_id' or :_id key.
    # @param row_id [Object, nil] deprecated explicit primary-key value; when
    #   nil, the key is read from the hashmap's _id.
    # @raise [BlockRequired] if no block is given.
    # @raise [ArgumentError] if table_name is blank, or the variable is not a
    #   hashmap with an _id key.
    # @raise [NameError] if no variable named table_name is in scope.
    def initialize(database_connection, table_name, row_id = nil, &block)
      raise BlockRequired, "Block must be given to target (otherwise, there's no work to do)" unless block_given?
      raise ArgumentError, "Table name must be supplied" if table_name.to_s.strip.empty?

      @table_name = table_name.to_s
      @binding = block.binding

      verify_table_name_in_scope
      # Honour (with a deprecation warning) an explicitly supplied row_id; it
      # used to be accepted and then silently ignored. to_s keeps the bound
      # parameter a String, matching infer_row_id.
      @row_id = obtain_row_id(row_id).to_s
      @column_names = []
      @column_values = []
      @sql = ''

      target(&block)
    end

    # Records a value for column_name. The value is the block's result when a
    # block is given, otherwise it is looked up on the hashmap via #infer_value.
    def assign(column_name, &block)
      @column_names << column_name
      @column_values << (block_given? ? yield : infer_value(column_name, @binding))
    end

    private

    # Returns the deprecated explicit row_id (after printing a warning), or
    # the id inferred from the hashmap when none was given.
    def obtain_row_id(row_id)
      #TODO: Remove in version 1.3 - just call infer_row_id in initialize
      return infer_row_id if row_id.nil?

      puts "\033[33mWARNING - squealer:\033[0m the 'target' row_id parameter is deprecated and will be invalid in version 1.3 and above. Remove it, and ensure the table_name matches a variable containing a hashmap with an _id key"
      row_id
    end

    # Reads _id (symbol key first, then string key) from the hashmap variable
    # named after the table, as a String.
    def infer_row_id
      (
        (eval "#{@table_name}[:_id]", @binding, __FILE__, __LINE__) ||
        (eval "#{@table_name}['_id']", @binding, __FILE__, __LINE__)
      ).to_s
    end

    # Ensures a variable named after the table is in scope of the block and is
    # a Hash with an '_id' or :_id key.
    def verify_table_name_in_scope
      table = eval @table_name, @binding, __FILE__, __LINE__
      raise ArgumentError, "The variable '#{@table_name}' is not a hashmap" unless table.is_a? Hash
      raise ArgumentError, "The hashmap '#{@table_name}' must have an '_id' key" unless table.has_key?('_id') || table.has_key?(:_id)
    rescue NameError
      raise NameError, "A variable named '#{@table_name}' must be in scope, and reference a hashmap with at least an '_id' key."
    end

    # Evaluates `<table>.<column>` in the block's binding. When that is falsy
    # and the column ends in "_id", falls back to `<related>._id` (e.g.
    # user_id -> user._id).
    # NOTE(review): `<table>.<column>` on a Hash presumably relies on a Hash
    # method_missing extension defined elsewhere in squealer — confirm.
    def infer_value(column_name, binding)
      value = eval "#{@table_name}.#{column_name}", binding, __FILE__, __LINE__
      unless value
        name = column_name.to_s
        if name.end_with?('_id')
          related = name[0..-4] # strip "_id"
          value = eval "#{related}._id", binding, __FILE__, __LINE__
        end
      end
      value
    end

    # Makes self the current target, yields so the caller can #assign columns,
    # then writes the row. The target is always popped, even when the block or
    # the SQL raises, so Target.current never points at a dead target.
    def target
      Queue.instance.push(self)
      begin
        yield self
        write_row
      ensure
        Queue.instance.pop
      end
    end

    # Builds and executes the statement(s) for this row. The row id is bound
    # as a parameter rather than interpolated, so ids containing quotes can
    # neither break nor inject into the SQL.
    def write_row
      values = typecast_values
      insert_statement = %{INSERT INTO "#{@table_name}" (#{pk_name}#{column_names}) VALUES (?#{column_value_markers})}
      insert_values = [@row_id] + values

      if Database.instance.upsertable?
        insert_statement << %{ ON DUPLICATE KEY UPDATE #{upsert_assignments}}
        @sql = insert_statement
        # Column values are bound twice: once for VALUES, once for the UPDATE list.
        execute_sql(insert_statement, insert_values + values)
      else
        statements = []
        # An UPDATE with an empty SET list is invalid SQL, so skip it when
        # nothing was assigned; the INSERT alone then covers new rows.
        unless @column_names.empty?
          update_statement = %{UPDATE "#{@table_name}" SET #{column_markers} WHERE #{pk_name}=?}
          execute_sql(update_statement, values + [@row_id])
          statements << update_statement
        end
        statements << insert_statement
        @sql = statements.join("; ")
        # When the row already exists this INSERT raises IntegrityError, which
        # execute_sql ignores on non-upsertable databases.
        execute_sql(insert_statement, insert_values)
      end
    end

    # SET list for ON DUPLICATE KEY UPDATE. With no assigned columns, a no-op
    # self-assignment of the primary key keeps the statement valid.
    def upsert_assignments
      column_markers || "#{pk_name}=#{pk_name}"
    end

    # NOTE(review): @@targets is never assigned in this class, so both readers
    # raise NameError (uninitialized class variable) if called — confirm
    # whether they are still needed.
    def self.targets
      @@targets
    end

    def targets
      @@targets
    end

    # Executes one parameterised statement against the export database.
    # IntegrityError is re-raised only for upsertable databases; on the
    # UPDATE-then-INSERT path it is expected and ignored.
    # @raise [RuntimeError] wrapping any other failure, with the SQL and values.
    def execute_sql(sql, values)
      Database.instance.export.create_command(sql).execute_non_query(*values)
    rescue DataObjects::IntegrityError => e
      raise failure_message(sql, values, e) if Database.instance.upsertable?
    rescue => e
      raise failure_message(sql, values, e)
    end

    # Error text shared by both rescue branches of #execute_sql.
    def failure_message(sql, values, error)
      "Failed to execute statement: #{sql} with #{values.inspect}.\nOriginal Exception was: #{error.to_s}"
    end

    def pk_name
      'id'
    end

    # ",\"col1\",\"col2\"" for the INSERT column list, or nil when none assigned.
    def column_names
      return if @column_names.empty?
      ",#{@column_names.map { |name| quote_identifier(name) }.join(',')}"
    end

    def column_values
      @column_values
    end

    # ",?" per assigned column (empty string when none).
    def column_value_markers
      ',?' * @column_names.size
    end

    # "\"col1\"=?,\"col2\"=?" for SET/UPDATE lists, or nil when none assigned.
    def column_markers
      return if @column_names.empty?
      @column_names.map { |name| "#{quote_identifier(name)}=?" }.join(',')
    end

    # Converts values to driver-friendly forms: arrays become comma-joined
    # strings and BSON ObjectIDs become their string form.
    def typecast_values
      column_values.map do |value|
        case value
        when Array
          value.join(",")
        when BSON::ObjectID
          value.to_s
        else
          value
        end
      end
    end

    def quote_identifier(name)
      %{"#{name}"}
    end

    # Stack of targets whose blocks are currently executing (innermost last).
    class Queue < DelegateClass(Array)
      include Singleton

      def current
        last
      end

      protected

      def initialize
        super([])
      end
    end

    class BlockRequired < ArgumentError; end
  end
end