disable_shape_infer.patch
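This patch disables shape and stride inference in tensor.cpp: set_shape() and set_strides() become no-ops, hasShapeData() and hasStridesData() always report false, and every shape/stride accessor (sizes(), strides(), dim(), numel(), is_contiguous(), size(), stride()) unconditionally calls ensure_materialized(), so the trace is flushed whenever shape information is requested instead of answering from inferred metadata.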
diff --git a/tensor.cpp b/tensor.cpp
index 4531bf5..ab11a2b 100644
--- a/tensor.cpp
+++ b/tensor.cpp
@@ -165,6 +165,7 @@ public:
   }
 
   void set_shape(IntArrayRef shape) {
+    return;
     sizes_and_strides_.set_sizes(shape);
     store_shape();
 
@@ -173,6 +174,7 @@ public:
   }
 
   void set_strides(IntArrayRef strides) {
+    return;
     assert(strides.size() == sizes_and_strides_.size());
     for (unsigned i = 0, e = strides.size(); i != e; ++i) {
       sizes_and_strides_.stride_at_unchecked(i) = strides[i];
@@ -204,8 +206,8 @@ public:
 #endif
 
   unsigned getTraceIdx() const { return trace_idx; }
-  bool hasShapeData() const { return has_shape_data; }
-  bool hasStridesData() const { return has_strides_data; }
+  bool hasShapeData() const { return false && has_shape_data; }
+  bool hasStridesData() const { return false && has_strides_data; }
 
   void set(const Tensor &t) {
     assert(dtype() == t.dtype());
@@ -282,13 +284,13 @@ public:
   // an extra indirection. Another way is to templatize these.
 
   IntArrayRef sizes() const override {
-    if (!has_shape_data)
+    if (1||!has_shape_data)
       ensure_materialized(STATS(FlushReason::SIZES));
     return TensorImpl::sizes();
   }
 
   IntArrayRef strides() const override {
-    if (!has_strides_data) {
+    if (1||!has_strides_data) {
       if (false && trace_idx != -1u && !trace.is_flushing())
         cerr << "BAD STRIDES FOR " << trace.getOps()[trace_idx].id << endl;
       ensure_materialized(STATS(FlushReason::STRIDES));
@@ -297,7 +299,7 @@ public:
   }
 
   int64_t dim() const override {
-    if (!has_shape_data)
+    if (1||!has_shape_data)
       ensure_materialized(STATS(FlushReason::DIM));
     return TensorImpl::dim();
   }
@@ -313,13 +315,13 @@ public:
   }
 
   int64_t numel() const override {
-    if (!has_shape_data)
+    if (1||!has_shape_data)
       ensure_materialized(STATS(FlushReason::NUMEL));
     return TensorImpl::numel();
   }
 
   bool is_contiguous(at::MemoryFormat memory_format) const override {
-    if (!has_strides_data || !has_shape_data) {
+    if (1||!has_strides_data || !has_shape_data) {
       if (false && trace_idx != -1u && !trace.is_flushing())
         cerr << "BAD ISCONTIGUOUS FOR " << trace.getOps()[trace_idx].id << endl;
       ensure_materialized(STATS(FlushReason::IS_CONTIGUOUS));
@@ -355,13 +357,13 @@ public:
   }
 
   int64_t size(int64_t d) const override {
-    if (!has_shape_data)
+    if (1||!has_shape_data)
       ensure_materialized(STATS(FlushReason::SIZE));
     return TensorImpl::size(d);
  }
 
   int64_t stride(int64_t d) const override {
-    if (!has_strides_data)
+    if (1||!has_strides_data)
       ensure_materialized(STATS(FlushReason::STRIDE));
     return TensorImpl::stride(d);
   }