In [1]:
import numbers
import os
from glob import glob1
from sys import getsizeof

import librosa, warnings
import numpy as np
from keras.utils import Sequence
from numpy import ndarray

from utils import complex_to_polar, polar_to_complex

class Train_Generator(Sequence) : 
    """Keras ``Sequence`` that streams fixed-length STFT windows for source separation.

    ``src_path`` must contain three equally sized groups of audio files whose
    names start with ``merge`` (model input), ``voice`` (first target) and
    ``music`` (second target). Each group is sorted by name and then by
    duration (longest first); the i-th file of each group is treated as one
    training example (presumably the same clip — TODO confirm naming scheme).

    Every ``__getitem__`` call loads the same ``bulk_num`` examples, starting at
    the current source position, and advances a per-example time cursor by
    ``sample_dur`` seconds. One epoch therefore walks through those clips
    window by window. ``on_epoch_end`` moves on to the next ``bulk_num``
    examples.

    Args:
        src_path: directory holding the ``merge*``/``voice*``/``music*`` files.
        bulk_num: number of examples stacked into each batch.
        sample_dur: window length in seconds.
        max_cache_size: byte threshold compared with ``cache_size``; reaching
            it emits a ``ResourceWarning``.
        n_fft: forwarded to ``librosa.stft``.
        win_length: forwarded to ``librosa.stft``.
        shuffle: if True, examples are visited in a random order.

    Raises:
        AssertionError: if the three file groups differ in size.
        ValueError: if no source files are found.
    """

    # STFT frames are trimmed to a multiple of this value; presumably so the
    # model's down/up-sampling stages divide the time axis evenly — TODO confirm.
    _FRAME_MULTIPLE = 64

    def __init__(self, src_path, bulk_num=5, sample_dur=5, max_cache_size=10000, n_fft=1918, win_length=1024, shuffle=False) :

        def list_files(pattern) :
            # os.path.join works whether or not src_path ends with a separator.
            return [os.path.join(src_path, name) for name in glob1(dirname=src_path, pattern=pattern)]

        def sort_via_dur(input_list) :
            # Name sort first so equal durations keep a deterministic order
            # through the (stable) duration sort.
            by_name = sorted(input_list)
            durations = {path : librosa.get_duration(path=path) for path in by_name}
            return sorted(by_name, key=durations.__getitem__, reverse=True)

        self._input_list = list_files("merge*")
        self._output_list1 = list_files("voice*")
        self._output_list2 = list_files("music*")

        num_input = len(self._input_list)
        num_output1 = len(self._output_list1)
        num_output2 = len(self._output_list2)
        if not (num_input == num_output1 == num_output2) :
            raise AssertionError("The number of source sample must be same. : [{}, {}, {}]".format(num_input, num_output1, num_output2))
        if num_input == 0 :
            # Without this the modulo in __src_index_list divides by zero.
            raise ValueError("No source samples found in {}".format(src_path))

        self._input_list = sort_via_dur(self._input_list)
        self._output_list1 = sort_via_dur(self._output_list1)
        self._output_list2 = sort_via_dur(self._output_list2)

        self._bulk_num = bulk_num
        self._sample_dur = sample_dur
        self._max_cache_size = max_cache_size
        self._n_fft = n_fft
        self._win_length = win_length

        self._src_index = 0
        # Per-example time cursor (seconds) of the next window to load.
        self._duration_list = np.zeros(num_input, dtype=np.float32)
        # Visiting order of the examples; identity unless shuffled.
        self._src_index_list = list(range(num_input))
        if shuffle : np.random.shuffle(self._src_index_list)

        self._max_dur = self.__estimate_max_len(self._src_index)
        self.sample_x, self.sample_y1, self.sample_y2 = self.__get_sample()

        self.__cache_warning()

    def __get_sample(self) :
        """Load the current bulk without advancing the time cursors."""
        X, (Y1, Y2) = self.__load_data(self._src_index, update=False)
        return X, Y1, Y2

    def __src_index_list(self, src_index) :
        """Return the file indices of the ``bulk_num`` examples starting at position ``src_index`` (wrapping around)."""
        num_src = len(self._src_index_list)
        return [self._src_index_list[i % num_src] for i in range(src_index, src_index + self._bulk_num)]

    def __resource_validation(self, X : ndarray, Y1 : ndarray, Y2 : ndarray) :
        """Raise ValueError unless the input and both targets have the same shape."""
        # Explicit raise: an ``assert`` would be stripped under ``python -O``.
        if not (X.shape == Y1.shape == Y2.shape) :
            raise ValueError("The shape of IO is different : [{} / {}, {}]".format(X.shape, Y1.shape, Y2.shape))

    def __next_offset(self, src_index, source_duration, update=True) -> float :
        """Return the start time (s) of the next window for example ``src_index``.

        With ``update`` the cursor advances by ``sample_dur``. Once a full
        window no longer fits before ``source_duration``, the last window
        ending at ``source_duration`` is returned and the cursor wraps to 0.
        """
        current = float(self._duration_list[src_index])
        if not update :
            return current
        if current + self._sample_dur >= source_duration :
            self._duration_list[src_index] = 0
            # Clamp so a clip shorter than one window never yields a negative offset.
            return max(source_duration - self._sample_dur, 0.0)
        self._duration_list[src_index] = current + self._sample_dur
        return current

    def __load_single_data(self, path, source_offset) -> ndarray :
        """Load ``sample_dur`` seconds of ``path`` from ``source_offset`` at its native rate and return the STFT."""
        sample_rate = librosa.get_samplerate(path)
        data_source = librosa.load(path=path, sr=sample_rate, offset=source_offset, duration=self._sample_dur)[0]
        return librosa.stft(data_source, n_fft=self._n_fft, win_length=self._win_length)

    def __trim_frames(self, D) :
        """Drop trailing STFT frames so the frame count is a multiple of ``_FRAME_MULTIPLE``."""
        usable = (D.shape[1] // self._FRAME_MULTIPLE) * self._FRAME_MULTIPLE
        return D[:, :usable]

    def __load_data(self, src_index, update=True) :
        """Load one bulk as ``(X, (Y1, Y2))``, each stacked over examples and passed through ``complex_to_polar``."""
        input_list, output1_list, output2_list = [], [], []

        for index in self.__src_index_list(src_index) :
            paths = (self._input_list[index], self._output_list1[index], self._output_list2[index])
            # The cursor is per example, so compute ONE offset and reuse it for
            # the mixture and both targets; they must cover the same time window.
            source_duration = min(librosa.get_duration(path=path) for path in paths) - 0.1
            source_offset = self.__next_offset(index, source_duration, update)

            X, Y1, Y2 = (self.__trim_frames(self.__load_single_data(path, source_offset)) for path in paths)
            self.__resource_validation(X, Y1, Y2)

            input_list.append(X)
            output1_list.append(Y1)
            output2_list.append(Y2)

        data_input = complex_to_polar(np.array(input_list))
        data_output1 = complex_to_polar(np.array(output1_list))
        data_output2 = complex_to_polar(np.array(output2_list))
        return data_input, (data_output1, data_output2)

    def __estimate_max_len(self, src_index) :
        """Return the longest ``merge*`` duration (s) in the bulk at ``src_index`` (0 for an empty bulk)."""
        return max((librosa.get_duration(path=self._input_list[index]) for index in self.__src_index_list(src_index)), default=0)

    def __estimate_cache(self) :
        """Return the bytes held by the cached sample batch."""
        # ndarray.nbytes counts the data buffer even for views; getsizeof only sees the header then.
        return sum(
            sample.nbytes if isinstance(sample, ndarray) else getsizeof(sample)
            for sample in (self.sample_x, self.sample_y1, self.sample_y2)
        )

    def __cache_warning(self) :
        """Emit a ResourceWarning (always shown) when the cached sample reaches ``max_cache_size``."""
        if self._max_cache_size <= self.__estimate_cache() :
            with warnings.catch_warnings():
                warnings.simplefilter("always")
                warnings.warn("Loaded data exceeded max_cache_size.", ResourceWarning)

    def __len__(self) :
        # Number of whole sample_dur windows in the longest clip of the current bulk.
        return int(self._max_dur // self._sample_dur)

    def __getitem__(self, index) :
        # ``index`` is intentionally ignored: each call yields the next window of the current bulk.
        return self.__load_data(self._src_index)

    def on_epoch_end(self) :
        """Move to the next bulk of examples and refresh the cached sample and step count."""
        del self.sample_x, self.sample_y1, self.sample_y2

        self._src_index = (self._src_index + self._bulk_num) % len(self._input_list)
        # Keep __len__ in sync with the clips of the new bulk.
        self._max_dur = self.__estimate_max_len(self._src_index)
        self.sample_x, self.sample_y1, self.sample_y2 = self.__get_sample()
        self.__cache_warning()

    @property
    def cache_size(self) :
        """Bytes held by the cached sample batch."""
        return self.__estimate_cache()

    @property
    def input_shape(self) :
        """Shape of one input example (the cached batch shape without its first axis)."""
        return self.sample_x.shape[1:]

In [2]:
from utils import gen_dataset

# Build the dataset under ./test_sample/; 0.7 is presumably the train fraction
# (a later cell reads ./test_sample/train_data/) — confirm in utils.gen_dataset.
gen_dataset(
    train_test_split=0.7,
    target_dir="./test_sample/",
)

Processing... [-] : [002/006]

[src/libmpg123/id3.c:process_comment():584] error: No comment text / valid description?
[src/libmpg123/id3.c:process_comment():584] error: No comment text / valid description?


Processing... [-] : [006/006]	  Done


In [2]:
from sys import getsizeof

def calc_mem_usage(x) : 
    """Print a human-readable memory size such as ``54.844 MBytes``.

    Args:
        x: either an integer byte count (a Python ``int`` or any other
            ``numbers.Integral`` such as ``numpy.int64``; ``bool`` is not
            treated as a count) or any other object, whose size is then
            measured with ``sys.getsizeof``.

    The value is divided by 1024 while it is >= 1000, so the printed number
    stays below 1000 (e.g. 1000 bytes prints as ``0.977 KBytes``).

    Raises:
        AssertionError: if the value is too large to express in TB.
    """
    # numpy integers are Integral but not ``int``; with a ``type(x) == int``
    # check they were measured as objects instead of read as byte counts.
    is_byte_count = isinstance(x, numbers.Integral) and not isinstance(x, bool)
    mem_use = int(x) if is_byte_count else getsizeof(x)
    unit_list = ["B", "KB", "MB", "GB", "TB"]
    count = 0
    while mem_use >= 1000 : 
        mem_use /= 1024
        count += 1

    if count >= len(unit_list) : 
        raise AssertionError("Memory usage is out of 1024 TB")

    print("{:,.3f} {}ytes".format(mem_use, unit_list[count]))

In [3]:
train_path = "./test_sample/train_data/"

# bulk_num=3: each batch stacks three examples along its first axis.
train_generator = Train_Generator(train_path, bulk_num=3)
print(train_generator.input_shape, train_generator.sample_x.shape, sep="\n")

(960, 832, 2)
(3, 960, 832, 2)




In [4]:
# Memory held by the generator's cached sample batch.
calc_mem_usage(x=train_generator.cache_size)

54.844 MBytes


In [5]:
from utils import gen_separate_unet

# input_shape is the cached sample's shape without its batch axis.
test_model = gen_separate_unet(
    input_shape=train_generator.input_shape,
)

Metal device set to: Apple M1 Pro

systemMemory: 16.00 GB
maxCacheSize: 5.33 GB



In [6]:
test_model.compile(optimizer="adam", loss="mae")
# Keep the returned History object for later inspection instead of leaving it
# as the cell's last expression, which only prints an opaque repr.
history = test_model.fit(x=train_generator, epochs=3)

Epoch 1/3


2023-07-20 18:00:16.260352: W tensorflow/tsl/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz


0
1


Epoch 2/3




1
0


Epoch 3/3




0
1






<keras.callbacks.History at 0x29ed08850>