-
Notifications
You must be signed in to change notification settings - Fork 48
/
copy.py
369 lines (351 loc) · 16.2 KB
/
copy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
from mysqlsh.plugin_manager import plugin, plugin_function
from mysqlsh_plugins_common import get_major_version
from mysqlsh_plugins_common import get_distro
from user import mds
import re
# Connection state shared by every call of copy_users_grants: the origin
# session and its banner are cached, and the destination session is reused
# (after confirmation) so it does not have to be typed again on each call.
# These are plain module attributes; `global` statements at module scope are
# no-ops, so the functions that rebind them declare `global` themselves.
session_destination = None
session_origin = None
server_info_destination = None
server_info_origin = None
def _get_server_info(session):
return session.run_sql(
"""SELECT CONCAT(@@hostname," (",@@version,")")"""
).fetch_one()[0]
def _connect_to_destination(shell, uri):
try:
session2 = shell.open_session(uri)
except:
print("ERROR: unable to connect {}, aborting...".format(uri))
return None
return session2
def _quote_account(name, host):
    """Return ``name``/``host`` as a backtick-quoted MySQL account (`n`@`h`).

    Backticks inside either part are doubled (MySQL's escape inside quoted
    identifiers) so unusual account names cannot break the statement.
    """
    return "`{}`@`{}`".format(name.replace("`", "``"), host.replace("`", "``"))


def _filter_grant_for_mds(grant_stmt):
    """Reduce a SHOW GRANTS statement to the privileges allowed on OCI MDS.

    Args:
        grant_stmt (str): one statement as returned by SHOW GRANTS.

    Returns:
        The GRANT statement rebuilt with only the privileges listed in
        ``mds.mds_allowed_privileges`` (original order kept), or None when
        none of its privileges are allowed.
    """
    # Keep the " ON <db>.<obj> TO <account>..." tail verbatim.
    on_stmt = re.sub(r"^.*( ON .*\..* TO .*$)", r"\1", grant_stmt)
    privileges = re.sub(r"^GRANT ", "", grant_stmt)
    privileges = re.sub(r" ON .*\..* TO .*$", "", privileges)
    kept = [
        priv for priv in privileges.split(", ")
        if priv in mds.mds_allowed_privileges
    ]
    if not kept:
        return None
    return "GRANT " + ", ".join(kept) + on_stmt


