In [1]:
import sys
# Put the parent directory at the front of the import search path. The
# membership check keeps a re-run of this cell from inserting it twice.
if "../" not in sys.path:
    sys.path.insert(0, "../")

# Building a word breaker

In [119]:
class token(str):
    """A ``str`` subclass that remembers the object it was built from.

    The instance behaves like the string produced by ``str(o, ...)``; the
    original constructor argument is kept on ``_o`` and is what ``__repr__``
    shows.
    """

    def __new__(cls, o, *args, **kwargs):
        # ``str`` is immutable, so the value has to be fixed in ``__new__``.
        # Any extra arguments are forwarded to ``str`` unchanged.
        _str = str.__new__(cls,  o, *args, **kwargs)
        _str._o = o
        return _str

    def __repr__(self):
        # The operand must be a 1-tuple: with a bare ``% self._o`` a tuple
        # source object would be unpacked into several format arguments and
        # raise TypeError.
        return "t'%s'" % (self._o,)
    
class byte(token):
    """Unit representation of a value."""

    def __repr__(self):
        # 1-tuple operand so that a tuple source object is formatted as a
        # whole instead of being unpacked into several format arguments.
        return "b'%s'" % (self._o,)

class morph(token):
    """Token object that represents a subword."""

    def __repr__(self):
        # 1-tuple operand so that a tuple source object is formatted as a
        # whole instead of being unpacked into several format arguments.
        return "m'%s'" % (self._o,)

class word(token):
    """Token object that represents a word."""

    def __repr__(self):
        # 1-tuple operand so that a tuple source object is formatted as a
        # whole instead of being unpacked into several format arguments.
        return "w'%s'" % (self._o,)

In [120]:
from typing import List, Iterable, Union, Sequence, Tuple
from itertools import chain

from functools import partial
from marynlp import funcutils as f

def add_start_end_mark_on_byte_sequence(byte_seq: Sequence[byte], sw: byte, ew: byte) -> Sequence[byte]:
    """Attach the start marker to the first byte and the end marker to the last.

    Args:
        byte_seq: The bytes of one word; must support ``len`` and
            indexing/slicing.
        sw: Start marker, prepended to the first byte.
        ew: End marker, appended to the last byte.

    Returns:
        A new list. For string-like bytes, joining it always gives
        ``sw + "".join(byte_seq) + ew``:

        * two or more bytes: only the first and last items carry a marker,
        * one byte: a single item carrying both markers,
        * no bytes: a single item made of the two markers.
    """
    if len(byte_seq) == 0:
        # There is nothing to attach the markers to; emit them on their own
        # instead of failing with IndexError on ``byte_seq[0]``.
        return [sw + ew]

    if len(byte_seq) == 1:
        return [sw + byte_seq[0] + ew]

    return [sw + byte_seq[0], *byte_seq[1:-1], byte_seq[-1] + ew]

# function to break a word into its bytes
def simple_break_word(w: Union[word, str]) -> Iterable[byte]:
    """Split a word into a list of its single-character pieces.

    The input is wrapped in ``word`` first and then iterated character by
    character. Note that iterating a ``str`` subclass yields plain ``str``
    characters, so the items of the list are not ``byte`` instances.
    """
    wrapped = word(w)
    return [*wrapped]

def get_all_possible_byte_pair_in_sequence(byte_seq: Sequence[byte], sw: byte, ew: byte) -> Iterable[Tuple[byte, byte]]:
    """Return every adjacent pair of the word once it is wrapped in its markers.

    The word is treated as ``sw, b0, b1, ..., bn, ew`` and each neighbouring
    couple is emitted in order: ``(sw, b0), (b0, b1), ..., (bn, ew)``.

    Args:
        byte_seq: The bytes of one word; any iterable is accepted.
        sw: Start marker placed before the first byte.
        ew: End marker placed after the last byte.

    Returns:
        A one-shot iterator of 2-tuples. An empty ``byte_seq`` yields the
        single pair ``(sw, ew)`` instead of raising IndexError.
    """
    # Padding with the markers first means the boundary pairs need no special
    # casing and the empty word falls out naturally.
    padded = (sw, *byte_seq, ew)
    return zip(padded, padded[1:])

def get_possible_byte_pair_sequence(byte_seq: Iterable[byte], sw: byte, ew: byte) -> Tuple[Tuple[Tuple[byte, byte], ...], Tuple[Tuple[byte, byte], ...]]:
    """Split the marker-wrapped adjacent pairs of a word into two alternating groups.

    Args:
        byte_seq: The bytes of one word.
        sw: Start marker.
        ew: End marker.

    Returns:
        A 2-tuple ``(even, odd)``: the pairs found at positions 0, 2, 4, ...
        of the adjacent-pair list, and the pairs found at positions 1, 3, 5, ....
        Each group is itself a tuple of 2-tuples.
    """
    bps = tuple(get_all_possible_byte_pair_in_sequence(byte_seq, sw, ew))
    # Extended slicing picks every second pair, starting at 0 and at 1.
    return bps[::2], bps[1::2]

