In [1]:
import os

from torch.utils.data import Dataset
import torchaudio
import torch


# This is a custom dataset class.
class CustomDataset(Dataset):
    def __init__(self, path_dataset, is_train=True,
                 transform=False, max_length=32000, sr=4000):
        self.samples = []
        self.labels_map = {}
        self.is_train = is_train
        self.transform = transform
        self.max_length = max_length
        self.sr = sr
        self.read(path_dataset)

    def read(self, path_dataset):
        for idx_class, class_name in enumerate(os.listdir(path_dataset)):
            path_class = os.path.join(path_dataset, class_name)
            self.labels_map[class_name] = idx_class
            if self.is_train:
                path_class = os.path.join(path_class, "train")
            else:
                path_class = os.path.join(path_class, "test")

            for idx, file_name in enumerate(os.listdir(path_class)):
                path_file = os.path.join(path_class, file_name)
                waveform, sr = torchaudio.load(path_file)
                metadata = torchaudio.info(path_file)                
                if self.transform:
                    transform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=self.sr)
                    waveform = transform(waveform)
                    waveform = self.padding(waveform, self.max_length)
                    print(waveform.shape)
                    sr = self.sr

                self.samples.append((waveform, idx_class, sr))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        return self.samples[index][0], self.samples[index][1], self.samples[index][2]

    def padding(self, waveform, max_len):
        # Pad the waveform
        length_waveform = waveform.shape[1]
        if length_waveform < max_len:
            waveform = torch.cat((waveform, torch.zeros((1, max_len - length_waveform))), dim=1)
        return waveform

    def getLabelsMap(self):
        return self.labels_map

    def getLabelCount(self):
        return len(self.labels_map)




In [2]:


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchaudio
import sys

import matplotlib.pyplot as plt
import IPython.display as ipd

from tqdm import tqdm

from torch.utils.data import Dataset, DataLoader


class M5(nn.Module):
    def __init__(self, n_input=1, n_output=35, stride=16, n_channel=32):
        super().__init__()
        self.conv1 = nn.Conv1d(n_input, n_channel, kernel_size=80, stride=stride)
        self.bn1 = nn.BatchNorm1d(n_channel)
        self.pool1 = nn.MaxPool1d(4)
        self.conv2 = nn.Conv1d(n_channel, n_channel, kernel_size=3)
        self.bn2 = nn.BatchNorm1d(n_channel)
        self.pool2 = nn.MaxPool1d(4)
        self.conv3 = nn.Conv1d(n_channel, 2 * n_channel, kernel_size=3)
        self.bn3 = nn.BatchNorm1d(2 * n_channel)
        self.pool3 = nn.MaxPool1d(4)
        self.conv4 = nn.Conv1d(2 * n_channel, 2 * n_channel, kernel_size=3)
        self.bn4 = nn.BatchNorm1d(2 * n_channel)
        self.pool4 = nn.MaxPool1d(4)
        self.fc1 = nn.Linear(2 * n_channel, n_output)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(self.bn1(x))
        x = self.pool1(x)
        x = self.conv2(x)
        x = F.relu(self.bn2(x))
        x = self.pool2(x)
        x = self.conv3(x)
        x = F.relu(self.bn3(x))
        x = self.pool3(x)
        x = self.conv4(x)
        x = F.relu(self.bn4(x))
        x = self.pool4(x)
        x = F.avg_pool1d(x, x.shape[-1])
        x = x.permute(0, 2, 1)
        x = self.fc1(x)
        return F.log_softmax(x, dim=2)


In [3]:
import custom_dataset
from model import M5
import torch
import torchaudio
import numpy as np
import torch
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import sounddevice as sd
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader


def plot_waveform(waveform, sample_rate):
    waveform = waveform.numpy()

    num_channels, num_frames = waveform.shape
    time_axis = torch.arange(0, num_frames) / sample_rate

    figure, axes = plt.subplots(num_channels, 1)
    if num_channels == 1:
        axes = [axes]
    for c in range(num_channels):
        axes[c].plot(time_axis, waveform[c], linewidth=1)
        axes[c].grid(True)
        if num_channels > 1:
            axes[c].set_ylabel(f"Channel {c + 1}")
    figure.suptitle("waveform")
    plt.show(block=True)


def pad_sequence(batch):
    # Make all tensor in a batch the same length by padding with zeros
    batch = [item.t() for item in batch]
    batch = torch.nn.utils.rnn.pad_sequence(batch, batch_first=True, padding_value=0.)
    return batch.permute(0, 2, 1)

