In [1]:
# 0. The core logic of the WPD_PWCN model
# 1. The input signal is segmented into 1000 data points, corresponding to a temporal resolution of 10 ms, to ensure consistent time scales across all samples.
# 2. Each signal is subjected to a three-level wavelet packet decomposition using the db4 wavelet basis, yielding eight sub-bands that capture distinct frequency-domain characteristics.
# 3. Each sub-band is convolved with 16 learnable Morlet wavelet kernels, resulting in the extraction of 128 low-level features in the first convolutional layer, which possess explicit physical interpretability.
# 4. An attention mechanism is embedded within the first wavelet convolution module to adaptively enhance the feature response of critical frequency components, thereby improving the overall representational capacity.
# 5. The extracted features are subsequently fed into a deep convolutional neural network for further high-level feature abstraction and pattern discrimination, enabling accurate classification of signal categories.

In [None]:
import torch
import numpy as np
import scipy.io
import pandas as pd
import copy
import sys
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset
from PIL import Image
import torchvision
import imageio
import time
from torchvision import transforms, models, datasets
from torchvision import transforms
import os
import random
from torchinfo import summary
import os   
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

In [None]:
# The location of the dataset.
# The directory is expected to contain one sub-folder per class (see the scan
# cell that walks `root`). Set the WPD_PWCN_DATA_ROOT environment variable to
# run on another machine without editing the notebook; when it is unset the
# original path is used unchanged.
root = os.environ.get("WPD_PWCN_DATA_ROOT", "C:\\Users\\asus\\Desktop\\JMP4\\expdata\\db4")

In [None]:
class MyDataSet(Dataset):
    """Dataset of MATLAB files, each contributing eight rows of its 'feature' array.

    Args:
        txts_path1: list of paths to the .mat files, one per sample.
        txts_class: list of labels, aligned index-for-index with ``txts_path1``.
        transform: stored on the instance but not applied anywhere in this class.
    """

    def __init__(self, txts_path1: list, txts_class: list, transform=None):
        self.txts_path1 = txts_path1
        self.txts_class = txts_class
        self.transform = transform

    def __len__(self):
        """Number of samples, i.e. the number of file paths."""
        return len(self.txts_path1)

    def __getitem__(self, item):
        """Load sample ``item`` and return ``(band1, ..., band8, label)``.

        The 'feature' entry of the file gets a singleton axis inserted at
        position 1 and is cast to float32; rows 0-7 are then returned as
        separate tensors.
        presumably 'feature' is (8, N), giving eight (1, N) tensors — TODO confirm.
        """
        mat_contents = scipy.io.loadmat(self.txts_path1[item])
        feature = np.expand_dims(mat_contents['feature'], 1)
        bands = torch.from_numpy(feature.astype(np.float32))

        # Rows 0..7 of the first axis, in order; fewer than eight rows raises IndexError.
        sub_bands = tuple(bands[row] for row in range(8))

        label = self.txts_class[item]

        return (*sub_bands, label)

In [None]:
import os
import random

manualSeed = 666
random.seed(manualSeed)
# Seed the other RNGs this notebook relies on as well: KFold(shuffle=True)
# without a random_state draws from NumPy's global generator, and DataLoader
# shuffling / weight initialisation draw from torch's global generator.
np.random.seed(manualSeed)
torch.manual_seed(manualSeed)

# Get the list of category folders (sorted so the class -> index mapping is stable)
item_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
item_class.sort()

# Categories are mapped to the index
class_indices = dict((k, v) for v, k in enumerate(item_class))

all_images_path = []
all_images_label = []
every_class_num = []

for cla in item_class:
    cla_path = os.path.join(root, cla)
    # os.listdir returns entries in arbitrary order; sorting keeps the sample
    # order (and therefore the cross-validation folds) identical across runs
    # and operating systems.
    images = sorted(os.path.join(cla_path, i) for i in os.listdir(cla_path))
    image_class = class_indices[cla]

    all_images_path.extend(images)
    all_images_label.extend([image_class] * len(images))
    every_class_num.append(len(images))

print("{} images were found in the dataset.".format(sum(every_class_num)))

