In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import h5py
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
from torchinfo import summary

# Hyper-parameters shared by the cells below.

# -- optimisation ---------------------------------------------------------
lr = 1.0          # learning rate handed to the optimizer
max_epochs = 160  # number of passes over the training set
batch_size = 150  # samples per mini-batch for both data loaders

# -- input / output geometry ----------------------------------------------
# NOTE(review): input_height / input_width are not referenced in the visible
# cells; presumably samples per clip and channel count - confirm before use.
input_height, input_width = 32000, 1
nb_classes = 2    # size of the classifier output

In [2]:
class VocalDataset(Dataset):
    """Map-style dataset that pairs each sample with the label at the same index.

    Args:
        data: Indexable container of samples; its length defines the dataset size.
        labels: Indexable container of labels, indexed in parallel with ``data``.
    """

    def __init__(self, data, labels):
        # Both containers are stored as given; nothing is copied or validated.
        self.data, self.labels = data, labels

    def __len__(self):
        """Return the number of samples, taken from ``data``."""
        n_items = len(self.data)
        return n_items

    def __getitem__(self, idx):
        """Return the ``(sample, label)`` pair stored at position ``idx``."""
        sample = self.data[idx]
        target = self.labels[idx]
        return sample, target

In [3]:
class STFT(nn.Module):
    """Convert a batch of raw waveforms into magnitude spectrograms.

    Args:
        n_fft: FFT size; also used as the Hann window length.
        hop_length: Number of samples between successive frames.
        n_bins: How many of the lowest frequency bins to keep (the STFT
            produces ``n_fft // 2 + 1`` bins; the rest are dropped).

    The defaults (2048 / 512 / 1024) reproduce the previously hard-coded
    front-end, so ``STFT()`` behaves exactly as before.
    """

    def __init__(self, n_fft=2048, hop_length=512, n_bins=1024):
        super(STFT, self).__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.n_bins = n_bins

    def forward(self, x):
        """Compute the magnitude spectrogram of ``x``.

        Args:
            x: Tensor of shape ``(batch, samples, 1)``.

        Returns:
            Tensor of shape ``(batch, 1, frames, n_bins)``. With
            ``torch.stft``'s default centring, ``frames`` is
            ``1 + samples // hop_length`` (63 for 32000 samples at the
            default hop of 512).
        """
        x = x.squeeze(dim=2)  # drop the trailing channel axis -> (batch, samples)

        # Build the window on the input's device so the module works wherever x lives.
        window = torch.hann_window(self.n_fft, device=x.device)

        # Complex STFT of shape (batch, n_fft // 2 + 1, frames)
        stft = torch.stft(
            x,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.n_fft,
            window=window,
            return_complex=True,
        )

        # Magnitude only; the phase is discarded.
        spectrogram = torch.abs(stft)

        # Keep the lowest n_bins frequency bins and put time before frequency:
        # (batch, freq, frames) -> (batch, frames, n_bins)
        spectrogram = spectrogram[:, :self.n_bins, :].permute(0, 2, 1)

        # Add a channel axis for the 2-D convolutions -> (batch, 1, frames, n_bins)
        spectrogram = spectrogram.unsqueeze(1)

        return spectrogram

