In [1]:
from bigbird import modeling
from bigbird import utils
import tensorflow as tf
import numpy as np

In [2]:
bert_config = dict(
    # --- transformer basic configs ---
    attention_probs_dropout_prob=0.1,
    hidden_act='gelu',
    hidden_dropout_prob=0.1,
    hidden_size=768,
    initializer_range=0.02,
    intermediate_size=3072,
    max_position_embeddings=4096,
    # Sequence length the encoder is built for; the timed inference cell
    # feeds a dummy batch of exactly this many tokens.
    max_encoder_length=1024,
    num_attention_heads=12,
    num_hidden_layers=12,
    type_vocab_size=2,
    scope='bert',  # variable-scope prefix of every model variable
    use_bias=True,
    rescale_embedding=False,
    vocab_model_file=None,
    # --- sparse attention configs ---
    attention_type='block_sparse',
    norm_type='postnorm',
    block_size=16,
    num_rand_blocks=3,
    # --- vocabulary (kept last to preserve the original key order) ---
    vocab_size=32000,
)

In [3]:
# Instantiate the BigBird BERT encoder from the hyperparameters in `bert_config`.
model = modeling.BertModel(bert_config)

In [4]:
# int32 input placeholder with dynamic batch and sequence dimensions;
# presumably token ids (the inference cell feeds id 1 everywhere).
# NOTE(review): the encoder output is statically shaped to 1024 tokens
# (`max_encoder_length`), so feeds of a different length will likely fail
# inside the model rather than at the placeholder — confirm.
X = tf.placeholder(tf.int32, [None, None])

In [5]:
# Apply the model to the placeholder. It returns two tensors: the last
# encoder layer's per-token output and the pooler output (see the tensor
# names and shapes displayed in the next cell).
sequence_output, pooled_output = model(X)

Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


In [6]:
# Display both output tensors to check their graph names and static shapes.
sequence_output, pooled_output

(<tf.Tensor 'bert/encoder/layer_11/output/LayerNorm/batchnorm/add_1:0' shape=(?, 1024, 768) dtype=float32>,
 <tf.Tensor 'bert/pooler/dense/Tanh:0' shape=(?, 768) dtype=float32>)

In [7]:
# Interactive session so later cells can call `sess.run` directly.
# Initialize every global variable first: the checkpoint restore further
# down only overwrites variables found in the checkpoint, so anything it
# skips (e.g. the position embeddings) keeps the values assigned here.
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

In [8]:
import collections
import re

def get_assignment_map_from_checkpoint(
    tvars,
    init_checkpoint,
    name_replacements=(('bert/embeddings/LayerNorm', 'bert/encoder/LayerNorm'),),
    skip_substrings=('embeddings/position_embeddings',),
    list_variables=None,
):
    """Compute the union of the current variables and checkpoint variables.

    Pairs each variable stored in ``init_checkpoint`` with the graph variable
    of the same (optionally renamed) name. The resulting map can be passed to
    ``tf.train.Saver(var_list=...)``.

    Args:
        tvars: Graph variables that may be restored. Each must expose a
            ``.name`` attribute such as ``'bert/pooler/dense/kernel:0'``.
        init_checkpoint: Checkpoint path handed to ``list_variables``.
        name_replacements: ``(old, new)`` substring pairs applied, in order,
            to each checkpoint variable name before looking it up among
            ``tvars``. The default maps the checkpoint's embedding LayerNorm
            onto this model's ``bert/encoder/LayerNorm``.
        skip_substrings: Checkpoint variables whose renamed name contains any
            of these substrings are never restored. The default skips the
            position embeddings. Presumably this is because the checkpoint was
            trained with a different ``max_position_embeddings`` than this
            model; TODO confirm.
        list_variables: Callable ``path -> iterable of (name, shape)`` used to
            read the checkpoint index. Defaults to ``tf.train.list_variables``;
            overridable for other checkpoint readers or for testing.

    Returns:
        ``(assignment_map, initialized_variable_names)``:

        - ``assignment_map`` is an ``OrderedDict`` from checkpoint name to
          graph variable, in checkpoint order.
        - ``initialized_variable_names`` maps every matched graph variable
          name to 1, both with and without the ``':0'`` suffix.
    """
    if list_variables is None:
        list_variables = tf.train.list_variables

    # Index graph variables by their name minus the ':<output index>'
    # suffix, which is the form names take inside a checkpoint.
    name_to_variable = collections.OrderedDict()
    for var in tvars:
        m = re.match(r'^(.*):\d+$', var.name)
        name = m.group(1) if m is not None else var.name
        name_to_variable[name] = var

    assignment_map = collections.OrderedDict()
    initialized_variable_names = {}
    for name, _shape in list_variables(init_checkpoint):
        name_r = name
        for old, new in name_replacements:
            name_r = name_r.replace(old, new)
        if name_r not in name_to_variable:
            continue  # checkpoint-only variable (e.g. optimizer/global_step)
        if any(s in name_r for s in skip_substrings):
            continue
        # The key stays the ORIGINAL checkpoint name: Saver looks values up
        # in the checkpoint by these keys.
        assignment_map[name] = name_to_variable[name_r]
        initialized_variable_names[name_r] = 1
        initialized_variable_names[name_r + ':0'] = 1

    return (assignment_map, initialized_variable_names)

In [12]:
# Download and extract the pretrained BERT-base checkpoint (uncomment to run once):
# !wget https://f000.backblazeb2.com/file/malaya-model/bert-bahasa/bert-base-2020-10-08.tar.gz
# !tar -zxf bert-base-2020-10-08.tar.gz

In [10]:
# Checkpoint path prefix; presumably the directory produced by extracting the
# tarball from the commented-out download cell — adjust if stored elsewhere.
checkpoint = 'bert-base/model.ckpt-1000000'
tvars = tf.trainable_variables()
assignment_map, initialized_variable_names = get_assignment_map_from_checkpoint(
    tvars=tvars,
    init_checkpoint=checkpoint,
)

In [11]:
# Restore only the variables matched by get_assignment_map_from_checkpoint.
# The dict keys are checkpoint names and the values are the graph variables
# to load them into; unmatched variables keep their initialized values.
saver = tf.train.Saver(var_list = assignment_map)
saver.restore(sess, checkpoint)

INFO:tensorflow:Restoring parameters from bert-base/model.ckpt-1000000


INFO:tensorflow:Restoring parameters from bert-base/model.ckpt-1000000


In [14]:
%%time

o = sess.run([sequence_output, pooled_output], feed_dict = {X: [[1] * 1024]})

CPU times: user 6.43 s, sys: 1.16 s, total: 7.6 s
Wall time: 1.28 s
