In [None]:
import glob
import os
import random

import librosa
import librosa.display
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import torch.nn as nn
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
from sklearn.metrics import confusion_matrix, accuracy_score, f1_score, precision_score, recall_score, roc_auc_score
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

from src.model import MobileNetv2Model, ResNet18Model

# Device handle for manual (non-Lightning) tensor work, falling back to CPU.
# GPU index 0 matches the GPU the Trainer is pinned to (`devices=[0]`); the
# previous 'cuda:1' disagreed with it and would fail on use on a single-GPU host.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    
%load_ext autoreload
%autoreload 2

In [None]:
import audiomentations
from audiomentations import AddGaussianSNR, TimeStretch, PitchShift, Shift
from sklearn.model_selection import KFold, StratifiedKFold
import collections

In [None]:
def spec_to_image(spec, eps=1e-6):
    """Standardize a spectrogram array and rescale it to an 8-bit image.

    The array is z-scored (``eps`` guards against a zero standard deviation),
    then min-max scaled to [0, 255] and cast to ``uint8``.

    Args:
        spec: numeric NumPy array, e.g. the dB mel spectrogram returned by
            ``get_melspectrogram_db``.
        eps: small constant added to the standard deviation.

    Returns:
        ``np.uint8`` array with the same shape as ``spec``. A constant input has
        no dynamic range; it maps to an all-zero image rather than to the NaNs a
        0/0 min-max division would produce (whose uint8 cast is undefined).
    """
    spec_norm = (spec - spec.mean()) / (spec.std() + eps)
    spec_min, spec_max = spec_norm.min(), spec_norm.max()
    value_range = spec_max - spec_min
    if value_range == 0:
        return np.zeros(spec_norm.shape, dtype=np.uint8)
    spec_scaled = 255 * (spec_norm - spec_min) / value_range
    return spec_scaled.astype(np.uint8)

def get_melspectrogram_db(file_path, aug=False, sr=48000, n_fft=2048, hop_length=256, n_mels=128, fmin=20, fmax=8300, top_db=80, duration=14):
    """Load an audio file and return its dB-scaled mel spectrogram.

    The waveform is loaded at ``sr`` and forced to exactly ``duration`` seconds:
    longer clips are truncated, shorter ones are reflect-padded on both sides.
    Optionally a random audiomentations chain is applied before the
    spectrogram is computed.

    Args:
        file_path: path of the audio file to load.
        aug: if True, apply random augmentation (Gaussian-SNR noise, time
            stretch, pitch shift, shift), each with probability 0.5.
        sr: sample rate passed to ``librosa.load``.
        n_fft, hop_length, n_mels, fmin, fmax: forwarded to
            ``librosa.feature.melspectrogram``.
        top_db: forwarded to ``librosa.power_to_db``.
        duration: clip length in seconds (default 14, the previously
            hard-coded value).

    Returns:
        The result of ``librosa.power_to_db`` on the mel spectrogram.
    """
    wav, _ = librosa.load(file_path, sr=sr)

    target_len = int(round(duration * sr))
    if wav.shape[0] < target_len:
        deficit = target_len - wav.shape[0]
        # Split the padding so the result is exactly target_len samples; padding
        # ceil(deficit / 2) on *both* sides added one extra sample whenever the
        # deficit was odd.
        pad_left = (deficit + 1) // 2
        pad_right = deficit - pad_left
        # np.pad cannot reflect an empty array, so fall back to zero padding.
        mode = 'reflect' if wav.shape[0] > 0 else 'constant'
        wav = np.pad(wav, (pad_left, pad_right), mode=mode)
    else:
        wav = wav[:target_len]

    if aug:
        # Only build the augmentation chain when it is actually used.
        audio_transforms = audiomentations.Compose([
            AddGaussianSNR(min_snr_in_db=5, max_snr_in_db=40.0, p=0.5),
            TimeStretch(min_rate=0.8, max_rate=1.25, p=0.5),
            PitchShift(min_semitones=-4, max_semitones=4, p=0.5),
            Shift(-0.5, 0.5, p=0.5),
        ])
        wav = audio_transforms(samples=wav, sample_rate=sr)

    spec = librosa.feature.melspectrogram(y=wav, sr=sr, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, fmin=fmin, fmax=fmax)
    spec_db = librosa.power_to_db(spec, top_db=top_db)

    return spec_db

def array_to_tensor(img_array) -> torch.Tensor:
    """Copy array-like data into a new ``torch.float32`` tensor.

    ``torch.tensor`` is used instead of the legacy ``torch.FloatTensor``
    constructor, which interprets an integer argument as a *shape*
    (``FloatTensor(3)`` is an uninitialised 3-element tensor) rather than data.

    Args:
        img_array: NumPy array, nested sequence or scalar.

    Returns:
        float32 tensor holding a copy of ``img_array``'s values.
    """
    return torch.tensor(img_array, dtype=torch.float32)

