Skip to content
This repository has been archived by the owner on Apr 19, 2023. It is now read-only.

Commit

Permalink
- can now specify dimension in ResizeAndConcatenate
Browse files Browse the repository at this point in the history
  • Loading branch information
nasimrahaman committed Jul 21, 2018
1 parent a6ce6f9 commit 4a88ebc
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions inferno/extensions/layers/reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,23 +112,24 @@ def forward(self, *inputs):
class ResizeAndConcatenate(nn.Module):
"""
Resize input tensors spatially (to a specified target size) before concatenating
them along the channel dimension. The downsampling mode can be specified
('average' or 'max'), but the upsampling is always 'nearest'.
them along a given `dim`ension (channel, i.e. 1 by default). The downsampling mode can
be specified ('average' or 'max'), but the upsampling is always 'nearest'.
"""

# Maps accepted user-facing pool-mode aliases ('avg'/'average'/'mean'/'max')
# to the canonical mode name ('avg' or 'max') used for downsampling.
POOL_MODE_MAPPING = {'avg': 'avg',
'average': 'avg',
'mean': 'avg',
'max': 'max'}

def __init__(self, target_size, pool_mode='average'):
def __init__(self, target_size, pool_mode='average', dim=1):
    """
    Parameters
    ----------
    target_size : int or tuple
        Spatial size all inputs are resized to before concatenation.
    pool_mode : str
        Downsampling mode; one of the keys of `POOL_MODE_MAPPING`
        (e.g. 'average', 'mean', 'avg' or 'max'). Raises a `ValueError`
        (via `assert_`) for anything else.
    dim : int
        Dimension along which the resized inputs are concatenated
        (defaults to 1, i.e. the channel dimension).
    """
    super(ResizeAndConcatenate, self).__init__()
    self.target_size = target_size
    # Membership test on the dict directly ('in dict' == 'in dict.keys()').
    assert_(pool_mode in self.POOL_MODE_MAPPING,
            "`pool_mode` must be one of {}, got {} instead."
            .format(self.POOL_MODE_MAPPING.keys(), pool_mode),
            ValueError)
    self.pool_mode = pool_mode
    self.dim = dim

def forward(self, *inputs):
dim = inputs[0].dim()
Expand All @@ -151,7 +152,7 @@ def forward(self, *inputs):
ShapeError)
resized_inputs.append(resize_function(input, target_size))
# Concatenate along the configured dimension (`self.dim`, channel by default)
concatenated = torch.cat(tuple(resized_inputs), 1)
concatenated = torch.cat(tuple(resized_inputs), self.dim)
# Done
return concatenated

Expand Down

0 comments on commit 4a88ebc

Please sign in to comment.