In [1]:
import sys
if "../" not in sys.path: sys.path.insert(0,"../");

## [EXPERIMENTAL] Building Pair Matching for Morphpiece

Core code to build the pairs in the code

### Building reliable byte map

In [2]:
from typing import List, Tuple
from itertools import chain

# for typing
byte = str

START_WORD_MARK = "<"
END_WORD_MARK = ">"

def prepare_word_to_bytes(text: str, sw: byte = START_WORD_MARK, ew: byte = END_WORD_MARK) -> List[byte]:
    """Converts word into a sequence of bytes"""
    text = text.strip() # trim logic must be abstracted outside
    assert len(text) >= 1, "Text length should be more than 1"
    
    if len(text) >= 2:
        return list(chain([sw + text[0], *text[1:-1], text[-1] + ew]))
                           
    return [(sw + text + ew)]

def generate_byte_pair(word_bytes: List[bytes]) -> List[Tuple[byte, byte]]:
    """Generates pairs of words form a word"""
    return zip(word_bytes[:-1], word_bytes[1:])

In [6]:
bs = prepare_word_to_bytes("anaenda")
list(generate_byte_pair(bs))

[('<a', 'n'), ('n', 'a'), ('a', 'e'), ('e', 'n'), ('n', 'd'), ('d', 'a>')]

In [3]:
from collections import Counter, abc
import itertools as it
# from functools import reduce
from typing import List, Tuple, Union, Iterable

from marynlp import funcutils as f


@f.apply(Counter)
def get_byte_level_counter(word: str) -> Counter:
    return prepare_word_to_bytes(word)

@f.apply(Counter)
def get_byte_pair_counter_from_word(word: str, word_to_bytes: abc.Callable) -> Counter:
    fn = f.apply(generate_byte_pair)(word_to_bytes)
    return Counter(fn(word)) + get_byte_level_counter(word)


def get_byte_pair_counter_from_words(words: Iterable[str], word_to_byte_counter: abc.Callable) -> Counter:
    # TODO: Multiple iterations should be done when constructing
    #  the byte pairs
    c = None
#     assert len(words) != 0, "You MUST pass words as iterables"
    for word in words:
        if c is None:
            c = word_to_byte_counter(word)
        else:
            c += word_to_byte_counter(word)
            
    return c


def merge_bytes(wbp: Iterable[Union[byte, Tuple[byte, byte]]]) -> List[Union[byte, str]]:
    return ["".join(i) if isinstance(i, tuple) else i for i in wbp]

# For inference
def bytes_merger(wb: List[byte], byte_pair_counter: Counter) -> List[Union[byte, Tuple[byte, byte]]]:
    var_1 = list(chain([wb[0], *zip(wb[:-1], wb[1:]), wb[-1]]))
    f1 = [var_1[i] for i in range(0, len(var_1), 2)]
    f2 = [var_1[i] for i in range(1, len(var_1), 2)]
    
#     print((f1, f2))
    
    # check the total pairs
    f1_matches = sum(map(lambda b: byte_pair_counter[b], f1))
    f2_matches = sum(map(lambda b: byte_pair_counter[b], f2))
#     print((f1_matches, f2_matches))

    chosen_pair = f1
    if f2_matches > f1_matches:
        chosen_pair = f2
        
    wbp = []
    for p in chosen_pair:
        if byte_pair_counter[p] == 0:
            if isinstance(p, tuple):
                # separate these
                wbp.append(p[0])
                wbp.append(p[1])
            else:
                wbp.append(p)
        else:
            wbp.append(p)
        
    return wbp


In [4]:
import re

def break_word_from_byte(byte_: byte, word: str):
    # temporary
    sw = word
    
    # holding the byte sequence
    flow = []

    it_ = list(re.finditer(byte_, word))
    
    # To check if the word is broken
    broken = False
    
    for i in it_[::-1]:
        # NOTE: might want to deal with this some how,
        # i.e. encode?
        target = 2 # sw[i.start(0):i.end(0)]

        sw, right = sw[:i.start(0)], sw[i.end(0):]
        broken = True

        # update the temp value
        flow.append((target, right))

    return broken, chain.from_iterable([(sw,), *flow[::-1]])


In [5]:
b, d = break_word_from_byte("cha", "chadema"); list(d)

['', 2, 'dema']

In [81]:
from typing import List, Iterable, Tuple
from collections import defaultdict

# When selecting this value, make sure it is not in
#  the range(0, len(tokens)) range of numbers
INDEX_FOR_OOV = -1
OOV_VAL = '<UNK>'

