In [8]:
from dataloading import ClarityAudioDataloaderSequenceAudio, ClarityAudioDataloaderSequenceSpectrograms
from pathlib import Path
import sklearn
import tensorflow as tf
import numpy as np
import IPython.display as ipd
import matplotlib.pyplot as plt
import librosa
import tqdm
from custom_unet import custom_unet
from pathlib import Path
import os
from IPython.display import clear_output

# Small positive constant. It is not referenced in any of the cells visible here;
# presumably intended as a guard against division by zero / log(0) — TODO confirm or remove.
eps = 1e-9

In [9]:
# --- Audio / STFT configuration ---------------------------------------------
# Sample rate in Hz; handed to the data loader as `new_sample_rate`.
fs = 8000

# STFT frame length in samples. 512 is an even length, and a one-sided spectrum
# of an even length N has N // 2 + 1 bins (257 here, which is the bin count the
# loader reports for its batches).
spec_frame_size = 512
# Hop of a quarter frame. (Earlier alternative kept for reference: int(fs * 5e-3) // 2.)
spec_frame_step = spec_frame_size // 4

# Number of input channels. Not referenced by the other cells visible here.
channels_in = 6

# Durations are named once, in seconds, and converted to whole samples below so
# the frame size and hop stay consistent if `fs` is changed.
clip_seconds = 6.0   # length of one look-ahead frame
hop_seconds = 5e-3   # step between consecutive look-ahead frames
lookahead_frame_size = int(clip_seconds * fs)
lookahead_frame_step = int(np.floor(hop_seconds * fs))

batch_size = 3

In [10]:
# Collect the loader settings in one place; the framing values come from the
# configuration cell, the remaining ones are fixed for this experiment.
loader_kwargs = dict(
    spec_frame_size=spec_frame_size,
    spec_frame_step=spec_frame_step,
    frame_size=lookahead_frame_size,
    frame_step=lookahead_frame_step,
    new_sample_rate=fs,
    target_length=6.0*fs,
    batch_size=batch_size,
    return_type="abs",
    subset_size_ratio=0.1,
    n_proc=8,
    verbose=0,
)
data_loader = ClarityAudioDataloaderSequenceSpectrograms(**loader_kwargs)

# Fetch the first batch to learn the layout the network has to accept:
# axis 1 -> frames, axis 2 -> bins, axis 3 -> channels.
x_spec, y_spec = data_loader[0]
frames, bins, channels = (x_spec.shape[axis] for axis in (1, 2, 3))
print(x_spec.shape)


(360, 376, 257, 6)


In [11]:
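# Disabled smoke test: pulls 30 batches through a Keras OrderedEnqueuer with
# multiprocessing enabled. NOTE: the last two lines reference `workers`, `start`,
# `datetime` and `results`, none of which are defined or imported in the cells
# visible here, so this block will not run if it is simply uncommented.
# enq = tf.keras.utils.OrderedEnqueuer(data_loader, use_multiprocessing=True)
# enq.start(workers=1)
# gen = enq.get()
# for i in range(30):
#     print(i)
#     next(gen)
#     print(f"outer {i}")
# enq.stop()
# print('test with', workers, 'workers took', datetime.datetime.now() - start)
# print("results:", results)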


In [12]:
# Build the U-Net on the shape measured from the first loader batch. The channel
# count is taken from the data (`channels`) instead of a literal so the model
# cannot silently disagree with what the loader actually produces.
model = custom_unet(
    input_shape=(frames, bins, channels),
    input_type="mag",
    output_channels=1,
    activation="relu",
    use_batch_norm=True,
    upsample_mode="deconv",  # 'deconv' or 'simple'
    dropout=0.3,
    dropout_change_per_layer=0.0,
    dropout_type="spatial",
    use_dropout_on_upsampling=False,
    use_attention=True,
    filters=16,
    num_layers=4,
    mag_activation="relu",
    phase_activation=None,
)

model.summary()


Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 376, 257, 6) 0                                            
__________________________________________________________________________________________________
tf.compat.v1.pad (TFOpLambda)   (None, 512, 512, 6)  0           input_1[0][0]                    
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 512, 512, 16) 864         tf.compat.v1.pad[0][0]           
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 512, 512, 16) 64          conv2d[0][0]                     
______________________________________________________________________________________________

In [16]:
# Optimizer for the hand-written training loop: Adam with a fixed 1e-3 step size.
opt = tf.keras.optimizers.Adam(learning_rate=1e-3)

# Mean absolute error, called as loss_fn(target, model_output) in the loop.
loss_fn = tf.keras.losses.MeanAbsoluteError(reduction="auto", name="mean_absolute_error")
def chunk_list(input_list, chunk_size):
    """Split a sliceable sequence into consecutive chunks of at most ``chunk_size`` items.

    Args:
        input_list: Any object supporting ``len()`` and slicing, e.g. a list or
            an array (which is then sliced along its first axis).
        chunk_size: Maximum number of items per chunk; must be positive.

    Returns:
        A list of slices of ``input_list`` in their original order. Every chunk
        holds ``chunk_size`` items except possibly the last, which holds the
        remainder. An empty input yields an empty list.

    Raises:
        ValueError: If ``chunk_size`` is zero or negative.
    """
    if chunk_size <= 0:
        # A negative step makes range() empty, which would silently drop every
        # item; zero makes range() fail with an unrelated-looking message.
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size!r}")
    return [
        input_list[start : start + chunk_size]
        for start in range(0, len(input_list), chunk_size)
    ]

# Directory that receives one saved model per epoch.
# NOTE(review): absolute, machine-specific path — consider making it configurable.
checkpoint_filepath = Path('/home/kenders/greenhdd/clarity_challenge/pk_speech_enhancement/models')
# Idempotent and race-free, unlike an exists() check followed by makedirs().
checkpoint_filepath.mkdir(parents=True, exist_ok=True)

epochs = 20
# Each loader batch is split into chunks of this many examples before being fed
# to the network.
sub_batch_size = 3

# A single background worker prefetches loader batches while the model trains.
enq = tf.keras.utils.OrderedEnqueuer(data_loader, use_multiprocessing=False)
enq.start(workers=1)
gen = enq.get()

try:
    for epoch in range(epochs):
        print(f"epoch {epoch} \n")
        # One bar per epoch: it is sized for a single pass over the loader, so a
        # bar shared across epochs would run past its target from epoch 1 onwards.
        progbar = tf.keras.utils.Progbar(len(data_loader))

        # Iterate over the batches of the dataset.
        for batch_n in range(len(data_loader)):
            x_spec, y_spec = next(gen)
            # The batches are large and need to be split into smaller chunks
            # for feeding into the network; shuffle inputs and targets together first.
            x_spec, y_spec = sklearn.utils.shuffle(x_spec, y_spec)
            x_spec_chunked = chunk_list(x_spec, sub_batch_size)
            y_spec_chunked = chunk_list(y_spec, sub_batch_size)

            for x_batch_train, y_batch_train in zip(x_spec_chunked, y_spec_chunked):
                with tf.device('/gpu:0'):
                    with tf.GradientTape() as tape:
                        # Forward pass, recorded on the tape for differentiation.
                        logits = model(x_batch_train, training=True)
                        loss_value = loss_fn(y_batch_train, logits)

                    # Gradients of the loss with respect to the trainable weights.
                    grads = tape.gradient(loss_value, model.trainable_weights)

                    # One optimizer step on this sub-batch.
                    opt.apply_gradients(zip(grads, model.trainable_weights))

            # Reports the loss of the last sub-batch of this loader batch.
            progbar.add(1, values=[("loss", float(loss_value))])

        # Checkpoint named "<epoch>_<last loss>" at the end of every epoch.
        model.save(checkpoint_filepath / f"{epoch}_{float(loss_value):.4}")
finally:
    # Always shut the prefetch worker down, including when training is
    # interrupted from the keyboard, so re-running the cell does not leak workers.
    enq.stop()



epoch 0 batch 3 of 6000
   1/6000 [..............................] - ETA: 48:23:24 - loss: 0.0402

KeyboardInterrupt: 

In [15]:
# Manual checkpoint, presumably used after interrupting training. It reuses
# `epoch` and `loss_value` as left in the kernel namespace by the training loop,
# so it only works once that cell has executed at least one step.
manual_checkpoint_dir = checkpoint_filepath / f"{epoch}_{float(loss_value):.4}"
model.save(manual_checkpoint_dir)

INFO:tensorflow:Assets written to: /home/kenders/greenhdd/clarity_challenge/pk_speech_enhancement/models/0_0.0848/assets


INFO:copy_assets_to_destination_dir: Assets written to: /home/kenders/greenhdd/clarity_challenge/pk_speech_enhancement/models/0_0.0848/assets
