Skip to content
Permalink
1 contributor

Users who have contributed to this file

import argparse
import os
from difflib import SequenceMatcher
import Levenshtein
import numpy as np
from tqdm import tqdm
from helpers import write_lines, read_parallel_lines, encode_verb_form, \
apply_reverse_transformation, SEQ_DELIMETERS, START_TOKEN
def perfect_align(t, T, insertions_allowed=0,
                  cost_function=Levenshtein.distance):
    """Align source tokens ``t`` to target tokens ``T`` with minimal total cost.

    Every source token ``t[i]`` is matched to a (possibly empty) contiguous
    slice ``T[j:k]``.  The match costs 0 when ``apply_transformation`` finds a
    known transformation from ``t[i]`` to the space-joined slice, and
    ``cost_function(t[i], joined_slice)`` otherwise.  In addition, up to
    ``insertions_allowed`` target slices may be inserted between matched
    source tokens; an insertion costs the length of the joined slice.

    Note: the nested loops call ``apply_transformation`` on the order of
    ``len(t) * len(T) ** 2 * (insertions_allowed + 1)`` times.

    Args:
        t: source tokens.
        T: target tokens.
        insertions_allowed: maximum number of insertions allowed after the
            last matched source token (size of the third DP axis minus one).
        cost_function: string distance used when no transformation applies.

    Returns:
        ``(cost, alignment)``: the minimal cost and a list, in source order,
        of ``[action, T_slice, (src_start, src_end)]`` entries, where
        ``action`` is ``'REPLACE_<t[i]>'`` with span ``(i, i + 1)`` or
        ``'INSERT'`` with span ``(i, i)``.
    """
    # dp[i, j, q] is a minimal cost of matching first `i` tokens of `t` with
    # first `j` tokens of `T`, after making `q` insertions after last match of
    # token from `t`. In other words t[:i] aligned with T[:j].
    # Initialize with INFINITY (unknown); 1e9 serves as the "infinity" sentinel.
    shape = (len(t) + 1, len(T) + 1, insertions_allowed + 1)
    dp = np.ones(shape, dtype=int) * int(1e9)
    # Back-pointers: the previous `j` and the previous `q` for each state.
    come_from = np.ones(shape, dtype=int) * int(1e9)
    come_from_ins = np.ones(shape, dtype=int) * int(1e9)
    dp[0, 0, 0] = 0  # The only known starting point. Nothing matched to nothing.
    for i in range(len(t) + 1):  # Go inclusive
        for j in range(len(T) + 1):  # Go inclusive
            for q in range(insertions_allowed + 1):  # Go inclusive
                if i < len(t):
                    # Given matched sequence of t[:i] and T[:j], match token
                    # t[i] with following tokens T[j:k].
                    for k in range(j, len(T) + 1):
                        transform = \
                            apply_transformation(t[i], ' '.join(T[j:k]))
                        if transform:
                            cost = 0
                        else:
                            cost = cost_function(t[i], ' '.join(T[j:k]))
                        current = dp[i, j, q] + cost
                        # A match resets the insertion counter to 0.
                        if dp[i + 1, k, 0] > current:
                            dp[i + 1, k, 0] = current
                            come_from[i + 1, k, 0] = j
                            come_from_ins[i + 1, k, 0] = q
                if q < insertions_allowed:
                    # Given matched sequence of t[:i] and T[:j], create
                    # insertion with following tokens T[j:k].
                    for k in range(j, len(T) + 1):
                        cost = len(' '.join(T[j:k]))
                        current = dp[i, j, q] + cost
                        if dp[i, k, q + 1] > current:
                            dp[i, k, q + 1] = current
                            come_from[i, k, q + 1] = j
                            come_from_ins[i, k, q + 1] = q
    # Solution is in the dp[len(t), len(T), *]. Backtracking from there.
    alignment = []
    i = len(t)
    j = len(T)
    q = dp[i, j, :].argmin()
    while i > 0 or q > 0:
        # A step was an insertion iff it increased the insertion counter.
        is_insert = (come_from_ins[i, j, q] != q) and (q != 0)
        # `k` keeps the old `j`, so T[j:k] is the slice consumed by this step.
        j, k, q = come_from[i, j, q], j, come_from_ins[i, j, q]
        if not is_insert:
            i -= 1
        if is_insert:
            alignment.append(['INSERT', T[j:k], (i, i)])
        else:
            alignment.append([f'REPLACE_{t[i]}', T[j:k], (i, i + 1)])
    # Sanity check: backtracking must have consumed all of T.
    assert j == 0
    return dp[len(t), len(T)].min(), list(reversed(alignment))
