In [59]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn as nn
import torchvision.models as models

# SE Block From Paper

Global Pooling -> FC -> ReLU -> FC -> Sigmoid

### Adaptive Average Pooling (Global Average Pooling):

**Implementation**: self.avg_pool = nn.AdaptiveAvgPool2d(1)

**Paper Reference**: In the SE block, global average pooling is utilized to aggregate information across the spatial dimensions of each channel, producing a compact representation of the global context.

The nn.AdaptiveAvgPool2d(1) layer ensures that the pooling operation adapts to the spatial dimensions of the input feature map.

### Excitation:

**Implementation**: self.fc = nn.Sequential(...)

**Paper Reference**: The feature transformation is achieved through a sequence of fully connected layers.

The first linear layer reduces the number of channels by a reduction ratio (default is 16), introducing a bottleneck structure.

The ReLU activation introduces non-linearity, and the second linear layer restores the original number of channels.

The final sigmoid activation function scales the features to values between 0 and 1, representing the importance or relevance of each channel.

# Integration of SE Block in Forward Pass:

**Implementation**: def forward(self, x): ...

**Paper Reference**: In the forward pass, the input feature map is passed through the adaptive average pooling layer resulting in a compact global representation for each channel.

This representation is then processed through the fully connected layers (self.fc), and the output is reshaped to match the original dimensions.

The element-wise multiplication of the input by the scaling factors obtained from the Sigmoid activation (x \* y) performs the adaptive recalibration, emphasizing informative channels and suppressing less relevant ones.


In [60]:
# Squeeze-and-Excitation (SE) Block
class SEBlock(nn.Module):
    def __init__(self, channels, reduction=16):
        super(SEBlock, self).__init__()
        # The nn.AdaptiveAvgPool2d layer in PyTorch is essentially a flexible implementation of global average pooling.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)

        self.fc = nn.Sequential(
            # Creates Bottleneck reducing number of channels
            nn.Linear(channels, channels // reduction, bias=False),
            # Non-linearity
            nn.ReLU(inplace=True),
            # Restores number of channels
            nn.Linear(channels // reduction, channels, bias=False),
            # Scale [0,1]
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)  # Squeeze
        y = self.fc(y).view(b, c, 1, 1)  # Excite
        return x * y  # Recalibrate

In [61]:
# Extend ResNet with SE Blocks
class SEResNet(models.ResNet):
    def __init__(self, block, layers, num_classes=1000):
        super(SEResNet, self).__init__(block, layers, num_classes=num_classes)

        # Replace the last layer with an SEBlock
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(self.inplanes, num_classes) 
        self.se_block = SEBlock(self.inplanes)  # Add SEBlock

    def forward(self, x):
        x = super(SEResNet, self).forward(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        x = self.se_block(x)  # Apply SEBlock
        return x

In [62]:
# Initialize models
resnet18 = models.resnet18(weights="DEFAULT")
seresnet = SEResNet(models.resnet.BasicBlock, [2, 2, 2, 2], num_classes=1000)

print(resnet18.state_dict().keys())
print(seresnet.state_dict().keys())

# Load the state_dict while excluding problematic keys
state_dict = resnet18.state_dict()
state_dict['se_block.fc.0.weight']=seresnet.state_dict()['se_block.fc.0.weight']
state_dict['se_block.fc.2.weight']=seresnet.state_dict()['se_block.fc.2.weight']


# Load the adjusted state_dict into SEResNet
seresnet.load_state_dict(state_dict)

odict_keys(['conv1.weight', 'bn1.weight', 'bn1.bias', 'bn1.running_mean', 'bn1.running_var', 'bn1.num_batches_tracked', 'layer1.0.conv1.weight', 'layer1.0.bn1.weight', 'layer1.0.bn1.bias', 'layer1.0.bn1.running_mean', 'layer1.0.bn1.running_var', 'layer1.0.bn1.num_batches_tracked', 'layer1.0.conv2.weight', 'layer1.0.bn2.weight', 'layer1.0.bn2.bias', 'layer1.0.bn2.running_mean', 'layer1.0.bn2.running_var', 'layer1.0.bn2.num_batches_tracked', 'layer1.1.conv1.weight', 'layer1.1.bn1.weight', 'layer1.1.bn1.bias', 'layer1.1.bn1.running_mean', 'layer1.1.bn1.running_var', 'layer1.1.bn1.num_batches_tracked', 'layer1.1.conv2.weight', 'layer1.1.bn2.weight', 'layer1.1.bn2.bias', 'layer1.1.bn2.running_mean', 'layer1.1.bn2.running_var', 'layer1.1.bn2.num_batches_tracked', 'layer2.0.conv1.weight', 'layer2.0.bn1.weight', 'layer2.0.bn1.bias', 'layer2.0.bn1.running_mean', 'layer2.0.bn1.running_var', 'layer2.0.bn1.num_batches_tracked', 'layer2.0.conv2.weight', 'layer2.0.bn2.weight', 'layer2.0.bn2.bias', '

<All keys matched successfully>