# In [3]:
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np
from tensorflow.keras.layers import Layer, Input

def get_angles(pos, i, d_model):
  """Compute the raw (pre-sin/cos) positional-encoding angles.

  Args:
    pos: position index or array of indices, broadcastable against ``i``.
    i: embedding-dimension index or array of indices.
    d_model: embedding width that scales the exponent.

  Returns:
    ``pos / 10000 ** (2 * (i // 2) / d_model)`` broadcast over ``pos`` and ``i``.
  """
  # Dimensions 2k and 2k+1 share one frequency, hence the floor division.
  exponent = (2 * (i // 2)) / np.float32(d_model)
  return pos * (1 / np.power(10000, exponent))

def positional_encoding(position, d_model):
  """Build the sinusoidal positional-encoding table.

  Args:
    position: number of positions (rows) to generate.
    d_model: embedding width (columns).

  Returns:
    float32 tensor of shape (1, position, d_model). Even columns hold the
    sine and odd columns the cosine of the angles from ``get_angles``.
  """
  positions = np.arange(position)[:, np.newaxis]
  dims = np.arange(d_model)[np.newaxis, :]
  angles = get_angles(positions, dims, d_model)

  table = np.empty_like(angles)
  table[:, 0::2] = np.sin(angles[:, 0::2])  # even indices: 2i
  table[:, 1::2] = np.cos(angles[:, 1::2])  # odd indices: 2i+1

  # The leading axis of size 1 lets the table broadcast over the batch axis.
  return tf.cast(table[np.newaxis, ...], dtype=tf.float32)

def create_padding_mask(seq):
  """Flag entries of ``seq`` that equal 0 (the padding id).

  Returns a float32 tensor holding 1.0 where ``seq == 0`` and 0.0 elsewhere,
  with two singleton axes inserted after the first axis, i.e.
  (batch_size, 1, 1, seq_len) for a (batch_size, seq_len) input, so it can be
  broadcast onto the attention logits.
  """
  is_pad = tf.math.equal(seq, 0)
  pad_mask = tf.cast(is_pad, tf.float32)
  return pad_mask[:, tf.newaxis, tf.newaxis, :]

def create_look_ahead_mask(size):
  """Return a (size, size) mask with 1.0 strictly above the diagonal.

  Entry [i, j] is 1.0 when j > i (a future position to hide) and 0.0
  otherwise.
  """
  # band_part(..., -1, 0) keeps the lower triangle including the diagonal.
  lower_triangle = tf.linalg.band_part(tf.ones((size, size)), -1, 0)
  return 1 - lower_triangle

def scaled_dot_product_attention(q, k, v, mask=None):
  """Calculate the attention weights.
  q, k, v must have matching leading dimensions.
  k, v must have matching penultimate dimension, i.e.: seq_len_k = seq_len_v.
  The mask has different shapes depending on its type(padding or look ahead)
  but it must be broadcastable for addition.

  Args:
    q: query shape == (..., seq_len_q, depth)
    k: key shape == (..., seq_len_k, depth)
    v: value shape == (..., seq_len_v, depth_v)
    mask: Float tensor with shape broadcastable
          to (..., seq_len_q, seq_len_k). Defaults to None, in which case
          no position is masked.

  Returns:
    output, attention_weights
      output: shape == (..., seq_len_q, depth_v)
      attention_weights: shape == (..., seq_len_q, seq_len_k)
  """

  matmul_qk = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)

  # Scale by sqrt(depth of the keys).
  dk = tf.cast(tf.shape(k)[-1], tf.float32)
  scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

  # Masked positions (mask == 1) get a large negative logit so that softmax
  # assigns them a weight of ~0.
  if mask is not None:
    scaled_attention_logits += (mask * -1e9)

  # softmax is normalized on the last axis (seq_len_k) so that the scores
  # add up to 1.
  attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)  # (..., seq_len_q, seq_len_k)

  output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)

  return output, attention_weights

def print_out(q, k, v):
  """Run unmasked scaled dot-product attention and print both results."""
  attn_output, attn_weights = scaled_dot_product_attention(q, k, v, None)
  print('Attention weights are:')
  print(attn_weights)
  print('Output is:')
  print(attn_output)
  
class MultiHeadAttention(tf.keras.layers.Layer):
  """Multi-head scaled dot-product attention layer.

  ``d_model`` is split evenly over ``num_heads`` heads of width ``depth``;
  each head attends independently, then the heads are concatenated and
  projected back to ``d_model`` by a final dense layer.
  """

  def __init__(self, d_model, num_heads):
    super().__init__()
    self.num_heads = num_heads
    self.d_model = d_model

    # The model width must divide evenly across the heads.
    assert d_model % self.num_heads == 0

    self.depth = d_model // self.num_heads

    # Query / key / value projections.
    self.wq = tf.keras.layers.Dense(d_model)
    self.wk = tf.keras.layers.Dense(d_model)
    self.wv = tf.keras.layers.Dense(d_model)

    # Output projection applied to the concatenated heads.
    self.dense = tf.keras.layers.Dense(d_model)

  def split_heads(self, x, batch_size):
    """Reshape (batch_size, seq_len, d_model) into per-head form.

    Returns a tensor of shape (batch_size, num_heads, seq_len, depth).
    """
    per_head = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
    return tf.transpose(per_head, perm=[0, 2, 1, 3])

  def call(self, v, k, q, mask):
    """Attend from ``q`` over ``k``/``v``.

    Returns:
      output: (batch_size, seq_len_q, d_model)
      attention_weights: (batch_size, num_heads, seq_len_q, seq_len_k)
    """
    batch_size = tf.shape(q)[0]

    # Project to (batch_size, seq_len, d_model), then split into heads:
    # (batch_size, num_heads, seq_len, depth).
    q_heads = self.split_heads(self.wq(q), batch_size)
    k_heads = self.split_heads(self.wk(k), batch_size)
    v_heads = self.split_heads(self.wv(v), batch_size)

    # head_outputs: (batch_size, num_heads, seq_len_q, depth)
    head_outputs, attention_weights = scaled_dot_product_attention(
        q_heads, k_heads, v_heads, mask)

    # Move heads next to depth so they can be merged:
    # (batch_size, seq_len_q, num_heads, depth) -> (batch_size, seq_len_q, d_model)
    head_outputs = tf.transpose(head_outputs, perm=[0, 2, 1, 3])
    merged = tf.reshape(head_outputs, (batch_size, -1, self.d_model))

    return self.dense(merged), attention_weights

def point_wise_feed_forward_network(d_model, dff):
  """Return a two-layer feed-forward block: Dense(dff, relu) -> Dense(d_model)."""
  expand = tf.keras.layers.Dense(dff, activation='relu')  # (batch_size, seq_len, dff)
  project = tf.keras.layers.Dense(d_model)  # (batch_size, seq_len, d_model)
  return tf.keras.Sequential([expand, project])
  
class EncoderLayer(tf.keras.layers.Layer):
  """One encoder block: self-attention then feed-forward.

  Each sub-block is followed by dropout, a residual connection and layer
  normalization.
  """

  def __init__(self, d_model, num_heads, dff, rate=0.1):
    super().__init__()

    self.mha = MultiHeadAttention(d_model, num_heads)
    self.ffn = point_wise_feed_forward_network(d_model, dff)

    self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)

    self.dropout1 = tf.keras.layers.Dropout(rate)
    self.dropout2 = tf.keras.layers.Dropout(rate)

  def call(self, x, training, mask):
    """Return the encoded sequence, shape (batch_size, input_seq_len, d_model)."""
    # Self-attention sub-block (value, key and query are all ``x``).
    attended, _ = self.mha(x, x, x, mask)
    attended = self.dropout1(attended, training=training)
    normed_attention = self.layernorm1(x + attended)

    # Feed-forward sub-block.
    transformed = self.ffn(normed_attention)
    transformed = self.dropout2(transformed, training=training)
    return self.layernorm2(normed_attention + transformed)

class DecoderLayer(tf.keras.layers.Layer):
  """One decoder block: self-attention, encoder attention, then feed-forward.

  Each of the three sub-blocks is followed by dropout, a residual connection
  and layer normalization.
  """

  def __init__(self, d_model, num_heads, dff, rate=0.1):
    super().__init__()

    self.mha1 = MultiHeadAttention(d_model, num_heads)
    self.mha2 = MultiHeadAttention(d_model, num_heads)

    self.ffn = point_wise_feed_forward_network(d_model, dff)

    self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    self.layernorm3 = tf.keras.layers.LayerNormalization(epsilon=1e-6)

    self.dropout1 = tf.keras.layers.Dropout(rate)
    self.dropout2 = tf.keras.layers.Dropout(rate)
    self.dropout3 = tf.keras.layers.Dropout(rate)

  def call(self, x, enc_output, training,
           look_ahead_mask, padding_mask):
    """Run the block.

    Returns:
      (output, self_attention_weights, encoder_attention_weights); output has
      shape (batch_size, target_seq_len, d_model).
    """
    # 1) Self-attention over the target sequence, masked by look_ahead_mask.
    self_attn, attn_weights_block1 = self.mha1(x, x, x, look_ahead_mask)
    self_attn = self.dropout1(self_attn, training=training)
    after_self_attn = self.layernorm1(self_attn + x)

    # 2) Attention with enc_output as value/key and the step-1 result as query.
    cross_attn, attn_weights_block2 = self.mha2(
        enc_output, enc_output, after_self_attn, padding_mask)
    cross_attn = self.dropout2(cross_attn, training=training)
    after_cross_attn = self.layernorm2(cross_attn + after_self_attn)

    # 3) Position-wise feed-forward network.
    transformed = self.ffn(after_cross_attn)
    transformed = self.dropout3(transformed, training=training)
    block_output = self.layernorm3(transformed + after_cross_attn)

    return block_output, attn_weights_block1, attn_weights_block2

class Encoder(tf.keras.layers.Layer):
  """Token embedding + positional encoding followed by ``num_layers`` EncoderLayers."""

  def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size,
               maximum_position_encoding, rate=0.1):
    super().__init__()

    self.d_model = d_model
    self.num_layers = num_layers

    self.embedding = tf.keras.layers.Embedding(input_vocab_size, d_model)
    self.pos_encoding = positional_encoding(maximum_position_encoding,
                                            self.d_model)

    self.enc_layers = [EncoderLayer(d_model, num_heads, dff, rate)
                       for _ in range(num_layers)]

    self.dropout = tf.keras.layers.Dropout(rate)

  def call(self, x, training, mask):
    """Encode token ids ``x``; returns (batch_size, input_seq_len, d_model)."""
    seq_len = tf.shape(x)[1]

    # Embed, scale by sqrt(d_model), then add the positional encoding for the
    # first seq_len positions.
    embedded = self.embedding(x)  # (batch_size, input_seq_len, d_model)
    embedded = embedded * tf.math.sqrt(tf.cast(self.d_model, tf.float32))
    embedded = embedded + self.pos_encoding[:, :seq_len, :]

    hidden = self.dropout(embedded, training=training)

    for layer_idx in range(self.num_layers):
      hidden = self.enc_layers[layer_idx](hidden, training, mask)

    return hidden
  
