In [None]:
from __future__ import print_function, division
%matplotlib inline
%load_ext autoreload
%autoreload 2

import sys, time, random, glob, os
import numpy as np

from IPython.display import Audio, clear_output
#from pyAudioAnalysis import audioBasicIO, audioFeatureExtraction

from scipy.io import wavfile
from scipy import signal as sig
from pydub import AudioSegment
from pydub import effects
import sklearn
from sklearn import neighbors, datasets, metrics, linear_model
import matplotlib.pyplot as plt

from sklearn.datasets import load_digits
from sklearn.model_selection import learning_curve
from numpy.random.mtrand import permutation
import random
import pandas as pd 
import librosa
import re

# Expose no CUDA device to this process (presumably to force CPU execution — confirm).
# The assignment happens in this first cell, before the next cell imports torch.
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

In [None]:
import torch
import numpy as np
from pydub import AudioSegment
import glob
from silero_vad import (load_silero_vad,
                    read_audio,
                    get_speech_timestamps,
                    save_audio,
                    VADIterator,
                    collect_chunks)

# Load the Silero VAD model once; use_silero_vad() reads this global for every file it analyses.
# NOTE(review): onnx=False presumably selects the non-ONNX model variant — confirm in the silero_vad docs.
silero_model = load_silero_vad(onnx=False)

def use_silero_vad(file_path, Fs=16000, retun_seconds=True):
    """Return the speech regions Silero VAD finds in an audio file.

    Args:
        file_path: Path of the audio file to analyse.
        Fs: Sampling rate in Hz. It is used both to decode the audio and to
            interpret it during detection, so the two always agree.
        retun_seconds: Forwarded as ``return_seconds`` to
            ``get_speech_timestamps`` (the misspelled name is kept because
            callers may pass it by keyword).

    Returns:
        Whatever ``get_speech_timestamps`` returns; the caller in this
        notebook treats it as a list of dicts with 'start' and 'end' keys.
    """
    # The audio used to be decoded at a literal 16000 Hz while detection ran
    # with `Fs`; any other `Fs` therefore produced timestamps on the wrong
    # time base.
    wav = read_audio(file_path, sampling_rate=Fs)
    return get_speech_timestamps(
        wav,
        silero_model,
        threshold=0.5,
        sampling_rate=Fs,
        return_seconds=retun_seconds,
    )


def add_duration(duration_list, window_size, time_in_seconds=True):
    """Sum the durations of a list of defect spans.

    Args:
        duration_list: Iterable whose items are either dicts with 'start' and
            'end' keys or pydub ``AudioSegment`` objects.
        window_size: Unused; kept so existing call sites stay valid.
        time_in_seconds: When False, the 'start'/'end' values of dict items
            are taken to be milliseconds and are converted to seconds before
            being added to the total. ``AudioSegment.duration_seconds`` is
            always added as is.

    Returns:
        ``(total_duration, unpaked_duration)``: the total in seconds and a
        list of ``[start, end]`` pairs for the dict items, in input order and
        in the unit they were given in.

    Raises:
        TypeError: If an item is neither a dict nor an ``AudioSegment``.
    """
    total_duration = 0
    unpaked_duration = []

    for data in duration_list:
        if isinstance(data, dict):
            start = data['start']
            end = data['end']

            span = end - start
            if not time_in_seconds:
                # Convert each span exactly once. The conversion used to be
                # applied to the running total on every iteration, which
                # divided earlier spans by 1000 again and again.
                span /= 1000
            total_duration += span
            unpaked_duration.append([start, end])

        elif isinstance(data, AudioSegment):
            total_duration += data.duration_seconds
            print(data.duration_seconds)

        else:
            # `assert TypeError, msg` never fired (the class object is truthy),
            # so unsupported items used to end the loop silently.
            raise TypeError(f'Type: {type(data)}')

    return total_duration, unpaked_duration

