
Commit da844e9
add channels last support for thnn_conv2d (non-dilated)
ghstack-source-id: 0584ef4b7b499004aa55c1d62db34c584ee60aab
Pull Request resolved: pytorch#49582
mingfeima committed May 15, 2021
1 parent b7af4f5 commit da844e9
Showing 7 changed files with 507 additions and 47 deletions.
aten/src/ATen/native/Convolution.cpp (15 changes: 12 additions & 3 deletions)
@@ -570,8 +570,9 @@ static at::Tensor subtensor(at::Tensor& tensor, int dim, int groups, int g) {
   if (!tensor.defined()) {
     return at::Tensor();
   }
+  auto memory_format = tensor.suggest_memory_format();
   int64_t n = tensor.sizes()[dim] / groups;
-  return tensor.narrow(dim, n * g, n).contiguous();
+  return tensor.narrow(dim, n * g, n).contiguous(memory_format);
 }
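For context, a minimal standalone sketch (assuming a libtorch build; the shapes are illustrative, not taken from the patch) of why subtensor must forward the memory format: narrow() returns a view, and a bare contiguous() would repack each per-group slice as NCHW, silently dropping a channels-last layout.

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // A channels-last (NHWC) tensor standing in for a grouped-conv input.
  auto t = at::randn({2, 8, 4, 4}).contiguous(at::MemoryFormat::ChannelsLast);

  // A bare contiguous() repacks the slice as NCHW, losing the layout.
  auto lost = t.narrow(/*dim=*/1, /*start=*/0, /*length=*/4).contiguous();

  // Forwarding suggest_memory_format(), as the patched subtensor does,
  // keeps the slice channels-last.
  auto kept = t.narrow(1, 0, 4).contiguous(t.suggest_memory_format());

  std::cout << lost.is_contiguous(at::MemoryFormat::ChannelsLast)   // prints 0
            << kept.is_contiguous(at::MemoryFormat::ChannelsLast)   // prints 1
            << std::endl;
  return 0;
}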


@@ -979,12 +980,20 @@ at::Tensor _convolution(
       params.stride,
       params.padding);
   } else if (input.device().is_cpu() || input.is_cuda()) {
+    bool is_channels_last_supported = !params.transposed && (input.ndimension() == 4) &&
+        !params.use_nnpack(input, weight) && input.device().is_cpu() &&
+        !params.is_dilated();
+    if (is_channels_last_supported) {
+      auto memory_format = input.suggest_memory_format();
+      input = input.contiguous(memory_format);
+    } else {
+      input = input.contiguous();
+    }
     if (params.groups == 1) {
       output = at::_convolution_nogroup(
-          input.contiguous(), weight, bias, params.stride, params.padding, params.dilation, params.transposed, params.output_padding);
+          input, weight, bias, params.stride, params.padding, params.dilation, params.transposed, params.output_padding);
     } else {
       std::vector<Tensor> outputs(params.groups);
-      input = input.contiguous();
       for (int g = 0; g < params.groups; ++g) {
         auto input_g = subtensor(input, 1, params.groups, g);
         auto weight_g = subtensor(weight, 0, params.groups, g);
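Taken together, the user-visible effect is that a channels-last CPU input can now flow through a non-dilated, non-transposed 2-D convolution without being repacked to NCHW. A hedged end-to-end sketch (assuming a libtorch build where dispatch actually reaches thnn_conv2d, i.e. NNPACK is not selected; shapes are illustrative):

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // 4-D, CPU, stride/padding only (no dilation, not transposed): the case
  // the new is_channels_last_supported branch targets.
  auto input  = at::randn({1, 3, 32, 32}).contiguous(at::MemoryFormat::ChannelsLast);
  auto weight = at::randn({8, 3, 3, 3}).contiguous(at::MemoryFormat::ChannelsLast);

  auto out = at::conv2d(input, weight, /*bias=*/{}, /*stride=*/1, /*padding=*/1);

  // suggest_memory_format() is what the patched _convolution consults; on
  // this path the output should preserve the channels-last layout.
  std::cout << (input.suggest_memory_format() == at::MemoryFormat::ChannelsLast)
            << out.is_contiguous(at::MemoryFormat::ChannelsLast) << std::endl;
  return 0;
}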
