In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import h5py
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
from torchinfo import summary


# Define hyper-parameters
lr = 0.5
max_epochs = 160
batch_size = 150
input_height = 32000
input_width = 1
nb_classes = 2

In [2]:
class VocalDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

In [3]:
class STFT(nn.Module):  
    def __init__(self):
        super(STFT, self).__init__()

    def forward(self, x):
        # 確保輸入尺寸為 [batch_size, 1, 32000]
        # if x.dim() != 3 or x.size(2) != 1:
        #     raise ValueError("Input tensor must be of shape [batch_size, 1, time]")

        x = x.squeeze(dim=2)  # 移除不必要的維度
        window = torch.hann_window(2048, device=x.device)
        
        # 執行 STFT 轉換
        stft = torch.stft(x, n_fft=2048, hop_length=507, win_length=2048, window=window, return_complex=True)
        
        # 計算頻譜圖
        spectrogram = torch.abs(stft)
        
        # 修剪頻率維度至 1024，並調整維度順序
        spectrogram = spectrogram[:, :1024, :].permute(0, 2, 1)
        
        # 確保輸出形狀為 (batch_size, 1, 64, 1024)
        spectrogram = spectrogram.unsqueeze(3)  # 增加 channel 維度
        # if spectrogram.size(2) != 63 or spectrogram.size(3) != 1024:
        #     raise ValueError("Output tensor must be of shape [batch_size, 1, 64, 1024]")
        
        return spectrogram

In [4]:
class PatchReverse(nn.Module):
    def __init__(self, patch_size):
        super(PatchReverse, self).__init__()
        self.patch_size = patch_size

    def forward(self, patches):
        B, H, W, C = patches.shape
        # print("Input shape:", B, H, W, C)
        
        ph, pw = self.patch_size
        
        original_c = C // (ph * pw)
        # print("Patches shape:", B, H, W, ph, pw, original_c)

        patches = patches.view(B, H, W, ph, pw, original_c)
    
        patches = patches.permute(0, 1, 3, 2, 4, 5).contiguous()
        patches = patches.view(B, H * ph, W * pw, original_c)

        return patches

