In [1]:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Activation, Permute, Dropout, Add
from tensorflow.keras.layers import Conv2D, MaxPooling2D, AveragePooling2D
from tensorflow.keras.layers import SeparableConv2D, DepthwiseConv2D
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import SpatialDropout2D
from tensorflow.keras.regularizers import l1_l2
from tensorflow.keras.layers import Input, Flatten
from tensorflow.keras.constraints import max_norm
from tensorflow.keras import backend as K

  from ._conv import register_converters as _register_converters


In [5]:
def EEGNet(nb_classes, Chans = 64, Samples = 128,
           dropoutRate = 0.5, kernLength = 64, F1 = 8,
           D = 2, F2 = 16, norm_rate = 0.25, dropoutType = 'Dropout'):
    """Build a residual variant of EEGNet as a Keras functional ``Model``.

    Every convolution / pooling layer uses ``data_format='channels_first'``,
    so intermediate tensors are laid out as ``(batch, features, Chans, Samples)``.

    Architecture
    ------------
    Block 1: Conv2D (F1 filters, kernel (1, kernLength)) -> BN ->
        DepthwiseConv2D with kernel (Chans, 1), depth multiplier D and
        'same' padding (so the Chans axis is kept) -> dropout -> BN -> ELU,
        added to a learned (1, kernLength) projection shortcut of the input.
    Block 2: SeparableConv2D (F2 filters, kernel (1, 16)) -> dropout -> BN ->
        ELU -> average-pool (1, 8), added to a strided (1, 8) projection
        shortcut of block 1. Both branches have shape (F2, Chans, Samples // 8).
    Head: average-pool (8, 8) -> flatten -> Dense(nb_classes) with a max-norm
        kernel constraint -> softmax.

    Note: the pooling stages need Samples >= 64 and Chans >= 8 to leave a
    non-empty feature map.

    Args:
        nb_classes: number of output classes (units of the final Dense layer).
        Chans: size of the second input axis (height); the depthwise kernel
            spans this whole axis.
        Samples: size of the third input axis (width); temporal kernels slide
            along this axis.
        dropoutRate: rate passed to both dropout layers.
        kernLength: width of the block-1 temporal kernels.
        F1: number of block-1 temporal filters.
        D: depth multiplier of the depthwise convolution.
        F2: number of filters produced by block 2.
        norm_rate: max-norm bound on the final Dense kernel.
        dropoutType: the string 'Dropout' or 'SpatialDropout2D'.

    Returns:
        An uncompiled Keras ``Model`` mapping inputs of shape
        ``(1, Chans, Samples)`` to a softmax over ``nb_classes``.

    Raises:
        ValueError: if ``dropoutType`` is not one of the supported strings.
    """
    if dropoutType == 'SpatialDropout2D':
        dropout_layer = SpatialDropout2D
    elif dropoutType == 'Dropout':
        dropout_layer = Dropout
    else:
        raise ValueError('dropoutType must be one of SpatialDropout2D '
                         'or Dropout, passed as a string.')

    input1 = Input(shape = (1, Chans, Samples))

    ##################################################################
    # Block 1. `input_shape` is not passed to the layers: in the functional
    # API the shape comes from `input1`.
    block1 = Conv2D(F1, (1, kernLength), padding = 'same',
                    data_format = 'channels_first',
                    use_bias = False)(input1)
    # Learned projection shortcut (not a pure identity): lifts the single input
    # feature map to F1*D maps so it can be added to the block-1 output.
    shortcut1 = Conv2D(F1*D, (1, kernLength), padding = 'same',
                       data_format = 'channels_first',
                       use_bias = False)(input1)

    block1 = BatchNormalization(axis = 1)(block1)
    # 'same' padding keeps the Chans axis, giving (F1*D, Chans, Samples),
    # the same shape as shortcut1.
    block1 = DepthwiseConv2D((Chans, 1), use_bias = False,
                             data_format = 'channels_first',
                             depth_multiplier = D, padding = 'same',
                             depthwise_constraint = max_norm(1.))(block1)
    block1 = dropout_layer(dropoutRate)(block1)

    block1 = BatchNormalization(axis = 1)(block1)
    block1 = Activation('elu')(block1)
    block1 = Add()([block1, shortcut1])

    ##################################################################
    # Block 2. Strided projection shortcut:
    # (F1*D, Chans, Samples) -> (F2, Chans, Samples // 8).
    # Uses F2 filters so it matches the main branch even when F2 != F1*D.
    shortcut2 = Conv2D(F2, (1, 8), padding = 'valid',
                       data_format = 'channels_first',
                       strides = (1, 8))(block1)

    # channels_first is required here like everywhere else. With the Keras
    # default (channels_last) the time samples would be treated as the channel
    # axis and the (1, 16) kernel would slide over the Chans axis instead.
    block2 = SeparableConv2D(F2, (1, 16), use_bias = False, padding = 'same',
                             data_format = 'channels_first')(block1)
    block2 = dropout_layer(dropoutRate)(block2)

    block2 = BatchNormalization(axis = 1)(block2)
    block2 = Activation('elu')(block2)
    # Downsample time by 8 so the main branch has shortcut2's shape:
    # both give floor(Samples / 8) along the time axis.
    block2 = AveragePooling2D((1, 8), data_format = 'channels_first')(block2)
    block2 = Add()([block2, shortcut2])

    block2 = AveragePooling2D((8, 8), data_format = 'channels_first')(block2)

    flatten = Flatten(name = 'flatten')(block2)

    dense = Dense(nb_classes, name = 'dense',
                  kernel_constraint = max_norm(norm_rate))(flatten)
    softmax = Activation('softmax', name = 'softmax')(dense)

    return Model(inputs = input1, outputs = softmax)


In [6]:
# Instantiate the network: 4 output classes, inputs of shape (1, 64, 128),
# a 32-wide temporal kernel in block 1 and element-wise Dropout.
model = EEGNet(
    nb_classes=4,
    Chans=64,
    Samples=128,
    kernLength=32,
    F1=8,
    D=2,
    F2=16,
    dropoutRate=0.5,
    dropoutType='Dropout',
)

In [7]:
# Print the layer-by-layer architecture (output shapes, parameter counts,
# connectivity) of the model built in the previous cell.
model.summary()

Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_2 (InputLayer)            [(None, 1, 64, 128)] 0                                            
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 8, 64, 128)   256         input_2[0][0]                    
__________________________________________________________________________________________________
batch_normalization_3 (BatchNor (None, 8, 64, 128)   32          conv2d_3[0][0]                   
__________________________________________________________________________________________________
depthwise_conv2d_1 (DepthwiseCo (None, 16, 64, 128)  1024        batch_normalization_3[0][0]      
____________________________________________________________________________________________