From 472c09fc810b6bf9fa77a999e2d832cf54c9ad6a Mon Sep 17 00:00:00 2001
From: Kulin Seth
Date: Thu, 17 Nov 2022 10:06:46 -0800
Subject: [PATCH] Fix the Channels last bug with GradientWithInput. (#179)

* Fix the Channels last bug with GradientWithInput.

The bug was mentioned in : https://github.com/pytorch/pytorch/issues/77764#issuecomment-1312241902

* Update the placeholder

* Remove the extra print.
---
 .../ATen/native/mps/operations/Convolution.mm | 25 +++++++++++--------------
 1 file changed, 11 insertions(+), 14 deletions(-)

diff --git a/aten/src/ATen/native/mps/operations/Convolution.mm b/aten/src/ATen/native/mps/operations/Convolution.mm
index eb1ee36eca028..1d2ffd0d662a9 100644
--- a/aten/src/ATen/native/mps/operations/Convolution.mm
+++ b/aten/src/ATen/native/mps/operations/Convolution.mm
@@ -198,25 +198,22 @@ Tensor _mps_convolution(
 }
 
 Tensor mps_convolution_backward_input(
-    IntArrayRef input_size, const Tensor& grad_output_t, const Tensor& weight_t,
+    IntArrayRef input_size, const Tensor& grad_output_, const Tensor& weight_,
     IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined) {
   namespace native_mps = at::native::mps;
   using namespace mps;
   CheckedFrom c = "mps_convolution_backward_input";
-  TensorArg grad_output{ grad_output_t, "grad_output", 1 },
-            weight{ weight_t, "weight", 2 };
+  TensorArg grad_output{ grad_output_, "grad_output", 1 },
+            weight{ weight_, "weight", 2 };
   checkAllSameType(c, {grad_output, weight});
   checkAllSameGPU(c, {grad_output, weight});
-  auto memory_format = grad_output_t.suggest_memory_format();
+  auto memory_format = grad_output_.suggest_memory_format();
   bool is_channels_last = (memory_format == at::MemoryFormat::ChannelsLast);
-
-  auto grad_input_t = at::empty(
-                        input_size,
-                        grad_output->scalar_type(),
-                        c10::nullopt,
-                        kMPS,
-                        c10::nullopt,
-                        c10::nullopt);
+  MPSShape* weightShape = get_mps_conv_shape(weight_, is_channels_last);
+  MPSShape* gradOutputShape = get_mps_conv_shape(grad_output_, is_channels_last);
+  Tensor grad_output_t = grad_output_.contiguous(memory_format);
+  Tensor weight_t = weight_.contiguous(memory_format);
+  auto grad_input_t = at::empty( input_size, grad_output_t.options(), c10::nullopt);
 
   // Avoid "grad_input" when this is being used as transposed convolution
   TensorArg grad_input{ grad_input_t, "result", 0 };
@@ -277,7 +274,7 @@ Tensor mps_convolution_backward_input(
                          at::MemoryFormat::Contiguous,
                          groups);
           MPSGraphTensor* gradOutputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSScalarType(grad_output_t.scalar_type()), gradOutputShape);
-          MPSGraphTensor* weightTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, weight_t);
+          MPSGraphTensor* weightTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSScalarType(weight_t.scalar_type()), weightShape);
 
           MPSGraphTensor *gradOutputTensorTranspose = gradOutputTensor;
           if (is_channels_last && grad_output_t.is_contiguous() && !grad_output_t.is_view()) {
@@ -300,7 +297,7 @@ Tensor mps_convolution_backward_input(
     }
 
     auto gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output_t, gradOutputShape);
-    auto weightsPlaceholder = Placeholder(cachedGraph->weightTensor_, weight_t);
+    auto weightsPlaceholder = Placeholder(cachedGraph->weightTensor_, weight_t, weightShape);
     auto outputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, *grad_input);
 
     NSDictionary *feeds = @{
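
For context, a minimal Python sketch of the kind of case the referenced issue describes: a channels-last input on the MPS backend whose gradient with respect to the input is requested, which is the path this patch fixes. The module and shapes below are illustrative assumptions, not taken from the patch or the issue.

    import torch

    # Illustrative repro sketch (assumed module/shapes, not from the patch):
    # a channels-last input on MPS whose input gradient is requested, which
    # routes through mps_convolution_backward_input.
    conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1).to("mps")
    x = (torch.randn(2, 3, 32, 32, device="mps")
             .to(memory_format=torch.channels_last)
             .requires_grad_(True))
    out = conv(x)
    out.sum().backward()
    print(x.grad.shape)  # expected: torch.Size([2, 3, 32, 32])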