From 8badb1a0bc1a50c3c2344a7c336f79ec6059812c Mon Sep 17 00:00:00 2001 From: Ryan Soley Date: Thu, 7 Mar 2024 10:03:26 -0500 Subject: [PATCH] guard dask import (#417) --- rubicon_ml/client/project.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/rubicon_ml/client/project.py b/rubicon_ml/client/project.py index cf60a859..74aa11ea 100644 --- a/rubicon_ml/client/project.py +++ b/rubicon_ml/client/project.py @@ -4,7 +4,6 @@ import warnings from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union -import dask.dataframe as dd import pandas as pd from rubicon_ml import domain @@ -15,6 +14,8 @@ from rubicon_ml.schema.logger import SchemaMixin if TYPE_CHECKING: + import dask.dataframe as dd + from rubicon_ml import Rubicon from rubicon_ml.client import Config, Dataframe from rubicon_ml.domain import Project as ProjectDomain @@ -204,6 +205,15 @@ def to_df( df = df.sort_values(by=["created_at"], ascending=False).reset_index(drop=True) if df_type == "dask": + try: + from dask import dataframe as dd + except ImportError: + raise RubiconException( + "`rubicon_ml` requires `dask` to be installed in the current " + "environment to create dataframes with `df_type`='dask'. `pip install " + "dask[dataframe]` or `conda install dask` to continue." + ) + df = dd.from_pandas(df, npartitions=1) experiment_dfs[group] = df