Permalink
Cannot retrieve contributors at this time
Name already in use
A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
lua-resty-mysql/lib/resty/mysql.lua
Go to fileThis commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
1417 lines (1048 sloc)
34.2 KB
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| -- Copyright (C) Yichun Zhang (agentzh) | |
| local bit = require "bit" | |
| local resty_sha256 = require "resty.sha256" | |
| local sub = string.sub | |
| local tcp = ngx.socket.tcp | |
| local strbyte = string.byte | |
| local strchar = string.char | |
| local strfind = string.find | |
| local format = string.format | |
| local strrep = string.rep | |
| local null = ngx.null | |
| local band = bit.band | |
| local bxor = bit.bxor | |
| local bor = bit.bor | |
| local lshift = bit.lshift | |
| local rshift = bit.rshift | |
| local tohex = bit.tohex | |
| local sha1 = ngx.sha1_bin | |
| local concat = table.concat | |
| local setmetatable = setmetatable | |
| local error = error | |
| local tonumber = tonumber | |
| local to_int = math.floor | |
| local has_rsa, resty_rsa = pcall(require, "resty.rsa") | |
| if not ngx.config then | |
| error("ngx_lua 0.9.11+ or ngx_stream_lua required") | |
| end | |
| if (not ngx.config.subsystem | |
| or ngx.config.subsystem == "http") -- subsystem is http | |
| and (not ngx.config.ngx_lua_version | |
| or ngx.config.ngx_lua_version < 9011) -- old version | |
| then | |
| error("ngx_lua 0.9.11+ required") | |
| end | |
| local ok, new_tab = pcall(require, "table.new") | |
| if not ok then | |
| new_tab = function (narr, nrec) return {} end | |
| end | |
| local _M = { _VERSION = '0.26' } | |
| -- constants | |
| local STATE_CONNECTED = 1 | |
| local STATE_COMMAND_SENT = 2 | |
| local COM_QUIT = 0x01 | |
| local COM_QUERY = 0x03 | |
| -- refer to https://dev.mysql.com/doc/internals/en/capability-flags.html#packet-Protocol::CapabilityFlags | |
| -- CLIENT_LONG_PASSWORD | CLIENT_FOUND_ROWS | CLIENT_LONG_FLAG | |
| -- | CLIENT_CONNECT_WITH_DB | CLIENT_ODBC | CLIENT_LOCAL_FILES | |
| -- | CLIENT_IGNORE_SPACE | CLIENT_PROTOCOL_41 | CLIENT_INTERACTIVE | |
| -- | CLIENT_IGNORE_SIGPIPE | CLIENT_TRANSACTIONS | CLIENT_RESERVED | |
| -- | CLIENT_SECURE_CONNECTION | CLIENT_MULTI_STATEMENTS | CLIENT_MULTI_RESULTS | |
| local DEFAULT_CLIENT_FLAGS = 0x3f7cf | |
| local CLIENT_SSL = 0x00000800 | |
| local CLIENT_PLUGIN_AUTH = 0x00080000 | |
| local CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA = 0x00200000 | |
| local DEFAULT_AUTH_PLUGIN = "mysql_native_password" | |
| local SERVER_MORE_RESULTS_EXISTS = 8 | |
| local RESP_OK = "OK" | |
| local RESP_AUTHMOREDATA = "AUTHMOREDATA" | |
| local RESP_LOCALINFILE = "LOCALINFILE" | |
| local RESP_EOF = "EOF" | |
| local RESP_ERR = "ERR" | |
| local RESP_DATA = "DATA" | |
| local MY_RND_MAX_VAL = 0x3FFFFFFF | |
| local MIN_PROTOCOL_VER = 10 | |
| local LEN_NATIVE_SCRAMBLE = 20 | |
| local LEN_OLD_SCRAMBLE = 8 | |
| -- 16MB - 1, the default max allowed packet size used by libmysqlclient | |
| local FULL_PACKET_SIZE = 16777215 | |
| -- the following charset map is generated from the following mysql query: | |
| -- SELECT CHARACTER_SET_NAME, ID | |
| -- FROM information_schema.collations | |
| -- WHERE IS_DEFAULT = 'Yes' ORDER BY id; | |
| local CHARSET_MAP = { | |
| _default = 0, | |
| big5 = 1, | |
| dec8 = 3, | |
| cp850 = 4, | |
| hp8 = 6, | |
| koi8r = 7, | |
| latin1 = 8, | |
| latin2 = 9, | |
| swe7 = 10, | |
| ascii = 11, | |
| ujis = 12, | |
| sjis = 13, | |
| hebrew = 16, | |
| tis620 = 18, | |
| euckr = 19, | |
| koi8u = 22, | |
| gb2312 = 24, | |
| greek = 25, | |
| cp1250 = 26, | |
| gbk = 28, | |
| latin5 = 30, | |
| armscii8 = 32, | |
| utf8 = 33, | |
| ucs2 = 35, | |
| cp866 = 36, | |
| keybcs2 = 37, | |
| macce = 38, | |
| macroman = 39, | |
| cp852 = 40, | |
| latin7 = 41, | |
| utf8mb4 = 45, | |
| cp1251 = 51, | |
| utf16 = 54, | |
| utf16le = 56, | |
| cp1256 = 57, | |
| cp1257 = 59, | |
| utf32 = 60, | |
| binary = 63, | |
| geostd8 = 92, | |
| cp932 = 95, | |
| eucjpms = 97, | |
| gb18030 = 248 | |
| } | |
| local mt = { __index = _M } | |
| -- mysql field value type converters | |
| local converters = new_tab(0, 9) | |
| for i = 0x01, 0x05 do | |
| -- tiny, short, long, float, double | |
| converters[i] = tonumber | |
| end | |
| converters[0x00] = tonumber -- decimal | |
| -- converters[0x08] = tonumber -- long long | |
| converters[0x09] = tonumber -- int24 | |
| converters[0x0d] = tonumber -- year | |
| converters[0xf6] = tonumber -- newdecimal | |
| local function _get_byte2(data, i) | |
| local a, b = strbyte(data, i, i + 1) | |
| return bor(a, lshift(b, 8)), i + 2 | |
| end | |
| local function _get_byte3(data, i) | |
| local a, b, c = strbyte(data, i, i + 2) | |
| return bor(a, lshift(b, 8), lshift(c, 16)), i + 3 | |
| end | |
| local function _get_byte4(data, i) | |
| local a, b, c, d = strbyte(data, i, i + 3) | |
| return bor(a, lshift(b, 8), lshift(c, 16), lshift(d, 24)), i + 4 | |
| end | |
| local function _get_byte8(data, i) | |
| local a, b, c, d, e, f, g, h = strbyte(data, i, i + 7) | |
| -- XXX workaround for the lack of 64-bit support in bitop: | |
| -- XXX return results in the range of signed 32 bit numbers | |
| local lo = bor(a, lshift(b, 8), lshift(c, 16)) | |
| local hi = bor(e, lshift(f, 8), lshift(g, 16), lshift(h, 24)) | |
| return lo + 16777216 * d + hi * 4294967296, i + 8 | |
| -- return bor(a, lshift(b, 8), lshift(c, 16), lshift(d, 24), lshift(e, 32), | |
| -- lshift(f, 40), lshift(g, 48), lshift(h, 56)), i + 8 | |
| end | |
| local function _set_byte2(n) | |
| return strchar(band(n, 0xff), band(rshift(n, 8), 0xff)) | |
| end | |
| local function _set_byte3(n) | |
| return strchar(band(n, 0xff), | |
| band(rshift(n, 8), 0xff), | |
| band(rshift(n, 16), 0xff)) | |
| end | |
| local function _set_byte4(n) | |
| return strchar(band(n, 0xff), | |
| band(rshift(n, 8), 0xff), | |
| band(rshift(n, 16), 0xff), | |
| band(rshift(n, 24), 0xff)) | |
| end | |
| local function _from_cstring(data, i) | |
| local last = strfind(data, "\0", i, true) | |
| if not last then | |
| return nil, nil | |
| end | |
| return sub(data, i, last - 1), last + 1 | |
| end | |
| local function _to_cstring(data) | |
| return data .. "\0" | |
| end | |
| local function _dump(data) | |
| local len = #data | |
| local bytes = new_tab(len, 0) | |
| for i = 1, len do | |
| bytes[i] = format("%x", strbyte(data, i)) | |
| end | |
| return concat(bytes, " ") | |
| end | |
| local function _dumphex(data) | |
| local len = #data | |
| local bytes = new_tab(len, 0) | |
| for i = 1, len do | |
| bytes[i] = tohex(strbyte(data, i), 2) | |
| end | |
| return concat(bytes, " ") | |
| end | |
| local function _pwd_hash(password) | |
| local add = 7 | |
| local hash1 = 1345345333 | |
| local hash2 = 0x12345671 | |
| local len = #password | |
| for i = 1, len do | |
| -- skip spaces and tabs in password | |
| local byte = strbyte(password, i) | |
| if byte ~= 32 and byte ~= 9 then -- not ' ' or '\t' | |
| hash1 = bxor(hash1, (band(hash1, 63) + add) * byte | |
| + lshift(hash1, 8)) | |
| hash2 = bxor(lshift(hash2, 8), hash1) + hash2 | |
| add = add + byte | |
| end | |
| end | |
| -- remove sign bit (1<<31)-1) | |
| return band(hash1, 0x7FFFFFFF), band(hash2, 0x7FFFFFFF) | |
| end | |
| local function _random_byte(seed1, seed2) | |
| seed1 = (seed1 * 3 + seed2) % MY_RND_MAX_VAL | |
| seed2 = (seed1 + seed2 + 33) % MY_RND_MAX_VAL | |
| return to_int(seed1 * 31 / MY_RND_MAX_VAL), seed1, seed2 | |
| end | |
| local function _compute_old_token(password, scramble) | |
| if password == "" then | |
| return "" | |
| end | |
| scramble = sub(scramble, 1, LEN_OLD_SCRAMBLE) | |
| local hash_pw1, hash_pw2 = _pwd_hash(password) | |
| local hash_sc1, hash_sc2 = _pwd_hash(scramble) | |
| local seed1 = bxor(hash_pw1, hash_sc1) % MY_RND_MAX_VAL | |
| local seed2 = bxor(hash_pw2, hash_sc2) % MY_RND_MAX_VAL | |
| local rand_byte | |
| local bytes = new_tab(LEN_OLD_SCRAMBLE, 0) | |
| for i = 1, LEN_OLD_SCRAMBLE do | |
| rand_byte, seed1, seed2 = _random_byte(seed1, seed2) | |
| bytes[i] = rand_byte + 64 | |
| end | |
| rand_byte = _random_byte(seed1, seed2) | |
| for i = 1, LEN_OLD_SCRAMBLE do | |
| bytes[i] = strchar(bxor(bytes[i], rand_byte)) | |
| end | |
| return _to_cstring(concat(bytes)) | |
| end | |
| local function _compute_sha256_token(password, scramble) | |
| if password == "" then | |
| return "" | |
| end | |
| local sha256 = resty_sha256:new() | |
| if not sha256 then | |
| return nil, "failed to create the sha256 object" | |
| end | |
| if not sha256:update(password) then | |
| return nil, "failed to update string to sha256" | |
| end | |
| local message1 = sha256:final() | |
| sha256:reset() | |
| if not sha256:update(message1) then | |
| return nil, "failed to update string to sha256" | |
| end | |
| local message1_hash = sha256:final() | |
| sha256:reset() | |
| if not sha256:update(message1_hash) then | |
| return nil, "failed to update string to sha256" | |
| end | |
| if not sha256:update(scramble) then | |
| return nil, "failed to update string to sha256" | |
| end | |
| local message2 = sha256:final() | |
| local n = #message2 | |
| local bytes = new_tab(n, 0) | |
| for i = 1, n do | |
| bytes[i] = strchar(bxor(strbyte(message1, i), strbyte(message2, i))) | |
| end | |
| return concat(bytes) | |
| end | |
| local function _compute_token(password, scramble) | |
| if password == "" then | |
| return "" | |
| end | |
| scramble = sub(scramble, 1, LEN_NATIVE_SCRAMBLE) | |
| local stage1 = sha1(password) | |
| local stage2 = sha1(stage1) | |
| local stage3 = sha1(scramble .. stage2) | |
| local n = #stage1 | |
| local bytes = new_tab(n, 0) | |
| for i = 1, n do | |
| bytes[i] = strchar(bxor(strbyte(stage3, i), strbyte(stage1, i))) | |
| end | |
| return concat(bytes) | |
| end | |
| local function _send_packet(self, req, size) | |
| local sock = self.sock | |
| self.packet_no = self.packet_no + 1 | |
| -- print("packet no: ", self.packet_no) | |
| local packet = _set_byte3(size) .. strchar(band(self.packet_no, 255)) .. req | |
| -- print("sending packet: ", _dump(packet)) | |
| -- print("sending packet... of size " .. #packet) | |
| return sock:send(packet) | |
| end | |
| local function _recv_packet(self) | |
| local sock = self.sock | |
| local data, err = sock:receive(4) -- packet header | |
| if not data then | |
| return nil, nil, "failed to receive packet header: " .. err | |
| end | |
| --print("packet header: ", _dump(data)) | |
| local len, pos = _get_byte3(data, 1) | |
| --print("packet length: ", len) | |
| if len == 0 then | |
| return nil, nil, "empty packet" | |
| end | |
| if len > self._max_packet_size then | |
| return nil, nil, "packet size too big: " .. len | |
| end | |
| local num = strbyte(data, pos) | |
| --print("recv packet: packet no: ", num) | |
| self.packet_no = num | |
| data, err = sock:receive(len) | |
| --print("receive returned") | |
| if not data then | |
| return nil, nil, "failed to read packet content: " .. err | |
| end | |
| --print("packet content: ", _dump(data)) | |
| --print("packet content (ascii): ", data) | |
| local field_count = strbyte(data, 1) | |
| local typ | |
| if field_count == 0x00 then | |
| typ = RESP_OK | |
| elseif field_count == 0x01 then | |
| typ = RESP_AUTHMOREDATA | |
| elseif field_count == 0xfb then | |
| typ = RESP_LOCALINFILE | |
| elseif field_count == 0xfe then | |
| typ = RESP_EOF | |
| elseif field_count == 0xff then | |
| typ = RESP_ERR | |
| else | |
| typ = RESP_DATA | |
| end | |
| return data, typ | |
| end | |
| local function _from_length_coded_bin(data, pos) | |
| local first = strbyte(data, pos) | |
| --print("LCB: first: ", first) | |
| if not first then | |
| return nil, pos | |
| end | |
| if first >= 0 and first <= 250 then | |
| return first, pos + 1 | |
| end | |
| if first == 251 then | |
| return null, pos + 1 | |
| end | |
| if first == 252 then | |
| pos = pos + 1 | |
| return _get_byte2(data, pos) | |
| end | |
| if first == 253 then | |
| pos = pos + 1 | |
| return _get_byte3(data, pos) | |
| end | |
| if first == 254 then | |
| pos = pos + 1 | |
| return _get_byte8(data, pos) | |
| end | |
| return nil, pos + 1 | |
| end | |
| local function _from_length_coded_str(data, pos) | |
| local len | |
| len, pos = _from_length_coded_bin(data, pos) | |
| if not len or len == null then | |
| return null, pos | |
| end | |
| return sub(data, pos, pos + len - 1), pos + len | |
| end | |
| local function _parse_ok_packet(packet) | |
| local res = new_tab(0, 5) | |
| local pos | |
| res.affected_rows, pos = _from_length_coded_bin(packet, 2) | |
| --print("affected rows: ", res.affected_rows, ", pos:", pos) | |
| res.insert_id, pos = _from_length_coded_bin(packet, pos) | |
| --print("insert id: ", res.insert_id, ", pos:", pos) | |
| res.server_status, pos = _get_byte2(packet, pos) | |
| --print("server status: ", res.server_status, ", pos:", pos) | |
| res.warning_count, pos = _get_byte2(packet, pos) | |
| --print("warning count: ", res.warning_count, ", pos: ", pos) | |
| local message = _from_length_coded_str(packet, pos) | |
| if message and message ~= null then | |
| res.message = message | |
| end | |
| --print("message: ", res.message, ", pos:", pos) | |
| return res | |
| end | |
| local function _parse_eof_packet(packet) | |
| local pos = 2 | |
| local warning_count, pos = _get_byte2(packet, pos) | |
| local status_flags = _get_byte2(packet, pos) | |
| return warning_count, status_flags | |
| end | |
| local function _parse_err_packet(packet) | |
| local errno, pos = _get_byte2(packet, 2) | |
| local marker = sub(packet, pos, pos) | |
| local sqlstate | |
| if marker == '#' then | |
| -- with sqlstate | |
| pos = pos + 1 | |
| sqlstate = sub(packet, pos, pos + 5 - 1) | |
| pos = pos + 5 | |
| end | |
| local message = sub(packet, pos) | |
| return errno, message, sqlstate | |
| end | |
| local function _parse_result_set_header_packet(packet) | |
| local field_count, pos = _from_length_coded_bin(packet, 1) | |
| local extra | |
| extra = _from_length_coded_bin(packet, pos) | |
| return field_count, extra | |
| end | |
| local function _parse_field_packet(data) | |
| local col = new_tab(0, 2) | |
| local catalog, db, table, orig_table, orig_name, charsetnr, length | |
| local pos | |
| catalog, pos = _from_length_coded_str(data, 1) | |
| --print("catalog: ", col.catalog, ", pos:", pos) | |
| db, pos = _from_length_coded_str(data, pos) | |
| table, pos = _from_length_coded_str(data, pos) | |
| orig_table, pos = _from_length_coded_str(data, pos) | |
| col.name, pos = _from_length_coded_str(data, pos) | |
| orig_name, pos = _from_length_coded_str(data, pos) | |
| pos = pos + 1 -- ignore the filler | |
| charsetnr, pos = _get_byte2(data, pos) | |
| length, pos = _get_byte4(data, pos) | |
| col.type = strbyte(data, pos) | |
| --[[ | |
| pos = pos + 1 | |
| col.flags, pos = _get_byte2(data, pos) | |
| col.decimals = strbyte(data, pos) | |
| pos = pos + 1 | |
| local default = sub(data, pos + 2) | |
| if default and default ~= "" then | |
| col.default = default | |
| end | |
| --]] | |
| return col | |
| end | |
| local function _parse_row_data_packet(data, cols, compact) | |
| local pos = 1 | |
| local ncols = #cols | |
| local row | |
| if compact then | |
| row = new_tab(ncols, 0) | |
| else | |
| row = new_tab(0, ncols) | |
| end | |
| for i = 1, ncols do | |
| local value | |
| value, pos = _from_length_coded_str(data, pos) | |
| local col = cols[i] | |
| local typ = col.type | |
| local name = col.name | |
| --print("row field value: ", value, ", type: ", typ) | |
| if value ~= null then | |
| local conv = converters[typ] | |
| if conv then | |
| value = conv(value) | |
| end | |
| end | |
| if compact then | |
| row[i] = value | |
| else | |
| row[name] = value | |
| end | |
| end | |
| return row | |
| end | |
| local function _recv_field_packet(self) | |
| local packet, typ, err = _recv_packet(self) | |
| if not packet then | |
| return nil, err | |
| end | |
| if typ == RESP_ERR then | |
| local errno, msg, sqlstate = _parse_err_packet(packet) | |
| return nil, msg, errno, sqlstate | |
| end | |
| if typ ~= RESP_DATA then | |
| return nil, "bad field packet type: " .. typ | |
| end | |
| -- typ == RESP_DATA | |
| return _parse_field_packet(packet) | |
| end | |
| -- refer to https://dev.mysql.com/doc/internals/en/connection-phase-packets.html | |
| local function _read_hand_shake_packet(self) | |
| local packet, typ, err = _recv_packet(self) | |
| if not packet then | |
| return nil, nil, err | |
| end | |
| if typ == RESP_ERR then | |
| local errno, msg, sqlstate = _parse_err_packet(packet) | |
| return nil, nil, msg, errno, sqlstate | |
| end | |
| local protocol_ver = tonumber(strbyte(packet)) | |
| if not protocol_ver then | |
| return nil, nil, | |
| "bad handshake initialization packet: bad protocol version" | |
| end | |
| if protocol_ver < MIN_PROTOCOL_VER then | |
| return nil, nil, "unsupported protocol version " .. protocol_ver | |
| .. ", version " .. MIN_PROTOCOL_VER | |
| .. " or higher is required" | |
| end | |
| self.protocol_ver = protocol_ver | |
| local server_ver, pos = _from_cstring(packet, 2) | |
| if not server_ver then | |
| return nil, nil, | |
| "bad handshake initialization packet: bad server version" | |
| end | |
| self._server_ver = server_ver | |
| local thread_id, pos = _get_byte4(packet, pos) | |
| local scramble = sub(packet, pos, pos + 8 - 1) | |
| if not scramble then | |
| return nil, nil, "1st part of scramble not found" | |
| end | |
| pos = pos + 9 -- skip filler(8 + 1) | |
| -- two lower bytes | |
| local capabilities -- server capabilities | |
| capabilities, pos = _get_byte2(packet, pos) | |
| self._server_lang = strbyte(packet, pos) | |
| pos = pos + 1 | |
| self._server_status, pos = _get_byte2(packet, pos) | |
| local more_capabilities | |
| more_capabilities, pos = _get_byte2(packet, pos) | |
| self.capabilities = bor(capabilities, lshift(more_capabilities, 16)) | |
| pos = pos + 11 -- skip length of auth-plugin-data(1) and reserved(10) | |
| -- follow official Python library uses the fixed length 12 | |
| -- and the 13th byte is "\0 byte | |
| local scramble_part2 = sub(packet, pos, pos + 12 - 1) | |
| if not scramble_part2 then | |
| return nil, nil, "2nd part of scramble not found" | |
| end | |
| pos = pos + 13 | |
| local plugin, _ | |
| if band(self.capabilities, CLIENT_PLUGIN_AUTH) > 0 then | |
| plugin, _ = _from_cstring(packet, pos) | |
| if not plugin then | |
| -- EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2) | |
| -- \NUL otherwise | |
| plugin = sub(packet, pos) | |
| end | |
| else | |
| plugin = DEFAULT_AUTH_PLUGIN | |
| end | |
| return scramble .. scramble_part2, plugin | |
| end | |
| local function _append_auth_length(self, data) | |
| local n = #data | |
| if n <= 250 then | |
| data = strchar(n) .. data | |
| return data, 1 + n | |
| end | |
| self.DEFAULT_CLIENT_FLAGS = bor(self.DEFAULT_CLIENT_FLAGS, | |
| CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA) | |
| if n <= 0xffff then | |
| data = strchar(0xfc, band(n, 0xff), band(rshift(n, 8), 0xff)) .. data | |
| return data, 3 + n | |
| end | |
| if n <= 0xffffff then | |
| data = strchar(0xfd, | |
| band(n, 0xff), | |
| band(rshift(n, 8), 0xff), | |
| band(rshift(n, 16), 0xff)) | |
| .. data | |
| return data, 4 + n | |
| end | |
| data = strchar(0xfe, | |
| band(n, 0xff), | |
| band(rshift(n, 8), 0xff), | |
| band(rshift(n, 16), 0xff), | |
| band(rshift(n, 24), 0xff), | |
| band(rshift(n, 32), 0xff), | |
| band(rshift(n, 40), 0xff), | |
| band(rshift(n, 48), 0xff), | |
| band(rshift(n, 56), 0xff)) | |
| .. data | |
| return data, 9 + n | |
| end | |
| local function _write_hand_shake_response(self, auth_resp, plugin) | |
| local append_auth, len = _append_auth_length(self, auth_resp) | |
| if self.use_ssl then | |
| if band(self.capabilities, CLIENT_SSL) == 0 then | |
| return "ssl disabled on server" | |
| end | |
| -- send a SSL Request Packet | |
| local req = _set_byte4(bor(self.DEFAULT_CLIENT_FLAGS, CLIENT_SSL)) | |
| .. _set_byte4(self._max_packet_size) | |
| .. strchar(self.charset) | |
| .. strrep("\0", 23) | |
| local packet_len = 4 + 4 + 1 + 23 | |
| local bytes, err = _send_packet(self, req, packet_len) | |
| if not bytes then | |
| return "failed to send client authentication packet: " .. err | |
| end | |
| local sock = self.sock | |
| local ok, err = sock:sslhandshake(false, nil, self.ssl_verify) | |
| if not ok then | |
| return "failed to do ssl handshake: " .. (err or "") | |
| end | |
| end | |
| local req = _set_byte4(self.DEFAULT_CLIENT_FLAGS) | |
| .. _set_byte4(self._max_packet_size) | |
| .. strchar(self.charset) | |
| .. strrep("\0", 23) | |
| .. _to_cstring(self.user) | |
| .. append_auth | |
| .. _to_cstring(self.database) | |
| .. _to_cstring(plugin) | |
| local packet_len = 4 + 4 + 1 + 23 + #self.user + 1 | |
| + len + #self.database + 1 + #plugin + 1 | |
| local bytes, err = _send_packet(self, req, packet_len) | |
| if not bytes then | |
| return "failed to send client authentication packet: " .. err | |
| end | |
| return nil | |
| end | |
| local function _read_auth_result(self, old_auth_data, plugin) | |
| local packet, typ, err = _recv_packet(self) | |
| if not packet then | |
| return nil, nil, "failed to receive the result packet: " .. err | |
| end | |
| if typ == RESP_OK then | |
| return RESP_OK, "" | |
| end | |
| if typ == RESP_AUTHMOREDATA then | |
| return sub(packet, 2), "" | |
| end | |
| if typ == RESP_EOF then | |
| if #packet == 1 then -- old pre-4.1 authentication protocol | |
| return nil, "mysql_old_password" | |
| end | |
| local pos | |
| plugin, pos = _from_cstring(packet, 2) | |
| if not plugin then | |
| return nil, nil, "malformed packet" | |
| end | |
| return sub(packet, pos), plugin | |
| end | |
| if typ == RESP_ERR then | |
| local errno, msg, sqlstate = _parse_err_packet(packet) | |
| return errno, sqlstate, msg | |
| end | |
| return nil, nil, "bad packet type: " .. typ | |
| end | |
| local function _read_ok_result(self) | |
| local packet, typ, err = _recv_packet(self) | |
| if not packet then | |
| return "failed to receive the result packet: " .. err | |
| end | |
| if typ == RESP_ERR then | |
| local errno, msg, sqlstate = _parse_err_packet(packet) | |
| return msg, errno, sqlstate | |
| end | |
| if typ ~= RESP_OK then | |
| return "bad packet type: " .. typ | |
| end | |
| end | |
| local function _encrypt_password(self, auth_data, public_key) | |
| if not has_rsa then | |
| error("auth plugin caching_sha2_password or sha256_password are not" .. | |
| " supported because resty.rsa is not installed", 2) | |
| end | |
| local password = _to_cstring(self.password) | |
| local n = #password | |
| local l = #auth_data | |
| local bytes = new_tab(n, 0) | |
| for i = 1, n do | |
| local j = i % l | |
| bytes[i] = strchar(bxor(strbyte(password, i), strbyte(auth_data, j))) | |
| end | |
| local pub, err = resty_rsa:new({ | |
| public_key = public_key, | |
| key_type = resty_rsa.KEY_TYPE.PKCS8, | |
| padding = resty_rsa.PADDING.RSA_PKCS1_OAEP_PADDING, | |
| algorithm = "sha1", | |
| }) | |
| if not pub then | |
| return nil, "new rsa err: " .. err | |
| end | |
| local enc, err = pub:encrypt(concat(bytes)) | |
| if not enc then | |
| return nil, "encode password packet: " .. err | |
| end | |
| return enc | |
| end | |
| local function _write_encode_password(self, auth_data, public_key) | |
| local enc, err = _encrypt_password(self, auth_data, public_key) | |
| local bytes, err = _send_packet(self, enc, #enc) | |
| if not bytes then | |
| return "failed to send encode password packet: " .. err | |
| end | |
| end | |
| local function _auth(self, auth_data, plugin) | |
| local password = self.password | |
| if plugin == "caching_sha2_password" then | |
| local auth_resp, err = _compute_sha256_token(password, auth_data) | |
| if err then | |
| return nil, "failed to compute sha256 token: " .. err | |
| end | |
| return auth_resp | |
| end | |
| if plugin == "mysql_old_password" then | |
| return _compute_old_token(password, auth_data) | |
| end | |
| if plugin == "mysql_clear_password" then | |
| return _to_cstring(password) | |
| end | |
| if plugin == "mysql_native_password" then | |
| return _compute_token(password, auth_data) | |
| end | |
| if plugin == "sha256_password" then | |
| if self.is_unix or self.use_ssl or #password == 0 then | |
| return _to_cstring(password) | |
| end | |
| local public_key = self.public_key | |
| if public_key then | |
| return _encrypt_password(self, auth_data, public_key) | |
| end | |
| return "\1" -- request public key from server | |
| end | |
| return nil, "unknown plugin: " .. plugin | |
| end | |
| local function _handle_auth_result(self, old_auth_data, plugin) | |
| local auth_data, new_plugin, err = _read_auth_result(self, old_auth_data, | |
| plugin) | |
| if err ~= nil then | |
| local errno, sqlstate = auth_data, new_plugin | |
| return err, errno, sqlstate | |
| end | |
| if auth_data == RESP_OK then | |
| return | |
| end | |
| if new_plugin ~= "" then | |
| if not auth_data then | |
| auth_data = old_auth_data | |
| else | |
| old_auth_data = auth_data | |
| end | |
| plugin = new_plugin | |
| local auth_resp, err = _auth(self, auth_data, plugin) | |
| if not auth_resp then | |
| return err | |
| end | |
| local bytes, err = _send_packet(self, auth_resp, #auth_resp) | |
| if not bytes then | |
| return "failed to send client authentication packet: " .. err | |
| end | |
| auth_data, new_plugin, err = _read_auth_result(self, old_auth_data, | |
| plugin) | |
| if err ~= nil then | |
| local errno, sqlstate = auth_data, new_plugin | |
| return err, errno, sqlstate | |
| end | |
| if auth_data == RESP_OK then | |
| return | |
| end | |
| if new_plugin ~= "" then | |
| return "malformed packet" | |
| end | |
| end | |
| if plugin == "caching_sha2_password" then | |
| local len = #auth_data | |
| if len == 0 then | |
| return | |
| end | |
| if len == 1 then | |
| local status = strbyte(auth_data) | |
| -- caching_sha2_password fast auth success | |
| if status == 3 then | |
| return _read_ok_result(self) | |
| end | |
| -- caching_sha2_password perform full authentication | |
| if status == 4 then | |
| if self.is_unix or self.use_ssl then | |
| local bytes, err = _send_packet(self, | |
| _to_cstring(self.password), | |
| #self.password + 1) | |
| if not bytes then | |
| return "failed to send cleartext auth packet: " | |
| .. err | |
| end | |
| else | |
| local public_key = self.public_key | |
| if not public_key then | |
| -- caching_sha2_password request public_key | |
| local bytes, err = _send_packet(self, "\2", 1) | |
| if not bytes then | |
| return "failed to send password request packet: " | |
| .. err | |
| end | |
| local packet, _, err = _recv_packet(self) | |
| if not packet then | |
| return "failed to receive the result packet: " | |
| .. err | |
| end | |
| public_key = sub(packet, 2) | |
| end | |
| err = _write_encode_password(self, old_auth_data, | |
| public_key) | |
| if err then | |
| return err | |
| end | |
| self.public_key = public_key | |
| end | |
| return _read_ok_result(self) | |
| end | |
| end | |
| return "malformed packet" | |
| end | |
| if plugin == "sha256_password" then | |
| if #auth_data ~= 0 then | |
| local enc, err = _write_encode_password(self, old_auth_data, | |
| auth_data) | |
| if err then | |
| return err | |
| end | |
| return _read_ok_result(self) | |
| end | |
| end | |
| end | |
| function _M.new() | |
| local sock, err = tcp() | |
| if not sock then | |
| return nil, err | |
| end | |
| return setmetatable({ sock = sock }, mt) | |
| end | |
| function _M.set_timeout(self, timeout) | |
| local sock = self.sock | |
| if not sock then | |
| return nil, "not initialized" | |
| end | |
| return sock:settimeout(timeout) | |
| end | |
| function _M.connect(self, opts) | |
| local sock = self.sock | |
| if not sock then | |
| return nil, "not initialized" | |
| end | |
| local max_packet_size = opts.max_packet_size | |
| if not max_packet_size then | |
| max_packet_size = 1024 * 1024 -- default 1 MB | |
| end | |
| self._max_packet_size = max_packet_size | |
| local ok, err | |
| self.compact = opts.compact_arrays | |
| self.database = opts.database or "" | |
| self.user = opts.user or "" | |
| self.charset = CHARSET_MAP[opts.charset or "_default"] | |
| if not self.charset then | |
| return nil, "charset '" .. opts.charset .. "' is not supported" | |
| end | |
| local pool = opts.pool | |
| self.ssl_verify = opts.ssl_verify | |
| self.use_ssl = opts.ssl or opts.ssl_verify | |
| self.password = opts.password or "" | |
| local host = opts.host | |
| if host then | |
| local port = opts.port or 3306 | |
| if not pool then | |
| pool = self.user .. ":" .. self.database .. ":" .. host .. ":" | |
| .. port | |
| end | |
| ok, err = sock:connect(host, port, { pool = pool, | |
| pool_size = opts.pool_size, | |
| backlog = opts.backlog }) | |
| else | |
| local path = opts.path | |
| if not path then | |
| return nil, 'neither "host" nor "path" options are specified' | |
| end | |
| if not pool then | |
| pool = self.user .. ":" .. self.database .. ":" .. path | |
| end | |
| self.is_unix = true | |
| ok, err = sock:connect("unix:" .. path, { pool = pool, | |
| pool_size = opts.pool_size, | |
| backlog = opts.backlog }) | |
| end | |
| if not ok then | |
| return nil, 'failed to connect: ' .. err | |
| end | |
| local reused = sock:getreusedtimes() | |
| if reused and reused > 0 then | |
| self.state = STATE_CONNECTED | |
| return 1 | |
| end | |
| self.DEFAULT_CLIENT_FLAGS = bor(DEFAULT_CLIENT_FLAGS, CLIENT_PLUGIN_AUTH) | |
| local auth_data, plugin, err, errno, sqlstate | |
| = _read_hand_shake_packet(self) | |
| if err ~= nil then | |
| return nil, err | |
| end | |
| local auth_resp, err = _auth(self, auth_data, plugin) | |
| if not auth_resp then | |
| return nil, err | |
| end | |
| err = _write_hand_shake_response(self, auth_resp, plugin) | |
| if err ~= nil then | |
| return nil, err | |
| end | |
| local err, errno, sqlstate = _handle_auth_result(self, auth_data, plugin) | |
| if err ~= nil then | |
| return nil, err, errno, sqlstate | |
| end | |
| self.state = STATE_CONNECTED | |
| return 1 | |
| end | |
| function _M.set_keepalive(self, ...) | |
| local sock = self.sock | |
| if not sock then | |
| return nil, "not initialized" | |
| end | |
| if self.state ~= STATE_CONNECTED then | |
| return nil, "cannot be reused in the current connection state: " | |
| .. (self.state or "nil") | |
| end | |
| self.state = nil | |
| return sock:setkeepalive(...) | |
| end | |
| function _M.get_reused_times(self) | |
| local sock = self.sock | |
| if not sock then | |
| return nil, "not initialized" | |
| end | |
| return sock:getreusedtimes() | |
| end | |
| function _M.close(self) | |
| local sock = self.sock | |
| if not sock then | |
| return nil, "not initialized" | |
| end | |
| self.state = nil | |
| local bytes, err = _send_packet(self, strchar(COM_QUIT), 1) | |
| if not bytes then | |
| return nil, err | |
| end | |
| return sock:close() | |
| end | |
| function _M.server_ver(self) | |
| return self._server_ver | |
| end | |
| local function send_query(self, query) | |
| if self.state ~= STATE_CONNECTED then | |
| return nil, "cannot send query in the current context: " | |
| .. (self.state or "nil") | |
| end | |
| local sock = self.sock | |
| if not sock then | |
| return nil, "not initialized" | |
| end | |
| self.packet_no = -1 | |
| local cmd_packet = strchar(COM_QUERY) .. query | |
| local packet_len = 1 + #query | |
| local bytes, err = _send_packet(self, cmd_packet, packet_len) | |
| if not bytes then | |
| return nil, err | |
| end | |
| self.state = STATE_COMMAND_SENT | |
| --print("packet sent ", bytes, " bytes") | |
| return bytes | |
| end | |
| _M.send_query = send_query | |
| local function read_result(self, est_nrows) | |
| if self.state ~= STATE_COMMAND_SENT then | |
| return nil, "cannot read result in the current context: " | |
| .. (self.state or "nil") | |
| end | |
| local sock = self.sock | |
| if not sock then | |
| return nil, "not initialized" | |
| end | |
| local packet, typ, err = _recv_packet(self) | |
| if not packet then | |
| return nil, err | |
| end | |
| if typ == RESP_ERR then | |
| self.state = STATE_CONNECTED | |
| local errno, msg, sqlstate = _parse_err_packet(packet) | |
| return nil, msg, errno, sqlstate | |
| end | |
| if typ == RESP_OK then | |
| local res = _parse_ok_packet(packet) | |
| if res and band(res.server_status, SERVER_MORE_RESULTS_EXISTS) ~= 0 then | |
| return res, "again" | |
| end | |
| self.state = STATE_CONNECTED | |
| return res | |
| end | |
| if typ == RESP_LOCALINFILE then | |
| self.state = STATE_CONNECTED | |
| return nil, "packet type " .. typ .. " not supported" | |
| end | |
| -- typ == RESP_DATA or RESP_AUTHMOREDATA(also mean RESP_DATA here) | |
| --print("read the result set header packet") | |
| local field_count, extra = _parse_result_set_header_packet(packet) | |
| --print("field count: ", field_count) | |
| local cols = new_tab(field_count, 0) | |
| for i = 1, field_count do | |
| local col, err, errno, sqlstate = _recv_field_packet(self) | |
| if not col then | |
| return nil, err, errno, sqlstate | |
| end | |
| cols[i] = col | |
| end | |
| local packet, typ, err = _recv_packet(self) | |
| if not packet then | |
| return nil, err | |
| end | |
| if typ ~= RESP_EOF then | |
| return nil, "unexpected packet type " .. typ .. " while eof packet is " | |
| .. "expected" | |
| end | |
| -- typ == RESP_EOF | |
| local compact = self.compact | |
| local rows = new_tab(est_nrows or 4, 0) | |
| local i = 0 | |
| while true do | |
| --print("reading a row") | |
| packet, typ, err = _recv_packet(self) | |
| if not packet then | |
| return nil, err | |
| end | |
| if typ == RESP_EOF then | |
| local warning_count, status_flags = _parse_eof_packet(packet) | |
| --print("status flags: ", status_flags) | |
| if band(status_flags, SERVER_MORE_RESULTS_EXISTS) ~= 0 then | |
| return rows, "again" | |
| end | |
| break | |
| end | |
| local row = _parse_row_data_packet(packet, cols, compact) | |
| i = i + 1 | |
| rows[i] = row | |
| end | |
| self.state = STATE_CONNECTED | |
| return rows | |
| end | |
| _M.read_result = read_result | |
| function _M.query(self, query, est_nrows) | |
| local bytes, err = send_query(self, query) | |
| if not bytes then | |
| return nil, "failed to send query: " .. err | |
| end | |
| return read_result(self, est_nrows) | |
| end | |
| function _M.set_compact_arrays(self, value) | |
| self.compact = value | |
| end | |
| return _M |