In [None]:
# ! pip install git+https://github.com/keunwoochoi/kapre.git

In [1]:
import IPython
import IPython.display as ipd
import librosa
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_addons as tfa
from perceiver import Perceiver
from tensorflow import keras
from tqdm import tqdm

from lib.float2d_to_rgb_layer import Float2DToRGB
from lib.melspectrogram_layer import MelSpectrogram
from lib.power_to_db_layer import PowerToDb
from lib.utils import float2d_to_rgb, save_keras_model
from src.config import c
from src.generator import Generator
from src.services import get_msg_provider, get_wave_provider

In [2]:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, BatchNormalization, ReLU, GlobalAveragePooling2D, Dense, Softmax
from kapre import STFT, Magnitude, MagnitudeToDecibel
from kapre.composed import get_melspectrogram_layer, get_log_frequency_spectrogram_layer

In [3]:
wave_p = get_wave_provider(c)

df = pd.read_pickle("/app/_work/dataset-C.pickle")

g = Generator(
    df=df,
    wave_provider=wave_p,
    batch_size=1,
    rating_as_sw=False,
    rareness_as_sw=False,
)

In [4]:
# 6 channels (!), maybe 1-sec audio signal, for an example.
input_shape = (160000, 1)
sr = 32000

In [14]:
N = 32000

In [15]:
waves = np.zeros((N, 160000), dtype=np.float16)
ys = np.zeros((N, 319), dtype=np.float16)

In [16]:
for i in tqdm(range(N)):
    x, y, sw = g.__getitem__(i)
    waves[i] = x["i_wave"][0]
    ys[i] = y[0]

100%|██████████| 32000/32000 [02:28<00:00, 215.71it/s]


In [10]:
waves = waves[..., np.newaxis]

In [11]:
waves.shape, ys.shape

((1024, 160000, 1, 1), (1024, 319))

In [18]:
get_melspectrogram_layer?

[0;31mSignature:[0m
[0mget_melspectrogram_layer[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0minput_shape[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mn_fft[0m[0;34m=[0m[0;36m2048[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mwin_length[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mhop_length[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mwindow_name[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mpad_begin[0m[0;34m=[0m[0;32mFalse[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mpad_end[0m[0;34m=[0m[0;32mFalse[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0msample_rate[0m[0;34m=[0m[0;36m22050[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mn_mels[0m[0;34m=[0m[0;36m128[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mmel_f_min[0m[0;34m=[0m[0;36m0.0[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mmel_f_max[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[

In [17]:
model = Sequential()
# A STFT layer
model.add(STFT(n_fft=2048, win_length=2018, hop_length=1024,
               window_name=None, pad_end=False,
               input_data_format='channels_last', output_data_format='channels_last',
               input_shape=input_shape))
model.add(Magnitude())
model.add(MagnitudeToDecibel())  # these three layers can be replaced with get_stft_magnitude_layer()
# Alternatively, you may want to use a melspectrogram layer

melgram_layer = get_melspectrogram_layer()

# or log-frequency layer
# log_stft_layer = get_log_frequency_spectrogram_layer() 

# add more layers as you want
model.add(Conv2D(32, (3, 3), strides=(2, 2)))
model.add(BatchNormalization())
model.add(ReLU())
model.add(GlobalAveragePooling2D())
model.add(Dense(ys.shape[1]))
model.add(Softmax())

model.compile(
    optimizer='binary_crossentropy',
    loss="bce",
    metrics=[
        tfa.metrics.F1Score(
            num_classes=ys.shape[1],
            threshold=0.5,
            average="micro",
        ),
    ],
)

# Compile the model
# model.compile('adam', 'categorical_crossentropy') # if single-label classification

# train it with raw audio sample inputs
# for example, you may have functions that load your data as below.
# x = load_x() # e.g., x.shape = (10000, 6, 44100)
# y = load_y() # e.g., y.shape = (10000, 10) if it's 10-class classification
# then..
model.fit(x=waves, y=ys, batch_size=8, verbose=1, validation_split=0.2, epochs=10)
# Done!

Epoch 1/10
Epoch 2/10

KeyboardInterrupt: 