# 1. Data Generator
- Raw Data를 읽어옴
- 여기서 만들어진 데이터는 모델의 입력으로 들어감

In [1]:
import os
import numpy as np
# import librosa
from tensorflow.keras.utils import Sequence

import scipy.signal
from scipy.io import wavfile

In [2]:
class RawForVAEGenerator(Sequence):
    def __init__(self, source, wav_dir, files, sourNum='s1', batch_size=10, shuffle=True):
        self.source = source
        self.wav_dir = wav_dir
        self.files = files
        self.sourNum = sourNum
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.on_epoch_end()
        
        self.sample_rate = 8000
    
    def on_epoch_end(self):
        self.indexes = np.arange(len(self.source))
        
        if self.shuffle:
            np.random.shuffle(self.indexes)
    
    def __audioread__(self, path, offset=0.0, duration=None, sample_rate=16000):
#         signal = librosa.load(path, sr=self.sample_rate, mono=False, offset=offset, duration=duration)
        
#         return signal[0]
    
        signal_rate, signal = wavfile.read(path)
        number_of_samples = round(len(signal) * float(self.sample_rate) / signal_rate)
        signal = scipy.signal.resample(signal, number_of_samples)
        signal /= np.max(np.abs(signal),axis=0)

        return signal
    
    def __padding__(self, data):
        n_batch = len(data)
        max_len = max([d.shape[0] for d in data])
        extrapadding = int(np.ceil(max_len / self.sample_rate) * self.sample_rate)
        pad = np.zeros((n_batch, extrapadding))
        
        for i in range(n_batch):
            pad[i, :data[i].shape[0]] = data[i]
        
        return np.expand_dims(pad, -1)
        
    def __data_generation__(self, source_list):
        wav_list = []
        for name in source_list:
            name = name.strip('\n')
            
            s_wav_name = self.wav_dir + self.files + '/' + self.sourNum + '/' + name
            
            # ------- AUDIO READ -------
            s_wav = (self.__audioread__(s_wav_name,  offset=0.0, duration=None, sample_rate=self.sample_rate))
            # --------------------------
            
            # ------- PADDING -------
#             pad_len = max(len(samples1),len(samples2))
#             pad_s1 = np.concatenate([s1_wav, np.zeros([pad_len - len(s1_wav)])])
            
#             extrapadding = ceil(len(pad_s1) / sample_rate) * sample_rate - len(pad_s1)
#             pad_s1 = np.concatenate([pad_s1, np.zeros([extrapadding - len(pad_s1)])])
#             pad_s2 = np.concatenate([s2_wav, np.zeros([extrapadding - len(s2_wav)])])
            # -----------------------
            
            wav_list.append(s_wav)
        
        return wav_list, wav_list, source_list
            
    
    def __len__(self):
        return int(np.floor(len(self.source) / self.batch_size))
    
    def __getitem__(self, index):
        indexes = self.indexes[index * self.batch_size : (index + 1) * self.batch_size]
        source_list = [self.source[k] for k in indexes]
        
        if self.files is not 'tt':
            sour, labels, _ = self.__data_generation__(source_list)
            
            # Get Lengths(K value of each batch)
            lengths = np.array([m.shape[0] for m in sour])
            exp = np.expand_dims(lengths, 1)
            exp = np.expand_dims(exp, -1) # [Batch, 1, 1] (length)
            
            # Padding
            sour_pad = self.__padding__(sour) # [Batch, Time_step, Dimension(=1)]
            label_pad = self.__padding__(labels) # [Batch, Time_step, Dimension(=1)]
            
#             return sour_pad, np.concatenate([label_pad, exp], axis=1)
            return sour_pad, label_pad
        else:
            sour, labels, name = self.__data_generation__(source_list)
            
            # Get Lengths(K value of each batch)
            lengths = np.array([m.shape[0] for m in sour])
            exp = np.expand_dims(lengths, 1)
            exp = np.expand_dims(exp, -1) # [Batch, 1, 1] (length)
            
            # Padding
            sour_pad = self.__padding__(sour) # [Batch, Time_step, Dimension(=1)]
            
            return sour_pad, exp, name

## Data를 어떻게 읽는지에 대한 부분

In [3]:
WAV_DIR = './mycode/wsj0_2mix/use_this/'
LIST_DIR = './mycode/wsj0_2mix/use_this/lists/'

In [4]:
# Directory List file create

wav_dir = WAV_DIR
output_lst = LIST_DIR

for folder in ['tr', 'cv', 'tt']:
    wav_files = os.listdir(wav_dir + folder + '/mix')
    output_lst_files = output_lst + folder + '_wav.lst'
    with open(output_lst_files, 'w') as f:
        for file in wav_files:
            f.write(file + "\n")

print("Generate wav file to .lst done!")

Generate wav file to .lst done!


In [5]:
batch_size_for_generator = 1

train_dataset = 0
valid_dataset = 0
test_dataset = 0

name_list = []
for files in ['tr', 'cv', 'tt']:
    # --- Lead lst file ---""
    output_lst_files = LIST_DIR + files + '_wav.lst'
    fid = open(output_lst_files, 'r')
    lines = fid.readlines()
    fid.close()
    # ---------------------
    
    if files == 'tr':
        train_dataset = RawForVAEGenerator(lines, WAV_DIR, files, 's1', batch_size_for_generator)
    elif files == 'cv':
        valid_dataset = RawForVAEGenerator(lines, WAV_DIR, files, 's1', batch_size_for_generator)
    else:
        test_batch = 1
        test_dataset = RawForVAEGenerator(lines, WAV_DIR, files, 's1', test_batch)

# 2. Building VQ-VAE model with Gumbel Softmax

In [6]:
import threading
from scipy.io.wavfile import write as wav_write
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow import keras
import matplotlib.pyplot as plt
from tqdm import tqdm_notebook as tqdm
from tensorflow.keras import backend as Kb
import numpy as np
import pandas as pd
from importlib import reload
import time
from tensorflow.keras.models import Model, Sequential, load_model

In [7]:
def mkdir_p(path):
    """ Creates a path recursively without throwing an error if it already exists
    :param path: path to create
    :return: None
    """
    if not os.path.exists(path):
        os.makedirs(path)

In [8]:
mkdir_p('./CKPT/') # model check point 폴더 만드는 코드
filepath = "./CKPT/CKP_ep_{epoch:d}__loss_{val_loss:.5f}_.h5"

In [9]:
initial_learning_rate = 0.001

# learning rate를 점점 줄이는 부분
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate, decay_steps=20, decay_rate=0.96, staircase=True
)

# validation loss에 대해서 좋은 것만 저장됨
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath, monitor='val_loss', verbose=1, save_best_only=True, save_weights_only=True, mode='min'
)

# early stop 하는 부분인데, validation loss에 대해서 제일 좋은 모델이 저장됨
early_stopping_cb = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', mode='min', verbose=1, patience=50, restore_best_weights=True
)

In [10]:
class GumbelSoftmax(layers.Layer):
    def __init__(self, temperature=0.5, hard=False, name = 'gumbel_softmax',**kwargs):
        super(GumbelSoftmax, self).__init__(name=name, **kwargs)
        
        self.temperature = temperature
        self.hard = hard
    
    def sample_gumbel(self, shape, eps=1e-20): 
        """Sample from Gumbel(0, 1)"""
        U = tf.random.uniform(shape,minval=0,maxval=1)
        
        return -tf.math.log(-tf.math.log(U + eps) + eps)

    def gumbel_softmax_sample(self, logits, temperature): 
        """ Draw a sample from the Gumbel-Softmax distribution"""
        y = logits + self.sample_gumbel(tf.shape(logits))
        
        return tf.nn.softmax(y / temperature)

    def call(self, inputs):
        y = self.gumbel_softmax_sample(inputs, self.temperature)
        
        if self.hard:
            y_hard = tf.cast(tf.equal(y, tf.math.reduce_max(y, 2, keepdims=True)), y.dtype)
            y = tf.stop_gradient(y_hard - y) + y
        
        return y


