Permalink
Fetching contributors…
Cannot retrieve contributors at this time
327 lines (270 sloc) 10.3 KB
"""Repair data on subscriber.

Walks tables in primary key order and searches for missing inserts,
updates and deletes.
"""

import os
import subprocess
import sys

import skytools

from londiste.syncer import Syncer

__all__ = ['Repairer']
def unescape(s):
    """Undo COPY text-format escaping of one field value.

    Delegates to skytools.unescape_copy(); callers in this module treat a
    None result as SQL NULL (presumably what the COPY null marker maps to).
    """
    plain = skytools.unescape_copy(s)
    return plain
class Repairer(Syncer):
    """Walks tables in primary key order and checks if data matches.

    Both tables are dumped with COPY into local files, the dumps are sorted
    with the external ``sort`` utility and then merged line by line.  Each
    difference becomes an SQL statement that is either appended to
    ``fix.<table>.sql`` or, with ``--apply``, executed directly.
    """

    # per-table statistics, reset at the start of dump_compare()
    cnt_insert = 0
    cnt_update = 0
    cnt_delete = 0
    total_src = 0
    total_dst = 0

    # column info, recomputed for every table by load_common_columns();
    # these are always reassigned (never mutated in place), so the shared
    # class-level defaults are harmless.
    pkey_list = []
    common_fields = []
    fq_common_fields = []

    # cursor used to execute fixes directly; only set with --apply
    apply_curs = None

    def init_optparse(self, p=None):
        """Initialize cmdline switches (adds ``--apply`` to Syncer's)."""
        p = super(Repairer, self).init_optparse(p)
        p.add_option("--apply", action="store_true", help="apply fixes")
        return p

    def process_sync(self, t1, t2, src_db, dst_db):
        """Actual comparison.

        Dumps ``t1.dest_table`` through ``src_db`` and ``t2.dest_table``
        through ``dst_db`` into local files, sorts both dumps and compares
        them.  The temporary dump files are removed afterwards, also when a
        step fails (they can be as large as the table itself).
        """
        if self.options.apply:
            apply_db = self.get_database('db', cache='applydb', autocommit=1)
            self.apply_curs = apply_db.cursor()
            # in 'replica' role PostgreSQL does not fire ordinary triggers
            # and rules for the repair statements
            self.apply_curs.execute("set session_replication_role = 'replica'")

        src_tbl = t1.dest_table
        dst_tbl = t2.dest_table

        src_curs = src_db.cursor()
        dst_curs = dst_db.cursor()

        self.log.info('Checking %s', dst_tbl)

        self.common_fields = []
        self.fq_common_fields = []
        self.pkey_list = []
        self.load_common_columns(src_tbl, dst_tbl, src_curs, dst_curs)

        dump_src = dst_tbl + ".src"
        dump_dst = dst_tbl + ".dst"
        dump_src_sorted = dump_src + ".sorted"
        dump_dst_sorted = dump_dst + ".sorted"

        # both sides are filtered with the condition from the dst plugin
        dst_where = t2.plugin.get_copy_condition(src_curs, dst_curs)
        src_where = dst_where

        try:
            self.log.info("Dumping src table: %s", src_tbl)
            self.dump_table(src_tbl, src_curs, dump_src, src_where)
            src_db.commit()
            self.log.info("Dumping dst table: %s", dst_tbl)
            self.dump_table(dst_tbl, dst_curs, dump_dst, dst_where)
            dst_db.commit()

            self.log.info("Sorting src table: %s", dump_src)
            self.do_sort(dump_src, dump_src_sorted)
            self.log.info("Sorting dst table: %s", dump_dst)
            self.do_sort(dump_dst, dump_dst_sorted)

            self.dump_compare(dst_tbl, dump_src_sorted, dump_dst_sorted)
        finally:
            for fn in (dump_src, dump_dst, dump_src_sorted, dump_dst_sorted):
                # a failed step may not have created every file
                if os.path.isfile(fn):
                    os.unlink(fn)

    def do_sort(self, src, dst):
        """Sort contents of src file, write them to dst file.

        ``sort`` runs under the C locale so the ordering is bytewise and
        independent of the user's locale; dump_compare() requires both dumps
        to be sorted the same way.

        Raises:
            Exception: if ``sort`` exits with a non-zero status.
        """
        # Probe the sort implementation; '-S 30%' is only passed to
        # coreutils sort.  universal_newlines gives str output on py2 and py3.
        p = subprocess.Popen(["sort", "--version"], stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE, universal_newlines=True)
        s_ver = p.communicate()[0]
        del p

        xenv = os.environ.copy()
        xenv['LANG'] = 'C'
        xenv['LC_ALL'] = 'C'

        # '-T .' puts sort's temporary files into the current directory
        cmdline = ['sort', '-T', '.']
        if "coreutils" in s_ver:
            cmdline += ['-S', '30%']
        cmdline += ['-o', dst, src]

        p = subprocess.Popen(cmdline, env=xenv)
        if p.wait() != 0:
            raise Exception('sort failed')

    def load_common_columns(self, src_tbl, dst_tbl, src_curs, dst_curs):
        """Get common fields, put pkeys in start.

        Sets ``pkey_list``, ``common_fields`` (primary key columns first,
        then the remaining source columns that also exist on the
        destination, in source order) and their quoted form
        ``fq_common_fields``.  Exits the process when the primary keys of
        the two tables differ.
        """
        self.pkey_list = skytools.get_table_pkeys(src_curs, src_tbl)
        dst_pkey = skytools.get_table_pkeys(dst_curs, dst_tbl)
        if dst_pkey != self.pkey_list:
            self.log.error('pkeys do not match')
            sys.exit(1)

        src_cols = skytools.get_table_columns(src_curs, src_tbl)
        dst_cols = skytools.get_table_columns(dst_curs, dst_tbl)

        field_list = list(self.pkey_list)
        for f in src_cols:
            if f not in self.pkey_list and f in dst_cols:
                field_list.append(f)

        self.common_fields = field_list
        self.fq_common_fields = [skytools.quote_ident(col) for col in field_list]
        self.log.debug("using columns: %s", ",".join(self.fq_common_fields))

    def dump_table(self, tbl, curs, fn, whr):
        """Dump the common columns of table tbl into file fn using COPY.

        ``whr`` is an SQL condition; an empty or missing one selects all rows.
        """
        cols = ','.join(self.fq_common_fields)
        if not whr:
            whr = 'true'
        q = "copy (SELECT %s FROM %s WHERE %s) to stdout" % (
            cols, skytools.quote_fqident(tbl), whr)
        self.log.debug("Query: %s", q)
        with open(fn, "w", 64*1024) as f:
            curs.copy_expert(q, f)
            size = f.tell()
        self.log.info('%s: Got %d bytes', tbl, size)

    def get_row(self, ln):
        """Parse one dump line into a dict keyed by common_fields.

        Returns None for an empty string (end of file).  Field values stay
        COPY-escaped.

        Raises:
            Exception: if the line does not have one field per common column.
        """
        if not ln:
            return None
        if ln.endswith('\n'):
            ln = ln[:-1]
        # COPY escapes tabs inside values, so splitting on tab is exact
        values = ln.split('\t')
        if len(values) != len(self.common_fields):
            raise Exception('unexpected field count %d, expected %d: %r' % (
                len(values), len(self.common_fields), ln))
        return dict(zip(self.common_fields, values))

    def _fix_filename(self, tbl):
        """Return the name of the file that collects fix SQL for tbl."""
        return "fix.%s.sql" % tbl

    def dump_compare(self, tbl, src_fn, dst_fn):
        """ Compare two table dumps, create sql file to fix target table
        or apply changes to target table directly.

        Both files must be sorted identically (see do_sort()).  They are
        merged like a sorted merge-join on the primary key.
        """
        self.log.info("Comparing dumps: %s", tbl)
        self.cnt_insert = 0
        self.cnt_update = 0
        self.cnt_delete = 0
        self.total_src = 0
        self.total_dst = 0

        fix = self._fix_filename(tbl)
        if os.path.isfile(fix):
            os.unlink(fix)

        with open(src_fn, "r", 64*1024) as f1:
            with open(dst_fn, "r", 64*1024) as f2:
                src_ln = f1.readline()
                dst_ln = f2.readline()
                if src_ln:
                    self.total_src += 1
                if dst_ln:
                    self.total_dst += 1

                while src_ln or dst_ln:
                    keep_src = keep_dst = False
                    if src_ln != dst_ln:
                        src_row = self.get_row(src_ln)
                        dst_row = self.get_row(dst_ln)

                        diff = self.cmp_keys(src_row, dst_row)
                        if diff > 0:
                            # src > dst: dst row has no source counterpart
                            self.got_missed_delete(tbl, dst_row)
                            keep_src = True
                        elif diff < 0:
                            # src < dst: src row is missing on dst
                            self.got_missed_insert(tbl, src_row)
                            keep_dst = True
                        elif self.cmp_data(src_row, dst_row) != 0:
                            # same key, different data
                            self.got_missed_update(tbl, src_row, dst_row)

                    if not keep_src:
                        src_ln = f1.readline()
                        if src_ln:
                            self.total_src += 1
                    if not keep_dst:
                        dst_ln = f2.readline()
                        if dst_ln:
                            self.total_dst += 1

        self.log.info("finished %s: src: %d rows, dst: %d rows,"
                      " missed: %d inserts, %d updates, %d deletes",
                      tbl, self.total_src, self.total_dst,
                      self.cnt_insert, self.cnt_update, self.cnt_delete)

    def got_missed_insert(self, tbl, src_row):
        """Create sql for missed insert of src_row into tbl."""
        self.cnt_insert += 1
        fields = self.common_fields
        fq_list = [skytools.quote_ident(f) for f in fields]
        val_list = [skytools.quote_literal(unescape(src_row[f])) for f in fields]
        q = "insert into %s (%s) values (%s);" % (
            skytools.quote_fqident(tbl), ", ".join(fq_list), ", ".join(val_list))
        self.show_fix(tbl, q, 'insert')

    def got_missed_update(self, tbl, src_row, dst_row):
        """Create sql that updates the differing columns of dst_row to src_row."""
        self.cnt_update += 1
        set_list = []
        whe_list = []
        for f in self.pkey_list:
            self.addcmp(whe_list, skytools.quote_ident(f), unescape(src_row[f]))
        for f in self.common_fields:
            v1 = src_row[f]
            v2 = dst_row[f]
            if self.cmp_value(v1, v2) == 0:
                continue
            fq = skytools.quote_ident(f)
            self.addeq(set_list, fq, unescape(v1))
            # the old dst value is also matched, presumably so the fix only
            # applies while the row still holds the stale value
            self.addcmp(whe_list, fq, unescape(v2))

        q = "update only %s set %s where %s;" % (
            skytools.quote_fqident(tbl), ", ".join(set_list), " and ".join(whe_list))
        self.show_fix(tbl, q, 'update')

    def got_missed_delete(self, tbl, dst_row):
        """Create sql for missed delete of dst_row (matched by primary key)."""
        self.cnt_delete += 1
        whe_list = []
        for f in self.pkey_list:
            self.addcmp(whe_list, skytools.quote_ident(f), unescape(dst_row[f]))
        q = "delete from only %s where %s;" % (
            skytools.quote_fqident(tbl), " and ".join(whe_list))
        self.show_fix(tbl, q, 'delete')

    def show_fix(self, tbl, q, desc):
        """Print/write/apply repair sql.

        Executes q on apply_curs when --apply is active, otherwise appends it
        to the table's fix file.
        """
        self.log.debug("missed %s: %s", desc, q)
        if self.apply_curs:
            self.apply_curs.execute(q)
        else:
            with open(self._fix_filename(tbl), "a") as f:
                f.write("%s\n" % q)

    def addeq(self, list, f, v):
        """Append the SET expression ``f = <quoted v>`` to list."""
        vq = skytools.quote_literal(v)
        s = "%s = %s" % (f, vq)
        list.append(s)

    def addcmp(self, list, f, v):
        """Append a comparison for f to list, using ``is null`` for None."""
        if v is None:
            s = "%s is null" % f
        else:
            vq = skytools.quote_literal(v)
            s = "%s = %s" % (f, vq)
        list.append(s)

    def cmp_data(self, src_row, dst_row):
        """Compare data field-by-field; 0 if all common fields match, else -1."""
        for k in self.common_fields:
            if self.cmp_value(src_row[k], dst_row[k]) != 0:
                return -1
        return 0

    def cmp_value(self, v1, v2):
        """Compare single field, tolerates tz vs notz dates.

        Returns 0 when equal, -1 otherwise (no ordering).  Two values are
        also considered equal when one is the other plus a 3-character
        ``+HH`` / ``-HH`` suffix and the shorter one is at least 19 chars
        long (the length of ``YYYY-MM-DD HH:MM:SS``).
        """
        if v1 == v2:
            return 0

        # try to work around timestamptz vs. timestamp columns
        z1 = len(v1)
        z2 = len(v2)
        if z1 == z2 + 3 and z2 >= 19 and v1[z2] in '+-':
            if v1[:z2] == v2:
                return 0
        elif z1 + 3 == z2 and z1 >= 19 and v2[z1] in '+-':
            if v2[:z1] == v1:
                return 0
        return -1

    def cmp_keys(self, src_row, dst_row):
        """Compare primary keys of the rows.
        Returns 1 if src > dst, -1 if src < dst and 0 if src == dst"""

        # None means table is done. tag it larger than any existing row.
        if src_row is None:
            if dst_row is None:
                return 0
            return 1
        elif dst_row is None:
            return -1

        for k in self.pkey_list:
            v1 = src_row[k]
            v2 = dst_row[k]
            if v1 < v2:
                return -1
            elif v1 > v2:
                return 1
        return 0