def _split(token):
if not token:
return []
parts = token.split()
return parts or [token]
def apply_merge_transformation(source_tokens, target_words, shift_idx):
    """Detect a merge (many source tokens -> one target word) or a two-token swap.

    Args:
        source_tokens: the source side of a diff chunk.
        target_words: the target side of the same chunk.
        shift_idx: sentence position of ``source_tokens[0]``.

    Returns:
        A list of ``[(start, end), label]`` edits.  For a merge, every source
        token except the last gets the ``$MERGE_*`` label.  For a swap, the
        first token gets the swap label.  The list is empty when neither applies.
    """
    if len(source_tokens) > 1 and len(target_words) == 1:
        merge_label = check_merge(source_tokens, target_words)
        if merge_label:
            return [[(shift_idx + offset, shift_idx + offset + 1), merge_label]
                    for offset in range(len(source_tokens) - 1)]
    if len(source_tokens) == len(target_words) == 2:
        swap_label = check_swap(source_tokens, target_words)
        if swap_label:
            return [[(shift_idx, shift_idx + 1), swap_label]]
    return []
def is_sent_ok(sent, delimeters=SEQ_DELIMETERS):
    """Return ``False`` if ``sent`` contains a reserved delimiter.

    Every value of ``delimeters`` except the ``"tokens"`` delimiter is
    forbidden inside a raw sentence.  Presumably these are the label and
    operation separators used when tagging (see helpers; TODO confirm).
    """
    return not any(
        marker in sent and marker != delimeters["tokens"]
        for marker in delimeters.values()
    )
def check_casetype(source_token, target_token):
    """Return the ``$TRANSFORM_CASE_*`` label turning ``source_token`` into ``target_token``.

    Only tokens that are equal ignoring case qualify.  The rules are tried in
    this order:

    1. Whole-token lower, capitalize, and upper.
    2. First character kept, rest capitalized (``_CAPITAL_1``).
    3. Last character kept, rest upper-cased (``_UPPER_-1``).

    Returns ``None`` when no rule matches.
    """
    if source_token.lower() != target_token.lower():
        return None
    for convert, label in ((str.lower, "$TRANSFORM_CASE_LOWER"),
                           (str.capitalize, "$TRANSFORM_CASE_CAPITAL"),
                           (str.upper, "$TRANSFORM_CASE_UPPER")):
        if convert(source_token) == target_token:
            return label
    # Both tokens are non-empty here: two empty strings match the lower rule.
    head_kept = source_token[0] == target_token[0]
    if head_kept and source_token[1:].capitalize() == target_token[1:]:
        return "$TRANSFORM_CASE_CAPITAL_1"
    tail_kept = source_token[-1] == target_token[-1]
    if tail_kept and source_token[:-1].upper() == target_token[:-1]:
        return "$TRANSFORM_CASE_UPPER_-1"
    return None
def check_equal(source_token, target_token):
    """Return ``"$KEEP"`` for identical tokens, ``None`` otherwise."""
    return "$KEEP" if source_token == target_token else None
def check_split(source_token, target_tokens):
    """Return ``"$TRANSFORM_SPLIT_HYPHEN"`` if splitting on ``-`` yields ``target_tokens``.

    Note: a token without hyphens "splits" into ``[source_token]``, so callers
    (``apply_transformation``) only invoke this for multi-token targets.
    """
    hyphen_parts = source_token.split("-")
    if hyphen_parts != target_tokens:
        return None
    return "$TRANSFORM_SPLIT_HYPHEN"
