In [23]:
import os
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from tqdm import tqdm
# import audiomentations
# from audiomentations import Compose, AddGaussianNoise, PitchShift
# import torch_audiomentations
# from torch_audiomentations import Compose, AddGaussianNoise, PitchShift
import torchaudio

In [2]:
# Adapted from https://github.com/musikalkemist/pytorchforaudio
class CoughDataset(Dataset):

    def __init__(self,
                 annotations_df,
                 audio_dir,
                 transformation,
                 target_sample_rate,
                 num_samples,
                 device,
                 augment=False,
                ):
        self.annotations = annotations_df
        self.audio_dir = audio_dir
        self.device = device
        self.transformation = transformation.to(self.device)
        self.target_sample_rate = target_sample_rate
        self.num_samples = num_samples
        self.label_dict = {'healthy':0, 'symptomatic':1, 'COVID-19':2}
        
        self.do_augment = augment
#         self.augmentations = Compose(
#                 [
#                     AddGaussianNoise(min_amplitude=0.01, max_amplitude=0.05, p=0.5),
#                     PitchShift(min_semitones=-8, max_semitones=8, p=0.5)
#                 ]
#         )

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        audio_sample_path = self._get_audio_sample_path(index)
        label = self.label_dict[self._get_audio_sample_label(index)]
        signal, sr = torchaudio.load(audio_sample_path)
        
#         if self.do_augment:
#             signal = torch.from_numpy(self.augmentations(signal.numpy(), sr))
        
        signal = signal.to(self.device)
        signal = self._resample_if_necessary(signal, sr)
        signal = self._mix_down_if_necessary(signal)
        signal = self._cut_if_necessary(signal)
        signal = self._right_pad_if_necessary(signal)
        signal = self.transformation(signal)
        
        return signal, label

    def _cut_if_necessary(self, signal):
        if signal.shape[1] > self.num_samples:
            signal = signal[:, :self.num_samples]
        return signal

    def _right_pad_if_necessary(self, signal):
        length_signal = signal.shape[1]
        if length_signal < self.num_samples:
            num_missing_samples = self.num_samples - length_signal
            last_dim_padding = (0, num_missing_samples)
            signal = torch.nn.functional.pad(signal, last_dim_padding)
        return signal

    def _resample_if_necessary(self, signal, sr):
        if sr != self.target_sample_rate:
            resampler = torchaudio.transforms.Resample(sr, self.target_sample_rate)
            signal = resampler(signal)
        return signal

    def _mix_down_if_necessary(self, signal):
        if signal.shape[0] > 1:
            signal = torch.mean(signal, dim=0, keepdim=True)
        return signal

    def _get_audio_sample_path(self, index):
        path = os.path.join(self.audio_dir, self.annotations.iloc[index, 0])+".wav"
        return path

    def _get_audio_sample_label(self, index):
        return self.annotations.iloc[index, 9]


In [3]:
AUDIO_DIR = "../valid_data/"
SAMPLE_RATE = 16000
NUM_SAMPLES = SAMPLE_RATE*10

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# print(f"Using device {device}")

mel_spectrogram = torchaudio.transforms.MelSpectrogram(
    sample_rate=SAMPLE_RATE,
    n_fft=1024,
    hop_length=512,
    n_mels=128
)

train_df = pd.read_parquet(os.path.join(AUDIO_DIR, "train.parquet.gzip"))
val_df = pd.read_parquet(os.path.join(AUDIO_DIR, "train.parquet.gzip"))
test_df = pd.read_parquet(os.path.join(AUDIO_DIR, "train.parquet.gzip"))


dataset = CoughDataset(train_df,
                        AUDIO_DIR,
                        mel_spectrogram,
                        SAMPLE_RATE,
                        NUM_SAMPLES,
                        device,
                        augment=False
                      )
# print(f"There are {len(usd)} samples in the dataset.")
# signal, label = usd[0]

In [4]:
dataset[0][0].shape

torch.Size([1, 128, 313])

In [25]:
class CNNNetwork(torch.nn.Module):

    def __init__(self):
        super().__init__()
        # 4 conv blocks / flatten / linear / softmax
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=1,
                out_channels=8,
                kernel_size=5,
                stride=2,
                padding=0
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.BatchNorm2d(8)
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(
                in_channels=8,
                out_channels=12,
                kernel_size=3,
                stride=1,
                padding=0
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.BatchNorm2d(12)
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(
                in_channels=12,
                out_channels=12,
                kernel_size=3,
                stride=1,
                padding=0
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.BatchNorm2d(12)
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(
                in_channels=12,
                out_channels=16,
                kernel_size=3,
                stride=1,
                padding=0
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.BatchNorm2d(16)
        )
        self.flatten = nn.Flatten()
        self.linear = nn.Linear(224, 3)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, input_data):
        #nomralization
        std = input_data.std()
        input_data -= input_data.mean()
        input_data /= std
        
        x = self.conv1(input_data)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.flatten(x)
        logits = self.linear(x)
        predictions = self.softmax(logits)
        return predictions
    



In [46]:
BATCH_SIZE =8
EPOCHS =5
MODEL_FOLDER = '../models/'

def create_data_loader(train_data, batch_size):
    train_dataloader = DataLoader(train_data, batch_size=batch_size)
    return train_dataloader

def count_correct(logits, y_true):
    y_pred = torch.argmax(logits, axis = 1)
    return torch.sum(y_pred==y_true)

def train_single_epoch(model, train_data_loader, val_data_loader, loss_fn, optimiser, device):
    total_loss = 0.0
    correct_pred = 0.0
    total_pred = 0
    for x_batch, y_batch in tqdm(train_data_loader):
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)

        # calculate loss
        y_pred = model(x_batch)
        loss = loss_fn(y_pred, y_batch)
        
        correct_pred += count_correct(y_pred, y_batch)
        total_pred += y_batch.shape[0]

        # backpropagate error and update weights
        optimiser.zero_grad()
        loss.backward()
        optimiser.step()
        total_loss += loss.item()
        
    print(f"Training loss: {total_loss}, Training accuracy : {correct_pred/total_pred}")
    
    total_loss = 0.0
    correct_pred = 0.0
    total_pred = 0
    for x_batch, y_batch in val_data_loader:
        with torch.no_grad():
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            y_pred = model(x_batch)
            loss = loss_fn(y_pred, y_batch)
            total_loss += loss.item() 
            
        correct_pred += count_correct(y_pred, y_batch)
        total_pred += y_batch.shape[0]
        
    print(f"Validataion loss: {total_loss}, Validation accuracy : {correct_pred/total_pred}")

    
