Skip to content

Commit

Permalink
Fix convolution crash in backward with weights; remove unnecessary contiguous calls (#341)
Browse files Browse the repository at this point in the history

* Fix convolution crash; remove unnecessary contiguous calls

* Fix lintrunner
  • Loading branch information
DenisVieriu97 committed Feb 17, 2023
1 parent 2aa8066 commit ded4299
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 32 deletions.
32 changes: 10 additions & 22 deletions aten/src/ATen/native/mps/operations/Convolution.mm
Original file line number Diff line number Diff line change
Expand Up @@ -252,20 +252,17 @@ Tensor _mps_convolution(
}

Tensor mps_convolution_backward_input(
IntArrayRef input_size, const Tensor& grad_output_, const Tensor& weight_,
IntArrayRef input_size, const Tensor& grad_output_t, const Tensor& weight_t,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined) {
namespace native_mps = at::native::mps;
using namespace mps;
CheckedFrom c = "mps_convolution_backward_input";
TensorArg grad_output{ grad_output_, "grad_output", 1 },
weight{ weight_, "weight", 2 };
TensorArg grad_output{ grad_output_t, "grad_output", 1 },
weight{ weight_t, "weight", 2 };
checkAllSameType(c, {grad_output, weight});
checkAllSameGPU(c, {grad_output, weight});
auto memory_format = grad_output_.suggest_memory_format();
auto memory_format = grad_output_t.suggest_memory_format();
bool is_channels_last = (memory_format == at::MemoryFormat::ChannelsLast);
Tensor grad_output_t = grad_output_.contiguous(memory_format);
Tensor weight_t = weight_.contiguous(memory_format);
MPSShape* weightShape = getMPSShape(weight_);
auto grad_input_t = at::empty( input_size, grad_output_t.options(), c10::nullopt);

// Avoid "grad_input" when this is being used as transposed convolution
Expand Down Expand Up @@ -341,7 +338,7 @@ Tensor mps_convolution_backward_input(
}

MPSGraphTensor* gradOutputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSScalarType(grad_output_t.scalar_type()), gradOutputShape);
MPSGraphTensor* weightTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSScalarType(weight_t.scalar_type()), weightShape);
MPSGraphTensor* weightTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, weight_t);

MPSGraphTensor *gradOutputTensorTranspose = gradOutputTensor;
if (is_channels_last && grad_output_t.is_contiguous() && !grad_output_t.is_view()) {
Expand Down Expand Up @@ -373,7 +370,7 @@ Tensor mps_convolution_backward_input(
}

auto gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output_t, gradOutputShape);
auto weightsPlaceholder = Placeholder(cachedGraph->weightTensor_, weight_t, weightShape);
auto weightsPlaceholder = Placeholder(cachedGraph->weightTensor_, weight_t);
auto outputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, *grad_input);

NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *feeds = @{
Expand All @@ -391,17 +388,14 @@ Tensor mps_convolution_backward_input(
}

Tensor mps_convolution_backward_weights(
IntArrayRef weight_size, const Tensor& grad_output_, const Tensor& input_,
IntArrayRef weight_size, const Tensor& grad_output_t, const Tensor& input_t,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined) {
namespace native_mps = at::native::mps;
using namespace mps;
CheckedFrom c = "mps_convolution_backward_weights";
auto memory_format = input_.suggest_memory_format();
auto memory_format = grad_output_t.suggest_memory_format();
bool is_channels_last = (memory_format == at::MemoryFormat::ChannelsLast);

auto grad_output_t = grad_output_.to(memory_format);
auto input_t = input_.to(memory_format);

MPSShape* gradOutputShape = mps::getMPSShape(grad_output_t, memory_format);

// For uniformity with everything else, although it seems grad_weight
Expand Down Expand Up @@ -539,12 +533,9 @@ Tensor mps_convolution_backward_weights(
}

std::tuple<at::Tensor,at::Tensor,at::Tensor> mps_convolution_backward(
const at::Tensor& input, const at::Tensor& grad_output_t, const at::Tensor& weight,
const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
std::array<bool,3> output_mask) {

Tensor grad_output = grad_output_t.contiguous(input.suggest_memory_format());

Tensor grad_input, grad_weight, grad_bias;
if (input.numel() == 0) {
if (output_mask[0]) {
Expand Down Expand Up @@ -609,12 +600,9 @@ Tensor mps_convolution_transpose_backward_weight(


std::tuple<Tensor,Tensor> mps_convolution_transpose_backward(
const Tensor& input, const Tensor& grad_output_t, const Tensor& weight,
const Tensor& input, const Tensor& grad_output, const Tensor& weight,
IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
std::array<bool,2> output_mask) {

Tensor grad_output = grad_output_t.contiguous(input.suggest_memory_format());

Tensor grad_input, grad_weight;
if (output_mask[0]) {
grad_input = mps_convolution_transpose_backward_input(grad_output, weight, padding, stride, dilation, groups, input.sizes());
Expand Down
95 changes: 85 additions & 10 deletions test/test_mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -7717,7 +7717,8 @@ def test_conv_transpose_1d_nn_functional(self):
def test_conv_backward_1d_channels_last(self):
def helper(shape, in_channels=1, out_channels=1, kernel_size=3, groups=1):
# https://github.com/pytorch/pytorch/issues/84511
conv_cpu = torch.nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, groups=groups)
conv_cpu = torch.nn.Conv1d(
in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, groups=groups).requires_grad_()
conv_mps = torch.nn.Conv1d(
in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, groups=groups).to("mps")
conv_mps.weight.data = conv_cpu.weight.data.detach().clone().to("mps").requires_grad_(True)
Expand Down Expand Up @@ -7757,15 +7758,89 @@ def test_conv1d_contiguous(self):

def test_conv2d_all_strides_paddings(self):
    # Compares torch.nn.Conv2d on CPU against MPS for every (strideX, strideY)
    # pair in 1..3 x 1..3: first a forward-only sweep on a fixed input, then a
    # forward + backward sweep over input/weight memory formats and groups.
    # https://github.com/pytorch/pytorch/issues/83180
    y_cpu = torch.randn(2, 2, 3, 6)
    y_gpu = y_cpu.to(device='mps')
    for strideX in range(1, 4):
        for strideY in range(1, 4):
            conv_cpu = torch.nn.Conv2d(in_channels=2, out_channels=2, kernel_size=3, stride=(strideX, strideY))
            conv_gpu = copy.deepcopy(conv_cpu).to(device='mps')
            x_cpu = conv_cpu(y_cpu)
            x_gpu = conv_gpu(y_gpu)
            self.assertEqual(x_cpu, x_gpu.cpu(), rtol=1e-03, atol=1e-05)

    def helper(N, C, H, W, groups, input_mem_format, weight_mem_format, permute_data):
        # Builds an (N, C, H, W) input in `input_mem_format`, mirrors it on MPS,
        # and checks outputs and gradients of matching CPU/MPS convolutions.
        x_cpu = torch.randn(N, C, H, W).to(memory_format=input_mem_format).requires_grad_()
        x_mps = x_cpu.detach().clone().to(device='mps').requires_grad_()

        if permute_data:
            # NOTE(review): permute() returns a new tensor and the result is
            # discarded, so these two calls leave x_cpu / x_mps unchanged.
            # Assigning the result would move H into the channel dimension,
            # which no longer matches the in_channels used below.
            x_cpu.permute(0, 2, 3, 1)
            x_mps.permute(0, 2, 3, 1)

        for strideX in range(1, 4):
            for strideY in range(1, 4):
                # in_channels is taken from N while the input has C channels;
                # every call below passes N == C, which keeps this valid.
                conv_cpu = torch.nn.Conv2d(
                    in_channels=N, out_channels=C, kernel_size=H, groups=groups, stride=(strideX, strideY)).requires_grad_()
                conv_cpu.weight.data = conv_cpu.weight.to(memory_format=weight_mem_format).requires_grad_()

                # The MPS module gets copies of the CPU parameters so both
                # sides compute with identical weights and biases.
                conv_mps = torch.nn.Conv2d(
                    in_channels=N, out_channels=C, kernel_size=H, groups=groups, stride=(strideX, strideY), device="mps")
                conv_mps.weight.data = conv_cpu.weight.data.detach().clone().to("mps").requires_grad_()
                conv_mps.bias.data = conv_cpu.bias.data.detach().clone().to("mps").requires_grad_()

                res_cpu = conv_cpu(x_cpu)
                res_mps = conv_mps(x_mps)
                self.assertEqual(res_cpu, res_mps.cpu(), rtol=1e-03, atol=1e-05)

                # backward() returns None, so its result is not kept or
                # compared; the gradient checks below are the real assertions.
                res_cpu.sum().backward()
                res_mps.sum().backward()
                self.assertEqual(conv_cpu.weight.grad, conv_mps.weight.grad, rtol=2.6e-05, atol=2e-04)
                self.assertEqual(conv_cpu.bias.grad, conv_mps.bias.grad)
                # x_cpu / x_mps are reused across stride pairs, so their .grad
                # accumulates over iterations on both devices alike.
                self.assertEqual(x_cpu.grad, x_mps.grad)

    for mem_format_input in [torch.contiguous_format, torch.channels_last]:
        for mem_format_weight in [torch.contiguous_format, torch.channels_last]:
            for permute_data in [True, False]:
                helper(2, 2, 3, 6, 1, mem_format_input, mem_format_weight, permute_data)
                helper(10, 10, 4, 6, 2, mem_format_input, mem_format_weight, permute_data)
                helper(32, 32, 4, 6, 2, mem_format_input, mem_format_weight, permute_data)

def test_conv_transpose_2d_strided(self):
    # Forward-only parity check of nn.ConvTranspose2d between CPU and MPS,
    # run for both contiguous and channels_last inputs.
    def helper(m_cpu, memory_format):
        # Clone the CPU module, then point the clone's parameters at MPS
        # copies of the CPU parameters so both modules use the same values.
        m_mps = copy.deepcopy(m_cpu).requires_grad_()
        for param_name in ("weight", "bias"):
            cpu_param = getattr(m_cpu, param_name)
            getattr(m_mps, param_name).data = cpu_param.data.detach().clone().to("mps").requires_grad_()

        cpu_in = torch.randn(20, 16, 50, 100).to(memory_format=memory_format).requires_grad_()
        mps_in = cpu_in.detach().clone().to("mps")

        cpu_out = m_cpu(cpu_in)
        mps_out = m_mps(mps_in)
        self.assertEqual(cpu_out, mps_out)

    # Keyword arguments for each layer under test; the layer itself is built
    # inside the loop, immediately before the helper call that uses it.
    layer_configs = (
        # With square kernels and equal stride
        dict(kernel_size=3, stride=2),
        # non-square kernels and unequal stride and with padding
        dict(kernel_size=(3, 5), stride=(2, 1), padding=(4, 2)),
    )

    for mem_format_input in (torch.contiguous_format, torch.channels_last):
        for config in layer_configs:
            layer = nn.ConvTranspose2d(16, 33, **config).requires_grad_()
            helper(layer, mem_format_input)

def test_conv_transpose_2d_specified_output(self):
    # Downsamples with a strided Conv2d, then upsamples with ConvTranspose2d
    # using an explicit output_size, and checks CPU and MPS agree on both
    # values and shapes at each step.
    def mirror_params(src, dst):
        # Replace dst's weight/bias data with MPS copies of src's.
        for param_name in ("weight", "bias"):
            src_param = getattr(src, param_name)
            getattr(dst, param_name).data = src_param.data.detach().clone().to("mps").requires_grad_()

    x_cpu = torch.randn(1, 16, 12, 12)
    x_mps = x_cpu.detach().clone().to("mps")

    down_cpu = nn.Conv2d(16, 16, 3, stride=2, padding=1)
    down_mps = nn.Conv2d(16, 16, 3, stride=2, padding=1, device="mps")
    mirror_params(down_cpu, down_mps)

    up_cpu = nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1)
    up_mps = nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1, device="mps")
    mirror_params(up_cpu, up_mps)

    # Downsampled activations must match in value and in shape.
    hidden_cpu = down_cpu(x_cpu)
    hidden_mps = down_mps(x_mps)
    self.assertEqual(hidden_cpu, hidden_mps)
    self.assertEqual(hidden_cpu.size(), hidden_mps.size())

    # output_size is set to the original input size on each device.
    restored_cpu = up_cpu(hidden_cpu, output_size=x_cpu.size())
    restored_mps = up_mps(hidden_mps, output_size=x_mps.size())
    self.assertEqual(restored_cpu, restored_mps)
    self.assertEqual(restored_cpu.size(), restored_mps.size())

def test_conv2d_single_stride(self):
y_cpu = torch.randn(2, 2, 3, 6)
Expand Down

0 comments on commit ded4299

Please sign in to comment.