
Commit

quantized mobilenet
eladhoffer committed Jul 30, 2018
1 parent 30827ba commit 3fd37e7
Showing 1 changed file with 16 additions and 8 deletions.
24 changes: 16 additions & 8 deletions models/mobilenet_quantized.py
@@ -13,6 +13,7 @@
NUM_BITS_GRAD = 8
BIPRECISION = True


def nearby_int(n):
return int(round(n))

@@ -25,6 +26,8 @@ def init_model(model):
elif isinstance(m, RangeBN):
m.weight.data.fill_(1)
m.bias.data.zero_()
+    model.fc.weight.data.normal_(0, 0.01)
+    model.fc.bias.data.zero_()


class DepthwiseSeparableFusedConv2d(nn.Module):
@@ -34,12 +37,15 @@ def __init__(self, in_channels, out_channels, kernel_size,
super(DepthwiseSeparableFusedConv2d, self).__init__()
self.components = nn.Sequential(
QConv2d(in_channels, in_channels, kernel_size,
-                    stride=stride, padding=padding, groups=in_channels),
-            RangeBN(in_channels),
+                    stride=stride, padding=padding, groups=in_channels, num_bits=NUM_BITS, num_bits_weight=NUM_BITS_WEIGHT, num_bits_grad=NUM_BITS_GRAD, biprecision=BIPRECISION),
+            RangeBN(in_channels, num_bits=NUM_BITS,
+                    num_bits_grad=NUM_BITS_GRAD),
nn.ReLU(),

-            QConv2d(in_channels, out_channels, 1, bias=False),
-            RangeBN(out_channels),
+            QConv2d(in_channels, out_channels, 1, bias=False, num_bits=NUM_BITS,
+                    num_bits_weight=NUM_BITS_WEIGHT, num_bits_grad=NUM_BITS_GRAD, biprecision=BIPRECISION),
+            RangeBN(out_channels, num_bits=NUM_BITS,
+                    num_bits_grad=NUM_BITS_GRAD),
nn.ReLU()
)

@@ -55,8 +61,10 @@ def __init__(self, width=1., shallow=False, num_classes=1000):
width = width or 1.
layers = [
QConv2d(3, nearby_int(width * 32),
-                    kernel_size=3, stride=2, padding=1, bias=False),
-            RangeBN(nearby_int(width * 32)),
+                    kernel_size=3, stride=2, padding=1, bias=False, num_bits=NUM_BITS,
+                    num_bits_weight=NUM_BITS_WEIGHT, num_bits_grad=NUM_BITS_GRAD, biprecision=BIPRECISION),
+            RangeBN(nearby_int(width * 32), num_bits=NUM_BITS,
+                    num_bits_grad=NUM_BITS_GRAD),
nn.ReLU(inplace=True),

DepthwiseSeparableFusedConv2d(
@@ -109,7 +117,8 @@ def __init__(self, width=1., shallow=False, num_classes=1000):
]
self.features = nn.Sequential(*layers)
self.avg_pool = nn.AvgPool2d(7)
-        self.fc = QLinear(nearby_int(width * 1024), num_classes)
+        self.fc = QLinear(nearby_int(width * 1024), num_classes, num_bits=NUM_BITS,
+                          num_bits_weight=NUM_BITS_WEIGHT, num_bits_grad=NUM_BITS_GRAD, biprecision=BIPRECISION)
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
self.input_transform = {
@@ -133,7 +142,6 @@ def __init__(self, width=1., shallow=False, num_classes=1000):
{'epoch': 80, 'lr': 1e-4}
]


@staticmethod
def regularization(model, weight_decay=4e-5):
l2_params = 0
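The change follows one pattern throughout the file: every quantized layer (QConv2d, RangeBN, QLinear) now receives the module-level quantization settings (NUM_BITS, NUM_BITS_WEIGHT, NUM_BITS_GRAD, BIPRECISION) explicitly instead of relying on the layers' defaults. Below is a minimal sketch of how that repetition could be bundled into shared keyword dictionaries; it is not part of the commit, it assumes the QConv2d/RangeBN keyword arguments visible in this diff, it guesses the import path, and it assumes NUM_BITS and NUM_BITS_WEIGHT are 8 (only NUM_BITS_GRAD = 8 and BIPRECISION = True appear in the hunk shown).

# Sketch only, not the commit's code: shared kwargs for the quantized layers.
# Assumptions: the import path below and the NUM_BITS / NUM_BITS_WEIGHT values.
from models.modules.quantize import QConv2d, RangeBN  # assumed path

NUM_BITS = 8           # assumed; not shown in the hunk
NUM_BITS_WEIGHT = 8    # assumed; not shown in the hunk
NUM_BITS_GRAD = 8
BIPRECISION = True

CONV_QUANT = dict(num_bits=NUM_BITS, num_bits_weight=NUM_BITS_WEIGHT,
                  num_bits_grad=NUM_BITS_GRAD, biprecision=BIPRECISION)
BN_QUANT = dict(num_bits=NUM_BITS, num_bits_grad=NUM_BITS_GRAD)

def q_conv_bn(in_channels, out_channels, **conv_kwargs):
    # Same QConv2d / RangeBN calls as in the diff, with the shared settings unpacked.
    return [QConv2d(in_channels, out_channels, **conv_kwargs, **CONV_QUANT),
            RangeBN(out_channels, **BN_QUANT)]

With such a helper, the depthwise stage of DepthwiseSeparableFusedConv2d would become *q_conv_bn(in_channels, in_channels, kernel_size=kernel_size, stride=stride, padding=padding, groups=in_channels) followed by nn.ReLU(); the commit instead spells the keyword arguments out at every call site.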

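The two lines added to init_model give the classifier an explicit initialization: zero-mean Gaussian weights with standard deviation 0.01 and a zero bias. In isolation, and using a plain torch.nn.Linear as a stand-in for QLinear (a simplification for illustration, not the commit's code), the effect is:

# Stand-in illustration of the added init_model lines; nn.Linear replaces QLinear.
import torch.nn as nn

fc = nn.Linear(1024, 1000)         # stands in for QLinear(nearby_int(width * 1024), num_classes)
fc.weight.data.normal_(0, 0.01)    # added line: small Gaussian weights
fc.bias.data.zero_()               # added line: zero bias

print(round(fc.weight.std().item(), 3))   # roughly 0.01
print(fc.bias.abs().max().item())         # 0.0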