# 1. Initialization

## Imports

In [None]:
import math, random

import torch
import torch.nn as nn
from torch.nn import init, functional
from torch.utils.data import DataLoader, Dataset, random_split
import torchaudio
from torchaudio import transforms

import pandas as pd
import numpy as np
from IPython.display import Audio
from matplotlib import pyplot as plt
from os import walk

In [None]:
# Define `ndim` on torch.Tensor as the number of entries in `shape`, so that
# plotting code can be handed tensors directly.
torch.Tensor.ndim = property(lambda self: len(self.shape))
# Compute device used throughout the notebook: the first CUDA GPU when one is
# available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

## Constants

In [None]:
# Clip length every recording is padded/truncated to, in milliseconds
# (passed as `max_ms` to AudioUtil.pad_trunc).
MAX_AUDIO_LENGTH = 5000
# Sample rate every recording is resampled to (passed as `newsr` to AudioUtil.resample).
SAMPLING_RATE = 2000
# Number of audio channels every recording is converted to (AudioUtil.rechannel).
N_CHANNELS = 1
# Upper bound handed to both AudioUtil.time_shift and AudioUtil.pitch_shift.
SHIFT_PCT = 0.3
# Base name of the checkpoint files written under the model save directory.
MODEL_NAME = 'model_1'

In [None]:
# Number of passes over the training set; passed to training() below.
num_epochs = 100

# 2. Audio Preprocessing Utilities

## AudioUtil Class
A class that includes all the methods required to load and preprocess the audio files of the dataset

