In [5]:
import torch
import torch.nn as nn

class CustomBatchNorm(nn.Module):
    """Batch normalization without learnable affine parameters.

    Each channel (dim 1 of the input) is normalized with statistics gathered
    over all other dimensions: over the batch for ``(N, C)`` input, and over
    batch and length for ``(N, C, L)`` input.  In training mode the current
    batch statistics are used and exponential running averages are updated;
    in eval mode the running averages are used.

    Args:
        num_features: number of channels ``C``; size of the running buffers.
        eps: added to the standard deviation to avoid division by zero.
        momentum: weight of the current batch in the running-average update,
            ``running = (1 - momentum) * running + momentum * batch``.  Kept
            separate from ``eps``: reusing a 1e-10 epsilon as the momentum
            leaves the running statistics effectively frozen.
    """

    def __init__(self, num_features, eps=1e-10, momentum=0.1):
        super(CustomBatchNorm, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, x):
        """Normalize ``x`` of shape ``(N, C, *)`` where ``C == num_features``.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: if ``x`` has fewer than 2 dims or the wrong channel count.
        """
        if x.dim() < 2 or x.size(1) != self.num_features:
            raise ValueError(
                f"expected input of shape (N, {self.num_features}, *), got {tuple(x.shape)}"
            )
        if self.training:
            # Reduce over every dim except the channel dim -> one statistic per channel.
            reduce_dims = (0,) + tuple(range(2, x.dim()))
            mean = x.mean(dim=reduce_dims, keepdim=True)
            var = x.var(dim=reduce_dims, keepdim=True, unbiased=False)
            # In-place and outside autograd: keeps the buffers at shape (C,)
            # and avoids retaining every batch's computation graph.
            with torch.no_grad():
                self.running_mean.mul_(1 - self.momentum).add_(self.momentum * mean.reshape(-1))
                self.running_var.mul_(1 - self.momentum).add_(self.momentum * var.reshape(-1))
        else:
            # (1, C, 1, ...) so the per-channel buffers broadcast against x.
            stat_shape = (1, self.num_features) + (1,) * (x.dim() - 2)
            mean = self.running_mean.view(stat_shape)
            var = self.running_var.view(stat_shape)
        return (x - mean) / (var.sqrt() + self.eps)

In [6]:
# Fix the RNG so the random input below is reproducible across runs.
torch.manual_seed(0)

# Create a random input tensor of shape (64, 624, 1)
input_tensor = torch.randn(64, 624, 1)

# Create a CustomBatchNorm instance with 624 features (matches dim 1 of the input)
norm_layer = CustomBatchNorm(num_features=624)

# Pass the input tensor through norm_layer (a freshly built nn.Module is in training mode)
output_tensor = norm_layer(input_tensor)

# Print the shape of the output tensor
print(output_tensor.shape)

torch.Size([64, 624, 1])


In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

