Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

_StructModuleWhich: Use enum #262

Merged
merged 3 commits into from
Jun 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 21 additions & 5 deletions capnp/lib/capnp.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ from libc.string cimport memcpy
import array
import asyncio
import collections as _collections
import enum as _enum
import inspect as _inspect
import os as _os
import random as _random
Expand Down Expand Up @@ -1392,6 +1393,8 @@ cdef class _DynamicStructBuilder:

:Raises: :exc:`KjException` if the field isn't in this struct
"""
if isinstance(field, _StructModuleWhich):
field = field.name[0].lower() + field.name[1:]
if size is None:
return to_python_builder(self.thisptr.init(field), self._parent)
else:
Expand Down Expand Up @@ -3152,8 +3155,12 @@ cdef _new_message(self, kwargs, num_first_segment_words):
return msg


class _StructModuleWhich(object):
pass
class _StructModuleWhich(_enum.Enum):
def __eq__(self, other):
if isinstance(other, int):
return self.value == other
else:
return self.name == other


class _StructModule(object):
Expand All @@ -3170,11 +3177,20 @@ class _StructModule(object):
if field_schema.discriminantCount == 0:
sub_module = _StructModule(raw_schema, name)
else:
sub_module = _StructModuleWhich()
setattr(sub_module, 'schema', raw_schema)
mapping = []
for union_field in field_schema.fields:
setattr(sub_module, union_field.name, union_field.discriminantValue)
mapping.append((union_field.name, union_field.discriminantValue))
sub_module = _StructModuleWhich("StructModuleWhich", mapping)
setattr(sub_module, 'schema', raw_schema)
setattr(self, name, sub_module)
if schema.union_fields and not schema.non_union_fields:
mapping = []
for union_field in schema.node.struct.fields:
name = union_field.name
name = name[0].upper() + name[1:]
mapping.append((name, union_field.discriminantValue))
sub_module = _StructModuleWhich("StructModuleWhich", mapping)
setattr(self, 'Union', sub_module)

def read(self, file, traversal_limit_in_words=None, nesting_limit=None):
"""Returns a Reader for the unpacked object read from file.
Expand Down
18 changes: 18 additions & 0 deletions test/all_types.capnp
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,21 @@ struct TestAllTypes {
enumList @32 : List(TestEnum);
interfaceList @33 : List(Void); # TODO
}

struct UnionAllTypes {
union {
unionStructField1 @0 : TestAllTypes;
unionStructField2 @1 : TestAllTypes;
}
}

struct GroupedUnionAllTypes {
union {
g1 :group {
unionStructField1 @0 : TestAllTypes;
}
g2 :group {
unionStructField2 @1 : TestAllTypes;
}
}
}
29 changes: 29 additions & 0 deletions test/test_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,38 @@ def test_set_dict(all_types):
def test_set_dict_union(addressbook):
person = addressbook.Person.new_message(**{'employment': {'employer': {'name': 'foo'}}})

assert person.employment.which == addressbook.Person.Employment.employer

assert person.employment.employer.name == 'foo'


def test_union_enum(all_types):
assert all_types.UnionAllTypes.Union.UnionStructField1 == 0
assert all_types.UnionAllTypes.Union.UnionStructField2 == 1

msg = all_types.UnionAllTypes.new_message(**{'unionStructField1': {'textField': "foo"}})
assert msg.which == all_types.UnionAllTypes.Union.UnionStructField1
assert msg.which == 'unionStructField1'
assert msg.which == 0

msg = all_types.UnionAllTypes.new_message(**{'unionStructField2': {'textField': "foo"}})
assert msg.which == all_types.UnionAllTypes.Union.UnionStructField2
assert msg.which == 'unionStructField2'
assert msg.which == 1

assert all_types.GroupedUnionAllTypes.Union.G1 == 0
assert all_types.GroupedUnionAllTypes.Union.G2 == 1

msg = all_types.GroupedUnionAllTypes.new_message(**{'g1': {'unionStructField1': {'textField': "foo"}}})
assert msg.which == all_types.GroupedUnionAllTypes.Union.G1

msg = all_types.GroupedUnionAllTypes.new_message(**{'g2': {'unionStructField2': {'textField': "foo"}}})
assert msg.which == all_types.GroupedUnionAllTypes.Union.G2

msg = all_types.UnionAllTypes.new_message()
msg.unionStructField2 = msg.init(all_types.UnionAllTypes.Union.UnionStructField2)


def isstr(s):
return isinstance(s, str)

Expand Down