From Subword Neural Machine Translation https://github.com/rsennrich/subword-nmt

In [93]:
import os
import sys
import inspect
import codecs
import re
import copy
import argparse
import warnings
import tempfile
from multiprocessing import Pool, cpu_count
from collections import defaultdict, Counter
from utils import get_vocabulary  , get_pair_statistics, prune_stats, replace_pair, update_pair_statistics, BPE, get_vocab
from tqdm import tqdm


In [48]:
# Root directories come from the environment; a missing variable raises KeyError here.
DATASET_PATH = os.environ['DL_DATASET']  # root under which the parallel corpus is looked up
DL_PATH = os.environ['DL_PATH']  # base directory for the files this notebook writes (under "bpe")

In [49]:
DATASET_PATH

'/media/mtb/nas/datasets/'

In [50]:
# Paths to the two sides of the English/French CommonCrawl parallel corpus.
eng_file = os.path.join(DATASET_PATH, "training-parallel-commoncrawl", "commoncrawl.fr-en.en")
french_file = os.path.join(DATASET_PATH, "training-parallel-commoncrawl", "commoncrawl.fr-en.fr")

In [51]:
eng_file

'/media/mtb/nas/datasets/training-parallel-commoncrawl/commoncrawl.fr-en.en'

In [52]:
# UTF-8 reader on the English corpus. It is left open on purpose: later cells
# reuse `infile.name` to reopen the same file with independent handles.
infile = codecs.open(eng_file, encoding='utf-8')

In [53]:
infile

<codecs.StreamReaderWriter at 0x7fa0d2375f10>

In [54]:
# Show the first two lines of the corpus. `range(2)` is the first zip argument
# so the loop stops without pulling a third line out of `infile`.
for line_no, sentence in zip(range(2), infile):
    print(sentence)

* - Main goods are marked with red color.

* - Main servises are marked with red color.



In [55]:
next(infile)

'Services of language translation of the...\n'

In [56]:
# Output file for the learned BPE merge operations (written by the merge loop below).
outfile= codecs.open(os.path.join(DL_PATH, "bpe", "eng_file.txt"), "w", encoding="utf-8")

In [57]:
# Print the corpus size in bytes, as reported by fstat on the open file descriptor.
with open(infile.name, encoding="utf-8") as f:
    print(os.fstat(f.fileno()).st_size)

434655470


In [58]:
# Baseline for the seek experiments that follow: read and show the first line.
with open(infile.name, encoding="utf8") as f:
    line = f.readline()
    print(line)

* - Main goods are marked with red color.



In [59]:
# Seek to offset 5 and read to the end of that line: the output starts mid-line,
# so seek() does not snap to a line boundary.
# NOTE(review): the io docs only define text-mode seek() for 0 or values returned
# by tell(); confirm that arbitrary offsets are acceptable for this file/encoding.
with open(infile.name, encoding="utf8") as f:
    f.seek(5)
    line = f.readline()
    print(line)

ain goods are marked with red color.



In [60]:
# Offset 176 is the first character of a line, so readline() returns a complete sentence.
with open(infile.name, encoding="utf8") as f:
    f.seek(176)
    line = f.readline()
    print(line)

Goods and services advancement through the P.O.Box system is NOT ALLOWED.



In [61]:
# One position earlier (175) holds only the line terminator: readline() yields a bare "\n".
with open(infile.name, encoding="utf8") as f:
    f.seek(175)
    line = f.readline()
    assert line == "\n"

In [62]:
# Offset 180 is inside a line again: readline() returns only the remainder of that line.
with open(infile.name, encoding="utf8") as f:
    f.seek(180)
    line = f.readline()
    print(line)

s and services advancement through the P.O.Box system is NOT ALLOWED.



In [63]:
# Same as the previous cell, but also print tell() to check the position reported after seek(180).
with open(infile.name, encoding="utf8") as f:
    f.seek(180)
    print(f.tell())
    line = f.readline()
    print(line)

180
s and services advancement through the P.O.Box system is NOT ALLOWED.



