In [2]:
import tensorflow as tf
import numpy as np
from bert_serving.server.graph import optimize_graph
from bert_serving.server.helper import get_args_parser
from tensorflow.python.estimator.estimator import Estimator
from tensorflow.python.estimator.run_config import RunConfig
from tensorflow.python.estimator.model_fn import EstimatorSpec
from tensorflow.keras.utils import Progbar
from bert_serving.server.bert.tokenization import FullTokenizer
from bert_serving.server.bert.extract_features import convert_lst_to_features

In [3]:
import logging
# Configure TensorFlow's logger: INFO level, with any handlers attached to it
# removed. `log` is also handed to convert_lst_to_features further down.
log = logging.getLogger('tensorflow')
log.handlers = []
log.setLevel(logging.INFO)

In [4]:
# NOTE(review): `sesh` is not referenced by any other cell in this notebook;
# the Estimator-based pipeline below presumably manages its own sessions.
# Kept in case other code relies on it being the default session — consider
# removing it.
sesh = tf.InteractiveSession()

In [5]:
# Paths and bert-as-service graph-export options.
MODEL_DIR = './bert/uncased_L-12_H-768_A-12/'   # pretrained checkpoint + vocab.txt
GRAPH_DIR = './bert/graph/'                      # where the frozen graph is written
GRAPH_OUT = 'extractor.pbtxt'

POOL_STRAT = 'REDUCE_MEAN'
# Kept as ints; the export cell converts them with str() for the CLI parser.
POOL_LAYER = -2     # presumably the second-to-last encoder layer — confirm in bert-as-service docs
SEQ_LEN = 256       # max sequence length passed to the graph optimizer

# MakeDirs (unlike MkDir) creates missing parent directories and succeeds if
# the directory already exists, so re-running this cell does not raise.
tf.gfile.MakeDirs(GRAPH_DIR)

In [6]:
import os

In [7]:
%%time
parser = get_args_parser()
carg = parser.parse_args(args=['-model_dir', MODEL_DIR,
                               '-graph_tmp_dir', GRAPH_DIR,
                               '-max_seq_len', str(SEQ_LEN),
                               '-pooling_layer', str(POOL_LAYER),
                               '-pooling_strategy', POOL_STRAT])

tmp_name, config = optimize_graph(carg)
graph_fout = os.path.join(GRAPH_DIR, GRAPH_OUT)

tf.gfile.Rename(
    tmp_name,
    graph_fout,
    overwrite=True
)
print("\nSerialized graph to {}".format(graph_fout))

From C:\Users\sutar\Anaconda3\lib\site-packages\bert_serving\server\helper.py:186: The name tf.logging.set_verbosity is deprecated. Please use tf.compat.v1.logging.set_verbosity instead.

From C:\Users\sutar\Anaconda3\lib\site-packages\bert_serving\server\helper.py:186: The name tf.logging.ERROR is deprecated. Please use tf.compat.v1.logging.ERROR instead.



