# 1. Data Generator
- Raw Data를 읽어옴
- 여기서 만들어진 데이터는 모델의 입력으로 들어감

In [1]:
import os
import numpy as np
import librosa
from tensorflow.keras.utils import Sequence

In [2]:
class RawForVAEGenerator(Sequence):
    def __init__(self, source, wav_dir, files, sourNum='s1', batch_size=10, shuffle=True):
        self.source = source
        self.wav_dir = wav_dir
        self.files = files
        self.sourNum = sourNum
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.on_epoch_end()
        
        self.sample_rate = 8000
    
    def on_epoch_end(self):
        self.indexes = np.arange(len(self.source))
        
        if self.shuffle:
            np.random.shuffle(self.indexes)
    
    def __audioread__(self, path, offset=0.0, duration=None, sample_rate=16000):
        signal = librosa.load(path, sr=self.sample_rate, mono=False, offset=offset, duration=duration)

        return signal[0]
    
    def __padding__(self, data):
        n_batch = len(data)
        max_len = max([d.shape[0] for d in data])
        extrapadding = int(np.ceil(max_len / self.sample_rate) * self.sample_rate)
        pad = np.zeros((n_batch, extrapadding))
        
        for i in range(n_batch):
            pad[i, :data[i].shape[0]] = data[i]
        
        return np.expand_dims(pad, -1)
        
    def __data_generation__(self, source_list):
        wav_list = []
        for name in source_list:
            name = name.strip('\n')
            
            s_wav_name = self.wav_dir + self.files + '/' + self.sourNum + '/' + name
            
            # ------- AUDIO READ -------
            s_wav = (self.__audioread__(s_wav_name,  offset=0.0, duration=None, sample_rate=self.sample_rate))
            # --------------------------
            
            # ------- PADDING -------
#             pad_len = max(len(samples1),len(samples2))
#             pad_s1 = np.concatenate([s1_wav, np.zeros([pad_len - len(s1_wav)])])
            
#             extrapadding = ceil(len(pad_s1) / sample_rate) * sample_rate - len(pad_s1)
#             pad_s1 = np.concatenate([pad_s1, np.zeros([extrapadding - len(pad_s1)])])
#             pad_s2 = np.concatenate([s2_wav, np.zeros([extrapadding - len(s2_wav)])])
            # -----------------------
            
            wav_list.append(s_wav)
        
        return wav_list, wav_list, source_list
            
    
    def __len__(self):
        return int(np.floor(len(self.source) / self.batch_size))
    
    def __getitem__(self, index):
        indexes = self.indexes[index * self.batch_size : (index + 1) * self.batch_size]
        source_list = [self.source[k] for k in indexes]
        
        if self.files is not 'tt':
            sour, labels, _ = self.__data_generation__(source_list)
            
            # Get Lengths(K value of each batch)
            lengths = np.array([m.shape[0] for m in sour])
            exp = np.expand_dims(lengths, 1)
            exp = np.expand_dims(exp, -1) # [Batch, 1, 1] (length)
            
            # Padding
            sour_pad = self.__padding__(sour) # [Batch, Time_step, Dimension(=1)]
            label_pad = self.__padding__(labels) # [Batch, Time_step, Dimension(=1)]
            
            return sour_pad, np.concatenate([label_pad, exp], axis=1)
        else:
            sour, labels, name = self.__data_generation__(source_list)
            
            # Get Lengths(K value of each batch)
            lengths = np.array([m.shape[0] for m in sour])
            exp = np.expand_dims(lengths, 1)
            exp = np.expand_dims(exp, -1) # [Batch, 1, 1] (length)
            
            # Padding
            sour_pad = self.__padding__(sour) # [Batch, Time_step, Dimension(=1)]
            
            return sour_pad, exp, name

## Data를 어떻게 읽는지에 대한 부분

In [3]:
WAV_DIR = './mycode/wsj0_2mix/use_this/'
LIST_DIR = './mycode/wsj0_2mix/use_this/lists/'

In [4]:
# Directory List file create

wav_dir = WAV_DIR
output_lst = LIST_DIR

for folder in ['tr', 'cv', 'tt']:
    wav_files = os.listdir(wav_dir + folder + '/mix')
    output_lst_files = output_lst + folder + '_wav.lst'
    with open(output_lst_files, 'w') as f:
        for file in wav_files:
            f.write(file + "\n")

print("Generate wav file to .lst done!")

Generate wav file to .lst done!


In [5]:
batch_size = 1

train_dataset = 0
valid_dataset = 0
test_dataset = 0

name_list = []
for files in ['tr', 'cv', 'tt']:
    # --- Lead lst file ---""
    output_lst_files = LIST_DIR + files + '_wav.lst'
    fid = open(output_lst_files, 'r')
    lines = fid.readlines()
    fid.close()
    # ---------------------
    
    if files == 'tr':
        train_dataset = RawForVAEGenerator(lines, WAV_DIR, files, 's1', batch_size)
    elif files == 'cv':
        valid_dataset = RawForVAEGenerator(lines, WAV_DIR, files, 's1', batch_size)
    else:
        test_batch = 1
        test_dataset = RawForVAEGenerator(lines, WAV_DIR, files, 's1', test_batch)

# 2. Building VQ-VAE model with Gumbel Softmax

In [6]:
import threading
from scipy.io.wavfile import write as wav_write
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow import keras
import matplotlib.pyplot as plt
from tqdm import tqdm_notebook as tqdm
from tensorflow.keras import backend as Kb
import numpy as np
import pandas as pd
from importlib import reload
import time
from tensorflow.keras.models import Model, Sequential, load_model

In [7]:
def mkdir_p(path):
    """ Creates a path recursively without throwing an error if it already exists
    :param path: path to create
    :return: None
    """
    if not os.path.exists(path):
        os.makedirs(path)

In [8]:
mkdir_p('./CKPT/') # model check point 폴더 만드는 코드
filepath = "./CKPT/CKP_ep_{epoch:d}__loss_{val_loss:.5f}_.h5"

In [9]:
initial_learning_rate = 0.001

# learning rate를 점점 줄이는 부분
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate, decay_steps=20, decay_rate=0.96, staircase=True
)

# validation loss에 대해서 좋은 것만 저장됨
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath, monitor='val_loss', verbose=1, save_best_only=True, save_weights_only=True, mode='min'
)

# early stop 하는 부분인데, validation loss에 대해서 제일 좋은 모델이 저장됨
early_stopping_cb = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', mode='min', verbose=1, patience=50, restore_best_weights=True
)

In [147]:
class GumbelSoftmax(layers.Layer):
    def __init__(self, temperature=0.5, hard=False, name = 'gumbel_softmax',**kwargs):
        super(GumbelSoftmax, self).__init__(name=name, **kwargs)
        
        self.temperature = temperature
        self.hard = hard
    
    def sample_gumbel(self, shape, eps=1e-20): 
        """Sample from Gumbel(0, 1)"""
        U = tf.random.uniform(shape,minval=0,maxval=1)
        
        return -tf.math.log(-tf.math.log(U + eps) + eps)

    def gumbel_softmax_sample(self, logits, temperature): 
        """ Draw a sample from the Gumbel-Softmax distribution"""
        y = logits + self.sample_gumbel(tf.shape(logits))
        
        return tf.nn.softmax(y / temperature)

    def call(self, inputs):
        y = self.gumbel_softmax_sample(inputs, self.temperature)
        
        if self.hard:
            y_hard = tf.cast(tf.equal(y, tf.math.reduce_max(y, 2, keepdims=True)), y.dtype)
            y = tf.stop_gradient(y_hard - y) + y
        
        return y


class Encoder(layers.Layer):
    def __init__(self, latent_dim, name = 'encoder',**kwargs):
        super(Encoder, self).__init__(name=name, **kwargs)
        
        self.conv1d_1 = layers.Conv1D(filters=32, kernel_size=4, strides=2, activation='relu', padding='same')
        self.conv1d_2 = layers.Conv1D(filters=128, kernel_size=4, strides=2, activation='relu', padding='same')
        self.conv1d_3 = layers.Conv1D(filters=128, kernel_size=4, strides=2, activation='relu', padding='same')
        self.conv1d_4 = layers.Conv1D(filters=256, kernel_size=4, strides=2, activation='relu', padding='same')
        self.conv1d_5 = layers.Conv1D(filters=512, kernel_size=4, strides=2, activation='relu', padding='same')
        self.logit = layers.Conv1D(filters=latent_dim, kernel_size=1, strides=1, activation=None, padding='valid')
    
    def call(self, inputs):
        x = self.conv1d_1(inputs)
        x = self.conv1d_2(x)
        x = self.conv1d_3(x)
        x = self.conv1d_4(x)
        x = self.conv1d_5(x)
        logit = self.logit(x)
        
        return logit


