In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DSPM(nn.Module):
    """Dilated spatial pyramid module: parallel dilated 3x3 convs plus a global-pool branch.

    Every branch maps ``in_channels`` -> ``out_channels``. The branch outputs are
    concatenated along the channel axis (global-pool branch first, then the dilated
    branches in ``dilation_rates`` order), fused by a 1x1 conv, batch-normalised and
    passed through ReLU.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: channel width of every branch and of the fused output.
        dilation_rates: one 3x3 conv branch is built per rate; each uses
            ``padding == dilation`` so the spatial size is preserved.

    Shapes:
        Input must be 4-D ``(N, C, H, W)`` (required by Conv2d / BatchNorm2d).
        Output is ``(N, out_channels, H, W)``.
    """

    def __init__(self, in_channels, out_channels, dilation_rates=(1, 3, 6)):
        super().__init__()
        # Copy into a tuple so a caller-supplied list is never shared or mutated.
        dilation_rates = tuple(dilation_rates)
        branch_channels = out_channels

        self.convs = nn.ModuleList([
            nn.Conv2d(in_channels, branch_channels, kernel_size=3, padding=d, dilation=d, bias=False)
            for d in dilation_rates
        ])

        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.global_conv = nn.Conv2d(in_channels, branch_channels, kernel_size=1, bias=False)

        # One slot per dilated branch plus one for the global-pool branch. A hard-coded
        # factor of 4 would only match exactly three dilation rates.
        num_branches = len(dilation_rates) + 1
        self.final_conv = nn.Conv2d(branch_channels * num_branches, out_channels, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Run every branch on ``x`` and fuse them; see the class docstring for shapes."""
        conv_outs = [conv(x) for conv in self.convs]

        # Global context: pool to 1x1, project, then resize back to the input's H x W.
        pool_out = self.global_conv(self.global_pool(x))
        pool_out = F.interpolate(pool_out, size=x.shape[2:], mode="bilinear", align_corners=False)

        fused = torch.cat([pool_out] + conv_outs, dim=1)
        fused = self.final_conv(fused)
        fused = self.bn(fused)
        return self.relu(fused)


In [None]:
from torchvision.models import resnet50

class ResNetWithDSPM(nn.Module):
    """ResNet-50 backbone followed by a DSPM head and a 1x1-conv classifier.

    The model outputs a spatial class-score map ``(N, num_classes, h, w)`` rather
    than a single vector per image.

    Args:
        num_classes: number of output channels of the classifier.
        weights: forwarded to ``torchvision.models.resnet50(weights=...)``. Pass
            ``None`` to build the backbone without loading pretrained weights.
        upsample_to_input: if True, bilinearly resize the score map to the input's
            spatial size. By default the map stays at the backbone's output resolution.
    """

    def __init__(self, num_classes=100, weights="ResNet50_Weights.DEFAULT", upsample_to_input=False):
        super().__init__()
        self.upsample_to_input = upsample_to_input

        base_model = resnet50(weights=weights)
        # Reuse the ResNet stem and residual stages; its avgpool and fc head are not used.
        self.layer0 = nn.Sequential(base_model.conv1, base_model.bn1, base_model.relu, base_model.maxpool)
        self.layer1 = base_model.layer1
        self.layer2 = base_model.layer2
        self.layer3 = base_model.layer3
        self.layer4 = base_model.layer4  # last ResNet stage; feeds the 2048-channel DSPM input

        self.dspm = DSPM(in_channels=2048, out_channels=1024)

        self.classifier = nn.Conv2d(1024, num_classes, kernel_size=1)

    def forward(self, x):
        """Return class scores for ``x`` (assumed ``(N, 3, H, W)`` images)."""
        input_size = x.shape[2:]
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.dspm(x)
        x = self.classifier(x)  # per-location class scores
        if self.upsample_to_input:
            x = F.interpolate(x, size=input_size, mode="bilinear", align_corners=False)
        return x


In [None]:
# Smoke test: one forward pass on a random image, in inference mode and seeded for reproducibility.
torch.manual_seed(0)
model = ResNetWithDSPM(num_classes=100)
model.eval()
input_tensor = torch.randn(1, 3, 512, 512)
with torch.no_grad():
    output = model(input_tensor)
# ResNet-50 reduces spatial size by 32, so a 512x512 input yields a 16x16 score map.
assert tuple(output.shape) == (1, 100, 16, 16), output.shape
output.shape