def get_window_timestamp_wo_df(file_duration, window_size=4, num_window=10, min_step=0.5, debug=False):
    """Spread ``num_window`` window start times evenly over a defect-free file.

    The first window starts at 0 and successive starts are
    ``(file_duration - window_size) / (num_window - 1)`` apart, so the last
    window ends at (about) ``file_duration``. A start that would push its
    window past the end of the file is pulled back to
    ``file_duration - window_size``.

    Args:
        file_duration: Length of the recording in seconds.
        window_size: Length of each window in seconds.
        num_window: Number of start times to produce.
        min_step: Unused; kept for signature parity with get_window_timestamp.
        debug: Unused; kept for signature parity with get_window_timestamp.

    Returns:
        A list of ``num_window`` start times in seconds, each truncated (not
        rounded) to millisecond precision.
    """
    # With a single window there is no spacing to compute; the unguarded
    # expression divided by zero in that case.
    step_size = (file_duration - window_size) / (num_window - 1) if num_window > 1 else 0

    valid_clips = []
    start = 0
    for _ in range(num_window):
        if start + window_size > file_duration:
            start = file_duration - window_size
        # int(x * 1000) / 1000 truncates to three decimals.
        valid_clips.append(int(start * 10 ** 3) / 10 ** 3)
        start += step_size
    assert len(valid_clips) == num_window, f"Expected {num_window} valid clips, but got {len(valid_clips)}"

    return valid_clips


def get_window_timestamp(defect_list, file_duration, window_size=4, num_window=10, min_step=0.1, debug=False):
    """Choose ``num_window`` window start times that avoid the defect spans.

    Steps:
      1. Sort the defect spans by start time and merge spans that overlap or
         touch.
      2. Collect the clean gaps: before the first defect, between consecutive
         defects, and after the last defect.
      3. Keep the gaps that are at least ``window_size`` long and count how
         many windows each could hold when shifted by ``min_step``.
      4. Hand out the ``num_window`` windows across those gaps roughly in
         proportion to that capacity, shortest gap first, spacing the windows
         evenly inside each gap.

    Args:
        defect_list: Sequence of ``[start, end]`` pairs in seconds. May be
            empty, unsorted or overlapping.
        file_duration: Length of the recording in seconds.
        window_size: Length of each window in seconds.
        num_window: Number of start times to produce.
        min_step: Smallest shift between two windows, used only to compute
            how many windows a gap can hold.
        debug: Print the gap tables when True.

    Returns:
        A list of ``num_window`` start times in seconds (truncated to
        millisecond precision), or ``None`` when the clean gaps cannot hold
        ``num_window`` windows. With an empty ``defect_list`` the result of
        ``get_window_timestamp_wo_df`` is returned.
    """
    defects = np.array(defect_list, dtype=float)
    if len(defects) == 0:
        return get_window_timestamp_wo_df(file_duration, window_size, num_window, min_step, debug)

    # Sort by start, then merge. The merged end is the larger of the two ends:
    # taking the next span's end unconditionally shrank a span that fully
    # contained its successor, which let windows land inside a defect.
    ordered = defects[np.argsort(defects[:, 0], kind='stable')].tolist()
    merged = []
    for span_start, span_end in ordered:
        if merged and span_start <= merged[-1][1]:
            merged[-1][1] = max(merged[-1][1], span_end)
        else:
            merged.append([span_start, span_end])

    # One row per clean gap: [length, gap_start, gap_end].
    gaps = []
    previous_end = 0
    for span_start, span_end in merged:
        gaps.append([span_start - previous_end, previous_end, span_start])
        previous_end = span_end
    gaps.append([file_duration - previous_end, previous_end, file_duration])

    windows = np.array(gaps, dtype=float)
    windows = windows[windows[:, 0].argsort()]  # shortest gap first

    new = windows[windows[:, 0] >= window_size]
    if debug:
        print('window:', windows)
        print('new window: ', new)

    # Number of window positions each usable gap offers at min_step spacing.
    MaxPossible = (new[:, 0] - window_size) / min_step + 1

    if np.sum(MaxPossible) < num_window:
        # Not enough clean audio; the caller skips the file.
        print(f'Possible windows: {np.sum(MaxPossible)} less than expected windows: {num_window}')
        return None

    total_possible = np.sum(MaxPossible)
    proportional = [num_window * capacity / total_possible for capacity in MaxPossible]
    assert len(new[:, 0]) == len(MaxPossible) == len(proportional)

    valid_clips = []
    remaining = num_window
    for x in range(len(new)):
        # Round the proportional share up, but never exceed what is left.
        w = min(int(proportional[x] + 1), remaining)
        remaining -= w
        new_step = (new[x, 0] - window_size) / (w - 1) if w > 1 else 0  # Avoid division by zero
        start = new[x, 1]
        for _ in range(w):
            valid_clips.append(int(start * 10 ** 3) / 10 ** 3)
            start = start + new_step
    assert len(valid_clips) == num_window, f"Expected {num_window} valid clips, but got {len(valid_clips)}"

    return valid_clips



