Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add dropout in MASTER #349

Merged
merged 4 commits into from
Jul 6, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doctr/models/recognition/master/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def __init__(
num_heads: int = 8,
num_layers: int = 3,
max_length: int = 50,
dropout: float = 0.2,
input_shape: Tuple[int, int, int] = (3, 48, 160),
cfg: Optional[Dict[str, Any]] = None,
) -> None:
Expand All @@ -190,6 +191,7 @@ def __init__(
dff=dff,
vocab_size=self.vocab_size,
maximum_position_encoding=max_length,
dropout=dropout,
)
self.feature_pe = positional_encoding(input_shape[1] * input_shape[2], d_model)
self.linear = nn.Linear(d_model, self.vocab_size + 3)
Expand Down
10 changes: 5 additions & 5 deletions doctr/models/recognition/master/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
'master': {
'mean': (.5, .5, .5),
'std': (1., 1., 1.),
'input_shape': (48, 160, 3),
'input_shape': (32, 128, 3),
'vocab': VOCABS['french'],
'url': None,
},
Expand Down Expand Up @@ -81,7 +81,6 @@ def __init__(
name='transform'
)

@tf.function
def context_modeling(self, inputs: tf.Tensor) -> tf.Tensor:
b, h, w, c = (tf.shape(inputs)[i] for i in range(4))

Expand Down Expand Up @@ -193,6 +192,7 @@ def __init__(
num_heads: int = 8,
num_layers: int = 3,
max_length: int = 50,
dropout: float = 0.2,
input_shape: Tuple[int, int, int] = (48, 160, 3),
cfg: Optional[Dict[str, Any]] = None,
) -> None:
Expand All @@ -213,13 +213,13 @@ def __init__(
dff=dff,
vocab_size=self.vocab_size,
maximum_position_encoding=max_length,
dropout=dropout,
)
self.feature_pe = positional_encoding(input_shape[0] * input_shape[1], d_model)
self.linear = layers.Dense(self.vocab_size + 3, kernel_initializer=tf.initializers.he_uniform())

self.postprocessor = MASTERPostProcessor(vocab=self.vocab)

@tf.function
def make_mask(self, target: tf.Tensor) -> tf.Tensor:
look_ahead_mask = create_look_ahead_mask(tf.shape(target)[1])
target_padding_mask = create_padding_mask(target, self.vocab_size + 2) # Pad symbol
Expand Down Expand Up @@ -395,8 +395,8 @@ def _master(arch: str, pretrained: bool, input_shape: Tuple[int, int, int] = Non
# Build the model
model = MASTER(cfg=_cfg, **kwargs)
# Load pretrained parameters
if pretrained:
load_pretrained_params(model, default_cfgs[arch]['url'])
# if pretrained:
# load_pretrained_params(model, default_cfgs[arch]['url'])
charlesmindee marked this conversation as resolved.
Show resolved Hide resolved

return model

Expand Down
3 changes: 2 additions & 1 deletion doctr/models/recognition/transformer/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(
dff: int = 2048,
vocab_size: int = 120,
maximum_position_encoding: int = 50,
dropout: float = 0.2,
) -> None:
super(Decoder, self).__init__()

Expand All @@ -54,7 +55,7 @@ def __init__(
d_model=d_model,
nhead=num_heads,
dim_feedforward=dff,
dropout=0.1,
dropout=dropout,
activation='relu',
) for _ in range(num_layers)
]
Expand Down
22 changes: 20 additions & 2 deletions doctr/models/recognition/transformer/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,13 @@ def point_wise_feed_forward_network(d_model: int = 512, dff: int = 2048) -> tf.k

class DecoderLayer(tf.keras.layers.Layer):

def __init__(self, d_model: int = 512, num_heads: int = 8, dff: int = 2048) -> None:
def __init__(
self,
d_model: int = 512,
num_heads: int = 8,
dff: int = 2048,
dropout: float = 0.2,
) -> None:
super(DecoderLayer, self).__init__()

self.mha1 = MultiHeadAttention(d_model, num_heads)
Expand All @@ -180,6 +186,10 @@ def __init__(self, d_model: int = 512, num_heads: int = 8, dff: int = 2048) -> N
self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
self.layernorm3 = tf.keras.layers.LayerNormalization(epsilon=1e-6)

self.dropout1 = tf.keras.layers.Dropout(dropout)
self.dropout2 = tf.keras.layers.Dropout(dropout)
self.dropout3 = tf.keras.layers.Dropout(dropout)

def call(
self,
x: tf.Tensor,
Expand All @@ -191,12 +201,15 @@ def call(
# enc_output.shape == (batch_size, input_seq_len, d_model)

attn1 = self.mha1(x, x, x, look_ahead_mask, **kwargs) # (batch_size, target_seq_len, d_model)
attn1 = self.dropout1(attn1, **kwargs)
out1 = self.layernorm1(attn1 + x, **kwargs)

attn2 = self.mha2(enc_output, enc_output, out1, padding_mask, **kwargs) # (batch_size, target_seq_len, d_model)
attn2 = self.dropout2(attn2, **kwargs)
out2 = self.layernorm2(attn2 + out1, **kwargs) # (batch_size, target_seq_len, d_model)

ffn_output = self.ffn(out2, **kwargs) # (batch_size, target_seq_len, d_model)
ffn_output = self.dropout3(ffn_output, **kwargs)
out3 = self.layernorm3(ffn_output + out2, **kwargs) # (batch_size, target_seq_len, d_model)

return out3
Expand All @@ -212,6 +225,7 @@ def __init__(
dff: int = 2048,
vocab_size: int = 120,
maximum_position_encoding: int = 50,
dropout: float = 0.2,
) -> None:
super(Decoder, self).__init__()

Expand All @@ -221,9 +235,11 @@ def __init__(
self.embedding = tf.keras.layers.Embedding(vocab_size + 3, d_model) # 3 more classes EOS/SOS/PAD
self.pos_encoding = positional_encoding(maximum_position_encoding, d_model)

self.dec_layers = [DecoderLayer(d_model, num_heads, dff)
self.dec_layers = [DecoderLayer(d_model, num_heads, dff, dropout)
for _ in range(num_layers)]

self.dropout = tf.keras.layers.Dropout(dropout)

def call(
self,
x: tf.Tensor,
Expand All @@ -239,6 +255,8 @@ def call(
x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
x += self.pos_encoding[:, :seq_len, :]

x = self.dropout(x, **kwargs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mmmh, we already have dropout in the DecoderLayer, are you positive we're supposed to put one here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://www.tensorflow.org/text/tutorials/transformer#decoder Here they put another one here, but we can remove it

Copy link
Collaborator Author

@charlesmindee charlesmindee Jul 5, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In pytorch it is also used here after the postional_embedding: https://pytorch.org/tutorials/beginner/transformer_tutorial.html#define-the-model


for i in range(self.num_layers):
x = self.dec_layers[i](
x, enc_output, look_ahead_mask, padding_mask, **kwargs
Expand Down