This repository has been archived by the owner on Feb 7, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
autograd.cpp
64 lines (47 loc) · 1.78 KB
/
autograd.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
// Copyright (c) 2021-present The Torchy Authors.
// Distributed under the MIT license that can be found in the LICENSE file.
#include "tensor.h"
#include <torch/csrc/autograd/variable.h>
#include <torch/library.h>
using namespace at;
using namespace std;
namespace {
// Short alias for PyTorch's stock autograd metadata class, which
// TorchyAutograd extends.
using AG = torch::autograd::AutogradMeta;

// Autograd metadata attached to Torchy tensors.
//
// Each overridden accessor first calls ensure_materialized() and only then
// forwards to the corresponding AutogradMeta implementation. Gradient state
// is therefore never read or written while Torchy still has deferred work
// outstanding.
struct TorchyAutograd final : public AG {
  void set_requires_grad(bool requires_grad, TensorImpl* self_impl) override {
    // Materialize the specific tensor whose flag is being changed, rather
    // than using the nullptr form the other accessors use.
    ensure_materialized(self_impl STATS_ARG(FlushReason::AUTOGRAD));
    AG::set_requires_grad(requires_grad, self_impl);
  }

  // requires_grad() keeps the base implementation on purpose: the other
  // overrides below already force materialization, so re-checking here would
  // add nothing.

  Tensor& mutable_grad() override {
    materialize_pending();
    return AG::mutable_grad();
  }

  const Tensor& grad() const override {
    materialize_pending();
    return AG::grad();
  }

  const Tensor& fw_grad(uint64_t level, const TensorBase& self) const override {
    materialize_pending();
    return AG::fw_grad(level, self);
  }

  void set_fw_grad(const TensorBase& new_grad, const TensorBase& self,
                   uint64_t level, bool is_inplace_op) override {
    materialize_pending();
    AG::set_fw_grad(new_grad, self, level, is_inplace_op);
  }

private:
  // Calls ensure_materialized() with no specific tensor, tagged with the
  // AUTOGRAD flush reason. NOTE(review): the nullptr argument presumably
  // means "flush everything pending" rather than one tensor -- confirm
  // against ensure_materialized in tensor.h.
  static void materialize_pending() {
    ensure_materialized(nullptr STATS_ARG(FlushReason::AUTOGRAD));
  }
};
Tensor singleton_undefined_tensor;
struct TorchyFactory : public c10::impl::AutogradMetaFactory {
unique_ptr<AutogradMetaInterface> make() const override {
return make_unique<TorchyAutograd>();
}
const Tensor& undefined_tensor() const override {
return singleton_undefined_tensor;
}
};
TorchyFactory factory;
c10::impl::AutogradMetaFactoryRegisterer reg(&factory);
}