class Encoder(layers.Layer):
    def __init__(self, latent_dim, name = 'encoder',**kwargs):
        super(Encoder, self).__init__(name=name, **kwargs)
        
        self.conv1d_1 = layers.Conv1D(filters=32, kernel_size=4, strides=2, activation='relu', padding='same')
        self.conv1d_2 = layers.Conv1D(filters=128, kernel_size=4, strides=2, activation='relu', padding='same')
        self.conv1d_3 = layers.Conv1D(filters=128, kernel_size=4, strides=2, activation='relu', padding='same')
        self.conv1d_4 = layers.Conv1D(filters=256, kernel_size=4, strides=2, activation='relu', padding='same')
        self.conv1d_5 = layers.Conv1D(filters=512, kernel_size=4, strides=2, activation='relu', padding='same')
        self.logit = layers.Conv1D(filters=latent_dim, kernel_size=1, strides=1, activation=None, padding='valid')
    
    def call(self, inputs):
        x = self.conv1d_1(inputs)
        x = self.conv1d_2(x)
        x = self.conv1d_3(x)
        x = self.conv1d_4(x)
        x = self.conv1d_5(x)
        logit = self.logit(x)
        
        return logit


class Decoder(layers.Layer):
    def __init__(self, latent_dim, name = 'decoder',**kwargs):
        super(Decoder, self).__init__(name=name, **kwargs)
        
        self.trans_conv1d_1 = layers.Conv1DTranspose(filters=512, kernel_size=4, strides=2, activation='relu', padding='same')
        self.trans_conv1d_2 = layers.Conv1DTranspose(filters=256, kernel_size=4, strides=2, activation='relu', padding='same')
        self.trans_conv1d_3 = layers.Conv1DTranspose(filters=128, kernel_size=4, strides=2, activation='relu', padding='same')
        self.trans_conv1d_4 = layers.Conv1DTranspose(filters=128, kernel_size=4, strides=2, activation='relu', padding='same')
        self.trans_conv1d_5 = layers.Conv1DTranspose(filters=32, kernel_size=4, strides=2, activation='relu', padding='same')
        self.logit = layers.Conv1DTranspose(filters=1, kernel_size=1, strides=1, activation=None, padding='valid')
    
    def call(self, inputs):
        x = self.trans_conv1d_1(inputs)
        x = self.trans_conv1d_2(x)
        x = self.trans_conv1d_3(x)
        x = self.trans_conv1d_4(x)
        x = self.trans_conv1d_5(x)
        logit = self.logit(x)
        
        return logit

In [11]:
# Custom Metric Si-sdr

class SiSdr(keras.metrics.Metric):
    def __init__(self, name="Si-sdr", **kwargs):
        super(SiSdr, self).__init__(name=name, **kwargs)
        self.sdr = self.add_weight(name="sdr", initializer="zeros")
        self.count = self.add_weight(name="cnt", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        ori_length = tf.shape(y_true)[1]
        
        # Label & Length divide
        labels = tf.slice(y_true, [0, 0, 0], [-1, ori_length-1, -1]) # [batch_size, length_size, 1]
        lengths = tf.slice(y_true, [0, ori_length-1, 0], [-1, -1, 1]) # [batch_size, 1, 1]
        
        # Check sequence length
        batch_size = tf.shape(labels)[0]
        label_size = tf.shape(labels)[1]
        pred_size = tf.shape(y_pred)[1]
        feature_size = tf.shape(labels)[-1]
        
        # Change sequence length
        if label_size < pred_size:
            y_pred = tf.slice(y_pred, [0, 0, 0], [-1, label_size, -1])
        elif label_size > pred_size:
            labels = tf.slice(labels, [0, 0, 0], [-1, pred_size, -1])

        # SI-SDR
        target = tf.linalg.matmul(y_pred, labels, transpose_a=True) * labels / tf.expand_dims(tf.experimental.numpy.square(tf.norm(labels, axis=1)), axis=-1)
        noise = y_pred - target
        values = 10 * tf.experimental.numpy.log10(tf.experimental.numpy.square(tf.norm(target, axis=1)) / tf.experimental.numpy.square(tf.norm(noise, axis=1)))
        
        if sample_weight is not None:
            sample_weight = tf.cast(sample_weight, "float32")
            values = tf.multiply(values, sample_weight)
        self.sdr.assign_add(tf.reduce_sum(values))
        self.count.assign_add(tf.cast(tf.shape(y_true)[0], tf.float32))

    def result(self):
        return self.sdr / self.count

    def reset_states(self):
        # The state of the metric will be reset at the start of each epoch.
        self.sdr.assign(0.0)
        self.count.assign(0.0)

In [12]:
# Custom loss

# Custom mse
def custom_mse(y_true, y_pred):
    ori_length = tf.shape(y_true)[1]

    # Label & Length divide
    labels = tf.slice(y_true, [0, 0, 0], [-1, ori_length-1, -1]) # [batch_size, length_size, 129]
    lengths = tf.slice(y_true, [0, ori_length-1, 0], [-1, -1, 1]) # [batch_size, 1, 1]

    loss = tf.reduce_sum(tf.pow(y_pred - y_true, 2), axis=[1, 2])
    loss = tf.reduce_mean(loss)

    return loss


# Custom si-sdr loss
def custom_sisdr_loss(y_true, y_pred):
    ori_length = tf.shape(y_true)[1]

    # Label & Length divide
    labels = tf.slice(y_true, [0, 0, 0], [-1, ori_length-1, -1]) # [batch_size, length_size, 1]
    lengths = tf.slice(y_true, [0, ori_length-1, 0], [-1, -1, 1]) # [batch_size, 1, 1]

    target = tf.linalg.matmul(y_pred, labels, transpose_a=True) * labels / tf.expand_dims(tf.experimental.numpy.square(tf.norm(labels, axis=1)), axis=-1)
    noise = y_pred - target
    si_sdr = 10 * tf.experimental.numpy.log10(tf.experimental.numpy.square(tf.norm(target, axis=1)) / tf.experimental.numpy.square(tf.norm(noise, axis=1)))
    si_sdr = tf.reduce_mean(si_sdr) * -1

    return si_sdr

In [15]:
class Vq_vae(keras.Model):
    def __init__(self, latent_dim, gumbel_hard=False, for_predict=False, name='vqvae', **kwargs):
        super(Vq_vae, self).__init__(name=name, **kwargs)
        
        self.for_predict = for_predict
        
        self.latent_dim = latent_dim
        self.softmax = layers.Softmax(-1)
        
        self.encoder = Encoder(latent_dim)
        self.gumbel = GumbelSoftmax(hard=gumbel_hard)
        self.sampled = layers.experimental.EinsumDense('bsc,cd->bsd',
                                                       output_shape=(None, latent_dim),
                                                       bias_axes='d')
        self.decoder = Decoder(latent_dim)
        
    def call(self, inputs, load=False):
        if load:
            inputs = layers.Input(shape=(None, 1))
        
        
        if self.for_predict:
            encode = self.encoder(inputs)
            one_hot_enc = tf.cast(tf.equal(encode, tf.math.reduce_max(encode, 2, keepdims=True)), encode.dtype)
            sample = self.sampled(one_hot_enc)
            decode = self.decoder(sample)
        else:
            encode = self.encoder(inputs)
            gumbel = self.gumbel(encode)
            sample = self.sampled(gumbel)
            decode = self.decoder(sample)
        
        # ------------------ KL loss ------------------
        qy = self.softmax(encode)
        log_qy = tf.math.log(qy + 1e-10)
        log_uniform = qy * (log_qy - tf.math.log(1.0 / self.latent_dim))
        kl_loss = tf.reduce_sum(log_uniform, axis=[1, 2])
        kl_loss = tf.reduce_mean(kl_loss) * 0.2
        # ---------------------------------------------
        
        self.add_loss(kl_loss)
        
        return decode

# 이렇게 GradientTape 를 사용해서 프로그램 해도 됨

In [23]:
tf.random.set_seed(42)

latent_size = 1024
epochs = 3

filePath = "./CKPT/CKP_ep_{0}__loss_{1:.5f}_.h5"
model_path = './CKPT/CKP_ep_576__loss_168.75122_.h5'

loss_fun = custom_mse
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)

