In [1]:
import json
import os
import random
import sys

import tensorflow as tf
import numpy as np
import pickle
from bert_serving.client import BertClient, ConcurrentBertClient
from tensorflow.estimator import BaselineClassifier
from tensorflow.python.estimator.canned.dnn import DNNClassifier
from tensorflow.python.estimator.run_config import RunConfig
from tensorflow.python.estimator.training import TrainSpec, EvalSpec, train_and_evaluate

tf.logging.set_verbosity(tf.logging.INFO)

In [2]:
batch_size = 128
num_parallel_calls = 1
bc = BertClient()

In [10]:
def chunks(l, n):
    """Yield successive n-sized chunks from l."""
    for i in range(0, len(l), n):
        yield l[i:i + n]

In [11]:
encode_count = 0
def encode(chunk):
    global encode_count
    print('Chunk {}'.format(encode_count))
    encode_count += 1
    return bc.encode(chunk)

In [23]:
def get_encodes(data):
    # x is `batch_size` of lines, each of which is a json object
    features = np.array([])
    text = [x[0] for x in data]
    text = text[len(text)-50:]                                                         # take the last 50 words
    features = np.concatenate([encode(chunk) for chunk in chunks(text, 256)])
    # randomly choose a label
    labels = [x[1] for x in data]
    return features, labels

In [19]:
def cache_data(data_dir, dest_dir, start_chunk, end_chunk):
    pos_files = os.listdir(os.path.join(data_dir, 'pos'))
    neg_files = os.listdir(os.path.join(data_dir, 'neg'))
        
    data = []
    for pos_file, neg_file in zip(pos_files, neg_files):
        with open(os.path.join(data_dir, 'pos', pos_file)) as f:
            review = f.readlines()[0].strip()
            data.append((review, 1))
        with open(os.path.join(data_dir, 'neg', neg_file)) as f:
            review = f.readlines()[0].strip()
            data.append((review, 0))
    chunk_num = -1
    chunk_size = 2048
    for chunk in chunks(data, chunk_size):
        chunk_num += 1
        if chunk_num < start_chunk:
            continue
        if chunk_num > end_chunk:
            break
        features, output = get_encodes(chunk)
        print('Wrote data_{:03d}.p'.format(chunk_num))
        with open(os.path.join(dest_dir, 'data_{:03d}.p'.format(chunk_num)), 'wb') as f:
            pickle.dump((features, output), f)

In [20]:
# !bert-serving-start -model_dir ./uncased_L-12_H-768_A-12 -num_worker=4 -max_seq_len=50

In [None]:
%%time
input_fn_train = cache_data('/Users/jlc/Google Drive/_code/MIDS_W266/BERT_Imdb/aclImdb/train', '/Users/jlc/Google Drive/_code/MIDS_W266/BERT_Imdb/uncased_L-12_H-768_A-12/cache/train_last50words_tokens50', 0, 14)
input_fn_eval = cache_data('/Users/jlc/Google Drive/_code/MIDS_W266/BERT_Imdb/aclImdb/test', '/Users/jlc/Google Drive/_code/MIDS_W266/BERT_Imdb/uncased_L-12_H-768_A-12/cache/test_last50words_tokens50', 0, 14)


Chunk 55
Wrote data_000.p
Chunk 56
Wrote data_001.p
Chunk 57
Wrote data_002.p
Chunk 58
Wrote data_003.p
Chunk 59
Wrote data_004.p
Chunk 60
Wrote data_005.p
Chunk 61
Wrote data_006.p
Chunk 62
Wrote data_007.p
Chunk 63
Wrote data_008.p
Chunk 64
Wrote data_009.p
Chunk 65
Wrote data_010.p
Chunk 66
Wrote data_011.p
Chunk 67
Wrote data_012.p
Chunk 68
Wrote data_000.p
Chunk 69
Wrote data_001.p
Chunk 70
Wrote data_002.p
Chunk 71
Wrote data_003.p
Chunk 72
Wrote data_004.p
Chunk 73
Wrote data_005.p
Chunk 74
Wrote data_006.p
Chunk 75
Wrote data_007.p
Chunk 76
Wrote data_008.p
Chunk 77
Wrote data_009.p
Chunk 78
