-
Notifications
You must be signed in to change notification settings - Fork 427
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Detection model #32
Detection model #32
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR! I added a few comments
@@ -1,5 +1,4 @@ | |||
# Copyright (C) 2021, Mindee. | |||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Better to keep this line break
|
||
class DetectionModel(keras.Model): | ||
"""Implements abstract DetectionModel class | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing args here
doctr/models/detection/model.py
Outdated
self, | ||
inputs: tf.Tensor, | ||
training: bool = False | ||
) -> Tuple[keras.Model, keras.Model]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Legacy typing here
@@ -161,3 +164,248 @@ def __call__( | |||
|
|||
bounding_boxes.append(boxes_batch) | |||
return bounding_boxes | |||
|
|||
|
|||
class DBModel(DetectionModel, keras.Model): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The inherited DetectionModel already inherits from keras.Model
Also, since we will use ResNet50 for now, let's name it DBResNet50
test/test_models.py
Outdated
assert np.shape(dboutput_notrain.numpy())[0] == 8 | ||
# output dimensions | ||
assert np.shape(dboutput_notrain.numpy())[1] == np.shape(dboutput_notrain.numpy())[2] == 640 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you can regroup shape value testing:
assert dboutput_notrain.numpy().shape == (8, 640, 640, 3)
test/test_models.py
Outdated
assert all(np.shape(maps.numpy())[0] == 8 for maps in dboutput_train) | ||
# output dimensions | ||
assert all(np.shape(maps.numpy())[1] == np.shape(maps.numpy())[2] == 640 for maps in dboutput_train) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here, this can be grouped (and np.shape
is not required, you can use the .shape
method of the ndarray)
channels: int = 128, | ||
) -> None: | ||
super().__init__(shape) | ||
self.channels = channels |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Many modules could be instantiated over here
Codecov Report
@@ Coverage Diff @@
## main #32 +/- ##
=======================================
Coverage 98.43% 98.43%
=======================================
Files 12 13 +1
Lines 256 320 +64
=======================================
+ Hits 252 315 +63
- Misses 4 5 +1
Continue to review full report at Codecov.
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry about the number of comments 😅
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just a few adjustments left 😅
It seems we burned out our free CI minutes
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the edits! Looks good to me!
DetectionModel class and DBModel (child class) added.
DB is implemented with a ResNet50 backbone (feature extactor) provided by keras