Skip to content

Commit

Permalink
[PASS] Remove Unused Functions in IRModule (apache#151)
Browse files Browse the repository at this point in the history
  • Loading branch information
sunggg authored and junrushao committed Feb 5, 2023
1 parent ee59912 commit 77cd4ab
Show file tree
Hide file tree
Showing 6 changed files with 525 additions and 46 deletions.
134 changes: 88 additions & 46 deletions include/tvm/ir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,68 @@ enum class CallingConv : int {
kDeviceKernelLaunch = 2,
};

/*!
* \brief Supported linkage types.
*/
enum class LinkageType : int {
/*!
* \brief Internal linkage.
*/
kInternal = 0,
/*!
* \brief External linkage.
- Function with external linkage should have a global symbol attached to it.
*/
kExternal = 1
};

/*!
* \brief Generic attribute names that can be attached to any function.
*
* \sa tvm::tir::attr, tvm::relay::attr
*/
namespace attr {
/*!
* \brief Indicates the special calling convention.
*
* Type: Integer
*
* \sa tvm::CallingConv
*/
constexpr const char* kCallingConv = "calling_conv";

/*!
* \brief Compilation target of the function.
*
* Type: Target
*
* \sa tvm::Target
*/
constexpr const char* kTarget = "target";

/*!
* \brief Global linker symbol of the function in generated code.
*
* This option forces the code generator to name the
* function with the given.
*
* For example, we could set a global_symbol of a function
* early to make sure that we can always refer to it by
* the symbol name in the generated DLL.
*
* We should not set the attribute for local functions,
* so that the compiler can freely rename them.
*
* A unique global symbol will be automatically assigned
* to each function in the module before the target code
* generation phase.
*
* Type: String
*/
constexpr const char* kGlobalSymbol = "global_symbol";

} // namespace attr

/*!
* \brief Base node of all functions.
*
Expand Down Expand Up @@ -131,6 +193,32 @@ class BaseFuncNode : public RelayExprNode {
*/
bool HasNonzeroAttr(const std::string& attr_key) const { return attrs.HasNonzeroAttr(attr_key); }

/*!
* \brief Get the type of the linkage.
*
* Currently, we only consider external/internal linkage.
* This can be extended in the future when necessary.
*
* \return Linkage type.
*
* \code
*
* void Example(const BaseFunc& f) {
* if (f->GetLinkageType() == tvm::LinkageType::kExternal) {
* // Do not remove a function with external linkage
* }
* }
*
* \endcode
*/

LinkageType GetLinkageType() const {
if (GetAttr<String>(attr::kGlobalSymbol))
return LinkageType::kExternal;
else
return LinkageType::kInternal;
}

static constexpr const char* _type_key = "BaseFunc";
static constexpr const uint32_t _type_child_slots = 2;
TVM_DECLARE_BASE_OBJECT_INFO(BaseFuncNode, RelayExprNode);
Expand All @@ -145,51 +233,5 @@ class BaseFunc : public RelayExpr {
TVM_DEFINE_OBJECT_REF_METHODS(BaseFunc, RelayExpr, BaseFuncNode);
};

/*!
* \brief Generic attribute names that can be attached to any function.
*
* \sa tvm::tir::attr, tvm::relay::attr
*/
namespace attr {
/*!
* \brief Indicates the special calling convention.
*
* Type: Integer
*
* \sa tvm::CallingConv
*/
constexpr const char* kCallingConv = "calling_conv";

/*!
* \brief Compilation target of the function.
*
* Type: Target
*
* \sa tvm::Target
*/
constexpr const char* kTarget = "target";

/*!
* \brief Global linker symbol of the function in generated code.
*
* This option forces the code generator to name the
* function with the given.
*
* For example, we could set a global_symbol of a function
* early to make sure that we can always refer to it by
* the symbol name in the generated DLL.
*
* We should not set the attribute for local functions,
* so that the compiler can freely rename them.
*
* A unique global symbol will be automatically assigned
* to each function in the module before the target code
* generation phase.
*
* Type: String
*/
constexpr const char* kGlobalSymbol = "global_symbol";

} // namespace attr
} // namespace tvm
#endif // TVM_IR_FUNCTION_H_
16 changes: 16 additions & 0 deletions python/tvm/ir/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,19 @@ def with_attr(self, attr_key_or_dict, attr_value=None):
return _ffi_api.BaseFuncWithAttr(
res._move(), attr_key_or_dict, tvm.runtime.convert(attr_value)
)

def without_attr(self, attr_key: str):
"""Create a new copy of the function with an attribute without provided key.
Parameters
----------
attr_key : str
The attribute key to delete from the attrubte pairs.
Returns
-------
func : Function
A new copy of the function
"""
return _ffi_api.BaseFuncWithoutAttr(self, attr_key)
18 changes: 18 additions & 0 deletions python/tvm/relax/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,24 @@ def BindParams(func_name: str, params: Dict[str, tvm.runtime.NDArray]) -> tvm.ir
return _ffi_api.BindParams(func_name, params)