class Decoder(layers.Layer):
    def __init__(self, latent_dim, name = 'decoder',**kwargs):
        super(Decoder, self).__init__(name=name, **kwargs)
        
        self.trans_conv1d_1 = layers.Conv1DTranspose(filters=512, kernel_size=1, strides=1, activation='relu', padding='same')
        self.trans_conv1d_2 = layers.Conv1DTranspose(filters=256, kernel_size=4, strides=2, activation='relu', padding='same')
        self.trans_conv1d_3 = layers.Conv1DTranspose(filters=128, kernel_size=4, strides=2, activation='relu', padding='same')
        self.trans_conv1d_4 = layers.Conv1DTranspose(filters=128, kernel_size=4, strides=2, activation='relu', padding='same')
        self.trans_conv1d_5 = layers.Conv1DTranspose(filters=32, kernel_size=4, strides=2, activation='relu', padding='same')
        self.logit = layers.Conv1DTranspose(filters=1, kernel_size=4, strides=2, activation=None, padding='same')
    
    def call(self, inputs):
        x = self.trans_conv1d_1(inputs)
        x = self.trans_conv1d_2(x)
        x = self.trans_conv1d_3(x)
        x = self.trans_conv1d_4(x)
        x = self.trans_conv1d_5(x)
        logit = self.logit(x)
        
        return logit

In [148]:
# Custom Metric Si-sdr

class SiSdr(keras.metrics.Metric):
    def __init__(self, name="Si-sdr", **kwargs):
        super(SiSdr, self).__init__(name=name, **kwargs)
        self.sdr = self.add_weight(name="sdr", initializer="zeros")
        self.count = self.add_weight(name="cnt", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        ori_length = tf.shape(y_true)[1]
        
        # Label & Length divide
        labels = tf.slice(y_true, [0, 0, 0], [-1, ori_length-1, -1]) # [batch_size, length_size, 1]
        lengths = tf.slice(y_true, [0, ori_length-1, 0], [-1, -1, 1]) # [batch_size, 1, 1]
        
        # Check sequence length
        batch_size = tf.shape(labels)[0]
        label_size = tf.shape(labels)[1]
        pred_size = tf.shape(y_pred)[1]
        feature_size = tf.shape(labels)[-1]
        
        # Change sequence length
        if label_size < pred_size:
            y_pred = tf.slice(y_pred, [0, 0, 0], [-1, label_size, -1])
        elif label_size > pred_size:
            labels = tf.slice(labels, [0, 0, 0], [-1, pred_size, -1])

        # SI-SDR
        target = tf.linalg.matmul(y_pred, labels, transpose_a=True) * labels / tf.expand_dims(tf.experimental.numpy.square(tf.norm(labels, axis=1)), axis=-1)
        noise = y_pred - target
        values = 10 * tf.experimental.numpy.log10(tf.experimental.numpy.square(tf.norm(target, axis=1)) / tf.experimental.numpy.square(tf.norm(noise, axis=1)))
        
        if sample_weight is not None:
            sample_weight = tf.cast(sample_weight, "float32")
            values = tf.multiply(values, sample_weight)
        self.sdr.assign_add(tf.reduce_sum(values))
        self.count.assign_add(tf.cast(tf.shape(y_true)[0], tf.float32))

    def result(self):
        return self.sdr / self.count

    def reset_states(self):
        # The state of the metric will be reset at the start of each epoch.
        self.sdr.assign(0.0)
        self.count.assign(0.0)

In [149]:
# Custom loss

# Custom mse
def custom_mse(y_true, y_pred):
    ori_length = tf.shape(y_true)[1]

    # Label & Length divide
    labels = tf.slice(y_true, [0, 0, 0], [-1, ori_length-1, -1]) # [batch_size, length_size, 129]
    lengths = tf.slice(y_true, [0, ori_length-1, 0], [-1, -1, 1]) # [batch_size, 1, 1]

    loss = tf.reduce_sum(tf.pow(y_pred - labels, 2), axis=[1, 2])
    loss = tf.reduce_mean(loss)

    return loss


# Custom si-sdr loss
def custom_sisdr_loss(y_true, y_pred):
    ori_length = tf.shape(y_true)[1]

    # Label & Length divide
    labels = tf.slice(y_true, [0, 0, 0], [-1, ori_length-1, -1]) # [batch_size, length_size, 1]
    lengths = tf.slice(y_true, [0, ori_length-1, 0], [-1, -1, 1]) # [batch_size, 1, 1]

    target = tf.linalg.matmul(y_pred, labels, transpose_a=True) * labels / tf.expand_dims(tf.experimental.numpy.square(tf.norm(labels, axis=1)), axis=-1)
    noise = y_pred - target
    si_sdr = 10 * tf.experimental.numpy.log10(tf.experimental.numpy.square(tf.norm(target, axis=1)) / tf.experimental.numpy.square(tf.norm(noise, axis=1)))
    si_sdr = tf.reduce_mean(si_sdr) * -1

    return si_sdr

In [166]:
class Vq_vae(keras.Model):
    def __init__(self, latent_dim, gumbel_hard=False, name='vqvae', **kwargs):
        super(Vq_vae, self).__init__(name=name, **kwargs)
        
        self.latent_dim = latent_dim
        self.softmax = layers.Softmax(-1)
        
        self.encoder = Encoder(latent_dim)
        self.embeddings = layers.Embedding(latent_dim, latent_dim)
        self.decoder = Decoder(latent_dim)
        self.gumbel = GumbelSoftmax(hard=gumbel_hard)
        
    def call(self, inputs, load=False):
        if load:
            inputs = layers.Input(shape=(None, 1))
        
        
        encode = self.encoder(inputs)
        gumbel = self.gumbel(encode)
        decode = self.decoder(gumbel)
        
        # ------------------ KL loss ------------------
        qy = self.softmax(encode)
        log_qy = tf.math.log(qy + 1e-10)
        log_uniform = qy * (log_qy - tf.math.log(1.0 / self.latent_dim))
        kl_loss = tf.reduce_sum(log_uniform, axis=[1, 2])
        kl_loss = tf.reduce_mean(kl_loss) * 0.1
        # ---------------------------------------------
        
        self.add_loss(kl_loss)
        
        return decode

# 이렇게 GradientTape 를 사용해서 프로그램 해도 됨

In [173]:
tf.random.set_seed(42)

latent_size = 1024
epochs = 600

filePath = "./CKPT/CKP_ep_{0}__loss_{1:.5f}_.h5"
model_path = './CKPT/CKP_ep_576__loss_168.75122_.h5'

loss_fun = custom_mse
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)

train_loss = tf.keras.metrics.Mean()
train_kl_loss = tf.keras.metrics.Mean()
valid_loss = tf.keras.metrics.Mean()
sisdr_Metric = SiSdr()
val_sisdr_Metric = SiSdr()

# Model 불러오는 부분이다
vq_vae = Vq_vae(latent_size, gumbel_hard=False)
vq_vae(0, True)
vq_vae.load_weights(model_path)

In [174]:
@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        # Call model
        results = vq_vae(x)
        
        loss_value = loss_fun(y, results)
        loss_value += sum(vq_vae.losses) # Add KL loss
    
    # Update weights
    grads = tape.gradient(loss_value, vq_vae.trainable_weights)
    optimizer.apply_gradients(zip(grads, vq_vae.trainable_weights))
    
    # Update loss and si-sdr
    train_loss.update_state(loss_value)
    sisdr_Metric.update_state(y, results)
    
    train_kl_loss.update_state(sum(vq_vae.losses))
    
    return loss_value

@tf.function
def test_step(x, y):
    # Call model
    val_results = vq_vae(x)
    
    val_loss_value = loss_fun(y, val_results)
    val_loss_value += sum(vq_vae.losses) # Add KL loss
    
    # Update loss and si-sdr
    valid_loss.update_state(val_loss_value)
    val_sisdr_Metric.update_state(y, val_results)
    
    return val_loss_value

In [175]:
previous_loss = float('inf')

for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch+1,))
    start_time = time.time()

    # Iterate over the batches of the dataset
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        x_batch_train = tf.cast(x_batch_train, dtype=tf.float32)
        y_batch_train = tf.cast(y_batch_train, dtype=tf.float32)
        
        loss_value = train_step(x_batch_train, y_batch_train)

        # Log every 1 batches
#         if step % 1 == 0:
#             print("Training loss (for one batch) at step %d: %.4f" % (step, train_loss.result()))
#             print("Training Si-sdr (for one batch) at step %d: %.4f" % (step, sisdr_Metric.result()))
#             print("Seen so far: %d samples" % ((step + 1) * batch_size))

    # Run a validation loop at the end of each epoch
    for x_batch_val, y_batch_val in valid_dataset:
        x_batch_val = tf.cast(x_batch_val, dtype=tf.float32)
        y_batch_val = tf.cast(y_batch_val, dtype=tf.float32)
        
        val_loss_value = test_step(x_batch_val, y_batch_val)
    
    print()
    print('----------------------------------------------------------------------------------')
    print("Time taken >>> %.2fs <<<" % (time.time() - start_time))
    print('epoch: {}, Train_loss: {}, Train_Si-sdr: {}, Train_KL_loss: {} \n\
    Valid_loss: {}, Valid_Si-sdr: {}'.format(
        epoch+1,
        train_loss.result(),
        sisdr_Metric.result(),
        train_kl_loss.result(),
        valid_loss.result(),
        val_sisdr_Metric.result()))
    print('----------------------------------------------------------------------------------')
    
    # Save Model
    if valid_loss.result() < previous_loss:
        filePath_temp = filePath.format(epoch+1, valid_loss.result())
        
        vq_vae.save_weights(filePath_temp)
        print('Epoch {}: val_loss improved from {} to {}, saving model to {}'.format(
            epoch+1,
            previous_loss,
            valid_loss.result(),
            filePath_temp))
        
        previous_loss = valid_loss.result()
    else:
        print('Epoch {}: val_loss did not improve from {}'.format(
            epoch+1,
            previous_loss))
    print()
    
    # Reset metrics at the end of each epoch
    train_loss.reset_states()
    sisdr_Metric.reset_states()
    valid_loss.reset_states()
    val_sisdr_Metric.reset_states()
    
    train_kl_loss.reset_states()
    
    # Data shuffle at the end of each epoch
    train_dataset.on_epoch_end()
    valid_dataset.on_epoch_end()


