Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FFI] Add SharedLibrary class to track the usage of dynamic library #63

Merged
merged 1 commit into from
Jan 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 3 additions & 3 deletions python/hidet/backend/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import functools
import warnings
import os
import ctypes
import shutil
import tempfile
import subprocess
Expand All @@ -25,6 +24,7 @@
from hidet.runtime import CompiledFunction
from hidet.ffi import PackedFunc
from hidet.ffi.ffi import library_paths
from hidet.ffi.shared_lib import SharedLibrary


class CompilationFailed(Exception):
Expand Down Expand Up @@ -181,7 +181,7 @@ def load_task_func(lib_path: str, task) -> CompiledFunction:
The loaded function that can be directly called in python.
"""
try:
lib = ctypes.CDLL(lib_path)
lib = SharedLibrary(lib_path)
except OSError as e:
print("Removed the file '{}'".format(lib_path))
os.remove(lib_path)
Expand All @@ -201,7 +201,7 @@ def load_task_func(lib_path: str, task) -> CompiledFunction:

def load_lib_func(lib_path: str, func_name: str, func_type: FuncType) -> CompiledFunction:
try:
lib = ctypes.CDLL(lib_path)
lib = SharedLibrary(lib_path)
except OSError as e:
print("Removed the file '{}'".format(lib_path))
os.remove(lib_path)
Expand Down
93 changes: 93 additions & 0 deletions python/hidet/ffi/shared_lib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ctypes
from typing import Dict

# retrieve the dlclose function
_dlclose = ctypes.CDLL(None).dlclose # None indicates the main program
_dlclose.argtypes = [ctypes.c_void_p]
_dlclose.rettype = ctypes.c_int


class SharedLibrary:
"""
Manage the loaded dynamic libraries.

Why we need this module?
------------------------

The ctypes.CDLL class does not provide a way to unload the loaded library. When a library is loaded, it will never
be unloaded until the program exits. However, when we tune an operator, we need to generate hundreds of kernels, and
each kernel will be compiled into a shared library. If we do not unload the shared library, we would load tens of
thousands of shared libraries, which will trigger memory error like:
"cannot apply additional memory protection after relocation"
(I also see other error messages).
To solve this problem, we need to unload the shared library after we use it. Thus, whenever we need to load a shared
library, we should use this module instead of the ctypes.CDLL class. The SharedLibrary class will keep track of the
loaded libraries, and unload them when no one references them.

The current implementation only supports *nix systems. Will add support for Windows when we plan to support Windows
in the project-level.

Usage
-----

>>> lib = SharedLibrary('./libhidet.so')
>>> func = lib['func_name']
>>> del func
>>> del lib
>>> # the 'libhidet.so' will be unloaded via dlclose after the last reference to it is deleted.
"""

loaded_cdll_libraries: Dict[str, ctypes.CDLL] = {}
reference_count: Dict[str, int] = {}

def __init__(self, lib_path: str):
self.lib_path: str = lib_path
cls = SharedLibrary
if lib_path in cls.loaded_cdll_libraries:
self.cdll: ctypes.CDLL = cls.loaded_cdll_libraries[lib_path]
cls.reference_count[lib_path] += 1
else:
cdll = ctypes.CDLL(lib_path)
self.cdll: ctypes.CDLL = cdll
cls.loaded_cdll_libraries[lib_path] = cdll
cls.reference_count[lib_path] = 1

def __getitem__(self, item):
"""
Get the function from the loaded library.

Parameters
----------
item: str
The name of the function.

Returns
-------
func: ctypes.CFUNCTYPE
The loaded function.
"""
ret = self.cdll[item]
ret._lib = self
return ret

def __getattr__(self, item):
return self[item]

def __del__(self):
self.reference_count[self.lib_path] -= 1
if self.reference_count[self.lib_path] == 0:
del self.loaded_cdll_libraries[self.lib_path]
del self.reference_count[self.lib_path]
_dlclose(self.cdll._handle)
del self.cdll