In [None]:
# Static graph variant: mode == 0
"""
MindSpore implementation of `ResNet` with Mamba integration.
Refer to Deep Residual Learning for Image Recognition.
"""

from typing import List, Optional, Type, Union

import mindspore.common.initializer as init
from mindspore import Tensor, nn, ops

from .helpers import build_model_with_cfg
from .layers.pooling import GlobalAvgPooling
from .registry import register_model

# Names exported by `from <module> import *`: the ResNet class and the
# model factory functions defined below.
__all__ = [
    "ResNet",
    "resnet18",
    "resnet34",
    "resnet50",
    "resnet101",
    "resnet152",
    "resnext50_32x4d",
    "resnext101_32x4d",
    "resnext101_64x4d",
    "resnext152_64x4d",
]


def _cfg(url="", **kwargs):
    return {
        "url": url,
        "num_classes": 1000,
        "first_conv": "conv1",
        "classifier": "classifier",
        **kwargs,
    }


# Default configuration for each model factory, keyed by factory name.
# Every entry is produced by `_cfg`, so it carries the checkpoint URL given
# here plus the shared defaults (num_classes, first_conv, classifier).
default_cfgs = {
    "resnet18": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/resnet/resnet18-1e65cd21.ckpt"),
    "resnet34": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/resnet/resnet34-f297d27e.ckpt"),
    "resnet50": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/resnet/resnet50-e0733ab8.ckpt"),
    "resnet101": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/resnet/resnet101-689c5e77.ckpt"),
    "resnet152": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/resnet/resnet152-beb689d8.ckpt"),
    "resnext50_32x4d": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/resnext/resnext50_32x4d-af8aba16.ckpt"),
    "resnext101_32x4d": _cfg(
        url="https://download.mindspore.cn/toolkits/mindcv/resnext/resnext101_32x4d-3c1e9c51.ckpt"
    ),
    "resnext101_64x4d": _cfg(
        url="https://download.mindspore.cn/toolkits/mindcv/resnext/resnext101_64x4d-8929255b.ckpt"
    ),
    "resnext152_64x4d": _cfg(
        url="https://download.mindspore.cn/toolkits/mindcv/resnext/resnext152_64x4d-3aba275c.ckpt"
    ),
}