Start of epoch 1

----------------------------------------------------------------------------------
Time taken >>> 3.19s <<<
epoch: 1, Train_loss: 170.64132690429688, Train_Si-sdr: 2.217930793762207, Train_KL_loss: 58.307884216308594 
    Valid_loss: 165.9063720703125, Valid_Si-sdr: 2.36375093460083
----------------------------------------------------------------------------------
Epoch 1: val_loss improved from inf to 165.9063720703125, saving model to ./CKPT/CKP_ep_1__loss_165.90637_.h5


Start of epoch 2

----------------------------------------------------------------------------------
Time taken >>> 1.01s <<<
epoch: 2, Train_loss: 181.36134338378906, Train_Si-sdr: 1.606610655784607, Train_KL_loss: 58.40456008911133 
    Valid_loss: 176.42575073242188, Valid_Si-sdr: 2.0854380130767822
----------------------------------------------------------------------------------
Epoch 2: val_loss did not improve from 165.9063720703125


Start of epoch 3

--------------------------------------


----------------------------------------------------------------------------------
Time taken >>> 1.02s <<<
epoch: 20, Train_loss: 177.59649658203125, Train_Si-sdr: 2.1253292560577393, Train_KL_loss: 63.61944580078125 
    Valid_loss: 174.72531127929688, Valid_Si-sdr: 2.309779644012451
----------------------------------------------------------------------------------
Epoch 20: val_loss did not improve from 165.9063720703125


Start of epoch 21

----------------------------------------------------------------------------------
Time taken >>> 1.01s <<<
epoch: 21, Train_loss: 171.7360382080078, Train_Si-sdr: 2.530550956726074, Train_KL_loss: 63.045631408691406 
    Valid_loss: 171.48529052734375, Valid_Si-sdr: 2.3708930015563965
----------------------------------------------------------------------------------
Epoch 21: val_loss did not improve from 165.9063720703125


Start of epoch 22

----------------------------------------------------------------------------------
Time taken >>> 1.0


----------------------------------------------------------------------------------
Time taken >>> 1.00s <<<
epoch: 39, Train_loss: 176.87759399414062, Train_Si-sdr: 2.1519222259521484, Train_KL_loss: 63.352088928222656 
    Valid_loss: 173.50804138183594, Valid_Si-sdr: 2.463007926940918
----------------------------------------------------------------------------------
Epoch 39: val_loss did not improve from 165.9063720703125


Start of epoch 40

----------------------------------------------------------------------------------
Time taken >>> 0.99s <<<
epoch: 40, Train_loss: 174.2301025390625, Train_Si-sdr: 2.3745625019073486, Train_KL_loss: 63.930946350097656 
    Valid_loss: 170.62049865722656, Valid_Si-sdr: 2.5669026374816895
----------------------------------------------------------------------------------
Epoch 40: val_loss did not improve from 165.9063720703125


Start of epoch 41

----------------------------------------------------------------------------------
Time taken >>> 0


----------------------------------------------------------------------------------
Time taken >>> 1.00s <<<
epoch: 58, Train_loss: 176.04168701171875, Train_Si-sdr: 1.9879190921783447, Train_KL_loss: 58.84613037109375 
    Valid_loss: 172.21522521972656, Valid_Si-sdr: 2.259948968887329
----------------------------------------------------------------------------------
Epoch 58: val_loss did not improve from 165.9063720703125


Start of epoch 59

----------------------------------------------------------------------------------
Time taken >>> 1.00s <<<
epoch: 59, Train_loss: 172.94662475585938, Train_Si-sdr: 2.2867884635925293, Train_KL_loss: 60.925140380859375 
    Valid_loss: 174.18374633789062, Valid_Si-sdr: 2.3049275875091553
----------------------------------------------------------------------------------
Epoch 59: val_loss did not improve from 165.9063720703125


Start of epoch 60

----------------------------------------------------------------------------------
Time taken >>> 1


----------------------------------------------------------------------------------
Time taken >>> 1.04s <<<
epoch: 77, Train_loss: 173.3389892578125, Train_Si-sdr: 2.182582378387451, Train_KL_loss: 60.04447937011719 
    Valid_loss: 169.20140075683594, Valid_Si-sdr: 2.48701810836792
----------------------------------------------------------------------------------
Epoch 77: val_loss did not improve from 165.9063720703125


Start of epoch 78

----------------------------------------------------------------------------------
Time taken >>> 1.02s <<<
epoch: 78, Train_loss: 171.6219024658203, Train_Si-sdr: 2.372267723083496, Train_KL_loss: 61.03981018066406 
    Valid_loss: 168.11337280273438, Valid_Si-sdr: 2.6053948402404785
----------------------------------------------------------------------------------
Epoch 78: val_loss did not improve from 165.9063720703125


Start of epoch 79

----------------------------------------------------------------------------------
Time taken >>> 1.00s <


----------------------------------------------------------------------------------
Time taken >>> 1.02s <<<
epoch: 96, Train_loss: 167.67250061035156, Train_Si-sdr: 2.488866090774536, Train_KL_loss: 60.04230880737305 
    Valid_loss: 164.0338592529297, Valid_Si-sdr: 2.662656784057617
----------------------------------------------------------------------------------
Epoch 96: val_loss improved from 165.9063720703125 to 164.0338592529297, saving model to ./CKPT/CKP_ep_96__loss_164.03386_.h5


Start of epoch 97

----------------------------------------------------------------------------------
Time taken >>> 1.00s <<<
epoch: 97, Train_loss: 174.62635803222656, Train_Si-sdr: 2.130462169647217, Train_KL_loss: 59.314109802246094 
    Valid_loss: 170.8115997314453, Valid_Si-sdr: 2.4208602905273438
----------------------------------------------------------------------------------
Epoch 97: val_loss did not improve from 164.0338592529297


Start of epoch 98

-----------------------------------


----------------------------------------------------------------------------------
Time taken >>> 1.01s <<<
epoch: 114, Train_loss: 174.51708984375, Train_Si-sdr: 2.171738386154175, Train_KL_loss: 61.12372970581055 
    Valid_loss: 170.81834411621094, Valid_Si-sdr: 2.426072597503662
----------------------------------------------------------------------------------
Epoch 114: val_loss did not improve from 163.11868286132812


Start of epoch 115

----------------------------------------------------------------------------------
Time taken >>> 1.02s <<<
epoch: 115, Train_loss: 171.55503845214844, Train_Si-sdr: 2.4900174140930176, Train_KL_loss: 62.71922302246094 
    Valid_loss: 171.54185485839844, Valid_Si-sdr: 2.5347208976745605
----------------------------------------------------------------------------------
Epoch 115: val_loss did not improve from 163.11868286132812


Start of epoch 116

----------------------------------------------------------------------------------
Time taken >>


----------------------------------------------------------------------------------
Time taken >>> 1.00s <<<
epoch: 133, Train_loss: 171.5380096435547, Train_Si-sdr: 2.4521946907043457, Train_KL_loss: 62.31769561767578 
    Valid_loss: 169.73385620117188, Valid_Si-sdr: 2.584282636642456
----------------------------------------------------------------------------------
Epoch 133: val_loss did not improve from 163.11868286132812


Start of epoch 134

----------------------------------------------------------------------------------
Time taken >>> 1.01s <<<
epoch: 134, Train_loss: 168.92770385742188, Train_Si-sdr: 2.64862060546875, Train_KL_loss: 63.35304260253906 
    Valid_loss: 167.595458984375, Valid_Si-sdr: 2.791741371154785
----------------------------------------------------------------------------------
Epoch 134: val_loss did not improve from 163.11868286132812


Start of epoch 135

----------------------------------------------------------------------------------
Time taken >>> 


----------------------------------------------------------------------------------
Time taken >>> 1.01s <<<
epoch: 152, Train_loss: 170.06150817871094, Train_Si-sdr: 2.5116617679595947, Train_KL_loss: 62.030418395996094 
    Valid_loss: 166.62408447265625, Valid_Si-sdr: 2.727428913116455
----------------------------------------------------------------------------------
Epoch 152: val_loss did not improve from 160.7294921875


Start of epoch 153

----------------------------------------------------------------------------------
Time taken >>> 0.99s <<<
epoch: 153, Train_loss: 170.2481231689453, Train_Si-sdr: 2.5282649993896484, Train_KL_loss: 61.94981002807617 
    Valid_loss: 166.4190673828125, Valid_Si-sdr: 2.7114713191986084
----------------------------------------------------------------------------------
Epoch 153: val_loss did not improve from 160.7294921875


Start of epoch 154

----------------------------------------------------------------------------------
Time taken >>> 0.9


----------------------------------------------------------------------------------
Time taken >>> 0.99s <<<
epoch: 171, Train_loss: 171.95394897460938, Train_Si-sdr: 2.174494743347168, Train_KL_loss: 59.50470733642578 
    Valid_loss: 169.2172088623047, Valid_Si-sdr: 2.426795721054077
----------------------------------------------------------------------------------
Epoch 171: val_loss did not improve from 160.7294921875


Start of epoch 172