train_loss = tf.keras.metrics.Mean()
train_kl_loss = tf.keras.metrics.Mean()
valid_loss = tf.keras.metrics.Mean()
sisdr_Metric = SiSdr()
val_sisdr_Metric = SiSdr()

# Model 불러오는 부분이다
vq_vae = Vq_vae(latent_size, gumbel_hard=False)
vq_vae(0, True)
# vq_vae.load_weights(model_path)

<KerasTensor: shape=(None, None, 1) dtype=float32 (created by layer 'decoder')>

In [31]:
@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        # Call model
        results = vq_vae(x)
        
        loss_value = loss_fun(y, results)
        loss_value += sum(vq_vae.losses) # Add KL loss
    
    # Update weights
    grads = tape.gradient(loss_value, vq_vae.trainable_weights)
    optimizer.apply_gradients(zip(grads, vq_vae.trainable_weights))
    
    # Update loss and si-sdr
    train_loss.update_state(loss_value)
    sisdr_Metric.update_state(y, results)
    
    train_kl_loss.update_state(sum(vq_vae.losses))
    
    return loss_value

@tf.function
def test_step(x, y):
    # Call model
    val_results = vq_vae(x)
    
    val_loss_value = loss_fun(y, val_results)
    val_loss_value += sum(vq_vae.losses) # Add KL loss
    
    # Update loss and si-sdr
    valid_loss.update_state(val_loss_value)
    val_sisdr_Metric.update_state(y, val_results)
    
    return val_loss_value

In [35]:
previous_loss = float('inf')

for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch+1,))
    start_time = time.time()

    # Iterate over the batches of the dataset
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        x_batch_train = tf.cast(x_batch_train, dtype=tf.float32)
        y_batch_train = tf.cast(y_batch_train, dtype=tf.float32)
        
        loss_value = train_step(x_batch_train, y_batch_train)

        # Log every 1 batches
#         if step % 1 == 0:
#             print("Training loss (for one batch) at step %d: %.4f" % (step, train_loss.result()))
#             print("Training Si-sdr (for one batch) at step %d: %.4f" % (step, sisdr_Metric.result()))
#             print("Seen so far: %d samples" % ((step + 1) * batch_size))

    # Run a validation loop at the end of each epoch
    for x_batch_val, y_batch_val in valid_dataset:
        x_batch_val = tf.cast(x_batch_val, dtype=tf.float32)
        y_batch_val = tf.cast(y_batch_val, dtype=tf.float32)
        
        val_loss_value = test_step(x_batch_val, y_batch_val)
    
    print()
    print('----------------------------------------------------------------------------------')
    print("Time taken >>> %.2fs <<<" % (time.time() - start_time))
    print('epoch: {}, Train_loss: {}, Train_Si-sdr: {}, Train_KL_loss: {} \n\
    Valid_loss: {}, Valid_Si-sdr: {}'.format(
        epoch+1,
        train_loss.result(),
        sisdr_Metric.result(),
        train_kl_loss.result(),
        valid_loss.result(),
        val_sisdr_Metric.result()))
    print('----------------------------------------------------------------------------------')
    
    # Save Model
    if valid_loss.result() < previous_loss:
        filePath_temp = filePath.format(epoch+1, valid_loss.result())
        
        vq_vae.save_weights(filePath_temp)
        print('Epoch {}: val_loss improved from {} to {}, saving model to {}'.format(
            epoch+1,
            previous_loss,
            valid_loss.result(),
            filePath_temp))
        
        previous_loss = valid_loss.result()
    else:
        print('Epoch {}: val_loss did not improve from {}'.format(
            epoch+1,
            previous_loss))
    print()
    
    # Reset metrics at the end of each epoch
    train_loss.reset_states()
    sisdr_Metric.reset_states()
    valid_loss.reset_states()
    val_sisdr_Metric.reset_states()
    
    train_kl_loss.reset_states()
    
    # Data shuffle at the end of each epoch
    train_dataset.on_epoch_end()
    valid_dataset.on_epoch_end()


Start of epoch 1
(2, 80000, 1)
(2, 64000, 1)

----------------------------------------------------------------------------------
Time taken >>> 1.70s <<<
epoch: 1, Train_loss: 468.0998840332031, Train_Si-sdr: -38.02525329589844, Train_KL_loss: 0.0 
    Valid_loss: 468.36773681640625, Valid_Si-sdr: -40.73042678833008
----------------------------------------------------------------------------------
Epoch 1: val_loss improved from inf to 468.36773681640625, saving model to ./CKPT/CKP_ep_1__loss_468.36774_.h5


Start of epoch 2
(2, 56000, 1)
(2, 80000, 1)

----------------------------------------------------------------------------------
Time taken >>> 1.63s <<<
epoch: 2, Train_loss: 468.0527648925781, Train_Si-sdr: -34.70179748535156, Train_KL_loss: 0.0 
    Valid_loss: 467.18829345703125, Valid_Si-sdr: -24.99163246154785
----------------------------------------------------------------------------------
Epoch 2: val_loss improved from 468.36773681640625 to 467.18829345703125, saving mod

In [14]:
def gen_train_data_generator():
    for i in range(train_dataset.__len__()):
        data = np.squeeze(train_dataset.__getitem__(i)[0], axis=0)
        label = np.squeeze(train_dataset.__getitem__(i)[1], axis=0)
        
        yield (data, label)

def gen_valid_data_generator():
    for i in range(valid_dataset.__len__()):
        data = np.squeeze(valid_dataset.__getitem__(i)[0], axis=0)
        label = np.squeeze(valid_dataset.__getitem__(i)[1], axis=0)
        
        yield (data, label)

# 여기는 기존의 .fit() 함수를 사용해서 학습하는 부분임

In [20]:
# import tensorflow.compat.v1 as tf
# tf.disable_v2_behavior()

tf.random.set_seed(42)

batch_size = 10
latent_size = 1024
epoch = 600

strategy = tf.distribute.MirroredStrategy(['/gpu:0'])
print('장치의 수: {}'.format(strategy.num_replicas_in_sync))

with strategy.scope():
    model_path = './CKPT/CKP_ep_594__loss_229.89435_.h5'
    
    loss_fun = custom_mse
#     loss_fun = custom_sisdr_loss
    
    vq_vae = Vq_vae(latent_size, gumbel_hard=False)

    optimizer = keras.optimizers.Adam(learning_rate=1e-4)
    vq_vae.compile(optimizer, loss=loss_fun, metrics=[SiSdr()])
    
    vq_vae(0, True)
    vq_vae.summary()
    
    # 사용 안할 때는 load_model 주석 처리 하자
    vq_vae.load_weights(model_path)
    # ----------------------------------------
    
    tf.executing_eagerly()

    
train_data = tf.data.Dataset.from_generator(gen_train_data_generator, output_signature=(
                                            tf.TensorSpec(shape=(None, 1), dtype=tf.float32),
                                            tf.TensorSpec(shape=(None, 1), dtype=tf.float32)))
train_data = train_data.shuffle(train_dataset.__len__()).padded_batch(batch_size)

val_data = tf.data.Dataset.from_generator(gen_valid_data_generator, output_signature=(
                                            tf.TensorSpec(shape=(None, 1), dtype=tf.float32),
                                            tf.TensorSpec(shape=(None, 1), dtype=tf.float32)))
val_data = val_data.padded_batch(batch_size)

# Disable AutoShard.
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
train_data = train_data.with_options(options)
val_data = val_data.with_options(options)
    