class BasicBlock(nn.Cell):
    """Basic residual block: two 3x3 convolutions plus a shortcut connection.

    Args:
        in_channels: Number of channels of the input tensor.
        channels: Number of channels produced by both convolutions.
        stride: Stride of the first convolution. Default: 1.
        groups: Must be 1; present for signature parity with `Bottleneck`.
        base_width: Must be 64; present for signature parity with `Bottleneck`.
        norm: Normalization layer class, called as ``norm(channels)``.
            Default: None, which selects ``nn.BatchNorm2d``.
        down_sample: Optional cell applied to the input to produce the
            shortcut branch. Default: None (the input itself is the shortcut).
        use_mamba: Accepted and ignored; this block never contains a
            `MambaBlock`. `ResNet._make_layer` passes this keyword to every
            block class, so without it building a ResNet from `BasicBlock`
            (resnet18 / resnet34) fails with ``TypeError``. Default: False.
    """
    expansion: int = 1

    def __init__(
        self,
        in_channels: int,
        channels: int,
        stride: int = 1,
        groups: int = 1,
        base_width: int = 64,
        norm: Optional[nn.Cell] = None,
        down_sample: Optional[nn.Cell] = None,
        use_mamba: bool = False,
    ) -> None:
        super().__init__()
        if norm is None:
            norm = nn.BatchNorm2d
        assert groups == 1, "BasicBlock only supports groups=1"
        assert base_width == 64, "BasicBlock only supports base_width=64"

        self.conv1 = nn.Conv2d(in_channels, channels, kernel_size=3,
                               stride=stride, padding=1, pad_mode="pad")
        self.bn1 = norm(channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3,
                               stride=1, padding=1, pad_mode="pad")
        self.bn2 = norm(channels)
        self.down_sample = down_sample

    def construct(self, x: Tensor) -> Tensor:
        """Return ``relu(branch(x) + shortcut(x))``."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # Project the input when a down-sampling cell was supplied; otherwise
        # the raw input is the shortcut.
        if self.down_sample is not None:
            identity = self.down_sample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Cell):
    """Bottleneck residual block with an optional `MambaBlock`.

    The main branch is 1x1 conv -> 3x3 conv -> 1x1 conv, each followed by a
    normalization layer. When `use_mamba` is true, the branch output is passed
    through a `MambaBlock` after the last normalization and before the
    shortcut is added.

    Args:
        in_channels: Number of channels of the input tensor.
        channels: Base channel count; the block outputs
            ``channels * expansion`` channels.
        stride: Stride of the 3x3 convolution. Default: 1.
        groups: Group count of the 3x3 convolution. Default: 1.
        base_width: Scales the inner width together with `groups`. Default: 64.
        norm: Normalization layer class. Default: None (``nn.BatchNorm2d``).
        down_sample: Optional cell producing the shortcut from the input.
            Default: None.
        use_mamba: Whether to append a `MambaBlock` to the main branch.
            Default: False.
    """
    expansion: int = 4

    def __init__(
        self,
        in_channels: int,
        channels: int,
        stride: int = 1,
        groups: int = 1,
        base_width: int = 64,
        norm: Optional[nn.Cell] = None,
        down_sample: Optional[nn.Cell] = None,
        use_mamba: bool = False,
    ) -> None:
        super().__init__()
        norm_layer = nn.BatchNorm2d if norm is None else norm

        out_channels = channels * self.expansion
        # Inner width of the 3x3 convolution, scaled by base_width and groups.
        width = int(channels * (base_width / 64.0)) * groups

        self.conv1 = nn.Conv2d(in_channels, width, kernel_size=1, stride=1)
        self.bn1 = norm_layer(width)
        self.conv2 = nn.Conv2d(
            width,
            width,
            kernel_size=3,
            stride=stride,
            padding=1,
            pad_mode="pad",
            group=groups,
        )
        self.bn2 = norm_layer(width)
        self.conv3 = nn.Conv2d(width, out_channels, kernel_size=1, stride=1)
        self.bn3 = norm_layer(out_channels)
        self.relu = nn.ReLU()
        self.down_sample = down_sample

        # The Mamba sub-block is only created when requested.
        self.use_mamba = use_mamba
        if use_mamba:
            self.mamba = MambaBlock(dim=out_channels)

    def construct(self, x: Tensor) -> Tensor:
        """Return ``relu(branch(x) + shortcut(x))``."""
        identity = x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # Mamba runs after the last normalization, before the shortcut is added.
        if self.use_mamba:
            out = self.mamba(out)

        if self.down_sample is not None:
            identity = self.down_sample(x)

        out += identity
        return self.relu(out)


class MambaBlock(nn.Cell):
    """Convolutional residual refinement block with a fixed branch scale.

    As written the block is purely convolutional: 1x1 projection ->
    5x5 convolution -> 1x1 projection, each followed by BatchNorm. The result
    is combined with the input as ``relu(x + scale * branch(x))``.

    Initialization notes:
        * `out_norm` starts with gamma = 0, so the branch initially
          contributes nothing and the block begins as ``relu(x)``.
        * `in_norm` uses the default gamma. A zero gamma there would make its
          output identically zero; the following ReLU passes no gradient at
          zero, so `in_proj` and `in_norm` would never be updated when the
          block is used on its own.
        * `ResNet._initialize_weights` in this file resets every BatchNorm2d
          gamma to 1, so inside `ResNet` the initial gammas set here are
          overwritten anyway.

    Args:
        dim: Number of input and output channels.
    """

    def __init__(self, dim: int) -> None:
        super().__init__()

        # Hidden width: half of the input channels, but never below 64.
        self.d_inner = max(dim // 2, 64)

        # Input projection.
        self.in_proj = nn.Conv2d(dim, self.d_inner, kernel_size=1,
                                 weight_init=init.HeNormal())
        self.in_norm = nn.BatchNorm2d(self.d_inner)

        # Spatial mixing with a 5x5 kernel; padding=2 keeps the spatial size.
        self.spatial = nn.Conv2d(self.d_inner, self.d_inner, kernel_size=5,
                                 stride=1, padding=2, pad_mode="pad")
        self.spatial_norm = nn.BatchNorm2d(self.d_inner)

        # Output projection back to `dim` channels; zero gamma makes the
        # branch start as a no-op.
        self.out_proj = nn.Conv2d(self.d_inner, dim, kernel_size=1)
        self.out_norm = nn.BatchNorm2d(dim, gamma_init='zeros')

        self.relu = nn.ReLU()

        # Fixed weight of the branch relative to the identity path.
        self.scale = 0.1

    def construct(self, x: Tensor) -> Tensor:
        """Return ``relu(x + scale * branch(x))``."""
        identity = x

        out = self.in_proj(x)
        out = self.in_norm(out)
        out = self.relu(out)

        out = self.spatial(out)
        out = self.spatial_norm(out)
        out = self.relu(out)

        out = self.out_proj(out)
        out = self.out_norm(out)

        # Fixed scaling coefficient on the branch.
        out = identity + self.scale * out
        out = self.relu(out)

        return out


class MambaBottleneck(nn.Cell):
    """Bottleneck residual block that includes a `MambaBlock` by default.

    The layer structure is the same as `Bottleneck`: 1x1 conv -> 3x3 conv ->
    1x1 conv with normalization, optionally followed by a `MambaBlock`, then
    the shortcut addition and a ReLU.

    Args:
        in_channels: Number of channels of the input tensor.
        channels: Base channel count; the block outputs
            ``channels * expansion`` channels.
        stride: Stride of the 3x3 convolution. Default: 1.
        groups: Group count of the 3x3 convolution. Default: 1.
        base_width: Scales the inner width together with `groups`. Default: 64.
        norm: Normalization layer class. Default: None (``nn.BatchNorm2d``).
        down_sample: Optional cell producing the shortcut from the input.
            Default: None.
        use_mamba: Whether to append a `MambaBlock` to the main branch.
            Default: True.
        layer_index: Accepted but not used anywhere in this class.
            Default: 0.
    """
    expansion: int = 4

    def __init__(
        self,
        in_channels: int,
        channels: int,
        stride: int = 1,
        groups: int = 1,
        base_width: int = 64,
        norm: Optional[nn.Cell] = None,
        down_sample: Optional[nn.Cell] = None,
        use_mamba: bool = True,
        layer_index: int = 0,
    ) -> None:
        super().__init__()
        norm_layer = nn.BatchNorm2d if norm is None else norm

        out_channels = channels * self.expansion
        # Inner width of the 3x3 convolution, scaled by base_width and groups.
        width = int(channels * (base_width / 64.0)) * groups

        self.conv1 = nn.Conv2d(in_channels, width, kernel_size=1, stride=1)
        self.bn1 = norm_layer(width)
        self.conv2 = nn.Conv2d(
            width,
            width,
            kernel_size=3,
            stride=stride,
            padding=1,
            pad_mode="pad",
            group=groups,
        )
        self.bn2 = norm_layer(width)
        self.conv3 = nn.Conv2d(width, out_channels, kernel_size=1, stride=1)
        self.bn3 = norm_layer(out_channels)
        self.relu = nn.ReLU()
        self.down_sample = down_sample

        # The Mamba sub-block is only created when requested.
        self.use_mamba = use_mamba
        if use_mamba:
            self.mamba = MambaBlock(dim=out_channels)

    def construct(self, x: Tensor) -> Tensor:
        """Return ``relu(branch(x) + shortcut(x))``."""
        identity = x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # Optional Mamba processing of the branch output.
        if self.use_mamba:
            out = self.mamba(out)

        if self.down_sample is not None:
            identity = self.down_sample(x)

        out += identity
        return self.relu(out)


class ResNet(nn.Cell):
    """ResNet backbone with optional `MambaBlock` augmentation.

    The network is a stem (conv -> norm -> ReLU -> pooling), four residual
    stages, global average pooling and a dense classifier. When `use_mamba`
    is true, Mamba is requested only for stages 3 and 4, and only for the
    blocks after the first one in each of those stages.

    Args:
        block: Residual block class. It is always called with a `use_mamba`
            keyword, so it must accept one.
        layers: Number of blocks in each of the four stages.
        num_classes: Output size of the classifier. Default: 1000.
        in_channels: Number of channels of the input image. Default: 3.
        groups: Group count forwarded to every block. Default: 1.
        base_width: Base width forwarded to every block. Default: 64.
        norm: Normalization layer class. Default: None (``nn.BatchNorm2d``).
        use_mamba: Whether stages 3 and 4 request Mamba-enabled blocks.
            Default: True.
        cifar_mode: If true, the stem uses a 3x3 stride-1 convolution and no
            max pooling; otherwise a 7x7 stride-2 convolution followed by a
            3x3 stride-2 max pooling. Default: False.
    """

    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        layers: List[int],
        num_classes: int = 1000,
        in_channels: int = 3,
        groups: int = 1,
        base_width: int = 64,
        norm: Optional[nn.Cell] = None,
        use_mamba: bool = True,
        cifar_mode: bool = False,
    ) -> None:
        super().__init__()
        if norm is None:
            norm = nn.BatchNorm2d

        # Key design choice: Mamba is only used in the two deepest stages.
        self.use_mamba_in_layer = [False, False, True, True] if use_mamba else [False, False, False, False]

        self.norm = norm
        self.groups = groups
        self.base_width = base_width
        # Running channel count; updated by `_make_layer` after each stage.
        self.in_channels = 64

        # CIFAR mode: smaller stem kernel, no stride, no max pooling.
        if cifar_mode:
            self.conv1 = nn.Conv2d(in_channels, self.in_channels, kernel_size=3, stride=1,
                                   padding=1, pad_mode="pad")
            self.max_pool = nn.Identity()  # no pooling
        else:
            self.conv1 = nn.Conv2d(in_channels, self.in_channels, kernel_size=7, stride=2,
                                   padding=3, pad_mode="pad")
            self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")

        self.bn1 = norm(self.in_channels)
        self.relu = nn.ReLU()

        # Build the four residual stages.
        self.layer1 = self._make_layer(block, 64, layers[0], use_mamba=self.use_mamba_in_layer[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, use_mamba=self.use_mamba_in_layer[1])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, use_mamba=self.use_mamba_in_layer[2])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, use_mamba=self.use_mamba_in_layer[3])

        self.pool = GlobalAvgPooling()
        self.num_features = 512 * block.expansion
        self.classifier = nn.Dense(self.num_features, num_classes)
        self._initialize_weights()

    def _initialize_weights(self) -> None:
        """Initialize the weights of all Dense, BatchNorm2d and Conv2d cells.

        Every BatchNorm2d gamma is reset to 1 here, including the gammas that
        `MambaBlock` creates with ``gamma_init='zeros'``.
        """
        for _, cell in self.cells_and_names():
            if isinstance(cell, nn.Dense):
                cell.weight.set_data(
                    init.initializer(init.TruncatedNormal(sigma=0.02), cell.weight.shape, cell.weight.dtype)
                )
                if cell.bias is not None:
                    cell.bias.set_data(init.initializer(init.Constant(0), cell.bias.shape, cell.bias.dtype))
            elif isinstance(cell, nn.BatchNorm2d):
                cell.gamma.set_data(init.initializer(init.Constant(1), cell.gamma.shape, cell.gamma.dtype))
                cell.beta.set_data(init.initializer(init.Constant(0), cell.beta.shape, cell.beta.dtype))
            elif isinstance(cell, nn.Conv2d):
                cell.weight.set_data(
                    init.initializer(init.HeNormal(), cell.weight.shape, cell.weight.dtype)
                )
                if cell.bias is not None:
                    cell.bias.set_data(init.initializer(init.Constant(0), cell.bias.shape, cell.bias.dtype))

    def _make_layer(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        channels: int,
        blocks_num: int,
        stride: int = 1,
        use_mamba: bool = False,
    ) -> nn.SequentialCell:
        """Build one residual stage of `blocks_num` blocks.

        Args:
            block: Residual block class.
            channels: Base channel count of the stage.
            blocks_num: Number of blocks in the stage.
            stride: Stride of the first block. Default: 1.
            use_mamba: Whether the blocks after the first one are created
                with ``use_mamba=True``. Default: False.

        Returns:
            The stage as an ``nn.SequentialCell``.
        """
        out_channels = channels * block.expansion

        # The shortcut needs a projection whenever the first block changes the
        # spatial size or the channel count.
        down_sample = None
        if stride != 1 or self.in_channels != out_channels:
            down_sample = nn.SequentialCell([
                nn.Conv2d(self.in_channels, out_channels, kernel_size=1, stride=stride),
                self.norm(out_channels)
            ])

        # The first block handles down-sampling and never uses Mamba.
        layers = [
            block(
                self.in_channels,
                channels,
                stride=stride,
                down_sample=down_sample,
                groups=self.groups,
                base_width=self.base_width,
                norm=self.norm,
                use_mamba=False
            )
        ]

        self.in_channels = out_channels

        # All remaining blocks share one configuration; only the Mamba flag
        # depends on the stage.
        mamba_flag = bool(use_mamba)
        for _ in range(1, blocks_num):
            layers.append(
                block(
                    self.in_channels,
                    channels,
                    groups=self.groups,
                    base_width=self.base_width,
                    norm=self.norm,
                    use_mamba=mamba_flag
                )
            )

        return nn.SequentialCell(layers)

    def forward_features(self, x: Tensor) -> Tensor:
        """Run the stem and the four residual stages."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.max_pool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x

    def forward_head(self, x: Tensor) -> Tensor:
        """Pool the feature map and apply the classifier."""
        x = self.pool(x)
        x = self.classifier(x)
        return x

    def construct(self, x: Tensor) -> Tensor:
        """Run feature extraction followed by the classification head."""
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x


def _create_resnet(pretrained=False, **kwargs):
    """Create a `ResNet` through `build_model_with_cfg`.

    Args:
        pretrained: Passed through to `build_model_with_cfg`. Default: False.
        **kwargs: Passed through to `build_model_with_cfg`.

    Returns:
        Whatever `build_model_with_cfg` returns.
    """
    model = build_model_with_cfg(ResNet, pretrained, **kwargs)
    return model


