In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import h5py
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
from torchinfo import summary


# Define hyper-parameters
lr = 0.5
max_epochs = 160
batch_size = 20
input_height = 32000
input_width = 1
nb_classes = 2

In [2]:
class VocalDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

In [3]:
class STFT(nn.Module):  
    def __init__(self):
        super(STFT, self).__init__()

    def forward(self, x):
        # 確保輸入尺寸為 [batch_size, 1, 32000]
        # if x.dim() != 3 or x.size(2) != 1:
        #     raise ValueError("Input tensor must be of shape [batch_size, 1, time]")

        x = x.squeeze(dim=2)  # 移除不必要的維度
        window = torch.hann_window(2048, device=x.device)
        
        # 執行 STFT 轉換
        stft = torch.stft(x, n_fft=2048, hop_length=512, win_length=2048, window=window, return_complex=True)
        
        # 計算頻譜圖
        spectrogram = torch.abs(stft)
        
        # 修剪頻率維度至 1024，並調整維度順序
        spectrogram = spectrogram[:, :1024, :].permute(0, 2, 1)
        
        # 確保輸出形狀為 (batch_size, 1, 64, 1024)
        spectrogram = spectrogram.unsqueeze(1)  # 增加 channel 維度
        # if spectrogram.size(2) != 63 or spectrogram.size(3) != 1024:
        #     raise ValueError("Output tensor must be of shape [batch_size, 1, 64, 1024]")
        
        return spectrogram

In [4]:
class USCLLayer(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, pooling=False, padding=0, activation='relu'):
        super(USCLLayer, self).__init__()
        
        self.pooling = pooling

        if self.pooling:
            self.pool = nn.MaxPool2d(kernel_size)

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu1 = nn.ReLU()

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.bn1(x)

        if self.pooling:
            x = self.pool(x)

        return x

In [5]:
class MultiheadSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(MultiheadSelfAttention, self).__init__()
        self.mhsa = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, add_bias_kv=True, batch_first=True)
        
    def forward(self, x):
        residual = x
        x , _= self.mhsa(x, x, x)
        x = x + residual
        return x    

