# 1. Data Generator
- Raw Data를 읽어옴
- 여기서 만들어진 데이터는 모델의 입력으로 들어감

In [3]:
import os
import numpy as np
import librosa
from tensorflow.keras.utils import Sequence

In [104]:
class RawForVAEGenerator(Sequence):
    def __init__(self, source, wav_dir, files, sourNum='s1', batch_size=10, shuffle=True):
        self.source = source
        self.wav_dir = wav_dir
        self.files = files
        self.sourNum = sourNum
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.on_epoch_end()
        
        self.sample_rate = 8000
    
    def on_epoch_end(self):
        self.indexes = np.arange(len(self.source))
        
        if self.shuffle:
            np.random.shuffle(self.indexes)
    
    def __audioread__(self, path, offset=0.0, duration=None, sample_rate=16000):
        signal = librosa.load(path, sr=self.sample_rate, mono=False, offset=offset, duration=duration)

        return signal[0]
    
    def __padding__(self, data):
        n_batch = len(data)
        max_len = max([d.shape[0] for d in data])
        extrapadding = int(np.ceil(max_len / self.sample_rate) * self.sample_rate)
        pad = np.zeros((n_batch, extrapadding))
        
        for i in range(n_batch):
            pad[i, :data[i].shape[0]] = data[i]
        
        return np.expand_dims(pad, -1)
        
    def __data_generation__(self, source_list):
        wav_list = []
        for name in source_list:
            name = name.strip('\n')
            
            s_wav_name = self.wav_dir + self.files + '/' + self.sourNum + '/' + name
            
            # ------- AUDIO READ -------
            s_wav = (self.__audioread__(s_wav_name,  offset=0.0, duration=None, sample_rate=self.sample_rate))
            # --------------------------
            
            # ------- PADDING -------
#             pad_len = max(len(samples1),len(samples2))
#             pad_s1 = np.concatenate([s1_wav, np.zeros([pad_len - len(s1_wav)])])
            
#             extrapadding = ceil(len(pad_s1) / sample_rate) * sample_rate - len(pad_s1)
#             pad_s1 = np.concatenate([pad_s1, np.zeros([extrapadding - len(pad_s1)])])
#             pad_s2 = np.concatenate([s2_wav, np.zeros([extrapadding - len(s2_wav)])])
            # -----------------------
            
            wav_list.append(s_wav)
        
        return wav_list, wav_list, source_list
            
    
    def __len__(self):
        return int(np.floor(len(self.source) / self.batch_size))
    
    def __getitem__(self, index):
        indexes = self.indexes[index * self.batch_size : (index + 1) * self.batch_size]
        source_list = [self.source[k] for k in indexes]
        
        if self.files is not 'tt':
            sour, labels, _ = self.__data_generation__(source_list)
            
            # Get Lengths(K value of each batch)
            lengths = np.array([m.shape[0] for m in sour])
            exp = np.expand_dims(lengths, 1)
            exp = np.expand_dims(exp, -1) # [Batch, 1, 1] (length)
            
            # Padding
            sour_pad = self.__padding__(sour) # [Batch, Time_step, Dimension(=1)]
            label_pad = self.__padding__(labels) # [Batch, Time_step, Dimension(=1)]
            
            return sour_pad, np.concatenate([label_pad, exp], axis=1)
        else:
            sour, labels, name = self.__data_generation__(source_list)
            
            # Get Lengths(K value of each batch)
            lengths = np.array([m.shape[0] for m in sour])
            exp = np.expand_dims(lengths, 1)
            exp = np.expand_dims(exp, -1) # [Batch, 1, 1] (length)
            
            # Padding
            sour_pad = self.__padding__(sour) # [Batch, Time_step, Dimension(=1)]
            
            return sour_pad, exp, name

## Data를 어떻게 읽는지에 대한 부분

In [105]:
WAV_DIR = './mycode/wsj0_2mix/use_this/'
LIST_DIR = './mycode/wsj0_2mix/use_this/lists/'

In [106]:
# Directory List file create

wav_dir = WAV_DIR
output_lst = LIST_DIR

for folder in ['tr', 'cv', 'tt']:
    wav_files = os.listdir(wav_dir + folder + '/mix')
    output_lst_files = output_lst + folder + '_wav.lst'
    with open(output_lst_files, 'w') as f:
        for file in wav_files:
            f.write(file + "\n")

print("Generate wav file to .lst done!")

Generate wav file to .lst done!


In [107]:
batch_size = 20

train_dataset = 0
valid_dataset = 0
test_dataset = 0

name_list = []
for files in ['tr', 'cv', 'tt']:
    # --- Lead lst file ---""
    output_lst_files = LIST_DIR + files + '_wav.lst'
    fid = open(output_lst_files, 'r')
    lines = fid.readlines()
    fid.close()
    # ---------------------
    
    if files == 'tr':
        train_dataset = RawForVAEGenerator(lines, WAV_DIR, files, 's1', batch_size)
    elif files == 'cv':
        valid_dataset = RawForVAEGenerator(lines, WAV_DIR, files, 's1', batch_size)
    else:
        test_batch = 1
        test_dataset = RawForVAEGenerator(lines, WAV_DIR, files, 's1', test_batch)

# 2. Building VQ-VAE model with Gumbel Softmax

In [4]:
import threading
from scipy.io.wavfile import write as wav_write
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow import keras
import matplotlib.pyplot as plt
from tqdm import tqdm_notebook as tqdm
from tensorflow.keras import backend as Kb
import numpy as np
import pandas as pd
from importlib import reload
import time
from tensorflow.keras.models import Model, Sequential, load_model

In [109]:
def mkdir_p(path):
    """ Creates a path recursively without throwing an error if it already exists
    :param path: path to create
    :return: None
    """
    if not os.path.exists(path):
        os.makedirs(path)

In [110]:
mkdir_p('./CKPT/') # model check point 폴더 만드는 코드
filepath = "./CKPT/CKP_ep_{epoch:d}__loss_{val_loss:.5f}_.h5"

In [111]:
initial_learning_rate = 0.001

# learning rate를 점점 줄이는 부분
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate, decay_steps=20, decay_rate=0.96, staircase=True
)

# validation loss에 대해서 좋은 것만 저장됨
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath, monitor='val_loss', verbose=1, save_best_only=True, save_weights_only=True, mode='min'
)

# early stop 하는 부분인데, validation loss에 대해서 제일 좋은 모델이 저장됨
early_stopping_cb = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', mode='min', verbose=1, patience=50, restore_best_weights=True
)

In [112]:
class GumbelSoftmax(layers.Layer):
    def __init__(self, temperature=0.5, hard=False, name = 'gumbel_softmax',**kwargs):
        super(GumbelSoftmax, self).__init__(name=name, **kwargs)
        
        self.temperature = temperature
        self.hard = hard
    
    def sample_gumbel(self, shape, eps=1e-20): 
        """Sample from Gumbel(0, 1)"""
        U = tf.random.uniform(shape,minval=0,maxval=1)
        
        return -tf.math.log(-tf.math.log(U + eps) + eps)

    def gumbel_softmax_sample(self, logits, temperature): 
        """ Draw a sample from the Gumbel-Softmax distribution"""
        y = logits + self.sample_gumbel(tf.shape(logits))
        
        return tf.nn.softmax(y / temperature)

    def call(self, inputs):
        y = self.gumbel_softmax_sample(inputs, self.temperature)
        
        if self.hard:
            y_hard = tf.cast(tf.equal(y, tf.math.reduce_max(y, 2, keepdims=True)), y.dtype)
            y = tf.stop_gradient(y_hard - y) + y
        
        return y


class Encoder(layers.Layer):
    def __init__(self, latent_dim, name = 'encoder',**kwargs):
        super(Encoder, self).__init__(name=name, **kwargs)
        
        self.conv1d_1 = layers.Conv1D(filters=32, kernel_size=4, strides=2, activation='relu', padding='same')
        self.conv1d_2 = layers.Conv1D(filters=128, kernel_size=4, strides=2, activation='relu', padding='same')
        self.conv1d_3 = layers.Conv1D(filters=128, kernel_size=4, strides=2, activation='relu', padding='same')
        self.conv1d_4 = layers.Conv1D(filters=256, kernel_size=4, strides=2, activation='relu', padding='same')
        self.conv1d_5 = layers.Conv1D(filters=512, kernel_size=4, strides=2, activation='relu', padding='same')
        self.logit = layers.Conv1D(filters=latent_dim, kernel_size=1, strides=1, activation=None, padding='valid')
    
    def call(self, inputs):
        x = self.conv1d_1(inputs)
        x = self.conv1d_2(x)
        x = self.conv1d_3(x)
        x = self.conv1d_4(x)
        x = self.conv1d_5(x)
        logit = self.logit(x)
        
        return logit