@register_model
def resnet18(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs):
    """Build an 18-layer ResNet (`BasicBlock`, stage depths [2, 2, 2, 2]).

    Extra keyword arguments are passed on to `_create_resnet`; see `ResNet`
    for the options it understands.
    """
    model_args = dict(
        block=BasicBlock,
        layers=[2, 2, 2, 2],
        num_classes=num_classes,
        in_channels=in_channels,
        **kwargs,
    )
    return _create_resnet(pretrained, default_cfg=default_cfgs["resnet18"], **model_args)


@register_model
def resnet34(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs):
    """Build a 34-layer ResNet (`BasicBlock`, stage depths [3, 4, 6, 3]).

    Extra keyword arguments are passed on to `_create_resnet`; see `ResNet`
    for the options it understands.
    """
    model_args = dict(
        block=BasicBlock,
        layers=[3, 4, 6, 3],
        num_classes=num_classes,
        in_channels=in_channels,
        **kwargs,
    )
    return _create_resnet(pretrained, default_cfg=default_cfgs["resnet34"], **model_args)


@register_model
def resnet50(
    pretrained: bool = False,
    num_classes: int = 1000,
    in_channels: int = 3,
    cifar_mode: bool = False,
    use_mamba: bool = True,
    **kwargs
) -> ResNet:
    """Build a 50-layer ResNet (`Bottleneck`, stage depths [3, 4, 6, 3]).

    Args:
        pretrained: Whether to download and load the pre-trained model.
            Default: False.
        num_classes: Number of output classes. Default: 1000.
        in_channels: Number of input channels. Default: 3.
        cifar_mode: Whether to use the CIFAR-oriented stem. Default: False.
        use_mamba: Whether to use Mamba blocks. Default: True.
        **kwargs: Extra arguments passed on to `_create_resnet`.

    Returns:
        ResNet network.
    """
    model_args = dict(
        block=Bottleneck,
        layers=[3, 4, 6, 3],
        num_classes=num_classes,
        in_channels=in_channels,
        cifar_mode=cifar_mode,
        use_mamba=use_mamba,
        **kwargs
    )
    return _create_resnet(pretrained, default_cfg=default_cfgs["resnet50"], **model_args)


@register_model
def resnet101(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs):
    """Build a 101-layer ResNet (`Bottleneck`, stage depths [3, 4, 23, 3]).

    Extra keyword arguments are passed on to `_create_resnet`; see `ResNet`
    for the options it understands.
    """
    model_args = dict(
        block=Bottleneck,
        layers=[3, 4, 23, 3],
        num_classes=num_classes,
        in_channels=in_channels,
        **kwargs,
    )
    return _create_resnet(pretrained, default_cfg=default_cfgs["resnet101"], **model_args)


@register_model
def resnet152(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs):
    """Build a 152-layer ResNet (`Bottleneck`, stage depths [3, 8, 36, 3]).

    Extra keyword arguments are passed on to `_create_resnet`; see `ResNet`
    for the options it understands.
    """
    model_args = dict(
        block=Bottleneck,
        layers=[3, 8, 36, 3],
        num_classes=num_classes,
        in_channels=in_channels,
        **kwargs,
    )
    return _create_resnet(pretrained, default_cfg=default_cfgs["resnet152"], **model_args)


@register_model
def resnext50_32x4d(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs):
    """Build a 50-layer ResNeXt with 32 groups and base width 4.

    Uses `Bottleneck` with stage depths [3, 4, 6, 3]. Extra keyword arguments
    are passed on to `_create_resnet`; see `ResNet` for details.
    """
    model_args = dict(
        block=Bottleneck,
        layers=[3, 4, 6, 3],
        groups=32,
        base_width=4,
        num_classes=num_classes,
        in_channels=in_channels,
        **kwargs,
    )
    return _create_resnet(pretrained, default_cfg=default_cfgs["resnext50_32x4d"], **model_args)


@register_model
def resnext101_32x4d(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs):
    """Build a 101-layer ResNeXt with 32 groups and base width 4.

    Uses `Bottleneck` with stage depths [3, 4, 23, 3]. Extra keyword arguments
    are passed on to `_create_resnet`; see `ResNet` for details.
    """
    model_args = dict(
        block=Bottleneck,
        layers=[3, 4, 23, 3],
        groups=32,
        base_width=4,
        num_classes=num_classes,
        in_channels=in_channels,
        **kwargs,
    )
    return _create_resnet(pretrained, default_cfg=default_cfgs["resnext101_32x4d"], **model_args)


@register_model
def resnext101_64x4d(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs):
    """Build a 101-layer ResNeXt with 64 groups and base width 4.

    Uses `Bottleneck` with stage depths [3, 4, 23, 3]. Extra keyword arguments
    are passed on to `_create_resnet`; see `ResNet` for details.
    """
    model_args = dict(
        block=Bottleneck,
        layers=[3, 4, 23, 3],
        groups=64,
        base_width=4,
        num_classes=num_classes,
        in_channels=in_channels,
        **kwargs,
    )
    return _create_resnet(pretrained, default_cfg=default_cfgs["resnext101_64x4d"], **model_args)


@register_model
def resnext152_64x4d(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs):
    """Build a 152-layer ResNeXt with 64 groups and base width 4.

    Uses `Bottleneck` with stage depths [3, 8, 36, 3]. Extra keyword arguments
    are passed on to `_create_resnet`; see `ResNet` for details.
    """
    model_args = dict(
        block=Bottleneck,
        layers=[3, 8, 36, 3],
        groups=64,
        base_width=4,
        num_classes=num_classes,
        in_channels=in_channels,
        **kwargs,
    )
    return _create_resnet(pretrained, default_cfg=default_cfgs["resnext152_64x4d"], **model_args)


In [None]:
# Dynamic graph variant: mode == 1
"""
MindSpore implementation of `ResNet` with Mamba integration.
Refer to Deep Residual Learning for Image Recognition.
"""

from typing import List, Optional, Type, Union

import mindspore.common.initializer as init
from mindspore import Tensor, nn, ops

from .helpers import build_model_with_cfg
from .layers.pooling import GlobalAvgPooling
from .registry import register_model

# Names exported by `from <module> import *`: the ResNet class and the
# model factory functions defined below.
__all__ = [
    "ResNet",
    "resnet18",
    "resnet34",
    "resnet50",
    "resnet101",
    "resnet152",
    "resnext50_32x4d",
    "resnext101_32x4d",
    "resnext101_64x4d",
    "resnext152_64x4d",
]


def _cfg(url="", **kwargs):
    return {
        "url": url,
        "num_classes": 1000,
        "first_conv": "conv1",
        "classifier": "classifier",
        **kwargs,
    }


# Default configuration for each model factory, keyed by factory name.
# Every entry is produced by `_cfg`, so it carries the checkpoint URL given
# here plus the shared defaults (num_classes, first_conv, classifier).
default_cfgs = {
    "resnet18": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/resnet/resnet18-1e65cd21.ckpt"),
    "resnet34": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/resnet/resnet34-f297d27e.ckpt"),
    "resnet50": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/resnet/resnet50-e0733ab8.ckpt"),
    "resnet101": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/resnet/resnet101-689c5e77.ckpt"),
    "resnet152": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/resnet/resnet152-beb689d8.ckpt"),
    "resnext50_32x4d": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/resnext/resnext50_32x4d-af8aba16.ckpt"),
    "resnext101_32x4d": _cfg(
        url="https://download.mindspore.cn/toolkits/mindcv/resnext/resnext101_32x4d-3c1e9c51.ckpt"
    ),
    "resnext101_64x4d": _cfg(
        url="https://download.mindspore.cn/toolkits/mindcv/resnext/resnext101_64x4d-8929255b.ckpt"
    ),
    "resnext152_64x4d": _cfg(
        url="https://download.mindspore.cn/toolkits/mindcv/resnext/resnext152_64x4d-3aba275c.ckpt"
    ),
}


