In [1]:
# from transformers import AutoTokenizer, AutoModelForMaskedLM
# tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-base')
# model = AutoModelForMaskedLM.from_pretrained("xlm-roberta-base")

from transformers import BertTokenizer, BertModel
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
model = BertModel.from_pretrained("bert-base-multilingual-cased")

In [2]:
%load_ext autoreload
%autoreload 2
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from my_ot_algorithm import *
from other_ot_algorithms import *
from align_utils import *

src = 'de'
tgt = 'en'

num_samples = len(open(f'PMI-Align/data/{src}-{tgt}/text.{src}', 'r').readlines())

results = {}


gold_aligen_list = []

with tqdm(total=num_samples, desc='Processing') as pbar:
# if True:
#     test_line_id = 508
    for test_line_id in range(num_samples):        
        gold_alignment = open(f'PMI-Align/data/{src}-{tgt}/gold.{src}-{tgt}.aligned', 'r').readlines()[test_line_id]
        src_text = open(f'PMI-Align/data/{src}-{tgt}/text.{src}', 'r').readlines()[test_line_id].strip()
        tgt_text = open(f'PMI-Align/data/{src}-{tgt}/text.{tgt}', 'r',).readlines()[test_line_id].strip()
        
        if len(src_text.split()) == 0:
            print('Finished @', test_line_id)
            print('Processed', len(results), 'samples')
            break
        
        try:
            src_emb_ = get_word_embeddings(src_text, model, tokenizer)
            tgt_emb_ = get_word_embeddings(tgt_text, model, tokenizer)
        except Exception as e:
            print('error emb @', test_line_id)
            continue
        
        gold_aligen_list.append(gold_alignment)
        
        """ Config Format
        config: {
          'method': 'emd' | 'sink'| 'p_sink' | 'p_sink_fwd' | 'p_sink_rev' | 'unk_sink_fwd' | 'unk_sink_rev' | 'unk_p_sink_fwd' | 'unk_p_sink_rev',
          'epsilon': 0.1,
          'relax_ratio': 1.5,
          'distance_metric': 'cosine' | 'l2' # default: cosine
          }
        """

        P_dict = {}
        

        # NOTE: emd is for hard full-to-full mapping
        # config_emd = {'method': 'emd'}
        # P_emd, C = get_ot_map(src_emb_, tgt_emb_, config_emd)
        # P_dict['hard_f2f'] = P_emd
        
        # NOTE: sinkhorn is for soft full-to-full mapping
        config_sink = {'method': 'sink', 'epsilon': 0.1}
        P_sink, _ = get_ot_map(src_emb_, tgt_emb_, config_sink)
        P_dict['f2f'] = P_sink
        
        # NOTE: p_sink is for soft partial-to-partial mapping
        config_p_sink = {'method': 'p_sink', 'epsilon': 0.1, 'relax_ratio': 1.5}
        P_p_sink, _ = get_ot_map(src_emb_, tgt_emb_, config_p_sink)
        P_dict['p2p'] = P_p_sink
        
        # NOTE: p_sink_f2p is for full-to-partial mapping
        config_p_sink_f2p = {'method': 'p_sink_f2p', 'epsilon': 0.1, 'relax_ratio': 1.5}
        P_p_sink_f2p, _ = get_ot_map(src_emb_, tgt_emb_, config_p_sink_f2p)
        P_dict['f2p'] = P_p_sink_f2p
        
        # NOTE: p_sink_p2f is for soft partial-to-full mapping
        config_p_sink_p2f = {'method': 'p_sink_p2f', 'epsilon': 0.1, 'relax_ratio': 1.5}
        P_p_sink_p2f, _ = get_ot_map(src_emb_, tgt_emb_, config_p_sink_p2f)
        P_dict['p2f'] = P_p_sink_p2f
        
        # NOTE: p_sink_f2f_fwd is for foward align with full-to-full sinkhorn
        config_p_sink_f2f_fwd = {'method': 'p_sink_f2f_fwd', 'epsilon': 0.1}
        P_p_sink_f2f_fwd, _ = get_ot_map(src_emb_, tgt_emb_, config_p_sink_f2f_fwd)
        P_dict['f2f_fwd'] = P_p_sink_f2f_fwd
        
        # NOTE: p_sink_f2f_rev is for reverse align with full-to-full sinkhorn
        config_p_sink_f2f_rev = {'method': 'p_sink_f2f_rev', 'epsilon': 0.1}     
        P_p_sink_f2f_rev, _ = get_ot_map(src_emb_, tgt_emb_, config_p_sink_f2f_rev)
        P_dict['f2f_rev'] = P_p_sink_f2f_rev
        
        # # NOTE: p_sink_p2f_fwd is for forward align with partial-to-full sinkhorn
        # config_p_sink_p2f_fwd = {'method': 'p_sink_p2f_fwd', 'epsilon': 0.1, 'relax_ratio': 1.5}
        # P_p_sink_p2f_fwd, _ = get_ot_map(src_emb_, tgt_emb_, config_p_sink_p2f_fwd)
        # P_dict['p2f_fwd'] = P_p_sink_p2f_fwd
        
        # # NOTE: p_sink_f2p_rev is for reverse align with full-to-partial sinkhorn
        # config_p_sink_f2p_rev = {'method': 'p_sink_f2p_rev', 'epsilon': 0.1, 'relax_ratio': 1.5}
        # P_p_sink_f2p_rev, _ = get_ot_map(src_emb_, tgt_emb_, config_p_sink_f2p_rev)
        # P_dict['f2p_rev'] = P_p_sink_f2p_rev
        
        # # # ===== not recommended =====
        # # NOTE: p_sink_f2p_fwd is for forward align with partial-to-full sinkhorn
        # p_sink_f2p_fwd = {'method': 'p_sink_f2p_fwd', 'epsilon': 0.1, 'relax_ratio': 1.5}
        # P_p_sink_f2p_fwd, _ = get_ot_map(src_emb_, tgt_emb_, p_sink_f2p_fwd)
        # P_dict['f2p_fwd'] = P_p_sink_f2p_fwd
        
        # # ===== not recommended =====
        # # NOTE: p_sink_p2f_rev is for reverse align with partial-to-full sinkhorn
        # p_sink_p2f_rev = {'method': 'p_sink_p2f_rev', 'epsilon': 0.1, 'relax_ratio': 1.5}
        # P_p_sink_p2f_rev, _ = get_ot_map(src_emb_, tgt_emb_, p_sink_p2f_rev)
        # P_dict['p2f_rev'] = P_p_sink_p2f_rev
        
        
        fwd_str, rev_str = ot_mbr(P_dict, C)
        
        if 'maj_vote_fwd' not in results:
            results['maj_vote_fwd'] = [fwd_str]
        else:
            results['maj_vote_fwd'].append(fwd_str)
        
        if 'maj_vote_rev' not in results:
            results['maj_vote_rev'] = [rev_str]
        else:
            results['maj_vote_rev'].append(rev_str)        
        
        if 'union' not in results:
            results['union'] = [get_joint_alignments(fwd_str, rev_str, method='union')]
        else:
            results['union'].append(get_joint_alignments(fwd_str, rev_str,  method='union'))
        
        if 'intersection' not in results:
            results['intersection'] = [get_joint_alignments(fwd_str, rev_str, method='intersection')]
        else:
            results['intersection'].append(get_joint_alignments(fwd_str, rev_str))
            
        
        tmp = 0
        # ====== evaluation ======
        eval_result = evaluate_corpus_level(results['maj_vote_rev'], gold_aligen_list)
        f1 = eval_result['F1 Score']
        aer = eval_result['AER']
        precision = eval_result['Precision']
        recall = eval_result['Recall']
        
        avg_f1_str = f'{f1:.4f}'
        avg_aer_str = f'{aer:.4f}'
        avg_pre_str = f'{precision:.4f}'
        avg_rec_str = f'{recall:.4f}'
        
        pbar.set_postfix({'F1': avg_f1_str, 'AER': avg_aer_str, 'Precision': avg_pre_str, 'Recall': avg_rec_str})
        pbar.update(1)