@plugin_function("user.copy")
def copy_users_grants(dryrun=False, ocimds=False, force=False, session=None):
    """
    Copy a user to another server

    The destination server is asked for interactively and remembered between
    calls; users (and, on MySQL 8, their roles) are selected interactively.

    Args:
        dryrun (bool): Don't run the statements, only show them. Default is False.
        ocimds (bool): Use OCI MDS compatibility mode. Default is False.
        force (bool): Reply "yes" to all questions when the plan is to copy multiple users. Default is False.
        session (object): The optional session object used to query the
            database. If omitted the MySQL Shell's current session will be used.
    """
    global session_destination
    global session_origin
    global server_info_destination
    global server_info_origin
    # Get hold of the global shell object
    import mysqlsh

    shell = mysqlsh.globals.shell
    old_format = None
    if session is None:
        session = shell.get_session()
        if session is None:
            print(
                "No session specified. Either pass a session object to this "
                "function or connect the shell to a database"
            )
            return
    # Only re-query the origin banner when the origin session changed.
    if session_origin != session:
        server_info_origin = _get_server_info(session)
        session_origin = session

    if session_destination is None:
        answer = shell.prompt(
            "You need to specify a destination server (<user@>server<:port>): "
        )
        if not answer:
            print("Aborting, no destination server defined.")
            return
        session_destination = _connect_to_destination(shell, answer)
        if not session_destination:
            return
        server_info_destination = _get_server_info(session_destination)
    else:
        answer = shell.prompt(
            "The current destination server is [{}] - {}, is that OK ? (y/N) ".format(
                session_destination.get_uri(), server_info_destination
            ),
            {"defaultValue": "n"},
        )
        if answer.lower() == "n":
            answer = shell.prompt(
                "You wanted to change destination server, please specify one (<user@>server<:port>): "
            )
            if not answer:
                print("Aborting, no destination server defined.")
                return
            session_destination = _connect_to_destination(shell, answer)
            if not session_destination:
                return
            server_info_destination = _get_server_info(session_destination)

    search_string = shell.prompt(
        "Enter the user to search (you can use wildcards '%', leave blank for all): "
    )
    # The pattern is bound as a query parameter instead of being pasted into
    # the SQL text, so quotes in it cannot break (or inject into) the query
    # and it does not depend on the ANSI_QUOTES sql_mode.
    search_filter = ""
    search_args = []
    if search_string.strip():
        search_filter = "AND user LIKE ?"
        search_args = [search_string]
    print("Info: locked users and users having expired password are not listed.")
    mysql_version = get_major_version(session)
    mysql_major_int = int(mysql_version.split(".")[0])
    mysql_distro = get_distro(session)
    if mysql_distro.lower().startswith("maria"):
        print("-- MariaDB Detected")
        # MariaDB is handled with the legacy (pre-5.7) code paths.
        mysql_version = "5.6"
        mysql_major_int = 5
    # Pre-5.7 servers have no SHOW CREATE USER; accounts are rebuilt from
    # SHOW GRANTS instead.
    legacy_server = mysql_major_int < 8 and mysql_version != "5.7"

    if not legacy_server:
        # Get the list of users
        stmt = """SELECT DISTINCT User, Host,
IF(authentication_string = "","NO", "YES") HAS_PWD
FROM mysql.user
WHERE NOT( `account_locked`="Y" AND `password_expired`="Y" AND `authentication_string`="" ) {}
ORDER BY User, Host;
""".format(search_filter)
    else:
        # Very old servers store the hash in `password` instead of
        # `authentication_string`.
        stmt = """SELECT COLUMN_NAME FROM INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_SCHEMA='mysql' AND TABLE_NAME='user' AND COLUMN_NAME='password';"""
        old_format = session.run_sql(stmt).fetch_all()
        pwd_column = "password" if old_format else "authentication_string"
        stmt = """SELECT DISTINCT User, Host,
IF({} = "","NO", "YES") HAS_PWD
FROM mysql.user
WHERE NOT(`password_expired`="Y" AND `authentication_string`="" ) {}
ORDER BY User, Host;
""".format(pwd_column, search_filter)
    if search_args:
        users = session.run_sql(stmt, search_args).fetch_all()
    else:
        users = session.run_sql(stmt).fetch_all()
    print("{} user{} found!".format(len(users), "" if len(users) == 1 else "s"))

    for user in users:
        user_name, user_host = user[0], user[1]
        if ocimds and user[2] == "NO":
            print(
                "[`{}`@`{}`] is not compatible with OCI MDS as it has not password, ignoring it...".format(
                    user_name, user_host
                )
            )
            continue
        if force:
            answer = "y"
        else:
            answer = shell.prompt(
                "Do you want to copy [`{}`@`{}`] ? (y/N) ".format(user_name, user_host),
                {"defaultValue": "n"},
            )
        if answer.lower() != "y":
            continue

        # --- Roles granted to the user (MySQL 8+ only) ---
        if mysql_major_int >= 8:
            roles = session.run_sql(
                """SELECT from_user, from_host FROM mysql.role_edges WHERE to_user = ? and to_host = ?""",
                [user_name, user_host],
            ).fetch_all()
            if len(roles) > 0:
                if len(roles) > 1:
                    final_s = "s are"
                    question = "Do you want to copy those roles ? (y/N) "
                else:
                    final_s = " is"
                    question = "Do you want to copy that role ? (y/N) "
                print("The following role{} assigned to the user:".format(final_s))
                for role in roles:
                    print("- `{}`@`{}`".format(role[0], role[1]))
                # With force, `answer` is still "y" from the user question.
                if not force:
                    answer = shell.prompt(question, {"defaultValue": "n"})
                if answer.lower() == "y":
                    for role in roles:
                        # Roles are recreated empty (CREATE ROLE) and then
                        # receive their grants below.
                        if dryrun:
                            print("-- Role `{}`@`{}`".format(role[0], role[1]))
                            print(
                                "CREATE ROLE IF NOT EXISTS {};".format(
                                    _quote_account(role[0], role[1])
                                )
                            )
                        else:
                            print(
                                "Copying ROLE `{}`@`{}`: {} - {}--> {} - {}".format(
                                    role[0],
                                    role[1],
                                    session.get_uri(),
                                    server_info_origin,
                                    session_destination.get_uri(),
                                    server_info_destination,
                                )
                            )
                            try:
                                session_destination.run_sql(
                                    "CREATE ROLE IF NOT EXISTS ?@?;",
                                    [role[0], role[1]],
                                )
                            except mysqlsh.DBError as err:
                                print("Aborting: {}".format(err))
                                return
                        grants = session.run_sql(
                            "SHOW GRANTS FOR {}".format(_quote_account(role[0], role[1]))
                        ).fetch_all()
                        for grant in grants:
                            grant_stmt = grant[0]
                            if ocimds:
                                grant_stmt = _filter_grant_for_mds(grant_stmt)
                            if not grant_stmt:
                                continue
                            if dryrun:
                                print("{};".format(grant_stmt))
                                continue
                            try:
                                session_destination.run_sql(grant_stmt)
                            except mysqlsh.DBError as err:
                                print("Aborting: {}".format(err))
                                return
                else:
                    print(
                        "Warning: some grants may fail if the role is not created on the destination server."
                    )

        # --- Build the CREATE USER statement ---
        back_tick = False
        if legacy_server:
            # The first SHOW GRANTS row is "GRANT ... TO 'u'@'h' IDENTIFIED
            # BY PASSWORD ..."; turn its " TO <account>" part into a
            # CREATE USER and drop everything before it.
            create_user = session.run_sql(
                "SHOW GRANTS FOR {}".format(_quote_account(user_name, user_host))
            ).fetch_one()[0] + ";"
            if "`{}`@".format(user_name) in create_user:
                create_user = create_user.replace(
                    " TO `{}`@`".format(user_name),
                    "CREATE USER IF NOT EXISTS `{}`@`".format(user_name),
                )
                back_tick = True
            else:
                create_user = create_user.replace(
                    " TO '{}'@'".format(user_name),
                    "CREATE USER IF NOT EXISTS '{}'@'".format(user_name),
                )
            create_user = re.sub(
                r".*CREATE USER IF NOT", "CREATE USER IF NOT", create_user
            )
        else:
            create_user = session.run_sql(
                "SHOW CREATE USER {}".format(_quote_account(user_name, user_host))
            ).fetch_one()[0] + ";"
            # 5.7 quotes the account with '...', 8.0 with backticks: rewrite
            # the statement prefix so both get IF NOT EXISTS.
            create_user = re.sub(
                r"^CREATE USER ", "CREATE USER IF NOT EXISTS ", create_user, count=1
            )
        # The stored hash may contain non-printable bytes (e.g.
        # caching_sha2_password), so re-emit it as a hex literal read in
        # binary form.
        auth_row = session.run_sql(
            "SELECT convert(authentication_string using binary) authbin "
            "FROM mysql.user WHERE user=? AND host=?",
            [user_name, user_host],
        ).fetch_one()
        auth_string_bin = auth_row[0] if auth_row else None
        if auth_string_bin:
            create_user = re.sub(
                r" AS '(.*)' ",
                r" AS 0x{} ".format(auth_string_bin.hex()),
                create_user,
            )
        if legacy_server:
            if old_format and not back_tick:
                # NOTE(review): if SHOW GRANTS already printed the hash after
                # BY PASSWORD, this yields two adjacent string literals, which
                # MySQL concatenates. Confirm against a 5.6 server.
                pwd = session.run_sql(
                    "SELECT password FROM mysql.user WHERE user=? AND host=?",
                    [user_name, user_host],
                ).fetch_one()[0]
                create_user = create_user.replace(
                    "BY PASSWORD",
                    "WITH 'mysql_native_password' AS '{}'".format(pwd),
                )
            else:
                create_user = create_user.replace(
                    "BY PASSWORD", "WITH 'mysql_native_password' AS"
                )

        if dryrun:
            print("-- User `{}`@`{}`".format(user_name, user_host))
            # A dry run must show the statement it would run.
            print(create_user)
        else:
            print(
                "Copying USER `{}`@`{}`: {} - {} --> {} - {} ".format(
                    user_name,
                    user_host,
                    session.get_uri(),
                    server_info_origin,
                    session_destination.get_uri(),
                    server_info_destination,
                )
            )
            try:
                session_destination.run_sql(create_user)
            except mysqlsh.DBError as err:
                print("Aborting: {}".format(err))
                return

        # --- Copy the user's grants ---
        grants = session.run_sql(
            "SHOW GRANTS FOR {}".format(_quote_account(user_name, user_host))
        ).fetch_all()
        if not dryrun and grants:
            print("Copying GRANTS.", end="")
        for grant in grants:
            # Pre-5.7 grants carry the password, which was already moved into
            # the CREATE USER statement above.
            grant_stmt = re.sub(r" IDENTIFIED BY PASSWORD.*$", "", grant[0])
            if ocimds:
                grant_stmt = _filter_grant_for_mds(grant_stmt)
            if not grant_stmt:
                continue
            if dryrun:
                print("{};".format(grant_stmt))
                continue
            try:
                session_destination.run_sql(grant_stmt)
                print(".", end="")
            except mysqlsh.DBError as err:
                print("\nAborting: {}".format(err))
                print("You may need to install mysql-client to save the password.")
                return
        if not dryrun and grants:
            print("\nUser(s) copied successfully!")