From a6e476c3d0d7a8603db9f70f63b336cfa421ac1c Mon Sep 17 00:00:00 2001 From: Jared Dobson Date: Thu, 6 Mar 2025 13:12:43 -0700 Subject: [PATCH 01/16] Fix enum --- mysql_ch_replicator/converter.py | 8 ++++++-- pyproject.toml | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/mysql_ch_replicator/converter.py b/mysql_ch_replicator/converter.py index 9aecffe..ff227f0 100644 --- a/mysql_ch_replicator/converter.py +++ b/mysql_ch_replicator/converter.py @@ -284,8 +284,12 @@ def convert_type(self, mysql_type, parameters): for idx, value_name in enumerate(enum_values): ch_enum_values.append(f"'{value_name}' = {idx+1}") ch_enum_values = ', '.join(ch_enum_values) - # Enum8('red' = 1, 'green' = 2, 'black' = 3) - return f'Enum8({ch_enum_values})' + if len(enum_values) <= 127: + # Enum8('red' = 1, 'green' = 2, 'black' = 3) + return f'Enum8({ch_enum_values})' + else: + # Enum16('red' = 1, 'green' = 2, 'black' = 3) + return f'Enum16({ch_enum_values})' if 'text' in mysql_type: return 'String' if 'blob' in mysql_type: diff --git a/pyproject.toml b/pyproject.toml index f617f26..e098d96 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "mysql-ch-replicator" -version = "0.0.40" +version = "0.0.70" description = "Tool for replication of MySQL databases to ClickHouse" authors = ["Filipp Ozinov "] license = "MIT" From 40fba7159f2a1df1b345126462b233b33af7d961 Mon Sep 17 00:00:00 2001 From: Jared Dobson Date: Thu, 6 Mar 2025 13:33:11 -0700 Subject: [PATCH 02/16] Fix enum parser --- mysql_ch_replicator/converter.py | 763 +++++++++++-------- mysql_ch_replicator/converter_enum_parser.py | 65 +- 2 files changed, 497 insertions(+), 331 deletions(-) diff --git a/mysql_ch_replicator/converter.py b/mysql_ch_replicator/converter.py index ff227f0..f534696 100644 --- a/mysql_ch_replicator/converter.py +++ b/mysql_ch_replicator/converter.py @@ -1,57 +1,61 @@ -import struct import json +import re +import struct import uuid +from typing import TYPE_CHECKING + import sqlparse -import re -from pyparsing import Suppress, CaselessKeyword, Word, alphas, alphanums, delimitedList +from pyparsing import CaselessKeyword, Suppress, Word, alphanums, alphas, delimitedList -from .table_structure import TableStructure, TableField from .converter_enum_parser import parse_mysql_enum +from .table_structure import TableField, TableStructure +if TYPE_CHECKING: + from .db_replicator import DbReplicator CHARSET_MYSQL_TO_PYTHON = { - 'armscii8': None, # ARMSCII-8 is not directly supported in Python - 'ascii': 'ascii', - 'big5': 'big5', - 'binary': 'latin1', # Treat binary data as Latin-1 in Python - 'cp1250': 'cp1250', - 'cp1251': 'cp1251', - 'cp1256': 'cp1256', - 'cp1257': 'cp1257', - 'cp850': 'cp850', - 'cp852': 'cp852', - 'cp866': 'cp866', - 'cp932': 'cp932', - 'dec8': 'latin1', # DEC8 is similar to Latin-1 - 'eucjpms': 'euc_jp', # Map to EUC-JP - 'euckr': 'euc_kr', - 'gb18030': 'gb18030', - 'gb2312': 'gb2312', - 'gbk': 'gbk', - 'geostd8': None, # GEOSTD8 is not directly supported in Python - 'greek': 'iso8859_7', - 'hebrew': 'iso8859_8', - 'hp8': None, # HP8 is not directly supported in Python - 'keybcs2': None, # KEYBCS2 is not directly supported in Python - 'koi8r': 'koi8_r', - 'koi8u': 'koi8_u', - 'latin1': 'cp1252', # MySQL's latin1 corresponds to Windows-1252 - 'latin2': 'iso8859_2', - 'latin5': 'iso8859_9', - 'latin7': 'iso8859_13', - 'macce': 'mac_latin2', - 'macroman': 'mac_roman', - 'sjis': 'shift_jis', - 'swe7': None, # SWE7 is not directly supported in Python - 'tis620': 'tis_620', - 'ucs2': 'utf_16', # UCS-2 can be mapped to UTF-16 - 'ujis': 'euc_jp', - 'utf16': 'utf_16', - 'utf16le': 'utf_16_le', - 'utf32': 'utf_32', - 'utf8mb3': 'utf_8', # Both utf8mb3 and utf8mb4 can be mapped to UTF-8 - 'utf8mb4': 'utf_8', - 'utf8': 'utf_8', + "armscii8": None, # ARMSCII-8 is not directly supported in Python + "ascii": "ascii", + "big5": "big5", + "binary": "latin1", # Treat binary data as Latin-1 in Python + "cp1250": "cp1250", + "cp1251": "cp1251", + "cp1256": "cp1256", + "cp1257": "cp1257", + "cp850": "cp850", + "cp852": "cp852", + "cp866": "cp866", + "cp932": "cp932", + "dec8": "latin1", # DEC8 is similar to Latin-1 + "eucjpms": "euc_jp", # Map to EUC-JP + "euckr": "euc_kr", + "gb18030": "gb18030", + "gb2312": "gb2312", + "gbk": "gbk", + "geostd8": None, # GEOSTD8 is not directly supported in Python + "greek": "iso8859_7", + "hebrew": "iso8859_8", + "hp8": None, # HP8 is not directly supported in Python + "keybcs2": None, # KEYBCS2 is not directly supported in Python + "koi8r": "koi8_r", + "koi8u": "koi8_u", + "latin1": "cp1252", # MySQL's latin1 corresponds to Windows-1252 + "latin2": "iso8859_2", + "latin5": "iso8859_9", + "latin7": "iso8859_13", + "macce": "mac_latin2", + "macroman": "mac_roman", + "sjis": "shift_jis", + "swe7": None, # SWE7 is not directly supported in Python + "tis620": "tis_620", + "ucs2": "utf_16", # UCS-2 can be mapped to UTF-16 + "ujis": "euc_jp", + "utf16": "utf_16", + "utf16le": "utf_16_le", + "utf32": "utf_32", + "utf8mb3": "utf_8", # Both utf8mb3 and utf8mb4 can be mapped to UTF-8 + "utf8mb4": "utf_8", + "utf8": "utf_8", } @@ -59,7 +63,7 @@ def convert_bytes(obj): if isinstance(obj, dict): new_obj = {} for k, v in obj.items(): - new_key = k.decode('utf-8') if isinstance(k, bytes) else k + new_key = k.decode("utf-8") if isinstance(k, bytes) else k new_value = convert_bytes(v) new_obj[new_key] = new_value return new_obj @@ -71,7 +75,7 @@ def convert_bytes(obj): return tuple(new_obj) return new_obj elif isinstance(obj, bytes): - return obj.decode('utf-8') + return obj.decode("utf-8") else: return obj @@ -92,37 +96,37 @@ def parse_mysql_point(binary): # Read the byte order byte_order = binary[0] if byte_order == 0: - endian = '>' + endian = ">" elif byte_order == 1: - endian = '<' + endian = "<" else: raise ValueError("Invalid byte order in WKB POINT") # Read the WKB Type - wkb_type = struct.unpack(endian + 'I', binary[1:5])[0] + wkb_type = struct.unpack(endian + "I", binary[1:5])[0] if wkb_type != 1: # WKB type 1 means POINT raise ValueError("Not a WKB POINT type") # Read X and Y coordinates - x = struct.unpack(endian + 'd', binary[5:13])[0] - y = struct.unpack(endian + 'd', binary[13:21])[0] + x = struct.unpack(endian + "d", binary[5:13])[0] + y = struct.unpack(endian + "d", binary[13:21])[0] elif len(binary) == 25: # With SRID included # First 4 bytes are the SRID - srid = struct.unpack('>I', binary[0:4])[0] # SRID is big-endian + srid = struct.unpack(">I", binary[0:4])[0] # SRID is big-endian # Next byte is byte order byte_order = binary[4] if byte_order == 0: - endian = '>' + endian = ">" elif byte_order == 1: - endian = '<' + endian = "<" else: raise ValueError("Invalid byte order in WKB POINT") # Read the WKB Type - wkb_type = struct.unpack(endian + 'I', binary[5:9])[0] + wkb_type = struct.unpack(endian + "I", binary[5:9])[0] if wkb_type != 1: # WKB type 1 means POINT raise ValueError("Not a WKB POINT type") # Read X and Y coordinates - x = struct.unpack(endian + 'd', binary[9:17])[0] - y = struct.unpack(endian + 'd', binary[17:25])[0] + x = struct.unpack(endian + "d", binary[9:17])[0] + y = struct.unpack(endian + "d", binary[17:25])[0] else: raise ValueError("Invalid binary length for WKB POINT") return (x, y) @@ -130,9 +134,9 @@ def parse_mysql_point(binary): def strip_sql_name(name): name = name.strip() - if name.startswith('`'): + if name.startswith("`"): name = name[1:] - if name.endswith('`'): + if name.endswith("`"): name = name[:-1] return name @@ -140,15 +144,15 @@ def strip_sql_name(name): def split_high_level(data, token): results = [] level = 0 - curr_data = '' + curr_data = "" for c in data: if c == token and level == 0: results.append(curr_data.strip()) - curr_data = '' + curr_data = "" continue - if c == '(': + if c == "(": level += 1 - if c == ')': + if c == ")": level -= 1 curr_data += c if curr_data: @@ -161,9 +165,8 @@ def strip_sql_comments(sql_statement): def convert_timestamp_to_datetime64(input_str): - # Define the regex pattern - pattern = r'^timestamp(?:\((\d+)\))?$' + pattern = r"^timestamp(?:\((\d+)\))?$" # Attempt to match the pattern match = re.match(pattern, input_str.strip(), re.IGNORECASE) @@ -172,34 +175,34 @@ def convert_timestamp_to_datetime64(input_str): # If a precision is provided, include it in the replacement precision = match.group(1) if precision is not None: - return f'DateTime64({precision})' + return f"DateTime64({precision})" else: - return 'DateTime64' + return "DateTime64" else: raise ValueError(f"Invalid input string format: '{input_str}'") class MysqlToClickhouseConverter: - def __init__(self, db_replicator: 'DbReplicator' = None): + def __init__(self, db_replicator: "DbReplicator" = None): self.db_replicator = db_replicator self.types_mapping = {} if self.db_replicator is not None: self.types_mapping = db_replicator.config.types_mapping def convert_type(self, mysql_type, parameters): - is_unsigned = 'unsigned' in parameters.lower() + is_unsigned = "unsigned" in parameters.lower() result_type = self.types_mapping.get(mysql_type) if result_type is not None: return result_type - if mysql_type == 'point': - return 'Tuple(x Float32, y Float32)' + if mysql_type == "point": + return "Tuple(x Float32, y Float32)" # Correctly handle numeric types - if mysql_type.startswith('numeric'): + if mysql_type.startswith("numeric"): # Determine if parameters are specified via parentheses: - if '(' in mysql_type and ')' in mysql_type: + if "(" in mysql_type and ")" in mysql_type: # Expecting a type definition like "numeric(precision, scale)" pattern = r"numeric\((\d+)\s*,\s*(\d+)\)" match = re.search(pattern, mysql_type) @@ -234,141 +237,160 @@ def convert_type(self, mysql_type, parameters): # For types with a defined fractional part, use a Decimal mapping. return f"Decimal({precision}, {scale})" - if mysql_type == 'int': + if mysql_type == "int": if is_unsigned: - return 'UInt32' - return 'Int32' - if mysql_type == 'integer': + return "UInt32" + return "Int32" + if mysql_type == "integer": if is_unsigned: - return 'UInt32' - return 'Int32' - if mysql_type == 'bigint': + return "UInt32" + return "Int32" + if mysql_type == "bigint": if is_unsigned: - return 'UInt64' - return 'Int64' - if mysql_type == 'double': - return 'Float64' - if mysql_type == 'real': - return 'Float64' - if mysql_type == 'float': - return 'Float32' - if mysql_type == 'date': - return 'Date32' - if mysql_type == 'tinyint(1)': - return 'Bool' - if mysql_type == 'bit(1)': - return 'Bool' - if mysql_type == 'bool': - return 'Bool' - if 'smallint' in mysql_type: + return "UInt64" + return "Int64" + if mysql_type == "double": + return "Float64" + if mysql_type == "real": + return "Float64" + if mysql_type == "float": + return "Float32" + if mysql_type == "date": + return "Date32" + if mysql_type == "tinyint(1)": + return "Bool" + if mysql_type == "bit(1)": + return "Bool" + if mysql_type == "bool": + return "Bool" + if "smallint" in mysql_type: if is_unsigned: - return 'UInt16' - return 'Int16' - if 'tinyint' in mysql_type: + return "UInt16" + return "Int16" + if "tinyint" in mysql_type: if is_unsigned: - return 'UInt8' - return 'Int8' - if 'mediumint' in mysql_type: + return "UInt8" + return "Int8" + if "mediumint" in mysql_type: if is_unsigned: - return 'UInt32' - return 'Int32' - if 'datetime' in mysql_type: - return mysql_type.replace('datetime', 'DateTime64') - if 'longtext' in mysql_type: - return 'String' - if 'varchar' in mysql_type: - return 'String' - if mysql_type.startswith('enum'): + return "UInt32" + return "Int32" + if "datetime" in mysql_type: + return mysql_type.replace("datetime", "DateTime64") + if "longtext" in mysql_type: + return "String" + if "varchar" in mysql_type: + return "String" + if mysql_type.startswith("enum"): enum_values = parse_mysql_enum(mysql_type) ch_enum_values = [] for idx, value_name in enumerate(enum_values): - ch_enum_values.append(f"'{value_name}' = {idx+1}") - ch_enum_values = ', '.join(ch_enum_values) + ch_enum_values.append(f"'{value_name}' = {idx + 1}") + ch_enum_values = ", ".join(ch_enum_values) if len(enum_values) <= 127: # Enum8('red' = 1, 'green' = 2, 'black' = 3) - return f'Enum8({ch_enum_values})' + return f"Enum8({ch_enum_values})" else: # Enum16('red' = 1, 'green' = 2, 'black' = 3) - return f'Enum16({ch_enum_values})' - if 'text' in mysql_type: - return 'String' - if 'blob' in mysql_type: - return 'String' - if 'char' in mysql_type: - return 'String' - if 'json' in mysql_type: - return 'String' - if 'decimal' in mysql_type: - return 'Float64' - if 'float' in mysql_type: - return 'Float32' - if 'double' in mysql_type: - return 'Float64' - if 'bigint' in mysql_type: + return f"Enum16({ch_enum_values})" + if "text" in mysql_type: + return "String" + if "blob" in mysql_type: + return "String" + if "char" in mysql_type: + return "String" + if "json" in mysql_type: + return "String" + if "decimal" in mysql_type: + return "Float64" + if "float" in mysql_type: + return "Float32" + if "double" in mysql_type: + return "Float64" + if "bigint" in mysql_type: if is_unsigned: - return 'UInt64' - return 'Int64' - if 'integer' in mysql_type or 'int(' in mysql_type: + return "UInt64" + return "Int64" + if "integer" in mysql_type or "int(" in mysql_type: if is_unsigned: - return 'UInt32' - return 'Int32' - if 'real' in mysql_type: - return 'Float64' - if mysql_type.startswith('timestamp'): + return "UInt32" + return "Int32" + if "real" in mysql_type: + return "Float64" + if mysql_type.startswith("timestamp"): return convert_timestamp_to_datetime64(mysql_type) - if mysql_type.startswith('time'): - return 'String' - if 'varbinary' in mysql_type: - return 'String' - if 'binary' in mysql_type: - return 'String' - if 'set(' in mysql_type: - return 'String' + if mysql_type.startswith("time"): + return "String" + if "varbinary" in mysql_type: + return "String" + if "binary" in mysql_type: + return "String" + if "set(" in mysql_type: + return "String" raise Exception(f'unknown mysql type "{mysql_type}"') def convert_field_type(self, mysql_type, mysql_parameters): mysql_type = mysql_type.lower() mysql_parameters = mysql_parameters.lower() - not_null = 'not null' in mysql_parameters + not_null = "not null" in mysql_parameters clickhouse_type = self.convert_type(mysql_type, mysql_parameters) - if 'Tuple' in clickhouse_type: + if "Tuple" in clickhouse_type: not_null = True if not not_null: - clickhouse_type = f'Nullable({clickhouse_type})' + clickhouse_type = f"Nullable({clickhouse_type})" return clickhouse_type - def convert_table_structure(self, mysql_structure: TableStructure) -> TableStructure: + def convert_table_structure( + self, mysql_structure: TableStructure + ) -> TableStructure: clickhouse_structure = TableStructure() clickhouse_structure.table_name = mysql_structure.table_name clickhouse_structure.if_not_exists = mysql_structure.if_not_exists for field in mysql_structure.fields: - clickhouse_field_type = self.convert_field_type(field.field_type, field.parameters) - clickhouse_structure.fields.append(TableField( - name=field.name, - field_type=clickhouse_field_type, - )) + clickhouse_field_type = self.convert_field_type( + field.field_type, field.parameters + ) + clickhouse_structure.fields.append( + TableField( + name=field.name, + field_type=clickhouse_field_type, + ) + ) clickhouse_structure.primary_keys = mysql_structure.primary_keys clickhouse_structure.preprocess() return clickhouse_structure def convert_records( - self, mysql_records, mysql_structure: TableStructure, clickhouse_structure: TableStructure, - only_primary: bool = False, + self, + mysql_records, + mysql_structure: TableStructure, + clickhouse_structure: TableStructure, + only_primary: bool = False, ): mysql_field_types = [field.field_type for field in mysql_structure.fields] - clickhouse_filed_types = [field.field_type for field in clickhouse_structure.fields] + clickhouse_filed_types = [ + field.field_type for field in clickhouse_structure.fields + ] clickhouse_records = [] for mysql_record in mysql_records: clickhouse_record = self.convert_record( - mysql_record, mysql_field_types, clickhouse_filed_types, mysql_structure, only_primary, + mysql_record, + mysql_field_types, + clickhouse_filed_types, + mysql_structure, + only_primary, ) clickhouse_records.append(clickhouse_record) return clickhouse_records def convert_record( - self, mysql_record, mysql_field_types, clickhouse_field_types, mysql_structure: TableStructure, - only_primary: bool, + self, + mysql_record, + mysql_field_types, + clickhouse_field_types, + mysql_structure: TableStructure, + only_primary: bool, ): clickhouse_record = [] for idx, mysql_field_value in enumerate(mysql_record): @@ -379,38 +401,50 @@ def convert_record( clickhouse_field_value = mysql_field_value mysql_field_type = mysql_field_types[idx] clickhouse_field_type = clickhouse_field_types[idx] - if mysql_field_type.startswith('time') and 'String' in clickhouse_field_type: + if ( + mysql_field_type.startswith("time") + and "String" in clickhouse_field_type + ): clickhouse_field_value = str(mysql_field_value) - if mysql_field_type == 'json' and 'String' in clickhouse_field_type: + if mysql_field_type == "json" and "String" in clickhouse_field_type: if not isinstance(clickhouse_field_value, str): - clickhouse_field_value = json.dumps(convert_bytes(clickhouse_field_value)) + clickhouse_field_value = json.dumps( + convert_bytes(clickhouse_field_value) + ) if clickhouse_field_value is not None: - if 'UUID' in clickhouse_field_type: + if "UUID" in clickhouse_field_type: if len(clickhouse_field_value) == 36: if isinstance(clickhouse_field_value, bytes): - clickhouse_field_value = clickhouse_field_value.decode('utf-8') + clickhouse_field_value = clickhouse_field_value.decode( + "utf-8" + ) clickhouse_field_value = uuid.UUID(clickhouse_field_value).bytes - if 'UInt16' in clickhouse_field_type and clickhouse_field_value < 0: + if "UInt16" in clickhouse_field_type and clickhouse_field_value < 0: clickhouse_field_value = 65536 + clickhouse_field_value - if 'UInt8' in clickhouse_field_type and clickhouse_field_value < 0: + if "UInt8" in clickhouse_field_type and clickhouse_field_value < 0: clickhouse_field_value = 256 + clickhouse_field_value - if 'mediumint' in mysql_field_type.lower() and clickhouse_field_value < 0: + if ( + "mediumint" in mysql_field_type.lower() + and clickhouse_field_value < 0 + ): clickhouse_field_value = 16777216 + clickhouse_field_value - if 'UInt32' in clickhouse_field_type and clickhouse_field_value < 0: + if "UInt32" in clickhouse_field_type and clickhouse_field_value < 0: clickhouse_field_value = 4294967296 + clickhouse_field_value - if 'UInt64' in clickhouse_field_type and clickhouse_field_value < 0: - clickhouse_field_value = 18446744073709551616 + clickhouse_field_value + if "UInt64" in clickhouse_field_type and clickhouse_field_value < 0: + clickhouse_field_value = ( + 18446744073709551616 + clickhouse_field_value + ) - if 'String' in clickhouse_field_type and ( - 'text' in mysql_field_type or 'char' in mysql_field_type + if "String" in clickhouse_field_type and ( + "text" in mysql_field_type or "char" in mysql_field_type ): if isinstance(clickhouse_field_value, bytes): - charset = mysql_structure.charset_python or 'utf-8' + charset = mysql_structure.charset_python or "utf-8" clickhouse_field_value = clickhouse_field_value.decode(charset) - if 'set(' in mysql_field_type: + if "set(" in mysql_field_type: set_values = mysql_structure.fields[idx].additional_data if isinstance(clickhouse_field_value, int): bit_mask = clickhouse_field_value @@ -423,29 +457,31 @@ def convert_record( clickhouse_field_value = [ v for v in set_values if v in clickhouse_field_value ] - clickhouse_field_value = ','.join(clickhouse_field_value) + clickhouse_field_value = ",".join(clickhouse_field_value) - if mysql_field_type.startswith('point'): + if mysql_field_type.startswith("point"): clickhouse_field_value = parse_mysql_point(clickhouse_field_value) - if mysql_field_type.startswith('enum(') and isinstance(clickhouse_field_value, int): + if mysql_field_type.startswith("enum(") and isinstance( + clickhouse_field_value, int + ): enum_values = mysql_structure.fields[idx].additional_data - clickhouse_field_value = enum_values[int(clickhouse_field_value)-1] + clickhouse_field_value = enum_values[int(clickhouse_field_value) - 1] clickhouse_record.append(clickhouse_field_value) return tuple(clickhouse_record) def __basic_validate_query(self, mysql_query): mysql_query = mysql_query.strip() - if mysql_query.endswith(';'): + if mysql_query.endswith(";"): mysql_query = mysql_query[:-1] - if mysql_query.find(';') != -1: - raise Exception('multi-query statement not supported') + if mysql_query.find(";") != -1: + raise Exception("multi-query statement not supported") return mysql_query - + def get_db_and_table_name(self, token, db_name): - if '.' in token: - db_name, table_name = token.split('.') + if "." in token: + db_name, table_name = token.split(".") else: table_name = token db_name = strip_sql_name(db_name) @@ -453,9 +489,9 @@ def get_db_and_table_name(self, token, db_name): if self.db_replicator: if db_name == self.db_replicator.database: db_name = self.db_replicator.target_database - matches_config = ( - self.db_replicator.config.is_database_matches(db_name) - and self.db_replicator.config.is_table_matches(table_name)) + matches_config = self.db_replicator.config.is_database_matches( + db_name + ) and self.db_replicator.config.is_table_matches(table_name) else: matches_config = True @@ -465,19 +501,21 @@ def convert_alter_query(self, mysql_query, db_name): mysql_query = self.__basic_validate_query(mysql_query) tokens = mysql_query.split() - if tokens[0].lower() != 'alter': - raise Exception('wrong query') + if tokens[0].lower() != "alter": + raise Exception("wrong query") - if tokens[1].lower() != 'table': - raise Exception('wrong query') + if tokens[1].lower() != "table": + raise Exception("wrong query") - db_name, table_name, matches_config = self.get_db_and_table_name(tokens[2], db_name) + db_name, table_name, matches_config = self.get_db_and_table_name( + tokens[2], db_name + ) if not matches_config: return - subqueries = ' '.join(tokens[3:]) - subqueries = split_high_level(subqueries, ',') + subqueries = " ".join(tokens[3:]) + subqueries = split_high_level(subqueries, ",") for subquery in subqueries: subquery = subquery.strip() @@ -486,33 +524,33 @@ def convert_alter_query(self, mysql_query, db_name): op_name = tokens[0].lower() tokens = tokens[1:] - if tokens[0].lower() == 'column': + if tokens[0].lower() == "column": tokens = tokens[1:] - if op_name == 'add': - if tokens[0].lower() in ('constraint', 'index', 'foreign', 'unique'): + if op_name == "add": + if tokens[0].lower() in ("constraint", "index", "foreign", "unique"): continue self.__convert_alter_table_add_column(db_name, table_name, tokens) continue - if op_name == 'drop': - if tokens[0].lower() in ('constraint', 'index', 'foreign', 'unique'): + if op_name == "drop": + if tokens[0].lower() in ("constraint", "index", "foreign", "unique"): continue self.__convert_alter_table_drop_column(db_name, table_name, tokens) continue - if op_name == 'modify': + if op_name == "modify": self.__convert_alter_table_modify_column(db_name, table_name, tokens) continue - if op_name == 'alter': + if op_name == "alter": continue - if op_name == 'change': + if op_name == "change": self.__convert_alter_table_change_column(db_name, table_name, tokens) continue - raise Exception(f'operation {op_name} not implement, query: {subquery}') + raise Exception(f"operation {op_name} not implement, query: {subquery}") @classmethod def _tokenize_alter_query(cls, sql_line): @@ -524,7 +562,8 @@ def _tokenize_alter_query(cls, sql_line): # # The order is important: for example, if a word is immediately followed by parentheses, # we want to grab it as a single token. - token_pattern = re.compile(r''' + token_pattern = re.compile( + r""" ( # start capture group for a token `[^`]+`(?:\([^)]*\))? | # backquoted identifier w/ optional parentheses \w+(?:\([^)]*\))? | # a word with optional parentheses @@ -532,7 +571,9 @@ def _tokenize_alter_query(cls, sql_line): "(?:\\"|[^"])*" | # a double-quoted string [^\s]+ # fallback: any sequence of non-whitespace characters ) - ''', re.VERBOSE) + """, + re.VERBOSE, + ) tokens = token_pattern.findall(sql_line) # Now, split the column definition into: @@ -543,10 +584,29 @@ def _tokenize_alter_query(cls, sql_line): # # We define a set of keywords that indicate the start of column options. constraint_keywords = { - "DEFAULT", "NOT", "NULL", "AUTO_INCREMENT", "PRIMARY", "UNIQUE", - "COMMENT", "COLLATE", "REFERENCES", "ON", "CHECK", "CONSTRAINT", - "AFTER", "BEFORE", "GENERATED", "VIRTUAL", "STORED", "FIRST", - "ALWAYS", "AS", "IDENTITY", "INVISIBLE", "PERSISTED", + "DEFAULT", + "NOT", + "NULL", + "AUTO_INCREMENT", + "PRIMARY", + "UNIQUE", + "COMMENT", + "COLLATE", + "REFERENCES", + "ON", + "CHECK", + "CONSTRAINT", + "AFTER", + "BEFORE", + "GENERATED", + "VIRTUAL", + "STORED", + "FIRST", + "ALWAYS", + "AS", + "IDENTITY", + "INVISIBLE", + "PERSISTED", } if not tokens: @@ -554,7 +614,7 @@ def _tokenize_alter_query(cls, sql_line): # The first token is always the column name. column_name = tokens[0] - # Now “merge” tokens after the column name that belong to the type. + # Now "merge" tokens after the column name that belong to the type. # (For many types the type is written as a single token already – # e.g. "VARCHAR(254)" or "NUMERIC(5, 2)", but for types like # "DOUBLE PRECISION" or "INT UNSIGNED" the .split() would produce two tokens.) @@ -575,26 +635,28 @@ def _tokenize_alter_query(cls, sql_line): return [column_name] + param_tokens def __convert_alter_table_add_column(self, db_name, table_name, tokens): - tokens = self._tokenize_alter_query(' '.join(tokens)) + tokens = self._tokenize_alter_query(" ".join(tokens)) if len(tokens) < 2: - raise Exception('wrong tokens count', tokens) + raise Exception("wrong tokens count", tokens) column_after = None column_first = False - if tokens[-2].lower() == 'after': + if tokens[-2].lower() == "after": column_after = strip_sql_name(tokens[-1]) tokens = tokens[:-2] if len(tokens) < 2: - raise Exception('wrong tokens count', tokens) - elif tokens[-1].lower() == 'first': + raise Exception("wrong tokens count", tokens) + elif tokens[-1].lower() == "first": column_first = True column_name = strip_sql_name(tokens[0]) column_type_mysql = tokens[1] - column_type_mysql_parameters = ' '.join(tokens[2:]) + column_type_mysql_parameters = " ".join(tokens[2:]) - column_type_ch = self.convert_field_type(column_type_mysql, column_type_mysql_parameters) + column_type_ch = self.convert_field_type( + column_type_mysql, column_type_mysql_parameters + ) # update table structure if self.db_replicator: @@ -606,7 +668,7 @@ def __convert_alter_table_add_column(self, db_name, table_name, tokens): mysql_table_structure.add_field_first( TableField(name=column_name, field_type=column_type_mysql) ) - + ch_table_structure.add_field_first( TableField(name=column_name, field_type=column_type_ch) ) @@ -624,18 +686,18 @@ def __convert_alter_table_add_column(self, db_name, table_name, tokens): column_after, ) - query = f'ALTER TABLE `{db_name}`.`{table_name}` ADD COLUMN `{column_name}` {column_type_ch}' + query = f"ALTER TABLE `{db_name}`.`{table_name}` ADD COLUMN `{column_name}` {column_type_ch}" if column_first: - query += ' FIRST' + query += " FIRST" else: - query += f' AFTER {column_after}' + query += f" AFTER {column_after}" if self.db_replicator: self.db_replicator.clickhouse_api.execute_command(query) def __convert_alter_table_drop_column(self, db_name, table_name, tokens): if len(tokens) != 1: - raise Exception('wrong tokens count', tokens) + raise Exception("wrong tokens count", tokens) column_name = strip_sql_name(tokens[0]) @@ -648,19 +710,21 @@ def __convert_alter_table_drop_column(self, db_name, table_name, tokens): mysql_table_structure.remove_field(field_name=column_name) ch_table_structure.remove_field(field_name=column_name) - query = f'ALTER TABLE `{db_name}`.`{table_name}` DROP COLUMN {column_name}' + query = f"ALTER TABLE `{db_name}`.`{table_name}` DROP COLUMN {column_name}" if self.db_replicator: self.db_replicator.clickhouse_api.execute_command(query) def __convert_alter_table_modify_column(self, db_name, table_name, tokens): if len(tokens) < 2: - raise Exception('wrong tokens count', tokens) + raise Exception("wrong tokens count", tokens) column_name = strip_sql_name(tokens[0]) column_type_mysql = tokens[1] - column_type_mysql_parameters = ' '.join(tokens[2:]) + column_type_mysql_parameters = " ".join(tokens[2:]) - column_type_ch = self.convert_field_type(column_type_mysql, column_type_mysql_parameters) + column_type_ch = self.convert_field_type( + column_type_mysql, column_type_mysql_parameters + ) # update table structure if self.db_replicator: @@ -676,20 +740,22 @@ def __convert_alter_table_modify_column(self, db_name, table_name, tokens): TableField(name=column_name, field_type=column_type_ch), ) - query = f'ALTER TABLE `{db_name}`.`{table_name}` MODIFY COLUMN `{column_name}` {column_type_ch}' + query = f"ALTER TABLE `{db_name}`.`{table_name}` MODIFY COLUMN `{column_name}` {column_type_ch}" if self.db_replicator: self.db_replicator.clickhouse_api.execute_command(query) def __convert_alter_table_change_column(self, db_name, table_name, tokens): if len(tokens) < 3: - raise Exception('wrong tokens count', tokens) + raise Exception("wrong tokens count", tokens) column_name = strip_sql_name(tokens[0]) new_column_name = strip_sql_name(tokens[1]) column_type_mysql = tokens[2] - column_type_mysql_parameters = ' '.join(tokens[3:]) + column_type_mysql_parameters = " ".join(tokens[3:]) - column_type_ch = self.convert_field_type(column_type_mysql, column_type_mysql_parameters) + column_type_ch = self.convert_field_type( + column_type_mysql, column_type_mysql_parameters + ) # update table structure if self.db_replicator: @@ -697,10 +763,11 @@ def __convert_alter_table_change_column(self, db_name, table_name, tokens): mysql_table_structure: TableStructure = table_structure[0] ch_table_structure: TableStructure = table_structure[1] - current_column_type_ch = ch_table_structure.get_field(column_name).field_type + current_column_type_ch = ch_table_structure.get_field( + column_name + ).field_type if current_column_type_ch != column_type_ch: - mysql_table_structure.update_field( TableField(name=column_name, field_type=column_type_mysql), ) @@ -709,7 +776,7 @@ def __convert_alter_table_change_column(self, db_name, table_name, tokens): TableField(name=column_name, field_type=column_type_ch), ) - query = f'ALTER TABLE `{db_name}`.`{table_name}` MODIFY COLUMN {column_name} {column_type_ch}' + query = f"ALTER TABLE `{db_name}`.`{table_name}` MODIFY COLUMN {column_name} {column_type_ch}" self.db_replicator.clickhouse_api.execute_command(query) if column_name != new_column_name: @@ -719,109 +786,119 @@ def __convert_alter_table_change_column(self, db_name, table_name, tokens): curr_field_mysql.name = new_column_name curr_field_clickhouse.name = new_column_name - query = f'ALTER TABLE `{db_name}`.`{table_name}` RENAME COLUMN {column_name} TO {new_column_name}' + query = f"ALTER TABLE `{db_name}`.`{table_name}` RENAME COLUMN {column_name} TO {new_column_name}" self.db_replicator.clickhouse_api.execute_command(query) - def parse_create_table_query(self, mysql_query) -> tuple[TableStructure, TableStructure]: + def parse_create_table_query( + self, mysql_query + ) -> tuple[TableStructure, TableStructure]: mysql_table_structure = self.parse_mysql_table_structure(mysql_query) ch_table_structure = self.convert_table_structure(mysql_table_structure) return mysql_table_structure, ch_table_structure def convert_drop_table_query(self, mysql_query): - raise Exception('not implement') + raise Exception("not implement") def _strip_comments(self, create_statement): pattern = r'\bCOMMENT(?:\s*=\s*|\s+)([\'"])(?:\\.|[^\\])*?\1' - return re.sub(pattern, '', create_statement, flags=re.IGNORECASE) + return re.sub(pattern, "", create_statement, flags=re.IGNORECASE) def parse_mysql_table_structure(self, create_statement, required_table_name=None): create_statement = self._strip_comments(create_statement) structure = TableStructure() - tokens = sqlparse.parse(create_statement.replace('\n', ' ').strip())[0].tokens + tokens = sqlparse.parse(create_statement.replace("\n", " ").strip())[0].tokens tokens = [t for t in tokens if not t.is_whitespace and not t.is_newline] # remove "IF NOT EXISTS" - if (len(tokens) > 5 and - tokens[0].normalized.upper() == 'CREATE' and - tokens[1].normalized.upper() == 'TABLE' and - tokens[2].normalized.upper() == 'IF' and - tokens[3].normalized.upper() == 'NOT' and - tokens[4].normalized.upper() == 'EXISTS'): + if ( + len(tokens) > 5 + and tokens[0].normalized.upper() == "CREATE" + and tokens[1].normalized.upper() == "TABLE" + and tokens[2].normalized.upper() == "IF" + and tokens[3].normalized.upper() == "NOT" + and tokens[4].normalized.upper() == "EXISTS" + ): del tokens[2:5] # Remove the 'IF', 'NOT', 'EXISTS' tokens structure.if_not_exists = True if tokens[0].ttype != sqlparse.tokens.DDL: - raise Exception('wrong create statement', create_statement) - if tokens[0].normalized.lower() != 'create': - raise Exception('wrong create statement', create_statement) + raise Exception("wrong create statement", create_statement) + if tokens[0].normalized.lower() != "create": + raise Exception("wrong create statement", create_statement) if tokens[1].ttype != sqlparse.tokens.Keyword: - raise Exception('wrong create statement', create_statement) + raise Exception("wrong create statement", create_statement) if not isinstance(tokens[2], sqlparse.sql.Identifier): - raise Exception('wrong create statement', create_statement) + raise Exception("wrong create statement", create_statement) # get_real_name() returns the table name if the token is in the # style `.` structure.table_name = strip_sql_name(tokens[2].get_real_name()) if not isinstance(tokens[3], sqlparse.sql.Parenthesis): - raise Exception('wrong create statement', create_statement) + raise Exception("wrong create statement", create_statement) - #print(' --- processing statement:\n', create_statement, '\n') + # print(' --- processing statement:\n', create_statement, '\n') inner_tokens = tokens[3].tokens - inner_tokens = ''.join([str(t) for t in inner_tokens[1:-1]]).strip() - inner_tokens = split_high_level(inner_tokens, ',') + inner_tokens = "".join([str(t) for t in inner_tokens[1:-1]]).strip() + inner_tokens = split_high_level(inner_tokens, ",") - prev_token = '' - prev_prev_token = '' + prev_token = "" + prev_prev_token = "" for line in tokens[4:]: curr_token = line.value - if prev_token == '=' and prev_prev_token.lower() == 'charset': + if prev_token == "=" and prev_prev_token.lower() == "charset": structure.charset = curr_token prev_prev_token = prev_token prev_token = curr_token - structure.charset_python = 'utf-8' + structure.charset_python = "utf-8" if structure.charset: structure.charset_python = CHARSET_MYSQL_TO_PYTHON[structure.charset] - prev_line = '' + prev_line = "" for line in inner_tokens: line = prev_line + line - q_count = line.count('`') + q_count = line.count("`") if q_count % 2 == 1: prev_line = line continue - prev_line = '' + prev_line = "" - if line.lower().startswith('unique key'): + if line.lower().startswith("unique key"): continue - if line.lower().startswith('key'): + if line.lower().startswith("key"): continue - if line.lower().startswith('constraint'): + if line.lower().startswith("constraint"): continue - if line.lower().startswith('fulltext'): + if line.lower().startswith("fulltext"): continue - if line.lower().startswith('spatial'): + if line.lower().startswith("spatial"): continue - if line.lower().startswith('primary key'): + if line.lower().startswith("primary key"): # Define identifier to match column names, handling backticks and unquoted names - identifier = (Suppress('`') + Word(alphas + alphanums + '_') + Suppress('`')) | Word( - alphas + alphanums + '_') + identifier = ( + Suppress("`") + Word(alphas + alphanums + "_") + Suppress("`") + ) | Word(alphas + alphanums + "_") # Build the parsing pattern - pattern = CaselessKeyword('PRIMARY') + CaselessKeyword('KEY') + Suppress('(') + delimitedList( - identifier)('column_names') + Suppress(')') + pattern = ( + CaselessKeyword("PRIMARY") + + CaselessKeyword("KEY") + + Suppress("(") + + delimitedList(identifier)("column_names") + + Suppress(")") + ) # Parse the line result = pattern.parseString(line) # Extract and process the primary key column names - primary_keys = [strip_sql_name(name) for name in result['column_names']] + primary_keys = [strip_sql_name(name) for name in result["column_names"]] structure.primary_keys = primary_keys @@ -830,59 +907,139 @@ def parse_mysql_table_structure(self, create_statement, required_table_name=None line = line.strip() # print(" === processing line", line) - if line.startswith('`'): - end_pos = line.find('`', 1) + if line.startswith("`"): + end_pos = line.find("`", 1) field_name = line[1:end_pos] - line = line[end_pos+1:].strip() - definition = line.split(' ') + line = line[end_pos + 1 :].strip() + # Don't split by space for enum and set types that might contain spaces + if line.lower().startswith("enum(") or line.lower().startswith("set("): + # Find the end of the enum/set definition (closing parenthesis) + open_parens = 0 + in_quotes = False + quote_char = None + end_pos = -1 + + for i, char in enumerate(line): + if char in "'\"" and (i == 0 or line[i - 1] != "\\"): + if not in_quotes: + in_quotes = True + quote_char = char + elif char == quote_char: + in_quotes = False + elif char == "(" and not in_quotes: + open_parens += 1 + elif char == ")" and not in_quotes: + open_parens -= 1 + if open_parens == 0: + end_pos = i + 1 + break + + if end_pos > 0: + field_type = line[:end_pos] + field_parameters = line[end_pos:].strip() + else: + # Fallback to original behavior if we can't find the end + definition = line.split(" ") + field_type = definition[0] + field_parameters = ( + " ".join(definition[1:]) if len(definition) > 1 else "" + ) + else: + definition = line.split(" ") + field_type = definition[0] + field_parameters = ( + " ".join(definition[1:]) if len(definition) > 1 else "" + ) else: - definition = line.split(' ') + definition = line.split(" ") field_name = strip_sql_name(definition[0]) definition = definition[1:] - field_type = definition[0] - field_parameters = '' - if len(definition) > 1: - field_parameters = ' '.join(definition[1:]) + # Handle enum and set types with spaces for non-backtick field names + if definition and ( + definition[0].lower().startswith("enum(") + or definition[0].lower().startswith("set(") + ): + line = " ".join(definition) + # Find the end of the enum/set definition (closing parenthesis) + open_parens = 0 + in_quotes = False + quote_char = None + end_pos = -1 + + for i, char in enumerate(line): + if char in "'\"" and (i == 0 or line[i - 1] != "\\"): + if not in_quotes: + in_quotes = True + quote_char = char + elif char == quote_char: + in_quotes = False + elif char == "(" and not in_quotes: + open_parens += 1 + elif char == ")" and not in_quotes: + open_parens -= 1 + if open_parens == 0: + end_pos = i + 1 + break + + if end_pos > 0: + field_type = line[:end_pos] + field_parameters = line[end_pos:].strip() + else: + # Fallback to original behavior + field_type = definition[0] + field_parameters = ( + " ".join(definition[1:]) if len(definition) > 1 else "" + ) + else: + field_type = definition[0] + field_parameters = ( + " ".join(definition[1:]) if len(definition) > 1 else "" + ) additional_data = None - if 'set(' in field_type.lower(): - vals = field_type[len('set('):] - close_pos = vals.find(')') + if "set(" in field_type.lower(): + vals = field_type[len("set(") :] + close_pos = vals.find(")") vals = vals[:close_pos] - vals = vals.split(',') + vals = vals.split(",") + def vstrip(e): if not e: return e - if e[0] in '"\'': + if e[0] in "\"'": return e[1:-1] return e + vals = [vstrip(v) for v in vals] additional_data = vals - if field_type.lower().startswith('enum('): + if field_type.lower().startswith("enum("): additional_data = parse_mysql_enum(field_type) - structure.fields.append(TableField( - name=field_name, - field_type=field_type, - parameters=field_parameters, - additional_data=additional_data, - )) - #print(' ---- params:', field_parameters) - + structure.fields.append( + TableField( + name=field_name, + field_type=field_type, + parameters=field_parameters, + additional_data=additional_data, + ) + ) + # print(' ---- params:', field_parameters) if not structure.primary_keys: for field in structure.fields: - if 'primary key' in field.parameters.lower(): + if "primary key" in field.parameters.lower(): structure.primary_keys.append(field.name) if not structure.primary_keys: - if structure.has_field('id'): - structure.primary_keys = ['id'] + if structure.has_field("id"): + structure.primary_keys = ["id"] if not structure.primary_keys: - raise Exception(f'No primary key for table {structure.table_name}, {create_statement}') + raise Exception( + f"No primary key for table {structure.table_name}, {create_statement}" + ) structure.preprocess() return structure diff --git a/mysql_ch_replicator/converter_enum_parser.py b/mysql_ch_replicator/converter_enum_parser.py index 92192ea..b4f497d 100644 --- a/mysql_ch_replicator/converter_enum_parser.py +++ b/mysql_ch_replicator/converter_enum_parser.py @@ -1,5 +1,3 @@ - - def parse_mysql_enum(enum_definition): """ Accepts a MySQL ENUM definition string (case–insensitive), @@ -22,7 +20,7 @@ def parse_mysql_enum(enum_definition): raise ValueError("String does not start with 'enum'") # Find the first opening parenthesis. - pos = s.find('(') + pos = s.find("(") if pos == -1: raise ValueError("Missing '(' in the enum definition") @@ -48,7 +46,8 @@ def _extract_parenthesized_content(s, start_index): ', " or `) and also to skip over escape sequences in single/double quotes. (Backticks do not process backslash escapes.) """ - if s[start_index] != '(': + + if s[start_index] != "(": raise ValueError("Expected '(' at position {}".format(start_index)) depth = 1 i = start_index + 1 @@ -56,14 +55,14 @@ def _extract_parenthesized_content(s, start_index): in_quote = None # will be set to a quoting character when inside a quoted literal # Allow these quote characters. - allowed_quotes = ("'", '"', '`') + allowed_quotes = ("'", '"', "`") while i < len(s): c = s[i] if in_quote: # Inside a quoted literal. if in_quote in ("'", '"'): - if c == '\\': + if c == "\\": # Skip the escape character and the next character. i += 2 continue @@ -87,21 +86,24 @@ def _extract_parenthesized_content(s, start_index): in_quote = c i += 1 continue - elif c == '(': + elif c == "(": depth += 1 i += 1 continue - elif c == ')': + elif c == ")": depth -= 1 i += 1 if depth == 0: # Return the substring inside (excluding the outer parentheses) - return s[content_start:i - 1], i + result = s[content_start : i - 1] + return result, i continue else: i += 1 - raise ValueError("Unbalanced parentheses in enum definition") + raise ValueError( + "Unbalanced parentheses in enum definition at position {} in {!r}".format(i, s) + ) def _parse_enum_values(content): @@ -116,7 +118,7 @@ def _parse_enum_values(content): """ values = [] i = 0 - allowed_quotes = ("'", '"', '`') + allowed_quotes = ("'", '"', "`") while i < len(content): # Skip any whitespace. while i < len(content) and content[i].isspace(): @@ -125,7 +127,11 @@ def _parse_enum_values(content): break # The next non–whitespace character must be one of the allowed quotes. if content[i] not in allowed_quotes: - raise ValueError("Expected starting quote for enum value at position {} in {!r}".format(i, content)) + raise ValueError( + "Expected starting quote for enum value at position {} in {!r}".format( + i, content + ) + ) quote = content[i] i += 1 # skip the opening quote @@ -133,26 +139,26 @@ def _parse_enum_values(content): while i < len(content): c = content[i] # For single- and double–quotes, process backslash escapes. - if quote in ("'", '"') and c == '\\': + if quote in ("'", '"') and c == "\\": if i + 1 < len(content): next_char = content[i + 1] # Mapping for common escapes. (For the quote character, map it to itself.) escapes = { - '0': '\0', - 'b': '\b', - 'n': '\n', - 'r': '\r', - 't': '\t', - 'Z': '\x1a', - '\\': '\\', - quote: quote + "0": "\0", + "b": "\b", + "n": "\n", + "r": "\r", + "t": "\t", + "Z": "\x1a", + "\\": "\\", + quote: quote, } literal_chars.append(escapes.get(next_char, next_char)) i += 2 continue else: # Trailing backslash – treat it as literal. - literal_chars.append('\\') + literal_chars.append("\\") i += 1 continue elif c == quote: @@ -169,24 +175,27 @@ def _parse_enum_values(content): literal_chars.append(c) i += 1 # Finished reading one literal; join the characters. - value = ''.join(literal_chars) + value = "".join(literal_chars) values.append(value) # Skip whitespace after the literal. while i < len(content) and content[i].isspace(): i += 1 - # If there’s a comma, skip it; otherwise, we must be at the end. + # If there's a comma, skip it; otherwise, we must be at the end. if i < len(content): - if content[i] == ',': + if content[i] == ",": i += 1 else: - raise ValueError("Expected comma between enum values at position {} in {!r}" - .format(i, content)) + raise ValueError( + "Expected comma between enum values at position {} in {!r}".format( + i, content + ) + ) return values # --- For testing purposes --- -if __name__ == '__main__': +if __name__ == "__main__": tests = [ "enum('point','qwe','def')", "ENUM('asd', 'qwe', 'def')", From aad93640b6a5ec1019d20d82bb1f43cec4697622 Mon Sep 17 00:00:00 2001 From: Jared Dobson Date: Thu, 6 Mar 2025 13:49:23 -0700 Subject: [PATCH 03/16] Fix formatting --- mysql_ch_replicator/converter.py | 712 ++++++++----------- mysql_ch_replicator/converter_enum_parser.py | 10 +- 2 files changed, 319 insertions(+), 403 deletions(-) diff --git a/mysql_ch_replicator/converter.py b/mysql_ch_replicator/converter.py index f534696..cf9b1f6 100644 --- a/mysql_ch_replicator/converter.py +++ b/mysql_ch_replicator/converter.py @@ -1,61 +1,57 @@ -import json -import re import struct +import json import uuid -from typing import TYPE_CHECKING - import sqlparse -from pyparsing import CaselessKeyword, Suppress, Word, alphanums, alphas, delimitedList +import re +from pyparsing import Suppress, CaselessKeyword, Word, alphas, alphanums, delimitedList +from .table_structure import TableStructure, TableField from .converter_enum_parser import parse_mysql_enum -from .table_structure import TableField, TableStructure -if TYPE_CHECKING: - from .db_replicator import DbReplicator CHARSET_MYSQL_TO_PYTHON = { - "armscii8": None, # ARMSCII-8 is not directly supported in Python - "ascii": "ascii", - "big5": "big5", - "binary": "latin1", # Treat binary data as Latin-1 in Python - "cp1250": "cp1250", - "cp1251": "cp1251", - "cp1256": "cp1256", - "cp1257": "cp1257", - "cp850": "cp850", - "cp852": "cp852", - "cp866": "cp866", - "cp932": "cp932", - "dec8": "latin1", # DEC8 is similar to Latin-1 - "eucjpms": "euc_jp", # Map to EUC-JP - "euckr": "euc_kr", - "gb18030": "gb18030", - "gb2312": "gb2312", - "gbk": "gbk", - "geostd8": None, # GEOSTD8 is not directly supported in Python - "greek": "iso8859_7", - "hebrew": "iso8859_8", - "hp8": None, # HP8 is not directly supported in Python - "keybcs2": None, # KEYBCS2 is not directly supported in Python - "koi8r": "koi8_r", - "koi8u": "koi8_u", - "latin1": "cp1252", # MySQL's latin1 corresponds to Windows-1252 - "latin2": "iso8859_2", - "latin5": "iso8859_9", - "latin7": "iso8859_13", - "macce": "mac_latin2", - "macroman": "mac_roman", - "sjis": "shift_jis", - "swe7": None, # SWE7 is not directly supported in Python - "tis620": "tis_620", - "ucs2": "utf_16", # UCS-2 can be mapped to UTF-16 - "ujis": "euc_jp", - "utf16": "utf_16", - "utf16le": "utf_16_le", - "utf32": "utf_32", - "utf8mb3": "utf_8", # Both utf8mb3 and utf8mb4 can be mapped to UTF-8 - "utf8mb4": "utf_8", - "utf8": "utf_8", + 'armscii8': None, # ARMSCII-8 is not directly supported in Python + 'ascii': 'ascii', + 'big5': 'big5', + 'binary': 'latin1', # Treat binary data as Latin-1 in Python + 'cp1250': 'cp1250', + 'cp1251': 'cp1251', + 'cp1256': 'cp1256', + 'cp1257': 'cp1257', + 'cp850': 'cp850', + 'cp852': 'cp852', + 'cp866': 'cp866', + 'cp932': 'cp932', + 'dec8': 'latin1', # DEC8 is similar to Latin-1 + 'eucjpms': 'euc_jp', # Map to EUC-JP + 'euckr': 'euc_kr', + 'gb18030': 'gb18030', + 'gb2312': 'gb2312', + 'gbk': 'gbk', + 'geostd8': None, # GEOSTD8 is not directly supported in Python + 'greek': 'iso8859_7', + 'hebrew': 'iso8859_8', + 'hp8': None, # HP8 is not directly supported in Python + 'keybcs2': None, # KEYBCS2 is not directly supported in Python + 'koi8r': 'koi8_r', + 'koi8u': 'koi8_u', + 'latin1': 'cp1252', # MySQL's latin1 corresponds to Windows-1252 + 'latin2': 'iso8859_2', + 'latin5': 'iso8859_9', + 'latin7': 'iso8859_13', + 'macce': 'mac_latin2', + 'macroman': 'mac_roman', + 'sjis': 'shift_jis', + 'swe7': None, # SWE7 is not directly supported in Python + 'tis620': 'tis_620', + 'ucs2': 'utf_16', # UCS-2 can be mapped to UTF-16 + 'ujis': 'euc_jp', + 'utf16': 'utf_16', + 'utf16le': 'utf_16_le', + 'utf32': 'utf_32', + 'utf8mb3': 'utf_8', # Both utf8mb3 and utf8mb4 can be mapped to UTF-8 + 'utf8mb4': 'utf_8', + 'utf8': 'utf_8', } @@ -63,7 +59,7 @@ def convert_bytes(obj): if isinstance(obj, dict): new_obj = {} for k, v in obj.items(): - new_key = k.decode("utf-8") if isinstance(k, bytes) else k + new_key = k.decode('utf-8') if isinstance(k, bytes) else k new_value = convert_bytes(v) new_obj[new_key] = new_value return new_obj @@ -75,7 +71,7 @@ def convert_bytes(obj): return tuple(new_obj) return new_obj elif isinstance(obj, bytes): - return obj.decode("utf-8") + return obj.decode('utf-8') else: return obj @@ -96,37 +92,37 @@ def parse_mysql_point(binary): # Read the byte order byte_order = binary[0] if byte_order == 0: - endian = ">" + endian = '>' elif byte_order == 1: - endian = "<" + endian = '<' else: raise ValueError("Invalid byte order in WKB POINT") # Read the WKB Type - wkb_type = struct.unpack(endian + "I", binary[1:5])[0] + wkb_type = struct.unpack(endian + 'I', binary[1:5])[0] if wkb_type != 1: # WKB type 1 means POINT raise ValueError("Not a WKB POINT type") # Read X and Y coordinates - x = struct.unpack(endian + "d", binary[5:13])[0] - y = struct.unpack(endian + "d", binary[13:21])[0] + x = struct.unpack(endian + 'd', binary[5:13])[0] + y = struct.unpack(endian + 'd', binary[13:21])[0] elif len(binary) == 25: # With SRID included # First 4 bytes are the SRID - srid = struct.unpack(">I", binary[0:4])[0] # SRID is big-endian + srid = struct.unpack('>I', binary[0:4])[0] # SRID is big-endian # Next byte is byte order byte_order = binary[4] if byte_order == 0: - endian = ">" + endian = '>' elif byte_order == 1: - endian = "<" + endian = '<' else: raise ValueError("Invalid byte order in WKB POINT") # Read the WKB Type - wkb_type = struct.unpack(endian + "I", binary[5:9])[0] + wkb_type = struct.unpack(endian + 'I', binary[5:9])[0] if wkb_type != 1: # WKB type 1 means POINT raise ValueError("Not a WKB POINT type") # Read X and Y coordinates - x = struct.unpack(endian + "d", binary[9:17])[0] - y = struct.unpack(endian + "d", binary[17:25])[0] + x = struct.unpack(endian + 'd', binary[9:17])[0] + y = struct.unpack(endian + 'd', binary[17:25])[0] else: raise ValueError("Invalid binary length for WKB POINT") return (x, y) @@ -134,9 +130,9 @@ def parse_mysql_point(binary): def strip_sql_name(name): name = name.strip() - if name.startswith("`"): + if name.startswith('`'): name = name[1:] - if name.endswith("`"): + if name.endswith('`'): name = name[:-1] return name @@ -144,15 +140,15 @@ def strip_sql_name(name): def split_high_level(data, token): results = [] level = 0 - curr_data = "" + curr_data = '' for c in data: if c == token and level == 0: results.append(curr_data.strip()) - curr_data = "" + curr_data = '' continue - if c == "(": + if c == '(': level += 1 - if c == ")": + if c == ')': level -= 1 curr_data += c if curr_data: @@ -165,8 +161,9 @@ def strip_sql_comments(sql_statement): def convert_timestamp_to_datetime64(input_str): + # Define the regex pattern - pattern = r"^timestamp(?:\((\d+)\))?$" + pattern = r'^timestamp(?:\((\d+)\))?$' # Attempt to match the pattern match = re.match(pattern, input_str.strip(), re.IGNORECASE) @@ -175,34 +172,34 @@ def convert_timestamp_to_datetime64(input_str): # If a precision is provided, include it in the replacement precision = match.group(1) if precision is not None: - return f"DateTime64({precision})" + return f'DateTime64({precision})' else: - return "DateTime64" + return 'DateTime64' else: raise ValueError(f"Invalid input string format: '{input_str}'") class MysqlToClickhouseConverter: - def __init__(self, db_replicator: "DbReplicator" = None): + def __init__(self, db_replicator: 'DbReplicator' = None): self.db_replicator = db_replicator self.types_mapping = {} if self.db_replicator is not None: self.types_mapping = db_replicator.config.types_mapping def convert_type(self, mysql_type, parameters): - is_unsigned = "unsigned" in parameters.lower() + is_unsigned = 'unsigned' in parameters.lower() result_type = self.types_mapping.get(mysql_type) if result_type is not None: return result_type - if mysql_type == "point": - return "Tuple(x Float32, y Float32)" + if mysql_type == 'point': + return 'Tuple(x Float32, y Float32)' # Correctly handle numeric types - if mysql_type.startswith("numeric"): + if mysql_type.startswith('numeric'): # Determine if parameters are specified via parentheses: - if "(" in mysql_type and ")" in mysql_type: + if '(' in mysql_type and ')' in mysql_type: # Expecting a type definition like "numeric(precision, scale)" pattern = r"numeric\((\d+)\s*,\s*(\d+)\)" match = re.search(pattern, mysql_type) @@ -237,160 +234,141 @@ def convert_type(self, mysql_type, parameters): # For types with a defined fractional part, use a Decimal mapping. return f"Decimal({precision}, {scale})" - if mysql_type == "int": + if mysql_type == 'int': if is_unsigned: - return "UInt32" - return "Int32" - if mysql_type == "integer": + return 'UInt32' + return 'Int32' + if mysql_type == 'integer': if is_unsigned: - return "UInt32" - return "Int32" - if mysql_type == "bigint": + return 'UInt32' + return 'Int32' + if mysql_type == 'bigint': if is_unsigned: - return "UInt64" - return "Int64" - if mysql_type == "double": - return "Float64" - if mysql_type == "real": - return "Float64" - if mysql_type == "float": - return "Float32" - if mysql_type == "date": - return "Date32" - if mysql_type == "tinyint(1)": - return "Bool" - if mysql_type == "bit(1)": - return "Bool" - if mysql_type == "bool": - return "Bool" - if "smallint" in mysql_type: + return 'UInt64' + return 'Int64' + if mysql_type == 'double': + return 'Float64' + if mysql_type == 'real': + return 'Float64' + if mysql_type == 'float': + return 'Float32' + if mysql_type == 'date': + return 'Date32' + if mysql_type == 'tinyint(1)': + return 'Bool' + if mysql_type == 'bit(1)': + return 'Bool' + if mysql_type == 'bool': + return 'Bool' + if 'smallint' in mysql_type: if is_unsigned: - return "UInt16" - return "Int16" - if "tinyint" in mysql_type: + return 'UInt16' + return 'Int16' + if 'tinyint' in mysql_type: if is_unsigned: - return "UInt8" - return "Int8" - if "mediumint" in mysql_type: + return 'UInt8' + return 'Int8' + if 'mediumint' in mysql_type: if is_unsigned: - return "UInt32" - return "Int32" - if "datetime" in mysql_type: - return mysql_type.replace("datetime", "DateTime64") - if "longtext" in mysql_type: - return "String" - if "varchar" in mysql_type: - return "String" - if mysql_type.startswith("enum"): + return 'UInt32' + return 'Int32' + if 'datetime' in mysql_type: + return mysql_type.replace('datetime', 'DateTime64') + if 'longtext' in mysql_type: + return 'String' + if 'varchar' in mysql_type: + return 'String' + if mysql_type.startswith('enum'): enum_values = parse_mysql_enum(mysql_type) ch_enum_values = [] for idx, value_name in enumerate(enum_values): - ch_enum_values.append(f"'{value_name}' = {idx + 1}") - ch_enum_values = ", ".join(ch_enum_values) + ch_enum_values.append(f"'{value_name}' = {idx+1}") + ch_enum_values = ', '.join(ch_enum_values) if len(enum_values) <= 127: # Enum8('red' = 1, 'green' = 2, 'black' = 3) - return f"Enum8({ch_enum_values})" + return f'Enum8({ch_enum_values})' else: # Enum16('red' = 1, 'green' = 2, 'black' = 3) - return f"Enum16({ch_enum_values})" - if "text" in mysql_type: - return "String" - if "blob" in mysql_type: - return "String" - if "char" in mysql_type: - return "String" - if "json" in mysql_type: - return "String" - if "decimal" in mysql_type: - return "Float64" - if "float" in mysql_type: - return "Float32" - if "double" in mysql_type: - return "Float64" - if "bigint" in mysql_type: + return f'Enum16({ch_enum_values})' + if 'text' in mysql_type: + return 'String' + if 'blob' in mysql_type: + return 'String' + if 'char' in mysql_type: + return 'String' + if 'json' in mysql_type: + return 'String' + if 'decimal' in mysql_type: + return 'Float64' + if 'float' in mysql_type: + return 'Float32' + if 'double' in mysql_type: + return 'Float64' + if 'bigint' in mysql_type: if is_unsigned: - return "UInt64" - return "Int64" - if "integer" in mysql_type or "int(" in mysql_type: + return 'UInt64' + return 'Int64' + if 'integer' in mysql_type or 'int(' in mysql_type: if is_unsigned: - return "UInt32" - return "Int32" - if "real" in mysql_type: - return "Float64" - if mysql_type.startswith("timestamp"): + return 'UInt32' + return 'Int32' + if 'real' in mysql_type: + return 'Float64' + if mysql_type.startswith('timestamp'): return convert_timestamp_to_datetime64(mysql_type) - if mysql_type.startswith("time"): - return "String" - if "varbinary" in mysql_type: - return "String" - if "binary" in mysql_type: - return "String" - if "set(" in mysql_type: - return "String" + if mysql_type.startswith('time'): + return 'String' + if 'varbinary' in mysql_type: + return 'String' + if 'binary' in mysql_type: + return 'String' + if 'set(' in mysql_type: + return 'String' raise Exception(f'unknown mysql type "{mysql_type}"') def convert_field_type(self, mysql_type, mysql_parameters): mysql_type = mysql_type.lower() mysql_parameters = mysql_parameters.lower() - not_null = "not null" in mysql_parameters + not_null = 'not null' in mysql_parameters clickhouse_type = self.convert_type(mysql_type, mysql_parameters) - if "Tuple" in clickhouse_type: + if 'Tuple' in clickhouse_type: not_null = True if not not_null: - clickhouse_type = f"Nullable({clickhouse_type})" + clickhouse_type = f'Nullable({clickhouse_type})' return clickhouse_type - def convert_table_structure( - self, mysql_structure: TableStructure - ) -> TableStructure: + def convert_table_structure(self, mysql_structure: TableStructure) -> TableStructure: clickhouse_structure = TableStructure() clickhouse_structure.table_name = mysql_structure.table_name clickhouse_structure.if_not_exists = mysql_structure.if_not_exists for field in mysql_structure.fields: - clickhouse_field_type = self.convert_field_type( - field.field_type, field.parameters - ) - clickhouse_structure.fields.append( - TableField( - name=field.name, - field_type=clickhouse_field_type, - ) - ) + clickhouse_field_type = self.convert_field_type(field.field_type, field.parameters) + clickhouse_structure.fields.append(TableField( + name=field.name, + field_type=clickhouse_field_type, + )) clickhouse_structure.primary_keys = mysql_structure.primary_keys clickhouse_structure.preprocess() return clickhouse_structure def convert_records( - self, - mysql_records, - mysql_structure: TableStructure, - clickhouse_structure: TableStructure, - only_primary: bool = False, + self, mysql_records, mysql_structure: TableStructure, clickhouse_structure: TableStructure, + only_primary: bool = False, ): mysql_field_types = [field.field_type for field in mysql_structure.fields] - clickhouse_filed_types = [ - field.field_type for field in clickhouse_structure.fields - ] + clickhouse_filed_types = [field.field_type for field in clickhouse_structure.fields] clickhouse_records = [] for mysql_record in mysql_records: clickhouse_record = self.convert_record( - mysql_record, - mysql_field_types, - clickhouse_filed_types, - mysql_structure, - only_primary, + mysql_record, mysql_field_types, clickhouse_filed_types, mysql_structure, only_primary, ) clickhouse_records.append(clickhouse_record) return clickhouse_records def convert_record( - self, - mysql_record, - mysql_field_types, - clickhouse_field_types, - mysql_structure: TableStructure, - only_primary: bool, + self, mysql_record, mysql_field_types, clickhouse_field_types, mysql_structure: TableStructure, + only_primary: bool, ): clickhouse_record = [] for idx, mysql_field_value in enumerate(mysql_record): @@ -401,50 +379,38 @@ def convert_record( clickhouse_field_value = mysql_field_value mysql_field_type = mysql_field_types[idx] clickhouse_field_type = clickhouse_field_types[idx] - if ( - mysql_field_type.startswith("time") - and "String" in clickhouse_field_type - ): + if mysql_field_type.startswith('time') and 'String' in clickhouse_field_type: clickhouse_field_value = str(mysql_field_value) - if mysql_field_type == "json" and "String" in clickhouse_field_type: + if mysql_field_type == 'json' and 'String' in clickhouse_field_type: if not isinstance(clickhouse_field_value, str): - clickhouse_field_value = json.dumps( - convert_bytes(clickhouse_field_value) - ) + clickhouse_field_value = json.dumps(convert_bytes(clickhouse_field_value)) if clickhouse_field_value is not None: - if "UUID" in clickhouse_field_type: + if 'UUID' in clickhouse_field_type: if len(clickhouse_field_value) == 36: if isinstance(clickhouse_field_value, bytes): - clickhouse_field_value = clickhouse_field_value.decode( - "utf-8" - ) + clickhouse_field_value = clickhouse_field_value.decode('utf-8') clickhouse_field_value = uuid.UUID(clickhouse_field_value).bytes - if "UInt16" in clickhouse_field_type and clickhouse_field_value < 0: + if 'UInt16' in clickhouse_field_type and clickhouse_field_value < 0: clickhouse_field_value = 65536 + clickhouse_field_value - if "UInt8" in clickhouse_field_type and clickhouse_field_value < 0: + if 'UInt8' in clickhouse_field_type and clickhouse_field_value < 0: clickhouse_field_value = 256 + clickhouse_field_value - if ( - "mediumint" in mysql_field_type.lower() - and clickhouse_field_value < 0 - ): + if 'mediumint' in mysql_field_type.lower() and clickhouse_field_value < 0: clickhouse_field_value = 16777216 + clickhouse_field_value - if "UInt32" in clickhouse_field_type and clickhouse_field_value < 0: + if 'UInt32' in clickhouse_field_type and clickhouse_field_value < 0: clickhouse_field_value = 4294967296 + clickhouse_field_value - if "UInt64" in clickhouse_field_type and clickhouse_field_value < 0: - clickhouse_field_value = ( - 18446744073709551616 + clickhouse_field_value - ) + if 'UInt64' in clickhouse_field_type and clickhouse_field_value < 0: + clickhouse_field_value = 18446744073709551616 + clickhouse_field_value - if "String" in clickhouse_field_type and ( - "text" in mysql_field_type or "char" in mysql_field_type + if 'String' in clickhouse_field_type and ( + 'text' in mysql_field_type or 'char' in mysql_field_type ): if isinstance(clickhouse_field_value, bytes): - charset = mysql_structure.charset_python or "utf-8" + charset = mysql_structure.charset_python or 'utf-8' clickhouse_field_value = clickhouse_field_value.decode(charset) - if "set(" in mysql_field_type: + if 'set(' in mysql_field_type: set_values = mysql_structure.fields[idx].additional_data if isinstance(clickhouse_field_value, int): bit_mask = clickhouse_field_value @@ -457,31 +423,29 @@ def convert_record( clickhouse_field_value = [ v for v in set_values if v in clickhouse_field_value ] - clickhouse_field_value = ",".join(clickhouse_field_value) + clickhouse_field_value = ','.join(clickhouse_field_value) - if mysql_field_type.startswith("point"): + if mysql_field_type.startswith('point'): clickhouse_field_value = parse_mysql_point(clickhouse_field_value) - if mysql_field_type.startswith("enum(") and isinstance( - clickhouse_field_value, int - ): + if mysql_field_type.startswith('enum(') and isinstance(clickhouse_field_value, int): enum_values = mysql_structure.fields[idx].additional_data - clickhouse_field_value = enum_values[int(clickhouse_field_value) - 1] + clickhouse_field_value = enum_values[int(clickhouse_field_value)-1] clickhouse_record.append(clickhouse_field_value) return tuple(clickhouse_record) def __basic_validate_query(self, mysql_query): mysql_query = mysql_query.strip() - if mysql_query.endswith(";"): + if mysql_query.endswith(';'): mysql_query = mysql_query[:-1] - if mysql_query.find(";") != -1: - raise Exception("multi-query statement not supported") + if mysql_query.find(';') != -1: + raise Exception('multi-query statement not supported') return mysql_query - + def get_db_and_table_name(self, token, db_name): - if "." in token: - db_name, table_name = token.split(".") + if '.' in token: + db_name, table_name = token.split('.') else: table_name = token db_name = strip_sql_name(db_name) @@ -489,9 +453,9 @@ def get_db_and_table_name(self, token, db_name): if self.db_replicator: if db_name == self.db_replicator.database: db_name = self.db_replicator.target_database - matches_config = self.db_replicator.config.is_database_matches( - db_name - ) and self.db_replicator.config.is_table_matches(table_name) + matches_config = ( + self.db_replicator.config.is_database_matches(db_name) + and self.db_replicator.config.is_table_matches(table_name)) else: matches_config = True @@ -501,21 +465,19 @@ def convert_alter_query(self, mysql_query, db_name): mysql_query = self.__basic_validate_query(mysql_query) tokens = mysql_query.split() - if tokens[0].lower() != "alter": - raise Exception("wrong query") + if tokens[0].lower() != 'alter': + raise Exception('wrong query') - if tokens[1].lower() != "table": - raise Exception("wrong query") + if tokens[1].lower() != 'table': + raise Exception('wrong query') - db_name, table_name, matches_config = self.get_db_and_table_name( - tokens[2], db_name - ) + db_name, table_name, matches_config = self.get_db_and_table_name(tokens[2], db_name) if not matches_config: return - subqueries = " ".join(tokens[3:]) - subqueries = split_high_level(subqueries, ",") + subqueries = ' '.join(tokens[3:]) + subqueries = split_high_level(subqueries, ',') for subquery in subqueries: subquery = subquery.strip() @@ -524,33 +486,33 @@ def convert_alter_query(self, mysql_query, db_name): op_name = tokens[0].lower() tokens = tokens[1:] - if tokens[0].lower() == "column": + if tokens[0].lower() == 'column': tokens = tokens[1:] - if op_name == "add": - if tokens[0].lower() in ("constraint", "index", "foreign", "unique"): + if op_name == 'add': + if tokens[0].lower() in ('constraint', 'index', 'foreign', 'unique'): continue self.__convert_alter_table_add_column(db_name, table_name, tokens) continue - if op_name == "drop": - if tokens[0].lower() in ("constraint", "index", "foreign", "unique"): + if op_name == 'drop': + if tokens[0].lower() in ('constraint', 'index', 'foreign', 'unique'): continue self.__convert_alter_table_drop_column(db_name, table_name, tokens) continue - if op_name == "modify": + if op_name == 'modify': self.__convert_alter_table_modify_column(db_name, table_name, tokens) continue - if op_name == "alter": + if op_name == 'alter': continue - if op_name == "change": + if op_name == 'change': self.__convert_alter_table_change_column(db_name, table_name, tokens) continue - raise Exception(f"operation {op_name} not implement, query: {subquery}") + raise Exception(f'operation {op_name} not implement, query: {subquery}') @classmethod def _tokenize_alter_query(cls, sql_line): @@ -562,8 +524,7 @@ def _tokenize_alter_query(cls, sql_line): # # The order is important: for example, if a word is immediately followed by parentheses, # we want to grab it as a single token. - token_pattern = re.compile( - r""" + token_pattern = re.compile(r''' ( # start capture group for a token `[^`]+`(?:\([^)]*\))? | # backquoted identifier w/ optional parentheses \w+(?:\([^)]*\))? | # a word with optional parentheses @@ -571,9 +532,7 @@ def _tokenize_alter_query(cls, sql_line): "(?:\\"|[^"])*" | # a double-quoted string [^\s]+ # fallback: any sequence of non-whitespace characters ) - """, - re.VERBOSE, - ) + ''', re.VERBOSE) tokens = token_pattern.findall(sql_line) # Now, split the column definition into: @@ -584,29 +543,10 @@ def _tokenize_alter_query(cls, sql_line): # # We define a set of keywords that indicate the start of column options. constraint_keywords = { - "DEFAULT", - "NOT", - "NULL", - "AUTO_INCREMENT", - "PRIMARY", - "UNIQUE", - "COMMENT", - "COLLATE", - "REFERENCES", - "ON", - "CHECK", - "CONSTRAINT", - "AFTER", - "BEFORE", - "GENERATED", - "VIRTUAL", - "STORED", - "FIRST", - "ALWAYS", - "AS", - "IDENTITY", - "INVISIBLE", - "PERSISTED", + "DEFAULT", "NOT", "NULL", "AUTO_INCREMENT", "PRIMARY", "UNIQUE", + "COMMENT", "COLLATE", "REFERENCES", "ON", "CHECK", "CONSTRAINT", + "AFTER", "BEFORE", "GENERATED", "VIRTUAL", "STORED", "FIRST", + "ALWAYS", "AS", "IDENTITY", "INVISIBLE", "PERSISTED", } if not tokens: @@ -614,7 +554,7 @@ def _tokenize_alter_query(cls, sql_line): # The first token is always the column name. column_name = tokens[0] - # Now "merge" tokens after the column name that belong to the type. + # Now “merge” tokens after the column name that belong to the type. # (For many types the type is written as a single token already – # e.g. "VARCHAR(254)" or "NUMERIC(5, 2)", but for types like # "DOUBLE PRECISION" or "INT UNSIGNED" the .split() would produce two tokens.) @@ -635,28 +575,26 @@ def _tokenize_alter_query(cls, sql_line): return [column_name] + param_tokens def __convert_alter_table_add_column(self, db_name, table_name, tokens): - tokens = self._tokenize_alter_query(" ".join(tokens)) + tokens = self._tokenize_alter_query(' '.join(tokens)) if len(tokens) < 2: - raise Exception("wrong tokens count", tokens) + raise Exception('wrong tokens count', tokens) column_after = None column_first = False - if tokens[-2].lower() == "after": + if tokens[-2].lower() == 'after': column_after = strip_sql_name(tokens[-1]) tokens = tokens[:-2] if len(tokens) < 2: - raise Exception("wrong tokens count", tokens) - elif tokens[-1].lower() == "first": + raise Exception('wrong tokens count', tokens) + elif tokens[-1].lower() == 'first': column_first = True column_name = strip_sql_name(tokens[0]) column_type_mysql = tokens[1] - column_type_mysql_parameters = " ".join(tokens[2:]) + column_type_mysql_parameters = ' '.join(tokens[2:]) - column_type_ch = self.convert_field_type( - column_type_mysql, column_type_mysql_parameters - ) + column_type_ch = self.convert_field_type(column_type_mysql, column_type_mysql_parameters) # update table structure if self.db_replicator: @@ -668,7 +606,7 @@ def __convert_alter_table_add_column(self, db_name, table_name, tokens): mysql_table_structure.add_field_first( TableField(name=column_name, field_type=column_type_mysql) ) - + ch_table_structure.add_field_first( TableField(name=column_name, field_type=column_type_ch) ) @@ -686,18 +624,18 @@ def __convert_alter_table_add_column(self, db_name, table_name, tokens): column_after, ) - query = f"ALTER TABLE `{db_name}`.`{table_name}` ADD COLUMN `{column_name}` {column_type_ch}" + query = f'ALTER TABLE `{db_name}`.`{table_name}` ADD COLUMN `{column_name}` {column_type_ch}' if column_first: - query += " FIRST" + query += ' FIRST' else: - query += f" AFTER {column_after}" + query += f' AFTER {column_after}' if self.db_replicator: self.db_replicator.clickhouse_api.execute_command(query) def __convert_alter_table_drop_column(self, db_name, table_name, tokens): if len(tokens) != 1: - raise Exception("wrong tokens count", tokens) + raise Exception('wrong tokens count', tokens) column_name = strip_sql_name(tokens[0]) @@ -710,21 +648,19 @@ def __convert_alter_table_drop_column(self, db_name, table_name, tokens): mysql_table_structure.remove_field(field_name=column_name) ch_table_structure.remove_field(field_name=column_name) - query = f"ALTER TABLE `{db_name}`.`{table_name}` DROP COLUMN {column_name}" + query = f'ALTER TABLE `{db_name}`.`{table_name}` DROP COLUMN {column_name}' if self.db_replicator: self.db_replicator.clickhouse_api.execute_command(query) def __convert_alter_table_modify_column(self, db_name, table_name, tokens): if len(tokens) < 2: - raise Exception("wrong tokens count", tokens) + raise Exception('wrong tokens count', tokens) column_name = strip_sql_name(tokens[0]) column_type_mysql = tokens[1] - column_type_mysql_parameters = " ".join(tokens[2:]) + column_type_mysql_parameters = ' '.join(tokens[2:]) - column_type_ch = self.convert_field_type( - column_type_mysql, column_type_mysql_parameters - ) + column_type_ch = self.convert_field_type(column_type_mysql, column_type_mysql_parameters) # update table structure if self.db_replicator: @@ -740,22 +676,20 @@ def __convert_alter_table_modify_column(self, db_name, table_name, tokens): TableField(name=column_name, field_type=column_type_ch), ) - query = f"ALTER TABLE `{db_name}`.`{table_name}` MODIFY COLUMN `{column_name}` {column_type_ch}" + query = f'ALTER TABLE `{db_name}`.`{table_name}` MODIFY COLUMN `{column_name}` {column_type_ch}' if self.db_replicator: self.db_replicator.clickhouse_api.execute_command(query) def __convert_alter_table_change_column(self, db_name, table_name, tokens): if len(tokens) < 3: - raise Exception("wrong tokens count", tokens) + raise Exception('wrong tokens count', tokens) column_name = strip_sql_name(tokens[0]) new_column_name = strip_sql_name(tokens[1]) column_type_mysql = tokens[2] - column_type_mysql_parameters = " ".join(tokens[3:]) + column_type_mysql_parameters = ' '.join(tokens[3:]) - column_type_ch = self.convert_field_type( - column_type_mysql, column_type_mysql_parameters - ) + column_type_ch = self.convert_field_type(column_type_mysql, column_type_mysql_parameters) # update table structure if self.db_replicator: @@ -763,11 +697,10 @@ def __convert_alter_table_change_column(self, db_name, table_name, tokens): mysql_table_structure: TableStructure = table_structure[0] ch_table_structure: TableStructure = table_structure[1] - current_column_type_ch = ch_table_structure.get_field( - column_name - ).field_type + current_column_type_ch = ch_table_structure.get_field(column_name).field_type if current_column_type_ch != column_type_ch: + mysql_table_structure.update_field( TableField(name=column_name, field_type=column_type_mysql), ) @@ -776,7 +709,7 @@ def __convert_alter_table_change_column(self, db_name, table_name, tokens): TableField(name=column_name, field_type=column_type_ch), ) - query = f"ALTER TABLE `{db_name}`.`{table_name}` MODIFY COLUMN {column_name} {column_type_ch}" + query = f'ALTER TABLE `{db_name}`.`{table_name}` MODIFY COLUMN {column_name} {column_type_ch}' self.db_replicator.clickhouse_api.execute_command(query) if column_name != new_column_name: @@ -786,119 +719,109 @@ def __convert_alter_table_change_column(self, db_name, table_name, tokens): curr_field_mysql.name = new_column_name curr_field_clickhouse.name = new_column_name - query = f"ALTER TABLE `{db_name}`.`{table_name}` RENAME COLUMN {column_name} TO {new_column_name}" + query = f'ALTER TABLE `{db_name}`.`{table_name}` RENAME COLUMN {column_name} TO {new_column_name}' self.db_replicator.clickhouse_api.execute_command(query) - def parse_create_table_query( - self, mysql_query - ) -> tuple[TableStructure, TableStructure]: + def parse_create_table_query(self, mysql_query) -> tuple[TableStructure, TableStructure]: mysql_table_structure = self.parse_mysql_table_structure(mysql_query) ch_table_structure = self.convert_table_structure(mysql_table_structure) return mysql_table_structure, ch_table_structure def convert_drop_table_query(self, mysql_query): - raise Exception("not implement") + raise Exception('not implement') def _strip_comments(self, create_statement): pattern = r'\bCOMMENT(?:\s*=\s*|\s+)([\'"])(?:\\.|[^\\])*?\1' - return re.sub(pattern, "", create_statement, flags=re.IGNORECASE) + return re.sub(pattern, '', create_statement, flags=re.IGNORECASE) def parse_mysql_table_structure(self, create_statement, required_table_name=None): create_statement = self._strip_comments(create_statement) structure = TableStructure() - tokens = sqlparse.parse(create_statement.replace("\n", " ").strip())[0].tokens + tokens = sqlparse.parse(create_statement.replace('\n', ' ').strip())[0].tokens tokens = [t for t in tokens if not t.is_whitespace and not t.is_newline] # remove "IF NOT EXISTS" - if ( - len(tokens) > 5 - and tokens[0].normalized.upper() == "CREATE" - and tokens[1].normalized.upper() == "TABLE" - and tokens[2].normalized.upper() == "IF" - and tokens[3].normalized.upper() == "NOT" - and tokens[4].normalized.upper() == "EXISTS" - ): + if (len(tokens) > 5 and + tokens[0].normalized.upper() == 'CREATE' and + tokens[1].normalized.upper() == 'TABLE' and + tokens[2].normalized.upper() == 'IF' and + tokens[3].normalized.upper() == 'NOT' and + tokens[4].normalized.upper() == 'EXISTS'): del tokens[2:5] # Remove the 'IF', 'NOT', 'EXISTS' tokens structure.if_not_exists = True if tokens[0].ttype != sqlparse.tokens.DDL: - raise Exception("wrong create statement", create_statement) - if tokens[0].normalized.lower() != "create": - raise Exception("wrong create statement", create_statement) + raise Exception('wrong create statement', create_statement) + if tokens[0].normalized.lower() != 'create': + raise Exception('wrong create statement', create_statement) if tokens[1].ttype != sqlparse.tokens.Keyword: - raise Exception("wrong create statement", create_statement) + raise Exception('wrong create statement', create_statement) if not isinstance(tokens[2], sqlparse.sql.Identifier): - raise Exception("wrong create statement", create_statement) + raise Exception('wrong create statement', create_statement) # get_real_name() returns the table name if the token is in the # style `.` structure.table_name = strip_sql_name(tokens[2].get_real_name()) if not isinstance(tokens[3], sqlparse.sql.Parenthesis): - raise Exception("wrong create statement", create_statement) + raise Exception('wrong create statement', create_statement) - # print(' --- processing statement:\n', create_statement, '\n') + #print(' --- processing statement:\n', create_statement, '\n') inner_tokens = tokens[3].tokens - inner_tokens = "".join([str(t) for t in inner_tokens[1:-1]]).strip() - inner_tokens = split_high_level(inner_tokens, ",") + inner_tokens = ''.join([str(t) for t in inner_tokens[1:-1]]).strip() + inner_tokens = split_high_level(inner_tokens, ',') - prev_token = "" - prev_prev_token = "" + prev_token = '' + prev_prev_token = '' for line in tokens[4:]: curr_token = line.value - if prev_token == "=" and prev_prev_token.lower() == "charset": + if prev_token == '=' and prev_prev_token.lower() == 'charset': structure.charset = curr_token prev_prev_token = prev_token prev_token = curr_token - structure.charset_python = "utf-8" + structure.charset_python = 'utf-8' if structure.charset: structure.charset_python = CHARSET_MYSQL_TO_PYTHON[structure.charset] - prev_line = "" + prev_line = '' for line in inner_tokens: line = prev_line + line - q_count = line.count("`") + q_count = line.count('`') if q_count % 2 == 1: prev_line = line continue - prev_line = "" + prev_line = '' - if line.lower().startswith("unique key"): + if line.lower().startswith('unique key'): continue - if line.lower().startswith("key"): + if line.lower().startswith('key'): continue - if line.lower().startswith("constraint"): + if line.lower().startswith('constraint'): continue - if line.lower().startswith("fulltext"): + if line.lower().startswith('fulltext'): continue - if line.lower().startswith("spatial"): + if line.lower().startswith('spatial'): continue - if line.lower().startswith("primary key"): + if line.lower().startswith('primary key'): # Define identifier to match column names, handling backticks and unquoted names - identifier = ( - Suppress("`") + Word(alphas + alphanums + "_") + Suppress("`") - ) | Word(alphas + alphanums + "_") + identifier = (Suppress('`') + Word(alphas + alphanums + '_') + Suppress('`')) | Word( + alphas + alphanums + '_') # Build the parsing pattern - pattern = ( - CaselessKeyword("PRIMARY") - + CaselessKeyword("KEY") - + Suppress("(") - + delimitedList(identifier)("column_names") - + Suppress(")") - ) + pattern = CaselessKeyword('PRIMARY') + CaselessKeyword('KEY') + Suppress('(') + delimitedList( + identifier)('column_names') + Suppress(')') # Parse the line result = pattern.parseString(line) # Extract and process the primary key column names - primary_keys = [strip_sql_name(name) for name in result["column_names"]] + primary_keys = [strip_sql_name(name) for name in result['column_names']] structure.primary_keys = primary_keys @@ -907,12 +830,12 @@ def parse_mysql_table_structure(self, create_statement, required_table_name=None line = line.strip() # print(" === processing line", line) - if line.startswith("`"): - end_pos = line.find("`", 1) + if line.startswith('`'): + end_pos = line.find('`', 1) field_name = line[1:end_pos] line = line[end_pos + 1 :].strip() # Don't split by space for enum and set types that might contain spaces - if line.lower().startswith("enum(") or line.lower().startswith("set("): + if line.lower().startswith('enum(') or line.lower().startswith('set('): # Find the end of the enum/set definition (closing parenthesis) open_parens = 0 in_quotes = False @@ -926,9 +849,9 @@ def parse_mysql_table_structure(self, create_statement, required_table_name=None quote_char = char elif char == quote_char: in_quotes = False - elif char == "(" and not in_quotes: + elif char == '(' and not in_quotes: open_parens += 1 - elif char == ")" and not in_quotes: + elif char == ')' and not in_quotes: open_parens -= 1 if open_parens == 0: end_pos = i + 1 @@ -939,28 +862,26 @@ def parse_mysql_table_structure(self, create_statement, required_table_name=None field_parameters = line[end_pos:].strip() else: # Fallback to original behavior if we can't find the end - definition = line.split(" ") + definition = line.split(' ') field_type = definition[0] field_parameters = ( - " ".join(definition[1:]) if len(definition) > 1 else "" + ' '.join(definition[1:]) if len(definition) > 1 else '' ) else: - definition = line.split(" ") + definition = line.split(' ') field_type = definition[0] field_parameters = ( - " ".join(definition[1:]) if len(definition) > 1 else "" + ' '.join(definition[1:]) if len(definition) > 1 else '' ) else: - definition = line.split(" ") + definition = line.split(' ') field_name = strip_sql_name(definition[0]) definition = definition[1:] - - # Handle enum and set types with spaces for non-backtick field names if definition and ( - definition[0].lower().startswith("enum(") - or definition[0].lower().startswith("set(") + definition[0].lower().startswith('enum(') + or definition[0].lower().startswith('set(') ): - line = " ".join(definition) + line = ' '.join(definition) # Find the end of the enum/set definition (closing parenthesis) open_parens = 0 in_quotes = False @@ -974,9 +895,9 @@ def parse_mysql_table_structure(self, create_statement, required_table_name=None quote_char = char elif char == quote_char: in_quotes = False - elif char == "(" and not in_quotes: + elif char == '(' and not in_quotes: open_parens += 1 - elif char == ")" and not in_quotes: + elif char == ')' and not in_quotes: open_parens -= 1 if open_parens == 0: end_pos = i + 1 @@ -989,57 +910,56 @@ def parse_mysql_table_structure(self, create_statement, required_table_name=None # Fallback to original behavior field_type = definition[0] field_parameters = ( - " ".join(definition[1:]) if len(definition) > 1 else "" + ' '.join(definition[1:]) if len(definition) > 1 else '' ) else: field_type = definition[0] field_parameters = ( - " ".join(definition[1:]) if len(definition) > 1 else "" - ) + ' '.join(definition[1:]) if len(definition) > 1 else '' + ) + field_type = definition[0] + field_parameters = '' + if len(definition) > 1: + field_parameters = ' '.join(definition[1:]) additional_data = None - if "set(" in field_type.lower(): - vals = field_type[len("set(") :] - close_pos = vals.find(")") + if 'set(' in field_type.lower(): + vals = field_type[len('set('):] + close_pos = vals.find(')') vals = vals[:close_pos] - vals = vals.split(",") - + vals = vals.split(',') def vstrip(e): if not e: return e - if e[0] in "\"'": + if e[0] in '"\'': return e[1:-1] return e - vals = [vstrip(v) for v in vals] additional_data = vals - if field_type.lower().startswith("enum("): + if field_type.lower().startswith('enum('): additional_data = parse_mysql_enum(field_type) - structure.fields.append( - TableField( - name=field_name, - field_type=field_type, - parameters=field_parameters, - additional_data=additional_data, - ) - ) - # print(' ---- params:', field_parameters) + structure.fields.append(TableField( + name=field_name, + field_type=field_type, + parameters=field_parameters, + additional_data=additional_data, + )) + #print(' ---- params:', field_parameters) + if not structure.primary_keys: for field in structure.fields: - if "primary key" in field.parameters.lower(): + if 'primary key' in field.parameters.lower(): structure.primary_keys.append(field.name) if not structure.primary_keys: - if structure.has_field("id"): - structure.primary_keys = ["id"] + if structure.has_field('id'): + structure.primary_keys = ['id'] if not structure.primary_keys: - raise Exception( - f"No primary key for table {structure.table_name}, {create_statement}" - ) + raise Exception(f'No primary key for table {structure.table_name}, {create_statement}') structure.preprocess() return structure diff --git a/mysql_ch_replicator/converter_enum_parser.py b/mysql_ch_replicator/converter_enum_parser.py index b4f497d..8694d5f 100644 --- a/mysql_ch_replicator/converter_enum_parser.py +++ b/mysql_ch_replicator/converter_enum_parser.py @@ -46,7 +46,6 @@ def _extract_parenthesized_content(s, start_index): ', " or `) and also to skip over escape sequences in single/double quotes. (Backticks do not process backslash escapes.) """ - if s[start_index] != "(": raise ValueError("Expected '(' at position {}".format(start_index)) depth = 1 @@ -95,15 +94,12 @@ def _extract_parenthesized_content(s, start_index): i += 1 if depth == 0: # Return the substring inside (excluding the outer parentheses) - result = s[content_start : i - 1] - return result, i + return s[content_start : i - 1], i continue else: i += 1 - raise ValueError( - "Unbalanced parentheses in enum definition at position {} in {!r}".format(i, s) - ) + raise ValueError("Unbalanced parentheses in enum definition") def _parse_enum_values(content): @@ -181,7 +177,7 @@ def _parse_enum_values(content): # Skip whitespace after the literal. while i < len(content) and content[i].isspace(): i += 1 - # If there's a comma, skip it; otherwise, we must be at the end. + # If there’s a comma, skip it; otherwise, we must be at the end. if i < len(content): if content[i] == ",": i += 1 From 883f4268491c3b71bb0779177236c772121d6606 Mon Sep 17 00:00:00 2001 From: Jared Dobson Date: Thu, 6 Mar 2025 13:55:36 -0700 Subject: [PATCH 04/16] Fix --- mysql_ch_replicator/converter_enum_parser.py | 59 +++++++++----------- 1 file changed, 27 insertions(+), 32 deletions(-) diff --git a/mysql_ch_replicator/converter_enum_parser.py b/mysql_ch_replicator/converter_enum_parser.py index 8694d5f..c8d29df 100644 --- a/mysql_ch_replicator/converter_enum_parser.py +++ b/mysql_ch_replicator/converter_enum_parser.py @@ -1,3 +1,5 @@ + + def parse_mysql_enum(enum_definition): """ Accepts a MySQL ENUM definition string (case–insensitive), @@ -20,7 +22,7 @@ def parse_mysql_enum(enum_definition): raise ValueError("String does not start with 'enum'") # Find the first opening parenthesis. - pos = s.find("(") + pos = s.find('(') if pos == -1: raise ValueError("Missing '(' in the enum definition") @@ -46,7 +48,7 @@ def _extract_parenthesized_content(s, start_index): ', " or `) and also to skip over escape sequences in single/double quotes. (Backticks do not process backslash escapes.) """ - if s[start_index] != "(": + if s[start_index] != '(': raise ValueError("Expected '(' at position {}".format(start_index)) depth = 1 i = start_index + 1 @@ -54,14 +56,14 @@ def _extract_parenthesized_content(s, start_index): in_quote = None # will be set to a quoting character when inside a quoted literal # Allow these quote characters. - allowed_quotes = ("'", '"', "`") + allowed_quotes = ("'", '"', '`') while i < len(s): c = s[i] if in_quote: # Inside a quoted literal. if in_quote in ("'", '"'): - if c == "\\": + if c == '\\': # Skip the escape character and the next character. i += 2 continue @@ -85,16 +87,16 @@ def _extract_parenthesized_content(s, start_index): in_quote = c i += 1 continue - elif c == "(": + elif c == '(': depth += 1 i += 1 continue - elif c == ")": + elif c == ')': depth -= 1 i += 1 if depth == 0: # Return the substring inside (excluding the outer parentheses) - return s[content_start : i - 1], i + return s[content_start:i - 1], i continue else: i += 1 @@ -114,7 +116,7 @@ def _parse_enum_values(content): """ values = [] i = 0 - allowed_quotes = ("'", '"', "`") + allowed_quotes = ("'", '"', '`') while i < len(content): # Skip any whitespace. while i < len(content) and content[i].isspace(): @@ -123,11 +125,7 @@ def _parse_enum_values(content): break # The next non–whitespace character must be one of the allowed quotes. if content[i] not in allowed_quotes: - raise ValueError( - "Expected starting quote for enum value at position {} in {!r}".format( - i, content - ) - ) + raise ValueError("Expected starting quote for enum value at position {} in {!r}".format(i, content)) quote = content[i] i += 1 # skip the opening quote @@ -135,26 +133,26 @@ def _parse_enum_values(content): while i < len(content): c = content[i] # For single- and double–quotes, process backslash escapes. - if quote in ("'", '"') and c == "\\": + if quote in ("'", '"') and c == '\\': if i + 1 < len(content): next_char = content[i + 1] # Mapping for common escapes. (For the quote character, map it to itself.) escapes = { - "0": "\0", - "b": "\b", - "n": "\n", - "r": "\r", - "t": "\t", - "Z": "\x1a", - "\\": "\\", - quote: quote, + '0': '\0', + 'b': '\b', + 'n': '\n', + 'r': '\r', + 't': '\t', + 'Z': '\x1a', + '\\': '\\', + quote: quote } literal_chars.append(escapes.get(next_char, next_char)) i += 2 continue else: # Trailing backslash – treat it as literal. - literal_chars.append("\\") + literal_chars.append('\\') i += 1 continue elif c == quote: @@ -171,7 +169,7 @@ def _parse_enum_values(content): literal_chars.append(c) i += 1 # Finished reading one literal; join the characters. - value = "".join(literal_chars) + value = ''.join(literal_chars) values.append(value) # Skip whitespace after the literal. @@ -179,19 +177,16 @@ def _parse_enum_values(content): i += 1 # If there’s a comma, skip it; otherwise, we must be at the end. if i < len(content): - if content[i] == ",": + if content[i] == ',': i += 1 else: - raise ValueError( - "Expected comma between enum values at position {} in {!r}".format( - i, content - ) - ) + raise ValueError("Expected comma between enum values at position {} in {!r}" + .format(i, content)) return values # --- For testing purposes --- -if __name__ == "__main__": +if __name__ == '__main__': tests = [ "enum('point','qwe','def')", "ENUM('asd', 'qwe', 'def')", @@ -208,4 +203,4 @@ def _parse_enum_values(content): result = parse_mysql_enum(t) print("Input: {}\nParsed: {}\n".format(t, result)) except Exception as e: - print("Error parsing {}: {}\n".format(t, e)) + print("Error parsing {}: {}\n".format(t, e)) \ No newline at end of file From 727c19f7e4797d86dbad8b5e997f5f9983938e57 Mon Sep 17 00:00:00 2001 From: Jared Dobson Date: Thu, 6 Mar 2025 13:56:28 -0700 Subject: [PATCH 05/16] Fix --- mysql_ch_replicator/converter_enum_parser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mysql_ch_replicator/converter_enum_parser.py b/mysql_ch_replicator/converter_enum_parser.py index c8d29df..92192ea 100644 --- a/mysql_ch_replicator/converter_enum_parser.py +++ b/mysql_ch_replicator/converter_enum_parser.py @@ -203,4 +203,4 @@ def _parse_enum_values(content): result = parse_mysql_enum(t) print("Input: {}\nParsed: {}\n".format(t, result)) except Exception as e: - print("Error parsing {}: {}\n".format(t, e)) \ No newline at end of file + print("Error parsing {}: {}\n".format(t, e)) From 017e48e65abae0d0019e1b57a9952404c2505485 Mon Sep 17 00:00:00 2001 From: Jared Dobson Date: Fri, 7 Mar 2025 10:57:09 -0700 Subject: [PATCH 06/16] Fix bug --- mysql_ch_replicator/converter.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/mysql_ch_replicator/converter.py b/mysql_ch_replicator/converter.py index cf9b1f6..53cc1e7 100644 --- a/mysql_ch_replicator/converter.py +++ b/mysql_ch_replicator/converter.py @@ -554,7 +554,7 @@ def _tokenize_alter_query(cls, sql_line): # The first token is always the column name. column_name = tokens[0] - # Now “merge” tokens after the column name that belong to the type. + # Now "merge" tokens after the column name that belong to the type. # (For many types the type is written as a single token already – # e.g. "VARCHAR(254)" or "NUMERIC(5, 2)", but for types like # "DOUBLE PRECISION" or "INT UNSIGNED" the .split() would produce two tokens.) @@ -872,7 +872,7 @@ def parse_mysql_table_structure(self, create_statement, required_table_name=None field_type = definition[0] field_parameters = ( ' '.join(definition[1:]) if len(definition) > 1 else '' - ) + ) else: definition = line.split(' ') field_name = strip_sql_name(definition[0]) @@ -917,10 +917,6 @@ def parse_mysql_table_structure(self, create_statement, required_table_name=None field_parameters = ( ' '.join(definition[1:]) if len(definition) > 1 else '' ) - field_type = definition[0] - field_parameters = '' - if len(definition) > 1: - field_parameters = ' '.join(definition[1:]) additional_data = None if 'set(' in field_type.lower(): From 53e4018c6700bace02b9143a94153311c9b956ce Mon Sep 17 00:00:00 2001 From: Jared Dobson Date: Mon, 10 Mar 2025 12:42:42 -0600 Subject: [PATCH 07/16] Add in enum translation fix --- mysql_ch_replicator/clickhouse_api.py | 41 ++++++ mysql_ch_replicator/converter.py | 123 +++------------- mysql_ch_replicator/db_replicator.py | 11 +- mysql_ch_replicator/enum/__init__.py | 21 +++ mysql_ch_replicator/enum/converter.py | 72 ++++++++++ mysql_ch_replicator/enum/ddl_parser.py | 134 ++++++++++++++++++ .../parser.py} | 34 ++--- mysql_ch_replicator/enum/utils.py | 99 +++++++++++++ test_mysql_ch_replicator.py | 56 ++++++++ 9 files changed, 463 insertions(+), 128 deletions(-) create mode 100644 mysql_ch_replicator/enum/__init__.py create mode 100644 mysql_ch_replicator/enum/converter.py create mode 100644 mysql_ch_replicator/enum/ddl_parser.py rename mysql_ch_replicator/{converter_enum_parser.py => enum/parser.py} (90%) create mode 100644 mysql_ch_replicator/enum/utils.py diff --git a/mysql_ch_replicator/clickhouse_api.py b/mysql_ch_replicator/clickhouse_api.py index 825986d..2dca0bf 100644 --- a/mysql_ch_replicator/clickhouse_api.py +++ b/mysql_ch_replicator/clickhouse_api.py @@ -117,6 +117,37 @@ def get_databases(self): database_list = [row[0] for row in databases] return database_list + def table_exists(self, table_name): + """Check if a specific table exists in the current database""" + if self.database not in self.get_databases(): + return False + + query = f"EXISTS TABLE `{self.database}`.`{table_name}`" + result = self.client.query(query) + if result.result_rows and result.result_rows[0][0] == 1: + return True + return False + + def validate_database_schema(self, expected_tables): + """Validates that all expected tables exist in the database and returns missing tables""" + if self.database not in self.get_databases(): + logger.warning(f"Database {self.database} does not exist") + return False, expected_tables + + existing_tables = self.get_tables() + missing_tables = [table for table in expected_tables if table not in existing_tables] + + if missing_tables: + # Log with a count for large numbers of missing tables + if len(missing_tables) > 10: + sample_tables = ", ".join(missing_tables[:10]) + logger.warning(f"Missing {len(missing_tables)} tables in {self.database}. First 10: {sample_tables}...") + else: + logger.warning(f"Missing tables in {self.database}: {', '.join(missing_tables)}") + return False, missing_tables + + return True, [] + def execute_command(self, query): for attempt in range(ClickhouseApi.MAX_RETRIES): try: @@ -127,6 +158,16 @@ def execute_command(self, query): if attempt == ClickhouseApi.MAX_RETRIES - 1: raise e time.sleep(ClickhouseApi.RETRY_INTERVAL) + except clickhouse_connect.driver.exceptions.DatabaseError as e: + # Handle TABLE_ALREADY_EXISTS errors + if "Table already exists" in str(e) or "TABLE_ALREADY_EXISTS" in str(e): + logger.warning(f"Table already exists, continuing: {e}") + break + else: + logger.error(f'error executing command {query}: {e}', exc_info=e) + if attempt == ClickhouseApi.MAX_RETRIES - 1: + raise e + time.sleep(ClickhouseApi.RETRY_INTERVAL) def recreate_database(self): self.execute_command(f'DROP DATABASE IF EXISTS `{self.database}`') diff --git a/mysql_ch_replicator/converter.py b/mysql_ch_replicator/converter.py index 53cc1e7..1f17e72 100644 --- a/mysql_ch_replicator/converter.py +++ b/mysql_ch_replicator/converter.py @@ -6,7 +6,11 @@ from pyparsing import Suppress, CaselessKeyword, Word, alphas, alphanums, delimitedList from .table_structure import TableStructure, TableField -from .converter_enum_parser import parse_mysql_enum +from .enum import ( + parse_mysql_enum, EnumConverter, + parse_enum_or_set_field, + extract_enum_or_set_values +) CHARSET_MYSQL_TO_PYTHON = { @@ -282,7 +286,7 @@ def convert_type(self, mysql_type, parameters): enum_values = parse_mysql_enum(mysql_type) ch_enum_values = [] for idx, value_name in enumerate(enum_values): - ch_enum_values.append(f"'{value_name}' = {idx+1}") + ch_enum_values.append(f"'{value_name.lower()}' = {idx+1}") ch_enum_values = ', '.join(ch_enum_values) if len(enum_values) <= 127: # Enum8('red' = 1, 'green' = 2, 'black' = 3) @@ -428,9 +432,15 @@ def convert_record( if mysql_field_type.startswith('point'): clickhouse_field_value = parse_mysql_point(clickhouse_field_value) - if mysql_field_type.startswith('enum(') and isinstance(clickhouse_field_value, int): + if mysql_field_type.startswith('enum('): enum_values = mysql_structure.fields[idx].additional_data - clickhouse_field_value = enum_values[int(clickhouse_field_value)-1] + field_name = mysql_structure.fields[idx].name if idx < len(mysql_structure.fields) else "unknown" + + clickhouse_field_value = EnumConverter.convert_mysql_to_clickhouse_enum( + clickhouse_field_value, + enum_values, + field_name + ) clickhouse_record.append(clickhouse_field_value) return tuple(clickhouse_record) @@ -834,107 +844,16 @@ def parse_mysql_table_structure(self, create_statement, required_table_name=None end_pos = line.find('`', 1) field_name = line[1:end_pos] line = line[end_pos + 1 :].strip() - # Don't split by space for enum and set types that might contain spaces - if line.lower().startswith('enum(') or line.lower().startswith('set('): - # Find the end of the enum/set definition (closing parenthesis) - open_parens = 0 - in_quotes = False - quote_char = None - end_pos = -1 - - for i, char in enumerate(line): - if char in "'\"" and (i == 0 or line[i - 1] != "\\"): - if not in_quotes: - in_quotes = True - quote_char = char - elif char == quote_char: - in_quotes = False - elif char == '(' and not in_quotes: - open_parens += 1 - elif char == ')' and not in_quotes: - open_parens -= 1 - if open_parens == 0: - end_pos = i + 1 - break - - if end_pos > 0: - field_type = line[:end_pos] - field_parameters = line[end_pos:].strip() - else: - # Fallback to original behavior if we can't find the end - definition = line.split(' ') - field_type = definition[0] - field_parameters = ( - ' '.join(definition[1:]) if len(definition) > 1 else '' - ) - else: - definition = line.split(' ') - field_type = definition[0] - field_parameters = ( - ' '.join(definition[1:]) if len(definition) > 1 else '' - ) + # Use our new enum parsing utilities + field_name, field_type, field_parameters = parse_enum_or_set_field(line, field_name, is_backtick_quoted=True) else: definition = line.split(' ') field_name = strip_sql_name(definition[0]) - definition = definition[1:] - if definition and ( - definition[0].lower().startswith('enum(') - or definition[0].lower().startswith('set(') - ): - line = ' '.join(definition) - # Find the end of the enum/set definition (closing parenthesis) - open_parens = 0 - in_quotes = False - quote_char = None - end_pos = -1 - - for i, char in enumerate(line): - if char in "'\"" and (i == 0 or line[i - 1] != "\\"): - if not in_quotes: - in_quotes = True - quote_char = char - elif char == quote_char: - in_quotes = False - elif char == '(' and not in_quotes: - open_parens += 1 - elif char == ')' and not in_quotes: - open_parens -= 1 - if open_parens == 0: - end_pos = i + 1 - break - - if end_pos > 0: - field_type = line[:end_pos] - field_parameters = line[end_pos:].strip() - else: - # Fallback to original behavior - field_type = definition[0] - field_parameters = ( - ' '.join(definition[1:]) if len(definition) > 1 else '' - ) - else: - field_type = definition[0] - field_parameters = ( - ' '.join(definition[1:]) if len(definition) > 1 else '' - ) - - additional_data = None - if 'set(' in field_type.lower(): - vals = field_type[len('set('):] - close_pos = vals.find(')') - vals = vals[:close_pos] - vals = vals.split(',') - def vstrip(e): - if not e: - return e - if e[0] in '"\'': - return e[1:-1] - return e - vals = [vstrip(v) for v in vals] - additional_data = vals - - if field_type.lower().startswith('enum('): - additional_data = parse_mysql_enum(field_type) + # Use our new enum parsing utilities + field_name, field_type, field_parameters = parse_enum_or_set_field(line, field_name, is_backtick_quoted=False) + + # Extract additional data for enum and set types + additional_data = extract_enum_or_set_values(field_type, from_parser_func=parse_mysql_enum) structure.fields.append(TableField( name=field_name, diff --git a/mysql_ch_replicator/db_replicator.py b/mysql_ch_replicator/db_replicator.py index 9a1ac92..ec29430 100644 --- a/mysql_ch_replicator/db_replicator.py +++ b/mysql_ch_replicator/db_replicator.py @@ -182,10 +182,9 @@ def run(self): # ensure target database still exists if self.target_database not in self.clickhouse_api.get_databases(): logger.warning(f'database {self.target_database} missing in CH') - if self.initial_only: - logger.warning('will run replication from scratch') - self.state.remove() - self.state = self.create_state() + logger.warning('will run replication from scratch') + self.state.remove() + self.state = self.create_state() if self.state.status == Status.RUNNING_REALTIME_REPLICATION: self.run_realtime_replication() @@ -227,6 +226,10 @@ def create_initial_structure_table(self, table_name): ) self.validate_mysql_structure(mysql_structure) clickhouse_structure = self.converter.convert_table_structure(mysql_structure) + + # Always set if_not_exists to True to prevent errors when tables already exist + clickhouse_structure.if_not_exists = True + self.state.tables_structure[table_name] = (mysql_structure, clickhouse_structure) indexes = self.config.get_indexes(self.database, table_name) self.clickhouse_api.create_table(clickhouse_structure, additional_indexes=indexes) diff --git a/mysql_ch_replicator/enum/__init__.py b/mysql_ch_replicator/enum/__init__.py new file mode 100644 index 0000000..9c36c98 --- /dev/null +++ b/mysql_ch_replicator/enum/__init__.py @@ -0,0 +1,21 @@ +from .parser import parse_mysql_enum, is_enum_type +from .converter import EnumConverter +from .utils import find_enum_definition_end, extract_field_components +from .ddl_parser import ( + find_enum_or_set_definition_end, + parse_enum_or_set_field, + extract_enum_or_set_values, + strip_value +) + +__all__ = [ + 'parse_mysql_enum', + 'is_enum_type', + 'EnumConverter', + 'find_enum_definition_end', + 'extract_field_components', + 'find_enum_or_set_definition_end', + 'parse_enum_or_set_field', + 'extract_enum_or_set_values', + 'strip_value' +] diff --git a/mysql_ch_replicator/enum/converter.py b/mysql_ch_replicator/enum/converter.py new file mode 100644 index 0000000..51549b7 --- /dev/null +++ b/mysql_ch_replicator/enum/converter.py @@ -0,0 +1,72 @@ +from typing import List, Union, Optional, Any +from logging import getLogger + +# Create a single module-level logger +logger = getLogger(__name__) + +class EnumConverter: + """Class to handle conversion of enum values between MySQL and ClickHouse""" + + @staticmethod + def convert_mysql_to_clickhouse_enum( + value: Any, + enum_values: List[str], + field_name: str = "unknown" + ) -> Optional[Union[str, int]]: + """ + Convert a MySQL enum value to the appropriate ClickHouse representation + + Args: + value: The MySQL enum value (can be int, str, None) + enum_values: List of possible enum string values + field_name: Name of the field (for better error reporting) + + Returns: + The properly converted enum value for ClickHouse + """ + # Handle NULL values + if value is None: + return None + + # Handle integer values (index-based) + if isinstance(value, int): + # Check if the value is 0 + if value == 0: + # Return 0 as-is - let ClickHouse handle it according to the field's nullability + logger.debug(f"ENUM CONVERSION: Found enum index 0 for field '{field_name}'. Keeping as 0.") + return 0 + + # Validate that the enum index is within range + if value < 1 or value > len(enum_values): + # Log the issue + logger.error(f"ENUM CONVERSION: Invalid enum index {value} for field '{field_name}' " + f"with values {enum_values}") + # Return the value unchanged + return value + else: + # Convert to the string representation (lowercase to match our new convention) + return enum_values[int(value)-1].lower() + + # Handle string values + elif isinstance(value, str): + # Validate that the string value exists in enum values + # First check case-sensitive, then case-insensitive + if value in enum_values: + return value.lower() + + # Try case-insensitive match + lowercase_enum_values = [v.lower() for v in enum_values] + if value.lower() in lowercase_enum_values: + return value.lower() + + # Value not found in enum values + logger.error(f"ENUM CONVERSION: Invalid enum value '{value}' not in {enum_values} " + f"for field '{field_name}'") + # Return the value unchanged + return value + + # Handle any other unexpected types + else: + logger.error(f"ENUM CONVERSION: Unexpected type {type(value)} for enum field '{field_name}'") + # Return the value unchanged + return value \ No newline at end of file diff --git a/mysql_ch_replicator/enum/ddl_parser.py b/mysql_ch_replicator/enum/ddl_parser.py new file mode 100644 index 0000000..504efcf --- /dev/null +++ b/mysql_ch_replicator/enum/ddl_parser.py @@ -0,0 +1,134 @@ +from typing import List, Tuple, Optional, Dict, Any + +def find_enum_or_set_definition_end(line: str) -> Tuple[int, str, str]: + """ + Find the end of an enum or set definition in a DDL line + + Args: + line: The DDL line containing an enum or set definition + + Returns: + Tuple containing (end_position, field_type, field_parameters) + """ + open_parens = 0 + in_quotes = False + quote_char = None + end_pos = -1 + + for i, char in enumerate(line): + if char in "'\"" and (i == 0 or line[i - 1] != "\\"): + if not in_quotes: + in_quotes = True + quote_char = char + elif char == quote_char: + in_quotes = False + elif char == '(' and not in_quotes: + open_parens += 1 + elif char == ')' and not in_quotes: + open_parens -= 1 + if open_parens == 0: + end_pos = i + 1 + break + + if end_pos > 0: + field_type = line[:end_pos] + field_parameters = line[end_pos:].strip() + return end_pos, field_type, field_parameters + + # Fallback to splitting by space if we can't find the end + definition = line.split(' ') + field_type = definition[0] + field_parameters = ' '.join(definition[1:]) if len(definition) > 1 else '' + + return -1, field_type, field_parameters + + +def parse_enum_or_set_field(line: str, field_name: str, is_backtick_quoted: bool = False) -> Tuple[str, str, str]: + """ + Parse a field definition line containing an enum or set type + + Args: + line: The line to parse + field_name: The name of the field (already extracted) + is_backtick_quoted: Whether the field name was backtick quoted + + Returns: + Tuple containing (field_name, field_type, field_parameters) + """ + # If the field name was backtick quoted, it's already been extracted + if is_backtick_quoted: + line = line.strip() + # Don't split by space for enum and set types that might contain spaces + if line.lower().startswith('enum(') or line.lower().startswith('set('): + end_pos, field_type, field_parameters = find_enum_or_set_definition_end(line) + else: + definition = line.split(' ') + field_type = definition[0] + field_parameters = ' '.join(definition[1:]) if len(definition) > 1 else '' + else: + # For non-backtick quoted fields + definition = line.split(' ') + definition = definition[1:] # Skip the field name which was already extracted + + if definition and ( + definition[0].lower().startswith('enum(') + or definition[0].lower().startswith('set(') + ): + line = ' '.join(definition) + end_pos, field_type, field_parameters = find_enum_or_set_definition_end(line) + else: + field_type = definition[0] if definition else "" + field_parameters = ' '.join(definition[1:]) if len(definition) > 1 else '' + + return field_name, field_type, field_parameters + + +def extract_enum_or_set_values(field_type: str, from_parser_func=None) -> Optional[List[str]]: + """ + Extract values from an enum or set field type + + Args: + field_type: The field type string (e.g. "enum('a','b','c')") + from_parser_func: Optional function to use for parsing (defaults to simple string parsing) + + Returns: + List of extracted values or None if not an enum/set + """ + if field_type.lower().startswith('enum('): + # Use the provided parser function if available + if from_parser_func: + return from_parser_func(field_type) + + # Simple parsing fallback + vals = field_type[len('enum('):] + close_pos = vals.find(')') + vals = vals[:close_pos] + vals = vals.split(',') + return [strip_value(v) for v in vals] + + elif 'set(' in field_type.lower(): + vals = field_type[field_type.lower().find('set(') + len('set('):] + close_pos = vals.find(')') + vals = vals[:close_pos] + vals = vals.split(',') + return [strip_value(v) for v in vals] + + return None + + +def strip_value(value: str) -> str: + """ + Strip quotes from enum/set values + + Args: + value: The value to strip + + Returns: + Stripped value + """ + value = value.strip() + if not value: + return value + if value[0] in '"\'`': + return value[1:-1] + return value \ No newline at end of file diff --git a/mysql_ch_replicator/converter_enum_parser.py b/mysql_ch_replicator/enum/parser.py similarity index 90% rename from mysql_ch_replicator/converter_enum_parser.py rename to mysql_ch_replicator/enum/parser.py index 92192ea..f7b4b7b 100644 --- a/mysql_ch_replicator/converter_enum_parser.py +++ b/mysql_ch_replicator/enum/parser.py @@ -1,5 +1,3 @@ - - def parse_mysql_enum(enum_definition): """ Accepts a MySQL ENUM definition string (case–insensitive), @@ -175,7 +173,7 @@ def _parse_enum_values(content): # Skip whitespace after the literal. while i < len(content) and content[i].isspace(): i += 1 - # If there’s a comma, skip it; otherwise, we must be at the end. + # If there's a comma, skip it; otherwise, we must be at the end. if i < len(content): if content[i] == ',': i += 1 @@ -185,22 +183,14 @@ def _parse_enum_values(content): return values -# --- For testing purposes --- -if __name__ == '__main__': - tests = [ - "enum('point','qwe','def')", - "ENUM('asd', 'qwe', 'def')", - 'enum("first", \'second\', "Don""t stop")', - "enum('a\\'b','c\\\\d','Hello\\nWorld')", - # Now with backticks: - "enum(`point`,`qwe`,`def`)", - "enum('point',`qwe`,'def')", - "enum(`first`, `Don``t`, `third`)", - ] - - for t in tests: - try: - result = parse_mysql_enum(t) - print("Input: {}\nParsed: {}\n".format(t, result)) - except Exception as e: - print("Error parsing {}: {}\n".format(t, e)) +def is_enum_type(field_type): + """ + Check if a field type is an enum type + + Args: + field_type: The MySQL field type string + + Returns: + bool: True if it's an enum type, False otherwise + """ + return field_type.lower().startswith('enum(') \ No newline at end of file diff --git a/mysql_ch_replicator/enum/utils.py b/mysql_ch_replicator/enum/utils.py new file mode 100644 index 0000000..bfed4f1 --- /dev/null +++ b/mysql_ch_replicator/enum/utils.py @@ -0,0 +1,99 @@ +from typing import List, Optional, Tuple + +def find_enum_definition_end(text: str, start_pos: int) -> int: + """ + Find the end position of an enum definition in a string + + Args: + text: The input text containing the enum definition + start_pos: The starting position (after 'enum(') + + Returns: + int: The position of the closing parenthesis + """ + open_parens = 1 + in_quotes = False + quote_char = None + + for i in range(start_pos, len(text)): + char = text[i] + + # Handle quote state + if not in_quotes and char in ("'", '"', '`'): + in_quotes = True + quote_char = char + continue + elif in_quotes and char == quote_char: + # Check for escaped quotes + if i > 0 and text[i-1] == '\\': + # This is an escaped quote, not the end of the quoted string + continue + # End of quoted string + in_quotes = False + quote_char = None + continue + + # Only process parentheses when not in quotes + if not in_quotes: + if char == '(': + open_parens += 1 + elif char == ')': + open_parens -= 1 + if open_parens == 0: + return i + + # If we get here, the definition is malformed + raise ValueError("Unbalanced parentheses in enum definition") + + +def extract_field_components(line: str) -> Tuple[str, str, List[str]]: + """ + Extract field name, type, and parameters from a MySQL field definition line + + Args: + line: A line from a field definition + + Returns: + Tuple containing field_name, field_type, and parameters + """ + components = line.split(' ') + field_name = components[0].strip('`') + + # Handle special case for enum and set types that might contain spaces + if len(components) > 1 and ( + components[1].lower().startswith('enum(') or + components[1].lower().startswith('set(') + ): + field_type_start = components[1] + field_type_components = [field_type_start] + + # If the enum definition is not complete on this component + if not _is_complete_definition(field_type_start): + # Join subsequent components until we find the end of the definition + for component in components[2:]: + field_type_components.append(component) + if ')' in component: + break + + field_type = ' '.join(field_type_components) + parameters = components[len(field_type_components) + 1:] + else: + field_type = components[1] if len(components) > 1 else "" + parameters = components[2:] if len(components) > 2 else [] + + return field_name, field_type, parameters + + +def _is_complete_definition(text: str) -> bool: + """ + Check if a string contains a complete enum definition (balanced parentheses) + + Args: + text: The string to check + + Returns: + bool: True if the definition is complete + """ + open_count = text.count('(') + close_count = text.count(')') + return open_count > 0 and open_count == close_count \ No newline at end of file diff --git a/test_mysql_ch_replicator.py b/test_mysql_ch_replicator.py index 9b5d30e..9711b57 100644 --- a/test_mysql_ch_replicator.py +++ b/test_mysql_ch_replicator.py @@ -1535,3 +1535,59 @@ def test_alter_tokens_split(): print("Match? ", result == expected) print("-" * 60) assert result == expected + + +def test_enum_conversion(): + """ + Test that enum values are properly converted to lowercase in ClickHouse + and that zero values are preserved rather than converted to first enum value. + """ + config_file = CONFIG_FILE + cfg = config.Settings(config_file) + mysql_config = cfg.mysql + clickhouse_config = cfg.clickhouse + mysql = mysql_api.MySQLApi(mysql_config) + ch = clickhouse_api.ClickhouseApi(clickhouse_config) + + prepare_env(cfg, mysql, ch) + + mysql.execute(f''' + CREATE TABLE `{TEST_TABLE_NAME}` ( + id INT NOT NULL AUTO_INCREMENT, + status_mixed_case ENUM('Purchase','Sell','Transfer') NOT NULL, + status_empty ENUM('Yes','No','Maybe'), + PRIMARY KEY (id) + ) + ''') + + # Insert values with mixed case and NULL/empty values + mysql.execute(f''' + INSERT INTO `{TEST_TABLE_NAME}` (status_mixed_case, status_empty) VALUES + ('Purchase', 'Yes'), + ('Sell', NULL), + ('Transfer', ''); + ''', commit=True) + + run_all_runner = RunAllRunner(cfg_file=config_file) + run_all_runner.run() + + assert_wait(lambda: TEST_DB_NAME in ch.get_databases()) + ch.execute_command(f'USE `{TEST_DB_NAME}`') + assert_wait(lambda: TEST_TABLE_NAME in ch.get_tables()) + assert_wait(lambda: len(ch.select(TEST_TABLE_NAME)) == 3) + + # Get the ClickHouse data + results = ch.select(TEST_TABLE_NAME) + + # Verify all values are properly converted + assert results[0][1] == 'purchase' # First row, status_mixed_case is lowercase 'purchase' + assert results[1][1] == 'sell' # Second row, status_mixed_case is lowercase 'sell' + assert results[2][1] == 'transfer' # Third row, status_mixed_case is lowercase 'transfer' + + # Status_empty should now keep 0s as 0s instead of converting to first enum value + assert results[1][2] is None # NULL should remain NULL + assert results[2][2] == 0 # Empty string should be stored as 0, not converted to 'yes' + + run_all_runner.stop() + assert_wait(lambda: 'stopping db_replicator' in read_logs(TEST_DB_NAME)) + assert('Traceback' not in read_logs(TEST_DB_NAME)) From 2fbec2e4baa23b0944cc8ac3dce8cedf90ca9b1c Mon Sep 17 00:00:00 2001 From: Jared Dobson Date: Mon, 10 Mar 2025 12:45:01 -0600 Subject: [PATCH 08/16] [skip ci] --- mysql_ch_replicator/enum/parser.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/mysql_ch_replicator/enum/parser.py b/mysql_ch_replicator/enum/parser.py index f7b4b7b..888f3a9 100644 --- a/mysql_ch_replicator/enum/parser.py +++ b/mysql_ch_replicator/enum/parser.py @@ -193,4 +193,23 @@ def is_enum_type(field_type): Returns: bool: True if it's an enum type, False otherwise """ - return field_type.lower().startswith('enum(') \ No newline at end of file + return field_type.lower().startswith('enum(') + +if __name__ == '__main__': + tests = [ + "enum('point','qwe','def')", + "ENUM('asd', 'qwe', 'def')", + 'enum("first", \'second\', "Don""t stop")', + "enum('a\\'b','c\\\\d','Hello\\nWorld')", + # Now with backticks: + "enum(`point`,`qwe`,`def`)", + "enum('point',`qwe`,'def')", + "enum(`first`, `Don``t`, `third`)", + ] + + for t in tests: + try: + result = parse_mysql_enum(t) + print("Input: {}\nParsed: {}\n".format(t, result)) + except Exception as e: + print("Error parsing {}: {}\n".format(t, e)) \ No newline at end of file From 202c8c01bba40fff88ddbbca8d347638c7e5a574 Mon Sep 17 00:00:00 2001 From: Jared Dobson Date: Tue, 11 Mar 2025 07:30:11 -0600 Subject: [PATCH 09/16] Check for temp as well --- mysql_ch_replicator/clickhouse_api.py | 11 ----------- mysql_ch_replicator/db_replicator.py | 2 +- 2 files changed, 1 insertion(+), 12 deletions(-) diff --git a/mysql_ch_replicator/clickhouse_api.py b/mysql_ch_replicator/clickhouse_api.py index 2dca0bf..0e82aa4 100644 --- a/mysql_ch_replicator/clickhouse_api.py +++ b/mysql_ch_replicator/clickhouse_api.py @@ -117,17 +117,6 @@ def get_databases(self): database_list = [row[0] for row in databases] return database_list - def table_exists(self, table_name): - """Check if a specific table exists in the current database""" - if self.database not in self.get_databases(): - return False - - query = f"EXISTS TABLE `{self.database}`.`{table_name}`" - result = self.client.query(query) - if result.result_rows and result.result_rows[0][0] == 1: - return True - return False - def validate_database_schema(self, expected_tables): """Validates that all expected tables exist in the database and returns missing tables""" if self.database not in self.get_databases(): diff --git a/mysql_ch_replicator/db_replicator.py b/mysql_ch_replicator/db_replicator.py index ec29430..87fd94d 100644 --- a/mysql_ch_replicator/db_replicator.py +++ b/mysql_ch_replicator/db_replicator.py @@ -180,7 +180,7 @@ def run(self): if self.state.status != Status.NONE: # ensure target database still exists - if self.target_database not in self.clickhouse_api.get_databases(): + if self.target_database not in self.clickhouse_api.get_databases() and f"{self.target_database}_tmp" not in self.clickhouse_api.get_databases(): logger.warning(f'database {self.target_database} missing in CH') logger.warning('will run replication from scratch') self.state.remove() From f458ffdf00bc27ed4eca17dfb48e9edd47af9a52 Mon Sep 17 00:00:00 2001 From: Jared Dobson Date: Tue, 11 Mar 2025 07:31:47 -0600 Subject: [PATCH 10/16] Remove unsafe --- mysql_ch_replicator/clickhouse_api.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/mysql_ch_replicator/clickhouse_api.py b/mysql_ch_replicator/clickhouse_api.py index 0e82aa4..ca4670b 100644 --- a/mysql_ch_replicator/clickhouse_api.py +++ b/mysql_ch_replicator/clickhouse_api.py @@ -147,16 +147,6 @@ def execute_command(self, query): if attempt == ClickhouseApi.MAX_RETRIES - 1: raise e time.sleep(ClickhouseApi.RETRY_INTERVAL) - except clickhouse_connect.driver.exceptions.DatabaseError as e: - # Handle TABLE_ALREADY_EXISTS errors - if "Table already exists" in str(e) or "TABLE_ALREADY_EXISTS" in str(e): - logger.warning(f"Table already exists, continuing: {e}") - break - else: - logger.error(f'error executing command {query}: {e}', exc_info=e) - if attempt == ClickhouseApi.MAX_RETRIES - 1: - raise e - time.sleep(ClickhouseApi.RETRY_INTERVAL) def recreate_database(self): self.execute_command(f'DROP DATABASE IF EXISTS `{self.database}`') From e9e9a2d42772ab6518f66b9f6f4873ce5e9e5b5a Mon Sep 17 00:00:00 2001 From: Jared Dobson Date: Tue, 11 Mar 2025 07:32:27 -0600 Subject: [PATCH 11/16] Remove unused --- mysql_ch_replicator/clickhouse_api.py | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/mysql_ch_replicator/clickhouse_api.py b/mysql_ch_replicator/clickhouse_api.py index ca4670b..825986d 100644 --- a/mysql_ch_replicator/clickhouse_api.py +++ b/mysql_ch_replicator/clickhouse_api.py @@ -117,26 +117,6 @@ def get_databases(self): database_list = [row[0] for row in databases] return database_list - def validate_database_schema(self, expected_tables): - """Validates that all expected tables exist in the database and returns missing tables""" - if self.database not in self.get_databases(): - logger.warning(f"Database {self.database} does not exist") - return False, expected_tables - - existing_tables = self.get_tables() - missing_tables = [table for table in expected_tables if table not in existing_tables] - - if missing_tables: - # Log with a count for large numbers of missing tables - if len(missing_tables) > 10: - sample_tables = ", ".join(missing_tables[:10]) - logger.warning(f"Missing {len(missing_tables)} tables in {self.database}. First 10: {sample_tables}...") - else: - logger.warning(f"Missing tables in {self.database}: {', '.join(missing_tables)}") - return False, missing_tables - - return True, [] - def execute_command(self, query): for attempt in range(ClickhouseApi.MAX_RETRIES): try: From d83f2dea86316d0f655ff89decc464de59184cb5 Mon Sep 17 00:00:00 2001 From: Jared Dobson Date: Tue, 11 Mar 2025 07:48:03 -0600 Subject: [PATCH 12/16] Fix test --- test_mysql_ch_replicator.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test_mysql_ch_replicator.py b/test_mysql_ch_replicator.py index 7568c3c..63ed30c 100644 --- a/test_mysql_ch_replicator.py +++ b/test_mysql_ch_replicator.py @@ -1560,7 +1560,8 @@ def test_enum_conversion(): and that zero values are preserved rather than converted to first enum value. """ config_file = CONFIG_FILE - cfg = config.Settings(config_file) + cfg = config.Settings() + cfg.load(config_file) mysql_config = cfg.mysql clickhouse_config = cfg.clickhouse mysql = mysql_api.MySQLApi(mysql_config) From a7d4ea7bd1db758aa5c4a4259ec92d37ad9dc223 Mon Sep 17 00:00:00 2001 From: Jared Dobson Date: Tue, 11 Mar 2025 07:57:37 -0600 Subject: [PATCH 13/16] Fix test --- test_mysql_ch_replicator.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test_mysql_ch_replicator.py b/test_mysql_ch_replicator.py index 63ed30c..c48ff2a 100644 --- a/test_mysql_ch_replicator.py +++ b/test_mysql_ch_replicator.py @@ -1292,6 +1292,9 @@ def test_percona_migration(monkeypatch): mysql.execute( f"DROP TABLE IF EXISTS `{TEST_DB_NAME}`.`_{TEST_TABLE_NAME}_old`;") + # Wait for table to be recreated in ClickHouse after rename + assert_wait(lambda: TEST_TABLE_NAME in ch.get_tables()) + mysql.execute( f"INSERT INTO `{TEST_TABLE_NAME}` (id, c1) VALUES (43, 1)", commit=True, @@ -1604,7 +1607,7 @@ def test_enum_conversion(): # Status_empty should now keep 0s as 0s instead of converting to first enum value assert results[1][2] is None # NULL should remain NULL - assert results[2][2] == 0 # Empty string should be stored as 0, not converted to 'yes' + assert results[2][2] == 0 # Empty string should be stored as 0, not converted to first enum value run_all_runner.stop() assert_wait(lambda: 'stopping db_replicator' in read_logs(TEST_DB_NAME)) From 8ac290040ddb30e728897d59fcee3f80fbdf1788 Mon Sep 17 00:00:00 2001 From: Jared Dobson Date: Tue, 11 Mar 2025 08:06:52 -0600 Subject: [PATCH 14/16] Fix test --- test_mysql_ch_replicator.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/test_mysql_ch_replicator.py b/test_mysql_ch_replicator.py index c48ff2a..653e90f 100644 --- a/test_mysql_ch_replicator.py +++ b/test_mysql_ch_replicator.py @@ -1567,8 +1567,14 @@ def test_enum_conversion(): cfg.load(config_file) mysql_config = cfg.mysql clickhouse_config = cfg.clickhouse - mysql = mysql_api.MySQLApi(mysql_config) - ch = clickhouse_api.ClickhouseApi(clickhouse_config) + mysql = mysql_api.MySQLApi( + database=None, + mysql_settings=mysql_config + ) + ch = clickhouse_api.ClickhouseApi( + database=TEST_DB_NAME, + clickhouse_settings=clickhouse_config + ) prepare_env(cfg, mysql, ch) From b3223cdb90fe2c32bc087942e0c45e7868fddfb3 Mon Sep 17 00:00:00 2001 From: Jared Dobson Date: Tue, 11 Mar 2025 08:53:03 -0600 Subject: [PATCH 15/16] Fix test --- test_mysql_ch_replicator.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/test_mysql_ch_replicator.py b/test_mysql_ch_replicator.py index 653e90f..68d965f 100644 --- a/test_mysql_ch_replicator.py +++ b/test_mysql_ch_replicator.py @@ -1587,12 +1587,12 @@ def test_enum_conversion(): ) ''') - # Insert values with mixed case and NULL/empty values + # Insert values with mixed case and NULL values mysql.execute(f''' INSERT INTO `{TEST_TABLE_NAME}` (status_mixed_case, status_empty) VALUES ('Purchase', 'Yes'), ('Sell', NULL), - ('Transfer', ''); + ('Transfer', NULL); ''', commit=True) run_all_runner = RunAllRunner(cfg_file=config_file) @@ -1611,9 +1611,10 @@ def test_enum_conversion(): assert results[1][1] == 'sell' # Second row, status_mixed_case is lowercase 'sell' assert results[2][1] == 'transfer' # Third row, status_mixed_case is lowercase 'transfer' - # Status_empty should now keep 0s as 0s instead of converting to first enum value - assert results[1][2] is None # NULL should remain NULL - assert results[2][2] == 0 # Empty string should be stored as 0, not converted to first enum value + # Status_empty should handle NULL values correctly + assert results[0][2] == 'yes' # First row has explicit 'Yes' value + assert results[1][2] is None # Second row is NULL + assert results[2][2] is None # Third row is NULL run_all_runner.stop() assert_wait(lambda: 'stopping db_replicator' in read_logs(TEST_DB_NAME)) From 037bdb08a32860346ecf66d8441c6591dfa71388 Mon Sep 17 00:00:00 2001 From: Jared Dobson Date: Tue, 11 Mar 2025 09:20:31 -0600 Subject: [PATCH 16/16] Fix spec --- test_mysql_ch_replicator.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/test_mysql_ch_replicator.py b/test_mysql_ch_replicator.py index 68d965f..8c35333 100644 --- a/test_mysql_ch_replicator.py +++ b/test_mysql_ch_replicator.py @@ -1607,14 +1607,14 @@ def test_enum_conversion(): results = ch.select(TEST_TABLE_NAME) # Verify all values are properly converted - assert results[0][1] == 'purchase' # First row, status_mixed_case is lowercase 'purchase' - assert results[1][1] == 'sell' # Second row, status_mixed_case is lowercase 'sell' - assert results[2][1] == 'transfer' # Third row, status_mixed_case is lowercase 'transfer' + assert results[0]['status_mixed_case'] == 'purchase' + assert results[1]['status_mixed_case'] == 'sell' + assert results[2]['status_mixed_case'] == 'transfer' # Status_empty should handle NULL values correctly - assert results[0][2] == 'yes' # First row has explicit 'Yes' value - assert results[1][2] is None # Second row is NULL - assert results[2][2] is None # Third row is NULL + assert results[0]['status_empty'] == 'yes' + assert results[1]['status_empty'] is None + assert results[2]['status_empty'] is None run_all_runner.stop() assert_wait(lambda: 'stopping db_replicator' in read_logs(TEST_DB_NAME))