diff --git a/src/sparseml/pytorch/models/classification/resnet.py b/src/sparseml/pytorch/models/classification/resnet.py
index cd8b979c3ad..b25894265c2 100644
--- a/src/sparseml/pytorch/models/classification/resnet.py
+++ b/src/sparseml/pytorch/models/classification/resnet.py
@@ -185,7 +185,11 @@ def __init__(self, in_channels: int, out_channels: int, stride: int = 1):
             else None
         )
 
-        self.add_relu = _AddReLU(out_channels)
+        # Prefer FloatFunctional when the name is truthy; else plain ReLU.
+        if FloatFunctional:
+            self.add_relu = FloatFunctional()
+        else:
+            self.add_relu = ReLU(num_channels=out_channels, inplace=True)
 
         self.initialize()
 
@@ -198,9 +202,13 @@ def forward(self, inp: Tensor):
         out = self.bn2(out)
 
         identity_val = self.identity(inp) if self.identity is not None else inp
-        out = self.add_relu(identity_val, out)
+        # FloatFunctional is truth-tested in __init__, so it may be a falsy
+        # placeholder; test it first, isinstance() rejects a non-type argument.
 
-        return out
+        if FloatFunctional and isinstance(self.add_relu, FloatFunctional):
+            return self.add_relu.add_relu(identity_val, out)
+        else:
+            return self.add_relu(identity_val + out)
 
     def initialize(self):
         _init_conv(self.conv1)
@@ -242,7 +250,11 @@ def __init__(
             else None
         )
 
-        self.add_relu = _AddReLU(out_channels)
+        # Prefer FloatFunctional when the name is truthy; else plain ReLU.
+        if FloatFunctional:
+            self.add_relu = FloatFunctional()
+        else:
+            self.add_relu = ReLU(num_channels=out_channels, inplace=True)
 
         self.initialize()
 
@@ -260,9 +272,13 @@ def forward(self, inp: Tensor):
 
         identity_val = self.identity(inp) if self.identity is not None else inp
 
-        out = self.add_relu(identity_val, out)
+        # FloatFunctional is truth-tested in __init__, so it may be a falsy
+        # placeholder; test it first, isinstance() rejects a non-type argument.
 
-        return out
+        if FloatFunctional and isinstance(self.add_relu, FloatFunctional):
+            return self.add_relu.add_relu(identity_val, out)
+        else:
+            return self.add_relu(identity_val + out)
 
     def initialize(self):
         _init_conv(self.conv1)