class Decoder(layers.Layer):
    def __init__(self, latent_dim, name = 'decoder',**kwargs):
        super(Decoder, self).__init__(name=name, **kwargs)
        
        self.trans_conv1d_1 = layers.Conv1DTranspose(filters=512, kernel_size=1, strides=1, activation='relu', padding='same')
        self.trans_conv1d_2 = layers.Conv1DTranspose(filters=256, kernel_size=4, strides=2, activation='relu', padding='same')
        self.trans_conv1d_3 = layers.Conv1DTranspose(filters=128, kernel_size=4, strides=2, activation='relu', padding='same')
        self.trans_conv1d_4 = layers.Conv1DTranspose(filters=128, kernel_size=4, strides=2, activation='relu', padding='same')
        self.trans_conv1d_5 = layers.Conv1DTranspose(filters=32, kernel_size=4, strides=2, activation='relu', padding='same')
        self.logit = layers.Conv1DTranspose(filters=1, kernel_size=4, strides=2, activation=None, padding='same')
    
    def call(self, inputs):
        x = self.trans_conv1d_1(inputs)
        x = self.trans_conv1d_2(x)
        x = self.trans_conv1d_3(x)
        x = self.trans_conv1d_4(x)
        x = self.trans_conv1d_5(x)
        logit = self.logit(x)
        
        return logit

In [113]:
# Custom Metric Si-sdr

class SiSdr(keras.metrics.Metric):
    def __init__(self, name="Si-sdr", **kwargs):
        super(SiSdr, self).__init__(name=name, **kwargs)
        self.sdr = self.add_weight(name="sdr", initializer="zeros")
        self.count = self.add_weight(name="cnt", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        ori_length = tf.shape(y_true)[1]
        
        # Label & Length divide
        labels = tf.slice(y_true, [0, 0, 0], [-1, ori_length-1, -1]) # [batch_size, length_size, 1]
        lengths = tf.slice(y_true, [0, ori_length-1, 0], [-1, -1, 1]) # [batch_size, 1, 1]
        
        # Check sequence length
        batch_size = tf.shape(labels)[0]
        label_size = tf.shape(labels)[1]
        pred_size = tf.shape(y_pred)[1]
        feature_size = tf.shape(labels)[-1]
        
        # Change sequence length
        if label_size < pred_size:
            y_pred = tf.slice(y_pred, [0, 0, 0], [-1, label_size, -1])
        elif label_size > pred_size:
            labels = tf.slice(labels, [0, 0, 0], [-1, pred_size, -1])

        # SI-SDR
        target = tf.linalg.matmul(y_pred, labels, transpose_a=True) * labels / tf.expand_dims(tf.experimental.numpy.square(tf.norm(labels, axis=1)), axis=-1)
        noise = y_pred - target
        values = 10 * tf.experimental.numpy.log10(tf.experimental.numpy.square(tf.norm(target, axis=1)) / tf.experimental.numpy.square(tf.norm(noise, axis=1)))
        
        if sample_weight is not None:
            sample_weight = tf.cast(sample_weight, "float32")
            values = tf.multiply(values, sample_weight)
        self.sdr.assign_add(tf.reduce_sum(values))
        self.count.assign_add(tf.cast(tf.shape(y_true)[0], tf.float32))

    def result(self):
        return self.sdr / self.count

    def reset_states(self):
        # The state of the metric will be reset at the start of each epoch.
        self.sdr.assign(0.0)
        self.count.assign(0.0)

In [114]:
# Custom loss

# Custom mse
def custom_mse(y_true, y_pred):
    ori_length = tf.shape(y_true)[1]

    # Label & Length divide
    labels = tf.slice(y_true, [0, 0, 0], [-1, ori_length-1, -1]) # [batch_size, length_size, 129]
    lengths = tf.slice(y_true, [0, ori_length-1, 0], [-1, -1, 1]) # [batch_size, 1, 1]

    loss = tf.reduce_sum(tf.pow(y_pred - labels, 2), axis=[1, 2])
    loss = tf.reduce_mean(loss)

    return loss


# Custom si-sdr loss
def custom_sisdr_loss(y_true, y_pred):
    ori_length = tf.shape(y_true)[1]

    # Label & Length divide
    labels = tf.slice(y_true, [0, 0, 0], [-1, ori_length-1, -1]) # [batch_size, length_size, 1]
    lengths = tf.slice(y_true, [0, ori_length-1, 0], [-1, -1, 1]) # [batch_size, 1, 1]

    target = tf.linalg.matmul(y_pred, labels, transpose_a=True) * labels / tf.expand_dims(tf.experimental.numpy.square(tf.norm(labels, axis=1)), axis=-1)
    noise = y_pred - target
    si_sdr = 10 * tf.experimental.numpy.log10(tf.experimental.numpy.square(tf.norm(target, axis=1)) / tf.experimental.numpy.square(tf.norm(noise, axis=1)))
    si_sdr = tf.reduce_mean(si_sdr) * -1

    return si_sdr

In [115]:
class Vq_vae(keras.Model):
    def __init__(self, latent_dim, gumbel_hard=False, name='vqvae', **kwargs):
        super(Vq_vae, self).__init__(name=name, **kwargs)
        
        self.latent_dim = latent_dim
        self.softmax = layers.Softmax(-1)
        
        self.encoder = Encoder(latent_dim)
        self.embeddings = layers.Embedding(latent_dim, latent_dim)
        self.decoder = Decoder(latent_dim)
        self.gumbel = GumbelSoftmax(hard=gumbel_hard)
        
    def call(self, inputs, load=False):
        if load:
            inputs = layers.Input(shape=(None, 1))
        
        
        encode = self.encoder(inputs)
        gumbel = self.gumbel(encode)
        decode = self.decoder(gumbel)
        
        # ------------------ KL loss ------------------
        qy = self.softmax(encode)
        log_qy = tf.math.log(qy + 1e-10)
        log_uniform = qy * (log_qy - tf.math.log(1.0 / self.latent_dim))
        kl_loss = tf.reduce_sum(log_uniform, axis=[1, 2])
        kl_loss = tf.reduce_mean(kl_loss) * 0.2
        # ---------------------------------------------
        
        self.add_loss(kl_loss)
        
        return decode

# 이렇게 GradientTape 를 사용해서 프로그램 해도 됨

In [116]:
tf.random.set_seed(42)

latent_size = 1024
epochs = 600

filePath = "./CKPT/CKP_ep_{0}__loss_{1:.5f}_.h5"
model_path = './CKPT/CKP_ep_576__loss_154.33980_.h5'

loss_fun = custom_mse
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)

train_loss = tf.keras.metrics.Mean()
train_kl_loss = tf.keras.metrics.Mean()
valid_loss = tf.keras.metrics.Mean()
sisdr_Metric = SiSdr()
val_sisdr_Metric = SiSdr()

# Model 불러오는 부분이다
vq_vae = Vq_vae(latent_size, gumbel_hard=False)
vq_vae(0, True)
# vq_vae.load_weights(model_path)

<KerasTensor: shape=(None, None, 1) dtype=float32 (created by layer 'decoder')>

In [117]:
@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        # Call model
        results = vq_vae(x)
        
        loss_value = loss_fun(y, results)
        loss_value += sum(vq_vae.losses) # Add KL loss
    
    # Update weights
    grads = tape.gradient(loss_value, vq_vae.trainable_weights)
    optimizer.apply_gradients(zip(grads, vq_vae.trainable_weights))
    
    # Update loss and si-sdr
    train_loss.update_state(loss_value)
    sisdr_Metric.update_state(y, results)
    
    train_kl_loss.update_state(sum(vq_vae.losses))
    
    return loss_value

@tf.function
def test_step(x, y):
    # Call model
    val_results = vq_vae(x)
    
    val_loss_value = loss_fun(y, val_results)
    val_loss_value += sum(vq_vae.losses) # Add KL loss
    
    # Update loss and si-sdr
    valid_loss.update_state(val_loss_value)
    val_sisdr_Metric.update_state(y, val_results)
    
    return val_loss_value

In [118]:
previous_loss = float('inf')

with tf.device('/cpu:0'):
    for epoch in range(epochs):
        print("\nStart of epoch %d" % (epoch+1,))
        start_time = time.time()

        # Iterate over the batches of the dataset
        for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
            x_batch_train = tf.cast(x_batch_train, dtype=tf.float32)
            y_batch_train = tf.cast(y_batch_train, dtype=tf.float32)

            loss_value = train_step(x_batch_train, y_batch_train)

            # Log every 1 batches
    #         if step % 1 == 0:
    #             print("Training loss (for one batch) at step %d: %.4f" % (step, train_loss.result()))
    #             print("Training Si-sdr (for one batch) at step %d: %.4f" % (step, sisdr_Metric.result()))
    #             print("Seen so far: %d samples" % ((step + 1) * batch_size))

        # Run a validation loop at the end of each epoch
        for x_batch_val, y_batch_val in valid_dataset:
            x_batch_val = tf.cast(x_batch_val, dtype=tf.float32)
            y_batch_val = tf.cast(y_batch_val, dtype=tf.float32)

            val_loss_value = test_step(x_batch_val, y_batch_val)

        print()
        print('----------------------------------------------------------------------------------')
        print("Time taken >>> %.2fs <<<" % (time.time() - start_time))
        print('epoch: {}, Train_loss: {}, Train_Si-sdr: {}, Train_KL_loss: {} \n\
        Valid_loss: {}, Valid_Si-sdr: {}'.format(
            epoch+1,
            train_loss.result(),
            sisdr_Metric.result(),
            train_kl_loss.result(),
            valid_loss.result(),
            val_sisdr_Metric.result()))
        print('----------------------------------------------------------------------------------')

        # Save Model
        if valid_loss.result() < previous_loss:
            filePath_temp = filePath.format(epoch+1, valid_loss.result())

            vq_vae.save_weights(filePath_temp)
            print('Epoch {}: val_loss improved from {} to {}, saving model to {}'.format(
                epoch+1,
                previous_loss,
                valid_loss.result(),
                filePath_temp))

            previous_loss = valid_loss.result()
        else:
            print('Epoch {}: val_loss did not improve from {}'.format(
                epoch+1,
                previous_loss))
        print()

        # Reset metrics at the end of each epoch
        train_loss.reset_states()
        sisdr_Metric.reset_states()
        valid_loss.reset_states()
        val_sisdr_Metric.reset_states()

        train_kl_loss.reset_states()

        # Data shuffle at the end of each epoch
        train_dataset.on_epoch_end()
        valid_dataset.on_epoch_end()