In [None]:
# A single dataset over every discovered file; the per-fold train/validation
# subsets are carved out of it later inside train_model.
full_dataset = MyDataSet(all_images_path, all_images_label)

# Shuffled mini-batches of 32, loaded in the main process (num_workers=0).
train_loader = torch.utils.data.DataLoader(full_dataset, batch_size=32, shuffle=True, num_workers=0)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

# Fetch exactly one batch. The previous enumerate/break loop pulled a second
# batch before breaking (so the preview came from batch 2 whenever one existed)
# and left txt1 undefined for an empty loader.
txt1, txt2, txt3, txt4, txt5, txt6, txt7, txt8, label = next(iter(train_loader))

# First sample of the first sub-band tensor in the batch.
sample = txt1[0]
print(sample.shape)
sample = torch.squeeze(sample, dim=0)

sample = sample.numpy()
plt.plot(sample)

In [None]:
import torch
import torch.nn as nn
from math import pi
import torch.nn.functional as F

In [None]:
def Morlet(p):
    """Evaluate a real Morlet-style wavelet at the points ``p``.

    Computes ``pi**0.25 * exp(-p**2 / 2) * cos(2*pi*p)`` element-wise.

    Args:
        p: tensor of evaluation points.

    Returns:
        Tensor of the same shape as ``p``.

    NOTE(review): the constant is pi**(+1/4); textbook Morlet normalisation is
    usually pi**(-1/4) — confirm whether the sign is intentional.
    """
    envelope = torch.exp(-p.pow(2) / 2)   # Gaussian window
    carrier = torch.cos(2 * pi * p)       # unit-frequency cosine
    return pi ** 0.25 * envelope * carrier

class Morlet_fast(nn.Module):
    """1-D convolution whose kernels are generated from the ``Morlet`` function.

    Each of the ``out_channels`` kernels is defined by two learnable scalars,
    ``a_`` and ``b_`` (both stored with shape ``(out_channels, 1)``); the kernels
    are rebuilt from them on every forward pass.

    Args:
        out_channels: number of wavelet kernels (output channels).
        kernel_size: requested kernel length. The effective length stored in
            ``self.kernel_size`` is always even: an even request is kept, an odd
            request is reduced by one.
        in_channels: must be 1.

    Raises:
        ValueError: if ``in_channels`` is not 1.
    """

    def __init__(self, out_channels, kernel_size, in_channels=1):

        super(Morlet_fast, self).__init__()

        if in_channels != 1:

            msg = "Morlet_fast only supports one input channel (here, in_channels = %i)" % (in_channels)
            raise ValueError(msg)

        self.out_channels = out_channels
        self.kernel_size = kernel_size - 1

        if kernel_size % 2 == 0:
            self.kernel_size = self.kernel_size + 1

        # Reshape BEFORE wrapping in nn.Parameter. Calling .view() on a Parameter
        # returns a plain tensor, which nn.Module does not register: it would be
        # missing from parameters() (never optimised), from state_dict() (never
        # saved) and would not follow model.to(device).
        # Note: state dicts saved before this fix do not contain a_/b_ keys.
        self.a_ = nn.Parameter(torch.linspace(1, 10, out_channels).view(-1, 1))

        self.b_ = nn.Parameter(torch.linspace(0, 10, out_channels).view(-1, 1))

    def forward(self, waveforms):
        """Convolve ``waveforms`` with the current wavelet kernels.

        Args:
            waveforms: input passed to ``F.conv1d`` against kernels of shape
                ``(out_channels, 1, kernel_size)``, i.e. a single-channel signal
                batch — presumably ``(batch, 1, length)``.

        Returns:
            The ``F.conv1d`` result (stride 1, padding 1, no bias).
        """
        half_len = int(self.kernel_size / 2)
        # Build the time grids on the parameters' own device/dtype instead of
        # forcing CUDA, so the layer also runs on CPU-only machines.
        grid_opts = dict(device=self.a_.device, dtype=self.a_.dtype)

        time_disc_right = torch.linspace(0, (self.kernel_size / 2) - 1,
                                         steps=half_len, **grid_opts)

        time_disc_left = torch.linspace(-(self.kernel_size / 2) + 1, -1,
                                        steps=half_len, **grid_opts)

        # NOTE(review): this evaluates t - (b / a). The conventional wavelet
        # argument is (t - b) / a; left untouched here because changing it alters
        # every filter — confirm the intended form.
        shift = self.b_ / self.a_
        p1 = time_disc_right - shift
        p2 = time_disc_left - shift

        Morlet_right = Morlet(p1)
        Morlet_left = Morlet(p2)

        # (out_channels, kernel_size): left half followed by right half
        Morlet_filter = torch.cat([Morlet_left, Morlet_right], dim=1)

        # Kept as an attribute so the most recent kernels can be inspected.
        self.filters = Morlet_filter.view(self.out_channels, 1, self.kernel_size)
        output = F.conv1d(waveforms, self.filters, stride=1, padding=1, dilation=1, bias=None, groups=1)

        return output