def check_merge(source_tokens, target_tokens):
    """Return the ``$MERGE_*`` label joining ``source_tokens`` into ``target_tokens``.

    ``"$MERGE_SPACE"`` if they match when concatenated directly,
    ``"$MERGE_HYPHEN"`` if they match when joined with ``-``, else ``None``.
    Callers (``apply_merge_transformation``) pass several source tokens and a
    single target word.
    """
    for glue, label in (("", "$MERGE_SPACE"), ("-", "$MERGE_HYPHEN")):
        if glue.join(source_tokens) == glue.join(target_tokens):
            return label
    return None
def check_swap(source_tokens, target_tokens):
    """Return ``"$MERGE_SWAP"`` if ``source_tokens`` is ``target_tokens`` reversed."""
    if source_tokens != list(reversed(target_tokens)):
        return None
    return "$MERGE_SWAP"
def check_plural(source_token, target_token):
    """Detect adding or removing a trailing ``s``.

    Returns ``"$TRANSFORM_AGREEMENT_SINGULAR"`` when ``target_token`` is
    ``source_token`` minus a final ``s``.  Returns
    ``"$TRANSFORM_AGREEMENT_PLURAL"`` when ``target_token`` is ``source_token``
    plus a final ``s``.  Returns ``None`` otherwise.
    """
    drops_s = source_token.endswith("s") and source_token[:-1] == target_token
    if drops_s:
        return "$TRANSFORM_AGREEMENT_SINGULAR"
    adds_s = target_token.endswith("s") and target_token[:-1] == source_token
    if adds_s:
        return "$TRANSFORM_AGREEMENT_PLURAL"
    return None
def check_verb(source_token, target_token):
    """Return ``"$TRANSFORM_VERB_<encoding>"`` when ``encode_verb_form`` knows the pair.

    ``encode_verb_form`` (from helpers) returns a falsy value when it has no
    encoding for the pair; in that case this returns ``None``.
    """
    verb_code = encode_verb_form(source_token, target_token)
    return f"$TRANSFORM_VERB_{verb_code}" if verb_code else None
def apply_transformation(source_token, target_token):
    """Return the first transformation label that maps ``source_token`` to ``target_token``.

    A multi-word ``target_token`` (space separated) is first tested as a
    hyphen split.  Then equality, case, verb-form, and plural checks run in
    that order, and the first truthy label wins.  Returns ``None`` if nothing
    applies.
    """
    target_tokens = target_token.split()
    if len(target_tokens) > 1:
        split_label = check_split(source_token, target_tokens)
        if split_label:
            return split_label
    single_token_checks = (check_equal, check_casetype, check_verb, check_plural)
    # Lazy generator keeps the original short-circuit: later checks only run
    # if earlier ones returned a falsy value.
    candidates = (check(source_token, target_token) for check in single_token_checks)
    return next((label for label in candidates if label), None)
def align_sequences(source_sent, target_sent):
    """Tag each token of ``source_sent`` with the edits that produce ``target_sent``.

    The sentences are whitespace-tokenized and diffed with ``SequenceMatcher``.
    Each non-equal chunk is handled by its opcode:

    * ``delete``: every deleted token gets ``$DELETE``.
    * ``insert``: the inserted words become ``$APPEND_*`` labels on the
      preceding token (slot ``i1``).
    * ``replace``: a merge or swap is tried first; otherwise the chunk is
      aligned token by token with ``perfect_align`` and converted to edits.

    Returns:
        The tagged sentence built by ``add_labels_to_the_tokens``, or ``None``
        if either sentence contains a reserved delimiter (see ``is_sent_ok``).
    """
    if not (is_sent_ok(source_sent) and is_sent_ok(target_sent)):
        return None
    source_tokens = source_sent.split()
    target_tokens = target_sent.split()
    opcodes = SequenceMatcher(None, source_tokens, target_tokens).get_opcodes()
    all_edits = []
    for tag, i1, i2, j1, j2 in opcodes:
        if tag == 'equal':
            continue
        source_part = _split(" ".join(source_tokens[i1:i2]))
        target_part = _split(" ".join(target_tokens[j1:j2]))
        if tag == 'delete':
            all_edits.extend([(pos, pos + 1), '$DELETE'] for pos in range(i1, i2))
        elif tag == 'insert':
            all_edits.extend(((i1 - 1, i1), f"$APPEND_{word}") for word in target_part)
        else:
            merge_edits = apply_merge_transformation(source_part, target_part,
                                                     shift_idx=i1)
            if merge_edits:
                all_edits.extend(merge_edits)
                continue
            # Fall back to a token-level alignment (singleton source spans).
            _, alignments = perfect_align(source_part, target_part,
                                          insertions_allowed=0)
            for alignment in alignments:
                source_offset = alignment[2][0]
                all_edits.extend(convert_alignments_into_edits(
                    alignment, shift_idx=i1 + source_offset))
    labels = convert_edits_into_labels(source_tokens, all_edits)
    return add_labels_to_the_tokens(source_tokens, labels)