def get_objList(file, win_timestamp,  window_size=4, Fs=16000):
    """Cut fixed-length clips out of an audio file.

    The file is loaded once, converted to frame rate ``Fs``, and sliced at
    every start time in ``win_timestamp``.

    Args:
        file: Path of the audio file.
        win_timestamp: Iterable of clip start times in seconds.
        window_size: Clip length in seconds.
        Fs: Frame rate (Hz) applied before slicing.

    Returns:
        A list with one ``AudioSegment`` slice per start time, in input order.
    """
    audio = AudioSegment.from_file(file).set_frame_rate(Fs)
    span_ms = window_size * 1000

    clips = []
    for ts in win_timestamp:
        # pydub slices are indexed in milliseconds.
        begin_ms = round(ts * 1000)
        clips.append(audio[begin_ms:begin_ms + span_ms])

    return clips

In [None]:
##################
# NOTE Run this cell if you have labels in a csv file - (UCSF R2D2_validation data)
##################
import torch
import numpy as np
from pydub import AudioSegment
import glob
import os
import sys
sys.path.append('.')
from utils.pre_generator_noderiv import get_click_objects
from utils.pre_generator_noderiv import create_dirs
from utils.refactored_common import gen_mel_feature, yaml_load
import pandas as pd
import random
import re

# Master switch for the `if DEBUG: print(...)` diagnostics in this cell.
DEBUG=False
# Settings loaded from YAML; process_file() in a later cell passes them unchanged to gen_mel_feature().
param = yaml_load('./old_files/gen_config.yaml')

# Root of the recordings; the loop below walks every directory under it looking for *.wav files.
tmp_path = './data/new_R2D2_Data/R2D2_Train_Data'   # audium base

# Patient metadata. The columns read below are StudyID, Country and Sputumxpertreferencestandard.
csv_file = os.path.join('./data/new_R2D2_Data/R2D2 lung sounds metadata_TRAIN_2025.05.08_v3.csv')
patient_data=pd.read_csv(csv_file)
# NOTE(review): this is not the last expression of the cell, so the preview is never displayed.
patient_data.head()

# Patient ids per reference-standard label, plus a StudyID -> Country lookup.
normal_patient_ids = patient_data[patient_data['Sputumxpertreferencestandard'] == "TB Negative"]['StudyID'].tolist()
tb_patient_ids = patient_data[patient_data['Sputumxpertreferencestandard'] == "TB Positive"]['StudyID'].tolist()
patient_country = {p:c for p, c in zip(patient_data['StudyID'], patient_data['Country'])}

print('normal patients:', len(normal_patient_ids), 'tb patients:',len(tb_patient_ids))

# Scratch notes on split ratios that were considered. None of them is applied here;
# the split actually used is the `data_split` list defined in the next cell.
#  (60% + 10% + 10% + 20% 
# 50% + + 10% + 10% + 30% test
# (55% + 10% + 10% + 25% split
# data_split = [0.3, 0.4, 0.5]
# (60% + 10% + 10% + 20% test)
# 50%+ 10% + 10% + 30%)  + ignore (100% + 0% + 0%) + 150_new_patient (70% + 10% + 10% + 10%)
# (30%+ 10% + 10% + 50%)
# (55% + 15% + 10% + 20%)  
# 0:55, 55:70, 70:80, 80:100
# train: 0.65, val_train: 0.70, val_test: 0.7-0.8 test: 0.80-1.0
# train: 0.65, (val_train: 0.70, val_test: 0.7-0.8) test: 0.80-1.0
# data_split = [0.80, 0.81, 0.90]