class PartialDecoder(tf.keras.Model):
  """Stand-alone decoder stack with its own output projection.

  Same pipeline as ``Decoder`` (embedding, positional encoding, a stack of
  DecoderLayers) followed by a Dense layer that maps to
  ``target_vocab_size`` logits.
  """

  def __init__(self, num_layers, d_model, num_heads, dff, target_vocab_size,
               maximum_position_encoding, rate=0.1):
    super().__init__()

    self.d_model = d_model
    self.num_layers = num_layers

    self.embedding = tf.keras.layers.Embedding(target_vocab_size, d_model)
    self.pos_encoding = positional_encoding(maximum_position_encoding, d_model)

    self.dec_layers = [DecoderLayer(d_model, num_heads, dff, rate)
                       for _ in range(num_layers)]
    self.dropout = tf.keras.layers.Dropout(rate)

    self.final_layer = tf.keras.layers.Dense(target_vocab_size)

  def call(self, x, enc_output, training,
           look_ahead_mask, padding_mask):
    """Decode target ids ``x`` against ``enc_output``.

    Returns:
      logits: (batch_size, target_seq_len, target_vocab_size)
      attention_weights: dict keyed ``decoder_layer{n}_block1`` /
        ``decoder_layer{n}_block2`` (n starts at 1) holding each layer's
        attention weights.
    """
    seq_len = tf.shape(x)[1]
    attention_weights = {}

    # Embed, scale by sqrt(d_model), add positional encoding.
    hidden = self.embedding(x)  # (batch_size, target_seq_len, d_model)
    hidden = hidden * tf.math.sqrt(tf.cast(self.d_model, tf.float32))
    hidden = hidden + self.pos_encoding[:, :seq_len, :]

    hidden = self.dropout(hidden, training=training)

    for layer_idx in range(self.num_layers):
      hidden, block1, block2 = self.dec_layers[layer_idx](
          hidden, enc_output, training, look_ahead_mask, padding_mask)

      layer_no = layer_idx + 1
      attention_weights[f'decoder_layer{layer_no}_block1'] = block1
      attention_weights[f'decoder_layer{layer_no}_block2'] = block2

    # Project d_model features onto the target vocabulary.
    logits = self.final_layer(hidden)
    return logits, attention_weights

class Decoder(tf.keras.layers.Layer):
  """Token embedding + positional encoding followed by ``num_layers`` DecoderLayers."""

  def __init__(self, num_layers, d_model, num_heads, dff, target_vocab_size,
               maximum_position_encoding, rate=0.1):
    super().__init__()

    self.d_model = d_model
    self.num_layers = num_layers

    self.embedding = tf.keras.layers.Embedding(target_vocab_size, d_model)
    self.pos_encoding = positional_encoding(maximum_position_encoding, d_model)

    self.dec_layers = [DecoderLayer(d_model, num_heads, dff, rate)
                       for _ in range(num_layers)]
    self.dropout = tf.keras.layers.Dropout(rate)

  def call(self, x, enc_output, training,
           look_ahead_mask, padding_mask):
    """Decode target ids ``x`` against ``enc_output``.

    Returns:
      output: (batch_size, target_seq_len, d_model)
      attention_weights: dict keyed ``decoder_layer{n}_block1`` /
        ``decoder_layer{n}_block2`` (n starts at 1) holding each layer's
        attention weights.
    """
    seq_len = tf.shape(x)[1]
    attention_weights = {}

    # Embed, scale by sqrt(d_model), add positional encoding.
    hidden = self.embedding(x)  # (batch_size, target_seq_len, d_model)
    hidden = hidden * tf.math.sqrt(tf.cast(self.d_model, tf.float32))
    hidden = hidden + self.pos_encoding[:, :seq_len, :]

    hidden = self.dropout(hidden, training=training)

    for layer_idx in range(self.num_layers):
      hidden, block1, block2 = self.dec_layers[layer_idx](
          hidden, enc_output, training, look_ahead_mask, padding_mask)

      layer_no = layer_idx + 1
      attention_weights[f'decoder_layer{layer_no}_block1'] = block1
      attention_weights[f'decoder_layer{layer_no}_block2'] = block2

    return hidden, attention_weights
  
class AddLayer(Layer):
    """Keras layer that adds the first two entries of ``inputs`` element-wise."""

    def call(self, inputs):
        lhs = inputs[0]
        rhs = inputs[1]
        return tf.math.add(lhs, rhs)

class Transformer(tf.keras.Model):
  """Encoder/decoder model with a final projection to ``target_vocab_size`` logits.

  ``call`` receives ``(inp, tar)``; ``inp`` is reduced to integer ids with an
  argmax over axis 2 before it is fed to the encoder.
  """

  def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size,
               target_vocab_size, pe_input, pe_target, rate=0.1):
    super().__init__()
    self.encoder = Encoder(num_layers, d_model, num_heads, dff,
                           input_vocab_size, pe_input, rate)

    self.decoder = Decoder(num_layers, d_model, num_heads, dff,
                           target_vocab_size, pe_target, rate)

    self.final_layer = tf.keras.layers.Dense(target_vocab_size)

    self.addition_layer = AddLayer()

  def call(self, inputs, training):
    """Run the model.

    Args:
      inputs: pair ``(inp, tar)``. Keras models prefer all inputs to be
        passed in the first argument.
      training: forwarded to the encoder and decoder.

    Returns:
      final_output: (batch_size, tar_seq_len, target_vocab_size)
      attention_weights: dict produced by the decoder.
    """
    inp, tar = inputs

    # One integer id per step: index of the largest value along axis 2.
    inp_idx = tf.argmax(inp, axis=2)

    masks = self.create_masks(inp_idx, tar)
    enc_padding_mask, look_ahead_mask, dec_padding_mask = masks

    # (batch_size, inp_seq_len, d_model)
    enc_output = self.encoder(inp_idx, training, enc_padding_mask)

    # NOTE(review): ``mixed_embedding`` is computed but never consumed; the
    # decoder below attends over the raw ``inp`` rather than ``enc_output``
    # or this sum. Confirm which of the three is intended.
    mixed_embedding = self.addition_layer([enc_output, inp])

    # dec_output.shape == (batch_size, tar_seq_len, d_model)
    dec_output, attention_weights = self.decoder(
        tar, inp, training, look_ahead_mask, dec_padding_mask)

    final_output = self.final_layer(dec_output)

    return final_output, attention_weights

  def create_masks(self, inp, tar):
    """Build the three attention masks.

    Returns:
      enc_padding_mask: padding mask of ``inp`` for the encoder.
      look_ahead_mask: element-wise maximum of the future-token mask and the
        padding mask of ``tar``; used in the decoder's first attention block.
      dec_padding_mask: padding mask of ``inp``; used in the decoder's second
        attention block.
    """
    enc_padding_mask = create_padding_mask(inp)
    dec_padding_mask = create_padding_mask(inp)

    # A target position is hidden if it is in the future OR is padding.
    future_mask = create_look_ahead_mask(tf.shape(tar)[1])
    tar_padding_mask = create_padding_mask(tar)
    look_ahead_mask = tf.maximum(tar_padding_mask, future_mask)

    return enc_padding_mask, look_ahead_mask, dec_padding_mask

# In [4]:
def getDataset(datasetName):
    """Return the dataset class registered under ``datasetName``.

    Only ``'speech'`` is known; any other name raises ``ValueError``.
    """
    if datasetName != 'speech':
        raise ValueError('Dataset not found')
    return SpeechDataset
    
import pathlib
import random
import numpy as np
import tensorflow as tf

# Phoneme label inventory (39 labels; these appear to be ARPAbet symbols).
PHONE_DEF = [
    'AA', 'AE', 'AH', 'AO', 'AW', 'AY', 'B', 'CH',
    'D', 'DH', 'EH', 'ER', 'EY', 'F', 'G', 'HH',
    'IH', 'IY', 'JH', 'K', 'L', 'M', 'N', 'NG',
    'OW', 'OY', 'P', 'R', 'S', 'SH', 'T', 'TH',
    'UH', 'UW', 'V', 'W', 'Y', 'Z', 'ZH',
]

# Same inventory with the 'SIL' label appended.
PHONE_DEF_SIL = [*PHONE_DEF, 'SIL']

# Reduced inventory: PHONE_DEF without AO, CH, JH, OY, SH and ZH.
CHANG_PHONE_DEF = [
    'AA', 'AE', 'AH', 'AW', 'AY', 'B', 'D', 'DH',
    'EH', 'ER', 'EY', 'F', 'G', 'HH', 'IH', 'IY',
    'K', 'L', 'M', 'N', 'NG', 'OW', 'P', 'R',
    'S', 'T', 'TH', 'UH', 'UW', 'V', 'W', 'Y',
    'Z',
]

# CONSONANT_DEF and VOWEL_DEF together contain every label in PHONE_DEF.
CONSONANT_DEF = [
    'CH', 'SH', 'JH', 'R', 'B', 'M', 'W', 'V',
    'F', 'P', 'D', 'N', 'L', 'S', 'T', 'Z',
    'TH', 'G', 'Y', 'HH', 'K', 'NG', 'ZH', 'DH',
]
VOWEL_DEF = [
    'EY', 'AE', 'AY', 'EH', 'AA', 'AW', 'IY', 'IH',
    'OY', 'OW', 'AO', 'UH', 'AH', 'UW', 'ER',
]

SIL_DEF = ['SIL']

