feat: add dropout in MASTER (#349)
* feat: add dropout

* fix: cfg

* fix: dropout torch
charlesmindee committed Jul 6, 2021
1 parent 6a558df commit 9f7edcd
Showing 4 changed files with 30 additions and 6 deletions.
2 changes: 2 additions & 0 deletions doctr/models/recognition/master/pytorch.py
@@ -171,6 +171,7 @@ def __init__(
num_heads: int = 8, # number of heads in the transformer decoder
num_layers: int = 3,
max_length: int = 50,
+ dropout: float = 0.2,
input_shape: Tuple[int, int, int] = (3, 48, 160),
cfg: Optional[Dict[str, Any]] = None,
) -> None:
@@ -192,6 +193,7 @@ def __init__(
dff=dff,
vocab_size=self.vocab_size,
maximum_position_encoding=max_length,
+ dropout=dropout,
)
self.feature_pe = positional_encoding(input_shape[1] * input_shape[2], d_model)
self.linear = nn.Linear(d_model, self.vocab_size + 3)
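On the PyTorch side, the new `dropout` argument is simply forwarded to the transformer decoder. As a rough standalone sketch (illustrative names, not doctr's actual Decoder class), this is the usual way such a constructor-level rate is threaded into torch.nn.TransformerDecoderLayer:

import torch
from torch import nn

# Illustrative only: forward one dropout rate into every decoder layer,
# mirroring the `dropout=dropout` forwarding in the hunk above.
def build_decoder(d_model: int = 512, num_heads: int = 8, dff: int = 2048,
                  num_layers: int = 3, dropout: float = 0.2) -> nn.TransformerDecoder:
    layer = nn.TransformerDecoderLayer(
        d_model=d_model,
        nhead=num_heads,
        dim_feedforward=dff,
        dropout=dropout,      # the knob MASTER now exposes
        activation='relu',
    )
    return nn.TransformerDecoder(layer, num_layers=num_layers)

decoder = build_decoder(dropout=0.1)  # e.g. pick a lower rate for a small dataset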
6 changes: 3 additions & 3 deletions doctr/models/recognition/master/tensorflow.py
@@ -24,7 +24,7 @@
'master': {
'mean': (.5, .5, .5),
'std': (1., 1., 1.),
- 'input_shape': (48, 160, 3),
+ 'input_shape': (32, 128, 3),
'vocab': VOCABS['french'],
'url': None,
},
@@ -84,7 +84,6 @@ def __init__(
name='transform'
)

- @tf.function
def context_modeling(self, inputs: tf.Tensor) -> tf.Tensor:
b, h, w, c = (tf.shape(inputs)[i] for i in range(4))

@@ -196,6 +195,7 @@ def __init__(
num_heads: int = 8, # number of heads in the transformer decoder
num_layers: int = 3,
max_length: int = 50,
+ dropout: float = 0.2,
input_shape: Tuple[int, int, int] = (48, 160, 3),
cfg: Optional[Dict[str, Any]] = None,
) -> None:
@@ -216,13 +216,13 @@ def __init__(
dff=dff,
vocab_size=self.vocab_size,
maximum_position_encoding=max_length,
+ dropout=dropout,
)
self.feature_pe = positional_encoding(input_shape[0] * input_shape[1], d_model)
self.linear = layers.Dense(self.vocab_size + 3, kernel_initializer=tf.initializers.he_uniform())

self.postprocessor = MASTERPostProcessor(vocab=self.vocab)

- @tf.function
def make_mask(self, target: tf.Tensor) -> tf.Tensor:
look_ahead_mask = create_look_ahead_mask(tf.shape(target)[1])
target_padding_mask = create_padding_mask(target, self.vocab_size + 2) # Pad symbol
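The TensorFlow implementation gains the same `dropout` argument (and the default 'master' config's input_shape moves from (48, 160, 3) to (32, 128, 3)). In Keras, dropout layers are only active when called with training=True; a minimal sketch, independent of doctr, of that behavior:

import tensorflow as tf

drop = tf.keras.layers.Dropout(0.2)  # rate matching the new default
x = tf.ones((1, 4))

print(drop(x, training=True))   # each value zeroed with probability 0.2, survivors scaled by 1.25
print(drop(x, training=False))  # unchanged: dropout is a no-op at inference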
6 changes: 5 additions & 1 deletion doctr/models/recognition/transformer/pytorch.py
@@ -40,6 +40,7 @@ def __init__(
dff: int = 2048,
vocab_size: int = 120,
maximum_position_encoding: int = 50,
+ dropout: float = 0.2,
) -> None:
super(Decoder, self).__init__()

@@ -54,11 +55,13 @@ def __init__(
d_model=d_model,
nhead=num_heads,
dim_feedforward=dff,
- dropout=0.1,
+ dropout=dropout,
activation='relu',
) for _ in range(num_layers)
]

+ self.dropout = nn.Dropout(dropout)

def forward(
self,
x: torch.Tensor,
@@ -72,6 +75,7 @@ def forward(
x = self.embedding(x) # (batch_size, target_seq_len, d_model)
x *= math.sqrt(self.d_model)
x += self.pos_encoding[:, :seq_len, :]
+ x = self.dropout(x)

# Batch first = False in decoder
x = x.permute(1, 0, 2)
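In the PyTorch decoder itself, the hard-coded dropout=0.1 on each nn.TransformerDecoderLayer is replaced by the configurable rate, and a new nn.Dropout is applied to the embedded, position-encoded target sequence before it enters the decoder stack. A minimal sketch of that embedding-side pattern (illustrative names; a zero tensor stands in for the positional encoding):

import math
import torch
from torch import nn

class EmbeddingWithDropout(nn.Module):
    # Scale embeddings by sqrt(d_model), add positional encodings, then apply dropout,
    # as in the forward() hunk above.
    def __init__(self, vocab_size: int = 123, d_model: int = 512,
                 max_len: int = 50, dropout: float = 0.2) -> None:
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.register_buffer("pos_encoding", torch.zeros(1, max_len, d_model))  # placeholder PE
        self.dropout = nn.Dropout(dropout)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        seq_len = tokens.shape[1]
        x = self.embedding(tokens) * math.sqrt(self.d_model)
        x = x + self.pos_encoding[:, :seq_len, :]
        return self.dropout(x)

out = EmbeddingWithDropout()(torch.zeros(2, 10, dtype=torch.long))  # (2, 10, 512)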
22 changes: 20 additions & 2 deletions doctr/models/recognition/transformer/tensorflow.py
@@ -168,7 +168,13 @@ def point_wise_feed_forward_network(d_model: int = 512, dff: int = 2048) -> tf.k

class DecoderLayer(tf.keras.layers.Layer):

- def __init__(self, d_model: int = 512, num_heads: int = 8, dff: int = 2048) -> None:
+ def __init__(
+     self,
+     d_model: int = 512,
+     num_heads: int = 8,
+     dff: int = 2048,
+     dropout: float = 0.2,
+ ) -> None:
super(DecoderLayer, self).__init__()

self.mha1 = MultiHeadAttention(d_model, num_heads)
@@ -180,6 +186,10 @@ def __init__(self, d_model: int = 512, num_heads: int = 8, dff: int = 2048) -> N
self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
self.layernorm3 = tf.keras.layers.LayerNormalization(epsilon=1e-6)

+ self.dropout1 = tf.keras.layers.Dropout(dropout)
+ self.dropout2 = tf.keras.layers.Dropout(dropout)
+ self.dropout3 = tf.keras.layers.Dropout(dropout)

def call(
self,
x: tf.Tensor,
@@ -191,12 +201,15 @@ def call(
# enc_output.shape == (batch_size, input_seq_len, d_model)

attn1 = self.mha1(x, x, x, look_ahead_mask, **kwargs) # (batch_size, target_seq_len, d_model)
+ attn1 = self.dropout1(attn1, **kwargs)
out1 = self.layernorm1(attn1 + x, **kwargs)

attn2 = self.mha2(enc_output, enc_output, out1, padding_mask, **kwargs) # (batch_size, target_seq_len, d_model)
+ attn2 = self.dropout2(attn2, **kwargs)
out2 = self.layernorm2(attn2 + out1, **kwargs) # (batch_size, target_seq_len, d_model)

ffn_output = self.ffn(out2, **kwargs) # (batch_size, target_seq_len, d_model)
+ ffn_output = self.dropout3(ffn_output, **kwargs)
out3 = self.layernorm3(ffn_output + out2, **kwargs) # (batch_size, target_seq_len, d_model)

return out3
@@ -212,6 +225,7 @@ def __init__(
dff: int = 2048,
vocab_size: int = 120,
maximum_position_encoding: int = 50,
+ dropout: float = 0.2,
) -> None:
super(Decoder, self).__init__()

@@ -221,9 +235,11 @@ def __init__(
self.embedding = tf.keras.layers.Embedding(vocab_size + 3, d_model) # 3 more classes EOS/SOS/PAD
self.pos_encoding = positional_encoding(maximum_position_encoding, d_model)

- self.dec_layers = [DecoderLayer(d_model, num_heads, dff)
+ self.dec_layers = [DecoderLayer(d_model, num_heads, dff, dropout)
for _ in range(num_layers)]

+ self.dropout = tf.keras.layers.Dropout(dropout)

def call(
self,
x: tf.Tensor,
@@ -239,6 +255,8 @@ def call(
x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
x += self.pos_encoding[:, :seq_len, :]

+ x = self.dropout(x, **kwargs)

for i in range(self.num_layers):
x = self.dec_layers[i](
x, enc_output, look_ahead_mask, padding_mask, **kwargs
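Finally, the Keras DecoderLayer now applies dropout to each sub-layer output (both attention blocks and the feed-forward network) before the residual add and layer normalization, and the Decoder drops out the embedded inputs, mirroring the PyTorch changes. A standalone sketch of that sub-layer pattern, using Keras's built-in MultiHeadAttention rather than doctr's own (names are illustrative):

import tensorflow as tf

class SelfAttentionBlock(tf.keras.layers.Layer):
    # Sub-layer pattern from the hunks above: attention -> dropout -> residual + layer norm.
    def __init__(self, d_model: int = 512, num_heads: int = 8, dropout: float = 0.2) -> None:
        super().__init__()
        self.mha = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model // num_heads)
        self.dropout = tf.keras.layers.Dropout(dropout)
        self.layernorm = tf.keras.layers.LayerNormalization(epsilon=1e-6)

    def call(self, x: tf.Tensor, training: bool = False) -> tf.Tensor:
        attn = self.dropout(self.mha(x, x, training=training), training=training)
        return self.layernorm(attn + x)

y = SelfAttentionBlock()(tf.random.normal((2, 10, 512)), training=True)  # (2, 10, 512)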
