Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ jobs:
with:
python-version: ${{ env.MLC_PYTHON_VERSION }}
- uses: pre-commit/action@v3.0.1
- uses: ytanikin/pr-conventional-commits@1.4.0
with:
task_types: '["feat", "fix", "ci", "chore", "test"]'
add_label: 'false'
windows:
name: Windows
runs-on: windows-latest
Expand Down
3 changes: 3 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
default_install_hook_types:
- pre-commit
- commit-msg
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
Expand Down
71 changes: 38 additions & 33 deletions cpp/json.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,26 +31,30 @@ inline mlc::Str Serialize(Any any) {
using TObj2Idx = std::unordered_map<Object *, int32_t>;
using TJsonTypeIndex = decltype(get_json_type_index);
struct Emitter {
MLC_INLINE void operator()(MLCTypeField *, const Any *any) { EmitAny(any); }
// clang-format off
MLC_INLINE void operator()(MLCTypeField *, const Any *any) { EmitAny(any); }
MLC_INLINE void operator()(MLCTypeField *, ObjectRef *obj) { if (Object *v = obj->get()) EmitObject(v); else EmitNil(); }
MLC_INLINE void operator()(MLCTypeField *, Optional<ObjectRef> *opt) { if (Object *v = opt->get()) EmitObject(v); else EmitNil(); }
MLC_INLINE void operator()(MLCTypeField *, Optional<int64_t> *opt) { if (const int64_t *v = opt->get()) EmitInt(*v); else EmitNil(); }
MLC_INLINE void operator()(MLCTypeField *, Optional<double> *opt) { if (const double *v = opt->get()) EmitFloat(*v); else EmitNil(); }
MLC_INLINE void operator()(MLCTypeField *, Optional<DLDevice> *opt) { if (const DLDevice *v = opt->get()) EmitDevice(*v); else EmitNil(); }
MLC_INLINE void operator()(MLCTypeField *, Optional<DLDataType> *opt) { if (const DLDataType *v = opt->get()) EmitDType(*v); else EmitNil(); }
MLC_INLINE void operator()(MLCTypeField *, int8_t *v) { EmitInt(static_cast<int64_t>(*v)); }
MLC_INLINE void operator()(MLCTypeField *, int16_t *v) { EmitInt(static_cast<int64_t>(*v)); }
MLC_INLINE void operator()(MLCTypeField *, int32_t *v) { EmitInt(static_cast<int64_t>(*v)); }
MLC_INLINE void operator()(MLCTypeField *, int64_t *v) { EmitInt(static_cast<int64_t>(*v)); }
MLC_INLINE void operator()(MLCTypeField *, float *v) { EmitFloat(static_cast<double>(*v)); }
MLC_INLINE void operator()(MLCTypeField *, double *v) { EmitFloat(static_cast<double>(*v)); }
MLC_INLINE void operator()(MLCTypeField *, DLDataType *v) { EmitDType(*v); }
MLC_INLINE void operator()(MLCTypeField *, DLDevice *v) { EmitDevice(*v); }
MLC_INLINE void operator()(MLCTypeField *, Optional<void *> *) { MLC_THROW(TypeError) << "Unserializable type: void *"; }
MLC_INLINE void operator()(MLCTypeField *, void **) { MLC_THROW(TypeError) << "Unserializable type: void *"; }
MLC_INLINE void operator()(MLCTypeField *, const char **) { MLC_THROW(TypeError) << "Unserializable type: const char *"; }
MLC_INLINE void operator()(MLCTypeField *, ObjectRef *obj) { if (Object *v = obj->get()) EmitObject(v); else EmitNil(); }
MLC_INLINE void operator()(MLCTypeField *, Optional<ObjectRef> *opt) { if (Object *v = opt->get()) EmitObject(v); else EmitNil(); }
MLC_INLINE void operator()(MLCTypeField *, Optional<int64_t> *opt) { if (const int64_t *v = opt->get()) EmitInt(*v); else EmitNil(); }
MLC_INLINE void operator()(MLCTypeField *, Optional<double> *opt) { if (const double *v = opt->get()) EmitFloat(*v); else EmitNil(); }
MLC_INLINE void operator()(MLCTypeField *, Optional<DLDevice> *opt) { if (const DLDevice *v = opt->get()) EmitDevice(*v); else EmitNil(); }
MLC_INLINE void operator()(MLCTypeField *, Optional<DLDataType> *opt) { if (const DLDataType *v = opt->get()) EmitDType(*v); else EmitNil(); }
// clang-format on
MLC_INLINE void operator()(MLCTypeField *, int8_t *v) { EmitInt(static_cast<int64_t>(*v)); }
MLC_INLINE void operator()(MLCTypeField *, int16_t *v) { EmitInt(static_cast<int64_t>(*v)); }
MLC_INLINE void operator()(MLCTypeField *, int32_t *v) { EmitInt(static_cast<int64_t>(*v)); }
MLC_INLINE void operator()(MLCTypeField *, int64_t *v) { EmitInt(static_cast<int64_t>(*v)); }
MLC_INLINE void operator()(MLCTypeField *, float *v) { EmitFloat(static_cast<double>(*v)); }
MLC_INLINE void operator()(MLCTypeField *, double *v) { EmitFloat(static_cast<double>(*v)); }
MLC_INLINE void operator()(MLCTypeField *, DLDataType *v) { EmitDType(*v); }
MLC_INLINE void operator()(MLCTypeField *, DLDevice *v) { EmitDevice(*v); }
MLC_INLINE void operator()(MLCTypeField *, Optional<void *> *) {
MLC_THROW(TypeError) << "Unserializable type: void *";
}
MLC_INLINE void operator()(MLCTypeField *, void **) { MLC_THROW(TypeError) << "Unserializable type: void *"; }
MLC_INLINE void operator()(MLCTypeField *, const char **) {
MLC_THROW(TypeError) << "Unserializable type: const char *";
}
inline void EmitNil() { (*os) << ", null"; }
inline void EmitFloat(double v) { (*os) << ", " << std::fixed << std::setprecision(19) << v; }
inline void EmitInt(int64_t v) {
Expand Down Expand Up @@ -98,10 +102,17 @@ inline mlc::Str Serialize(Any any) {
const TObj2Idx *obj2index;
};

std::unordered_map<Object *, int32_t> topo_indices;
std::ostringstream os;
auto on_visit = [get_json_type_index = &get_json_type_index, os = &os, is_first_object = true](
Object *object, MLCTypeInfo *type_info, const TObj2Idx &obj2index) mutable -> void {
Emitter emitter{os, get_json_type_index, &obj2index};
auto on_visit = [&topo_indices, get_json_type_index = &get_json_type_index, os = &os,
is_first_object = true](Object *object, MLCTypeInfo *type_info) mutable -> void {
int32_t &topo_index = topo_indices[object];
if (topo_index == 0) {
topo_index = static_cast<int32_t>(topo_indices.size()) - 1;
} else {
MLC_THROW(InternalError) << "This should never happen: object already visited";
}
Emitter emitter{os, get_json_type_index, &topo_indices};
if (is_first_object) {
is_first_object = false;
} else {
Expand Down Expand Up @@ -163,29 +174,23 @@ inline mlc::Str Serialize(Any any) {
}

inline Any Deserialize(const char *json_str, int64_t json_str_len) {
MLCVTableHandle init_vtable;
MLCVTableGetGlobal(nullptr, "__init__", &init_vtable);
MLCVTableHandle init_table = ::mlc::base::LibState::init;
// Step 0. Parse JSON string
UDict json_obj = JSONLoads(json_str, json_str_len);
// Step 1. type_key => constructors
UList type_keys = json_obj->at("type_keys");
std::vector<Func> constructors;
std::vector<FuncObj *> constructors;
constructors.reserve(type_keys.size());
for (Str type_key : type_keys) {
Any init_func;
int32_t type_index = ::mlc::base::TypeKey2TypeIndex(type_key->data());
MLCVTableGetFunc(init_vtable, type_index, false, &init_func);
if (!::mlc::base::IsTypeIndexNone(init_func.type_index)) {
constructors.push_back(init_func.operator Func());
} else {
MLC_THROW(InternalError) << "Method `__init__` is not defined for type " << type_key;
}
FuncObj *func = ::mlc::base::LibState::VTableGetFunc(init_table, type_index, "__init__");
constructors.push_back(func);
}
auto invoke_init = [&constructors](UList args) {
int32_t json_type_index = args[0];
Any ret;
::mlc::base::FuncCall(constructors.at(json_type_index).get(), static_cast<int32_t>(args.size()) - 1,
args->data() + 1, &ret);
::mlc::base::FuncCall(constructors.at(json_type_index), static_cast<int32_t>(args.size()) - 1, args->data() + 1,
&ret);
return ret;
};
// Step 2. Translate JSON object to objects
Expand Down
131 changes: 131 additions & 0 deletions cpp/structure.cc
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include "mlc/core/error.h"
#include <algorithm>
#include <cmath>
#include <cstdint>
Expand Down Expand Up @@ -532,11 +533,141 @@ inline uint64_t StructuralHash(Object *obj) {
#undef MLC_CORE_HASH_S_POD
#undef MLC_CORE_HASH_S_ANY

inline Any CopyShallow(AnyView source) {
int32_t type_index = source.type_index;
if (::mlc::base::IsTypeIndexPOD(type_index)) {
return source;
} else if (UListObj *list = source.TryCast<UListObj>()) {
return UList(list->begin(), list->end());
} else if (UDictObj *dict = source.TryCast<UDictObj>()) {
return UDict(dict->begin(), dict->end());
} else if (source.IsInstance<StrObj>() || source.IsInstance<ErrorObj>() || source.IsInstance<FuncObj>()) {
return source;
}
struct Copier {
MLC_INLINE void operator()(MLCTypeField *, const Any *any) { fields->push_back(AnyView(*any)); }
MLC_INLINE void operator()(MLCTypeField *, ObjectRef *obj) { fields->push_back(AnyView(*obj)); }
MLC_INLINE void operator()(MLCTypeField *, Optional<ObjectRef> *opt) { fields->push_back(AnyView(*opt)); }
MLC_INLINE void operator()(MLCTypeField *, Optional<int64_t> *opt) { fields->push_back(AnyView(*opt)); }
MLC_INLINE void operator()(MLCTypeField *, Optional<double> *opt) { fields->push_back(AnyView(*opt)); }
MLC_INLINE void operator()(MLCTypeField *, Optional<DLDevice> *opt) { fields->push_back(AnyView(*opt)); }
MLC_INLINE void operator()(MLCTypeField *, Optional<DLDataType> *opt) { fields->push_back(AnyView(*opt)); }
MLC_INLINE void operator()(MLCTypeField *, int8_t *v) { fields->push_back(AnyView(*v)); }
MLC_INLINE void operator()(MLCTypeField *, int16_t *v) { fields->push_back(AnyView(*v)); }
MLC_INLINE void operator()(MLCTypeField *, int32_t *v) { fields->push_back(AnyView(*v)); }
MLC_INLINE void operator()(MLCTypeField *, int64_t *v) { fields->push_back(AnyView(*v)); }
MLC_INLINE void operator()(MLCTypeField *, float *v) { fields->push_back(AnyView(*v)); }
MLC_INLINE void operator()(MLCTypeField *, double *v) { fields->push_back(AnyView(*v)); }
MLC_INLINE void operator()(MLCTypeField *, DLDataType *v) { fields->push_back(AnyView(*v)); }
MLC_INLINE void operator()(MLCTypeField *, DLDevice *v) { fields->push_back(AnyView(*v)); }
MLC_INLINE void operator()(MLCTypeField *, Optional<void *> *v) { fields->push_back(AnyView(*v)); }
MLC_INLINE void operator()(MLCTypeField *, void **v) { fields->push_back(AnyView(*v)); }
MLC_INLINE void operator()(MLCTypeField *, const char **v) { fields->push_back(AnyView(*v)); }
std::vector<AnyView> *fields;
};
FuncObj *init_func = ::mlc::base::LibState::VTableGetFunc(::mlc::base::LibState::init, type_index, "__init__");
MLCTypeInfo *type_info = ::mlc::base::TypeIndex2TypeInfo(type_index);
std::vector<AnyView> fields;
VisitFields(source.operator Object *(), type_info, Copier{&fields});
Any ret;
::mlc::base::FuncCall(init_func, static_cast<int32_t>(fields.size()), fields.data(), &ret);
return ret;
}

inline Any CopyDeep(AnyView source) {
if (::mlc::base::IsTypeIndexPOD(source.type_index)) {
return source;
}
struct Copier {
MLC_INLINE void operator()(MLCTypeField *, const Any *any) { HandleAny(any); }
MLC_INLINE void operator()(MLCTypeField *, ObjectRef *ref) {
if (const Object *obj = ref->get()) {
HandleObject(obj);
} else {
fields->push_back(AnyView());
}
}
MLC_INLINE void operator()(MLCTypeField *, Optional<ObjectRef> *opt) {
if (const Object *obj = opt->get()) {
HandleObject(obj);
} else {
fields->push_back(AnyView());
}
}
MLC_INLINE void operator()(MLCTypeField *, Optional<int64_t> *opt) { fields->push_back(AnyView(*opt)); }
MLC_INLINE void operator()(MLCTypeField *, Optional<double> *opt) { fields->push_back(AnyView(*opt)); }
MLC_INLINE void operator()(MLCTypeField *, Optional<DLDevice> *opt) { fields->push_back(AnyView(*opt)); }
MLC_INLINE void operator()(MLCTypeField *, Optional<DLDataType> *opt) { fields->push_back(AnyView(*opt)); }
MLC_INLINE void operator()(MLCTypeField *, int8_t *v) { fields->push_back(AnyView(*v)); }
MLC_INLINE void operator()(MLCTypeField *, int16_t *v) { fields->push_back(AnyView(*v)); }
MLC_INLINE void operator()(MLCTypeField *, int32_t *v) { fields->push_back(AnyView(*v)); }
MLC_INLINE void operator()(MLCTypeField *, int64_t *v) { fields->push_back(AnyView(*v)); }
MLC_INLINE void operator()(MLCTypeField *, float *v) { fields->push_back(AnyView(*v)); }
MLC_INLINE void operator()(MLCTypeField *, double *v) { fields->push_back(AnyView(*v)); }
MLC_INLINE void operator()(MLCTypeField *, DLDataType *v) { fields->push_back(AnyView(*v)); }
MLC_INLINE void operator()(MLCTypeField *, DLDevice *v) { fields->push_back(AnyView(*v)); }
MLC_INLINE void operator()(MLCTypeField *, Optional<void *> *v) { fields->push_back(AnyView(*v)); }
MLC_INLINE void operator()(MLCTypeField *, void **v) { fields->push_back(AnyView(*v)); }
MLC_INLINE void operator()(MLCTypeField *, const char **v) { fields->push_back(AnyView(*v)); }

void HandleObject(const Object *obj) {
if (auto it = orig2copy->find(obj); it != orig2copy->end()) {
fields->push_back(AnyView(it->second));
} else {
MLC_THROW(InternalError) << "InternalError: object doesn't exist in the memo: " << AnyView(obj);
}
}

void HandleAny(const Any *any) {
if (const Object *obj = any->TryCast<Object>()) {
HandleObject(obj);
} else {
fields->push_back(AnyView(*any));
}
}

std::unordered_map<const Object *, ObjectRef> *orig2copy;
std::vector<AnyView> *fields;
};
std::unordered_map<const Object *, ObjectRef> orig2copy;
std::vector<AnyView> fields;
TopoVisit(source.operator Object *(), nullptr, [&](Object *object, MLCTypeInfo *type_info) mutable -> void {
Any ret;
if (UListObj *list = object->TryCast<UListObj>()) {
fields.clear();
fields.reserve(list->size());
for (Any &e : *list) {
Copier{&orig2copy, &fields}.HandleAny(&e);
}
UList::FromAnyTuple(static_cast<int32_t>(fields.size()), fields.data(), &ret);
} else if (UDictObj *dict = object->TryCast<UDictObj>()) {
fields.clear();
for (auto [key, value] : *dict) {
Copier{&orig2copy, &fields}.HandleAny(&key);
Copier{&orig2copy, &fields}.HandleAny(&value);
}
UDict::FromAnyTuple(static_cast<int32_t>(fields.size()), fields.data(), &ret);
} else if (object->IsInstance<StrObj>() || object->IsInstance<ErrorObj>() || object->IsInstance<FuncObj>()) {
ret = object;
} else {
fields.clear();
VisitFields(object, type_info, Copier{&orig2copy, &fields});
FuncObj *func =
::mlc::base::LibState::VTableGetFunc(::mlc::base::LibState::init, type_info->type_index, "__init__");
::mlc::base::FuncCall(func, static_cast<int32_t>(fields.size()), fields.data(), &ret);
}
orig2copy[object] = ret.operator ObjectRef();
});
return orig2copy.at(source.operator Object *());
}

MLC_REGISTER_FUNC("mlc.core.StructuralEqual").set_body(::mlc::core::StructuralEqual);
MLC_REGISTER_FUNC("mlc.core.StructuralHash").set_body([](::mlc::Object *obj) -> int64_t {
uint64_t ret = ::mlc::core::StructuralHash(obj);
return static_cast<int64_t>(ret);
});
MLC_REGISTER_FUNC("mlc.core.CopyShallow").set_body(::mlc::core::CopyShallow);
MLC_REGISTER_FUNC("mlc.core.CopyDeep").set_body(::mlc::core::CopyDeep);
} // namespace
} // namespace core
} // namespace mlc
2 changes: 1 addition & 1 deletion include/mlc/base/all.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ template <typename T> MLC_INLINE AnyView::AnyView(Ref<T> &&src) : AnyView(static
// `src` is not reset here because `AnyView` does not take ownership of the object
}

template <typename T> MLC_INLINE AnyView::AnyView(const Optional<T> &src) {
template <typename T> MLC_INLINE AnyView::AnyView(const Optional<T> &src) : MLCAny() {
if (const auto *value = src.get()) {
if constexpr (::mlc::base::IsPOD<T>) {
using TPOD = T;
Expand Down
7 changes: 6 additions & 1 deletion include/mlc/base/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,11 @@ struct LibState {
DecRef(func.v.v_obj);
}
FuncObj *ret = reinterpret_cast<FuncObj *>(func.v.v_obj);
if (func.type_index != kMLCFunc) {
if (func.type_index == kMLCNone) {
MLC_THROW(TypeError) << "Function `" << vtable_name
<< "` for type: " << ::mlc::base::TypeIndex2TypeKey(type_index)
<< " is not defined in the vtable";
} else if (func.type_index != kMLCFunc) {
MLC_THROW(TypeError) << "Function `" << vtable_name
<< "` for type: " << ::mlc::base::TypeIndex2TypeKey(type_index)
<< " is not callable. Its type is " << ::mlc::base::TypeIndex2TypeKey(func.type_index);
Expand All @@ -401,6 +405,7 @@ struct LibState {
static MLC_SYMBOL_HIDE inline MLCVTableHandle cxx_str = VTableGetGlobal("__cxx_str__");
static MLC_SYMBOL_HIDE inline MLCVTableHandle str = VTableGetGlobal("__str__");
static MLC_SYMBOL_HIDE inline MLCVTableHandle ir_print = VTableGetGlobal("__ir_print__");
static MLC_SYMBOL_HIDE inline MLCVTableHandle init = VTableGetGlobal("__init__");
};

} // namespace base
Expand Down
5 changes: 4 additions & 1 deletion include/mlc/core/dict.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,11 +138,14 @@ struct UDict : public ObjectRef {
MLC_INLINE const_iterator end() const { return get()->end(); }
MLC_INLINE const_reverse_iterator rbegin() const { return get()->rbegin(); }
MLC_INLINE const_reverse_iterator rend() const { return get()->rend(); }
MLC_INLINE static void FromAnyTuple(int32_t num_args, const AnyView *args, Any *ret) {
::mlc::core::DictBase::Accessor<UDictObj>::New(num_args, args, ret);
}
MLC_DEF_OBJ_REF(UDict, UDictObj, ObjectRef)
.FieldReadOnly("capacity", &MLCDict::capacity)
.FieldReadOnly("size", &MLCDict::size)
.FieldReadOnly("data", &MLCDict::data)
.StaticFn("__init__", ::mlc::core::DictBase::Accessor<UDictObj>::New)
.StaticFn("__init__", FromAnyTuple)
.MemFn("__str__", &UDictObj::__str__)
.MemFn("__getitem__", ::mlc::core::DictBase::Accessor<UDictObj>::GetItem)
.MemFn("__iter_get_key__", ::mlc::core::DictBase::Accessor<UDictObj>::GetKey)
Expand Down
17 changes: 4 additions & 13 deletions include/mlc/core/field_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,7 @@ template <typename Visitor> inline void VisitStructure(Object *root, MLCTypeInfo
}

inline void TopoVisit(Object *root, std::function<void(Object *object, MLCTypeInfo *type_info)> pre_visit,
std::function<void(Object *object, MLCTypeInfo *type_info,
const std::unordered_map<Object *, int32_t> &topo_indices)>
on_visit) {
std::function<void(Object *object, MLCTypeInfo *type_info)> on_visit) {
struct TopoInfo {
Object *obj;
MLCTypeInfo *type_info;
Expand Down Expand Up @@ -271,20 +269,13 @@ inline void TopoVisit(Object *root, std::function<void(Object *object, MLCTypeIn
}
}
// Step 3. Traverse the graph by topological order
std::unordered_map<Object *, int32_t> topo_indices;
size_t num_objects = 0;
for (; !stack.empty(); ++num_objects) {
TopoInfo *current = stack.back();
stack.pop_back();
// Step 3.1. Lable object index
int32_t &topo_index = topo_indices[current->obj];
if (topo_index != 0) {
MLC_THROW(InternalError) << "This should never happen: object already visited";
}
topo_index = static_cast<int32_t>(num_objects);
// Step 3.2. Visit object
on_visit(current->obj, current->type_info, topo_indices);
// Step 3.3. Decrease the dependency count of topo_parents
// Step 3.1. Visit object
on_visit(current->obj, current->type_info);
// Step 3.2. Decrease the dependency count of topo_parents
for (TopoInfo *parent : current->topo_parents) {
if (--parent->topo_deps == 0) {
stack.push_back(parent);
Expand Down
Loading
Loading