I:[36mGRAPHOPT[0m:model config: ./bert/uncased_L-12_H-768_A-12/bert_config.json
I:[36mGRAPHOPT[0m:checkpoint: ./bert/uncased_L-12_H-768_A-12/bert_model.ckpt
I:[36mGRAPHOPT[0m:build graph...
I:[36mGRAPHOPT[0m:load parameters from checkpoint...
I:[36mGRAPHOPT[0m:optimize...
I:[36mGRAPHOPT[0m:freeze...
I:[36mGRAPHOPT[0m:write graph to a tmp file: E:\ML\BE Project\bert\graph\tmpyj5lb1b9

Serialized graph to ./bert/graph/extractor.pbtxt
Wall time: 11.7 s


In [8]:
# Paths used by the prediction pipeline. They are derived from the constants
# used to export the graph, so the two cells cannot drift apart.
GRAPH_PATH = os.path.join(GRAPH_DIR, GRAPH_OUT)
VOCAB_PATH = os.path.join(MODEL_DIR, 'vocab.txt')
# Tokenization needs an int; coerce whatever the config cell set (it may be
# a string meant for the CLI parser). Idempotent on re-run.
SEQ_LEN = int(SEQ_LEN)

In [9]:
# Names of the frozen graph's input tensors (fed as "<name>:0" in model_fn);
# build_feed_dict produces one array per name.
INPUT_NAMES = 'input_ids input_mask input_type_ids'.split()
# BERT tokenizer built from the model's vocabulary file.
bert_tokenizer = FullTokenizer(VOCAB_PATH)

In [10]:
def build_feed_dict(texts):
    """Tokenize a batch of texts into the int32 arrays the BERT graph expects.

    Args:
        texts: list of strings forming one batch.

    Returns:
        dict mapping each name in ``INPUT_NAMES`` to an int32 array of shape
        ``(len(texts), L)``. ``L`` is the per-example feature length produced
        by ``convert_lst_to_features`` (presumably ``SEQ_LEN`` after padding
        — confirm).
    """
    # Positional args: SEQ_LEN twice (presumably max_seq_length and the max
    # position embeddings), tokenizer, logger, then two False flags
    # (presumably is_tokenized / mask_cls_sep — verify against bert_serving).
    features = list(convert_lst_to_features(
        texts, SEQ_LEN, SEQ_LEN, bert_tokenizer, log, False, False))
    n_texts = len(texts)

    def _column(field):
        # Gather one attribute across all examples, one row per text.
        values = np.array([getattr(feat, field) for feat in features])
        return values.reshape((n_texts, -1)).astype("int32")

    return {name: _column(name) for name in INPUT_NAMES}

In [11]:
def build_input_fn(container):
    """Build an Estimator ``input_fn`` that streams batches out of ``container``.

    The dataset wraps an endless generator. Each element is the feed dict
    for whatever texts ``container`` currently holds, so the caller must
    ``container.set(...)`` the next batch before pulling each prediction.

    Args:
        container: object with a ``get()`` method returning the current list
            of texts (a ``DataContainer`` in this notebook).

    Returns:
        A zero-argument callable returning a ``tf.data.Dataset``. Its
        elements are dicts mapping each name in ``INPUT_NAMES`` to a rank-2
        int32 tensor ``(batch, tokens)``.
    """
    def gen():
        while True:
            texts = container.get()
            if texts is None:
                raise ValueError(
                    "DataContainer is empty: call container.set(texts) "
                    "before requesting a prediction")
            # Deliberately no try/except. Retrying the same deterministic call
            # cannot succeed. A bare `except:` around this yield would also
            # swallow GeneratorExit, making generator.close() raise
            # "generator ignored GeneratorExit".
            yield build_feed_dict(texts)

    def input_fn():
        return tf.data.Dataset.from_generator(
            gen,
            output_types={iname: tf.int32 for iname in INPUT_NAMES},
            output_shapes={iname: (None, None) for iname in INPUT_NAMES})
    return input_fn

In [12]:
class DataContainer:
    """Mutable holder for the batch of texts the prediction generator reads.

    ``build_vectorizer`` stores each batch here with ``set``. The generator
    built by ``build_input_fn`` reads it back with ``get``.
    """

    def __init__(self):
        # None until the first call to set().
        self._texts = None

    def set(self, texts):
        """Store ``texts``; a single string is wrapped into a one-item list."""
        # isinstance (not `type(...) is str`) so str subclasses such as
        # numpy.str_ are also treated as one document. Otherwise they would be
        # treated downstream as a sequence of characters.
        if isinstance(texts, str):
            texts = [texts]
        self._texts = texts

    def get(self):
        """Return the most recently stored texts, or None if none were set."""
        return self._texts

In [13]:
def model_fn(features, mode):
    """Estimator model_fn that runs the frozen BERT feature-extraction graph.

    Reads the serialized GraphDef at ``GRAPH_PATH`` and imports it into the
    current graph. Each ``features[name]`` (for ``name`` in ``INPUT_NAMES``)
    replaces the graph tensor ``name:0``.

    Args:
        features: dict of input tensors keyed by the names in ``INPUT_NAMES``.
        mode: Estimator mode, forwarded unchanged to ``EstimatorSpec``.
            The Estimator passes arguments by name, so keep the parameter
            names ``features`` and ``mode``.

    Returns:
        ``EstimatorSpec`` whose predictions dict holds the ``final_encodes:0``
        tensor under the key ``'output'``.
    """
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(GRAPH_PATH, 'rb') as graph_file:
        graph_def.ParseFromString(graph_file.read())

    input_map = {}
    for name in INPUT_NAMES:
        input_map[name + ':0'] = features[name]

    (final_encodes,) = tf.import_graph_def(
        graph_def,
        input_map=input_map,
        return_elements=['final_encodes:0'])

    return EstimatorSpec(mode=mode, predictions={'output': final_encodes})


# No model_dir or RunConfig is supplied, so the Estimator uses its defaults
# (presumably a temporary model directory).
estimator = Estimator(model_fn=model_fn)

In [14]:
def batch(iterable, n=1):
    """Yield consecutive slices of ``iterable`` holding at most ``n`` items.

    ``iterable`` must support ``len()`` and slicing (list, str, ...). The
    last slice may be shorter than ``n``.
    """
    size = len(iterable)
    starts = range(0, size, n)
    yield from (iterable[start:min(start + n, size)] for start in starts)

def build_vectorizer(_estimator, _input_fn_builder, batch_size=128):
    """Wrap an Estimator in a callable that turns texts into embedding rows.

    One ``predict`` generator is created here and kept for the lifetime of
    the returned function. Each batch of texts is pushed into a shared
    ``DataContainer``, and exactly one prediction is pulled with ``next()``.

    Args:
        _estimator: Estimator whose ``predict`` yields dicts with an
            ``'output'`` key (as produced by ``model_fn``).
        _input_fn_builder: callable taking the container and returning an
            ``input_fn`` (e.g. ``build_input_fn``).
        batch_size: maximum number of texts sent through the graph per
            step; must be a positive integer.

    Returns:
        ``vectorize(text, verbose=False)`` -> ``np.ndarray`` with one row per
        input text.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size < 1:
        raise ValueError(
            "batch_size must be a positive integer, got {!r}".format(batch_size))

    container = DataContainer()
    predict_fn = _estimator.predict(_input_fn_builder(container),
                                    yield_single_examples=False)

    def vectorize(text, verbose=False):
        """Embed ``text`` (a sequence of strings, or a single string).

        Args:
            text: list-like of strings; a bare ``str`` is treated as one text.
            verbose: show a Keras ``Progbar`` while batches are processed.

        Returns:
            ``np.ndarray`` stacking the per-batch outputs, one row per text,
            in input order.

        Raises:
            ValueError: if ``text`` is empty.
        """
        if isinstance(text, str):
            # Without this, batch() would slice the string into chunks of up
            # to batch_size characters, each embedded as a separate text.
            text = [text]
        if len(text) == 0:
            # np.vstack([]) would otherwise fail with an unhelpful message.
            raise ValueError("vectorize() needs at least one text")

        bar = Progbar(len(text)) if verbose else None
        outputs = []
        for text_batch in batch(text, batch_size):
            container.set(text_batch)
            outputs.append(next(predict_fn)['output'])
            if bar is not None:
                bar.add(len(text_batch))
        return np.vstack(outputs)

    return vectorize

In [15]:
# Callable mapping a list of texts to an array of BERT embeddings, one row per text.
bert_vectorizer = build_vectorizer(estimator, build_input_fn, batch_size=128)

In [16]:
%%time
bert_vectorizer(64*['sample text']).shape

Wall time: 46.8 s


(64, 768)

In [17]:
import pandas as pd

In [18]:
# Keyword table: one row per (project, keyword, file).
# NOTE(review): the 'Unnamed: ...' columns in the preview below suggest the CSV
# was saved with its index; passing index_col=0 here may be cleaner — confirm.
df = pd.read_csv('./datasets/keywords.csv')

In [19]:
# Peek at the raw keyword table.
df.head(5)

Unnamed: 0.1,Unnamed: 0,project,keyword,file
0,0,integrated spice modeling/simulation of circui...,circuit element,./reports_doc/1.pdf
1,1,integrated spice modeling/simulation of circui...,circuit simulation,./reports_doc/1.pdf
2,2,integrated spice modeling/simulation of circui...,circuit spice,./reports_doc/1.pdf
3,3,integrated spice modeling/simulation of circui...,component library,./reports_doc/1.pdf
4,4,integrated spice modeling/simulation of circui...,current source,./reports_doc/1.pdf


In [20]:
# The CSV was written with its index (more than once, judging by the
# 'Unnamed: 0' / 'Unnamed: 0.1' columns), which pandas reads back as
# 'Unnamed: ...' columns. Drop every such column without mutating in place,
# so re-running this cell is a no-op instead of a KeyError.
df = df.drop(columns=[col for col in df.columns if str(col).startswith('Unnamed:')])

In [21]:
# Inspect the frame after the column drop above.
df.head(5)

Unnamed: 0,project,keyword,file
0,integrated spice modeling/simulation of circui...,circuit element,./reports_doc/1.pdf
1,integrated spice modeling/simulation of circui...,circuit simulation,./reports_doc/1.pdf
2,integrated spice modeling/simulation of circui...,circuit spice,./reports_doc/1.pdf
3,integrated spice modeling/simulation of circui...,component library,./reports_doc/1.pdf
4,integrated spice modeling/simulation of circui...,current source,./reports_doc/1.pdf


In [None]:
%%time
bert_feature_matrix = pd.DataFrame(bert_vectorizer(list(df['keyword'])))