In [1]:
# !wget https://raw.githubusercontent.com/huseinzol05/Malaya-Dataset/master/entities/entities-data-v3.json
# !wget https://raw.githubusercontent.com/huseinzol05/Malaya-Dataset/master/entities/entities-augmentation.json

In [2]:
import json

# Load the entity dataset. The encoding is given explicitly so the JSON is
# decoded as UTF-8 regardless of the platform's default locale encoding.
with open('entities-data-v3.json', encoding = 'utf-8') as fopen:
    ori = json.load(fopen)

In [3]:
# Number of entries in ori['text'].
# NOTE(review): the output recorded for this cell is a 2-tuple, which a bare
# len() cannot produce - the cell was presumably edited after its last run;
# re-run it to refresh the output.
len(ori['text'])

(2923, 58447)

In [4]:
import malaya

# Fetch the pretrained wiki embedding, then wrap its weight matrix
# ('nce_weights') and vocabulary ('dictionary') in a word2vec interface object.
embedded_wiki = malaya.word2vec.load_wiki()
word_vector_wiki = malaya.word2vec.word2vec(
    embedded_wiki['nce_weights'],
    embedded_wiki['dictionary'],
)

W0730 15:15:41.557207 140072063063872 deprecation_wrapper.py:119] From /usr/local/lib/python3.6/dist-packages/malaya/word2vec.py:257: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.

W0730 15:15:41.585768 140072063063872 deprecation_wrapper.py:119] From /usr/local/lib/python3.6/dist-packages/malaya/word2vec.py:268: The name tf.InteractiveSession is deprecated. Please use tf.compat.v1.InteractiveSession instead.



In [11]:
from unidecode import unidecode
import re
import inspect
import random

# Words that are not wanted in augmented sentences; consulted by
# w2v_augmentation when it decides whether to keep an augmentation.
rejected = [
    'spacetoontv',
    'estalak',
    'perlawananby',
    'itzpapalotl',
    'intrabenua',
    'normalina',
    'eshkhal',
    'mariang',
    'ingvaeonik',
    'strahl',
    'yeolmae',
    'sheat',
    'deglet',
    'muzheiko',
    'hassanandani',
]

def _check_digit(string):
    return any(i.isdigit() for i in string)

def simple_textcleaning(string, lowering = True):
    """
    use by topic modelling.

    The text is passed through `unidecode`, then every run of characters
    other than A-Z, a-z, 0-9, '-', '/' and space is replaced by one space,
    and repeated spaces are collapsed.

    Parameters
    ----------
    string: str
    lowering: bool, optional (default=True)
        if True, the result is lowercased.

    Returns
    -------
    result: str
        cleaned string with no leading / trailing spaces.
    """
    string = unidecode(string)
    # Raw string: the previous non-raw literal contained the invalid escape
    # sequence '\-', which newer Python versions warn about. The compiled
    # pattern is unchanged (hyphen and slash are kept, backslash is not).
    string = re.sub(r'[^A-Za-z0-9\-\/ ]+', ' ', string)
    if lowering:
        string = string.lower()
    return re.sub(r'[ ]+', ' ', string).strip()

def make_upper(p, o):
    """
    Re-apply the capitalisation of `o` onto `p`, token by token.

    Both strings are split on whitespace. The token of `p` at position n is
    title-cased when the token of `o` at the same position starts with an
    uppercase character; otherwise it is left as is.

    Parameters
    ----------
    p: str
        processed string whose tokens receive the capitalisation.
    o: str
        original string used as the capitalisation reference.

    Returns
    -------
    result: str
        tokens of `p` joined by a single space.
    """
    o_split = o.split()
    restored = []
    for no, s in enumerate(p.split()):
        # `p` can hold more tokens than `o` (e.g. when cleaning turned one
        # token into two); tokens without a counterpart are kept unchanged
        # instead of raising IndexError.
        if no < len(o_split) and o_split[no][0].isupper():
            s = s.title()
        restored.append(s)
    return ' '.join(restored)

