# 1. Data Generator
- Raw Data를 읽어옴
- 여기서 만들어진 데이터는 모델의 입력으로 들어감

In [1]:
import os
import numpy as np
import librosa
from tensorflow.keras.utils import Sequence

In [2]:
class RawForVAEGenerator(Sequence):
    def __init__(self, source, wav_dir, files, sourNum='s1', batch_size=10, shuffle=True):
        self.source = source
        self.wav_dir = wav_dir
        self.files = files
        self.sourNum = sourNum
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.on_epoch_end()
        
        self.sample_rate = 8000
    
    def on_epoch_end(self):
        self.indexes = np.arange(len(self.source))
        
        if self.shuffle:
            np.random.shuffle(self.indexes)
    
    def __audioread__(self, path, offset=0.0, duration=None, sample_rate=16000):
        signal = librosa.load(path, sr=self.sample_rate, mono=False, offset=offset, duration=duration)

        return signal[0]
    
    def __padding__(self, data):
        n_batch = len(data)
        max_len = max([d.shape[0] for d in data])
        extrapadding = int(np.ceil(max_len / self.sample_rate) * self.sample_rate)
        pad = np.zeros((n_batch, extrapadding))
        
        for i in range(n_batch):
            pad[i, :data[i].shape[0]] = data[i]
        
        return np.expand_dims(pad, -1)
        
    def __data_generation__(self, source_list):
        wav_list = []
        for name in source_list:
            name = name.strip('\n')
            
            s_wav_name = self.wav_dir + self.files + '/' + self.sourNum + '/' + name
            
            # ------- AUDIO READ -------
            s_wav = (self.__audioread__(s_wav_name,  offset=0.0, duration=None, sample_rate=self.sample_rate))
            # --------------------------
            
            # ------- PADDING -------
#             pad_len = max(len(samples1),len(samples2))
#             pad_s1 = np.concatenate([s1_wav, np.zeros([pad_len - len(s1_wav)])])
            
#             extrapadding = ceil(len(pad_s1) / sample_rate) * sample_rate - len(pad_s1)
#             pad_s1 = np.concatenate([pad_s1, np.zeros([extrapadding - len(pad_s1)])])
#             pad_s2 = np.concatenate([s2_wav, np.zeros([extrapadding - len(s2_wav)])])
            # -----------------------
            
            wav_list.append(s_wav)
        
        return wav_list, wav_list, source_list
            
    
    def __len__(self):
        return int(np.floor(len(self.source) / self.batch_size))
    
    def __getitem__(self, index):
        indexes = self.indexes[index * self.batch_size : (index + 1) * self.batch_size]
        source_list = [self.source[k] for k in indexes]
        
        if self.files is not 'tt':
            sour, labels, _ = self.__data_generation__(source_list)
            
            # Get Lengths(K value of each batch)
            lengths = np.array([m.shape[0] for m in sour])
            exp = np.expand_dims(lengths, 1)
            exp = np.expand_dims(exp, -1) # [Batch, 1, 1] (length)
            
            # Padding
            sour_pad = self.__padding__(sour) # [Batch, Time_step, Dimension(=1)]
            label_pad = self.__padding__(labels) # [Batch, Time_step, Dimension(=1)]
            
            return sour_pad , exp#, np.concatenate([label_pad, exp], axis=1)
        else:
            sour, labels, name = self.__data_generation__(source_list)
            
            # Get Lengths(K value of each batch)
            lengths = np.array([m.shape[0] for m in sour])
            exp = np.expand_dims(lengths, 1)
            exp = np.expand_dims(exp, -1) # [Batch, 1, 1] (length)
            
            # Padding
            sour_pad = self.__padding__(sour) # [Batch, Time_step, Dimension(=1)]
            
            return sour_pad, exp, name

## Data를 어떻게 읽는지에 대한 부분

In [3]:
WAV_DIR = './mycode/wsj0_2mix/use_this/'
LIST_DIR = './mycode/wsj0_2mix/use_this/lists/'

In [4]:
# Directory List file create

wav_dir = WAV_DIR
output_lst = LIST_DIR

for folder in ['tr', 'cv', 'tt']:
    wav_files = os.listdir(wav_dir + folder + '/mix')
    output_lst_files = output_lst + folder + '_wav.lst'
    with open(output_lst_files, 'w') as f:
        for file in wav_files:
            f.write(file + "\n")

print("Generate wav file to .lst done!")

Generate wav file to .lst done!


In [5]:
batch_size = 2

train_dataset = 0
valid_dataset = 0
test_dataset = 0

name_list = []
for files in ['tr', 'cv', 'tt']:
    # --- Lead lst file ---""
    output_lst_files = LIST_DIR + files + '_wav.lst'
    fid = open(output_lst_files, 'r')
    lines = fid.readlines()
    fid.close()
    # ---------------------
    
    if files == 'tr':
        train_dataset = RawForVAEGenerator(lines, WAV_DIR, files, 's1', batch_size)
    elif files == 'cv':
        valid_dataset = RawForVAEGenerator(lines, WAV_DIR, files, 's1', batch_size)
    else:
        test_batch = 1
        test_dataset = RawForVAEGenerator(lines, WAV_DIR, files, 's1', test_batch)

# 2. Building VQ-VAE model with Gumbel Softmax

In [6]:
import threading
from scipy.io.wavfile import write as wav_write
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow import keras
import matplotlib.pyplot as plt
from tqdm import tqdm_notebook as tqdm
from tensorflow.keras import backend as Kb
import numpy as np
import pandas as pd
from importlib import reload
import time
from tensorflow.keras.models import Model, Sequential, load_model

In [7]:
def mkdir_p(path):
    """ Creates a path recursively without throwing an error if it already exists
    :param path: path to create
    :return: None
    """
    if not os.path.exists(path):
        os.makedirs(path)

In [8]:
mkdir_p('./CKPT/') # model check point 폴더 만드는 코드
filepath = "./CKPT/CKP_ep_{epoch:d}__loss_{val_loss:.5f}_.h5"

In [9]:
initial_learning_rate = 0.001

# learning rate를 점점 줄이는 부분
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate, decay_steps=20, decay_rate=0.96, staircase=True
)

# validation loss에 대해서 좋은 것만 저장됨
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath, monitor='val_loss', verbose=1, save_best_only=True, save_weights_only=True, mode='min'
)

# early stop 하는 부분인데, validation loss에 대해서 제일 좋은 모델이 저장됨
early_stopping_cb = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', mode='min', verbose=1, patience=50, restore_best_weights=True
)

In [10]:
class GumbelSoftmax(layers.Layer):
    def __init__(self, temperature=0.5, hard=False, name = 'gumbel_softmax',**kwargs):
        super(GumbelSoftmax, self).__init__(name=name, **kwargs)
        
        self.temperature = temperature
        self.hard = hard
    
    def sample_gumbel(self, shape, eps=1e-8): 
        """Sample from Gumbel(0, 1)"""
        U = tf.random.uniform(shape,minval=0,maxval=1)
        
        return -tf.math.log(-tf.math.log(U + eps) + eps)

    def gumbel_softmax_sample(self, logits, temperature): 
        """ Draw a sample from the Gumbel-Softmax distribution"""
        y = logits + self.sample_gumbel(tf.shape(logits))
        
        return tf.nn.softmax(y / temperature)

    def call(self, inputs):
        y = self.gumbel_softmax_sample(inputs, self.temperature)
        
        if self.hard:
            y_hard = tf.cast(tf.equal(y, tf.math.reduce_max(y, 2, keepdims=True)), y.dtype)
            y = tf.stop_gradient(y_hard - y) + y
        
        return y


class Encoder(layers.Layer):
    def __init__(self, latent_dim, name = 'encoder',**kwargs):
        super(Encoder, self).__init__(name=name, **kwargs)
        
        self.conv1d_1 = layers.Conv1D(filters=64, kernel_size=3, strides=2, padding='same', activation='relu')
        self.conv1d_2 = layers.Conv1D(filters=128, kernel_size=3, strides=2, padding='same', activation='relu')
        self.conv1d_3 = layers.Conv1D(filters=256, kernel_size=3, strides=2, padding='same', activation='relu')
        self.logit = layers.Conv1D(filters=latent_dim, kernel_size=3, strides=2, activation='relu', padding='same')
    
    def call(self, inputs):
        x = self.conv1d_1(inputs)
        x = self.conv1d_2(x)
        x = self.conv1d_3(x)
        logit = self.logit(x)
        
        return logit


