In [1]:
# !wget https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip
# !unzip uncased_L-12_H-768_A-12.zip

In [2]:
import pickle
import json

In [3]:
with open('data.json') as fopen:
    output = json.load(fopen)

In [4]:
import bert
from bert import run_classifier
from bert import optimization
from bert import tokenization
from bert import modeling
from tqdm import tqdm

BERT_VOCAB = 'uncased_L-12_H-768_A-12/vocab.txt'
BERT_INIT_CHKPNT = 'uncased_L-12_H-768_A-12/bert_model.ckpt'
BERT_CONFIG = 'uncased_L-12_H-768_A-12/bert_config.json'

tokenization.validate_case_matches_checkpoint(True,BERT_INIT_CHKPNT)
tokenizer = tokenization.FullTokenizer(
      vocab_file=BERT_VOCAB, do_lower_case=True)

MAX_SEQ_LENGTH = 100

for i in tqdm(range(len(output))):
    tokens_a = tokenizer.tokenize(output[i][0])
    if len(tokens_a) > MAX_SEQ_LENGTH - 1:
        tokens_a = tokens_a[:(MAX_SEQ_LENGTH - 1)]
    tokens_a.append('|')
    output[i][0] = tokens_a

100%|██████████| 303625/303625 [00:55<00:00, 5462.11it/s]


In [5]:
phr2sg_id = pickle.load(open('phr2sg_id.pkl', 'rb'))
sg_id2phr = pickle.load(open('sg_id2phr.pkl', 'rb'))

In [6]:
from collections import Counter

n_classes = 100

sg_ids = []
for line in output:
    sg_id = line[1]
    if sg_id != 0:
        sg_ids.append(sg_id)
sg_id2cnt = Counter(sg_ids)
sg_ids = [sg_id for sg_id, cnt in sg_id2cnt.most_common(n_classes)]
idx2sg_id = {idx: sg_id for idx, sg_id in enumerate(sg_ids)}
sg_id2idx = {sg_id: idx for idx, sg_id in enumerate(sg_ids)}

In [7]:
phr2idx = dict()
for phr, sg_id in phr2sg_id.items():
    if sg_id in sg_id2idx:
        phr2idx[phr] = sg_id2idx[sg_id]
        
idx2phr = dict()
for idx, sg_id in idx2sg_id.items():
    if sg_id in sg_id2phr:
        idx2phr[idx] = sg_id2phr[sg_id]

In [8]:
contexts_li = [[] for _ in range(n_classes)]
for i, cols in tqdm(enumerate(output), total=len(output)):
    if i==0: continue
    sent, sg_id = cols
    if sg_id in sg_id2idx:
        idx = sg_id2idx[sg_id]
        ctx = []
        for l in output[:i][-10:]:
            ctx.append(l[0])
        contexts = contexts_li[idx]
        contexts.append(ctx)

100%|██████████| 303625/303625 [00:54<00:00, 5527.43it/s] 


In [9]:
train, dev = [], []
for contexts in contexts_li:
    if len(contexts) > 1:
        train.append(contexts[1:])
        dev.append(contexts[:1])
    else:
        train.append(contexts)
        dev.append([])

In [10]:
from itertools import chain
from keras.preprocessing.sequence import pad_sequences

def to_xy(data, max_span = 128):
    x, y = [], []
    for i in tqdm(range(n_classes)):
        contexts = data[i]
        for c in contexts:
            flatted = list(chain.from_iterable(c))
            for k in range(0, len(flatted), max_span):
                index = min(k + max_span, len(flatted))
                batch_x = flatted[k: index]
                batch_x = ["[CLS]"] + batch_x + ["[SEP]"]
                batch_x = tokenizer.convert_tokens_to_ids(batch_x)
                x.append(batch_x)
                y.append(i)
    return pad_sequences(x, padding='post'), y

Using TensorFlow backend.


In [11]:
train_X, train_Y = to_xy(train)

100%|██████████| 100/100 [00:01<00:00, 73.32it/s]


In [12]:
test_X, test_Y = to_xy(dev)

100%|██████████| 100/100 [00:00<00:00, 25063.07it/s]


In [13]:
train_X.shape, test_X.shape

((63022, 130), (165, 130))

In [14]:
epoch = 5
batch_size = 50
warmup_proportion = 0.1
num_train_steps = int(train_X.shape[0] / batch_size * epoch)
num_warmup_steps = int(num_train_steps * warmup_proportion)

In [15]:
class Model:
    def __init__(
        self,
        dimension_output,
        learning_rate = 2e-5,
    ):
        self.X = tf.placeholder(tf.int32, [None, None])
        self.Y = tf.placeholder(tf.int32, [None])
        
        model = modeling.BertModel(
            config=bert_config,
            is_training=True,
            input_ids=self.X,
            use_one_hot_embeddings=False)
        
        output_layer = model.get_pooled_output()
        self.logits = tf.layers.dense(output_layer, dimension_output)
        
        self.cost = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits = self.logits, labels = self.Y
            )
        )
        
        self.optimizer = optimization.create_optimizer(self.cost, learning_rate, 
                                                       num_train_steps, num_warmup_steps, False)
        correct_pred = tf.equal(
            tf.argmax(self.logits, 1, output_type = tf.int32), self.Y
        )
        self.accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

In [16]:
import tensorflow as tf
import time
from sklearn.utils import shuffle
import numpy as np

In [17]:
train_X, train_Y = shuffle(train_X, train_Y)
test_X, test_Y = shuffle(test_X, test_Y)

In [18]:
bert_config = modeling.BertConfig.from_json_file(BERT_CONFIG)

In [19]:
learning_rate = 2e-5

tf.reset_default_graph()
sess = tf.InteractiveSession()
model = Model(
    n_classes,
    learning_rate
)

