In [1]:
# !wget https://storage.googleapis.com/xlnet/released_models/cased_L-12_H-768_A-12.zip
# !unzip cased_L-12_H-768_A-12.zip
# !wget https://raw.githubusercontent.com/zihangdai/xlnet/master/xlnet.py
# !wget https://raw.githubusercontent.com/zihangdai/xlnet/master/modeling.py
# !wget https://raw.githubusercontent.com/zihangdai/xlnet/master/prepro_utils.py
# !wget https://raw.githubusercontent.com/zihangdai/xlnet/master/model_utils.py

In [2]:
# !pip3 install sentencepiece

In [3]:
import xlnet
import numpy as np
from tqdm import tqdm

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [4]:
import sentencepiece as spm
from prepro_utils import preprocess_text, encode_ids

# SentencePiece model bundled with the pretrained XLNet checkpoint — presumably
# the directory produced by unzipping cased_L-12_H-768_A-12.zip in the first
# cell; adjust here if the archive was extracted elsewhere.
SPIECE_MODEL_PATH = 'xlnet_cased_L-12_H-768_A-12/spiece.model'

sp_model = spm.SentencePieceProcessor()
sp_model.Load(SPIECE_MODEL_PATH)

def tokenize_fn(text):
    """Encode one sentence with the global ``sp_model``.

    The text is first passed through ``preprocess_text`` with ``lower=False``
    (case is kept, matching the cased checkpoint), then handed to
    ``encode_ids``. The result is used downstream as a list of token ids.
    """
    cleaned = preprocess_text(text, lower=False)
    return encode_ids(sp_model, cleaned)

In [5]:
# Longest sequence fed to the model, counting the trailing <sep> and <cls>
# tokens that get_tokenize appends after truncation.
MAX_SEQ_LENGTH = 150

# Segment ids, numbered consecutively: A=0, B=1, CLS=2, SEP=3, PAD=4.
SEG_ID_A, SEG_ID_B, SEG_ID_CLS, SEG_ID_SEP, SEG_ID_PAD = range(5)

# Special-symbol ids, listed in id order starting at 0.
# NOTE(review): assumed to match the ids inside spiece.model — confirm.
special_symbols = {
    symbol: symbol_id
    for symbol_id, symbol in enumerate(
        ("<unk>", "<s>", "</s>", "<cls>", "<sep>", "<pad>", "<mask>", "<eod>", "<eop>")
    )
}

VOCAB_SIZE = 32000
UNK_ID, CLS_ID, SEP_ID, MASK_ID, EOD_ID = (
    special_symbols[name] for name in ("<unk>", "<cls>", "<sep>", "<mask>", "<eod>")
)

In [37]:
def get_tokenize(texts, max_seq_length=None):
    """Convert raw sentences into single-segment XLNet inputs.

    Each sentence is tokenized with ``tokenize_fn`` and truncated so that the
    tokens plus the two trailing special tokens fit in ``max_seq_length``.
    It is then laid out as ``tokens + [SEP_ID, CLS_ID]``, with ``<cls>`` last.
    No padding is applied, so returned sequences have variable length.

    Args:
        texts: iterable of sentences (str).
        max_seq_length: maximum length of each returned sequence, including
            the ``<sep>`` and ``<cls>`` tokens. ``None`` (the default) uses the
            global ``MAX_SEQ_LENGTH``.

    Returns:
        Tuple ``(input_ids, input_masks, segment_ids)``. Each element is a list
        holding one per-sentence list:
          * input_ids: token ids followed by ``SEP_ID`` and ``CLS_ID``.
          * input_masks: all zeros, same length as the ids.
            NOTE(review): XLNet's reference code appears to use 0 for real
            tokens and 1 for padding — confirm against the consumer.
          * segment_ids: ``SEG_ID_A`` for the tokens and ``<sep>``, and
            ``SEG_ID_CLS`` for ``<cls>``.

    Raises:
        ValueError: if ``max_seq_length`` is smaller than 2, leaving no room
            for the two special tokens.
    """
    # Resolve the default at call time so the global is read lazily, as before.
    if max_seq_length is None:
        max_seq_length = MAX_SEQ_LENGTH
    if max_seq_length < 2:
        # A negative slice bound below would silently drop tokens from the end.
        raise ValueError(
            'max_seq_length must be >= 2 to fit <sep> and <cls>, got %r'
            % (max_seq_length,))
    max_tokens = max_seq_length - 2

    input_ids, input_masks, segment_ids = [], [], []
    for text in tqdm(texts):
        # Slicing is a no-op for short inputs, so no length check is needed.
        tokens = list(tokenize_fn(text)[:max_tokens])
        input_id = tokens + [SEP_ID, CLS_ID]
        # <sep> shares segment A with the sentence; <cls> gets its own segment.
        segment_id = [SEG_ID_A] * (len(tokens) + 1) + [SEG_ID_CLS]

        input_ids.append(input_id)
        input_masks.append([0] * len(input_id))
        segment_ids.append(segment_id)

    return input_ids, input_masks, segment_ids

In [7]:
# The corpus is read as UTF-8 explicitly rather than with the locale default.
# Only the single empty string left by a trailing newline is dropped, so a
# file without a final newline keeps its last sentence. Blank lines elsewhere
# are kept, preserving line alignment with any parallel file.
with open('train.en', encoding='utf-8') as fopen:
    train_english = fopen.read().split('\n')
if train_english and train_english[-1] == '':
    train_english.pop()

len(train_english)

133317

In [8]:
# Same loading rules as the training corpus:
#   * read as explicit UTF-8;
#   * drop only the empty string produced by a trailing newline, so no real
#     last sentence is lost when the file lacks a final newline.
with open('tst2012.en', encoding='utf-8') as fopen:
    test_english_2012 = fopen.read().split('\n')
if test_english_2012 and test_english_2012[-1] == '':
    test_english_2012.pop()

with open('tst2013.en', encoding='utf-8') as fopen:
    test_english_2013 = fopen.read().split('\n')
if test_english_2013 and test_english_2013[-1] == '':
    test_english_2013.pop()

In [38]:
# Tuple of per-sentence lists: (input_ids, input_masks, segment_ids).
train_X = get_tokenize(texts=train_english)

100%|██████████| 133317/133317 [00:17<00:00, 7716.30it/s]


In [40]:
# tst2012 and tst2013 are tokenized together, in that order, as one test set.
test_X = get_tokenize(texts=[*test_english_2012, *test_english_2013])

100%|██████████| 2821/2821 [00:00<00:00, 7842.78it/s]


In [41]:
import pickle

# Persist both tokenized splits in one file. Each value is the
# (input_ids, input_masks, segment_ids) tuple returned by get_tokenize.
with open('train-test-xlnet.pkl', 'wb') as fopen:
    pickle.dump(dict(train_X=train_X, test_X=test_X), fopen)