In [5]:
import torch
import torch.nn as nn

class CustomBatchNorm(nn.Module):
    """Batch normalisation without learnable affine parameters.

    Statistics are computed per channel, where the channel axis is dim 1 and
    every other axis (batch and any trailing axes) is reduced.  For the
    ``(B, C)`` and ``(B, C, 1)`` inputs this gives the same normalised output
    as reducing over the batch axis only.

    Args:
        num_features: number of channels ``C`` (size of dim 1 of the input).
        eps: small constant added to the standard deviation for stability.
        momentum: weight of the current batch statistics in the running
            averages.  Previously ``eps`` was (mis)used for this, which with
            the default ``1e-10`` left the running statistics unchanged.

    Raises:
        ValueError: if the input has fewer than 2 dims or dim 1 does not
            equal ``num_features``.
    """

    def __init__(self, num_features, eps=1e-10, momentum=0.1):
        super(CustomBatchNorm, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, x):
        if x.dim() < 2 or x.size(1) != self.num_features:
            raise ValueError(
                f"expected input of shape (B, {self.num_features}, ...), got {tuple(x.shape)}"
            )

        # Reduce over every axis except the channel axis.
        reduce_dims = [d for d in range(x.dim()) if d != 1]
        # Shape used to broadcast the (C,) buffers against the input.
        stat_shape = [1] * x.dim()
        stat_shape[1] = self.num_features

        if self.training:
            mean = x.mean(dim=reduce_dims, keepdim=True)
            var = x.var(dim=reduce_dims, keepdim=True, unbiased=False)
            # Update the buffers in place and outside autograd so they keep
            # their (C,) shape and do not hold on to the graph.
            with torch.no_grad():
                batch_mean = mean.reshape(-1).to(self.running_mean.dtype)
                batch_var = var.reshape(-1).to(self.running_var.dtype)
                self.running_mean.mul_(1 - self.momentum).add_(batch_mean, alpha=self.momentum)
                self.running_var.mul_(1 - self.momentum).add_(batch_var, alpha=self.momentum)
        else:
            mean = self.running_mean.view(stat_shape)
            var = self.running_var.view(stat_shape)

        return (x - mean) / (var.sqrt() + self.eps)

In [6]:
# Create a random input tensor
input_tensor = torch.randn(64, 624, 1)

# Create a CustomBatchNorm instance
norm_layer = CustomBatchNorm(num_features=624)

# Pass the input tensor through norm_layer
output_tensor = norm_layer(input_tensor)

# Print the shape of the output tensor
print(output_tensor.shape)

torch.Size([64, 624, 1])


In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

class SEAttention(nn.Module):
    """Squeeze-and-excitation style gating for ``(b, l, c)`` inputs.

    The input is averaged over the ``l`` axis, passed through a two-layer
    bottleneck MLP ending in a sigmoid, and the resulting per-channel gate
    multiplies the input.  The output has the same ``(b, l, c)`` layout.

    Args:
        channel: size of the last input axis.
        reduction: bottleneck ratio of the hidden layer.
    """

    def __init__(self, channel=512,reduction=4):
        super().__init__()
        hidden = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        batch, _, channels = x.size()
        # Pooling operates on the last axis, so move channels in front of it.
        channels_first = rearrange(x, 'b l c -> b c l')
        squeezed = self.avg_pool(channels_first).view(batch, channels)
        gate = self.fc(squeezed).view(batch, channels, 1)
        gated = channels_first * gate
        return rearrange(gated, 'b c l -> b l c')

# Smoke test: SEAttention.forward unpacks the input as (b, l, c), so this is
# batch=2, length=128, channels=16 and `channel` must match the last axis.
x = torch.randn(2, 128, 16)
att = SEAttention(channel=16)
y = att(x)

# other wavelet

## sinc

In [None]:
import math

def sinc(band, t_right):
    """Build a mirrored ``sin(u) / u`` kernel with a centre tap of 1.

    ``u = 2 * pi * band * t_right``; ``1e-6`` is added to the denominator to
    avoid dividing by exactly zero.  The result is
    ``[flip(right half), 1, right half]`` concatenated along dim 0, so its
    length is ``2 * len(t_right) + 1``.
    """
    phase = 2 * math.pi * band * t_right
    right_half = torch.sin(phase) / (phase + 1e-6)
    left_half = torch.flip(right_half, [0])
    centre = torch.ones(1).to(t_right.device)
    return torch.cat([left_half, centre, right_half])

def Mexh(p):
    """Evaluate ``(1 - p**2) * exp(-p**2 / 2)`` element-wise on tensor ``p``."""
    # p = 0.04 * p  # (disabled) rescale time into the interval [-5, 5]
    squared = torch.pow(p, 2)
    return (1 - squared) * torch.exp(-squared / 2)

def Laplace(p):
    """Evaluate a damped, time-shifted sine kernel element-wise on tensor ``p``.

    Computes ``A * exp(-ep / sqrt(1 - ep**2) * w * (p - tal)) * (-sin(w * (p - tal)))``
    with ``w = 2 * pi * f``.  The original referenced an undefined name ``pi``
    and raised ``NameError``; ``math.pi`` is used instead.
    """
    A = 0.08   # overall scale factor
    ep = 0.03  # NOTE(review): presumably the damping ratio - confirm
    tal = 0.1  # shift subtracted from p
    f = 50     # frequency used to form w
    w = 2 * math.pi * f
    q = torch.tensor(1 - pow(ep, 2))
    shifted = w * (p - tal)
    y = A * torch.exp((-ep / (torch.sqrt(q))) * shifted) * (-torch.sin(shifted))
    return y

