In [1]:
from mpi4py import MPI
from collections import OrderedDict
import re
import copy

In [243]:
def encode(text, sorted_tokens, unknown_token_id=0, end_token='@'):
    """Encode whitespace-separated ``text`` into a list of token ids.

    Each word gets ``end_token`` appended and the words are concatenated.  The
    resulting string is split greedily: tokens are tried in the iteration order
    of ``sorted_tokens`` (this notebook sorts it longest-first), every
    non-overlapping occurrence of the first token that occurs becomes its id,
    and the text between occurrences is encoded recursively with the remaining
    (lower-priority) tokens only.  A non-empty span that no token covers becomes
    a single ``unknown_token_id``.

    Args:
        text: input text; split on whitespace.
        sorted_tokens: mapping ``token -> id`` in priority order.  Any mapping
            works (an ``OrderedDict`` is not required) and it is not modified.
        unknown_token_id: id emitted for spans no token covers.
        end_token: end-of-word marker appended to every word.  Defaults to
            ``'@'``, the marker used by this notebook's vocabulary (it used to
            be read from a global that no visible cell defines).

    Returns:
        List of token ids; ``[]`` for blank ``text``.
    """
    words_string = ''.join(word + end_token for word in text.split())
    # Snapshot as a list so recursion can share it by index instead of copying
    # the mapping at every call.
    token_items = list(sorted_tokens.items())

    def _encode(string, start):
        """Encode ``string`` using only ``token_items[start:]``."""
        if string == '':
            return []
        # Skip absent tokens in a loop. Recursing once per miss made stack
        # depth grow with vocabulary size and hit Python's default recursion
        # limit (1000) on realistic vocabularies.
        while start < len(token_items):
            token, token_id = token_items[start]
            start += 1
            match_starts = [m.start() for m in re.finditer(re.escape(token), string)]
            if match_starts:
                break
        else:
            return [unknown_token_id]

        ids = []
        piece_start = 0
        for match_start in match_starts:
            # Text before this occurrence may only use lower-priority tokens.
            ids += _encode(string[piece_start:match_start], start)
            ids.append(token_id)
            piece_start = match_start + len(token)
        ids += _encode(string[piece_start:], start)
        return ids

    return _encode(words_string, 0)

In [244]:
def decode(token_ids, id2token, unknown_token_id=0, end_token='@'):
    """Turn a sequence of token ids back into text.

    ``unknown_token_id`` is rendered as the literal ``'<UNK>'``.  A token ending
    in ``end_token`` has that marker replaced by a space, which closes a word.
    Any other token is copied verbatim.  Leading and trailing whitespace of the
    result is stripped.

    Args:
        token_ids: iterable of ids, e.g. the output of ``encode``.
        id2token: mapping ``id -> token``; an id missing from it raises
            ``KeyError``.
        unknown_token_id: id rendered as ``'<UNK>'``.
        end_token: end-of-word marker.  Defaults to ``'@'``, the marker used by
            this notebook's vocabulary.  Markers of any length are removed, and
            an empty marker disables word-boundary handling.

    Returns:
        The decoded string.
    """
    pieces = []
    for token_id in token_ids:
        if token_id == unknown_token_id:
            pieces.append('<UNK>')
            continue
        token = id2token[token_id]
        if end_token and token.endswith(end_token):
            # Drop the whole marker, not just its last character.
            pieces.append(token[:-len(end_token)] + ' ')
        else:
            pieces.append(token)
    # Join once rather than growing a str with repeated `+=`.
    return ''.join(pieces).strip()

In [245]:
# Toy vocabulary ordered longest-token-first: encode() tries tokens in
# iteration order, so longer tokens such as 's@' win over their pieces.
token2id = OrderedDict(
    sorted(
        {'a': 1, 's': 2, 's@': 3, 'a@': 4, '@': 5}.items(),
        key=lambda pair: -len(pair[0]),  # sort is stable: ties keep insertion order
    )
)
# Inverse lookup (id -> token) used when decoding.
id2token = dict(zip(token2id.values(), token2id.keys()))
token2id

OrderedDict([('s@', 3), ('a@', 4), ('a', 1), ('s', 2), ('@', 5)])

In [247]:
# Sample sentence; 'd' is absent from the toy vocabulary, so it exercises the unknown-id path.
s = "d a asdsa d"

In [248]:
# Encode the sample with the longest-first vocabulary; uncovered spans map to id 0.
encode(text=s, sorted_tokens=token2id)

[0, 5, 4, 1, 2, 0, 2, 4, 0, 5]

In [249]:
# Round trip: unknown spans come back as '<UNK>', end-of-word tokens as word breaks.
decode(token_ids=encode(text=s, sorted_tokens=token2id), id2token=id2token)

'<UNK> a as<UNK>sa <UNK>'

In [250]:
%%writefile tokenize.py

from collections import OrderedDict
import re
import copy
import argparse
import pickle


