# pre-training BERT
# The code below is taken as-is from the following link
- https://colab.research.google.com/drive/1nVn6AFpQSzXBt8_ywfx6XR8ZfQXlKGAz#scrollTo=myjxQe5awo1v

# Install packages

In [None]:
!pip install tensorflow==1.15
!pip install nltk

In [None]:
!pip install sentencepiece
!git clone https://github.com/google-research/bert

# import & set logging

In [5]:
import os
import sys
import json
import nltk
import random
import logging
import tensorflow as tf
import sentencepiece as spm

from glob import glob
from tensorflow.keras.utils import Progbar

sys.path.append("bert")

from bert import modeling, optimization, tokenization
from bert.run_pretraining import input_fn_builder, model_fn_builder




In [6]:
# configure logging
log = logging.getLogger('tensorflow')
log.setLevel(logging.INFO)

# create formatter and add it to the handlers
formatter = logging.Formatter('%(asctime)s :  %(message)s')
sh = logging.StreamHandler()
sh.setLevel(logging.INFO)
sh.setFormatter(formatter)
log.handlers = [sh]

In [7]:
USE_TPU = False

# Download dataset

In [None]:
!wget "http://opus.nlpl.eu/download.php?f=OpenSubtitles/v2016/mono/OpenSubtitles.raw.en.gz" -O dataset.txt.gz
!gzip -d dataset.txt.gz
!tail dataset.txt

DEMO_MODE = True #@param {type:"boolean"}

if DEMO_MODE:
  CORPUS_SIZE = 10000
else:
  CORPUS_SIZE = 100000000 #@param {type: "integer"}
  
!(head -n $CORPUS_SIZE dataset.txt) > subdataset.txt
!mv subdataset.txt dataset.txt

# preprocess text
Remove punctuation, uppercase letters and non-UTF symbols

In [8]:
regex_tokenizer = nltk.RegexpTokenizer(r"\w+")

def normalize_text(text):
  # lowercase text
  text = str(text).lower()
  # remove non-UTF
  text = text.encode("utf-8", "ignore").decode()
  # remove punctuation symbols
  text = " ".join(regex_tokenizer.tokenize(text))
  return text

def count_lines(filename):
  count = 0
  with open(filename, encoding="utf-8") as fi:
    for line in fi:
      count += 1
  return count

In [9]:
normalize_text('Thanks to the advance, they have succeeded in getting over their adversaries.')

'thanks to the advance they have succeeded in getting over their adversaries'

In [10]:
RAW_DATA_FPATH = "dataset.txt" #@param {type: "string"}
PRC_DATA_FPATH = "proc_dataset.txt" #@param {type: "string"}

# apply normalization to the dataset
# this will take a minute or two

total_lines = count_lines(RAW_DATA_FPATH)
bar = Progbar(total_lines)

with open(RAW_DATA_FPATH,encoding="utf-8") as fi:
  with open(PRC_DATA_FPATH, "w",encoding="utf-8") as fo:
    for l in fi:
      fo.write(normalize_text(l)+"\n")
      bar.add(1)



In [11]:
!tail proc_dataset.txt

that s nice
sounds like the left bank s running lean
the service department s over there
can i talk to you
talk to me
yeah sure talk
well thing of it is i came back because
do you know what time it is
it s 5 30
i came back because


# building the vocabulary
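BERT uses a WordPiece tokenizer, but the tool used to build its WordPiece vocabulary is not open source. So we use the SentencePiece tokenizer in unigram mode instead. Its output cannot be fed to BERT directly; a few tricks are needed.
SentencePiece uses a lot of RAM, so running it on the full dataset crashes. We therefore train it on a random subsample of the sentences.
SentencePiece also adds BOS and EOS symbols automatically, so we set the indices of those symbols to -1 to disable them.
NUM_PLACEHOLDERS is the number of vocabulary slots kept in reserve for fine-tuning.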

In [12]:
MODEL_PREFIX = "tokenizer" #@param {type: "string"}
VOC_SIZE = 32000 #@param {type:"integer"}
SUBSAMPLE_SIZE = 12800000 #@param {type:"integer"}
NUM_PLACEHOLDERS = 256 #@param {type:"integer"}

SPM_COMMAND = ('--input={} --model_prefix={} '
               '--vocab_size={} --input_sentence_size={} '
               '--shuffle_input_sentence=true ' 
               '--bos_id=-1 --eos_id=-1').format(
               PRC_DATA_FPATH, MODEL_PREFIX, 
               VOC_SIZE - NUM_PLACEHOLDERS, SUBSAMPLE_SIZE)

spm.SentencePieceTrainer.Train(SPM_COMMAND)

True

In [13]:
testcase = "Colorless geothermal substations are generating furiously"

```
>>> wordpiece.tokenize("Colorless geothermal substations are generating furiously")
 
['color',
 '##less',
 'geo',
 '##thermal',
 'sub',
 '##station',
 '##s',
 'are',
 'generating',
 'furiously']
 ```

As shown above, the WordPiece tokenizer prepends '##' to every subword that does not start a word.
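One way to check the '##' convention is to reverse it. The sketch below uses a hypothetical helper, `join_wordpieces` (not part of the BERT repo), that glues each continuation piece back onto the piece before it.

```python
def join_wordpieces(tokens):
    """Rebuild whitespace-separated words from WordPiece-style tokens."""
    words = []
    for token in tokens:
        if token.startswith("##") and words:
            # continuation piece: append to the previous word without a space
            words[-1] += token[2:]
        else:
            # word-initial piece: start a new word
            words.append(token)
    return " ".join(words)


pieces = ["color", "##less", "geo", "##thermal", "sub", "##station", "##s",
          "are", "generating", "furiously"]
restored = join_wordpieces(pieces)
print(restored)
```

The result is the lowercased test sentence, which shows that the '##' prefix carries exactly the information about where word boundaries were.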

In [14]:
!ls

README.md             proc_dataset.txt      tokenizer.vocab
bert                  requirements.txt      venv
dataset.txt           shards                vocab.txt
preprocess_bert.ipynb tokenizer.model


SentencePiece produces two files.
- tokenizer.model 
- tokenizer.vocab

In [15]:
!head -n 10 tokenizer.vocab

<unk>	0
▁you	-3.2342
▁i	-3.2821
▁the	-3.56375
▁s	-3.84955
▁to	-3.87601
▁a	-3.9102
▁it	-3.97593
▁t	-4.25729
▁and	-4.32686


In [16]:
def read_sentencepiece_vocab(filepath):
  voc = []
  with open(filepath, encoding='utf-8') as fi:
    for line in fi:
      voc.append(line.split("\t")[0])
  # skip the first <unk> token
  voc = voc[1:]
  return voc

snt_vocab = read_sentencepiece_vocab("{}.vocab".format(MODEL_PREFIX))
print("Learnt vocab size: {}".format(len(snt_vocab)))
print("Sample tokens: {}".format(random.sample(snt_vocab, 10)))

Learnt vocab size: 31743
Sample tokens: ['▁alana', 'potatoes', '▁defend', 'sailing', 'artman', '▁head', 'pb', 'sation', 'nette', '▁serpent']


We can see that SentencePiece works the opposite way from WordPiece. SentencePiece replaces whitespace with "▁" (U+2581), as shown below.
```
Hello▁World.
```

It then splits the sentence into pieces.
```
[Hello] [▁Wor] [ld] [.]
```

So if a token starts with "▁" we strip it, and otherwise we prepend "##".

In [17]:
def parse_sentencepiece_token(token):
    if token.startswith("▁"):
        return token[1:]
    else:
        return "##" + token

In [18]:
bert_vocab = list(map(parse_sentencepiece_token, snt_vocab))
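A quick check of the conversion on a few tokens taken from the sample output above. The function is repeated here only so that the snippet runs on its own.

```python
def parse_sentencepiece_token(token):
    # same rule as the cell above: strip a leading "▁", otherwise prepend "##"
    if token.startswith("▁"):
        return token[1:]
    return "##" + token


sample = ["▁alana", "potatoes", "▁defend", "sailing", "▁head", "pb"]
converted = [parse_sentencepiece_token(t) for t in sample]
print(converted)
```

Word-initial pieces such as "▁defend" lose their marker, while word-internal pieces such as "sailing" gain the "##" prefix that the BERT tokenizer expects.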

In [19]:
ctrl_symbols = ["[PAD]","[UNK]","[CLS]","[SEP]","[MASK]"]
bert_vocab = ctrl_symbols + bert_vocab

Add the control symbols used by BERT and the placeholder tokens to the vocab.

In [20]:
bert_vocab += ["[UNUSED_{}]".format(i) for i in range(VOC_SIZE - len(bert_vocab))]
print(len(bert_vocab))

32000
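The arithmetic behind the final size of 32000 can be traced step by step. This is only a worked check using the values from the cells above; `NUM_CTRL_SYMBOLS` is a name introduced here for illustration.

```python
VOC_SIZE = 32000
NUM_PLACEHOLDERS = 256
NUM_CTRL_SYMBOLS = 5  # [PAD], [UNK], [CLS], [SEP], [MASK]

requested = VOC_SIZE - NUM_PLACEHOLDERS  # vocab size passed to SentencePiece
learnt = requested - 1                   # <unk> is dropped when reading the .vocab file
with_ctrl = NUM_CTRL_SYMBOLS + learnt    # after prepending the control symbols
num_unused = VOC_SIZE - with_ctrl        # [UNUSED_i] tokens that pad the vocab

print(requested, learnt, with_ctrl, num_unused)
```

`learnt` matches the "Learnt vocab size: 31743" output above. Because the five control symbols come out of the reserved budget and the dropped `<unk>` gives one slot back, 252 of the 256 reserved slots end up as `[UNUSED_i]` tokens.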


In [21]:
VOC_FNAME = "vocab.txt" #@param {type:"string"}

with open(VOC_FNAME, "w", encoding="utf-8") as fo:
  for token in bert_vocab:
    fo.write(token+"\n")

In [22]:
bert_tokenizer = tokenization.FullTokenizer(VOC_FNAME)
bert_tokenizer.tokenize(testcase)

2019-12-28 18:19:05,423 :  From /Users/kmryu/code/deep/bert_code_review/bert/tokenization.py:125: The name tf.gfile.GFile is deprecated. Please use tf.io.gfile.GFile instead.



['color',
 '##less',
 'geo',
 '##ther',
 '##mal',
 'subs',
 '##tation',
 '##s',
 'are',
 'generat',
 '##ing',
 'furious',
 '##ly']

# generating pre-training data

In [28]:
!mkdir ./shards
!split -a 4 -l 256000 -d $PRC_DATA_FPATH ./shards/shard_
!ls ./shards/

split: illegal option -- d
usage: split [-a sufflen] [-b byte_count] [-l line_count] [-p pattern]
             [file [prefix]]


In [32]:
!ls ./shards/

shard_0000 shard_0001 shard_0002 shard_0003


In [2]:
MAX_SEQ_LENGTH = 128 #@param {type:"integer"}
MASKED_LM_PROB = 0.15 #@param
MAX_PREDICTIONS = 20 #@param {type:"integer"}
DO_LOWER_CASE = True #@param {type:"boolean"}
PROCESSES = 2 #@param {type:"integer"}
PRETRAINING_DIR = "pretraining_data" #@param {type:"string"}

In [23]:
XARGS_CMD = ("ls ./shards/ | "
             "xargs -n 1 -P {} -I{} "
             "python3 bert/create_pretraining_data.py "
             "--input_file=./shards/{} "
             "--output_file={}/{}.tfrecord "
             "--vocab_file={} "
             "--do_lower_case={} "
             "--max_predictions_per_seq={} "
             "--max_seq_length={} "
             "--masked_lm_prob={} "
             "--random_seed=34 "
             "--dupe_factor=5")

XARGS_CMD = XARGS_CMD.format(PROCESSES, '{}', '{}', PRETRAINING_DIR, '{}', 
                             VOC_FNAME, DO_LOWER_CASE, 
                             MAX_PREDICTIONS, MAX_SEQ_LENGTH, MASKED_LM_PROB)

In [24]:
tf.gfile.MkDir(PRETRAINING_DIR)
!$XARGS_CMD





W1228 18:19:22.213583 4766645696 module_wrapper.py:139] From bert/create_pretraining_data.py:437: The name tf.logging.set_verbosity is deprecated. Please use tf.compat.v1.logging.set_verbosity instead.

W1228 18:19:22.213608 4493614528 module_wrapper.py:139] From bert/create_pretraining_data.py:437: The name tf.logging.set_verbosity is deprecated. Please use tf.compat.v1.logging.set_verbosity instead.



W1228 18:19:22.213773 4766645696 module_wrapper.py:139] From bert/create_pretraining_data.py:437: The name tf.logging.INFO is deprecated. Please use tf.compat.v1.logging.INFO instead.

W1228 18:19:22.213793 4493614528 module_wrapper.py:139] From bert/create_pretraining_data.py:437: The name tf.logging.INFO is deprecated. Please use tf.compat.v1.logging.INFO instead.


W1228 18:19:22.213932 4766645696 module_wrapper.py:139] From /Users/kmryu/code/deep/bert_code_review/bert/tokenization.py:125: The name tf.gfile.GFile is deprecated. Please use tf.io.gfile.GFile instead.


W1228 18:19

INFO:tensorflow:*** Writing to output files ***
I1228 18:20:35.434767 4493614528 create_pretraining_data.py:457] *** Writing to output files ***
INFO:tensorflow:  pretraining_data/shard_0001.tfrecord
I1228 18:20:35.434947 4493614528 create_pretraining_data.py:459]   pretraining_data/shard_0001.tfrecord

W1228 18:20:35.435140 4493614528 module_wrapper.py:139] From bert/create_pretraining_data.py:101: The name tf.python_io.TFRecordWriter is deprecated. Please use tf.io.TFRecordWriter instead.

INFO:tensorflow:*** Example ***
I1228 18:20:35.435830 4493614528 create_pretraining_data.py:149] *** Example ***
INFO:tensorflow:tokens: [CLS] s [MASK] what has happened vissing call an ambulance now vissing wake up it s bjarne vissing bi there is an ambulance on its way his [MASK] [MASK] jesper [MASK] is your son arthur knew that jesper is [MASK] [MASK] that s why it was arthur who [UNUSED_222] her gotcha morphine it was of course you wrote the story about the drug addicts in ny [MASK] ##

INFO:tensorflow:Wrote 91978 total instances
I1228 18:20:54.001211 4766645696 create_pretraining_data.py:166] Wrote 91978 total instances


W1228 18:20:57.004986 4480216512 module_wrapper.py:139] From bert/create_pretraining_data.py:437: The name tf.logging.set_verbosity is deprecated. Please use tf.compat.v1.logging.set_verbosity instead.


W1228 18:20:57.005164 4480216512 module_wrapper.py:139] From bert/create_pretraining_data.py:437: The name tf.logging.INFO is deprecated. Please use tf.compat.v1.logging.INFO instead.


W1228 18:20:57.005317 4480216512 module_wrapper.py:139] From /Users/kmryu/code/deep/bert_code_review/bert/tokenization.py:125: The name tf.gfile.GFile is deprecated. Please use tf.io.gfile.GFile instead.


W1228 18:20:57.093976 4480216512 module_wrapper.py:139] From bert/create_pretraining_data.py:444: The name tf.gfile.Glob is deprecated. Please use tf.io.gfile.glob instead.


W1228 18:20:57.096812 4480216512 module_wrapper.py:139] From bert/create_pretraining_data.

INFO:tensorflow:*** Writing to output files ***
I1228 18:22:04.587583 4492705216 create_pretraining_data.py:457] *** Writing to output files ***
INFO:tensorflow:  pretraining_data/shard_0003.tfrecord
I1228 18:22:04.587751 4492705216 create_pretraining_data.py:459]   pretraining_data/shard_0003.tfrecord

W1228 18:22:04.587924 4492705216 module_wrapper.py:139] From bert/create_pretraining_data.py:101: The name tf.python_io.TFRecordWriter is deprecated. Please use tf.io.TFRecordWriter instead.

INFO:tensorflow:*** Example ***
I1228 18:22:04.588566 4492705216 create_pretraining_data.py:149] *** Example ***
INFO:tensorflow:tokens: [CLS] what good s having [MASK] unless you spend it huh this is this is so extravagant i i m embarrassed don midnight take up smoking but if it s good enough for ##toilet duke of windsor [MASK] well i m really [MASK] and my little bernard shaw is so [MASK] no [SEP] you to [MASK] how unbelievabl ##y proud i am of you david this is [MASK] insignificant mayb

INFO:tensorflow:masked_lm_ids: 3614 9 223 23 262 79 5 38 1447 9 1802 583 5 628 4101 143 79 6 628 0
I1228 18:22:04.602460 4492705216 create_pretraining_data.py:161] masked_lm_ids: 3614 9 223 23 262 79 5 38 1447 9 1802 583 5 628 4101 143 79 6 628 0
INFO:tensorflow:masked_lm_weights: 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 0.0
I1228 18:22:04.602529 4492705216 create_pretraining_data.py:161] masked_lm_weights: 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 0.0
INFO:tensorflow:next_sentence_labels: 0
I1228 18:22:04.602587 4492705216 create_pretraining_data.py:161] next_sentence_labels: 0
INFO:tensorflow:*** Example ***
I1228 18:22:04.603002 4492705216 create_pretraining_data.py:149] *** Example ***
INFO:tensorflow:tokens: [CLS] all right all right lt [MASK] me he s he s here ln the red snappy ##y nova uh that s [MASK] sir [MASK] over all right as soon as he leaves you call me with the owners addresses spirits [MASK] i 

INFO:tensorflow:Wrote 88775 total instances
I1228 18:22:24.664735 4492705216 create_pretraining_data.py:166] Wrote 88775 total instances
INFO:tensorflow:Wrote 94522 total instances
I1228 18:22:24.705079 4480216512 create_pretraining_data.py:166] Wrote 94522 total instances


---
# create_pretraining_data code review

In [25]:
XARGS_CMD

'ls ./shards/ | xargs -n 1 -P 2 -I{} python3 bert/create_pretraining_data.py --input_file=./shards/{} --output_file=pretraining_data/{}.tfrecord --vocab_file=vocab.txt --do_lower_case=True --max_predictions_per_seq=20 --max_seq_length=128 --masked_lm_prob=0.15 --random_seed=34 --dupe_factor=5'

In [None]:
flags = tf.flags

FLAGS = flags.FLAGS

flags.DEFINE_string("input_file", None,
                    "Input raw text file (or comma-separated list of files).")

flags.DEFINE_string(
    "output_file", None,
    "Output TF example file (or comma-separated list of files).")

flags.DEFINE_string("vocab_file", None,
                    "The vocabulary file that the BERT model was trained on.")

flags.DEFINE_bool(
    "do_lower_case", True,
    "Whether to lower case the input text. Should be True for uncased "
    "models and False for cased models.")

flags.DEFINE_bool(
    "do_whole_word_mask", False,
    "Whether to use whole word masking rather than per-WordPiece masking.")

flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.")

flags.DEFINE_integer("max_predictions_per_seq", 20,
                     "Maximum number of masked LM predictions per sequence.")

flags.DEFINE_integer("random_seed", 12345, "Random seed for data generation.")

flags.DEFINE_integer(
    "dupe_factor", 10,
    "Number of times to duplicate the input data (with different masks).")

flags.DEFINE_float("masked_lm_prob", 0.15, "Masked LM probability.")

flags.DEFINE_float(
    "short_seq_prob", 0.1,
    "Probability of creating sequences which are shorter than the "
    "maximum length.")

FLAGS appears to be the TensorFlow object that holds the command-line options.
Inspecting it in a debugger shows that the options above take the following values.
```
input_file = {str} './shards/shard_0000'
output_file = {str} 'pretraining_data/shard_0000.tfrecord'
vocab_file = {str} 'vocab.txt'

do_lower_case = {bool} True
do_whole_word_mask = {bool} False
max_seq_length = {int} 128
max_predictions_per_seq = {int} 20

dupe_factor = {int} 5
masked_lm_prob = {float} 0.15
short_seq_prob = {float} 0.1
```
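To see how the defaults in the flag definitions interact with the values on the command line, the sketch below mirrors a subset of the flags with the standard-library `argparse` module. This is a stand-in for illustration only, not how `tf.flags` is implemented; `str2bool` is a hypothetical helper.

```python
import argparse


def str2bool(value):
    # argparse has no built-in parsing for "--flag=True" style booleans
    return str(value).lower() in ("true", "1", "yes")


parser = argparse.ArgumentParser()
parser.add_argument("--input_file")
parser.add_argument("--output_file")
parser.add_argument("--vocab_file")
parser.add_argument("--do_lower_case", type=str2bool, default=True)
parser.add_argument("--max_seq_length", type=int, default=128)
parser.add_argument("--max_predictions_per_seq", type=int, default=20)
parser.add_argument("--dupe_factor", type=int, default=10)
parser.add_argument("--masked_lm_prob", type=float, default=0.15)
parser.add_argument("--short_seq_prob", type=float, default=0.1)

# the arguments that the xargs command passes for the first shard
args = parser.parse_args([
    "--input_file=./shards/shard_0000",
    "--output_file=pretraining_data/shard_0000.tfrecord",
    "--vocab_file=vocab.txt",
    "--do_lower_case=True",
    "--max_predictions_per_seq=20",
    "--max_seq_length=128",
    "--masked_lm_prob=0.15",
    "--dupe_factor=5",
])
print(args.dupe_factor, args.short_seq_prob)
```

`dupe_factor` is 5 because the command line overrides the default of 10, while `short_seq_prob` stays at its default of 0.1 because the command never sets it. This matches the debugger values listed above.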

Apart from logging, the main function consists of four steps.
- declare the tokenizer
- declare input_files
- create_training_instances
- write_instance_to_example_files

In [27]:
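def main(_):
  # 1. tokenizer built from our vocab file
  tokenizer = tokenization.FullTokenizer(
    vocab_file=FLAGS.vocab_file,
    do_lower_case=FLAGS.do_lower_case
  )

  # 2. expand the comma-separated glob patterns into a list of input files
  input_files = []
  for input_pattern in FLAGS.input_file.split(","):
    input_files.extend(tf.gfile.Glob(input_pattern))

  # 3. build the training instances with a seeded random number generator
  rng = random.Random(FLAGS.random_seed)
  instances = create_training_instances(
      input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
      FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq,
      rng)

  # 4. write the instances to the output TFRecord files
  output_files = FLAGS.output_file.split(",")
  write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length,
                                  FLAGS.max_predictions_per_seq, output_files)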

Here, tokenization is tokenization.py from the bert repo.
We already used this module for the tokenizer test above.

In [31]:
tokenizer = tokenization.FullTokenizer(VOC_FNAME)
tokenizer.tokenize(testcase)

['color',
 '##less',
 'geo',
 '##ther',
 '##mal',
 'subs',
 '##tation',
 '##s',
 'are',
 'generat',
 '##ing',
 'furious',
 '##ly']

input_files are the files passed in through the command-line option.
The code below produces the following result.
```
input_files = <class 'list'>: ['../shards/shard_0000']
```

In [None]:
input_files = []
for input_pattern in FLAGS.input_file.split(","):
  input_files.extend(tf.gfile.Glob(input_pattern))
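The same expansion can be tried without TensorFlow. The sketch below swaps `tf.gfile.Glob` for the standard-library `glob.glob` as a stand-in; `expand_input_patterns` is a hypothetical name, and the files are created in a temporary directory only for the demonstration.

```python
import glob
import os
import tempfile


def expand_input_patterns(input_file_flag):
    """Expand a comma-separated list of glob patterns into file paths."""
    input_files = []
    for input_pattern in input_file_flag.split(","):
        # sorted() only makes the demo output deterministic
        input_files.extend(sorted(glob.glob(input_pattern)))
    return input_files


with tempfile.TemporaryDirectory() as tmp:
    for name in ("shard_0000", "shard_0001", "other.txt"):
        open(os.path.join(tmp, name), "w").close()
    flag_value = "{},{}".format(os.path.join(tmp, "shard_*"),
                                os.path.join(tmp, "other.txt"))
    found = [os.path.basename(p) for p in expand_input_patterns(flag_value)]
    print(found)
```

Each comma-separated entry is treated as its own pattern, so a single flag value can mix wildcards and explicit file names. In this notebook xargs passes one shard per process, so the list holds a single file.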

create_training_instances reads input_files line by line and converts them into training instances.

In [None]:
instances = create_training_instances(
      input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
      FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq,
      rng)
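The masking part of this step can be sketched in a few lines. This is a simplified illustration, not the code from `create_pretraining_data.py`: it ignores whole-word masking and next-sentence pairing, and `mask_tokens` is a hypothetical name. It follows the 80/10/10 replacement rule described in the BERT paper, as I understand it.

```python
import random


def mask_tokens(tokens, vocab_words, rng, masked_lm_prob=0.15, max_predictions=20):
    """Simplified masked-LM masking: returns (masked tokens, positions, labels)."""
    # [CLS] and [SEP] are never candidates for masking
    candidates = [i for i, tok in enumerate(tokens)
                  if tok not in ("[CLS]", "[SEP]")]
    rng.shuffle(candidates)
    num_to_predict = min(max_predictions,
                         max(1, int(round(len(tokens) * masked_lm_prob))))
    output = list(tokens)
    picked = sorted(candidates[:num_to_predict])
    for i in picked:
        if rng.random() < 0.8:
            output[i] = "[MASK]"                 # 80%: replace with [MASK]
        elif rng.random() < 0.5:
            output[i] = tokens[i]                # 10%: keep the original token
        else:
            output[i] = rng.choice(vocab_words)  # 10%: random vocab token
    labels = [tokens[i] for i in picked]         # targets the model must predict
    return output, picked, labels


rng = random.Random(34)
tokens = ["[CLS]", "can", "i", "talk", "to", "you", "[SEP]",
          "talk", "to", "me", "[SEP]"]
masked, positions, labels = mask_tokens(tokens, ["yeah", "sure", "nice"], rng)
print(masked)
print(positions, labels)
```

With 11 tokens and `masked_lm_prob=0.15`, two positions are selected. The `max_predictions` cap corresponds to `--max_predictions_per_seq=20`; sequences with fewer predictions are padded, which is why the `masked_lm_weights` in the log above end in `0.0`.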