## Pretraining BERT on the Mueller Report

This notebook pretrains BERT on text from the Mueller report. Be sure to change your runtime to TPU to speed up training.



In [3]:
# Install SentencePiece and download all BERT modules.

!pip install sentencepiece
!git clone https://github.com/google-research/bert

Collecting sentencepiece
  Downloading https://files.pythonhosted.org/packages/00/95/7f357995d5eb1131aa2092096dca14a6fc1b1d2860bd99c22a612e1d1019/sentencepiece-0.1.82-cp36-cp36m-manylinux1_x86_64.whl (1.0MB)
     |████████████████████████████████| 1.0MB 3.5MB/s 
Installing collected packages: sentencepiece
Successfully installed sentencepiece-0.1.82
Cloning into 'bert'...
remote: Enumerating objects: 325, done.
remote: Total 325 (delta 0), reused 0 (delta 0), pack-reused 325
Receiving objects: 100% (325/325), 232.46 KiB | 3.18 MiB/s, done.
Resolving deltas: 100% (186/186), done.


**Next we import all of our dependencies, including modules from the BERT repository.**

In [0]:
import os
import sys
import json
import nltk
import random
import logging
import tensorflow as tf
import sentencepiece as spm

from glob import glob
from google.colab import auth, drive, files
from tensorflow.keras.utils import Progbar

sys.path.append("bert")

from bert import modeling, optimization, tokenization
from bert.run_pretraining import input_fn_builder, model_fn_builder


**Authenticate to be able to use TPUs.**

In [0]:
auth.authenticate_user()
  
# configure logging
log = logging.getLogger('tensorflow')
log.setLevel(logging.INFO)

# create formatter and add it to the handlers
formatter = logging.Formatter('%(asctime)s :  %(message)s')
sh = logging.StreamHandler()
sh.setLevel(logging.INFO)
sh.setFormatter(formatter)
log.handlers = [sh]

if 'COLAB_TPU_ADDR' in os.environ:
  log.info("Using TPU runtime")
  USE_TPU = True
  TPU_ADDRESS = 'grpc://' + os.environ['COLAB_TPU_ADDR']

  with tf.Session(TPU_ADDRESS) as session:
    log.info('TPU address is ' + TPU_ADDRESS)
    # Upload credentials to TPU.
    with open('/content/adc.json', 'r') as f:
      auth_info = json.load(f)
    tf.contrib.cloud.configure_gcs(session, credentials=auth_info)
    
else:
  log.warning('Not connected to TPU runtime')
  USE_TPU = False

**Next we upload our *clean_sents.txt* file for pre-processing.**

In [1]:
files.upload()

NameError: ignored

In [0]:
with open('./clean_sents.txt', encoding='utf-8') as f:
    # Strip only the line terminator so a final line without one keeps its last character.
    clean_sents = [line.rstrip('\n') for line in f]

In [0]:
clean_sents

['U. S. Department of Justice Attorney Werk Preduet// May Contain Material Protected Under Fed.',
 'R. Crim.',
 'P. 6 (e)',
 'Report On The Investigation Into',
 'Russian Interference In The 2016 Presidential Election',
 'Volume I of II',
 'Special Counsel Robert S. Mueller, III',
 'Submitted Pursuant to 28 C. F. R. $600.8 (c)',
 'Washington, D. C.',
 'March 2019',
 'U. S. Department of Justice Attorney Worle Product// May Contain Material Protected Underted R. Erim Pfe)',
 'U. S. Department of Justice Attorney Worle Prodret// May Contain Material Proteeted Under Fed.',
 'R. Erim.',
 'P - 6 (e)',
 'TABLE OF CONTENTS - VOLUME I',
 'INTRODUCTION TO VOLUMEI... EXECUTIVE SUMMARY TO VOLUME I.........................................',
 'I.',
 '<REDACTED>',
 '<REDACTED>',
 'A.',
 'Structure of the Internet Research Agency.................................................... B.',
 'Funding and Oversight from Concord and Prigozhin C. The IRA Targets U. S. Elections........',
 '1.',
 'The IRA Ram

In [0]:
# This will be used to filter out redacted sentences and sents before and after and those that are 
# 2 characters or less. 

# no_redact = clean_sents.copy()

# for sent in no_redact:
#   sent_index = no_redact.index(sent)
#   if sent == '<REDACTED>':
#     del no_redact[sent_index]
#   elif len(sent) <= 2:
#     del no_redact[sent_index]

# vocab_size = 0

# with open('no_redact.txt', 'w') as f:
#   for line in no_redact:
#     f.write(line+'\n')
#     for i in line:
#       vocab_size += 1
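The commented-out loop above deletes items from `no_redact` while iterating over it, which makes Python skip the element that follows each deletion. If this filter is ever re-enabled, a list comprehension avoids that problem. The following is only a minimal sketch: `filter_sentences`, `redaction_marker` and `min_len` are hypothetical names, and it drops just the redaction markers and very short lines, not their neighbours.

```python
def filter_sentences(sents, redaction_marker="<REDACTED>", min_len=3):
    """Keep sentences that are not redaction markers and have at least min_len characters."""
    return [s for s in sents if s != redaction_marker and len(s) >= min_len]

# A few lines taken from the clean_sents output shown earlier.
sample = ["I.", "<REDACTED>", "<REDACTED>", "A.", "Washington, D. C.", "March 2019"]
kept = filter_sentences(sample)
print(kept)  # ['Washington, D. C.', 'March 2019']
```

Building a new list leaves `clean_sents` untouched, so the unfiltered data stays available for comparison.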


**Use NLTK's** [RegexpTokenizer](https://www.nltk.org/_modules/nltk/tokenize/regexp.html) **to strip punctuation; the text is also lowercased and any characters that cannot be encoded as UTF-8 are dropped.**

In [0]:
# raw string so that \w is passed to the regex engine unchanged
regex_tokenizer = nltk.RegexpTokenizer(r"\w+")

def normalize_text(text):
  # lowercase text
  text = str(text).lower()
  # remove non-UTF
  text = text.encode("utf-8", "ignore").decode()
  # remove punctuation symbols
  text = " ".join(regex_tokenizer.tokenize(text))
  return text

def count_lines(filename):
  count = 0
  with open(filename) as fi:
    for line in fi:
      count += 1
  return count
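To see what the normalization does to a line, here is a small stand-alone sketch. It uses the standard library's `re.findall` as a stand-in for `nltk.RegexpTokenizer` (an assumption made so the example runs without NLTK), and `normalize_text_sketch` is a hypothetical name.

```python
import re

def normalize_text_sketch(text):
    """Lowercase, drop characters that cannot be encoded as UTF-8, keep only runs of word characters."""
    text = str(text).lower()
    text = text.encode("utf-8", "ignore").decode()
    return " ".join(re.findall(r"\w+", text))

print(normalize_text_sketch("Submitted Pursuant to 28 C. F. R. $600.8 (c)"))
# submitted pursuant to 28 c f r 600 8 c
print(normalize_text_sketch("<REDACTED>"))
# redacted
```

Note that the `<REDACTED>` marker loses its angle brackets, which is why plain `redacted` tokens can show up in the pretraining examples.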

In [0]:
RAW_DATA_FPATH = "clean_sents.txt" #@param {type: "string"}
PRC_DATA_FPATH = "proc_dataset.txt" #@param {type: "string"}

# apply normalization to the dataset
# this will take a minute or two

total_lines = count_lines(RAW_DATA_FPATH)
bar = Progbar(total_lines)

with open(RAW_DATA_FPATH,encoding="utf-8") as fi:
  with open(PRC_DATA_FPATH, "w",encoding="utf-8") as fo:
    for l in fi:
      fo.write(normalize_text(l)+"\n")
      bar.add(1)



In [2]:
!head -n 5 proc_dataset.txt

head: cannot open 'proc_dataset.txt' for reading: No such file or directory


In [0]:
MODEL_PREFIX = "tokenizer" #@param {type: "string"}
VOC_SIZE = 6850  #@param {type:"integer"}
SUBSAMPLE_SIZE = 500000 #@param {type:"integer"}
NUM_PLACEHOLDERS = -1 #@param {type:"integer"}

SPM_COMMAND = ('--input={} --model_prefix={} '
               '--vocab_size={} --input_sentence_size={} '
               '--shuffle_input_sentence=false ' 
               '--bos_id=-1 --eos_id=-1').format(
               PRC_DATA_FPATH, MODEL_PREFIX, 
               VOC_SIZE - NUM_PLACEHOLDERS, SUBSAMPLE_SIZE)

spm.SentencePieceTrainer.Train(SPM_COMMAND)

True

Since the WordPiece tokenizer used in the paper is not open source, the SentencePiece tokenizer is used in its place.

In [0]:
def read_sentencepiece_vocab(filepath):
  voc = []
  with open(filepath, encoding='utf-8') as fi:
    for line in fi:
      voc.append(line.split("\t")[0])
  # skip the first <unk> token
  voc = voc[1:]
  return voc

snt_vocab = read_sentencepiece_vocab("{}.vocab".format(MODEL_PREFIX))
print("Learnt vocab size: {}".format(len(snt_vocab)))
print("Sample tokens: {}".format(random.sample(snt_vocab, 10)))

Learnt vocab size: 6850
Sample tokens: ['aw', '▁tillerson', '▁omit', '▁waiting', '▁meg', 'nolog', '▁televis', '▁date', '▁1512', '▁event']


**SentencePiece marks the start of a word with '▁', while BERT expects word-internal pieces to be prefixed with "##". Here we use a small hack to convert from one convention to the other.**


In [0]:
def parse_sentencepiece_token(token):
    if token.startswith("▁"):
        return token[1:]
    else:
        return "##" + token
        
bert_vocab = list(map(parse_sentencepiece_token, snt_vocab))

ctrl_symbols = ["[PAD]","[UNK]","[CLS]","[SEP]","[MASK]"]
bert_vocab = ctrl_symbols + bert_vocab

bert_vocab += ["[UNUSED_{}]".format(i) for i in range(VOC_SIZE - len(bert_vocab))]
print(len(bert_vocab))

6855
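The conversion and the printed size can be checked by hand. The sketch below repeats `parse_sentencepiece_token` so that it runs on its own, applies it to four tokens from the sample output, and then redoes the size arithmetic with the settings used in this notebook; `spm_vocab_size`, `learnt`, `bert_vocab_len` and `padding` are names introduced only for this illustration.

```python
def parse_sentencepiece_token(token):
    # "▁" (U+2581) marks a word start in SentencePiece; BERT marks continuations with "##".
    if token.startswith("▁"):
        return token[1:]
    return "##" + token

sample = ["▁tillerson", "nolog", "▁date", "aw"]
converted = [parse_sentencepiece_token(t) for t in sample]
print(converted)  # ['tillerson', '##nolog', 'date', '##aw']

VOC_SIZE = 6850
NUM_PLACEHOLDERS = -1
spm_vocab_size = VOC_SIZE - NUM_PLACEHOLDERS   # 6851 pieces requested from SentencePiece
learnt = spm_vocab_size - 1                    # 6850 once the leading <unk> is skipped
bert_vocab_len = 5 + learnt                    # 5 control symbols + learnt pieces = 6855
padding = max(0, VOC_SIZE - bert_vocab_len)    # range() of a negative number is empty
print(bert_vocab_len, padding)  # 6855 0
```

With these settings no `[UNUSED_i]` placeholders are appended, and the final vocabulary is five entries longer than `VOC_SIZE`.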


**We now write the vocabulary to a file.**

In [0]:
VOC_FNAME = "vocab.txt" #@param {type:"string"}

with open(VOC_FNAME, "w") as fo:
  for token in bert_vocab:
    fo.write(token+"\n")

**Let's test it out to make sure it works as expected.**

In [0]:
testcase = "Colorless geothermal substations are generating furiously"
bert_tokenizer = tokenization.FullTokenizer(VOC_FNAME)
bert_tokenizer.tokenize(testcase)

['col',
 '##or',
 '##less',
 'ge',
 '##other',
 '##ma',
 '##l',
 'sub',
 '##sta',
 '##tion',
 '##s',
 'are',
 'generat',
 '##ing',
 'fur',
 '##iously']
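The split above comes from greedy longest-match-first matching: for each word, the tokenizer repeatedly takes the longest vocabulary entry that fits at the current position. The following is a simplified sketch of that idea, not the BERT implementation itself; `wordpiece_sketch` and `toy_vocab` are hypothetical.

```python
def wordpiece_sketch(word, vocab, unk="[UNK]"):
    """Greedy longest-match-first split of one lowercase word into vocabulary pieces."""
    pieces, start = [], 0
    while start < len(word):
        end = len(word)
        match = None
        # Shrink the candidate from the right until it is found in the vocabulary.
        while start < end:
            piece = word[start:end]
            if start > 0:
                piece = "##" + piece  # continuation pieces carry the "##" prefix
            if piece in vocab:
                match = piece
                break
            end -= 1
        if match is None:
            return [unk]  # no prefix of the remainder is in the vocabulary
        pieces.append(match)
        start = end
    return pieces

toy_vocab = {"col", "##or", "##less", "##o", "##r"}
print(wordpiece_sketch("colorless", toy_vocab))  # ['col', '##or', '##less']
print(wordpiece_sketch("xyz", toy_vocab))        # ['[UNK]']
```

Because the longest match wins, `##or` is chosen over the shorter `##o` even though both are in the toy vocabulary.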

**To account for the size of our dataset we will split it into shards.**

In [0]:
!mkdir ./shards
!split -a 4 -l 256000 -d $PRC_DATA_FPATH ./shards/shard_
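For a sense of how many shards to expect, the sketch below mirrors the flags used above: `-l 256000` lines per shard, `-d` numeric suffixes and `-a 4` suffix digits. `shard_names` is a hypothetical helper, and the line counts passed to it are made-up inputs rather than measurements of this dataset.

```python
import math

def shard_names(num_lines, lines_per_shard=256000, prefix="shard_", width=4):
    """List the file names `split -a 4 -l 256000 -d` would produce for num_lines lines."""
    num_shards = math.ceil(num_lines / lines_per_shard)
    return [f"{prefix}{i:0{width}d}" for i in range(num_shards)]

print(shard_names(1000))    # ['shard_0000']
print(shard_names(600000))  # ['shard_0000', 'shard_0001', 'shard_0002']
```

Any input with at most 256000 lines ends up in a single `shard_0000` file.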

**For each shard we call the create_pretraining_data.py script, which can be found in the [BERT](https://github.com/google-research/bert/blob/master/create_pretraining_data.py) GitHub repository.**

In [0]:
MAX_SEQ_LENGTH = 128 #@param {type:"integer"}
MASKED_LM_PROB = 0.15 #@param
MAX_PREDICTIONS = 20 #@param {type:"integer"}
DO_LOWER_CASE = True #@param {type:"boolean"}

PRETRAINING_DIR = "pretraining_data" #@param {type:"string"}
# controls how many parallel processes xargs can create
PROCESSES = 2 #@param {type:"integer"}
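`MASKED_LM_PROB` and `MAX_PREDICTIONS` work together. As best we can tell from create_pretraining_data.py, the number of masked positions in a sequence is the rounded product of its token count and the masking probability, at least 1 and at most `MAX_PREDICTIONS`. The sketch below reproduces that arithmetic under this assumption; `num_masked_positions` is a hypothetical name.

```python
def num_masked_positions(num_tokens, masked_lm_prob=0.15, max_predictions=20):
    """Number of positions selected for masking in a sequence of num_tokens tokens."""
    return min(max_predictions, max(1, int(round(num_tokens * masked_lm_prob))))

for n in (8, 64, 128, 200):
    print(n, num_masked_positions(n))
# 8 1
# 64 10
# 128 19
# 200 20
```

With `MAX_SEQ_LENGTH = 128`, a full sequence gets about 19 masked positions, so a cap of 20 is rarely the limiting factor.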

In [0]:
XARGS_CMD = ("ls ./shards/ | "
             "xargs -n 1 -P {} -I{} "
             "python3 bert/create_pretraining_data.py "
             "--input_file=./shards/{} "
             "--output_file={}/{}.tfrecord "
             "--vocab_file={} "
             "--do_lower_case={} "
             "--max_predictions_per_seq={} "
             "--max_seq_length={} "
             "--masked_lm_prob={} "
             "--random_seed=34 "
             "--dupe_factor=5")

XARGS_CMD = XARGS_CMD.format(PROCESSES, '{}', '{}', PRETRAINING_DIR, '{}', 
                             VOC_FNAME, DO_LOWER_CASE, 
                             MAX_PREDICTIONS, MAX_SEQ_LENGTH, MASKED_LM_PROB)
                             
tf.gfile.MkDir(PRETRAINING_DIR)
!$XARGS_CMD

INFO:tensorflow:*** Reading from input files ***
INFO:tensorflow:  ./shards/shard_0000
INFO:tensorflow:*** Writing to output files ***
INFO:tensorflow:  pretraining_data/shard_0000.tfrecord
INFO:tensorflow:*** Example ***
INFO:tensorflow:tokens: [CLS] the wash ##ing [MASK] post publishe ##d an [MASK] holl ##y ##w ##ood video that capture ##d comments by candidate trump some years earlie ##r and that [MASK] expected to adverse ##ly ##ities the campaign 239 less than an [MASK] after [MASK] video s publication [MASK] released the first set of emails stolen [MASK] the [MASK] from the account of clinton campaign chairman john [MASK] [SEP] redacted redacted redacted redacted redacted [MASK] [MASK] that beca ##use [MASK] had no direct means of communicat ##ing with wikileaks he told [MASK] of [MASK] news site wnd who were participati ##ng [MASK] a conference [MASK] with him that day to reach assange immediately [MASK] corsi claim ##ed that the pressure was [SEP]
INFO:tensorflow:input_ids: 2 5
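The `'{}'` arguments passed to `format` in the cell above are easy to misread: they put literal `{}` back into the command so that `xargs -I{}` can later substitute each shard name. A shortened sketch of the same pattern, with only the first five placeholders and the values used in this notebook:

```python
template = ("ls ./shards/ | "
            "xargs -n 1 -P {} -I{} "
            "python3 bert/create_pretraining_data.py "
            "--input_file=./shards/{} "
            "--output_file={}/{}.tfrecord")

# Positional order: processes, xargs token, xargs token, output directory, xargs token.
cmd = template.format(2, "{}", "{}", "pretraining_data", "{}")
print(cmd)
# ls ./shards/ | xargs -n 1 -P 2 -I{} python3 bert/create_pretraining_data.py --input_file=./shards/{} --output_file=pretraining_data/{}.tfrecord
```

At run time `xargs` replaces every remaining `{}` with a file name such as `shard_0000`, which matches the paths in the log output.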

**Provide a GCS bucket name where these files will be stored.**

In [0]:
BUCKET_NAME = "manceps-mueller" #@param {type:"string"}
MODEL_DIR = "bert_model" #@param {type:"string"}
tf.gfile.MkDir(MODEL_DIR)

if not BUCKET_NAME:
  log.warning("WARNING: BUCKET_NAME is not set. "
              "You will not be able to train the model.")

**Define hyperparameters for BERT-base.**

In [0]:
# use this for BERT-base

bert_base_config = {
  "attention_probs_dropout_prob": 0.1,
  "directionality": "bidi",
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pooler_fc_size": 768,
  "pooler_num_attention_heads": 12,
  "pooler_num_fc_layers": 3,
  "pooler_size_per_head": 128,
  "pooler_type": "first_token_transform",
  "type_vocab_size": 2,
  # must cover every token id in vocab.txt, which has len(bert_vocab) entries
  # (6855 here, five more than VOC_SIZE)
  "vocab_size": len(bert_vocab)
}

with open("{}/bert_config.json".format(MODEL_DIR), "w") as fo:
  json.dump(bert_base_config, fo, indent=2)
  
with open("{}/{}".format(MODEL_DIR, VOC_FNAME), "w") as fo:
  for token in bert_vocab:
    fo.write(token+"\n")
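BERT's embedding table has `vocab_size` rows, so every token id from `vocab.txt` has to be smaller than the `vocab_size` written to `bert_config.json`. A quick sanity check before uploading can catch a mismatch early. This is only a sketch; `unused_embedding_rows` and the toy inputs are hypothetical.

```python
def unused_embedding_rows(vocab_tokens, config):
    """Return the number of spare embedding rows; raise if a token id would be out of range."""
    declared = config["vocab_size"]
    if len(vocab_tokens) > declared:
        raise ValueError(
            "vocab file has {} tokens but config declares vocab_size={}".format(
                len(vocab_tokens), declared))
    return declared - len(vocab_tokens)

toy_vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "the", "##ing"]
print(unused_embedding_rows(toy_vocab, {"vocab_size": 8}))  # 1
```

In the notebook the same check could be called with `bert_vocab` and `bert_base_config`.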

In [0]:
if BUCKET_NAME:
  !gsutil -m cp -r $MODEL_DIR $PRETRAINING_DIR gs://$BUCKET_NAME

Copying file://pretraining_data/shard_0000.tfrecord [Content-Type=application/octet-stream]...
Copying file://bert_model/bert_config.json [Content-Type=application/json]...
Copying file://bert_model/vocab.txt [Content-Type=text/plain]...
/ [3/3 files][  9.0 MiB/  9.0 MiB] 100% Done                                    
Operation completed over 3 objects/9.0 MiB.                                      


**Some parameters here repeat values from earlier steps, so make sure they are set to exactly the same values throughout the notebook.**

In [0]:
BUCKET_NAME = "manceps-mueller" #@param {type:"string"}
MODEL_DIR = "bert_model" #@param {type:"string"}
PRETRAINING_DIR = "pretraining_data" #@param {type:"string"}
VOC_FNAME = "vocab.txt" #@param {type:"string"}

# Input data pipeline config
TRAIN_BATCH_SIZE = 128 #@param {type:"integer"}
MAX_PREDICTIONS = 20 #@param {type:"integer"}
MAX_SEQ_LENGTH = 128 #@param {type:"integer"}
MASKED_LM_PROB = 0.15 #@param

# Training procedure config
EVAL_BATCH_SIZE = 64
LEARNING_RATE = 2e-5
TRAIN_STEPS = 100000 #@param {type:"integer"}
SAVE_CHECKPOINTS_STEPS = 2500 #@param {type:"integer"}
NUM_TPU_CORES = 8

if BUCKET_NAME:
  BUCKET_PATH = "gs://{}".format(BUCKET_NAME)
else:
  BUCKET_PATH = "."

BERT_GCS_DIR = "{}/{}".format(BUCKET_PATH, MODEL_DIR)
DATA_GCS_DIR = "{}/{}".format(BUCKET_PATH, PRETRAINING_DIR)

VOCAB_FILE = os.path.join(BERT_GCS_DIR, VOC_FNAME)
CONFIG_FILE = os.path.join(BERT_GCS_DIR, "bert_config.json")

INIT_CHECKPOINT = tf.train.latest_checkpoint(BERT_GCS_DIR)

bert_config = modeling.BertConfig.from_json_file(CONFIG_FILE)
input_files = tf.gfile.Glob(os.path.join(DATA_GCS_DIR,'*tfrecord'))

log.info("Using checkpoint: {}".format(INIT_CHECKPOINT))
log.info("Using {} data shards".format(len(input_files)))

2019-05-30 02:39:43,093 :  Using checkpoint: gs://manceps-mueller/bert_model/model.ckpt-2500
2019-05-30 02:39:43,094 :  Using 1 data shards


**Prepare the training run configuration, then build the estimator and input function.**

In [0]:
model_fn = model_fn_builder(
      bert_config=bert_config,
      init_checkpoint=INIT_CHECKPOINT,
      learning_rate=LEARNING_RATE,
      num_train_steps=TRAIN_STEPS,
      num_warmup_steps=10,
      use_tpu=USE_TPU,
      use_one_hot_embeddings=True)

tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(TPU_ADDRESS)

run_config = tf.contrib.tpu.RunConfig(
    cluster=tpu_cluster_resolver,
    model_dir=BERT_GCS_DIR,
    save_checkpoints_steps=SAVE_CHECKPOINTS_STEPS,
    tpu_config=tf.contrib.tpu.TPUConfig(
        iterations_per_loop=SAVE_CHECKPOINTS_STEPS,
        num_shards=NUM_TPU_CORES,
        per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2))

estimator = tf.contrib.tpu.TPUEstimator(
    use_tpu=USE_TPU,
    model_fn=model_fn,
    config=run_config,
    train_batch_size=TRAIN_BATCH_SIZE,
    eval_batch_size=EVAL_BATCH_SIZE)
  
train_input_fn = input_fn_builder(
        input_files=input_files,
        max_seq_length=MAX_SEQ_LENGTH,
        max_predictions_per_seq=MAX_PREDICTIONS,
        is_training=True)

2019-05-30 02:39:55,560 :  Estimator's model_fn (<function model_fn_builder.<locals>.model_fn at 0x7f6d09b9c840>) includes params argument, but params are not passed to Estimator.
2019-05-30 02:39:55,563 :  Using config: {'_model_dir': 'gs://manceps-mueller/bert_model', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': 2500, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
cluster_def {
  job {
    name: "worker"
    tasks {
      key: 0
      value: "10.75.18.10:8470"
    }
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7f6d0fa747b8>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': 'grpc://10.75.18.10:8470', '_evaluati
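The `learning_rate`, `num_train_steps` and `num_warmup_steps` arguments passed to `model_fn_builder` define a schedule. As we understand the BERT repository's optimization.py, the rate rises linearly during warmup and then decays linearly to zero at the final step. The plain-Python sketch below restates that schedule under this assumption; `bert_learning_rate` is a hypothetical name, and the defaults are the values used in this notebook.

```python
def bert_learning_rate(step, init_lr=2e-5, num_train_steps=100000, num_warmup_steps=10):
    """Learning rate at a given global step: linear warmup, then linear decay to zero."""
    if step < num_warmup_steps:
        return init_lr * step / num_warmup_steps
    step = min(step, num_train_steps)  # no further decay past the final step
    return init_lr * (1.0 - step / num_train_steps)

for s in (0, 5, 10, 50000, 100000):
    print(s, bert_learning_rate(s))
```

With only 10 warmup steps the rate is already close to its 2e-5 peak by step 10 and falls to half of that at step 50000.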

**Begin training the model. This will take quite some time, so be patient. If the kernel restarts, training can resume from the last checkpoint.**

In [0]:
estimator.train(input_fn=train_input_fn, max_steps=TRAIN_STEPS)


2019-05-30 02:41:15,277 :  Querying Tensorflow master (grpc://10.75.18.10:8470) for TPU system metadata.
2019-05-30 02:41:15,299 :  Found TPU system:
2019-05-30 02:41:15,300 :  *** Num TPU Cores: 8
2019-05-30 02:41:15,301 :  *** Num TPU Workers: 1
2019-05-30 02:41:15,310 :  *** Num TPU Cores Per Worker: 8
2019-05-30 02:41:15,312 :  *** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, -1, 7822140648737772533)
2019-05-30 02:41:15,314 :  *** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 17179869184, 5236194508695992152)
2019-05-30 02:41:15,316 :  *** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 17179869184, 6119358150869997627)
2019-05-30 02:41:15,319 :  *** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 17179869184, 10956677226285470421)
2019-05-30 02:41:15,321 :  *** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/d

KeyboardInterrupt: ignored