In [None]:
import tensorflow as tf
from tensorflow.keras import layers

In [None]:
class CConv2D(layers.Layer):
    """Complex-valued 2D convolution built from two real ``Conv2D`` layers.

    Complex tensors are carried with a trailing axis of size 2 holding the
    real part (index 0) and the imaginary part (index 1). For an input
    ``x = a + i*b`` and a kernel ``W = W_re + i*W_im`` the layer computes

        out_re = conv(a, W_re) - conv(b, W_im)
        out_im = conv(b, W_re) + conv(a, W_im)

    Both real convolutions are bias-free and Glorot-uniform initialised.

    Args:
        out_channels: Number of output channels (``Conv2D`` ``filters``).
        kernel_size: Passed to ``Conv2D(kernel_size=...)``.
        stride: Passed to ``Conv2D(strides=...)``.
        padding: Passed to ``Conv2D(padding=...)``.
        **kwargs: Forwarded to ``layers.Layer`` (e.g. ``name``, ``dtype``).
    """

    def __init__(self, out_channels, kernel_size, stride, padding, **kwargs):
        super(CConv2D, self).__init__(**kwargs)
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

        # Real and imaginary kernel components, each a plain real convolution.
        self.re_conv = layers.Conv2D(filters=out_channels, kernel_size=kernel_size,
                                     strides=stride, padding=padding, use_bias=False,
                                     kernel_initializer="glorot_uniform")
        self.im_conv = layers.Conv2D(filters=out_channels, kernel_size=kernel_size,
                                     strides=stride, padding=padding, use_bias=False,
                                     kernel_initializer="glorot_uniform")

    def build(self, input_shape):
        """Build the sub-convolutions for a single real/imaginary component.

        Args:
            input_shape: Shape of the full input, assumed to be
                ``(batch, height, width, channels, 2)``.
        """
        # The sub-convs only ever receive ``inputs[..., k]`` (the trailing
        # real/imag axis removed). Keras sizes the kernel from the last axis of
        # the shape passed to ``build``, so building on the full shape would
        # give a kernel expecting 2 input channels instead of ``channels``.
        component_shape = input_shape[:-1]
        self.re_conv.build(component_shape)
        self.im_conv.build(component_shape)
        super(CConv2D, self).build(input_shape)

    def call(self, inputs):
        """Apply the complex convolution.

        Args:
            inputs: Tensor whose last axis has size 2 (real, imaginary),
                assumed to be ``(batch, height, width, channels, 2)``.

        Returns:
            ``tf.stack([out_re, out_im], axis=-1)``, i.e. the same trailing
            real/imaginary layout as the input.
        """
        x_re = inputs[..., 0]
        x_im = inputs[..., 1]

        # (a + ib)(W_re + iW_im) = (aW_re - bW_im) + i(bW_re + aW_im)
        out_re = self.re_conv(x_re) - self.im_conv(x_im)
        out_im = self.re_conv(x_im) + self.im_conv(x_re)

        return tf.stack([out_re, out_im], axis=-1)

    def get_config(self):
        """Return constructor arguments so the layer can be saved/re-created."""
        config = super(CConv2D, self).get_config()
        config.update({
            "out_channels": self.out_channels,
            "kernel_size": self.kernel_size,
            "stride": self.stride,
            "padding": self.padding,
        })
        return config


In [None]:
class CBatchNorm(layers.Layer):
    """Batch normalization applied independently to real and imaginary parts.

    Each component gets its own ``BatchNormalization`` layer (own moving
    statistics and own gamma/beta).

    NOTE(review): this is the "naive" complex batch norm. There is no joint
    whitening of the real/imaginary covariance.

    Args:
        in_channels: Number of input channels. It is stored and serialized but
            not used to configure the normalization layers.
        **kwargs: Forwarded to ``layers.Layer`` (e.g. ``name``, ``dtype``).
    """

    def __init__(self, in_channels, **kwargs):
        super(CBatchNorm, self).__init__(**kwargs)
        self.in_channels = in_channels

        # Separate batch-normalization layers for the real and imaginary parts.
        self.re_batch = layers.BatchNormalization()
        self.im_batch = layers.BatchNormalization()

    def call(self, inputs, training=None):
        """Normalize each component and re-stack them.

        Args:
            inputs: Tensor whose last axis has size 2 (real, imaginary),
                assumed to be ``(batch, height, width, channels, 2)``.
            training: Forwarded to both ``BatchNormalization`` layers. It selects
                batch statistics (True) or moving statistics (False). ``None``
                lets Keras infer the mode, which is the same as omitting it.

        Returns:
            ``tf.stack([out_re, out_im], axis=-1)``.
        """
        # Split the trailing axis into real and imaginary parts.
        x_re = inputs[..., 0]
        x_im = inputs[..., 1]

        # Normalize the real and imaginary parts separately. `training` is passed
        # explicitly so direct calls such as `layer(x, training=True)` reach the
        # inner layers.
        out_re = self.re_batch(x_re, training=training)
        out_im = self.im_batch(x_im, training=training)

        # Re-stack the processed parts into the trailing real/imag axis.
        return tf.stack([out_re, out_im], axis=-1)

    def get_config(self):
        """Return constructor arguments so the layer can be saved/re-created."""
        config = super(CBatchNorm, self).get_config()
        config.update({"in_channels": self.in_channels})
        return config