class BasicBlock(nn.Cell):
    """Basic residual block: two 3x3 convolutions plus a shortcut connection.

    Args:
        in_channels: Number of channels of the input tensor.
        channels: Number of channels produced by both convolutions.
        stride: Stride of the first convolution. Default: 1.
        groups: Must be 1; present for signature parity with `Bottleneck`.
        base_width: Must be 64; present for signature parity with `Bottleneck`.
        norm: Normalization layer class, called as ``norm(channels)``.
            Default: None, which selects ``nn.BatchNorm2d``.
        down_sample: Optional cell applied to the input to produce the
            shortcut branch. Default: None (the input itself is the shortcut).
        use_mamba: Accepted and ignored; this block never contains a
            `MambaBlock`. `ResNet._make_layer` passes this keyword to every
            block class, so without it building a ResNet from `BasicBlock`
            (resnet18 / resnet34) fails with ``TypeError``. Default: False.
    """
    expansion: int = 1

    def __init__(
        self,
        in_channels: int,
        channels: int,
        stride: int = 1,
        groups: int = 1,
        base_width: int = 64,
        norm: Optional[nn.Cell] = None,
        down_sample: Optional[nn.Cell] = None,
        use_mamba: bool = False,
    ) -> None:
        super().__init__()
        if norm is None:
            norm = nn.BatchNorm2d
        assert groups == 1, "BasicBlock only supports groups=1"
        assert base_width == 64, "BasicBlock only supports base_width=64"

        self.conv1 = nn.Conv2d(in_channels, channels, kernel_size=3,
                               stride=stride, padding=1, pad_mode="pad")
        self.bn1 = norm(channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3,
                               stride=1, padding=1, pad_mode="pad")
        self.bn2 = norm(channels)
        self.down_sample = down_sample

    def construct(self, x: Tensor) -> Tensor:
        """Return ``relu(branch(x) + shortcut(x))``."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # Project the input when a down-sampling cell was supplied; otherwise
        # the raw input is the shortcut.
        if self.down_sample is not None:
            identity = self.down_sample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Cell):
    """Bottleneck residual block with an optional `MambaBlock`.

    The main branch is 1x1 conv -> 3x3 conv -> 1x1 conv, each followed by a
    normalization layer. When `use_mamba` is true, the branch output is passed
    through a `MambaBlock` after the last normalization and before the
    shortcut is added.

    Args:
        in_channels: Number of channels of the input tensor.
        channels: Base channel count; the block outputs
            ``channels * expansion`` channels.
        stride: Stride of the 3x3 convolution. Default: 1.
        groups: Group count of the 3x3 convolution. Default: 1.
        base_width: Scales the inner width together with `groups`. Default: 64.
        norm: Normalization layer class. Default: None (``nn.BatchNorm2d``).
        down_sample: Optional cell producing the shortcut from the input.
            Default: None.
        use_mamba: Whether to append a `MambaBlock` to the main branch.
            Default: False.
    """
    expansion: int = 4

    def __init__(
        self,
        in_channels: int,
        channels: int,
        stride: int = 1,
        groups: int = 1,
        base_width: int = 64,
        norm: Optional[nn.Cell] = None,
        down_sample: Optional[nn.Cell] = None,
        use_mamba: bool = False,
    ) -> None:
        super().__init__()
        norm_layer = nn.BatchNorm2d if norm is None else norm

        out_channels = channels * self.expansion
        # Inner width of the 3x3 convolution, scaled by base_width and groups.
        width = int(channels * (base_width / 64.0)) * groups

        self.conv1 = nn.Conv2d(in_channels, width, kernel_size=1, stride=1)
        self.bn1 = norm_layer(width)
        self.conv2 = nn.Conv2d(
            width,
            width,
            kernel_size=3,
            stride=stride,
            padding=1,
            pad_mode="pad",
            group=groups,
        )
        self.bn2 = norm_layer(width)
        self.conv3 = nn.Conv2d(width, out_channels, kernel_size=1, stride=1)
        self.bn3 = norm_layer(out_channels)
        self.relu = nn.ReLU()
        self.down_sample = down_sample

        # The Mamba sub-block is only created when requested.
        self.use_mamba = use_mamba
        if use_mamba:
            self.mamba = MambaBlock(dim=out_channels)

    def construct(self, x: Tensor) -> Tensor:
        """Return ``relu(branch(x) + shortcut(x))``."""
        identity = x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # Mamba runs after the last normalization, before the shortcut is added.
        if self.use_mamba:
            out = self.mamba(out)

        if self.down_sample is not None:
            identity = self.down_sample(x)

        out += identity
        return self.relu(out)


class MambaBlock(nn.Cell):
    """Convolutional residual refinement block with a warm-up branch scale.

    As written the block is purely convolutional: 1x1 projection ->
    5x5 convolution -> 1x1 projection, each followed by BatchNorm. The result
    is combined with the input as ``relu(x + current_scale * branch(x))``.

    Branch scale:
        * While training, `epoch_counter` is incremented on every forward
          call (it counts calls, not epochs) and capped at `warmup_epochs`.
          The scale is ``base_scale * epoch_counter / warmup_epochs``, so it
          reaches `base_scale` after `warmup_epochs` training calls.
        * In evaluation mode the full `base_scale` is always used. The
          counter is a plain Python attribute that starts at 0 for every new
          instance, so deriving the eval scale from it would disable the
          branch whenever a model is evaluated without having been trained in
          the same process.
        * NOTE(review): the counter is updated inside `construct`; this
          presumably relies on the dynamic (PyNative) execution mode named in
          the cell header - confirm before using graph mode.

    Initialization notes:
        * `out_norm` starts with gamma = 0, so the branch initially
          contributes nothing.
        * `in_norm` uses the default gamma. A zero gamma there would make its
          output identically zero; the following ReLU passes no gradient at
          zero, so `in_proj` and `in_norm` would never be updated when the
          block is used on its own.
        * `ResNet._initialize_weights` in this file resets every BatchNorm2d
          gamma to 1, so inside `ResNet` the initial gammas set here are
          overwritten anyway.

    Args:
        dim: Number of input and output channels.
    """

    def __init__(self, dim: int) -> None:
        super().__init__()

        # Hidden width: half of the input channels, but never below 64.
        self.d_inner = max(dim // 2, 64)

        # Input projection.
        self.in_proj = nn.Conv2d(dim, self.d_inner, kernel_size=1,
                                 weight_init=init.HeNormal())
        self.in_norm = nn.BatchNorm2d(self.d_inner)

        # Spatial mixing with a 5x5 kernel; padding=2 keeps the spatial size.
        self.spatial = nn.Conv2d(self.d_inner, self.d_inner, kernel_size=5,
                                 stride=1, padding=2, pad_mode="pad")
        self.spatial_norm = nn.BatchNorm2d(self.d_inner)

        # Output projection back to `dim` channels; zero gamma makes the
        # branch start as a no-op.
        self.out_proj = nn.Conv2d(self.d_inner, dim, kernel_size=1)
        self.out_norm = nn.BatchNorm2d(dim, gamma_init='zeros')

        self.relu = nn.ReLU()

        # Warm-up state: number of training forward calls seen so far
        # (capped), the cap, and the final branch weight.
        self.epoch_counter = 0
        self.warmup_epochs = 5
        self.base_scale = 0.3

    def construct(self, x: Tensor) -> Tensor:
        """Return ``relu(x + current_scale * branch(x))``."""
        identity = x

        if self.training:
            # Advance the warm-up and ramp the branch weight linearly.
            self.epoch_counter = min(self.epoch_counter + 1, self.warmup_epochs)
            current_scale = self.base_scale * (self.epoch_counter / self.warmup_epochs)
        else:
            # Inference always uses the fully warmed-up weight.
            current_scale = self.base_scale

        out = self.in_proj(x)
        out = self.in_norm(out)
        out = self.relu(out)

        # Spatial processing.
        out = self.spatial(out)
        out = self.spatial_norm(out)
        out = self.relu(out)

        out = self.out_proj(out)
        out = self.out_norm(out)

        # Scaled residual connection.
        out = identity + current_scale * out
        out = self.relu(out)

        return out


class MambaBottleneck(nn.Cell):
    """Bottleneck residual block that includes a `MambaBlock` by default.

    The layer structure is the same as `Bottleneck`: 1x1 conv -> 3x3 conv ->
    1x1 conv with normalization, optionally followed by a `MambaBlock`, then
    the shortcut addition and a ReLU.

    Args:
        in_channels: Number of channels of the input tensor.
        channels: Base channel count; the block outputs
            ``channels * expansion`` channels.
        stride: Stride of the 3x3 convolution. Default: 1.
        groups: Group count of the 3x3 convolution. Default: 1.
        base_width: Scales the inner width together with `groups`. Default: 64.
        norm: Normalization layer class. Default: None (``nn.BatchNorm2d``).
        down_sample: Optional cell producing the shortcut from the input.
            Default: None.
        use_mamba: Whether to append a `MambaBlock` to the main branch.
            Default: True.
        layer_index: Accepted but not used anywhere in this class.
            Default: 0.
    """
    expansion: int = 4

    def __init__(
        self,
        in_channels: int,
        channels: int,
        stride: int = 1,
        groups: int = 1,
        base_width: int = 64,
        norm: Optional[nn.Cell] = None,
        down_sample: Optional[nn.Cell] = None,
        use_mamba: bool = True,
        layer_index: int = 0,
    ) -> None:
        super().__init__()
        norm_layer = nn.BatchNorm2d if norm is None else norm

        out_channels = channels * self.expansion
        # Inner width of the 3x3 convolution, scaled by base_width and groups.
        width = int(channels * (base_width / 64.0)) * groups

        self.conv1 = nn.Conv2d(in_channels, width, kernel_size=1, stride=1)
        self.bn1 = norm_layer(width)
        self.conv2 = nn.Conv2d(
            width,
            width,
            kernel_size=3,
            stride=stride,
            padding=1,
            pad_mode="pad",
            group=groups,
        )
        self.bn2 = norm_layer(width)
        self.conv3 = nn.Conv2d(width, out_channels, kernel_size=1, stride=1)
        self.bn3 = norm_layer(out_channels)
        self.relu = nn.ReLU()
        self.down_sample = down_sample

        # The Mamba sub-block is only created when requested.
        self.use_mamba = use_mamba
        if use_mamba:
            self.mamba = MambaBlock(dim=out_channels)

    def construct(self, x: Tensor) -> Tensor:
        """Return ``relu(branch(x) + shortcut(x))``."""
        identity = x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # Optional Mamba processing of the branch output.
        if self.use_mamba:
            out = self.mamba(out)

        if self.down_sample is not None:
            identity = self.down_sample(x)

        out += identity
        return self.relu(out)


class ResNet(nn.Cell):
    """ResNet backbone with optional `MambaBlock` augmentation.

    The network is a stem (conv -> norm -> ReLU -> pooling), four residual
    stages, global average pooling and a dense classifier. When `use_mamba`
    is true, Mamba is requested only for stages 3 and 4, and only for the
    blocks after the first one in each of those stages.

    Args:
        block: Residual block class. It is always called with a `use_mamba`
            keyword, so it must accept one.
        layers: Number of blocks in each of the four stages.
        num_classes: Output size of the classifier. Default: 1000.
        in_channels: Number of channels of the input image. Default: 3.
        groups: Group count forwarded to every block. Default: 1.
        base_width: Base width forwarded to every block. Default: 64.
        norm: Normalization layer class. Default: None (``nn.BatchNorm2d``).
        use_mamba: Whether stages 3 and 4 request Mamba-enabled blocks.
            Default: True.
        cifar_mode: If true, the stem uses a 3x3 stride-1 convolution and no
            max pooling; otherwise a 7x7 stride-2 convolution followed by a
            3x3 stride-2 max pooling. Default: False.
    """

    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        layers: List[int],
        num_classes: int = 1000,
        in_channels: int = 3,
        groups: int = 1,
        base_width: int = 64,
        norm: Optional[nn.Cell] = None,
        use_mamba: bool = True,
        cifar_mode: bool = False,
    ) -> None:
        super().__init__()
        if norm is None:
            norm = nn.BatchNorm2d

        # Key design choice: Mamba is only used in the two deepest stages.
        self.use_mamba_in_layer = [False, False, True, True] if use_mamba else [False, False, False, False]

        self.norm = norm
        self.groups = groups
        self.base_width = base_width
        # Running channel count; updated by `_make_layer` after each stage.
        self.in_channels = 64

        # CIFAR mode: smaller stem kernel, no stride, no max pooling.
        if cifar_mode:
            self.conv1 = nn.Conv2d(in_channels, self.in_channels, kernel_size=3, stride=1,
                                   padding=1, pad_mode="pad")
            self.max_pool = nn.Identity()  # no pooling
        else:
            self.conv1 = nn.Conv2d(in_channels, self.in_channels, kernel_size=7, stride=2,
                                   padding=3, pad_mode="pad")
            self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")

        self.bn1 = norm(self.in_channels)
        self.relu = nn.ReLU()

        # Build the four residual stages.
        self.layer1 = self._make_layer(block, 64, layers[0], use_mamba=self.use_mamba_in_layer[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, use_mamba=self.use_mamba_in_layer[1])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, use_mamba=self.use_mamba_in_layer[2])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, use_mamba=self.use_mamba_in_layer[3])

        self.pool = GlobalAvgPooling()
        self.num_features = 512 * block.expansion
        self.classifier = nn.Dense(self.num_features, num_classes)
        self._initialize_weights()

    def _initialize_weights(self) -> None:
        """Initialize the weights of all Dense, BatchNorm2d and Conv2d cells.

        Every BatchNorm2d gamma is reset to 1 here, including the gammas that
        `MambaBlock` creates with ``gamma_init='zeros'``.
        """
        for _, cell in self.cells_and_names():
            if isinstance(cell, nn.Dense):
                cell.weight.set_data(
                    init.initializer(init.TruncatedNormal(sigma=0.02), cell.weight.shape, cell.weight.dtype)
                )
                if cell.bias is not None:
                    cell.bias.set_data(init.initializer(init.Constant(0), cell.bias.shape, cell.bias.dtype))
            elif isinstance(cell, nn.BatchNorm2d):
                cell.gamma.set_data(init.initializer(init.Constant(1), cell.gamma.shape, cell.gamma.dtype))
                cell.beta.set_data(init.initializer(init.Constant(0), cell.beta.shape, cell.beta.dtype))
            elif isinstance(cell, nn.Conv2d):
                cell.weight.set_data(
                    init.initializer(init.HeNormal(), cell.weight.shape, cell.weight.dtype)
                )
                if cell.bias is not None:
                    cell.bias.set_data(init.initializer(init.Constant(0), cell.bias.shape, cell.bias.dtype))

    def _make_layer(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        channels: int,
        blocks_num: int,
        stride: int = 1,
        use_mamba: bool = False,
    ) -> nn.SequentialCell:
        """Build one residual stage of `blocks_num` blocks.

        Args:
            block: Residual block class.
            channels: Base channel count of the stage.
            blocks_num: Number of blocks in the stage.
            stride: Stride of the first block. Default: 1.
            use_mamba: Whether the blocks after the first one are created
                with ``use_mamba=True``. Default: False.

        Returns:
            The stage as an ``nn.SequentialCell``.
        """
        out_channels = channels * block.expansion

        # The shortcut needs a projection whenever the first block changes the
        # spatial size or the channel count.
        down_sample = None
        if stride != 1 or self.in_channels != out_channels:
            down_sample = nn.SequentialCell([
                nn.Conv2d(self.in_channels, out_channels, kernel_size=1, stride=stride),
                self.norm(out_channels)
            ])

        # The first block handles down-sampling and never uses Mamba.
        layers = [
            block(
                self.in_channels,
                channels,
                stride=stride,
                down_sample=down_sample,
                groups=self.groups,
                base_width=self.base_width,
                norm=self.norm,
                use_mamba=False
            )
        ]

        self.in_channels = out_channels

        # All remaining blocks share one configuration; only the Mamba flag
        # depends on the stage.
        mamba_flag = bool(use_mamba)
        for _ in range(1, blocks_num):
            layers.append(
                block(
                    self.in_channels,
                    channels,
                    groups=self.groups,
                    base_width=self.base_width,
                    norm=self.norm,
                    use_mamba=mamba_flag
                )
            )

        return nn.SequentialCell(layers)

    def forward_features(self, x: Tensor) -> Tensor:
        """Run the stem and the four residual stages."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.max_pool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x

    def forward_head(self, x: Tensor) -> Tensor:
        """Pool the feature map and apply the classifier."""
        x = self.pool(x)
        x = self.classifier(x)
        return x

    def construct(self, x: Tensor) -> Tensor:
        """Run feature extraction followed by the classification head."""
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x


