In [None]:
# %%
import torch
from torch.utils.data import Dataset
import soundfile as sf
import os, fnmatch
import numpy as np
import torch.nn as nn
import torch.fft as fft
from sklearn.model_selection import train_test_split
from scipy.signal.windows import hann
from tqdm import tqdm

#Data Normalization
def minMaxNorm(wav, eps=1e-8):
    """Rescale a waveform using the extremes of its absolute amplitude.

    Evaluates ``(wav - lo) / (hi - lo + eps)`` where ``lo`` and ``hi`` are the
    smallest and largest values of ``abs(wav)``. The bounds come from the
    magnitudes rather than from the signed samples, so negative samples stay
    negative and the result is not confined to ``[0, 1]``.

    Args:
        wav: Array of audio samples.
        eps: Small constant added to the denominator so that a signal whose
            magnitude is constant does not cause a division by zero.

    Returns:
        The rescaled array.
    """
    magnitude = abs(wav)
    hi = np.max(magnitude)
    lo = np.min(magnitude)
    span = hi - lo + eps
    return (wav - lo) / span

class DataConfig():
    """Plain settings container for framing, dataset locations and training.

    Every constructor argument is stored unchanged on the instance under the
    same name; no validation or conversion is performed here.
    """
    def __init__(self, 
                 frameSize = 512, 
                 stride_length = 32,
                 sample_rate = 16000,
                 duration = 3,
                 n_fft = 512,
                 modelBufferFrames = 10,
                 batchSize = 32,
                 shuffle = True,
                 noisyPath = 'dataset/train/',
                 cleanPath = 'dataset/y_train/',
                 dtype = torch.float64,
                 device = 'cpu',
                 learningRate = 0.001,
                ):
        # Framing / spectral-analysis settings.
        (self.frameSize,
         self.stride_length,
         self.sample_rate,
         self.duration,
         self.n_fft) = (frameSize, stride_length, sample_rate, duration, n_fft)

        # Model buffer depth and DataLoader batching options.
        (self.modelBufferFrames,
         self.batchSize,
         self.shuffle) = (modelBufferFrames, batchSize, shuffle)

        # Directories holding the noisy recordings and their clean references.
        self.noisyPath, self.cleanPath = noisyPath, cleanPath

        # Numeric precision, torch device and optimiser step size.
        (self.dtype,
         self.device,
         self.learningRate) = (dtype, device, learningRate)

class CustomDataset(Dataset):
    """Map-style dataset that pairs pre-computed inputs with their targets.

    ``inputs`` and ``targets`` are stored as given and indexed with the same
    index; sample ``idx`` is the tuple ``(inputs[idx], targets[idx])``.
    """
    def __init__(self, inputs, targets):
        self.inputs, self.targets = inputs, targets

    def __len__(self):
        # The dataset size is defined by the inputs alone; targets are not checked.
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx], self.targets[idx]

class CustomDataloaderCreator():
    def __init__(self, 
                 noisy_files_list, 
                 clean_files_list,
                 test_noisy_files_list,
                 test_clean_files_list,
                 val_noisy_files_list,
                 val_clean_files_list,
                 dataconfig
                 ):
        
        self.debugFlag = True

        #Training Dataset
        self.noisy_files_list = noisy_files_list
        self.clean_files_list = clean_files_list
        self.trainNoisyDataset = []
        self.trainCleanDataset = []
        self.trainModelInputBuffers = []
        self.trainPhaseInfo = []
        self.targets = []

        #Test Dataset
        self.test_noisy_files_list = test_noisy_files_list
        self.test_clean_files_list = test_clean_files_list
        self.testNoisyDataset = []
        self.testCleanDataset = []
        
        #Validation Dataset
        self.val_noisy_files_list = val_noisy_files_list
        self.val_clean_files_list = val_clean_files_list
        self.validationNoisyDataset = []
        self.validationCleanDataset = []
        self.val_modelInputBuffers = []
        self.val_phaseInfo = []
        self.val_targets = []
        
        self.dataconfig = dataconfig

        if(dataconfig.dtype == torch.float64) :
            self.dtype = np.float64
            self.complexDtype = np.complex128
        
        elif(dataconfig.dtype == torch.float32) :
            self.dtype = np.float32
            self.complexDtype = np.complex64
        
    #Creates the specified duration audio clips from the noisy and clean files
    def createAudioClips(self):
        print("CustomDataLoader.createAudioClips()")
        speechSampleSize = self.dataconfig.duration * self.dataconfig.sample_rate

        listIter = [self.noisy_files_list, self.val_noisy_files_list]
        datasetIter = [(self.trainNoisyDataset, self.trainCleanDataset), (self.validationNoisyDataset, self.validationCleanDataset)]
        NOISY = 0
        CLEAN = 1
        TEST = 2    #Not used yet

        #Create the training and then validation dataset
        for index,currlist in enumerate(listIter):
            for idx,filename in enumerate(currlist):
                if idx == 500:
                    break

                noisySpeech,_ = sf.read(os.path.join(self.dataconfig.noisyPath, filename))
                cleanSpeech,_ = sf.read(os.path.join(self.dataconfig.cleanPath, filename))

                #Normalize
                noisySpeech = minMaxNorm(noisySpeech)
                cleanSpeech = minMaxNorm(cleanSpeech)

                numSubSamples = int(len(noisySpeech)/speechSampleSize)
                for i in range(numSubSamples):
                    datasetIter[index][NOISY].append(noisySpeech[i*speechSampleSize:(i+1)*speechSampleSize])
                    datasetIter[index][CLEAN].append(cleanSpeech[i*speechSampleSize:(i+1)*speechSampleSize])
        
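    # Creates the train and validation inputs and targets.
    # Input  : frequency-domain buffer holding the last modelBufferFrames (default 10) rFFT frames,
    #          shape (modelBufferFrames, n_fft/2 + 1)
    # Target : time-domain clean speech for the newest stride (stride_length samples; 2 ms at the defaults)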
    def createModelBufferInputs(self):
        print("CustomDataLoader.createModelBufferInputs()")
       
        fft_freq_bins = int(self.dataconfig.n_fft/2) + 1

        datasetIter = [(self.trainNoisyDataset, self.trainCleanDataset), (self.validationNoisyDataset, self.validationCleanDataset)]
        modelBufferFramesIter = [self.trainModelInputBuffers, self.val_modelInputBuffers]
        targetsIter = [self.targets, self.val_targets]
        
        NOISY = 0
        CLEAN = 1
       
        #Create the training and validation inputs and targets
        for index,data in enumerate(datasetIter):
            currNoisyDataset = data[NOISY]
            corrCleanDataset = data[CLEAN]
            print(f'xFrames (expectedFrames) per audio clip = {len(currNoisyDataset[0])//self.dataconfig.stride_length}')
            for idx, currNoisySample in enumerate(tqdm(currNoisyDataset)):
                modelInputBuffer = np.zeros((self.dataconfig.modelBufferFrames,fft_freq_bins)).astype(self.complexDtype)
                inbuffer = np.zeros((self.dataconfig.frameSize)).astype(self.dtype)

                for i in range(0, len(currNoisySample),self.dataconfig.stride_length):      

                    if(i+self.dataconfig.stride_length > len(currNoisySample)-1):
                        break

                    #inbuffer is moved : [__s1__+++++++++++++__s2__] -> [+++++++++++++__s2__]
                    inbuffer[:-self.dataconfig.stride_length] = inbuffer[self.dataconfig.stride_length:] 
                    #inbuffer is filled with new data: [+++++++++++++__s2__] -> [+++++++++++++----]
                    inbuffer[-self.dataconfig.stride_length:] = currNoisySample[i : i + self.dataconfig.stride_length]

                    #Start up time
                    if i < self.dataconfig.frameSize:
                        continue

                    buffer_array = np.array(inbuffer)
                    windowed_buffer = buffer_array * hann(len(buffer_array), sym=False)

                    # Taking the real-valued FFT
                    frame = np.fft.rfft(windowed_buffer)    

                    # if(self.debugFlag):
                    #     print(f'frame.shape = {frame.shape}')
            
                    # Shift the modelInputBuffer
                    modelInputBuffer[:-1, :] = modelInputBuffer[1:, :]

                    # Fill the last row of modelInputBuffer with the new spectrogram values
                    modelInputBuffer[-1, :] = frame
                    modelBufferFramesIter[index].append(np.array(modelInputBuffer))
                    targetsIter[index].append(np.array(corrCleanDataset[idx][i:i+self.dataconfig.stride_length]))
            
            # #Shuffle up the dataset if required
            # if self.dataconfig.shuffle:
            #     self.indices = np.random.permutation(len(modelBufferFramesIter[index]))
            # else:
            #     self.indices = np.arange(len(modelBufferFramesIter[index]))
            # print("CustomDataLoader.createModelBufferInputs(): modelinputbuffers size = ", len(self.trainModelInputBuffers))
            # print(f'gotten frames per modelinputbuffer = {len(self.trainModelInputBuffers)//len(self.trainNoisyDataset)}')
            # print("CustomDataLoader.createModelBufferInputs(): targets size = ", len(self.targets))

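    # Creates the train and validation inputs and targets.
    # Input  : frequency-domain buffer holding the last modelBufferFrames (default 10) rFFT frames,
    #          shape (modelBufferFrames, n_fft/2 + 1)
    # Target : one Hann-windowed frame taken through rfft and straight back through irfft,
    #          i.e. a time-domain frame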
    def createModelBufferInputs2(self):
        print("CustomDataLoader.createModelBufferInputs2()")
       
        fft_freq_bins = int(self.dataconfig.n_fft/2) + 1

        datasetIter = [(self.trainNoisyDataset, self.trainCleanDataset), (self.validationNoisyDataset, self.validationCleanDataset)]
        modelBufferFramesIter = [self.trainModelInputBuffers, self.val_modelInputBuffers]
        targetsIter = [self.targets, self.val_targets]
        
        NOISY = 0
        CLEAN = 1
       
        #Create the training and validation inputs and targets
        for index,data in enumerate(datasetIter):
            currNoisyDataset = data[NOISY]
            corrCleanDataset = data[CLEAN]
            print(f'xFrames (expectedFrames) per audio clip = {len(currNoisyDataset[0])//self.dataconfig.stride_length}')
            for idx, currNoisySample in enumerate(tqdm(currNoisyDataset)):
                modelInputBuffer = np.zeros((self.dataconfig.modelBufferFrames,fft_freq_bins)).astype(self.complexDtype)
                inbuffer = np.zeros((self.dataconfig.frameSize)).astype(self.dtype)
                inbufferClean = np.zeros((self.dataconfig.frameSize)).astype(self.dtype)
            
                for i in range(0, len(currNoisySample),self.dataconfig.stride_length):      
                
                    if(i+self.dataconfig.stride_length > len(currNoisySample)-1):
                        break

                    #inbuffer is moved : [__s1__+++++++++++++__s2__] -> [+++++++++++++__s2__]
                    inbuffer[:-self.dataconfig.stride_length] = inbuffer[self.dataconfig.stride_length:] 
                    #inbuffer is filled with new data: [+++++++++++++__s2__] -> [+++++++++++++----]
                    inbuffer[-self.dataconfig.stride_length:] = currNoisySample[i : i + self.dataconfig.stride_length]

                    inbufferClean[:-self.dataconfig.stride_length] = inbufferClean[self.dataconfig.stride_length:] 
                    inbufferClean[-self.dataconfig.stride_length:] = corrCleanDataset[idx][i : i + self.dataconfig.stride_length]
                    
                    #Start up time
                    if i < self.dataconfig.frameSize:
                        continue

                    # ModelInput Creation
                    buffer_array = np.array(inbuffer)
                    windowed_buffer = buffer_array * hann(len(buffer_array), sym=False)
                    frame = np.fft.rfft(windowed_buffer)    
                    modelInputBuffer[:-1, :] = modelInputBuffer[1:, :]
                    modelInputBuffer[-1, :] = frame
                    modelBufferFramesIter[index].append(np.array(modelInputBuffer))

                    # Target Creation
                    clean_buffer_array = np.array(inbufferClean)
                    clean_windowed_buffer = buffer_array * hann(len(clean_buffer_array), sym=False)
                    clean_frame = np.fft.rfft(clean_windowed_buffer)
                    clean_iffted_segment = np.fft.irfft(clean_frame)
                    targetsIter[index].append(clean_iffted_segment)

    def createModelBufferInputs3(self):
        print("CustomDataLoader.createModelBufferInputs3()")
       
        fft_freq_bins = int(self.dataconfig.n_fft/2) + 1

        datasetIter = [(self.trainNoisyDataset, self.trainCleanDataset), (self.validationNoisyDataset, self.validationCleanDataset)]
        modelBufferFramesIter = [self.trainModelInputBuffers, self.val_modelInputBuffers]
        phaseInfoIter = [self.trainPhaseInfo, self.val_phaseInfo]
        targetsIter = [self.targets, self.val_targets]
        
        NOISY = 0
        CLEAN = 1
       
        #Create the training and validation inputs and targets
        for index,data in enumerate(datasetIter):
            currNoisyDataset = data[NOISY]
            corrCleanDataset = data[CLEAN]
            print(f'xFrames (expectedFrames) per audio clip = {len(currNoisyDataset[0])//self.dataconfig.stride_length}')
            for idx, currNoisySample in enumerate(tqdm(currNoisyDataset)):
                modelInputBuffer = np.zeros((self.dataconfig.modelBufferFrames,fft_freq_bins)).astype(self.complexDtype)
                inbuffer = np.zeros((self.dataconfig.frameSize)).astype(self.dtype)
                inbufferClean = np.zeros((self.dataconfig.frameSize)).astype(self.dtype)
            
                for i in range(0, len(currNoisySample),self.dataconfig.stride_length):      
                
                    if(i+self.dataconfig.stride_length > len(currNoisySample)-1):
                        break

                    #inbuffer is moved : [__s1__+++++++++++++__s2__] -> [+++++++++++++__s2__]
                    inbuffer[:-self.dataconfig.stride_length] = inbuffer[self.dataconfig.stride_length:] 
                    #inbuffer is filled with new data: [+++++++++++++__s2__] -> [+++++++++++++----]
                    inbuffer[-self.dataconfig.stride_length:] = currNoisySample[i : i + self.dataconfig.stride_length]

                    inbufferClean[:-self.dataconfig.stride_length] = inbufferClean[self.dataconfig.stride_length:] 
                    inbufferClean[-self.dataconfig.stride_length:] = corrCleanDataset[idx][i : i + self.dataconfig.stride_length]
                    
                    #Start up time
                    if i < self.dataconfig.frameSize:
                        continue

                    # ModelInput Creation
                    buffer_array = np.array(inbuffer)
                    windowed_buffer = buffer_array * hann(len(buffer_array), sym=False)
                    frame = np.fft.rfft(windowed_buffer)    

                    # Phase Info
                    phaseInfo = np.angle(frame)

                    modelInputBuffer[:-1, :] = modelInputBuffer[1:, :]
                    modelInputBuffer[-1, :] = frame
                    modelBufferFramesIter[index].append(np.array(modelInputBuffer))
                    phaseInfoIter[index].append(np.array(phaseInfo))

                    # Target Creation
                    clean_buffer_array = np.array(inbufferClean)
                    clean_windowed_buffer = buffer_array * hann(len(clean_buffer_array), sym=False)
                    clean_frame = np.fft.rfft(clean_windowed_buffer)
                    targetsIter[index].append(np.abs(clean_frame))

    # This function creates the test dataset 
    def createTestDataset(self):
        print("CustomDataLoader.createTestDataset()")
        speechSampleSize = self.dataconfig.duration * self.dataconfig.sample_rate
        for index,filename in enumerate(self.test_noisy_files_list):
            if index == 100:
                break

            noisySpeech,_ = sf.read(os.path.join(self.dataconfig.noisyPath, filename))
            cleanSpeech,_ = sf.read(os.path.join(self.dataconfig.cleanPath, filename))

            #Normalize
            noisySpeech = minMaxNorm(noisySpeech)
            cleanSpeech = minMaxNorm(cleanSpeech)

            numSubSamples = int(len(noisySpeech)/speechSampleSize)
            for i in range(numSubSamples):
                self.testNoisyDataset.append(noisySpeech[i*speechSampleSize:(i+1)*speechSampleSize])
                self.testCleanDataset.append(cleanSpeech[i*speechSampleSize:(i+1)*speechSampleSize])

        print("Test Noisy Dataset Size: ", len(self.testNoisyDataset))
        print("Test Clean Dataset Size: ", len(self.testCleanDataset))

    #Call this function to prepare the dataloader
    def prepare(self):
        self.createAudioClips()
        # self.createModelBufferInputs()
        self.createModelBufferInputs2()
        self.printMembers()
            
    def getTrainDataloader(self):
        trainingDataset = CustomDataset(
            np.array(self.trainModelInputBuffers).astype(self.complexDtype),
            np.array(self.targets).astype(self.dtype)
            )

        return torch.utils.data.DataLoader(trainingDataset,
            batch_size = self.dataconfig.batchSize,
            shuffle = self.dataconfig.shuffle,
            generator = torch.Generator(device= self.dataconfig.device)
            )

    def getTrainDataloader2(self):
        trainingDataset = CustomDataset(
            [np.array(self.trainModelInputBuffers).astype(self.complexDtype), np.array(self.trainPhaseInfo).astype(self.dtype)],
            np.array(self.targets).astype(self.dtype)
            )

        return torch.utils.data.DataLoader(trainingDataset,
            batch_size = self.dataconfig.batchSize,
            shuffle = self.dataconfig.shuffle,
            generator = torch.Generator(device= self.dataconfig.device)
            )

    def getValidationDataloader(self):
        validationDataset = CustomDataset(
            np.array(self.val_modelInputBuffers).astype(self.complexDtype),
            np.array(self.val_targets).astype(self.dtype)
            )

        return torch.utils.data.DataLoader(validationDataset,
                        batch_size = self.dataconfig.batchSize, 
                        shuffle = self.dataconfig.shuffle,
                        generator = torch.Generator(device=self.dataconfig.device)
                        )
    def getValidationDataloader2(self):
        validationDataset = CustomDataset(
            [np.array(self.val_modelInputBuffers).astype(self.complexDtype), np.array(self.val_phaseInfo).astype(self.dtype)],
            np.array(self.val_targets).astype(self.dtype)
            )

        return torch.utils.data.DataLoader(validationDataset,
                        batch_size = self.dataconfig.batchSize, 
                        shuffle = self.dataconfig.shuffle,
                        generator = torch.Generator(device=self.dataconfig.device)
                        )

    def printMembers(self):
        print('--------------------------DISPLAY---------------------------------------------')
        print(f'noisy_files_list.shape = {np.array(self.noisy_files_list).shape}')
        print(f'clean_files_list.shape = {np.array(self.clean_files_list).shape}')
        print(f'trainNoisyDataset.shape = {len(self.trainNoisyDataset)}')
        print(f'trainCleanDataset.shape = {len(self.trainCleanDataset)}')
        print(f'trainModelInputBuffers.shape = {len(self.trainModelInputBuffers)},{len(self.trainModelInputBuffers[0])}')
        print(f'targets.shape = {len(self.targets)}')

        print('-----------------------------------------------------------------------')
        print(f'test_noisy_files_list.shape = {np.array(self.test_noisy_files_list).shape}')
        print(f'test_clean_files_list.shape = {np.array(self.test_clean_files_list).shape}')
        print(f'testNoisyDataset.shape = {len(self.testNoisyDataset)}')
        print(f'testCleanDataset.shape = {len(self.testCleanDataset)}')
        print('-----------------------------------------------------------------------')
        print(f'val_noisy_files_list.shape = {np.array(self.val_noisy_files_list).shape}')
        print(f'val_clean_files_list.shape = {np.array(self.val_clean_files_list).shape}')
        print(f'validationNoisyDataset.shape = {len(self.validationNoisyDataset)}')
        print(f'validationCleanDataset.shape = {len(self.validationCleanDataset)}')
        print(f'val_modelInputBuffers.shape = {len(self.val_modelInputBuffers)},{len(self.val_modelInputBuffers[0])}')
        print(f'val_targets.shape = {len(self.val_targets)}')


In [None]:
# Directories holding the noisy recordings and their clean references.
# NOTE(review): absolute, machine-specific paths — consider a configurable data root.
noisyPath = '/home/ubuntu/OticonStuff/dataset/train'
cleanPath = '/home/ubuntu/OticonStuff/dataset/y_train'

# os.listdir() returns entries in arbitrary order. Sorting keeps the seeded split
# below reproducible across machines and lines the two lists up index-by-index
# (assuming both directories contain the same filenames — the loader reads the
# clean file under the noisy file's name).
noisy_files_list = sorted(fnmatch.filter(os.listdir(noisyPath), '*.wav'))
clean_files_list = sorted(fnmatch.filter(os.listdir(cleanPath), '*.wav'))


#Split into train and temp ( 70-15-15 split for now)
X_train, X_temp, y_train, y_temp = train_test_split(noisy_files_list, clean_files_list, test_size=0.3, random_state=42)

#Splitting the temp into validation and test
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=42)


# Defaults everywhere except the precision and the dataset locations: the files
# must be read from the same directories they were listed from above.
dataConfig = DataConfig(
    dtype = torch.float32,
    noisyPath = noisyPath,
    cleanPath = cleanPath,
)

dataloaderCreator = CustomDataloaderCreator(X_train, y_train,X_test,y_test,X_val,y_val,dataconfig=dataConfig)
dataloaderCreator.prepare()

dataloader = dataloaderCreator.getTrainDataloader()
validationDataloader = dataloaderCreator.getValidationDataloader()

In [None]:
print(f'dataloader length = {len(dataloader)}')
print(f'validationDataloader length = {len(validationDataloader)}')

# Peek at a single batch instead of looping and breaking; an empty dataloader
# simply prints nothing further.
batch = next(iter(dataloader), None)
if batch is not None:
    modelInputs, targets = batch

    # Both shapes below include the leading batch dimension.
    print(f'First data point shape = {batch[0].shape}')
    print(f'First target shape = {batch[1].shape}')