class Decoder(layers.Layer):
    def __init__(self, latent_dim, name = 'decoder',**kwargs):
        super(Decoder, self).__init__(name=name, **kwargs)
        
        self.trans_conv1d_1 = layers.Conv1DTranspose(filters=256, kernel_size=3, strides=2, activation='relu', padding='same')
        self.trans_conv1d_2 = layers.Conv1DTranspose(filters=128, kernel_size=3, strides=2, activation='relu', padding='same')
        self.trans_conv1d_3 = layers.Conv1DTranspose(filters=64, kernel_size=3, strides=2, activation='relu', padding='same')
        self.logit = layers.Conv1DTranspose(filters=1, kernel_size=3, strides=2, padding='same', activation=None)
    
    def call(self, inputs):
        x = self.trans_conv1d_1(inputs)
        x = self.trans_conv1d_2(x)
        x = self.trans_conv1d_3(x)
        logit = self.logit(x)
        
        return logit

In [11]:
# Custom Metric Si-sdr

class SiSdr(keras.metrics.Metric):
    def __init__(self, name="Si-sdr", **kwargs):
        super(SiSdr, self).__init__(name=name, **kwargs)
        self.sdr = self.add_weight(name="sdr", initializer="zeros")
        self.count = self.add_weight(name="cnt", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        ori_length = tf.shape(y_true)[1]
        
        # Label & Length divide
        labels = tf.slice(y_true, [0, 0, 0], [-1, ori_length-1, -1]) # [batch_size, length_size, 1]
        lengths = tf.slice(y_true, [0, ori_length-1, 0], [-1, -1, 1]) # [batch_size, 1, 1]
        
        # Check sequence length
        batch_size = tf.shape(labels)[0]
        label_size = tf.shape(labels)[1]
        pred_size = tf.shape(y_pred)[1]
        feature_size = tf.shape(labels)[-1]
        
        # Change sequence length
        if label_size < pred_size:
            y_pred = tf.slice(y_pred, [0, 0, 0], [-1, label_size, -1])
        elif label_size > pred_size:
            labels = tf.slice(labels, [0, 0, 0], [-1, pred_size, -1])

        # SI-SDR
        target = tf.linalg.matmul(y_pred, labels, transpose_a=True) * labels / tf.expand_dims(tf.experimental.numpy.square(tf.norm(labels, axis=1)), axis=-1)
        noise = y_pred - target
        values = 10 * tf.experimental.numpy.log10(tf.experimental.numpy.square(tf.norm(target, axis=1)) / tf.experimental.numpy.square(tf.norm(noise, axis=1)))
        
        if sample_weight is not None:
            sample_weight = tf.cast(sample_weight, "float32")
            values = tf.multiply(values, sample_weight)
        self.sdr.assign_add(tf.reduce_sum(values))
        self.count.assign_add(tf.cast(tf.shape(y_true)[0], tf.float32))

    def result(self):
        return self.sdr / self.count

    def reset_states(self):
        # The state of the metric will be reset at the start of each epoch.
        self.sdr.assign(0.0)
        self.count.assign(0.0)

In [12]:
# Custom loss

# Custom mse
def custom_mse(y_true, y_pred):
    ori_length = tf.shape(y_true)[1]

    # Label & Length divide
    labels = tf.slice(y_true, [0, 0, 0], [-1, ori_length-1, -1]) # [batch_size, length_size, 129]
    lengths = tf.slice(y_true, [0, ori_length-1, 0], [-1, -1, 1]) # [batch_size, 1, 1]

    loss = tf.reduce_sum(tf.pow(y_pred - labels, 2))

    return loss


# Custom si-sdr loss
def custom_sisdr_loss(y_true, y_pred):
    ori_length = tf.shape(y_true)[1]

    # Label & Length divide
    labels = tf.slice(y_true, [0, 0, 0], [-1, ori_length-1, -1]) # [batch_size, length_size, 1]
    lengths = tf.slice(y_true, [0, ori_length-1, 0], [-1, -1, 1]) # [batch_size, 1, 1]

    target = tf.linalg.matmul(y_pred, labels, transpose_a=True) * labels / tf.expand_dims(tf.experimental.numpy.square(tf.norm(labels, axis=1)), axis=-1)
    noise = y_pred - target
    si_sdr = 10 * tf.experimental.numpy.log10(tf.experimental.numpy.square(tf.norm(target, axis=1)) / tf.experimental.numpy.square(tf.norm(noise, axis=1)))
    si_sdr = tf.reduce_mean(si_sdr) * -1

    return si_sdr

In [13]:
class Vq_vae(keras.Model):
    def __init__(self, latent_dim, gumbel_hard=False, name='vqvae', **kwargs):
        super(Vq_vae, self).__init__(name=name, **kwargs)
        
        self.latent_dim = latent_dim
        self.softmax = layers.Softmax(-1)
        
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)
        self.gumbel = GumbelSoftmax(hard=gumbel_hard)
        
    def call(self, inputs, load=False):
        if load:
            inputs = layers.Input(shape=(None, 1))
        
        
        encode = self.encoder(inputs)
        gumbel = self.gumbel(encode)
        decode = self.decoder(gumbel)
        
        # ------------------ KL loss ------------------
        qy = self.softmax(encode)
        log_qy = tf.math.log(qy + 1e-8)
        log_uniform = qy * (log_qy - tf.math.log(1.0 / self.latent_dim))
        kl_loss = tf.reduce_mean(log_uniform)
        # ---------------------------------------------
        
        self.add_loss(kl_loss)
        
        return decode

In [108]:
latent_size = 512
epoch = 300
BATCH_SIZE = 2

strategy = tf.distribute.MirroredStrategy(['cpu:0'])
print('장치의 수: {}'.format(strategy.num_replicas_in_sync))

with strategy.scope():
    model_path = './CKPT/CKP_ep_30__loss_158.95885_.h5'
    
    loss_fun = custom_mse
#     loss_fun = custom_sisdr_loss
    
    vq_vae = Vq_vae(latent_size, gumbel_hard=False)

    optimizer = keras.optimizers.Adam(learning_rate=1e-3)
    vq_vae.compile(optimizer, loss=loss_fun, metrics=[SiSdr()])
    
    vq_vae(0, True)
    vq_vae.summary()
    
    # 사용 안할 때는 load_model 주석 처리 하자
#     vq_vae.load_weights(model_path)
    # ----------------------------------------
    
    tf.executing_eagerly()

