In [8]:

import librosa, os
import numpy as np


from IPython.display import Audio

from sys import getsizeof
from glob import glob1
from numpy import ndarray
from keras.utils import Sequence

from utils import complex_to_polar, polar_to_complex

class PredGenerator(Sequence) :
    """Keras ``Sequence`` that serves fixed-size polar-spectrogram chunks of audio files.

    Every file in ``src_path`` matching ``pattern`` is cut into consecutive chunks whose
    length (in samples) is that of ``librosa.istft`` applied to an all-zero spectrogram
    with the model's (freq, frames) input shape. Each chunk is STFT'd, converted with
    ``complex_to_polar`` and zero-padded along the frame axis to ``model_input_shape``.

    Only one source file (``self._src_index``) is served per epoch; ``on_epoch_end``
    inverts the accumulated chunks with ``librosa.istft``, writes
    ``<pred_dir>Pred_<stem>.wav`` and moves on to the next file (wrapping around).

    NOTE(review): ``_output`` accumulates the *input* chunks returned by
    ``__getitem__``, not the model's predictions (a ``Sequence`` never sees what
    ``predict`` returns), so the written file is a reconstruction of the input.
    TODO confirm this is the intent.

    NOTE(review): ``__getitem__`` ignores ``index`` and serves chunks in call order.
    The first ``_prevent_update_step`` calls per epoch do not advance the offset,
    presumably to absorb Keras peeking at the first batch -- fragile, TODO confirm
    against the Keras version in use.

    NOTE(review): ``__len__`` sums the chunk counts of *all* files, while one epoch
    only reads the current file.

    Args:
        model_input_shape: model input shape, with or without a leading ``None`` batch dim.
        src_path: directory containing the source audio files.
        pred_dir: directory the reconstructed ``.wav`` files are written into.
        pattern: glob suffix for source files; a leading ``*`` is prepended if missing.
        max_cache_size: cache budget in GiB; only stored (as bytes), not enforced.
        restrict_cache: only stored, not used yet.
        n_fft: FFT size for ``librosa.stft`` / ``librosa.istft``.
        win_length: window length for ``librosa.stft`` / ``librosa.istft``.

    Raises:
        AssertionError: if no file in ``src_path`` matches ``pattern``.
    """

    # Batched arrays are (batch, freq, frames, channels); frames are axis 2.
    # Assumes complex_to_polar keeps (freq, frames) as its leading axes -- TODO confirm.
    _FRAME_AXIS = 2

    def __init__(
            self, model_input_shape : tuple, src_path : str, pred_dir : str, pattern : str=".mp3",
            max_cache_size : float=2, restrict_cache=False,
            n_fft : int=1918, win_length : int=1024,
    ) :
        if not pattern.startswith("*") : pattern = "*" + pattern

        input_name_list = glob1(dirname=src_path, pattern=pattern)
        # Explicit raise (not `assert`) so the check survives `python -O`.
        if not input_name_list :
            raise AssertionError("In src_path, no match with pattern [{}]".format(pattern))

        def sort_and_estimate(name_list) :
            # Return (paths, names, durations) sorted by duration, longest first.
            var_list = []
            for name in name_list :
                path = os.path.join(src_path, name)
                var_list.append((path, name, librosa.get_duration(path=path)))
            var_list.sort(key=lambda var : var[2], reverse=True)
            path_list = [var[0] for var in var_list]
            name_list = [var[1] for var in var_list]
            dur_list = [var[2] for var in var_list]
            return path_list, name_list, dur_list

        self._input_path_list, self._name_list, self._input_dur_list = sort_and_estimate(input_name_list)
        # Appends exactly one trailing separator when missing; no IndexError on "".
        self._pred_dir = os.path.join(pred_dir, "")

        if model_input_shape[0] is None : model_input_shape = model_input_shape[1:]

        self._sample_arr = np.zeros(model_input_shape, dtype=np.float32)
        # istft of an all-zero (freq, frames) matrix: its length is the number of
        # samples one model-sized chunk covers.
        self._sample_src = librosa.istft(np.zeros(model_input_shape[:-1], dtype=np.float32), n_fft=n_fft, win_length=win_length)

        def estimate_tot_dur(path_list, dur_list, sample_src) :
            # Per file: chunk duration (s), padded total duration (s), number of chunks.
            chunk_samples = len(sample_src)
            sample_dur_list, tot_dur_list, length_list = [], [], []
            for path, tot_dur in zip(path_list, dur_list) :
                sample_rate = librosa.get_samplerate(path)
                sample_dur = librosa.get_duration(y=sample_src, sr=sample_rate)
                # Integer ceil-division on sample counts: a file that is an exact
                # multiple of the chunk length gets no empty trailing chunk, and float
                # error in `tot_dur / sample_dur` cannot flip the count.
                tot_samples = int(round(tot_dur * sample_rate))
                quotient = max(1, -(-tot_samples // chunk_samples))

                sample_dur_list.append(sample_dur)
                tot_dur_list.append(quotient * sample_dur)
                length_list.append(quotient)
            return sample_dur_list, tot_dur_list, length_list

        self._sample_dur_list, self._tot_dur_list, self._tot_len_list = estimate_tot_dur(
            self._input_path_list, self._input_dur_list, self._sample_src)
        # float64: offsets are accumulated once per chunk and float32 rounding error adds up.
        self._offset_list = np.zeros(len(self._sample_dur_list), dtype=np.float64)
        self._tot_len = np.sum(self._tot_len_list, dtype=int)
        self._src_index = 0

        self._n_fft = n_fft
        self._win_length = win_length

        self._max_cache_size = max_cache_size * (1024**3)
        self._restrict_cache = restrict_cache

        self._prevent_update_step = 3
        self._count = 0
        self.X = self._load_data_via_index(update=False)
        self._output = self._empty_output_like(self.X)

    def _estimate_cache(self, arr_batch) :
        # TODO: unimplemented cache-size estimate (was sketched as `getsizeof(...) * ...`).
        pass

    def _empty_output_like(self, arr) :
        # Zero-length buffer along the frame axis that matches `arr` on every other axis.
        shape = list(arr.shape)
        shape[self._FRAME_AXIS] = 0
        return np.zeros(shape, dtype=np.float32)

    def refresh(self) :
        """Restart the current source: rewind its offset, reload the first chunk, clear `_output`."""
        del self.X, self._output
        self._count = 0
        # Rewind so a wrapped-around or repeated pass does not read past the end of the file.
        self._offset_list[self._src_index] = 0.0
        self.X = self._load_data_via_index(update=False)
        self._output = self._empty_output_like(self.X)

    def _resource_validation(self, arr : ndarray) :
        """Zero-pad `arr` along its second (frame) axis up to `self._sample_arr.shape`.

        Arrays already of that shape are returned unchanged; arrays with more frames
        than the model input raise a numpy broadcasting error.
        """
        if arr.shape == self._sample_arr.shape :
            return arr
        temp_arr = np.zeros_like(self._sample_arr)
        temp_arr[:, :arr.shape[1]] = arr
        return temp_arr

    def __len__(self):
        """Total number of chunks over all source files."""
        return self._tot_len

    def __getitem__(self, index) :
        """Return the next chunk of the current source; `index` is ignored (see class NOTE)."""
        return self._load_data_via_index()

    def _load_data(self, src_index, update=True) :
        """Load the chunk of source `src_index` that starts at its current offset.

        Returns an array of shape (1, *model_input_shape). With `update=True` the first
        `_prevent_update_step` calls only bump `_count`; later calls advance the
        source's offset by one chunk and append the chunk to `_output` along the
        frame (time) axis.
        """
        path = self._input_path_list[src_index]
        sample_rate = librosa.get_samplerate(path)
        sample_dur = self._sample_dur_list[src_index]
        src_offset = float(self._offset_list[src_index])

        source = librosa.load(path, sr=sample_rate, offset=src_offset, duration=sample_dur)[0]
        D = librosa.stft(source, n_fft=self._n_fft, win_length=self._win_length)
        D = complex_to_polar(D)
        D = self._resource_validation(D)
        D = np.array([D])   # add batch axis

        del source

        if update and (self._count < self._prevent_update_step) :
            self._count += 1
        elif update :
            self._offset_list[src_index] += sample_dur
            # Append along time so istft later yields consecutive audio; stacking on
            # axis 1 (as np.hstack does) would pile up frequency bins instead.
            self._output = np.concatenate((self._output, D), axis=self._FRAME_AXIS)

        return D

    def _load_data_via_index(self, update=True) :
        """Load the current chunk of the source selected by `self._src_index`."""
        return self._load_data(self._src_index, update)

    def on_epoch_end(self) :
        """Write the accumulated chunks of the current source as a `.wav`, then advance.

        The file is `<pred_dir>Pred_<name without extension>.wav`; nothing is written
        when no chunk was accumulated. Afterwards `_src_index` moves to the next
        source (wrapping around) and `refresh` is called.

        Raises:
            AssertionError: if the output file already exists (nothing is overwritten).
        """
        index = self._src_index
        sample_rate = librosa.get_samplerate(self._input_path_list[index])
        name = self._name_list[index]
        # Drop a leading "/" so the name is appended relative to pred_dir.
        music_name = name[1:] if name.startswith("/") else name
        out_path = self._pred_dir + "Pred_" + os.path.splitext(music_name)[0] + ".wav"

        if os.path.isfile(out_path) :
            raise AssertionError("File already exists in [{}]".format(self._pred_dir))

        if self._output.shape[self._FRAME_AXIS] > 0 :
            output = polar_to_complex(self._output)
            output = librosa.istft(output, n_fft=self._n_fft, win_length=self._win_length)
            # `Audio(...).data` holds the encoded WAV bytes; the Audio object itself is
            # a display wrapper and cannot be written to a binary file.
            wav_bytes = Audio(output, rate=sample_rate).data
            with open(file=out_path, mode="wb") as f :
                f.write(wav_bytes)

        self._src_index += 1
        self._src_index %= len(self._input_path_list)

        self.refresh()



In [3]:
from keras.models import load_model

# Restore the trained U-Net from its checkpoint and report its input/output shapes.
model_path = "./U_Net_checkpoint/checkpoint.h5"
temp_model = load_model(filepath=model_path)
io_shapes = (temp_model.input_shape, temp_model.output_shape)
print(*io_shapes)

Metal device set to: Apple M1 Pro

systemMemory: 16.00 GB
maxCacheSize: 5.33 GB

(None, 960, 832, 2) [(None, 960, 832, 2), (None, 960, 832, 2)]


In [9]:
# Source audio directory, and the directory the reconstructed .wav files go to.
temp_path1 = "../Data/music/music_only/sample/"
temp_path2 = "../Data/music/music_only/pred_sample/"
temp_gen = PredGenerator(
    model_input_shape=temp_model.input_shape,
    src_path=temp_path1,
    pred_dir=temp_path2,
)

0.0 405.21142857142854


In [10]:
# Inspect which files were matched (sorted longest first) and their bare names.
for listing in (temp_gen._input_path_list, temp_gen._name_list) :
    print(listing)


['../Data/music/music_only/sample/sb_adriftamonginfinitestars.mp3']
['sb_adriftamonginfinitestars.mp3']


In [15]:
# Name of the first source without its 4-character extension (e.g. ".mp3").
first_name = temp_gen._name_list[0]
first_name[:-4]

'sb_adriftamonginfinitestars'

In [11]:
# Total number of chunks, then the per-file chunk counts.
print(len(temp_gen))
print(temp_gen._tot_len_list)

84
[84]


In [12]:
from keras import Model

# Run the U-Net over every chunk the generator serves. The bare
# `Model.predict_generator` expression that used to sit here was a no-op leftover
# (and names a deprecated API), so it has been removed.
temp_ouput = temp_model.predict(x=temp_gen)

0.0 405.21142857142854


2023-07-21 18:50:20.567769: W tensorflow/tsl/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz


0.0 405.21142857142854
0.0 405.21142857142854
 1/84 [..............................] - ETA: 54s0.0 405.21142857142854
 3/84 [>.............................] - ETA: 9s 4.8239455 405.21142857142854
 4/84 [>.............................] - ETA: 11s9.647891 405.21142857142854
 5/84 [>.............................] - ETA: 12s14.471837 405.21142857142854
 6/84 [=>............................] - ETA: 13s19.295782 405.21142857142854
 7/84 [=>............................] - ETA: 13s24.119728 405.21142857142854
 8/84 [=>............................] - ETA: 13s28.943674 405.21142857142854
 9/84 [==>...........................] - ETA: 13s33.76762 405.21142857142854
10/84 [==>...........................] - ETA: 13s38.591564 405.21142857142854
11/84 [==>...........................] - ETA: 13s43.41551 405.21142857142854
12/84 [===>..........................] - ETA: 13s48.239452 405.21142857142854
13/84 [===>..........................] - ETA: 13s53.063396 405.21142857142854
14/84 [====>...............

KeyboardInterrupt: 

In [None]:
# The checkpoint reports two output heads, so predict returns a list of two arrays.
n_outputs = len(temp_ouput)
n_outputs

2

In [None]:
# Shape of each of the two prediction heads.
for head_idx in (0, 1) :
    print(temp_ouput[head_idx].shape)

(84, 960, 832, 2)
(84, 960, 832, 2)


In [None]:
from keras import Model

# Read the documentation of Model.predict (generator / Sequence input handling).
predict_doc = Model.predict.__doc__
print(predict_doc)

Generates output predictions for the input samples.

        Computation is done in batches. This method is designed for batch
        processing of large numbers of inputs. It is not intended for use inside
        of loops that iterate over your data and process small numbers of inputs
        at a time.

        For small numbers of inputs that fit in one batch,
        directly use `__call__()` for faster execution, e.g.,
        `model(x)`, or `model(x, training=False)` if you have layers such as
        `tf.keras.layers.BatchNormalization` that behave differently during
        inference. You may pair the individual model call with a `tf.function`
        for additional performance inside your inner loop.
        If you need access to numpy array values instead of tensors after your
        model call, you can use `tensor.numpy()` to get the numpy array value of
        an eager tensor.

        Also, note the fact that test loss is not affected by
        regularization layers l

In [None]:
from keras import losses

# Every public and private name exported by keras.losses, one per line.
print("\n".join(dir(losses)))

BCE
BinaryCrossentropy
BinaryFocalCrossentropy
CategoricalCrossentropy
CategoricalHinge
CosineSimilarity
Hinge
Huber
KLD
KLDivergence
LABEL_DTYPES_FOR_LOSSES
LogCosh
Loss
LossFunctionWrapper
MAE
MAPE
MSE
MSLE
MeanAbsoluteError
MeanAbsolutePercentageError
MeanSquaredError
MeanSquaredLogarithmicError
Poisson
SparseCategoricalCrossentropy
SquaredHinge
__builtins__
__cached__
__doc__
__file__
__loader__
__name__
__package__
__spec__
_maybe_convert_labels
_ragged_tensor_apply_loss
_ragged_tensor_binary_crossentropy
_ragged_tensor_binary_focal_crossentropy
_ragged_tensor_categorical_crossentropy
_ragged_tensor_mae
_ragged_tensor_mape
_ragged_tensor_mse
_ragged_tensor_msle
_ragged_tensor_sparse_categorical_crossentropy
abc
backend
bce
binary_crossentropy
binary_focal_crossentropy
categorical_crossentropy
categorical_hinge
cosine_similarity
deserialize
deserialize_keras_object
dispatch
doc_controls
functools
get
hinge
huber
huber_loss
is_categorical_crossentropy
keras_export
kl_divergence
kld
