Skip to content

Commit 7e8b0d3

Browse files
authored
Merge pull request #250 from rvandeghen/patch-1
Add new checkpoint
2 parents 45834ee + 75ca1bf commit 7e8b0d3

File tree

1 file changed

+14
-12
lines changed

1 file changed

+14
-12
lines changed

efficientnet_pytorch/model.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -238,18 +238,18 @@ def extract_endpoints(self, inputs):
238238
Returns:
239239
Dictionary of last intermediate features
240240
with reduction levels i in [1, 2, 3, 4, 5].
241-
242-
Example:
243-
>>> import torch
244-
>>> from efficientnet.model import EfficientNet
245-
>>> inputs = torch.rand(1, 3, 224, 224)
246-
>>> model = EfficientNet.from_pretrained('efficientnet-b0')
247-
>>> endpoints = model.extract_endpoints(inputs)
248-
>>> print(endpoints['reduction_1'].shape) # torch.Size([1, 16, 112, 112])
249-
>>> print(endpoints['reduction_2'].shape) # torch.Size([1, 24, 56, 56])
250-
>>> print(endpoints['reduction_3'].shape) # torch.Size([1, 40, 28, 28])
251-
>>> print(endpoints['reduction_4'].shape) # torch.Size([1, 112, 14, 14])
252-
>>> print(endpoints['reduction_5'].shape) # torch.Size([1, 1280, 7, 7])
241+
Example:
242+
>>> import torch
243+
>>> from efficientnet.model import EfficientNet
244+
>>> inputs = torch.rand(1, 3, 224, 224)
245+
>>> model = EfficientNet.from_pretrained('efficientnet-b0')
246+
>>> endpoints = model.extract_endpoints(inputs)
247+
>>> print(endpoints['reduction_1'].shape) # torch.Size([1, 16, 112, 112])
248+
>>> print(endpoints['reduction_2'].shape) # torch.Size([1, 24, 56, 56])
249+
>>> print(endpoints['reduction_3'].shape) # torch.Size([1, 40, 28, 28])
250+
>>> print(endpoints['reduction_4'].shape) # torch.Size([1, 112, 14, 14])
251+
>>> print(endpoints['reduction_5'].shape) # torch.Size([1, 320, 7, 7])
252+
>>> print(endpoints['reduction_6'].shape) # torch.Size([1, 1280, 7, 7])
253253
"""
254254
endpoints = dict()
255255

@@ -265,6 +265,8 @@ def extract_endpoints(self, inputs):
265265
x = block(x, drop_connect_rate=drop_connect_rate)
266266
if prev_x.size(2) > x.size(2):
267267
endpoints['reduction_{}'.format(len(endpoints) + 1)] = prev_x
268+
elif idx == len(self._blocks) - 1:
269+
endpoints['reduction_{}'.format(len(endpoints) + 1)] = x
268270
prev_x = x
269271

270272
# Head

0 commit comments

Comments
 (0)