history = vq_vae.fit(
    train_dataset,
    epochs=epoch,
    validation_data=valid_dataset,
    shuffle=True,
    callbacks=[checkpoint_cb, early_stopping_cb],
)

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)
장치의 수: 1
Model: "vqvae"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
softmax_10 (Softmax)         (None, None, 512)         0         
_________________________________________________________________
encoder (Encoder)            (None, None, 512)         517248    
_________________________________________________________________
decoder (Decoder)            (None, None, 1)           516737    
_________________________________________________________________
gumbel_softmax (GumbelSoftma (None, None, 512)         0         
Total params: 1,033,985
Trainable params: 1,033,985
Non-trainable params: 0
_________________________________________________________________
Epoch 1/300

Epoch 00001: val_loss did not improve from 619.74805
Epoch 2/300

Epoch 00002: val_loss did not improve from 619.74805
Epoch 


Epoch 00032: val_loss improved from 115.00878 to 110.13300, saving model to ./CKPT\CKP_ep_32__loss_110.13300_.h5
Epoch 33/300

Epoch 00033: val_loss improved from 110.13300 to 105.24696, saving model to ./CKPT\CKP_ep_33__loss_105.24696_.h5
Epoch 34/300

Epoch 00034: val_loss improved from 105.24696 to 98.27646, saving model to ./CKPT\CKP_ep_34__loss_98.27646_.h5
Epoch 35/300

Epoch 00035: val_loss did not improve from 98.27646
Epoch 36/300

Epoch 00036: val_loss improved from 98.27646 to 90.68739, saving model to ./CKPT\CKP_ep_36__loss_90.68739_.h5
Epoch 37/300

Epoch 00037: val_loss improved from 90.68739 to 88.85553, saving model to ./CKPT\CKP_ep_37__loss_88.85553_.h5
Epoch 38/300

Epoch 00038: val_loss improved from 88.85553 to 86.96925, saving model to ./CKPT\CKP_ep_38__loss_86.96925_.h5
Epoch 39/300

Epoch 00039: val_loss improved from 86.96925 to 83.52572, saving model to ./CKPT\CKP_ep_39__loss_83.52572_.h5
Epoch 40/300

Epoch 00040: val_loss improved from 83.52572 to 81.04535, 


Epoch 00066: val_loss improved from 52.34985 to 51.62273, saving model to ./CKPT\CKP_ep_66__loss_51.62273_.h5
Epoch 67/300

Epoch 00067: val_loss improved from 51.62273 to 49.83034, saving model to ./CKPT\CKP_ep_67__loss_49.83034_.h5
Epoch 68/300

Epoch 00068: val_loss improved from 49.83034 to 49.51345, saving model to ./CKPT\CKP_ep_68__loss_49.51345_.h5
Epoch 69/300

Epoch 00069: val_loss improved from 49.51345 to 48.84103, saving model to ./CKPT\CKP_ep_69__loss_48.84103_.h5
Epoch 70/300

Epoch 00070: val_loss improved from 48.84103 to 46.77575, saving model to ./CKPT\CKP_ep_70__loss_46.77575_.h5
Epoch 71/300

Epoch 00071: val_loss improved from 46.77575 to 46.60735, saving model to ./CKPT\CKP_ep_71__loss_46.60735_.h5
Epoch 72/300

Epoch 00072: val_loss did not improve from 46.60735
Epoch 73/300

Epoch 00073: val_loss improved from 46.60735 to 44.95594, saving model to ./CKPT\CKP_ep_73__loss_44.95594_.h5
Epoch 74/300

Epoch 00074: val_loss improved from 44.95594 to 43.57182, saving 


Epoch 00101: val_loss did not improve from 29.51239
Epoch 102/300

Epoch 00102: val_loss improved from 29.51239 to 29.32246, saving model to ./CKPT\CKP_ep_102__loss_29.32246_.h5
Epoch 103/300

Epoch 00103: val_loss improved from 29.32246 to 29.00892, saving model to ./CKPT\CKP_ep_103__loss_29.00892_.h5
Epoch 104/300

Epoch 00104: val_loss improved from 29.00892 to 28.42862, saving model to ./CKPT\CKP_ep_104__loss_28.42862_.h5
Epoch 105/300

Epoch 00105: val_loss improved from 28.42862 to 27.99593, saving model to ./CKPT\CKP_ep_105__loss_27.99593_.h5
Epoch 106/300

Epoch 00106: val_loss did not improve from 27.99593
Epoch 107/300

Epoch 00107: val_loss did not improve from 27.99593
Epoch 108/300

Epoch 00108: val_loss did not improve from 27.99593
Epoch 109/300

Epoch 00109: val_loss did not improve from 27.99593
Epoch 110/300

Epoch 00110: val_loss improved from 27.99593 to 27.10705, saving model to ./CKPT\CKP_ep_110__loss_27.10705_.h5
Epoch 111/300

Epoch 00111: val_loss improved fro


Epoch 00138: val_loss did not improve from 21.45797
Epoch 139/300

Epoch 00139: val_loss did not improve from 21.45797
Epoch 140/300

Epoch 00140: val_loss improved from 21.45797 to 21.18142, saving model to ./CKPT\CKP_ep_140__loss_21.18142_.h5
Epoch 141/300

Epoch 00141: val_loss improved from 21.18142 to 21.13377, saving model to ./CKPT\CKP_ep_141__loss_21.13377_.h5
Epoch 142/300

Epoch 00142: val_loss improved from 21.13377 to 20.85234, saving model to ./CKPT\CKP_ep_142__loss_20.85234_.h5
Epoch 143/300

Epoch 00143: val_loss did not improve from 20.85234
Epoch 144/300

Epoch 00144: val_loss improved from 20.85234 to 20.49308, saving model to ./CKPT\CKP_ep_144__loss_20.49308_.h5
Epoch 145/300

Epoch 00145: val_loss did not improve from 20.49308
Epoch 146/300

Epoch 00146: val_loss improved from 20.49308 to 20.44168, saving model to ./CKPT\CKP_ep_146__loss_20.44168_.h5
Epoch 147/300

Epoch 00147: val_loss did not improve from 20.44168
Epoch 148/300

Epoch 00148: val_loss improved fro


Epoch 00212: val_loss improved from 15.28763 to 15.16775, saving model to ./CKPT\CKP_ep_212__loss_15.16775_.h5
Epoch 213/300

Epoch 00213: val_loss improved from 15.16775 to 14.95414, saving model to ./CKPT\CKP_ep_213__loss_14.95414_.h5
Epoch 214/300

Epoch 00214: val_loss did not improve from 14.95414
Epoch 215/300

Epoch 00215: val_loss did not improve from 14.95414
Epoch 216/300

Epoch 00216: val_loss improved from 14.95414 to 14.93959, saving model to ./CKPT\CKP_ep_216__loss_14.93959_.h5
Epoch 217/300

Epoch 00217: val_loss did not improve from 14.93959
Epoch 218/300

Epoch 00218: val_loss improved from 14.93959 to 14.79410, saving model to ./CKPT\CKP_ep_218__loss_14.79410_.h5
Epoch 219/300

Epoch 00219: val_loss did not improve from 14.79410
Epoch 220/300

Epoch 00220: val_loss did not improve from 14.79410
Epoch 221/300

Epoch 00221: val_loss did not improve from 14.79410
Epoch 222/300

Epoch 00222: val_loss did not improve from 14.79410
Epoch 223/300

Epoch 00223: val_loss did 


Epoch 00251: val_loss did not improve from 13.53243
Epoch 252/300

Epoch 00252: val_loss did not improve from 13.53243
Epoch 253/300

Epoch 00253: val_loss did not improve from 13.53243
Epoch 254/300

Epoch 00254: val_loss did not improve from 13.53243
Epoch 255/300

Epoch 00255: val_loss improved from 13.53243 to 13.38348, saving model to ./CKPT\CKP_ep_255__loss_13.38348_.h5
Epoch 256/300

Epoch 00256: val_loss did not improve from 13.38348
Epoch 257/300

Epoch 00257: val_loss did not improve from 13.38348
Epoch 258/300

Epoch 00258: val_loss did not improve from 13.38348
Epoch 259/300

Epoch 00259: val_loss did not improve from 13.38348
Epoch 260/300

Epoch 00260: val_loss did not improve from 13.38348
Epoch 261/300

Epoch 00261: val_loss did not improve from 13.38348
Epoch 262/300

Epoch 00262: val_loss improved from 13.38348 to 13.36317, saving model to ./CKPT\CKP_ep_262__loss_13.36317_.h5
Epoch 263/300

Epoch 00263: val_loss did not improve from 13.36317
Epoch 264/300

Epoch 0026


Epoch 00290: val_loss did not improve from 12.43894
Epoch 291/300

Epoch 00291: val_loss did not improve from 12.43894
Epoch 292/300

Epoch 00292: val_loss did not improve from 12.43894
Epoch 293/300

Epoch 00293: val_loss did not improve from 12.43894
Epoch 294/300

Epoch 00294: val_loss did not improve from 12.43894
Epoch 295/300

Epoch 00295: val_loss improved from 12.43894 to 12.43871, saving model to ./CKPT\CKP_ep_295__loss_12.43871_.h5
Epoch 296/300

Epoch 00296: val_loss did not improve from 12.43871
Epoch 297/300

Epoch 00297: val_loss did not improve from 12.43871
Epoch 298/300

Epoch 00298: val_loss did not improve from 12.43871
Epoch 299/300

Epoch 00299: val_loss did not improve from 12.43871
Epoch 300/300

Epoch 00300: val_loss did not improve from 12.43871


## 2.2. Encoder 부르는 방법, Decoder에 값 넣는 방법

In [None]:
latent_size = 512
epoch = 200
BATCH_SIZE = 2

strategy = tf.distribute.MirroredStrategy(['cpu:0'])
print('장치의 수: {}'.format(strategy.num_replicas_in_sync))

with strategy.scope():
    model_path = './CKPT/CKP_ep_283__loss_141.77045_.h5'
    
    vq_vae = Vq_vae(latent_size, gumbel_hard=False)
    vq_vae(0, True)
    vq_vae.summary()
    
    vq_vae.load_weights(model_path)
    
    # 이렇게 하면, transforer의 input으로 들어가는 one-hot 형식의 값을 얻을 수 있음
    for inputs, label in train_dataset:
        encode = vq_vae.encoder(inputs).numpy()
        encode_onehot = tf.cast(tf.equal(encode, tf.math.reduce_max(encode, 2, keepdims=True)), encode.dtype)
    
    # 이렇게 하면, transformer의 output을 vq-vae의 decoder 입력으로 넣을 수 있음
    for inputs, label in train_dataset:
        encode = vq_vae.encoder(inputs).numpy()
        encode_onehot = tf.cast(tf.equal(encode, tf.math.reduce_max(encode, 2, keepdims=True)), encode.dtype)
        
        # 이렇게 이전 layer의 출렫을 넣으면 됨
        decode = vq_vae.decoder(encode_onehot).numpy()

# 3. Test Model

In [14]:
def mkdir_p(path):
    """ Creates a path recursively without throwing an error if it already exists
    :param path: path to create
    :return: None
    """
    if not os.path.exists(path):
        os.makedirs(path)

In [15]:
mkdir_p('./test_wav/') # Result wav 폴더 만드는 코드

In [16]:
def audiowrite(data, path, samplerate=16000, normalize=False, threaded=True):
    """ Write the audio data ``data`` to the wav file ``path``
    The file can be written in a threaded mode. In this case, the writing
    process will be started at a separate thread. Consequently, the file will
    not be written when this function exits.
    :param data: A numpy array with the audio data
    :param path: The wav file the data should be written to
    :param samplerate: Samplerate of the audio data
    :param normalize: Normalize the audio first so that the values are within
        the range of [INTMIN, INTMAX]. E.g. no clipping occurs
    :param threaded: If true, the write process will be started as a separate
        thread
    :return: The number of clipped samples
    """
    data = data.copy()
    int16_max = np.iinfo(np.int16).max
    int16_min = np.iinfo(np.int16).min

    if normalize:
        if not data.dtype.kind == 'f':
            data = data.astype(np.float)
        data /= np.max(np.abs(data))

    if data.dtype.kind == 'f':
        data *= int16_max

    sample_to_clip = np.sum(data > int16_max)
    if sample_to_clip > 0:
        print('Warning, clipping {} samples'.format(sample_to_clip))
    data = np.clip(data, int16_min, int16_max)
    data = data.astype(np.int16)

    if threaded:
        threading.Thread(target=wav_write, args=(path, samplerate, data)).start()
    else:
        wav_write(path, samplerate, data)

    return sample_to_clip

In [127]:
with tf.device('/cpu:0'):
    latent_size = 512
    sample_rate = 8000
    model_path = './CKPT/CKP_ep_295__loss_12.43871_.h5'
    
    vq_vae = Vq_vae(latent_size, gumbel_hard=True)
    vq_vae(0, True)
    vq_vae.summary()
    vq_vae.load_weights(model_path)

    for batch in test_dataset:
        input_batch, length_batch, name = batch

        result = vq_vae.predict(input_batch)
        
        wav_name = './test_wav/' + name[0][:-5] + '.wav'
        audiowrite(result[0], wav_name, sample_rate, True, True)

Model: "vqvae"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
softmax_12 (Softmax)         (None, None, 512)         0         
_________________________________________________________________
encoder (Encoder)            (None, None, 512)         517248    
_________________________________________________________________
decoder (Decoder)            (None, None, 1)           516737    
_________________________________________________________________
gumbel_softmax (GumbelSoftma (None, None, 512)         0         
Total params: 1,033,985
Trainable params: 1,033,985
Non-trainable params: 0
_________________________________________________________________


# 여기 밑에는 연습장임

In [17]:
def create_masks(inp, tar, length=None):
    enc_padding_mask = tf.cast(tf.sequence_mask(tf.repeat(tf.shape(inp)[1],tf.shape(inp)[0]), tf.shape(inp)[1]), tf.float32)
    dec_padding_mask = tf.linalg.band_part(tf.ones((tf.shape(tar)[1], tf.shape(tar)[1])), -1, 0)
    
    return enc_padding_mask, dec_padding_mask        

In [16]:
encode = vq_vae.encoder(next(iter(train_dataset))[0]).numpy()
encode_onehot = tf.cast(tf.equal(encode, tf.math.reduce_max(encode, 2, keepdims=True)), encode.dtype)

In [24]:
encode.shape

(2, 5000, 512)

In [25]:
encode_onehot.shape

TensorShape([2, 5000, 512])

In [28]:
target = vq_vae.encoder(next(iter(train_dataset))[0]).numpy()
target_onehot = tf.cast(tf.equal(target, tf.math.reduce_max(target, 2, keepdims=True)), target.dtype)

In [173]:
target_onehot[0][100][324]

<tf.Tensor: shape=(), dtype=float32, numpy=1.0>

In [38]:
a = target_onehot[0][0]
a = tf.cast(a, tf.int32)

In [135]:
[tf.shape(encode_onehot)[0].numpy(), tf.shape(encode_onehot)[1].numpy(), 2]

[2, 5000, 2]

In [53]:
zeros = tf.zeros([2,5000,2], a.dtype)

In [55]:
decoder_input = tf.concat([tf.cast(target_onehot,tf.int32), zeros], -1)

In [60]:
decoder_input[:,:,:-2]

<tf.Tensor: shape=(2, 5000, 512), dtype=int32, numpy=
array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]])>