history = vq_vae.fit(
    train_data,
    epochs=epoch,
    validation_data=val_data,
    shuffle=True,
    callbacks=[checkpoint_cb],
)

# history = vq_vae.fit_generator(
#     generator=train_dataset,
#     validation_data=valid_dataset,
#     epochs=epoch,
#     use_multiprocessing=False,
#     shuffle=True,
#     callbacks=[checkpoint_cb],
# )

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
장치의 수: 1
Model: "vqvae"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 softmax_4 (Softmax)         (None, None, 1024)        0         
                                                                 
 encoder (Encoder)           (None, None, 1024)        1263776   
                                                                 
 gumbel_softmax (GumbelSoftm  (None, None, 1024)       0         
 ax)                                                             
                                                                 
 einsum_dense_4 (EinsumDense  (None, None, 1024)       1049600   
 )                                                               
                                                                 
 decoder (Decoder)           (None, None, 1)           2835521   
                

Epoch 25/600
Epoch 00025: val_loss did not improve from 229.88667
Epoch 26/600
Epoch 00026: val_loss did not improve from 229.88667
Epoch 27/600
Epoch 00027: val_loss did not improve from 229.88667
Epoch 28/600
Epoch 00028: val_loss did not improve from 229.88667
Epoch 29/600
Epoch 00029: val_loss did not improve from 229.88667
Epoch 30/600
Epoch 00030: val_loss did not improve from 229.88667
Epoch 31/600
Epoch 00031: val_loss did not improve from 229.88667
Epoch 32/600
Epoch 00032: val_loss did not improve from 229.88667
Epoch 33/600
Epoch 00033: val_loss did not improve from 229.88667
Epoch 34/600
Epoch 00034: val_loss did not improve from 229.88667
Epoch 35/600
Epoch 00035: val_loss did not improve from 229.88667
Epoch 36/600
Epoch 00036: val_loss did not improve from 229.88667
Epoch 37/600
Epoch 00037: val_loss did not improve from 229.88667
Epoch 38/600
Epoch 00038: val_loss did not improve from 229.88667
Epoch 39/600
Epoch 00039: val_loss did not improve from 229.88667
Epoch 40/6

Epoch 54/600
Epoch 00054: val_loss did not improve from 229.72289
Epoch 55/600
Epoch 00055: val_loss did not improve from 229.72289
Epoch 56/600
Epoch 00056: val_loss did not improve from 229.72289
Epoch 57/600
Epoch 00057: val_loss did not improve from 229.72289
Epoch 58/600
Epoch 00058: val_loss did not improve from 229.72289
Epoch 59/600
Epoch 00059: val_loss did not improve from 229.72289
Epoch 60/600
Epoch 00060: val_loss did not improve from 229.72289
Epoch 61/600
Epoch 00061: val_loss did not improve from 229.72289
Epoch 62/600
Epoch 00062: val_loss improved from 229.72289 to 228.70078, saving model to ./CKPT\CKP_ep_62__loss_228.70078_.h5
Epoch 63/600
Epoch 00063: val_loss did not improve from 228.70078
Epoch 64/600
Epoch 00064: val_loss did not improve from 228.70078
Epoch 65/600
Epoch 00065: val_loss did not improve from 228.70078
Epoch 66/600
Epoch 00066: val_loss did not improve from 228.70078
Epoch 67/600
Epoch 00067: val_loss did not improve from 228.70078
Epoch 68/600
Epo

Epoch 83/600
Epoch 00083: val_loss did not improve from 228.70078
Epoch 84/600
Epoch 00084: val_loss did not improve from 228.70078
Epoch 85/600
Epoch 00085: val_loss did not improve from 228.70078
Epoch 86/600
Epoch 00086: val_loss did not improve from 228.70078
Epoch 87/600
Epoch 00087: val_loss did not improve from 228.70078
Epoch 88/600
Epoch 00088: val_loss did not improve from 228.70078
Epoch 89/600
Epoch 00089: val_loss did not improve from 228.70078
Epoch 90/600
Epoch 00090: val_loss did not improve from 228.70078
Epoch 91/600
Epoch 00091: val_loss did not improve from 228.70078
Epoch 92/600
Epoch 00092: val_loss did not improve from 228.70078
Epoch 93/600
Epoch 00093: val_loss did not improve from 228.70078
Epoch 94/600
Epoch 00094: val_loss did not improve from 228.70078
Epoch 95/600
Epoch 00095: val_loss did not improve from 228.70078
Epoch 96/600
Epoch 00096: val_loss did not improve from 228.70078
Epoch 97/600
Epoch 00097: val_loss did not improve from 228.70078
Epoch 98/6

Epoch 112/600
Epoch 00112: val_loss did not improve from 228.70078
Epoch 113/600
Epoch 00113: val_loss did not improve from 228.70078
Epoch 114/600
Epoch 00114: val_loss did not improve from 228.70078
Epoch 115/600
Epoch 00115: val_loss did not improve from 228.70078
Epoch 116/600
Epoch 00116: val_loss did not improve from 228.70078
Epoch 117/600
Epoch 00117: val_loss did not improve from 228.70078
Epoch 118/600
Epoch 00118: val_loss did not improve from 228.70078
Epoch 119/600
Epoch 00119: val_loss did not improve from 228.70078
Epoch 120/600
Epoch 00120: val_loss did not improve from 228.70078
Epoch 121/600
Epoch 00121: val_loss did not improve from 228.70078
Epoch 122/600
Epoch 00122: val_loss did not improve from 228.70078
Epoch 123/600
Epoch 00123: val_loss did not improve from 228.70078
Epoch 124/600
Epoch 00124: val_loss did not improve from 228.70078
Epoch 125/600
Epoch 00125: val_loss did not improve from 228.70078
Epoch 126/600
Epoch 00126: val_loss did not improve from 228.7

Epoch 141/600
Epoch 00141: val_loss did not improve from 228.70078
Epoch 142/600
Epoch 00142: val_loss did not improve from 228.70078
Epoch 143/600
Epoch 00143: val_loss did not improve from 228.70078
Epoch 144/600
Epoch 00144: val_loss did not improve from 228.70078
Epoch 145/600
Epoch 00145: val_loss did not improve from 228.70078
Epoch 146/600
Epoch 00146: val_loss did not improve from 228.70078
Epoch 147/600
Epoch 00147: val_loss did not improve from 228.70078
Epoch 148/600
Epoch 00148: val_loss did not improve from 228.70078
Epoch 149/600
Epoch 00149: val_loss did not improve from 228.70078
Epoch 150/600
Epoch 00150: val_loss did not improve from 228.70078
Epoch 151/600
Epoch 00151: val_loss did not improve from 228.70078
Epoch 152/600
Epoch 00152: val_loss did not improve from 228.70078
Epoch 153/600
Epoch 00153: val_loss did not improve from 228.70078
Epoch 154/600
Epoch 00154: val_loss did not improve from 228.70078
Epoch 155/600
Epoch 00155: val_loss did not improve from 228.7

Epoch 170/600
Epoch 00170: val_loss did not improve from 228.70078
Epoch 171/600
Epoch 00171: val_loss did not improve from 228.70078
Epoch 172/600
Epoch 00172: val_loss did not improve from 228.70078
Epoch 173/600
Epoch 00173: val_loss improved from 228.70078 to 228.02232, saving model to ./CKPT\CKP_ep_173__loss_228.02232_.h5
Epoch 174/600
Epoch 00174: val_loss did not improve from 228.02232
Epoch 175/600
Epoch 00175: val_loss did not improve from 228.02232
Epoch 176/600
Epoch 00176: val_loss did not improve from 228.02232
Epoch 177/600
Epoch 00177: val_loss did not improve from 228.02232
Epoch 178/600
Epoch 00178: val_loss did not improve from 228.02232
Epoch 179/600
Epoch 00179: val_loss did not improve from 228.02232
Epoch 180/600
Epoch 00180: val_loss did not improve from 228.02232
Epoch 181/600
Epoch 00181: val_loss did not improve from 228.02232
Epoch 182/600
Epoch 00182: val_loss did not improve from 228.02232
Epoch 183/600
Epoch 00183: val_loss did not improve from 228.02232
E