In [64]:
# Split the corpus into `num_workers` ranges whose boundaries fall on line starts.
# For each interior boundary: jump to the nominal position, consume the (partial)
# line found there, and record where the next full line begins.
with open(infile.name, encoding="utf8") as f:
    size = os.fstat(f.fileno()).st_size
    num_workers = 8
    # Floor division keeps the arithmetic in exact integers (no float round-trip).
    chunk_size = size // num_workers
    # offsets[0] is the start of the file. The last slot is never assigned by the
    # loop below and therefore stays 0.
    offsets = [0] * (num_workers + 1)
    for i in range(1, num_workers):
        f.seek(chunk_size * i)
        pos = f.tell()
        while True:
            try:
                line = f.readline()
                print(line)
                break
            except UnicodeDecodeError:
                # The read could not be decoded from this position (likely the
                # middle of a multi-byte character): step back one and retry.
                pos -= 1
                f.seek(pos)
        print(chunk_size * i, f.tell())
        offsets[i] = f.tell()
        # Debug output: the first full line of chunk i.
        print("Cursor ", f.readline())
        assert 0 <= offsets[i] < 1e20, "Bad new line separator, e.g. '\\r'"


 is located in the Kirchstraße 54, in the south-west district of Bergheim. The Cologne-Bonn airport is only a 73-minute drive away and is 45 km away from the hotel.

54331933 54332099
Cursor  The district town of Bergheim, with its cultural sights, downtown and ideal location, is the ideal base for your stay.The lovingly arranged rooms exude cosiness and lightness even in the absence of summer and sun. Each of our 16 rooms is barrier-free and appointed exclusively.

ges in national central banks' relative income positions, the amount of income to be allocated pursuant to Article 32 shall be reduced by a uniform percentage which shall not exceed 60% in the first financial year after the start of the third stage and which shall decrease by at least 12 percentage points in each subsequent financial year.

108663866 108664208
Cursor  Article 51.1 shall be applicable for not more than five financial years after the start of the third stage.

 the guidelines to use TABS rather than spaces.



In [65]:
offsets

[0,
 54332099,
 108664208,
 162995847,
 217327849,
 271659715,
 325991719,
 380323634,
 0]

In [66]:
# For every consecutive pair of offsets, show the line found at the chunk start
# and the line found at the chunk end.
# Note: the final offset is 0, so the last "End" line printed is the first line of the file.
with open(infile.name, encoding="utf8") as f:
    for begin_offset, end_offset in zip(offsets[:-1], offsets[1:]):
        f.seek(begin_offset)
        print(begin_offset, end_offset)
        print("Start ", f.readline())
        f.seek(end_offset)
        print("End   ", f.readline())

0 54332099
Start  * - Main goods are marked with red color.

End    The district town of Bergheim, with its cultural sights, downtown and ideal location, is the ideal base for your stay.The lovingly arranged rooms exude cosiness and lightness even in the absence of summer and sun. Each of our 16 rooms is barrier-free and appointed exclusively.

54332099 108664208
Start  The district town of Bergheim, with its cultural sights, downtown and ideal location, is the ideal base for your stay.The lovingly arranged rooms exude cosiness and lightness even in the absence of summer and sun. Each of our 16 rooms is barrier-free and appointed exclusively.

End    Article 51.1 shall be applicable for not more than five financial years after the start of the third stage.

108664208 162995847
Start  Article 51.1 shall be applicable for not more than five financial years after the start of the third stage.

End    I'm often guilty of this. Remember to run repoman over your ebuilds so it can tell you if

In [67]:
# Build the vocabulary with the helper from utils, here with 10 workers.
# Presumably a mapping from word (as a sequence of symbols) to its count in the
# corpus, judging by how the following cells use it — confirm against utils.get_vocabulary.
vocab = get_vocabulary(infile, num_workers=10)

0 43465562

43465562 86931099

86931099 130396695

130396695 173862295

173862295 217327849

217327849 260793349

260793349 304258912

304258912 347724605

347724605 391189932

391189932 0