In [69]:
enc_padding_mask, combined_mask, dec_padding_mask = create_masks(encode_onehot, decoder_input, next(iter(train_dataset))[2])

In [147]:

start = tf.one_hot(tf.repeat([512], tf.shape(encode_onehot)[0]),tf.shape(encode_onehot)[-1]+2)
start = tf.expand_dims(start, 1)
zeros = tf.zeros([tf.shape(encode_onehot)[0].numpy(), tf.shape(encode_onehot)[1].numpy(), 2], encode_onehot.dtype)
encode_added = tf.concat([encode_onehot, zeros], -1)

tf.concat([start, encode_added],1)

<tf.Tensor: shape=(2, 5001, 514), dtype=float32, numpy=
array([[[0., 0., 0., ..., 0., 1., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       [[0., 0., 0., ..., 0., 1., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]]], dtype=float32)>

In [68]:
next(iter(train_dataset))[0].shape

(2, 56000, 1)

In [77]:
dec_padding_mask[0][0][0]

<tf.Tensor: shape=(5000,), dtype=float32, numpy=array([0., 0., 0., ..., 0., 0., 0.], dtype=float32)>

In [108]:
create_masks(encode_onehot, decoder_input, next(iter(train_dataset))[2])

(<tf.Tensor: shape=(2, 5000), dtype=float32, numpy=
 array([[1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.]], dtype=float32)>,
 <tf.Tensor: shape=(5000, 5000), dtype=float32, numpy=
 array([[1., 0., 0., ..., 0., 0., 0.],
        [1., 1., 0., ..., 0., 0., 0.],
        [1., 1., 1., ..., 0., 0., 0.],
        ...,
        [1., 1., 1., ..., 1., 0., 0.],
        [1., 1., 1., ..., 1., 1., 0.],
        [1., 1., 1., ..., 1., 1., 1.]], dtype=float32)>)

In [104]:
tf.repeat(tf.shape(encode_onehot)[1],tf.shape(encode_onehot)[0])

<tf.Tensor: shape=(2,), dtype=int32, numpy=array([5000, 5000])>

# 여기서부터 Transformer 임

In [88]:
from src.util.math_function import create_padding_mask, create_look_ahead_mask
from src.losses.custom_loss import mse_with_proper_loss, MSE_Custom_Loss_No_Length, pit_with_outputsize, pit_with_stft_trace
from src.models.Layers import TransformerSpeechSep
from src.models.Schedulers import CustomSchedule
from src.models.Real_Layers import T5Model, T5ModelNoMaskCreationModel, T5ModelYesMaskCreationModel
from src.pre_processing.data_pre_processing import load_data
import tensorflow_addons as tfa
from collections import namedtuple
import time

In [117]:
Config = namedtuple('Config',  
    field_names="d_ff,     d_kv,     d_model,              dropout, feed_forward_proj, num_layers, init_factor," 
                "layer_norm_epsilon, model_type, num_heads, positional_embedding, n_epochs, vocab_size, relative_attention_num_buckets,"
                    "model_path, wav_type, size_type, train_type, loss_type, learning_rate_type,"
                    "input_size, output_size, batch_size, case, ckpt_path, tr_path, val_path, tt_path,"
                    "test_wav_dir, is_load_model")
args = Config( 1024      , 64      , 512              , 0.1 , "gated-gelu", 1       , 1.,
                1e-06    , "t5"             , 8 , "absolute" , 5     , 512   , 32,
                "CKPT", "wav8k", "min", "train-360", "mse", "inverse_root",
                512, 514, 25, 'trace', 'C:/J_and_J_Research/CKPT/gen_code2', 
                'C:/J_and_J_Research/mycode/wsj0_2mix/use_this/tr/', 
                'C:/J_and_J_Research/mycode/wsj0_2mix/use_this/cv/',
                'C:/J_and_J_Research/mycode/wsj0_2mix/use_this/tt/', 
                'C:/J_and_J_Research/test_wav/gen2',
                True) 

In [19]:
!pip install tensorflow_addons



You should consider upgrading via the 'c:\users\qkrwo\anaconda3\envs\nlp_task\python.exe -m pip install --upgrade pip' command.


In [20]:
def create_masks(inp, tar, length=None):
    enc_padding_mask = tf.cast(tf.sequence_mask(tf.repeat(tf.shape(inp)[1],tf.shape(inp)[0]), tf.shape(inp)[1]), tf.float32)
    #dec_padding_mask = tf.linalg.band_part(tf.ones((tf.shape(tar)[1], tf.shape(tar)[1])), -1, 0)
    
    return enc_padding_mask#, dec_padding_mask        

In [56]:
class T5VQ_VAE(tf.keras.Model):
    def train_step(self, data):
        """print('inp',inp.shape) 
        startMask = tf.cast(tf.fill([1,258],-1),dtype=tf.float32)
        endMask = tf.cast(tf.fill([1,258],-2),dtype=tf.float32)
        tar_inp = tf.concat([startMask, tar],0)
        tar_real = tf.concat([tar, endMask],0)
        """
        inp, tar, length = data
        """start = tf.repeat([tf.shape(inp)[-1]], tf.shape(inp)[0])
        start = tf.cast(tf.one_hot(start,tf.shape(inp)[-1]+2),dtype=tf.float32)
        start = tf.expand_dims(start, 1)
        end = tf.repeat([tf.shape(inp)[-1]+1], tf.shape(inp)[0])
        end = tf.cast(tf.one_hot(end,tf.shape(inp)[-1]+2),dtype=tf.float32)
        end = tf.expand_dims(end, 1)
        
        tar = tf.concat([start, tar],1)
        tar = tf.concat([tar, end],1)"""

        tar_inp = tar[:, :-1]
        tar_real = tar[:, 1:]
        

        with tf.GradientTape() as tape:
            prediction = self((inp, tar_inp, length), training=True)
            
            loss = self.compiled_loss(tar_real, prediction, regularization_losses=self.losses)
        
        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        self.compiled_metrics.update_state(tar_real, prediction)

        return {m.name: m.result() for m in self.metrics}
        #train_accuracy(accuracy_function(tar_real, predictions))


    def test_step(self, data):
        inp, tar, length = data
        """start = tf.repeat([tf.shape(inp)[-1]], tf.shape(inp)[0])
        start = tf.cast(tf.one_hot(start,tf.shape(inp)[-1]+2),dtype=tf.float32)
        start = tf.expand_dims(start, 1)
        end = tf.repeat([tf.shape(inp)[-1]+1], tf.shape(inp)[0])
        end = tf.cast(tf.one_hot(end,tf.shape(inp)[-1]+2),dtype=tf.float32)
        end = tf.expand_dims(end, 1)
        
        tar = tf.concat([start, tar],1)
        tar = tf.concat([tar, end],1)"""

        tar_inp = tar[:, :-1]
        tar_real = tar[:, 1:]

        predictions = self((inp, tar_inp, length), training=False)

        # Updates stateful loss metrics.
        self.compiled_loss(tar_real, predictions, regularization_losses=self.losses)

        self.compiled_metrics.update_state(tar_real, predictions)
        # Collect metrics to return
        return {m.name: m.result() for m in self.metrics}

        return_metrics = {}
        for metric in self.metrics:
            result = metric.result()
            if isinstance(result, dict):
                return_metrics.update(result)
            else:
                return_metrics[metric.name] = result
        return return_metrics
    
    def predict_step(self, data):
        inp, tar, length = data
        startMask = tf.cast(tf.fill([tf.shape(tar)[0], 1, tf.shape(tar)[-1]],-1),dtype=tf.float32)
        tar = tf.concat([startMask, tar],1)

        tar_inp = tar[:, :-1, :]

        return self((inp, tar_inp, length), training=False)

In [57]:
def build_real_T5(input_size, output_size, args):
    inputs = (tf.keras.layers.Input(shape=(None, 1)),
    tf.keras.layers.Input(shape=(None, 1)),
    tf.keras.layers.Input(shape=(1)) )
    # targets, length
    transformer = T5ModelNoMaskCreationModel(vocab_size = args.vocab_size, num_layers=args.num_layers, d_model=args.d_model, num_heads=args.num_heads, d_ff=args.d_ff, d_kv = args.d_kv, feed_forward_proj = args.feed_forward_proj, 
            relative_attention_num_buckets=args.relative_attention_num_buckets, eps=args.layer_norm_epsilon, dropout=args.dropout, factor=args.init_factor,
            embed_or_dense="embed", target_size= args.output_size)

    inp, tar, length = inputs
    enc_padding_mask = create_masks(inp, tar, length)
    #dec_padding_mask = tf.squeeze(dec_padding_mask)
    outputs = transformer(input_ids=inp, attention_mask=enc_padding_mask, 
            decoder_input_ids=tar, 
             training=False) # (batch_size, tar_seq_len, target_vocab_size)
    
    model = T5VQ_VAE(inputs=inputs, outputs=outputs)
    model.summary()
    learning_rate = CustomSchedule(args.d_model)
    #optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98,epsilon=1e-8)
    optimizer = tfa.optimizers.AdamW(learning_rate=learning_rate, beta_1=0.9, beta_2=0.999,epsilon=1e-8, weight_decay = 0.01)
    #model.add_metric(tf.keras.metrics.Mean(name='train_loss')(outputs))
    #model.compile(loss=mse_with_proper_loss(output_size), optimizer=optimizer)
    model.compile(loss=pit_with_stft_trace(output_size), optimizer=optimizer)
#     model.compile(loss=keras.losses.mean_squared_error, optimizer=adam)

    return model

In [23]:
ckpt_path = args.ckpt_path
mkdir_p(ckpt_path) # model check point 폴더 만드는 코드

filepath = ckpt_path + "/CKP_ep_{epoch:d}__loss_{val_loss:.5f}_.h5"

# validation loss에 대해서 좋은 것만 저장됨
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath, monitor='val_loss', verbose=1, save_best_only=True, mode='min', save_weights_only=True
)

# early stop 하는 부분인데, validation loss에 대해서 제일 좋은 모델이 저장됨
early_stopping_cb = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', mode='min', verbose=1, patience=50, restore_best_weights=True
)