def train(model, train_data_loader, val_data_loader, loss_fn, optimiser, device, epochs):
    for i in range(epochs):
        print(f"Epoch {i+1}")
        train_single_epoch(model, train_data_loader, val_data_loader, loss_fn, optimiser, device)
        
        path = os.path.join(MODEL_FOLDER, f"epoch_{i}.pth")
        torch.save(cnn.state_dict(), path)
        print(f"Saved at {path}")
        print("---------------------------")
    print("Finished training")
    print("---------------------------")
    
    
def evaluate(model, eval_data_loader, loss_fn, device):
    print("Evaluating model")
    total_loss = 0.0
    correct_pred = 0.0
    total_pred = 0
    for x_batch, y_batch in tqdm(eval_data_loader):
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)

        # calculate loss
        y_pred = model(x_batch)
        loss = loss_fn(y_pred, y_batch)
        
        correct_pred += count_correct(y_pred, y_batch)
        total_pred += y_batch.shape[0]

        # backpropagate error and update weights
        optimiser.zero_grad()
        loss.backward()
        optimiser.step()
        total_loss += loss.item()
        
    print(f"Evaluation loss: {total_loss}, Evaluation accuracy : {correct_pred/total_pred}")
    print("---------------------------")

In [47]:
train_data = CoughDataset(train_df,
                        AUDIO_DIR,
                        mel_spectrogram,
                        SAMPLE_RATE,
                        NUM_SAMPLES,
                        device)

val_data = CoughDataset(val_df,
                        AUDIO_DIR,
                        mel_spectrogram,
                        SAMPLE_RATE,
                        NUM_SAMPLES,
                        device)

test_data = CoughDataset(test_df,
                        AUDIO_DIR,
                        mel_spectrogram,
                        SAMPLE_RATE,
                        NUM_SAMPLES,
                        device)

train_dataloader = create_data_loader(train_data, BATCH_SIZE)
val_dataloader = create_data_loader(val_data, BATCH_SIZE)
test_dataloader = create_data_loader(val_data, BATCH_SIZE)

# construct model and assign it to device
model = CNNNetwork().to(device)

# initialise loss funtion + optimiser
loss_fn = nn.CrossEntropyLoss()
optimiser = torch.optim.Adam(cnn.parameters())

# train model
train(model, train_dataloader, val_dataloader, loss_fn, optimiser, device, EPOCHS)
evaluate(model, test_dataloader, loss_fn, device)

# save model


Epoch 1


  0%|                                                                                                                                                                          | 0/1371 [00:00<?, ?it/s]


Training loss: 1.0370057821273804, Training accuracy : 0.5
Validataion loss: 1.0370057821273804, Validation accuracy : 0.5
Saved at ../models/epoch_0.pth
---------------------------
Epoch 2


  0%|                                                                                                                                                                          | 0/1371 [00:00<?, ?it/s]


Training loss: 1.0370057821273804, Training accuracy : 0.5
Validataion loss: 1.0370057821273804, Validation accuracy : 0.5
Saved at ../models/epoch_1.pth
---------------------------
Epoch 3


  0%|                                                                                                                                                                          | 0/1371 [00:00<?, ?it/s]


Training loss: 1.0370057821273804, Training accuracy : 0.5
Validataion loss: 1.0370057821273804, Validation accuracy : 0.5
Saved at ../models/epoch_2.pth
---------------------------
Epoch 4


  0%|                                                                                                                                                                          | 0/1371 [00:00<?, ?it/s]


Training loss: 1.0370057821273804, Training accuracy : 0.5
Validataion loss: 1.0370057821273804, Validation accuracy : 0.5
Saved at ../models/epoch_3.pth
---------------------------
Epoch 5


  0%|                                                                                                                                                                          | 0/1371 [00:00<?, ?it/s]


Training loss: 1.0370057821273804, Training accuracy : 0.5
Validataion loss: 1.0370057821273804, Validation accuracy : 0.5
Saved at ../models/epoch_4.pth
---------------------------
Finished training
Evaluating model


  0%|                                                                                                                                                                          | 0/1371 [00:00<?, ?it/s]

Evaluation loss: 1.0370057821273804, Evaluation accuracy : 0.5
---------------------------





In [9]:
CNNNetwork()

CNNNetwork(
  (conv1): Sequential(
    (0): Conv2d(1, 8, kernel_size=(5, 5), stride=(2, 2))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (conv2): Sequential(
    (0): Conv2d(8, 12, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (conv3): Sequential(
    (0): Conv2d(12, 12, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (conv4): Sequential(
    (0): Conv2d(12, 16, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ce