def _create_resnet(pretrained=False, **kwargs):
    """Build a ``ResNet`` by delegating to ``build_model_with_cfg``.

    ``pretrained`` and every keyword argument are passed through unchanged.
    """
    model = build_model_with_cfg(ResNet, pretrained, **kwargs)
    return model


@register_model
def resnet18(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs):
    """Create a ResNet-18: ``BasicBlock`` with stage depths [2, 2, 2, 2].

    ``pretrained`` goes to ``_create_resnet``; the remaining arguments (plus
    any extra ``kwargs``) are forwarded as model arguments. See ``ResNet``.
    """
    resnet_kwargs = dict(
        block=BasicBlock,
        layers=[2, 2, 2, 2],
        num_classes=num_classes,
        in_channels=in_channels,
        **kwargs
    )
    return _create_resnet(pretrained, default_cfg=default_cfgs["resnet18"], **resnet_kwargs)


@register_model
def resnet34(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs):
    """Create a ResNet-34: ``BasicBlock`` with stage depths [3, 4, 6, 3].

    ``pretrained`` goes to ``_create_resnet``; the remaining arguments (plus
    any extra ``kwargs``) are forwarded as model arguments. See ``ResNet``.
    """
    resnet_kwargs = dict(
        block=BasicBlock,
        layers=[3, 4, 6, 3],
        num_classes=num_classes,
        in_channels=in_channels,
        **kwargs
    )
    return _create_resnet(pretrained, default_cfg=default_cfgs["resnet34"], **resnet_kwargs)


@register_model
def resnet50(
    pretrained: bool = False,
    num_classes: int = 1000,
    in_channels: int = 3,
    cifar_mode: bool = False,
    use_mamba: bool = True,
    **kwargs
) -> ResNet:
    """Create a ResNet-50: ``Bottleneck`` with stage depths [3, 4, 6, 3].

    Args:
        pretrained: Passed to ``_create_resnet`` (whether to download and load
            the pre-trained model). Default: False.
        num_classes: Number of output classes. Default: 1000.
        in_channels: Number of input channels. Default: 3.
        cifar_mode: Forwarded as the ``cifar_mode`` model argument (selects the
            CIFAR-oriented architecture). Default: False.
        use_mamba: Forwarded as the ``use_mamba`` model argument (switch for the
            Mamba blocks). Default: True.
        **kwargs: Extra model arguments forwarded unchanged.

    Returns:
        The value returned by ``_create_resnet``.
    """
    resnet_kwargs = dict(
        block=Bottleneck,
        layers=[3, 4, 6, 3],
        num_classes=num_classes,
        in_channels=in_channels,
        cifar_mode=cifar_mode,
        use_mamba=use_mamba,
        **kwargs
    )
    return _create_resnet(pretrained, default_cfg=default_cfgs["resnet50"], **resnet_kwargs)


@register_model
def resnet101(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs):
    """Create a ResNet-101: ``Bottleneck`` with stage depths [3, 4, 23, 3].

    ``pretrained`` goes to ``_create_resnet``; the remaining arguments (plus
    any extra ``kwargs``) are forwarded as model arguments. See ``ResNet``.
    """
    resnet_kwargs = dict(
        block=Bottleneck,
        layers=[3, 4, 23, 3],
        num_classes=num_classes,
        in_channels=in_channels,
        **kwargs
    )
    return _create_resnet(pretrained, default_cfg=default_cfgs["resnet101"], **resnet_kwargs)


