diff --git a/tests/test_adapter_freeze.py b/tests/test_adapter_freeze.py
index 475ed7c..7ca4280 100644
--- a/tests/test_adapter_freeze.py
+++ b/tests/test_adapter_freeze.py
@@ -8,12 +8,20 @@
 import unittest
 
+import os
+import tempfile
+
+import numpy as np
+import tensorflow as tf
 from tensorflow import keras
 
 import bert
+from .test_common import AbstractBertTest, MiniBertFactory
+
+tf.enable_eager_execution()
 
 
-class AdapterFreezeTest(unittest.TestCase):
+class AdapterFreezeTest(AbstractBertTest):
 
     def test_adapter_freezing(self):
         bert_params = bert.BertModelLayer.Params(hidden_size=32,
@@ -53,5 +61,61 @@ def to_model(bert_params):
             print(weight.name, weight.shape)
 
+
+    def test_freeze(self):
+        model_dir = tempfile.TemporaryDirectory().name
+        os.makedirs(model_dir)
+        save_path = MiniBertFactory.create_mini_bert_weights(model_dir)
+        tokenizer = bert.FullTokenizer(vocab_file=os.path.join(model_dir, "vocab.txt"), do_lower_case=True)
+
+        # prepare input
+        max_seq_len = 24
+        input_str_batch = ["hello, bert!", "how are you doing!"]
+
+        input_ids, token_type_ids = self.prepare_input_batch(input_str_batch, tokenizer, max_seq_len)
+
+        bert_ckpt_file = os.path.join(model_dir, "bert_model.ckpt")
+
+        bert_params = bert.params_from_pretrained_ckpt(model_dir)
+        bert_params.adapter_size = 4
+        l_bert = bert.BertModelLayer.from_params(bert_params)
+
+        model = keras.models.Sequential([
+            l_bert,
+        ])
+
+
+        model.build(input_shape=(None, max_seq_len))
+
+        model.summary()
+        l_bert.apply_adapter_freeze()
+        model.summary()
+
+        bert.load_stock_weights(l_bert, bert_ckpt_file)
+        #l_bert.embeddings_layer.trainable = False
+
+        model.summary()
+
+        orig_weight_values = []
+        for weight in l_bert.weights:
+            orig_weight_values.append(weight.numpy())
+
+        model.compile(optimizer=keras.optimizers.Adam(),
+                      loss=keras.losses.mean_squared_error)
+
+        orig_pred = model.predict(input_ids)
+        model.fit(x=input_ids, y=np.zeros_like(orig_pred),
+                  batch_size=2,
+                  epochs=4)
+
+        for ndx, weight in enumerate(l_bert.weights):
+            print("{}: {}".format(
+                np.array_equal(weight.numpy(), orig_weight_values[ndx]),
+                weight.name))
+
+        model.summary()
+
+
+
diff --git a/tests/test_common.py b/tests/test_common.py
index 98944cd..dab1717 100644
--- a/tests/test_common.py
+++ b/tests/test_common.py
@@ -97,3 +97,24 @@ def create_mini_bert_weights():
         print("mini_bert save_path", save_path)
         print("\n\t".join([""] + os.listdir(model_dir)))
         return model_dir
+
+    def prepare_input_batch(self, input_str_batch, tokenizer, max_seq_len):
+        input_ids_batch = []
+        token_type_ids_batch = []
+        for input_str in input_str_batch:
+            input_tokens = tokenizer.tokenize(input_str)
+            input_tokens = ["[CLS]"] + input_tokens + ["[SEP]"]
+
+            print("input_tokens len:", len(input_tokens))
+
+            input_ids = tokenizer.convert_tokens_to_ids(input_tokens)
+            input_ids = input_ids + [0]*(max_seq_len - len(input_tokens))
+            token_type_ids = [0]*len(input_tokens) + [0]*(max_seq_len - len(input_tokens))
+
+            input_ids_batch.append(input_ids)
+            token_type_ids_batch.append(token_type_ids)
+
+        input_ids = np.array(input_ids_batch, dtype=np.int32)
+        token_type_ids = np.array(token_type_ids_batch, dtype=np.int32)
+
+        return input_ids, token_type_ids
\ No newline at end of file