Skip to content

Commit

Permalink
Fix MySQL reset packet with prepared statements (#611)
Browse files Browse the repository at this point in the history
Fixed bug with mysql PS processing/add reset statement handler
  • Loading branch information
Zhaars committed Dec 19, 2022
1 parent 38dc19e commit 0b47fca
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 27 deletions.
2 changes: 1 addition & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ jobs:

mysql-ssl:
docker:
- image: cossacklabs/ci-py-go-themis:0.93.0
- image: cossacklabs/ci-py-go-themis:0.94.1
# use the same credentials for mysql db as for postgresql (which support was added first)
# has latest tag on 2018.03.29
- image: cossacklabs/mysql-ssl:5.7.31-1
Expand Down
9 changes: 8 additions & 1 deletion decryptor/mysql/packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ func (packet *Packet) GetData() []byte {

// GetBindParameters returns packet Bind parameters
func (packet *Packet) GetBindParameters(paramNum int) ([]base.BoundValue, error) {
// https://dev.mysql.com/doc/internals/en/com-stmt-execute.html#packet-COM_STMT_EXECUTE
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_execute.html
// 1 - packet header
// 4 - stmt-id
// 1 - flags
Expand Down Expand Up @@ -285,6 +285,13 @@ func (packet *Packet) ReadPacket(connection net.Conn) error {
return err
}

// IsOK return true if packet is OkPacket
func (packet *Packet) IsOK() bool {
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_ok_packet.html
isOkPacket := packet.data[0] == OkPacket && packet.GetPacketPayloadLength() >= 7
return isOkPacket
}

// IsEOF return true if packet is OkPacket or EOFPacket
func (packet *Packet) IsEOF() bool {
// https://dev.mysql.com/doc/internals/en/packet-OK_Packet.html
Expand Down
66 changes: 44 additions & 22 deletions decryptor/mysql/response_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -380,8 +380,11 @@ func (handler *Handler) ProxyClientConnection(ctx context.Context, errCh chan<-

handler.setQueryHandler(handler.QueryResponseHandler)
break
case CommandStatementClose, CommandStatementSendLongData, CommandStatementReset:
clientLog.Debugln("Close|SendLongData|Reset command")
case CommandStatementClose, CommandStatementSendLongData:
clientLog.Debugln("Close|SendLongData command")
case CommandStatementReset:
clientLog.Debugln("Reset Request Statement")
handler.setQueryHandler(handler.ResetStatementResponseHandler)
default:
clientLog.Debugf("Command %d not supported now", cmd)
}
Expand All @@ -406,32 +409,34 @@ func (handler *Handler) handleStatementExecute(ctx context.Context, packet *Pack
return nil
}

parameters, err := packet.GetBindParameters(statement.ParamsNum())
if err != nil {
log.WithError(err).Error("Can't parse OnBind parameters")
return nil
}

newParameters, changed, err := handler.queryObserverManager.OnBind(ctx, statement.Query(), parameters)
if err != nil {
// Security: here we should interrupt proxying in case of any keys read related errors
// in other cases we just stop the processing to let db protocol handle the error.
if filesystem.IsKeyReadError(err) {
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_execute.html
// we expect list of parameters if the paramsNum > 0
if paramsNum := statement.ParamsNum(); paramsNum > 0 {
parameters, err := packet.GetBindParameters(paramsNum)
if err != nil {
log.WithError(err).Error("Can't parse OnBind parameters")
return err
}

log.WithError(err).Error("Failed to handle Bind packet")
return nil
}

// Finally, if the parameter values have been changed, update the packet.
// If that fails, send the packet unchanged, as usual.
if changed {
err := packet.SetParameters(newParameters)
newParameters, changed, err := handler.queryObserverManager.OnBind(ctx, statement.Query(), parameters)
if err != nil {
log.WithError(err).Error("Failed to update Bind packet")
// Security: here we should interrupt proxying in case of any keys read related errors
// in other cases we just stop the processing to let db protocol handle the error.
if filesystem.IsKeyReadError(err) {
return err
}

log.WithError(err).Error("Failed to handle Bind packet")
return nil
}

// Finally, if the parameter values have been changed, update the packet.
if changed {
if err := packet.SetParameters(newParameters); err != nil {
log.WithError(err).Error("Failed to update Bind packet")
return err
}
}
}

return nil
Expand Down Expand Up @@ -582,6 +587,23 @@ func (handler *Handler) isPreparedStatementResult() bool {
return handler.currentCommand == CommandStatementExecute
}

// ResetStatementResponseHandler handle response for Reset Request Statement
func (handler *Handler) ResetStatementResponseHandler(ctx context.Context, packet *Packet, dbConnection, clientConnection net.Conn) (err error) {
if packet.IsOK() {
handler.logger.Debugln("OK Packet on Reset Request Statement")
} else {
handler.logger.Debugln("Err Packet on Reset Request Statement")
}

handler.resetQueryHandler()

if _, err := clientConnection.Write(packet.Dump()); err != nil {
handler.logger.WithError(err).WithField(logging.FieldKeyEventCode, logging.EventCodeErrorNetworkWrite).
Debugln("Can't proxy output")
}
return nil
}

// QueryResponseHandler parses data from database response
func (handler *Handler) QueryResponseHandler(ctx context.Context, packet *Packet, dbConnection, clientConnection net.Conn) (err error) {
handler.resetQueryHandler()
Expand Down
100 changes: 97 additions & 3 deletions tests/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,6 +885,82 @@ def execute_prepared_statement(self, query, args=None):
return cursor.fetchall()


# MysqlConnectorCExecutor uses CMySQLConnection type, which sends the client packets different than standard MySQLConnection
# the difference is packets order in PreparedStatements processing
# after sending CommandStatementPrepare, ConnectorC send the CommandStatementReset and expect StatementResetResult(OK|Err)
# then send the empty CommandStatementExecute without params, params come in next packets
# to handle such behaviour properly, MySQL proxy should have StatementReset handler
# and skip params parsing if the params number is 0 in CommandStatementExecute handler
class MysqlConnectorCExecutor(QueryExecutor):
def _result_to_dict(self, description, data):
"""convert list of tuples of rows to list of dicts"""
columns_name = [i[0] for i in description]
result = []
for row in data:
row_data = {column_name: value
for column_name, value in zip(columns_name, row)}
result.append(row_data)
return result

def execute(self, query, args=None):
if args is None:
args = []
with contextlib.closing(mysql.connector.connect(
use_unicode=True, raw=self.connection_args.raw, charset='ascii',
host=self.connection_args.host, port=self.connection_args.port,
user=self.connection_args.user,
password=self.connection_args.password,
database=self.connection_args.dbname,
ssl_ca=self.connection_args.ssl_ca,
ssl_cert=self.connection_args.ssl_cert,
ssl_key=self.connection_args.ssl_key,
ssl_disabled=not TEST_WITH_TLS)) as connection:

with contextlib.closing(connection.cursor()) as cursor:
cursor.execute(query, args)
data = cursor.fetchall()
result = self._result_to_dict(cursor.description, data)
return result

def execute_prepared_statement(self, query, args=None):
if args is None:
args = []
with contextlib.closing(mysql.connector.connect(
use_unicode=True, charset='ascii',
host=self.connection_args.host, port=self.connection_args.port,
user=self.connection_args.user,
password=self.connection_args.password,
database=self.connection_args.dbname,
ssl_ca=self.connection_args.ssl_ca,
ssl_cert=self.connection_args.ssl_cert,
ssl_key=self.connection_args.ssl_key,
ssl_disabled=not TEST_WITH_TLS)) as connection:

with contextlib.closing(connection.cursor(prepared=True)) as cursor:
cursor.execute(query, args)
data = cursor.fetchall()
result = self._result_to_dict(cursor.description, data)
return result

def execute_prepared_statement_no_result(self, query, args=None):
if args is None:
args = []
with contextlib.closing(mysql.connector.connect(
use_unicode=True, charset='ascii',
host=self.connection_args.host, port=self.connection_args.port,
user=self.connection_args.user,
password=self.connection_args.password,
database=self.connection_args.dbname,
ssl_ca=self.connection_args.ssl_ca,
ssl_cert=self.connection_args.ssl_cert,
ssl_key=self.connection_args.ssl_key,
ssl_disabled=not TEST_WITH_TLS)) as connection:

with contextlib.closing(connection.cursor(prepared=True)) as cursor:
cursor.execute(query, args)
connection.commit()


class MysqlExecutor(QueryExecutor):
def _result_to_dict(self, description, data):
"""convert list of tuples of rows to list of dicts"""
Expand All @@ -899,7 +975,7 @@ def _result_to_dict(self, description, data):
def execute(self, query, args=None):
if args is None:
args = []
with contextlib.closing(mysql.connector.Connect(
with contextlib.closing(mysql.connector.connection.MySQLConnection(
use_unicode=False, raw=self.connection_args.raw, charset='ascii',
host=self.connection_args.host, port=self.connection_args.port,
user=self.connection_args.user,
Expand All @@ -919,7 +995,7 @@ def execute(self, query, args=None):
def execute_prepared_statement(self, query, args=None):
if args is None:
args = []
with contextlib.closing(mysql.connector.Connect(
with contextlib.closing(mysql.connector.connection.MySQLConnection(
use_unicode=False, charset='ascii',
host=self.connection_args.host, port=self.connection_args.port,
user=self.connection_args.user,
Expand All @@ -939,7 +1015,7 @@ def execute_prepared_statement(self, query, args=None):
def execute_prepared_statement_no_result(self, query, args=None):
if args is None:
args = []
with contextlib.closing(mysql.connector.Connect(
with contextlib.closing(mysql.connector.connection.MySQLConnection(
use_unicode=False, charset='ascii',
host=self.connection_args.host, port=self.connection_args.port,
user=self.connection_args.user,
Expand Down Expand Up @@ -4264,6 +4340,24 @@ def executePreparedStatement(self, query, args=None):
).execute_prepared_statement(query, args=args)


class TestMysqlConnectorCBinaryPreparedStatement(BasePrepareStatementMixin, BaseTestCase):
def checkSkip(self):
if not TEST_MYSQL:
self.skipTest("run test only for mysql")
elif not TEST_WITH_TLS:
self.skipTest("running tests only with TLS")

def executePreparedStatement(self, query, args=None):
return MysqlConnectorCExecutor(
ConnectionArgs(host='localhost', port=self.ACRASERVER_PORT,
user=DB_USER, password=DB_USER_PASSWORD,
dbname=DB_NAME, ssl_ca=TEST_TLS_CA,
ssl_key=TEST_TLS_CLIENT_KEY,
raw=True,
ssl_cert=TEST_TLS_CLIENT_CERT)
).execute_prepared_statement(query, args=args)


class TestMysqlBinaryPreparedStatementWholeCell(TestMysqlBinaryPreparedStatement):
WHOLECELL_MODE = True

Expand Down

0 comments on commit 0b47fca

Please sign in to comment.