Feature/add spec augment layer #135
Hi, thank you for the PR! It looks pretty good and SpecAug would be useful for sure. There are some requests for changes :)
kapre/augmentation.py
Outdated
Generate a mask for the axis provided
Args:
    inputs (`tuple`): A 3-tuple with the following structure:
        inputs[0]: A spectrogram
shape of spectrogram? (i guessed it's 4D)
kapre/augmentation.py
Outdated
Args:
    inputs (`tuple`): A 3-tuple with the following structure:
        inputs[0]: A spectrogram
        inputs[1]: The axis where the mask will be applied
(`int`)
kapre/augmentation.py
Outdated
inputs (`tuple`): A 3-tuple with the following structure:
    inputs[0]: A spectrogram
    inputs[1]: The axis where the mask will be applied
    inputs[2]: The mask param as defined in the original paper
..which is the max width of the mask, right? let's clarify.
kapre/augmentation.py
Outdated
    inputs[1]: The axis where the mask will be applied
    inputs[2]: The mask param as defined in the original paper
Returns:
    A mask represented as a boolean tensor
shall we say something about the shape here? (to be extremely clear about everything)
axis_indices = tf.reshape(axis_indices, (1, 1, -1))

mask_width = tf.random.uniform(shape=(), maxval=mask_param, dtype=tf.int32)
mask_start = tf.random.uniform(shape=(), maxval=axis_limit - mask_width, dtype=tf.int32)
seems like `axis_limit - mask_width` can be `<0`. should we take care of this case?
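One possible way to handle that case, sketched in plain Python rather than TensorFlow: clamp the maximum width to the axis length before sampling, so the range for the start index is never empty. `sample_mask_bounds` is a hypothetical helper (not part of kapre), and the exclusive upper bounds are chosen to mirror how `tf.random.uniform` treats `maxval` for integer dtypes.

```python
import random


def sample_mask_bounds(axis_limit, mask_param, rng=random):
    """Hypothetical helper: return (start, width) with start + width <= axis_limit."""
    if axis_limit <= 0:
        raise ValueError("axis_limit must be positive")
    # Clamp so that width < axis_limit even when mask_param exceeds the axis length.
    max_width = min(mask_param, axis_limit)
    # Exclusive upper bound, like tf.random.uniform(maxval=...) for integer dtypes.
    width = rng.randrange(max_width) if max_width > 0 else 0
    # axis_limit - width >= 1 at this point, so this range is never empty.
    start = rng.randrange(axis_limit - width)
    return start, width
```

In the TensorFlow version the same idea would presumably amount to applying `tf.minimum(mask_param, axis_limit)` before drawing `mask_width`; raising an error for an oversized `mask_param` is another reasonable option.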
kapre/augmentation.py
Outdated
by the axis provided.
Args:
    x: The input spectrogram
    axis: The axis where the masks will be applied
(`int`)
kapre/augmentation.py
Outdated
Args:
    x: The input spectrogram
    axis: The axis where the masks will be applied
    mask_param: The mask param as defined in the original paper
same here - type and description.
kapre/augmentation.py
Outdated
    axis_indices = tf.reshape(axis_indices, (-1, 1, 1))
elif axis == 1:
    axis_indices = tf.reshape(axis_indices, (1, -1, 1))
else:
would `else` be okay? or shall we add `elif axis == 2:` followed by `else: raise NotImplementedError`? i usually prefer the latter unless it's taken care of somewhere else very clearly.
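A minimal sketch of the explicit variant suggested here, in plain Python. `reshape_target` is a hypothetical helper name; the three shapes are the ones that appear in the diff, and only the final `else` branch is new.

```python
def reshape_target(axis):
    """Hypothetical helper: shape that lets a 1-D index vector broadcast along `axis`."""
    if axis == 0:
        return (-1, 1, 1)
    elif axis == 1:
        return (1, -1, 1)
    elif axis == 2:
        return (1, 1, -1)
    else:
        # Fail loudly instead of silently treating any other value as axis 2.
        raise NotImplementedError(f"axis must be 0, 1 or 2, got {axis}")
```

With a bare `else`, a typo such as `axis=3` or `axis=-1` would be masked along the last axis without any warning; the explicit branch turns that into an immediate error.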
tests/test_augmentation.py
Outdated
model.add(spec_augment)

# We must force training to True to test properly if SpecAugment works as expected
spec_augmented = model(batch_src, training=True)[0]
let's add a test to ensure it doesn't change anything when `training=False`.
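A rough sketch of what such a test could look like. To stay self-contained it uses NumPy and a toy stand-in (`ToyMaskLayer`, hypothetical, not kapre's SpecAugment layer) that follows the same `training` flag convention; the real test would call the Keras model with `training=False` and compare against `batch_src` in the same way.

```python
import numpy as np


class ToyMaskLayer:
    """Hypothetical stand-in for an augmentation layer with a `training` flag."""

    def __init__(self, start, width):
        self.start = start
        self.width = width

    def __call__(self, x, training=False):
        if not training:
            # Inference path: the input must pass through untouched.
            return x
        out = x.copy()
        # Zero a fixed band along axis 2 (a real layer would sample start and width).
        out[:, :, self.start:self.start + self.width, :] = 0.0
        return out


def test_inference_is_identity():
    x = np.random.default_rng(0).normal(size=(2, 8, 16, 1))
    layer = ToyMaskLayer(start=4, width=3)
    np.testing.assert_array_equal(layer(x, training=False), x)


def test_training_applies_mask():
    x = np.ones((2, 8, 16, 1))
    y = ToyMaskLayer(start=4, width=3)(x, training=True)
    assert not np.array_equal(y, x)
    assert np.all(y[:, :, 4:7, :] == 0.0)
```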
tests/test_augmentation.py
Outdated
# We must force training to True to test properly if SpecAugment works as expected
spec_augmented = model(batch_src, training=True)[0]
i get that it's difficult to test the behavior of the whole layer. but can't there be any way at all..?
also, at least we can test the static methods we defined.
Sure, you mean the methods for generating masks, applying masks and similar, right? I
yup, that's right.
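As an illustration of what testing the mask helpers in isolation might look like, here is a NumPy sketch with deterministic `start` and `width`, so the expected mask can be written down exactly. `make_axis_mask` and `apply_mask` are hypothetical names, and the comparison-based mask construction is an assumption about how the PR builds its boolean tensor from `axis_indices`.

```python
import numpy as np


def make_axis_mask(shape, axis, start, width):
    """Boolean mask broadcastable to `shape`, True on [start, start + width) along `axis`."""
    target = [1] * len(shape)
    target[axis] = -1
    # Index vector reshaped so it broadcasts along the chosen axis only.
    idx = np.arange(shape[axis]).reshape(target)
    return (idx >= start) & (idx < start + width)


def apply_mask(x, mask, value=0.0):
    """Replace the masked positions of `x` with `value`, leaving `x` itself unchanged."""
    return np.where(mask, value, x)


def test_mask_shape_and_content():
    mask = make_axis_mask((3, 5, 4), axis=2, start=1, width=2)
    assert mask.dtype == bool
    assert mask.shape == (1, 1, 4)
    assert mask.ravel().tolist() == [False, True, True, False]


def test_apply_mask_zeroes_only_the_band():
    x = np.ones((3, 5, 4))
    y = apply_mask(x, make_axis_mask(x.shape, axis=2, start=1, width=2))
    assert y.shape == x.shape
    assert np.all(y[..., 1:3] == 0.0)
    assert np.all(y[..., [0, 3]] == 1.0)
```

Fixing the random draws this way (or seeding the generator) sidesteps the difficulty of asserting on the full layer's random output.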
I had to refactor the original code a bit because I was having problems saving the layer because of the
Hi @keunwoochoi! Have you had a chance to look at the changes? If you have any suggestions, please go ahead! 👍
Thank you again for the solid work and sorry for my delay. It has been overwhelmingly busy these days. I have two final requests.
@MichaelisTrofficus looks great, thank you! i made some small changes. will do another review soon just in case; and merge.
* Feature/add spec augment layer (#135)
* Add SpecAugment Layer
* Extend docstrings for methods in the main class
* Add __all__ to import augmentation techniques
* Improve docstrings and add exceptions handling
* Improve tests for SpecAugment augmentation layer
* Uncomment tests. Fix setup.py
* Solve saving issue with tf format / Refactor code
* Add Jupyter Notebook and apply black for code reformatting
* add version printing; add 5 figures
* bump version; 0.3.6 -> 0.3.7
Co-authored-by: Keunwoo Choi <keunwoo.choi@bytedance.com>
* add release note
Co-authored-by: Miguel Otero Pedrido <32078719+MichaelisTrofficus@users.noreply.github.com>
This PR adds the SpecAugment augmentation technique. Currently, it cannot be used for spectrograms with a depth greater than 1 (I am working on it, but the original paper does not discuss this case). Also, image_warping has not been included because it caused problems when encapsulated in a Keras Layer. I have tried to test the features that seemed important to me, but any suggestions are welcome! :)