In [None]:
class CMaxPool2D(layers.Layer):
    """Max pooling applied independently to real and imaginary parts.

    NOTE(review): pooling the components separately means the pooled real and
    imaginary values may come from different spatial positions, so the output
    is not the complex value with the largest magnitude.

    Args:
        kernel_size: Passed to ``MaxPooling2D(pool_size=...)``.
        strides: Passed to ``MaxPooling2D(strides=...)``. ``None`` (the default)
            means "same as ``kernel_size``", which is Keras' own default.
        padding: Passed to ``MaxPooling2D(padding=...)``. Defaults to
            ``"valid"``, which is Keras' own default.
        **kwargs: Forwarded to ``layers.Layer`` (e.g. ``name``, ``dtype``).
    """

    def __init__(self, kernel_size, strides=None, padding="valid", **kwargs):
        super(CMaxPool2D, self).__init__(**kwargs)
        self.kernel_size = kernel_size
        self.strides = strides
        self.padding = padding

        # Only `dtype` is shared with the sub-layers. Forwarding every base-Layer
        # kwarg would copy e.g. `name` onto both of them, and pooling options
        # travel through the explicit parameters above instead.
        pool_kwargs = {"dtype": kwargs["dtype"]} if "dtype" in kwargs else {}

        # Max-pooling layers for the real and imaginary parts.
        self.CMax_re = layers.MaxPooling2D(pool_size=kernel_size, strides=strides,
                                           padding=padding, **pool_kwargs)
        self.CMax_im = layers.MaxPooling2D(pool_size=kernel_size, strides=strides,
                                           padding=padding, **pool_kwargs)

    def call(self, inputs):
        """Pool each component and re-stack them.

        Args:
            inputs: Tensor whose last axis has size 2 (real, imaginary),
                assumed to be ``(batch, height, width, channels, 2)``.

        Returns:
            ``tf.stack([out_re, out_im], axis=-1)``.
        """
        # Split the trailing axis into real and imaginary parts.
        x_re = inputs[..., 0]
        x_im = inputs[..., 1]

        # Max-pool the real and imaginary parts separately.
        out_re = self.CMax_re(x_re)
        out_im = self.CMax_im(x_im)

        # Re-stack the processed parts into the trailing real/imag axis.
        return tf.stack([out_re, out_im], axis=-1)

    def get_config(self):
        """Return constructor arguments so the layer can be saved/re-created."""
        config = super(CMaxPool2D, self).get_config()
        config.update({
            "kernel_size": self.kernel_size,
            "strides": self.strides,
            "padding": self.padding,
        })
        return config

In [None]:
class MagnitudeOperation(layers.Layer):
    """Complex magnitude ``sqrt(re**2 + im**2)`` over a trailing real/imag axis.

    The input's last axis must hold the real part at index 0 and the imaginary
    part at index 1. The output has the input's shape with that last axis
    removed.
    """

    def __init__(self, **kwargs):
        super(MagnitudeOperation, self).__init__(**kwargs)

    def call(self, inputs):
        """Return the element-wise magnitude of the complex input.

        Args:
            inputs: Tensor whose last axis has size 2 (real, imaginary).

        Returns:
            Tensor of magnitudes with the trailing real/imag axis dropped.
        """
        real_part = inputs[..., 0]
        imag_part = inputs[..., 1]
        squared = tf.square(real_part) + tf.square(imag_part)

        # d/du sqrt(u) is infinite at u == 0, so a plain tf.sqrt yields NaN
        # gradients wherever both components are exactly zero (e.g. padding).
        # "Double where": feed sqrt a harmless 1 at those positions and select 0
        # afterwards. The forward value is unchanged, the gradient there is 0,
        # and NaN/inf inputs still propagate because the test is `!= 0`.
        nonzero = tf.math.not_equal(squared, tf.zeros_like(squared))
        safe_squared = tf.where(nonzero, squared, tf.ones_like(squared))
        return tf.where(nonzero, tf.sqrt(safe_squared), tf.zeros_like(squared))
