Skip to content

Commit

Permalink
fix(common): disallow plain string inputs for SequenceOf patterns
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs authored and cpcloud committed Sep 2, 2023
1 parent 62d1dc4 commit 578980d
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 9 deletions.
12 changes: 4 additions & 8 deletions ibis/common/patterns.py
Expand Up @@ -1224,13 +1224,11 @@ def __init__(self, item, type=tuple):
super().__init__(item=pattern(item), type=type)

def match(self, values, context):
try:
iterable = iter(values)
except TypeError:
if not is_iterable(values):
return NoMatch

result = []
for item in iterable:
for item in values:
item = self.item.match(item, context)
if item is NoMatch:
return NoMatch
Expand Down Expand Up @@ -1293,13 +1291,11 @@ def __init__(
super().__init__(item=item, type=type, length=length)

def match(self, values, context):
try:
iterable = iter(values)
except TypeError:
if not is_iterable(values):
return NoMatch

result = []
for value in iterable:
for value in values:
value = self.item.match(value, context)
if value is NoMatch:
return NoMatch
Expand Down
2 changes: 2 additions & 0 deletions ibis/common/tests/test_patterns.py
Expand Up @@ -414,6 +414,7 @@ def test_sequence_of():
assert p.match(["foo", "bar"], context={}) == ["foo", "bar"]
assert p.match([1, 2], context={}) is NoMatch
assert p.match(1, context={}) is NoMatch
assert p.match("string", context={}) is NoMatch


def test_generic_sequence_of():
Expand All @@ -426,6 +427,7 @@ def __coerce__(cls, value, T=...):
assert isinstance(p, GenericSequenceOf)
assert p == GenericSequenceOf(InstanceOf(str), MyList)
assert p.match(["foo", "bar"], context={}) == MyList(["foo", "bar"])
assert p.match("string", context={}) is NoMatch

p = SequenceOf(InstanceOf(str), tuple, at_least=1)
assert isinstance(p, GenericSequenceOf)
Expand Down
10 changes: 9 additions & 1 deletion ibis/expr/operations/structs.py
Expand Up @@ -4,7 +4,7 @@

import ibis.expr.datatypes as dt
import ibis.expr.rules as rlz
from ibis.common.annotations import attribute
from ibis.common.annotations import ValidationError, attribute
from ibis.common.typing import VarTuple # noqa: TCH001
from ibis.expr.operations.core import Value

Expand Down Expand Up @@ -34,6 +34,14 @@ class StructColumn(Value):

shape = rlz.shape_like("values")

def __init__(self, names, values):
if len(names) != len(values):
raise ValidationError(
f"Length of names ({len(names)}) does not match length of "
f"values ({len(values)})"
)
super().__init__(names=names, values=values)

@attribute
def dtype(self) -> dt.DataType:
dtypes = (value.dtype for value in self.values)
Expand Down
14 changes: 14 additions & 0 deletions ibis/expr/operations/tests/test_structs.py
@@ -1,9 +1,12 @@
from __future__ import annotations

import pytest

import ibis
import ibis.expr.datashape as ds
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.common.annotations import ValidationError


def test_struct_column_shape():
Expand All @@ -17,3 +20,14 @@ def test_struct_column_shape():
)
op = ops.StructColumn(names=("a",), values=(col,))
assert op.shape == ds.columnar


def test_struct_column_validates_input_lengths():
one = ops.Literal(1, dtype=dt.int64)
two = ops.Literal(2, dtype=dt.int64)

with pytest.raises(ValidationError):
ops.StructColumn(names=("a",), values=(one, two))

with pytest.raises(ValidationError):
ops.StructColumn(names=("a", "b"), values=(one,))

0 comments on commit 578980d

Please sign in to comment.