class SpeechDataset():
    """tf.data input pipeline over a directory of ``*.tfrecord`` files.

    ``build`` parses each record into a dict of tensors, optionally mixes in
    a second ("synthetic") directory, optionally applies random time warping
    and channel sub-selection, and returns padded batches.
    """

    # Width to which channel-subset features are zero-padded on the last axis.
    _PADDED_FEATURE_DIM = 256

    def __init__(self,
                 rawFileDir,
                 nInputFeatures,
                 nClasses,
                 maxSeqElements,
                 bufferSize,
                 syntheticFileDir=None,
                 syntheticMixingRate=0.33,
                 subsetSize=-1,
                 labelDir=None,
                 timeWarpSmoothSD=0.0,
                 timeWarpNoiseSD=0.0,
                 chanIndices=None
                 ):
        """Store the pipeline configuration; no data is read here.

        Args:
            rawFileDir: directory searched for ``*.tfrecord`` files.
            nInputFeatures: size of each ``inputFeatures`` row in a record.
            nClasses: number of classes (stored; not used by ``build``).
            maxSeqElements: fixed length of ``seqClassIDs``/``transcription``.
            bufferSize: shuffle buffer size used when training.
            syntheticFileDir: optional second tfrecord directory to mix in.
            syntheticMixingRate: sampling weight given to the synthetic data.
            subsetSize: if > 0, keep only this many training examples.
            labelDir: stored for callers; not used by ``build``.
            timeWarpSmoothSD: SD of the Gaussian that smooths the warp noise.
            timeWarpNoiseSD: SD of the white noise driving the time warp.
            chanIndices: optional feature-column indices to keep.
        """
        self.rawFileDir = rawFileDir
        self.nInputFeatures = nInputFeatures
        self.nClasses = nClasses
        self.maxSeqElements = maxSeqElements
        self.bufferSize = bufferSize
        self.syntheticFileDir = syntheticFileDir
        self.syntheticMixingRate = syntheticMixingRate
        self.timeWarpSmoothSD = timeWarpSmoothSD
        self.timeWarpNoiseSD = timeWarpNoiseSD
        self.subsetSize = subsetSize
        # Previously accepted but silently dropped; keep it on the instance.
        self.labelDir = labelDir
        self.chanIndices = chanIndices

    def _selectChannels(self, inputFeatures):
        """Keep the ``self.chanIndices`` columns and zero-pad the last axis to 256."""
        selectChans = tf.gather(inputFeatures, tf.constant(self.chanIndices), axis=-1)
        paddings = [[0, 0], [0, self._PADDED_FEATURE_DIM - tf.shape(selectChans)[-1]]]
        return tf.pad(selectChans, paddings, 'CONSTANT', constant_values=0)

    def build(self, batchSize, isTraining):
        """Assemble the tf.data pipeline.

        Args:
            batchSize: size of each padded batch.
            isTraining: if true, shuffle the file list and the examples and
                also return the un-batched dataset used for normalization.

        Returns:
            ``(dataset, datasetForAdapt)`` when ``isTraining`` is true,
            otherwise just ``dataset``.
        """
        def _loadDataset(fileDir):
            files = sorted([str(x) for x in pathlib.Path(fileDir).glob("*.tfrecord")])
            if isTraining:
                random.shuffle(files)

            dataset = tf.data.TFRecordDataset(files)
            return dataset

        print(f'Load data from {self.rawFileDir}')
        rawDataset = _loadDataset(self.rawFileDir)
        if self.syntheticFileDir and self.syntheticMixingRate > 0:
            print(f'Load data from {self.syntheticFileDir}')
            syntheticDataset = _loadDataset(self.syntheticFileDir)
            # Both sources repeat indefinitely so the sampler never runs dry.
            dataset = tf.data.experimental.sample_from_datasets(
                [rawDataset.repeat(), syntheticDataset.repeat()],
                weights=[1.0 - self.syntheticMixingRate, self.syntheticMixingRate])
        else:
            dataset = rawDataset

        datasetFeatures = {
            "inputFeatures": tf.io.FixedLenSequenceFeature([self.nInputFeatures], tf.float32, allow_missing=True),
            #"classLabelsOneHot": tf.io.FixedLenSequenceFeature([self.nClasses+1], tf.float32, allow_missing=True),
            "newClassSignal": tf.io.FixedLenSequenceFeature([], tf.float32, allow_missing=True),
            "ceMask": tf.io.FixedLenSequenceFeature([], tf.float32, allow_missing=True),
            "seqClassIDs": tf.io.FixedLenFeature((self.maxSeqElements), tf.int64),
            "nTimeSteps": tf.io.FixedLenFeature((), tf.int64),
            "nSeqElements": tf.io.FixedLenFeature((), tf.int64),
            "transcription": tf.io.FixedLenFeature((self.maxSeqElements), tf.int64)
        }

        if self.timeWarpNoiseSD>0 and self.timeWarpSmoothSD>0:
            # `scipy.ndimage.filters` is a deprecated alias namespace; import
            # from `scipy.ndimage` directly.
            from scipy.ndimage import gaussian_filter1d

            # Gaussian kernel = filtered unit impulse, truncated to taps above
            # 0.001 and renormalized to sum to 1.
            inp = np.zeros([200])
            inp[len(inp) // 2] = 1
            gaussKernel = gaussian_filter1d(inp, self.timeWarpSmoothSD)

            validIdx = np.argwhere(gaussKernel>0.001)
            gaussKernel = gaussKernel[validIdx]
            gaussKernel = np.squeeze(gaussKernel/np.sum(gaussKernel))

            timeWarpNoiseSD= self.timeWarpNoiseSD

            def parseDatasetFunctionWarp(exampleProto):
                dat = tf.io.parse_single_example(exampleProto, datasetFeatures)

                warpDat = {}
                warpDat['seqClassIDs'] = dat['seqClassIDs']
                warpDat['nSeqElements'] = dat['nSeqElements']
                warpDat['transcription'] = dat['transcription']

                # Smoothed white noise perturbs a unit playback rate; twice
                # nTimeSteps samples are drawn before the index cut-off below.
                whiteNoise = tf.random.normal([dat['nTimeSteps']*2], mean=0, stddev=timeWarpNoiseSD)
                rateNoise = tf.nn.conv1d(whiteNoise[tf.newaxis,:,tf.newaxis],
                                         gaussKernel[:,np.newaxis,np.newaxis].astype(np.float32), 1, 'SAME')

                rateNoise = rateNoise[0,:,0]
                toSum = tf.ones([dat['nTimeSteps']*2], dtype=tf.float32) + rateNoise
                # Clamp to non-negative rates so the warp never runs backwards.
                toSum = tf.nn.relu(toSum)

                # Cumulative rate -> source index for each warped time step;
                # indices beyond the original length are dropped.
                warpFun = tf.cumsum(toSum)
                resampleIdx = tf.cast(warpFun, dtype=tf.int32)
                resampleIdx = resampleIdx[resampleIdx<tf.cast(dat['nTimeSteps'],dtype=tf.int32)]

                warpDat['nTimeSteps'] = tf.cast(tf.reduce_sum(tf.cast(resampleIdx>-1,dtype=tf.int32)), dtype=tf.int32)
                warpDat['inputFeatures'] = tf.gather(dat['inputFeatures'], resampleIdx, axis=0)
                if self.chanIndices is not None:
                    warpDat['inputFeatures'] = self._selectChannels(warpDat['inputFeatures'])
                warpDat['newClassSignal'] = tf.gather(dat['newClassSignal'], resampleIdx, axis=0)
                warpDat['ceMask'] = tf.gather(dat['ceMask'], resampleIdx, axis=0)

                return warpDat

            dataset = dataset.map(parseDatasetFunctionWarp, num_parallel_calls=tf.data.AUTOTUNE)

        else:
            def parseDatasetFunctionSimple(exampleProto):
                dat = tf.io.parse_single_example(exampleProto, datasetFeatures)
                if self.chanIndices is None:
                    return dat

                # Same fields as the parsed record; only the features change.
                newDat = dict(dat)
                newDat['inputFeatures'] = self._selectChannels(dat['inputFeatures'])
                return newDat

            dataset = dataset.map(parseDatasetFunctionSimple, num_parallel_calls=tf.data.AUTOTUNE)

        if isTraining:
            # Use all elements to adapt normalization layer
            datasetForAdapt = dataset.map(lambda x: x['inputFeatures'] + 0.001,
                num_parallel_calls=tf.data.AUTOTUNE)

            # Take a subset of the data if specified
            if self.subsetSize > 0:
                dataset = dataset.take(self.subsetSize)

            # Shuffle and transform data if training
            dataset = dataset.shuffle(self.bufferSize)
            if self.syntheticMixingRate == 0:
                dataset = dataset.repeat()
            dataset = dataset.padded_batch(batchSize)
            dataset = dataset.prefetch(tf.data.AUTOTUNE)

            return dataset, datasetForAdapt
        else:
            dataset = dataset.padded_batch(batchSize)
            dataset = dataset.prefetch(tf.data.AUTOTUNE)

            return dataset


# In [31]:
import os
import copy
import random
from datetime import datetime
from pathlib import Path
import numpy as np
import scipy.io
import scipy.special
import tensorflow as tf
import pickle
from jiwer import wer
from tqdm import tqdm

# import tensorflow_probability as tfp
from omegaconf import OmegaConf
from omegaconf.listconfig import ListConfig
from tensorflow.keras.optimizers import Adam

from tensorflow.keras.losses import SparseCategoricalCrossentropy
from scipy.ndimage.filters import gaussian_filter1d

# Module-level history lists (three independent, initially empty lists);
# presumably appended to during training — no writer is visible in this part of the file.
record_train_loss, record_gradNorm, record_cer = [], [], []


@tf.function(experimental_relax_shapes=True)
def gaussSmooth(inputs, kernelSD=2, padding="SAME"):
    """
    Smooth a batch of multi-channel time series along the time axis with a 1D Gaussian kernel.

    Args:
        inputs (tensor : B x T x N): A 3d tensor with batch size B, time steps T, and number of features N
        kernelSD (float): standard deviation of the Gaussian smoothing kernel
        padding (str): padding mode forwarded to tf.nn.depthwise_conv2d

    Returns:
        smoothedData (tensor : B x T x N): A smoothed 3d tensor with batch size B, time steps T, and number of features N
    """

    # Kernel = Gaussian-filtered unit impulse, truncated to taps above 0.01
    # and renormalized so the taps sum to 1.
    impulse = np.zeros([100], dtype=np.float32)
    impulse[50] = 1
    kernel = gaussian_filter1d(impulse, kernelSD)
    keepIdx = np.argwhere(kernel > 0.01)
    kernel = kernel[keepIdx]
    kernel = np.squeeze(kernel / np.sum(kernel))

    # Depthwise convolution: the same kernel is applied to every feature
    # channel independently.
    _, _, nChannels = inputs.shape.as_list()
    filters = tf.tile(kernel[None, :, None, None], [1, 1, nChannels, 1])  # [1, W, C, 1]
    expanded = inputs[:, None, :, :]  # [B, 1, T, C]
    smoothed = tf.nn.depthwise_conv2d(
        expanded, filters, strides=[1, 1, 1, 1], padding=padding
    )

    # Drop the dummy height axis added above.
    return tf.squeeze(smoothed, 1)


class NeuralSequenceDecoder(object):
    """
    This class encapsulates all the functionality needed for training, loading and running the neural sequence decoder RNN.
    To use it, initialize this class and then call .train() or .inference(). It can also be run from the command line (see bottom
    of the script). The args dictionary passed during initialization is used to configure all aspects of its behavior.
    """

    def __init__(self, args):
        """Set up the output directory, RNG seeds, model, optimizer and datasets.

        Args:
            args: configuration mapping. This method reads ``outputDir``,
                ``mode`` and ``seed``; a ``seed`` of -1 is replaced in place by
                a clock-derived value. In ``"train"`` mode the configuration
                is written to ``<outputDir>/args.yaml``.
        """
        self.args = args

        # makedirs(exist_ok=True) also creates missing parent directories and
        # does not fail if the directory appears between a check and the call.
        os.makedirs(self.args["outputDir"], exist_ok=True)

        # Resolve the seed BEFORE recording the configuration so that the
        # saved args.yaml holds the seed actually used, not the -1 placeholder.
        if self.args["seed"] == -1:
            self.args["seed"] = datetime.now().microsecond

        # record these parameters
        if self.args["mode"] == "train":
            with open(os.path.join(args["outputDir"], "args.yaml"), "w") as f:
                OmegaConf.save(config=self.args, f=f)

        # random variable seeding
        np.random.seed(self.args["seed"])
        tf.random.set_seed(self.args["seed"])
        random.seed(self.args["seed"])

        # Hyperparameters
        # NOTE(review): the learning-rate schedule below is built with
        # d_model=256 while the model itself is built with d_model=512 —
        # confirm whether the mismatch is intentional.
        d_model = 256

        self.model = PartialDecoder(
            num_layers=4, d_model=512, num_heads=4, dff=512,
            target_vocab_size=45,
            maximum_position_encoding=5000
        )

        # Compile the model with an appropriate optimizer and loss function
        self.model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
        # A separate Adam optimizer driven by the custom schedule is kept on
        # the instance in addition to the one passed to compile().
        self.learning_rate = CustomSchedule(d_model)
        self.optimizer = tf.keras.optimizers.Adam(self.learning_rate, beta_1=0.9, beta_2=0.98,
                                                  epsilon=1e-9)

        self.train_loss = tf.keras.metrics.Mean(name='train_loss')
        self.train_accuracy = tf.keras.metrics.Mean(name='train_accuracy')

        self.count = 0
        self.max_wer = 10000

        self._prepareForTraining()
        
    def _save_model(self, path):
        """Save the model weights to ``outputDir + path + "_model_weights.h5"``.

        NOTE(review): the output directory and ``path`` are joined by plain
        string concatenation, so no path separator is inserted between them —
        confirm callers pass a ``path`` that accounts for this.
        """
        weights_path = self.args["outputDir"] + (path + "_model_weights.h5")
        self.model.save_weights(weights_path)
        print(f"Model weights saved to {weights_path}")
        
    def _load_model(self, weights_path):
        """Run one dummy forward pass, then load weights from ``weights_path``.

        The model is called once on placeholder tensors before
        ``load_weights``; presumably this creates the model's variables so
        the weights have somewhere to go — confirm.

        NOTE(review): ``create_masks`` is called as a module-level function;
        it must be defined elsewhere in the file (not visible here).
        """
        dummy_input = tf.constant(tf.random.normal([64, 50, 256]))

        # (64, 1) int64 target ids, all set to 44.
        outputLabels = tf.cast(tf.fill((64, 1), 44), dtype=tf.int64)

        _, look_ahead_mask, dec_mask = create_masks(
            tf.argmax(dummy_input, axis=2), outputLabels
        )

        # The outputs are discarded; only the call itself matters here.
        self.model(outputLabels, dummy_input, False, look_ahead_mask, dec_mask)
        self.model.load_weights(weights_path)
        print(f"Model weights loaded from {weights_path}")

    def _buildInputNetworks(self, isTraining):
        """Build one normalization layer and one dense input network per input-layer index.

        The number of input layers is ``max(datasetToLayerMap) + 1``. Results
        are stored in ``self.normLayers`` and ``self.inputLayers``.

        Args:
            isTraining: when true (and ``args["normLayer"]`` is set) each
                normalization layer is adapted on the adapt-dataset of the
                first dataset mapped to that layer.
        """
        datasetToLayerMap = self.args["dataset"]["datasetToLayerMap"]
        self.nInputLayers = np.max(datasetToLayerMap) + 1
        self.inputLayers = []
        self.normLayers = []

        for layerIdx in range(self.nInputLayers):
            # First dataset that maps onto this input layer.
            matches = np.argwhere(np.array(datasetToLayerMap) == layerIdx)
            datasetIdx = matches[0, 0]
            nInputFeatures = self.args["dataset"]["nInputFeatures"]

            normLayer = tf.keras.layers.experimental.preprocessing.Normalization(
                input_shape=[nInputFeatures]
            )
            if isTraining and self.args["normLayer"]:
                normLayer.adapt(self.tfAdaptDatasets[datasetIdx].take(-1))

            inputModel = tf.keras.Sequential()
            inputModel.add(tf.keras.Input(shape=(None, nInputFeatures)))

            netCfg = self.args["model"]["inputNetwork"]
            for i in range(netCfg["nInputLayers"]):
                # A layer whose width equals its input width starts as the
                # identity map; any other layer uses Glorot-uniform init.
                if i == 0:
                    sameWidth = netCfg["inputLayerSizes"][0] == nInputFeatures
                else:
                    sameWidth = (
                        netCfg["inputLayerSizes"][i]
                        == netCfg["inputLayerSizes"][i - 1]
                    )
                if sameWidth:
                    kernelInit = tf.keras.initializers.identity()
                else:
                    kernelInit = "glorot_uniform"

                denseLayer = tf.keras.layers.Dense(
                    netCfg["inputLayerSizes"][i],
                    activation=netCfg["activation"],
                    kernel_initializer=kernelInit,
                    kernel_regularizer=tf.keras.regularizers.L2(
                        self.args["model"]["weightReg"]
                    ),
                )
                inputModel.add(denseLayer)
                inputModel.add(tf.keras.layers.Dropout(rate=netCfg["dropout"]))

            inputModel.trainable = self.args["model"]["inputNetwork"].get(
                "trainable", True
            )
            inputModel.summary()

            self.inputLayers.append(inputModel)
            self.normLayers.append(normLayer)

    def _buildInputLayers(self, isTraining):
        """Build one normalization layer and one linear layer per input-layer index.

        The number of input layers is ``max(datasetToLayerMap) + 1``. Results
        are stored in ``self.normLayers`` and ``self.inputLayers``.

        Args:
            isTraining: when true (and ``args["normLayer"]`` is set) each
                normalization layer is adapted on the adapt-dataset of the
                first dataset mapped to that layer.
        """
        datasetToLayerMap = self.args["dataset"]["datasetToLayerMap"]
        self.nInputLayers = np.max(datasetToLayerMap) + 1
        self.inputLayers = []
        self.normLayers = []

        for layerIdx in range(self.nInputLayers):
            # First dataset that maps onto this input layer.
            matches = np.argwhere(np.array(datasetToLayerMap) == layerIdx)
            datasetIdx = matches[0, 0]

            nInputFeatures = self.args["dataset"]["nInputFeatures"]

            # Adapt normalization layer with all data.
            normLayer = tf.keras.layers.experimental.preprocessing.Normalization(
                input_shape=[nInputFeatures]
            )
            if isTraining and self.args["normLayer"]:
                normLayer.adapt(self.tfAdaptDatasets[datasetIdx].take(-1))

            # A square layer starts as the identity map; otherwise Glorot-uniform.
            inputLayerSize = self.args["model"].get("inputLayerSize", nInputFeatures)
            if inputLayerSize != nInputFeatures:
                kernelInit = "glorot_uniform"
            else:
                kernelInit = tf.keras.initializers.identity()

            weightPenalty = tf.keras.regularizers.L2(self.args["model"]["weightReg"])
            linearLayer = tf.keras.layers.Dense(
                inputLayerSize,
                kernel_initializer=kernelInit,
                kernel_regularizer=weightPenalty,
            )
            linearLayer.build(input_shape=[nInputFeatures])

            self.inputLayers.append(linearLayer)
            self.normLayers.append(normLayer)

    def _prepareForTraining(self):
        """Build datasets, input layers, dataset selectors and checkpoint state.

        Steps, in order:
          1. Choose the channel indices handed to every dataset.
          2. For each (session, dataDir) pair build a train dataset, a dataset
             used to adapt the normalization layer, and a validation dataset.
          3. Build the input layers / input networks (``isTraining=True``).
          4. Create one selector callable per session that pulls the next
             training batch and passes it through ``_datasetLayerTransform``.
          5. Delete old checkpoint files from ``outputDir``, then create
             ``self.checkpoint`` / ``self.ckptManager`` (optionally restoring
             weights from ``outputDir`` or ``loadDir``).
          6. In "train" mode, create the TensorBoard summary writer.

        Everything is stored on ``self``; nothing is returned.
        """
        # build the dataset pipelines
        self.tfAdaptDatasets = []
        self.tfTrainDatasets = []
        self.tfValDatasets = []
        subsetChans = self.args["dataset"].get("subsetChans", -1)
        # NOTE(review): lastDaySubsetChans is read here but not used anywhere
        # else in this method.
        lastDaySubsetChans = self.args["dataset"].get("lastDaySubsetChans", -1)
        TXThreshold = self.args["dataset"].get("TXThreshold", True)
        spkPower = self.args["dataset"].get("spkPower", True)
        nInputFeatures = self.args["dataset"]["nInputFeatures"]
        if subsetChans > 0:
            # Random channel subset. The hard-coded 128 presumably is the
            # number of channels per feature type, with the second feature
            # type stored at an offset of 128 -- TODO confirm.
            if TXThreshold and spkPower:
                # nInputFeatures = 2*subsetChans
                chanIndices = np.random.permutation(128)[:subsetChans]
                chanIndices = np.concatenate((chanIndices, chanIndices + 128))
            else:
                # nInputFeatures = subsetChans
                if TXThreshold:
                    chanIndices = np.random.permutation(128)[:subsetChans]
                else:
                    chanIndices = np.random.permutation(128)[:subsetChans] + 128
        else:
            chanIndices = None
            if "chanIndices" in self.args["dataset"]:
                # Config gives a [start, stop) pair; expand it to explicit indices.
                chanIndices = np.array(
                    list(
                        range(
                            self.args["dataset"]["chanIndices"][0],
                            self.args["dataset"]["chanIndices"][1],
                        )
                    )
                )
            # Same value as assigned above; this reassignment is redundant.
            nInputFeatures = self.args["dataset"]["nInputFeatures"]

        for i, (thisDataset, thisDataDir) in enumerate(
            zip(self.args["dataset"]["sessions"], self.args["dataset"]["dataDir"])
        ):
            trainDir = os.path.join(thisDataDir, thisDataset, "train")
            # Synthetic data is only mixed in when a mixing rate > 0 and a
            # directory are both configured; the directory may be given per
            # session (ListConfig) or as a single path shared by all sessions.
            syntheticDataDir = None
            if (
                self.args["dataset"]["syntheticMixingRate"] > 0
                and self.args["dataset"]["syntheticDataDir"] is not None
            ):
                if isinstance(self.args["dataset"]["syntheticDataDir"], ListConfig):
                    if self.args["dataset"]["syntheticDataDir"][i] is not None:
                        syntheticDataDir = os.path.join(
                            self.args["dataset"]["syntheticDataDir"][i],
                            f"{thisDataset}_syntheticSentences",
                        )
                else:
                    syntheticDataDir = os.path.join(
                        self.args["dataset"]["syntheticDataDir"],
                        f"{thisDataset}_syntheticSentences",
                    )

            datasetName = self.args["dataset"]["name"]
            labelDir = None
            labelDirs = self.args["dataset"].get("labelDir", None)
            if labelDirs is not None and labelDirs[i] is not None:
                labelDir = os.path.join(labelDirs[i], thisDataset)

            # The last session may use its own subset size.
            lastDaySubsetSize = self.args["dataset"].get("lastDaySubsetSize", -1)
            if (
                i == (len(self.args["dataset"]["sessions"]) - 1)
                and lastDaySubsetSize != -1
            ):
                subsetSize = lastDaySubsetSize
            else:
                subsetSize = self.args["dataset"].get("subsetSize", -1)

            newTrainDataset = getDataset(datasetName)(
                trainDir,
                nInputFeatures,
                self.args["dataset"]["nClasses"],
                self.args["dataset"]["maxSeqElements"],
                self.args["dataset"]["bufferSize"],
                syntheticDataDir,
                0
                if syntheticDataDir is None
                else self.args["dataset"]["syntheticMixingRate"],
                subsetSize,
                labelDir,
                self.args["dataset"].get("timeWarpSmoothSD", 0),
                self.args["dataset"].get("timeWarpNoiseSD", 0),
                chanIndices=chanIndices,
            )

            # With isTraining=True, build() returns the training pipeline and
            # a second dataset that is later used to adapt the norm layer.
            newTrainDataset, newDatasetForAdapt = newTrainDataset.build(
                self.args["batchSize"], isTraining=True
            )

            # Validation reads from "test" (or args["testDir"]) unless
            # testOnTrain is set, in which case it reads the train split.
            testOnTrain = self.args["dataset"].get("testOnTrain", False)
            if "testDir" in self.args.keys():
                testDir = self.args["testDir"]
            else:
                testDir = "test"
            valDir = os.path.join(
                thisDataDir, thisDataset, testDir if not testOnTrain else "train"
            )

            newValDataset = getDataset(datasetName)(
                valDir,
                nInputFeatures,
                self.args["dataset"]["nClasses"],
                self.args["dataset"]["maxSeqElements"],
                self.args["dataset"]["bufferSize"],
                chanIndices=chanIndices,
            )
            newValDataset = newValDataset.build(
                self.args["batchSize"], isTraining=False
            )

            self.tfAdaptDatasets.append(newDatasetForAdapt)
            self.tfTrainDatasets.append(newTrainDataset)
            self.tfValDatasets.append(newValDataset)

        # Define input layers, including feature normalization which is adapted on the training data
        if "inputNetwork" in self.args["model"]:
            self._buildInputNetworks(isTraining=True)
        else:
            self._buildInputLayers(isTraining=True)

        # Train dataset selector. Used for switch between different day's data during training.
        self.trainDatasetSelector = {}
        self.trainDatasetIterators = [iter(d) for d in self.tfTrainDatasets]
        for x in range(len(self.args["dataset"]["sessions"])):
            # `x=x` binds the current session index at definition time so each
            # lambda keeps its own index instead of the loop's final value.
            self.trainDatasetSelector[x] = lambda x=x: self._datasetLayerTransform(
                self.trainDatasetIterators[x].get_next(),
                self.normLayers[self.args["dataset"]["datasetToLayerMap"][x]],
                self.args["dataset"]["whiteNoiseSD"],
                self.args["dataset"]["constantOffsetSD"],
                self.args["dataset"]["randomWalkSD"],
                self.args["dataset"]["staticGainSD"],
                self.args["dataset"].get("randomCut", 0),
            )

        # clear old checkpoints
        ckptFiles = [str(x) for x in Path(self.args["outputDir"]).glob("ckpt-*")]
        for file in ckptFiles:
            os.remove(file)

        if os.path.isfile(self.args["outputDir"] + "/checkpoint"):
            os.remove(self.args["outputDir"] + "/checkpoint")

        # saving/loading
        ckptVars = {}
        ckptVars["net"] = self.model
        for x in range(len(self.normLayers)):
            ckptVars["normLayer_" + str(x)] = self.normLayers[x]
            ckptVars["inputLayer_" + str(x)] = self.inputLayers[x]

        # Resume if checkpoint exists in outputDir
        # NOTE(review): the cleanup above has just removed
        # <outputDir>/checkpoint when it is a regular file, so `resume` is
        # expected to be False on a normal run and the resume branch below is
        # effectively unreachable -- confirm whether the cleanup is intended.
        resume = os.path.exists(os.path.join(self.args["outputDir"], "checkpoint"))
        if resume:
            # Resume training, so we need to load optimizer and step from checkpoint.
            ckptVars["step"] = tf.Variable(0)
            ckptVars["bestValCer"] = tf.Variable(1.0)
            ckptVars["optimizer"] = self.optimizer
            self.checkpoint = tf.train.Checkpoint(**ckptVars)
            ckptPath = tf.train.latest_checkpoint(self.args["outputDir"])
            # If in infer mode, we may want to load a particular checkpoint idx
            if self.args["mode"] == "infer":
                if self.args["loadCheckpointIdx"] is not None:
                    ckptPath = os.path.join(
                        self.args["outputDir"], f'ckpt-{self.args["loadCheckpointIdx"]}'
                    )
            # NOTE(review): if latest_checkpoint() found nothing, ckptPath is
            # presumably None and this string concatenation would raise --
            # confirm.
            print("Loading from : " + ckptPath)
            self.checkpoint.restore(ckptPath).expect_partial()
        else:
            if self.args["loadDir"] != None and os.path.exists(
                os.path.join(self.args["loadDir"], "checkpoint")
            ):
                # Warm start from another run: restore only net / norm / input
                # layer weights (ckptVars has no optimizer or step yet).
                if self.args["loadCheckpointIdx"] is not None:
                    ckptPath = os.path.join(
                        self.args["loadDir"], f'ckpt-{self.args["loadCheckpointIdx"]}'
                    )
                else:
                    ckptPath = tf.train.latest_checkpoint(self.args["loadDir"])

                print("Loading from : " + ckptPath)
                self.checkpoint = tf.train.Checkpoint(**ckptVars)
                self.checkpoint.restore(ckptPath)

                if (
                    "copyInputLayer" in self.args["dataset"]
                    and self.args["dataset"]["copyInputLayer"] is not None
                ):
                    print(self.args["dataset"]["copyInputLayer"].items())
                    # Mapping is {target index: source index}: variables of
                    # inputLayers[f] are copied into inputLayers[t].
                    for t, f in self.args["dataset"]["copyInputLayer"].items():
                        for vf, vt in zip(
                            self.inputLayers[int(f)].variables,
                            self.inputLayers[int(t)].variables,
                        ):
                            vt.assign(vf)

                # After loading, we need to put optimizer and step back to checkpoint in order to save them.
                ckptVars["step"] = tf.Variable(0)
                ckptVars["bestValCer"] = tf.Variable(1.0)
                ckptVars["optimizer"] = self.optimizer
                self.checkpoint = tf.train.Checkpoint(**ckptVars)
            else:
                # Nothing to load.
                ckptVars["step"] = tf.Variable(0)
                ckptVars["bestValCer"] = tf.Variable(1.0)
                ckptVars["optimizer"] = self.optimizer
                self.checkpoint = tf.train.Checkpoint(**ckptVars)

        # Keep every checkpoint when saving periodically (batchesPerSave > 0);
        # otherwise keep only the 10 most recent.
        self.ckptManager = tf.train.CheckpointManager(
            self.checkpoint,
            self.args["outputDir"],
            max_to_keep=None if self.args["batchesPerSave"] > 0 else 10,
        )

        # Tensorboard summary
        if self.args["mode"] == "train":
            self.summary_writer = tf.summary.create_file_writer(self.args["outputDir"])

    # Original dictionary of the data batches used in train.
    def _datasetLayerTransform(
        self,
        dat,
        normLayer,
        whiteNoiseSD,
        constantOffsetSD,
        randomWalkSD,
        staticGainSD,
        randomCut,
    ):
        """Normalize one batch and apply the configured feature augmentations.

        The features are indexed as rank-3 ``(batch, time, feature)`` below.
        Each augmentation is skipped when its parameter is not > 0.

        Args:
            dat: Batch dictionary. Reads "inputFeatures", "newClassSignal",
                "seqClassIDs", "nTimeSteps", "nSeqElements", "ceMask",
                "transcription", plus "classLabelsOneHot" when
                ``args["lossType"] == "ce"``. The dictionary is not modified.
            normLayer: Callable applied to ``dat["inputFeatures"]`` first.
            whiteNoiseSD: Std-dev of i.i.d. Gaussian noise added to every value.
            constantOffsetSD: Std-dev of a per-sample, per-feature offset that
                is constant over time.
            randomWalkSD: Std-dev of the Gaussian increments of a cumulative
                sum along ``args["randomWalkAxis"]``.
            staticGainSD: Std-dev of the perturbation added to an identity
                matrix that is right-multiplied onto the features.
            randomCut: If > 0, drop ``cut`` leading time steps with
                ``cut`` drawn from ``[0, randomCut)``; "nTimeSteps" in the
                returned dictionary is reduced by the same amount.

        Returns:
            A new dictionary with the transformed "inputFeatures" and the
            label/metadata entries required by the configured loss type.

        Raises:
            ValueError: If ``args["lossType"]`` is neither "ctc" nor "ce".
                (Previously this surfaced as an unbound-local error at return.)
        """
        lossType = self.args["lossType"]
        if lossType not in ("ctc", "ce"):
            raise ValueError(
                f"Unsupported lossType {lossType!r}; expected 'ctc' or 'ce'."
            )

        features = normLayer(dat["inputFeatures"])
        # Tracked locally so the caller's dictionary is never mutated.
        nTimeSteps = dat["nTimeSteps"]

        featShape = tf.shape(features)
        batchSize = featShape[0]
        featDim = featShape[2]

        if staticGainSD > 0:
            # One independently perturbed identity matrix per batch element.
            warpMat = tf.tile(
                tf.eye(features.shape[2])[tf.newaxis, :, :], [batchSize, 1, 1]
            )
            warpMat += tf.random.normal(tf.shape(warpMat), mean=0, stddev=staticGainSD)
            features = tf.linalg.matmul(features, warpMat)

        if whiteNoiseSD > 0:
            features += tf.random.normal(featShape, mean=0, stddev=whiteNoiseSD)

        if constantOffsetSD > 0:
            # Shape [batch, 1, feat] broadcasts the same offset over time.
            features += tf.random.normal(
                [batchSize, 1, featDim], mean=0, stddev=constantOffsetSD
            )

        if randomWalkSD > 0:
            features += tf.math.cumsum(
                tf.random.normal(featShape, mean=0, stddev=randomWalkSD),
                axis=self.args["randomWalkAxis"],
            )

        if randomCut > 0:
            # NOTE(review): this is a Python-side NumPy draw. When the method
            # is reached through a tf.function trace, the value is presumably
            # fixed at trace time rather than re-drawn per batch -- confirm.
            cut = np.random.randint(0, randomCut)
            features = features[:, cut:, :]
            nTimeSteps = nTimeSteps - cut

        if self.args["smoothInputs"]:
            features = gaussSmooth(features, kernelSD=self.args["smoothKernelSD"])

        # Both loss types share every entry except the one-hot labels, which
        # are only forwarded for cross-entropy.
        outDict = {"inputFeatures": features}
        if lossType == "ce":
            outDict["classLabelsOneHot"] = dat["classLabelsOneHot"]
        outDict.update(
            {
                "newClassSignal": dat["newClassSignal"],
                "seqClassIDs": dat["seqClassIDs"],
                "nTimeSteps": nTimeSteps,
                "nSeqElements": dat["nSeqElements"],
                "ceMask": dat["ceMask"],
                "transcription": dat["transcription"],
            }
        )

        return outDict

    def train(self):
        """Run the training loop with periodic validation and checkpointing.

        Iterates from the step stored in ``self.checkpoint.step`` up to and
        including ``args["nBatchesToTrain"]``. Every iteration samples one
        dataset index according to ``args["dataset"]["datasetProbability"]``
        and runs ``_trainStep``. Every ``args["batchesPerVal"]`` batches
        (except batch 0) it runs ``inference()``, averages the returned
        "wer" values, optionally saves a best checkpoint and calls
        ``_save_model``. After the loop the recorded losses / error rates are
        pickled to ``../record_train_loss.pkl`` and ``../record_wer.pkl``.

        Returns:
            ``float(bestValCer)``, the best validation error tracked here.

        NOTE(review): ``record_train_loss`` and ``record_cer`` are not defined
        in this method; presumably they are module-level lists created
        elsewhere -- confirm. The same holds for ``pickle`` and ``scipy``.
        """
        # NOTE(review): these tables are only ever replaced by the restored
        # snapshot below; the call that would fill them is commented out.
        perBatchData_train = np.zeros([self.args["nBatchesToTrain"] + 1, 6])
        perBatchData_val = np.zeros([self.args["nBatchesToTrain"] + 1, 6])

        # Restore snapshot
        restoredStep = int(self.checkpoint.step)
        if restoredStep > 0:
            outputSnapshot = scipy.io.loadmat(
                self.args["outputDir"] + "/outputSnapshot"
            )
            perBatchData_train = outputSnapshot["perBatchData_train"]
            perBatchData_val = outputSnapshot["perBatchData_val"]

        # batchesPerSave == 0 means "save only when validation improves".
        saveBestCheckpoint = self.args["batchesPerSave"] == 0
        bestValCer = self.checkpoint.bestValCer
        print("bestVal-WER: " + str(bestValCer))
        for batchIdx in range(restoredStep, self.args["nBatchesToTrain"] + 1):
            # --training--
            # Default to a uniform distribution over sessions; the default is
            # written back into self.args so it is computed only once.
            if self.args["dataset"]["datasetProbability"] is None:
                nSessions = len(self.args["dataset"]["sessions"])
                self.args["dataset"]["datasetProbability"] = [
                    1.0 / nSessions
                ] * nSessions
            # A single multinomial draw yields a one-hot vector; argwhere
            # recovers the index of the sampled dataset.
            datasetIdx = int(
                np.argwhere(
                    np.random.multinomial(1, self.args["dataset"]["datasetProbability"])
                )[0][0]
            )

            layerIdx = self.args["dataset"]["datasetToLayerMap"][datasetIdx]

            dtStart = datetime.now()
            try:
                sample_input, sample_output, seqLength = self._trainStep(
                    tf.constant(datasetIdx, dtype=tf.int32),
                    tf.constant(layerIdx, dtype=tf.int32),
                )

                self.checkpoint.step.assign_add(1)
                totalSeconds = (datetime.now() - dtStart).total_seconds()
                # self._addRowToStatsTable(
                #     perBatchData_train, batchIdx, totalSeconds, trainOut, True
                # )
                print(
                    f"Train batch {batchIdx}: "
                    + f'loss: {self.train_loss.result():.8f} '
                    + f'Accuracy: {self.train_accuracy.result():.8f} '
                    + f"time {totalSeconds:.2f}"
                )

                record_train_loss.append(self.train_loss.result())

            except tf.errors.InvalidArgumentError as e:
                # The failing batch is skipped: the error is printed, the
                # checkpoint step is not incremented, and the loop continues.
                print(e)

            # --validation--
            if batchIdx % self.args["batchesPerVal"] == 0 and batchIdx != 0:

                # print("------------------- Train Sample Result ----------------")
                # print(sample_input[0][:seqLength[0] + 2].numpy())
                # print(sample_output[0][:seqLength[0] + 2].numpy())

                dtStart = datetime.now()
                valOutputs = self.inference()

                # Mean of the per-batch error values returned by inference().
                avg_wer = np.average(valOutputs["wer"])

                totalSeconds = (datetime.now() - dtStart).total_seconds()

                print(
                    f"Val batch {batchIdx}: "
                    + f'WER: {avg_wer} '
                    + f"time {totalSeconds:.2f}"
                )
                # print(valOutputs["targetSentences"])
                # print(valOutputs["decodedSentences"])
                # print("-------------------- EXAMPLE -------------------")
                # # Target : Phoneme
                # print("Target : " + str(valOutputs["targetSentences"][0][0][:valOutputs["targetLength"][0][0]]))
                # print("Output : " + str(valOutputs["decodedSentences"][0][0][:valOutputs["targetLength"][0][0]]))
                # print(f"WER {self.max_wer} -> {avg_wer}")

                if saveBestCheckpoint and avg_wer < bestValCer:
                    bestValCer = avg_wer
                    self.checkpoint.bestValCer.assign(bestValCer)
                    savedCkpt = self.ckptManager.save(checkpoint_number=batchIdx)
                    print(f"Checkpoint saved {savedCkpt}")

                record_cer.append(avg_wer)

                # Early stop: the error went up compared with the previous
                # validation while already being below 0.0001.
                # NOTE(review): `break` leaves the loop before _save_model is
                # reached, so the "Early_Stop" file name is never used --
                # confirm whether a final save was intended. Despite its
                # name, self.max_wer holds the previous validation error.
                if self.max_wer < avg_wer and avg_wer < 0.0001:
                    file_name = "/Early_Stop_" + str(batchIdx)
                    break
                else:
                    file_name = "/" + str(batchIdx)
                    self.max_wer = avg_wer

                self._save_model(file_name)

            if (
                self.args["batchesPerSave"] > 0
                and batchIdx % self.args["batchesPerSave"] == 0
            ):
                savedCkpt = self.ckptManager.save(checkpoint_number=batchIdx)
                print(f"Checkpoint saved {savedCkpt}")

        # Persist the training curves next to the working directory.
        with open('../record_train_loss.pkl', 'wb') as file:
            pickle.dump(record_train_loss, file)
        with open('../record_wer.pkl', 'wb') as file:
            pickle.dump(record_cer, file)

        return float(bestValCer)

    def inference(self, returnData=False, load=False, weights_path=None):
        """Run one pass over every enabled validation dataset.

        A dataset is enabled when its entry in
        ``args["dataset"]["datasetProbabilityVal"]`` is > 0. Each batch is
        passed to ``_valStep`` and its outputs are appended, batch by batch,
        to lists in the result dictionary.

        Args:
            returnData: When true, also return the ``allData`` list (it is
                never filled here, so it is always empty).
            load: When true, load model weights from ``weights_path`` first.
            weights_path: Path handed to ``self.model.load_weights``.

        Returns:
            The result dictionary, or ``(results, allData)`` if ``returnData``.
        """
        if load:
            self.model.load_weights(weights_path)

        # Entries converted with .numpy() versus entries stored as returned.
        tensorKeys = ("logits", "transcription", "inferSeqs")
        plainKeys = ("targetSentences", "targetLength", "decodedSentences", "wer")

        results = {
            key: []
            for key in (
                "logits",
                "inferSeqs",
                "transcription",
                "targetSentences",
                "targetLength",
                "decodedSentences",
                "wer",
            )
        }
        allData = []

        print("--------------- Start Validation Step --------------")

        datasetCfg = self.args["dataset"]
        for datasetIdx, valProb in enumerate(datasetCfg["datasetProbabilityVal"]):
            if valProb > 0:
                layerIdx = datasetCfg["datasetToLayerMap"][datasetIdx]
                for batch in self.tfValDatasets[datasetIdx]:
                    stepOut = self._valStep(batch, layerIdx)
                    for key in tensorKeys:
                        results[key].append(stepOut[key].numpy())
                    for key in plainKeys:
                        results[key].append(stepOut[key])

        return (results, allData) if returnData else results

    def _addRowToStatsTable(
        self, currentTable, batchIdx, computationTime, minibatchOutput, isTrainBatch
    ):
        """Store one row of batch statistics and mirror it to TensorBoard.

        Row layout: batch index, computation time, prediction loss,
        regularization loss, mean sequence error rate, gradient norm.
        The three training-only columns are 0.0 for validation batches.

        Args:
            currentTable: Array indexed as ``currentTable[batchIdx, :]``.
            batchIdx: Row index, also used as the summary step.
            computationTime: Time spent on the batch.
            minibatchOutput: Dictionary with "seqErrorRate" and, for training
                batches, "predictionLoss", "regularizationLoss", "gradNorm".
            isTrainBatch: Selects the "train" or "val" summary prefix.
        """

        def trainOnly(key):
            # Training-only statistics fall back to 0.0 on validation rows.
            if isTrainBatch:
                return minibatchOutput[key]
            return 0.0

        rowValues = [
            batchIdx,
            computationTime,
            trainOnly("predictionLoss"),
            trainOnly("regularizationLoss"),
            tf.reduce_mean(minibatchOutput["seqErrorRate"]),
            trainOnly("gradNorm"),
        ]
        currentTable[batchIdx, :] = np.array(rowValues, dtype=object)

        if isTrainBatch:
            prefix = "train"
        else:
            prefix = "val"

        with self.summary_writer.as_default():
            scalars = []
            if isTrainBatch:
                scalars.append(("predictionLoss", minibatchOutput["predictionLoss"]))
                scalars.append(("regLoss", minibatchOutput["regularizationLoss"]))
                scalars.append(("gradNorm", minibatchOutput["gradNorm"]))
            scalars.append(
                ("seqErrorRate", tf.reduce_mean(minibatchOutput["seqErrorRate"]))
            )
            scalars.append(("computationTime", computationTime))

            for tag, value in scalars:
                tf.summary.scalar(f"{prefix}/{tag}", value, step=batchIdx)

    @tf.function()
    def _trainStep(self, datasetIdx, layerIdx):
        """Run one optimization step on a batch from the selected dataset.

        Args:
            datasetIdx: Scalar int32 tensor choosing the entry of
                ``self.trainDatasetSelector`` that supplies the batch.
            layerIdx: Scalar int32 tensor choosing which input layer
                transforms the batch features.

        Returns:
            A tuple ``(padded_tensor, predicted ids, data["nSeqElements"])``
            where ``padded_tensor`` is ``seqClassIDs`` with the value 44
            prepended to every row and the predicted ids are the argmax of
            the model output over its last axis.
        """

        # Pull and augment the next batch of the selected session.
        data = tf.switch_case(datasetIdx, self.trainDatasetSelector)

        # One branch per input layer; `x=x` binds the index per lambda.
        inputTransformSelector = {}
        for x in range(self.nInputLayers):
            inputTransformSelector[x] = lambda x=x: self.inputLayers[x](
                data["inputFeatures"], training=True
            )

        # NOTE(review): regLossSelector is built but never used below, so the
        # input layers' regularization losses do not enter `loss`.
        regLossSelector = {}
        for x in range(self.nInputLayers):
            regLossSelector[x] = lambda x=x: self.inputLayers[x].losses

        with tf.GradientTape() as tape:

            inputTransformedFeatures = tf.switch_case(layerIdx, inputTransformSelector)

            # Target : Character
            # padded_tensor = tf.pad(data["transcription"], paddings=[[0, 0], [1, 0]], constant_values=149)

            # Target : Phoneme
            # Prepend 44 to every label row (presumably a start-of-sequence
            # id -- TODO confirm); dropping the last column afterwards gives
            # a decoder input with the same length as the targets.
            padded_tensor = tf.pad(data["seqClassIDs"], paddings=[[0, 0], [1, 0]], constant_values=44)

            # Transformer
            # predictions, _ = self.model([inputTransformedFeatures, padded_tensor[:,:-1]], training = True)

            # Only Decoder
            # The "input sequence" given to create_masks is the argmax over
            # the feature axis, so time steps whose argmax is 0 are flagged
            # as padding. enc_padding_mask is not used afterwards.
            enc_padding_mask, look_ahead_mask, dec_mask = create_masks(tf.argmax(inputTransformedFeatures, axis=2),padded_tensor[:,:-1])


            predictions, _ = self.model(padded_tensor[:,:-1], inputTransformedFeatures, True, look_ahead_mask, dec_mask)

            #Target : Character
            # loss = loss_function(data["transcription"], predictions)

            #Target : Phoneme
            loss = loss_function(data["seqClassIDs"], predictions)

        # NOTE(review): gradients are taken only with respect to
        # self.model.trainable_variables; the variables of self.inputLayers
        # are not updated here -- confirm this is intended.
        gradients = tape.gradient(loss,self.model.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))

        # Update the running metrics that train() prints.
        self.train_loss(loss)
        # self.train_accuracy(accuracy_function(data["transcription"], predictions))
        self.train_accuracy(accuracy_function(data["seqClassIDs"], predictions))

        return padded_tensor, tf.argmax(predictions, axis=-1), data["nSeqElements"]

    def _valStep(self, data, layerIdx):
        """Greedily decode one validation batch and score it.

        The batch is normalized (all augmentations disabled), optionally
        channel-masked, passed through the selected input layer, and then
        decoded for a fixed 100 steps starting from the id 44.

        Args:
            data: Validation batch dictionary.
            layerIdx: Index into ``self.normLayers`` / ``self.inputLayers``.

        Returns:
            Dictionary with "logits" (model output for the last position of
            the final decoding step only), "inferSeqs" (all decoded ids
            including the leading 44), "transcription", "targetSentences",
            "targetLength", "decodedSentences" and "wer" (mean of
            ``calculate_cer`` over the batch).
        """
        # All augmentation parameters are 0, so only normalization (and
        # smoothing, if configured) is applied.
        data = self._datasetLayerTransform(
            data, self.normLayers[layerIdx], 0, 0, 0, 0, 0
        )

        # channel zeroing
        if "channelMask" in self.args.keys():
            maskedFeatures = data["inputFeatures"] * tf.constant(
                np.array(self.args["channelMask"])[np.newaxis, np.newaxis, :],
                dtype=tf.float32,
            )
            print("masking")
        else:
            maskedFeatures = data["inputFeatures"]

        inputTransformedFeatures = self.inputLayers[layerIdx](
            maskedFeatures, training=False
        )

        target_sent = []
        target_length = []
        # Target : Character
        # for seq in data["transcription"]:
        #     endIdx = tf.argmax(tf.cast(tf.equal(seq, 0), tf.int32)).numpy()
        #     target_length.append(endIdx)
        #     characters = [chr(value) for value in seq]
        #     result_string = ''.join(characters)
        #     removed = result_string[:endIdx]
        #     target_sent.append(removed)

        # Debug output.
        print("-------------------- FIND PROPER END STEP ------------------------")
        print()
        print("SAMPLES")
        print(data['nSeqElements'])
        print()
        print()
        print()
        # print(data['classLabelsOneHot'])
        print()

        # Target : Phoneme
        for seq in data["seqClassIDs"]:
            # Index of the first 0 in the row; argmax yields 0 as well when
            # the row contains no 0 at all.
            endIdx = tf.argmax(tf.cast(tf.equal(seq, 0), tf.int32)).numpy()
            target_length.append(endIdx)
            # NOTE(review): the full row, including trailing zeros, is kept as
            # the target; target_length is recorded but not used to truncate
            # it before scoring.
            phoneme_idx = [value.numpy() for value in seq]
            target_sent.append(phoneme_idx)

        # Every decoded sequence starts with the id 44, matching the value
        # prepended to the targets in _trainStep.
        outputLabels = tf.fill((inputTransformedFeatures.shape[0], 1), 44)
        outputLabels = tf.cast(outputLabels, dtype=tf.int64)

        # output_array = tf.TensorArray(dtype=tf.int64, size=0, dynamic_size=True)
        # output_array = output_array.write(0, 149)

        # In Real Inference
        # Fixed-length greedy decoding: always 100 steps, with no early stop
        # on an end-of-sequence or padding id.
        for i in range(100):

            enc_padding_mask, look_ahead_mask, dec_mask = create_masks(tf.argmax(inputTransformedFeatures, axis=2),outputLabels)

            # Only Decoder
            # NOTE(review): the third positional argument is True, exactly as
            # in _trainStep. If it is the model's `training` flag, dropout
            # stays active during validation -- confirm against the model.
            predictions, _ = self.model(outputLabels, inputTransformedFeatures, True, look_ahead_mask, dec_mask)

            # output = tf.transpose(output_array.stack())
            # predictions, _ = self.model([inputTransformedFeatures, output], training = False)
            # print("Predict")
            # print(outputLabels)
            # print(predictions)
            # print(predictions.shape)
            # 20 1 150
            full_pred = predictions  # kept only for the debug prints below
            # Only the newest position is needed to pick the next id.
            predictions = predictions[:,-1:,:]
            predicted_id = tf.argmax(predictions, axis=-1)
            # nextLabels = tf.argmax(predictions, axis=2)

            # print("Output Lables")
            # print(predicted_id)

            # print("Concated Next Labels")
            # output_array = output_array.write(i+1, predicted_id[0])

            outputLabels = tf.concat([outputLabels, predicted_id], axis=1)
            # print(outputLabels)

        # Target : Character
        # output_sent = []
        # for idx in range(len(outputLabels)):
        #     characters = [chr(value) for value in outputLabels[idx]]
        #     result_string = ''.join(characters)
        #     removed = result_string[1:target_length[idx]+1]
        #     output_sent.append(removed)

        # Target : Phoneme
        # print(outputLabels) # 20, 71
        # print(full_pred)  # 20, 70, 45
        # print(tf.argmax(full_pred, axis=2)) # 20 70
        # print(data["seqClassIDs"]) # 20, 500
        # print(data["seqClassIDs"][:,:outputLabels.shape[1]]) # 20, 71

        output_sent = []
        for idx in range(len(outputLabels)):
            output_idx = [value.numpy() for value in outputLabels[idx]]
            # Drop the leading start id.
            output_sent.append(output_idx[1:])
        # Target : Phoneme
        print("Target : " + str(target_sent[0][:100]))
        print("Output : " + str(output_sent[0][:100]))

        s_wer = 0
        for idx in range(len(target_sent)):
            # s_wer += wer(target_sent[idx], output_sent[idx])
            # NOTE(review): calculate_cer is declared as (predicted, target)
            # and divides by len(target). Here the target sequence is passed
            # first, so the edit distance is normalized by the decoded length
            # (always 100) rather than by the reference length -- confirm.
            s_wer += calculate_cer(target_sent[idx], output_sent[idx])
            # print(s_wer)
        s_wer /= len(target_sent)

        output = {}
        # `predictions` was overwritten inside the loop: this is the
        # last-position slice of the final decoding step.
        output["logits"] = predictions
        output["inferSeqs"] = outputLabels

        output["transcription"] = data["transcription"]
        output["targetSentences"] = target_sent
        # Duplicate of the assignment above (no effect).
        output["targetSentences"] = target_sent
        output["targetLength"] = target_length

        output["decodedSentences"] = output_sent

        output["wer"] = s_wer

        return output