class ImageCombinedDataset(Dataset):
    """Pairs each pre-swallow recording with its post-swallow counterpart.

    Each item is a 2-channel stack built from the two recordings' spectrogram
    images (channel 0 = pre, channel 1 = post) together with an int64 label.

    Args:
        files: indexable collection of pre-swallow file paths.
        labels: indexable collection of integer labels aligned with ``files``.
        augs: forwarded as ``aug`` to ``get_melspectrogram_db`` for both files.
    """

    def __init__(self, files, labels, augs=False):
        self.files, self.labels, self.augs = files, labels, augs

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        pre_path = self.files[idx]
        target_value = self.labels[idx]
        # NOTE(review): str.replace swaps *every* 'pre' in the path (the
        # 'pre swallow' directory and the '_pre' suffix); a path containing
        # 'pre' anywhere else would be mapped to the wrong post file.
        post_path = pre_path.replace('pre', 'post')

        # The post file is processed first, so any random augmentation is
        # drawn in the same order as before.
        channels = {}
        for key, path in (('post', post_path), ('pre', pre_path)):
            spec = get_melspectrogram_db(path, aug=self.augs)
            channels[key] = spec_to_image(spec)[np.newaxis, ...]

        stacked = np.concatenate([channels['pre'], channels['post']])
        target = torch.tensor(target_value, dtype=torch.int64)
        return array_to_tensor(stacked), target

### Splitting data

In [None]:
random.seed(10)

# Data files. Both globs use the same 'wetvoice' root as the excluded file
# below; the pre-swallow pattern previously read 'etvoice/...', which broke the
# pre/post pairing. Results are sorted so the fold assignment does not depend
# on filesystem listing order.
data_path_pre = 'wetvoice/extractedMPTs/pre swallow/**'
data_files_pre = sorted(x for x in glob.glob(data_path_pre)
                        if x.endswith('.wav') and 'android' not in x and 'apple' not in x)

data_path_post = 'wetvoice/extractedMPTs/post swallow/**'
data_files_post = sorted(x for x in glob.glob(data_path_post)
                         if x.endswith('.wav') and 'android' not in x and 'apple' not in x)

# Recording IDs that have a post-swallow file; only pre-swallow files with a
# matching ID are kept so every sample can be paired.
ids = {os.path.basename(x).split('_post')[0] for x in data_files_post}

# TODO(review): document why recording 159 is excluded.
EXCLUDED_FILES = {'159_pre_mpt.wav'}
train_files = [x for x in data_files_pre
               if os.path.basename(x).split('_pre')[0] in ids
               and os.path.basename(x) not in EXCLUDED_FILES]
assert train_files, 'no pre-swallow files paired with post-swallow files; check the data paths'

# LABELS: recording ID -> aspiration value, keeping the first row per ID as the
# previous `.values[0]` lookup did. The column name keeps its two trailing
# spaces exactly as written ('Aspiration  ').
df = pd.read_excel('audio file numbers and aspiration values.xlsx')
aspiration_by_id = (df.drop_duplicates(subset='Audio File Name', keep='first')
                      .set_index('Audio File Name')['Aspiration  '])
# NOTE(review): only the exact value 'Yes' counts as positive; variants such as
# 'yes' or 'Yes ' would be labelled 0 — verify against the spreadsheet.
train_labels = [1 if aspiration_by_id.loc[int(os.path.basename(f).split('_pre')[0])] == 'Yes' else 0
                for f in train_files]

len(train_files), len(train_labels)

## K-fold split

In [None]:
# Stratified k-fold over the paired recordings: each fold keeps roughly the
# same label ratio as the full set.
k = 5
kf = StratifiedKFold(n_splits=k, shuffle=True, random_state=42)

train_files = np.array(train_files)
train_labels = np.array(train_labels)

# fold index -> [X_train, X_held_out, y_train, y_held_out]. A plain dict: the
# previous `collections.defaultdict()` had no default factory, so it behaved
# exactly like a dict while suggesting missing keys were auto-created.
folds = {
    fold: [train_files[train_idx], train_files[held_out_idx],
           train_labels[train_idx], train_labels[held_out_idx]]
    for fold, (train_idx, held_out_idx) in enumerate(kf.split(train_files, train_labels))
}

In [None]:
# Sanity check: one entry per fold index produced by the k-fold split above.
folds.keys()