def convert_edits_into_labels(source_tokens, all_edits):
    """Turn span-level edits into one list of labels per token slot.

    Slot 0 belongs to the artificial start token and slot ``i`` (``i >= 1``)
    to ``source_tokens[i - 1]``.  An edit with span ``(i - 1, i)`` lands in
    slot ``i``.  Slots without any edit get ``["$KEEP"]``.  Edits whose span
    is not of that form, or falls outside the sentence, are ignored.

    Args:
        source_tokens: the tokens the labels will be attached to.
        all_edits: iterable of ``((start, end), operations)`` pairs, where
            ``operations`` is a single label string or a list of labels.

    Returns:
        A list of ``len(source_tokens) + 1`` fresh label lists.  Within a
        slot, labels keep the order in which their edits appear in
        ``all_edits``.

    Raises:
        TypeError: if an edit's operations are neither ``str`` nor ``list``.
    """
    # Group labels by span in one pass.  The previous version rescanned all
    # edits for every slot (O(len(tokens) * len(edits))).
    labels_by_span = {}
    for (start, end), edit_operations in all_edits:
        if isinstance(edit_operations, list):
            labels_by_span.setdefault((start, end), []).extend(edit_operations)
        elif isinstance(edit_operations, str):
            labels_by_span.setdefault((start, end), []).append(edit_operations)
        else:
            raise TypeError("Unknown operation type")
    # `or` also covers spans that only carried empty operation lists.
    return [labels_by_span.get((i - 1, i)) or ["$KEEP"]
            for i in range(len(source_tokens) + 1)]