In [None]:
class AudioUtil():
    """Static helpers to load, augment and featurize audio clips.

    A clip is passed around as an ``(sig, sr)`` tuple: ``sig`` is the waveform
    tensor returned by ``torchaudio.load`` (indexed as ``[channel, sample]``)
    and ``sr`` is its sample rate.
    """

    @staticmethod
    def open(audio_file):
        """Load ``audio_file`` and return it as an ``(sig, sr)`` tuple."""
        # The waveform stays where torchaudio puts it; the feature extractors
        # (get_mel_spectrogram / get_MFCC) move it to `device` themselves.
        sig, sr = torchaudio.load(audio_file)
        return (sig, sr)

    @staticmethod
    def print(aud, channel):
        """Plot the waveform of one channel of ``aud``.

        ``channel`` is 1-based: ``channel=1`` plots ``sig[0]``.
        """
        sig, sr = aud
        duration = sig.shape[1]
        # One time stamp (in seconds) per sample.
        time = torch.linspace(0, duration/sr, duration)

        print(sig.shape)
        print('Plotting...')

        plt.figure(figsize=(15, 5))
        plt.plot(time, sig[channel - 1])
        plt.title('Audio Plot')
        plt.ylabel(' signal wave')
        plt.xlabel('time (s)')
        plt.show()

    @staticmethod
    def display_spectrogram(spec, label='Audio mel spectrogram'):
        """Show the first channel of ``spec`` as an image titled ``label``."""
        print(spec.shape)
        print('Plotting...')

        plt.imshow(spec[0])
        plt.title(label)
        plt.ylabel('Frequency (mels)')
        plt.xlabel('Time (ms)')
        plt.colorbar(format='%+2.0f dB')

        plt.show()

    @staticmethod
    def rechannel(aud, new_channel):
        """Convert ``aud`` to ``new_channel`` channels.

        The clip is returned unchanged when it already has that many channels.
        ``new_channel == 1`` keeps only the first channel; any other value
        concatenates the signal with itself (mono -> stereo).
        """
        sig, sr = aud

        if sig.shape[0] == new_channel:
            return aud

        if new_channel == 1:
            resig = sig[:1, :]
        else:
            resig = torch.cat([sig, sig])

        return (resig, sr)

    @staticmethod
    def resample(aud, newsr):
        """Resample ``aud`` to ``newsr`` and return the new ``(sig, sr)`` tuple."""
        sig, sr = aud

        if sr == newsr:
            # Already at the target rate: return the clip itself. Returning
            # None here would make the next pipeline step fail on unpacking.
            return aud

        # Resample operates on the last (time) dimension, so all channels can
        # be converted in a single call.
        resig = torchaudio.transforms.Resample(sr, newsr)(sig)

        return (resig, newsr)

    @staticmethod
    def pad_trunc(aud, max_ms):
        """Force ``aud`` to last exactly ``max_ms`` milliseconds.

        Longer signals are cut at the end; shorter ones are zero-padded with
        the padding split randomly between the beginning and the end.
        """
        sig, sr = aud
        num_rows, sig_len = sig.shape
        # Multiply before dividing so sample rates that are not a multiple of
        # 1000 Hz do not lose their fractional samples-per-millisecond.
        max_len = sr * max_ms // 1000

        if sig_len > max_len:
            sig = sig[:, :max_len]
        elif sig_len < max_len:
            pad_begin_len = random.randint(0, max_len - sig_len)
            pad_end_len = max_len - sig_len - pad_begin_len

            # Match the signal's dtype/device so the concatenation is valid
            # wherever the signal lives.
            pad_begin = torch.zeros((num_rows, pad_begin_len), dtype=sig.dtype, device=sig.device)
            pad_end = torch.zeros((num_rows, pad_end_len), dtype=sig.dtype, device=sig.device)

            sig = torch.cat((pad_begin, sig, pad_end), 1)

        return (sig, sr)

    @staticmethod
    def time_shift(aud, shift_limit):
        """Circularly shift ``aud`` in time by a random fraction of its length.

        The shift is drawn uniformly from ``[0, shift_limit * sig_len)`` samples.
        """
        sig, sr = aud
        _, sig_len = sig.shape
        shift_amt = int(random.random() * shift_limit * sig_len)
        # Roll along the time axis only; rolling the flattened tensor would
        # push samples from one channel into the next.
        return (sig.roll(shift_amt, dims=-1), sr)

    @staticmethod
    def pitch_shift(aud, shift_limit):
        """Scale ``aud`` by a random factor drawn from ``[0, shift_limit)``.

        NOTE(review): despite the name, this multiplies the amplitude and does
        not alter the pitch; a factor of exactly 0 silences the clip. Behaviour
        is kept as-is -- confirm whether a real pitch shift was intended.
        """
        sig, sr = aud
        shift_amt = random.random() * shift_limit
        return (sig * shift_amt, sr)

    @staticmethod
    def get_mel_spectrogram(aud, hop_length):
        """Return the mel spectrogram of ``aud`` in decibels, computed on ``device``.

        Uses ``n_fft=1024`` and 64 mel bands; the dB conversion uses ``top_db=80``.
        """
        sig, sr = aud
        top_db = 80

        mel_transformation = torchaudio.transforms.MelSpectrogram(sample_rate=sr, n_fft=1024, hop_length=hop_length, n_mels=64)
        db_transformation = torchaudio.transforms.AmplitudeToDB(top_db=top_db)
        mel_transformation.to(device)
        db_transformation.to(device)
        spec = mel_transformation(sig.to(device))
        spec = db_transformation(spec)
        return spec

    @staticmethod
    def spectro_augment(spec, max_mask_pct = 0.1, n_freq_masks = 1, n_time_masks=1):
        """Apply random frequency and time masks to ``spec``.

        Each mask covers at most ``max_mask_pct`` of the corresponding axis and
        is filled with the mean of the input spectrogram.
        """
        _, n_mels, n_steps = spec.shape
        mask_value = spec.mean()
        aug_spec = spec

        freq_max_param = max_mask_pct * n_mels
        for _ in range(n_freq_masks):
            aug_spec = transforms.FrequencyMasking(freq_max_param).to(device)(aug_spec, mask_value)
        time_mask_params = max_mask_pct * n_steps
        for _ in range(n_time_masks):
            aug_spec = transforms.TimeMasking(time_mask_params).to(device)(aug_spec, mask_value)

        return aug_spec

    @staticmethod
    def get_MFCC(aud, hop_length=512):
        """Return 64 MFCCs of ``aud`` computed on ``device``.

        The underlying mel spectrogram uses ``n_fft=1024`` and 64 mel bands.
        """
        sig, sr = aud
        mfcc_fn = transforms.MFCC(  sample_rate=sr,
                                    n_mfcc=64,
                                    melkwargs={"n_fft": 1024, "n_mels": 64, "hop_length": hop_length})
        mfcc_fn.to(device)
        return mfcc_fn(sig.to(device))

    @staticmethod
    def plot_MFCC(mfcc, label="MFCC Features"):
        """Plot every coefficient track of the first channel of ``mfcc``."""
        for i in range(mfcc.shape[1]):
            sig = mfcc[0, i]
            plt.plot(range(sig.shape[0]), sig, label=f"MFCC {i}")
        plt.title(label)
        plt.show()

    @staticmethod
    def preprocess_audio(audio_dir):
        """Run the full pipeline on the file at ``audio_dir``.

        Steps: load, rechannel to ``N_CHANNELS``, resample to ``SAMPLING_RATE``,
        pad/truncate to ``MAX_AUDIO_LENGTH`` ms, random time shift, random
        amplitude scaling, then feature extraction.

        Returns:
            ``(aud, spec, aug_spec, mfcc)``: the processed clip, its mel
            spectrogram, the masked spectrogram and the MFCCs.
        """
        aud = AudioUtil.open(audio_dir)
        aud = AudioUtil.rechannel(aud, N_CHANNELS)
        aud = AudioUtil.resample(aud, SAMPLING_RATE)
        aud = AudioUtil.pad_trunc(aud, MAX_AUDIO_LENGTH)
        aud = AudioUtil.time_shift(aud, SHIFT_PCT)
        aud = AudioUtil.pitch_shift(aud, SHIFT_PCT)
        spec = AudioUtil.get_mel_spectrogram(aud, hop_length=512)
        aug_spec = AudioUtil.spectro_augment(spec, n_freq_masks=2, n_time_masks=2)
        mfcc = AudioUtil.get_MFCC(aud=aud)
        return (aud, spec, aug_spec, mfcc)

