Fix channel last 3d support for batch_norm (#642)
jiayisunx committed Mar 25, 2022
1 parent a61732e commit ae268ac
Showing 3 changed files with 62 additions and 37 deletions.
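
The fix extends the IPEX CPU batch_norm kernels so that 5-D (N, C, D, H, W) inputs in channels-last-3d layout are recognized and handled on the channels-last code path. A minimal sketch of the case this enables (the shapes and module size are illustrative, not taken from the commit):

import torch
import torch.nn as nn

# 5-D input converted to channels-last-3d memory format.
x = torch.randn(2, 10, 8, 16, 16).to(memory_format=torch.channels_last_3d)
m = nn.BatchNorm3d(10)

y = m(x)
# With this fix, the CPU kernel preserves the channels-last-3d layout.
assert y.is_contiguous(memory_format=torch.channels_last_3d)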
7 changes: 5 additions & 2 deletions intel_extension_for_pytorch/csrc/aten/cpu/Normalization.cpp
@@ -68,7 +68,8 @@ struct Var {
};

static inline bool is_contiguous(const at::Tensor& t) {
return t.is_contiguous() || t.is_contiguous(at::MemoryFormat::ChannelsLast);
return t.is_contiguous() || t.is_contiguous(at::MemoryFormat::ChannelsLast) ||
t.is_contiguous(at::MemoryFormat::ChannelsLast3d);
}

// For some ambiguous cases, it is possible a channels last contiguous
@@ -78,7 +79,9 @@ static inline bool is_contiguous(const at::Tensor& t) {
static inline at::MemoryFormat suggest_memory_format_contig(
const at::Tensor& t) {
return t.is_contiguous() ? at::MemoryFormat::Contiguous
: at::MemoryFormat::ChannelsLast;
: (t.is_contiguous(at::MemoryFormat::ChannelsLast3d)
? at::MemoryFormat::ChannelsLast3d
: at::MemoryFormat::ChannelsLast);
}

template <typename scalar_t, typename param_t>
@@ -1267,7 +1267,9 @@ void batch_norm_cpu_kernel_impl(
eps);
}
});
} else if (input.is_contiguous(at::MemoryFormat::ChannelsLast)) {
} else if (
input.is_contiguous(at::MemoryFormat::ChannelsLast) ||
input.is_contiguous(at::MemoryFormat::ChannelsLast3d)) {
AT_DISPATCH_FLOATING_TYPES_AND(
at::ScalarType::BFloat16,
input.scalar_type(),
@@ -1338,7 +1340,9 @@ void batch_norm_cpu_collect_stats_kernel_impl(
}
}
});
} else if (input.is_contiguous(at::MemoryFormat::ChannelsLast)) {
} else if (
input.is_contiguous(at::MemoryFormat::ChannelsLast) ||
input.is_contiguous(at::MemoryFormat::ChannelsLast3d)) {
AT_DISPATCH_FLOATING_TYPES_AND(
at::ScalarType::BFloat16,
input.scalar_type(),
@@ -1445,7 +1449,9 @@ void batch_norm_cpu_backward_kernel_impl(
}
}
});
} else if (input.is_contiguous(at::MemoryFormat::ChannelsLast)) {
} else if (
input.is_contiguous(at::MemoryFormat::ChannelsLast) ||
input.is_contiguous(at::MemoryFormat::ChannelsLast3d)) {
AT_DISPATCH_FLOATING_TYPES_AND(
at::ScalarType::BFloat16,
input.scalar_type(),
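The suggest_memory_format_contig() change matters because a tensor can report contiguity under more than one memory format at the same time; the helper now prefers Contiguous, then ChannelsLast3d, then ChannelsLast. A rough Python illustration of that ambiguity (the degenerate shape is an assumed example, not from the commit):

import torch

# With singleton spatial dims, the same strides satisfy several formats,
# so both checks below return True for an ordinary contiguous tensor.
t = torch.randn(2, 3, 1, 1, 1)
print(t.is_contiguous())                                      # True
print(t.is_contiguous(memory_format=torch.channels_last_3d))  # True
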
80 changes: 48 additions & 32 deletions tests/cpu/test_cpu_ops.py
@@ -13,6 +13,8 @@
HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")

bn_m = {1 : nn.BatchNorm1d, 2 : nn.BatchNorm2d, 3 : nn.BatchNorm3d}

class CPUOPsTester(TestCase):

def test_channelshuffle(self):
@@ -142,38 +144,52 @@ def test_pixel_shuffle_nhwc_cpu(self):
self.assertEqual(input.grad, ref_input.grad)

def test_batch_norm(self):
m = nn.BatchNorm2d(100)
x = torch.randn(20, 100, 35, 45)
x1 = x.clone().detach().requires_grad_()
y1 = m(x1)
y1.mean().backward()

# test channels last
x2 = x.clone().detach().to(memory_format=torch.channels_last).requires_grad_()
y2 = m(x2)
y2.mean().backward()
self.assertTrue(y2.is_contiguous(memory_format=torch.channels_last))
self.assertEqual(y1, y2)
self.assertTrue(x2.grad.is_contiguous(memory_format=torch.channels_last))
self.assertEqual(x1.grad, x2.grad)

# test bfloat16
x3 = x.clone().detach().bfloat16().requires_grad_()
y3 = m(x3)
y3.mean().backward()
self.assertTrue(y3.dtype == torch.bfloat16)
self.assertEqual(y1, y3, prec=0.1)
self.assertTrue(x3.grad.dtype == torch.bfloat16)
self.assertEqual(x1.grad, x3.grad)

# test autocast
with torch.cpu.amp.autocast():
for datatype in (torch.bfloat16, torch.float32):
x4 = x.clone().detach().to(datatype).requires_grad_()
y4 = m(x4)
y4.mean().backward()
self.assertTrue(y4.dtype == datatype)
self.assertTrue(x4.grad.dtype == datatype)
for dim in [2, 3]:
m = bn_m[dim](10)
input_size = [3, 10, 25, 25]
if dim == 3:
input_size.append(25)
x = torch.randn(input_size)
x1 = x.clone().detach().requires_grad_()
y1 = m(x1)
y1.mean().backward()

# test channels last
suggest_memory_format = torch.channels_last if dim == 2 else torch.channels_last_3d
x2 = x.clone().detach().to(memory_format=suggest_memory_format).requires_grad_()

y2 = m(x2)
y2.mean().backward()
self.assertTrue(y2.is_contiguous(memory_format=suggest_memory_format))
self.assertEqual(y1, y2)
self.assertTrue(x2.grad.is_contiguous(memory_format=suggest_memory_format))
self.assertEqual(x1.grad, x2.grad)

# test bfloat16
x3 = x.clone().detach().bfloat16().requires_grad_()
y3 = m(x3)
y3.mean().backward()
self.assertTrue(y3.dtype == torch.bfloat16)
self.assertEqual(y1, y3, prec=0.1)
self.assertTrue(x3.grad.dtype == torch.bfloat16)
self.assertEqual(x1.grad, x3.grad)

# test autocast
with torch.cpu.amp.autocast():
for datatype in (torch.bfloat16, torch.float32):
x4 = x.clone().detach().to(datatype).requires_grad_()
y4 = m(x4)
y4.mean().backward()
self.assertTrue(y4.dtype == datatype)
self.assertTrue(x4.grad.dtype == datatype)

x5 = x.clone().detach().to(datatype).to(memory_format=suggest_memory_format).requires_grad_()
y5 = m(x5)
y5.mean().backward()
self.assertTrue(y5.dtype == datatype)
self.assertTrue(x5.grad.dtype == datatype)
self.assertTrue(y5.is_contiguous(memory_format=suggest_memory_format))
self.assertTrue(x5.grad.is_contiguous(memory_format=suggest_memory_format))

def test_adaptive_avg_pool2d(self):
m = nn.AdaptiveAvgPool2d((5,7))
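To run just the updated test locally, something along these lines should work (the exact invocation is an assumption; the repository may wire its tests up differently):

python -m pytest tests/cpu/test_cpu_ops.py -k test_batch_norm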