Start of epoch 1

----------------------------------------------------------------------------------
Time taken >>> 445.02s <<<
epoch: 1, Train_loss: 52.05436325073242, Train_Si-sdr: -53.79879379272461, Train_KL_loss: 3.934980486519635e-05 
        Valid_loss: 46.281185150146484, Valid_Si-sdr: -62.87128829956055
----------------------------------------------------------------------------------
Epoch 1: val_loss improved from inf to 46.281185150146484, saving model to ./CKPT/CKP_ep_1__loss_46.28119_.h5


Start of epoch 2

----------------------------------------------------------------------------------
Time taken >>> 449.51s <<<
epoch: 2, Train_loss: 52.04542541503906, Train_Si-sdr: -57.11431884765625, Train_KL_loss: 4.152093970333226e-05 
        Valid_loss: 46.36128616333008, Valid_Si-sdr: -67.22840118408203
----------------------------------------------------------------------------------
Epoch 2: val_loss did not improve from 46.281185150146484


Start of epoch 3

----------------


----------------------------------------------------------------------------------
Time taken >>> 447.62s <<<
epoch: 16, Train_loss: 52.027645111083984, Train_Si-sdr: -58.0009651184082, Train_KL_loss: 3.2182146242121235e-05 
        Valid_loss: 46.28401184082031, Valid_Si-sdr: -67.84493255615234
----------------------------------------------------------------------------------
Epoch 16: val_loss did not improve from 45.8974609375


Start of epoch 17

----------------------------------------------------------------------------------
Time taken >>> 444.19s <<<
epoch: 17, Train_loss: 51.95405578613281, Train_Si-sdr: -58.16285705566406, Train_KL_loss: 3.1641167879570276e-05 
        Valid_loss: 46.10024642944336, Valid_Si-sdr: -68.47142791748047
----------------------------------------------------------------------------------
Epoch 17: val_loss did not improve from 45.8974609375


Start of epoch 18

----------------------------------------------------------------------------------
Time t


----------------------------------------------------------------------------------
Time taken >>> 447.45s <<<
epoch: 35, Train_loss: 52.01133728027344, Train_Si-sdr: -58.18434143066406, Train_KL_loss: 3.724927228176966e-05 
        Valid_loss: 46.39446258544922, Valid_Si-sdr: -67.89771270751953
----------------------------------------------------------------------------------
Epoch 35: val_loss did not improve from 45.8974609375


Start of epoch 36

----------------------------------------------------------------------------------
Time taken >>> 442.65s <<<
epoch: 36, Train_loss: 52.03670883178711, Train_Si-sdr: -58.275001525878906, Train_KL_loss: 4.359529339126311e-05 
        Valid_loss: 46.25807189941406, Valid_Si-sdr: -67.94987487792969
----------------------------------------------------------------------------------
Epoch 36: val_loss did not improve from 45.8974609375


Start of epoch 37

----------------------------------------------------------------------------------
Time ta


----------------------------------------------------------------------------------
Time taken >>> 440.31s <<<
epoch: 53, Train_loss: 51.979034423828125, Train_Si-sdr: -57.99827194213867, Train_KL_loss: 5.172163582756184e-06 
        Valid_loss: 46.27432632446289, Valid_Si-sdr: -67.8656005859375
----------------------------------------------------------------------------------
Epoch 53: val_loss did not improve from 45.645652770996094


Start of epoch 54

----------------------------------------------------------------------------------
Time taken >>> 442.52s <<<
epoch: 54, Train_loss: 52.04646301269531, Train_Si-sdr: -58.02843475341797, Train_KL_loss: 2.9078248189762235e-05 
        Valid_loss: 46.41225051879883, Valid_Si-sdr: -68.13856506347656
----------------------------------------------------------------------------------
Epoch 54: val_loss did not improve from 45.645652770996094


Start of epoch 55

--------------------------------------------------------------------------------


----------------------------------------------------------------------------------
Time taken >>> 442.28s <<<
epoch: 71, Train_loss: 52.03144454956055, Train_Si-sdr: -58.058204650878906, Train_KL_loss: 3.9938946429174393e-05 
        Valid_loss: 46.344242095947266, Valid_Si-sdr: -67.82205963134766
----------------------------------------------------------------------------------
Epoch 71: val_loss did not improve from 45.645652770996094


Start of epoch 72

----------------------------------------------------------------------------------
Time taken >>> 442.71s <<<
epoch: 72, Train_loss: 52.03883361816406, Train_Si-sdr: -58.088321685791016, Train_KL_loss: 6.711843161610886e-05 
        Valid_loss: 46.604393005371094, Valid_Si-sdr: -68.04227447509766
----------------------------------------------------------------------------------
Epoch 72: val_loss did not improve from 45.645652770996094


Start of epoch 73

----------------------------------------------------------------------------


----------------------------------------------------------------------------------
Time taken >>> 438.82s <<<
epoch: 89, Train_loss: 52.03865432739258, Train_Si-sdr: -57.99946975708008, Train_KL_loss: 5.7402434322284535e-05 
        Valid_loss: 46.45106506347656, Valid_Si-sdr: -68.03071594238281
----------------------------------------------------------------------------------
Epoch 89: val_loss did not improve from 45.54467010498047


Start of epoch 90

----------------------------------------------------------------------------------
Time taken >>> 438.19s <<<
epoch: 90, Train_loss: 52.02333068847656, Train_Si-sdr: -57.986392974853516, Train_KL_loss: 2.6425848773214966e-05 
        Valid_loss: 46.248470306396484, Valid_Si-sdr: -68.1055908203125
----------------------------------------------------------------------------------
Epoch 90: val_loss did not improve from 45.54467010498047


Start of epoch 91

--------------------------------------------------------------------------------


----------------------------------------------------------------------------------
Time taken >>> 446.03s <<<
epoch: 107, Train_loss: 52.03288269042969, Train_Si-sdr: -58.072139739990234, Train_KL_loss: 3.6446443118620664e-05 
        Valid_loss: 46.35171890258789, Valid_Si-sdr: -67.94539642333984
----------------------------------------------------------------------------------
Epoch 107: val_loss did not improve from 45.54467010498047


Start of epoch 108

----------------------------------------------------------------------------------
Time taken >>> 444.65s <<<
epoch: 108, Train_loss: 52.04637145996094, Train_Si-sdr: -58.07156753540039, Train_KL_loss: 0.00010220426338491961 
        Valid_loss: 46.460105895996094, Valid_Si-sdr: -67.9457015991211
----------------------------------------------------------------------------------
Epoch 108: val_loss did not improve from 45.54467010498047


Start of epoch 109

--------------------------------------------------------------------------


----------------------------------------------------------------------------------
Time taken >>> 446.18s <<<
epoch: 125, Train_loss: 52.0529670715332, Train_Si-sdr: -58.0317268371582, Train_KL_loss: 2.2677404558635317e-05 
        Valid_loss: 46.109195709228516, Valid_Si-sdr: -68.10261535644531
----------------------------------------------------------------------------------
Epoch 125: val_loss did not improve from 45.54467010498047


Start of epoch 126

----------------------------------------------------------------------------------
Time taken >>> 451.50s <<<
epoch: 126, Train_loss: 52.00446701049805, Train_Si-sdr: -58.11899185180664, Train_KL_loss: -0.0003559110627975315 
        Valid_loss: 46.531429290771484, Valid_Si-sdr: -67.88389587402344
----------------------------------------------------------------------------------
Epoch 126: val_loss did not improve from 45.54467010498047


Start of epoch 127

---------------------------------------------------------------------------


----------------------------------------------------------------------------------
Time taken >>> 448.32s <<<
epoch: 143, Train_loss: 51.96795654296875, Train_Si-sdr: -58.02313995361328, Train_KL_loss: 8.15163875813596e-05 
        Valid_loss: 46.59821319580078, Valid_Si-sdr: -67.9674301147461
----------------------------------------------------------------------------------
Epoch 143: val_loss did not improve from 45.54467010498047


Start of epoch 144