## Importing dataset

In [None]:
# Column accumulators filled by fillPaths(): one entry per audio file.
d = {'relative_path' : [], 'classID': [], 'file_name': []}
# class_weights[classID] = 1 / (number of files of that class); stays 0 until
# a non-empty folder has been registered for the class.
class_weights = [0] * 4
def fillPaths(path, classID):
    """Register every audio file directly inside ``path`` under ``classID``.

    Appends one row per file to the module-level ``d`` accumulators and sets
    ``class_weights[classID]`` to the inverse of the file count. Only the top
    level of ``path`` is scanned. macOS ``.DS_Store`` entries are skipped so
    they neither reach the table nor skew the class weight. A missing or empty
    folder leaves the class weight untouched instead of dividing by zero.
    """
    for (dirpath, dirnames, filenames) in walk(path):
        audio_files = [name for name in filenames if name != '.DS_Store']
        d['relative_path'].extend(dirpath + '/' + name for name in audio_files)
        d['classID'].extend([classID] * len(audio_files))
        d['file_name'].extend(audio_files)
        print(f"{path} [{len(audio_files)}]")
        if audio_files:
            class_weights[classID] = 1/len(audio_files)
        # Only the first tuple yielded by walk() (the folder itself) is wanted.
        break

# Register each class folder with its numeric label, in a fixed order.
for class_folder, class_label in (
    ('Dataset/Atraining_extrahls', 3),
    ('Dataset/Atraining_murmur', 2),
    ('Dataset/Atraining_normal', 1),
    ('Dataset/Atraining_artifact', 0),
):
    fillPaths(class_folder, class_label)