In [5]:
class PatchPartition(nn.Module):
    def __init__(self, patch_size):
        super(PatchPartition, self).__init__()
        self.patch_size = patch_size

    def forward(self, x):
        B, H, W, C = x.shape
        ph, pw = self.patch_size
        assert H % ph == 0 and W % pw == 0, 'Invalid patch size'
        x = x.contiguous().view(B, H // ph, ph, W // pw, pw, C)
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
        x = x.view(B, H // ph, W // pw, ph * pw * C)
        return x

In [6]:
class WindowMultiheadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, window_h, window_w):
        super(WindowMultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.window_h = window_h
        self.window_w = window_w
        self.multihead_attn = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True, add_bias_kv=True)

    def forward(self, x):
        B, H, W, C = x.shape
        
        # 計算窗口數量
        num_windows_h = H // self.window_h
        num_windows_w = W // self.window_w

        # 將圖像劃分為多個窗口
        x = x.view(B, num_windows_h, self.window_h, num_windows_w, self.window_w, C)
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
        x = x.view(B * num_windows_h * num_windows_w, self.window_h * self.window_w, C)
        
        # 對每個窗口應用多頭注意力
        attn_output, _ = self.multihead_attn(x, x, x)
        
        # 將窗口重構回原始尺寸
        attn_output = x.view(B, num_windows_h, num_windows_w, self.window_h, self.window_w, C)
        attn_output = attn_output.permute(0, 1, 3, 2, 4, 5).contiguous()
        attn_output = attn_output.view(B, H, W, C)
        
        return attn_output

In [7]:
class ShiftedWindowMultiheadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, height, width, shifted_window_h, shifted_window_w):
        super(ShiftedWindowMultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.shifted_window_h = shifted_window_h
        self.shifted_window_w = shifted_window_w
        self.multihead_attn = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True)
        self.height = height
        self.width = width

    def splitwindow(self, x):
        H, W, C = x.shape
        B = 1
        x = x.contiguous().view(B, -1, self.shifted_window_h, self.shifted_window_w, C)
        
        def concatenate_slices(image_list, dim=0):
            return torch.cat(image_list, dim=dim)

        output = []
        output.append(x[:, 0])  # 保留第1個切片
        output.append(concatenate_slices([x[:, 1], x[:, 2]], dim=2))  # 連接第2和第3個切片
        output.append(x[:, 3])  # 保留第4個切片
        output.append(concatenate_slices([x[:, 4], x[:, 8]], dim=1))  # 連接第5和第9個切片
        output.append(concatenate_slices([
            concatenate_slices([x[:, 5], x[:, 6]], dim=2),
            concatenate_slices([x[:, 9], x[:, 10]], dim=2)
        ], dim=1))  # 連接第6、7、10和11個切片
        output.append(concatenate_slices([x[:, 7], x[:, 11]], dim=1))  # 連接第8和12個切片
        output.append(x[:, 12])  # 保留第13個切片
        output.append(concatenate_slices([x[:, 13], x[:, 14]], dim=2))  # 連接第14和15個切片
        output.append(x[:, 15])  # 保留第16個切片
        processed_output = []
        for img in output:
            _, H, W, C = img.shape
            
            reshaped_img = img.view(1, H * W, C)
            attn_output, _ = self.multihead_attn(reshaped_img, reshaped_img, reshaped_img)
            processed_output.append(attn_output.view(1, H, W, C))

        final_output = []
        final_output.append(processed_output[0])
        split2_3 = torch.split(processed_output[1], self.shifted_window_w, dim=2)
        final_output.extend(split2_3)
        final_output.append(processed_output[2])
        split5_9 = torch.split(processed_output[3], self.shifted_window_h, dim=1)
        final_output.append(split5_9[0])
        split6_7_10_11 = torch.split(processed_output[4], self.shifted_window_h, dim=1)
        split6_7 = split6_7_10_11[0]
        split6_7 = torch.split(split6_7, self.shifted_window_w, dim=2)
        split_6 = split6_7[0]
        split_7 = split6_7[1]
        split10_11 = split6_7_10_11[1]
        split10_11 = torch.split(split10_11, self.shifted_window_w, dim=2)
        split_10 = split10_11[0]
        split_11 = split10_11[1]
        final_output.append(split_6)
        final_output.append(split_7)
        split8_12 = torch.split(processed_output[5], self.shifted_window_h, dim=1)
        final_output.append(split8_12[0])
        final_output.append(split5_9[1])
        final_output.append(split_10)
        final_output.append(split_11)
        final_output.append(split8_12[1])
        final_output.append(processed_output[6])
        split14_15 = torch.split(processed_output[7], self.shifted_window_w, dim=2)
        final_output.extend(split14_15)
        final_output.append(processed_output[8])

        list_tensor = torch.zeros((16, B, self.shifted_window_h * self.shifted_window_w, C), device=x.device)
        for i, list_item in enumerate(final_output):
            reshaped_item = list_item.contiguous().view(B, self.shifted_window_h * self.shifted_window_w, C)
            list_tensor[i] = reshaped_item
        
        list_tensor = list_tensor.view(B, H * 4, W * 4, C)
        return list_tensor
     
    def forward(self, x):
        tensor_list = []
        for shift_item in x:
            shift_item = self.splitwindow(shift_item)
            tensor_list.append(shift_item)
            
        stacked_tensor = torch.stack(tensor_list).squeeze(1)
        
        return stacked_tensor