def timeWarpDataElement(dat, timeScalingRange):
    """Stretch or compress a batch along the time axis by one random factor.

    A single factor ``1 + (u - 0.5) * timeScalingRange`` with ``u`` uniform
    in [0, 1) is drawn for the whole batch. "inputFeatures",
    "newClassSignal" and "ceMask" are resampled at evenly spaced time indices
    cast to int32 (no interpolation); "nTimeSteps" is scaled by the same
    factor. "seqClassIDs", "nSeqElements" and "transcription" are passed
    through unchanged.

    Args:
        dat: Batch dictionary with the keys named above.
        timeScalingRange: Width of the interval the factor is drawn from.

    Returns:
        A new dictionary holding the warped batch.
    """
    scale = 1 + (tf.random.uniform(shape=[], dtype=tf.float32) - 0.5) * timeScalingRange

    scaledSteps = tf.cast(
        tf.cast(dat["nTimeSteps"], dtype=tf.float32) * scale, dtype=tf.int64
    )

    featShape = tf.shape(dat["inputFeatures"])
    nBatch = featShape[0]
    nTime = tf.cast(featShape[1], dtype=tf.int32)
    nWarped = tf.cast(tf.cast(nTime, dtype=tf.float32) * scale, dtype=tf.int32)

    # Evenly spaced source indices from 0 to nTime - 1, one row per sample.
    firstIdx = tf.zeros_like(dat["nTimeSteps"], dtype=tf.int32)
    lastIdx = tf.ones_like(dat["nTimeSteps"], dtype=tf.int32) * (nTime - 1)
    timeIdx = tf.cast(tf.linspace(firstIdx, lastIdx, nWarped, axis=1), dtype=tf.int32)

    # Pair every time index with its batch index for gather_nd.
    sampleIdx = tf.tile(tf.range(nBatch)[:, None, None], [1, nWarped, 1])
    gatherIdx = tf.concat([sampleIdx, timeIdx[..., None]], axis=-1)

    return {
        "seqClassIDs": dat["seqClassIDs"],
        "nSeqElements": dat["nSeqElements"],
        "transcription": dat["transcription"],
        "nTimeSteps": scaledSteps,
        "inputFeatures": tf.gather_nd(dat["inputFeatures"], gatherIdx),
        "newClassSignal": tf.gather_nd(dat["newClassSignal"], gatherIdx),
        "ceMask": tf.gather_nd(dat["ceMask"], gatherIdx),
    }