In [4]:
class USCLLayer(nn.Module):
    """Conv2d -> ReLU -> BatchNorm2d block with an optional trailing max-pool.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the convolution.
        kernel_size: Convolution kernel size; also used as the pooling window
            when ``pooling`` is true.
        stride: Convolution stride.
        pooling: When true, a ``MaxPool2d(kernel_size)`` follows the batch norm.
        padding: Convolution padding.
        activation: Accepted for interface compatibility but not used; the
            block always applies ReLU.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, pooling=False, padding=0, activation='relu'):
        super().__init__()

        self.pooling = pooling
        if pooling:
            # The pooling window reuses the convolution's kernel size.
            self.pool = nn.MaxPool2d(kernel_size)

        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu1 = nn.ReLU()

    def forward(self, x):
        """Apply convolution, ReLU, batch norm and (optionally) max-pooling."""
        # The activation runs before the normalisation in this block.
        activated = self.relu1(self.conv1(x))
        normalized = self.bn1(activated)
        return self.pool(normalized) if self.pooling else normalized

In [5]:
class SCNN18(nn.Module):
    """STFT front-end, 16 USCL convolution blocks, two closing conv stages and a linear classifier.

    ``forward`` returns raw class scores: the softmax module is created but
    deliberately not applied.

    Args:
        nb_classes: Number of output classes of the final linear layer.
    """

    # (in_channels, out_channels, kernel_size, stride, pooling) for
    # uscl_conv1 .. uscl_conv16, in order.
    _USCL_SPECS = (
        (1, 64, (3, 2), (3, 2), False),
        (64, 64, (1, 2), (1, 1), False),
        (64, 64, (1, 2), (1, 1), False),
        (64, 64, (1, 2), (1, 1), True),
        (64, 64, (1, 2), (1, 1), True),
        (64, 64, (1, 2), (1, 1), False),
        (64, 64, (1, 2), (1, 1), True),
        (64, 128, (1, 2), (1, 1), False),
        (128, 128, (1, 2), (1, 1), False),
        (128, 128, (1, 2), (1, 1), True),
        (128, 128, (1, 2), (1, 1), True),
        (128, 128, (1, 2), (1, 1), False),
        (128, 128, (1, 2), (1, 1), True),
        (128, 256, (1, 2), (1, 1), True),
        (256, 256, (1, 2), (1, 1), True),
        (256, 256, (1, 2), (1, 1), False),
    )

    def __init__(self, nb_classes):
        super().__init__()

        self.STFT = STFT()

        # Register the blocks under their historical attribute names
        # (uscl_conv1 .. uscl_conv16) so state_dict keys stay the same.
        for idx, (c_in, c_out, kernel, stride, pooled) in enumerate(self._USCL_SPECS, start=1):
            block = USCLLayer(
                in_channels=c_in,
                out_channels=c_out,
                kernel_size=kernel,
                stride=stride,
                pooling=pooled,
                padding=0,
            )
            setattr(self, f"uscl_conv{idx}", block)

        # Closing (3, 2) convolution stage: conv -> ReLU -> BN -> max-pool.
        self.final_uscl = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(3, 2), stride=(1, 1), padding=0)
        self.final_uscl_relu = nn.ReLU()
        self.final_uscl_bn = nn.BatchNorm2d(256)
        self.final_pool = nn.MaxPool2d((3, 2))

        # 1x1 convolution stage: conv -> ReLU -> BN.
        self.final_conv = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(1, 1), stride=(1, 1), padding=0)
        self.final_relu = nn.ReLU()
        self.final_bn = nn.BatchNorm2d(256)

        # Classifier head. 1792 is the hard-coded flattened feature count the
        # linear layer expects.
        self.dropout = nn.Dropout(0.5)
        self.flatten = nn.Flatten()
        self.out_dense = nn.Linear(1792, nb_classes)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Map a batch of waveforms to unnormalised class scores."""
        x = self.STFT(x)

        # The first block consumes the spectrogram directly; every later block
        # is preceded by one extra column on the right of the last axis.
        x = self.uscl_conv1(x)
        for idx in range(2, len(self._USCL_SPECS) + 1):
            x = F.pad(x, (0, 1, 0, 0))
            x = getattr(self, f"uscl_conv{idx}")(x)

        # Pad the last axis on the right and the second-to-last axis on both
        # sides before the (3, 2) convolution.
        x = F.pad(x, (0, 1, 1, 1))
        x = self.final_uscl_bn(self.final_uscl_relu(self.final_uscl(x)))
        x = self.final_pool(x)

        x = self.final_bn(self.final_relu(self.final_conv(x)))

        # Softmax is intentionally not applied here; raw scores are returned.
        scores = self.out_dense(self.flatten(self.dropout(x)))
        return scores

In [6]:
def save_log(message, log_file):
    """Append ``message`` followed by a newline to the text file ``log_file``.

    The file is opened in append mode on every call, so it is created on
    first use and earlier content is preserved.
    """
    with open(log_file, mode='a') as log_handle:
        line = message + '\n'
        log_handle.write(line)