# Build the file table and drop macOS folder-metadata entries.
df = pd.DataFrame(data=d)
df = df.loc[df['file_name'] != '.DS_Store']

## Sound Dataset class
Overrides `__getitem__` so that every access returns a freshly, randomly augmented version of the requested audio clip.

In [None]:
class SoundDS(Dataset):
    """Dataset that builds ``(features, class_id)`` pairs from audio files on demand.

    ``df`` must provide the columns ``relative_path`` and ``classID``. Items are
    addressed by position (0 .. len - 1). Every access re-runs
    ``AudioUtil.preprocess_audio``, so the random augmentation differs per access.
    """

    def __init__(self, df):
        super().__init__()
        self.df = df

    def __len__(self):
        """Return the number of rows in the underlying table."""
        return len(self.df)

    def _load(self, pos):
        """Preprocess the file at row position ``pos`` and return ``(data, class_id)``."""
        # Column-first access keeps the label's scalar type; going through a
        # whole row first would upcast it to a generic Python object.
        audio_file = self.df['relative_path'].iloc[pos]
        class_id = self.df['classID'].iloc[pos]

        _, _, aug_spec, mfcc = AudioUtil.preprocess_audio(audio_file)
        # Stack the masked spectrogram and the MFCCs along the first dimension.
        data = torch.cat([aug_spec, mfcc], dim=0)
        return data, class_id

    def __getitem__(self, idx):
        """Return one item, or a list of ``[data, class_id]`` items for a slice.

        For a slice, rows that fail to load are reported and skipped, so the
        result may be shorter than the slice. For an integer, errors propagate.
        """
        if isinstance(idx, slice):
            items = []
            # slice.indices() resolves None, negative bounds and the step.
            for pos in range(*idx.indices(len(self.df))):
                try:
                    data, class_id = self._load(pos)
                except Exception as err:
                    # Best effort: report the offending file and keep going.
                    print(f"An exception occurred at {pos}, file: {self.df['relative_path'].iloc[pos]} ({err})")
                    continue
                items.append([data, class_id])
            return items

        return self._load(idx)

## Random split
Split the dataset into an 80:20 ratio between training and validation sets

In [None]:
# Shuffle the rows, then renumber the index 0..n-1 so that row labels and row
# positions coincide.
df = df.sample(frac=1).reset_index(drop=True) #shuffle data
myds = SoundDS(df)

# 80 % of the items go to training; the remainder forms the validation set.
num_items = len(myds)
num_train = round(num_items * 0.8)
num_val = num_items - num_train
train_ds, val_ds = random_split(myds, [num_train, num_val])

## Data loaders
Create training and validation data loaders

In [None]:
# Give every training sample the weight of its class so that the sampler draws
# all classes with roughly equal probability.
# The labels are read straight from the table behind the split; indexing
# train_ds would load, augment and featurize every audio file just to obtain
# its label.
train_labels = train_ds.dataset.df['classID'].iloc[train_ds.indices]
sample_weights = [class_weights[int(label)] for label in train_labels]
sampler = torch.utils.data.WeightedRandomSampler(sample_weights,
                                                num_samples=len(sample_weights),
                                                replacement=True)

In [None]:
# Training batches are drawn through the class-balancing sampler.
# Unweighted alternative: DataLoader(train_ds, batch_size=16, shuffle=True)
train_dl = DataLoader(train_ds, batch_size=16, sampler=sampler)
# Validation batches are read in dataset order.
val_dl = DataLoader(val_ds, batch_size=16, shuffle=False)

# 3. Audio Classification Model

## Creating a model instance
- Put it on the gpu if available
- Import a model if you want to

In [None]:
# The network architecture lives in the project-local `packages` directory.
from packages.CNN_1 import AudioClassifier
# Each sample stacks the masked spectrogram and the MFCCs along the first
# dimension, hence 2 * N_CHANNELS input planes; output_dim=4 matches the four
# class IDs (0-3) registered when the dataset was imported.
myModel = AudioClassifier(input_dim=2*N_CHANNELS, output_dim=4)

