Commit e78660f
clean
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
WeichenXu123 committed Dec 27, 2023
1 parent 889d44a commit e78660f
Showing 1 changed file with 21 additions and 30 deletions: mlflow/pytorch/__init__.py
@@ -531,7 +531,6 @@ def save_model(
         code=code_dir_subpath,
         conda_env=_CONDA_ENV_FILE_NAME,
         python_env=_PYTHON_ENV_FILE_NAME,
-        model_config={"device": None},
     )
     if size := get_total_file_size(path):
         mlflow_model.model_size_bytes = size
@@ -572,7 +571,7 @@ def save_model(
     _PythonEnv.current().to_yaml(os.path.join(path, _PYTHON_ENV_FILE_NAME))


-def _load_model(path, device=None, **kwargs):
+def _load_model(path, **kwargs):
     """
     :param path: The path to a serialized PyTorch model.
     :param kwargs: Additional kwargs to pass to the PyTorch ``torch.load`` function.
@@ -609,21 +608,16 @@ def _load_model(path, device=None, **kwargs):
         model_path = path

     if Version(torch.__version__) >= Version("1.5.0"):
-        pytorch_model = torch.load(model_path, **kwargs)
+        return torch.load(model_path, **kwargs)
     else:
         try:
             # load the model as an eager model.
-            pytorch_model = torch.load(model_path, **kwargs)
+            return torch.load(model_path, **kwargs)
         except Exception:
             # If loading fails, assume the model is a scripted model.
             # `torch.jit.load` does not accept `pickle_module`.
             kwargs.pop("pickle_module", None)
-            pytorch_model = torch.jit.load(model_path, **kwargs)
-
-    pytorch_model.eval()
-    if device:
-        pytorch_model.to(device=device)
-    return pytorch_model
+            return torch.jit.load(model_path, **kwargs)


 def load_model(model_uri, dst_path=None, **kwargs):
@@ -691,28 +685,13 @@ def load_model(model_uri, dst_path=None, **kwargs):
     return _load_model(path=torch_model_artifacts_path, **kwargs)


-def _load_pyfunc(path, model_config):
+def _load_pyfunc(path, **kwargs):
     """
     Load PyFunc implementation. Called by ``pyfunc.load_model``.

     :param path: Local filesystem path to the MLflow Model with the ``pytorch`` flavor.
     """
-    import torch
-
-    device = model_config.get("device", None)
-    # if CUDA is available, we use the default CUDA device.
-    # To force inference to the CPU when the GPU is available, please set
-    # MLFLOW_DEFAULT_PREDICTION_DEVICE to "cpu"
-    # If a specific non-default device is passed in, we continue to respect that.
-    if device is None:
-        if MLFLOW_DEFAULT_PREDICTION_DEVICE.get():
-            device = MLFLOW_DEFAULT_PREDICTION_DEVICE.get()
-        elif torch.cuda.is_available():
-            device = _TORCH_DEFAULT_GPU_DEVICE_NAME
-        else:
-            device = _TORCH_CPU_DEVICE_NAME
-
-    return _PyTorchWrapper(_load_model(path, device=device), device=device)
+    return _PyTorchWrapper(_load_model(path, **kwargs))


 class _PyTorchWrapper:
@@ -721,9 +700,8 @@ class _PyTorchWrapper:
     predict(data: pd.DataFrame) -> model's output as pd.DataFrame (pandas DataFrame)
     """

-    def __init__(self, pytorch_model, device):
+    def __init__(self, pytorch_model):
         self.pytorch_model = pytorch_model
-        self.device = device

     def predict(self, data, params: Optional[Dict[str, Any]] = None):
         """
@@ -737,6 +715,18 @@ def predict(self, data, params: Optional[Dict[str, Any]] = None):
         """
         import torch

+        device = params.get("device", None) if params else None
+        # if CUDA is available, we use the default CUDA device.
+        # To force inference to the CPU when the GPU is available, please set
+        # MLFLOW_DEFAULT_PREDICTION_DEVICE to "cpu"
+        # If a specific non-default device is passed in, we continue to respect that.
+        if device is None:
+            if MLFLOW_DEFAULT_PREDICTION_DEVICE.get():
+                device = MLFLOW_DEFAULT_PREDICTION_DEVICE.get()
+            elif torch.cuda.is_available():
+                device = _TORCH_DEFAULT_GPU_DEVICE_NAME
+            else:
+                device = _TORCH_CPU_DEVICE_NAME
         if isinstance(data, pd.DataFrame):
             inp_data = data.values.astype(np.float32)
         elif isinstance(data, np.ndarray):
Expand All @@ -749,7 +739,8 @@ def predict(self, data, params: Optional[Dict[str, Any]] = None):
else:
raise TypeError("Input data should be pandas.DataFrame or numpy.ndarray")

device = self.device
self.pytorch_model.to(device)
self.pytorch_model.eval()
with torch.no_grad():
input_tensor = torch.from_numpy(inp_data).to(device)
preds = self.pytorch_model(input_tensor)
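
After this commit, pyfunc inference resolves the torch device per ``predict`` call instead of fixing it at load time: an explicit ``params`` entry wins, then the ``MLFLOW_DEFAULT_PREDICTION_DEVICE`` environment variable, then CUDA if available, else CPU. A minimal sketch of the new call pattern; the model URI is a placeholder, the input shape is illustrative, and it assumes the model was logged with a signature that declares the ``device`` param (pyfunc only forwards params listed in the signature):

    import mlflow
    import numpy as np

    # Placeholder URI; substitute a real run ID or registered model.
    model = mlflow.pyfunc.load_model("runs:/<run_id>/model")
    data = np.random.rand(4, 10).astype(np.float32)

    # An explicit device takes precedence over the env var and CUDA autodetection.
    preds_cpu = model.predict(data, params={"device": "cpu"})
    preds_auto = model.predict(data)  # env var, else cuda if available, else cpu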
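On the native flavor side, ``_load_model`` now returns whatever ``torch.load`` (or ``torch.jit.load``) produces and forwards ``**kwargs`` to it, so placement of the raw weights can still be controlled through standard ``torch.load`` arguments such as ``map_location``. A sketch under that assumption (the ``models:/my_model/1`` URI is hypothetical):

    import mlflow.pytorch

    # map_location is a standard torch.load kwarg that _load_model forwards.
    model = mlflow.pytorch.load_model("models:/my_model/1", map_location="cpu")
    model.eval()

Note that ``eval()`` is no longer applied inside ``_load_model``; the pyfunc wrapper now calls it in ``predict``, but callers of ``mlflow.pytorch.load_model`` who want inference behavior should set it themselves.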