def convert_alignments_into_edits(alignment, shift_idx):
    """Convert one ``perfect_align`` alignment entry into span edits.

    Args:
        alignment: ``[action, target_tokens, span]`` as produced by
            ``perfect_align``.  ``action`` is expected to be
            ``'REPLACE_<source token>'``.  NOTE(review): an ``'INSERT'``
            action would be treated as a source token literally named
            "INSERT"; ``align_sequences`` calls ``perfect_align`` with
            ``insertions_allowed=0`` so this path is presumably never hit.
        shift_idx: sentence position of the aligned source token.

    Returns:
        A list of ``[(start, end), label]`` edits:

        * An empty target slice gives ``$DELETE``.
        * If a prefix of two or more target tokens is a known transformation
          of the source token (e.g. a hyphen split), that transformation is
          used and the remaining target tokens become ``$APPEND_*``.
        * Otherwise the target token closest to the source (cost 0 for a
          known transformation, else Levenshtein distance) becomes the
          replacement or transformation.  Tokens before it are appended to
          the previous slot ``(shift_idx - 1, shift_idx)``; tokens after it
          are appended to this slot.
    """
    edits = []
    # `new_idx` (the source span) is unpacked but not used here.
    action, target_tokens, new_idx = alignment
    source_token = action.replace("REPLACE_", "")
    # check if delete
    if not target_tokens:
        edit = [(shift_idx, 1 + shift_idx), "$DELETE"]
        return [edit]
    # check splits: try growing prefixes target_tokens[:2], [:3], ...
    for i in range(1, len(target_tokens)):
        target_token = " ".join(target_tokens[:i + 1])
        transform = apply_transformation(source_token, target_token)
        if transform:
            edit = [(shift_idx, shift_idx + 1), transform]
            edits.append(edit)
            target_tokens = target_tokens[i + 1:]
            for target in target_tokens:
                edits.append([(shift_idx, shift_idx + 1), f"$APPEND_{target}"])
            return edits
    transform_costs = []
    transforms = []
    for target_token in target_tokens:
        transform = apply_transformation(source_token, target_token)
        if transform:
            cost = 0
            transforms.append(transform)
        else:
            cost = Levenshtein.distance(source_token, target_token)
            transforms.append(None)
        transform_costs.append(cost)
    # First index with the minimal cost (ties resolve to the leftmost token).
    min_cost_idx = transform_costs.index(min(transform_costs))
    # append to the previous word
    for i in range(0, min_cost_idx):
        target = target_tokens[i]
        edit = [(shift_idx - 1, shift_idx), f"$APPEND_{target}"]
        edits.append(edit)
    # replace/transform target word
    transform = transforms[min_cost_idx]
    target = transform if transform is not None \
        else f"$REPLACE_{target_tokens[min_cost_idx]}"
    edit = [(shift_idx, 1 + shift_idx), target]
    edits.append(edit)
    # append to this word
    for i in range(min_cost_idx + 1, len(target_tokens)):
        target = target_tokens[i]
        edit = [(shift_idx, 1 + shift_idx), f"$APPEND_{target}"]
        edits.append(edit)
    return edits
def add_labels_to_the_tokens(source_tokens, labels, delimeters=SEQ_DELIMETERS):
    """Serialize tokens and their label lists into one tagged line.

    ``START_TOKEN`` is prepended so that ``labels[0]`` has a token to attach
    to.  Each record is ``token + delimeters['labels'] + labels`` joined by
    ``delimeters['operations']``.  Records are joined by
    ``delimeters['tokens']``.  ``zip`` stops at the shorter of the two
    sequences.
    """
    records = [
        token + delimeters['labels'] + delimeters['operations'].join(label_list)
        for token, label_list in zip([START_TOKEN] + source_tokens, labels)
    ]
    return delimeters['tokens'].join(records)
def convert_data_from_raw_files(source_file, target_file, output_file, chunk_size):
    """Tag a parallel corpus and append the tagged sentences to ``output_file``.

    Each (source, target) sentence pair is aligned with ``align_sequences``.
    The tagged line is decoded back with ``convert_tagged_line`` and kept only
    if it reproduces the target sentence, ignoring whitespace.  Kept lines are
    buffered and flushed with ``write_lines(..., mode='a')`` whenever more
    than ``chunk_size`` have accumulated, plus once at the end.

    Pairs are skipped, with a printed message, when:

    * either sentence contains a reserved delimiter (``align_sequences``
      returns ``None``).  This case previously crashed the whole run.
    * the decoded tagged line does not reproduce the target sentence.

    Args:
        source_file: path to the source-side sentences.
        target_file: path to the target-side sentences.
        output_file: path the tagged lines are written to.
        chunk_size: flush threshold for the in-memory buffer.
    """
    tagged = []
    source_data, target_data = read_parallel_lines(source_file, target_file)
    print(f"The size of raw dataset is {len(source_data)}")
    cnt_total, cnt_all, cnt_tp = 0, 0, 0
    # The former "retry on exception" calls were removed: both helpers are
    # deterministic, so the retry only repeated the work before re-raising.
    for source_sent, target_sent in tqdm(zip(source_data, target_data),
                                         total=len(source_data)):
        aligned_sent = align_sequences(source_sent, target_sent)
        if aligned_sent is None:
            print(f"Skipping pair with reserved delimiters: \n"
                  f"{source_sent}\n{target_sent}")
            continue
        if source_sent != target_sent:
            cnt_tp += 1
        cnt_all += 1
        check_sent = convert_tagged_line(aligned_sent)
        if "".join(check_sent.split()) != "".join(target_sent.split()):
            print(f"Incorrect pair: \n{target_sent}\n{check_sent}")
            continue
        cnt_total += 1
        tagged.append(aligned_sent)
        if len(tagged) > chunk_size:
            write_lines(output_file, tagged, mode='a')
            tagged = []
    print(f"Overall extracted {cnt_total}. "
          f"Original TP {cnt_tp}."
          f" Original TN {cnt_all - cnt_tp}")
    if tagged:
        write_lines(output_file, tagged, 'a')
