In [13]:
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Dropout, Reshape, Permute
from tensorflow.keras.layers import Convolution2D
from tensorflow.keras.layers import MaxPooling2D, ZeroPadding2D
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import ELU
from tensorflow.keras.layers import GRU
import tensorflow as tf

In [21]:
channel_axis = 3
freq_axis = 1
time_axis = 2

melgram_input = Input(shape=(96, 1336, 2))

In [27]:
x = ZeroPadding2D(padding=(0, 37))(melgram_input)
x = BatchNormalization(axis=freq_axis, name='bn_0_freq')(x)

# Conv block 1
x = Convolution2D(64, 3, 3, padding='same', name='conv1')(x)
x = BatchNormalization(axis=channel_axis, name='bn1')(x)
x = ELU()(x)
x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), name='pool1')(x)
x = Dropout(0.1, name='dropout1')(x)

# Conv block 2
x = Convolution2D(128, 3, 3, padding='same', name='conv2')(x)
x = BatchNormalization(axis=channel_axis,name='bn2')(x)
x = ELU()(x)
x = MaxPooling2D(pool_size=(3, 3), strides=(3, 3), name='pool2')(x)
x = Dropout(0.1, name='dropout2')(x)

# Conv block 3
x = Convolution2D(128, 3, 3, padding='same', name='conv3')(x)
x = BatchNormalization(axis=channel_axis, name='bn3')(x)
x = ELU()(x)
# x = MaxPooling2D(pool_size=(4, 4), strides=(4, 4), name='pool3')(x)
x = Dropout(0.1, name='dropout3')(x)

# # Conv block 4
x = Convolution2D(128, 3, 3, padding='same', name='conv4')(x)
x = BatchNormalization(axis=channel_axis, name='bn4')(x)
x = ELU()(x)
# x = MaxPooling2D(pool_size=(4, 4), strides=(4, 4), name='pool4')(x)
x = Dropout(0.1, name='dropout4')(x)

x = Reshape((-1, 128))(x)
x = GRU(32)(x)

model = tf.keras.Model(inputs=melgram_input, outputs=x)
model.summary()

Model: "functional_21"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_3 (InputLayer)         [(None, 96, 1336, 2)]     0         
_________________________________________________________________
zero_padding2d_16 (ZeroPaddi (None, 96, 1410, 2)       0         
_________________________________________________________________
bn_0_freq (BatchNormalizatio (None, 96, 1410, 2)       384       
_________________________________________________________________
conv1 (Conv2D)               (None, 32, 470, 64)       1216      
_________________________________________________________________
bn1 (BatchNormalization)     (None, 32, 470, 64)       256       
_________________________________________________________________
elu_41 (ELU)                 (None, 32, 470, 64)       0         
_________________________________________________________________
pool1 (MaxPooling2D)         (None, 16, 235, 64)     