----------------------------------------------------------------------------------
Time taken >>> 452.87s <<<
epoch: 144, Train_loss: 52.0551643371582, Train_Si-sdr: -58.036136627197266, Train_KL_loss: 6.604811642318964e-05 
        Valid_loss: 46.30998229980469, Valid_Si-sdr: -68.03687286376953
----------------------------------------------------------------------------------
Epoch 144: val_loss did not improve from 45.54467010498047


Start of epoch 145

-------------------------------------------------------------------------------


----------------------------------------------------------------------------------
Time taken >>> 461.02s <<<
epoch: 161, Train_loss: 51.97719192504883, Train_Si-sdr: -58.09435272216797, Train_KL_loss: 3.985613875556737e-05 
        Valid_loss: 46.39007568359375, Valid_Si-sdr: -67.98092651367188
----------------------------------------------------------------------------------
Epoch 161: val_loss did not improve from 45.54467010498047


Start of epoch 162

----------------------------------------------------------------------------------
Time taken >>> 457.88s <<<
epoch: 162, Train_loss: 52.01642990112305, Train_Si-sdr: -58.054019927978516, Train_KL_loss: 4.792742038262077e-05 
        Valid_loss: 46.56180953979492, Valid_Si-sdr: -67.9769287109375
----------------------------------------------------------------------------------
Epoch 162: val_loss did not improve from 45.54467010498047


Start of epoch 163

-----------------------------------------------------------------------------


----------------------------------------------------------------------------------
Time taken >>> 469.75s <<<
epoch: 179, Train_loss: 52.03944396972656, Train_Si-sdr: -57.982112884521484, Train_KL_loss: 5.0406044465489686e-05 
        Valid_loss: 46.689029693603516, Valid_Si-sdr: -68.07566833496094
----------------------------------------------------------------------------------
Epoch 179: val_loss did not improve from 45.54467010498047


Start of epoch 180

----------------------------------------------------------------------------------
Time taken >>> 468.48s <<<
epoch: 180, Train_loss: 52.040069580078125, Train_Si-sdr: -57.995750427246094, Train_KL_loss: 5.1509065087884665e-05 
        Valid_loss: 46.08577346801758, Valid_Si-sdr: -67.79029846191406
----------------------------------------------------------------------------------
Epoch 180: val_loss did not improve from 45.54467010498047


Start of epoch 181

-----------------------------------------------------------------------


----------------------------------------------------------------------------------
Time taken >>> 449.59s <<<
epoch: 197, Train_loss: 52.07259750366211, Train_Si-sdr: -58.03911209106445, Train_KL_loss: 5.6695775128901005e-05 
        Valid_loss: 46.436824798583984, Valid_Si-sdr: -68.11444854736328
----------------------------------------------------------------------------------
Epoch 197: val_loss did not improve from 45.54467010498047


Start of epoch 198

----------------------------------------------------------------------------------
Time taken >>> 445.70s <<<
epoch: 198, Train_loss: 52.04338455200195, Train_Si-sdr: -58.018062591552734, Train_KL_loss: 3.744601781363599e-05 
        Valid_loss: 46.30893325805664, Valid_Si-sdr: -68.11286163330078
----------------------------------------------------------------------------------
Epoch 198: val_loss did not improve from 45.54467010498047


Start of epoch 199

--------------------------------------------------------------------------


----------------------------------------------------------------------------------
Time taken >>> 449.89s <<<
epoch: 215, Train_loss: 52.0468864440918, Train_Si-sdr: -58.06999588012695, Train_KL_loss: 4.344479020801373e-05 
        Valid_loss: 46.39200210571289, Valid_Si-sdr: -68.00041961669922
----------------------------------------------------------------------------------
Epoch 215: val_loss did not improve from 45.54467010498047


Start of epoch 216

----------------------------------------------------------------------------------
Time taken >>> 452.28s <<<
epoch: 216, Train_loss: 52.01406478881836, Train_Si-sdr: -58.06248474121094, Train_KL_loss: 5.7034692872548476e-05 
        Valid_loss: 46.38555908203125, Valid_Si-sdr: -67.997802734375
----------------------------------------------------------------------------------
Epoch 216: val_loss did not improve from 45.54467010498047


Start of epoch 217

-------------------------------------------------------------------------------


----------------------------------------------------------------------------------
Time taken >>> 449.58s <<<
epoch: 233, Train_loss: 52.04537582397461, Train_Si-sdr: -58.058387756347656, Train_KL_loss: 7.408401143038645e-05 
        Valid_loss: 46.18297576904297, Valid_Si-sdr: -67.99649047851562
----------------------------------------------------------------------------------
Epoch 233: val_loss did not improve from 45.54467010498047


Start of epoch 234

----------------------------------------------------------------------------------
Time taken >>> 457.02s <<<
epoch: 234, Train_loss: 52.04290771484375, Train_Si-sdr: -58.145118713378906, Train_KL_loss: 5.478636012412608e-05 
        Valid_loss: 46.540504455566406, Valid_Si-sdr: -67.967529296875
----------------------------------------------------------------------------------
Epoch 234: val_loss did not improve from 45.54467010498047


Start of epoch 235

----------------------------------------------------------------------------


----------------------------------------------------------------------------------
Time taken >>> 445.95s <<<
epoch: 251, Train_loss: 51.99032974243164, Train_Si-sdr: -57.98505783081055, Train_KL_loss: 6.59824872855097e-05 
        Valid_loss: 46.406044006347656, Valid_Si-sdr: -68.14048767089844
----------------------------------------------------------------------------------
Epoch 251: val_loss did not improve from 45.54467010498047


Start of epoch 252

----------------------------------------------------------------------------------
Time taken >>> 450.45s <<<
epoch: 252, Train_loss: 52.008670806884766, Train_Si-sdr: -58.08256912231445, Train_KL_loss: 5.1738694310188293e-05 
        Valid_loss: 46.367679595947266, Valid_Si-sdr: -67.82405090332031
----------------------------------------------------------------------------------
Epoch 252: val_loss did not improve from 45.54467010498047


Start of epoch 253

--------------------------------------------------------------------------


----------------------------------------------------------------------------------
Time taken >>> 455.48s <<<
epoch: 269, Train_loss: 52.045772552490234, Train_Si-sdr: -58.03475570678711, Train_KL_loss: -7.5348107202444226e-06 
        Valid_loss: 46.183738708496094, Valid_Si-sdr: -68.23616790771484
----------------------------------------------------------------------------------
Epoch 269: val_loss did not improve from 45.54467010498047


Start of epoch 270

----------------------------------------------------------------------------------
Time taken >>> 453.14s <<<
epoch: 270, Train_loss: 52.03837966918945, Train_Si-sdr: -58.08070373535156, Train_KL_loss: 3.9687627577222884e-05 
        Valid_loss: 46.31733703613281, Valid_Si-sdr: -67.89303588867188
----------------------------------------------------------------------------------
Epoch 270: val_loss did not improve from 45.54467010498047


Start of epoch 271

------------------------------------------------------------------------


----------------------------------------------------------------------------------
Time taken >>> 450.29s <<<
epoch: 287, Train_loss: 52.0472297668457, Train_Si-sdr: -58.06001663208008, Train_KL_loss: 2.1503779862541705e-05 
        Valid_loss: 46.03738021850586, Valid_Si-sdr: -68.02552795410156
----------------------------------------------------------------------------------
Epoch 287: val_loss did not improve from 45.54467010498047


Start of epoch 288

----------------------------------------------------------------------------------
Time taken >>> 446.82s <<<
epoch: 288, Train_loss: 52.05072021484375, Train_Si-sdr: -58.01803970336914, Train_KL_loss: 3.049212864425499e-05 
        Valid_loss: 46.17576599121094, Valid_Si-sdr: -68.07018280029297
----------------------------------------------------------------------------------
Epoch 288: val_loss did not improve from 45.54467010498047


Start of epoch 289

-----------------------------------------------------------------------------


----------------------------------------------------------------------------------
Time taken >>> 438.82s <<<
epoch: 305, Train_loss: 52.023563385009766, Train_Si-sdr: -58.000465393066406, Train_KL_loss: 5.5928398069227114e-05 
        Valid_loss: 46.617794036865234, Valid_Si-sdr: -67.734130859375
----------------------------------------------------------------------------------
Epoch 305: val_loss did not improve from 45.54467010498047


Start of epoch 306

----------------------------------------------------------------------------------
Time taken >>> 450.10s <<<
epoch: 306, Train_loss: 52.05722427368164, Train_Si-sdr: -58.07223129272461, Train_KL_loss: 2.237044827779755e-05 
        Valid_loss: 46.52487564086914, Valid_Si-sdr: -68.02803039550781
----------------------------------------------------------------------------------
Epoch 306: val_loss did not improve from 45.54467010498047


Start of epoch 307

---------------------------------------------------------------------------


----------------------------------------------------------------------------------
Time taken >>> 453.31s <<<
epoch: 323, Train_loss: 52.053707122802734, Train_Si-sdr: -58.0736083984375, Train_KL_loss: 8.440464625891764e-06 
        Valid_loss: 46.33705139160156, Valid_Si-sdr: -68.12684631347656
----------------------------------------------------------------------------------
Epoch 323: val_loss did not improve from 45.54467010498047