class Tokenizer(object):
    def __init__(self, tokens: List[byte]):
        self.tokens = sorted(tokens, key=lambda c: -len(c))
        self.enc_map = dict(zip(self.tokens, range(0, len(list(self.tokens)))))
        self.dec_map = { v: k for k, v in self.enc_map.items()}
        
    def decode(self, index: int) -> byte:
        try:
            return self.dec_map[index]
        except KeyError as err:
            if index == INDEX_FOR_OOV:
                return OOV_VAL
            raise err

    def encode(self, token: byte):
        """Encode word"""
        try:
            return self.enc_map[token]
        except KeyError:
            return INDEX_FOR_OOV
        
    def tokenize(self, word: str):
        # NOTE: make sure tokens are sorted from widest
        return self.demolish_word(word, self.tokens)
    
    def break_word_from_byte(self, byte_: byte, word: str) -> Tuple[bool, Iterable[Union[str, byte]]]:
        # temporary
        sw = word

        # holding the byte sequence
        flow = []

        # To check if the word is broken
        it_ = list(re.finditer(byte_, word))
        broken = False

        for i in it_[::-1]:
            target = self.encode(sw[i.start(0):i.end(0)])

            sw, right = sw[:i.start(0)], sw[i.end(0):]
            broken = True

            # update the temp value
            flow.append((target, right))

        return broken, chain.from_iterable([(sw,), *flow[::-1]])

    def demolish_word(self, word: str, sorted_subwords: List[byte]):
        word_in_list = [word]
        continue_break = True

        while continue_break:
            for sbw in sorted_subwords:
                item_should_break = False
    #             print("-->", sbw)
                collate = []

    #             print("Word in list:", list(word_in_list))
                for ix, w_tb in enumerate(word_in_list):
    #                 print('\tbreaking:', w_tb)
                    if not isinstance(w_tb, int):      
                        if len(w_tb.strip()) >= 0: 
                            is_broken, out = self.break_word_from_byte(sbw, w_tb)
                            item_should_break = item_should_break or is_broken  # update the broken status
                        else:
                            out = [w_tb]
                    else:
                        out = [w_tb]

                    # re-merge ?
    #                 if not is_broken:
    #                     print('not')
    #                     collate = [word_in_list]
    #                 else:
                    collate.append(list(out))


    #             print("Collate:", collate)
                word_in_list = list(chain.from_iterable(collate))
    #             print("AFTER COLLATE:", word_in_list, end="\n\n")

                # continue only of the list is broken
            continue_break = not item_should_break
            
        return list(filter(lambda s: isinstance(s, int) or len(str(s).strip()) > 0, word_in_list))

In [82]:
characters = 'abcdefghijklmnoprstuvwyz'
char = str

# derive the possible use for the characters
# TODO: Change API for `@f.flowBy`
@f.flowBy(set) 
def possible_character_use(characters: List[char], sc: byte = START_WORD_MARK, ec: byte = END_WORD_MARK):
    for c in characters:
        yield (sc + c, c, c + ec)
    
# chars = set(possible_character_use(characters))
chars = set([*list(characters), 'ch'])


# breaking word
word = "cheza"
sort_bytes = sorted(chars, key=lambda c: -len(c))

# check if 

tkr = WordBreaker(sort_bytes)
tkr.break_word("cheza")


KeyboardInterrupt: 

#### Using the created functions

In [92]:
from functools import partial
from pathlib import Path

data_path = Path("../resources/data")
helsinki_na_path = data_path / Path("./hcs-na-v2")

# File to test out the concept
sample_file = helsinki_na_path / Path("./new-mat/bunge/han1-2004.shu")

# prepared during the 'fitting' phase
# might want to sample and change to get_unique_words

In [152]:
from tqdm import tqdm
from collections import Counter

from marynlp import funcutils as f

def read_file(file):
    with open(file, "r") as f:
        return f.readlines()

def not_have_html(text: str) -> bool:
    return not text.find('<text') > -1 and not text.find('</text>') > -1

@f.flowBy(lambda t: t.split())
@f.forEach(lambda t: t.lower())
@f.filterBy(not_have_html)
def get_all_words_from_file(file_path: str):
    return read_file(file_path)

# building functions
get_unique_words_from_file =  f.apply(set)(get_all_words_from_file)
get_all_words_from_files = f.apply(list)(get_all_words_from_file)

get_bytes_from_word = partial(get_byte_pair_counter_from_word, word_to_bytes=prepare_word_to_bytes)
build_byte_pair_possible_combination_counter = partial(get_byte_pair_counter_from_words, word_to_byte_counter=get_bytes_from_word)

In [94]:
best_byte_pair_list = bytes_merger(prepare_word_to_bytes("anaenda"), byte_pair_counter); best_byte_pair_list