----------------------------------------------------------------------------------
Time taken >>> 1.00s <<<
epoch: 172, Train_loss: 169.3689727783203, Train_Si-sdr: 2.467247724533081, Train_KL_loss: 60.533058166503906 
    Valid_loss: 165.77134704589844, Valid_Si-sdr: 2.7495439052581787
----------------------------------------------------------------------------------
Epoch 172: val_loss did not improve from 160.7294921875


Start of epoch 173

----------------------------------------------------------------------------------
Time taken >>> 1.04s


----------------------------------------------------------------------------------
Time taken >>> 1.03s <<<
epoch: 190, Train_loss: 170.2555694580078, Train_Si-sdr: 2.543661117553711, Train_KL_loss: 62.21354293823242 
    Valid_loss: 164.88021850585938, Valid_Si-sdr: 2.981992721557617
----------------------------------------------------------------------------------
Epoch 190: val_loss did not improve from 160.7294921875


Start of epoch 191

----------------------------------------------------------------------------------
Time taken >>> 1.05s <<<
epoch: 191, Train_loss: 167.01937866210938, Train_Si-sdr: 2.8199305534362793, Train_KL_loss: 64.55899047851562 
    Valid_loss: 163.8179473876953, Valid_Si-sdr: 3.002420425415039
----------------------------------------------------------------------------------
Epoch 191: val_loss did not improve from 160.7294921875


Start of epoch 192

----------------------------------------------------------------------------------
Time taken >>> 1.03s 


----------------------------------------------------------------------------------
Time taken >>> 1.00s <<<
epoch: 209, Train_loss: 169.46463012695312, Train_Si-sdr: 2.3611600399017334, Train_KL_loss: 59.47051239013672 
    Valid_loss: 164.95384216308594, Valid_Si-sdr: 2.6290762424468994
----------------------------------------------------------------------------------
Epoch 209: val_loss did not improve from 160.7294921875


Start of epoch 210

----------------------------------------------------------------------------------
Time taken >>> 0.99s <<<
epoch: 210, Train_loss: 170.27740478515625, Train_Si-sdr: 2.4332408905029297, Train_KL_loss: 60.65489959716797 
    Valid_loss: 165.56826782226562, Valid_Si-sdr: 2.8502955436706543
----------------------------------------------------------------------------------
Epoch 210: val_loss did not improve from 160.7294921875


Start of epoch 211

----------------------------------------------------------------------------------
Time taken >>> 1


----------------------------------------------------------------------------------
Time taken >>> 1.03s <<<
epoch: 228, Train_loss: 167.26698303222656, Train_Si-sdr: 2.6598761081695557, Train_KL_loss: 61.512001037597656 
    Valid_loss: 166.05352783203125, Valid_Si-sdr: 2.6429715156555176
----------------------------------------------------------------------------------
Epoch 228: val_loss did not improve from 160.7294921875


Start of epoch 229

----------------------------------------------------------------------------------
Time taken >>> 1.07s <<<
epoch: 229, Train_loss: 165.28738403320312, Train_Si-sdr: 2.7347912788391113, Train_KL_loss: 60.874366760253906 
    Valid_loss: 163.8655242919922, Valid_Si-sdr: 2.6780693531036377
----------------------------------------------------------------------------------
Epoch 229: val_loss did not improve from 160.7294921875


Start of epoch 230

----------------------------------------------------------------------------------
Time taken >>> 


----------------------------------------------------------------------------------
Time taken >>> 0.99s <<<
epoch: 247, Train_loss: 168.7781219482422, Train_Si-sdr: 2.436429977416992, Train_KL_loss: 59.372398376464844 
    Valid_loss: 164.86781311035156, Valid_Si-sdr: 2.6134145259857178
----------------------------------------------------------------------------------
Epoch 247: val_loss did not improve from 160.7294921875


Start of epoch 248

----------------------------------------------------------------------------------
Time taken >>> 1.00s <<<
epoch: 248, Train_loss: 168.11805725097656, Train_Si-sdr: 2.501309871673584, Train_KL_loss: 58.9129638671875 
    Valid_loss: 163.57534790039062, Valid_Si-sdr: 2.831010341644287
----------------------------------------------------------------------------------
Epoch 248: val_loss did not improve from 160.7294921875


Start of epoch 249

----------------------------------------------------------------------------------
Time taken >>> 0.99s


----------------------------------------------------------------------------------
Time taken >>> 0.99s <<<
epoch: 266, Train_loss: 170.51092529296875, Train_Si-sdr: 2.486591339111328, Train_KL_loss: 61.5166130065918 
    Valid_loss: 170.2000274658203, Valid_Si-sdr: 2.3843820095062256
----------------------------------------------------------------------------------
Epoch 266: val_loss did not improve from 160.7294921875


Start of epoch 267

----------------------------------------------------------------------------------
Time taken >>> 1.00s <<<
epoch: 267, Train_loss: 168.15719604492188, Train_Si-sdr: 2.550713062286377, Train_KL_loss: 60.916221618652344 
    Valid_loss: 166.55929565429688, Valid_Si-sdr: 2.6125235557556152
----------------------------------------------------------------------------------
Epoch 267: val_loss did not improve from 160.7294921875


Start of epoch 268

----------------------------------------------------------------------------------
Time taken >>> 0.99


----------------------------------------------------------------------------------
Time taken >>> 1.01s <<<
epoch: 285, Train_loss: 167.3802490234375, Train_Si-sdr: 2.667372465133667, Train_KL_loss: 61.963951110839844 
    Valid_loss: 163.4745330810547, Valid_Si-sdr: 2.8566551208496094
----------------------------------------------------------------------------------
Epoch 285: val_loss did not improve from 160.7294921875


Start of epoch 286

----------------------------------------------------------------------------------
Time taken >>> 1.01s <<<
epoch: 286, Train_loss: 163.47586059570312, Train_Si-sdr: 2.8094539642333984, Train_KL_loss: 59.96678924560547 
    Valid_loss: 162.84295654296875, Valid_Si-sdr: 2.7466564178466797
----------------------------------------------------------------------------------
Epoch 286: val_loss did not improve from 160.7294921875


Start of epoch 287

----------------------------------------------------------------------------------
Time taken >>> 1.0


----------------------------------------------------------------------------------
Time taken >>> 1.00s <<<
epoch: 304, Train_loss: 168.80551147460938, Train_Si-sdr: 2.6281487941741943, Train_KL_loss: 62.307350158691406 
    Valid_loss: 164.12892150878906, Valid_Si-sdr: 2.865877151489258
----------------------------------------------------------------------------------
Epoch 304: val_loss did not improve from 160.09945678710938


Start of epoch 305

----------------------------------------------------------------------------------
Time taken >>> 1.00s <<<
epoch: 305, Train_loss: 168.35597229003906, Train_Si-sdr: 2.571929931640625, Train_KL_loss: 61.595359802246094 
    Valid_loss: 162.8459014892578, Valid_Si-sdr: 2.863967180252075
----------------------------------------------------------------------------------
Epoch 305: val_loss did not improve from 160.09945678710938


Start of epoch 306

----------------------------------------------------------------------------------
Time taken


----------------------------------------------------------------------------------
Time taken >>> 0.99s <<<
epoch: 323, Train_loss: 168.0724334716797, Train_Si-sdr: 2.324481248855591, Train_KL_loss: 57.42721939086914 
    Valid_loss: 168.34384155273438, Valid_Si-sdr: 2.341130495071411
----------------------------------------------------------------------------------
Epoch 323: val_loss did not improve from 160.09945678710938


Start of epoch 324

----------------------------------------------------------------------------------
Time taken >>> 0.99s <<<
epoch: 324, Train_loss: 165.2529296875, Train_Si-sdr: 2.6230268478393555, Train_KL_loss: 59.10908508300781 
    Valid_loss: 161.25030517578125, Valid_Si-sdr: 3.012364387512207
----------------------------------------------------------------------------------
Epoch 324: val_loss did not improve from 160.09945678710938


Start of epoch 325

----------------------------------------------------------------------------------
Time taken >>> 1


----------------------------------------------------------------------------------
Time taken >>> 1.00s <<<
epoch: 341, Train_loss: 163.26934814453125, Train_Si-sdr: 2.9741177558898926, Train_KL_loss: 62.79109191894531 
    Valid_loss: 161.34912109375, Valid_Si-sdr: 3.0404253005981445
----------------------------------------------------------------------------------
Epoch 341: val_loss did not improve from 159.79026794433594


Start of epoch 342

----------------------------------------------------------------------------------
Time taken >>> 1.00s <<<
epoch: 342, Train_loss: 167.80404663085938, Train_Si-sdr: 2.630074977874756, Train_KL_loss: 60.85508728027344 
    Valid_loss: 164.41485595703125, Valid_Si-sdr: 2.6846776008605957
----------------------------------------------------------------------------------
Epoch 342: val_loss did not improve from 159.79026794433594


Start of epoch 343

----------------------------------------------------------------------------------
Time taken >


----------------------------------------------------------------------------------
Time taken >>> 0.99s <<<
epoch: 360, Train_loss: 162.1719970703125, Train_Si-sdr: 2.8649909496307373, Train_KL_loss: 59.59733963012695 
    Valid_loss: 163.75228881835938, Valid_Si-sdr: 2.7369894981384277
----------------------------------------------------------------------------------
Epoch 360: val_loss did not improve from 159.79026794433594