# Markers passed as `sw` / `ew` to the byte-pair helpers to delimit a word.
# NOTE(review): annotated as `byte`, but the values are plain `str` literals.
START_BYTE_MARKER: byte = "<"
END_BYTE_MARKER: byte = ">"
    
# Callable built from `simple_break_word` and the marker-bound pair extractor.
# NOTE(review): `f.apply` / `f.partial` come from `marynlp.funcutils`, which is not
# shown here; presumably `simple_break_word` runs first and its result is fed to
# `get_all_possible_byte_pair_in_sequence` — confirm against that module.
prepare_word = f.apply(f.partial(get_all_possible_byte_pair_in_sequence, sw=START_BYTE_MARKER, ew=END_BYTE_MARKER))(simple_break_word)

In [121]:

# Demo: split one word into characters and look at its byte pairs.
out = simple_break_word("anaenda")
# All adjacent pairs (markers included) as a lazy iterator; it is not consumed in this cell.
bps = get_all_possible_byte_pair_in_sequence(out, sw=START_BYTE_MARKER, ew=END_BYTE_MARKER)

# Displayed result: the pairs at even positions, then the pairs at odd positions.
get_possible_byte_pair_sequence(out, sw=START_BYTE_MARKER, ew=END_BYTE_MARKER)

((('<', 'a'), ('n', 'a'), ('e', 'n'), ('d', 'a')),
 (('a', 'n'), ('a', 'e'), ('n', 'd'), ('a', '>')))

In [122]:
from collections import Counter

# Toy corpus used by the cells below.
words = ['kaenda', 'haendi', 'anaenda', 'rudi']

# Counter for single bytes (created here, not filled in this cell).
byte_counter = Counter()
# How often each adjacent byte pair (markers included) occurs across `words`.
byte_pair_counter = Counter()

# add all the pairs
for w in words:
    byte_seq = simple_break_word(w)
    all_possible = tuple(get_all_possible_byte_pair_in_sequence(byte_seq, sw=START_BYTE_MARKER, ew=END_BYTE_MARKER))
    # Count in place: `counter + Counter(...)` would build a brand-new Counter
    # for every word.
    byte_pair_counter.update(all_possible)


In [134]:
def break_word(wo: word, counter_object: Counter):
    """Placeholder for the word-breaking step: for now it only prints ``wo``.

    ``counter_object`` is accepted but not used by the current body.
    """
    print(wo)

# Try the stub on one word; it prints the word that was passed in.
break_word("haendi", Counter())

haendi


In [132]:
"""
Get Pairs
"""

def score_function(byte_pair_sequence, byte_pair_counter):
    """Add up the counts of every pair in ``byte_pair_sequence``.

    Each pair is looked up with ``byte_pair_counter[pair]`` and the looked-up
    values are summed, starting from 0.
    """
    total = 0
    for pair in byte_pair_sequence:
        total += byte_pair_counter[pair]
    return total

# @f.apply(token)
def merge_bytes(byte_pair: Sequence[byte]) -> str:
    """Concatenate the bytes of a pair into one merged string.

    Args:
        byte_pair: The string-like bytes to merge, in order.

    Returns:
        A plain ``str`` holding the bytes joined without a separator.
    """
    return "".join(byte_pair)

# Scorer with the corpus-wide pair counts bound in.
# NOTE(review): `f.partial` comes from `marynlp.funcutils`; presumably it behaves like
# `functools.partial` — confirm.
score_fn = f.partial(score_function, byte_pair_counter=byte_pair_counter)
# A pair is merged only when its count is strictly greater than this threshold.
MERGE_SCORE = 1

# Can start with different one
# Counts of the resulting units: merged pairs and leftover single bytes.
group_counter = Counter()

for w in words:
    byte_seq = simple_break_word(w)
    # Two alternative groupings of the word's adjacent pairs:
    # `left` holds the pairs at even positions, `right` those at odd positions.
    all_possible_forms = tuple(get_possible_byte_pair_sequence(byte_seq, sw=START_BYTE_MARKER, ew=END_BYTE_MARKER))
    left, right = all_possible_forms

    # Score each grouping by summing the corpus counts of its pairs.
    left_score, right_score = score_fn(left), score_fn(right)

    # NOTE(review): `right` is chosen when `left` scores higher, i.e. the
    # lower-scoring grouping wins (ties keep `left`) — confirm this is intended.
    chosen_pair = left
    if left_score > right_score:
        chosen_pair = right

    for cp in chosen_pair:
        b_score = byte_pair_counter[cp]
        if b_score > MERGE_SCORE:
            # Frequent enough: count the pair as one merged unit.
            group_counter[merge_bytes(cp)] += 1
        else:
            # Not frequent enough: count its bytes one by one, skipping the
            # start/end markers.
            for byt in filter(lambda b: b != START_BYTE_MARKER and b != END_BYTE_MARKER, cp):
                group_counter[byt] += 1

# Display the units, most frequent first.
group_counter.most_common()

[('a', 4),
 ('en', 3),
 ('da', 2),
 ('di', 2),
 ('k', 1),
 ('h', 1),
 ('n', 1),
 ('r', 1),
 ('u', 1)]