Skip to content

Commit

Permalink
feat: add bottleneck ratio in MASTER (#350)
Browse files Browse the repository at this point in the history
* feat: add ratio bottleneck

* feat: add ratio bottleneck

* fix: tests
  • Loading branch information
charlesmindee committed Jul 5, 2021
1 parent d354011 commit b2ded17
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 13 deletions.
16 changes: 9 additions & 7 deletions doctr/models/recognition/master/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,25 +45,27 @@ class MAGC(nn.Module):
def __init__(
self,
inplanes: int,
headers: int = 1,
headers: int = 8,
att_scale: bool = False,
ratio: float = 0.0625, # bottleneck ratio of 1/16 as described in paper
) -> None:
super().__init__()

self.headers = headers # h
self.inplanes = inplanes # C
self.att_scale = att_scale
self.planes = int(inplanes * ratio)

self.single_header_inplanes = int(inplanes / headers) # C / h

self.conv_mask = nn.Conv2d(self.single_header_inplanes, 1, kernel_size=1)
self.softmax = nn.Softmax(dim=2)

self.channel_add_conv = nn.Sequential(
nn.Conv2d(self.inplanes, self.inplanes, kernel_size=1),
nn.LayerNorm([self.inplanes, 1, 1]),
nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
nn.LayerNorm([self.planes, 1, 1]),
nn.ReLU(inplace=True),
nn.Conv2d(self.inplanes, self.inplanes, kernel_size=1)
nn.Conv2d(self.planes, self.inplanes, kernel_size=1)
)

def forward(self, inputs: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -115,7 +117,7 @@ class MAGCResnet(nn.Sequential):

def __init__(
self,
headers: int = 1,
headers: int = 8,
) -> None:
_layers = [
# conv_1x
Expand Down Expand Up @@ -164,9 +166,9 @@ def __init__(
self,
vocab: str,
d_model: int = 512,
headers: int = 1,
headers: int = 8, # number of multi-aspect context
dff: int = 2048,
num_heads: int = 8,
num_heads: int = 8, # number of heads in the transformer decoder
num_layers: int = 3,
max_length: int = 50,
input_shape: Tuple[int, int, int] = (3, 48, 160),
Expand Down
15 changes: 9 additions & 6 deletions doctr/models/recognition/master/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# This program is licensed under the Apache License version 2.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.

import math
import tensorflow as tf
from tensorflow.keras import layers, Sequential, Model
from typing import Tuple, List, Dict, Any, Optional
Expand Down Expand Up @@ -45,15 +46,17 @@ class MAGC(layers.Layer):
def __init__(
self,
inplanes: int,
headers: int = 1,
headers: int = 8,
att_scale: bool = False,
ratio: float = 0.0625, # bottleneck ratio of 1/16 as described in paper
**kwargs
) -> None:
super().__init__(**kwargs)

self.headers = headers # h
self.inplanes = inplanes # C
self.att_scale = att_scale
self.planes = int(inplanes * ratio)

self.single_header_inplanes = int(inplanes / headers) # C / h

Expand All @@ -66,7 +69,7 @@ def __init__(
self.transform = tf.keras.Sequential(
[
tf.keras.layers.Conv2D(
filters=self.inplanes,
filters=self.planes,
kernel_size=1,
kernel_initializer=tf.initializers.he_normal()
),
Expand Down Expand Up @@ -104,7 +107,7 @@ def context_modeling(self, inputs: tf.Tensor) -> tf.Tensor:
context_mask = tf.reshape(context_mask, shape=(b * self.headers, 1, h * w, 1))
# scale variance
if self.att_scale and self.headers > 1:
context_mask = context_mask / tf.sqrt(self.single_header_inplanes)
context_mask = context_mask / math.sqrt(self.single_header_inplanes)
# B*h, 1, H*W, 1
context_mask = tf.keras.activations.softmax(context_mask, axis=2)

Expand Down Expand Up @@ -138,7 +141,7 @@ class MAGCResnet(Sequential):

def __init__(
self,
headers: int = 1,
headers: int = 8,
input_shape: Tuple[int, int, int] = (48, 160, 3),
) -> None:
_layers = [
Expand Down Expand Up @@ -188,9 +191,9 @@ def __init__(
self,
vocab: str,
d_model: int = 512,
headers: int = 1,
headers: int = 8, # number of multi-aspect context
dff: int = 2048,
num_heads: int = 8,
num_heads: int = 8, # number of heads in the transformer decoder
num_layers: int = 3,
max_length: int = 50,
input_shape: Tuple[int, int, int] = (48, 160, 3),
Expand Down

0 comments on commit b2ded17

Please sign in to comment.