Skip to content

Commit

Permalink
minor - adapter_freeze test case added
Browse files Browse the repository at this point in the history
  • Loading branch information
kpe committed Aug 11, 2019
1 parent 90f04ec commit c1afb9f
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 1 deletion.
66 changes: 65 additions & 1 deletion tests/test_adapter_freeze.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,20 @@

import unittest

import os
import tempfile

import numpy as np
import tensorflow as tf
from tensorflow import keras

import bert

from .test_common import AbstractBertTest, MiniBertFactory

tf.enable_eager_execution()

class AdapterFreezeTest(unittest.TestCase):
class AdapterFreezeTest(AbstractBertTest):

def test_adapter_freezing(self):
bert_params = bert.BertModelLayer.Params(hidden_size=32,
Expand Down Expand Up @@ -53,5 +61,61 @@ def to_model(bert_params):
print(weight.name, weight.shape)


def test_freeze(self):
model_dir = tempfile.TemporaryDirectory().name
os.makedirs(model_dir)
save_path = MiniBertFactory.create_mini_bert_weights(model_dir)
tokenizer = bert.FullTokenizer(vocab_file=os.path.join(model_dir, "vocab.txt"), do_lower_case=True)

# prepare input
max_seq_len = 24
input_str_batch = ["hello, bert!", "how are you doing!"]

input_ids, token_type_ids = self.prepare_input_batch(input_str_batch, tokenizer, max_seq_len)

bert_ckpt_file = os.path.join(model_dir, "bert_model.ckpt")

bert_params = bert.params_from_pretrained_ckpt(model_dir)
bert_params.adapter_size = 4
l_bert = bert.BertModelLayer.from_params(bert_params)

model = keras.models.Sequential([
l_bert,
])


model.build(input_shape=(None, max_seq_len))

model.summary()
l_bert.apply_adapter_freeze()
model.summary()

bert.load_stock_weights(l_bert, bert_ckpt_file)
#l_bert.embeddings_layer.trainable = False

model.summary()

orig_weight_values = []
for weight in l_bert.weights:
orig_weight_values.append(weight.numpy())

model.compile(optimizer=keras.optimizers.Adam(),
loss=keras.losses.mean_squared_error)

orig_pred = model.predict(input_ids)
model.fit(x=input_ids, y=np.zeros_like(orig_pred),
batch_size=2,
epochs=4)

for ndx, weight in enumerate(l_bert.weights):
print("{}: {}".format(
np.array_equal(weight.numpy(), orig_weight_values[ndx]),
weight.name))

model.summary()






21 changes: 21 additions & 0 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,24 @@ def create_mini_bert_weights():
print("mini_bert save_path", save_path)
print("\n\t".join([""] + os.listdir(model_dir)))
return model_dir

def prepare_input_batch(self, input_str_batch, tokenizer, max_seq_len):
input_ids_batch = []
token_type_ids_batch = []
for input_str in input_str_batch:
input_tokens = tokenizer.tokenize(input_str)
input_tokens = ["[CLS]"] + input_tokens + ["[SEP]"]

print("input_tokens len:", len(input_tokens))

input_ids = tokenizer.convert_tokens_to_ids(input_tokens)
input_ids = input_ids + [0]*(max_seq_len - len(input_tokens))
token_type_ids = [0]*len(input_tokens) + [0]*(max_seq_len - len(input_tokens))

input_ids_batch.append(input_ids)
token_type_ids_batch.append(token_type_ids)

input_ids = np.array(input_ids_batch, dtype=np.int32)
token_type_ids = np.array(token_type_ids_batch, dtype=np.int32)

return input_ids, token_type_ids

0 comments on commit c1afb9f

Please sign in to comment.