Skip to content

Commit

Permalink
fix(mysql): avoid creating any tables when using .sql() (#9363)
Browse files Browse the repository at this point in the history
## Description of changes

This PR fixes an issue with using `.sql()` with MySQL in environments
where the
user doesn't have permission to create tables.

Previously we created a temporary table with zero rows and then used
that to
discover types of a query.

In this PR, I use the given zero-row `SELECT` query to get the types of
the columns
using `cursor.description` along with some low level details of the
MySQL protocol.

Inferring types this way doesn't allow getting type information from
ad-hoc
types that use plugins, like uuid or inet, so their physical type is
used
(string in both those cases).

This seems like an acceptable tradeoff given that without this change
`.sql()`
cannot be used in a whole class of deployments/environments.

## Issues closed

Closes #9354.
  • Loading branch information
cpcloud authored Jun 12, 2024
1 parent d8fc4f5 commit d2d5251
Show file tree
Hide file tree
Showing 3 changed files with 177 additions and 10 deletions.
16 changes: 8 additions & 8 deletions ibis/backends/mysql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from ibis import util
from ibis.backends import CanCreateDatabase
from ibis.backends.mysql.compiler import MySQLCompiler
from ibis.backends.mysql.datatypes import _type_from_cursor_info
from ibis.backends.sql import SQLBackend
from ibis.backends.sql.compiler import TRUE, C

Expand Down Expand Up @@ -189,16 +190,15 @@ def list_databases(self, like: str | None = None) -> list[str]:
return self._filter_with_like(databases, like)

def _get_schema_using_query(self, query: str) -> sch.Schema:
table = util.gen_name(f"{self.name}_metadata")

with self.begin() as cur:
cur.execute(
f"CREATE TEMPORARY TABLE {table} AS SELECT * FROM ({query}) AS tmp LIMIT 0"
cur.execute(f"SELECT * FROM ({query}) AS tmp LIMIT 0")

return sch.Schema(
{
field.name: _type_from_cursor_info(descr, field)
for descr, field in zip(cur.description, cur._result.fields)
}
)
try:
return self.get_schema(table)
finally:
cur.execute(f"DROP TABLE {table}")

def get_schema(
self, name: str, *, catalog: str | None = None, database: str | None = None
Expand Down
142 changes: 142 additions & 0 deletions ibis/backends/mysql/datatypes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
from __future__ import annotations

import inspect
from functools import partial

from pymysql.constants import FIELD_TYPE

import ibis.expr.datatypes as dt

# binary character set
# used to distinguish blob binary vs blob text
MY_CHARSET_BIN = 63


def _type_from_cursor_info(descr, field) -> dt.DataType:
"""Construct an ibis type from MySQL field descr and field result metadata.
This method is complex because the MySQL protocol is complex.
Types are not encoded in a self contained way, meaning you need
multiple pieces of information coming from the result set metadata to
determine the most precise type for a field. Even then, the decoding is
not high fidelity in some cases: UUIDs for example are decoded as
strings, because the protocol does not appear to preserve the logical
type, only the physical type.
"""
from pymysql.connections import TEXT_TYPES

_, type_code, _, _, field_length, scale, _ = descr
flags = _FieldFlags(field.flags)
typename = _type_codes.get(type_code)
if typename is None:
raise NotImplementedError(f"MySQL type code {type_code:d} is not supported")

if typename in ("DECIMAL", "NEWDECIMAL"):
precision = _decimal_length_to_precision(
length=field_length,
scale=scale,
is_unsigned=flags.is_unsigned,
)
typ = partial(_type_mapping[typename], precision=precision, scale=scale)
elif typename == "BIT":
if field_length <= 8:
typ = dt.int8
elif field_length <= 16:
typ = dt.int16
elif field_length <= 32:
typ = dt.int32
elif field_length <= 64:
typ = dt.int64
else:
raise AssertionError("invalid field length for BIT type")
elif flags.is_set:
# sets are limited to strings
typ = dt.Array(dt.string)
elif type_code in TEXT_TYPES:
# binary text
if field.charsetnr == MY_CHARSET_BIN:
typ = dt.Binary
else:
typ = dt.String
elif flags.is_timestamp or typename == "TIMESTAMP":
typ = partial(dt.Timestamp, timezone="UTC", scale=scale or None)
elif typename == "DATETIME":
typ = partial(dt.Timestamp, scale=scale or None)
else:
typ = _type_mapping[typename]

# projection columns are always nullable
return typ(nullable=True)


# ported from my_decimal.h:my_decimal_length_to_precision in mariadb
def _decimal_length_to_precision(*, length: int, scale: int, is_unsigned: bool) -> int:
return length - (scale > 0) - (not (is_unsigned or not length))


_type_codes = {v: k for k, v in inspect.getmembers(FIELD_TYPE) if not k.startswith("_")}


_type_mapping = {
"DECIMAL": dt.Decimal,
"TINY": dt.Int8,
"SHORT": dt.Int16,
"LONG": dt.Int32,
"FLOAT": dt.Float32,
"DOUBLE": dt.Float64,
"NULL": dt.Null,
"LONGLONG": dt.Int64,
"INT24": dt.Int32,
"DATE": dt.Date,
"TIME": dt.Time,
"DATETIME": dt.Timestamp,
"YEAR": dt.Int8,
"VARCHAR": dt.String,
"JSON": dt.JSON,
"NEWDECIMAL": dt.Decimal,
"ENUM": dt.String,
"SET": partial(dt.Array, dt.string),
"TINY_BLOB": dt.Binary,
"MEDIUM_BLOB": dt.Binary,
"LONG_BLOB": dt.Binary,
"BLOB": dt.Binary,
"VAR_STRING": dt.String,
"STRING": dt.String,
"GEOMETRY": dt.Geometry,
}


class _FieldFlags:
"""Flags used to disambiguate field types.
Gaps in the flag numbers are because we do not map in flags that are
of no use in determining the field's type, such as whether the field
is a primary key or not.
"""

UNSIGNED = 1 << 5
TIMESTAMP = 1 << 10
SET = 1 << 11
NUM = 1 << 15

__slots__ = ("value",)

def __init__(self, value: int) -> None:
self.value = value

@property
def is_unsigned(self) -> bool:
return (self.UNSIGNED & self.value) != 0

@property
def is_timestamp(self) -> bool:
return (self.TIMESTAMP & self.value) != 0

@property
def is_set(self) -> bool:
return (self.SET & self.value) != 0

@property
def is_num(self) -> bool:
return (self.NUM & self.value) != 0
29 changes: 27 additions & 2 deletions ibis/backends/mysql/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,9 @@
# mariadb doesn't have a distinct json type
param("json", dt.string, id="json"),
param("enum('small', 'medium', 'large')", dt.string, id="enum"),
param("inet6", dt.inet, id="inet"),
param("set('a', 'b', 'c', 'd')", dt.Array(dt.string), id="set"),
param("mediumblob", dt.binary, id="mediumblob"),
param("blob", dt.binary, id="blob"),
param("uuid", dt.uuid, id="uuid"),
] + [
param(
f"datetime({scale:d})",
Expand Down Expand Up @@ -85,6 +83,33 @@ def test_get_schema_from_query(con, mysql_type, expected_type):
assert t.schema() == expected_schema


@pytest.mark.parametrize(
("mysql_type", "get_schema_expected_type", "table_expected_type"),
[
param("inet6", dt.string, dt.inet, id="inet"),
param("uuid", dt.string, dt.uuid, id="uuid"),
],
)
def test_get_schema_from_query_special_cases(
con, mysql_type, get_schema_expected_type, table_expected_type
):
raw_name = ibis.util.guid()
name = sg.to_identifier(raw_name, quoted=True).sql("mysql")
get_schema_expected_schema = ibis.schema(dict(x=get_schema_expected_type))
table_expected_schema = ibis.schema(dict(x=table_expected_type))

# temporary tables get cleaned up by the db when the session ends, so we
# don't need to explicitly drop the table
with con.begin() as c:
c.execute(f"CREATE TEMPORARY TABLE {name} (x {mysql_type})")

result_schema = con._get_schema_using_query(f"SELECT * FROM {name}")
assert result_schema == get_schema_expected_schema

t = con.table(raw_name)
assert t.schema() == table_expected_schema


@pytest.mark.parametrize("coltype", ["TINYBLOB", "MEDIUMBLOB", "BLOB", "LONGBLOB"])
def test_blob_type(con, coltype):
tmp = f"tmp_{ibis.util.guid()}"
Expand Down

0 comments on commit d2d5251

Please sign in to comment.