Skip to content

Commit

Permalink
Update model.py
Browse files Browse the repository at this point in the history
  • Loading branch information
leaderj1001 committed May 24, 2019
1 parent f1e4c8b commit 2500981
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions model.py
Expand Up @@ -77,12 +77,11 @@ def forward(self, x):


class MobileBlock(nn.Module):
def __init__(self, in_channels, out_channels, kernal_size, stride, nonLinear, SE, exp_size, dropout_rate=1.0):
def __init__(self, in_channels, out_channels, kernal_size, stride, nonLinear, SE, exp_size):
super(MobileBlock, self).__init__()
self.out_channels = out_channels
self.nonLinear = nonLinear
self.SE = SE
self.dropout_rate = dropout_rate
padding = (kernal_size - 1) // 2

self.use_connect = stride == 1 and in_channels == out_channels
Expand Down Expand Up @@ -131,7 +130,7 @@ def forward(self, x):


class MobileNetV3(nn.Module):
def __init__(self, model_mode="LARGE", num_classes=1000, multiplier=1.0):
def __init__(self, model_mode="LARGE", num_classes=1000, multiplier=1.0, dropout_rate=0.0):
super(MobileNetV3, self).__init__()
self.num_classes = num_classes

Expand Down Expand Up @@ -183,6 +182,7 @@ def __init__(self, model_mode="LARGE", num_classes=1000, multiplier=1.0):
self.out_conv2 = nn.Sequential(
nn.Conv2d(out_conv2_in, out_conv2_out, kernel_size=1, stride=1),
h_swish(inplace=True),
nn.Dropout(dropout_rate),
nn.Conv2d(out_conv2_out, self.num_classes, kernel_size=1, stride=1),
)

Expand Down Expand Up @@ -230,6 +230,7 @@ def __init__(self, model_mode="LARGE", num_classes=1000, multiplier=1.0):
self.out_conv2 = nn.Sequential(
nn.Conv2d(out_conv2_in, out_conv2_out, kernel_size=1, stride=1),
h_swish(inplace=True),
nn.Dropout(dropout_rate),
nn.Conv2d(out_conv2_out, self.num_classes, kernel_size=1, stride=1),
)

Expand Down

0 comments on commit 2500981

Please sign in to comment.