['<a', ('n', 'a'), ('e', 'n'), ('d', 'a>')]

In [95]:
build_byte_pair_possible_combination_counter(['anaenda', 'alienda'])

Counter({('<a', 'n'): 1,
         ('n', 'a'): 1,
         ('a', 'e'): 1,
         ('e', 'n'): 2,
         ('n', 'd'): 2,
         ('d', 'a>'): 2,
         '<a': 2,
         'n': 3,
         'a': 1,
         'e': 2,
         'd': 2,
         'a>': 2,
         ('<a', 'l'): 1,
         ('l', 'i'): 1,
         ('i', 'e'): 1,
         'l': 1,
         'i': 1})

In [181]:
import operator
from functools import reduce
from itertools import chain

## FRESH START
words = ['anaenda', 'alienda']

# Fitting all the possible combinations
all_possible_counter = build_byte_pair_possible_combination_counter(words)

# To store the new combinations
fresh_counter_state_list = []
discorage_counter_state_list = []
fn_word_to_bytes = prepare_word_to_bytes

for word in words:
    w2b = fn_word_to_bytes(word)
    print(w2b)
    best_pair = bytes_merger(w2b, all_possible_counter)
    
    # merge
    mb = merge_bytes(best_pair)
    
    fresh_counter_state_list.append(Counter(mb))
    discorage_counter_state_list.append(Counter(best_pair) + Counter(w2b))
    
print("\n\n")
print("Before update:")
print(all_possible_counter)
print("\n")
    
# Update the byte_pair_combinations list
for word, nc, ob in zip(words, fresh_counter_state_list, discorage_counter_state_list):
    print(word +":", nc, ob)
    all_possible_counter -= ob
    all_possible_counter += nc
# selected_combinations_counter  = reduce(operator.add, fresh_counter_state_list)
# discoraged_combinations_counter  = reduce(operator.add, discorage_counter_state_list)

# # Update the counter with infromation
# all_possible_counter = (all_possible_counter - discoraged_combinations_counter) + selected_combinations_counter

# def get_best_pair_from_word(word: str):
#     b = prepare_word_to_bytes(word)
#     b = bytes_merger(b, all_possible_counter)
#     return (b)
    
print("\n")
# # get_best_pair_from_word(word)
print("After update:")
all_possible_counter
# discoraged_combinations_counter
# w2b = greedy_byte_pair(word)
# bytes_merger(w2b, all_possible_counter)



['<a', 'n', 'a', 'e', 'n', 'd', 'a>']
['<a', 'l', 'i', 'e', 'n', 'd', 'a>']



Before update:
Counter({'n': 3, ('e', 'n'): 2, ('n', 'd'): 2, ('d', 'a>'): 2, '<a': 2, 'e': 2, 'd': 2, 'a>': 2, ('<a', 'n'): 1, ('n', 'a'): 1, ('a', 'e'): 1, 'a': 1, ('<a', 'l'): 1, ('l', 'i'): 1, ('i', 'e'): 1, 'l': 1, 'i': 1})


anaenda: Counter({'<a': 1, 'na': 1, 'en': 1, 'da>': 1}) Counter({'<a': 2, 'n': 2, ('n', 'a'): 1, ('e', 'n'): 1, ('d', 'a>'): 1, 'a': 1, 'e': 1, 'd': 1, 'a>': 1})
alienda: Counter({'<a': 1, 'li': 1, 'en': 1, 'da>': 1}) Counter({'<a': 2, ('l', 'i'): 1, ('e', 'n'): 1, ('d', 'a>'): 1, 'l': 1, 'i': 1, 'e': 1, 'n': 1, 'd': 1, 'a>': 1})


After update:


Counter({('<a', 'n'): 1,
         ('a', 'e'): 1,
         ('n', 'd'): 2,
         ('<a', 'l'): 1,
         ('i', 'e'): 1,
         'na': 1,
         'en': 2,
         'da>': 2,
         '<a': 1,
         'li': 1})

In [182]:

fn_word_to_bytes = f.apply(partial(bytes_merger, byte_pair_counter=all_possible_counter))(prepare_word_to_bytes)

merge_bytes()

['<k', 'e', 'v', 'i', 'n>']

In [179]:
import operator
operator.add

ModuleNotFoundError: No module named 'ooperator'

In [24]:
bytes_merger(prepare_word_to_bytes("ataenda"), byte_pair_counter)

['<a', 't', 'a', ('e', 'n'), ('d', 'a>')]

In [46]:
# words = get_all_words_from_files(sample_file)
words = ['anaenda', 'alienda']

