
In [0]:
import sys
import os
import tensorflow as tf
import pprint
import json

assert 'COLAB_TPU_ADDR' in os.environ, 'ERROR: Not connected to a TPU runtime; please see the first cell in this notebook for instructions!'
TPU_ADDRESS = 'grpc://' + os.environ['COLAB_TPU_ADDR']
print('TPU address is', TPU_ADDRESS)

from google.colab import auth
auth.authenticate_user()
with tf.Session(TPU_ADDRESS) as session:
  print('TPU devices:')
  pprint.pprint(session.list_devices())

  # Upload credentials to TPU.
  with open('/content/adc.json', 'r') as f:
    auth_info = json.load(f)
  tf.contrib.cloud.configure_gcs(session, credentials=auth_info)
  # Now credentials are set for all future sessions on this TPU.

TPU address is grpc://10.59.171.186:8470

For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
If you depend on functionality not listed there, please file an issue.

TPU devices:
[_DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:CPU:0, CPU, -1, 4511277199797020198),
 _DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 17179869184, 10375007403089140325),
 _DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:TPU:0, TPU, 17179869184, 8744274357333971060),
 _DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:TPU:1, TPU, 17179869184, 3293099263645622010),
 _DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:TPU:2, TPU, 17179869184, 7186576120902363608),
 _DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:TPU:3, TPU, 17179869184, 6248051809526301553),
 _DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:TPU:4, TPU, 171

In [0]:
!rm -rf bert_repo
!test -d bert_repo || git clone https://github.com/chenqinkai/bert.git bert_repo
# Make the cloned repository importable from this notebook.
if 'bert_repo' not in sys.path:
  sys.path.append('bert_repo')

Cloning into 'bert_repo'...
remote: Enumerating objects: 3, done.
remote: Counting objects: 100% (3/3), done.
remote: Compressing objects: 100% (3/3), done.
remote: Total 385 (delta 0), reused 1 (delta 0), pack-reused 382
Receiving objects: 100% (385/385), 272.36 KiB | 3.24 MiB/s, done.
Resolving deltas: 100% (221/221), done.


In [0]:
BERT_MODEL = 'uncased_L-12_H-768_A-12' #@param {type:"string"}
BERT_PRETRAINED_DIR = 'gs://cloud-tpu-checkpoints/bert/' + BERT_MODEL
print('***** BERT pretrained directory: {} *****'.format(BERT_PRETRAINED_DIR))
!gsutil ls $BERT_PRETRAINED_DIR

***** BERT pretrained directory: gs://cloud-tpu-checkpoints/bert/uncased_L-12_H-768_A-12 *****
gs://cloud-tpu-checkpoints/bert/uncased_L-12_H-768_A-12/bert_config.json
gs://cloud-tpu-checkpoints/bert/uncased_L-12_H-768_A-12/bert_model.ckpt.data-00000-of-00001
gs://cloud-tpu-checkpoints/bert/uncased_L-12_H-768_A-12/bert_model.ckpt.index
gs://cloud-tpu-checkpoints/bert/uncased_L-12_H-768_A-12/bert_model.ckpt.meta
gs://cloud-tpu-checkpoints/bert/uncased_L-12_H-768_A-12/checkpoint
gs://cloud-tpu-checkpoints/bert/uncased_L-12_H-768_A-12/vocab.txt


Get the embedding matrix for the training data

In [0]:
MAX_LEN = 64 #@param {type:"integer"}
LAYER = -1 #@param {type:"integer"}
REMOVE_SW = False #@param {type:"boolean"}

USE_FINE_TUNED = True #@param {type:"boolean"}
FINE_TUNED_PATH = "gs://bert-news-sentiment/bert/models/reuters/horizon-3_percentile-10_epoch-4_batch-32_lr-2_maxlen-64/model.ckpt-6250"  #@param {type:"string"}

if not USE_FINE_TUNED:
    FINE_TUNED_PATH = BERT_PRETRAINED_DIR + "/bert_model.ckpt"
print("Model path: %s" % FINE_TUNED_PATH)

import nltk
nltk.download('stopwords')

Model path: gs://bert-news-sentiment/bert/models/reuters/horizon-3_percentile-10_epoch-4_batch-32_lr-2_maxlen-64/model.ckpt-6250
[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


True

Generate training embedding

In [0]:
INPUT_FILE = "gs://bert-news-sentiment/reuters/horizon_3/training_horizon_3_percentile_10_headlines.txt"
OUTPUT_EMBEDDING_PATH = "gs://bert-news-sentiment/rnn/embedding/training_horizon_3_percentile_10_%s_layer_%d_maxlen_%d%s.npy" % ("tuned" if USE_FINE_TUNED else "base", -LAYER, MAX_LEN, "_nostop" if REMOVE_SW else "")

!python bert_repo/extract_features.py \
  --input_file=$INPUT_FILE \
  --output_file=$OUTPUT_EMBEDDING_PATH \
  --vocab_file=$BERT_PRETRAINED_DIR/vocab.txt \
  --bert_config_file=$BERT_PRETRAINED_DIR/bert_config.json \
  --init_checkpoint=$FINE_TUNED_PATH \
  --layers=$LAYER \
  --max_seq_length=$MAX_LEN \
  --batch_size=32 \
  --use_tpu=True \
  --use_one_hot_embeddings=True \
  --master=$TPU_ADDRESS \
  --remove_stopwords=$REMOVE_SW


INFO:tensorflow:*** Example ***
INFO:tensorflow:unique_id: 0
INFO:tensorflow:tokens: [CLS] * rig working 46 miles off ghana for en ##i for $ 450 000 / day [SEP]
INFO:tensorflow:input_ids: 101 1008 19838 2551 4805 2661 2125 9701 2005 4372 2072 2005 1002 10332 2199 1013 2154 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
INFO:tensorflow:input_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
INFO:tensorflow:input_type_ids: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
INFO:tensorflow:*** Example ***
INFO:tensorflow:unique_id: 1
INFO:tensorflow:t
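
The output file names in this notebook encode the checkpoint variant, the layer, the maximum sequence length, and whether stop words were removed. The following is a minimal sketch of that naming pattern as a reusable function. The function name `embedding_path` and its `prefix` and `suffix` arguments are hypothetical helpers introduced here for illustration; only the format string itself comes from the cells in this notebook.

```python
def embedding_path(prefix, use_fine_tuned, layer, max_len, remove_sw, suffix=""):
    """Build an output path following the naming pattern used in this notebook.

    prefix and suffix are illustrative arguments (e.g. "test_horizon_3", "_part1").
    """
    base = "gs://bert-news-sentiment/rnn/embedding/"
    variant = "tuned" if use_fine_tuned else "base"
    nostop = "_nostop" if remove_sw else ""
    # LAYER is negative (-1 is the last layer); the file name stores it negated.
    return "%s%s_%s_layer_%d_maxlen_%d%s%s.npy" % (
        base, prefix, variant, -layer, max_len, nostop, suffix)


print(embedding_path("training_horizon_3_percentile_10", True, -1, 64, False))
```

Keeping the pattern in one place would make it harder for the training, test, and split-test cells to drift apart in how they name their outputs.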

Get the embedding matrix for the test data

Warning: use this cell only when MAX_LEN <= 32.
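
To see why the sequence length matters for memory, here is a rough back-of-the-envelope sketch. It assumes the output holds one float32 vector of size 768 (the `H-768` in the model name) for every token position of every headline; whether `extract_features.py` in this fork stores exactly that layout is an assumption, and `NUM_HEADLINES` is a made-up count, not the size of the real test set.

```python
def embedding_bytes(num_examples, max_len, hidden_size=768, bytes_per_value=4):
    """Size in bytes of a dense (num_examples, max_len, hidden_size) array."""
    return num_examples * max_len * hidden_size * bytes_per_value


NUM_HEADLINES = 100000  # hypothetical count, not taken from the dataset

for max_len in (32, 64):
    gib = embedding_bytes(NUM_HEADLINES, max_len) / 2**30
    print("MAX_LEN=%d: %.1f GiB" % (max_len, gib))
```

Under these assumptions the array size grows linearly with MAX_LEN, so doubling the sequence length doubles the memory needed to hold the result.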

In [0]:
INPUT_FILE = "gs://bert-news-sentiment/reuters/horizon_3/test_horizon_3_headlines.txt"
OUTPUT_EMBEDDING_PATH = "gs://bert-news-sentiment/rnn/embedding/test_horizon_3_%s_layer_%d_maxlen_%d%s.npy" % ("tuned" if USE_FINE_TUNED else "base", -LAYER, MAX_LEN, "_nostop" if REMOVE_SW else "")

!python bert_repo/extract_features.py \
  --input_file=$INPUT_FILE \
  --output_file=$OUTPUT_EMBEDDING_PATH \
  --vocab_file=$BERT_PRETRAINED_DIR/vocab.txt \
  --bert_config_file=$BERT_PRETRAINED_DIR/bert_config.json \
  --init_checkpoint=$FINE_TUNED_PATH \
  --layers=$LAYER \
  --max_seq_length=$MAX_LEN \
  --batch_size=32 \
  --use_tpu=True \
  --use_one_hot_embeddings=True \
  --master=$TPU_ADDRESS \
  --remove_stopwords=$REMOVE_SW


INFO:tensorflow:*** Example ***
INFO:tensorflow:unique_id: 0
INFO:tensorflow:tokens: [CLS] president donald proposal sell half strategic petroleum reserves likely little impact efforts reduce global oil g ##lu ##t goldman sachs said [SEP]
INFO:tensorflow:input_ids: 101 2343 6221 6378 5271 2431 6143 11540 8269 3497 2210 4254 4073 5547 3795 3514 1043 7630 2102 17765 22818 2056 102 0 0 0 0 0 0 0 0 0
INFO:tensorflow:input_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0
INFO:tensorflow:input_type_ids: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
INFO:tensorflow:*** Example ***
INFO:tensorflow:unique_id: 1
INFO:tensorflow:tokens: [CLS] sets regular quarterly divide ##nd source text further company [SEP]
INFO:tensorflow:
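
The logged examples show how each headline is padded to MAX_LEN: the token ids are followed by zeros, and `input_mask` marks real tokens with 1 and padding with 0. The following is a small sketch of that layout, reconstructed from the log output only. It is not the code of `extract_features.py`, and the function name `pad_features` is hypothetical.

```python
def pad_features(token_ids, max_len):
    """Pad token ids to max_len and build the matching mask and segment ids."""
    if len(token_ids) > max_len:
        raise ValueError("sequence is longer than max_len")
    n_pad = max_len - len(token_ids)
    input_ids = list(token_ids) + [0] * n_pad
    input_mask = [1] * len(token_ids) + [0] * n_pad  # 1 = real token, 0 = padding
    input_type_ids = [0] * max_len  # single-sentence input: all segment ids are 0
    return input_ids, input_mask, input_type_ids


# Ids copied from the first logged test example ([CLS] = 101, [SEP] = 102).
ids = [101, 2343, 6221, 6378, 5271, 2431, 6143, 11540, 8269, 3497, 2210, 4254,
       4073, 5547, 3795, 3514, 1043, 7630, 2102, 17765, 22818, 2056, 102]

input_ids, input_mask, input_type_ids = pad_features(ids, 32)
print(input_mask)
```

With 23 real tokens and a maximum length of 32, the mask has 23 ones followed by 9 zeros, which matches the `input_mask` line in the log.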

Run this cell instead if MAX_LEN is larger than 32.

It splits the test data into two parts to avoid an out-of-memory error.
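
As an illustration of the splitting idea, the sketch below cuts a list of headlines into contiguous parts and then merges per-part results back in the original order. How the `_part1` and `_part2` input files were actually produced is not shown in this notebook, so the helper `split_lines` and the toy "embeddings" are assumptions made for the example.

```python
def split_lines(lines, n_parts=2):
    """Split lines into at most n_parts contiguous chunks of near-equal size."""
    chunk = max(1, -(-len(lines) // n_parts))  # ceiling division
    return [lines[i:i + chunk] for i in range(0, len(lines), chunk)]


headlines = ["headline %d" % i for i in range(5)]
part1, part2 = split_lines(headlines)

# Toy stand-ins for per-part embeddings: one row per headline.
rows1 = [[float(len(h))] for h in part1]
rows2 = [[float(len(h))] for h in part2]

# Concatenating in part order restores the original headline order.
merged = rows1 + rows2
print(len(part1), len(part2), len(merged))
```

For the real `.npy` outputs, the same merge would typically be `numpy.concatenate([part1, part2], axis=0)`, assuming the first axis of each array indexes headlines.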

In [0]:
INPUT_FILE = "gs://bert-news-sentiment/reuters/horizon_3/test_horizon_3_headlines_part1.txt"
OUTPUT_EMBEDDING_PATH = "gs://bert-news-sentiment/rnn/embedding/test_horizon_3_%s_layer_%d_maxlen_%d%s_part1.npy" % ("tuned" if USE_FINE_TUNED else "base", -LAYER, MAX_LEN, "_nostop" if REMOVE_SW else "")

!python bert_repo/extract_features.py \
  --input_file=$INPUT_FILE \
  --output_file=$OUTPUT_EMBEDDING_PATH \
  --vocab_file=$BERT_PRETRAINED_DIR/vocab.txt \
  --bert_config_file=$BERT_PRETRAINED_DIR/bert_config.json \
  --init_checkpoint=$FINE_TUNED_PATH \
  --layers=$LAYER \
  --max_seq_length=$MAX_LEN \
  --batch_size=32 \
  --use_tpu=True \
  --use_one_hot_embeddings=True \
  --master=$TPU_ADDRESS \
  --remove_stopwords=$REMOVE_SW

INPUT_FILE = "gs://bert-news-sentiment/reuters/horizon_3/test_horizon_3_headlines_part2.txt"
OUTPUT_EMBEDDING_PATH = "gs://bert-news-sentiment/rnn/embedding/test_horizon_3_%s_layer_%d_maxlen_%d%s_part2.npy" % ("tuned" if USE_FINE_TUNED else "base", -LAYER, MAX_LEN, "_nostop" if REMOVE_SW else "")

!python bert_repo/extract_features.py \
  --input_file=$INPUT_FILE \
  --output_file=$OUTPUT_EMBEDDING_PATH \
  --vocab_file=$BERT_PRETRAINED_DIR/vocab.txt \
  --bert_config_file=$BERT_PRETRAINED_DIR/bert_config.json \
  --init_checkpoint=$FINE_TUNED_PATH \
  --layers=$LAYER \
  --max_seq_length=$MAX_LEN \
  --batch_size=32 \
  --use_tpu=True \
  --use_one_hot_embeddings=True \
  --master=$TPU_ADDRESS \
  --remove_stopwords=$REMOVE_SW


INFO:tensorflow:*** Example ***
INFO:tensorflow:unique_id: 0
INFO:tensorflow:tokens: [CLS] president donald trump ' s proposal to sell half of the u . s . strategic petroleum reserves ( sp ##r ) will likely have little impact on op ##ec ' s efforts to reduce a global oil g ##lu ##t goldman sachs said on tuesday . [SEP]
INFO:tensorflow:input_ids: 101 2343 6221 8398 1005 1055 6378 2000 5271 2431 1997 1996 1057 1012 1055 1012 6143 11540 8269 1006 11867 2099 1007 2097 3497 2031 2210 4254 2006 6728 8586 1005 1055 4073 2000 5547 1037 3795 3514 1043 7630 2102 17765 22818 2056 2006 9857 1012 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
INFO:tensorflow:input_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0