Start of epoch 361

----------------------------------------------------------------------------------
Time taken >>> 0.99s <<<
epoch: 361, Train_loss: 165.56675720214844, Train_Si-sdr: 2.7657809257507324, Train_KL_loss: 60.727901458740234 
    Valid_loss: 164.7001495361328, Valid_Si-sdr: 2.749007225036621
----------------------------------------------------------------------------------
Epoch 361: val_loss did not improve from 159.79026794433594


Start of epoch 362

----------------------------------------------------------------------------------
Time taken


----------------------------------------------------------------------------------
Time taken >>> 1.01s <<<
epoch: 379, Train_loss: 165.82696533203125, Train_Si-sdr: 2.803722858428955, Train_KL_loss: 62.98537063598633 
    Valid_loss: 162.31866455078125, Valid_Si-sdr: 3.0736303329467773
----------------------------------------------------------------------------------
Epoch 379: val_loss did not improve from 159.6747283935547


Start of epoch 380

----------------------------------------------------------------------------------
Time taken >>> 1.00s <<<
epoch: 380, Train_loss: 165.47927856445312, Train_Si-sdr: 2.86286997795105, Train_KL_loss: 63.58484649658203 
    Valid_loss: 160.99928283691406, Valid_Si-sdr: 3.1211376190185547
----------------------------------------------------------------------------------
Epoch 380: val_loss did not improve from 159.6747283935547


Start of epoch 381

----------------------------------------------------------------------------------
Time taken >>


----------------------------------------------------------------------------------
Time taken >>> 1.01s <<<
epoch: 397, Train_loss: 165.3421630859375, Train_Si-sdr: 2.7681522369384766, Train_KL_loss: 62.063629150390625 
    Valid_loss: 161.0601806640625, Valid_Si-sdr: 2.9639394283294678
----------------------------------------------------------------------------------
Epoch 397: val_loss did not improve from 158.20651245117188


Start of epoch 398

----------------------------------------------------------------------------------
Time taken >>> 1.02s <<<
epoch: 398, Train_loss: 164.12747192382812, Train_Si-sdr: 2.773609161376953, Train_KL_loss: 60.21014404296875 
    Valid_loss: 160.16500854492188, Valid_Si-sdr: 2.9571797847747803
----------------------------------------------------------------------------------
Epoch 398: val_loss did not improve from 158.20651245117188


Start of epoch 399

----------------------------------------------------------------------------------
Time taken


----------------------------------------------------------------------------------
Time taken >>> 1.00s <<<
epoch: 416, Train_loss: 166.70925903320312, Train_Si-sdr: 2.7579257488250732, Train_KL_loss: 62.05764389038086 
    Valid_loss: 163.27261352539062, Valid_Si-sdr: 2.81951904296875
----------------------------------------------------------------------------------
Epoch 416: val_loss did not improve from 158.20651245117188


Start of epoch 417

----------------------------------------------------------------------------------
Time taken >>> 1.00s <<<
epoch: 417, Train_loss: 167.9879150390625, Train_Si-sdr: 2.5337281227111816, Train_KL_loss: 60.88188934326172 
    Valid_loss: 163.67434692382812, Valid_Si-sdr: 2.938814640045166
----------------------------------------------------------------------------------
Epoch 417: val_loss did not improve from 158.20651245117188


Start of epoch 418

----------------------------------------------------------------------------------
Time taken >


----------------------------------------------------------------------------------
Time taken >>> 0.99s <<<
epoch: 435, Train_loss: 164.6424560546875, Train_Si-sdr: 2.721987247467041, Train_KL_loss: 60.36070251464844 
    Valid_loss: 161.1997528076172, Valid_Si-sdr: 3.012190580368042
----------------------------------------------------------------------------------
Epoch 435: val_loss did not improve from 158.20651245117188


Start of epoch 436

----------------------------------------------------------------------------------
Time taken >>> 1.00s <<<
epoch: 436, Train_loss: 165.1339111328125, Train_Si-sdr: 2.8263564109802246, Train_KL_loss: 62.04298782348633 
    Valid_loss: 160.61410522460938, Valid_Si-sdr: 3.207573890686035
----------------------------------------------------------------------------------
Epoch 436: val_loss did not improve from 158.20651245117188


Start of epoch 437

----------------------------------------------------------------------------------
Time taken >>>


----------------------------------------------------------------------------------
Time taken >>> 1.02s <<<
epoch: 454, Train_loss: 166.94520568847656, Train_Si-sdr: 2.6031863689422607, Train_KL_loss: 59.16876983642578 
    Valid_loss: 162.97608947753906, Valid_Si-sdr: 2.8616507053375244
----------------------------------------------------------------------------------
Epoch 454: val_loss did not improve from 158.20651245117188


Start of epoch 455

----------------------------------------------------------------------------------
Time taken >>> 1.03s <<<
epoch: 455, Train_loss: 166.80172729492188, Train_Si-sdr: 2.6023969650268555, Train_KL_loss: 60.736793518066406 
    Valid_loss: 162.90084838867188, Valid_Si-sdr: 2.9248194694519043
----------------------------------------------------------------------------------
Epoch 455: val_loss did not improve from 158.20651245117188


Start of epoch 456

----------------------------------------------------------------------------------
Time ta


----------------------------------------------------------------------------------
Time taken >>> 1.12s <<<
epoch: 472, Train_loss: 166.95220947265625, Train_Si-sdr: 2.681650161743164, Train_KL_loss: 60.73919677734375 
    Valid_loss: 162.82717895507812, Valid_Si-sdr: 3.0104362964630127
----------------------------------------------------------------------------------
Epoch 472: val_loss did not improve from 157.59002685546875


Start of epoch 473

----------------------------------------------------------------------------------
Time taken >>> 1.03s <<<
epoch: 473, Train_loss: 164.19375610351562, Train_Si-sdr: 2.918610095977783, Train_KL_loss: 62.565345764160156 
    Valid_loss: 159.91470336914062, Valid_Si-sdr: 3.258347988128662
----------------------------------------------------------------------------------
Epoch 473: val_loss did not improve from 157.59002685546875


Start of epoch 474

----------------------------------------------------------------------------------
Time taken


----------------------------------------------------------------------------------
Time taken >>> 1.05s <<<
epoch: 490, Train_loss: 161.2119903564453, Train_Si-sdr: 2.9891462326049805, Train_KL_loss: 60.38112258911133 
    Valid_loss: 157.60443115234375, Valid_Si-sdr: 3.1865479946136475
----------------------------------------------------------------------------------
Epoch 490: val_loss did not improve from 157.28408813476562


Start of epoch 491

----------------------------------------------------------------------------------
Time taken >>> 1.02s <<<
epoch: 491, Train_loss: 162.59881591796875, Train_Si-sdr: 2.8196637630462646, Train_KL_loss: 59.40234375 
    Valid_loss: 158.87149047851562, Valid_Si-sdr: 3.018937110900879
----------------------------------------------------------------------------------
Epoch 491: val_loss did not improve from 157.28408813476562


Start of epoch 492

----------------------------------------------------------------------------------
Time taken >>> 1


----------------------------------------------------------------------------------
Time taken >>> 1.01s <<<
epoch: 508, Train_loss: 162.1533203125, Train_Si-sdr: 3.046971321105957, Train_KL_loss: 62.03819274902344 
    Valid_loss: 159.0111846923828, Valid_Si-sdr: 3.200991630554199
----------------------------------------------------------------------------------
Epoch 508: val_loss did not improve from 156.6813201904297


Start of epoch 509

----------------------------------------------------------------------------------
Time taken >>> 1.00s <<<
epoch: 509, Train_loss: 165.22494506835938, Train_Si-sdr: 2.730090618133545, Train_KL_loss: 60.16706848144531 
    Valid_loss: 161.47344970703125, Valid_Si-sdr: 2.8400681018829346
----------------------------------------------------------------------------------
Epoch 509: val_loss did not improve from 156.6813201904297


Start of epoch 510

----------------------------------------------------------------------------------
Time taken >>> 0.9


----------------------------------------------------------------------------------
Time taken >>> 1.00s <<<
epoch: 527, Train_loss: 162.94056701660156, Train_Si-sdr: 2.966662883758545, Train_KL_loss: 61.42108154296875 
    Valid_loss: 160.01065063476562, Valid_Si-sdr: 2.962556838989258
----------------------------------------------------------------------------------
Epoch 527: val_loss did not improve from 156.6080780029297


Start of epoch 528

----------------------------------------------------------------------------------
Time taken >>> 0.99s <<<
epoch: 528, Train_loss: 165.39743041992188, Train_Si-sdr: 2.6423685550689697, Train_KL_loss: 58.92369842529297 
    Valid_loss: 161.3029022216797, Valid_Si-sdr: 2.8261635303497314
----------------------------------------------------------------------------------
Epoch 528: val_loss did not improve from 156.6080780029297


Start of epoch 529

----------------------------------------------------------------------------------
Time taken >>


----------------------------------------------------------------------------------
Time taken >>> 1.01s <<<
epoch: 545, Train_loss: 157.48068237304688, Train_Si-sdr: 3.1532680988311768, Train_KL_loss: 59.5720100402832 
    Valid_loss: 154.5758056640625, Valid_Si-sdr: 3.2737977504730225
----------------------------------------------------------------------------------
Epoch 545: val_loss improved from 154.94232177734375 to 154.5758056640625, saving model to ./CKPT/CKP_ep_545__loss_154.57581_.h5