# Training part

epoch = args.n_epochs
strategy = tf.distribute.MirroredStrategy() # '/gpu:0','/gpu:1','/gpu:2','/gpu:4','/gpu:5','/gpu:6','/gpu:7'
#physical_devices = tf.config.list_physical_devices('GPU')
#tf.config.set_visible_devices(physical_devices[0:7], 'GPU')
#strategy =  tf.distribute.MultiWorkerMirroredStrategy()
print('장치의 수: {}'.format(strategy.num_replicas_in_sync))

with strategy.scope():
    # 사용 안할 때는 load_model 주석 처리 하자
#     model = load_model('./CKPT/CKP_ep_29__loss_102.63367_.h5', custom_objects={'pit_loss': pit_with_outputsize(OUTPUT_SIZE)})

    model = build_real_T5(args.input_size, args.output_size, args)
    #if args.is_load_model is True:

    tf.executing_eagerly()

history = model.fit(
    train_dataset,
    epochs=epoch,
    validation_data=valid_dataset,
    shuffle=True,
    callbacks=[checkpoint_cb, early_stopping_cb],
)

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)
장치의 수: 1
Model: "t5vq_vae"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, None, 512)]  0                                            
__________________________________________________________________________________________________
tf.compat.v1.shape (TFOpLambda) (3,)                 0           input_1[0][0]                    
__________________________________________________________________________________________________
tf.compat.v1.shape_1 (TFOpLambd (3,)                 0           input_1[0][0]                    
__________________________________________________________________________________________________
tf.__operators__.getitem (Slici ()                   0           tf.compat.v1.s

ValueError: in user code:

    c:\users\qkrwo\anaconda3\envs\nlp_task\lib\site-packages\keras\engine\training.py:853 train_function  *
        return step_function(self, iterator)
    c:\users\qkrwo\anaconda3\envs\nlp_task\lib\site-packages\keras\engine\training.py:842 step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    c:\users\qkrwo\anaconda3\envs\nlp_task\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:1286 run
        return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
    c:\users\qkrwo\anaconda3\envs\nlp_task\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:2849 call_for_each_replica
        return self._call_for_each_replica(fn, args, kwargs)
    c:\users\qkrwo\anaconda3\envs\nlp_task\lib\site-packages\tensorflow\python\distribute\mirrored_strategy.py:671 _call_for_each_replica
        self._container_strategy(), fn, args, kwargs)
    c:\users\qkrwo\anaconda3\envs\nlp_task\lib\site-packages\tensorflow\python\distribute\mirrored_run.py:104 call_for_each_replica
        return _call_for_each_replica(strategy, fn, args, kwargs)
    c:\users\qkrwo\anaconda3\envs\nlp_task\lib\site-packages\tensorflow\python\distribute\mirrored_run.py:246 _call_for_each_replica
        coord.join(threads)
    c:\users\qkrwo\anaconda3\envs\nlp_task\lib\site-packages\tensorflow\python\training\coordinator.py:389 join
        six.reraise(*self._exc_info_to_raise)
    c:\users\qkrwo\anaconda3\envs\nlp_task\lib\site-packages\six.py:703 reraise
        raise value
    c:\users\qkrwo\anaconda3\envs\nlp_task\lib\site-packages\tensorflow\python\training\coordinator.py:297 stop_on_exception
        yield
    c:\users\qkrwo\anaconda3\envs\nlp_task\lib\site-packages\tensorflow\python\distribute\mirrored_run.py:346 run
        self.main_result = self.main_fn(*self.main_args, **self.main_kwargs)
    c:\users\qkrwo\anaconda3\envs\nlp_task\lib\site-packages\keras\engine\training.py:835 run_step  **
        outputs = model.train_step(data)
    <ipython-input-21-bb0f33be6f31>:9 train_step
        inp, tar, length = data

    ValueError: not enough values to unpack (expected 3, got 2)


In [58]:
latent_size = 512
epochs = 2

strategy = tf.distribute.MirroredStrategy(['cpu:0'])
print('장치의 수: {}'.format(strategy.num_replicas_in_sync))

with tf.device('/cpu:0'):
    model_path = './CKPT/vqvae_same/CKP_ep_291__loss_89.49190_.h5'
    
    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
    sisdr_Metric = SiSdr()
    
    vq_vae = Vq_vae(latent_size, gumbel_hard=False)
    vq_vae(0, True)
    vq_vae.summary()
    vq_vae.load_weights(model_path)
    

    model = build_real_T5(args.input_size, args.output_size, args)
    
    # Iterate over epochs
    for epoch in range(epochs):
        print("Start of epoch %d" % (epoch,))

        # Iterate over the batches of the dataset
        for step, x_batch_train in enumerate(train_dataset):
            train_inputs = tf.cast(x_batch_train[0], dtype=tf.float32)
            train_labels = tf.cast(x_batch_train[0], dtype=tf.float32)

            
            
            with tf.GradientTape() as tape:
                vqvae_encode = vq_vae.encoder(train_inputs).numpy()
                #encode_onehot = tf.cast(tf.equal(vqvae_encode, tf.math.reduce_max(vqvae_encode, 2, keepdims=True)), vqvae_encode.dtype)
                encode_inp = tf.math.argmax(vqvae_encode, -1)
                
                #zeros = tf.zeros([tf.shape(encode_onehot)[0].numpy(), tf.shape(encode_onehot)[1].numpy(), 2], encode_onehot.dtype)
                #decode_onehot = tf.concat([encode_onehot, zeros],-1)
                start = tf.cast(tf.repeat(tf.constant([[args.vocab_size]]), tf.shape(encode_inp)[0], 0), encode_inp.dtype)
                end = tf.cast(tf.repeat(tf.constant([[args.vocab_size+1]]), tf.shape(encode_inp)[0], 0), encode_inp.dtype)
                decode_inp = tf.concat([start, encode_inp],1)
                decode_inp = tf.concat([decode_inp, end],1)
                
                reconstructed = model((encode_inp, decode_inp, x_batch_train[1]))
                reconstructed = reconstructed[:,:-1,:-1]
                reconstructed = tf.one_hot(reconstructed, 512)
                vqvae_decode = vq_vae.decoder(reconstructed).numpy()
                
                # Compute reconstruction loss
                loss = custom_mse(vqvae_encode, reconstructed)
                loss += sum(ae.losses)  # Add KL loss

            #grads = tape.gradient(loss, ae.trainable_weights)
            #optimizer.apply_gradients(zip(grads, ae.trainable_weights))

            loss_metric(loss)
#             sisdr_Metric.update_state(x_batch_train[0], x_batch_train[0])

            if step % 100 == 0:
                print("step %d: mean loss = %.4f, Si-sdr = %.4f" % (step, loss_metric.result(), sisdr_Metric()))

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)
장치의 수: 1
Model: "vqvae"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
softmax_33 (Softmax)         (None, None, 512)         0         
_________________________________________________________________
encoder (Encoder)            (None, None, 512)         517248    
_________________________________________________________________
decoder (Decoder)            (None, None, 1)           516737    
_________________________________________________________________
gumbel_softmax (GumbelSoftma (None, None, 512)         0         
Total params: 1,033,985
Trainable params: 1,033,985
Non-trainable params: 0
_________________________________________________________________
Model: "t5vq_vae_1"
__________________________________________________________________________________________________
Layer (type)       

InvalidArgumentError: Incompatible shapes: [7000,8,1,1] vs. [2,8,1,3500] [Op:AddV2]

In [45]:
tf.zeros(encode_onehot.shape, encode_onehot.dtype)

<tf.Tensor: shape=(2, 55992, 512), dtype=float32, numpy=
array([[[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]]], dtype=float32)>

In [36]:
vq_vae.encoder(train_inputs)
decode_inp

<tf.Tensor: shape=(2, 3502), dtype=int64, numpy=
array([[512, 324, 324, ..., 324, 176, 513],
       [512, 324, 324, ..., 324, 176, 513]], dtype=int64)>

In [37]:
encode_inp = tf.math.argmax(encode_onehot,-1)

In [55]:
tf.repeat(tf.constant([[args.vocab_size,args.vocab_size+1]]), tf.shape(encode_inp)[0], 0)

<tf.Tensor: shape=(2, 2), dtype=int32, numpy=
array([[512, 513],
       [512, 513]])>

In [54]:
tf.constant([[args.vocab_size,args.vocab_size+1]])

<tf.Tensor: shape=(1, 2), dtype=int32, numpy=array([[512, 513]])>

In [63]:
start = tf.cast(tf.repeat(tf.constant([[args.vocab_size]]), tf.shape(encode_inp)[0], 0), encode_inp.dtype)
end = tf.cast(tf.repeat(tf.constant([[args.vocab_size+1]]), tf.shape(encode_inp)[0], 0), encode_inp.dtype)
deocde_inp = tf.concat([start, encode_inp],1)
deocde_inp = tf.concat([deocde_inp, end],1)

In [74]:
create_masks(deocde_inp,_)

<tf.Tensor: shape=(2, 5002), dtype=float32, numpy=
array([[1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.]], dtype=float32)>

In [66]:
tf.one_hot(deocde_inp, 512)

<tf.Tensor: shape=(2, 5002, 512), dtype=float32, numpy=
array([[[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]]], dtype=float32)>

In [76]:
tf.math.argmax(vqvae_encode, -1)

<tf.Tensor: shape=(2, 5000), dtype=int64, numpy=
array([[324, 324, 324, ..., 324, 324, 176],
       [324, 324, 324, ..., 324, 324, 176]], dtype=int64)>

In [34]:

emb_layer = tf.keras.layers.Embedding(args.vocab_size, args.d_model)
dec_emb_layer = tf.keras.layers.Embedding(args.vocab_size+2, args.d_model)
enc_mask = create_masks(encode_inp,_)

In [35]:
emb_layer(encode_inp)
dec_emb_layer(decode_inp)


<tf.Tensor: shape=(2, 3502, 512), dtype=float32, numpy=
array([[[-0.04950542, -0.00063734, -0.00788502, ..., -0.03748046,
         -0.03154826, -0.03086646],
        [ 0.04070134, -0.02675874,  0.01688543, ..., -0.00723369,
          0.04475541, -0.01825378],
        [ 0.04070134, -0.02675874,  0.01688543, ..., -0.00723369,
          0.04475541, -0.01825378],
        ...,
        [ 0.04070134, -0.02675874,  0.01688543, ..., -0.00723369,
          0.04475541, -0.01825378],
        [ 0.00895997, -0.03796548,  0.03192111, ...,  0.03491438,
         -0.04272968, -0.02189927],
        [ 0.03629345,  0.00849929, -0.03875323, ..., -0.02384899,
         -0.02035632, -0.01134278]],

       [[-0.04950542, -0.00063734, -0.00788502, ..., -0.03748046,
         -0.03154826, -0.03086646],
        [ 0.04070134, -0.02675874,  0.01688543, ..., -0.00723369,
          0.04475541, -0.01825378],
        [ 0.04070134, -0.02675874,  0.01688543, ..., -0.00723369,
          0.04475541, -0.01825378],
        ...

In [26]:
from src.models.Real_Layers import T5Stack


In [42]:
encoder = T5Stack(args.num_layers, args.d_model, args.d_ff, args.d_kv, args.feed_forward_proj, args.num_heads, is_decoder = False, relative_attention_num_buckets=args.relative_attention_num_buckets, eps = args.layer_norm_epsilon, dropout=args.dropout, embed_tokens = emb_layer, factor=args.init_factor)
decoder = T5Stack(args.num_layers, args.d_model, args.d_ff, args.d_kv, args.feed_forward_proj, args.num_heads, is_decoder = True, relative_attention_num_buckets=args.relative_attention_num_buckets, eps = args.layer_norm_epsilon, dropout=args.dropout, embed_tokens = dec_emb_layer, factor=args.init_factor)


In [43]:
output = encoder(
    input_ids=encode_inp,
    attention_mask=enc_mask,
)

dec_output = decoder(
    input_ids=decode_inp,
    encoder_hidden_states=output[0],
    encoder_attention_mask=enc_mask
)

In [46]:
dec_output


(<tf.Tensor: shape=(2, 3502, 512), dtype=float32, numpy=
 array([[[ 1.2077979 , -0.30607352,  0.5464545 , ..., -0.31013697,
          -0.9685639 ,  2.2401168 ],
         [ 1.2952253 , -0.8128181 , -1.0662067 , ..., -1.1302086 ,
          -0.7823541 ,  0.8705663 ],
         [ 1.0742295 , -0.8347188 , -1.2404468 , ..., -1.1410517 ,
          -0.75787365,  0.6026806 ],
         ...,
         [ 0.56603163, -0.46112776, -1.6801546 , ..., -0.97513324,
          -0.43376756, -0.35038418],
         [ 0.18591672, -0.58123124, -1.4556829 , ..., -0.9629582 ,
          -0.5377447 , -0.27069557],
         [ 0.22506867, -0.30955017, -1.7474878 , ..., -0.96178037,
          -0.28729254, -0.05892743]],
 
        [[ 1.0260235 , -0.2839064 ,  0.5034074 , ..., -0.3286717 ,
          -1.0567868 ,  1.9603367 ],
         [ 0.9995614 , -0.7665592 , -1.071817  , ..., -1.1798645 ,
          -0.96889806,  0.87005365],
         [ 0.8442797 , -0.7402053 , -1.2602735 , ..., -1.1824063 ,
          -0.9689222 ,  0.6

In [49]:
test_model = T5ModelNoMaskCreationModel(vocab_size = args.vocab_size, num_layers=args.num_layers, d_model=args.d_model, num_heads=args.num_heads, d_ff=args.d_ff, d_kv = args.d_kv, feed_forward_proj = args.feed_forward_proj, 
            relative_attention_num_buckets=args.relative_attention_num_buckets, eps=args.layer_norm_epsilon, dropout=args.dropout, factor=args.init_factor,
            embed_or_dense="embed", target_size= args.output_size)

In [68]:
output_fin = test_model(
        input_ids=encode_inp, #
        attention_mask=enc_mask, #
        decoder_input_ids=decode_inp
)

In [69]:
tf.math.argmax(output_fin, -1)

<tf.Tensor: shape=(2, 3501), dtype=int64, numpy=
array([[ 98, 359,  61, ...,  61,  61, 275],
       [420, 222,  61, ...,  61,  61, 275]], dtype=int64)>

In [70]:
output_fin

<tf.Tensor: shape=(2, 3501, 514), dtype=float32, numpy=
array([[[ 0.869631  ,  1.1690212 ,  0.35820168, ...,  0.6019451 ,
          1.0112873 ,  2.1435962 ],
        [ 0.80149245,  2.3579922 , -0.4399348 , ...,  0.60069275,
         -0.3548818 ,  1.9100547 ],
        [ 0.68666923,  2.3211918 , -0.5342808 , ...,  0.41276547,
         -0.68009174,  1.6850523 ],
        ...,
        [ 0.911541  ,  2.128042  , -0.632104  , ..., -0.6814339 ,
         -0.76991796,  1.3301015 ],
        [ 0.9115375 ,  2.1280565 , -0.63212836, ..., -0.6814034 ,
         -0.76994693,  1.3300073 ],
        [ 1.1326382 ,  0.8186412 ,  0.89920235, ..., -0.2691553 ,
          0.5539378 ,  1.2252066 ]],

       [[ 0.7201735 ,  1.2553155 ,  0.17787439, ...,  0.42479405,
          1.0152178 ,  1.9904431 ],
        [ 0.8037424 ,  2.2758029 , -0.5985501 , ...,  0.302431  ,
         -0.13018467,  1.7118871 ],
        [ 0.7113944 ,  2.2694674 , -0.67817116, ...,  0.13510275,
         -0.41518915,  1.5366485 ],
        ...

In [86]:
latent_size = 512
epochs = 2

strategy = tf.distribute.MirroredStrategy(['cpu:0'])
print('장치의 수: {}'.format(strategy.num_replicas_in_sync))

with tf.device('/cpu:0'):
    model_path = './CKPT/vqvae_same/CKP_ep_291__loss_89.49190_.h5'
    
    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
    sisdr_Metric = SiSdr()
    
    vq_vae = Vq_vae(latent_size, gumbel_hard=False)
    vq_vae(0, True)
    vq_vae.summary()
    vq_vae.load_weights(model_path)
    

    transformer = T5ModelNoMaskCreationModel(vocab_size = args.vocab_size, num_layers=args.num_layers, d_model=args.d_model, num_heads=args.num_heads, d_ff=args.d_ff, d_kv = args.d_kv, feed_forward_proj = args.feed_forward_proj, 
            relative_attention_num_buckets=args.relative_attention_num_buckets, eps=args.layer_norm_epsilon, dropout=args.dropout, factor=args.init_factor,
            embed_or_dense="embed", target_size= args.output_size)
    ce_loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')
    inputs = tf.keras.layers.Input(shape=(None, 1))
    dec_inputs = tf.keras.layers.Input(shape=(None, 1))
    transformer(input_ids=inputs, decoder_input_ids = dec_inputs)
    transformer.summary()
    
    
    # Iterate over epochs
    for epoch in range(epochs):
        print("Start of epoch %d" % (epoch,))

        # Iterate over the batches of the dataset
        for step, x_batch_train in enumerate(train_dataset):
            train_inputs = tf.cast(x_batch_train[0], dtype=tf.float32)
            train_labels = tf.cast(x_batch_train[0], dtype=tf.float32)

            
            
            with tf.GradientTape() as tape:
                vqvae_encode = vq_vae.encoder(train_inputs).numpy()
                #encode_onehot = tf.cast(tf.equal(vqvae_encode, tf.math.reduce_max(vqvae_encode, 2, keepdims=True)), vqvae_encode.dtype)
                encode_inp = tf.math.argmax(vqvae_encode, -1)
                
                #zeros = tf.zeros([tf.shape(encode_onehot)[0].numpy(), tf.shape(encode_onehot)[1].numpy(), 2], encode_onehot.dtype)
                #decode_onehot = tf.concat([encode_onehot, zeros],-1)
                start = tf.cast(tf.repeat(tf.constant([[args.vocab_size]]), tf.shape(encode_inp)[0], 0), encode_inp.dtype)
                end = tf.cast(tf.repeat(tf.constant([[args.vocab_size+1]]), tf.shape(encode_inp)[0], 0), encode_inp.dtype)
                decode_inp = tf.concat([start, encode_inp],1)
                decode_label = tf.concat([encode_inp, end],1)
                
                enc_mask = create_masks(encode_inp, _)
                
                prediction = transformer(
                                input_ids=encode_inp, #
                                attention_mask=enc_mask, #
                                decoder_input_ids=decode_inp
                            )
                reconstructed = prediction[:,:-1,:-2]
                reconstructed = tf.math.argmax(reconstructed,-1)
                reconstructed = tf.one_hot(reconstructed, args.vocab_size)
                vqvae_decode = vq_vae.decoder(reconstructed).numpy()
                
                # Compute reconstruction loss
                loss = ce_loss_object(decode_label, prediction)
                loss += sum(transformer.losses)  # Add KL loss

            grads = tape.gradient(loss, transformer.trainable_weights)
            optimizer.apply_gradients(zip(grads, transformer.trainable_weights))

            #loss_metric(loss)
#             sisdr_Metric.update_state(x_batch_train[0], x_batch_train[0])

            if step % 100 == 1:
                print("step %d: mean loss = %.4f, Si-sdr = %.4f" % (step, loss_metric.result(), sisdr_Metric()))

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)
장치의 수: 1
Model: "vqvae"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
softmax_96 (Softmax)         (None, None, 512)         0         
_________________________________________________________________
encoder (Encoder)            (None, None, 512)         517248    
_________________________________________________________________
decoder (Decoder)            (None, None, 1)           516737    
_________________________________________________________________
gumbel_softmax (GumbelSoftma (None, None, 512)         0         
Total params: 1,033,985
Trainable params: 1,033,985
Non-trainable params: 0
_________________________________________________________________
Model: "t5_model_no_mask_creation_model_12"
_________________________________________________________________
Layer (type)                

KeyboardInterrupt: 

In [None]:
def build_real_T5(input_size, output_size, args):
    inputs = (tf.keras.layers.Input(shape=(None, 1)),
    tf.keras.layers.Input(shape=(None, 1)),
    tf.keras.layers.Input(shape=(1)) )
    # targets, length
    transformer = T5ModelNoMaskCreationModel(vocab_size = args.vocab_size, num_layers=args.num_layers, d_model=args.d_model, num_heads=args.num_heads, d_ff=args.d_ff, d_kv = args.d_kv, feed_forward_proj = args.feed_forward_proj, 
            relative_attention_num_buckets=args.relative_attention_num_buckets, eps=args.layer_norm_epsilon, dropout=args.dropout, factor=args.init_factor,
            embed_or_dense="embed", target_size= args.output_size)

    inp, tar, length = inputs
    enc_padding_mask = create_masks(inp, tar, length)
    #dec_padding_mask = tf.squeeze(dec_padding_mask)
    outputs = transformer(input_ids=inp, attention_mask=enc_padding_mask, 
            decoder_input_ids=tar, 
             training=False) # (batch_size, tar_seq_len, target_vocab_size)
    
    model = T5VQ_VAE(inputs=inputs, outputs=outputs)
    model.summary()
    learning_rate = CustomSchedule(args.d_model)
    #optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98,epsilon=1e-8)
    optimizer = tfa.optimizers.AdamW(learning_rate=learning_rate, beta_1=0.9, beta_2=0.999,epsilon=1e-8, weight_decay = 0.01)
    #model.add_metric(tf.keras.metrics.Mean(name='train_loss')(outputs))
    #model.compile(loss=mse_with_proper_loss(output_size), optimizer=optimizer)
    model.compile(loss=pit_with_stft_trace(output_size), optimizer=optimizer)
#     model.compile(loss=keras.losses.mean_squared_error, optimizer=adam)

    return model

In [83]:
# Custom mse
ce_loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
ce_loss_object(decode_label, prediction)

ValueError: Shape mismatch: The shape of labels (received (2, 3502)) should equal the shape of logits except for the last dimension (received (2, 3501, 514)).

In [131]:
@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        vqvae_encode = vq_vae.encoder(x)
        #encode_onehot = tf.cast(tf.equal(vqvae_encode, tf.math.reduce_max(vqvae_encode, 2, keepdims=True)), vqvae_encode.dtype)
        encode_inp = tf.math.argmax(vqvae_encode, -1)

        #zeros = tf.zeros([tf.shape(encode_onehot)[0].numpy(), tf.shape(encode_onehot)[1].numpy(), 2], encode_onehot.dtype)
        #decode_onehot = tf.concat([encode_onehot, zeros],-1)
        start = tf.cast(tf.repeat(tf.constant([[args.vocab_size]]), tf.shape(encode_inp)[0], 0), encode_inp.dtype)
        end = tf.cast(tf.repeat(tf.constant([[args.vocab_size+1]]), tf.shape(encode_inp)[0], 0), encode_inp.dtype)
        decode_inp = tf.concat([start, encode_inp],1)
        decode_label = tf.concat([encode_inp, end],1)

        enc_mask = create_masks(encode_inp, _)

        prediction = transformer(
                        input_ids=encode_inp, #
                        attention_mask=enc_mask, #
                        decoder_input_ids=decode_inp
                    )
        reconstructed = prediction[:,:-1,:-2]
        reconstructed = tf.math.argmax(reconstructed,-1)
        reconstructed = tf.one_hot(reconstructed, args.vocab_size)
        vqvae_decode = vq_vae.decoder(reconstructed)

        # Compute reconstruction loss
        loss = ce_loss_object(decode_label, prediction)
        loss += sum(transformer.losses)  # Add KL loss
    
    # Update weights
    grads = tape.gradient(loss, transformer.trainable_weights)
    optimizer.apply_gradients(zip(grads, transformer.trainable_weights))
    
    # Update loss and si-sdr
    train_loss.update_state(loss)
    sisdr_Metric.update_state(y, vqvae_decode)
    
    return loss

@tf.function
def test_step(x, y):
    # Call model
    vqvae_encode = vq_vae.encoder(x)
    # 원핫따위 쓰지 않고 바로 (batch, sequence) 로 맞춰줌
    encode_inp = tf.math.argmax(vqvae_encode, -1)

    # decoder input과 label을 만들어주자.
    start = tf.cast(tf.repeat(tf.constant([[args.vocab_size]]), tf.shape(encode_inp)[0], 0), encode_inp.dtype)
    end = tf.cast(tf.repeat(tf.constant([[args.vocab_size+1]]), tf.shape(encode_inp)[0], 0), encode_inp.dtype)
    decode_inp = tf.concat([start, encode_inp],1)
    decode_label = tf.concat([encode_inp, end],1)

    # attention mask 만들어주자.
    enc_mask = create_masks(encode_inp, _)

    prediction = transformer(
                    input_ids=encode_inp, 
                    attention_mask=enc_mask, 
                    decoder_input_ids=decode_inp
                )
    # decoder = (batch, seq_len + 1, 514) 를 (batch, seq_len, 512)로 되돌림
    reconstructed = prediction[:,:-1,:-2]
    reconstructed = tf.math.argmax(reconstructed,-1)
    reconstructed = tf.one_hot(reconstructed, args.vocab_size)
    # decoder로 복원하자.
    vqvae_decode = vq_vae.decoder(reconstructed)

    # Calculate losses
    val_loss_value = ce_loss_object(decode_label, prediction)
    val_loss_value += sum(transformer.losses) # Add KL loss
    
    # Update loss and si-sdr
    valid_loss.update_state(val_loss_value)
    val_sisdr_Metric.update_state(y, vqvae_decode)
    
    return val_loss_value

In [None]:
from tqdm.auto import tqdm
latent_size = 512
epochs = 2

with tf.device('/cpu:0'):
    model_path = './CKPT/vqvae_same/CKP_ep_291__loss_89.49190_.h5'
    
    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
    ce_loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')
    train_loss = tf.keras.metrics.Mean()
    valid_loss = tf.keras.metrics.Mean()
    sisdr_Metric = SiSdr()
    val_sisdr_Metric = SiSdr()
    
    
    # load vq_vae
    vq_vae = Vq_vae(latent_size, gumbel_hard=False)
    vq_vae(0, True)
    vq_vae.summary()
    vq_vae.load_weights(model_path)
    
    # load_transformer
    transformer = T5ModelNoMaskCreationModel(vocab_size = args.vocab_size, num_layers=args.num_layers, d_model=args.d_model, num_heads=args.num_heads, d_ff=args.d_ff, d_kv = args.d_kv, feed_forward_proj = args.feed_forward_proj, 
            relative_attention_num_buckets=args.relative_attention_num_buckets, eps=args.layer_norm_epsilon, dropout=args.dropout, factor=args.init_factor,
            embed_or_dense="embed", target_size= args.output_size)
    inputs = tf.keras.layers.Input(shape=(None, 1))
    dec_inputs = tf.keras.layers.Input(shape=(None, 1))
    transformer(input_ids=inputs, decoder_input_ids = dec_inputs)
    transformer.summary()
    
    
    for epoch in range(epochs):
        print("\nStart of epoch %d" % (epoch+1,))
        start_time = time.time()

        # Iterate over the batches of the dataset
        progress_bar = tqdm(range(len(train_dataset)))
        for step, (x_batch_train, length) in enumerate(train_dataset):
            x_batch_train = tf.cast(x_batch_train, dtype=tf.float32)

            loss_value = train_step(x_batch_train, x_batch_train)

            # Log every 1 batches
            progress_bar.update(1)
            progress_bar.set_description("step : %d loss : %.4f Si-sdr : %.4f" % (step, train_loss.result(), sisdr_Metric.result()))

        # Run a validation loop at the end of each epoch
        valid_progress_bar = tqdm(range(len(train_dataset)))
        for x_batch_val, length in valid_dataset:
            x_batch_val = tf.cast(x_batch_val, dtype=tf.float32)

            val_loss_value = test_step(x_batch_val, x_batch_val)
            valid_progress_bar.update(1)
            valid_progress_bar.set_description("valid loss : %.4f valid Si-sdr : %.4f" % (valid_loss.result(), val_sisdr_Metric.result()))


        print()
        print('----------------------------------------------------------------------------------')
        print("Time taken >>> %.2fs <<<" % (time.time() - start_time))
        print('epoch: {}, Train_loss: {}, Train_Si-sdr: {} \n\
        Valid_loss: {}, Valid_Si-sdr: {}'.format(
            epoch + 1,
            train_loss.result(),
            sisdr_Metric.result(),
            valid_loss.result(),
            val_sisdr_Metric.result()))
        print('----------------------------------------------------------------------------------')
        print()

        # Reset metrics at the end of each epoch
        train_loss.reset_states()
        sisdr_Metric.reset_states()
        valid_loss.reset_states()
        val_sisdr_Metric.reset_states()

        # Data shuffle at the end of each epoch
        train_dataset.on_epoch_end()
        valid_dataset.on_epoch_end()

Model: "vqvae"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
softmax_208 (Softmax)        (None, None, 512)         0         
_________________________________________________________________
encoder (Encoder)            (None, None, 512)         517248    
_________________________________________________________________
decoder (Decoder)            (None, None, 1)           516737    
_________________________________________________________________
gumbel_softmax (GumbelSoftma (None, None, 512)         0         
Total params: 1,033,985
Trainable params: 1,033,985
Non-trainable params: 0
_________________________________________________________________
Model: "t5_model_no_mask_creation_model_31"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
t5_model_31 (T5Model)        multiple                  6824448   
_______

HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))





In [98]:
for step, (x_batch_train, length) in enumerate(train_dataset):
    x_batch_train = tf.cast(x_batch_train, dtype=tf.float32)
    vq_vae.encoder(x_batch_train)

In [100]:
vq_vae.encoder(x_batch_train).numpy()

array([[[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]]], dtype=float32)