# # contains the proper byte pair list
# # NOTE: should change to pure function
global_byte_counter = Counter()

# Build the 
byte_pair_counter = build_byte_pair_possible_combination_counter(words)

# """Function that computes the most relevant byte pairs in a word from a given byte pair collection"""
greedy_byte_pair = f.apply(partial(bytes_merger, byte_pair_counter=byte_pair_counter))(prepare_word_to_bytes)

test_words = ['alienda', 'ameenda']

# # in each word, outputs a byte_pair_counter
for word in tqdm(test_words):
    # initial word break down
    b = greedy_byte_pair(word)
    
    # word byte counter for the word
    wbc = Counter(b)
    global_byte_counter += wbc
    mb = merge_bytes(b)
    print(word+ ":", b,"->", mb)    
    
    
#     break

100%|██████████| 2/2 [00:00<00:00, 8783.88it/s]

alienda: ['<a', ('l', 'i'), ('e', 'n'), ('d', 'a>')] -> ['<a', 'li', 'en', 'da>']
ameenda: ['<a', 'm', 'e', ('e', 'n'), ('d', 'a>')] -> ['<a', 'm', 'e', 'en', 'da>']





In [47]:
infer_word_to_bytes = f.apply(merge_bytes)(greedy_byte_pair)
infer_word_to_bytes("ameenda")

['<a', 'm', 'e', 'en', 'da>']

In [54]:
# """Function that computes the most relevant byte pairs in a word from a given byte pair collection"""
greedy_byte_pair_ = f.apply(partial(bytes_merger, byte_pair_counter=global_byte_counter))(infer_word_to_bytes)

greedy_byte_pair("ameenda")
# test_words = ['alienda', 'ameenda']

# # # in each word, outputs a byte_pair_counter
# for word in tqdm(test_words):
#     # initial word break down
#     b = greedy_byte_pair_(word)
    
#     # word byte counter for the word
#     wbc = Counter(b)
#     global_byte_counter += wbc
#     mb = merge_bytes(b)
#     print(word+ ":", b,"->", mb)    
    
    
#     break

['<a', 'm', 'e', ('e', 'n'), ('d', 'a>')]

In [44]:
prepare_word_to_bytes("amenda"), merge_bytes(greedy_byte_pair("ameenda"))

(['<a', 'm', 'e', 'n', 'd', 'a>'], ['<a', 'm', 'e', 'en', 'da>'])

In [45]:
# second iteration
obp = f.apply(partial(bytes_merger, byte_pair_counter=global_byte_counter))(infer_word_to_bytes)

for word in tqdm(test_words):
    # initial word break down
    b_ = obp(word)
#     print(b_)
    # word byte counter for the word
#     wbc = Counter(b)
    mb = merge_bytes(b_)
    print(word+ ":", b_,"->", mb)

100%|██████████| 2/2 [00:00<00:00, 9776.93it/s]

alienda: ['<a', 'li', 'en', 'da>'] -> ['<a', 'li', 'en', 'da>']
ameenda: ['<a', 'm', 'e', 'en', 'da>'] -> ['<a', 'm', 'e', 'en', 'da>']





In [77]:

clone_bc = Counter(global_byte_counter)
uc = Counter(generate_byte_pair(wbl))

bcc = clone_bc + uc + Counter([('en', 'da>')])

wbl = infer_word_to_bytes("alienda")
bytes_merger(wbl, bcc), bcc

(['<a', ('li', 'en'), 'da>'], [('<a', 'li'), ('en', 'da>')])


(['<a', ('li', 'en'), 'da>'],
 Counter({'<a': 6,
          ('l', 'i'): 1,
          ('e', 'n'): 2,
          ('d', 'a>'): 2,
          'm': 3,
          'e': 3,
          'li': 2,
          'en': 4,
          'da>': 4,
          ('<a', 'li'): 1,
          ('li', 'en'): 1,
          ('en', 'da>'): 2}))

In [43]:
greedy_byte_pair("hapana")

['<h', ('a', 'p'), ('a', 'n'), 'a>']

In [None]:
# Build the byte pair of the word from 
optimistic_byte_pair = f.apply(partial(bytes_merger, byte_pair_counter=global_byte_counter))(greedy_byte_pair)

In [41]:

break_word = f.apply(convert_to_word)(greedy_byte_pair)
break_word("hapana")

['<h', 'ap', 'an', 'a>']

In [29]:
from functools import partial

text = "miswada"
ls_byte = break_word(text)
list(generate_byte_pair(ls_byte)), greedy_byte_pair(text)
# 
# get_byte_pair_counter_from_word("miswada", break_word)