In [8]:
class PatchMerging(nn.Module):
    def __init__(self, in_channels, out_channels, patch_size):
        super(PatchMerging, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        return self.conv(x)

In [9]:
class WidthMerging(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(WidthMerging, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 2), stride=(1, 2))

    def forward(self, x):
        x = x.float()
        return self.conv(x)


In [10]:
class Window_MSA(nn.Module):
    def __init__(self, embed_dim):
        super(Window_MSA, self).__init__()
        self.stft = STFT()
        self.patch_partition1 = PatchPartition((4, 32))
        self.window_multihead_attn1 = WindowMultiheadAttention(embed_dim=128, num_heads=8, window_h=8, window_w=16)
        self.shifted_window_multihead_attn1 = ShiftedWindowMultiheadAttention(embed_dim=128, num_heads=8, height=16, width=32, shifted_window_h=4, shifted_window_w=8)
        self.patch_reverse1 = PatchReverse((4, 32))
        self.patch_merging_half1 = WidthMerging(in_channels=1, out_channels=2)
        self.patch_partition2 = PatchPartition((4, 32))
        self.window_multihead_attn2 = WindowMultiheadAttention(embed_dim=256, num_heads=8, window_h=8, window_w=8)
        self.shifted_window_multihead_attn2 = ShiftedWindowMultiheadAttention(embed_dim=256, num_heads=8, height=16, width=16, shifted_window_h=4, shifted_window_w=4)
        self.patch_reverse2 = PatchReverse((4, 32))
        self.patch_merging_half2 = WidthMerging(in_channels=2, out_channels=4)
        self.patch_partition3 = PatchPartition((4, 32))
        self.window_multihead_attn3 = WindowMultiheadAttention(embed_dim=512, num_heads=8, window_h=8, window_w=4)
        self.shifted_window_multihead_attn3 = ShiftedWindowMultiheadAttention(embed_dim=512, num_heads=8, height=16, width=8, shifted_window_h=4, shifted_window_w=2)
        self.patch_reverse3 = PatchReverse((4, 32))
        self.patch_merging_half3 = WidthMerging(in_channels=4, out_channels=8)
        self.patch_partition4 = PatchPartition((4, 32))
        self.window_multihead_attn4 = WindowMultiheadAttention(embed_dim=1024, num_heads=8, window_h=8, window_w=2)
        self.shifted_window_multihead_attn4 = ShiftedWindowMultiheadAttention(embed_dim=1024, num_heads=8, height=16, width=4, shifted_window_h=4, shifted_window_w=1)
        self.patch_reverse4 = PatchReverse((4, 32))
        self.patch_merging_half4 = WidthMerging(in_channels=8,out_channels=16)
        self.patch_partition5 = PatchPartition((4, 4))
        self.window_multihead_attn5 = WindowMultiheadAttention(embed_dim=256, num_heads=8, window_h=8, window_w=8)
        self.shifted_window_multihead_attn5 = ShiftedWindowMultiheadAttention(embed_dim=256, num_heads=8, height=16, width=16, shifted_window_h=4, shifted_window_w=4)
        self.patch_merging5 = PatchMerging(in_channels=256,out_channels=512,patch_size=(4,4))
        self.patch_partition6 = PatchPartition((4, 4))
        self.window_multihead_attn6 = WindowMultiheadAttention(embed_dim=512, num_heads=8, window_h=2, window_w=2)
        self.shifted_window_multihead_attn6 = ShiftedWindowMultiheadAttention(embed_dim=512, num_heads=8, height=8, width=8, shifted_window_h=1, shifted_window_w=1)
        self.flatten = nn.Flatten()
        self.out_dense = nn.Linear(8192, nb_classes)
        self.softmax = nn.Softmax(dim=1)
        

    def forward(self, x):
        x = self.stft(x)
        x = self.patch_partition1(x)
        x = self.window_multihead_attn1(x)
        x = self.shifted_window_multihead_attn1(x)
        x = self.patch_reverse1(x)
        x = x.permute(0, 3, 1, 2)
        x = self.patch_merging_half1(x)
        x = x.permute(0, 2, 3, 1)
        x = self.patch_partition2(x)
        x = self.window_multihead_attn2(x)
        x = self.shifted_window_multihead_attn2(x)
        x = self.patch_reverse2(x)
        x = x.permute(0, 3, 1, 2)
        x = self.patch_merging_half2(x)
        x = x.permute(0, 2, 3, 1)
        x = self.patch_partition3(x)
        x = self.window_multihead_attn3(x)
        x = self.shifted_window_multihead_attn3(x)
        x = self.patch_reverse3(x)
        x = x.permute(0, 3, 1, 2)
        x = self.patch_merging_half3(x)
        x = x.permute(0, 2, 3, 1)
        x = self.patch_partition4(x)
        x = self.window_multihead_attn4(x)
        x = self.shifted_window_multihead_attn4(x)
        x = self.patch_reverse4(x)
        x = x.permute(0, 3, 1, 2)
        x = self.patch_merging_half4(x)
        x = x.permute(0, 2, 3, 1)
        x = self.patch_partition5(x)
        x = self.window_multihead_attn5(x)
        x = self.shifted_window_multihead_attn5(x)
        x = x.permute(0, 3, 1, 2)
        x = self.patch_merging5(x)
        x = x.permute(0, 2, 3, 1)
        x = self.window_multihead_attn6(x)
        x = self.shifted_window_multihead_attn6(x)
        x = self.flatten(x)
        x = self.out_dense(x)

        return x


In [11]:
def save_log(message, log_file):
    with open(log_file, 'a') as f:
        f.write(message + '\n')

In [12]:
if __name__ == '__main__':
    log_file = 'W-MHSA_training_log_0726.txt'
    # Load data from HDF5 files
    # with h5py.File('./FMA-C-1-fixed-SCNN-Train.h5', 'r') as train_file:
    training_data = 'FMA-C-1-fixed-SCNN-Train.h5 + SCNN-Jamendo-train.h5'
    # Load data from FMA-C-1-fixed-SCNN-Train.h5
    with h5py.File('./FMA-C-1-fixed-SCNN-Train.h5', 'r') as fma_file:
        fma_train_data = torch.tensor(fma_file['X'][:])
        fma_train_labels = torch.tensor(fma_file['Y'][:])

    # Load data from SCNN-Jamendo-train.h5
    with h5py.File('./SCNN-Jamendo-train.h5', 'r') as jamendo_file:
        jamendo_train_data = torch.tensor(jamendo_file['X'][:])
        jamendo_train_labels = torch.tensor(jamendo_file['Y'][:])

    # Combine the training data and labels from both datasets
    train_data = torch.cat((fma_train_data, jamendo_train_data), dim=0)
    train_labels = torch.cat((fma_train_labels, jamendo_train_labels), dim=0)
    print("Combined train_data shape:", train_data.shape)

    with h5py.File('./FMA-C-1-fixed-SCNN-Test.h5', 'r') as val_file:
        val_data = torch.tensor(val_file['X'][:])
        val_labels = torch.tensor(val_file['Y'][:])

    # Create TensorDataset
    train_dataset = VocalDataset(train_data, train_labels)
    val_dataset = VocalDataset(val_data, val_labels)

    # DataLoader for training dataset
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=False)

    # DataLoader for validation dataset
    val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)

    # Printing dataset lengths
    save_log("Training dataset: " + training_data, log_file)
    save_log("Training dataset length: " + str(len(train_dataset)), log_file)
    save_log("Validation dataset length: " + str(len(val_dataset)), log_file)

    # Check device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    save_log("Using device: " + str(device), log_file)

    # Initialize model, loss function and optimizer
    model = Window_MSA(nb_classes).to(device)
    # if torch.cuda.device_count() > 1:
    #     model = nn.DataParallel(model)
    #     save_log("Let's use " + str(torch.cuda.device_count()) + " GPUs!", log_file)
    
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adadelta(model.parameters(), lr=lr)
    # optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    # Train the model
    for epoch in range(max_epochs):
        model.train()
        epoch_train_loss = 0.0
        correct_train = 0
        total_train = 0
        for i, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)
            labels = torch.argmax(labels, dim=1)
            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            epoch_train_loss += loss.item()

            _, predicted = torch.max(outputs, 1)
            total_train += labels.size(0)
            correct_train += (predicted == labels).sum().item()

            print(f"[{epoch + 1}, {i + 1}] loss: {epoch_train_loss}")


        epoch_train_loss /= len(train_loader)
        train_accuracy = 100 * correct_train / total_train


        # Validate the model
        model.eval()
        val_loss = 0.0
        correct_val = 0
        total_val = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                labels = torch.argmax(labels, dim=1)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).sum().item()

        val_loss /= len(val_loader)
        val_accuracy = 100 * correct_val / total_val

        epoch_log = (f"Epoch {epoch + 1}:\n"
                     f"Train loss: {epoch_train_loss}, Train accuracy: {train_accuracy:.2f}%\n"
                     f"Validation loss: {val_loss:.8f}, Validation accuracy: {val_accuracy:.2f}%\n")
        save_log(epoch_log, log_file)
        if (epoch + 1) % 5 == 0:
            save_path = f'./FMAC1_Jamendo_SWMHSA_weights_epoch_{epoch + 1}.pth'
            save_log(f"Saving model at epoch {epoch + 1}: {save_path}", log_file)
            torch.save(model.state_dict(), save_path)

    print("Finished Training")
    save_log("./FMAC1_Jamendo_SWMHSA_weights_1.pth", log_file)
    torch.save(model.state_dict(), './FMAC1_Jamendo_SWMHSA_weights_1.pth')