# True selects the four-way split branch in the next cell; the other branch aborts with `assert False`.
test_mode = True
# Patient ids whose directory yielded exactly `expected_files` usable recordings (split across all sets).
train_test_normal = []
train_test_tb = []
# Patient ids with at least one usable recording but not exactly `expected_files` (added to train only).
train_normal = []
train_tb = []
# A path component must match this in full to be taken as the patient id.
patient_id_pattern = r'^R2D2\d{5}$'

# Frame rate (Hz) every recording must have. NOTE: a later cell rebinds Fs to 8000.
Fs = 16000
# Target number of recordings per directory: longer lists are cut to this, shorter ones may be padded.
expected_files = 20
# Minimum number of usable recordings for a directory to be padded up to `expected_files`.
min_expected_files = 16
# Minimum recording length in seconds (the upper bound, 21 s, is written literally in the loop below).
expected_duration = 16
# Maximum total detected speech, in seconds, before a recording is rejected.
duration_thresh = 4
# NOTE(review): despite the name and the trailing comment, the loop below stores
# file basename -> list of window start times here, not the defect spans.
defect_dict = {} # speech [{start, end}], peaks [{start, end}]
# patient id -> list of accepted file paths.
filtered_file_dict = {}
# NOTE(review): second assignment of DEBUG in this cell; this one is the value in effect below.
DEBUG = False
# Pad directories that hold min_expected_files..expected_files-1 recordings by re-drawing files at random.
repeat_files = True ## option 1 
# Window length (s), windows per recording and minimum shift (s) passed to get_window_timestamp().
window_size = 4 
num_window = 9
min_step = 0.5

# features to gen 
# NOTE(review): neither flag is read in the visible cells; process_file() passes gen_stft=True literally.
gen_mel = False
gen_stft = True

# [patient_id, number_of_wav_files] records appended by the loop below.
total_patients_found = [] # just for debug
# Patient ids that had usable files but no label in the CSV.
pateint_label_notFound = []

# is good == 0 -> use for train only 
# is good == 1 -> ok use in train and val
# is good == 2 -> empty
# ---------------------------------------------------------------------------
# Walk tmp_path, keep the recordings that pass every quality check, and sort
# each patient into train_test_* (exactly `expected_files` usable recordings)
# or train_* (fewer), by label.
#
# Per recording, in order:
#   1. it must decode and have frame rate `Fs`;
#   2. its length must lie in [expected_duration, max_file_duration];
#   3. detected speech must total at most `duration_thresh` seconds;
#   4. `num_window` windows must fit outside the speech spans.
# The window start times of an accepted recording go into `defect_dict`.
# NOTE(review): `defect_dict` is keyed by file *basename*; two patients with
# an identically named file would overwrite each other — confirm names are
# globally unique.
# ---------------------------------------------------------------------------
max_file_duration = 21  # seconds; longer recordings are rejected

# Built from the accumulators so the guards below stay correct even if those
# lists already hold entries when this loop starts.
recorded_patient_ids = {entry[0] for entry in total_patients_found}
assigned_patient_ids = set(train_test_normal) | set(train_test_tb) | set(train_normal) | set(train_tb)

