Skip to content

Commit

Permalink
[jit] Add RRef to IValue and JIT type system (pytorch#32992)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#32992

This PR add RRef to IValue and the JIT type system.

- The RRefInterface abstract class inherit from intrusive_ptr_target,
  this made the RRef class can be hold in ivalue as intrusive_ptr

- Add RRefType as a JIT type, it's a container type similar to
future type.

Test Plan: Imported from OSS

Differential Revision: D19871242

Pulled By: wanchaol

fbshipit-source-id: cb80ca32605096f9a42ef147109fb368a7c1d4d3
  • Loading branch information
wanchaol authored and facebook-github-bot committed Feb 14, 2020
1 parent 9ae4d38 commit b2c5896
Show file tree
Hide file tree
Showing 7 changed files with 86 additions and 1 deletion.
4 changes: 4 additions & 0 deletions aten/src/ATen/core/ivalue.cpp
Expand Up @@ -47,6 +47,8 @@ TypePtr IValue::type() const {
return ListType::create(toList().elementType());
case Tag::Future:
return toFuture()->type();
case Tag::RRef:
return toRRef()->type();
case Tag::Device:
return DeviceObjType::get();
case Tag::Object:
Expand Down Expand Up @@ -221,6 +223,8 @@ std::ostream& operator<<(std::ostream & out, const IValue & v) {
return out << "Capsule";
case IValue::Tag::GenericList:
return printList(out, v.toList(), "[", "]", formatter);
case IValue::Tag::RRef:
return out << "RRef";
case IValue::Tag::Future:
return out << "Future";
case IValue::Tag::Uninitialized:
Expand Down
10 changes: 9 additions & 1 deletion aten/src/ATen/core/ivalue.h
Expand Up @@ -23,6 +23,7 @@ template<class T> class List;
struct IValue;
struct ClassType;
struct Type;
class RRefInterface;
using TypePtr = std::shared_ptr<Type>;
namespace ivalue {
struct Tuple;
Expand Down Expand Up @@ -56,7 +57,8 @@ struct PyObjectHolder;
_(Object) \
_(PyObject) \
_(Uninitialized) \
_(Capsule)
_(Capsule) \
_(RRef) \

// [doxygen private]
// These methods are not actually private but we don't want to document them, so
Expand Down Expand Up @@ -244,6 +246,12 @@ struct CAFFE2_API IValue final {
c10::intrusive_ptr<ivalue::Future> toFuture() &&;
c10::intrusive_ptr<ivalue::Future> toFuture() const &;

// RRef
IValue(c10::intrusive_ptr<c10::RRefInterface> v);
bool isRRef() const { return Tag::RRef == tag; }
c10::intrusive_ptr<c10::RRefInterface> toRRef() &&;
c10::intrusive_ptr<c10::RRefInterface> toRRef() const &;

// Int
IValue(int64_t i)
: tag(Tag::Int), is_intrusive_ptr(false) {
Expand Down
19 changes: 19 additions & 0 deletions aten/src/ATen/core/ivalue_inl.h
Expand Up @@ -10,6 +10,7 @@
#include <c10/core/UndefinedTensorImpl.h>
#include <ATen/core/Dict.h>
#include <ATen/core/List.h>
#include <ATen/core/rref_interface.h>

namespace torch {
namespace jit {
Expand Down Expand Up @@ -58,6 +59,11 @@ intrusive_ptr<T> static_intrusive_pointer_cast(intrusive_ptr<U> r) {
return intrusive_ptr<T>::reclaim(static_cast<T*>(r.release()));
}

template<class T, class U>
intrusive_ptr<T> dynamic_intrusive_pointer_cast(intrusive_ptr<U> r) {
return intrusive_ptr<T>::reclaim(dynamic_cast<T*>(r.release()));
}

inline c10::intrusive_ptr<ivalue::Future> IValue::toFuture() && {
AT_ASSERT(isFuture(), "Expected Future but got ", tagKind());
return moveToIntrusivePtr<ivalue::Future>();
Expand All @@ -66,6 +72,14 @@ inline c10::intrusive_ptr<ivalue::Future> IValue::toFuture() const & {
AT_ASSERT(isFuture(), "Expected Future but got ", tagKind());
return toIntrusivePtr<ivalue::Future>();
}
inline c10::intrusive_ptr<c10::RRefInterface> IValue::toRRef() && {
AT_ASSERT(isRRef(), "Expected RRef but got ", tagKind());
return moveToIntrusivePtr<c10::RRefInterface>();
}
inline c10::intrusive_ptr<c10::RRefInterface> IValue::toRRef() const & {
AT_ASSERT(isRRef(), "Expected RRef but got ", tagKind());
return toIntrusivePtr<c10::RRefInterface>();
}
inline c10::intrusive_ptr<ivalue::ConstantString> IValue::toString() && {
AT_ASSERT(isString(), "Expected String but got ", tagKind());
return moveToIntrusivePtr<ivalue::ConstantString>();
Expand Down Expand Up @@ -472,6 +486,7 @@ DEFINE_TO(c10::impl::GenericDict, toGenericDict)
DEFINE_TO(c10::intrusive_ptr<ivalue::Tuple>, toTuple)
DEFINE_TO(std::string, toStringRef)
DEFINE_TO(c10::intrusive_ptr<ivalue::Future>, toFuture)
DEFINE_TO(c10::intrusive_ptr<c10::RRefInterface>, toRRef)
DEFINE_TO(IValue, toIValue)
DEFINE_TO(c10::Device, toDevice)
DEFINE_TO(at::ScalarType, toScalarType)
Expand Down Expand Up @@ -770,6 +785,10 @@ inline IValue::IValue(c10::intrusive_ptr<ivalue::Future> v)
payload.as_intrusive_ptr = v.release();
}

inline IValue::IValue(c10::intrusive_ptr<c10::RRefInterface> v)
: tag(Tag::RRef), is_intrusive_ptr(true) {
payload.as_intrusive_ptr = v.release();
}
inline const std::string& IValue::toStringRef() const {
return toString()->string();
}
Expand Down
33 changes: 33 additions & 0 deletions aten/src/ATen/core/jit_type.h
Expand Up @@ -38,6 +38,7 @@ using OptNameList = c10::optional<std::vector<std::string>>;
_(NumberType) \
_(FloatType) \
_(FutureType) \
_(RRefType) \
_(IntType) \
_(NoneType) \
_(StringType) \
Expand Down Expand Up @@ -790,6 +791,38 @@ struct CAFFE2_API FutureType
FutureType(TypePtr elem) : SingleElementType(elem) {}
};

struct RRefType;
using RRefTypePtr = std::shared_ptr<RRefType>;

struct CAFFE2_API RRefType
: public SingleElementType<TypeKind::RRefType, RRefType> {
friend struct Type;
template <typename... T>
static RRefTypePtr create(TypePtr elem) {
return RRefTypePtr(
new RRefType(std::move(elem))); // NOLINT(modernize-make-shared)
}

std::string str() const override {
std::stringstream ss;
ss << "RRef(" << getElementType()->str() << ")";
return ss.str();
}
std::string python_str() const override {
std::stringstream ss;
ss << "RRef[" << getElementType()->python_str() << "]";
return ss.str();
}
TypePtr createWithContained(
std::vector<TypePtr> contained_types) const override {
return create(contained_types.at(0));
}

private:
RRefType(TypePtr elem) : SingleElementType(elem) {}
};


using ::torch::jit::Function;
struct NamedType;
using NamedTypePtr = std::shared_ptr<NamedType>;
Expand Down
16 changes: 16 additions & 0 deletions aten/src/ATen/core/type.cpp
Expand Up @@ -44,6 +44,9 @@ std::ostream& operator<<(std::ostream & out, const Type & t) {
} else if(t.kind() == TypeKind::FutureType) {
auto elem = t.cast<FutureType>()->getElementType();
out << "Future[" << *elem << "]";
} else if(t.kind() == TypeKind::RRefType) {
auto elem = t.cast<RRefType>()->getElementType();
out << "RRef[" << *elem << "]";
} else if(auto tup = t.cast<TupleType>()) {
if (tup->schema()) {
out << "NamedTuple";
Expand Down Expand Up @@ -377,6 +380,19 @@ MatchTypeReturn matchTypeVariables(
ss << "Cannot match a future to " << actual->python_str();
return ss.str();
}
} else if (auto lt_formal = formal->cast<RRefType>()) {
if (auto lt_actual = actual->cast<RRefType>()) {
const auto innerMatch = matchTypeVariables(
lt_formal->getElementType(), lt_actual->getElementType(), type_env);
if (!innerMatch.success()) {
return innerMatch;
}
return MatchTypeReturn::Success();
} else {
std::stringstream ss;
ss << "Cannot match a rref to " << actual->python_str();
return ss.str();
}
} else if (auto opt_formal = formal->cast<OptionalType>()) {
if (auto opt_actual = actual->cast<OptionalType>()) {
const auto optionedMatch = matchTypeVariables(
Expand Down
4 changes: 4 additions & 0 deletions torch/csrc/jit/pybind_utils.h
Expand Up @@ -560,8 +560,12 @@ inline IValue toIValue(
return py::cast<int64_t>(obj);
} else if (py::isinstance<py::float_>(obj)) {
return py::cast<double>(obj);
} else {
throw py::cast_error(
c10::str("Cannot cast ", py::str(obj), " to ", type->python_str()));
}
}
case TypeKind::RRefType:
case TypeKind::GeneratorType:
case TypeKind::VarType:
case TypeKind::FutureType:
Expand Down
1 change: 1 addition & 0 deletions torch/csrc/jit/unpickler.cpp
Expand Up @@ -65,6 +65,7 @@ void restoreAccurateTypeTags(const IValue& root, const TypePtr& type_tag) {
case QSchemeType::Kind:
case LayoutType::Kind:
case ScalarTypeType::Kind:
case RRefType::Kind:
// no op, there is nothing to tag
break;
case AnyType::Kind:
Expand Down

0 comments on commit b2c5896

Please sign in to comment.