Skip to content

Commit

Permalink
Add unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
Jay Chia committed Jul 8, 2023
1 parent 36db991 commit 2a8e399
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 1 deletion.
2 changes: 1 addition & 1 deletion ludwig/data/dataframe/daft.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class LudwigDaftSeries:
```
"""

def __init__(self, expr: daft.Expression):
def __init__(self, expr: daft.expressions.Expression):
self._expr = expr

@property
Expand Down
32 changes: 32 additions & 0 deletions tests/ludwig/data/dataframe/test_daft.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import daft
import numpy as np
import pytest

from ludwig.data.dataframe.daft import DaftEngine, LudwigDaftDataframe, LudwigDaftSeries


@pytest.fixture(scope="function")
def df() -> LudwigDaftDataframe:
data = {
"a": [i for i in range(10)],
"b": ["a" * i for i in range(10)],
"c": [np.zeros((i, i)) for i in range(1, 11)],
}
return LudwigDaftDataframe(daft.from_pydict(data))


@pytest.fixture(scope="function", params=[1, 2])
def engine(request) -> DaftEngine:
parallelism = request.param
return DaftEngine(parallelism=parallelism)


def test_df_like(df: LudwigDaftDataframe, engine: DaftEngine):
s1 = LudwigDaftSeries(df["a"].expr * 2)
s2 = LudwigDaftSeries(df["b"].expr + "_suffix")
df = engine.df_like(df, {"foo": s1, "bar": s2})
pd_df = engine.compute(df)

assert list(pd_df.columns) == ["a", "b", "c", "foo", "bar"]
np.testing.assert_equal(np.array(pd_df["foo"]), np.array(pd_df["a"] * 2))
np.testing.assert_equal(np.array(pd_df["bar"]), np.array([item + "_suffix" for item in pd_df["b"]]))

0 comments on commit 2a8e399

Please sign in to comment.