diff --git a/clip/model.py b/clip/model.py index e743d2c78..3121dd75d 100644 --- a/clip/model.py +++ b/clip/model.py @@ -16,16 +16,18 @@ def __init__(self, inplanes, planes, stride=1): # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) self.bn1 = nn.BatchNorm2d(planes) + self.relu1 = nn.ReLU(inplace=True) self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(planes) + self.relu2 = nn.ReLU(inplace=True) self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu3 = nn.ReLU(inplace=True) - self.relu = nn.ReLU(inplace=True) self.downsample = None self.stride = stride @@ -40,8 +42,8 @@ def __init__(self, inplanes, planes, stride=1): def forward(self, x: torch.Tensor): identity = x - out = self.relu(self.bn1(self.conv1(x))) - out = self.relu(self.bn2(self.conv2(out))) + out = self.relu1(self.bn1(self.conv1(x))) + out = self.relu2(self.bn2(self.conv2(out))) out = self.avgpool(out) out = self.bn3(self.conv3(out)) @@ -49,7 +51,7 @@ def forward(self, x: torch.Tensor): identity = self.downsample(x) out += identity - out = self.relu(out) + out = self.relu3(out) return out @@ -106,12 +108,14 @@ def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): # the 3-layer stem self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) self.bn1 = nn.BatchNorm2d(width // 2) + self.relu1 = nn.ReLU(inplace=True) self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(width // 2) + self.relu2 = nn.ReLU(inplace=True) self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) self.bn3 = nn.BatchNorm2d(width) + self.relu3 = nn.ReLU(inplace=True) self.avgpool = nn.AvgPool2d(2) - self.relu = nn.ReLU(inplace=True) # residual layers self._inplanes = width # this is a *mutable* variable used during construction @@ -134,8 +138,9 @@ def _make_layer(self, planes, blocks, stride=1): def forward(self, x): def stem(x): - for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]: - x = self.relu(bn(conv(x))) + x = self.relu1(self.bn1(self.conv1(x))) + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) x = self.avgpool(x) return x