for subdir, dirs, files in os.walk(tmp_path):
    # is good == 1 -> ok, is good == 2 -> empty (informational only; never read below)
    is_good = 1
    dirname = os.path.basename(subdir)
    tmp_files = glob.glob(os.path.join(subdir, "*.wav"))
    print('subdir', subdir, len(tmp_files))
    if len(tmp_files) == 0:
        is_good = 2
        continue

    # continue
    # if (dirname != 'R2D204281'): continue
    # print(subdir)
    # if ('R2D204227' not in subdir): continue

    filtered_files = []
    for file in tmp_files:
        filename = os.path.basename(file)
        # Catch decoding failures only. A bare `except:` also swallowed
        # KeyboardInterrupt (making this long loop impossible to stop) and
        # reported a frame-rate mismatch as a read error.
        try:
            dataObj = AudioSegment.from_file(file)
        except Exception as err:
            print("File read error:", filename, err)
            continue
        if dataObj.frame_rate != Fs:
            print(f'Frame rate error, file: {filename}, frame rate: {dataObj.frame_rate}, expected: {Fs}')
            continue

        if dataObj.duration_seconds < expected_duration or dataObj.duration_seconds > max_file_duration:
            print(f'Duration error, file: {filename}, duration: {dataObj.duration_seconds}')
            continue

        ## good file > proceed checks
        speech_duration = use_silero_vad(file, Fs)
        if DEBUG: print('speech', speech_duration)
        total_speech_duration, speech_duration_ = add_duration(speech_duration, window_size)
        if DEBUG: print(f'File : {filename}, speech duration: {total_speech_duration}s')

        if total_speech_duration > duration_thresh:
            print(f'speech duration is more skipping file: {file}, duration: {total_speech_duration}')
            continue

        # Click/peak detection is currently disabled; only speech counts as a defect.
        # # click_objList = get_click_objects([file], duration=100, Fs=Fs, disable_tqdm=True)
        # click_objList = get_click_objects([file], duration=0.1, Fs=Fs, dB_threshold=-5, disable_tqdm=True, DEBUG=False)
        # # click_objList[0][2] = []
        # click_duration_list = click_objList[0][2] if len(click_objList[0])>=3 else click_objList[0][1]
        # print('click', click_duration_list)
        # total_peak_duration, unpaked_peak_duraiton = add_duration(click_duration_list, window_size)
        # if DEBUG:print(f'File : {filename}, peak duration: {total_peak_duration}s, unpaked peak_duration: ', unpaked_peak_duraiton)
        # if total_peak_duration > duration_thresh:
        #     print(f'peak duration is more skipping file: {file}')
        #     continue

        # if DEBUG: print(f'File : {filename}, total duration: {(total_speech_duration + total_peak_duration)}s')
        # if (total_speech_duration + total_peak_duration) > duration_thresh:
        #     print(f'defect duration is more skipping file: {file}')
        #     continue

        defect_ = speech_duration_ #+ unpaked_peak_duraiton
        win_timestamp = get_window_timestamp(defect_, dataObj.duration_seconds, window_size, num_window, min_step, debug=DEBUG)
        if DEBUG: print('win_timestamp', file, win_timestamp)

        if win_timestamp is not None and len(win_timestamp) == num_window:
            defect_dict[filename] = win_timestamp
        else:
            print(f'File: {filename}, expected win: {num_window}, windows got: {win_timestamp}')
            continue

        filtered_files.append(file)

    if len(filtered_files) > expected_files:
        filtered_files = filtered_files[0:expected_files]
    elif len(filtered_files) == expected_files:
        # Exactly on target. This case used to fall into the final `else`
        # and print a misleading ERROR line.
        pass
    elif min_expected_files <= len(filtered_files) < expected_files:  ## option 1
        if repeat_files:
            if DEBUG: print('???', len(filtered_files))
            # Pad up to expected_files by re-drawing accepted files at random.
            files_need = expected_files - len(filtered_files)
            filtered_files = filtered_files + random.choices(filtered_files, k=files_need)
    else:
        # Fewer than min_expected_files: the patient is still kept, but only
        # in the train-only lists further down.
        print(f'ERROR: file count criteria not matching. filtered files: {len(filtered_files)}')

    print('subdir3', subdir)
    # The patient id is the deepest path component matching the id pattern.
    # normpath + os.sep keeps this working with either path separator.
    subdir_parts = os.path.normpath(subdir).split(os.sep)
    patient_id = next(
        (part for part in reversed(subdir_parts) if re.match(patient_id_pattern, part)),
        None,
    )
    if patient_id is None:
        print('Error in patient id finding ', subdir_parts, subdir)
    assert patient_id is not None, f'no patient id found in path: {subdir}'

    # `patient_id not in total_patients_found` was always true because the
    # list holds [id, count] pairs; membership is tracked in a set instead.
    if patient_id not in recorded_patient_ids and (patient_id in normal_patient_ids or patient_id in tb_patient_ids):
        total_patients_found.append([patient_id, len(tmp_files)])
        recorded_patient_ids.add(patient_id)

    if patient_id not in filtered_file_dict:
        filtered_file_dict[patient_id] = filtered_files
    else:
        if DEBUG: print('adding into existing patient data')
        filtered_file_dict[patient_id].extend(filtered_files)

    if DEBUG: print('!!!!', patient_id, len(filtered_files))

    if len(filtered_files) == 0 or patient_id is None:
        pass
    elif patient_id in assigned_patient_ids:
        # The same patient can own several directories. Appending the id
        # again would list it twice, and after shuffling it could land in
        # both the train and the test split. The first assignment wins; the
        # extra files were already added to filtered_file_dict above.
        print(f'patient {patient_id} is already assigned to a split list; not adding it again')
    else:
        has_full_set = len(filtered_files) == expected_files
        if patient_id in normal_patient_ids:
            (train_test_normal if has_full_set else train_normal).append(patient_id)
            assigned_patient_ids.add(patient_id)
        elif patient_id in tb_patient_ids:
            (train_test_tb if has_full_set else train_tb).append(patient_id)
            assigned_patient_ids.add(patient_id)
        else:
            print('label not found for pid', patient_id)
            pateint_label_notFound.append(patient_id)