def w2v_augmentation(
    string,
    w2v,
    threshold = 0.5,
    soft = False,
    random_select = True,
    augment_counts = 1,
    top_n = 5,
    cleaning_function = simple_textcleaning,
):
    """
    augmenting a string using word2vec

    Parameters
    ----------
    string: str
    w2v: object
        word2vec interface object, must expose `batch_n_closest` and `_dictionary`.
    threshold: float, optional (default=0.5)
        a candidate word is picked for replacement when `random.random() > threshold`,
        so a higher threshold replaces fewer words.
    soft: bool, optional (default=False)
        if True, every word without a digit is a candidate and `soft = True` is forwarded
        to `w2v.batch_n_closest` when that method accepts a `soft` argument.
        if False, only words found in `w2v._dictionary` are candidates; other words are kept as is.
    random_select: bool, (default=True)
        if True, a word randomly selected in the pool.
        if False, based on the index
    augment_counts: int, (default=1)
        augmentation count for a string.
    top_n: int, (default=5)
        number of nearest neighbors returned.
    cleaning_function: function, (default=simple_textcleaning)
        applied to `string` before splitting; pass None to skip cleaning.

    Returns
    -------
    result: list
        augmented strings. Holds fewer than `augment_counts` items when an
        augmentation contains a word from `rejected`, and is empty when the
        string has no word that can be replaced.
    """
    if not isinstance(string, str):
        raise ValueError('string must be a string')
    if not isinstance(threshold, float):
        raise ValueError('threshold must be a float')
    if not (threshold > 0 and threshold < 1):
        raise ValueError('threshold must be bigger than 0 and less than 1')
    if not isinstance(soft, bool):
        raise ValueError('soft must be a boolean')
    if not hasattr(w2v, 'batch_n_closest'):
        raise ValueError('word2vec must has `batch_n_closest` method')
    if not hasattr(w2v, '_dictionary'):
        raise ValueError('word2vec must has `_dictionary` attribute')
    if not isinstance(random_select, bool):
        raise ValueError('random_select must be a boolean')
    if not isinstance(top_n, int):
        raise ValueError('top_n must be an integer')
    if not isinstance(augment_counts, int):
        raise ValueError('augment_counts must be an integer')
    if not random_select:
        if augment_counts > top_n:
            raise ValueError(
                'if random_select is False, augment_counts need to be less than or equal to top_n'
            )
    original_string = string
    if cleaning_function:
        string = cleaning_function(string)
    string = string.split()

    # Positions that may be replaced at all: never words containing digits
    # and, in non-soft mode, only words known to the embedding.
    if soft:
        candidates = [
            (no, w) for no, w in enumerate(string) if not _check_digit(w)
        ]
    else:
        candidates = [
            (no, w)
            for no, w in enumerate(string)
            if not _check_digit(w) and w in w2v._dictionary
        ]
    # Without any candidate the sampling loop below could never terminate.
    if not candidates:
        return []

    # Re-sample until at least one word is picked; this terminates with
    # probability 1 because 0 < threshold < 1.
    selected = []
    while not selected:
        selected = [c for c in candidates if random.random() > threshold]

    indices, words = [i[0] for i in selected], [i[1] for i in selected]
    # Only forward `soft` to implementations whose signature accepts it.
    batch_parameters = list(
        inspect.signature(w2v.batch_n_closest).parameters.keys()
    )
    if 'soft' in batch_parameters:
        results = w2v.batch_n_closest(words, num_closest = top_n, soft = soft)
    else:
        results = w2v.batch_n_closest(words, num_closest = top_n)
    augmented = []
    for i in range(augment_counts):
        string_ = string[:]
        for no in range(len(results)):
            if random_select:
                index = random.randint(0, len(results[no]) - 1)
            else:
                # deterministic mode: the i-th augmentation uses the i-th neighbour
                index = i
            string_[indices[no]] = results[no][index]
        # Keep the augmentation only when none of the rejected words is one
        # of its tokens. (The previous `all([...] for r in rejected)` tested
        # non-empty lists, which are always truthy, so nothing was filtered.)
        if all(r not in string_ for r in rejected):
            augmented.append(make_upper(' '.join(string_), original_string))
    return augmented

In [12]:
from tqdm import tqdm

batch_size = 20
X, Y = [], []

# Walk over the dataset in windows of `batch_size` tokens. Each window is
# joined into one sentence and augmented; X collects the list returned by
# w2v_augmentation and Y the label slice of the same window.
for start in tqdm(range(0, len(ori['text']), batch_size)):
    # Slicing clamps at the end of the sequence, so the last (shorter)
    # window needs no explicit bound check.
    stop = start + batch_size
    sentence = ' '.join(ori['text'][start:stop])
    augmented = w2v_augmentation(
        sentence,
        word_vector_wiki,
        soft = False,
        augment_counts = 3,
        random_select = False,
    )
    X.append(augmented)
    Y.append(ori['label'][start:stop])

100%|██████████| 2923/2923 [1:02:45<00:00,  1.24s/it]


In [21]:
# Sanity check: token count of the first augmentation stored in the
# fifth-from-last entry of X.
len(X[-5][0].split())

20

In [23]:
# Sanity check: number of labels stored in the fifth-from-last entry of Y,
# to compare against the token count of the corresponding X entry.
len(Y[-5])

20