Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix MySQL reset packet with prepared statements #611

Merged
merged 3 commits into from
Dec 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ jobs:

mysql-ssl:
docker:
- image: cossacklabs/ci-py-go-themis:0.93.0
- image: cossacklabs/ci-py-go-themis:0.94.1
# use the same credentials for mysql db as for postgresql (which support was added first)
# has latest tag on 2018.03.29
- image: cossacklabs/mysql-ssl:5.7.31-1
Expand Down
9 changes: 8 additions & 1 deletion decryptor/mysql/packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ func (packet *Packet) GetData() []byte {

// GetBindParameters returns packet Bind parameters
func (packet *Packet) GetBindParameters(paramNum int) ([]base.BoundValue, error) {
// https://dev.mysql.com/doc/internals/en/com-stmt-execute.html#packet-COM_STMT_EXECUTE
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_execute.html
// 1 - packet header
// 4 - stmt-id
// 1 - flags
Expand Down Expand Up @@ -285,6 +285,13 @@ func (packet *Packet) ReadPacket(connection net.Conn) error {
return err
}

// IsOK return true if packet is OkPacket
func (packet *Packet) IsOK() bool {
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_ok_packet.html
isOkPacket := packet.data[0] == OkPacket && packet.GetPacketPayloadLength() >= 7
return isOkPacket
}

// IsEOF return true if packet is OkPacket or EOFPacket
func (packet *Packet) IsEOF() bool {
// https://dev.mysql.com/doc/internals/en/packet-OK_Packet.html
Expand Down
66 changes: 44 additions & 22 deletions decryptor/mysql/response_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -380,8 +380,11 @@ func (handler *Handler) ProxyClientConnection(ctx context.Context, errCh chan<-

handler.setQueryHandler(handler.QueryResponseHandler)
break
case CommandStatementClose, CommandStatementSendLongData, CommandStatementReset:
clientLog.Debugln("Close|SendLongData|Reset command")
case CommandStatementClose, CommandStatementSendLongData:
clientLog.Debugln("Close|SendLongData command")
case CommandStatementReset:
clientLog.Debugln("Reset Request Statement")
handler.setQueryHandler(handler.ResetStatementResponseHandler)
default:
clientLog.Debugf("Command %d not supported now", cmd)
}
Expand All @@ -406,32 +409,34 @@ func (handler *Handler) handleStatementExecute(ctx context.Context, packet *Pack
return nil
}

parameters, err := packet.GetBindParameters(statement.ParamsNum())
if err != nil {
log.WithError(err).Error("Can't parse OnBind parameters")
return nil
}

newParameters, changed, err := handler.queryObserverManager.OnBind(ctx, statement.Query(), parameters)
if err != nil {
// Security: here we should interrupt proxying in case of any keys read related errors
// in other cases we just stop the processing to let db protocol handle the error.
if filesystem.IsKeyReadError(err) {
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_execute.html
// we expect list of parameters if the paramsNum > 0
if paramsNum := statement.ParamsNum(); paramsNum > 0 {
parameters, err := packet.GetBindParameters(paramsNum)
if err != nil {
log.WithError(err).Error("Can't parse OnBind parameters")
return err
}

log.WithError(err).Error("Failed to handle Bind packet")
return nil
}

// Finally, if the parameter values have been changed, update the packet.
// If that fails, send the packet unchanged, as usual.
if changed {
err := packet.SetParameters(newParameters)
newParameters, changed, err := handler.queryObserverManager.OnBind(ctx, statement.Query(), parameters)
if err != nil {
log.WithError(err).Error("Failed to update Bind packet")
// Security: here we should interrupt proxying in case of any keys read related errors
// in other cases we just stop the processing to let db protocol handle the error.
if filesystem.IsKeyReadError(err) {
return err
}

log.WithError(err).Error("Failed to handle Bind packet")
return nil
}

// Finally, if the parameter values have been changed, update the packet.
if changed {
if err := packet.SetParameters(newParameters); err != nil {
log.WithError(err).Error("Failed to update Bind packet")
return err
}
}
}

return nil
Expand Down Expand Up @@ -582,6 +587,23 @@ func (handler *Handler) isPreparedStatementResult() bool {
return handler.currentCommand == CommandStatementExecute
}

// ResetStatementResponseHandler handle response for Reset Request Statement
func (handler *Handler) ResetStatementResponseHandler(ctx context.Context, packet *Packet, dbConnection, clientConnection net.Conn) (err error) {
if packet.IsOK() {
handler.logger.Debugln("OK Packet on Reset Request Statement")
} else {
handler.logger.Debugln("Err Packet on Reset Request Statement")
}

handler.resetQueryHandler()

if _, err := clientConnection.Write(packet.Dump()); err != nil {
handler.logger.WithError(err).WithField(logging.FieldKeyEventCode, logging.EventCodeErrorNetworkWrite).
Debugln("Can't proxy output")
}
return nil
}

// QueryResponseHandler parses data from database response
func (handler *Handler) QueryResponseHandler(ctx context.Context, packet *Packet, dbConnection, clientConnection net.Conn) (err error) {
handler.resetQueryHandler()
Expand Down
100 changes: 97 additions & 3 deletions tests/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -905,6 +905,82 @@ def execute_prepared_statement(self, query, args=None):
return cursor.fetchall()


# MysqlConnectorCExecutor uses CMySQLConnection type, which sends the client packets different than standard MySQLConnection
# the difference is packets order in PreparedStatements processing
# after sending CommandStatementPrepare, ConnectorC send the CommandStatementReset and expect StatementResetResult(OK|Err)
# then send the empty CommandStatementExecute without params, params come in next packets
# to handle such behaviour properly, MySQL proxy should have StatementReset handler
# and skip params parsing if the params number is 0 in CommandStatementExecute handler
class MysqlConnectorCExecutor(QueryExecutor):
def _result_to_dict(self, description, data):
"""convert list of tuples of rows to list of dicts"""
columns_name = [i[0] for i in description]
result = []
for row in data:
row_data = {column_name: value
for column_name, value in zip(columns_name, row)}
result.append(row_data)
return result

def execute(self, query, args=None):
if args is None:
args = []
with contextlib.closing(mysql.connector.connect(
use_unicode=True, raw=self.connection_args.raw, charset='ascii',
host=self.connection_args.host, port=self.connection_args.port,
user=self.connection_args.user,
password=self.connection_args.password,
database=self.connection_args.dbname,
ssl_ca=self.connection_args.ssl_ca,
ssl_cert=self.connection_args.ssl_cert,
ssl_key=self.connection_args.ssl_key,
ssl_disabled=not TEST_WITH_TLS)) as connection:

with contextlib.closing(connection.cursor()) as cursor:
cursor.execute(query, args)
data = cursor.fetchall()
result = self._result_to_dict(cursor.description, data)
return result

def execute_prepared_statement(self, query, args=None):
if args is None:
args = []
with contextlib.closing(mysql.connector.connect(
use_unicode=True, charset='ascii',
host=self.connection_args.host, port=self.connection_args.port,
user=self.connection_args.user,
password=self.connection_args.password,
database=self.connection_args.dbname,
ssl_ca=self.connection_args.ssl_ca,
ssl_cert=self.connection_args.ssl_cert,
ssl_key=self.connection_args.ssl_key,
ssl_disabled=not TEST_WITH_TLS)) as connection:

with contextlib.closing(connection.cursor(prepared=True)) as cursor:
cursor.execute(query, args)
data = cursor.fetchall()
result = self._result_to_dict(cursor.description, data)
return result

def execute_prepared_statement_no_result(self, query, args=None):
if args is None:
args = []
with contextlib.closing(mysql.connector.connect(
use_unicode=True, charset='ascii',
host=self.connection_args.host, port=self.connection_args.port,
user=self.connection_args.user,
password=self.connection_args.password,
database=self.connection_args.dbname,
ssl_ca=self.connection_args.ssl_ca,
ssl_cert=self.connection_args.ssl_cert,
ssl_key=self.connection_args.ssl_key,
ssl_disabled=not TEST_WITH_TLS)) as connection:

with contextlib.closing(connection.cursor(prepared=True)) as cursor:
cursor.execute(query, args)
connection.commit()


class MysqlExecutor(QueryExecutor):
def _result_to_dict(self, description, data):
"""convert list of tuples of rows to list of dicts"""
Expand All @@ -919,7 +995,7 @@ def _result_to_dict(self, description, data):
def execute(self, query, args=None):
if args is None:
args = []
with contextlib.closing(mysql.connector.Connect(
with contextlib.closing(mysql.connector.connection.MySQLConnection(
use_unicode=False, raw=self.connection_args.raw, charset='ascii',
host=self.connection_args.host, port=self.connection_args.port,
user=self.connection_args.user,
Expand All @@ -939,7 +1015,7 @@ def execute(self, query, args=None):
def execute_prepared_statement(self, query, args=None):
if args is None:
args = []
with contextlib.closing(mysql.connector.Connect(
with contextlib.closing(mysql.connector.connection.MySQLConnection(
use_unicode=False, charset='ascii',
host=self.connection_args.host, port=self.connection_args.port,
user=self.connection_args.user,
Expand All @@ -959,7 +1035,7 @@ def execute_prepared_statement(self, query, args=None):
def execute_prepared_statement_no_result(self, query, args=None):
if args is None:
args = []
with contextlib.closing(mysql.connector.Connect(
with contextlib.closing(mysql.connector.connection.MySQLConnection(
use_unicode=False, charset='ascii',
host=self.connection_args.host, port=self.connection_args.port,
user=self.connection_args.user,
Expand Down Expand Up @@ -4689,6 +4765,24 @@ def executePreparedStatement(self, query, args=None):
).execute_prepared_statement(query, args=args)


class TestMysqlConnectorCBinaryPreparedStatement(BasePrepareStatementMixin, BaseTestCase):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add some comment with description what we expected this connector will do? it will help to understand at the future if this test will fail after updating driver or something else. and explain what we expected and should reproduce. because after a year if this test will fail, we will not remember what behavior we expected and what to do to back it again

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure

def checkSkip(self):
if not TEST_MYSQL:
self.skipTest("run test only for mysql")
elif not TEST_WITH_TLS:
self.skipTest("running tests only with TLS")

def executePreparedStatement(self, query, args=None):
return MysqlConnectorCExecutor(
ConnectionArgs(host='localhost', port=self.ACRASERVER_PORT,
user=DB_USER, password=DB_USER_PASSWORD,
dbname=DB_NAME, ssl_ca=TEST_TLS_CA,
ssl_key=TEST_TLS_CLIENT_KEY,
raw=True,
ssl_cert=TEST_TLS_CLIENT_CERT)
).execute_prepared_statement(query, args=args)


class TestMysqlBinaryPreparedStatementWholeCell(TestMysqlBinaryPreparedStatement):
WHOLECELL_MODE = True

Expand Down