In [None]:
class ChannelAttention1D(nn.Module):
    """Channel attention for 1-D feature maps.

    Average- and max-pooled channel descriptors go through one shared
    bottleneck (two bias-free 1x1 convolutions with a ReLU in between); the two
    results are summed and squashed with a sigmoid to give one weight per channel.

    Args:
        in_channels: number of channels of the incoming feature map.
        ratio: reduction factor of the bottleneck (hidden width is
            ``in_channels // ratio``).
    """

    def __init__(self, in_channels, ratio=4):
        super().__init__()
        hidden_channels = in_channels // ratio

        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        self.max_pool = nn.AdaptiveMaxPool1d(1)

        # Shared by both pooled descriptors.
        self.fc = nn.Sequential(
            nn.Conv1d(in_channels, hidden_channels, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv1d(hidden_channels, in_channels, 1, bias=False),
        )

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``(weights, x, weights * x)`` for the input feature map ``x``."""
        from_avg = self.fc(self.avg_pool(x))
        from_max = self.fc(self.max_pool(x))
        weights = self.sigmoid(from_avg + from_max)
        return weights, x, weights * x

class SpatialAttention1D(nn.Module):
    """Position-wise attention for 1-D feature maps.

    The channel-wise mean and max of the input are stacked into a two-channel
    map, passed through one bias-free convolution and a sigmoid, which yields
    one weight per position along the last axis.

    Args:
        kernel_size: width of the convolution; only 3 or 7 is accepted.
    """

    def __init__(self, kernel_size=7):
        super().__init__()

        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        # "Same" padding for the two supported widths.
        if kernel_size == 7:
            padding = 3
        else:
            padding = 1
        self.conv1 = nn.Conv1d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``(weights, x, weights * x)`` for the input feature map ``x``."""
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True)[0]
        descriptor = torch.cat([channel_mean, channel_max], dim=1)
        weights = self.sigmoid(self.conv1(descriptor))
        return weights, x, weights * x

class CBAM1D(nn.Module):
    """Channel attention followed by spatial attention on a 1-D feature map.

    Args:
        in_channels: channel count of the incoming feature map.
        ratio: bottleneck reduction factor handed to ``ChannelAttention1D``.
        kernel_size: convolution width handed to ``SpatialAttention1D``.
    """

    def __init__(self, in_channels, ratio=4, kernel_size=3):
        super().__init__()
        self.channelattention = ChannelAttention1D(in_channels, ratio=ratio)
        self.spatialattention = SpatialAttention1D(kernel_size=kernel_size)

    def forward(self, x):
        """Return the input re-weighted first per channel, then per position."""
        # Both sub-modules return (weights, input, weighted); only the weighted
        # tensor (index 2) is carried forward.
        channel_refined = self.channelattention(x)[2]
        return self.spatialattention(channel_refined)[2]
    