Start of epoch 546

----------------------------------------------------------------------------------
Time taken >>> 1.00s <<<
epoch: 546, Train_loss: 161.62152099609375, Train_Si-sdr: 2.8255813121795654, Train_KL_loss: 58.09868621826172 
    Valid_loss: 159.3902130126953, Valid_Si-sdr: 2.908630132675171
----------------------------------------------------------------------------------
Epoch 546: val_loss did not improve from 154.5758056640625


Start of epoch 547

---------------------------


----------------------------------------------------------------------------------
Time taken >>> 1.00s <<<
epoch: 564, Train_loss: 164.87786865234375, Train_Si-sdr: 2.535874128341675, Train_KL_loss: 57.082908630371094 
    Valid_loss: 160.48934936523438, Valid_Si-sdr: 2.8276891708374023
----------------------------------------------------------------------------------
Epoch 564: val_loss did not improve from 154.5758056640625


Start of epoch 565

----------------------------------------------------------------------------------
Time taken >>> 1.00s <<<
epoch: 565, Train_loss: 161.5125732421875, Train_Si-sdr: 2.7732274532318115, Train_KL_loss: 57.678611755371094 
    Valid_loss: 157.64572143554688, Valid_Si-sdr: 3.1005406379699707
----------------------------------------------------------------------------------
Epoch 565: val_loss did not improve from 154.5758056640625


Start of epoch 566

----------------------------------------------------------------------------------
Time taken


----------------------------------------------------------------------------------
Time taken >>> 1.02s <<<
epoch: 582, Train_loss: 163.84475708007812, Train_Si-sdr: 2.948310375213623, Train_KL_loss: 61.48612976074219 
    Valid_loss: 159.52378845214844, Valid_Si-sdr: 3.2044599056243896
----------------------------------------------------------------------------------
Epoch 582: val_loss did not improve from 154.3397979736328


Start of epoch 583

----------------------------------------------------------------------------------
Time taken >>> 1.01s <<<
epoch: 583, Train_loss: 164.42433166503906, Train_Si-sdr: 2.775447368621826, Train_KL_loss: 61.317874908447266 
    Valid_loss: 161.92909240722656, Valid_Si-sdr: 2.911527156829834
----------------------------------------------------------------------------------
Epoch 583: val_loss did not improve from 154.3397979736328


Start of epoch 584

----------------------------------------------------------------------------------
Time taken >

# 여기는 기존의 .fit() 함수를 사용해서 학습하는 부분임

In [140]:
tf.random.set_seed(42)

latent_size = 1024
epoch = 300

strategy = tf.distribute.MirroredStrategy(['cpu:0'])
print('장치의 수: {}'.format(strategy.num_replicas_in_sync))

with strategy.scope():
    model_path = './CKPT/CKP_ep_252__loss_45.39530_.h5'
    
    loss_fun = custom_mse
#     loss_fun = custom_sisdr_loss
    
    vq_vae = Vq_vae(latent_size, gumbel_hard=False)

    optimizer = keras.optimizers.Adam(learning_rate=1e-4)
    vq_vae.compile(optimizer, loss=loss_fun, metrics=[SiSdr()])
    
    vq_vae(0, True)
    vq_vae.summary()
    
    # 사용 안할 때는 load_model 주석 처리 하자
    vq_vae.load_weights(model_path)
    # ----------------------------------------
    
    tf.executing_eagerly()

