This repository has been archived by the owner on Mar 3, 2024. It is now read-only.

finetune BERT with custom dataset #20

Closed
BerenLuthien opened this issue Feb 27, 2019 · 2 comments
Labels: enhancement (New feature or request), wontfix (This will not be worked on)

Comments


BerenLuthien commented Feb 27, 2019

Is your feature request related to a problem? Please describe.
I wish to fine-tune BERT (MLM, sentence pairs) on a custom dataset, e.g. text extracted from a book.

Describe alternatives you've considered
Which function can we use to feed in a custom dataset, for example, a text file from a book? Do we need to write a function to format the text file so that BERT can consume it?

@BerenLuthien BerenLuthien added the enhancement New feature or request label Feb 27, 2019
@CyberZHG (Owner)

There is a test case that trains the model:

import os

import numpy as np
import keras

# Import paths as used in this repository's tests; they may differ across keras-bert versions.
from keras_bert import get_base_dict, get_model, get_custom_objects, gen_batch_inputs, gelu


def test_fit(self):
    current_path = os.path.dirname(os.path.abspath(__file__))
    model_path = os.path.join(current_path, 'test_bert_fit.h5')
    sentence_pairs = [
        [['all', 'work', 'and', 'no', 'play'], ['makes', 'jack', 'a', 'dull', 'boy']],
        [['from', 'the', 'day', 'forth'], ['my', 'arm', 'changed']],
        [['and', 'a', 'voice', 'echoed'], ['power', 'give', 'me', 'more', 'power']],
    ]
    # Start from the base dictionary (special tokens) and add every token in the corpus.
    token_dict = get_base_dict()
    for pairs in sentence_pairs:
        for token in pairs[0] + pairs[1]:
            if token not in token_dict:
                token_dict[token] = len(token_dict)
    token_list = list(token_dict.keys())
    if os.path.exists(model_path):
        model = keras.models.load_model(
            model_path,
            custom_objects=get_custom_objects(),
        )
    else:
        model = get_model(
            token_num=len(token_dict),
            head_num=5,
            transformer_num=12,
            embed_dim=25,
            feed_forward_dim=100,
            seq_len=20,
            pos_num=20,
            dropout_rate=0.05,
            attention_activation=gelu,
            lr=1e-3,
        )
    model.summary()

    def _generator():
        # Endlessly yield (inputs, outputs) batches for MLM and sentence-pair tasks.
        while True:
            yield gen_batch_inputs(
                sentence_pairs,
                token_dict,
                token_list,
                seq_len=20,
                mask_rate=0.3,
                swap_sentence_rate=1.0,
            )

    model.fit_generator(
        generator=_generator(),
        steps_per_epoch=1000,
        epochs=1,
        validation_data=_generator(),
        validation_steps=100,
        callbacks=[
            keras.callbacks.ReduceLROnPlateau(monitor='val_MLM_loss', factor=0.5, patience=3),
            keras.callbacks.EarlyStopping(monitor='val_MLM_loss', patience=5),
        ],
    )
    # model.save(model_path)
    for inputs, outputs in _generator():
        predicts = model.predict(inputs)
        outputs = list(map(lambda x: np.squeeze(x, axis=-1), outputs))
        predicts = list(map(lambda x: np.argmax(x, axis=-1), predicts))
        batch_size, seq_len = inputs[-1].shape
        for i in range(batch_size):
            for j in range(seq_len):
                # Check MLM predictions only at the masked positions.
                if inputs[-1][i][j]:
                    self.assertEqual(outputs[0][i][j], predicts[0][i][j])
        self.assertTrue(np.allclose(outputs[1], predicts[1]))
        break

However, I recommend training with the official implementation and then loading the checkpoint, since the optimizer and the way sentence pairs are generated differ.
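For reference, a minimal sketch of turning a raw text file (e.g. a book) into the `sentence_pairs` format that `gen_batch_inputs` expects. The helper name and the whitespace tokenization are illustrative assumptions; real BERT pre-training uses WordPiece tokenization and document-aware sentence sampling:

```python
import re


def text_to_sentence_pairs(text):
    """Split raw text into pairs of consecutive sentences, each a list of tokens.

    Assumptions for illustration only: sentences end at ., ! or ?, and
    tokens are lowercase whitespace-separated words (not WordPiece).
    """
    # Crude sentence splitting on terminal punctuation.
    sentences = [s.strip() for s in re.split(r'[.!?]+', text.lower()) if s.strip()]
    tokens = [s.split() for s in sentences]
    # Pair each sentence with the one that follows it.
    return [[tokens[i], tokens[i + 1]] for i in range(len(tokens) - 1)]


pairs = text_to_sentence_pairs(
    'All work and no play. Makes Jack a dull boy. From that day forth my arm changed.'
)
# Each element has the shape [tokens_of_sentence_a, tokens_of_sentence_b],
# matching the sentence_pairs literal in the test above.
```

The resulting list can be passed straight to the token-dictionary loop and `gen_batch_inputs` call shown in the test.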


stale bot commented Mar 5, 2019

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.