for merge_method, results_list in results.items():
    eval_result = evaluate_corpus_level(results_list, gold_aligen_list)
    f1 = eval_result['F1 Score']
    aer = eval_result['AER']
    precision = eval_result['Precision']
    recall = eval_result['Recall']
    print(f'{merge_method}: F1: {f1:.4f}, AER: {aer:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}') 

        
    

Processing: 100%|██████████| 508/508 [08:32<00:00,  1.01s/it, F1=0.7463, AER=0.2540, Precision=0.7298, Recall=0.7637]  

maj_vote_fwd: F1: 0.7475, AER: 0.2525, Precision: 0.7477, Recall: 0.7473
maj_vote_rev: F1: 0.7463, AER: 0.2540, Precision: 0.7298, Recall: 0.7637
union: F1: 0.7287, AER: 0.2720, Precision: 0.6894, Recall: 0.7728
intersection: F1: 0.7288, AER: 0.2720, Precision: 0.6896, Recall: 0.7728





In [19]:
for merge_method, results_list in results.items():
    eval_result = evaluate_corpus_level(results_list, gold_aligen_list)
    f1 = eval_result['F1 Score']
    aer = eval_result['AER']
    precision = eval_result['Precision']
    recall = eval_result['Recall']
    print(f'{merge_method}: F1: {f1:.4f}, AER: {aer:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}') 