print (f'train test normal: {len(train_test_normal)}, train normal: {len(train_normal)}, train test tb: {len(train_test_tb)}, train tb: {len(train_tb)}')


In [None]:
outdir = "./data/dia_split2_2/"

# Cumulative cut points over a shuffled patient list:
# [0, 0.50) train, [0.50, 0.60) val_train, [0.60, 0.70) val_test, [0.70, 1.0] test.
data_split = [0.50, 0.60, 0.70]


def _split_by_fractions(ids, cut_fractions):
    """Cut `ids` into consecutive slices at int(fraction * len(ids)).

    Returns len(cut_fractions) + 1 lists; the last one runs to the end of
    `ids`. `ids` itself is not modified.
    """
    bounds = [0] + [int(fraction * len(ids)) for fraction in cut_fractions] + [len(ids)]
    return [ids[lo:hi] for lo, hi in zip(bounds[:-1], bounds[1:])]


if test_mode:
    print ("Number used only for training", len(train_normal), len(train_tb))

    # Shuffles in place: re-running this cell produces a different split.
    random.shuffle(train_test_normal)
    random.shuffle(train_test_tb)

    (train_split_normal,
     val_train_split_normal,
     val_test_split_normal,
     test_split_normal) = _split_by_fractions(train_test_normal, data_split)

    (train_split_tb,
     val_train_split_tb,
     val_test_split_tb,
     test_split_tb) = _split_by_fractions(train_test_tb, data_split)

else:
    # The previous code put `assert False` in front of a three-way
    # (train / val / test) split that could therefore never run, and the
    # later cells need the four-way names anyway. Fail with an explanation.
    raise AssertionError('test_mode=False is not supported: later cells need the four-way split')

    #  [0.0, 1.537, 3.075, 7.175, 8.645, 10.116, 11.587, 13.058, 14.529, 15.999]
print(train_normal)

In [None]:
# Collapse total_patients_found ([patient_id, n_wav_files] records) to one row
# per patient, keeping each patient's first record in its original order.
#
# dtype=object keeps the counts as integers. A plain np.array() over mixed
# [str, int] rows converts everything to strings, so the `!= 0` filter below
# never compared numbers. reshape(-1, 2) also keeps the code working when no
# patient was recorded at all.
p = np.array(total_patients_found, dtype=object).reshape(-1, 2)
_, idx = np.unique(p[:, 0].astype(str), return_index=True)
flattern = p[np.sort(idx)]
print(flattern.shape)

p = flattern

print(p.shape)
# Drop patients recorded with zero wav files.
p = p[p[:, 1].astype(int) != 0]
print(p.shape)

pid = p[:, 0].tolist()     # patient ids (str)
files = p[:, 1].tolist()   # wav-file counts (int)


In [None]:

# Summary of the patient counts produced by the filtering and split cells.
# print('total patients', len(set(total_patients_found)))
n_full_normal, n_full_tb = len(train_test_normal), len(train_test_tb)
n_train_only_normal, n_train_only_tb = len(train_normal), len(train_tb)

print(f'total test normal : {n_full_normal}')
print(f'total test tb     : {n_full_tb}')
print(f'data split   : {data_split}')

print(f'train normal : {n_train_only_normal}')
print(f'train tb     : {n_train_only_tb}')

print(f'total patients: {n_full_normal + n_full_tb + n_train_only_normal + n_train_only_tb}')

# One line per class: sizes of train / val_train / val_test / test.
print(*map(len, (train_split_normal, val_train_split_normal, val_test_split_normal, test_split_normal)))
print(*map(len, (train_split_tb, val_train_split_tb, val_test_split_tb, test_split_tb)))


In [None]:
##################
# NOTE Run this cell if you have labels in a csv file - (UCSF R2D2_validation data)
##################

from tqdm.notebook import tqdm

# target_dBFS = -28
# Sampling rate (Hz) for feature generation. NOTE: this rebinds the notebook-wide `Fs`,
# which the filtering cell set to 16000; everything executed after this cell sees 8000.
Fs = 8000
# Spectral-analysis settings that process_file() forwards to gen_mel_feature().
# If the sizes are in samples, at Fs = 8000 a 512-sample frame spans 64 ms and a 160-sample hop 20 ms.
n_fft = 512      # FFT size
hop_length = 160 # alternatives tried: 512, 1200; 240 for impulse
win_length = 512
window = 'hann'
n_mels = 128
    
def process_file(f, duration=20) :
    """Turn one audio clip into the feature returned by ``gen_mel_feature``.

    Args:
        f: Either a path (str) to an audio file or an already loaded
            ``AudioSegment``.
        duration: Exact length, in seconds, the clip must have after it has
            been converted to the global frame rate ``Fs``.

    Returns:
        The value returned by ``gen_mel_feature`` for the clip, or ``None``
        (after printing a message) when the clip length differs from
        ``duration``.
    """
    if isinstance(f, AudioSegment):
        clip = f
    elif isinstance(f, str):
        clip = AudioSegment.from_file(f)
    else:
        assert False, f'Unknown data type {type(f)}'

    # Convert to the module-level sampling rate before measuring the length.
    clip = clip.set_frame_rate(Fs)

    # Exact comparison: a clip that is even one frame short or long is rejected.
    if clip.duration_seconds != duration :
        print ("Duration issue", f, clip.duration_seconds)
        return None

    # Scale the raw samples by 32768.
    # NOTE(review): that constant presumes 16-bit samples — confirm for the input data.
    samples = np.array(clip.get_array_of_samples())/32768

    return gen_mel_feature(param, samples, Fs, n_fft, hop_length, win_length, n_mels, window=window, enable_librosa=True, gen_stft=True) # librosa = True


# tuples_ = [
#         #    (train_split_normal, "train", "good"), 
#         #    (train_normal, "train", "good"), 
#         #    (test_split_normal, "test", "good"), 
#         #    (val_split_normal, "val", "good"), 
#         #    (train_split_tb, "train", "bad"), 
#         #    (train_tb, "train", "bad"), 
#         #    (test_split_tb, "test", "bad"), 
#         #    (val_split_tb, "val", "bad")  
#             (train_test_normal, "test", "good"), 
#             (train_test_tb, "val", "bad"), 
#            ]

# if set(train_split_normal).issuperset(set(train_normal)) == False:
#     train_split_normal.extend(train_normal)
# if set(train_split_tb).issuperset(set(train_tb)) == False:
#     train_split_tb.extend(train_tb)

# Patients that did not reach the full file count are used for training only:
# add them to the train split. A plain `extend` appended them again on every
# re-run of this cell; only ids that are not in the split yet are added.
for train_only_ids, train_split in ((train_normal, train_split_normal), (train_tb, train_split_tb)):
    already_present = set(train_split)
    for patient in train_only_ids:
        if patient not in already_present:
            train_split.append(patient)
            already_present.add(patient)

print('normal', len(train_split_normal))
print('tb', len(train_split_tb))

