In [1]:
# NOTE(review): no `gen_separate_unet` is defined in the visible notebook -- the
# builder below is named `gen_unet`. Confirm which name is meant to be exported.
__all__ = ["gen_separate_unet"]

from keras import Model
from keras.layers import InputLayer, Conv2D, Dropout, BatchNormalization
from keras.layers import Concatenate, Conv2DTranspose, LeakyReLU, ReLU
from keras.layers import Multiply

def gen_unet(input_shape=(960, 832, 2), encode=5) : 
    """Build a U-Net style masking model.

    The network predicts a sigmoid mask with the same shape as its input and
    returns ``input * mask``.

    Args:
        input_shape: ``(height, width, channels)`` of a single sample. Height
            and width must be divisible by ``2 ** (encode + 1)`` (they are
            halved ``encode + 1`` times and doubled back); ``None`` entries
            are not checked.
        encode: number of strided encoder stages that feed skip connections.
            One extra strided convolution (the bottleneck) follows them.

    Returns:
        An uncompiled ``keras.Model`` with one input and one output.

    Raises:
        ValueError: if a known spatial dimension is not divisible by
            ``2 ** (encode + 1)``.
    """
    channels = input_shape[2]

    # Every stride-2 stage halves (rounding up) and every transposed stage
    # doubles, so skip connections only line up for sizes divisible by
    # 2 ** (encode + 1). Fail early with a readable message.
    down_factor = 2 ** (encode + 1)
    for size in input_shape[:2] : 
        if size is not None and size % down_factor : 
            raise ValueError(
                "Spatial dimensions of input_shape must be divisible by {} : {}".format(
                    down_factor, tuple(input_shape)
                )
            )

    input_layer = InputLayer(input_shape=input_shape).input

    last_layer = input_layer
    concate_list = []
    filter_num = 16 * channels                      # 16 filters per input channel at the first stage

    # Encoder: strided conv -> batch norm -> LeakyReLU, doubling the filters each stage.
    for _ in range(encode) : 
        conv_encoder = Conv2D(filters=filter_num, kernel_size=5, strides=2, padding="same")(last_layer)
        batch_norm = BatchNormalization()(conv_encoder)
        activ_layer = LeakyReLU(alpha=0.2)(batch_norm)

        # The pre-normalisation conv output is the skip tensor; deepest stage goes first.
        concate_list.insert(0, conv_encoder)
        last_layer = activ_layer
        filter_num *= 2

    # Bottleneck: one more strided convolution, without batch norm or skip.
    conv_layer = Conv2D(filters=filter_num, kernel_size=5, strides=2, padding="same")(last_layer)
    last_layer = LeakyReLU(alpha=0.2)(conv_layer)

    # Decoder: upsample, merge with the matching skip tensor, then refine.
    for stage, concate_layer in enumerate(concate_list) : 
        filter_num //= 2                            # integer division keeps the filter count an int

        conv_trans = Conv2DTranspose(filters=filter_num, kernel_size=5, strides=2, padding="same")(last_layer)
        merge_layer = Concatenate(axis=3)([conv_trans, concate_layer])
        conv_decoder = Conv2D(filters=filter_num, kernel_size=5, strides=1, padding="same")(merge_layer)
        batch_norm = BatchNormalization()(conv_decoder)
        activ_layer = ReLU()(batch_norm)

        # Only the first three decoder stages are followed by dropout.
        if stage < 3 : 
            last_layer = Dropout(rate=0.5)(activ_layer)
        else : 
            last_layer = activ_layer

    # Final upsampling back to the input resolution, then the sigmoid mask.
    conv_layer = Conv2DTranspose(filters=8 * channels, kernel_size=5, strides=2, padding="same")(last_layer)
    conv_layer = Conv2DTranspose(filters=channels, kernel_size=3, strides=1, padding="same", activation="sigmoid")(conv_layer)
    output_layer = Multiply()([input_layer, conv_layer])

    return Model(inputs=[input_layer], outputs=[output_layer])