Epoch 199/600
Epoch 00199: val_loss did not improve from 228.02232
Epoch 200/600
Epoch 00200: val_loss did not improve from 228.02232
Epoch 201/600
Epoch 00201: val_loss did not improve from 228.02232
Epoch 202/600
Epoch 00202: val_loss did not improve from 228.02232
Epoch 203/600
Epoch 00203: val_loss did not improve from 228.02232
Epoch 204/600
Epoch 00204: val_loss did not improve from 228.02232
Epoch 205/600
Epoch 00205: val_loss did not improve from 228.02232
Epoch 206/600
Epoch 00206: val_loss did not improve from 228.02232
Epoch 207/600
Epoch 00207: val_loss did not improve from 228.02232
Epoch 208/600
Epoch 00208: val_loss did not improve from 228.02232
Epoch 209/600
Epoch 00209: val_loss did not improve from 228.02232
Epoch 210/600
Epoch 00210: val_loss did not improve from 228.02232
Epoch 211/600
Epoch 00211: val_loss did not improve from 228.02232
Epoch 212/600
Epoch 00212: val_loss did not improve from 228.02232
Epoch 213/600
Epoch 00213: val_loss did not improve from 228.0

Epoch 228/600
Epoch 00228: val_loss did not improve from 227.94547
Epoch 229/600
Epoch 00229: val_loss did not improve from 227.94547
Epoch 230/600
Epoch 00230: val_loss did not improve from 227.94547
Epoch 231/600
Epoch 00231: val_loss did not improve from 227.94547
Epoch 232/600
Epoch 00232: val_loss did not improve from 227.94547
Epoch 233/600
Epoch 00233: val_loss did not improve from 227.94547
Epoch 234/600
Epoch 00234: val_loss did not improve from 227.94547
Epoch 235/600
Epoch 00235: val_loss did not improve from 227.94547
Epoch 236/600
Epoch 00236: val_loss did not improve from 227.94547
Epoch 237/600
Epoch 00237: val_loss did not improve from 227.94547
Epoch 238/600
Epoch 00238: val_loss did not improve from 227.94547
Epoch 239/600
Epoch 00239: val_loss did not improve from 227.94547
Epoch 240/600
Epoch 00240: val_loss did not improve from 227.94547
Epoch 241/600
Epoch 00241: val_loss did not improve from 227.94547
Epoch 242/600
Epoch 00242: val_loss did not improve from 227.9

Epoch 257/600
Epoch 00257: val_loss did not improve from 227.94547
Epoch 258/600
Epoch 00258: val_loss did not improve from 227.94547
Epoch 259/600
Epoch 00259: val_loss did not improve from 227.94547
Epoch 260/600
Epoch 00260: val_loss did not improve from 227.94547
Epoch 261/600
Epoch 00261: val_loss did not improve from 227.94547
Epoch 262/600
Epoch 00262: val_loss did not improve from 227.94547
Epoch 263/600
Epoch 00263: val_loss did not improve from 227.94547
Epoch 264/600
Epoch 00264: val_loss did not improve from 227.94547
Epoch 265/600
Epoch 00265: val_loss did not improve from 227.94547
Epoch 266/600
Epoch 00266: val_loss did not improve from 227.94547
Epoch 267/600
Epoch 00267: val_loss did not improve from 227.94547
Epoch 268/600
Epoch 00268: val_loss did not improve from 227.94547
Epoch 269/600
Epoch 00269: val_loss did not improve from 227.94547
Epoch 270/600
Epoch 00270: val_loss did not improve from 227.94547
Epoch 271/600
Epoch 00271: val_loss did not improve from 227.9

Epoch 314/600
Epoch 00314: val_loss did not improve from 227.24799
Epoch 315/600
Epoch 00315: val_loss did not improve from 227.24799
Epoch 316/600
Epoch 00316: val_loss did not improve from 227.24799
Epoch 317/600
Epoch 00317: val_loss did not improve from 227.24799
Epoch 318/600
Epoch 00318: val_loss did not improve from 227.24799
Epoch 319/600
Epoch 00319: val_loss did not improve from 227.24799
Epoch 320/600
Epoch 00320: val_loss did not improve from 227.24799
Epoch 321/600
Epoch 00321: val_loss did not improve from 227.24799
Epoch 322/600
Epoch 00322: val_loss did not improve from 227.24799
Epoch 323/600
Epoch 00323: val_loss did not improve from 227.24799
Epoch 324/600
Epoch 00324: val_loss did not improve from 227.24799
Epoch 325/600
Epoch 00325: val_loss did not improve from 227.24799
Epoch 326/600
Epoch 00326: val_loss did not improve from 227.24799
Epoch 327/600
Epoch 00327: val_loss did not improve from 227.24799
Epoch 328/600
Epoch 00328: val_loss did not improve from 227.2

Epoch 343/600
Epoch 00343: val_loss did not improve from 227.24799
Epoch 344/600
Epoch 00344: val_loss did not improve from 227.24799
Epoch 345/600
Epoch 00345: val_loss did not improve from 227.24799
Epoch 346/600
Epoch 00346: val_loss did not improve from 227.24799
Epoch 347/600
Epoch 00347: val_loss did not improve from 227.24799
Epoch 348/600
Epoch 00348: val_loss did not improve from 227.24799
Epoch 349/600
Epoch 00349: val_loss did not improve from 227.24799
Epoch 350/600
Epoch 00350: val_loss did not improve from 227.24799
Epoch 351/600
Epoch 00351: val_loss did not improve from 227.24799
Epoch 352/600
Epoch 00352: val_loss did not improve from 227.24799
Epoch 353/600
Epoch 00353: val_loss did not improve from 227.24799
Epoch 354/600
Epoch 00354: val_loss did not improve from 227.24799
Epoch 355/600
Epoch 00355: val_loss did not improve from 227.24799
Epoch 356/600
Epoch 00356: val_loss did not improve from 227.24799
Epoch 357/600
Epoch 00357: val_loss did not improve from 227.2

Epoch 372/600
Epoch 00372: val_loss did not improve from 227.24799
Epoch 373/600
Epoch 00373: val_loss did not improve from 227.24799
Epoch 374/600
Epoch 00374: val_loss did not improve from 227.24799
Epoch 375/600
Epoch 00375: val_loss did not improve from 227.24799
Epoch 376/600
Epoch 00376: val_loss did not improve from 227.24799
Epoch 377/600
Epoch 00377: val_loss did not improve from 227.24799
Epoch 378/600
Epoch 00378: val_loss did not improve from 227.24799
Epoch 379/600
Epoch 00379: val_loss did not improve from 227.24799
Epoch 380/600
Epoch 00380: val_loss did not improve from 227.24799
Epoch 381/600
Epoch 00381: val_loss did not improve from 227.24799
Epoch 382/600
Epoch 00382: val_loss did not improve from 227.24799
Epoch 383/600
Epoch 00383: val_loss did not improve from 227.24799
Epoch 384/600
Epoch 00384: val_loss did not improve from 227.24799
Epoch 385/600
Epoch 00385: val_loss did not improve from 227.24799
Epoch 386/600
Epoch 00386: val_loss did not improve from 227.2