@register_model
def resnet152(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs):
    """Create a ResNet-152: ``Bottleneck`` with stage depths [3, 8, 36, 3].

    ``pretrained`` goes to ``_create_resnet``; the remaining arguments (plus
    any extra ``kwargs``) are forwarded as model arguments. See ``ResNet``.
    """
    resnet_kwargs = dict(
        block=Bottleneck,
        layers=[3, 8, 36, 3],
        num_classes=num_classes,
        in_channels=in_channels,
        **kwargs
    )
    return _create_resnet(pretrained, default_cfg=default_cfgs["resnet152"], **resnet_kwargs)


@register_model
def resnext50_32x4d(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs):
    """Create a ResNeXt-50 (32x4d): ``Bottleneck``, depths [3, 4, 6, 3], groups=32, base_width=4.

    ``pretrained`` goes to ``_create_resnet``; the remaining arguments (plus
    any extra ``kwargs``) are forwarded as model arguments. See ``ResNet``.
    """
    resnet_kwargs = dict(
        block=Bottleneck,
        layers=[3, 4, 6, 3],
        groups=32,
        base_width=4,
        num_classes=num_classes,
        in_channels=in_channels,
        **kwargs
    )
    return _create_resnet(pretrained, default_cfg=default_cfgs["resnext50_32x4d"], **resnet_kwargs)


@register_model
def resnext101_32x4d(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs):
    """Create a ResNeXt-101 (32x4d): ``Bottleneck``, depths [3, 4, 23, 3], groups=32, base_width=4.

    ``pretrained`` goes to ``_create_resnet``; the remaining arguments (plus
    any extra ``kwargs``) are forwarded as model arguments. See ``ResNet``.
    """
    resnet_kwargs = dict(
        block=Bottleneck,
        layers=[3, 4, 23, 3],
        groups=32,
        base_width=4,
        num_classes=num_classes,
        in_channels=in_channels,
        **kwargs
    )
    return _create_resnet(pretrained, default_cfg=default_cfgs["resnext101_32x4d"], **resnet_kwargs)


@register_model
def resnext101_64x4d(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs):
    """Create a ResNeXt-101 (64x4d): ``Bottleneck``, depths [3, 4, 23, 3], groups=64, base_width=4.

    ``pretrained`` goes to ``_create_resnet``; the remaining arguments (plus
    any extra ``kwargs``) are forwarded as model arguments. See ``ResNet``.
    """
    resnet_kwargs = dict(
        block=Bottleneck,
        layers=[3, 4, 23, 3],
        groups=64,
        base_width=4,
        num_classes=num_classes,
        in_channels=in_channels,
        **kwargs
    )
    return _create_resnet(pretrained, default_cfg=default_cfgs["resnext101_64x4d"], **resnet_kwargs)


@register_model
def resnext152_64x4d(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs):
    """Create a ResNeXt-152 (64x4d): ``Bottleneck``, depths [3, 8, 36, 3], groups=64, base_width=4.

    ``pretrained`` goes to ``_create_resnet``; the remaining arguments (plus
    any extra ``kwargs``) are forwarded as model arguments. See ``ResNet``.
    """
    resnet_kwargs = dict(
        block=Bottleneck,
        layers=[3, 8, 36, 3],
        groups=64,
        base_width=4,
        num_classes=num_classes,
        in_channels=in_channels,
        **kwargs
    )
    return _create_resnet(pretrained, default_cfg=default_cfgs["resnext152_64x4d"], **resnet_kwargs)


In [None]:
# Revision: reduce damage (lessen the Mamba blocks' interference with the baseline network)
"""
MindSpore implementation of `ResNet` with Mamba integration.
Refer to Deep Residual Learning for Image Recognition.
"""

from typing import List, Optional, Type, Union

import mindspore.common.initializer as init
from mindspore import Tensor, nn, ops

from .helpers import build_model_with_cfg
from .layers.pooling import GlobalAvgPooling
from .registry import register_model

# Public API of this module: the network class plus the factory functions
# that are decorated with ``@register_model`` below.
__all__ = [
    "ResNet",
    "resnet18",
    "resnet34",
    "resnet50",
    "resnet101",
    "resnet152",
    "resnext50_32x4d",
    "resnext101_32x4d",
    "resnext101_64x4d",
    "resnext152_64x4d",
]


def _cfg(url="", **kwargs):
    return {
        "url": url,
        "num_classes": 1000,
        "first_conv": "conv1",
        "classifier": "classifier",
        **kwargs,
    }


# Default config per model variant, keyed by the factory function name.
# Each entry is built by ``_cfg`` and differs only in its checkpoint ``url``.
default_cfgs = {
    "resnet18": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/resnet/resnet18-1e65cd21.ckpt"),
    "resnet34": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/resnet/resnet34-f297d27e.ckpt"),
    "resnet50": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/resnet/resnet50-e0733ab8.ckpt"),
    "resnet101": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/resnet/resnet101-689c5e77.ckpt"),
    "resnet152": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/resnet/resnet152-beb689d8.ckpt"),
    "resnext50_32x4d": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/resnext/resnext50_32x4d-af8aba16.ckpt"),
    "resnext101_32x4d": _cfg(
        url="https://download.mindspore.cn/toolkits/mindcv/resnext/resnext101_32x4d-3c1e9c51.ckpt"
    ),
    "resnext101_64x4d": _cfg(
        url="https://download.mindspore.cn/toolkits/mindcv/resnext/resnext101_64x4d-8929255b.ckpt"
    ),
    "resnext152_64x4d": _cfg(
        url="https://download.mindspore.cn/toolkits/mindcv/resnext/resnext152_64x4d-3aba275c.ckpt"
    ),
}


class BasicBlock(nn.Cell):
    """Two-convolution residual block (3x3 -> 3x3) with ``expansion`` of 1.

    Args:
        in_channels: Channels of the incoming feature map.
        channels: Output channels of both convolutions.
        stride: Stride of the first convolution. Default: 1.
        groups: Must be 1; any other value fails an assertion.
        base_width: Must be 64; any other value fails an assertion.
        norm: Callable producing the normalisation layer, invoked as
            ``norm(channels)``. ``nn.BatchNorm2d`` is used when ``None``.
        down_sample: Optional cell applied to the block input to produce the
            shortcut branch; the raw input is used when ``None``.
    """
    expansion: int = 1

    def __init__(
        self,
        in_channels: int,
        channels: int,
        stride: int = 1,
        groups: int = 1,
        base_width: int = 64,
        norm: Optional[nn.Cell] = None,
        down_sample: Optional[nn.Cell] = None,
    ) -> None:
        super().__init__()
        norm_layer = nn.BatchNorm2d if norm is None else norm
        assert groups == 1, "BasicBlock only supports groups=1"
        assert base_width == 64, "BasicBlock only supports base_width=64"

        # Only the first convolution carries the (possibly >1) stride.
        self.conv1 = nn.Conv2d(
            in_channels, channels, kernel_size=3, stride=stride, padding=1, pad_mode="pad"
        )
        self.bn1 = norm_layer(channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(
            channels, channels, kernel_size=3, stride=1, padding=1, pad_mode="pad"
        )
        self.bn2 = norm_layer(channels)
        self.down_sample = down_sample

    def construct(self, x: Tensor) -> Tensor:
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        # Main branch: conv1 -> bn1 -> relu -> conv2 -> bn2.
        hidden = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(hidden))

        # Shortcut branch: the input itself unless a down-sample cell was given.
        shortcut = x
        if self.down_sample is not None:
            shortcut = self.down_sample(x)

        return self.relu(residual + shortcut)


class Bottleneck(nn.Cell):
    """Bottleneck residual block (1x1 -> 3x3 -> 1x1) with an optional ``MambaBlock``.

    Args:
        in_channels: Channels of the incoming feature map.
        channels: Base channel count; the block outputs ``channels * expansion``.
        stride: Stride of the 3x3 convolution. Default: 1.
        groups: Group count of the 3x3 convolution. Default: 1.
        base_width: Width per group; the inner width is
            ``int(channels * (base_width / 64.0)) * groups``. Default: 64.
        norm: Callable producing a normalisation layer from a channel count.
            ``nn.BatchNorm2d`` is used when ``None``.
        down_sample: Optional cell applied to the block input to produce the
            shortcut branch; the raw input is used when ``None``.
        use_mamba: When True, a ``MambaBlock`` is created and applied to the
            output of ``bn3`` before the shortcut is added. Default: False.
    """
    expansion: int = 4

    def __init__(
        self,
        in_channels: int,
        channels: int,
        stride: int = 1,
        groups: int = 1,
        base_width: int = 64,
        norm: Optional[nn.Cell] = None,
        down_sample: Optional[nn.Cell] = None,
        use_mamba: bool = False,
    ) -> None:
        super().__init__()
        norm_layer = nn.BatchNorm2d if norm is None else norm

        inner_width = int(channels * (base_width / 64.0)) * groups
        out_channels = channels * self.expansion

        self.conv1 = nn.Conv2d(in_channels, inner_width, kernel_size=1, stride=1)
        self.bn1 = norm_layer(inner_width)
        # The 3x3 convolution is the only one that is strided and grouped.
        self.conv2 = nn.Conv2d(
            inner_width, inner_width, kernel_size=3, stride=stride,
            padding=1, pad_mode="pad", group=groups,
        )
        self.bn2 = norm_layer(inner_width)
        self.conv3 = nn.Conv2d(inner_width, out_channels, kernel_size=1, stride=1)
        self.bn3 = norm_layer(out_channels)
        self.relu = nn.ReLU()
        self.down_sample = down_sample

        # The Mamba sub-block is only instantiated when requested.
        self.use_mamba = use_mamba
        if use_mamba:
            self.mamba = MambaBlock(dim=out_channels)

    def construct(self, x: Tensor) -> Tensor:
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        hidden = self.relu(self.bn1(self.conv1(x)))
        hidden = self.relu(self.bn2(self.conv2(hidden)))
        residual = self.bn3(self.conv3(hidden))

        # Mamba runs after bn3 and before the residual addition.
        if self.use_mamba:
            residual = self.mamba(residual)

        shortcut = x
        if self.down_sample is not None:
            shortcut = self.down_sample(x)

        return self.relu(residual + shortcut)


