feat: add support for aggregates and toxicity classification #551
Conversation
```
    agg_func_name = self.visit(child).value
elif isinstance(child, Token):
    token = child.value
    # Support for COUNT(*)
```
I don't understand this logic. Are we hardcoding `*` to `id` in the parser? If yes, I guess even though it is hacky, it saves us from handling this corner case in the binder. We could change it to `IDENTIFIER_COLUMN`, which is supposed to be a unique row id in all the tables.
Yes. This query currently works. I was also worried about whether `id` will always be present. How should I change it to `IDENTIFIER_COLUMN`?
The unit tests do not create tables using `IDENTIFIER_COLUMN`, so the test case fails.
Yeah, I verified we don't support projecting `IDENTIFIER_COLUMN`, which causes the binder to fail. `id` won't work for images or other tables. Ideally, the binder should take care of it. An `if` condition here should fix it.
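A minimal sketch of that idea, mapping the `COUNT(*)` wildcard to the unique row-id column before binding; the constant name `IDENTIFIER_COLUMN` and the helper below are illustrative, not the exact EVA parser API:

```python
# Hypothetical helper sketching the suggested special case: map COUNT(*)
# to the unique row-id column instead of hardcoding "id", which is not
# guaranteed to exist for image or video tables.
IDENTIFIER_COLUMN = "_row_id"  # assumed name of the unique row-id column


def resolve_aggregate_target(token: str) -> str:
    """Return the column name a COUNT aggregate should bind to."""
    if token == "*":
        return IDENTIFIER_COLUMN
    return token
```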
eva/models/storage/batch.py (Outdated)
```
@@ -378,7 +380,14 @@ def aggregate(self, method: str) -> None:
        Arguments:
            method: string with one of the five above options
        """
        self._frames = self._frames.agg([method])
        # Aggregate ndarray
        if isinstance(self._frames.iat[0, 0], np.ndarray):
```
Is it aggregating each row of the array? If yes, I suspect that will break the execution logic.
Yes, how will it break the execution logic? The `NDARRAY` case is for object detection arrays, etc. The normal case is the one that existed earlier: `self._frames = self._frames.agg([method])`.
I have reverted the `NDARRAY` case as it does not make sense.
Aggregating an `ndarray` column and a primitive column will result in different row counts. We have a different set of aggregate operators to apply row-wise aggregates, like `Array_Count`.
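A small pandas illustration of why the two cases diverge; the `bboxes` column and the use of `len` as a stand-in for a row-wise operator such as `Array_Count` are assumptions made for the example:

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({
    "id": [1, 2, 3],
    "bboxes": [np.zeros((2, 4)), np.zeros((1, 4)), np.zeros((3, 4))],
})

# A column-wise aggregate on a primitive column reduces to a single row.
print(df[["id"]].agg(["sum"]))   # one row: id == 6

# A row-wise aggregate on the ndarray column yields one value per input
# row (e.g. the number of detected boxes), so the row counts no longer match.
print(df["bboxes"].apply(len))   # three rows: 2, 1, 3
```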
```
self.assertEqual(actual_batch.frames.iat[0, 0], 10)
self.assertEqual(actual_batch.frames.iat[0, 1], 4.5)

complex_aggregate_query = """SELECT SUM(id), COUNT(label)
```
We should add a test case with an aggregate on an ndarray column.
When the query operates on an ndarray column, it does not reduce it to a single row. It actually keeps as many rows as there are in the input column.
I have reverted the `NDARRAY` case as it does not make sense.
Do we raise an error if the query tries to aggregate on an array column? We should add a test case to verify it. Thanks!
I just added a
```
@@ -55,9 +55,14 @@ def evaluate(self, *args, **kwargs):
    elif self.etype == ExpressionType.AGGREGATION_MAX:
```
We are missing an `else` condition to raise an error. Right now, we silently ignore it and return the original batch.
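A sketch of what that branch could look like inside `evaluate()`; the surrounding `AGGREGATION_*` branches are paraphrased from the diff context, and the choice of `NotImplementedError` is an assumption:

```python
if self.etype == ExpressionType.AGGREGATION_SUM:
    batch.aggregate("sum")
elif self.etype == ExpressionType.AGGREGATION_COUNT:
    batch.aggregate("count")
elif self.etype == ExpressionType.AGGREGATION_MAX:
    batch.aggregate("max")
else:
    # Fail loudly instead of silently returning the original batch.
    raise NotImplementedError(f"Unsupported aggregation type: {self.etype}")
```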
```
@@ -0,0 +1,49 @@
# coding=utf-8
```
What is this for?
```
single_result = self.model.predict(text)
toxicity_score = single_result["toxicity"][0]
if toxicity_score >= self.threshold:
    outcome = outcome.append({"labels": "toxic"}, ignore_index=True)
```
I changed it to use a `list` for the append operation. `DataFrame.append` throws a lot of warnings. You can refer to other UDFs.
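A minimal sketch of the list-based pattern; `texts`, `self.model`, and `self.threshold` follow the snippet quoted above, and the `"not toxic"` label for the below-threshold case is an assumption:

```python
import pandas as pd

labels = []
for text in texts:
    single_result = self.model.predict(text)
    toxicity_score = single_result["toxicity"][0]
    # Accumulate labels in a plain list; repeated DataFrame.append calls
    # are deprecated and emit a warning on every iteration.
    labels.append("toxic" if toxicity_score >= self.threshold else "not toxic")

# Build the output frame once at the end.
outcome = pd.DataFrame({"labels": labels})
```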
Did you push your change?
```
if len(inp.columns) != 1:
    raise ValueError("input must only contain one column (seconds)")

seconds = pd.DataFrame(inp[inp.columns[0]])
```
Isn't it a no-op?
This is a timestamp UDF.
```sql
SELECT id, seconds, Timestamp(seconds) FROM MyVideo WHERE Timestamp(seconds) <= "00:00:01";
```
```
@@ -76,6 +76,7 @@ def test_create_multimedia_table_catalog_entry(self, mock):
    ColumnDefinition(
        "data", ColumnType.NDARRAY, NdArrayType.UINT8, [None, None, None]
```
`array_dimension` should be a tuple.
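A minimal sketch of the suggested fix, keeping the call quoted above and only changing the dimension list to a tuple:

```python
ColumnDefinition(
    "data", ColumnType.NDARRAY, NdArrayType.UINT8, (None, None, None)
)
```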
Thanks! This will be a fun example to showcase 💯
`stable` version in read-the-docs