/tmp/tmpes3maj3u
/tmp/tmp91vt48ww
/tmp/tmpojn9n57p
/tmp/tmpcoe1ujpy
/tmp/tmp6negxf2b
/tmp/tmpvbadtsa8
/tmp/tmp41zre10b
/tmp/tmpg76z7juf
/tmp/tmpqvfv9qh4
/tmp/tmp5z1js_xp


In [68]:
len(vocab)

1918160

In [69]:
# Re-key every word as a tuple of symbols whose last symbol carries the
# end-of-word marker '</w>'. This rebinds `vocab`, so running the cell a second
# time would append the marker again.
vocab = {
    tuple(word[:-1]) + (word[-1] + '</w>',): count
    for word, count in vocab.items()
}

In [70]:
list(vocab.keys())[:10]

[('*</w>',),
 ('-</w>',),
 ('M', 'a', 'i', 'n</w>'),
 ('g', 'o', 'o', 'd', 's</w>'),
 ('a', 'r', 'e</w>'),
 ('m', 'a', 'r', 'k', 'e', 'd</w>'),
 ('w', 'i', 't', 'h</w>'),
 ('r', 'e', 'd</w>'),
 ('c', 'o', 'l', 'o', 'r', '.</w>'),
 ('s', 'e', 'r', 'v', 'i', 's', 'e', 's</w>')]

In [71]:
vocab[('M', 'a', 'i', 'n</w>')]

1544

In [72]:
# List of (word, count) pairs ordered from the most to the least frequent word.
sorted_vocab = sorted(vocab.items(), key=lambda x: x[1], reverse=True)

In [73]:
sorted_vocab[0]

(('t', 'h', 'e</w>'), 3967456)

In [74]:
# Pair statistics from the helper in utils. As the outputs below show, `stats`
# maps a symbol pair to a count and `indices[pair]` maps integers to counts —
# presumably positions in `sorted_vocab`; confirm against utils.get_pair_statistics.
stats, indices = get_pair_statistics(sorted_vocab)

In [75]:
stats