In [7]:
# Training entry point: load the train/test HDF5 files, train SCNN18 with
# Adadelta and cross-entropy, log per-epoch metrics and save the final weights.
if __name__ == '__main__':
    log_file = 'FMA-C-1_SCNN18_training_log.txt'
    training_data = 'FMA-C-1-fixed-SCNN-Train.h5'
    validation_data = 'FMA-C-1-fixed-SCNN-Test.h5'
    checkpoint_path = './SCNN18_FMAC-1_train_1.pth'

    # Load both splits fully into memory from the HDF5 files
    with h5py.File('./' + training_data, 'r') as train_file:
        train_data = torch.tensor(train_file['X'][:])
        train_labels = torch.tensor(train_file['Y'][:])
        print("train_data", train_data.shape)

    with h5py.File('./' + validation_data, 'r') as val_file:
        val_data = torch.tensor(val_file['X'][:])
        val_labels = torch.tensor(val_file['Y'][:])

    # Wrap the tensors in VocalDataset
    train_dataset = VocalDataset(train_data, train_labels)
    val_dataset = VocalDataset(val_data, val_labels)

    # Training batches are reshuffled every epoch so that batch composition
    # (and the batch-norm statistics) do not depend on the storage order of
    # the HDF5 file.
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

    # Validation order does not affect the metrics, so it stays deterministic
    val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)

    # Record dataset information
    save_log("Training dataset: " + training_data, log_file)
    save_log("Training dataset length: " + str(len(train_dataset)), log_file)
    save_log("Validation dataset length: " + str(len(val_dataset)), log_file)

    # Check device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    save_log("Using device: " + str(device), log_file)

    # Initialize model, loss function and optimizer
    model = SCNN18(nb_classes).to(device)
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
        save_log("Let's use " + str(torch.cuda.device_count()) + " GPUs!", log_file)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adadelta(model.parameters(), lr=lr)

    # Train the model
    for epoch in range(max_epochs):
        model.train()
        epoch_train_loss = 0.0
        correct_train = 0
        total_train = 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            # Reduce the label rows to class indices for CrossEntropyLoss
            # (presumably one-hot encoded - confirm against the HDF5 schema).
            labels = torch.argmax(labels, dim=1)

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            # criterion returns the batch mean; weight it by the batch size so
            # a smaller final batch does not skew the epoch average.
            batch_len = labels.size(0)
            epoch_train_loss += loss.item() * batch_len

            _, predicted = torch.max(outputs, 1)
            total_train += batch_len
            correct_train += (predicted == labels).sum().item()

        # Per-sample average train loss and accuracy
        epoch_train_loss /= total_train
        train_accuracy = 100 * correct_train / total_train

        # Validate the model
        model.eval()
        val_loss = 0.0
        correct_val = 0
        total_val = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                labels = torch.argmax(labels, dim=1)
                outputs = model(inputs)
                loss = criterion(outputs, labels)

                batch_len = labels.size(0)
                val_loss += loss.item() * batch_len
                _, predicted = torch.max(outputs, 1)
                total_val += batch_len
                correct_val += (predicted == labels).sum().item()

        # Per-sample average validation loss and accuracy
        val_loss /= total_val
        val_accuracy = 100 * correct_val / total_val

        # Log epoch results
        epoch_log = (f"Epoch {epoch + 1}:\n"
                     f"Train loss: {epoch_train_loss:.8f}, Train accuracy: {train_accuracy:.2f}%\n"
                     f"Validation loss: {val_loss:.8f}, Validation accuracy: {val_accuracy:.2f}%\n")
        save_log(epoch_log, log_file)

    print("Finished Training")
    save_log(checkpoint_path, log_file)

    # nn.DataParallel prefixes every state_dict key with "module."; save the
    # wrapped network's weights so the checkpoint loads into a plain SCNN18
    # regardless of how many GPUs were used for training.
    model_to_save = model.module if isinstance(model, nn.DataParallel) else model
    torch.save(model_to_save.state_dict(), checkpoint_path)

train_data torch.Size([12254, 32000, 1])




Finished Training
