6 changes: 6 additions & 0 deletions mysql_ch_replicator/config.py
@@ -119,6 +119,7 @@ def __init__(self):
self.http_port = 0
self.types_mapping = {}
self.target_databases = {}
self.initial_replication_threads = 0

def load(self, settings_file):
data = open(settings_file, 'r').read()
@@ -143,6 +144,7 @@ def load(self, settings_file):
self.http_host = data.pop('http_host', '')
self.http_port = data.pop('http_port', 0)
self.target_databases = data.pop('target_databases', {})
self.initial_replication_threads = data.pop('initial_replication_threads', 0)

indexes = data.pop('indexes', [])
for index in indexes:
@@ -202,3 +204,7 @@ def validate(self):
self.validate_log_level()
if not isinstance(self.target_databases, dict):
raise ValueError(f'wrong target databases {self.target_databases}')
if not isinstance(self.initial_replication_threads, int):
raise ValueError(f'initial_replication_threads should be an integer, not {type(self.initial_replication_threads)}')
if self.initial_replication_threads < 0:
raise ValueError(f'initial_replication_threads should be non-negative')
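
A hedged sketch (not part of the diff) of how the new option is expected to flow from the settings file into validation. The YAML parsing step and the file name are assumptions, since only the data.pop(...) call and the validation rules appear in the hunks above; 0 (the default) keeps the existing single-process path.

import yaml

# Assumed: the settings file parses into a plain dict (the parser is not shown in this diff)
data = yaml.safe_load(open('config.yaml'))
threads = data.pop('initial_replication_threads', 0)  # 0 or 1 keeps the serial path

if not isinstance(threads, int):
    raise ValueError(f'initial_replication_threads should be an integer, not {type(threads)}')
if threads < 0:
    raise ValueError('initial_replication_threads should be non-negative')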
164 changes: 144 additions & 20 deletions mysql_ch_replicator/db_replicator.py
@@ -1,11 +1,16 @@
import json
import os.path
import random
import time
import pickle
import hashlib
from logging import getLogger
from enum import Enum
from dataclasses import dataclass
from collections import defaultdict
import sys
import subprocess
import select

from .config import Settings, MysqlSettings, ClickhouseSettings
from .mysql_api import MySQLApi
@@ -106,10 +111,15 @@ class DbReplicator:

READ_LOG_INTERVAL = 0.3

def __init__(self, config: Settings, database: str, target_database: str = None, initial_only: bool = False):
def __init__(self, config: Settings, database: str, target_database: str = None, initial_only: bool = False,
worker_id: int = None, total_workers: int = None, table: str = None):
self.config = config
self.database = database

self.worker_id = worker_id
self.total_workers = total_workers
self.settings_file = config.settings_file
self.single_table = table # Store the single table to process

# use same as source database by default
self.target_database = database

@@ -122,9 +132,42 @@ def __init__(self, config: Settings, database: str, target_database: str = None,
if target_database:
self.target_database = target_database

self.target_database_tmp = self.target_database + '_tmp'
self.initial_only = initial_only

# Handle state file differently for parallel workers
if self.worker_id is not None and self.total_workers is not None:
# For worker processes in parallel mode, use a different state file with a deterministic name
self.is_parallel_worker = True

# Determine a filesystem-safe identifier for the state file name
if self.single_table:
# Use a short hex digest of the table name so arbitrary names stay filesystem-safe
table_identifier = hashlib.sha256(self.single_table.encode('utf-8')).hexdigest()[:16]
else:
table_identifier = "all_tables"

# Create a deterministic state file path that includes worker_id, total_workers, and table hash
self.state_path = os.path.join(
self.config.binlog_replicator.data_dir,
self.database,
f'state_worker_{self.worker_id}_of_{self.total_workers}_{table_identifier}.pckl'
)

logger.info(f"Worker {self.worker_id}/{self.total_workers} using state file: {self.state_path}")

if self.single_table:
logger.info(f"Worker {self.worker_id} focusing only on table: {self.single_table}")
else:
self.state_path = os.path.join(self.config.binlog_replicator.data_dir, self.database, 'state.pckl')
self.is_parallel_worker = False

self.target_database_tmp = self.target_database + '_tmp'
if self.is_parallel_worker:
self.target_database_tmp = self.target_database

self.mysql_api = MySQLApi(
database=self.database,
mysql_settings=config.mysql,
@@ -148,7 +191,7 @@ def __init__(self, config: Settings, database: str, target_database: str = None,
self.start_time = time.time()

def create_state(self):
return State(os.path.join(self.config.binlog_replicator.data_dir, self.database, 'state.pckl'))
return State(self.state_path)

def validate_database_settings(self):
if not self.initial_only:
@@ -196,7 +239,9 @@ def run(self):

logger.info('recreating database')
self.clickhouse_api.database = self.target_database_tmp
self.clickhouse_api.recreate_database()
if not self.is_parallel_worker:
self.clickhouse_api.recreate_database()

self.state.tables = self.mysql_api.get_tables()
self.state.tables = [
table for table in self.state.tables if self.config.is_table_matches(table)
@@ -220,6 +265,10 @@ def create_initial_structure(self):
def create_initial_structure_table(self, table_name):
if not self.config.is_table_matches(table_name):
return

if self.single_table and self.single_table != table_name:
return

mysql_create_statement = self.mysql_api.get_table_create_statement(table_name)
mysql_structure = self.converter.parse_mysql_table_structure(
mysql_create_statement, required_table_name=table_name,
@@ -232,7 +281,9 @@ def create_initial_structure_table(self, table_name):

self.state.tables_structure[table_name] = (mysql_structure, clickhouse_structure)
indexes = self.config.get_indexes(self.database, table_name)
self.clickhouse_api.create_table(clickhouse_structure, additional_indexes=indexes)

if not self.is_parallel_worker:
self.clickhouse_api.create_table(clickhouse_structure, additional_indexes=indexes)

def prevent_binlog_removal(self):
if time.time() - self.last_touch_time < self.BINLOG_TOUCH_INTERVAL:
@@ -253,22 +304,26 @@ def perform_initial_replication(self):
for table in self.state.tables:
if start_table and table != start_table:
continue
if self.single_table and self.single_table != table:
continue
self.perform_initial_replication_table(table)
start_table = None
logger.info(f'initial replication - swapping database')
if self.target_database in self.clickhouse_api.get_databases():
self.clickhouse_api.execute_command(
f'RENAME DATABASE `{self.target_database}` TO `{self.target_database}_old`',
)
self.clickhouse_api.execute_command(
f'RENAME DATABASE `{self.target_database_tmp}` TO `{self.target_database}`',
)
self.clickhouse_api.drop_database(f'{self.target_database}_old')
else:
self.clickhouse_api.execute_command(
f'RENAME DATABASE `{self.target_database_tmp}` TO `{self.target_database}`',
)
self.clickhouse_api.database = self.target_database

if not self.is_parallel_worker:
logger.info(f'initial replication - swapping database')
if self.target_database in self.clickhouse_api.get_databases():
self.clickhouse_api.execute_command(
f'RENAME DATABASE `{self.target_database}` TO `{self.target_database}_old`',
)
self.clickhouse_api.execute_command(
f'RENAME DATABASE `{self.target_database_tmp}` TO `{self.target_database}`',
)
self.clickhouse_api.drop_database(f'{self.target_database}_old')
else:
self.clickhouse_api.execute_command(
f'RENAME DATABASE `{self.target_database_tmp}` TO `{self.target_database}`',
)
self.clickhouse_api.database = self.target_database
logger.info(f'initial replication - done')

def perform_initial_replication_table(self, table_name):
@@ -278,6 +333,13 @@ def perform_initial_replication_table(self, table_name):
logger.info(f'skip table {table_name} - not matching any allowed table')
return

if not self.is_parallel_worker and self.config.initial_replication_threads > 1:
self.state.initial_replication_table = table_name
self.state.initial_replication_max_primary_key = None
self.state.save()
self.perform_initial_replication_table_parallel(table_name)
return

max_primary_key = None
if self.state.initial_replication_table == table_name:
# continue replication from saved position
@@ -322,6 +384,8 @@ def perform_initial_replication_table(self, table_name):
order_by=primary_keys,
limit=DbReplicator.INITIAL_REPLICATION_BATCH_SIZE,
start_value=query_start_values,
worker_id=self.worker_id,
total_workers=self.total_workers,
)
logger.debug(f'extracted {len(records)} records from mysql')

@@ -360,6 +424,66 @@ def perform_initial_replication_table(self, table_name):
f'primary key: {max_primary_key}',
)

def perform_initial_replication_table_parallel(self, table_name):
"""
Execute initial replication for a table using multiple parallel worker processes.
Each worker will handle a portion of the table based on its worker_id and total_workers.
"""
logger.info(f"Starting parallel replication for table {table_name} with {self.config.initial_replication_threads} workers")

# Create and launch worker processes
processes = []
for worker_id in range(self.config.initial_replication_threads):
# Prepare command to launch a worker process
cmd = [
sys.executable, "-m", "mysql_ch_replicator.main",
"db_replicator", # Required positional mode argument
"--config", self.settings_file,
"--db", self.database,
"--worker_id", str(worker_id),
"--total_workers", str(self.config.initial_replication_threads),
"--table", table_name,
"--target_db", self.target_database_tmp,
"--initial_only=True",
]

logger.info(f"Launching worker {worker_id}: {' '.join(cmd)}")
process = subprocess.Popen(cmd)
processes.append(process)

# Wait for all worker processes to complete
logger.info(f"Waiting for {len(processes)} workers to complete replication of {table_name}")

try:
while processes:
for i, process in enumerate(processes[:]):
# Check if process is still running
if process.poll() is not None:
exit_code = process.returncode
if exit_code == 0:
logger.info(f"Worker process {i} completed successfully")
else:
logger.error(f"Worker process {i} failed with exit code {exit_code}")
# Abort the whole parallel replication if any worker fails
raise Exception(f"Worker process failed with exit code {exit_code}")

processes.remove(process)

if processes:
# Wait a bit before checking again
time.sleep(0.1)

# Every 30 seconds, log progress
if int(time.time()) % 30 == 0:
logger.info(f"Still waiting for {len(processes)} workers to complete")
except KeyboardInterrupt:
logger.warning("Received interrupt, terminating worker processes")
for process in processes:
process.terminate()
raise

logger.info(f"All workers completed replication of table {table_name}")

def run_realtime_replication(self):
if self.initial_only:
logger.info('skip running realtime replication, only initial replication was requested')
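
For reference, a minimal sketch (not part of the diff) of the per-worker state file naming introduced in __init__ above, which lets a restarted worker resume from the same file; the data_dir, database, and table values here are illustrative.

import hashlib, os

def worker_state_path(data_dir, database, worker_id, total_workers, table=None):
    # Mirrors the naming above: a short sha256 digest keeps arbitrary table names filesystem-safe
    table_identifier = hashlib.sha256(table.encode('utf-8')).hexdigest()[:16] if table else 'all_tables'
    return os.path.join(data_dir, database, f'state_worker_{worker_id}_of_{total_workers}_{table_identifier}.pckl')

# e.g. worker 2 of 8 replicating table "users":
print(worker_state_path('/srv/binlog', 'shop', 2, 8, 'users'))
# -> /srv/binlog/shop/state_worker_2_of_8_<16 hex chars>.pckl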
29 changes: 28 additions & 1 deletion mysql_ch_replicator/main.py
@@ -87,13 +87,28 @@ def run_db_replicator(args, config: Settings):
'db_replicator.log',
)

set_logging_config(f'dbrepl {args.db}', log_file=log_file, log_level_str=config.log_level)
# Set log tag according to whether this is a worker or main process
if args.worker_id is not None:
if args.table:
log_tag = f'dbrepl {db_name} worker_{args.worker_id} table_{args.table}'
else:
log_tag = f'dbrepl {db_name} worker_{args.worker_id}'
else:
log_tag = f'dbrepl {db_name}'

set_logging_config(log_tag, log_file=log_file, log_level_str=config.log_level)

if args.table:
logging.info(f"Processing specific table: {args.table}")

db_replicator = DbReplicator(
config=config,
database=db_name,
target_database=getattr(args, 'target_db', None),
initial_only=args.initial_only,
worker_id=args.worker_id,
total_workers=args.total_workers,
table=args.table,
)
db_replicator.run()

@@ -142,6 +157,18 @@ def main():
"--initial_only", type=bool, default=False,
help="don't run realtime replication, run initial replication only",
)
parser.add_argument(
"--worker_id", type=int, default=None,
help="Worker ID for parallel initial replication (0-based)",
)
parser.add_argument(
"--total_workers", type=int, default=None,
help="Total number of workers for parallel initial replication",
)
parser.add_argument(
"--table", type=str, default=None,
help="Specific table to process (used with --worker_id for parallel processing of a single table)",
)
args = parser.parse_args()

config = Settings()
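
The new arguments also allow running a single worker by hand, which is exactly what perform_initial_replication_table_parallel launches. A hedged example follows; the config file, database, and table names are illustrative, while the flag names and the positional db_replicator mode match the diff above.

import subprocess, sys

subprocess.run([
    sys.executable, '-m', 'mysql_ch_replicator.main',
    'db_replicator',                  # positional mode argument
    '--config', 'config.yaml',        # illustrative path
    '--db', 'shop',                   # illustrative database
    '--table', 'users',               # replicate only this table
    '--worker_id', '0',
    '--total_workers', '4',
    '--target_db', 'shop_tmp',        # workers write straight into the parent's _tmp database
    '--initial_only=True',
], check=True)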
23 changes: 19 additions & 4 deletions mysql_ch_replicator/mysql_api.py
@@ -93,14 +93,29 @@ def get_table_create_statement(self, table_name) -> str:
create_statement = res[0][1].strip()
return create_statement

def get_records(self, table_name, order_by, limit, start_value=None):
def get_records(self, table_name, order_by, limit, start_value=None, worker_id=None, total_workers=None):
self.reconnect_if_required()
order_by = ','.join(order_by)
order_by_str = ','.join(order_by)
where = ''
if start_value is not None:
start_value = ','.join(map(str, start_value))
where = f'WHERE ({order_by}) > ({start_value}) '
query = f'SELECT * FROM `{table_name}` {where}ORDER BY {order_by} LIMIT {limit}'
where = f'WHERE ({order_by_str}) > ({start_value}) '

# Add partitioning filter for parallel processing if needed
if worker_id is not None and total_workers is not None and total_workers > 1:
# Build a COALESCE expression per key column so NULL values hash consistently
coalesce_expressions = [f"COALESCE({key}, '')" for key in order_by]
concat_keys = f"CONCAT_WS('|', {', '.join(coalesce_expressions)})"
hash_condition = f"CRC32({concat_keys}) % {total_workers} = {worker_id}"
if where:
where += f'AND {hash_condition} '
else:
where = f'WHERE {hash_condition} '

query = f'SELECT * FROM `{table_name}` {where}ORDER BY {order_by_str} LIMIT {limit}'
print("query:", query)

# Execute the actual query
self.cursor.execute(query)
res = self.cursor.fetchall()
records = [x for x in res]
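
A small illustration (not part of the diff) of the partitioning predicate built above: each worker selects only the rows whose hashed primary key falls into its bucket, so the workers cover disjoint slices of the table. The column names are illustrative.

order_by = ['account_id', 'id']        # composite primary key (illustrative)
worker_id, total_workers = 1, 4

coalesce_expressions = [f"COALESCE({key}, '')" for key in order_by]
concat_keys = f"CONCAT_WS('|', {', '.join(coalesce_expressions)})"
hash_condition = f"CRC32({concat_keys}) % {total_workers} = {worker_id}"

print(hash_condition)
# CRC32(CONCAT_WS('|', COALESCE(account_id, ''), COALESCE(id, ''))) % 4 = 1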
15 changes: 13 additions & 2 deletions mysql_ch_replicator/runner.py
@@ -24,8 +24,19 @@ def __init__(self, config_file):


class DbReplicatorRunner(ProcessRunner):
def __init__(self, db_name, config_file):
super().__init__(f'{sys.argv[0]} --config {config_file} --db {db_name} db_replicator')
def __init__(self, db_name, config_file, worker_id=None, total_workers=None, initial_only=False):
cmd = f'{sys.argv[0]} --config {config_file} --db {db_name} db_replicator'

if worker_id is not None:
cmd += f' --worker_id={worker_id}'

if total_workers is not None:
cmd += f' --total_workers={total_workers}'

if initial_only:
cmd += ' --initial_only=True'

super().__init__(cmd)


class DbOptimizerRunner(ProcessRunner):
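
A hedged usage sketch of the extended runner; the database and config file names are illustrative, and starting and supervising the child process is handled by the ProcessRunner base class, which is not part of this diff.

from mysql_ch_replicator.runner import DbReplicatorRunner

# Resulting launch command:
#   <argv0> --config config.yaml --db shop db_replicator --worker_id=0 --total_workers=4 --initial_only=True
worker_runner = DbReplicatorRunner(
    db_name='shop',
    config_file='config.yaml',
    worker_id=0,
    total_workers=4,
    initial_only=True,
)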