From 8063cbd80acf103e1b66b7237996326df9cd987e Mon Sep 17 00:00:00 2001
From: Alexandre Marques
Date: Fri, 8 Apr 2022 19:51:47 -0400
Subject: [PATCH 1/3] Revert ResNet definition to not quantize input to add op in residual branches.

---
 .../pytorch/models/classification/resnet.py | 30 +++++++++++++++++++++++-------
 1 file changed, 23 insertions(+), 7 deletions(-)

diff --git a/src/sparseml/pytorch/models/classification/resnet.py b/src/sparseml/pytorch/models/classification/resnet.py
index cd8b979c3ad..e4c79278afd 100644
--- a/src/sparseml/pytorch/models/classification/resnet.py
+++ b/src/sparseml/pytorch/models/classification/resnet.py
@@ -151,7 +151,7 @@ def __init__(self, num_channels):
         if FloatFunctional:
             self.functional = FloatFunctional()
             self.wrap_qat = True
-            self.qat_wrapper_kwargs = {"num_inputs": 1, "num_outputs": 0}
+            self.qat_wrapper_kwargs = {"num_inputs": 1, "num_outputs": 1}
         else:
             self.functional = ReLU(num_channels=num_channels, inplace=True)
@@ -185,7 +185,11 @@ def __init__(self, in_channels: int, out_channels: int, stride: int = 1):
             else None
         )
 
-        self.add_relu = _AddReLU(out_channels)
+        # self.add_relu = _AddReLU(out_channels)
+        if FloatFunctional:
+            self.self.add_relu = FloatFunctional()
+        else:
+            self.self.add_relu = ReLU(num_channels=out_channels, inplace=True)
 
         self.initialize()
 
@@ -198,9 +202,13 @@ def forward(self, inp: Tensor):
         out = self.bn2(out)
         identity_val = self.identity(inp) if self.identity is not None else inp
 
-        out = self.add_relu(identity_val, out)
+        # out = self.add_relu(identity_val, out)
+        # return out
 
-        return out
+        if isinstance(self.add_relu, FloatFunctional):
+            return self.add_relu.add_relu(identity_val, out)
+        else:
+            return self.add_relu(identity_val + out)
 
     def initialize(self):
         _init_conv(self.conv1)
@@ -242,7 +250,11 @@ def __init__(
             else None
         )
 
-        self.add_relu = _AddReLU(out_channels)
+        # self.add_relu = _AddReLU(out_channels)
+        if FloatFunctional:
+            self.self.add_relu = FloatFunctional()
+        else:
+            self.self.add_relu = ReLU(num_channels=out_channels, inplace=True)
 
         self.initialize()
 
@@ -260,9 +272,13 @@ def forward(self, inp: Tensor):
 
         identity_val = self.identity(inp) if self.identity is not None else inp
 
-        out = self.add_relu(identity_val, out)
+        # out = self.add_relu(identity_val, out)
+        # return out
 
-        return out
+        if isinstance(self.add_relu, FloatFunctional):
+            return self.add_relu.add_relu(identity_val, out)
+        else:
+            return self.add_relu(identity_val + out)
 
     def initialize(self):
         _init_conv(self.conv1)

From 8c533eda55f83a19fbda1c67a1268bae3d7137fb Mon Sep 17 00:00:00 2001
From: Alexandre Marques
Date: Fri, 8 Apr 2022 19:56:13 -0400
Subject: [PATCH 2/3] Correct typo.

---
 src/sparseml/pytorch/models/classification/resnet.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/sparseml/pytorch/models/classification/resnet.py b/src/sparseml/pytorch/models/classification/resnet.py
index e4c79278afd..34328c1b8fc 100644
--- a/src/sparseml/pytorch/models/classification/resnet.py
+++ b/src/sparseml/pytorch/models/classification/resnet.py
@@ -187,9 +187,9 @@ def __init__(self, in_channels: int, out_channels: int, stride: int = 1):
 
         # self.add_relu = _AddReLU(out_channels)
         if FloatFunctional:
-            self.self.add_relu = FloatFunctional()
+            self.add_relu = FloatFunctional()
         else:
-            self.self.add_relu = ReLU(num_channels=out_channels, inplace=True)
+            self.add_relu = ReLU(num_channels=out_channels, inplace=True)
 
         self.initialize()
@@ -252,9 +252,9 @@ def __init__(
 
         # self.add_relu = _AddReLU(out_channels)
         if FloatFunctional:
-            self.self.add_relu = FloatFunctional()
+            self.add_relu = FloatFunctional()
         else:
-            self.self.add_relu = ReLU(num_channels=out_channels, inplace=True)
+            self.add_relu = ReLU(num_channels=out_channels, inplace=True)
 
         self.initialize()

From 71d26703c4e31e2fe77fc5b502cbe3bdb64f27ae Mon Sep 17 00:00:00 2001
From: Alexandre Marques
Date: Tue, 12 Apr 2022 09:22:08 -0400
Subject: [PATCH 3/3] Correct number of quantized outputs for future changes.

---
 src/sparseml/pytorch/models/classification/resnet.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/sparseml/pytorch/models/classification/resnet.py b/src/sparseml/pytorch/models/classification/resnet.py
index 34328c1b8fc..b25894265c2 100644
--- a/src/sparseml/pytorch/models/classification/resnet.py
+++ b/src/sparseml/pytorch/models/classification/resnet.py
@@ -151,7 +151,7 @@ def __init__(self, num_channels):
         if FloatFunctional:
             self.functional = FloatFunctional()
             self.wrap_qat = True
-            self.qat_wrapper_kwargs = {"num_inputs": 1, "num_outputs": 1}
+            self.qat_wrapper_kwargs = {"num_inputs": 1, "num_outputs": 0}
         else:
             self.functional = ReLU(num_channels=num_channels, inplace=True)
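
Note for readers outside the sparseml codebase: below is a minimal, self-contained sketch of the residual join these patches toggle. torch.nn.quantized.FloatFunctional is the stock PyTorch quantization wrapper whose add_relu method runs the residual add and the ReLU under a single output observer, so during QAT the inputs to the add are not themselves fake-quantized (the intent of PATCH 1/3). The ResidualAddReLU class and its use_float_functional flag are illustrative names, not sparseml API; sparseml's ReLU fallback takes num_channels, but plain nn.ReLU is used here to keep the sketch runnable. On newer PyTorch releases the same wrapper is exposed as torch.ao.nn.quantized.FloatFunctional.

import torch
from torch import Tensor, nn
from torch.nn.quantized import FloatFunctional


class ResidualAddReLU(nn.Module):
    """Sketch of the add + ReLU join at the end of a ResNet residual block."""

    def __init__(self, use_float_functional: bool = True):
        super().__init__()
        if use_float_functional:
            # FloatFunctional observes only its output, so the inputs to the
            # add stay unquantized during QAT.
            self.add_relu = FloatFunctional()
        else:
            # Plain eager-mode fallback: float add followed by ReLU.
            self.add_relu = nn.ReLU(inplace=True)

    def forward(self, identity_val: Tensor, out: Tensor) -> Tensor:
        if isinstance(self.add_relu, FloatFunctional):
            # Fused add + ReLU through the quantization wrapper.
            return self.add_relu.add_relu(identity_val, out)
        return self.add_relu(identity_val + out)


# Both paths compute relu(identity_val + out) when run in float mode.
x, y = torch.randn(2, 8), torch.randn(2, 8)
assert torch.allclose(ResidualAddReLU(True)(x, y), ResidualAddReLU(False)(x, y))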