In [2]:
import librosa, warnings
import numpy as np

from sys import getsizeof
from glob import glob1

from keras.utils import Sequence
from numpy import ndarray

from utils import complex_to_polar

class TrainGenerator(Sequence) : 
    """Keras ``Sequence`` that streams paired spectrogram windows from audio files.

    Files matching ``merge*`` in ``src_path`` are the model inputs and files
    matching ``voice*`` are the targets. The two lists are paired by sorted
    file name and then ordered by input duration, longest first. A "bulk" of
    ``bulk_num`` consecutive files forms one batch; each call to
    ``__getitem__`` loads the next ``sample_dur``-second window of every file
    in the current bulk, and ``on_epoch_end`` moves on to the next bulk.

    NOTE(review): ``shuffle`` only shuffles ``_index_list``, which is not
    consulted when batches are selected -- confirm whether it should be.

    Args:
        src_path: directory containing the ``merge*`` / ``voice*`` files. An
            empty string means the current directory.
        bulk_num: number of files (= batch size) processed together.
        sample_dur: length in seconds of each window.
        max_cache_size: memory budget in GiB for one batch of inputs + targets.
        restrict_cache: if true, exceeding the budget raises ``MemoryError``
            instead of emitting a ``ResourceWarning``.
        n_fft: FFT size passed to ``librosa.stft``.
        win_length: window length passed to ``librosa.stft``.
        sample_rate: resampling rate; falsy means each file's native rate.
        shuffle: shuffle ``_index_list`` (see the note above).

    Raises:
        AssertionError: if no files match or the two lists differ in length.
        MemoryError: if ``restrict_cache`` is set and the estimate exceeds
            the budget.
    """

    # Time frames are cropped down to a multiple of this value.
    _FRAME_MULTIPLE = 64

    def __init__(
            self, src_path : str, bulk_num : int=5, 
            sample_dur : float=5, max_cache_size : float=2, restrict_cache=False,
            n_fft : int=1918, win_length : int=1024, 
            sample_rate=None, shuffle=True
    ) :
        # Empty string is left alone so that it keeps meaning "current directory".
        if src_path and not src_path.endswith("/") : src_path += "/"

        input_list = sorted(src_path + name for name in glob1(dirname=src_path, pattern="merge*"))
        output_list = sorted(src_path + name for name in glob1(dirname=src_path, pattern="voice*"))

        assert len(input_list) == len(output_list), \
            "The number of source sample must be same. : [{}, | {}]".format(
                len(input_list), len(output_list)
            )
        assert len(input_list), "In src_path, no match with pattern [merge*]"
        assert len(output_list), "In src_path, no match with pattern [voice*]"

        # Probe every input duration once, then apply ONE permutation (longest
        # first, stable for ties) to inputs, outputs and durations alike so the
        # merge/voice pairing established by file name can never drift apart.
        durations = [librosa.get_duration(path=path) for path in input_list]
        order = sorted(range(len(input_list)), key=lambda i : durations[i], reverse=True)

        self._input_path = [input_list[i] for i in order]
        self._output_path = [output_list[i] for i in order]
        self._max_dur_list = [durations[i] for i in order]

        self._bulk_num = bulk_num
        self._sample_rate = sample_rate
        self._sample_dur = sample_dur

        # Per-file read position in seconds.
        self._offset_list = np.zeros(len(self._input_path), dtype=np.float32)

        self._index_list = list(range(len(self._input_path)))
        if shuffle : np.random.shuffle(self._index_list)

        self._n_fft = n_fft
        self._win_len = win_length

        # Reference window taken from the start of the longest input; it
        # defines the expected sample shape and the cache estimate.
        self._sample_arr = self.__load_features(self._input_path[0], 0.0)

        self._src_index = 0

        self._max_cache_size = max_cache_size * (1024**3)       # GiB -> bytes
        self._restrict_cache = restrict_cache

        self.__cache_warning()

    def __load_features(self, path, offset) : 
        """Load one ``sample_dur`` window of ``path`` starting at ``offset`` seconds.

        The audio is transformed with ``librosa.stft``, its time axis is
        cropped to a multiple of ``_FRAME_MULTIPLE`` frames, and the result is
        passed through ``complex_to_polar`` (project helper; presumably turns
        the complex STFT into magnitude/phase channels -- TODO confirm).
        """
        sample_rate = self._sample_rate if self._sample_rate else librosa.get_samplerate(path)
        source = librosa.load(path, sr=sample_rate, offset=offset, duration=self._sample_dur)[0]
        D = librosa.stft(source, n_fft=self._n_fft, win_length=self._win_len)
        del source

        usable = (len(D[0]) // self._FRAME_MULTIPLE) * self._FRAME_MULTIPLE
        return complex_to_polar(D[:,:usable])

    def __cache_warning(self) : 
        """Warn, or raise if ``restrict_cache`` is set, when a batch exceeds the budget."""
        if self._max_cache_size <= self.__estimate_cache() : 
            if self._restrict_cache : 
                raise MemoryError("Loaded data exceeded max_cache_size.")
            else : 
                # ResourceWarning is hidden by default; force it to show here only.
                with warnings.catch_warnings():
                    warnings.simplefilter("always")
                    warnings.warn("Loaded data exceeded max_cache_size.", ResourceWarning)

    def __estimate_cache(self) : 
        """Return the estimated byte size of one batch (inputs plus targets)."""
        # nbytes counts the array data even when the array is a view, which
        # sys.getsizeof does not.
        return np.asarray(self._sample_arr).nbytes * self._bulk_num * 2

    def __resource_validation(self, arr : ndarray) : 
        """Assert that ``arr`` has the same shape as the reference window."""
        shape1, shape2 = arr.shape, self._sample_arr.shape
        assert shape1 == shape2, "The shape of IO is different : [{} | {}]".format(shape1, shape2)

    def __get_src_list(self, src_index) : 
        """Return the ``bulk_num`` file indices starting at ``src_index``, wrapping around."""
        return [i % len(self._input_path) for i in range(src_index, src_index + self._bulk_num)]

    def __load_data(self, src_index, update=True) : 
        """Load one batch ``(x, y)`` for the bulk starting at ``src_index``.

        When ``update`` is true, each file's offset is advanced by
        ``sample_dur`` after both of its windows have been read.
        """
        x_list = []
        y_list = []
        for index in self.__get_src_list(src_index) : 
            x_list.append(self.__load_single_data(index, self._input_path))
            y_list.append(self.__load_single_data(index, self._output_path))
            if update : self._offset_list[index] += self._sample_dur

        return np.array(x_list), np.array(y_list)

    def __load_single_data(self, src_index, path_list) : 
        """Load the current window of file ``src_index`` from ``path_list``.

        If the window would run past the end of the file (with a 0.1 s safety
        margin) the file's offset is rewound to 0 first.

        Raises:
            ValueError: if the file is too short to hold even one window.
            AssertionError: if the loaded window's shape differs from the
                reference window.
        """
        path = path_list[src_index]
        sample_dur = self._sample_dur
        max_dur = self._max_dur_list[src_index] - 0.1

        if self._offset_list[src_index] + sample_dur >= max_dur : 
            # Rewinding cannot help a file shorter than one window; the
            # previous recursive retry never terminated in that case.
            if sample_dur >= max_dur : 
                raise ValueError(
                    "Source is shorter than sample_dur ({} sec) : [{}]".format(sample_dur, path)
                )
            self._offset_list[src_index] = 0

        D = self.__load_features(path, self._offset_list[src_index])
        self.__resource_validation(D)

        return D

    def __len__(self) :
        """Number of windows needed to cover the longest file of the current bulk."""
        longest = max(
            (self._max_dur_list[index] for index in self.__get_src_list(self._src_index)),
            default=0
        )
        return int(longest // self._sample_dur + 1)

    def __getitem__(self, index) :
        """Return the next batch of the current bulk.

        ``index`` is ignored: the batch is determined by the internal
        per-file offsets, so results depend on call order.
        """
        return self.__load_data(self._src_index, update=True)

    def on_epoch_end(self) :
        """Advance to the next bulk of files, wrapping at the end of the list."""
        self._src_index += self._bulk_num
        self._src_index %= len(self._input_path)

        return super().on_epoch_end()

    @property
    def input_shape(self) : 
        """Shape of a single sample, taken from the reference window."""
        return self._sample_arr.shape


In [3]:
# Directory that TrainGenerator scans for "merge*" (input) and "voice*" (target) files.
path = "./test_sample/train_data/"
# Two files per bulk; every other generator option keeps its default.
temp_gen = TrainGenerator(path, bulk_num=2)

In [4]:
# Size the network input from the generator's reference window, then print the layer table.
temp_model = gen_unet(temp_gen.input_shape)
temp_model.summary()

Metal device set to: Apple M1 Pro

systemMemory: 16.00 GB
maxCacheSize: 5.33 GB

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 960, 832, 2  0           []                               
                                )]                                                                
                                                                                                  
 conv2d (Conv2D)                (None, 480, 416, 32  1632        ['input_1[0][0]']                
                                )                                                                 
                                                                                                  
 batch_normalization (BatchNorm  (None, 480, 416, 32  128        ['conv2d[0][0]']                 
 alization)  

In [5]:
# Adam optimizer with mean-absolute-error loss; train for 3 epochs straight from the generator.
# NOTE(review): the recorded output below shows per-file offset arrays being printed,
# but no print call exists in the visible code -- the output looks like it came from
# an earlier revision. Re-run the notebook to refresh it.
temp_model.compile(optimizer="adam", loss="mae")
temp_model.fit(temp_gen, epochs=3)

[0. 0. 0. 0. 0.]
Epoch 1/3


2023-07-21 21:34:37.904340: W tensorflow/tsl/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz


[5. 5. 0. 0. 0.]
[10. 10.  0.  0.  0.]
 1/82 [..............................] - ETA: 2:27 - loss: 0.7597[15. 15.  0.  0.  0.]
 2/82 [..............................] - ETA: 47s - loss: 0.7194 [20. 20.  0.  0.  0.]
 3/82 [>.............................] - ETA: 47s - loss: 0.7193[25. 25.  0.  0.  0.]
 4/82 [>.............................] - ETA: 46s - loss: 0.7084[30. 30.  0.  0.  0.]
 5/82 [>.............................] - ETA: 45s - loss: 0.7063[35. 35.  0.  0.  0.]
 6/82 [=>............................] - ETA: 45s - loss: 0.7122[40. 40.  0.  0.  0.]
 7/82 [=>............................] - ETA: 44s - loss: 0.7078[45. 45.  0.  0.  0.]
 8/82 [=>............................] - ETA: 43s - loss: 0.7034[50. 50.  0.  0.  0.]
 9/82 [==>...........................] - ETA: 43s - loss: 0.7010[55. 55.  0.  0.  0.]
10/82 [==>...........................] - ETA: 42s - loss: 0.6968[60. 60.  0.  0.  0.]
11/82 [===>..........................] - ETA: 42s - loss: 0.6936[65. 65.  0.  0.  0.]
12/82 [===>..

<keras.callbacks.History at 0x294ec4be0>

In [6]:
# NOTE: rebinds `path` (previously the training-data directory) to the checkpoint directory.
path = "./temp_checkpoint/"
# Save the trained model as checkpoint.h5 under that directory.
temp_model.save(path + "checkpoint.h5")