def RemoveUnusedFunctions(entry_functions=None) -> tvm.ir.transform.Pass:
"""Remove unused relax/prim functions without external linkage in a IRModule.
Parameters
----------
entry_functions: list[string]
The set of entry functions to start from.
Returns
-------
ret : tvm.transform.Pass
The registered pass to remove unused functions.
"""
if entry_functions is None:
entry_functions = ["main"]
return _ffi_api.RemoveUnusedFunctions(entry_functions)


def FoldConstant() -> tvm.ir.transform.Pass:
"""Fold constant expressions.
Expand Down
14 changes: 14 additions & 0 deletions src/ir/function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,18 @@ TVM_REGISTER_GLOBAL("ir.BaseFuncWithAttr")
LOG(FATAL) << "Do not support function type " << func->GetTypeKey();
});

TVM_REGISTER_GLOBAL("ir.BaseFuncWithoutAttr")
.set_body_typed([](BaseFunc func, String key) -> BaseFunc {
if (func->IsInstance<tir::PrimFuncNode>()) {
return WithoutAttr(Downcast<tir::PrimFunc>(std::move(func)), key);
} else if (func->IsInstance<relay::FunctionNode>()) {
return WithoutAttr(Downcast<relay::Function>(std::move(func)), key);
} else if (func->IsInstance<relax::FunctionNode>()) {
return WithoutAttr(Downcast<relax::Function>(std::move(func)), key);
} else {
LOG(FATAL) << "Do not support function type " << func->GetTypeKey();
return func;
}
});

} // namespace tvm
118 changes: 118 additions & 0 deletions src/relax/transform/removed_unused_funcs.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
*
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/relax/backend/remove_unused_funcs.cc
* \brief Remove unused global relax functions in a IRModule.
*/

#include <tvm/relax/expr_functor.h>

#include <iostream>
#include <unordered_set>
#include <vector>

namespace tvm {
namespace relax {

/**
* \brief Detects all the functions that can be possibly called by entry function.
*/
class CallTracer : ExprVisitor {
public:
explicit CallTracer(IRModule mod_) : mod_{mod_}, called_funcs_{}, visiting_{} {}

void VisitExpr_(const GlobalVarNode* op) final {
called_funcs_.insert(GetRef<GlobalVar>(op));
auto func = mod_->Lookup(op->name_hint);
if (const auto* function_node = func.as<FunctionNode>()) {
VisitExpr(GetRef<Function>(function_node));
}
// else: Don't visit PrimFuncs -- we don't need to collect any tir.Calls therein.
}

void VisitExpr_(const CallNode* call_node) final { ExprVisitor::VisitExpr_(call_node); }

void VisitExpr_(const FunctionNode* func_node) final {
auto func = GetRef<Function>(func_node);
if (visiting_.find(func) == visiting_.end()) {
visiting_.insert(func);
for (auto param : func_node->params) {
ExprVisitor::VisitExpr(param);
}
ExprVisitor::VisitExpr(func_node->body);
}
}

void Trace(std::string entry) {
called_funcs_.insert(mod_->GetGlobalVar(entry));
auto main_func = mod_->Lookup(entry);
VisitExpr(main_func);
}

bool check_if_called(GlobalVar gv) { return called_funcs_.count(gv) > 0; }

private:
IRModule mod_;

// Record the names of all encountered functions.
std::unordered_set<GlobalVar, ObjectPtrHash, ObjectPtrEqual> called_funcs_;

// Record the expressions that are being visited.
std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> visiting_;
};

/*!
* \brief Remove functions that are not used.
*
* \param mod_ IRModule.
* \param entry_funcs The set of functions that can be entry function.
*
* \return The module with dead functions removed.
*/
IRModule RemoveUnusedFunctions(IRModule mod_, Array<runtime::String> entry_funcs) {
auto tracer = CallTracer(mod_);
for (auto entry : entry_funcs) {
tracer.Trace(entry);
}
auto existing_functions = mod_->functions;
for (auto f : existing_functions) {
// If a function has an external linkage type, we do not remove it.
// Otherwise, we check the function and remove it if it is not used anywhere.
if (f.second->GetLinkageType() == LinkageType::kInternal && !tracer.check_if_called(f.first)) {
mod_->Remove(f.first);
}
}
return mod_;
}

} // namespace relax

namespace transform {
Pass RemoveUnusedFunctions(Array<runtime::String> entry_functions) {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
[=](IRModule m, PassContext pc) { return relax::RemoveUnusedFunctions(m, entry_functions); };
return CreateModulePass(pass_func, 0, "RemoveUnusedFunctions", {});
}

TVM_REGISTER_GLOBAL("relax.transform.RemoveUnusedFunctions").set_body_typed(RemoveUnusedFunctions);

} // namespace transform
} // namespace tvm
Loading

0 comments on commit 77cd4ab

Please sign in to comment.