Skip to content

Commit

Permalink
Add XBlobGetMutableTensor that returns Tensor (pytorch#14424)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#14424

Pull Request resolved: pytorch#14136

Since now Tensor is a shared_ptr, it doesn't make sense to have Tensor* around anymore,
so we want to change Tensor* to Tensor in the interface.
We added functions that work with `Tensor` instead of `Tensor*` in this diff.

To remove Tensor*, we'll do following
```
auto* Y = Ouptut(0);
Y->mutable_data...
```
-->
```
auto Y = Output(0);
Y.mutable_data...
```

But to run clangr codemod, we'll keep both APIs in different names, e.g. `Output` and `XOutput`, and do the refactor and then delete the old method and rename the new method into the old one.
For example for `Output`, we'll first codemod the callsites from `Output` to `XOutput`, then delete the old `Output` and rename `XOutput` to `Output` in the end.

Reviewed By: smessmer

Differential Revision: D12934074

fbshipit-source-id: 120778830835fc4d90286cf2ed00b4994cf32737
  • Loading branch information
jerryzh168 authored and facebook-github-bot committed Nov 28, 2018
1 parent 0f62af4 commit 8bb14df
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 1 deletion.
7 changes: 7 additions & 0 deletions caffe2/core/blob.h
Expand Up @@ -28,6 +28,8 @@ inline Tensor* BlobSetTensor(Blob* blob, const Tensor& tensor) {
return blob->Reset<Tensor>(new Tensor(tensor));
}

// need to keep both functions that returns Tensor* and the one
// returns Tensor for clangr codemod
inline Tensor*
BlobGetMutableTensor(Blob* blob, at::IntList dims, at::TensorOptions options) {
if (blob->IsType<Tensor>()) {
Expand Down Expand Up @@ -58,6 +60,11 @@ BlobGetMutableTensor(Blob* blob, at::IntList dims, at::TensorOptions options) {
return BlobSetTensor(blob, caffe2::empty(dims, options));
}

inline Tensor
XBlobGetMutableTensor(Blob* blob, at::IntList dims, at::TensorOptions options) {
return *BlobGetMutableTensor(blob, dims, options);
}

inline Tensor* BlobGetMutableTensor(Blob* blob, DeviceType device_type) {
if (blob->IsType<Tensor>()) {
Tensor* tensor = blob->GetMutable<Tensor>();
Expand Down
18 changes: 17 additions & 1 deletion caffe2/core/operator.h
Expand Up @@ -127,6 +127,14 @@ class CAFFE2_API OperatorBase : public Observable<OperatorBase> {
return BlobGetMutableTensor(outputs_.at(idx), type);
}

inline Tensor
XOutputTensor(int idx, at::IntList dims, at::TensorOptions options) {
CAFFE_ENFORCE_WITH_CALLER(
options.device_opt() != c10::nullopt,
"device must be provided in option.");
return XBlobGetMutableTensor(outputs_.at(idx), dims, options);
}

inline Tensor*
OutputTensor(int idx, at::IntList dims, at::TensorOptions options) {
CAFFE_ENFORCE_WITH_CALLER(
Expand Down Expand Up @@ -495,7 +503,15 @@ class Operator : public OperatorBase {
return OperatorBase::template Input<Tensor>(idx, type);
}

inline Tensor* Output(int idx, at::IntList dims, at::TensorOptions options) {
Tensor XOutput(int idx, at::IntList dims, at::TensorOptions options) {
if (options.device_opt() == c10::nullopt) {
return OperatorBase::XOutputTensor(
idx, dims, options.device(context_.device()));
}
return OperatorBase::XOutputTensor(idx, dims, options);
}

Tensor* Output(int idx, at::IntList dims, at::TensorOptions options) {
if (options.device_opt() == c10::nullopt) {
return OperatorBase::OutputTensor(
idx, dims, options.device(context_.device()));
Expand Down

0 comments on commit 8bb14df

Please sign in to comment.