class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Learning-rate schedule with linear warm-up and inverse-sqrt decay.

    For a step ``s`` the rate is

        rsqrt(d_model) * min(rsqrt(s), s * warmup_steps ** -1.5)

    so it grows linearly for the first ``warmup_steps`` steps and decays
    proportionally to ``1 / sqrt(s)`` afterwards.

    Args:
        d_model: Model width used to scale the rate; stored as a float32
            tensor in ``self.d_model``.
        warmup_steps: Number of warm-up steps.
    """

    def __init__(self, d_model, warmup_steps=4000):
        super(CustomSchedule, self).__init__()

        self.d_model = tf.cast(d_model, tf.float32)
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        """Return the learning rate for ``step`` as a float32 tensor."""
        # The step may arrive as an integer (Python int or integer tensor);
        # tf.math.rsqrt only accepts floating-point input, so cast first.
        step = tf.cast(step, dtype=tf.float32)

        arg1 = tf.math.rsqrt(step)
        arg2 = step * (self.warmup_steps ** -1.5)

        return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)

# def loss_function(real, pred):
#     loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
#     from_logits=True, reduction='none')
#     mask = tf.math.logical_not(tf.math.equal(real, 0))
#     loss_ = loss_object(real, pred)

#     mask = tf.cast(mask, dtype=loss_.dtype)
#     loss_ *= mask

#     return tf.reduce_sum(loss_)/tf.reduce_sum(mask)

# Only decoder
def loss_function(real, pred):
    """Mean sparse categorical cross-entropy over non-padding positions.

    Positions where ``real == 0`` are treated as padding and excluded from
    both the sum and the normalizer.

    Args:
        real: Integer label tensor; 0 marks padding.
        pred: Logits with one extra trailing class axis relative to ``real``.

    Returns:
        Scalar tensor: summed masked loss divided by the number of
        non-padding positions, or 0 when there are none (the plain division
        used previously produced NaN in that case).
    """
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction='none'
    )
    loss_ = loss_object(real, pred)

    mask = tf.cast(tf.math.not_equal(real, 0), dtype=loss_.dtype)
    loss_ *= mask

    # divide_no_nan returns 0 for a zero denominator instead of NaN.
    return tf.math.divide_no_nan(tf.reduce_sum(loss_), tf.reduce_sum(mask))


def accuracy_function(real, pred):
    """Fraction of non-padding positions whose argmax prediction is correct.

    Positions where ``real == 0`` are treated as padding and ignored.

    Args:
        real: Label tensor; 0 marks padding.
        pred: Score tensor whose axis 2 is the class axis.

    Returns:
        Scalar float32 tensor: correct non-padding positions divided by the
        number of non-padding positions, or 0 when there are none (the plain
        division used previously produced NaN in that case).
    """
    # argmax yields int64; cast to the label dtype so labels stored with a
    # different integer width can still be compared.
    predicted_ids = tf.cast(tf.argmax(pred, axis=2), real.dtype)

    mask = tf.math.not_equal(real, 0)
    correct = tf.math.logical_and(mask, tf.equal(real, predicted_ids))

    correct = tf.cast(correct, dtype=tf.float32)
    mask = tf.cast(mask, dtype=tf.float32)
    return tf.math.divide_no_nan(tf.reduce_sum(correct), tf.reduce_sum(mask))

def _wer(reference, hypothesis):
    # Create a 2D matrix to store the distances
    distance = [[0] * (len(hypothesis) + 1) for _ in range(len(reference) + 1)]

    # Initialize the matrix with the distances
    for i in range(len(reference) + 1):
        for j in range(len(hypothesis) + 1):
            if i == 0:
                distance[i][j] = j
            elif j == 0:
                distance[i][j] = i
            else:
                cost = 0 if reference[i - 1] == hypothesis[j - 1] else 1
                distance[i][j] = min(
                    distance[i - 1][j] + 1,      # Deletion
                    distance[i][j - 1] + 1,      # Insertion
                    distance[i - 1][j - 1] + cost  # Substitution
                )

    # Return the WER
    return distance[len(reference)][len(hypothesis)] / len(reference)

def create_look_ahead_mask(size):
    """Return a (size, size) mask that is 1 strictly above the diagonal.

    Entry (i, j) is 1 when j > i and 0 otherwise.
    """
    all_ones = tf.ones((size, size))
    # band_part(-1, 0) keeps the lower triangle including the diagonal.
    lower_triangle = tf.linalg.band_part(all_ones, -1, 0)
    return 1 - lower_triangle  # (size, size)

def create_decoder_padding_mask(seq):
    """Return a float mask that is 1.0 wherever ``seq`` equals 0.

    Two singleton axes are inserted after the first axis so the mask can
    broadcast against attention logits.
    """
    is_padding = tf.math.equal(seq, 0)
    padding_flags = tf.cast(is_padding, tf.float32)
    return padding_flags[:, tf.newaxis, tf.newaxis, :]  # (batch_size, 1, 1, seq_len)

def create_padding_mask(seq):
    """Mark the positions of ``seq`` that hold the padding value 0.

    Returns a float32 tensor with 1.0 at padding positions and 0.0
    elsewhere, reshaped with two extra singleton axes so that it broadcasts
    when added to attention logits.
    """
    padding_flags = tf.cast(tf.math.equal(seq, 0), tf.float32)
    expanded = padding_flags[:, tf.newaxis, tf.newaxis, :]
    return expanded  # (batch_size, 1, 1, seq_len)

def create_masks(inp, tar):
    """Build the attention masks for an input / target pair.

    Args:
        inp: Input id tensor; zeros are treated as padding.
        tar: Target id tensor; zeros are treated as padding.

    Returns:
        A tuple ``(enc_padding_mask, look_ahead_mask, dec_padding_mask)``:
        the first and third are both padding masks of ``inp``; the second is
        the element-wise maximum of the target padding mask and the causal
        look-ahead mask for the target length.
    """
    # Padding mask of the input, once for the encoder ...
    enc_padding_mask = create_padding_mask(inp)
    # ... and once for the decoder's attention over the encoder output.
    dec_padding_mask = create_padding_mask(inp)

    # A target position is hidden when it is padding or lies in the future.
    combined_mask = tf.maximum(
        create_padding_mask(tar),
        create_look_ahead_mask(tf.shape(tar)[1]),
    )

    return enc_padding_mask, combined_mask, dec_padding_mask

def calculate_cer(predicted, target):
    """Return the character error rate (CER) of `predicted` against `target`.

    The CER is the Levenshtein edit distance between the two sequences
    (unit cost for deletion, insertion and substitution) divided by the
    length of the reference `target`.

    Args:
      predicted: hypothesis sequence (a string, or any sized iterable
        sequence of comparable items).
      target: reference sequence of the same kind.

    Returns:
      The edit distance divided by ``len(target)``, as a float. When
      `target` is empty the distance is divided by 1 instead of raising
      ZeroDivisionError, so the result is 0.0 for an empty prediction and
      ``float(len(predicted))`` otherwise.
    """
    n = len(target)

    # Each DP row depends only on the row above it, so two rows are enough.
    # Row 0: turning an empty prediction into target[:j] costs j insertions.
    prev_row = list(range(n + 1))

    for i, pred_item in enumerate(predicted, start=1):
        # Column 0: turning predicted[:i] into an empty target costs i
        # deletions.
        curr_row = [i]
        for j, ref_item in enumerate(target, start=1):
            cost = 0 if pred_item == ref_item else 1
            curr_row.append(
                min(
                    prev_row[j] + 1,  # Deletion
                    curr_row[j - 1] + 1,  # Insertion
                    prev_row[j - 1] + cost,  # Substitution
                )
            )
        prev_row = curr_row

    # Normalise by the reference length; guard against an empty reference.
    return prev_row[n] / max(n, 1)


  from scipy.ndimage.filters import gaussian_filter1d


In [26]:
def getSubsampledTimeSteps(timeSteps):
    """Convert a count of input time steps into the count after subsampling.

    Computes ``int32((int32(timeSteps / 1) - 14) / 4 + 1)``, where each
    int32 conversion is a ``tf.cast``.

    NOTE(review): 14 and 4 look like a window length and a stride, and the
    division by 1 like a disabled subsampling factor -- confirm against the
    model configuration.

    Args:
      timeSteps: number(s) of time steps before subsampling.

    Returns:
      An int32 tensor with the subsampled number of time steps.
    """
    whole_steps = tf.cast(timeSteps / 1, dtype=tf.int32)
    window_count = (whole_steps - 14) / 4 + 1
    return tf.cast(window_count, dtype=tf.int32)

In [5]:
# Root directory of the project data; the evaluation cell further down
# builds its checkpoint and tfRecords paths from it.
baseDir = '/home/s2/nlp002/pj_data'

import os
from glob import glob
from pathlib import Path
# Order CUDA devices by PCI bus ID and set the visible-device list to an
# empty string.
# NOTE(review): an empty CUDA_VISIBLE_DEVICES presumably hides every GPU so
# that TensorFlow runs on CPU. These variables are set before the
# `import tensorflow` below -- keep that ordering unless confirmed otherwise.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   
os.environ["CUDA_VISIBLE_DEVICES"]=""

import numpy as np
from omegaconf import OmegaConf
import tensorflow as tf

In [33]:
# Evaluate the RNN on the test partition and the competitionHoldOut partition.
# NOTE(review): as written only ONE pass runs. `dirIdx += 1` shifts the index
# to 1 ('competitionHoldOut') and the `break` after print(out) leaves the
# loop. Without that `break`, the second pass would evaluate testDirs[2] and
# raise IndexError. Remove both the increment and the `break` to evaluate
# both partitions.
from tqdm import tqdm

# Leading "/" because this is appended to ckptDir by plain string
# concatenation below.
model_file_name = "/7000_model_weights.h5"

testDirs = ['test','competitionHoldOut']
# One result slot per entry in testDirs; currently never filled because the
# assignments at the end of the loop body are commented out.
trueTranscriptions = [[],[]]
decodedTranscriptions = [[],[]]
for dirIdx in range(2):
    ckptDir = baseDir + '/derived/tr_to_ph/baselineRelease'
    # NOTE(review): mutating the loop variable -- see the note at the top.
    dirIdx += 1
    # Load the saved run configuration and override it for inference.
    args = OmegaConf.load(os.path.join(ckptDir, 'args.yaml'))
    args['loadDir'] = ckptDir
    args['mode'] = 'infer'
    args['loadCheckpointIdx'] = None

    # Zero the validation probability of every dataset ...
    for x in range(len(args['dataset']['datasetProbabilityVal'])):
        args['dataset']['datasetProbabilityVal'][x] = 0.0

    # ... then enable only datasets 4..18 and point them at the local
    # tfRecords directory.
    for sessIdx in range(4,19):
        args['dataset']['datasetProbabilityVal'][sessIdx] = 1.0
        args['dataset']['dataDir'][sessIdx] = baseDir+'/derived/tfRecords'
    args['testDir'] = testDirs[dirIdx]

    # Initialize model
    tf.compat.v1.reset_default_graph()
    nsd = NeuralSequenceDecoder(args)
    # Weights come from the explicit .h5 file inside the checkpoint directory.
    nsd._load_model(ckptDir + model_file_name)
    # Inference
    out = nsd.inference()
    
    # Debug output; the `break` stops after the first (and only) pass.
    print(out)
    break

    # trueTranscriptions[dirIdx] = out['targetSentences']
    # decodedTranscriptions[dirIdx] = out["decodedSentences"]


Load data from /home/s2/nlp002/pj_data/derived/tfRecords/t12.2022.04.28/train
Load data from /home/s2/nlp002/pj_data/derived/tfRecords/t12.2022.04.28/competitionHoldOut
Load data from /home/s2/nlp002/pj_data/derived/tfRecords/t12.2022.05.05/train
Load data from /home/s2/nlp002/pj_data/derived/tfRecords/t12.2022.05.05/competitionHoldOut
Load data from /home/s2/nlp002/pj_data/derived/tfRecords/t12.2022.05.17/train
Load data from /home/s2/nlp002/pj_data/derived/tfRecords/t12.2022.05.17/competitionHoldOut
Load data from /home/s2/nlp002/pj_data/derived/tfRecords/t12.2022.05.19/train
Load data from /home/s2/nlp002/pj_data/derived/tfRecords/t12.2022.05.19/competitionHoldOut
Load data from /home/s2/nlp002/pj_data/derived/tfRecords/t12.2022.05.24/train
Load data from /home/s2/nlp002/pj_data/derived/tfRecords/t12.2022.05.24/competitionHoldOut
Load data from /home/s2/nlp002/pj_data/derived/tfRecords/t12.2022.05.26/train
Load data from /home/s2/nlp002/pj_data/derived/tfRecords/t12.2022.05.26/compe

KeyboardInterrupt: 