In [8]:
class SCNN18_MHSA_1layers(nn.Module):
    def __init__(self, nb_classes):
        super(SCNN18_MHSA_1layers, self).__init__()

        self.STFT = STFT()
        self.uscl_conv1 = USCLLayer(in_channels=1, out_channels=64, kernel_size=(3, 2), stride=(3, 2), pooling=False, padding=0)
        self.uscl_conv2 = USCLLayer(in_channels=64, out_channels=64, kernel_size=(1, 2), stride=(1, 1), pooling=False, padding=0)
        # self.mhsa1 = MultiheadSelfAttention(embed_dim=2048, num_heads=256)
        self.uscl_conv3 = USCLLayer(in_channels=64, out_channels=64, kernel_size=(1, 2), stride=(1, 1), pooling=False, padding=0)
        self.uscl_conv4 = USCLLayer(in_channels=64, out_channels=64, kernel_size=(1, 2), stride=(1, 1), pooling=True, padding=0)
        # self.mhsa2 = MultiheadSelfAttention(embed_dim=1024, num_heads=128)
        self.uscl_conv5 = USCLLayer(in_channels=64, out_channels=64, kernel_size=(1, 2), stride=(1, 1), pooling=True, padding=0)
        self.uscl_conv6 = USCLLayer(in_channels=64, out_channels=64, kernel_size=(1, 2), stride=(1, 1), pooling=False, padding=0)
        self.uscl_conv7 = USCLLayer(in_channels=64, out_channels=64, kernel_size=(1, 2), stride=(1, 1), pooling=True, padding=0)
        # self.mhsa3 = MultiheadSelfAttention(embed_dim=256, num_heads=32)
        self.uscl_conv8 = USCLLayer(in_channels=64, out_channels=128, kernel_size=(1, 2), stride=(1, 1), pooling=False, padding=0)
        self.uscl_conv9 = USCLLayer(in_channels=128, out_channels=128, kernel_size=(1, 2), stride=(1, 1), pooling=False, padding=0)
        # self.mhsa4 = MultiheadSelfAttention(embed_dim=512, num_heads=64)
        self.uscl_conv10 = USCLLayer(in_channels=128, out_channels=128, kernel_size=(1, 2), stride=(1, 1), pooling=True, padding=0)
        self.uscl_conv11 = USCLLayer(in_channels=128, out_channels=128, kernel_size=(1, 2), stride=(1, 1), pooling=True, padding=0)
        # self.mhsa5 = MultiheadSelfAttention(embed_dim=128, num_heads=16)
        self.uscl_conv12 = USCLLayer(in_channels=128, out_channels=128, kernel_size=(1, 2), stride=(1, 1), pooling=False, padding=0)
        self.uscl_conv13 = USCLLayer(in_channels=128, out_channels=128, kernel_size=(1, 2), stride=(1, 1), pooling=True, padding=0)
        # self.mhsa6 = MultiheadSelfAttention(embed_dim=128, num_heads=16)
        self.uscl_conv14 = USCLLayer(in_channels=128, out_channels=256, kernel_size=(1, 2), stride=(1, 1), pooling=True, padding=0)
        self.uscl_conv15 = USCLLayer(in_channels=256, out_channels=256, kernel_size=(1, 2), stride=(1, 1), pooling=True, padding=0)
        self.uscl_conv16 = USCLLayer(in_channels=256, out_channels=256, kernel_size=(1, 2), stride=(1, 1), pooling=False, padding=0)
        # self.mhsa7 = MultiheadSelfAttention(embed_dim=256, num_heads=32)

        self.final_uscl = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(3, 2), stride=(1, 1), padding=0)
        self.final_uscl_relu = nn.ReLU()
        self.final_uscl_bn = nn.BatchNorm2d(256)
        self.final_pool = nn.MaxPool2d((3, 2))

        self.final_conv = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(1, 1), stride=(1, 1), padding=0)
        self.final_relu = nn.ReLU()
        self.final_bn = nn.BatchNorm2d(256)
        self.mhsa8 = MultiheadSelfAttention(embed_dim=256, num_heads=32)

        self.dropout = nn.Dropout(0.5)
        self.flatten = nn.Flatten()
        self.out_dense = nn.Linear(1792, nb_classes)

    def forward(self, x):
        x = self.STFT(x)
        x = self.uscl_conv1(x)
        x = F.pad(x, (0, 1, 0, 0))
        x = self.uscl_conv2(x)
        # x = torch.reshape(x, (x.shape[0], 336, 2048))
        # x = self.mhsa1(x)
        # x = torch.reshape(x, (x.shape[0], 64, 21, 512))
        x = F.pad(x, (0, 1, 0, 0))
        x = self.uscl_conv3(x)
        x = F.pad(x, (0, 1, 0, 0))
        x = self.uscl_conv4(x)
        # x = torch.reshape(x, (x.shape[0], 336, 1024))
        # x = self.mhsa2(x)
        # x = torch.reshape(x, (x.shape[0], 64, 21, 256))
        x = F.pad(x, (0, 1, 0, 0))
        x = self.uscl_conv5(x)
        x = F.pad(x, (0, 1, 0, 0))
        x = self.uscl_conv6(x)
        x = F.pad(x, (0, 1, 0, 0))
        x = self.uscl_conv7(x)
        # x = torch.reshape(x, (x.shape[0], 336, 256))
        # x = self.mhsa3(x)
        # x = torch.reshape(x, (x.shape[0], 64, 21, 64))
        x = F.pad(x, (0, 1, 0, 0))
        x = self.uscl_conv8(x)
        x = F.pad(x, (0, 1, 0, 0))
        x = self.uscl_conv9(x)
        # x = torch.reshape(x, (x.shape[0], 336, 512))
        # x = self.mhsa4(x)
        # x = torch.reshape(x, (x.shape[0], 128, 21, 64))
        x = F.pad(x, (0, 1, 0, 0))
        x = self.uscl_conv10(x)
        x = F.pad(x, (0, 1, 0, 0))
        x = self.uscl_conv11(x)
        # x = torch.reshape(x, (x.shape[0], 336, 128))
        # x = self.mhsa5(x)
        # x = torch.reshape(x, (x.shape[0], 128, 21, 16))
        x = F.pad(x, (0, 1, 0, 0))
        x = self.uscl_conv12(x)
        x = F.pad(x, (0, 1, 0, 0))
        x = self.uscl_conv13(x)
        # x = torch.reshape(x, (x.shape[0], 168, 128))
        # x = self.mhsa6(x)
        # x = torch.reshape(x, (x.shape[0], 128, 21, 8))
        x = F.pad(x, (0, 1, 0, 0))
        x = self.uscl_conv14(x)
        x = F.pad(x, (0, 1, 0, 0))
        x = self.uscl_conv15(x)
        x = F.pad(x, (0, 1, 0, 0))
        x = self.uscl_conv16(x)
        # x = torch.reshape(x, (x.shape[0], 42, 256))
        # x = self.mhsa7(x)
        # x = torch.reshape(x, (x.shape[0], 256, 21, 2))
        x = F.pad(x, (0, 1, 1, 1))
        x = self.final_uscl(x)

        x = self.final_uscl_relu(x)
        x = self.final_uscl_bn(x)
        x = self.final_pool(x)

        x = self.final_conv(x)
        x = self.final_relu(x)
        x = self.final_bn(x)
        x = torch.reshape(x, (x.shape[0], 7, 256))
        x = self.mhsa8(x)
        x = torch.reshape(x, (x.shape[0], 256, 7, 1))

        x = self.dropout(x)
        x = self.flatten(x)
        x = self.out_dense(x)
        return x

