# pre-training BERT
# Code copied as-is from the link below
- https://colab.research.google.com/drive/1nVn6AFpQSzXBt8_ywfx6XR8ZfQXlKGAz#scrollTo=myjxQe5awo1v

# Install packages

In [None]:
!pip install tensorflow==1.15
!pip install nltk

In [None]:
!pip install sentencepiece
!git clone https://github.com/google-research/bert

# import & set logging

In [1]:
import os
import sys
import json
import nltk
import random
import logging
import tensorflow as tf
import sentencepiece as spm

from glob import glob
from tensorflow.keras.utils import Progbar

sys.path.append("bert")

from bert import modeling, optimization, tokenization
from bert.run_pretraining import input_fn_builder, model_fn_builder




In [2]:
# configure logging
log = logging.getLogger('tensorflow')
log.setLevel(logging.INFO)

# create formatter and add it to the handlers
formatter = logging.Formatter('%(asctime)s :  %(message)s')
sh = logging.StreamHandler()
sh.setLevel(logging.INFO)
sh.setFormatter(formatter)
log.handlers = [sh]

In [3]:
USE_TPU = False

# Download dataset

In [None]:
!wget "http://opus.nlpl.eu/download.php?f=OpenSubtitles/v2016/mono/OpenSubtitles.raw.en.gz" -O dataset.txt.gz
!gzip -d dataset.txt.gz
!tail dataset.txt

DEMO_MODE = True #@param {type:"boolean"}

if DEMO_MODE:
  CORPUS_SIZE = 10000
else:
  CORPUS_SIZE = 100000000 #@param {type: "integer"}
  
!(head -n $CORPUS_SIZE dataset.txt) > subdataset.txt
!mv subdataset.txt dataset.txt

# preprocess text
Lowercase the text, then strip punctuation and any symbols that cannot be encoded as UTF-8.

In [4]:
regex_tokenizer = nltk.RegexpTokenizer(r"\w+")

def normalize_text(text):
  # lowercase text
  text = str(text).lower()
  # drop characters that cannot be encoded as UTF-8
  text = text.encode("utf-8", "ignore").decode()
  # remove punctuation symbols
  text = " ".join(regex_tokenizer.tokenize(text))
  return text

def count_lines(filename):
  # count the lines so the progress bar knows its total
  count = 0
  with open(filename, encoding="utf-8") as fi:
    for line in fi:
      count += 1
  return count

In [5]:
normalize_text('Thanks to the advance, they have succeeded in getting over their adversaries.')

'thanks to the advance they have succeeded in getting over their adversaries'
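
A side effect of keeping only `\w+` runs is that apostrophes and colons act as separators, which is why the processed corpus contains lines such as "that s nice" and "it s 5 30". The sketch below reproduces the same steps with the standard-library `re` module, on the assumption that `nltk.RegexpTokenizer(r"\w+")` behaves like `re.findall` for this pattern; `normalize_text_sketch` is a hypothetical name, not part of the notebook.

```python
import re

def normalize_text_sketch(text):
    # lowercase, as in normalize_text
    text = str(text).lower()
    # drop anything that cannot be encoded as UTF-8
    text = text.encode("utf-8", "ignore").decode()
    # keep only runs of word characters, joined by single spaces
    return " ".join(re.findall(r"\w+", text))

examples = ["That's nice.", "It's 5:30!"]
normalized = [normalize_text_sketch(s) for s in examples]
print(normalized)  # ['that s nice', 'it s 5 30']
```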

In [11]:
!cp ~/dataset.txt ./

In [12]:
RAW_DATA_FPATH = "dataset.txt" #@param {type: "string"}
PRC_DATA_FPATH = "proc_dataset.txt" #@param {type: "string"}

# apply normalization to the dataset
# this will take a minute or two

total_lines = count_lines(RAW_DATA_FPATH)
bar = Progbar(total_lines)

with open(RAW_DATA_FPATH, encoding="utf-8") as fi:
  with open(PRC_DATA_FPATH, "w", encoding="utf-8") as fo:
    for line in fi:
      fo.write(normalize_text(line) + "\n")
      bar.add(1)



In [13]:
!tail proc_dataset.txt

that s nice
sounds like the left bank s running lean
the service department s over there
can i talk to you
talk to me
yeah sure talk
well thing of it is i came back because
do you know what time it is
it s 5 30
i came back because


# building the vocabulary
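BERT uses a WordPiece tokenizer, but the WordPiece vocabulary trainer is not open source. Instead, we use the SentencePiece tokenizer in unigram mode. Its output cannot be plugged into BERT directly, so a few tricks are needed.
SentencePiece uses a lot of RAM, so running it on the whole dataset will crash. We therefore train it on a random subsample.
SentencePiece also adds BOS and EOS symbols automatically; to prevent this, we set the indices of those symbols to -1.
NUM_PLACEHOLDERS is the number of vocabulary slots kept in reserve for fine-tuning.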

In [15]:
MODEL_PREFIX = "tokenizer" #@param {type: "string"}
VOC_SIZE = 32000 #@param {type:"integer"}
SUBSAMPLE_SIZE = 12800000 #@param {type:"integer"}
NUM_PLACEHOLDERS = 256 #@param {type:"integer"}

SPM_COMMAND = ('--input={} --model_prefix={} '
               '--vocab_size={} --input_sentence_size={} '
               '--shuffle_input_sentence=true ' 
               '--bos_id=-1 --eos_id=-1').format(
               PRC_DATA_FPATH, MODEL_PREFIX, 
               VOC_SIZE - NUM_PLACEHOLDERS, SUBSAMPLE_SIZE)

spm.SentencePieceTrainer.Train(SPM_COMMAND)

True

In [16]:
testcase = "Colorless geothermal substations are generating furiously"

```
>>> wordpiece.tokenize("Colorless geothermal substations are generating furiously")
 
['color',
 '##less',
 'geo',
 '##thermal',
 'sub',
 '##station',
 '##s',
 'are',
 'generating',
 'furiously']
 ```

As shown above, the WordPiece tokenizer prefixes '##' to every subword that continues a word, that is, every piece that does not start the word.
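
To make the '##' convention concrete, here is a minimal sketch of greedy longest-match-first subword splitting, the strategy used by the WordPiece tokenizer in `bert/tokenization.py`. The function name `wordpiece_split` and the toy vocabulary are assumptions made for illustration; they are not taken from the notebook or from a real BERT vocabulary.

```python
def wordpiece_split(word, vocab, unk_token="[UNK]"):
    """Split one word into pieces by greedy longest-match-first lookup."""
    pieces = []
    start = 0
    while start < len(word):
        end = len(word)
        match = None
        # try the longest remaining substring first, then shrink it
        while start < end:
            candidate = word[start:end]
            if start > 0:
                # pieces that continue a word carry the '##' prefix
                candidate = "##" + candidate
            if candidate in vocab:
                match = candidate
                break
            end -= 1
        if match is None:
            # no piece fits: the whole word maps to the unknown token
            return [unk_token]
        pieces.append(match)
        start = end
    return pieces

# hypothetical toy vocabulary
toy_vocab = {"color", "##less", "sub", "##station", "##s"}

colorless = wordpiece_split("colorless", toy_vocab)
substations = wordpiece_split("substations", toy_vocab)
unknown = wordpiece_split("xyz", toy_vocab)
print(colorless, substations, unknown)
```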

In [19]:
!ls

README.md                      requirements.txt
bert                           setup_preprocessing_dataset.sh
bert_preprocessing.py          tokenizer.model
dataset.txt                    tokenizer.vocab
preprocess_bert.ipynb          venv
proc_dataset.txt


SentencePiece produces two files:
- tokenizer.model 
- tokenizer.vocab

In [20]:
!head -n 10 tokenizer.vocab

<unk>	0
▁you	-3.2342
▁i	-3.2821
▁the	-3.56375
▁s	-3.84955
▁to	-3.87601
▁a	-3.9102
▁it	-3.97593
▁t	-4.25729
▁and	-4.32686


In [21]:
def read_sentencepiece_vocab(filepath):
  voc = []
  with open(filepath, encoding='utf-8') as fi:
    for line in fi:
      voc.append(line.split("\t")[0])
  # skip the first <unk> token
  voc = voc[1:]
  return voc

snt_vocab = read_sentencepiece_vocab("{}.vocab".format(MODEL_PREFIX))
print("Learnt vocab size: {}".format(len(snt_vocab)))
print("Sample tokens: {}".format(random.sample(snt_vocab, 10)))

Learnt vocab size: 31743
Sample tokens: ['▁irwin', '▁tarde', '▁breathe', '▁treating', '▁blaine', 'estigation', 'illus', '▁nathan', '▁orange', 'implicit']


As the sample shows, SentencePiece marks word boundaries the opposite way from WordPiece. It first replaces whitespace with "▁" (U+2581), as below.
```
Hello▁World.
```

Then it splits the sentence into pieces.
```
[Hello] [▁Wor] [ld] [.]
```

So to convert to the WordPiece convention, we strip the "▁" from tokens that start with it and prepend "##" to all other tokens.

In [22]:
def parse_sentencepiece_token(token):
    if token.startswith("▁"):
        return token[1:]
    else:
        return "##" + token

In [23]:
bert_vocab = list(map(parse_sentencepiece_token, snt_vocab))
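
As a quick sanity check of the conversion, the sketch below applies the same rule to a handful of SentencePiece-style tokens. `to_bert_token` mirrors `parse_sentencepiece_token` so that the block runs on its own, and the sample tokens are chosen for illustration rather than read from `tokenizer.vocab`.

```python
def to_bert_token(token):
    # a leading "▁" (U+2581) marks a word start: drop the marker
    if token.startswith("▁"):
        return token[1:]
    # everything else continues a word: add the WordPiece prefix
    return "##" + token

sample = ["▁geo", "ther", "mal", "▁are"]
converted = [to_bert_token(t) for t in sample]
print(converted)  # ['geo', '##ther', '##mal', 'are']
```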

In [24]:
ctrl_symbols = ["[PAD]","[UNK]","[CLS]","[SEP]","[MASK]"]
bert_vocab = ctrl_symbols + bert_vocab

Add BERT's control symbols and the placeholder tokens to the vocabulary.

In [25]:
bert_vocab += ["[UNUSED_{}]".format(i) for i in range(VOC_SIZE - len(bert_vocab))]
print(len(bert_vocab))

32000
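
The total of 32000 follows from the numbers used so far. The arithmetic below is a sketch that restates the notebook's constants; the variable names other than `VOC_SIZE` and `NUM_PLACEHOLDERS` are made up for this example.

```python
VOC_SIZE = 32000
NUM_PLACEHOLDERS = 256

# SentencePiece was trained with vocab_size = VOC_SIZE - NUM_PLACEHOLDERS
spm_vocab_size = VOC_SIZE - NUM_PLACEHOLDERS      # 31744, including <unk>
learned_tokens = spm_vocab_size - 1               # 31743 after dropping <unk>

ctrl_symbols = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"]
used = len(ctrl_symbols) + learned_tokens         # 31748

# the remaining slots are filled with [UNUSED_i] tokens
unused_slots = VOC_SIZE - used                    # 252
print(learned_tokens, used, unused_slots)
```

Note that only 252 of the 256 reserved slots end up as `[UNUSED_i]` tokens: five go to the control symbols, and one is regained by dropping `<unk>`.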


In [26]:
VOC_FNAME = "vocab.txt" #@param {type:"string"}

with open(VOC_FNAME, "w", encoding="utf-8") as fo:
  for token in bert_vocab:
    fo.write(token+"\n")

In [27]:
bert_tokenizer = tokenization.FullTokenizer(VOC_FNAME)
bert_tokenizer.tokenize(testcase)

2019-12-27 07:35:05,914 :  From /Users/kmryu/code/deep/bert_code_review/bert/tokenization.py:125: The name tf.gfile.GFile is deprecated. Please use tf.io.gfile.GFile instead.



['color',
 '##less',
 'geo',
 '##ther',
 '##mal',
 'subs',
 '##tation',
 '##s',
 'are',
 'generat',
 '##ing',
 'furious',
 '##ly']

# generating pre-training data

In [28]:
!mkdir ./shards
!split -a 4 -l 256000 -d $PRC_DATA_FPATH ./shards/shard_
!ls ./shards/

split: illegal option -- d
usage: split [-a sufflen] [-b byte_count] [-l line_count] [-p pattern]
             [file [prefix]]


In [32]:
!ls ./shards/

shard_0000 shard_0001 shard_0002 shard_0003


In [33]:
MAX_SEQ_LENGTH = 128 #@param {type:"integer"}
MASKED_LM_PROB = 0.15 #@param
MAX_PREDICTIONS = 20 #@param {type:"integer"}
DO_LOWER_CASE = True #@param {type:"boolean"}
PROCESSES = 2 #@param {type:"integer"}
PRETRAINING_DIR = "pretraining_data" #@param {type:"string"}
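
The choice of `MAX_PREDICTIONS = 20` is consistent with the rule of thumb from the BERT repository, which suggests setting `max_predictions_per_seq` to roughly `max_seq_length * masked_lm_prob`. The check below is a sketch of that arithmetic only; rounding up with `math.ceil` is an assumption about how to turn the product into an integer.

```python
import math

MAX_SEQ_LENGTH = 128
MASKED_LM_PROB = 0.15

# expected number of masked positions in a full-length sequence
expected_masked = MAX_SEQ_LENGTH * MASKED_LM_PROB     # about 19.2
suggested_max_predictions = math.ceil(expected_masked)
print(suggested_max_predictions)  # 20
```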

In [34]:
XARGS_CMD = ("ls ./shards/ | "
             "xargs -n 1 -P {} -I{} "
             "python3 bert/create_pretraining_data.py "
             "--input_file=./shards/{} "
             "--output_file={}/{}.tfrecord "
             "--vocab_file={} "
             "--do_lower_case={} "
             "--max_predictions_per_seq={} "
             "--max_seq_length={} "
             "--masked_lm_prob={} "
             "--random_seed=34 "
             "--dupe_factor=5")

XARGS_CMD = XARGS_CMD.format(PROCESSES, '{}', '{}', PRETRAINING_DIR, '{}', 
                             VOC_FNAME, DO_LOWER_CASE, 
                             MAX_PREDICTIONS, MAX_SEQ_LENGTH, MASKED_LM_PROB)
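
The `'{}'` arguments passed to `format` may look odd. They put a literal `{}` back into the string wherever `xargs -I{}` needs its own placeholder, while the other slots receive Python values. The shortened template below is a sketch of that trick using only a subset of the flags above; `short_template` and `short_cmd` are names made up for this example.

```python
short_template = ("ls ./shards/ | "
                  "xargs -n 1 -P {} -I{} "
                  "python3 bert/create_pretraining_data.py "
                  "--input_file=./shards/{} "
                  "--output_file={}/{}.tfrecord")

# slots 2, 3 and 5 are refilled with a literal '{}' for xargs to substitute
short_cmd = short_template.format(2, "{}", "{}", "pretraining_data", "{}")
print(short_cmd)
```

After formatting, xargs replaces each remaining `{}` with one shard file name per invocation.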