def collate_fn(batch):
    tensors, targets = [], []
    for waveform, idx_class, sr in batch:
        tensors += [waveform]
        targets += [torch.tensor(idx_class)]

        # Group the list of tensors into a batched tensor
    tensors = pad_sequence(tensors)
    targets = torch.stack(targets)

    return tensors, targets


def number_of_correct(pred, target):
    return pred.squeeze().eq(target).sum().item()

def padding(waveform, max_len):
    # Pad the waveform
    length_waveform = waveform.shape[1]
    if length_waveform < max_len:
        waveform = torch.cat((waveform, torch.zeros((1, max_len - length_waveform))), dim=1)
    return waveform


def get_likely_index(tensor):
    return tensor.argmax(dim=-1)

In [6]:
path_dataset = "/home/goktug/projects/Ceres/dataset/"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device: ", device)

sr_target = 8000
max_length = 32000
batch_size = 2
transform = True

custom_dataset_train = custom_dataset.CustomDataset(path_dataset, True, transform,
                                                    max_length, sr_target)
train_loader = torch.utils.data.DataLoader(custom_dataset_train, batch_size=batch_size,
                                           shuffle=True, num_workers=1, collate_fn=collate_fn,
                                           pin_memory=True)

custom_dataset_test = custom_dataset.CustomDataset(path_dataset, False, transform, max_length, sr_target)
test_loader = torch.utils.data.DataLoader(custom_dataset_test, batch_size=batch_size,
                                          shuffle=True, num_workers=1, collate_fn=collate_fn,
                                          pin_memory=True)

device:  cuda


In [7]:
# Iterate over batches train:
sample_waveform = None
for idx, (waveforms, targets) in enumerate(train_loader):
    # print("Batch index: ", idx)
    # print("Waveforms shape: ", len(waveforms))
    # print("Targets: ", targets)
    sample_waveform = waveforms[0]
    break
    # plot_waveform(waveforms[0], sr_target)

print("Number of samples in train dataset: ", len(custom_dataset_train))
print("Number of samples in test dataset: ", len(custom_dataset_test))
print("Sample rate: ", sr_target)
print("Wavelength: ", len(sample_waveform[0]))
print(custom_dataset_train.labels_map)
print(custom_dataset_test.labels_map)

Number of samples in train dataset:  3066
Number of samples in test dataset:  771
Sample rate:  8000
Wavelength:  16000
{'kapa': 0, 'ac': 1}
{'kapa': 0, 'ac': 1}


In [6]:
model = M5(n_input=sample_waveform.shape[0],
           n_output=custom_dataset_train.getLabelCount())
model.to(device)

optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)

In [7]:

losses_train_list = []
for epoch in range(12):

    loss_epoch = 0

    # Train
    model.train()
    for batch_idx, (waveforms, targets) in enumerate(train_loader):
        sample_count = len(waveforms)
        waveforms = waveforms.to(device)
        targets = targets.to(device)
        output_ac = model(waveforms)
        loss = F.nll_loss(output_ac.squeeze(), targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        #print("Batch: ", batch_idx, " Sample count: ", sample_count, " Loss: ", loss.item())
        loss_epoch += loss.item()

    loss_epoch /= len(train_loader)
    losses_train_list.append(loss_epoch)
    print("\nEpoch: ", epoch, " Loss: ", loss_epoch)
    
    
    # Test
    model.eval()
    correct = 0
    with torch.no_grad():
        for batch_idx, (waveforms, targets) in enumerate(test_loader):
            waveforms = waveforms.to(device)
            targets = targets.to(device)
            output_ac = model(waveforms)
            pred = get_likely_index(output_ac)
            correct += number_of_correct(pred, targets)

    print("Test accuracy: ", correct / len(custom_dataset_test))
    print("-----------------------------------------------------")



Epoch:  0  Loss:  0.21246020459727974
Test accuracy:  0.9559014267185474
-----------------------------------------------------


KeyboardInterrupt: 

In [None]:
path_ac = "/home/goktug/Downloads/ac.wav"
path_kapa = "/home/goktug/Downloads/kapa.wav"
waveform_ac, sr_ac = torchaudio.load(path_ac)
waveform_ac = waveform_ac[0]
print("Loaded audio with sample rate: ", sr_ac)

waveform_ac = waveform_ac.reshape(1, -1)
transform = torchaudio.transforms.Resample(orig_freq=sr_ac, new_freq=sr_target)
waveform_ac = transform(waveform_ac)
waveform_ac = padding(waveform_ac, max_length)
waveform_ac.shape