In [9]:
model = SCNN18_MHSA_1layers(nb_classes)
summary(model, input_size=(1, 1, 32000))

Layer (type:depth-idx)                   Output Shape              Param #
SCNN18_MHSA_1layers                      [1, 2]                    --
├─STFT: 1-1                              [1, 1, 63, 1024]          --
├─USCLLayer: 1-2                         [1, 64, 21, 512]          --
│    └─Conv2d: 2-1                       [1, 64, 21, 512]          448
│    └─ReLU: 2-2                         [1, 64, 21, 512]          --
│    └─BatchNorm2d: 2-3                  [1, 64, 21, 512]          128
├─USCLLayer: 1-3                         [1, 64, 21, 512]          --
│    └─Conv2d: 2-4                       [1, 64, 21, 512]          8,256
│    └─ReLU: 2-5                         [1, 64, 21, 512]          --
│    └─BatchNorm2d: 2-6                  [1, 64, 21, 512]          128
├─USCLLayer: 1-4                         [1, 64, 21, 512]          --
│    └─Conv2d: 2-7                       [1, 64, 21, 512]          8,256
│    └─ReLU: 2-8                         [1, 64, 21, 512]          --
│    └

In [7]:
def save_log(message, log_file):
    with open(log_file, 'a') as f:
        f.write(message + '\n')