# The two validation subsets are written to a single "val" directory.
val_split_normal = val_train_split_normal + val_test_split_normal
val_split_tb = val_train_split_tb + val_test_split_tb

# A patient that sits in both train and an evaluation split leaks data; report it.
for class_label, train_ids, eval_ids in (
        ('normal', train_split_normal, test_split_normal + val_split_normal),
        ('tb', train_split_tb, test_split_tb + val_split_tb)):
    leaked = set(train_ids) & set(eval_ids)
    if leaked:
        print(f'WARNING: {len(leaked)} {class_label} patients are in train and in test/val: {sorted(leaked)}')

# Fraction of each "good" list to keep; 1 keeps everything.
normal_fraction = 1

# tuples_[class][gen_type] -> list of patient ids; the index order matches
# classList and gen_types below.
tuples_ = [[train_split_normal[:int(len(train_split_normal)*normal_fraction)],
            test_split_normal[:int(len(test_split_normal)*normal_fraction)],
            val_split_normal[:int(len(val_split_normal)*normal_fraction)]],
           [train_split_tb, test_split_tb, val_split_tb]]

gen_types = ['train', 'test', 'val']
classList = ['good', 'bad']

# Write one feature tensor per (patient, recording, window) to
#   <outdir>/<train|test|val>/<good|bad>/_<pid>_<recording stem>_<window index>.pt
#
# NOTE(review): the output name depends only on pid, recording stem and window
# index, so a recording that is listed twice for a patient (padding by
# repetition in the filtering cell) regenerates and overwrites the same files.
pid_seen = []
for idx, data_path_list in enumerate(tuples_):
    classname = classList[idx]
    print(data_path_list)
    for genId, gen_path_ist in enumerate(data_path_list):
        genType = gen_types[genId]
        print(outdir, genType, classname)

        path_ = os.path.join(outdir, genType, classname)
        os.makedirs(path_, exist_ok=True)

        print(genId, gen_types[genId], len(gen_path_ist))
        for pid in tqdm(gen_path_ist):
            files = filtered_file_dict.get(pid, [])
            pid_seen.append(pid)
            for file in files:
                filename = os.path.basename(file)
                # sorted() instead of .sort(): the list stored in defect_dict
                # is left untouched.
                file_win_ts = sorted(defect_dict.get(filename, []))
                if not file_win_ts:
                    # No windows recorded for this file: nothing to cut, so
                    # do not decode the audio at all.
                    continue

                # `Fs` is the value set at the top of this cell.
                win_objlist = get_objList(file, file_win_ts, window_size, Fs)
                stem = os.path.splitext(filename)[0]

                for i, obj in enumerate(win_objlist):
                    baseName =  pid + "_" + stem + f'_{i}'
                    mel_dst = os.path.join(path_, "_" + baseName + '.pt')
                    print(mel_dst)

                    mel_feature = process_file(obj, duration=window_size)
                    if mel_feature is None:
                        # process_file() rejects clips whose length is not
                        # exactly window_size; torch.tensor(None) would raise
                        # and abort the whole run.
                        print(f'Skipping {mel_dst}: clip was rejected by process_file')
                        continue

                    torch.save(torch.tensor(mel_feature), mel_dst)

    # break

       

In [None]:
# (unique, total) patient ids visited by the generation loop; the two numbers
# differ when an id occurs more than once across the lists in tuples_.
len(set(pid_seen)), len(pid_seen)

In [None]:
# Dry run of the generation loop: walk the same class / split lists and record
# every patient id visited, without producing any features.
pid_seen2 = []
for idx, data_path_list in enumerate(tuples_):
    for genId, gen_path_ist in enumerate(data_path_list):
        genType, classname = gen_types[genId], classList[idx]

        print(genId, genType, len(gen_path_ist))
        for pid in tqdm(gen_path_ist):
            # Looked up as in the real loop; the result is not used here.
            files = filtered_file_dict.get(pid, [])
            pid_seen2.append(pid)

len(pid_seen2)

In [None]:
# Print the id of every patient that ended up without any accepted recording.
for pid, kept_files in filtered_file_dict.items():
    if not kept_files:
        print(pid)