Combined train_data shape: torch.Size([25611, 32000, 1])
[1, 1] loss: 0.6933262944221497
[1, 2] loss: 1.3643097877502441
[1, 3] loss: 2.0380415320396423
[1, 4] loss: 2.7198262214660645
[1, 5] loss: 3.390055239200592
[1, 6] loss: 4.054217457771301
[1, 7] loss: 4.739806830883026
[1, 8] loss: 5.414485812187195
[1, 9] loss: 6.1338934898376465
[1, 10] loss: 6.82484757900238
[1, 11] loss: 7.510760068893433
[1, 12] loss: 8.204543054103851
[1, 13] loss: 8.878549933433533
[1, 14] loss: 9.554711699485779
[1, 15] loss: 10.242636322975159
[1, 16] loss: 10.918583452701569
[1, 17] loss: 11.608520805835724
[1, 18] loss: 12.28189867734909
[1, 19] loss: 12.93231201171875
[1, 20] loss: 13.632288992404938
[1, 21] loss: 14.31502115726471
[1, 22] loss: 14.997853517532349
[1, 23] loss: 15.682101488113403
[1, 24] loss: 16.34697663784027
[1, 25] loss: 17.043636739253998
[1, 26] loss: 17.71511596441269
[1, 27] loss: 18.400352239608765
[1, 28] loss: 19.084711253643036
[1, 29] loss: 19.782024562358856
[1, 30] lo