def _encode_ids(string, token_items, unknown_token_id, start=0):
    """Split ``string`` into token ids using only ``token_items[start:]``.

    ``token_items`` is a list of ``(token, id)`` pairs in priority order.  The
    first pair whose token occurs in ``string`` wins.  Each non-overlapping
    occurrence of it (scanned left to right) is emitted as its id, and the
    pieces between occurrences are encoded recursively with the remaining
    pairs only.  A non-empty piece that no remaining token matches becomes a
    single ``unknown_token_id``; an empty string encodes to ``[]``.
    """
    if string == '':
        return []
    # Skip absent tokens in a loop. Recursing once per miss made stack depth
    # grow with vocabulary size and hit Python's default recursion limit
    # (1000) on realistic vocabularies.
    while start < len(token_items):
        token, token_id = token_items[start]
        start += 1
        match_starts = [m.start() for m in re.finditer(re.escape(token), string)]
        if match_starts:
            break
    else:
        return [unknown_token_id]

    ids = []
    piece_start = 0
    for match_start in match_starts:
        ids += _encode_ids(string[piece_start:match_start], token_items, unknown_token_id, start)
        ids.append(token_id)
        piece_start = match_start + len(token)
    ids += _encode_ids(string[piece_start:], token_items, unknown_token_id, start)
    return ids


def main():
    """Command-line entry point: encode a text file and print its token ids.

    Reads the pickled ``token -> id`` mapping from ``-dict-file`` and the
    whitespace-separated words of ``-file``.  It appends ``-end-token`` to every
    word so word boundaries survive concatenation, encodes the result with
    :func:`_encode_ids`, and prints the id list.  Tokens are tried in the
    mapping's iteration order, which presumably is longest-first as in the
    notebook (TODO confirm for the pickled dicts in use).  Any mapping type
    works; an ``OrderedDict`` is not required.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-file', type=str, required=True)
    parser.add_argument('-dict-file', type=str, required=True)
    parser.add_argument('-unk-id', type=int, default=0)
    parser.add_argument('-end-token', type=str, default='@',
                        help="end-of-word marker appended to every word ('' to disable)")

    args = parser.parse_args()

    # SECURITY: pickle.load can execute arbitrary code; only pass trusted dict files.
    with open(args.dict_file, 'rb') as f:
        token2id = pickle.load(f)
    with open(args.file, 'r') as f:
        words = f.read().split()

    words_string = ''.join(word + args.end_token for word in words)
    ids = _encode_ids(words_string, list(token2id.items()), args.unk_id)
    print(ids)


if __name__ == '__main__':
    main()

Overwriting tokenize.py


In [253]:
%%writefile tokenize_parallel.py

from mpi4py import MPI
from collections import OrderedDict
import re
import copy
import argparse
import pickle


def _encode_ids(string, token_items, unknown_token_id, start=0):
    """Split ``string`` into token ids using only ``token_items[start:]``.

    ``token_items`` is a list of ``(token, id)`` pairs in priority order.  The
    first pair whose token occurs in ``string`` wins.  Each non-overlapping
    occurrence of it (scanned left to right) is emitted as its id, and the
    pieces between occurrences are encoded recursively with the remaining
    pairs only.  A non-empty piece that no remaining token matches becomes a
    single ``unknown_token_id``; an empty string encodes to ``[]``.
    """
    if string == '':
        return []
    # Skip absent tokens in a loop. Recursing once per miss made stack depth
    # grow with vocabulary size and hit Python's default recursion limit
    # (1000) on realistic vocabularies.
    while start < len(token_items):
        token, token_id = token_items[start]
        start += 1
        match_starts = [m.start() for m in re.finditer(re.escape(token), string)]
        if match_starts:
            break
    else:
        return [unknown_token_id]

    ids = []
    piece_start = 0
    for match_start in match_starts:
        ids += _encode_ids(string[piece_start:match_start], token_items, unknown_token_id, start)
        ids.append(token_id)
        piece_start = match_start + len(token)
    ids += _encode_ids(string[piece_start:], token_items, unknown_token_id, start)
    return ids


def _chunk_bounds(n_items, size, rank):
    """Return the ``[start, stop)`` slice of ``n_items`` owned by ``rank``.

    The first ``n_items % size`` ranks get one extra item, so chunk sizes differ
    by at most one.  Concatenating the chunks in rank order restores the
    original sequence.
    """
    base, extra = divmod(n_items, size)
    start = rank * base + min(rank, extra)
    stop = start + base + (1 if rank < extra else 0)
    return start, stop


def main():
    """MPI entry point: encode a text file in parallel and print ids on rank 0.

    Every rank reads ``-dict-file`` and ``-file``, then takes a balanced,
    contiguous slice of the words.  It appends ``-end-token`` to each word so
    word boundaries survive concatenation, and encodes the slice with
    :func:`_encode_ids`.  Rank 0 gathers the per-rank id lists and prints them
    concatenated in rank order, as one flat list like the serial script prints.
    """
    comm = MPI.COMM_WORLD
    size = comm.Get_size()
    rank = comm.Get_rank()

    parser = argparse.ArgumentParser()
    parser.add_argument('-file', type=str, required=True)
    parser.add_argument('-dict-file', type=str, required=True)
    parser.add_argument('-unk-id', type=int, default=0)
    parser.add_argument('-end-token', type=str, default='@',
                        help="end-of-word marker appended to every word ('' to disable)")

    args = parser.parse_args()

    # SECURITY: pickle.load can execute arbitrary code; only pass trusted dict files.
    with open(args.dict_file, 'rb') as f:
        token2id = pickle.load(f)
    with open(args.file, 'r') as f:
        words = f.read().split()

    start, stop = _chunk_bounds(len(words), size, rank)
    words_string = ''.join(word + args.end_token for word in words[start:stop])
    ids = _encode_ids(words_string, list(token2id.items()), args.unk_id)

    gathered = comm.gather(ids, root=0)
    if rank == 0:
        # Flatten in rank order so the output format matches the serial tool.
        print([token_id for chunk in gathered for token_id in chunk])


if __name__ == '__main__':
    main()

Overwriting tokenize_parallel.py