class MambaBlock(nn.Cell):
    """Lightweight residual refinement block meant to disturb the host network little.

    As written, the block consists only of convolutions, batch norms and ReLU:
    a 1x1 projection down to ``d_inner`` channels, a 3x3 convolution, and a 1x1
    projection back to ``dim`` channels. The result is scaled by ``scale`` and
    added to the input, followed by a ReLU.

    Args:
        dim: Number of input (and output) channels.
    """

    def __init__(self, dim: int) -> None:
        super().__init__()

        # Reduced inner width (a quarter of ``dim``, but at least 32) to keep
        # parameter count and compute low.
        inner = max(dim // 4, 32)
        self.d_inner = inner

        # Input projection; its batch norm is created with gamma_init='zeros'.
        self.in_proj = nn.Conv2d(dim, inner, kernel_size=1, weight_init=init.HeNormal())
        self.in_norm = nn.BatchNorm2d(inner, gamma_init='zeros')

        # Small 3x3 kernel (the original note says this suits CIFAR better).
        self.spatial = nn.Conv2d(
            inner, inner, kernel_size=3, stride=1, padding=1, pad_mode="pad"
        )
        self.spatial_norm = nn.BatchNorm2d(inner)

        # Output projection back to ``dim``; batch norm also gamma_init='zeros'.
        self.out_proj = nn.Conv2d(inner, dim, kernel_size=1)
        self.out_norm = nn.BatchNorm2d(dim, gamma_init='zeros')

        self.relu = nn.ReLU()

        # Very small fixed scaling factor so the branch perturbs the input
        # features only slightly.
        self.scale = 0.01

    def construct(self, x: Tensor) -> Tensor:
        """Return ``relu(x + scale * branch(x))``."""
        branch = self.relu(self.in_norm(self.in_proj(x)))
        branch = self.relu(self.spatial_norm(self.spatial(branch)))
        branch = self.out_norm(self.out_proj(branch))

        # Scaled residual update of the input, then ReLU.
        return self.relu(x + self.scale * branch)


class MambaBottleneck(nn.Cell):
    """Bottleneck residual block (1x1 -> 3x3 -> 1x1) with Mamba enabled by default.

    Args:
        in_channels: Channels of the incoming feature map.
        channels: Base channel count; the block outputs ``channels * expansion``.
        stride: Stride of the 3x3 convolution. Default: 1.
        groups: Group count of the 3x3 convolution. Default: 1.
        base_width: Width per group; the inner width is
            ``int(channels * (base_width / 64.0)) * groups``. Default: 64.
        norm: Callable producing a normalisation layer from a channel count.
            ``nn.BatchNorm2d`` is used when ``None``.
        down_sample: Optional cell applied to the block input to produce the
            shortcut branch; the raw input is used when ``None``.
        use_mamba: When True, a ``MambaBlock`` is created and applied to the
            output of ``bn3`` before the shortcut is added. Default: True.
        layer_index: Layer index. Accepted but not read anywhere in this class.
    """
    expansion: int = 4

    def __init__(
        self,
        in_channels: int,
        channels: int,
        stride: int = 1,
        groups: int = 1,
        base_width: int = 64,
        norm: Optional[nn.Cell] = None,
        down_sample: Optional[nn.Cell] = None,
        use_mamba: bool = True,
        layer_index: int = 0,
    ) -> None:
        super().__init__()
        norm_layer = nn.BatchNorm2d if norm is None else norm

        mid_channels = int(channels * (base_width / 64.0)) * groups
        expanded = channels * self.expansion

        self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=1, stride=1)
        self.bn1 = norm_layer(mid_channels)
        self.conv2 = nn.Conv2d(
            mid_channels, mid_channels, kernel_size=3, stride=stride,
            padding=1, pad_mode="pad", group=groups,
        )
        self.bn2 = norm_layer(mid_channels)
        self.conv3 = nn.Conv2d(mid_channels, expanded, kernel_size=1, stride=1)
        self.bn3 = norm_layer(expanded)
        self.relu = nn.ReLU()
        self.down_sample = down_sample

        # The Mamba sub-block is only instantiated when requested.
        self.use_mamba = use_mamba
        if use_mamba:
            self.mamba = MambaBlock(dim=expanded)

    def construct(self, x: Tensor) -> Tensor:
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        y = self.conv1(x)
        y = self.relu(self.bn1(y))

        y = self.conv2(y)
        y = self.relu(self.bn2(y))

        y = self.bn3(self.conv3(y))

        # Optional Mamba processing ahead of the residual addition.
        if self.use_mamba:
            y = self.mamba(y)

        skip = x
        if self.down_sample is not None:
            skip = self.down_sample(x)

        return self.relu(y + skip)