## Creating save path

In [None]:
IMPORT = False  #Import model state_dict if available
TEMP = False    #Import temp model state_dict if available

from pathlib import Path
# Directory that holds the checkpoints; created on first use.
MODEL_PATH = Path('models')
MODEL_PATH.mkdir(parents=True, exist_ok=True)

# Final checkpoint and the temporary checkpoint written during training.
MODEL_NAME_PTH = MODEL_NAME + '.pth'
MODEL_TEMP_NAME_PTH = MODEL_NAME + '_TEMP.pth'
MODEL_SAVE_PATH = MODEL_PATH / MODEL_NAME_PTH
MODEL_TEMP_SAVE_PATH = MODEL_PATH / MODEL_TEMP_NAME_PTH

if IMPORT:
    # Resolve the checkpoint once so loading and reporting use the same path.
    load_path = MODEL_TEMP_SAVE_PATH if TEMP else MODEL_SAVE_PATH
    try:
        # map_location lets a checkpoint saved on a GPU be restored on a
        # CPU-only machine (and vice versa).
        myModel.load_state_dict(torch.load(f=load_path, map_location=device))
    except FileNotFoundError:
        # Only a missing file is tolerated; a corrupt or mismatching checkpoint
        # should surface instead of being reported as "not found".
        print(f"{load_path} not found.")

In [None]:
# Move the model's parameters to the selected compute device.
myModel = myModel.to(device)
# Sanity check: as the last expression of a notebook cell this displays the
# device of the first parameter (it has no effect when run as a plain script).
next(myModel.parameters()).device

# 4. Training Loop

## Training function

In [None]:
def _evaluate_accuracy(model, dl):
    """Return ``(correct, total)`` prediction counts of ``model`` over ``dl``."""
    correct = 0
    total = 0
    model.eval()
    with torch.inference_mode():
        for batch in dl:
            inputs, labels = batch[0].to(device), batch[1].to(device)
            # Same per-batch normalization as in the training loop.
            inputs = (inputs - inputs.mean()) / inputs.std()
            _, prediction = torch.max(model(inputs), 1)
            correct += (prediction == labels).sum().item()
            total += prediction.shape[0]
    return correct, total


def training(model, train_dl, num_epochs, test_dl, log_every=10, checkpoint_every=1000):
    """Train ``model`` on ``train_dl`` and report accuracy on ``test_dl``.

    Uses cross-entropy loss, Adam (lr 0.001) and a linearly annealed one-cycle
    learning-rate schedule stepped once per batch. Relies on the module-level
    ``device`` and ``MODEL_TEMP_SAVE_PATH``.

    Args:
        model: network to optimize; expected to already live on ``device``.
        train_dl: loader yielding ``(inputs, labels)`` training batches.
        num_epochs: number of passes over ``train_dl``.
        test_dl: loader evaluated once at the end of every epoch.
        log_every: print statistics every this many epochs (must be > 0).
        checkpoint_every: save a temporary checkpoint every this many epochs
            (must be > 0). With the default, only epoch 0 and multiples of
            1000 are saved.
    """
    # Loss function, Optimizer and Scheduler
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.001, steps_per_epoch=len(train_dl), epochs=num_epochs, anneal_strategy='linear')

    for epoch in range(num_epochs):
        running_loss = 0.0
        correct_prediction = 0
        total_prediction = 0

        # The evaluation at the end of the previous epoch left the model in
        # eval mode, so switch back before training.
        model.train()
        for inputs, labels in train_dl:
            inputs, labels = inputs.to(device), labels.to(device)
            # Normalize the inputs with the statistics of the batch
            inputs = (inputs - inputs.mean()) / inputs.std()

            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()
            scheduler.step()

            # keep stats for loss and accuracy
            running_loss += loss.item()
            _, prediction = torch.max(outputs, 1)
            correct_prediction += (prediction == labels).sum().item()
            total_prediction += prediction.shape[0]

        # Evaluate once per epoch. Doing this inside the batch loop re-runs the
        # whole validation set after every training batch while only the last
        # result is ever reported.
        test_correct_prediction, test_total_prediction = _evaluate_accuracy(model, test_dl)

        # Print stats at the end of the epoch
        avg_loss = running_loss / len(train_dl)
        acc = correct_prediction / total_prediction
        # An empty evaluation loader yields NaN rather than a ZeroDivisionError.
        test_acc = test_correct_prediction / test_total_prediction if test_total_prediction else float('nan')
        if epoch % log_every == 0:
            print(f'Epoch: {epoch}, Loss: {avg_loss:.2f}, Accuracy: {acc:.2f}, Test Accuracy: {test_acc:.2f}')
        if epoch % checkpoint_every == 0:
            print(f"Saving model to {MODEL_TEMP_SAVE_PATH}")
            # Save the model that is being trained, not a module-level global.
            torch.save(obj=model.state_dict(), f=MODEL_TEMP_SAVE_PATH)

    print('Finished Training')

