
Commit

Add attention_dropout
bojone committed Sep 13, 2021
1 parent 4de1dba commit 48626f0
Showing 1 changed file with 15 additions and 0 deletions.
bert4keras/models.py (15 additions & 0 deletions)
@@ -24,6 +24,7 @@ def __init__(
        intermediate_size,  # hidden dimension of FeedForward
        hidden_act,  # activation function of the FeedForward hidden layer
        dropout_rate=None,  # dropout rate
+       attention_dropout_rate=None,  # dropout rate for the attention matrix
        embedding_size=None,  # optionally specify embedding_size
        attention_head_size=None,  # head_size of V in attention
        attention_key_size=None,  # head_size of Q and K in attention
@@ -50,6 +51,7 @@ def __init__(
        self.attention_key_size = attention_key_size or self.attention_head_size
        self.intermediate_size = intermediate_size
        self.dropout_rate = dropout_rate or 0
+       self.attention_dropout_rate = attention_dropout_rate or 0
        self.hidden_act = hidden_act
        self.embedding_size = embedding_size or hidden_size
        self.sequence_length = sequence_length
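The new argument is optional and falls back to 0 (see the `attention_dropout_rate or 0` assignment above), so existing code keeps its behaviour. A minimal sketch of passing it when constructing a model class directly; the hyperparameter values are illustrative, not from this diff, and the exact set of required constructor arguments depends on the subclass:

from bert4keras.models import BERT

# Illustrative hyperparameters, not taken from this commit.
model = BERT(
    vocab_size=21128,
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072,
    hidden_act='gelu',
    max_position=512,
    dropout_rate=0.1,            # dropout elsewhere in the block, as before
    attention_dropout_rate=0.1,  # new: dropout on the attention matrix
)
model.build()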
@@ -559,6 +561,7 @@ def apply_main_layers(self, inputs, index):
            head_size=self.attention_head_size,
            out_dim=self.hidden_size,
            key_size=self.attention_key_size,
+           attention_dropout=self.attention_dropout_rate,
            kernel_initializer=self.initializer,
            name=attention_name
        )
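Each model's apply_main_layers now forwards the rate to MultiHeadAttention via the attention_dropout argument; the same one-line change recurs in every hunk below. What this rate controls is dropout on the softmax-normalized attention matrix, before it weights V. A self-contained sketch of that standard technique in plain TensorFlow (not the library's layer code):

import tensorflow as tf

def attention_with_dropout(q, k, v, rate, training=True):
    """Scaled dot-product attention with dropout on the attention
    probabilities (a simplified sketch, not bert4keras's implementation)."""
    d = tf.cast(tf.shape(k)[-1], q.dtype)
    scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(d)
    probs = tf.nn.softmax(scores, axis=-1)
    if training and rate > 0:
        # The matrix dropped out here is what attention_dropout applies to
        # (attention_probs_dropout_prob in the original BERT).
        probs = tf.nn.dropout(probs, rate=rate)
    return tf.matmul(probs, v)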
@@ -809,6 +812,7 @@ def apply_main_layers(self, inputs, index):
            head_size=self.attention_head_size,
            out_dim=self.hidden_size,
            key_size=self.attention_key_size,
+           attention_dropout=self.attention_dropout_rate,
            kernel_initializer=self.initializer,
            name=attention_name
        )
@@ -1029,6 +1033,7 @@ def apply_main_layers(self, inputs, index):
            head_size=self.attention_head_size,
            out_dim=self.hidden_size,
            key_size=self.attention_key_size,
+           attention_dropout=self.attention_dropout_rate,
            kernel_initializer=self.initializer,
            name=attention_name
        )
@@ -1132,6 +1137,7 @@ def apply_main_layers(self, inputs, index):
            head_size=self.attention_head_size,
            out_dim=self.hidden_size,
            key_size=self.attention_key_size,
+           attention_dropout=self.attention_dropout_rate,
            kernel_initializer=self.initializer,
            name=attention_name
        )
@@ -1469,6 +1475,7 @@ def apply_main_layers(self, inputs, index):
            head_size=self.attention_head_size,
            out_dim=self.hidden_size,
            key_size=self.attention_key_size,
+           attention_dropout=self.attention_dropout_rate,
            kernel_initializer=self.initializer,
            name=attention_name
        )
@@ -1637,6 +1644,7 @@ def apply_main_layers(self, inputs, index):
            head_size=self.attention_head_size,
            out_dim=self.hidden_size,
            key_size=self.attention_key_size,
+           attention_dropout=self.attention_dropout_rate,
            kernel_initializer=self.initializer,
            name=attention_name
        )
@@ -1946,6 +1954,7 @@ def apply_main_layers(self, inputs, index):
            key_size=self.attention_key_size,
            use_bias=False,
            attention_scale=False,
+           attention_dropout=self.attention_dropout_rate,
            kernel_initializer=self.initializer,
            name=attention_name
        )
@@ -2141,6 +2150,7 @@ def apply_main_layers(self, inputs, index):
            key_size=self.attention_key_size,
            use_bias=False,
            attention_scale=False,
+           attention_dropout=self.attention_dropout_rate,
            kernel_initializer=self.initializer,
            name=self_attention_name
        )
@@ -2180,6 +2190,7 @@ def apply_main_layers(self, inputs, index):
            key_size=self.attention_key_size,
            use_bias=False,
            attention_scale=False,
+           attention_dropout=self.attention_dropout_rate,
            kernel_initializer=self.initializer,
            name=cross_attention_name
        )
@@ -2406,6 +2417,10 @@ def build_transformer_model(
        configs['max_position'] = configs.get('max_position_embeddings', 512)
    if 'dropout_rate' not in configs:
        configs['dropout_rate'] = configs.get('hidden_dropout_prob')
+   if 'attention_dropout_rate' not in configs:
+       configs['attention_dropout_rate'] = configs.get(
+           'attention_probs_dropout_prob'
+       )
    if 'segment_vocab_size' not in configs:
        configs['segment_vocab_size'] = configs.get('type_vocab_size', 2)
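The hunk above maps the standard BERT config key attention_probs_dropout_prob onto attention_dropout_rate, so the rate is picked up automatically from a checkpoint's config file; if the key is absent, configs.get returns None and the constructor's `or 0` fallback disables the dropout. A usage sketch with placeholder paths:

from bert4keras.models import build_transformer_model

# A stock Google BERT config contains "attention_probs_dropout_prob": 0.1,
# which now populates attention_dropout_rate without any extra argument.
model = build_transformer_model(
    config_path='bert_config.json',     # placeholder path
    checkpoint_path='bert_model.ckpt',  # placeholder path
    model='bert',
)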