In [None]:
# Fold 0: its held-out part is split 50/50, stratified by label, into the
# validation set (passed to trainer.fit) and the test set.
X_train, X_testing, y_train, y_testing = folds[0]

X_val, X_test, y_val, y_test = train_test_split(
    X_testing,
    y_testing,
    stratify=y_testing,
    test_size=0.5,
    random_state=42,
)

tuple(len(split) for split in (X_train, y_train, X_val, y_val, X_test, y_test))

In [None]:
# Augmentation is enabled for the training split only.
train_data = ImageCombinedDataset(X_train, y_train, augs=True)
valid_data = ImageCombinedDataset(X_val, y_val, augs=False)
test_data = ImageCombinedDataset(X_test, y_test, augs=False)

# Shared loader settings; only the training loader shuffles.
loader_kwargs = {'batch_size': 12, 'num_workers': 2}
train_loader = DataLoader(train_data, shuffle=True, **loader_kwargs)
valid_loader = DataLoader(valid_data, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_data, shuffle=False, **loader_kwargs)

### Model Training

In [None]:
# Seed python/numpy/torch (and dataloader workers) so weight init and the
# random augmentations are reproducible across runs.
pl.seed_everything(42, workers=True)

model = MobileNetv2Model(in_channels=2, num_classes=2)

logger = TensorBoardLogger(save_dir=os.getcwd(), name="logs/mobilenet_augs_fold0")

# Keep the checkpoint with the lowest validation loss. Without an explicit
# ModelCheckpoint, Lightning keeps only the latest epoch, which after early
# stopping is typically `patience` epochs past the best one.
checkpoint_cb = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1)
early_stop_cb = EarlyStopping(monitor="val_loss", mode="min", patience=5)

# Train on GPU 0 when available, otherwise fall back to CPU.
use_gpu = torch.cuda.is_available()
trainer = pl.Trainer(
    logger=logger,
    accelerator='gpu' if use_gpu else 'cpu',
    devices=[0] if use_gpu else 1,
    max_epochs=30,
    callbacks=[early_stop_cb, checkpoint_cb],
)

# NOTE(review): assumes MobileNetv2Model's optimizer setup reads
# self.hparams.lr; confirm in src/model.py, otherwise this has no effect.
model.hparams.lr = 1e-5
trainer.fit(model, train_loader, valid_loader)

## Model Inference

In [None]:
# Load the checkpoint chosen by the training run in this kernel. The hard-coded
# path is only a fallback for a fresh kernel: TensorBoardLogger creates a new
# version_N directory on every run, so a fixed path goes stale after retraining.
FALLBACK_CHECKPOINT = 'logs/mobilenet_augs_fold0/version_0/checkpoints/epoch=18-step=114.ckpt'
try:
    check = trainer.checkpoint_callback.best_model_path or FALLBACK_CHECKPOINT
except NameError:  # `trainer` undefined: the training cell was not run
    check = FALLBACK_CHECKPOINT
model = MobileNetv2Model.load_from_checkpoint(check, in_channels=2, num_classes=2)

In [None]:
# Training set 
model.eval()

tr_data = ImageCombinedDataset(X_train, y_train, augs=False)
results, labels = [], []    
for i, sample in tqdm(enumerate(tr_data)): 
    img, cl = sample
    labels.append(cl.item())    
    logits = model(img.unsqueeze(0)).squeeze()  
    res = torch.sigmoid(logits) > 0.5
    results.append(res.item())

print(confusion_matrix(labels, results))
print('Precision', precision_score(labels, results, average='macro'))
print('Recall', recall_score(labels, results, average='macro'))
print('Accuracy', accuracy_score(labels, results))
print('F1', f1_score(labels, results, average='macro'))

tn, fp, fn, tp = confusion_matrix(labels, results).ravel()
specificity = tn / (tn+fp)
print(specificity)

In [None]:
print('\n', 'Validation & Test')
v_data = ImageCombinedDataset(X_testing, y_testing, augs=False)

results, labels = [], []    
for i, sample in tqdm(enumerate(v_data)): 
    img, cl = sample
    labels.append(cl.item())
    logits = model(img.unsqueeze(0)).squeeze()  
    res = torch.sigmoid(logits) > 0.5
    results.append(res.item())
    
print(confusion_matrix(labels, results))
# Accuracy, Precision, Recall, F1-score, AUROC,
print('Precision', precision_score(labels, results, average='macro'))
print('Recall', recall_score(labels, results, average='macro'))
print('Accuracy', accuracy_score(labels, results))
print('F1', f1_score(labels, results, average='macro'))

tn, fp, fn, tp = confusion_matrix(labels, results).ravel()
specificity = tn / (tn+fp)
print(specificity)