class SincConv_multiple_channel(nn.Module):
    """1-D convolution whose filters are generated by ``sinc`` from learnable
    per-filter scale (``a_``) and shift (``b_``) values.

    Args:
        out_channels: number of generated filters.
        kernel_size: requested kernel size; even values are bumped to the
            next odd number.
        in_channels: number of input channels; each one is convolved with the
            same filter bank and the results are concatenated on dim 1.

    Fix: ``nn.Parameter(...).view(-1, 1)`` produced a plain non-leaf tensor, so
    ``a_``/``b_`` were never registered (not trained, not moved by ``.to()``,
    not in ``state_dict``).  The reshape now happens before wrapping.
    """

    def __init__(self, out_channels, kernel_size, in_channels=1):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size

        if kernel_size % 2 == 0:
            self.kernel_size += 1

        self.a_ = nn.Parameter(torch.linspace(1, 10, out_channels).view(-1, 1))
        self.b_ = nn.Parameter(torch.linspace(0, 10, out_channels).view(-1, 1))

    def forward(self, waveforms):
        half_kernel = self.kernel_size // 2
        device = waveforms.device
        time_disc = torch.linspace(-half_kernel, half_kernel, steps=self.kernel_size).to(device)
        # Use device-local views instead of reassigning the registered
        # parameters (assigning a plain tensor to a Parameter name raises).
        scale = self.a_.to(device)
        shift = self.b_.to(device)

        # NOTE(review): sinc() mirrors its input and inserts a centre tap, so
        # each filter has 2 * kernel_size + 1 taps while the padding below is
        # kernel_size // 2; the output is therefore shorter than the input.
        # Confirm this is intended.
        bank = [sinc(scale[i], time_disc - shift[i]) for i in range(self.out_channels)]
        self.filters = torch.stack(bank).view(self.out_channels, 1, -1)

        per_input = [
            F.conv1d(waveforms[:, i:i+1], self.filters, stride=1, padding=half_kernel,
                     dilation=1, bias=None, groups=1)
            for i in range(self.in_channels)
        ]
        return torch.cat(per_input, dim=1)


class Morlet_multiple_channel(nn.Module):
    """1-D convolution whose filters are generated by ``Morlet`` from learnable
    per-filter scale (``a_``) and shift (``b_``) values.

    Args:
        out_channels: number of generated filters.
        kernel_size: requested kernel size; the stored ``self.kernel_size`` is
            ``kernel_size`` when it is even and ``kernel_size - 1`` otherwise.
        in_channels: number of input channels; each one is convolved with the
            same filter bank and the results are concatenated on dim 1.

    Fixes: ``a_``/``b_`` are now real registered parameters (the previous
    ``nn.Parameter(...).view(-1, 1)`` yielded plain tensors), and the time
    axis is created on the input's device so it can be combined with them.
    """

    def __init__(self, out_channels, kernel_size, in_channels=1):

        super(Morlet_multiple_channel, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size - 1

        if kernel_size % 2 == 0:
            self.kernel_size = self.kernel_size + 1

        self.a_ = nn.Parameter(torch.linspace(1, 10, out_channels).view(-1, 1))

        self.b_ = nn.Parameter(torch.linspace(0, 10, out_channels).view(-1, 1))

    def forward(self, waveforms):
        device = waveforms.device
        half = int(self.kernel_size / 2)

        time_disc_right = torch.linspace(0, (self.kernel_size / 2) - 1, steps=half).to(device)
        time_disc_left = torch.linspace(-(self.kernel_size / 2) + 1, -1, steps=half).to(device)

        # Per-filter offset b / a, broadcast against the time axis.
        offset = self.b_.to(device) / self.a_.to(device)
        p1 = time_disc_right - offset
        p2 = time_disc_left - offset

        Morlet_right = Morlet(p1)
        Morlet_left = Morlet(p2)

        Morlet_filter = torch.cat([Morlet_left, Morlet_right], dim=1)

        self.filters = Morlet_filter.view(self.out_channels, 1, self.kernel_size)

        output = []
        for i in range(self.in_channels):
            output.append(F.conv1d(waveforms[:, i:i+1], self.filters, stride=1, padding=1, dilation=1, bias=None, groups=1))
        return torch.cat(output, dim=1)
    


In [2]:
import numpy as np
import torch

# Hypothetical select_validation_samples function
def select_validation_samples(data_all, label_all, num_samples):
    """Keep at most ``num_samples`` leading samples of every label.

    Labels are visited in the order returned by ``np.unique`` and, within a
    label, samples keep their original order.

    Returns:
        ``(data_all[idx], label_all[idx])`` for the selected indices ``idx``.
    """
    indices_to_keep = []
    for label in np.unique(label_all):
        label_positions = np.where(label_all == label)[0]
        kept = label_positions[:num_samples] if len(label_positions) > num_samples else label_positions
        indices_to_keep.extend(kept)

    return data_all[indices_to_keep], label_all[indices_to_keep]

# Generate mock data and labels
# NOTE: only NumPy's RNG is seeded here; torch.randn below is not seeded.
np.random.seed(0)  # for reproducibility
data_all = torch.randn(100, 10)  # assume 100 samples with 10 features each
label_all = np.random.randint(0, 5, size=(100,))  # assume 5 classes

# Call the function
selected_data, selected_labels = select_validation_samples(data_all, label_all, 10)

# Print the results
print("Selected data shape:", selected_data.shape)
print("Selected labels shape:", selected_labels.shape)
print("Unique labels in selected set:", np.unique(selected_labels))

# Verify that the number of samples per class is correct
for label in np.unique(label_all):
    print(f"Number of samples for label {label}: {np.sum(selected_labels == label)}")

Selected data shape: torch.Size([50, 10])
Selected labels shape: (50,)
Unique labels in selected set: [0 1 2 3 4]
Number of samples for label 0: 10
Number of samples for label 1: 10
Number of samples for label 2: 10
Number of samples for label 3: 10
Number of samples for label 4: 10


In [95]:
import torch.nn as nn
import torch
from einops import rearrange

import numpy as np
import torch
from torch.nn import init



class ChannelAttention(nn.Module):
    """Channel attention weights for ``(B, C, L)`` inputs.

    Max- and average-pooled descriptors (over the last axis) go through the
    same bottleneck of 1x1 convolutions; their sum is normalised with a
    softmax over the channel axis.  Output shape is ``(B, C, 1)``.

    Args:
        channel: number of input channels ``C``.
        reduction: bottleneck ratio of the shared 1x1-conv stack.
    """

    def __init__(self,channel,reduction=16):
        super().__init__()
        squeezed = channel // reduction
        self.maxpool = nn.AdaptiveMaxPool1d(1)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.se = nn.Sequential(
            nn.Conv1d(channel, squeezed, 1, bias=False),
            nn.ReLU(),
            nn.Conv1d(squeezed, channel, 1, bias=False),
        )
        # Kept as an attribute although forward() only uses the softmax.
        self.sigmoid = nn.Sigmoid()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x) :
        logits = self.se(self.maxpool(x)) + self.se(self.avgpool(x))
        return self.softmax(logits)



# Smoke test for ChannelAttention.
# NOTE: `input` shadows the Python builtin of the same name for the rest of
# the notebook session, and `kernel_size` is not used by anything in this cell.
input=torch.randn(50,512,6) # B,C,L
kernel_size= 7
cbam = ChannelAttention(channel=512,reduction=16)
output=cbam(input)
print(output.shape)

    
       
# batch_size, seq_length, channel = 10, 512, 6
# x = torch.randn(batch_size, seq_length, channel)  # 创建一个随机输入张量