defaultdict(int,
            {('t', 'h'): 6345851,
             ('h', 'e</w>'): 4656730,
             ('o', 'f</w>'): 2275816,
             ('a', 'n'): 4732890,
             ('n', 'd</w>'): 2670997,
             ('t', 'o</w>'): 1916214,
             ('i', 'n</w>'): 1604122,
             ('i', 's</w>'): 1427088,
             ('f', 'o'): 1389272,
             ('o', 'r</w>'): 1378733,
             ('w', 'i'): 1182446,
             ('i', 't'): 2695836,
             ('t', 'h</w>'): 820613,
             ('T', 'h'): 859608,
             ('o', 'n</w>'): 1771216,
             ('h', 'a'): 1759247,
             ('a', 't</w>'): 913141,
             ('a', 'r'): 2800787,
             ('r', 'e</w>'): 1344472,
             ('y', 'o'): 848980,
             ('o', 'u</w>'): 506580,
             ('b', 'y</w>'): 385953,
             ('b', 'e</w>'): 386404,
             ('a', 's</w>'): 842021,
             ('f', 'r'): 558271,
             ('r', 'o'): 2209515,
             ('o', 'm</w>'): 469607,
           

In [76]:
indices[('y', 'y')]

defaultdict(int,
            {139265: 1,
             145225: 1,
             148126: 1,
             183363: 1,
             221409: 1,
             229226: 1,
             230071: 1,
             230711: 1,
             265511: 1,
             266599: 3,
             271752: 1,
             289694: 1,
             296273: 1,
             307182: 1,
             334689: 1,
             341728: 3,
             341730: 2,
             361468: 1,
             373300: 1,
             390204: 1,
             402288: 1,
             404983: 1,
             422631: 1,
             427785: 1,
             456637: 1,
             457915: 1,
             502112: 1,
             503339: 3,
             504554: 2,
             504625: 1,
             514965: 1,
             517033: 1,
             518460: 1,
             523285: 1,
             526994: 1,
             531202: 1,
             532445: 1,
             535937: 1,
             536121: 1,
             540597: 1,
             551985: 1,

In [77]:
# Independent copy of the pair statistics; the merge loop passes it to
# prune_stats and rebuilds `stats` from it when the pruned view runs dry.
big_stats = copy.deepcopy(stats)

In [78]:
total_symbols = True

In [79]:
num_symbols = 10000

In [80]:
num_symbols

10000

In [81]:
if total_symbols:
    # Symbols seen in any position of a word except the last one ...
    uniq_char_internal = {char for word in vocab for char in word[:-1]}
    # ... and symbols seen in the last position of a word.
    uniq_char_final = {word[-1] for word in vocab}
    n_internal, n_final = len(uniq_char_internal), len(uniq_char_final)
    sys.stderr.write('Number of word-internal characters: {0}\n'.format(n_internal))
    sys.stderr.write('Number of word-final characters: {0}\n'.format(n_final))
    sys.stderr.write('Reducing number of merge operations by {0}\n'.format(n_internal + n_final))
    # The single symbols count towards the symbol budget, so fewer merges remain.
    # This mutates `num_symbols`; re-running the cell subtracts again.
    num_symbols -= n_internal + n_final

Number of word-internal characters: 2091
Number of word-final characters: 1029
Reducing number of merge operations by 3120


In [82]:
num_symbols

6880

In [83]:
uniq_char_final

{'이</w>',
 'ử</w>',
 '風</w>',
 '\x91</w>',
 'Q</w>',
 'Ю</w>',
 '画</w>',
 '𠄲</w>',
 'も</w>',
 'ה</w>',
 'こ</w>',
 'Ì</w>',
 'ب</w>',
 'l</w>',
 'Ê</w>',
 'つ</w>',
 'ờ</w>',
 'Љ</w>',
 '℣</w>',
 'х</w>',
 'ủ</w>',
 'ך</w>',
 'お</w>',
 'ć</w>',
 'Ь</w>',
 ']</w>',
 'á</w>',
 'м</w>',
 '物</w>',
 'া</w>',
 'ר</w>',
 'ν</w>',
 'ľ</w>',
 'ぶ</w>',
 '菜</w>',
 'Ш</w>',
 'プ</w>',
 '優</w>',
 '器</w>',
 'i</w>',
 '5</w>',
 '编</w>',
 '셔</w>',
 'ː</w>',
 'ồ</w>',
 '谷</w>',
 'მ</w>',
 'บ</w>',
 '.</w>',
 'ś</w>',
 '家</w>',
 '게</w>',
 'ф</w>',
 'ể</w>',
 '♫</w>',
 '銘</w>',
 'λ</w>',
 'Т</w>',
 'Ç</w>',
 'ส</w>',
 ',</w>',
 'グ</w>',
 'რ</w>',
 'T</w>',
 'ơ</w>',
 'O</w>',
 'ổ</w>',
 'ه</w>',
 'Þ</w>',
 'ī</w>',
 'げ</w>',
 '¹</w>',
 '鏡</w>',
 'ء</w>',
 'ت</w>',
 '1</w>',
 '장</w>',
 'ㅎ</w>',
 'œ</w>',
 '『</w>',
 '所</w>',
 '雷</w>',
 '飼</w>',
 '|</w>',
 '}</w>',
 'ฯ</w>',
 'י</w>',
 'Ⅱ</w>',
 'Ќ</w>',
 'w</w>',
 'Ο</w>',
 '▪</w>',
 '队</w>',
 'ک</w>',
 'ะ</w>',
 'j</w>',
 'º</w>',
 'ώ</w>',
 'ŭ</w>',
 'น</w>

In [84]:
sorted(uniq_char_internal)

['!',
 '"',
 '#',
 '$',
 '%',
 '&',
 "'",
 '(',
 ')',
 '*',
 '+',
 ',',
 '-',
 '.',
 '/',
 '0',
 '1',
 '2',
 '3',
 '4',
 '5',
 '6',
 '7',
 '8',
 '9',
 ':',
 ';',
 '<',
 '=',
 '>',
 '?',
 '@',
 'A',
 'B',
 'C',
 'D',
 'E',
 'F',
 'G',
 'H',
 'I',
 'J',
 'K',
 'L',
 'M',
 'N',
 'O',
 'P',
 'Q',
 'R',
 'S',
 'T',
 'U',
 'V',
 'W',
 'X',
 'Y',
 'Z',
 '[',
 '\\',
 ']',
 '^',
 '_',
 '`',
 'a',
 'b',
 'c',
 'd',
 'e',
 'f',
 'g',
 'h',
 'i',
 'j',
 'k',
 'l',
 'm',
 'n',
 'o',
 'p',
 'q',
 'r',
 's',
 't',
 'u',
 'v',
 'w',
 'x',
 'y',
 'z',
 '{',
 '|',
 '}',
 '~',
 '\x7f',
 '\x80',
 '\x81',
 '\x83',
 '\x84',
 '\x86',
 '\x8a',
 '\x8b',
 '\x8c',
 '\x8d',
 '\x8f',
 '\x90',
 '\x91',
 '\x92',
 '\x93',
 '\x94',
 '\x95',
 '\x96',
 '\x97',
 '\x99',
 '\x9b',
 '\x9c',
 '\x9d',
 '\x9e',
 '¡',
 '¢',
 '£',
 '¤',
 '¥',
 '¦',
 '§',
 '¨',
 '©',
 'ª',
 '«',
 '¬',
 '\xad',
 '®',
 '¯',
 '°',
 '±',
 '²',
 '³',
 '´',
 'µ',
 '¶',
 '·',
 '¸',
 '¹',
 'º',
 '»',
 '¼',
 '½',
 '¾',
 '¿',
 'À',
 'Á',
 'Â',
 'Ã',
 'Ä',


In [85]:
# Initial pruning threshold: one tenth of the count of the most frequent pair.
threshold = max(stats.values()) / 10

In [86]:
threshold

634585.1

In [87]:
sorted_vocab[:15]

[(('t', 'h', 'e</w>'), 3967456),
 (('o', 'f</w>'), 2265654),
 (('a', 'n', 'd</w>'), 2253612),
 (('t', 'o</w>'), 1799868),
 (('a</w>',), 1350178),
 (('i', 'n</w>'), 1332512),
 (('i', 's</w>'), 860408),
 (('f', 'o', 'r</w>'), 709763),
 (('w', 'i', 't', 'h</w>'), 557776),
 (('T', 'h', 'e</w>'), 546532),
 (('o', 'n</w>'), 496228),
 (('t', 'h', 'a', 't</w>'), 458174),
 (('a', 'r', 'e</w>'), 432169),
 (('y', 'o', 'u</w>'), 430459),
 (('o', 'r</w>'), 397593)]

In [88]:
verbose = True
min_frequency = 2

In [89]:
# Learn up to `num_symbols` merge operations and write them to `outfile`, one
# "left right" pair per line. The first line is a version header; the loading
# cell further down reads that line separately from the merges.
outfile.write('#version: 0.2\n')
for i in tqdm(range(num_symbols)):
    if stats:
        # Highest count wins; ties are broken by comparing the pairs themselves.
        most_frequent = max(stats, key=lambda x: (stats[x], x))
    # we probably missed the best pair because of pruning; go back to full statistics
    if not stats or (i and stats[most_frequent] < threshold):
        prune_stats(stats, big_stats, threshold)
        stats = copy.deepcopy(big_stats)
        most_frequent = max(stats, key=lambda x: (stats[x], x))
        # threshold is inspired by Zipfian assumption, but should only affect speed
        threshold = stats[most_frequent] * i/(i+10000.0)
        prune_stats(stats, big_stats, threshold)

    if stats[most_frequent] < min_frequency:
        sys.stderr.write('no pair has frequency >= {0}. Stopping\n'.format(min_frequency))
        break

    if verbose:
        sys.stderr.write('pair {0}: {1} {2} -> {1}{2} (frequency {3})\n'.format(i, most_frequent[0], most_frequent[1], stats[most_frequent]))
    outfile.write('{0} {1}\n'.format(*most_frequent))
    # Apply the merge to the vocabulary entries and update the pair counts
    # (helpers from utils).
    changes = replace_pair(most_frequent, sorted_vocab, indices)
    update_pair_statistics(most_frequent, changes, stats, indices)
    # The merged pair must not be selected again.
    stats[most_frequent] = 0
    if not i % 100:
        prune_stats(stats, big_stats, threshold)

# `outfile` stays open, and later cells reopen `outfile.name` for reading:
# flush so that every merge written above is actually on disk by then.
outfile.flush()
    
    

  0%|          | 0/6880 [00:00<?, ?it/s]pair 0: t h -> th (frequency 6345851)
  0%|          | 1/6880 [00:00<36:14,  3.16it/s]pair 1: i n -> in (frequency 4757165)
  0%|          | 2/6880 [00:02<2:40:53,  1.40s/it]pair 2: a n -> an (frequency 4732890)
  0%|          | 3/6880 [00:03<2:20:05,  1.22s/it]pair 3: th e</w> -> the</w> (frequency 3983179)
pair 4: e r -> er (frequency 3758573)
  0%|          | 5/6880 [00:04<1:47:03,  1.07it/s]pair 5: t i -> ti (frequency 3313240)
  0%|          | 6/6880 [00:06<2:13:28,  1.17s/it]pair 6: r e -> re (frequency 3257206)
  0%|          | 7/6880 [00:07<2:04:05,  1.08s/it]pair 7: o n -> on (frequency 3043644)
  0%|          | 8/6880 [00:08<2:04:01,  1.08s/it]pair 8: e n -> en (frequency 2886404)
  0%|          | 9/6880 [00:10<2:33:54,  1.34s/it]pair 9: a r -> ar (frequency 2612173)
  0%|          | 10/6880 [00:11<2:18:27,  1.21s/it]pair 10: an d</w> -> and</w> (frequency 2357364)
pair 11: o u -> ou (frequency 2279098)
  0%|          | 12/6880 [00:11<1

In [90]:
def learn_bpe(infile, outfile, num_symbols, min_frequency=2, verbose=False, is_dict=False,
              total_symbols=False, num_workers=8):
    """Placeholder for a function wrapping the BPE learning steps; not implemented yet.

    Every argument is accepted and ignored: nothing is read from ``infile``,
    nothing is written to ``outfile``, and the call returns ``None``.
    """
    # outfile.write('#version: 0.2\n')
    return None
    

learn_bpe(infile, outfile, 10000)

In [None]:
# def replace_pair(pair, vocab, indices):
#     """Replace all occurrences of a symbol pair ('A', 'B') with a new symbol 'AB'"""
#     first, second = pair
#     pair_str = ''.join(pair)
#     pair_str = pair_str.replace('\\','\\\\')
#     changes = []
#     pattern = re.compile(r'(?<!\S)' + re.escape(first + ' ' + second) + r'(?!\S)')
#     if sys.version_info < (3, 0):
#         iterator = indices[pair].iteritems()
#     else:
#         iterator = indices[pair].items()
#     for j, freq in iterator:
#         if freq < 1:
#             continue
#         word, freq = vocab[j]
#         new_word = ' '.join(word)
#         new_word = pattern.sub(pair_str, new_word)
#         new_word = tuple(new_word.split(' '))

#         vocab[j] = (new_word, freq)
#         changes.append((j, new_word, word, freq))

#     return changes


In [None]:
# replace_pair(('t', 'h'), copy.deepcopy(sorted_vocab), indices)


In [None]:
# "o\op".replace('\\','\\\\')

In [None]:
first, second = 'AB', 'C'

In [None]:
pattern = re.compile(r'(?<!\S)' + re.escape(first + ' ' + second) + r'(?!\S)')

In [None]:
pattern

In [None]:
# `re.match` only tries position 0, so it can never find the pair in the middle
# of the string and returns None; `search` scans the whole string.
pattern.search('Po AB C o i m')

In [None]:
pattern.sub('ABC', 'Po sAB C o i m')

In [None]:
pattern.sub('ABC', 'Po AB C o i m')

In [None]:
sorted_vocab[:10]

In [None]:
stats[('t', 'h')]

In [None]:
stats[('t', 'h')]

In [None]:
sorted_vocab[107]

Apply BPE

In [None]:
codes = codecs.open(outfile.name, encoding='utf-8')

In [None]:
merges = -1
separator = "@@"
vocabulary = None
glossaries = None

In [None]:
# Load the merge operations from the codes file.
codes.seek(0)
firstline = codes.readline()  # the first line is read separately from the merges
raw_merges = codes.read().rstrip('\n').split('\n')
# merges == -1 keeps every entry; otherwise only the entries with index < merges.
kept_merges = raw_merges if merges == -1 else raw_merges[:max(merges, 0)]
bpe_codes = [tuple(entry.strip('\r\n ').split(' ')) for entry in kept_merges]

codes.close()

In [None]:
bpe_codes[:10]

In [None]:
# Map each merge pair to its position in the list. Iterating in reverse lets the
# earliest position win when a pair occurs more than once.
# This rebinds `bpe_codes` from a list to a dict, so the cell must not be run twice.
bpe_codes = {code: rank for rank, code in reversed(list(enumerate(bpe_codes)))}

In [None]:
bpe_codes[('t', 'h')]

In [None]:
bpe_codes[('a', 'n')]

In [None]:
bpe_codes[('i', 'n')]

In [None]:
# Reverse lookup: merged symbol (the two halves concatenated) -> the original pair.
bpe_codes_reverse = {pair[0] + pair[1]: pair for pair in bpe_codes.keys()}

In [None]:
bpe_codes_reverse

In [None]:
infile

In [None]:
ref_outfile = codecs.open(os.path.join(DL_PATH, "bpe", "eng_file.ref.txt"), "w", encoding="utf-8")

In [None]:
codes = codecs.open(outfile.name, encoding='utf-8')

In [None]:
bpe = BPE(codes, merges, separator, vocabulary, glossaries)

In [None]:
bpe

In [None]:
# Read only the first corpus line; the context manager closes the handle
# instead of leaving it to the garbage collector.
with open(infile.name, encoding="utf-8") as f:
    line = f.readline()

In [None]:
line

In [None]:
id(bpe_codes_reverse)

In [None]:
bpe.cache

In [None]:
bpe.process_line(line)

In [None]:
# bpe.process_line(line)

In [None]:
# Strip leading newline characters. The escaped form '\\n' is a backslash
# followed by the letter n, which would strip those two characters instead.
line.lstrip('\n')

In [None]:
bpe_codes_reverse['ain']

In [None]:
bpe_codes_reverse['col']

In [None]:
bpe_codes_reverse['Ma']

In [None]:
bpe_codes_reverse['in']

In [None]:
bpe_codes_reverse['goods</w>']

In [91]:
# Reopen the learned merge file and build the BPE encoder (class from utils) from it.
codes = codecs.open(outfile.name, encoding='utf-8')
# NOTE(review): -1 presumably means "use every merge", as in the manual loading
# cell earlier in this notebook — confirm against utils.BPE.
merges = -1
separator = "@@"
vocabulary = None
glossaries = None
bpe = BPE(codes, merges, separator, vocabulary, glossaries)

In [92]:
bpe.process_lines(infile.name, ref_outfile, num_workers=8)

/media/mtb/nas/datasets/training-parallel-commoncrawl/commoncrawl.fr-en.en <codecs.StreamReaderWriter object at 0x7fa1042e8850>
Proceesing 
Proceesing 
Proceesing 
Proceesing 
Proceesing 
Proceesing 
Proceesing 
Proceesing 


In [None]:
ref_outfile.name

In [None]:
outfile.name

In [94]:
vocab_outfile = codecs.open(os.path.join(DL_PATH, "bpe", "vocab_file.ref.txt"), "w", encoding="utf-8")

In [None]:
# Feed the BPE-encoded corpus to get_vocab (helper from utils), writing to
# `vocab_outfile`. The context manager closes the read handle once the call returns.
with codecs.open(ref_outfile.name, encoding="utf-8") as ref_infile:
    get_vocab(ref_infile, vocab_outfile)

In [1]:
chr(9601)

'▁'