sess.run(tf.global_variables_initializer())
var_lists = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope = 'bert')
saver = tf.train.Saver(var_list = var_lists)
saver.restore(sess, BERT_INIT_CHKPNT)

Instructions for updating:
Colocations handled automatically by placer.

For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
If you depend on functionality not listed there, please file an issue.

Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
Instructions for updating:
Use keras.layers.dense instead.
Instructions for updating:
Deprecated in favor of operator or tf.math.divide.
Instructions for updating:
Use tf.cast instead.
Instructions for updating:
Use standard file APIs to check for files with this prefix.
INFO:tensorflow:Restoring parameters from uncased_L-12_H-768_A-12/bert_model.ckpt


In [20]:
EARLY_STOPPING, CURRENT_CHECKPOINT, CURRENT_ACC, EPOCH = 3, 0, 0, 0

while True:
    lasttime = time.time()
    if CURRENT_CHECKPOINT == EARLY_STOPPING:
        print('break epoch:%d\n' % (EPOCH))
        break

    train_acc, train_loss, test_acc, test_loss = 0, 0, 0, 0
    pbar = tqdm(
        range(0, len(train_X), batch_size), desc = 'train minibatch loop'
    )
    for i in pbar:
        index = min(i + batch_size, len(train_X))
        batch_x = train_X[i: index]
        batch_y = train_Y[i: index]
        acc, cost, _ = sess.run(
            [model.accuracy, model.cost, model.optimizer],
            feed_dict = {
                model.Y: batch_y,
                model.X: batch_x,
            },
        )
        assert not np.isnan(cost)
        train_loss += cost
        train_acc += acc
        pbar.set_postfix(cost = cost, accuracy = acc)
        
    pbar = tqdm(range(0, len(test_X), batch_size), desc = 'test minibatch loop')
    for i in pbar:
        index = min(i + batch_size, len(test_X))
        batch_x = test_X[i: index]
        batch_y = test_Y[i: index]
        acc, cost = sess.run(
            [model.accuracy, model.cost],
            feed_dict = {
                model.Y: batch_y,
                model.X: batch_x,
            },
        )
        test_loss += cost
        test_acc += acc
        pbar.set_postfix(cost = cost, accuracy = acc)

    train_loss /= len(train_X) / batch_size
    train_acc /= len(train_X) / batch_size
    test_loss /= len(test_X) / batch_size
    test_acc /= len(test_X) / batch_size

    if test_acc > CURRENT_ACC:
        print(
            'epoch: %d, pass acc: %f, current acc: %f'
            % (EPOCH, CURRENT_ACC, test_acc)
        )
        CURRENT_ACC = test_acc
        CURRENT_CHECKPOINT = 0
    else:
        CURRENT_CHECKPOINT += 1
        
    print('time taken:', time.time() - lasttime)
    print(
        'epoch: %d, training loss: %f, training acc: %f, valid loss: %f, valid acc: %f\n'
        % (EPOCH, train_loss, train_acc, test_loss, test_acc)
    )
    EPOCH += 1

train minibatch loop: 100%|██████████| 1261/1261 [09:27<00:00,  2.59it/s, accuracy=0.136, cost=4.09]
test minibatch loop: 100%|██████████| 4/4 [00:00<00:00,  4.72it/s, accuracy=0, cost=4.82]   
train minibatch loop:   0%|          | 0/1261 [00:00<?, ?it/s]

epoch: 0, pass acc: 0.000000, current acc: 0.012121
time taken: 568.682032585144
epoch: 0, training loss: 3.864423, training acc: 0.145216, valid loss: 6.014607, valid acc: 0.012121



train minibatch loop: 100%|██████████| 1261/1261 [09:28<00:00,  2.59it/s, accuracy=0.182, cost=3.94]
test minibatch loop: 100%|██████████| 4/4 [00:00<00:00,  7.73it/s, accuracy=0, cost=4.81]   
train minibatch loop:   0%|          | 0/1261 [00:00<?, ?it/s]

epoch: 1, pass acc: 0.012121, current acc: 0.024242
time taken: 569.0105755329132
epoch: 1, training loss: 3.600702, training acc: 0.192252, valid loss: 5.924614, valid acc: 0.024242



train minibatch loop: 100%|██████████| 1261/1261 [09:28<00:00,  2.60it/s, accuracy=0.227, cost=3.65]
test minibatch loop: 100%|██████████| 4/4 [00:00<00:00,  7.78it/s, accuracy=0, cost=4.77]   
train minibatch loop:   0%|          | 0/1261 [00:00<?, ?it/s]

time taken: 569.0057277679443
epoch: 2, training loss: 3.432430, training acc: 0.222849, valid loss: 5.912193, valid acc: 0.024242



train minibatch loop: 100%|██████████| 1261/1261 [09:28<00:00,  2.59it/s, accuracy=0.227, cost=3.45]
test minibatch loop: 100%|██████████| 4/4 [00:00<00:00,  7.75it/s, accuracy=0, cost=4.91]  
train minibatch loop:   0%|          | 0/1261 [00:00<?, ?it/s]

time taken: 568.930771112442
epoch: 3, training loss: 3.292074, training acc: 0.248348, valid loss: 6.023255, valid acc: 0.012121



train minibatch loop: 100%|██████████| 1261/1261 [09:28<00:00,  2.59it/s, accuracy=0.318, cost=3.29]
test minibatch loop: 100%|██████████| 4/4 [00:00<00:00,  7.74it/s, accuracy=0, cost=4.86]   

time taken: 568.9259750843048
epoch: 4, training loss: 3.198256, training acc: 0.267873, valid loss: 5.987054, valid acc: 0.012121

break epoch:5