class ResNet(nn.Cell):
    """ResNet backbone with optional Mamba-augmented blocks in the last stage.

    Args:
        block: Residual block class (``BasicBlock`` or ``Bottleneck``).
        layers: Number of blocks in each of the four stages.
        num_classes: Output size of the classifier. Default: 1000.
        in_channels: Channels of the input image. Default: 3.
        groups: Group count forwarded to every block. Default: 1.
        base_width: Base width forwarded to every block. Default: 64.
        norm: Callable producing a normalisation layer from a channel count.
            ``nn.BatchNorm2d`` is used when ``None``.
        use_mamba: When True, Mamba is requested for ``layer4`` only (and only
            for blocks after the first one). It has no effect when ``block`` is
            a ``BasicBlock``, which has no Mamba support. Default: True.
        cifar_mode: When True, the stem uses a 3x3 stride-1 convolution and no
            max pooling instead of the 7x7 stride-2 convolution followed by a
            3x3 stride-2 max pool. Default: False.
    """

    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        layers: List[int],
        num_classes: int = 1000,
        in_channels: int = 3,
        groups: int = 1,
        base_width: int = 64,
        norm: Optional[nn.Cell] = None,
        use_mamba: bool = True,
        cifar_mode: bool = False,
    ) -> None:
        super().__init__()
        if norm is None:
            norm = nn.BatchNorm2d

        # Mamba is used only in the deepest stage (layer4).
        self.use_mamba_in_layer = [False, False, False, True] if use_mamba else [False, False, False, False]

        self.norm = norm
        self.groups = groups
        self.base_width = base_width
        self.in_channels = 64

        # CIFAR mode: smaller stem kernel and no max pooling.
        if cifar_mode:
            self.conv1 = nn.Conv2d(in_channels, self.in_channels, kernel_size=3, stride=1,
                                   padding=1, pad_mode="pad")
            self.max_pool = nn.Identity()  # no pooling
        else:
            self.conv1 = nn.Conv2d(in_channels, self.in_channels, kernel_size=7, stride=2,
                                   padding=3, pad_mode="pad")
            self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")

        self.bn1 = norm(self.in_channels)
        self.relu = nn.ReLU()

        # Build the four residual stages; every stage after the first halves
        # the spatial resolution through stride 2.
        self.layer1 = self._make_layer(block, 64, layers[0], use_mamba=self.use_mamba_in_layer[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, use_mamba=self.use_mamba_in_layer[1])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, use_mamba=self.use_mamba_in_layer[2])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, use_mamba=self.use_mamba_in_layer[3])

        self.pool = GlobalAvgPooling()
        self.num_features = 512 * block.expansion
        self.classifier = nn.Dense(self.num_features, num_classes)
        self._initialize_weights()

    def _initialize_weights(self) -> None:
        """Initialize weights for cells."""
        for _, cell in self.cells_and_names():
            if isinstance(cell, nn.Dense):
                cell.weight.set_data(
                    init.initializer(init.TruncatedNormal(sigma=0.02), cell.weight.shape, cell.weight.dtype)
                )
                if cell.bias is not None:
                    cell.bias.set_data(init.initializer(init.Constant(0), cell.bias.shape, cell.bias.dtype))
            elif isinstance(cell, nn.BatchNorm2d):
                # NOTE(review): this resets gamma to 1 for every BatchNorm2d,
                # including layers created with gamma_init='zeros' (as
                # MambaBlock does) - confirm that overriding them is intended.
                cell.gamma.set_data(init.initializer(init.Constant(1), cell.gamma.shape, cell.gamma.dtype))
                cell.beta.set_data(init.initializer(init.Constant(0), cell.beta.shape, cell.beta.dtype))
            elif isinstance(cell, nn.Conv2d):
                cell.weight.set_data(
                    init.initializer(init.HeNormal(), cell.weight.shape, cell.weight.dtype)
                )
                if cell.bias is not None:
                    cell.bias.set_data(init.initializer(init.Constant(0), cell.bias.shape, cell.bias.dtype))

    def _make_layer(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        channels: int,
        blocks_num: int,
        stride: int = 1,
        use_mamba: bool = False,
    ) -> nn.SequentialCell:
        """Build one residual stage of ``blocks_num`` blocks.

        The first block carries the stride and, when needed, the down-sample
        shortcut; it never uses Mamba. The remaining blocks use Mamba when
        ``use_mamba`` is True and the block class supports it.

        Side effect: updates ``self.in_channels`` to the stage's output width.
        """
        out_channels = channels * block.expansion

        # A projection shortcut is needed whenever the first block changes the
        # spatial size or the channel count.
        down_sample = None
        if stride != 1 or self.in_channels != out_channels:
            down_sample = nn.SequentialCell([
                nn.Conv2d(self.in_channels, out_channels, kernel_size=1, stride=stride),
                self.norm(out_channels)
            ])

        shared_kwargs = dict(groups=self.groups, base_width=self.base_width, norm=self.norm)

        # BasicBlock.__init__ has no ``use_mamba`` parameter; passing it raised
        # TypeError for resnet18/resnet34. Only forward the flag to block
        # classes that are not BasicBlock subclasses.
        accepts_mamba = not issubclass(block, BasicBlock)
        first_mamba = {"use_mamba": False} if accepts_mamba else {}
        rest_mamba = {"use_mamba": bool(use_mamba)} if accepts_mamba else {}

        # First block: handles stride / down-sampling, never uses Mamba.
        layers = [
            block(
                self.in_channels,
                channels,
                stride=stride,
                down_sample=down_sample,
                **shared_kwargs,
                **first_mamba
            )
        ]

        self.in_channels = out_channels

        # Remaining blocks: stride 1, identity shortcut, optional Mamba.
        for _ in range(1, blocks_num):
            layers.append(block(self.in_channels, channels, **shared_kwargs, **rest_mamba))

        return nn.SequentialCell(layers)

    def forward_features(self, x: Tensor) -> Tensor:
        """Network forward feature extraction."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.max_pool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x

    def forward_head(self, x: Tensor) -> Tensor:
        """Apply global pooling followed by the classifier."""
        x = self.pool(x)
        x = self.classifier(x)
        return x

    def construct(self, x: Tensor) -> Tensor:
        """Full forward pass: feature extraction, then the classification head."""
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x


def _create_resnet(pretrained=False, **kwargs):
    """Build a ``ResNet`` by delegating to ``build_model_with_cfg``.

    ``pretrained`` and every keyword argument are passed through unchanged.
    """
    model = build_model_with_cfg(ResNet, pretrained, **kwargs)
    return model


@register_model
def resnet18(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs):
    """Create a ResNet-18: ``BasicBlock`` with stage depths [2, 2, 2, 2].

    ``pretrained`` goes to ``_create_resnet``; the remaining arguments (plus
    any extra ``kwargs``) are forwarded as model arguments. See ``ResNet``.
    """
    resnet_kwargs = dict(
        block=BasicBlock,
        layers=[2, 2, 2, 2],
        num_classes=num_classes,
        in_channels=in_channels,
        **kwargs
    )
    return _create_resnet(pretrained, default_cfg=default_cfgs["resnet18"], **resnet_kwargs)


@register_model
def resnet34(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs):
    """Create a ResNet-34: ``BasicBlock`` with stage depths [3, 4, 6, 3].

    ``pretrained`` goes to ``_create_resnet``; the remaining arguments (plus
    any extra ``kwargs``) are forwarded as model arguments. See ``ResNet``.
    """
    resnet_kwargs = dict(
        block=BasicBlock,
        layers=[3, 4, 6, 3],
        num_classes=num_classes,
        in_channels=in_channels,
        **kwargs
    )
    return _create_resnet(pretrained, default_cfg=default_cfgs["resnet34"], **resnet_kwargs)


@register_model
def resnet50(
    pretrained: bool = False,
    num_classes: int = 1000,
    in_channels: int = 3,
    cifar_mode: bool = False,
    use_mamba: bool = True,
    **kwargs
) -> ResNet:
    """Create a ResNet-50: ``Bottleneck`` with stage depths [3, 4, 6, 3].

    Args:
        pretrained: Passed to ``_create_resnet`` (whether to download and load
            the pre-trained model). Default: False.
        num_classes: Number of output classes. Default: 1000.
        in_channels: Number of input channels. Default: 3.
        cifar_mode: Whether to use the CIFAR-oriented stem (3x3 stride-1
            convolution, no max pooling). Default: False.
        use_mamba: Whether to use Mamba blocks. Default: True.
        **kwargs: Extra model arguments forwarded unchanged.

    Returns:
        The value returned by ``_create_resnet``.
    """
    resnet_kwargs = dict(
        block=Bottleneck,
        layers=[3, 4, 6, 3],
        num_classes=num_classes,
        in_channels=in_channels,
        cifar_mode=cifar_mode,
        use_mamba=use_mamba,
        **kwargs
    )
    return _create_resnet(pretrained, default_cfg=default_cfgs["resnet50"], **resnet_kwargs)


@register_model
def resnet101(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs):
    """Create a ResNet-101: ``Bottleneck`` with stage depths [3, 4, 23, 3].

    ``pretrained`` goes to ``_create_resnet``; the remaining arguments (plus
    any extra ``kwargs``) are forwarded as model arguments. See ``ResNet``.
    """
    resnet_kwargs = dict(
        block=Bottleneck,
        layers=[3, 4, 23, 3],
        num_classes=num_classes,
        in_channels=in_channels,
        **kwargs
    )
    return _create_resnet(pretrained, default_cfg=default_cfgs["resnet101"], **resnet_kwargs)


@register_model
def resnet152(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs):
    """Create a ResNet-152: ``Bottleneck`` with stage depths [3, 8, 36, 3].

    ``pretrained`` goes to ``_create_resnet``; the remaining arguments (plus
    any extra ``kwargs``) are forwarded as model arguments. See ``ResNet``.
    """
    resnet_kwargs = dict(
        block=Bottleneck,
        layers=[3, 8, 36, 3],
        num_classes=num_classes,
        in_channels=in_channels,
        **kwargs
    )
    return _create_resnet(pretrained, default_cfg=default_cfgs["resnet152"], **resnet_kwargs)


@register_model
def resnext50_32x4d(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs):
    """Create a ResNeXt-50 (32x4d): ``Bottleneck``, depths [3, 4, 6, 3], groups=32, base_width=4.

    ``pretrained`` goes to ``_create_resnet``; the remaining arguments (plus
    any extra ``kwargs``) are forwarded as model arguments. See ``ResNet``.
    """
    resnet_kwargs = dict(
        block=Bottleneck,
        layers=[3, 4, 6, 3],
        groups=32,
        base_width=4,
        num_classes=num_classes,
        in_channels=in_channels,
        **kwargs
    )
    return _create_resnet(pretrained, default_cfg=default_cfgs["resnext50_32x4d"], **resnet_kwargs)


@register_model
def resnext101_32x4d(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs):
    """Create a ResNeXt-101 (32x4d): ``Bottleneck``, depths [3, 4, 23, 3], groups=32, base_width=4.

    ``pretrained`` goes to ``_create_resnet``; the remaining arguments (plus
    any extra ``kwargs``) are forwarded as model arguments. See ``ResNet``.
    """
    resnet_kwargs = dict(
        block=Bottleneck,
        layers=[3, 4, 23, 3],
        groups=32,
        base_width=4,
        num_classes=num_classes,
        in_channels=in_channels,
        **kwargs
    )
    return _create_resnet(pretrained, default_cfg=default_cfgs["resnext101_32x4d"], **resnet_kwargs)


@register_model
def resnext101_64x4d(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs):
    """Create a ResNeXt-101 (64x4d): ``Bottleneck``, depths [3, 4, 23, 3], groups=64, base_width=4.

    ``pretrained`` goes to ``_create_resnet``; the remaining arguments (plus
    any extra ``kwargs``) are forwarded as model arguments. See ``ResNet``.
    """
    resnet_kwargs = dict(
        block=Bottleneck,
        layers=[3, 4, 23, 3],
        groups=64,
        base_width=4,
        num_classes=num_classes,
        in_channels=in_channels,
        **kwargs
    )
    return _create_resnet(pretrained, default_cfg=default_cfgs["resnext101_64x4d"], **resnet_kwargs)


@register_model
def resnext152_64x4d(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs):
    """Create a ResNeXt-152 (64x4d): ``Bottleneck``, depths [3, 8, 36, 3], groups=64, base_width=4.

    ``pretrained`` goes to ``_create_resnet``; the remaining arguments (plus
    any extra ``kwargs``) are forwarded as model arguments. See ``ResNet``.
    """
    resnet_kwargs = dict(
        block=Bottleneck,
        layers=[3, 8, 36, 3],
        groups=64,
        base_width=4,
        num_classes=num_classes,
        in_channels=in_channels,
        **kwargs
    )
    return _create_resnet(pretrained, default_cfg=default_cfgs["resnext152_64x4d"], **resnet_kwargs)