In [1]:
import torch.nn as nn
from jaxtyping import Float, Array

In [None]:
class SimpleCNNSmall(nn.Module):
    __version__ = '0.1.0'

    def __init__(self):
        super().__init__()
        
        self.conv1 = nn.Conv2d(2, 8, kernel_size=3, stride=1, padding=1, padding_mode="circular")
        self.bn1 = nn.BatchNorm2d(8)
        self.act1 = nn.ReLU()
        self.conv2 = nn.Conv2d(8, 8, kernel_size=3, stride=1, padding=1, padding_mode="circular")
        self.bn2 = nn.BatchNorm2d(8)
        self.act2 = nn.ReLU()
        self.conv3 = nn.Conv2d(8, 2, kernel_size=3, stride=1, padding=1, padding_mode="circular")
        # 可选：self.bn3 = nn.BatchNorm2d(2)

    def forward(self, x: Float[Array, "batch 2 w h"]) -> Float[Array, "batch 2 w h"]:
        x = self.act1(self.bn1(self.conv1(x)))
        x = self.act2(self.bn2(self.conv2(x)))
        x = self.conv3(x)
        # 如果用于分类，建议不加 sigmoid
        # x = torch.sigmoid(x)
        return x

In [None]:
# p4_cnn_small.py
from __future__ import annotations
import torch
import torch.nn as nn
from jaxtyping import Float, Array
from e2cnn import gspaces, nn as enn


class SimpleCNNSmallP4(nn.Module):
    """P4-旋转等价版 SimpleCNNSmall
    输入: (B, 2, H, W)  输出: (B, 2, H, W)
    用法与原模型完全一致，直接替换类名即可。
    """
    __version__ = '0.1.0-p4'

    # ---------- 内部工具 ----------
    def __init__(self):
        super().__init__()

        # 1. 定义对称群：4 个离散旋转 [0°,90°,180°,270°]
        r2_act = gspaces.Rot2dOnR2(N=4)
        
        in_type = enn.FieldType(r2_act, 2 * [r2_act.trivial_repr])
        hid_type = enn.FieldType(r2_act, 8 * [r2_act.regular_repr])   # 8*4=32 实际通道
        out_type = enn.FieldType(r2_act, 2 * [r2_act.trivial_repr])   # 输出回到 2 通道

        # 4. 等变卷积 + BN + ReLU
        self.conv1 = enn.R2Conv(in_type, hid_type, kernel_size=3,
                                stride=1, padding=1, padding_mode="circular", bias=False)
        self.bn1   = enn.InnerBatchNorm(hid_type)
        self.act1  = enn.ReLU(hid_type, inplace=True)
        self.conv2 = enn.R2Conv(hid_type, hid_type, kernel_size=3,
                                stride=1, padding=1, padding_mode="circular", bias=False)
        self.bn2   = enn.InnerBatchNorm(hid_type)
        self.act2  = enn.ReLU(hid_type, inplace=True)
        self.conv3 = enn.R2Conv(hid_type, out_type, kernel_size=3,
                                stride=1, padding=1, padding_mode="circular", bias=True)

        # 5. 保存 in/out_type 供 forward 用
        self.in_type  = in_type
        self.out_type = out_type

    # ---------- 标准 forward ----------
    def forward(self, x: Float[Array, "batch 2 w h"]) -> Float[Array, "batch 2 w h"]:
        x: enn.GeometricTensor = enn.GeometricTensor(x, self.in_type)

        x = self.act1(self.bn1(self.conv1(x)))
        x = self.act2(self.bn2(self.conv2(x)))
        x = self.conv3(x)

        return x.tensor

    # ---------- 可选：导出纯 PyTorch 模型，推理无 e2cnn 依赖 ----------
    def export(self) -> nn.Module:
        """返回普通 nn.Module，等变性固化，推理更快。"""
        return torch.jit.trace(self, torch.randn(1, 2, 100, 100))


# ------------------- 简单测试 -------------------
if __name__ == "__main__":
    net = SimpleCNNSmallP4()
    x = torch.randn(5, 2, 64, 64)
    y = net(x)
    print(y.shape)  # 期望: torch.Size([5, 2, 64, 64])

torch.Size([5, 2, 64, 64])


  full_mask[mask] = norms.to(torch.uint8)


In [None]:
net = SimpleCNNSmallP4()
x = torch.randn(5, 2, 64, 64)
y = net(x)
print(y.shape)  

torch.Size([5, 2, 64, 64])


In [17]:
from torchinfo import summary

summary(net.cpu().export(), input_size=(1, 2, 100, 100))

Layer (type:depth-idx)                             Output Shape              Param #
SimpleCNNSmallP4                                   --                        --
├─R2Conv: 1-1                                      --                        96
│    └─BlocksBasisExpansion: 2-1                   --                        --
│    │    └─SingleBlockBasisExpansion: 3-1         --                        --
├─InnerBatchNorm: 1-2                              --                        --
│    └─BatchNorm3d: 2-2                            --                        16
├─ReLU: 1-3                                        --                        --
├─R2Conv: 1-4                                      --                        1,536
│    └─BlocksBasisExpansion: 2-3                   --                        --
│    │    └─SingleBlockBasisExpansion: 3-2         --                        --
├─InnerBatchNorm: 1-5                              --                        --
│    └─BatchNorm3d: 2-4         

  assert tensor.shape[1] == type.size, \
  assert weights.shape[0] == self.dimension()
  assert len(weights.shape) == 2 and weights.shape[1] == self.dimension()


SimpleCNNSmallP4(
  original_name=SimpleCNNSmallP4
  (conv1): R2Conv(
    original_name=R2Conv
    (_basisexpansion): BlocksBasisExpansion(
      original_name=BlocksBasisExpansion
      (block_expansion_('irrep_0', 'regular')): SingleBlockBasisExpansion(original_name=SingleBlockBasisExpansion)
    )
  )
  (bn1): InnerBatchNorm(
    original_name=InnerBatchNorm
    (batch_norm_[4]): BatchNorm3d(original_name=BatchNorm3d)
  )
  (act1): ReLU(original_name=ReLU)
  (conv2): R2Conv(
    original_name=R2Conv
    (_basisexpansion): BlocksBasisExpansion(
      original_name=BlocksBasisExpansion
      (block_expansion_('regular', 'regular')): SingleBlockBasisExpansion(original_name=SingleBlockBasisExpansion)
    )
  )
  (bn2): InnerBatchNorm(
    original_name=InnerBatchNorm
    (batch_norm_[4]): BatchNorm3d(original_name=BatchNorm3d)
  )
  (act2): ReLU(original_name=ReLU)
  (conv3): R2Conv(
    original_name=R2Conv
    (_basisexpansion): BlocksBasisExpansion(
      original_name=BlocksBasisExp

In [None]:
from pprint import pprint

for i in net.parameters():
    print()

<generator object Module.parameters at 0x0000025B1F0777D0>