Start of epoch 324

----------------------------------------------------------------------------------
Time taken >>> 449.83s <<<
epoch: 324, Train_loss: 52.061275482177734, Train_Si-sdr: -58.05030822753906, Train_KL_loss: 0.00010609376477077603 
        Valid_loss: 45.91270065307617, Valid_Si-sdr: -67.98577880859375
----------------------------------------------------------------------------------
Epoch 324: val_loss did not improve from 45.54467010498047


Start of epoch 325

---------------------------------------------------------------------------


----------------------------------------------------------------------------------
Time taken >>> 449.05s <<<
epoch: 341, Train_loss: 52.02431106567383, Train_Si-sdr: -58.063358306884766, Train_KL_loss: 4.7287598135881126e-05 
        Valid_loss: 46.52494430541992, Valid_Si-sdr: -67.83160400390625
----------------------------------------------------------------------------------
Epoch 341: val_loss did not improve from 45.54467010498047


Start of epoch 342

----------------------------------------------------------------------------------
Time taken >>> 448.49s <<<
epoch: 342, Train_loss: 51.9880485534668, Train_Si-sdr: -58.064762115478516, Train_KL_loss: 1.375209558318602e-05 
        Valid_loss: 45.975379943847656, Valid_Si-sdr: -67.8877182006836
----------------------------------------------------------------------------------
Epoch 342: val_loss did not improve from 45.54467010498047


Start of epoch 343

---------------------------------------------------------------------------


----------------------------------------------------------------------------------
Time taken >>> 443.81s <<<
epoch: 359, Train_loss: 52.073394775390625, Train_Si-sdr: -57.98613357543945, Train_KL_loss: 1.4459808880928904e-05 
        Valid_loss: 46.45198440551758, Valid_Si-sdr: -67.84223937988281
----------------------------------------------------------------------------------
Epoch 359: val_loss did not improve from 45.54467010498047


Start of epoch 360

----------------------------------------------------------------------------------
Time taken >>> 448.51s <<<
epoch: 360, Train_loss: 52.07356643676758, Train_Si-sdr: -57.999290466308594, Train_KL_loss: 1.0337548701500054e-05 
        Valid_loss: 46.570556640625, Valid_Si-sdr: -68.09197235107422
----------------------------------------------------------------------------------
Epoch 360: val_loss did not improve from 45.54467010498047


Start of epoch 361

---------------------------------------------------------------------------


----------------------------------------------------------------------------------
Time taken >>> 450.18s <<<
epoch: 377, Train_loss: 52.01921463012695, Train_Si-sdr: -58.04930114746094, Train_KL_loss: 8.630430784251075e-06 
        Valid_loss: 46.09756088256836, Valid_Si-sdr: -67.93277740478516
----------------------------------------------------------------------------------
Epoch 377: val_loss did not improve from 45.54467010498047


Start of epoch 378

----------------------------------------------------------------------------------
Time taken >>> 447.60s <<<
epoch: 378, Train_loss: 52.0498046875, Train_Si-sdr: -58.060733795166016, Train_KL_loss: 3.0194962164387107e-05 
        Valid_loss: 46.41958236694336, Valid_Si-sdr: -67.9539566040039
----------------------------------------------------------------------------------
Epoch 378: val_loss did not improve from 45.54467010498047


Start of epoch 379

--------------------------------------------------------------------------------


----------------------------------------------------------------------------------
Time taken >>> 444.16s <<<
epoch: 395, Train_loss: 52.04122543334961, Train_Si-sdr: -58.03474044799805, Train_KL_loss: 2.666846376087051e-05 
        Valid_loss: 46.478004455566406, Valid_Si-sdr: -68.07744598388672
----------------------------------------------------------------------------------
Epoch 395: val_loss did not improve from 45.54467010498047


Start of epoch 396

----------------------------------------------------------------------------------
Time taken >>> 444.90s <<<
epoch: 396, Train_loss: 52.00115966796875, Train_Si-sdr: -58.01897430419922, Train_KL_loss: 7.268571062013507e-05 
        Valid_loss: 46.45310592651367, Valid_Si-sdr: -67.99752807617188
----------------------------------------------------------------------------------
Epoch 396: val_loss did not improve from 45.54467010498047


Start of epoch 397

----------------------------------------------------------------------------


----------------------------------------------------------------------------------
Time taken >>> 445.90s <<<
epoch: 413, Train_loss: 51.97610092163086, Train_Si-sdr: -58.007137298583984, Train_KL_loss: 5.816709017381072e-05 
        Valid_loss: 46.414894104003906, Valid_Si-sdr: -67.70803833007812
----------------------------------------------------------------------------------
Epoch 413: val_loss did not improve from 45.54467010498047


Start of epoch 414

----------------------------------------------------------------------------------
Time taken >>> 449.99s <<<
epoch: 414, Train_loss: 52.076080322265625, Train_Si-sdr: -58.04161071777344, Train_KL_loss: 5.540974598261528e-05 
        Valid_loss: 46.31205368041992, Valid_Si-sdr: -68.10523223876953
----------------------------------------------------------------------------------
Epoch 414: val_loss did not improve from 45.54467010498047


Start of epoch 415

--------------------------------------------------------------------------


----------------------------------------------------------------------------------
Time taken >>> 448.00s <<<
epoch: 431, Train_loss: 52.067413330078125, Train_Si-sdr: -58.02652359008789, Train_KL_loss: 3.913342152372934e-05 
        Valid_loss: 46.527427673339844, Valid_Si-sdr: -68.06779479980469
----------------------------------------------------------------------------------
Epoch 431: val_loss did not improve from 45.54467010498047


Start of epoch 432

----------------------------------------------------------------------------------
Time taken >>> 449.64s <<<
epoch: 432, Train_loss: 52.02119064331055, Train_Si-sdr: -58.06525421142578, Train_KL_loss: 1.5521878594881855e-05 
        Valid_loss: 46.66925048828125, Valid_Si-sdr: -67.78691101074219
----------------------------------------------------------------------------------
Epoch 432: val_loss did not improve from 45.54467010498047


Start of epoch 433

--------------------------------------------------------------------------


----------------------------------------------------------------------------------
Time taken >>> 451.93s <<<
epoch: 449, Train_loss: 52.057891845703125, Train_Si-sdr: -58.055419921875, Train_KL_loss: 5.4671905672876164e-05 
        Valid_loss: 46.46221160888672, Valid_Si-sdr: -67.97576141357422
----------------------------------------------------------------------------------
Epoch 449: val_loss did not improve from 45.54467010498047


Start of epoch 450

----------------------------------------------------------------------------------
Time taken >>> 445.87s <<<
epoch: 450, Train_loss: 51.98521423339844, Train_Si-sdr: -58.011573791503906, Train_KL_loss: 2.2217625883058645e-05 
        Valid_loss: 46.36651611328125, Valid_Si-sdr: -67.94044494628906
----------------------------------------------------------------------------------
Epoch 450: val_loss did not improve from 45.54467010498047


Start of epoch 451

---------------------------------------------------------------------------


----------------------------------------------------------------------------------
Time taken >>> 447.83s <<<
epoch: 467, Train_loss: 51.91593933105469, Train_Si-sdr: -58.016944885253906, Train_KL_loss: -1.0392021749794367e-06 
        Valid_loss: 46.33205032348633, Valid_Si-sdr: -67.9543685913086
----------------------------------------------------------------------------------
Epoch 467: val_loss did not improve from 45.54467010498047


Start of epoch 468

----------------------------------------------------------------------------------
Time taken >>> 448.94s <<<
epoch: 468, Train_loss: 51.99626159667969, Train_Si-sdr: -58.09265899658203, Train_KL_loss: 3.6551478842739016e-05 
        Valid_loss: 46.25740051269531, Valid_Si-sdr: -67.90290832519531
----------------------------------------------------------------------------------
Epoch 468: val_loss did not improve from 45.54467010498047


Start of epoch 469

--------------------------------------------------------------------------


----------------------------------------------------------------------------------
Time taken >>> 445.78s <<<
epoch: 485, Train_loss: 52.0442008972168, Train_Si-sdr: -58.030521392822266, Train_KL_loss: 5.405176489148289e-05 
        Valid_loss: 46.51171875, Valid_Si-sdr: -68.17035675048828
----------------------------------------------------------------------------------
Epoch 485: val_loss did not improve from 45.54467010498047


Start of epoch 486

----------------------------------------------------------------------------------
Time taken >>> 444.88s <<<
epoch: 486, Train_loss: 52.00860595703125, Train_Si-sdr: -58.04813003540039, Train_KL_loss: 3.699549051816575e-05 
        Valid_loss: 46.35709762573242, Valid_Si-sdr: -67.86611938476562
----------------------------------------------------------------------------------
Epoch 486: val_loss did not improve from 45.54467010498047


Start of epoch 487

----------------------------------------------------------------------------------



----------------------------------------------------------------------------------
Time taken >>> 446.74s <<<
epoch: 503, Train_loss: 52.04266357421875, Train_Si-sdr: -58.00826644897461, Train_KL_loss: 3.452334203757346e-05 
        Valid_loss: 46.17028045654297, Valid_Si-sdr: -68.02401733398438