NameError: name 'break_word' is not defined

In [None]:

break_word = f.apply(convert_to_word)(greedy_byte_pair)

In [75]:
for word in get_words():
    b = greedy_byte_pair(word, global_byte_counter)
    print(b)
    break

[('<m', 'i'), ('s', 'w'), ('a', 'd'), 'a>']


In [46]:
text = "miswada"
greedy_byte_pair(text, global_byte_counter), greedy_byte_pair(text, byte_pair_counter)

(['<m', ('i', 's'), ('w', 'a'), ('d', 'a>')],
 [('<m', 'i'), ('s', 'w'), ('a', 'd'), 'a>'])

In [100]:
global_byte_counter

Counter()

In [None]:
global_byte_counter

In [25]:
get_byte_pair_counter_from_word('alienda')

Counter({('<a', 'l'): 1,
         ('l', 'i'): 1,
         ('i', 'e'): 1,
         ('e', 'n'): 1,
         ('n', 'd'): 1,
         ('d', 'a>'): 1,
         '<a': 1,
         'l': 1,
         'i': 1,
         'e': 1,
         'n': 1,
         'd': 1,
         'a>': 1})

In [23]:
byte_pair_counter

Counter({('<m', 'i'): 348,
         ('i', 's'): 671,
         ('s', 'w'): 29,
         ('w', 'a'): 862,
         ('a', 'd'): 562,
         ('d', 'a>'): 316,
         '<m': 2237,
         'i': 0,
         's': 0,
         'w': 0,
         'a': 0,
         'd': 0,
         'a>': 11796,
         ('<s', 'a'): 383,
         ('a', 'b'): 520,
         ('b', 'a>'): 402,
         '<s': 1058,
         'b': 0,
         ('<k', 'a'): 735,
         ('a', 't'): 1560,
         ('t', 'i>'): 242,
         '<k': 2691,
         't': 0,
         'i>': 4231,
         ('<y', 'a>'): 1068,
         '<y': 1446,
         ('<h', 'i'): 633,
         ('i', 'y'): 258,
         ('y', 'o>'): 342,
         '<h': 1169,
         'y': 0,
         'o>': 1668,
         ('<i', 'm'): 83,
         ('m', 'e'): 390,
         ('e', 'k'): 351,
         ('k', 'w'): 98,
         ('w', 'i'): 104,
         ('s', 'h'): 1071,
         ('h', 'a'): 787,
         ('a', 'p'): 299,
         ('p', 'a'): 292,
         ('t', 'a>'): 255,
       

In [13]:
words = ['anaenda', 'alienda', 'ataenda', 'nimekuja', 'shuka']
for word in words:
    print(break_word(word, byte_pair_counter))
    
    
# second iteration


[('<a', 'n'), ('a', 'e'), ('n', 'd'), 'a>']
[('<a', 'l'), ('i', 'e'), ('n', 'd'), 'a>']
[('<a', 't'), ('a', 'e'), ('n', 'd'), 'a>']
['<n', ('i', 'm'), ('e', 'k'), ('u', 'j'), 'a>']
[('<s', 'h'), ('u', 'k'), 'a>']


In [22]:
c = Counter([('a', 'n'), ('n', 'b'), ('a', 'n')])
c[('a', 'n')]

2

In [86]:
byte_pair_counter[('a', 'n')]

0

In [150]:
byte_pair_counter.most_common()

[('a>', 5),
 ('e', 4),
 ('n', 4),
 (('e', 'n'), 3),
 (('n', 'd'), 3),
 (('d', 'a>'), 3),
 ('<a', 3),
 ('d', 3),
 (('a', 'e'), 2),
 ('a', 2),
 ('i', 2),
 ('u', 2),
 ('k', 2),
 (('<a', 't'), 1),
 (('t', 'a'), 1),
 ('t', 1),
 (('<a', 'l'), 1),
 (('l', 'i'), 1),
 (('i', 'e'), 1),
 ('l', 1),
 (('<a', 'n'), 1),
 (('n', 'a'), 1),
 (('<s', 'h'), 1),
 (('h', 'u'), 1),
 (('u', 'k'), 1),
 (('k', 'a>'), 1),
 ('<s', 1),
 ('h', 1),
 (('<n', 'i'), 1),
 (('i', 'm'), 1),
 (('m', 'e'), 1),
 (('e', 'k'), 1),
 (('k', 'u'), 1),
 (('u', 'j'), 1),
 (('j', 'a>'), 1),
 ('<n', 1),
 ('m', 1),
 ('j', 1)]

In [95]:
import re

list(re.finditer(b're', b'brsom'))

[]