history = vq_vae.fit(
    train_dataset,
    epochs=epoch,
    validation_data=valid_dataset,
    shuffle=True,
    callbacks=[checkpoint_cb],
)

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)
장치의 수: 1
Model: "vqvae"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
softmax_30 (Softmax)         (None, None, 1024)        0         
_________________________________________________________________
encoder (Encoder)            (None, None, 1024)        739232    
_________________________________________________________________
embedding_30 (Embedding)     multiple                  0 (unused)
_________________________________________________________________
decoder (Decoder)            (None, None, 1)           1524641   
_________________________________________________________________
gumbel_softmax (GumbelSoftma (None, None, 1024)        0         
Total params: 2,263,873
Trainable params: 2,263,873
Non-trainable params: 0
_________________________________________________________________
Epoch 


Epoch 00036: val_loss did not improve from 50.88303
Epoch 37/300

Epoch 00037: val_loss did not improve from 50.88303
Epoch 38/300

Epoch 00038: val_loss did not improve from 50.88303
Epoch 39/300

Epoch 00039: val_loss did not improve from 50.88303
Epoch 40/300

Epoch 00040: val_loss did not improve from 50.88303
Epoch 41/300

Epoch 00041: val_loss did not improve from 50.88303
Epoch 42/300

Epoch 00042: val_loss did not improve from 50.88303
Epoch 43/300

Epoch 00043: val_loss did not improve from 50.88303
Epoch 44/300

Epoch 00044: val_loss did not improve from 50.88303
Epoch 45/300

Epoch 00045: val_loss did not improve from 50.88303
Epoch 46/300

Epoch 00046: val_loss did not improve from 50.88303
Epoch 47/300

Epoch 00047: val_loss did not improve from 50.88303
Epoch 48/300

Epoch 00048: val_loss did not improve from 50.88303
Epoch 49/300

Epoch 00049: val_loss did not improve from 50.88303
Epoch 50/300

Epoch 00050: val_loss did not improve from 50.88303
Epoch 51/300

Epoch 000


Epoch 00078: val_loss did not improve from 50.88303
Epoch 79/300

Epoch 00079: val_loss did not improve from 50.88303
Epoch 80/300

Epoch 00080: val_loss did not improve from 50.88303
Epoch 81/300

Epoch 00081: val_loss did not improve from 50.88303
Epoch 82/300

Epoch 00082: val_loss did not improve from 50.88303
Epoch 83/300

Epoch 00083: val_loss did not improve from 50.88303
Epoch 84/300

Epoch 00084: val_loss did not improve from 50.88303
Epoch 85/300

Epoch 00085: val_loss did not improve from 50.88303
Epoch 86/300

Epoch 00086: val_loss did not improve from 50.88303
Epoch 87/300

Epoch 00087: val_loss did not improve from 50.88303
Epoch 88/300

Epoch 00088: val_loss did not improve from 50.88303
Epoch 89/300

Epoch 00089: val_loss did not improve from 50.88303
Epoch 90/300

Epoch 00090: val_loss did not improve from 50.88303
Epoch 91/300

Epoch 00091: val_loss did not improve from 50.88303
Epoch 92/300

Epoch 00092: val_loss did not improve from 50.88303
Epoch 93/300

Epoch 000


Epoch 00119: val_loss did not improve from 50.88303
Epoch 120/300

Epoch 00120: val_loss did not improve from 50.88303
Epoch 121/300

Epoch 00121: val_loss did not improve from 50.88303
Epoch 122/300

Epoch 00122: val_loss did not improve from 50.88303
Epoch 123/300

Epoch 00123: val_loss did not improve from 50.88303
Epoch 124/300

Epoch 00124: val_loss did not improve from 50.88303
Epoch 125/300

Epoch 00125: val_loss did not improve from 50.88303
Epoch 126/300

Epoch 00126: val_loss did not improve from 50.88303
Epoch 127/300

Epoch 00127: val_loss did not improve from 50.88303
Epoch 128/300

Epoch 00128: val_loss did not improve from 50.88303
Epoch 129/300

Epoch 00129: val_loss did not improve from 50.88303
Epoch 130/300

Epoch 00130: val_loss did not improve from 50.88303
Epoch 131/300

Epoch 00131: val_loss did not improve from 50.88303
Epoch 132/300

Epoch 00132: val_loss did not improve from 50.88303
Epoch 133/300

Epoch 00133: val_loss did not improve from 50.88303
Epoch 134


Epoch 00161: val_loss did not improve from 50.88303
Epoch 162/300

Epoch 00162: val_loss did not improve from 50.88303
Epoch 163/300

Epoch 00163: val_loss did not improve from 50.88303
Epoch 164/300

Epoch 00164: val_loss did not improve from 50.88303
Epoch 165/300

Epoch 00165: val_loss did not improve from 50.88303
Epoch 166/300

Epoch 00166: val_loss did not improve from 50.88303
Epoch 167/300

Epoch 00167: val_loss did not improve from 50.88303
Epoch 168/300

Epoch 00168: val_loss did not improve from 50.88303
Epoch 169/300

Epoch 00169: val_loss did not improve from 50.88303
Epoch 170/300

Epoch 00170: val_loss did not improve from 50.88303
Epoch 171/300

Epoch 00171: val_loss did not improve from 50.88303
Epoch 172/300

Epoch 00172: val_loss did not improve from 50.88303
Epoch 173/300

Epoch 00173: val_loss did not improve from 50.88303
Epoch 174/300

Epoch 00174: val_loss did not improve from 50.88303
Epoch 175/300

Epoch 00175: val_loss did not improve from 50.88303
Epoch 176


Epoch 00202: val_loss did not improve from 50.88303
Epoch 203/300

Epoch 00203: val_loss did not improve from 50.88303
Epoch 204/300

Epoch 00204: val_loss did not improve from 50.88303
Epoch 205/300

Epoch 00205: val_loss did not improve from 50.88303
Epoch 206/300

Epoch 00206: val_loss did not improve from 50.88303
Epoch 207/300

Epoch 00207: val_loss did not improve from 50.88303
Epoch 208/300

Epoch 00208: val_loss did not improve from 50.88303
Epoch 209/300

Epoch 00209: val_loss did not improve from 50.88303
Epoch 210/300

Epoch 00210: val_loss did not improve from 50.88303
Epoch 211/300

Epoch 00211: val_loss did not improve from 50.88303
Epoch 212/300

Epoch 00212: val_loss did not improve from 50.88303
Epoch 213/300

Epoch 00213: val_loss did not improve from 50.88303
Epoch 214/300

Epoch 00214: val_loss did not improve from 50.88303
Epoch 215/300

Epoch 00215: val_loss did not improve from 50.88303
Epoch 216/300

Epoch 00216: val_loss did not improve from 50.88303
Epoch 217


Epoch 00243: val_loss did not improve from 50.88303
Epoch 244/300

Epoch 00244: val_loss did not improve from 50.88303
Epoch 245/300

Epoch 00245: val_loss did not improve from 50.88303
Epoch 246/300

Epoch 00246: val_loss did not improve from 50.88303
Epoch 247/300

Epoch 00247: val_loss did not improve from 50.88303
Epoch 248/300

Epoch 00248: val_loss did not improve from 50.88303
Epoch 249/300

Epoch 00249: val_loss did not improve from 50.88303
Epoch 250/300

Epoch 00250: val_loss did not improve from 50.88303
Epoch 251/300

Epoch 00251: val_loss did not improve from 50.88303
Epoch 252/300

Epoch 00252: val_loss did not improve from 50.88303
Epoch 253/300

Epoch 00253: val_loss did not improve from 50.88303
Epoch 254/300

Epoch 00254: val_loss did not improve from 50.88303
Epoch 255/300

Epoch 00255: val_loss did not improve from 50.88303
Epoch 256/300

Epoch 00256: val_loss did not improve from 50.88303
Epoch 257/300

Epoch 00257: val_loss did not improve from 50.88303
Epoch 258


Epoch 00285: val_loss did not improve from 50.88303
Epoch 286/300

Epoch 00286: val_loss did not improve from 50.88303
Epoch 287/300

Epoch 00287: val_loss did not improve from 50.88303
Epoch 288/300

Epoch 00288: val_loss did not improve from 50.88303
Epoch 289/300

Epoch 00289: val_loss did not improve from 50.88303
Epoch 290/300

Epoch 00290: val_loss did not improve from 50.88303
Epoch 291/300

Epoch 00291: val_loss did not improve from 50.88303
Epoch 292/300

Epoch 00292: val_loss did not improve from 50.88303
Epoch 293/300

Epoch 00293: val_loss did not improve from 50.88303
Epoch 294/300

Epoch 00294: val_loss did not improve from 50.88303
Epoch 295/300

Epoch 00295: val_loss did not improve from 50.88303
Epoch 296/300

Epoch 00296: val_loss did not improve from 50.88303
Epoch 297/300

Epoch 00297: val_loss did not improve from 50.88303
Epoch 298/300

Epoch 00298: val_loss did not improve from 50.88303
Epoch 299/300

Epoch 00299: val_loss did not improve from 50.88303
Epoch 300

## 2.2. Encoder 부르는 방법, Decoder에 값 넣는 방법

In [50]:
latent_size = 512
epoch = 200
BATCH_SIZE = 2

strategy = tf.distribute.MirroredStrategy(['cpu:0'])
print('장치의 수: {}'.format(strategy.num_replicas_in_sync))

with strategy.scope():
    model_path = './CKPT/CKP_ep_283__loss_141.77045_.h5'
    
    vq_vae = Vq_vae(latent_size, gumbel_hard=False)
    vq_vae(0, True)
    vq_vae.summary()
    
    vq_vae.load_weights(model_path)
    
    # 이렇게 하면, transforer의 input으로 들어가는 one-hot 형식의 값을 얻을 수 있음
    for inputs, label in train_dataset:
        encode = vq_vae.encoder(inputs).numpy()
        encode_onehot = tf.cast(tf.equal(encode, tf.math.reduce_max(encode, 2, keepdims=True)), encode.dtype)
    
    # 이렇게 하면, transformer의 output을 vq-vae의 decoder 입력으로 넣을 수 있음
    for inputs, label in train_dataset:
        encode = vq_vae.encoder(inputs).numpy()
        encode_onehot = tf.cast(tf.equal(encode, tf.math.reduce_max(encode, 2, keepdims=True)), encode.dtype)
        
        # 이렇게 이전 layer의 출렫을 넣으면 됨
        decode = vq_vae.decoder(encode_onehot).numpy()

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)
장치의 수: 1
Model: "vqvae"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
softmax_25 (Softmax)         (None, None, 512)         0         
_________________________________________________________________
encoder (Encoder)            (None, None, 512)         517248    
_________________________________________________________________
decoder (Decoder)            (None, None, 1)           516737    
_________________________________________________________________
gumbel_softmax (GumbelSoftma (None, None, 512)         0         
Total params: 1,033,985
Trainable params: 1,033,985
Non-trainable params: 0
_________________________________________________________________


# 3. Test Model

In [110]:
def mkdir_p(path):
    """ Creates a path recursively without throwing an error if it already exists
    :param path: path to create
    :return: None
    """
    if not os.path.exists(path):
        os.makedirs(path)

In [111]:
mkdir_p('./test_wav/') # Result wav 폴더 만드는 코드

In [112]:
def audiowrite(data, path, samplerate=16000, normalize=False, threaded=True):
    """ Write the audio data ``data`` to the wav file ``path``
    The file can be written in a threaded mode. In this case, the writing
    process will be started at a separate thread. Consequently, the file will
    not be written when this function exits.
    :param data: A numpy array with the audio data
    :param path: The wav file the data should be written to
    :param samplerate: Samplerate of the audio data
    :param normalize: Normalize the audio first so that the values are within
        the range of [INTMIN, INTMAX]. E.g. no clipping occurs
    :param threaded: If true, the write process will be started as a separate
        thread
    :return: The number of clipped samples
    """
    data = data.copy()
    int16_max = np.iinfo(np.int16).max
    int16_min = np.iinfo(np.int16).min

    if normalize:
        if not data.dtype.kind == 'f':
            data = data.astype(np.float)
        data /= np.max(np.abs(data))

    if data.dtype.kind == 'f':
        data *= int16_max

    sample_to_clip = np.sum(data > int16_max)
    if sample_to_clip > 0:
        print('Warning, clipping {} samples'.format(sample_to_clip))
    data = np.clip(data, int16_min, int16_max)
    data = data.astype(np.int16)

    if threaded:
        threading.Thread(target=wav_write, args=(path, samplerate, data)).start()
    else:
        wav_write(path, samplerate, data)

    return sample_to_clip

In [115]:
with tf.device('/cpu:0'):
    latent_size = 1024
    sample_rate = 8000
    model_path = './CKPT/CKP_ep_299__loss_70.91695_.h5'
    
    vq_vae = Vq_vae(latent_size, gumbel_hard=True)
    vq_vae(0, True)
    vq_vae.summary()
    vq_vae.load_weights(model_path)

    for batch in test_dataset:
        input_batch, length_batch, name = batch

        result = vq_vae.predict(input_batch)
        
        wav_name = './test_wav/' + name[0][:-5] + '_s1.wav'
        audiowrite(result[0], wav_name, sample_rate, True, True)

Model: "vqvae"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
softmax_20 (Softmax)         (None, None, 1024)        0         
_________________________________________________________________
encoder (Encoder)            (None, None, 1024)        739232    
_________________________________________________________________
embedding_20 (Embedding)     multiple                  0 (unused)
_________________________________________________________________
decoder (Decoder)            (None, None, 1)           1524641   
_________________________________________________________________
gumbel_softmax (GumbelSoftma (None, None, 1024)        0         
Total params: 2,263,873
Trainable params: 2,263,873
Non-trainable params: 0
_________________________________________________________________


# 여기 밑에는 연습장임

In [63]:
import numpy as np
from keras.models import Sequential
from keras.layers import Embedding

model = Sequential()
model.add(layers.Conv1D(filters=64, kernel_size=3, activation='relu', padding='same', input_shape=(None, 1)))
model.add(layers.Conv1D(filters=512, kernel_size=3, padding='same'))

input_array = np.random.randn(2, 10, 1)
input_array2 = np.random.randn(2, 9, 1)
with tf.device('/cpu:0'):
    model.compile('rmsprop', 'mse')

    output_array = model.predict(input_array)
    output_array2 = model.predict(input_array2)

In [64]:
tf.reduce_sum(output_array, axis=[1, 2])

<tf.Tensor: shape=(2,), dtype=float32, numpy=array([-3.2364433, -5.300486 ], dtype=float32)>

In [65]:
print(output_array)
print(output_array.shape)
print(output_array2.shape)

[[[ 0.00213541  0.01945671 -0.01550578 ... -0.00601448  0.01089236
    0.03509172]
  [-0.00979825 -0.01161936 -0.03782888 ... -0.01964907  0.02604836
   -0.01746842]
  [-0.01680236 -0.04395067  0.00300831 ...  0.02191263  0.02651824
   -0.0194194 ]
  ...
  [ 0.03252641 -0.03450746 -0.01116706 ...  0.04157236  0.00767434
    0.00222839]
  [ 0.03689077 -0.00653183  0.0138604  ...  0.03399033  0.01157293
   -0.01727439]
  [ 0.00956812  0.00808328 -0.00985544 ...  0.02565406  0.01499317
    0.0055492 ]]

 [[ 0.00062903 -0.02250458 -0.01209254 ... -0.00480657  0.02518445
    0.00395943]
  [ 0.00067916 -0.03764667 -0.00435403 ... -0.00190964  0.0198026
   -0.01638822]
  [ 0.03713017 -0.02867446 -0.01009638 ...  0.03275302  0.02717083
    0.00666822]
  ...
  [ 0.01911856  0.02970011 -0.00635595 ... -0.01212452  0.03101327
   -0.01940764]
  [-0.04032427 -0.04538781 -0.00670558 ...  0.00985033  0.0156134
   -0.02115321]
  [ 0.0188317  -0.04442241  0.00422117 ...  0.03525497 -0.0081987
   -0.036

In [69]:
one_hot = tf.math.argmax(tf.cast(tf.equal(output_array, tf.math.reduce_max(output_array, 2, keepdims=True)), output_array.dtype), axis=-1)
print(one_hot.shape)
print(one_hot)
# layers.Embedding(512, 512)(one_hot)

(2, 10)
tf.Tensor(
[[228 249 283 206  20 206 435  32 270  30]
 [428 206  20 244 357 289 324 249 498 134]], shape=(2, 10), dtype=int64)


In [70]:
one_hot = tf.math.argmax(output_array, axis=-1)
print(one_hot.shape)
print(one_hot)

(2, 10)
tf.Tensor(
[[228 249 283 206  20 206 435  32 270  30]
 [428 206  20 244 357 289 324 249 498 134]], shape=(2, 10), dtype=int64)


In [50]:
qy = tf.nn.softmax(output_array)
log_qy = tf.math.log(qy + 1e-10)
log_uniform = qy * (log_qy - tf.math.log(1.0 / 512))
kl_loss = tf.reduce_sum(log_uniform, axis=[1, 2])
kl_loss = tf.reduce_mean(kl_loss)

print(kl_loss)

tf.Tensor(0.012074914, shape=(), dtype=float32)


In [30]:
batch_size = tf.shape(output_array)[0]
array1_size = tf.shape(output_array)[1]
array2_size = tf.shape(output_array2)[1]
feature_size = tf.shape(output_array)[-1]

if array1_size < array2_size:
#     append_size = array1_size - array2_size
#     append_zeros = tf.zeros([batch_size, append_size, feature_size])
#     append_zeros = tf.Variable(initial_value=tf.zeros((batch_size, append_size, feature_size)))
#     output_array2 = tf.concat([output_array2, append_zeros], axis=1)
    output_array2 = tf.slice(output_array2, [0, 0, 0], [-1, array1_size, -1])
elif array1_size > array2_size:
#     append_size = array2_size - array1_size
#     append_zeros = tf.zeros([batch_size, append_size, feature_size])
#     append_zeros = tf.Variable(initial_value=tf.zeros((batch_size, append_size, feature_size)))
#     output_array = tf.concat([output_array, append_zeros], axis=1)
    output_array = tf.slice(output_array, [0, 0, 0], [-1, array2_size, -1])

print(output_array.shape)
print(output_array2.shape)
# output_array0 = output_array[1]
# output_array20 = output_array2[1]
# target = np.sum(output_array20 * output_array0) * output_array0 / np.square(np.linalg.norm(output_array0, ord=2))
# noise = output_array20 - target
# npnp = 10 * np.log10(np.square(np.linalg.norm(target, ord=2)) / np.square(np.linalg.norm(noise, ord=2)))
# print(npnp)

target = tf.linalg.matmul(output_array2, output_array, transpose_a=True) * output_array / tf.expand_dims(tf.experimental.numpy.square(tf.norm(output_array, axis=1)), axis=-1)
noise = output_array2 - target
si_sdr = 10 * tf.experimental.numpy.log10(tf.experimental.numpy.square(tf.norm(target, axis=1)) / tf.experimental.numpy.square(tf.norm(noise, axis=1)))
si_sdr = tf.reduce_mean(si_sdr)
print(si_sdr)

(2, 9, 1)
(2, 9, 1)
tf.Tensor(2.8309882, shape=(), dtype=float32)


In [20]:
tf.cast(tf.equal(output_array, tf.math.reduce_max(output_array, 2, keepdims=True)), output_array.dtype)

<tf.Tensor: shape=(2, 9, 4), dtype=float32, numpy=
array([[[1., 0., 0., 0.],
        [0., 0., 1., 0.],
        [1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 1., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [0., 0., 1., 0.]],

       [[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [1., 0., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]]], dtype=float32)>

In [21]:
output_array

<tf.Tensor: shape=(2, 9, 4), dtype=float32, numpy=
array([[[-0.03009652, -0.03612775, -0.06680483, -0.03670201],
        [-0.04768711, -0.12344762, -0.03924457, -0.11762322],
        [ 0.01808495, -0.16106637, -0.19467078, -0.15282159],
        [-0.0986427 , -0.08625205, -0.12661007, -0.16366175],
        [-0.09758376, -0.08886974, -0.0433558 , -0.19985165],
        [-0.06933096, -0.03154394, -0.13725929, -0.20143284],
        [ 0.03375649,  0.00182091, -0.01022564, -0.35924646],
        [-0.01645333, -0.10466891, -0.13975918, -0.12066491],
        [-0.13588801, -0.08173112, -0.00253745, -0.28615874]],

       [[ 0.04865369, -0.02880372, -0.06414615, -0.07730438],
        [-0.08225074, -0.03192509, -0.06216412, -0.08035193],
        [-0.09515338,  0.04221668,  0.14230826, -0.23082384],
        [-0.00094383,  0.05597762, -0.09290768, -0.08630683],
        [-0.09894791, -0.04727853, -0.01004983, -0.30325216],
        [ 0.01705559, -0.16948727, -0.08829505, -0.16453639],
        [-0.07230

In [67]:
output_softmax = tf.nn.softmax(output_array)
output_softmax

<tf.Tensor: shape=(2, 5, 4), dtype=float32, numpy=
array([[[0.24628553, 0.282701  , 0.2385324 , 0.23248109],
        [0.2298986 , 0.23856457, 0.25392184, 0.27761498],
        [0.2101039 , 0.2444843 , 0.26824066, 0.27717113],
        [0.22202027, 0.29728216, 0.2501223 , 0.23057525],
        [0.24084595, 0.2724257 , 0.25000373, 0.23672463]],

       [[0.23988546, 0.27957046, 0.23915865, 0.24138539],
        [0.24975686, 0.27135593, 0.24479878, 0.23408844],
        [0.24278015, 0.26340333, 0.24202275, 0.25179377],
        [0.23139507, 0.262904  , 0.25970972, 0.2459912 ],
        [0.2601803 , 0.25627998, 0.24524413, 0.23829558]]], dtype=float32)>

In [73]:
output_reshape = tf.reshape(output_softmax, [-1, 4])
output_reshape.shape

TensorShape([10, 4])

In [83]:
tf.reshape(tf.nn.softmax(output_array), [-1, 5, 4])

<tf.Tensor: shape=(2, 5, 4), dtype=float32, numpy=
array([[[0.24628553, 0.282701  , 0.2385324 , 0.23248109],
        [0.2298986 , 0.23856457, 0.25392184, 0.27761498],
        [0.2101039 , 0.2444843 , 0.26824066, 0.27717113],
        [0.22202027, 0.29728216, 0.2501223 , 0.23057525],
        [0.24084595, 0.2724257 , 0.25000373, 0.23672463]],

       [[0.23988546, 0.27957046, 0.23915865, 0.24138539],
        [0.24975686, 0.27135593, 0.24479878, 0.23408844],
        [0.24278015, 0.26340333, 0.24202275, 0.25179377],
        [0.23139507, 0.262904  , 0.25970972, 0.2459912 ],
        [0.2601803 , 0.25627998, 0.24524413, 0.23829558]]], dtype=float32)>

In [76]:
# tf.cast(tf.equal(y, tf.reduce_max(y,1,keep_dims=True)), y.dtype)
output_hard = tf.cast(tf.equal(output_reshape, tf.math.reduce_max(output_reshape, 1, keepdims=True)), output_softmax.dtype)
output_hard

<tf.Tensor: shape=(10, 4), dtype=float32, numpy=
array([[0., 1., 0., 0.],
       [0., 0., 0., 1.],
       [0., 0., 0., 1.],
       [0., 1., 0., 0.],
       [0., 1., 0., 0.],
       [0., 1., 0., 0.],
       [0., 1., 0., 0.],
       [0., 1., 0., 0.],
       [0., 1., 0., 0.],
       [1., 0., 0., 0.]], dtype=float32)>

In [81]:
tf.reshape(output_hard, [-1, 5, 4])

<tf.Tensor: shape=(2, 5, 4), dtype=float32, numpy=
array([[[0., 1., 0., 0.],
        [0., 0., 0., 1.],
        [0., 0., 0., 1.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.]],

       [[0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [1., 0., 0., 0.]]], dtype=float32)>

In [322]:
tf.transpose(output_array, perm=[0, 2, 1]).shape

TensorShape([2, 4, 5])

In [316]:
layers.Softmax(output_array)

<tensorflow.python.keras.layers.advanced_activations.Softmax at 0x26f14e934c8>

In [317]:
np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).shape

(3, 3)

In [65]:
np.log(10)

2.302585092994046