def convert_labels_into_edits(labels):
    """Inverse of label assignment: turn per-slot label lists back into span edits.

    Slot ``i`` maps to span ``(i - 1, i)``; slots whose labels are exactly
    ``["$KEEP"]`` are dropped.  The label list objects are reused, not copied.
    """
    return [
        [(slot - 1, slot), slot_labels]
        for slot, slot_labels in enumerate(labels)
        if slot_labels != ["$KEEP"]
    ]
def get_target_sent_by_levels(source_tokens, labels):
    """Apply per-slot label lists to ``source_tokens``, one label per slot per level.

    ``labels[0]`` belongs to the start token and ``labels[i]`` to
    ``source_tokens[i - 1]``.  At each level the first remaining label of
    every slot is applied.  The rest are carried over, with positions shifted
    by earlier insertions and deletions, to the next level.  ``$MERGE_*``
    labels are deferred to the last level, where they are inserted as marker
    tokens and resolved by ``replace_merge_transforms``.

    Args:
        source_tokens: tokens without the start token.
        labels: ``len(source_tokens) + 1`` label lists.

    Returns:
        ``(leveled_target_tokens, target_sentence)``:

        * ``leveled_target_tokens`` maps level number (1-based) to
          ``{"tokens": ..., "labels": ...}``.  It is empty when there are no
          non-``$KEEP`` labels.
        * ``target_sentence`` is the final tokens joined by spaces.
    """
    relevant_edits = convert_labels_into_edits(labels)
    target_tokens = source_tokens[:]
    leveled_target_tokens = {}
    if not relevant_edits:
        target_sentence = " ".join(target_tokens)
        return leveled_target_tokens, target_sentence
    # Number of levels = longest label list of any slot.
    max_level = max([len(x[1]) for x in relevant_edits])
    for level in range(max_level):
        rest_edits = []
        # Net number of tokens inserted (+) / deleted (-) so far in this level;
        # later edits' positions are offset by it.
        shift_idx = 0
        for edits in relevant_edits:
            (start, end), label_list = edits
            label = label_list[0]
            target_pos = start + shift_idx
            # target_pos == -1 is the start-token slot.
            # NOTE(review): a $DELETE/$REPLACE_/$TRANSFORM on that slot would
            # index target_tokens[-1] (the last token) — confirm never emitted.
            source_token = target_tokens[target_pos] if target_pos >= 0 else START_TOKEN
            if label == "$DELETE":
                del target_tokens[target_pos]
                shift_idx -= 1
            elif label.startswith("$APPEND_"):
                word = label.replace("$APPEND_", "")
                target_tokens[target_pos + 1: target_pos + 1] = [word]
                shift_idx += 1
            elif label.startswith("$REPLACE_"):
                word = label.replace("$REPLACE_", "")
                target_tokens[target_pos] = word
            elif label.startswith("$TRANSFORM"):
                word = apply_reverse_transformation(source_token, label)
                # Unknown/failed transformation keeps the token unchanged.
                if word is None:
                    word = source_token
                target_tokens[target_pos] = word
            elif label.startswith("$MERGE_"):
                # apply merge only on last stage
                if level == (max_level - 1):
                    # Insert the label itself as a marker token after the
                    # current token; replace_merge_transforms resolves it.
                    target_tokens[target_pos + 1: target_pos + 1] = [label]
                    shift_idx += 1
                else:
                    rest_edit = [(start + shift_idx, end + shift_idx), [label]]
                    rest_edits.append(rest_edit)
            # Carry the slot's remaining labels to the next level.
            rest_labels = label_list[1:]
            if rest_labels:
                rest_edit = [(start + shift_idx, end + shift_idx), rest_labels]
                rest_edits.append(rest_edit)
        leveled_tokens = target_tokens[:]
        # update next step
        relevant_edits = rest_edits[:]
        if level == (max_level - 1):
            leveled_tokens = replace_merge_transforms(leveled_tokens)
        leveled_labels = convert_edits_into_labels(leveled_tokens,
                                                   relevant_edits)
        leveled_target_tokens[level + 1] = {"tokens": leveled_tokens,
                                            "labels": leveled_labels}
    target_sentence = " ".join(leveled_target_tokens[max_level]["tokens"])
    return leveled_target_tokens, target_sentence
