Fix the Channels last bug with GradientWithInput. (#179)
* Fix the Channels last bug with GradientWithInput.

The bug was mentioned in:
pytorch#77764 (comment)

* Update the placeholder.

* Remove the extra print.
kulinseth committed Feb 5, 2023
1 parent 0467f0e commit 472c09f
Showing 1 changed file with 11 additions and 14 deletions.
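For context, the bug shows up when the input to a convolution backward pass on MPS is in the channels-last memory format: the gradient with respect to the input can come back wrong. A minimal repro sketch along these lines (the module, shapes, and tolerances below are illustrative, not taken from the original report in pytorch#77764):

```python
import torch

def check_channels_last_grad_input():
    if not torch.backends.mps.is_available():
        print("MPS not available; skipping")
        return

    conv_cpu = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
    conv_mps = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1).to("mps")
    conv_mps.load_state_dict(conv_cpu.state_dict())  # same weights on both devices

    x_cpu = torch.randn(2, 3, 16, 16, requires_grad=True)
    x_mps = (x_cpu.detach()
                  .to("mps")
                  .to(memory_format=torch.channels_last)  # trigger the channels-last path
                  .requires_grad_())

    conv_cpu(x_cpu).sum().backward()
    conv_mps(x_mps).sum().backward()

    # Before this commit, the gradient w.r.t. the input could be wrong
    # when grad_output suggested a channels-last memory format.
    torch.testing.assert_close(x_mps.grad.cpu(), x_cpu.grad, rtol=1e-4, atol=1e-4)

check_channels_last_grad_input()
```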
aten/src/ATen/native/mps/operations/Convolution.mm

@@ -198,25 +198,22 @@ Tensor _mps_convolution(
 }
 
 Tensor mps_convolution_backward_input(
-    IntArrayRef input_size, const Tensor& grad_output_t, const Tensor& weight_t,
+    IntArrayRef input_size, const Tensor& grad_output_, const Tensor& weight_,
     IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined) {
   namespace native_mps = at::native::mps;
   using namespace mps;
   CheckedFrom c = "mps_convolution_backward_input";
-  TensorArg grad_output{ grad_output_t, "grad_output", 1 },
-            weight{ weight_t, "weight", 2 };
+  TensorArg grad_output{ grad_output_, "grad_output", 1 },
+            weight{ weight_, "weight", 2 };
   checkAllSameType(c, {grad_output, weight});
   checkAllSameGPU(c, {grad_output, weight});
-  auto memory_format = grad_output_t.suggest_memory_format();
+  auto memory_format = grad_output_.suggest_memory_format();
   bool is_channels_last = (memory_format == at::MemoryFormat::ChannelsLast);
 
-  auto grad_input_t = at::empty(
-                    input_size,
-                    grad_output->scalar_type(),
-                    c10::nullopt,
-                    kMPS,
-                    c10::nullopt,
-                    c10::nullopt);
+  MPSShape* weightShape = get_mps_conv_shape(weight_, is_channels_last);
+  MPSShape* gradOutputShape = get_mps_conv_shape(grad_output_, is_channels_last);
+  Tensor grad_output_t = grad_output_.contiguous(memory_format);
+  Tensor weight_t = weight_.contiguous(memory_format);
+  auto grad_input_t = at::empty( input_size, grad_output_t.options(), c10::nullopt);
 
   // Avoid "grad_input" when this is being used as transposed convolution
   TensorArg grad_input{ grad_input_t, "result", 0 };
@@ -277,7 +274,7 @@ Tensor mps_convolution_backward_input(
                                      at::MemoryFormat::Contiguous, groups);
 
             MPSGraphTensor* gradOutputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSScalarType(grad_output_t.scalar_type()), gradOutputShape);
-            MPSGraphTensor* weightTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, weight_t);
+            MPSGraphTensor* weightTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSScalarType(weight_t.scalar_type()), weightShape);
 
             MPSGraphTensor *gradOutputTensorTranspose = gradOutputTensor;
             if (is_channels_last && grad_output_t.is_contiguous() && !grad_output_t.is_view()) {
@@ -300,7 +297,7 @@ Tensor mps_convolution_backward_input(
     }
 
     auto gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output_t, gradOutputShape);
-    auto weightsPlaceholder = Placeholder(cachedGraph->weightTensor_, weight_t);
+    auto weightsPlaceholder = Placeholder(cachedGraph->weightTensor_, weight_t, weightShape);
     auto outputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, *grad_input);
 
     NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *feeds = @{
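The core of the fix is to materialize `grad_output_` and `weight_` as contiguous in the suggested memory format before computing the MPS graph shapes, and to feed those explicit shapes to the graph placeholders. A small Python sketch of the same memory-format semantics (illustrative only; the commit itself works at the Objective-C++ level via `Tensor::contiguous(memory_format)`):

```python
import torch

# An NCHW tensor rearranged into the channels-last (NHWC) layout.
x = torch.randn(2, 3, 4, 4).to(memory_format=torch.channels_last)

# A channels-last tensor is contiguous with respect to the NHWC layout,
# not the default NCHW layout.
print(x.is_contiguous())                                   # False
print(x.is_contiguous(memory_format=torch.channels_last))  # True

# Tensor.contiguous(memory_format=...) mirrors the C++ call used in the
# fix: a no-op when the tensor already has the requested layout, a copy
# otherwise.
y = x.contiguous(memory_format=torch.channels_last)
print(y.data_ptr() == x.data_ptr())                        # True: no copy needed
```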