class SEAttention(nn.Module):
    """Squeeze-and-Excitation channel attention for channels-last sequences.

    Averages the input over the sequence axis, feeds the per-channel means
    through a two-layer bottleneck MLP ending in a sigmoid, and rescales every
    channel of the input by the resulting gate.

    Args:
        channel: number of channels ``c`` (last input dimension).
        reduction: bottleneck reduction ratio; the hidden width is
            ``channel // reduction``, clamped to at least 1.
    """

    def __init__(self, channel=512, reduction=4):
        super().__init__()
        # Without the clamp, channel < reduction yields a zero-width bottleneck:
        # the MLP then always outputs 0 and the gate is a constant 0.5.
        hidden = max(1, channel // reduction)
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Rescale the channels of ``x`` shaped ``(b, l, c)``; returns the same shape."""
        b, _, c = x.size()
        # AdaptiveAvgPool1d pools over the last axis, so move channels to dim 1.
        x_cl = x.transpose(1, 2)                     # (b, c, l)
        squeezed = self.avg_pool(x_cl).view(b, c)    # (b, c)
        gate = self.fc(squeezed).view(b, c, 1)       # (b, c, 1)
        return (x_cl * gate).transpose(1, 2)         # back to (b, l, c)

torch.manual_seed(0)  # reproducible demo input and layer initialization
x = torch.randn(2, 128, 16)
att = SEAttention(channel=16)
y = att(x)
assert y.shape == x.shape  # SE attention only rescales channels, never reshapes

# Other wavelets

## Sinc, Mexican-hat, Laplace and Morlet kernels

In [None]:
import math

def sinc(band, t_right):
    """Build a symmetric sinc kernel from its right half.

    The right half is ``sin(2*pi*band*t) / (2*pi*band*t)``; the result is
    ``[flip(right), 1, right]`` concatenated along dim 0.

    Args:
        band: bandwidth; a scalar or tensor broadcastable against ``t_right``.
        t_right: 1-D tensor of time offsets for the right half.

    Returns:
        1-D tensor of length ``2 * len(t_right) + 1``.
    """
    # torch.sinc(x) = sin(pi*x) / (pi*x) and evaluates to exactly 1 at x == 0,
    # so torch.sinc(2*band*t) == sin(2*pi*band*t) / (2*pi*band*t) without a
    # "+1e-6" denominator hack (which gave 0 at t == 0 and exploded for t
    # slightly below zero).
    y_right = torch.sinc(2 * band * t_right)
    y_left = torch.flip(y_right, [0])
    center = torch.ones(1, dtype=y_right.dtype, device=y_right.device)
    return torch.cat([y_left, center, y_right])

def Mexh(p):
    """Mexican-hat (Ricker) wavelet shape without its normalization constant.

    Computes ``(1 - p**2) * exp(-p**2 / 2)`` element-wise.

    Args:
        p: tensor of time points.

    Returns:
        Tensor with the same shape as ``p``.
    """
    p_sq = p * p
    envelope = torch.exp(-0.5 * p_sq)
    return (1 - p_sq) * envelope

def Laplace(p, A=0.08, ep=0.03, tal=0.1, f=50):
    """Laplace-wavelet-shaped kernel evaluated element-wise at time points ``p``.

    y = A * exp(-ep / sqrt(1 - ep**2) * w * (p - tal)) * (-sin(w * (p - tal))),
    with w = 2 * pi * f.

    Args:
        p: tensor of time points.
        A: amplitude.
        ep: damping factor; must satisfy |ep| < 1 so the square root is real.
        tal: time shift.
        f: frequency; the angular frequency is ``w = 2 * pi * f``.

    Returns:
        Tensor with the same shape as ``p``.
    """
    # math.pi: only the `math` module is imported for this cell, not a bare `pi`.
    w = 2 * math.pi * f
    decay = -ep / math.sqrt(1 - ep ** 2)
    shifted = p - tal
    return A * torch.exp(decay * (w * shifted)) * (-torch.sin(w * shifted))

class SincConv_multiple_channel(nn.Module):
    """Bank of learnable, shifted sinc filters applied to every input channel.

    Filter ``i`` is built by ``sinc`` from a learnable bandwidth ``a_[i]`` and a
    learnable shift ``b_[i]``.  The same filter bank is convolved with each
    input channel separately and the results are concatenated along dim 1, so
    the output has ``in_channels * out_channels`` channels.

    Args:
        out_channels: number of sinc filters.
        kernel_size: requested kernel size; an even value is bumped to the next odd one.
        in_channels: number of input channels, each filtered independently.
    """

    def __init__(self, out_channels, kernel_size, in_channels=1):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size

        if kernel_size % 2 == 0:
            self.kernel_size += 1

        # Reshape *before* wrapping in nn.Parameter: calling .view() on a
        # Parameter returns a plain tensor, which leaves a_/b_ unregistered
        # (invisible to optimizers, state_dict and Module.to()).
        self.a_ = nn.Parameter(torch.linspace(1, 10, out_channels).view(-1, 1))
        self.b_ = nn.Parameter(torch.linspace(0, 10, out_channels).view(-1, 1))

    def forward(self, waveforms):
        """Filter each input channel with the sinc bank.

        Args:
            waveforms: tensor sliced as ``waveforms[:, i:i+1]`` per input channel;
                presumably ``(batch, in_channels, time)`` — TODO confirm with callers.

        Returns:
            Concatenation along dim 1 of the per-channel convolution outputs.
        """
        device = waveforms.device
        half_kernel = self.kernel_size // 2
        time_disc = torch.linspace(-half_kernel, half_kernel, steps=self.kernel_size, device=device)
        # Local copies only: assigning a moved tensor back onto a registered
        # Parameter raises TypeError. Gradients still flow through .to().
        a = self.a_.to(device)
        b = self.b_.to(device)

        filters = torch.stack([sinc(a[i], time_disc - b[i]) for i in range(self.out_channels)])
        # NOTE(review): sinc() returns 2 * kernel_size + 1 taps for this
        # kernel_size-point time axis, so with padding=half_kernel the output
        # length is L - kernel_size - 1 rather than L. Confirm the intent.
        self.filters = filters.view(self.out_channels, 1, -1)

        outputs = [
            F.conv1d(waveforms[:, i:i + 1], self.filters, stride=1, padding=half_kernel,
                     dilation=1, bias=None, groups=1)
            for i in range(self.in_channels)
        ]
        return torch.cat(outputs, dim=1)


class Morlet_multiple_channel(nn.Module):
    """Bank of learnable Morlet-style filters applied to every input channel.

    Filter ``i`` samples ``Morlet`` at time points shifted by ``b_[i] / a_[i]``.
    The same filter bank is convolved with each input channel separately and
    the results are concatenated along dim 1.

    Args:
        out_channels: number of filters.
        kernel_size: requested kernel size; the stored ``self.kernel_size`` is
            ``kernel_size - 1`` for odd values and ``kernel_size`` for even
            values, i.e. always even.
        in_channels: number of input channels, each filtered independently.
    """

    def __init__(self, out_channels, kernel_size, in_channels=1):

        super(Morlet_multiple_channel, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size - 1

        if kernel_size % 2 == 0:
            self.kernel_size = self.kernel_size + 1

        # Reshape *before* wrapping in nn.Parameter: .view() on a Parameter
        # returns a plain tensor, leaving a_/b_ unregistered.
        self.a_ = nn.Parameter(torch.linspace(1, 10, out_channels).view(-1, 1))

        self.b_ = nn.Parameter(torch.linspace(0, 10, out_channels).view(-1, 1))

    def forward(self, waveforms):
        """Filter each input channel with the Morlet bank.

        Args:
            waveforms: tensor sliced as ``waveforms[:, i:i+1]`` per input channel;
                presumably ``(batch, in_channels, time)`` — TODO confirm with callers.

        Returns:
            Concatenation along dim 1 of the per-channel convolution outputs.
        """
        half = self.kernel_size // 2
        # Build the time axes on the parameters' device so the arithmetic with
        # a_/b_ still works after Module.to()/cuda() has moved them.
        param_device = self.a_.device
        time_disc_right = torch.linspace(0, (self.kernel_size / 2) - 1,
                                         steps=half, device=param_device)

        time_disc_left = torch.linspace(-(self.kernel_size / 2) + 1, -1,
                                        steps=half, device=param_device)

        # NOTE(review): evaluates at t - b/a rather than the usual (t - b)/a
        # wavelet scaling; confirm this is intended.
        shift = self.b_ / self.a_
        p1 = time_disc_right - shift
        p2 = time_disc_left - shift

        Morlet_right = Morlet(p1).to(waveforms.device)
        Morlet_left = Morlet(p2).to(waveforms.device)

        # Assumes Morlet() is element-wise, so each half is
        # (out_channels, kernel_size // 2) — TODO confirm.
        Morlet_filter = torch.cat([Morlet_left, Morlet_right], dim=1)

        self.filters = Morlet_filter.view(self.out_channels, 1, self.kernel_size)

        # NOTE(review): with an even kernel and padding=1 the output length is
        # L - kernel_size + 3, not L. Kept unchanged.
        output = [
            F.conv1d(waveforms[:, i:i + 1], self.filters, stride=1, padding=1,
                     dilation=1, bias=None, groups=1)
            for i in range(self.in_channels)
        ]
        return torch.cat(output, dim=1)
    
