-
Notifications
You must be signed in to change notification settings - Fork 216
/
Conv.cpp
591 lines (548 loc) · 18.6 KB
/
Conv.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
#include "Conv.h"
#include <torch/all.h>
#include "WeightPack.h"
#include "autocast/autocast_mode.h"
#include "ideep/IDeepConversions.h"
#include "utils/utils.h"
namespace torch_ipex {
namespace cpu {
// Computes the output shape of a convolution over `input_size.size() - 2`
// spatial dimensions.
//
// input_size  : [N, C_in, spatial...]; dim 0 is copied through unchanged.
// kernel_size : weight dims (at the call site in convolution_kernel this is
//               mkldnn_weight.get_dims()); index 0 becomes output dim 1 and
//               indices >= 2 are the spatial kernel extents.
// padding / stride / dilation : one entry per spatial dimension.
//
// Per spatial dim:
//   out = (in + 2 * pad - (dil * (k - 1) + 1)) / stride + 1
std::vector<int64_t> calc_conv_output_size(
    at::IntArrayRef input_size,
    at::IntArrayRef kernel_size,
    at::IntArrayRef padding,
    at::IntArrayRef stride,
    at::IntArrayRef dilation) {
  const size_t rank = input_size.size();
  std::vector<int64_t> result(rank);
  result[0] = input_size[0];
  result[1] = kernel_size[0];
  for (size_t axis = 2; axis < rank; ++axis) {
    // Spatial parameters are indexed from 0 for the first spatial axis.
    const size_t s = axis - 2;
    const int64_t effective_kernel = dilation[s] * (kernel_size[axis] - 1) + 1;
    result[axis] =
        (input_size[axis] + 2 * padding[s] - effective_kernel) / stride[s] + 1;
  }
  return result;
}
// Symbolic-shape variant of the convolution output-size calculation. It is
// used by the Meta kernel so that the batch and spatial input dims may be
// SymInts.
//
// The formula is the same as the at::IntArrayRef overload:
//   out[0] = input[0], out[1] = kernel[0],
//   out[d] = (in[d] + 2 * pad - (dil * (k[d] - 1) + 1)) / stride + 1
// kernel_size / padding / stride / dilation are concrete integers.
c10::SymDimVector calc_conv_output_size(
    c10::SymIntArrayRef input_size,
    at::IntArrayRef kernel_size,
    at::IntArrayRef padding,
    at::IntArrayRef stride,
    at::IntArrayRef dilation) {
  const size_t rank = input_size.size();
  c10::SymDimVector result(rank);
  result[0] = input_size[0];
  result[1] = kernel_size[0];
  for (size_t axis = 2; axis < rank; ++axis) {
    const size_t s = axis - 2;
    const int64_t effective_kernel = dilation[s] * (kernel_size[axis] - 1) + 1;
    result[axis] =
        (input_size[axis] + 2 * padding[s] - effective_kernel) / stride[s] + 1;
  }
  return result;
}
// Runs a oneDNN (ideep) forward convolution into a caller-provided `output`.
//
// Both `input` and `output` must satisfy IS_CONTIGUOUS_ANY. `output` is
// wrapped as an ideep view via itensor_view_from_dense and handed to oneDNN
// as the destination. An empty `mkldnn_bias` selects the bias-free compute
// overload. The same `padding` values are passed as both left and right
// padding. Default-constructed ideep::scale_t values are passed for the three
// scale arguments.
void convolution_kernel_output(
    const at::Tensor& input,
    const ideep::tensor& mkldnn_weight,
    const ideep::tensor& mkldnn_bias,
    at::Tensor& output,
    at::IntArrayRef stride,
    at::IntArrayRef padding,
    at::IntArrayRef dilation,
    int64_t groups,
    const ideep::attr_t& attr) {
  TORCH_CHECK(
      (IS_CONTIGUOUS_ANY(input)) && (IS_CONTIGUOUS_ANY(output)),
      "input and output are need contiguous tensor for "
      "convolution_kernel_output");

  // Temporary workaround before channels-last 1D is formally supported in
  // PyTorch: a 3-D input that is not already channels-last is reordered into
  // an nwc copy. Any other input is used directly as a view of `input`.
  const auto make_source = [&input]() -> ideep::tensor {
    const ideep::tensor dense_view = itensor_view_from_dense(input);
    const bool keep_as_is = dense_view.ndims() != 3 ||
        dense_view.get_desc().is_channels_last();
    if (keep_as_is) {
      return dense_view;
    }
    ideep::tensor nwc_copy{
        dense_view.get_desc().to_format(ideep::format_tag::nwc)};
    nwc_copy.feed_from(dense_view);
    return nwc_copy;
  };
  ideep::tensor src = make_source();

  const auto out_dims = output.sizes();
  ideep::tensor dst = itensor_view_from_dense(output);
  if (!mkldnn_bias.is_empty()) {
    ideep::convolution_forward::compute(
        src,
        mkldnn_weight,
        mkldnn_bias,
        {out_dims.begin(), out_dims.end()},
        dst,
        {stride.begin(), stride.end()},
        {dilation.begin(), dilation.end()},
        {padding.begin(), padding.end()},
        {padding.begin(), padding.end()},
        groups,
        ideep::scale_t(),
        ideep::scale_t(),
        ideep::scale_t(),
        attr);
    return;
  }
  ideep::convolution_forward::compute(
      src,
      mkldnn_weight,
      {out_dims.begin(), out_dims.end()},
      dst,
      {stride.begin(), stride.end()},
      {dilation.begin(), dilation.end()},
      {padding.begin(), padding.end()},
      {padding.begin(), padding.end()},
      groups,
      ideep::scale_t(),
      ideep::scale_t(),
      ideep::scale_t(),
      attr);
}
// Allocates the convolution output and fills it via convolution_kernel_output.
//
// This base kernel does not change the input's memory format, so callers must
// prepare the input's format before calling it. The output is allocated in
// the memory format suggested by the input. The exception is 3-D (conv1d)
// input: as a temporary workaround before channels-last 1D is formally
// supported in PyTorch, the output is always allocated with nwc strides.
// TODO: the input will be actively converted to channels last format after
// the 5-D tensor supports channels last format.
at::Tensor convolution_kernel(
    const at::Tensor& input,
    const ideep::tensor& mkldnn_weight,
    const ideep::tensor& mkldnn_bias,
    at::IntArrayRef stride,
    at::IntArrayRef padding,
    at::IntArrayRef dilation,
    int64_t groups,
    const ideep::attr_t& attr) {
  TORCH_CHECK(
      IS_CONTIGUOUS_ANY(input),
      "input is need to a contiguous tensor for convolution_kernel");
  const std::vector<int64_t> out_shape = calc_conv_output_size(
      input.sizes(), mkldnn_weight.get_dims(), padding, stride, dilation);

  at::Tensor output;
  if (input.dim() == 3) {
    // Sizes [N, C, W] with strides {C*W, 1, C}: the channel dim is innermost.
    const int64_t channels = out_shape[1];
    const int64_t width = out_shape[2];
    const std::vector<int64_t> nwc_strides = {channels * width, 1, channels};
    output = at::empty_strided(out_shape, nwc_strides, input.options());
  } else {
    const auto out_options =
        input.options().memory_format(input.suggest_memory_format());
    output = at::empty(out_shape, out_options);
  }

  convolution_kernel_output(
      input,
      mkldnn_weight,
      mkldnn_bias,
      output,
      stride,
      padding,
      dilation,
      groups,
      attr);
  return output;
}
// CPU kernel registered for torch_ipex::convolution_forward.
//
// Only `input` and `op_context` are read here. The first int64 element of
// `op_context` holds the address of an IpexConvolutionOpContext, whose run()
// performs the convolution with ideep::attr_t(torch_ipex::fpmath_mode).
// weight / bias / kernel_size / padding / stride / dilation are part of the
// operator schema but are not consulted by this kernel. Presumably the
// context already carries them; confirm in IpexConvolutionOpContext.
at::Tensor convolution_forward_impl(
    const at::Tensor& input,
    const at::Tensor& weight,
    const c10::optional<at::Tensor>& bias_opt,
    const at::Tensor& op_context,
    c10::optional<at::IntArrayRef> kernel_size,
    c10::optional<at::IntArrayRef> padding,
    c10::optional<at::IntArrayRef> stride,
    c10::optional<at::IntArrayRef> dilation) {
#if defined(IPEX_DISP_OP)
  printf("torch_ipex::convolution_forward_impl\n");
#endif
  RECORD_FUNCTION(
      "torch_ipex::convolution_forward_impl", c10::ArrayRef<c10::IValue>({}));
  const int64_t ctx_address = op_context.data_ptr<int64_t>()[0];
  auto* conv_ctx = reinterpret_cast<IpexConvolutionOpContext*>(ctx_address);
  return conv_ctx->run(input, ideep::attr_t(torch_ipex::fpmath_mode));
}
// Computes the gradient w.r.t. the convolution input with oneDNN
// (ideep::convolution_backward_data).
//
// If grad_output is channels-last contiguous, grad_input is allocated up
// front in the matching channels-last format (ChannelsLast for 4-D,
// ChannelsLast3d for 5-D), wrapped as an ideep view, and filled in place.
// Otherwise an empty ideep tensor is passed as the destination and ideep
// initializes it. The result is then converted back to a dense tensor with
// grad_output's dtype and device.
//
// Only 2-D / 3-D convolutions (4-D / 5-D input_size) are accepted.
// `bias_defined` and `weight_use_channels_last` are not used by this
// function; they are kept for interface compatibility.
at::Tensor convolution_backward_input(
    at::IntArrayRef input_size,
    const at::Tensor& grad_output,
    const ideep::tensor& mkldnn_weight,
    at::IntArrayRef padding,
    at::IntArrayRef stride,
    at::IntArrayRef dilation,
    int64_t groups,
    bool bias_defined,
    bool weight_use_channels_last) {
  TORCH_CHECK(
      input_size.size() == 4 || input_size.size() == 5,
      "Only support 2d or 3d convolution for convolution_backward_input");
  const ideep::tensor mkldnn_grad_output = itensor_view_from_dense(grad_output);
  const bool is_channels_last_contiguous =
      grad_output.is_contiguous(at::MemoryFormat::ChannelsLast) ||
      grad_output.is_contiguous(at::MemoryFormat::ChannelsLast3d);

  // grad_input is only needed as oneDNN's destination on the channels-last
  // path. On the other path ideep fills `mkldnn_grad_input` itself, so no
  // dense tensor is allocated up front. The old code allocated one and then
  // discarded it.
  at::Tensor grad_input;
  ideep::tensor mkldnn_grad_input;
  if (is_channels_last_contiguous) {
    const auto memory_format = input_size.size() == 4
        ? at::MemoryFormat::ChannelsLast
        : at::MemoryFormat::ChannelsLast3d;
    grad_input = at::empty(
        input_size, grad_output.options().memory_format(memory_format));
    mkldnn_grad_input = itensor_view_from_dense(grad_input);
  }

  ideep::convolution_backward_data::compute(
      mkldnn_grad_output,
      mkldnn_weight,
      input_size.vec(),
      mkldnn_grad_input,
      stride.vec(),
      dilation.vec(),
      padding.vec(),
      padding.vec(),
      groups,
      ideep::attr_t(torch_ipex::fpmath_mode));

  if (is_channels_last_contiguous) {
    return grad_input;
  }
  return mkldnn_to_dense(new_with_itensor_mkldnn(
      std::move(mkldnn_grad_input),
      optTypeMetaToScalarType(grad_output.options().dtype_opt()),
      grad_output.options().device_opt()));
}
// Computes weight (and, when `bias_defined`, bias) gradients with oneDNN
// (ideep::convolution_backward_weights).
//
// grad_weight is allocated with at::empty_like(weight, grad_output.options()).
// Its buffer is wrapped with `packed_weight_desc` and passed to ideep as the
// weight-gradient destination, so the gradient is laid out as that desc
// describes. grad_bias has size grad_output.size(1) and is only allocated and
// computed when `bias_defined`; otherwise the returned grad_bias is undefined.
//
// Only 2-D / 3-D convolutions (4-D / 5-D input) are accepted, and
// grad_output must be float, bfloat16 or float16.
std::tuple<at::Tensor, at::Tensor> convolution_backward_weights(
    const at::Tensor& grad_output,
    const at::Tensor& input,
    const at::Tensor& weight,
    const ideep::tensor::desc& packed_weight_desc,
    at::IntArrayRef padding,
    at::IntArrayRef stride,
    at::IntArrayRef dilation,
    int64_t groups,
    bool bias_defined) {
  TORCH_CHECK(
      input.dim() == 4 || input.dim() == 5,
      "Only support 2d or 3d convolution for convolution_backward_weights");
  const ideep::tensor mkldnn_grad_output = itensor_view_from_dense(grad_output);
  const ideep::tensor mkldnn_input = itensor_view_from_dense(input);
  // (The old code also computed an unused channels-last contiguity flag for
  // grad_output here; it had no effect and has been removed.)
  auto grad_weight = at::empty_like(weight, grad_output.options());
  at::Tensor grad_bias;
  ideep::tensor mkldnn_grad_weight, mkldnn_grad_bias;

  // Wrap grad_weight's storage with the packed weight descriptor. The typed
  // data_ptr<T>() must match grad_weight's dtype, which follows grad_output.
  const auto grad_dtype = grad_output.scalar_type();
  if (grad_dtype == at::ScalarType::Float) {
    mkldnn_grad_weight.init(
        packed_weight_desc, grad_weight.template data_ptr<float>());
  } else if (grad_dtype == at::ScalarType::BFloat16) {
    mkldnn_grad_weight.init(
        packed_weight_desc, grad_weight.template data_ptr<c10::BFloat16>());
  } else {
    TORCH_CHECK(
        grad_dtype == at::ScalarType::Half,
        "Only support bfloat16, float16 and float for convolution_backward_weights");
    mkldnn_grad_weight.init(
        packed_weight_desc, grad_weight.template data_ptr<c10::Half>());
  }

  if (bias_defined) {
    grad_bias = at::empty({grad_output.size(1)}, grad_output.options());
    mkldnn_grad_bias = itensor_view_from_dense(grad_bias);
    ideep::convolution_backward_weights::compute(
        mkldnn_input,
        mkldnn_grad_output,
        packed_weight_desc.get_dims(),
        mkldnn_grad_weight,
        mkldnn_grad_bias,
        stride.vec(),
        dilation.vec(),
        padding.vec(),
        padding.vec(),
        groups,
        ideep::attr_t(torch_ipex::fpmath_mode));
  } else {
    ideep::convolution_backward_weights::compute(
        mkldnn_input,
        mkldnn_grad_output,
        packed_weight_desc.get_dims(),
        mkldnn_grad_weight,
        stride.vec(),
        dilation.vec(),
        padding.vec(),
        padding.vec(),
        groups,
        ideep::attr_t(torch_ipex::fpmath_mode));
  }
  return std::make_tuple(grad_weight, grad_bias);
}
// Convolution backward for 2-D / 3-D convolutions (4-D / 5-D input).
//
// A channels-last memory format (ChannelsLast for 4-D, ChannelsLast3d for
// 5-D) is chosen when `weight_channels_last` is set or the input suggests a
// channels-last format. Otherwise the contiguous format is used.
// grad_output is made contiguous in that format.
//
// output_mask = {input grad, weight grad, bias grad}:
//   [0]      -> convolution_backward_input
//   [1]||[2] -> convolution_backward_weights (input made contiguous in the
//               chosen format; bias grad only when [2])
// Gradients that are not requested are returned as undefined tensors.
// `mkldnn_bias` is not used by this function.
std::tuple<at::Tensor, at::Tensor, at::Tensor> convolution_backward_kernel(
    const at::Tensor& input,
    const at::Tensor& grad_output,
    const at::Tensor& at_weight,
    const ideep::tensor& mkldnn_weight,
    const ideep::tensor& mkldnn_bias,
    at::IntArrayRef stride,
    at::IntArrayRef padding,
    at::IntArrayRef dilation,
    int64_t groups,
    const bool weight_channels_last,
    std::array<bool, 3> output_mask) {
#if defined(IPEX_DISP_OP)
  printf("torch_ipex::convolution_backward\n");
#endif
  RECORD_FUNCTION(
      "torch_ipex::convolution_backward", c10::ArrayRef<c10::IValue>({}));
  TORCH_CHECK(
      input.dim() == 4 || input.dim() == 5,
      "Only support 2d or 3d convolution for convolution_backward");

  const bool need_input_grad = output_mask[0];
  const bool need_weight_grad = output_mask[1];
  const bool need_bias_grad = output_mask[2];

  const auto suggested_format = input.suggest_memory_format();
  const bool use_channels_last = weight_channels_last ||
      suggested_format == at::MemoryFormat::ChannelsLast ||
      suggested_format == at::MemoryFormat::ChannelsLast3d;
  const at::MemoryFormat memory_format = !use_channels_last
      ? at::MemoryFormat::Contiguous
      : (input.dim() == 4 ? at::MemoryFormat::ChannelsLast
                          : at::MemoryFormat::ChannelsLast3d);

  const auto grad_output_ = grad_output.contiguous(memory_format);

  at::Tensor grad_input;
  if (need_input_grad) {
    grad_input = convolution_backward_input(
        input.sizes(),
        grad_output_,
        mkldnn_weight,
        padding,
        stride,
        dilation,
        groups,
        need_bias_grad,
        weight_channels_last);
  }

  at::Tensor grad_weight;
  at::Tensor grad_bias;
  if (need_weight_grad || need_bias_grad) {
    const auto input_ = input.contiguous(memory_format);
    std::tie(grad_weight, grad_bias) = convolution_backward_weights(
        grad_output_,
        input_,
        at_weight,
        mkldnn_weight.get_desc(),
        padding,
        stride,
        dilation,
        groups,
        need_bias_grad);
  }
  return std::make_tuple(grad_input, grad_weight, grad_bias);
}
// CPU kernel registered for torch_ipex::convolution_backward.
// The first int64 element of `op_context` holds the address of an
// IpexConvolutionOpContext. This kernel delegates to its run_backward with
// the same input, grad_output and output_mask.
std::tuple<at::Tensor, at::Tensor, at::Tensor> convolution_backward(
    const at::Tensor& input,
    const at::Tensor& grad_output,
    std::array<bool, 3> output_mask,
    const at::Tensor& op_context) {
  const int64_t ctx_address = op_context.data_ptr<int64_t>()[0];
  auto* conv_ctx = reinterpret_cast<IpexConvolutionOpContext*>(ctx_address);
  return conv_ctx->run_backward(input, grad_output, output_mask);
}
// Re-dispatches torch_ipex::convolution_forward below the autograd layer.
// The AutoDispatchBelowADInplaceOrView guard keeps the call from re-entering
// the Autograd kernel. The dispatcher handle is looked up once and cached in
// a function-local static.
at::Tensor IPEXConvolutionOp::_forward(
    const at::Tensor& input,
    const at::Tensor& weight,
    const c10::optional<at::Tensor>& bias_opt,
    const at::Tensor& op_context,
    c10::optional<at::IntArrayRef> kernel_size,
    c10::optional<at::IntArrayRef> padding,
    c10::optional<at::IntArrayRef> stride,
    c10::optional<at::IntArrayRef> dilation) {
  at::AutoDispatchBelowADInplaceOrView below_autograd_guard;
  RECORD_FUNCTION(
      "IPEXConvolutionOp::_forward", c10::ArrayRef<c10::IValue>({}));
  static auto conv_forward_op =
      torch::Dispatcher::singleton()
          .findSchemaOrThrow("torch_ipex::convolution_forward", "")
          .typed<decltype(convolution_forward)>();
  return conv_forward_op.call(
      input,
      weight,
      bias_opt,
      op_context,
      kernel_size,
      padding,
      stride,
      dilation);
}
// Autograd forward. It records in `ctx` which of input / weight / bias
// require grad, stores `op_context`, saves `input` for backward, and then
// runs the convolution via _forward.
at::Tensor IPEXConvolutionOp::forward(
    torch::autograd::AutogradContext* ctx,
    const at::Tensor& input,
    const at::Tensor& weight,
    const c10::optional<at::Tensor>& bias_opt,
    const at::Tensor& op_context,
    c10::optional<at::IntArrayRef> kernel_size,
    c10::optional<at::IntArrayRef> padding,
    c10::optional<at::IntArrayRef> stride,
    c10::optional<at::IntArrayRef> dilation) {
  RECORD_FUNCTION("IPEXConvolutionOp::forward", c10::ArrayRef<c10::IValue>({}));
  at::AutoDispatchBelowADInplaceOrView below_autograd_guard;

  const bool bias_needs_grad =
      bias_opt.has_value() && bias_opt.value().requires_grad();
  auto& saved = ctx->saved_data;
  saved["op_context"] = op_context;
  saved["input_requires_grad"] = input.requires_grad();
  saved["weight_requires_grad"] = weight.requires_grad();
  saved["bias_requires_grad"] = bias_needs_grad;
  ctx->save_for_backward({input});

  return _forward(
      input,
      weight,
      bias_opt,
      op_context,
      kernel_size,
      padding,
      stride,
      dilation);
}
// Autograd backward. It rebuilds the output_mask from the requires-grad flags
// recorded in forward and calls torch_ipex::convolution_backward through the
// dispatcher. It returns one gradient slot per forward argument: input,
// weight and bias get computed gradients. op_context, kernel_size, padding,
// stride and dilation get undefined tensors.
torch::autograd::variable_list IPEXConvolutionOp::backward(
    torch::autograd::AutogradContext* ctx,
    torch::autograd::variable_list grad_outputs) {
  RECORD_FUNCTION(
      "IPEXConvolutionOp::backward", c10::ArrayRef<c10::IValue>({}));
  auto& saved = ctx->saved_data;
  auto op_context = saved["op_context"].toTensor();
  const std::array<bool, 3> output_mask = {
      saved["input_requires_grad"].toBool(),
      saved["weight_requires_grad"].toBool(),
      saved["bias_requires_grad"].toBool()};
  const at::Tensor input = ctx->get_saved_variables()[0];

  static auto conv_backward_op =
      torch::Dispatcher::singleton()
          .findSchemaOrThrow("torch_ipex::convolution_backward", "")
          .typed<decltype(convolution_backward)>();
  const auto grads =
      conv_backward_op.call(input, grad_outputs[0], output_mask, op_context);

  return {
      std::get<0>(grads),
      std::get<1>(grads),
      std::get<2>(grads),
      at::Tensor(),
      at::Tensor(),
      at::Tensor(),
      at::Tensor(),
      at::Tensor()};
}
// Autograd-key entry point for torch_ipex::convolution_forward.
// With grad mode off it calls IPEXConvolutionOp::_forward directly. With
// grad mode on it goes through IPEXConvolutionOp::apply so that a backward
// node is recorded.
at::Tensor convolution_forward(
    const at::Tensor& input,
    const at::Tensor& weight,
    const c10::optional<at::Tensor>& bias_opt,
    const at::Tensor& op_context,
    c10::optional<at::IntArrayRef> kernel_size,
    c10::optional<at::IntArrayRef> padding,
    c10::optional<at::IntArrayRef> stride,
    c10::optional<at::IntArrayRef> dilation) {
  if (!at::GradMode::is_enabled()) {
    return IPEXConvolutionOp::_forward(
        input,
        weight,
        bias_opt,
        op_context,
        kernel_size,
        padding,
        stride,
        dilation);
  }
  return IPEXConvolutionOp::apply(
      input,
      weight,
      bias_opt,
      op_context,
      kernel_size,
      padding,
      stride,
      dilation);
}
// Meta (shape-only) kernel for torch_ipex::convolution_forward.
//
// kernel_size, padding, stride and dilation must all be provided. The output
// shape comes from the symbolic calc_conv_output_size, where kernel_size[0]
// is used as the output-channel count. The result is allocated with
// input.options() and no explicit memory format.
// NOTE(review): convolution_kernel allocates outputs in the input's suggested
// memory format (nwc strides for 3-D input), so the strides here may differ
// from what the CPU kernel produces. Confirm whether tracing relies on them.
at::Tensor convolution_forward_meta(
    const at::Tensor& input,
    const at::Tensor& weight,
    const c10::optional<at::Tensor>& bias_opt,
    const at::Tensor& op_context,
    c10::optional<at::IntArrayRef> kernel_size,
    c10::optional<at::IntArrayRef> padding,
    c10::optional<at::IntArrayRef> stride,
    c10::optional<at::IntArrayRef> dilation) {
  const bool has_all_params = kernel_size.has_value() && padding.has_value() &&
      stride.has_value() && dilation.has_value();
  TORCH_CHECK(
      has_all_params,
      "kernel_size, padding, stride and dilation must have value for convolution_forward_meta");
  const c10::SymDimVector output_sizes = calc_conv_output_size(
      input.sym_sizes(),
      kernel_size.value(),
      padding.value(),
      stride.value(),
      dilation.value());
  return at::empty_symint(output_sizes, input.options());
}
} // namespace cpu
} // namespace torch_ipex
namespace torch_ipex {
namespace autocast {
// AutocastCPU kernel for torch_ipex::convolution_forward.
// AutocastCPU is excluded for the duration of the call, so the re-dispatch
// does not come back here. Only `input` is cast (via cpu_cached_cast) to the
// current autocast dtype. Weight, bias and the op context are forwarded
// unchanged.
at::Tensor convolution_forward(
    const at::Tensor& input,
    const at::Tensor& weight,
    const c10::optional<at::Tensor>& bias_opt,
    const at::Tensor& op_context,
    c10::optional<at::IntArrayRef> kernel_size,
    c10::optional<at::IntArrayRef> padding,
    c10::optional<at::IntArrayRef> stride,
    c10::optional<at::IntArrayRef> dilation) {
  c10::impl::ExcludeDispatchKeyGuard no_autocastCPU(DispatchKey::AutocastCPU);
  static auto conv_forward_op =
      torch::Dispatcher::singleton()
          .findSchemaOrThrow("torch_ipex::convolution_forward", "")
          .typed<decltype(convolution_forward)>();
  const auto target_type = get_autocast_dtype();
  // TODO: make check weight dtype should be float for training case.
  auto cast_input = cpu_cached_cast(target_type, input);
  return conv_forward_op.call(
      cast_input,
      weight,
      bias_opt,
      op_context,
      kernel_size,
      padding,
      stride,
      dilation);
}
} // namespace autocast
} // namespace torch_ipex
// Operator registrations for the torch_ipex convolution ops.
namespace {
TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
// Forward schema. The "W_prepack" argument is the op-context tensor
// (`op_context` in the C++ kernels), whose first int64 element holds an
// IpexConvolutionOpContext address.
m.def(
"convolution_forward(Tensor input, Tensor weight, Tensor? bias, "
"Tensor W_prepack, int[]? kernel_size, int[]? padding, int[]? stride, int[]? dilation) -> Tensor");
// Autograd key: records a backward node via IPEXConvolutionOp when grad mode
// is enabled.
m.impl(
"convolution_forward",
c10::DispatchKey::Autograd,
torch_ipex::cpu::convolution_forward);
// AutocastCPU key: casts `input` to the autocast dtype, then re-dispatches.
m.impl(
"convolution_forward",
c10::DispatchKey::AutocastCPU,
torch_ipex::autocast::convolution_forward);
// CPU key: runs the convolution through the op context.
m.impl(
"convolution_forward",
c10::DispatchKey::CPU,
torch_ipex::cpu::convolution_forward_impl);
// Meta key: shape-only kernel; requires all optional int[] arguments.
m.impl(
"convolution_forward",
c10::DispatchKey::Meta,
torch_ipex::cpu::convolution_forward_meta);
// bw
// Backward schema. out_mask selects {input, weight, bias} gradients. Only a
// CPU kernel is registered for it.
m.def(
"convolution_backward(Tensor input, Tensor grad_output, bool[3] out_mask, "
"Tensor W_prepack) -> (Tensor, Tensor, Tensor)");
m.impl(
"convolution_backward",
c10::DispatchKey::CPU,
torch_ipex::cpu::convolution_backward);
}
} // namespace