maj_vote_fwd: F1: 0.7475, AER: 0.2525, Precision: 0.7477, Recall: 0.7473
maj_vote_rev: F1: 0.7463, AER: 0.2540, Precision: 0.7298, Recall: 0.7637


[(13, 13), (14, 13), (15, 14), (16, 8), (17, 10), (17, 9), (17, 12), (17, 11), (18, 15)]
[(3, 1), (4, 4), (5, 10), (6, 7), (7, 9), (8, 2), (9, 8), (10, 6), (11, 5), (12, 12), (13, 13), (15, 14), (17, 11), (18, 15)]
F1 Score: 0.34782608695652173
AER: 0.5652173913043479


In [15]:
def map_original_to_sentencepiece(x1, x2):
    mapping = {}
    tokenized_index = 0
    original_index = 0

    # Replace with the actual underscore character used in your SentencePiece tokenization
    sentence_piece_underscore = '▁'

    while original_index < len(x2) and tokenized_index < len(x1):
        original_word = x2[original_index]
        subword_sequence = ''

        indices = []
        while tokenized_index < len(x1) and (subword_sequence != original_word):
            subword = x1[tokenized_index].lstrip(sentence_piece_underscore)
            if subword and (subword_sequence + subword == original_word[:len(subword_sequence) + len(subword)]):
                subword_sequence += subword
                indices.append(tokenized_index)
            tokenized_index += 1

        # Special handling for tokens that are just the SentencePiece underscore or punctuation
        if not indices and x1[tokenized_index - 1].strip(sentence_piece_underscore) == '':
            indices.append(tokenized_index - 1)

        if indices:
            mapping[original_index] = indices

        original_index += 1

    return mapping

# Example usage with SentencePiece tokenized and original sentences
x1_sp = ['▁Wir', '▁glauben', '▁nicht', '▁', ',', '▁daß', '▁wir', '▁nur', '▁Ros', 'inen', '▁heraus', 'pi', 'cken', '▁sollten', '▁', '.']
x2 = ['Wir', 'glauben', 'nicht', ',', 'daß', 'wir', 'nur', 'Rosinen', 'herauspicken', 'sollten', '.']
mapping = map_original_to_sentencepiece(x1_sp, x2)
print(mapping)


{0: [0], 1: [1], 2: [2], 3: [4], 4: [5], 5: [6], 6: [7], 7: [8, 9], 8: [10, 11, 12], 9: [13], 10: [15]}


In [4]:
def parse_alignment_string(alignment_str):
    return set(tuple(map(int, pair.split('-'))) for pair in alignment_str.split())

def grow_diag_final(forward_str, reverse_str, src_len, tgt_len):
    forward = parse_alignment_string(forward_str)
    reverse = parse_alignment_string(reverse_str)

    # Intersection
    alignment = forward & reverse

    # Define neighbors (8 surrounding points)
    neighbors = [(-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1), (1, -1), (1, 0), (1, 1)]

    # Grow step
    added = True
    while added:
        added = False
        new_points = set()
        for s, t in alignment:
            for dx, dy in neighbors:
                ns, nt = s + dx, t + dy
                if 0 <= ns < src_len and 0 <= nt < tgt_len and (ns, nt) not in alignment:
                    if (ns, nt) in forward or (ns, nt) in reverse:
                        new_points.add((ns, nt))
                        added = True
        alignment.update(new_points)

    # Final step
    for s in range(src_len):
        for t in range(tgt_len):
            if all((s, t_) not in alignment for t_ in range(tgt_len)) and (s, t) in forward:
                alignment.add((s, t))
            if all((s_, t) not in alignment for s_ in range(src_len)) and (s, t) in reverse:
                alignment.add((s, t))

    return alignment

# Example usage
forward_str = '0-0 1-2'
reverse_str = '0-0 2-1'
src_len = 3  # Length of source sentence
tgt_len = 3  # Length of target sentence

result = grow_diag_final(forward_str, reverse_str, src_len, tgt_len)
print(result)


{(1, 2), (2, 1), (0, 0)}


Balanced Grow-diag-final output: {(0, 0), (2, 2)}