class WPD_PWCN(nn.Module):
    """Eight parallel wavelet-convolution branches followed by a CNN classifier.

    Each of the eight inputs is processed by its own branch (``conv1`` ...
    ``conv8``): a 16-kernel ``Morlet_fast`` layer, batch norm, ReLU, max-pooling
    and a ``CBAM1D`` attention block. The eight 16-channel outputs are
    concatenated along the channel axis (128 channels), passed through
    ``conv9`` and three fully connected layers.

    The eight branches are built by one helper instead of eight copy-pasted
    definitions; attribute names, registration order and state-dict keys are
    the same as with the explicit definitions.

    Args:
        in_channel: accepted for interface compatibility; not used by the model.
        out_channel: number of output classes (size of the final linear layer).
    """

    def __init__(self, in_channel=1, out_channel=5):
        super(WPD_PWCN, self).__init__()

        # conv1 ... conv8: one independent branch per input.
        for branch_idx in range(1, 9):
            setattr(self, "conv%d" % branch_idx, self._make_branch())

        self.conv9 = nn.Sequential(
            nn.Conv1d(128, 128, 16),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.AdaptiveMaxPool1d(25),  # fixes the output length regardless of the input length
        )

        self.fc1 = nn.Sequential(
            nn.Linear(128 * 25, 256),
            nn.ReLU(),
            nn.Dropout(p=0.2),
        )

        self.fc2 = nn.Sequential(
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Dropout(p=0.2),
        )

        self.fc3 = nn.Linear(64, out_channel)

    @staticmethod
    def _make_branch():
        """Build one branch: wavelet conv stack followed by CBAM attention."""
        return nn.ModuleList([
            nn.Sequential(
                Morlet_fast(16, 32),
                nn.BatchNorm1d(16),
                nn.ReLU(),
                nn.MaxPool1d(kernel_size=2, stride=2),
            ),
            CBAM1D(in_channels=16),
        ])

    def forward(self, x1, x2, x3, x4, x5, x6, x7, x8):
        """Classify one batch given its eight inputs.

        Args:
            x1 ... x8: the eight single-channel inputs, one per branch, in the
                same order as ``conv1`` ... ``conv8``.

        Returns:
            Tensor of raw class scores with ``out_channel`` columns.
        """
        branch_outputs = []
        for branch_idx, signal in enumerate((x1, x2, x3, x4, x5, x6, x7, x8), start=1):
            for module in getattr(self, "conv%d" % branch_idx):
                signal = module(signal)
            branch_outputs.append(signal)

        # Stack the eight 16-channel maps into one 128-channel map.
        x = torch.cat(branch_outputs, 1)
        x = self.conv9(x)
        x = x.view(x.size()[0], -1)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)

        return x

In [None]:
# Build the network with its default configuration, spelled out explicitly
# (in_channel=1, out_channel=5 are the constructor defaults).
model = WPD_PWCN(in_channel=1, out_channel=5)

In [None]:
# Use the GPU when one is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("using {} device.".format(device))
model.to(device)

In [None]:
# Layer-by-layer summary for eight identical inputs of shape (1, 1, 1256).
summary(model, [(1, 1, 1256)] * 8)

In [None]:
# five-fold cross-validation
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import numpy as np
from sklearn.model_selection import KFold


def weights_init(m):
    """Re-initialise one module in place; intended for ``model.apply(weights_init)``.

    Called at the start of every cross-validation fold so that each fold trains
    from a fresh model.

    * ``nn.Conv1d`` / ``nn.Linear``: Xavier-uniform weights, zero bias.
    * ``nn.BatchNorm1d``: affine parameters and running statistics are reset;
      otherwise statistics learned in one fold carry over into the next.
    * ``Morlet_fast``: ``a_`` and ``b_`` are restored to the linspace values
      assigned in ``Morlet_fast.__init__``. Xavier initialisation is not used
      for them because it samples from a zero-centred range, and the layer
      divides by ``a_`` in its forward pass.

    Any other module type is left untouched.

    Args:
        m: the module to re-initialise.
    """
    if isinstance(m, (nn.Conv1d, nn.Linear)):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm1d):
        m.reset_parameters()
    elif isinstance(m, Morlet_fast):
        with torch.no_grad():
            # Same ranges as the constructor: a_ in [1, 10], b_ in [0, 10].
            m.a_.copy_(torch.linspace(1, 10, m.a_.numel()).view_as(m.a_))
            m.b_.copy_(torch.linspace(0, 10, m.b_.numel()).view_as(m.b_))

def _batch_to_device(batch, device):
    """Move every tensor of one ``(x1, ..., x8, label)`` batch onto ``device``."""
    return [tensor.to(device) for tensor in batch]


