Skip to content

Commit

Permalink
Add information_schema aware version of cast_if_necessary
Browse files Browse the repository at this point in the history
  • Loading branch information
koichirok committed Nov 29, 2017
1 parent 65faa82 commit 45e04ee
Showing 1 changed file with 56 additions and 14 deletions.
70 changes: 56 additions & 14 deletions library/mysql_query
Original file line number Diff line number Diff line change
Expand Up @@ -209,24 +209,64 @@ def connect(connection, module):
module.fail_json(msg="Error connecting to mysql database: %s" % str(e))


def extract_column_value_maps(parameter):
def extract_column_value_maps(cursor, params, target):
"""
extract mysql-quoted tuple-lists for parameters given as ansible params
:param parameter:
:param cursor:
:param params:
:param target: one of 'identifiers', 'values' or 'defaults'
:return:
"""
parameter = params[target]

def cast_if_necessary(v):
if isinstance(v, int):
return long(v)
if isinstance(v, str) and v.isdigit():
return long(v)
if parameter:

return v
placeholder = ", ".join(["%s"] * len(parameter))
where = ' AND '.join(['table_schema = %s',
'table_name = %s',
'column_name in ({0})'.format(placeholder)])
query = "select column_name,data_type from information_schema.columns where {0}".format(where)

try:
# Get columns types from information_schema
res = cursor.execute(query, tuple([params['name'], params['table']] + parameter.keys()))
if res == 0:
raise Exception()

data_type=dict(cursor.fetchall())
if len(data_type) != len(parameter):
raise Exception()

# XXX: Not tested for all column types.

def cast_if_necessary(column, value):
if data_type[column] in {'tinyint'}: # 'smallint' may also be here
if isinstance(value, int):
return int(value)
if isinstance(value, str) and value.isdigit():
return int(value)
elif 'int' in data_type[column]:
if isinstance(value, int):
return long(value)
if isinstance(value, str) and value.isdigit():
return long(value)
elif data_type[column] in {'float','double'}:
return float(value)
return value

except Exception, e:
# could not obtain data type info, use original logic

def cast_if_necessary(c,v):
if isinstance(v, int):
return long(v)
if isinstance(v, str) and v.isdigit():
return long(v)

return v

if parameter:
for column, value in parameter.items():
yield (mysql_quote_identifier(column, 'column'), cast_if_necessary(value))
yield (mysql_quote_identifier(column, 'column'), cast_if_necessary(column, value))


def failed(action):
Expand Down Expand Up @@ -262,11 +302,8 @@ def main():
if module.params["state"] == "absent":
module.fail_json(msg="state=absent is not yet implemented")

# mysql_quote all identifiers and get the parameters into shape
# mysql_quote table name
table = mysql_quote_identifier(module.params['table'], 'table')
identifiers = dict(extract_column_value_maps(module.params['identifiers']))
values = dict(extract_column_value_maps(module.params['values']))
defaults = dict(extract_column_value_maps(module.params['defaults']))

exit_messages = {
INSERT_REQUIRED: dict(changed=True, msg='No such record, need to insert'),
Expand All @@ -276,6 +313,11 @@ def main():
}

with closing(connect(build_connection_parameter(module.params), module)) as db_connection:
# mysql_quote all identifiers and get the parameters into shape
identifiers = dict(extract_column_value_maps(db_connection.cursor(), module.params, 'identifiers'))
values = dict(extract_column_value_maps(db_connection.cursor(), module.params, 'values'))
defaults = dict(extract_column_value_maps(db_connection.cursor(), module.params, 'defaults'))

# find out what needs to be done (independently of check-mode)
required_action, diff = change_required(db_connection.cursor(), table, identifiers, values)

Expand Down

0 comments on commit 45e04ee

Please sign in to comment.