Skip to content

Commit

Permalink
Implement basic support for the field reflection API
Browse files Browse the repository at this point in the history
  • Loading branch information
cyang1 committed Nov 29, 2021
1 parent 415acbf commit dd5336e
Show file tree
Hide file tree
Showing 15 changed files with 4,898 additions and 258 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@
* `copy_from_reader` replaces `from_reader` and allows a `PbBuffer` to be constructed, by copying, from any `Buf`. Implementations can still opt out by returning `Err`.
* `copy_to_writer` replaces `into_reader`. Callers that were using `into_reader` to call `Message::deserialize` should instead construct their desired `PbBufferReader` directly (e.g. `Cursor<Bytes>`).
* `PbBufferReader::as_buffer` is renamed to `read_buffer`, and provided implementations fall back to copying `Lazy` fields by default (instead of returning an error).
* Implement basic support for the field reflection API. (#121)
* `MessageDescriptor` has been reworked to be a `struct` rather than a `trait`, and is returned from `Message::descriptor` rather than implemented on a `Message`.
* `MessageDescriptor` has been augmented to return information about the fields and oneofs in the `Message`.
* A `Reflection` trait has been added to provide more dynamic access to proto fields.
* This is automatically generated by `pb-jelly-gen` for all generated protos, but any manually implemented `Message` that is used in a generated `Message` will also need to manually implement `Reflection`.
* Example usage can be seen in `pb-test/src/pbtest.rs` in `all_fields_reflection3`.

# 0.0.9
### November 2, 2021
Expand Down
188 changes: 160 additions & 28 deletions pb-jelly-gen/codegen/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,10 +415,7 @@ def get_method(self) -> Tuple[Text, Text]:
elif self.field.type == FieldDescriptorProto.TYPE_BOOL:
return "bool", "self.%s.unwrap_or(false)" % name
elif self.field.type == FieldDescriptorProto.TYPE_STRING:
return (
"&str",
'self.%s.as_deref().unwrap_or("")' % name,
)
return ("&str", 'self.%s.as_deref().unwrap_or("")' % name)
elif self.field.type == FieldDescriptorProto.TYPE_BYTES:
assert not (
self.is_blob() or self.is_grpc_slices() or self.is_lazy_bytes()
Expand Down Expand Up @@ -529,12 +526,14 @@ def enum_closed(enum: EnumDescriptorProto) -> bool:


@contextmanager
def block(ctx: "CodeWriter", header: Text) -> Iterator[None]:
ctx.write("%s {" % header)
def block(
ctx: "CodeWriter", header: Text, start: Text = " {", end: Text = "}"
) -> Iterator[None]:
ctx.write("%s%s" % (header, start))
ctx.indentation += 1
yield
ctx.indentation -= 1
ctx.write("}")
ctx.write(end)


@contextmanager
Expand Down Expand Up @@ -1013,7 +1012,70 @@ def gen_msg(
)

with block(self, "impl ::pb_jelly::Message for " + name):
with block(self, "fn compute_size(&self) -> usize "):
with block(
self,
"fn descriptor(&self) -> ::std::option::Option<::pb_jelly::MessageDescriptor>",
):
name = "_".join(path + [msg_type.name])
full_name = (
".".join([self.proto_file.package, name])
if self.proto_file.package
else name
)

with block(
self, "Some(::pb_jelly::MessageDescriptor", start=" {", end="})"
):
self.write('name: "%s",' % name)
self.write('full_name: "%s",' % full_name)
with block(self, "fields:", start=" &[", end="],"):
for i, field in enumerate(msg_type.field):
with block(
self,
"::pb_jelly::FieldDescriptor",
start=" {",
end="},",
):
full_name = ".".join(
[self.proto_file.package, name, field.name]
if self.proto_file.package
else [name, field.name]
)

typ = self.rust_type(msg_type, field)
self.write('name: "%s",' % field.name)
self.write('full_name: "%s",' % full_name)
self.write("index: %d," % i)
self.write("number: %d," % field.number)
self.write(
"typ: ::pb_jelly::wire_format::Type::%s,"
% typ.wire_format()
)
if field.label == FieldDescriptorProto.LABEL_OPTIONAL:
self.write("label: ::pb_jelly::Label::Optional,")
elif field.label == FieldDescriptorProto.LABEL_REQUIRED:
self.write("label: ::pb_jelly::Label::Required,")
elif field.label == FieldDescriptorProto.LABEL_REPEATED:
self.write("label: ::pb_jelly::Label::Repeated,")

if field.HasField("oneof_index"):
self.write(
"oneof_index: Some(%d)," % field.oneof_index
)
else:
self.write("oneof_index: None,")

with block(self, "oneofs:", start=" &[", end="],"):
for oneof in msg_type.oneof_decl:
with block(
self,
"::pb_jelly::OneofDescriptor",
start=" {",
end="},",
):
self.write('name: "%s",' % oneof.name)

with block(self, "fn compute_size(&self) -> usize"):
if (
len(msg_type.field) > 0
or msg_type.options.Extensions[extensions_pb2.preserve_unrecognized]
Expand Down Expand Up @@ -1057,7 +1119,7 @@ def gen_msg(
else:
self.write("0")

with block(self, "fn compute_grpc_slices_size(&self) -> usize "):
with block(self, "fn compute_grpc_slices_size(&self) -> usize"):
if len(msg_type.field) > 0:
self.write("let mut size = 0;")
for field in msg_type.field:
Expand Down Expand Up @@ -1291,21 +1353,96 @@ def gen_msg(
self.write("unrecognized.serialize(&mut self._unrecognized)?;")
self.write("Ok(())")

def gen_msg_descriptor(
self,
path: List[Text],
desc_proto: DescriptorProto,
package: Optional[Text],
scl: SourceCodeLocation,
) -> None:
assert self.indentation == 0
with block(self, "impl ::pb_jelly::Reflection for " + name):
with block(
self,
"fn which_one_of(&self, oneof_name: &str) -> ::std::option::Option<&'static str>",
):
with block(self, "match oneof_name"):
for oneof in msg_type.oneof_decl:
oneof_field = oneof_fields[oneof.name][0]

name = "_".join(path + [desc_proto.name])
full_name = ".".join([package, name]) if package else name
for oneof in msg_type.oneof_decl:
with block(self, '"%s" =>' % oneof.name):
for oneof_field in oneof_fields[oneof.name]:
with field_iter(
self, "val", name, msg_type, oneof_field
):
self.write('return Some("%s");' % oneof_field.name)
self.write("return None;")
with block(self, "_ =>"):
self.write('panic!("unknown oneof name given");')

with block(self, "impl ::pb_jelly::MessageDescriptor for " + name):
self.write('const NAME: &\'static str = "%s";' % name)
self.write('const FULL_NAME: &\'static str = "%s";' % full_name)
with block(
self,
"fn get_field_mut(&mut self, field_name: &str) -> ::pb_jelly::reflection::FieldMut<'_>",
):
with block(self, "match field_name"):
for field in msg_type.field:
typ = self.rust_type(msg_type, field)
with block(self, '"%s" =>' % field.name):
if typ.oneof:
with block(self, "match self.%s" % typ.oneof.name):
self.write("%s => ()," % typ.oneof_val(name, "_"))
with block(self, "_ =>", start=" {", end="},"):
# If this oneof is not currently set to this variant, we explicitly
# set it to this variant.
self.write(
"self.%s = %s;"
% (
typ.oneof.name,
typ.oneof_val(
name,
"::std::default::Default::default()",
),
)
)
if typ.is_empty_oneof_field():
self.write(
"return ::pb_jelly::reflection::FieldMut::Empty;"
)
else:
with block(
self,
"if let %s = self.%s"
% (
typ.oneof_val(name, "ref mut val"),
typ.oneof.name,
),
):
self.write(
"return ::pb_jelly::reflection::FieldMut::Value(val);"
)
self.write("unreachable!()")
elif typ.is_repeated():
# TODO: Would be nice to support this, but some more thought would
# need to be put into what the API for it looks like.
# self.write("return ::pb_jelly::reflection::FieldMut::Repeated(&mut self.%s);" % field.name)
self.write(
'unimplemented!("Repeated fields are not currently supported.");'
)
elif typ.is_nullable() and typ.is_boxed():
self.write(
"return ::pb_jelly::reflection::FieldMut::Value(self.%s.get_or_insert_with(::std::default::Default::default).as_mut());"
% field.name
)
elif typ.is_boxed():
self.write(
"return ::pb_jelly::reflection::FieldMut::Value(self.%s.as_mut());"
% field.name
)
elif typ.is_nullable():
self.write(
"return ::pb_jelly::reflection::FieldMut::Value(self.%s.get_or_insert_with(::std::default::Default::default));"
% field.name
)
else:
self.write(
"return ::pb_jelly::reflection::FieldMut::Value(&mut self.%s);"
% field.name
)
with block(self, "_ =>"):
self.write('panic!("unknown field name given");')


def walk(proto: FileDescriptorProto) -> WalkRet:
Expand Down Expand Up @@ -1344,11 +1481,7 @@ def _walk(

class ProtoType(Generic[M]):
def __init__(
self,
ctx: "Context",
proto_file: FileDescriptorProto,
path: List[Text],
typ: M,
self, ctx: "Context", proto_file: FileDescriptorProto, path: List[Text], typ: M
) -> None:
self.ctx = ctx
self.proto_file = proto_file
Expand Down Expand Up @@ -1823,7 +1956,6 @@ def add_mod(writer: CodeWriter) -> None:
for path, msg_typ, scl in messages:
ctx.set_boxed_if_recursive(ProtoType(ctx, proto_file, path, msg_typ))
writer.gen_msg(path, msg_typ, scl)
writer.gen_msg_descriptor(path, msg_typ, proto_file.package, scl)
writer.write("")

add_mod(writer=writer)
Expand Down

0 comments on commit dd5336e

Please sign in to comment.