----------------------------------------------------------------------------------
Epoch 503: val_loss did not improve from 45.54467010498047


Start of epoch 504

----------------------------------------------------------------------------------
Time taken >>> 449.54s <<<
epoch: 504, Train_loss: 52.00831604003906, Train_Si-sdr: -58.09680938720703, Train_KL_loss: 3.657916749943979e-05 
        Valid_loss: 46.530555725097656, Valid_Si-sdr: -67.89299011230469
----------------------------------------------------------------------------------
Epoch 504: val_loss did not improve from 45.54467010498047


Start of epoch 505

----------------------------------------------------------------------------


----------------------------------------------------------------------------------
Time taken >>> 446.78s <<<
epoch: 521, Train_loss: 52.02486801147461, Train_Si-sdr: -58.01689147949219, Train_KL_loss: 6.485896301455796e-05 
        Valid_loss: 46.364051818847656, Valid_Si-sdr: -67.9418716430664
----------------------------------------------------------------------------------
Epoch 521: val_loss did not improve from 45.54467010498047


Start of epoch 522

----------------------------------------------------------------------------------
Time taken >>> 443.36s <<<
epoch: 522, Train_loss: 52.06379699707031, Train_Si-sdr: -57.983760833740234, Train_KL_loss: 3.56966738763731e-05 
        Valid_loss: 46.588096618652344, Valid_Si-sdr: -68.04586791992188
----------------------------------------------------------------------------------
Epoch 522: val_loss did not improve from 45.54467010498047


Start of epoch 523

----------------------------------------------------------------------------


----------------------------------------------------------------------------------
Time taken >>> 443.89s <<<
epoch: 539, Train_loss: 51.9930305480957, Train_Si-sdr: -57.99472427368164, Train_KL_loss: 4.5364187826635316e-05 
        Valid_loss: 46.36756896972656, Valid_Si-sdr: -68.0090560913086
----------------------------------------------------------------------------------
Epoch 539: val_loss did not improve from 45.54467010498047


Start of epoch 540

----------------------------------------------------------------------------------
Time taken >>> 448.19s <<<
epoch: 540, Train_loss: 52.04765319824219, Train_Si-sdr: -58.035362243652344, Train_KL_loss: 6.86115090502426e-05 
        Valid_loss: 46.02019500732422, Valid_Si-sdr: -68.00927734375
----------------------------------------------------------------------------------
Epoch 540: val_loss did not improve from 45.54467010498047


Start of epoch 541

---------------------------------------------------------------------------------


----------------------------------------------------------------------------------
Time taken >>> 449.75s <<<
epoch: 557, Train_loss: 52.054195404052734, Train_Si-sdr: -58.04725646972656, Train_KL_loss: 3.519913298077881e-05 
        Valid_loss: 46.25051498413086, Valid_Si-sdr: -67.97972106933594
----------------------------------------------------------------------------------
Epoch 557: val_loss did not improve from 45.54467010498047


Start of epoch 558

----------------------------------------------------------------------------------
Time taken >>> 453.16s <<<
epoch: 558, Train_loss: 52.03129577636719, Train_Si-sdr: -58.098812103271484, Train_KL_loss: 5.0155376811744645e-05 
        Valid_loss: 46.36841583251953, Valid_Si-sdr: -68.01551818847656
----------------------------------------------------------------------------------
Epoch 558: val_loss did not improve from 45.54467010498047


Start of epoch 559

--------------------------------------------------------------------------

KeyboardInterrupt: 

# 여기는 기존의 .fit() 함수를 사용해서 학습하는 부분임

In [140]:
tf.random.set_seed(42)

latent_size = 1024
epoch = 300

strategy = tf.distribute.MirroredStrategy(['cpu:0'])
print('장치의 수: {}'.format(strategy.num_replicas_in_sync))

with strategy.scope():
    model_path = './CKPT/CKP_ep_252__loss_45.39530_.h5'
    
    loss_fun = custom_mse
#     loss_fun = custom_sisdr_loss
    
    vq_vae = Vq_vae(latent_size, gumbel_hard=False)

    optimizer = keras.optimizers.Adam(learning_rate=1e-4)
    vq_vae.compile(optimizer, loss=loss_fun, metrics=[SiSdr()])
    
    vq_vae(0, True)
    vq_vae.summary()
    
    # 사용 안할 때는 load_model 주석 처리 하자
    vq_vae.load_weights(model_path)
    # ----------------------------------------
    
    tf.executing_eagerly()