Epoch 401/600
Epoch 00401: val_loss did not improve from 227.24799
Epoch 402/600
Epoch 00402: val_loss did not improve from 227.24799
Epoch 403/600
Epoch 00403: val_loss did not improve from 227.24799
Epoch 404/600
Epoch 00404: val_loss did not improve from 227.24799
Epoch 405/600
Epoch 00405: val_loss did not improve from 227.24799
Epoch 406/600
Epoch 00406: val_loss did not improve from 227.24799
Epoch 407/600
Epoch 00407: val_loss did not improve from 227.24799
Epoch 408/600
Epoch 00408: val_loss did not improve from 227.24799
Epoch 409/600
Epoch 00409: val_loss did not improve from 227.24799
Epoch 410/600
Epoch 00410: val_loss did not improve from 227.24799
Epoch 411/600
Epoch 00411: val_loss did not improve from 227.24799
Epoch 412/600
Epoch 00412: val_loss did not improve from 227.24799
Epoch 413/600
Epoch 00413: val_loss did not improve from 227.24799
Epoch 414/600
Epoch 00414: val_loss did not improve from 227.24799
Epoch 415/600
Epoch 00415: val_loss did not improve from 227.2

Epoch 430/600
Epoch 00430: val_loss did not improve from 226.30409
Epoch 431/600
Epoch 00431: val_loss did not improve from 226.30409
Epoch 432/600
Epoch 00432: val_loss did not improve from 226.30409
Epoch 433/600
Epoch 00433: val_loss did not improve from 226.30409
Epoch 434/600
Epoch 00434: val_loss did not improve from 226.30409
Epoch 435/600
Epoch 00435: val_loss did not improve from 226.30409
Epoch 436/600
Epoch 00436: val_loss did not improve from 226.30409
Epoch 437/600
Epoch 00437: val_loss did not improve from 226.30409
Epoch 438/600
Epoch 00438: val_loss did not improve from 226.30409
Epoch 439/600
Epoch 00439: val_loss did not improve from 226.30409
Epoch 440/600
Epoch 00440: val_loss did not improve from 226.30409
Epoch 441/600
Epoch 00441: val_loss did not improve from 226.30409
Epoch 442/600
Epoch 00442: val_loss did not improve from 226.30409
Epoch 443/600
Epoch 00443: val_loss did not improve from 226.30409
Epoch 444/600
Epoch 00444: val_loss did not improve from 226.3

Epoch 459/600
Epoch 00459: val_loss did not improve from 226.29158
Epoch 460/600
Epoch 00460: val_loss did not improve from 226.29158
Epoch 461/600
Epoch 00461: val_loss did not improve from 226.29158
Epoch 462/600
Epoch 00462: val_loss did not improve from 226.29158
Epoch 463/600
Epoch 00463: val_loss did not improve from 226.29158
Epoch 464/600
Epoch 00464: val_loss did not improve from 226.29158
Epoch 465/600
Epoch 00465: val_loss did not improve from 226.29158
Epoch 466/600
Epoch 00466: val_loss did not improve from 226.29158
Epoch 467/600
Epoch 00467: val_loss did not improve from 226.29158
Epoch 468/600
Epoch 00468: val_loss improved from 226.29158 to 226.11737, saving model to ./CKPT\CKP_ep_468__loss_226.11737_.h5
Epoch 469/600
Epoch 00469: val_loss did not improve from 226.11737
Epoch 470/600
Epoch 00470: val_loss did not improve from 226.11737
Epoch 471/600
Epoch 00471: val_loss did not improve from 226.11737
Epoch 472/600
Epoch 00472: val_loss did not improve from 226.11737
E

Epoch 516/600
Epoch 00516: val_loss did not improve from 225.47810
Epoch 517/600
Epoch 00517: val_loss did not improve from 225.47810
Epoch 518/600
Epoch 00518: val_loss did not improve from 225.47810
Epoch 519/600
Epoch 00519: val_loss did not improve from 225.47810
Epoch 520/600
Epoch 00520: val_loss did not improve from 225.47810
Epoch 521/600
Epoch 00521: val_loss did not improve from 225.47810
Epoch 522/600
Epoch 00522: val_loss did not improve from 225.47810
Epoch 523/600
Epoch 00523: val_loss did not improve from 225.47810
Epoch 524/600
Epoch 00524: val_loss did not improve from 225.47810
Epoch 525/600
Epoch 00525: val_loss did not improve from 225.47810
Epoch 526/600
Epoch 00526: val_loss did not improve from 225.47810
Epoch 527/600
Epoch 00527: val_loss did not improve from 225.47810
Epoch 528/600
Epoch 00528: val_loss did not improve from 225.47810
Epoch 529/600
Epoch 00529: val_loss did not improve from 225.47810
Epoch 530/600
Epoch 00530: val_loss did not improve from 225.4

Epoch 545/600
Epoch 00545: val_loss did not improve from 225.47810
Epoch 546/600
Epoch 00546: val_loss did not improve from 225.47810
Epoch 547/600
Epoch 00547: val_loss did not improve from 225.47810
Epoch 548/600
Epoch 00548: val_loss did not improve from 225.47810
Epoch 549/600
Epoch 00549: val_loss did not improve from 225.47810
Epoch 550/600
Epoch 00550: val_loss did not improve from 225.47810
Epoch 551/600
Epoch 00551: val_loss did not improve from 225.47810
Epoch 552/600
Epoch 00552: val_loss did not improve from 225.47810
Epoch 553/600
Epoch 00553: val_loss did not improve from 225.47810
Epoch 554/600
Epoch 00554: val_loss did not improve from 225.47810
Epoch 555/600
Epoch 00555: val_loss did not improve from 225.47810
Epoch 556/600
Epoch 00556: val_loss did not improve from 225.47810
Epoch 557/600
Epoch 00557: val_loss did not improve from 225.47810
Epoch 558/600
Epoch 00558: val_loss did not improve from 225.47810
Epoch 559/600
Epoch 00559: val_loss did not improve from 225.4

Epoch 574/600
Epoch 00574: val_loss did not improve from 225.47810
Epoch 575/600
Epoch 00575: val_loss did not improve from 225.47810
Epoch 576/600
Epoch 00576: val_loss did not improve from 225.47810
Epoch 577/600
Epoch 00577: val_loss did not improve from 225.47810
Epoch 578/600
Epoch 00578: val_loss did not improve from 225.47810
Epoch 579/600
Epoch 00579: val_loss did not improve from 225.47810
Epoch 580/600
Epoch 00580: val_loss did not improve from 225.47810
Epoch 581/600
Epoch 00581: val_loss did not improve from 225.47810
Epoch 582/600
Epoch 00582: val_loss did not improve from 225.47810
Epoch 583/600
Epoch 00583: val_loss did not improve from 225.47810
Epoch 584/600
Epoch 00584: val_loss did not improve from 225.47810
Epoch 585/600
Epoch 00585: val_loss did not improve from 225.47810
Epoch 586/600
Epoch 00586: val_loss did not improve from 225.47810
Epoch 587/600
Epoch 00587: val_loss did not improve from 225.47810
Epoch 588/600
Epoch 00588: val_loss did not improve from 225.4

## 2.2. Encoder 부르는 방법, Decoder에 값 넣는 방법

In [50]:
latent_size = 512
epoch = 200
BATCH_SIZE = 2

strategy = tf.distribute.MirroredStrategy(['cpu:0'])
print('장치의 수: {}'.format(strategy.num_replicas_in_sync))

