mo.ui.dataframe cannot render a Polars dataframe #616

Closed
jsnelgro opened this issue Jan 21, 2024 · 4 comments · Fixed by #1612
Labels
enhancement, widget

Comments

@jsnelgro commented Jan 21, 2024

Describe the bug

Thanks for creating such a refreshing alternative to Jupyter! However, I'm having trouble using polars with the dataframe ui widget. Here's the simplest example:

import marimo as mo
import polars as pl
 
df = pl.DataFrame({"a": [1,2,3], "b":[True, False, True], "c":["Bob", "Sally", "Jane"]})
ui_df = mo.ui.dataframe(df)
ui_df

Produces the error:

AttributeError
This cell raised an exception: AttributeError("'list' object has no attribute 'to_dict'")

with stacktrace:

Traceback (most recent call last):
  Cell , line 5
    ui_df = mo.ui.dataframe(df)
  File /Users/johndoe/Library/Caches/pypoetry/virtualenvs/python-playground-dPNtgDsu-py3.12/lib/python3.12/site-packages/marimo/_plugins/ui/_impl/dataframes/dataframe.py, line 95, in __init__
    "columns": df.dtypes.to_dict(),
AttributeError: 'list' object has no attribute 'to_dict'
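
The root cause appears to be that pl.DataFrame.dtypes is a plain Python list, whereas pd.DataFrame.dtypes is a pandas Series, which is what provides .to_dict(). A minimal sketch:

import pandas as pd
import polars as pl

data = {"a": [1, 2, 3]}

print(type(pd.DataFrame(data).dtypes))  # <class 'pandas.core.series.Series'>
print(type(pl.DataFrame(data).dtypes))  # <class 'list'>

pd.DataFrame(data).dtypes.to_dict()     # OK: a Series has .to_dict()
# pl.DataFrame(data).dtypes.to_dict()   # would raise the AttributeError above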

Environment

pyproject.toml

[tool.poetry]
name = "python-playground"
version = "0.1.0"
description = "using poetry and marimo to make python actually nice"
authors = ["Your Name "]
readme = "README.md"
 
[tool.poetry.dependencies]
python = "^3.12"
marimo = "^0.1.76"
altair = "^5.2.0"
numpy = "^1.26.3"
polars = "^0.20.4"
pyarrow = "^14.0.2"
 
 
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

Code to reproduce

import marimo as mo
import polars as pl
 
df = pl.DataFrame({"a": [1,2,3], "b":[True, False, True], "c":["Bob", "Sally", "Jane"]})
ui_df = mo.ui.dataframe(df)
ui_df
@jsnelgro added the bug label on Jan 21, 2024
@mscolnick
Contributor

Unfortunately, we don't support Polars for this specific plugin at this time. You'll need to call to_pandas, but note that you'll get back a pandas df.

We can add first-class Polars support for this though.
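
In the meantime the workaround looks like this (note that the widget will then hold and return a pandas df, not a Polars one):

import marimo as mo
import polars as pl

df = pl.DataFrame({"a": [1, 2, 3], "b": [True, False, True]})

# Convert before handing the frame to the widget
ui_df = mo.ui.dataframe(df.to_pandas())
ui_df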

@misolietavec

Also, with to_pandas there is another error:
TypeError('Object of type DatetimeTZDtype is not JSON serializable')
with a polars dataframe containing a column of type Datetime(time_unit='us', time_zone='UTC').
Apart from mo.ui.dataframe, polars works excellently for me with marimo.
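
A minimal repro sketch of that serialization error (assuming it surfaces as soon as the widget is constructed from the converted frame):

import datetime

import marimo as mo
import polars as pl

df = pl.DataFrame(
    {"ts": [datetime.datetime(2024, 1, 21, tzinfo=datetime.timezone.utc)]}
).with_columns(pl.col("ts").cast(pl.Datetime(time_unit="us", time_zone="UTC")))

# df.to_pandas() yields a DatetimeTZDtype column, which reportedly
# fails to JSON-serialize inside mo.ui.dataframe
mo.ui.dataframe(df.to_pandas())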

@mscolnick
Contributor

Thanks for finding the serialization bug - it will be fixed in this PR: #631

@robmck1995
Contributor

Hi there. I've taken a pass at this issue with our tool, Glide. Please see the High-Level Plan and Implementation below.
Note: The plan does not cover unit tests or frontend rendering.

Plan Steps

Step 1: Modify the dataframe class to accept Polars dataframes

Description: Update the dataframe class in marimo/_plugins/ui/_impl/dataframes/dataframe.py to handle Polars dataframes by checking the type of the input dataframe and setting up the appropriate handlers.

  • Check the type of the df parameter in the __init__ method.
  • If df is a Polars dataframe, initialize a PolarsTransformHandlers instance.
  • Store the type of dataframe (Pandas or Polars) for use in other methods.

Step 2: Create an abstract base class for transform handlers

Description: Define an abstract base class TransformHandlers with abstract methods for each type of transform.

  • Define an abstract base class TransformHandlers.
  • Add abstract methods for each transform type: handle_column_conversion, handle_rename_column, handle_sort_column, handle_filter_rows, handle_group_by, handle_aggregate, handle_select_columns, handle_shuffle_rows, and handle_sample_rows.

Step 3: Implement PandasTransformHandlers and PolarsTransformHandlers subclasses

Description: Create two subclasses that inherit from TransformHandlers and implement the methods for handling transforms for Pandas and Polars dataframes respectively.

  • Create PandasTransformHandlers subclass with methods that use Pandas-specific code.
  • Create PolarsTransformHandlers subclass with methods that use Polars-specific code.

Step 4: Implement Polars transform handler methods

Description: Implement each transform handler method in PolarsTransformHandlers using the provided Polars code context.

  • For handle_shuffle_rows, use Polars' sample method with fraction=1.0 and shuffle=True (Polars has no separate sample_frac).
  • Implement other methods by translating the Pandas logic to Polars API calls, using the provided code context as a reference.

Step 5: Implement get_dataframe for Polars

Description: Use Polars' write_csv method to produce a CSV output that matches the format expected by Marimo.

  • Modify the get_dataframe method to check the type of dataframe.
  • If it's a Polars dataframe, use write_csv with the correct arguments to produce a CSV string.
  • Convert the CSV string to a VirtualFile and return the appropriate GetDataFrameResponse.

Step 6: Update get_column_values for Polars

Description: Ensure that the get_column_values function can retrieve unique values from a specified column in a Polars dataframe.

  • Modify the get_column_values method to handle Polars dataframes.
  • Use Polars' API to get unique values from the specified column.
  • Return a GetColumnValuesResponse with the values or indicate if there are too many values.

Additional context to gather

  • Verify the compatibility of CSV formatting options between Pandas' to_csv and Polars' write_csv methods (see the sketch after this list).
  • Ensure that the VirtualFile creation process is consistent with Marimo's handling of CSV data.
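
As a quick parity check for the CSV bullet above, the two outputs can be compared directly. A minimal sketch; the exact formatting marimo's frontend expects is not confirmed here:

import polars as pl

df = pl.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]})

pandas_csv = df.to_pandas().to_csv(index=False)
polars_csv = df.write_csv()  # returns a string when no target is given

# Compare ignoring newline style, which pandas derives from the OS
assert pandas_csv.splitlines() == polars_csv.splitlines()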

Watch out for:

  • Ensure that the behavior of transformations is consistent between Pandas and Polars dataframes.
  • Make sure that the CSV output from Polars is formatted correctly for Marimo's expectations.
  • Handle any edge cases or differences in API behavior between Pandas and Polars.

Implementation Plan

Edit 1: Update the dataframe class to accept Polars dataframes

Description: Modify the dataframe class in marimo/_plugins/ui/_impl/dataframes/dataframe.py to handle both Pandas and Polars dataframes by checking the type of the input dataframe and setting up the appropriate handlers. Code:

# Import necessary libraries at the top of the file
import polars as pl

# Modify the __init__ method of the dataframe class
def __init__(
    self,
    df: Union[pd.DataFrame, pl.DataFrame],
    on_change: Optional[Callable[[Union[pd.DataFrame, pl.DataFrame]], None]] = None,
) -> None:
    dataframe_name = "df"
    try:
        frame = inspect.currentframe()
        if frame is not None and frame.f_back is not None:
            for (
                var_name,
                var_value,
            ) in frame.f_back.f_locals.items():
                if var_value is df:
                    dataframe_name = var_name
                    break
    except Exception:
        pass

    self._data = df
    self._transform_container = TransformsContainer(df)
    self._error: Optional[str] = None

    # Determine if the dataframe is a Pandas or Polars dataframe
    if isinstance(df, pd.DataFrame):
        self._df_type = 'pandas'
    elif isinstance(df, pl.DataFrame):
        self._df_type = 'polars'
    else:
        raise ValueError("Unsupported dataframe type. Only Pandas and Polars dataframes are supported.")

    super().__init__(
        component_name=dataframe._name,
        initial_value={
            "transforms": [],
        },
        on_change=on_change,
        label="",
        args={
            "columns": self._get_columns_dict(df),
            "dataframe-name": dataframe_name,
            "total": len(df),
        },
        functions=(
            Function(
                name=self.get_dataframe.__name__,
                arg_cls=EmptyArgs,
                function=self.get_dataframe,
            ),
            Function(
                name=self.get_column_values.__name__,
                arg_cls=GetColumnValuesArgs,
                function=self.get_column_values,
            ),
        ),
    )

# Add a new method to get the columns dictionary
def _get_columns_dict(self, df: Union[pd.DataFrame, pl.DataFrame]) -> Dict[str, str]:
    if self._df_type == 'pandas':
        return df.dtypes.to_dict()
    elif self._df_type == 'polars':
        return {name: str(dtype) for name, dtype in zip(df.columns, df.dtypes)}
    else:
        raise ValueError("Unsupported dataframe type.")

Edit 2: Create an abstract base class for transform handlers and implement subclasses

Description: Define an abstract base class TransformHandlers with abstract methods for each type of transform and create two subclasses PandasTransformHandlers and PolarsTransformHandlers. Code:

from abc import ABC, abstractmethod
from typing import Union

# Abstract base class for transform handlers
class TransformHandlers(ABC):
    def handle(self, df, transform):
        transform_type = transform.type

        if transform_type is TransformType.COLUMN_CONVERSION:
            return self.handle_column_conversion(df, transform)
        elif transform_type is TransformType.RENAME_COLUMN:
            return self.handle_rename_column(df, transform)
        elif transform_type is TransformType.SORT_COLUMN:
            return self.handle_sort_column(df, transform)
        elif transform_type is TransformType.FILTER_ROWS:
            return self.handle_filter_rows(df, transform)
        elif transform_type is TransformType.GROUP_BY:
            return self.handle_group_by(df, transform)
        elif transform_type is TransformType.AGGREGATE:
            return self.handle_aggregate(df, transform)
        elif transform_type is TransformType.SELECT_COLUMNS:
            return self.handle_select_columns(df, transform)
        elif transform_type is TransformType.SHUFFLE_ROWS:
            return self.handle_shuffle_rows(df, transform)
        elif transform_type is TransformType.SAMPLE_ROWS:
            return self.handle_sample_rows(df, transform)
        else:
            raise NotImplementedError(f"Transform type {transform_type} is not implemented")

    @abstractmethod
    def handle_column_conversion(self, df, transform): pass

    @abstractmethod
    def handle_rename_column(self, df, transform): pass

    @abstractmethod
    def handle_sort_column(self, df, transform): pass

    @abstractmethod
    def handle_filter_rows(self, df, transform): pass

    @abstractmethod
    def handle_group_by(self, df, transform): pass

    @abstractmethod
    def handle_aggregate(self, df, transform): pass

    @abstractmethod
    def handle_select_columns(self, df, transform): pass

    @abstractmethod
    def handle_shuffle_rows(self, df, transform): pass

    @abstractmethod
    def handle_sample_rows(self, df, transform): pass

# Subclass for handling Pandas dataframe transforms.
# The existing pandas-specific implementations should be moved out of the
# current TransformHandlers and into the corresponding methods below.
class PandasTransformHandlers(TransformHandlers):
    def handle_column_conversion(self, df, transform):
        ...  # existing column-conversion implementation moves here

    def handle_rename_column(self, df, transform):
        ...  # existing column-renaming implementation moves here

    def handle_sort_column(self, df, transform):
        ...  # existing column-sorting implementation moves here

    def handle_filter_rows(self, df, transform):
        ...  # existing row-filtering implementation moves here

    def handle_group_by(self, df, transform):
        ...  # existing group-by implementation moves here

    def handle_aggregate(self, df, transform):
        ...  # existing aggregation implementation moves here

    def handle_select_columns(self, df, transform):
        ...  # existing column-selection implementation moves here

    def handle_shuffle_rows(self, df, transform):
        ...  # existing row-shuffling implementation moves here

    def handle_sample_rows(self, df, transform):
        ...  # existing row-sampling implementation moves here

# Subclass for handling Polars dataframe transforms
from polars import col  # ideally imported at the top of the file

class PolarsTransformHandlers(TransformHandlers):
    def handle_column_conversion(self, df, transform):
        # Use the `cast` method from the Polars API; transform.data_type may
        # need mapping from a string to a Polars dtype first
        dtypes = {transform.column_id: transform.data_type}
        return df.cast(dtypes)

    def handle_rename_column(self, df, transform):
        # Use the `rename` method from the Polars API
        return df.rename({transform.column_id: transform.new_column_id})

    def handle_sort_column(self, df, transform):
        # Use the `sort` method from the Polars API (there is no `by_column`
        # keyword; the column is passed positionally)
        return df.sort(
            transform.column_id,
            descending=not transform.ascending,
            nulls_last=transform.na_position == 'last',
        )

    def handle_filter_rows(self, df, transform):
        # Start with no filter (all rows included)
        filter_expr = None

        # Iterate over all conditions and build the filter expression
        for condition in transform.where:
            column = col(condition.column_id)
            value = condition.value

            # Build the expression based on the operator
            if condition.operator == "==":
                condition_expr = column == value
            elif condition.operator == "!=":
                condition_expr = column != value
            elif condition.operator == ">":
                condition_expr = column > value
            elif condition.operator == "<":
                condition_expr = column < value
            elif condition.operator == ">=":
                condition_expr = column >= value
            elif condition.operator == "<=":
                condition_expr = column <= value
            elif condition.operator == "is_true":
                condition_expr = column.is_true()
            elif condition.operator == "is_false":
                condition_expr = column.is_false()
            elif condition.operator == "is_nan":
                condition_expr = column.is_null()
            elif condition.operator == "is_not_nan":
                condition_expr = column.is_not_null()
            elif condition.operator == "equals":
                condition_expr = column == value
            elif condition.operator == "does_not_equal":
                condition_expr = column != value
            elif condition.operator == "contains":
                condition_expr = column.str_contains(value)
            elif condition.operator == "regex":
                condition_expr = column.str_contains(value, regex=True)
            elif condition.operator == "starts_with":
                condition_expr = column.str_starts_with(value)
            elif condition.operator == "ends_with":
                condition_expr = column.str_ends_with(value)
            elif condition.operator == "in":
                condition_expr = column.is_in(value)
            else:
                raise ValueError(f"Unsupported operator: {condition.operator}")

            # Combine the condition expression with the filter expression
            if filter_expr is None:
                filter_expr = condition_expr
            else:
                filter_expr = filter_expr & condition_expr

        # Apply the filter expression according to the operation
        if filter_expr is None:
            return df

        if transform.operation == "keep_rows":
            return df.filter(filter_expr)
        elif transform.operation == "remove_rows":
            # Negate the combined expression instead of filtering twice,
            # which would otherwise always return an empty frame
            return df.filter(~filter_expr)
        else:
            raise ValueError(f"Unsupported operation: {transform.operation}")

    def handle_group_by(self, df, transform):
        # Use the `group_by` and `agg` methods from the Polars API; the
        # transform's aggregations must be translated into Polars
        # expressions, as in handle_aggregate below
        return df.group_by(transform.column_ids).agg(transform.aggregations)

    def handle_aggregate(self, df, transform):
        agg_exprs = []

        for column_id, aggregations in transform.aggregations.items():
            for agg_func in aggregations:
                if agg_func == "count":
                    agg_exprs.append(col(column_id).count().alias(f"{column_id}_count"))
                elif agg_func == "sum":
                    agg_exprs.append(col(column_id).sum().alias(f"{column_id}_sum"))
                elif agg_func == "mean":
                    agg_exprs.append(col(column_id).mean().alias(f"{column_id}_mean"))
                elif agg_func == "median":
                    agg_exprs.append(col(column_id).median().alias(f"{column_id}_median"))
                elif agg_func == "min":
                    agg_exprs.append(col(column_id).min().alias(f"{column_id}_min"))
                elif agg_func == "max":
                    agg_exprs.append(col(column_id).max().alias(f"{column_id}_max"))
                else:
                    raise ValueError(f"Unsupported aggregation function: {agg_func}")

        return df.group_by(transform.column_ids).agg(agg_exprs)

    def handle_select_columns(self, df, transform):
        # Use the `select` method from the Polars API
        return df.select(transform.column_ids)

    def handle_shuffle_rows(self, df, transform):
        # Use the `sample` method from the Polars API with fraction=1.0 and
        # shuffle=True (Polars has no separate `sample_frac` method)
        return df.sample(fraction=1.0, shuffle=True, seed=transform.seed)

    def handle_sample_rows(self, df, transform):
        # Use the `sample` method from the Polars API
        return df.sample(
            n=transform.n,
            shuffle=True,
            seed=transform.seed,
            with_replacement=transform.replace,
        )
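
The corrected Polars calls above can be sanity-checked in isolation; a minimal sketch using only documented Polars APIs:

import polars as pl

df = pl.DataFrame({"a": [1, 2, 3], "c": ["Bob", "Sally", "Jane"]})

# Shuffle: sample the whole frame with a fixed seed
shuffled = df.sample(fraction=1.0, shuffle=True, seed=42)

# Filter: a literal "contains" match on a string column
bobs = df.filter(pl.col("c").str.contains("Bob", literal=True))

# Aggregate: named expressions, mirroring handle_aggregate
sums = df.group_by("c").agg(pl.col("a").sum().alias("a_sum"))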

Edit 3: Implement get_dataframe for Polars

Description: Modify the get_dataframe method in the dataframe class to handle Polars dataframes using the write_csv method to produce a compatible CSV output.

import io  # Make sure to import io at the top of the file

# Modify the get_dataframe method of the dataframe class
def get_dataframe(self, _args: EmptyArgs) -> GetDataFrameResponse:
    LIMIT = 100

    if self._error is not None:
        raise Exception(self._error)

    # Check if the dataframe is a Polars dataframe and handle accordingly
    if self._df_type == 'polars':
        # Create a buffer to write the CSV data
        buffer = io.BytesIO()
        # Write the CSV data to the buffer
        self._value.head(LIMIT).write_csv(buffer)
        # Seek to the start of the buffer to read its content
        buffer.seek(0)
        # Read the buffer content
        csv_data = buffer.read()
        # Create a VirtualFile from the CSV data
        url = mo_data.any_data(csv_data, ext="csv").url
    else:
        # Existing handling for Pandas dataframe
        url = mo_data.csv(self._value.head(LIMIT)).url

    total_rows = len(self._value)
    return GetDataFrameResponse(
        url=url,
        total_rows=total_rows,
        has_more=total_rows > LIMIT,
        row_headers=get_row_headers(self._value),
    )
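
The buffer round-trip in the Polars branch can be verified on its own (a minimal sketch):

import io

import polars as pl

df = pl.DataFrame({"a": [1, 2, 3]})

buffer = io.BytesIO()
df.head(2).write_csv(buffer)  # write_csv accepts a file-like target
buffer.seek(0)
print(buffer.read().decode())  # "a\n1\n2\n"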

Edit 4: Update get_column_values for Polars

Description: Update the get_column_values method in the dataframe class to work with Polars dataframes, retrieving unique values from a specified column.

# Modify the get_column_values method of the dataframe class
def get_column_values(self, args: GetColumnValuesArgs) -> GetColumnValuesResponse:
    LIMIT = 500

    # Check if the dataframe is a Polars dataframe and handle accordingly
    if self._df_type == 'polars':
        # select() returns a DataFrame, so pull the column out as a Series
        # before taking unique values
        unique_values = self._data.get_column(args.column).unique().to_list()
    else:
        # Existing handling for Pandas dataframe
        unique_values = self._data[args.column].unique().tolist()

    if len(unique_values) <= LIMIT:
        return GetColumnValuesResponse(
            values=list(sorted(unique_values, key=str)),
            too_many_values=False,
        )
    else:
        return GetColumnValuesResponse(
            values=[],
            too_many_values=True,
        )
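
And the Polars branch of get_column_values can be checked directly (a minimal sketch):

import polars as pl

df = pl.DataFrame({"c": ["Bob", "Sally", "Bob"]})

values = df.get_column("c").unique().to_list()
print(sorted(values, key=str))  # ['Bob', 'Sally']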

Generated with Glide by Agentic Labs

@mscolnick added the enhancement label and removed the bug label on Jun 4, 2024