In [1]:
import os

In [2]:
# Location of the COUGHVID public dataset on the local machine.
# Raw string: the Windows path contains backslashes followed by "C" and "p",
# which a normal string literal would parse as (invalid) escape sequences.
path_to_dataset = r"D:\COUGHVID\public_dataset"
files = os.listdir(path_to_dataset)
print(f"There are {len(files)} files")
# Collect the distinct file extensions (including the leading dot).
extensions = {os.path.splitext(f)[1] for f in files}
print(f"Files have extensions {extensions}")

There are 55101 files
Files have extensions {'.json', '.webm', '.csv', '.ogg'}


In [3]:
# Group the file names by extension.
# The extension is compared via os.path.splitext (the same rule used to build
# `extensions`) instead of str.endswith: endswith("") is True for every name,
# so an extension-less file would otherwise make the "" group contain all
# files. With splitext the groups form a true partition of `files`.
files_by_type = {}
for ext in extensions:
    files_by_type[ext] = [f for f in files if os.path.splitext(f)[1] == ext]
    print(f"There are {len(files_by_type[ext])} files with extension {ext}")

There are 27550 files with extension .json
There are 25985 files with extension .webm
There are 1 files with extension .csv
There are 1565 files with extension .ogg


In [4]:
import librosa
import IPython

In [5]:
import numpy as np

In [6]:
def segment_cough(x,fs, cough_padding=0.2,min_cough_len=0.2, th_l_multiplier = 0.1, th_h_multiplier = 2):
    """Preprocess the data by segmenting each file into individual coughs using a hysteresis comparator on the signal power

    A cough starts when the instantaneous power (x**2) rises above the high
    threshold. It ends once the power has stayed below the low threshold for
    more than 10 ms worth of consecutive samples, or when the signal ends.

    Inputs:
    *x (np.array): cough signal
    *fs (float): sampling frequency in Hz
    *cough_padding (float): number of seconds added to the beginning and end of each detected cough to make sure coughs are not cut short
    *min_cough_len (float): minimum duration in seconds (padding excluded) a segment must exceed to be considered a cough
    *th_l_multiplier (float): multiplier of the RMS energy used as a lower threshold of the hysteresis comparator
    *th_h_multiplier (float): multiplier of the RMS energy used as a high threshold of the hysteresis comparator

    Outputs:
    *coughSegments (list of np.arrays): one slice of x per detected cough, padding included
    *cough_mask (np.array of bool): True at the indices that belong to a returned cough segment"""

    x = np.asarray(x)
    n_samples = len(x)
    cough_mask = np.zeros(n_samples, dtype=bool)
    coughSegments = []

    # Nothing to segment; also avoids taking the mean of an empty array.
    if n_samples == 0:
        return coughSegments, cough_mask

    #Define hysteresis thresholds
    rms = np.sqrt(np.mean(np.square(x)))
    seg_th_l = th_l_multiplier * rms
    seg_th_h = th_h_multiplier * rms

    #Convert the time-based settings to sample counts
    padding = round(fs*cough_padding)
    min_cough_samples = round(fs*min_cough_len)
    tolerance = round(0.01*fs)

    def _store_if_long_enough(start, end):
        # Keep the segment [start, end] only if it is long enough once the
        # nominal padding on both sides is discounted. The full 2*padding is
        # subtracted even when the padding was clipped at a signal boundary.
        if end + 1 - start - 2*padding > min_cough_samples:
            coughSegments.append(x[start:end+1])
            cough_mask[start:end+1] = True

    #Segment coughs
    cough_start = 0
    cough_in_progress = False
    below_th_counter = 0

    for i, sample in enumerate(x**2):
        if cough_in_progress:
            if sample < seg_th_l:
                below_th_counter += 1
                if below_th_counter > tolerance:
                    cough_in_progress = False
                    _store_if_long_enough(cough_start, min(i + padding, n_samples - 1))
            else:
                # Power recovered: restart the count of consecutive quiet samples.
                below_th_counter = 0
        elif sample > seg_th_h:
            cough_start = max(i - padding, 0)
            cough_in_progress = True
            # Every cough gets a fresh tolerance window; without this reset the
            # count left over from the previous cough would end it prematurely.
            below_th_counter = 0

    # A cough still open when the signal ends is closed at the last sample,
    # whatever the level of that sample, and is recorded in the mask as well.
    if cough_in_progress:
        _store_if_long_enough(cough_start, n_samples - 1)

    return coughSegments, cough_mask

In [42]:
import time

In [68]:
def test_segmentation_time(max_files=100, out_path='segmentation.npy'):
    """Time the segmentation of the first .webm recordings and save their cough masks.

    Loads up to `max_files` files listed in files_by_type['.webm'] with
    librosa, runs segment_cough on each one and stores the resulting masks in
    a single .npy file.

    Inputs:
    *max_files (int or None): maximum number of recordings to process; None processes all of them
    *out_path (str): path of the .npy file the masks are written to

    Outputs:
    *total (float): elapsed time in seconds for loading, segmenting and saving"""
    start = time.perf_counter()

    total_mask = []
    for file in files_by_type['.webm'][:max_files]:
        audio_file = os.path.join(path_to_dataset, file)
        x, sr = librosa.load(audio_file)
        _, cough_mask = segment_cough(x, sr, cough_padding=0.2,min_cough_len=0.2, th_l_multiplier = 0.1, th_h_multiplier = 2)
        total_mask.append(cough_mask)

    # The masks have one entry per audio sample, so their lengths differ from
    # file to file. Recent NumPy versions refuse to build an array from such a
    # ragged list implicitly, hence the explicit object array.
    masks = np.empty(len(total_mask), dtype=object)
    for idx, mask in enumerate(total_mask):
        masks[idx] = mask
    np.save(out_path, masks, allow_pickle = True)

    total = time.perf_counter() - start
    return total

In [69]:
# Run the benchmark defined above; the displayed value is the elapsed time in seconds.
test_segmentation_time()

72.87383246421814

In [86]:
# Time how long it takes to read the saved masks back from disk.
# perf_counter() is a monotonic, high-resolution clock intended for measuring
# short intervals such as this one.
# NOTE: allow_pickle=True lets NumPy unpickle object arrays; unpickling can
# execute arbitrary code, so only load files that you created yourself.
start = time.perf_counter()
data = np.load('segmentation.npy', allow_pickle = True)
end = time.perf_counter()
print(end - start)

0.012370586395263672