# attention_layer = SPAttention(channel=channel, reduction=8)  # 实例化注意力层
# output = attention_layer(x)  # 前向传播

# print(f"Input shape: {x.shape}")
# print(f"Output shape: {output.shape}")

# # 检查输出形状是否正确
# assert output.shape == x.shape, f"Expected output shape {x.shape}, but got {output.shape}"

# print("Test passed successfully.")

torch.Size([50, 512, 1])


In [101]:
import torch
import torch.nn as nn
from einops import rearrange

class ChannelAttention(nn.Module):
    """Sparse channel attention weights for ``(b, l, c)`` inputs.

    Per-channel variance and mean over the ``l`` axis are each fed through
    their own bottleneck of 1x1 convolutions.  The two results are summed,
    everything but the ``topk`` largest channel logits is masked out, and a
    softmax over the channel axis yields weights of shape ``(b, c, 1)``.

    Args:
        channel: number of channels ``c`` (last axis of the input).
        reduction: bottleneck ratio of the 1x1-conv stacks.
        topk: number of channels that keep a non-zero weight.

    Changes from the previous version: ``sparser`` no longer mutates its
    argument in place, ``topk`` is clamped to the number of channels (the
    default of 10 used to crash for fewer than 10 channels), and ``varpool``
    is a method instead of a lambda attribute so the module can be pickled.
    """

    def __init__(self, channel, reduction=16, topk=10):
        super().__init__()
        # self.maxpool=nn.AdaptiveMaxPool1d(1) # B,C,L -> B,C,1
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.se1 = nn.Sequential(
            nn.Conv1d(channel, channel // reduction, 1, bias=False), # = linear
            nn.ReLU(),
            nn.Conv1d(channel // reduction, channel, 1, bias=False)
        )
        self.se2 = nn.Sequential(
            nn.Conv1d(channel, channel // reduction, 1, bias=False), # = linear
            nn.ReLU(),
            nn.Conv1d(channel // reduction, channel, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()
        self.softmax = nn.Softmax(dim=1)
        self.topk = topk

    def varpool(self, x):
        """Biased variance over the last axis, keeping that axis with size 1."""
        centred = x - torch.mean(x, dim=-1, keepdim=True)
        return (centred ** 2).mean(dim=-1, keepdim=True)

    def sparser(self, x, topk):
        """Return a copy of ``x`` with everything outside the top-k entries
        along dim 1 set to ``-inf``; ``x`` itself is left untouched."""
        # Clamp so that a topk larger than the channel count keeps everything.
        k = max(1, min(int(topk), x.size(1)))
        _, topk_indices = torch.topk(x, k, dim=1)
        mask = torch.zeros_like(x, dtype=torch.bool)
        mask.scatter_(1, topk_indices, True)
        return x.masked_fill(~mask, float('-inf'))

    def forward(self, x):
        x = rearrange(x, 'b l c -> b c l')
        var_out = self.se1(self.varpool(x))
        avg_out = self.se2(self.avgpool(x))
        res = self.sparser(var_out + avg_out, self.topk)
        # output=self.sigmoid(max_out+avg_out)
        return self.softmax(res)

# Test the module with a random input.
# topk is left at its default of 10, so 10 of the 16 channels keep a non-zero
# softmax weight per sample.
x = torch.randn(100, 4096, 16)  # Example input tensor # B,L,C
channel_attention = ChannelAttention(channel=16)
output = channel_attention(x)
print(output)


tensor([[[0.0000],
         [0.0000],
         [0.1000],
         ...,
         [0.1000],
         [0.1000],
         [0.1000]],

        [[0.0995],
         [0.0000],
         [0.1005],
         ...,
         [0.0994],
         [0.0000],
         [0.0000]],

        [[0.0993],
         [0.0000],
         [0.1006],
         ...,
         [0.0992],
         [0.0000],
         [0.0000]],

        ...,

        [[0.0000],
         [0.0000],
         [0.1000],
         ...,
         [0.1000],
         [0.1000],
         [0.1000]],

        [[0.0999],
         [0.0000],
         [0.1001],
         ...,
         [0.0999],
         [0.0000],
         [0.0000]],

        [[0.0000],
         [0.0000],
         [0.1000],
         ...,
         [0.1000],
         [0.1000],
         [0.1000]]], grad_fn=<SoftmaxBackward0>)


In [56]:
import torch
import torch.nn.functional as F

def random_shuffle_channel(tensor, channel_index):
    """Return a copy of a 3-D ``tensor`` in which the slice
    ``tensor[:, channel_index, :]`` has its rows permuted along dim 0.

    The permutation is drawn with ``torch.randperm``; the input tensor is not
    modified.
    """
    new_tensor = tensor.clone()

    # One random ordering of the dim-0 entries for the selected channel.
    perm = torch.randperm(tensor.size(0))
    new_tensor[:, channel_index, :] = tensor[:, channel_index, :][perm]

    return new_tensor

def cosine_similarity(x, y):
    """Cosine similarity of ``x`` and ``y`` computed along the last dimension."""
    similarity = F.cosine_similarity(x, y, dim=-1)
    return similarity

# Example usage
# The tensor built here is 3-D with shape (1, 2, 3); random_shuffle_channel indexes it as [:, channel, :].
tensor = torch.randn(1, 2, 3)

# Shuffle channel index 1 along dim 0. Dim 0 has size 1 here, so the permutation is necessarily the identity.
shuffled_tensor = random_shuffle_channel(tensor, 1)

# Compute the cosine similarity between the flattened original and shuffled tensors
similarity = cosine_similarity(tensor.view(tensor.size(0), -1), shuffled_tensor.view(shuffled_tensor.size(0), -1))

print(similarity)

print("Original Tensor:\n", tensor)
print("Shuffled Tensor:\n", shuffled_tensor)
print("Cosine Similarity:\n", similarity)


tensor([1.])
Original Tensor:
 tensor([[[ 0.4005, -0.7963,  0.5716],
         [ 0.1521, -0.8838,  1.3007]]])
Shuffled Tensor:
 tensor([[[ 0.4005, -0.7963,  0.5716],
         [ 0.1521, -0.8838,  1.3007]]])
Cosine Similarity:
 tensor([1.])


In [94]:
import torch
import torch.nn.functional as F

def random_shuffle_channels(tensor):
    """Return ``tensor`` with the entries along dim 1 in a random order.

    The same permutation (from ``torch.randperm``) is applied to every
    element along dim 0.
    """
    # Random ordering of the dim-1 indices.
    channel_order = torch.randperm(tensor.size(1))
    return tensor[:, channel_order]

def cosine_similarity(x, y):
    """Cosine similarity of ``x`` and ``y`` computed along dimension 1."""
    similarity = F.cosine_similarity(x, y, dim=1)
    return similarity

# Example input
B, C = 5, 3  # B is the number of samples, C is the number of channels
tensor = torch.randn(B, C,1)

# Shuffle the C channels
shuffled_tensor = random_shuffle_channels(tensor)

# Compute the cosine similarity
# Here we compare `tensor` with its channel-shuffled version `shuffled_tensor`
cos_sim = cosine_similarity(tensor, shuffled_tensor)

print("Original Tensor:\n", tensor)
print("Shuffled Tensor:\n", shuffled_tensor)
print("Cosine Similarity:\n", cos_sim)


Original Tensor:
 tensor([[[-1.5382],
         [-0.6274],
         [ 0.2217]],

        [[-1.2906],
         [-1.5065],
         [ 0.0462]],

        [[ 0.8653],
         [-0.1549],
         [ 0.8502]],

        [[-0.9844],
         [ 0.1447],
         [ 0.4886]],

        [[-0.7618],
         [-1.5920],
         [-0.9459]]])
Shuffled Tensor:
 tensor([[[-0.6274],
         [-1.5382],
         [ 0.2217]],

        [[-1.5065],
         [-1.2906],
         [ 0.0462]],

        [[-0.1549],
         [ 0.8653],
         [ 0.8502]],

        [[ 0.1447],
         [-0.9844],
         [ 0.4886]],

        [[-1.5920],
         [-0.7618],
         [-0.9459]]])
Cosine Similarity:
 tensor([[ 0.7046],
        [ 0.9882],
        [ 0.3040],
        [-0.0376],
        [ 0.8281]])


# test DEN operator

In [6]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import sympy
from Signal_processing import SignalProcessingBase
from einops import rearrange

# Shared defaults for the operators defined in this cell.
KERNEL_SIZE = 49  # default number of kernel taps
FRE = 10  # frequency used by Morlet / Laplace / SinOperation
DEVICE = 'cuda'  # hard-coded: this cell needs a CUDA device to run
STRIDE = 1  # default convolution stride
# Time axis of shape (1, 1, KERNEL_SIZE), created on DEVICE when the cell runs.
T = torch.linspace(-KERNEL_SIZE / 2, KERNEL_SIZE / 2, KERNEL_SIZE).view(1, 1, KERNEL_SIZE).to(DEVICE)

def Morlet(t):
    """Evaluate ``pi**0.25 * exp(-t**2 / 2) * cos(2 * pi * FRE * t)`` element-wise.

    NOTE(review): textbook Morlet normalisation is usually ``pi**-0.25``;
    confirm that the positive exponent is intended.
    """
    scale = pow(math.pi, 0.25)
    omega = 2 * math.pi * FRE
    envelope = torch.exp(-torch.pow(t, 2) / 2)
    return scale * envelope * torch.cos(omega * t)

def Laplace(t):
    """Evaluate a damped, time-shifted sine kernel element-wise on tensor ``t``.

    Computes ``a * exp(-ep / sqrt(1 - ep**2) * w * (t - tal)) * (-sin(w * (t - tal)))``
    with ``w = 2 * pi * FRE``.
    """
    a, ep, tal = 0.08, 0.03, 0.1
    w = 2 * math.pi * FRE
    q = torch.tensor(1 - pow(ep, 2))
    decay_rate = -ep / (torch.sqrt(q))
    decay = torch.exp(decay_rate * (w * (t - tal)))
    oscillation = -torch.sin(w * (t - tal))
    return a * decay * oscillation

class convlutional_operator(nn.Module):
    """Depthwise 1-D convolution whose kernel is a fixed function of a
    learnably rescaled time axis.

    A time axis ``t`` (``kernel_size`` points in ``[-pi/2, pi/2]``) is passed
    through an affine ``InstanceNorm1d`` and then through the function chosen
    by ``kernel_op``; the result is repeated for every input channel and used
    as the weight of a grouped ``conv1d``.  Inputs and outputs are
    ``(b, l, c)``.

    Args:
        kernel_op: key selecting the kernel function (``'conv_sin'``,
            ``'conv_sin2'``, ``'conv_exp'``, ``'conv_exp2'``, ``'Morlet'``,
            ``'Laplace'``).
        dim: ``num_features`` of the instance norm applied to ``t``.
        stride: convolution stride.
        kernel_size: number of kernel taps.
        device: device on which the norm and the time axis are created.
        in_channels: number of input channels (also the number of groups).

    Raises:
        KeyError: if ``kernel_op`` is not one of the keys above.

    Fix: ``t`` is now a non-persistent buffer so it follows ``module.to(...)``
    together with the norm parameters; as a plain attribute it stayed on the
    construction device and caused device mismatches after a move.  The
    ``state_dict`` is unchanged.
    """

    def __init__(self, kernel_op='conv_sin', dim=1, stride=STRIDE, kernel_size=KERNEL_SIZE, device='cuda', in_channels=1):
        super().__init__()
        self.affline = nn.InstanceNorm1d(num_features=dim, affine=True).to(device)
        op_dic = {'conv_sin': torch.sin,
                  'conv_sin2': lambda x: torch.sin(x ** 2),
                  'conv_exp': torch.exp,
                  'conv_exp2': lambda x: torch.exp(x ** 2),
                  'Morlet': Morlet,
                  'Laplace': Laplace}
        self.op = op_dic[kernel_op]
        self.stride = stride
        self.in_channels = in_channels
        self.kernel_size = kernel_size
        time_axis = torch.linspace(-math.pi / 2, math.pi / 2, kernel_size).view(1, 1, kernel_size).to(device)
        self.register_buffer('t', time_axis, persistent=False)

    def forward(self, x):
        x = rearrange(x, 'b l c -> b c l')
        # The kernel is rebuilt on every call because the affine norm is learnable.
        self.aff_t = self.affline(self.t)
        self.weight = self.op(self.aff_t).view(1, 1, -1).repeat(self.in_channels, 1, 1)
        conv = F.conv1d(x, self.weight, stride=self.stride, padding=(self.kernel_size - 1) // 2, dilation=1, groups=self.in_channels)
        conv = rearrange(conv, 'b c l -> b l c')
        return conv

class signal_filter_(nn.Module):
    """Depthwise 1-D convolution with a fixed 3-tap kernel.

    Inputs and outputs are ``(b, l, c)``; the same kernel is applied to every
    channel through a grouped ``conv1d`` with padding 1.

    Args:
        kernel_op: key selecting the kernel: ``'order1_MA'`` ``[0.5, 0, 0.5]``,
            ``'order2_MA'`` ``[1/3, 1/3, 1/3]``, ``'order1_DF'`` ``[-1, 0, 1]``,
            ``'order2_DF'`` ``[-1, 2, -1]``.
        dim: ``num_features`` of the (unused in forward) instance norm.
        stride: convolution stride.
        kernel_size: accepted for signature parity; the kernel always has 3 taps.
        device: device on which the kernel is created.
        in_channels: number of input channels (also the number of groups).

    Raises:
        KeyError: if ``kernel_op`` is not one of the keys above.

    Fix: the kernel is now a non-persistent buffer so it follows
    ``module.to(...)``; as a plain attribute it stayed on the construction
    device.  The ``state_dict`` is unchanged.
    """

    def __init__(self, kernel_op='order1_MA', dim=1, stride=STRIDE, kernel_size=KERNEL_SIZE, device='cuda', in_channels=1):
        super().__init__()
        # Not used by forward(); kept so existing checkpoints still load.
        self.affline = nn.InstanceNorm1d(num_features=dim, affine=True).to(device)
        op_dic = {'order1_MA': torch.tensor([0.5, 0, 0.5]),
                  'order2_MA': torch.tensor([1 / 3, 1 / 3, 1 / 3]),
                  'order1_DF': torch.tensor([-1.0, 0, 1.0]),
                  'order2_DF': torch.tensor([-1.0, 2.0, -1.0])}
        kernel = op_dic[kernel_op].view(1, 1, -1).to(device).repeat(in_channels, 1, 1)
        self.register_buffer('weight', kernel, persistent=False)
        self.stride = stride
        self.kernel_size = 3
        self.in_channels = in_channels

    def forward(self, x):
        x = rearrange(x, 'b l c -> b c l')
        conv = F.conv1d(x, self.weight, stride=self.stride, padding=(self.kernel_size - 1) // 2, dilation=1, groups=self.in_channels)
        conv = rearrange(conv, 'b c l -> b l c')
        return conv

class MorletFilter(SignalProcessingBase):
    """Signal-processing module applying the ``'Morlet'`` convolution operator
    over ``args.scale`` channels on ``args.device``."""

    def __init__(self, args):
        super().__init__(args)
        self.name = "Morlet"
        self.convolution_operator = convlutional_operator(
            'Morlet',
            device=args.device,
            in_channels=args.scale,
        )

    def forward(self, x):
        return self.convolution_operator(x)

class LaplaceFilter(SignalProcessingBase):
    """Signal-processing module applying the ``'Laplace'`` convolution operator
    over ``args.scale`` channels on ``args.device``."""

    def __init__(self, args):
        super().__init__(args)
        self.name = "Laplace"
        self.convolution_operator = convlutional_operator(
            'Laplace',
            device=args.device,
            in_channels=args.scale,
        )

    def forward(self, x):
        return self.convolution_operator(x)

class Order1MAFilter(SignalProcessingBase):
    """Signal-processing module applying the fixed ``'order1_MA'`` kernel
    over ``args.scale`` channels on ``args.device``."""

    def __init__(self, args):
        super().__init__(args)
        self.name = "order1_MA"
        self.filter_operator = signal_filter_(
            'order1_MA',
            device=args.device,
            in_channels=args.scale,
        )

    def forward(self, x):
        return self.filter_operator(x)

class Order2MAFilter(SignalProcessingBase):
    """Signal-processing module applying the fixed ``'order2_MA'`` kernel
    over ``args.scale`` channels on ``args.device``."""

    def __init__(self, args):
        super().__init__(args)
        self.name = "order2_MA"
        self.filter_operator = signal_filter_(
            'order2_MA',
            device=args.device,
            in_channels=args.scale,
        )

    def forward(self, x):
        return self.filter_operator(x)

class Order1DFFilter(SignalProcessingBase):
    """Signal-processing module applying the fixed ``'order1_DF'`` kernel
    over ``args.scale`` channels on ``args.device``."""

    def __init__(self, args):
        super().__init__(args)
        self.name = "order1_DF"
        self.filter_operator = signal_filter_(
            'order1_DF',
            device=args.device,
            in_channels=args.scale,
        )

    def forward(self, x):
        return self.filter_operator(x)

class Order2DFFilter(SignalProcessingBase):
    """Signal-processing module applying the fixed ``'order2_DF'`` kernel
    over ``args.scale`` channels on ``args.device``."""

    def __init__(self, args):
        super().__init__(args)
        self.name = "order2_DF"
        self.filter_operator = signal_filter_(
            'order2_DF',
            device=args.device,
            in_channels=args.scale,
        )

    def forward(self, x):
        return self.filter_operator(x)

class LogOperation(SignalProcessingBase):
    """Element-wise natural logarithm of the input.

    NOTE(review): ``torch.log`` is applied without clamping, so non-positive
    inputs produce NaN / -inf - confirm callers guarantee positive input.
    """

    def __init__(self, args):
        super().__init__(args)
        self.name = "log"

    def forward(self, x):
        logged = torch.log(x)
        return logged
class SquOperation(SignalProcessingBase):
    """Element-wise square of the input."""

    def __init__(self, args):
        super().__init__(args)
        self.name = "squ"

    def forward(self, x):
        squared = x ** 2
        return squared

class SinOperation(SignalProcessingBase):
    """Element-wise ``sin(fre * x)`` with ``fre`` taken from the module
    constant ``FRE`` at construction time."""

    def __init__(self, args):
        super().__init__(args)
        self.name = "sin"
        self.fre = FRE  # TODO learnable

    def forward(self, x):
        scaled = self.fre * x
        return torch.sin(scaled)




class SignalProcessingBase(torch.nn.Module):
    """Base class for signal-processing modules.

    Copies ``in_dim``, ``out_dim``, ``in_channels``, ``out_channels`` and
    ``device`` from ``args`` onto the instance and moves the module to that
    device.  Subclasses implement ``forward``.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.in_dim, self.out_dim = args.in_dim, args.out_dim
        self.in_channels, self.out_channels = args.in_channels, args.out_channels
        self.device = args.device
        self.to(self.device)

    def forward(self, x):
        raise NotImplementedError("This method should be implemented by subclass.")

    def test_forward(self):
        """Run ``forward`` on a random ``(2, in_dim, in_channels)`` input and
        assert the output shape is ``(2, out_dim, out_channels)``."""
        expected_shape = (2, self.out_dim, self.out_channels)
        test_input = torch.randn(2, self.in_dim, self.in_channels).to(self.device)
        output = self.forward(test_input)
        assert output.shape == expected_shape, f"\
        input shape is {test_input.shape}, \n\
        Output shape is {output.shape}, \n\
        expected {(2, self.out_dim, self.out_channels)}"

class SignalProcessingModuleDict(torch.nn.ModuleDict):
    """``ModuleDict`` that dispatches ``forward(x, key)`` to the module stored
    under ``key`` and can run every member's ``test_forward``."""

    def __init__(self, module_dict):
        super().__init__(module_dict)

    def forward(self, x, key):
        # Guard clause: unknown keys are reported explicitly.
        if key not in self:
            raise KeyError(f"No signal processing module found for key: {key}")
        return self[key](x)

    def test_forward(self):
        for module in self.values():
            module.test_forward()
class Args:
    """Plain container for the test configuration.

    Stores the six constructor arguments and sets ``f_c_mu``/``f_b_mu`` to 0
    and ``f_c_sigma``/``f_b_sigma`` to 1.
    """

    def __init__(self, in_dim, out_dim, in_channels, out_channels, device, scale):
        self.in_dim, self.out_dim = in_dim, out_dim
        self.in_channels, self.out_channels = in_channels, out_channels
        self.device, self.scale = device, scale
        self.f_c_mu, self.f_c_sigma = 0, 1
        self.f_b_mu, self.f_b_sigma = 0, 1

# Set up the test parameters
args = Args(in_dim=49, out_dim=49, in_channels=3, out_channels=3, device='cuda', scale=3)

# Create the filter instances
morlet_filter = MorletFilter(args)
laplace_filter = LaplaceFilter(args)
order1_ma_filter = Order1MAFilter(args)
order2_ma_filter = Order2MAFilter(args)
order1_df_filter = Order1DFFilter(args)
order2_df_filter = Order2DFFilter(args)

# Build the module dictionary
module_dict = {
    "Morlet": morlet_filter,
    "Laplace": laplace_filter,
    "order1_MA": order1_ma_filter,
    "order2_MA": order2_ma_filter,
    "order1_DF": order1_df_filter,
    "order2_DF": order2_df_filter
}

spm_dict = SignalProcessingModuleDict(module_dict)

# Test all modules
spm_dict.test_forward()

print("All tests passed.")


# Create the operation instances
log_op = LogOperation(args)
squ_op = SquOperation(args)
sin_op = SinOperation(args)

# Test all operations
log_op.test_forward()
squ_op.test_forward()
sin_op.test_forward()

print("All tests passed.")


All tests passed.
All tests passed.


In [7]:
import torch

# Define the base class and the individual operation classes
class SignalProcessingBase2Arity(torch.nn.Module):
    """Base class for two-operand signal operations.

    The input is split in two along its last axis at ``in_channels // 2`` and
    the two parts are handed to ``operation``, which subclasses implement.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.in_dim, self.out_dim = args.in_dim, args.out_dim
        self.in_channels, self.out_channels = args.in_channels, args.out_channels
        self.device = args.device
        self.to(self.device)

    def split_input(self, x):
        # Split the input signal into two parts along the last axis
        cut = self.in_channels // 2
        return x[:, :, :cut], x[:, :, cut:]

    def forward(self, x):
        first, second = self.split_input(x)
        return self.operation(first, second)

    def operation(self, x1, x2):
        raise NotImplementedError("This method should be implemented by subclass.")

    def test_forward(self):
        """Run ``forward`` on a random ``(2, in_dim, in_channels)`` input and
        assert the output shape is ``(2, out_dim, out_channels)``."""
        expected_shape = (2, self.out_dim, self.out_channels)
        test_input = torch.randn(2, self.in_dim, self.in_channels).to(self.device)
        output = self.forward(test_input)
        assert output.shape == expected_shape, f"\
        input shape is {test_input.shape}, \n\
        Output shape is {output.shape}, \n\
        expected {(2, self.out_dim, self.out_channels)}"

class AddOperation(SignalProcessingBase2Arity):
    """Two-operand operation returning the element-wise sum of both parts."""

    def __init__(self, args):
        super().__init__(args)
        self.name = "add"

    def operation(self, x1, x2):
        total = x1 + x2
        return total

class MulOperation(SignalProcessingBase2Arity):
    """Two-operand operation returning the element-wise product of both parts."""

    def __init__(self, args):
        super().__init__(args)
        self.name = "mul"

    def operation(self, x1, x2):
        product = x1 * x2
        return product

class DivOperation(SignalProcessingBase2Arity):
    """Two-operand operation returning ``x1 / x2`` element-wise.

    No guard against zeros in ``x2`` is applied here.
    """

    def __init__(self, args):
        super().__init__(args)
        self.name = "div"

    def operation(self, x1, x2):
        quotient = x1 / x2
        return quotient

# Set up the test parameters
class Args:
    """Plain container holding the six test-configuration values."""

    def __init__(self, in_dim, out_dim, in_channels, out_channels, device, scale):
        self.in_dim, self.out_dim = in_dim, out_dim
        self.in_channels, self.out_channels = in_channels, out_channels
        self.device, self.scale = device, scale

# Four input channels are split into two halves of two, hence out_channels=2.
args = Args(in_dim=49, out_dim=49, in_channels=4, out_channels=2, device='cuda', scale=3)

# Create the operation instances
add_op = AddOperation(args)
mul_op = MulOperation(args)
div_op = DivOperation(args)

# Test the split_input method
def test_split_input(operation):
    """Check that ``operation.split_input`` halves the last axis of a random
    ``(2, args.in_dim, args.in_channels)`` tensor; reads the global ``args``."""
    half_shape = (2, args.in_dim, args.in_channels // 2)
    test_input = torch.randn(2, args.in_dim, args.in_channels).to(args.device)
    x1, x2 = operation.split_input(test_input)
    assert x1.shape == half_shape, f"Expected shape: {(2, args.in_dim, args.in_channels // 2)}, but got {x1.shape}"
    assert x2.shape == half_shape, f"Expected shape: {(2, args.in_dim, args.in_channels // 2)}, but got {x2.shape}"
    print(f"split_input test passed for {operation.name}")

# 测试 forward 方法
def test_forward(operation):
    operation.test_forward()
    print(f"forward test passed for {operation.name}")

# Run the tests
for op in [add_op, mul_op, div_op]:
    test_split_input(op)
    test_forward(op)

print("All tests passed.")


split_input test passed for add
forward test passed for add
split_input test passed for mul
forward test passed for mul
split_input test passed for div
forward test passed for div
All tests passed.


# 测试逻辑

In [1]:
import torch
import torch.nn as nn

# One-element float tensors holding 1 and 0, moved to CUDA when the cell runs
# (so this cell requires a CUDA device).
ONE = torch.Tensor([1]).cuda()
ZERO = torch.Tensor([0]).cuda()

class LogicInferenceBase(nn.Module):
    """Base module bundling smooth (differentiable) fuzzy-logic connectives.

    All connectives are static methods working element-wise on tensors.

    Changes from the previous version:
    * ``generalized_softmax`` uses the algebraically identical form
      ``y + (x - y) * sigmoid(alpha * (x - y))``; the former
      ``exp(alpha * x)`` overflowed to ``inf / inf = NaN`` for moderately
      large inputs.
    * The constants 1 and 0 are created on the operands' device instead of
      being read from CUDA-resident globals, so CPU inputs work too.
    """

    def __init__(self, args):
        super(LogicInferenceBase, self).__init__()
        self.args = args
        self.in_dim = args.in_dim
        self.out_dim = args.out_dim
        self.in_channels = args.in_channels
        self.out_channels = args.out_channels
        self.device = args.device
        self.to(self.device)

    @staticmethod
    def _constant(value, *operands):
        """One-element float tensor holding ``value`` on the device of the
        first tensor operand (default device when there is none)."""
        device = next((t.device for t in operands if torch.is_tensor(t)), None)
        return torch.full((1,), float(value), device=device)

    @staticmethod
    def generalized_softmax(x, y, alpha=20):
        """Smooth maximum: softmax-weighted average of ``x`` and ``y`` with
        sharpness ``alpha``."""
        gap = x - y
        return y + gap * torch.sigmoid(alpha * gap)

    @staticmethod
    def generalized_softmin(x, y, alpha=20):
        """Smooth minimum, defined as ``-softmax(-x, -y)``."""
        return -LogicInferenceBase.generalized_softmax(-x, -y, alpha=alpha)

    @staticmethod
    def implication(x, y):
        """Smooth ``min(1, 1 - x + y)``."""
        one = LogicInferenceBase._constant(1, x, y)
        return LogicInferenceBase.generalized_softmin(one, one - x + y)

    @staticmethod
    def equivalence(x, y):
        """``1 - |x - y|``."""
        return LogicInferenceBase._constant(1, x, y) - torch.abs(x - y)

    @staticmethod
    def negation(x):
        """``1 - x``."""
        return LogicInferenceBase._constant(1, x) - x

    @staticmethod
    def weak_conjunction(x, y):
        """Smooth ``min(x, y)``."""
        return LogicInferenceBase.generalized_softmin(x, y)

    @staticmethod
    def weak_disjunction(x, y):
        """Smooth ``max(x, y)``."""
        return LogicInferenceBase.generalized_softmax(x, y)

    @staticmethod
    def strong_conjunction(x, y):
        """Smooth ``max(0, x + y - 1)``."""
        zero = LogicInferenceBase._constant(0, x, y)
        return LogicInferenceBase.generalized_softmax(zero, x + y - 1)

    @staticmethod
    def strong_disjunction(x, y):
        """Smooth ``min(1, x + y)``."""
        one = LogicInferenceBase._constant(1, x, y)
        return LogicInferenceBase.generalized_softmin(one, x + y)


In [11]:
import torch
import torch.nn as nn

# One-element float tensors holding 1 and 0, moved to CUDA when the cell runs
# (so this cell requires a CUDA device).
ONE = torch.Tensor([1]).cuda()
ZERO = torch.Tensor([0]).cuda()

class LogicInferenceBase(nn.Module):
    """Base module bundling smooth (differentiable) fuzzy-logic connectives.

    All connectives are static methods working element-wise on tensors;
    subclasses implement ``forward``.

    Changes from the previous version:
    * ``generalized_softmax`` uses the algebraically identical form
      ``y + (x - y) * sigmoid(alpha * (x - y))``; the former
      ``exp(alpha * x)`` overflowed to ``inf / inf = NaN`` for moderately
      large inputs.
    * The constants 1 and 0 are created on the operands' device instead of
      being read from CUDA-resident globals, so CPU inputs work too.
    """

    def __init__(self, args):
        super(LogicInferenceBase, self).__init__()
        self.args = args
        self.in_channels = args.in_channels
        self.out_channels = args.out_channels
        self.device = args.device
        self.to(self.device)

    @staticmethod
    def _constant(value, *operands):
        """One-element float tensor holding ``value`` on the device of the
        first tensor operand (default device when there is none)."""
        device = next((t.device for t in operands if torch.is_tensor(t)), None)
        return torch.full((1,), float(value), device=device)

    @staticmethod
    def generalized_softmax(x, y, alpha=20):
        """Smooth maximum: softmax-weighted average of ``x`` and ``y`` with
        sharpness ``alpha``."""
        gap = x - y
        return y + gap * torch.sigmoid(alpha * gap)

    @staticmethod
    def generalized_softmin(x, y, alpha=20):
        """Smooth minimum, defined as ``-softmax(-x, -y)``."""
        return -LogicInferenceBase.generalized_softmax(-x, -y, alpha=alpha)

    @staticmethod
    def implication(x, y):
        """Smooth ``min(1, 1 - x + y)``."""
        one = LogicInferenceBase._constant(1, x, y)
        return LogicInferenceBase.generalized_softmin(one, one - x + y)

    @staticmethod
    def equivalence(x, y):
        """``1 - |x - y|``."""
        return LogicInferenceBase._constant(1, x, y) - torch.abs(x - y)

    @staticmethod
    def negation(x):
        """``1 - x``."""
        return LogicInferenceBase._constant(1, x) - x

    @staticmethod
    def weak_conjunction(x, y):
        """Smooth ``min(x, y)``."""
        return LogicInferenceBase.generalized_softmin(x, y)

    @staticmethod
    def weak_disjunction(x, y):
        """Smooth ``max(x, y)``."""
        return LogicInferenceBase.generalized_softmax(x, y)

    @staticmethod
    def strong_conjunction(x, y):
        """Smooth ``max(0, x + y - 1)``."""
        zero = LogicInferenceBase._constant(0, x, y)
        return LogicInferenceBase.generalized_softmax(zero, x + y - 1)

    @staticmethod
    def strong_disjunction(x, y):
        """Smooth ``min(1, x + y)``."""
        one = LogicInferenceBase._constant(1, x, y)
        return LogicInferenceBase.generalized_softmin(one, x + y)

    def test_forward(self):
        """Run ``forward`` on a random ``(2, in_channels)`` input and assert
        the output shape is ``(2, out_channels)``."""
        test_input = torch.randn(2, self.in_channels).to(self.device)
        output = self.forward(test_input)
        assert output.shape == (2, self.out_channels), f"\
        input shape is {test_input.shape}, \n\
        Output shape is {output.shape}, \n\
        expected {(2, self.out_channels)}"

    def forward(self, x):
        raise NotImplementedError("This method should be implemented by subclass.")

    

        

In [12]:
class LogicInferenceBase2Arity(LogicInferenceBase):
    """Base class for two-operand logic operations on ``(batch, channels)`` inputs.

    ``forward`` splits the input in two along dim 1 at ``in_channels // 2``,
    applies ``operation`` (implemented by subclasses) and concatenates the
    result with itself along the last axis.

    Fix: ``forward`` ended with a bare ``return`` and therefore always
    returned ``None``; it now returns the computed tensor.
    """

    def __init__(self, args):
        super(LogicInferenceBase2Arity, self).__init__(args)

    def split_input(self, x):
        # Split the input signal into two parts along dim 1
        half_channels = self.in_channels // 2
        x1 = x[:, :half_channels]
        x2 = x[:, half_channels:]
        return x1, x2

    def repeat_input(self, x):
        """Concatenate ``x`` with itself along the last axis."""
        return torch.cat([x, x], dim=-1)

    def forward(self, x):
        x1, x2 = self.split_input(x)
        x = self.operation(x1, x2)
        x = self.repeat_input(x)
        return x

    def operation(self, x1, x2):
        raise NotImplementedError("This method should be implemented by subclass.")

    def test_forward(self):
        """Run ``forward`` on a random ``(2, in_channels)`` input and assert
        the output shape is ``(2, out_channels)``."""
        test_input = torch.randn(2, self.in_channels).to(self.device)
        output = self.forward(test_input)
        assert output.shape == (2, self.out_channels), f"\
        input shape is {test_input.shape}, \n\
        Output shape is {output.shape}, \n\
        expected {(2, self.out_channels)}"


In [13]:
class ImplicationOperation(LogicInferenceBase2Arity):
    """Two-operand module applying ``LogicInferenceBase.implication``."""

    def __init__(self, args):
        super().__init__(args)
        self.name = "implication"

    def operation(self, x1, x2):
        result = LogicInferenceBase.implication(x1, x2)
        return result

class EquivalenceOperation(LogicInferenceBase2Arity):
    """Two-operand module applying ``LogicInferenceBase.equivalence``."""

    def __init__(self, args):
        super().__init__(args)
        self.name = "equivalence"

    def operation(self, x1, x2):
        result = LogicInferenceBase.equivalence(x1, x2)
        return result

class NegationOperation(LogicInferenceBase):
    """Unary module applying ``LogicInferenceBase.negation``."""

    def __init__(self, args):
        super().__init__(args)
        self.name = "negation"

    def forward(self, x):
        # Negation is unary: the whole input is negated, with no split.
        negated = LogicInferenceBase.negation(x)
        return negated

class WeakConjunctionOperation(LogicInferenceBase2Arity):
    """Two-operand module applying ``LogicInferenceBase.weak_conjunction``."""

    def __init__(self, args):
        super().__init__(args)
        self.name = "weak_conjunction"

    def operation(self, x1, x2):
        result = LogicInferenceBase.weak_conjunction(x1, x2)
        return result

class WeakDisjunctionOperation(LogicInferenceBase2Arity):
    """Two-operand module applying ``LogicInferenceBase.weak_disjunction``."""

    def __init__(self, args):
        super().__init__(args)
        self.name = "weak_disjunction"

    def operation(self, x1, x2):
        result = LogicInferenceBase.weak_disjunction(x1, x2)
        return result

class StrongConjunctionOperation(LogicInferenceBase2Arity):
    """Two-operand module applying ``LogicInferenceBase.strong_conjunction``."""

    def __init__(self, args):
        super().__init__(args)
        self.name = "strong_conjunction"

    def operation(self, x1, x2):
        result = LogicInferenceBase.strong_conjunction(x1, x2)
        return result

class StrongDisjunctionOperation(LogicInferenceBase2Arity):
    """Two-operand module applying ``LogicInferenceBase.strong_disjunction``."""

    def __init__(self, args):
        super().__init__(args)
        self.name = "strong_disjunction"

    def operation(self, x1, x2):
        result = LogicInferenceBase.strong_disjunction(x1, x2)
        return result


In [14]:
class Args:
    """Plain container holding ``in_channels``, ``out_channels`` and ``device``."""

    def __init__(self, in_channels, out_channels, device):
        self.in_channels, self.out_channels = in_channels, out_channels
        self.device = device

# Shared configuration for every logic operation instantiated below.
args = Args(in_channels=4, out_channels=4, device='cuda')

# Create the operation instances
implication_op = ImplicationOperation(args)
equivalence_op = EquivalenceOperation(args)
negation_op = NegationOperation(args)
weak_conjunction_op = WeakConjunctionOperation(args)
weak_disjunction_op = WeakDisjunctionOperation(args)
strong_conjunction_op = StrongConjunctionOperation(args)
strong_disjunction_op = StrongDisjunctionOperation(args)

# Test the split_input method
def test_split_input(operation):
    """Check that ``operation.split_input`` halves dim 1 of a random
    ``(2, args.in_channels)`` tensor; reads the global ``args``."""
    half_shape = (2, args.in_channels // 2)
    test_input = torch.randn(2, args.in_channels).to(args.device)
    x1, x2 = operation.split_input(test_input)
    assert x1.shape == half_shape, f"Expected shape: {(2, args.in_channels // 2)}, but got {x1.shape}"
    assert x2.shape == half_shape, f"Expected shape: {(2, args.in_channels // 2)}, but got {x2.shape}"
    print(f"split_input test passed for {operation.name}")

# 测试 forward 方法
def test_forward(operation):
    operation.test_forward()
    print(f"forward test passed for {operation.name}")

# Run the tests
for op in [implication_op, equivalence_op, negation_op, weak_conjunction_op, weak_disjunction_op, strong_conjunction_op, strong_disjunction_op]:
    if op is not negation_op:  # negation_op is unary, so it has no split_input method to test
        test_split_input(op)
    test_forward(op)

print("All tests passed.")


split_input test passed for implication
forward test passed for implication
split_input test passed for equivalence
forward test passed for equivalence
forward test passed for negation
split_input test passed for weak_conjunction
forward test passed for weak_conjunction
split_input test passed for weak_disjunction
forward test passed for weak_disjunction
split_input test passed for strong_conjunction
forward test passed for strong_conjunction
split_input test passed for strong_disjunction
forward test passed for strong_disjunction
All tests passed.