history = vq_vae.fit(
    train_dataset,
    epochs=epoch,
    validation_data=valid_dataset,
    shuffle=True,
    callbacks=[checkpoint_cb],
)

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)
장치의 수: 1
Model: "vqvae"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
softmax_30 (Softmax)         (None, None, 1024)        0         
_________________________________________________________________
encoder (Encoder)            (None, None, 1024)        739232    
_________________________________________________________________
embedding_30 (Embedding)     multiple                  0 (unused)
_________________________________________________________________
decoder (Decoder)            (None, None, 1)           1524641   
_________________________________________________________________
gumbel_softmax (GumbelSoftma (None, None, 1024)        0         
Total params: 2,263,873
Trainable params: 2,263,873
Non-trainable params: 0
_________________________________________________________________
Epoch 


Epoch 00036: val_loss did not improve from 50.88303
Epoch 37/300

Epoch 00037: val_loss did not improve from 50.88303
Epoch 38/300

Epoch 00038: val_loss did not improve from 50.88303
Epoch 39/300

Epoch 00039: val_loss did not improve from 50.88303
Epoch 40/300

Epoch 00040: val_loss did not improve from 50.88303
Epoch 41/300

Epoch 00041: val_loss did not improve from 50.88303
Epoch 42/300

Epoch 00042: val_loss did not improve from 50.88303
Epoch 43/300

Epoch 00043: val_loss did not improve from 50.88303
Epoch 44/300

Epoch 00044: val_loss did not improve from 50.88303
Epoch 45/300

Epoch 00045: val_loss did not improve from 50.88303
Epoch 46/300

Epoch 00046: val_loss did not improve from 50.88303
Epoch 47/300

Epoch 00047: val_loss did not improve from 50.88303
Epoch 48/300

Epoch 00048: val_loss did not improve from 50.88303
Epoch 49/300

Epoch 00049: val_loss did not improve from 50.88303
Epoch 50/300

Epoch 00050: val_loss did not improve from 50.88303
Epoch 51/300

Epoch 000


Epoch 00078: val_loss did not improve from 50.88303
Epoch 79/300

Epoch 00079: val_loss did not improve from 50.88303
Epoch 80/300

Epoch 00080: val_loss did not improve from 50.88303
Epoch 81/300

Epoch 00081: val_loss did not improve from 50.88303
Epoch 82/300

Epoch 00082: val_loss did not improve from 50.88303
Epoch 83/300

Epoch 00083: val_loss did not improve from 50.88303
Epoch 84/300

Epoch 00084: val_loss did not improve from 50.88303
Epoch 85/300

Epoch 00085: val_loss did not improve from 50.88303
Epoch 86/300

Epoch 00086: val_loss did not improve from 50.88303
Epoch 87/300

Epoch 00087: val_loss did not improve from 50.88303
Epoch 88/300

Epoch 00088: val_loss did not improve from 50.88303
Epoch 89/300

Epoch 00089: val_loss did not improve from 50.88303
Epoch 90/300

Epoch 00090: val_loss did not improve from 50.88303
Epoch 91/300

Epoch 00091: val_loss did not improve from 50.88303
Epoch 92/300

Epoch 00092: val_loss did not improve from 50.88303
Epoch 93/300

Epoch 000


Epoch 00119: val_loss did not improve from 50.88303
Epoch 120/300

Epoch 00120: val_loss did not improve from 50.88303
Epoch 121/300

Epoch 00121: val_loss did not improve from 50.88303
Epoch 122/300

Epoch 00122: val_loss did not improve from 50.88303
Epoch 123/300

Epoch 00123: val_loss did not improve from 50.88303
Epoch 124/300

Epoch 00124: val_loss did not improve from 50.88303
Epoch 125/300

Epoch 00125: val_loss did not improve from 50.88303
Epoch 126/300

Epoch 00126: val_loss did not improve from 50.88303
Epoch 127/300

Epoch 00127: val_loss did not improve from 50.88303
Epoch 128/300

Epoch 00128: val_loss did not improve from 50.88303
Epoch 129/300

Epoch 00129: val_loss did not improve from 50.88303
Epoch 130/300

Epoch 00130: val_loss did not improve from 50.88303
Epoch 131/300

Epoch 00131: val_loss did not improve from 50.88303
Epoch 132/300

Epoch 00132: val_loss did not improve from 50.88303
Epoch 133/300

Epoch 00133: val_loss did not improve from 50.88303
Epoch 134


Epoch 00202: val_loss did not improve from 50.88303
Epoch 203/300

Epoch 00203: val_loss did not improve from 50.88303
Epoch 204/300

Epoch 00204: val_loss did not improve from 50.88303
Epoch 205/300

Epoch 00205: val_loss did not improve from 50.88303
Epoch 206/300

Epoch 00206: val_loss did not improve from 50.88303
Epoch 207/300

Epoch 00207: val_loss did not improve from 50.88303
Epoch 208/300

Epoch 00208: val_loss did not improve from 50.88303
Epoch 209/300

Epoch 00209: val_loss did not improve from 50.88303
Epoch 210/300

Epoch 00210: val_loss did not improve from 50.88303
Epoch 211/300

Epoch 00211: val_loss did not improve from 50.88303
Epoch 212/300

Epoch 00212: val_loss did not improve from 50.88303
Epoch 213/300

Epoch 00213: val_loss did not improve from 50.88303
Epoch 214/300

Epoch 00214: val_loss did not improve from 50.88303
Epoch 215/300

Epoch 00215: val_loss did not improve from 50.88303
Epoch 216/300

Epoch 00216: val_loss did not improve from 50.88303
Epoch 217


Epoch 00243: val_loss did not improve from 50.88303
Epoch 244/300

Epoch 00244: val_loss did not improve from 50.88303
Epoch 245/300

Epoch 00245: val_loss did not improve from 50.88303
Epoch 246/300

Epoch 00246: val_loss did not improve from 50.88303
Epoch 247/300

Epoch 00247: val_loss did not improve from 50.88303
Epoch 248/300

Epoch 00248: val_loss did not improve from 50.88303
Epoch 249/300

Epoch 00249: val_loss did not improve from 50.88303
Epoch 250/300

Epoch 00250: val_loss did not improve from 50.88303
Epoch 251/300

Epoch 00251: val_loss did not improve from 50.88303
Epoch 252/300

Epoch 00252: val_loss did not improve from 50.88303
Epoch 253/300

Epoch 00253: val_loss did not improve from 50.88303
Epoch 254/300

Epoch 00254: val_loss did not improve from 50.88303
Epoch 255/300

Epoch 00255: val_loss did not improve from 50.88303
Epoch 256/300

Epoch 00256: val_loss did not improve from 50.88303
Epoch 257/300

Epoch 00257: val_loss did not improve from 50.88303
Epoch 258


Epoch 00285: val_loss did not improve from 50.88303
Epoch 286/300

Epoch 00286: val_loss did not improve from 50.88303
Epoch 287/300

Epoch 00287: val_loss did not improve from 50.88303
Epoch 288/300

Epoch 00288: val_loss did not improve from 50.88303
Epoch 289/300

Epoch 00289: val_loss did not improve from 50.88303
Epoch 290/300

Epoch 00290: val_loss did not improve from 50.88303
Epoch 291/300

Epoch 00291: val_loss did not improve from 50.88303
Epoch 292/300

Epoch 00292: val_loss did not improve from 50.88303
Epoch 293/300

Epoch 00293: val_loss did not improve from 50.88303
Epoch 294/300

Epoch 00294: val_loss did not improve from 50.88303
Epoch 295/300

Epoch 00295: val_loss did not improve from 50.88303
Epoch 296/300

Epoch 00296: val_loss did not improve from 50.88303
Epoch 297/300

Epoch 00297: val_loss did not improve from 50.88303
Epoch 298/300

Epoch 00298: val_loss did not improve from 50.88303
Epoch 299/300

Epoch 00299: val_loss did not improve from 50.88303
Epoch 300

## 2.2. Encoder 부르는 방법, Decoder에 값 넣는 방법

In [50]:
latent_size = 512
epoch = 200
BATCH_SIZE = 2

strategy = tf.distribute.MirroredStrategy(['cpu:0'])
print('장치의 수: {}'.format(strategy.num_replicas_in_sync))

with strategy.scope():
    model_path = './CKPT/CKP_ep_283__loss_141.77045_.h5'
    
    vq_vae = Vq_vae(latent_size, gumbel_hard=False)
    vq_vae(0, True)
    vq_vae.summary()
    
    vq_vae.load_weights(model_path)
    
    # 이렇게 하면, transforer의 input으로 들어가는 one-hot 형식의 값을 얻을 수 있음
    for inputs, label in train_dataset:
        encode = vq_vae.encoder(inputs).numpy()
        encode_onehot = tf.cast(tf.equal(encode, tf.math.reduce_max(encode, 2, keepdims=True)), encode.dtype)
    
    # 이렇게 하면, transformer의 output을 vq-vae의 decoder 입력으로 넣을 수 있음
    for inputs, label in train_dataset:
        encode = vq_vae.encoder(inputs).numpy()
        encode_onehot = tf.cast(tf.equal(encode, tf.math.reduce_max(encode, 2, keepdims=True)), encode.dtype)
        
        # 이렇게 이전 layer의 출렫을 넣으면 됨
        decode = vq_vae.decoder(encode_onehot).numpy()

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)
장치의 수: 1
Model: "vqvae"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
softmax_25 (Softmax)         (None, None, 512)         0         
_________________________________________________________________
encoder (Encoder)            (None, None, 512)         517248    
_________________________________________________________________
decoder (Decoder)            (None, None, 1)           516737    
_________________________________________________________________
gumbel_softmax (GumbelSoftma (None, None, 512)         0         
Total params: 1,033,985
Trainable params: 1,033,985
Non-trainable params: 0
_________________________________________________________________


# 3. Test Model

In [110]:
def mkdir_p(path):
    """ Creates a path recursively without throwing an error if it already exists
    :param path: path to create
    :return: None
    """
    if not os.path.exists(path):
        os.makedirs(path)

In [111]:
mkdir_p('./test_wav/') # Result wav 폴더 만드는 코드

In [112]:
def audiowrite(data, path, samplerate=16000, normalize=False, threaded=True):
    """ Write the audio data ``data`` to the wav file ``path``
    The file can be written in a threaded mode. In this case, the writing
    process will be started at a separate thread. Consequently, the file will
    not be written when this function exits.
    :param data: A numpy array with the audio data
    :param path: The wav file the data should be written to
    :param samplerate: Samplerate of the audio data
    :param normalize: Normalize the audio first so that the values are within
        the range of [INTMIN, INTMAX]. E.g. no clipping occurs
    :param threaded: If true, the write process will be started as a separate
        thread
    :return: The number of clipped samples
    """
    data = data.copy()
    int16_max = np.iinfo(np.int16).max
    int16_min = np.iinfo(np.int16).min

    if normalize:
        if not data.dtype.kind == 'f':
            data = data.astype(np.float)
        data /= np.max(np.abs(data))

    if data.dtype.kind == 'f':
        data *= int16_max

    sample_to_clip = np.sum(data > int16_max)
    if sample_to_clip > 0:
        print('Warning, clipping {} samples'.format(sample_to_clip))
    data = np.clip(data, int16_min, int16_max)
    data = data.astype(np.int16)

    if threaded:
        threading.Thread(target=wav_write, args=(path, samplerate, data)).start()
    else:
        wav_write(path, samplerate, data)

    return sample_to_clip

In [115]:
with tf.device('/cpu:0'):
    latent_size = 1024
    sample_rate = 8000
    model_path = './CKPT/CKP_ep_299__loss_70.91695_.h5'
    
    vq_vae = Vq_vae(latent_size, gumbel_hard=True)
    vq_vae(0, True)
    vq_vae.summary()
    vq_vae.load_weights(model_path)

    for batch in test_dataset:
        input_batch, length_batch, name = batch

        result = vq_vae.predict(input_batch)
        
        wav_name = './test_wav/' + name[0][:-5] + '_s1.wav'
        audiowrite(result[0], wav_name, sample_rate, True, True)

Model: "vqvae"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
softmax_20 (Softmax)         (None, None, 1024)        0         
_________________________________________________________________
encoder (Encoder)            (None, None, 1024)        739232    
_________________________________________________________________
embedding_20 (Embedding)     multiple                  0 (unused)
_________________________________________________________________
decoder (Decoder)            (None, None, 1)           1524641   
_________________________________________________________________
gumbel_softmax (GumbelSoftma (None, None, 1024)        0         
Total params: 2,263,873
Trainable params: 2,263,873
Non-trainable params: 0
_________________________________________________________________


# 여기 밑에는 연습장임

In [75]:
import numpy as np
from keras.models import Sequential
from keras.layers import Embedding

model = Sequential()
model.add(layers.Conv1D(filters=64, kernel_size=3, activation='relu', padding='same', input_shape=(None, 1)))
model.add(layers.Conv1D(filters=3, kernel_size=3, padding='same'))

input_array = np.random.randn(2, 3, 1)
input_array2 = np.random.randn(2, 9, 1)
with tf.device('/cpu:0'):
    model.compile('rmsprop', 'mse')

    output_array = model.predict(input_array)
    output_array2 = model.predict(input_array2)

In [76]:
tf.reduce_sum(output_array, axis=[1, 2])

<tf.Tensor: shape=(2,), dtype=float32, numpy=array([0.26833484, 0.08281735], dtype=float32)>

In [43]:
print(output_array)
print(output_array.shape)
print(output_array2.shape)

[[[-0.2349534  -0.01047383 -0.0306325 ]
  [-0.10266599 -0.07345683 -0.0540017 ]
  [-0.09509992 -0.09991434  0.05194581]]

 [[-0.1935267  -0.18672779  0.12265931]
  [-0.18398881 -0.0192      0.01416024]
  [-0.01810659  0.06382702 -0.02847608]]]
(2, 3, 3)
(2, 9, 3)


In [74]:
# dist = tfp.distributions.Categorical(logits=output_array)
a = np.array([[5.0, 2.0, 1.0], [1.0, 2.0, 3.0]])
dist = tf.compat.v1.distributions.Categorical(logits=a)
sampled = dist.sample()
sampled

<tf.Tensor: shape=(2,), dtype=int32, numpy=array([0, 2])>

In [69]:
one_hot = tf.math.argmax(tf.cast(tf.equal(output_array, tf.math.reduce_max(output_array, 2, keepdims=True)), output_array.dtype), axis=-1)
print(one_hot.shape)
print(one_hot)
# layers.Embedding(512, 512)(one_hot)

(2, 10)
tf.Tensor(
[[228 249 283 206  20 206 435  32 270  30]
 [428 206  20 244 357 289 324 249 498 134]], shape=(2, 10), dtype=int64)


In [70]:
one_hot = tf.math.argmax(output_array, axis=-1)
print(one_hot.shape)
print(one_hot)

(2, 10)
tf.Tensor(
[[228 249 283 206  20 206 435  32 270  30]
 [428 206  20 244 357 289 324 249 498 134]], shape=(2, 10), dtype=int64)


In [50]:
qy = tf.nn.softmax(output_array)
log_qy = tf.math.log(qy + 1e-10)
log_uniform = qy * (log_qy - tf.math.log(1.0 / 512))
kl_loss = tf.reduce_sum(log_uniform, axis=[1, 2])
kl_loss = tf.reduce_mean(kl_loss)

print(kl_loss)

tf.Tensor(0.012074914, shape=(), dtype=float32)


In [30]:
batch_size = tf.shape(output_array)[0]
array1_size = tf.shape(output_array)[1]
array2_size = tf.shape(output_array2)[1]
feature_size = tf.shape(output_array)[-1]

if array1_size < array2_size:
#     append_size = array1_size - array2_size
#     append_zeros = tf.zeros([batch_size, append_size, feature_size])
#     append_zeros = tf.Variable(initial_value=tf.zeros((batch_size, append_size, feature_size)))
#     output_array2 = tf.concat([output_array2, append_zeros], axis=1)
    output_array2 = tf.slice(output_array2, [0, 0, 0], [-1, array1_size, -1])
elif array1_size > array2_size:
#     append_size = array2_size - array1_size
#     append_zeros = tf.zeros([batch_size, append_size, feature_size])
#     append_zeros = tf.Variable(initial_value=tf.zeros((batch_size, append_size, feature_size)))
#     output_array = tf.concat([output_array, append_zeros], axis=1)
    output_array = tf.slice(output_array, [0, 0, 0], [-1, array2_size, -1])

print(output_array.shape)
print(output_array2.shape)
# output_array0 = output_array[1]
# output_array20 = output_array2[1]
# target = np.sum(output_array20 * output_array0) * output_array0 / np.square(np.linalg.norm(output_array0, ord=2))
# noise = output_array20 - target
# npnp = 10 * np.log10(np.square(np.linalg.norm(target, ord=2)) / np.square(np.linalg.norm(noise, ord=2)))
# print(npnp)

target = tf.linalg.matmul(output_array2, output_array, transpose_a=True) * output_array / tf.expand_dims(tf.experimental.numpy.square(tf.norm(output_array, axis=1)), axis=-1)
noise = output_array2 - target
si_sdr = 10 * tf.experimental.numpy.log10(tf.experimental.numpy.square(tf.norm(target, axis=1)) / tf.experimental.numpy.square(tf.norm(noise, axis=1)))
si_sdr = tf.reduce_mean(si_sdr)
print(si_sdr)

(2, 9, 1)
(2, 9, 1)
tf.Tensor(2.8309882, shape=(), dtype=float32)


In [20]:
tf.cast(tf.equal(output_array, tf.math.reduce_max(output_array, 2, keepdims=True)), output_array.dtype)

<tf.Tensor: shape=(2, 9, 4), dtype=float32, numpy=
array([[[1., 0., 0., 0.],
        [0., 0., 1., 0.],
        [1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 1., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [0., 0., 1., 0.]],

       [[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [1., 0., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]]], dtype=float32)>

In [21]:
output_array

<tf.Tensor: shape=(2, 9, 4), dtype=float32, numpy=
array([[[-0.03009652, -0.03612775, -0.06680483, -0.03670201],
        [-0.04768711, -0.12344762, -0.03924457, -0.11762322],
        [ 0.01808495, -0.16106637, -0.19467078, -0.15282159],
        [-0.0986427 , -0.08625205, -0.12661007, -0.16366175],
        [-0.09758376, -0.08886974, -0.0433558 , -0.19985165],
        [-0.06933096, -0.03154394, -0.13725929, -0.20143284],
        [ 0.03375649,  0.00182091, -0.01022564, -0.35924646],
        [-0.01645333, -0.10466891, -0.13975918, -0.12066491],
        [-0.13588801, -0.08173112, -0.00253745, -0.28615874]],

       [[ 0.04865369, -0.02880372, -0.06414615, -0.07730438],
        [-0.08225074, -0.03192509, -0.06216412, -0.08035193],
        [-0.09515338,  0.04221668,  0.14230826, -0.23082384],
        [-0.00094383,  0.05597762, -0.09290768, -0.08630683],
        [-0.09894791, -0.04727853, -0.01004983, -0.30325216],
        [ 0.01705559, -0.16948727, -0.08829505, -0.16453639],
        [-0.07230

In [67]:
output_softmax = tf.nn.softmax(output_array)
output_softmax

<tf.Tensor: shape=(2, 5, 4), dtype=float32, numpy=
array([[[0.24628553, 0.282701  , 0.2385324 , 0.23248109],
        [0.2298986 , 0.23856457, 0.25392184, 0.27761498],
        [0.2101039 , 0.2444843 , 0.26824066, 0.27717113],
        [0.22202027, 0.29728216, 0.2501223 , 0.23057525],
        [0.24084595, 0.2724257 , 0.25000373, 0.23672463]],

       [[0.23988546, 0.27957046, 0.23915865, 0.24138539],
        [0.24975686, 0.27135593, 0.24479878, 0.23408844],
        [0.24278015, 0.26340333, 0.24202275, 0.25179377],
        [0.23139507, 0.262904  , 0.25970972, 0.2459912 ],
        [0.2601803 , 0.25627998, 0.24524413, 0.23829558]]], dtype=float32)>

In [73]:
output_reshape = tf.reshape(output_softmax, [-1, 4])
output_reshape.shape

TensorShape([10, 4])

In [83]:
tf.reshape(tf.nn.softmax(output_array), [-1, 5, 4])

<tf.Tensor: shape=(2, 5, 4), dtype=float32, numpy=
array([[[0.24628553, 0.282701  , 0.2385324 , 0.23248109],
        [0.2298986 , 0.23856457, 0.25392184, 0.27761498],
        [0.2101039 , 0.2444843 , 0.26824066, 0.27717113],
        [0.22202027, 0.29728216, 0.2501223 , 0.23057525],
        [0.24084595, 0.2724257 , 0.25000373, 0.23672463]],

       [[0.23988546, 0.27957046, 0.23915865, 0.24138539],
        [0.24975686, 0.27135593, 0.24479878, 0.23408844],
        [0.24278015, 0.26340333, 0.24202275, 0.25179377],
        [0.23139507, 0.262904  , 0.25970972, 0.2459912 ],
        [0.2601803 , 0.25627998, 0.24524413, 0.23829558]]], dtype=float32)>

In [76]:
# tf.cast(tf.equal(y, tf.reduce_max(y,1,keep_dims=True)), y.dtype)
output_hard = tf.cast(tf.equal(output_reshape, tf.math.reduce_max(output_reshape, 1, keepdims=True)), output_softmax.dtype)
output_hard

<tf.Tensor: shape=(10, 4), dtype=float32, numpy=
array([[0., 1., 0., 0.],
       [0., 0., 0., 1.],
       [0., 0., 0., 1.],
       [0., 1., 0., 0.],
       [0., 1., 0., 0.],
       [0., 1., 0., 0.],
       [0., 1., 0., 0.],
       [0., 1., 0., 0.],
       [0., 1., 0., 0.],
       [1., 0., 0., 0.]], dtype=float32)>

In [81]:
tf.reshape(output_hard, [-1, 5, 4])

<tf.Tensor: shape=(2, 5, 4), dtype=float32, numpy=
array([[[0., 1., 0., 0.],
        [0., 0., 0., 1.],
        [0., 0., 0., 1.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.]],

       [[0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [1., 0., 0., 0.]]], dtype=float32)>

In [322]:
tf.transpose(output_array, perm=[0, 2, 1]).shape

TensorShape([2, 4, 5])

In [316]:
layers.Softmax(output_array)

<tensorflow.python.keras.layers.advanced_activations.Softmax at 0x26f14e934c8>

In [317]:
np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).shape

(3, 3)

In [65]:
np.log(10)

2.302585092994046