Skip to content

Commit

Permalink
Remove unnecesary shape from BN layers
Browse files Browse the repository at this point in the history
  • Loading branch information
edufonseca committed Dec 6, 2018
1 parent d720c2b commit e1a4622
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 9 deletions.
12 changes: 6 additions & 6 deletions architectures.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,38 +22,38 @@ def get_model_baseline(params_learn=None, params_extract=None):
spec_x = spec_start

# l1
spec_x = BatchNormalization(axis=1, input_shape=input_shape)(spec_x)
spec_x = BatchNormalization(axis=1)(spec_x)
spec_x = Activation('relu')(spec_x)

spec_x = Conv2D(24, (5, 5),
padding='same', # fmap has same size as input
kernel_initializer='he_normal',
data_format='channels_first')(spec_x)
spec_x = BatchNormalization(axis=1, input_shape=input_shape)(spec_x)
spec_x = BatchNormalization(axis=1)(spec_x)
spec_x = Activation('relu')(spec_x)
spec_x = MaxPooling2D(pool_size=(4, 2), data_format="channels_first")(spec_x)

# l2
spec_x = BatchNormalization(axis=1, input_shape=input_shape)(spec_x)
spec_x = BatchNormalization(axis=1)(spec_x)
spec_x = Activation('relu')(spec_x)

spec_x = Conv2D(48, (5, 5),
padding='same', # fmap has same size as input
kernel_initializer='he_normal',
data_format='channels_first')(spec_x)
spec_x = BatchNormalization(axis=1, input_shape=input_shape)(spec_x)
spec_x = BatchNormalization(axis=1)(spec_x)
spec_x = Activation('relu')(spec_x)
spec_x = MaxPooling2D(pool_size=(4, 2), data_format="channels_first")(spec_x)

# l3
spec_x = BatchNormalization(axis=1, input_shape=input_shape)(spec_x)
spec_x = BatchNormalization(axis=1)(spec_x)
spec_x = Activation('relu')(spec_x)

spec_x = Conv2D(48, (5, 5),
padding='same', # fmap has same size as input
kernel_initializer='he_normal',
data_format='channels_first')(spec_x)
spec_x = BatchNormalization(axis=1, input_shape=input_shape)(spec_x)
spec_x = BatchNormalization(axis=1)(spec_x)
spec_x = Activation('relu')(spec_x)

spec_x = Flatten()(spec_x)
Expand Down
3 changes: 0 additions & 3 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,6 @@
filelist_audio_tr_nV_small_dur = train_csv.nV_small_dur.values.tolist()
idx_nV_small_dur = [i for i, x in enumerate(filelist_audio_tr_nV_small_dur) if x == 1]

filelist_audio_tr_nV_small_clips = train_csv.nV_small_clips.values.tolist()
idx_nV_small_clips= [i for i, x in enumerate(filelist_audio_tr_nV_small_clips) if x == 1]


# create dict with ground truth mapping with labels:
# -key: path to wav
Expand Down

0 comments on commit e1a4622

Please sign in to comment.