def _run_epoch(model, loader, criterion, device, optimizer=None, collect=False):
    """Run one pass over ``loader``: a training pass if ``optimizer`` is given, else evaluation.

    Returns ``(mean_loss, accuracy, labels, predictions)``; the two lists are
    only filled when ``collect`` is true.
    """
    training = optimizer is not None
    model.train(training)

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    labels, predictions = [], []

    with torch.set_grad_enabled(training):
        for batch in loader:
            *inputs, target = _batch_to_device(batch, device)

            if training:
                optimizer.zero_grad()
            output = model(*inputs)
            loss = criterion(output, target)
            if training:
                loss.backward()
                optimizer.step()

            pre_lab = torch.argmax(output, 1)
            batch_size = target.size(0)
            loss_sum += loss.item() * batch_size
            # .item() keeps the counter a plain int even if no batch is seen.
            n_correct += torch.sum(pre_lab == target).item()
            n_seen += batch_size

            if collect:
                labels.extend(target.cpu().numpy())
                predictions.extend(pre_lab.cpu().numpy())

    if n_seen == 0:
        raise ValueError("the data loader produced no samples")

    return loss_sum / n_seen, n_correct / n_seen, labels, predictions


def train_model(model, train_loader, train_rate=None, criterion=None, optimizer=None, num_epochs=None, num_folds=5):
    """Train and evaluate ``model`` with K-fold cross-validation.

    For every fold the model is re-initialised with ``weights_init``, the
    optimizer is restored to the state it had on entry, the model is trained
    for ``num_epochs`` epochs, and accuracy / macro precision / macro recall /
    macro F1 are computed on the fold's held-out part with the weights of the
    last epoch.

    Args:
        model: the network; batches are moved to the device of its parameters.
        train_loader: DataLoader whose ``dataset`` is split into folds and whose
            ``batch_size`` is reused for the per-fold loaders.
        train_rate: unused; kept (now optional) for backward compatibility.
        criterion: loss function called as ``criterion(output, target)``. Required.
        optimizer: optimizer built over ``model.parameters()``. Required.
        num_epochs: epochs per fold. Required.
        num_folds: number of cross-validation folds.

    Returns:
        ``(model, train_process)``: the model loaded with the weights that gave
        the highest validation accuracy over all folds and epochs, and a
        DataFrame with one row per trained epoch (all folds concatenated).
        NOTE(review): the best weights are chosen on the same held-out data the
        fold metrics are computed on — treat the returned model accordingly.

    Raises:
        ValueError: if ``criterion``, ``optimizer`` or ``num_epochs`` is missing,
            or if a fold's loader yields no samples.
    """
    # These only carry defaults because they follow the now-optional train_rate.
    if criterion is None or optimizer is None or num_epochs is None:
        raise ValueError("train_model requires criterion, optimizer and num_epochs")

    # Follow the model's device instead of assuming CUDA.
    device = next(model.parameters()).device
    dataset = train_loader.dataset
    batch_size = train_loader.batch_size

    best_model_wts = copy.deepcopy(model.state_dict())
    # Snapshot so every fold starts without the previous fold's optimizer
    # statistics (e.g. Adam moment estimates).
    initial_optimizer_state = copy.deepcopy(optimizer.state_dict())
    best_acc = 0.0
    train_loss_all = []
    train_acc_all = []
    val_loss_all = []
    val_acc_all = []
    since = time.time()

    kf = KFold(n_splits=num_folds, shuffle=True)

    accuracy_scores = []
    precision_scores = []
    recall_scores = []
    f1_scores = []

    # Split plain indices; KFold only needs the number of samples.
    for fold, (train_index, val_index) in enumerate(kf.split(np.arange(len(dataset))), 1):
        print(f'Fold {fold}/{num_folds}')
        train_dataset = torch.utils.data.Subset(dataset, train_index)
        val_dataset = torch.utils.data.Subset(dataset, val_index)

        train_loader_fold = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader_fold = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

        # Fresh model and fresh optimizer state for this fold.
        model.apply(weights_init)
        optimizer.load_state_dict(copy.deepcopy(initial_optimizer_state))

        for epoch in range(num_epochs):
            print(f'Epoch {epoch + 1}/{num_epochs}')
            print('-' * 10)

            # training
            train_loss, train_acc, _, _ = _run_epoch(
                model, train_loader_fold, criterion, device, optimizer=optimizer)
            train_loss_all.append(train_loss)
            train_acc_all.append(train_acc)

            # val
            val_loss, val_acc, _, _ = _run_epoch(model, val_loader_fold, criterion, device)
            val_loss_all.append(val_loss)
            val_acc_all.append(val_acc)

            print(f'Train Loss: {train_loss_all[-1]:.4f} | Train Acc: {train_acc_all[-1]:.4f}')
            print(f'Val Loss: {val_loss_all[-1]:.4f} | Val Acc: {val_acc_all[-1]:.4f}')

            if val_acc_all[-1] > best_acc:
                best_acc = val_acc_all[-1]
                best_model_wts = copy.deepcopy(model.state_dict())

        # testing: fold metrics on the held-out part with the last-epoch weights
        _, _, labels_for_test, preds_for_test = _run_epoch(
            model, val_loader_fold, criterion, device, collect=True)

        accuracy = accuracy_score(labels_for_test, preds_for_test)
        precision = precision_score(labels_for_test, preds_for_test, average='macro')
        recall = recall_score(labels_for_test, preds_for_test, average='macro')
        f1 = f1_score(labels_for_test, preds_for_test, average='macro')

        accuracy_scores.append(accuracy)
        precision_scores.append(precision)
        recall_scores.append(recall)
        f1_scores.append(f1)

        print(f"Fold {fold} metrics:")
        print(f"  Accuracy: {accuracy:.4f}")
        print(f"  Precision: {precision:.4f}")
        print(f"  Recall: {recall:.4f}")
        print(f"  F1-score: {f1:.4f}")

        time_use = time.time() - since
        print(f'Train and validation complete in {time_use // 60:.0f}m {time_use % 60:.0f}s\n')

    model.load_state_dict(best_model_wts)

    # 'epoch' counts trained epochs across all folds (num_folds * num_epochs rows).
    train_process = pd.DataFrame({
        'epoch': range(len(train_loss_all)),
        'train_loss_all': train_loss_all,
        'val_loss_all': val_loss_all,
        'train_acc_all': train_acc_all,
        'val_acc_all': val_acc_all
    })

    print("Final Results:")
    print(f"Accuracy: Mean = {np.mean(accuracy_scores):.4f}, Std = {np.std(accuracy_scores):.4f}")
    print(f"Precision: Mean = {np.mean(precision_scores):.4f}, Std = {np.std(precision_scores):.4f}")
    print(f"Recall: Mean = {np.mean(recall_scores):.4f}, Std = {np.std(recall_scores):.4f}")
    print(f"F1-score: Mean = {np.mean(f1_scores):.4f}, Std = {np.std(f1_scores):.4f}")

    return model, train_process