with strategy.scope():
    model_path = './CKPT/CKP_ep_283__loss_141.77045_.h5'
    
    vq_vae = Vq_vae(latent_size, gumbel_hard=False)
    vq_vae(0, True)
    vq_vae.summary()
    
    vq_vae.load_weights(model_path)
    
    # 이렇게 하면, transforer의 input으로 들어가는 one-hot 형식의 값을 얻을 수 있음
    for inputs, label in train_dataset:
        encode = vq_vae.encoder(inputs).numpy()
        encode_onehot = tf.cast(tf.equal(encode, tf.math.reduce_max(encode, 2, keepdims=True)), encode.dtype)
    
    # 이렇게 하면, transformer의 output을 vq-vae의 decoder 입력으로 넣을 수 있음
    for inputs, label in train_dataset:
        encode = vq_vae.encoder(inputs).numpy()
        encode_onehot = tf.cast(tf.equal(encode, tf.math.reduce_max(encode, 2, keepdims=True)), encode.dtype)
        
        # 이렇게 이전 layer의 출렫을 넣으면 됨
        decode = vq_vae.decoder(encode_onehot).numpy()

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)
장치의 수: 1
Model: "vqvae"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
softmax_25 (Softmax)         (None, None, 512)         0         
_________________________________________________________________
encoder (Encoder)            (None, None, 512)         517248    
_________________________________________________________________
decoder (Decoder)            (None, None, 1)           516737    
_________________________________________________________________
gumbel_softmax (GumbelSoftma (None, None, 512)         0         
Total params: 1,033,985
Trainable params: 1,033,985
Non-trainable params: 0
_________________________________________________________________


# 3. Test Model

In [16]:
def mkdir_p(path):
    """ Creates a path recursively without throwing an error if it already exists
    :param path: path to create .//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
    :return: None
    """
    if not os.path.exists(path):
        os.makedirs(path)

In [17]:
mkdir_p('./test_wav/') # Result wav 폴더 만드는 코드

In [18]:
def audiowrite(data, path, samplerate=16000, normalize=False, threaded=True):
    """ Write the audio data ``data`` to the wav file ``path``
    The file can be written in a threaded mode. In this case, the writing
    process will be started at a separate thread. Consequently, the file will
    not be written when this function exits.
    :param data: A numpy array with the audio data
    :param path: The wav file the data should be written to
    :param samplerate: Samplerate of the audio data
    :param normalize: Normalize the audio first so that the values are within
        the range of [INTMIN, INTMAX]. E.g. no clipping occurs
    :param threaded: If true, the write process will be started as a separate
        thread
    :return: The number of clipped samples
    """
    data = data.copy()
    int16_max = np.iinfo(np.int16).max
    int16_min = np.iinfo(np.int16).min

    if normalize:
        if not data.dtype.kind == 'f':
            data = data.astype(np.float)
        data /= np.max(np.abs(data))

    if data.dtype.kind == 'f':
        data *= int16_max

    sample_to_clip = np.sum(data > int16_max)
    if sample_to_clip > 0:
        print('Warning, clipping {} samples'.format(sample_to_clip))
    data = np.clip(data, int16_min, int16_max)
    data = data.astype(np.int16)

    if threaded:
        threading.Thread(target=wav_write, args=(path, samplerate, data)).start()
    else:
        wav_write(path, samplerate, data)

    return sample_to_clip

In [20]:
with tf.device('/cpu:0'):
    latent_size = 1024
    sample_rate = 8000
    model_path = './CKPT/CKP_ep_293__loss_49.28763_.h5'
    
    vq_vae = Vq_vae(latent_size, for_predict=True)
    vq_vae(0, True)
    vq_vae.summary()
    vq_vae.load_weights(model_path)

    for batch in test_dataset:
        input_batch, length_batch, name = batch

        result = vq_vae.predict(input_batch)
        
        wav_name = './test_wav/' + name[0][:-5] + '_s1.wav'
        audiowrite(result[0], wav_name, sample_rate, True, True)

Model: "vqvae"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
softmax_3 (Softmax)          (None, None, 1024)        0         
_________________________________________________________________
encoder (Encoder)            (None, None, 1024)        1263776   
_________________________________________________________________
gumbel_softmax (GumbelSoftma multiple                  0 (unused)
_________________________________________________________________
einsum_dense_3 (EinsumDense) (None, None, 1024)        1049600   
_________________________________________________________________
decoder (Decoder)            (None, None, 1)           2835521   
Total params: 5,148,897
Trainable params: 5,148,897
Non-trainable params: 0
_________________________________________________________________


# 여기 밑에는 연습장임

In [73]:
import numpy as np
from keras.models import Sequential
from keras.layers import Embedding

model = Sequential()
model.add(layers.Conv1D(filters=64, kernel_size=3, activation='relu', padding='same', input_shape=(None, 1)))
model.add(layers.Conv1D(filters=512, kernel_size=3, padding='same'))

input_array = np.random.randn(2, 10, 1)
input_array2 = np.random.randn(2, 9, 1)
with tf.device('/cpu:0'):
    model.compile('rmsprop', 'mse')

    output_array = model.predict(input_array)
    output_array2 = model.predict(input_array2)

In [74]:
tf.reduce_sum(output_array, axis=[1, 2])

<tf.Tensor: shape=(2,), dtype=float32, numpy=array([3.5056162, 3.214589 ], dtype=float32)>

In [75]:
tf.unstack(output_array, axis=1)

[<tf.Tensor: shape=(2, 512), dtype=float32, numpy=
 array([[ 0.05127785,  0.04667316,  0.02635628, ...,  0.04573939,
          0.01375185, -0.01533533],
        [ 0.01121127,  0.06675138,  0.0021627 , ...,  0.07490488,
          0.07613068,  0.02180689]], dtype=float32)>,
 <tf.Tensor: shape=(2, 512), dtype=float32, numpy=
 array([[ 0.08263874,  0.03258345,  0.01568549, ...,  0.03589562,
         -0.01625197, -0.03177397],
        [ 0.13691738,  0.09397386, -0.01036285, ..., -0.00614451,
          0.05317175, -0.01284541]], dtype=float32)>,
 <tf.Tensor: shape=(2, 512), dtype=float32, numpy=
 array([[ 0.07028106,  0.00527185,  0.01043324, ...,  0.0617263 ,
         -0.03244079, -0.01786789],
        [ 0.10520705,  0.042405  ,  0.03726472, ...,  0.07816655,
          0.03663039, -0.03122897]], dtype=float32)>,
 <tf.Tensor: shape=(2, 512), dtype=float32, numpy=
 array([[ 0.07235897,  0.00636261, -0.03309729, ...,  0.05743181,
         -0.01674554, -0.01104611],
        [ 0.11989237,  0.047

In [76]:
print(output_array)
print(output_array.shape)
print(output_array2.shape)

[[[ 0.05127785  0.04667316  0.02635628 ...  0.04573939  0.01375185
   -0.01533533]
  [ 0.08263874  0.03258345  0.01568549 ...  0.03589562 -0.01625197
   -0.03177397]
  [ 0.07028106  0.00527185  0.01043324 ...  0.0617263  -0.03244079
   -0.01786789]
  ...
  [-0.01148595  0.03309207 -0.02289026 ...  0.02078854 -0.06813245
   -0.09970599]
  [ 0.07415138  0.0991153  -0.03032039 ...  0.01240078 -0.04231954
   -0.04340101]
  [ 0.02075784  0.00618587 -0.07708566 ...  0.10809869  0.01363952
    0.02574195]]

 [[ 0.01121127  0.06675138  0.0021627  ...  0.07490488  0.07613068
    0.02180689]
  [ 0.13691738  0.09397386 -0.01036285 ... -0.00614451  0.05317175
   -0.01284541]
  [ 0.10520705  0.042405    0.03726472 ...  0.07816655  0.03663039
   -0.03122897]
  ...
  [ 0.04621019  0.06871783  0.00798892 ...  0.01603857  0.04825678
    0.01381727]
  [ 0.0583101   0.0205835  -0.01229584 ... -0.01952541  0.00030562
   -0.05487346]
  [-0.00304264 -0.00650905 -0.00512165 ...  0.01797911 -0.00719647
   -0.

In [79]:
one_hot = tf.math.argmax(tf.cast(tf.equal(output_array, tf.math.reduce_max(output_array, 2, keepdims=True)), output_array.dtype), axis=-1)
print(one_hot.shape)
print(one_hot)
# layers.Embedding(512, 512)(one_hot)

(2, 10)
tf.Tensor(
[[259 259  34  50   0 110   0 305 200 287]
 [257   0 200 259  34 200 509 110 200 287]], shape=(2, 10), dtype=int64)


In [70]:
one_hot = tf.math.argmax(output_array, axis=-1)
print(one_hot.shape)
print(one_hot)

(2, 10)
tf.Tensor(
[[228 249 283 206  20 206 435  32 270  30]
 [428 206  20 244 357 289 324 249 498 134]], shape=(2, 10), dtype=int64)


In [50]:
qy = tf.nn.softmax(output_array)
log_qy = tf.math.log(qy + 1e-10)
log_uniform = qy * (log_qy - tf.math.log(1.0 / 512))
kl_loss = tf.reduce_sum(log_uniform, axis=[1, 2])
kl_loss = tf.reduce_mean(kl_loss)

print(kl_loss)

tf.Tensor(0.012074914, shape=(), dtype=float32)


In [30]:
batch_size = tf.shape(output_array)[0]
array1_size = tf.shape(output_array)[1]
array2_size = tf.shape(output_array2)[1]
feature_size = tf.shape(output_array)[-1]

if array1_size < array2_size:
#     append_size = array1_size - array2_size
#     append_zeros = tf.zeros([batch_size, append_size, feature_size])
#     append_zeros = tf.Variable(initial_value=tf.zeros((batch_size, append_size, feature_size)))
#     output_array2 = tf.concat([output_array2, append_zeros], axis=1)
    output_array2 = tf.slice(output_array2, [0, 0, 0], [-1, array1_size, -1])
elif array1_size > array2_size:
#     append_size = array2_size - array1_size
#     append_zeros = tf.zeros([batch_size, append_size, feature_size])
#     append_zeros = tf.Variable(initial_value=tf.zeros((batch_size, append_size, feature_size)))
#     output_array = tf.concat([output_array, append_zeros], axis=1)
    output_array = tf.slice(output_array, [0, 0, 0], [-1, array2_size, -1])

print(output_array.shape)
print(output_array2.shape)
# output_array0 = output_array[1]
# output_array20 = output_array2[1]
# target = np.sum(output_array20 * output_array0) * output_array0 / np.square(np.linalg.norm(output_array0, ord=2))
# noise = output_array20 - target
# npnp = 10 * np.log10(np.square(np.linalg.norm(target, ord=2)) / np.square(np.linalg.norm(noise, ord=2)))
# print(npnp)

target = tf.linalg.matmul(output_array2, output_array, transpose_a=True) * output_array / tf.expand_dims(tf.experimental.numpy.square(tf.norm(output_array, axis=1)), axis=-1)
noise = output_array2 - target
si_sdr = 10 * tf.experimental.numpy.log10(tf.experimental.numpy.square(tf.norm(target, axis=1)) / tf.experimental.numpy.square(tf.norm(noise, axis=1)))
si_sdr = tf.reduce_mean(si_sdr)
print(si_sdr)

(2, 9, 1)
(2, 9, 1)
tf.Tensor(2.8309882, shape=(), dtype=float32)


In [20]:
tf.cast(tf.equal(output_array, tf.math.reduce_max(output_array, 2, keepdims=True)), output_array.dtype)

<tf.Tensor: shape=(2, 9, 4), dtype=float32, numpy=
array([[[1., 0., 0., 0.],
        [0., 0., 1., 0.],
        [1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 1., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [0., 0., 1., 0.]],

       [[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [1., 0., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]]], dtype=float32)>

In [21]:
output_array

<tf.Tensor: shape=(2, 9, 4), dtype=float32, numpy=
array([[[-0.03009652, -0.03612775, -0.06680483, -0.03670201],
        [-0.04768711, -0.12344762, -0.03924457, -0.11762322],
        [ 0.01808495, -0.16106637, -0.19467078, -0.15282159],
        [-0.0986427 , -0.08625205, -0.12661007, -0.16366175],
        [-0.09758376, -0.08886974, -0.0433558 , -0.19985165],
        [-0.06933096, -0.03154394, -0.13725929, -0.20143284],
        [ 0.03375649,  0.00182091, -0.01022564, -0.35924646],
        [-0.01645333, -0.10466891, -0.13975918, -0.12066491],
        [-0.13588801, -0.08173112, -0.00253745, -0.28615874]],

       [[ 0.04865369, -0.02880372, -0.06414615, -0.07730438],
        [-0.08225074, -0.03192509, -0.06216412, -0.08035193],
        [-0.09515338,  0.04221668,  0.14230826, -0.23082384],
        [-0.00094383,  0.05597762, -0.09290768, -0.08630683],
        [-0.09894791, -0.04727853, -0.01004983, -0.30325216],
        [ 0.01705559, -0.16948727, -0.08829505, -0.16453639],
        [-0.07230

In [67]:
output_softmax = tf.nn.softmax(output_array)
output_softmax

<tf.Tensor: shape=(2, 5, 4), dtype=float32, numpy=
array([[[0.24628553, 0.282701  , 0.2385324 , 0.23248109],
        [0.2298986 , 0.23856457, 0.25392184, 0.27761498],
        [0.2101039 , 0.2444843 , 0.26824066, 0.27717113],
        [0.22202027, 0.29728216, 0.2501223 , 0.23057525],
        [0.24084595, 0.2724257 , 0.25000373, 0.23672463]],

       [[0.23988546, 0.27957046, 0.23915865, 0.24138539],
        [0.24975686, 0.27135593, 0.24479878, 0.23408844],
        [0.24278015, 0.26340333, 0.24202275, 0.25179377],
        [0.23139507, 0.262904  , 0.25970972, 0.2459912 ],
        [0.2601803 , 0.25627998, 0.24524413, 0.23829558]]], dtype=float32)>

In [73]:
output_reshape = tf.reshape(output_softmax, [-1, 4])
output_reshape.shape

TensorShape([10, 4])

In [83]:
tf.reshape(tf.nn.softmax(output_array), [-1, 5, 4])

<tf.Tensor: shape=(2, 5, 4), dtype=float32, numpy=
array([[[0.24628553, 0.282701  , 0.2385324 , 0.23248109],
        [0.2298986 , 0.23856457, 0.25392184, 0.27761498],
        [0.2101039 , 0.2444843 , 0.26824066, 0.27717113],
        [0.22202027, 0.29728216, 0.2501223 , 0.23057525],
        [0.24084595, 0.2724257 , 0.25000373, 0.23672463]],

       [[0.23988546, 0.27957046, 0.23915865, 0.24138539],
        [0.24975686, 0.27135593, 0.24479878, 0.23408844],
        [0.24278015, 0.26340333, 0.24202275, 0.25179377],
        [0.23139507, 0.262904  , 0.25970972, 0.2459912 ],
        [0.2601803 , 0.25627998, 0.24524413, 0.23829558]]], dtype=float32)>

In [76]:
# tf.cast(tf.equal(y, tf.reduce_max(y,1,keep_dims=True)), y.dtype)
output_hard = tf.cast(tf.equal(output_reshape, tf.math.reduce_max(output_reshape, 1, keepdims=True)), output_softmax.dtype)
output_hard

<tf.Tensor: shape=(10, 4), dtype=float32, numpy=
array([[0., 1., 0., 0.],
       [0., 0., 0., 1.],
       [0., 0., 0., 1.],
       [0., 1., 0., 0.],
       [0., 1., 0., 0.],
       [0., 1., 0., 0.],
       [0., 1., 0., 0.],
       [0., 1., 0., 0.],
       [0., 1., 0., 0.],
       [1., 0., 0., 0.]], dtype=float32)>

In [81]:
tf.reshape(output_hard, [-1, 5, 4])

<tf.Tensor: shape=(2, 5, 4), dtype=float32, numpy=
array([[[0., 1., 0., 0.],
        [0., 0., 0., 1.],
        [0., 0., 0., 1.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.]],

       [[0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [1., 0., 0., 0.]]], dtype=float32)>

In [322]:
tf.transpose(output_array, perm=[0, 2, 1]).shape

TensorShape([2, 4, 5])

In [316]:
layers.Softmax(output_array)

<tensorflow.python.keras.layers.advanced_activations.Softmax at 0x26f14e934c8>

In [317]:
np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).shape

(3, 3)

In [65]:
np.log(10)

2.302585092994046