In [8]:
if __name__ == '__main__':
    log_file = 'Jamendo_SCNN18_MHSA_training_log3.txt'
    # Load data from HDF5 files
    # with h5py.File('./FMA-C-1-fixed-SCNN-Train.h5', 'r') as train_file:
    training_data = 'FMA-C-1-fixed-SCNN-Train.h5'
    with h5py.File('./FMA-C-1-fixed-SCNN-Train.h5', 'r') as train_file:
        train_data = torch.tensor(train_file['X'][:])
        train_labels = torch.tensor(train_file['Y'][:])
        print("train_data",train_data.shape)

    with h5py.File('./FMA-C-1-fixed-SCNN-Test.h5', 'r') as val_file:
        val_data = torch.tensor(val_file['X'][:])
        val_labels = torch.tensor(val_file['Y'][:])

    # Create TensorDataset
    train_dataset = VocalDataset(train_data, train_labels)
    val_dataset = VocalDataset(val_data, val_labels)

    # DataLoader for training dataset
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=False)

    # DataLoader for validation dataset
    val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)

    # Printing dataset lengths
    save_log("Training dataset: " + training_data, log_file)
    save_log("Training dataset length: " + str(len(train_dataset)), log_file)
    save_log("Validation dataset length: " + str(len(val_dataset)), log_file)

    # Check device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    save_log("Using device: " + str(device), log_file)

    # Initialize model, loss function and optimizer
    model = SCNN18_MHSA_1layers(nb_classes).to(device)
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
        save_log("Let's use " + str(torch.cuda.device_count()) + " GPUs!", log_file)
    
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adadelta(model.parameters(), lr=lr)
    # optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    # Train the model
    for epoch in range(max_epochs):
        model.train()
        epoch_train_loss = 0.0
        correct_train = 0
        total_train = 0
        for i, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)
            labels = torch.argmax(labels, dim=1)
            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            epoch_train_loss += loss.item()

            _, predicted = torch.max(outputs, 1)
            total_train += labels.size(0)
            correct_train += (predicted == labels).sum().item()

            print(f"[{epoch + 1}, {i + 1}] loss: {epoch_train_loss}")


        train_accuracy = 100 * correct_train / total_train

        # Validate the model
        model.eval()
        val_loss = 0.0
        correct_val = 0
        total_val = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                labels = torch.argmax(labels, dim=1)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).sum().item()

        val_accuracy = 100 * correct_val / total_val

        epoch_log = (f"Epoch {epoch + 1}:\n"
                     f"Train loss: {epoch_train_loss}, Train accuracy: {train_accuracy:.2f}%\n"
                     f"Validation loss: {val_loss / len(val_loader):.8f}, Validation accuracy: {val_accuracy:.2f}%\n")
        save_log(epoch_log, log_file)

    print("Finished Training")
    save_log("./SCNN18_FMAC1_MHSA_1layers_weights_1.pth", log_file)
    torch.save(model.state_dict(), './SCNN18_FMAC1_MHSA_1layers_weights_1.pth')

train_data torch.Size([12254, 32000, 1])




[1, 1] loss: 0.8302696943283081
[1, 2] loss: 1.6352658867835999
[1, 3] loss: 2.335068166255951
[1, 4] loss: 3.3290732502937317
[1, 5] loss: 4.274637877941132
[1, 6] loss: 5.200613975524902
[1, 7] loss: 5.97530210018158
[1, 8] loss: 6.679151117801666
[1, 9] loss: 7.372725188732147
[1, 10] loss: 8.357544004917145
[1, 11] loss: 9.313083052635193
[1, 12] loss: 10.11479926109314
[1, 13] loss: 10.5976123213768
[1, 14] loss: 11.246461689472198
[1, 15] loss: 12.027625262737274
[1, 16] loss: 12.634738385677338
[1, 17] loss: 13.906274616718292
[1, 18] loss: 14.960134088993073
[1, 19] loss: 15.79520457983017
[1, 20] loss: 16.77294659614563
[1, 21] loss: 18.021053910255432
[1, 22] loss: 19.37508499622345
[1, 23] loss: 20.04096382856369
[1, 24] loss: 21.194097459316254
[1, 25] loss: 21.982347071170807
[1, 26] loss: 23.1859033703804
[1, 27] loss: 23.984370052814484
[1, 28] loss: 24.812599778175354
[1, 29] loss: 26.089016318321228
[1, 30] loss: 26.90662384033203
[1, 31] loss: 27.638180434703827
[1, 3