In [None]:
# Start training the model
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
criterion = nn.CrossEntropyLoss()

# train_rate is part of train_model's signature but is not read anywhere in its
# body; it is passed explicitly so the call binds every parameter.
model, train_process = train_model(
    model=model,
    train_loader=train_loader,
    train_rate=None,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=50,
    num_folds=5
)

In [None]:
# Visualization
import matplotlib.pyplot as plt
# Two panels side by side: loss on the left, accuracy on the right.
# The x axis is the running epoch index over all folds.
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))
epochs = train_process["epoch"]

loss_ax.plot(epochs, train_process["train_loss_all"], "ro-", label="Train Loss")
loss_ax.plot(epochs, train_process["val_loss_all"], "bs-", label="Val Loss")
loss_ax.legend()
loss_ax.set_xlabel('epoch')
loss_ax.set_ylabel('loss')

acc_ax.plot(epochs, train_process["train_acc_all"], "ro-", label="Train acc")
acc_ax.plot(epochs, train_process["val_acc_all"], "bs-", label="Val acc")
acc_ax.legend()
acc_ax.set_xlabel('epoch')
acc_ax.set_ylabel('acc')

plt.show()

In [None]:
# Persist only the weights (state dict), not the pickled module object.
torch.save(obj=model.state_dict(), f='WPD_PWCN.pth')