def replace_merge_transforms(tokens):
    """Resolve ``$MERGE_*`` marker tokens into the final token list.

    * ``$MERGE_SWAP`` between two tokens swaps those neighbours and is
      removed.  A swap marker at either end has no pair and is left as is.
    * ``a $MERGE_HYPHEN b`` becomes ``a-b``.
    * ``a $MERGE_SPACE b`` becomes ``ab``.

    Returns ``tokens`` itself when it contains no ``$MERGE_`` marker;
    otherwise a new list.
    """
    if all(not x.startswith("$MERGE_") for x in tokens):
        return tokens
    target_tokens = tokens[:]
    # Swap markers already deleted from target_tokens.  Needed to map an index
    # in `tokens` onto `target_tokens`.  The old code ignored this and
    # misplaced every swap after the first.
    removed = 0
    for i, token in enumerate(tokens):
        # A swap needs a neighbour on both sides.  The old test
        # `i in (1, len(tokens) - 1)` checked membership in a 2-tuple: it
        # skipped most positions and crashed (IndexError) for a trailing marker.
        if token.startswith("$MERGE_SWAP") and 0 < i < len(tokens) - 1:
            pos = i - removed
            target_tokens[pos - 1] = tokens[i + 1]
            target_tokens[pos + 1] = tokens[i - 1]
            del target_tokens[pos]
            removed += 1
    target_line = " ".join(target_tokens)
    target_line = target_line.replace(" $MERGE_HYPHEN ", "-")
    target_line = target_line.replace(" $MERGE_SPACE ", "")
    return target_line.split()
def convert_tagged_line(line, delimeters=SEQ_DELIMETERS):
    """Decode a tagged line back into the target sentence it describes.

    The line is split into records on ``delimeters['tokens']``.  Each record
    is ``token`` + ``delimeters['labels']`` + labels joined by
    ``delimeters['operations']``.  The first record is the start token, so it
    contributes labels but no source token.
    """
    label_del = delimeters['labels']
    records = line.split(delimeters['tokens'])
    source_tokens = [record.split(label_del)[0] for record in records[1:]]
    labels = [record.split(label_del)[1].split(delimeters['operations'])
              for record in records]
    assert len(source_tokens) + 1 == len(labels)
    _, target_line = get_target_sent_by_levels(source_tokens, labels)
    return target_line
def main(args):
    """Run the corpus conversion with the parsed command-line ``args``."""
    convert_data_from_raw_files(source_file=args.source,
                                target_file=args.target,
                                output_file=args.output_file,
                                chunk_size=args.chunk_size)
if __name__ == '__main__':
    # Command-line entry point: parse paths and buffering options, then run.
    parser = argparse.ArgumentParser(
        description='Tag a parallel corpus (source/target sentence files) '
                    'with token-level edit labels.')
    parser.add_argument('-s', '--source',
                        help='Path to the source file',
                        required=True)
    parser.add_argument('-t', '--target',
                        help='Path to the target file',
                        required=True)
    parser.add_argument('-o', '--output_file',
                        help='Path to the output file',
                        required=True)
    parser.add_argument('--chunk_size',
                        type=int,
                        help='Flush buffered tagged sentences to the output '
                             'file once more than this many have '
                             'accumulated (default: %(default)s).',
                        default=1000000)
    args = parser.parse_args()
    main(args)