## Start training

In [None]:
# Train on the weighted training loader; the validation loader supplies the
# accuracy reported alongside the training statistics.
training(myModel, train_dl, num_epochs, test_dl=val_dl)

Save the model after the training is complete

In [None]:
# Persist only the learned parameters (state_dict), not the model object.
print(f"Saving model to {MODEL_SAVE_PATH}")
torch.save(obj=myModel.state_dict(), f=MODEL_SAVE_PATH)

# 5. Inference
Run inference on the trained model with the validation set

In [None]:
def inference(model, val_dl):
    """Print the classification accuracy of ``model`` over ``val_dl``.

    Batches are moved to the module-level ``device`` and normalized with their
    own mean and standard deviation before being fed to the model.
    """
    correct_prediction = 0
    total_prediction = 0

    # Put the model in evaluation mode explicitly instead of relying on
    # whatever mode an earlier cell left it in.
    model.eval()
    with torch.inference_mode():
        for data in val_dl:
            inputs, labels = data[0].to(device), data[1].to(device)

            inputs_m, inputs_s = inputs.mean(), inputs.std()
            inputs = (inputs - inputs_m) / inputs_s

            outputs = model(inputs)

            # Predicted class = index of the highest score.
            _, prediction = torch.max(outputs, 1)

            correct_prediction += (prediction == labels).sum().item()
            total_prediction += prediction.shape[0]

    if total_prediction == 0:
        # Nothing was evaluated; avoid dividing by zero.
        print('Accuracy: n/a, Total items: 0')
        return

    acc = correct_prediction / total_prediction
    print(f'Accuracy: {acc:.2f}, Total items: {total_prediction}')

In [None]:
# Report the accuracy of the trained model on the validation loader.
inference(myModel, val_dl)

# 6. Testing

In [None]:
def testing(model, data):
    """Classify a single ``(features, class_id)`` sample and print the result.

    Prints ``"<predicted> <--> <actual>"`` with both values as plain integers.
    """
    model.eval()
    with torch.inference_mode():
        # data[0] is already a tensor: add the batch dimension and move it to
        # the device directly. Wrapping a tensor in torch.tensor() copies it
        # and makes PyTorch emit a warning.
        features = data[0].unsqueeze(dim=0).to(device)
        # The label is only printed, so a plain int is enough.
        label = int(data[1])
        features = (features - features.mean()) / features.std()
        output = model(features)
        _, prediction = torch.max(output, 1)
        print(f"{prediction[0].item()} <--> {label}")

# Debug aid: print(df[(df.index > num_train) & (df.index < num_train + len(val_ds))])
# Header line, followed by one "predicted <--> actual" line per validation sample.
print("predicted <--> actual")
for sample_pos in range(len(val_ds)):
    testing(myModel, val_ds[sample_pos])