Skip to content

Commit

Permalink
Merge remote-tracking branch 'robertwb/rust-dofn' into rust_sdk
Browse files Browse the repository at this point in the history
  • Loading branch information
laysakura committed Feb 8, 2023
2 parents a85aad7 + eb3894d commit 9564b4e
Show file tree
Hide file tree
Showing 17 changed files with 938 additions and 116 deletions.
1 change: 1 addition & 0 deletions sdks/rust/src/internals/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

pub mod pipeline;
pub mod pvalue;
pub mod serialize;
pub mod urns;

pub mod utils {
Expand Down
53 changes: 43 additions & 10 deletions sdks/rust/src/internals/pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,8 @@ impl<'a> Pipeline {

let flattened = flatten_pvalue(input.clone(), None);
let mut inputs: HashMap<String, String> = HashMap::new();
for (name, pvalue) in flattened {
inputs.insert(name.clone(), pvalue.get_id());
for (name, id) in flattened {
inputs.insert(name.clone(), id);
}

let transform_proto = proto_pipeline::PTransform {
Expand Down Expand Up @@ -249,18 +249,52 @@ impl<'a> Pipeline {
Out: Clone + Send,
F: PTransform<In, Out> + Send,
{
let (transform_id, transform_proto) = self.pre_apply_transform(&transform, &input);
// TODO: Inline pre_apply and post_apply.
// (They exist in typescript only to share code between the sync and
// async variants).
let (transform_id, mut transform_proto) = self.pre_apply_transform(&transform, &input);

let mut transform_stack = self.transform_stack.lock().unwrap();
{
let mut transform_stack = self.transform_stack.lock().unwrap();
transform_stack.push(transform_id.clone());
drop(transform_stack);
}

transform_stack.push(transform_id);
let result = transform.expand_internal(input, pipeline, &mut transform_proto);

let result = transform.expand_internal(input, pipeline, transform_proto.clone());
for (name, id) in flatten_pvalue(result.clone(), None) {
// Causes test to hang...
transform_proto.outputs.insert(name.clone(), id);
}

// TODO: ensure this happens even if an error takes place above
transform_stack.pop();
// Re-insert the transform with its outputs and any mutation that
// expand_internal performed.
let mut pipeline_proto = self.proto.lock().unwrap();
// This may have been mutated.
// TODO: Perhaps only insert at the end?
transform_proto.subtransforms = pipeline_proto
.components
.as_mut()
.unwrap()
.transforms
.get(&transform_id)
.unwrap()
.subtransforms
.clone();
pipeline_proto
.components
.as_mut()
.unwrap()
.transforms
.insert(transform_id.clone(), transform_proto.clone());
drop(pipeline_proto);

drop(transform_stack);
// TODO: ensure this happens even if an error takes place above
{
let mut transform_stack = self.transform_stack.lock().unwrap();
transform_stack.pop();
drop(transform_stack);
}

self.post_apply_transform(transform, transform_proto, result)
}
Expand Down Expand Up @@ -291,7 +325,6 @@ impl<'a> Pipeline {
// TODO: remove pcoll_proto arg
PValue::new(
crate::internals::pvalue::PType::PCollection,
proto_pipeline::PCollection::default(),
pipeline,
self.create_pcollection_id_internal(coder_id),
)
Expand Down
79 changes: 24 additions & 55 deletions sdks/rust/src/internals/pvalue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ use crate::proto::beam_api::pipeline as proto_pipeline;

use crate::internals::pipeline::Pipeline;

// TODO: remove field pcoll_proto.
// T should be never(!) for Root
// https://github.com/rust-lang/rust/issues/35121
#[derive(Clone)]
Expand All @@ -37,7 +36,6 @@ where
{
id: String,
ptype: PType,
pcoll_proto: proto_pipeline::PCollection,
pipeline: Arc<Pipeline>,

phantom: PhantomData<T>,
Expand All @@ -47,61 +45,29 @@ impl<T> PValue<T>
where
T: Clone + Send,
{
pub fn new(
ptype: PType,
pcoll_proto: proto_pipeline::PCollection,
pipeline: Arc<Pipeline>,
id: String,
) -> Self {
pub fn new(ptype: PType, pipeline: Arc<Pipeline>, id: String) -> Self {
Self {
id,
ptype,
pcoll_proto,
pipeline,

phantom: PhantomData::default(),
}
}

pub fn new_root(pipeline: Arc<Pipeline>) -> Self {
let pcoll_name = "root".to_string();

let proto_coder_id = pipeline.register_coder_proto(proto_pipeline::Coder {
spec: Some(proto_pipeline::FunctionSpec {
urn: String::from(crate::coders::urns::BYTES_CODER_URN),
payload: Vec::with_capacity(0),
}),
component_coder_ids: Vec::with_capacity(0),
});

pipeline.register_coder::<BytesCoder, Vec<u8>>(Box::new(BytesCoder::new()));

let output_proto = proto_pipeline::PCollection {
unique_name: pcoll_name.clone(),
coder_id: proto_coder_id,
is_bounded: proto_pipeline::is_bounded::Enum::Bounded as i32,
windowing_strategy_id: "placeholder".to_string(),
display_data: Vec::with_capacity(0),
};

let impulse_proto = proto_pipeline::PTransform {
unique_name: "root".to_string(),
spec: None,
subtransforms: Vec::with_capacity(0),
inputs: HashMap::with_capacity(0),
outputs: HashMap::from([("out".to_string(), pcoll_name)]),
display_data: Vec::with_capacity(0),
environment_id: "".to_string(),
annotations: HashMap::with_capacity(0),
};

pipeline.register_proto_transform(impulse_proto);
PValue::new(PType::Root, pipeline, crate::internals::utils::get_bad_id())
}

pub fn new_array(pcolls: &[PValue<T>]) -> Self {
PValue::new(
PType::Root,
output_proto,
pipeline,
crate::internals::utils::get_bad_id(),
PType::PValueArr,
pcolls[0].clone().pipeline,
pcolls
.iter()
.map(|pcoll| -> String { pcoll.id.clone() })
.collect::<Vec<String>>()
.join(","),
)
}

Expand Down Expand Up @@ -145,29 +111,32 @@ where
// }
}

/// Returns a PValue as a flat object with string keys and PCollection values.
/// Returns a PValue as a flat object with string keys and PCollection id values.
///
/// The full set of PCollections reachable by this PValue will be returned,
/// with keys corresponding roughly to the path taken to get there
pub fn flatten_pvalue<T>(pvalue: PValue<T>, prefix: Option<String>) -> HashMap<String, PValue<T>>
pub fn flatten_pvalue<T>(pvalue: PValue<T>, prefix: Option<String>) -> HashMap<String, String>
where
T: Clone + Send,
{
let mut result: HashMap<String, PValue<T>> = HashMap::new();
let mut result: HashMap<String, String> = HashMap::new();
match pvalue.ptype {
PType::PCollection => match prefix {
Some(pr) => {
result.insert(pr, pvalue);
result.insert(pr, pvalue.get_id());
}
None => {
result.insert("main".to_string(), pvalue);
result.insert("main".to_string(), pvalue.get_id());
}
},
PType::PValueArr => todo!(),
PType::PValueMap => todo!(),
PType::Root => {
result.insert("main".to_string(), pvalue);
PType::PValueArr => {
// TODO: Remove this hack, PValues can have multiple ids.
for (i, id) in pvalue.get_id().split(",").enumerate() {
result.insert(i.to_string(), id.to_string());
}
}
PType::PValueMap => todo!(),
PType::Root => {}
}

result
Expand Down Expand Up @@ -201,7 +170,7 @@ where
&self,
input: &PValue<In>,
pipeline: Arc<Pipeline>,
transform_proto: proto_pipeline::PTransform,
transform_proto: &mut proto_pipeline::PTransform,
) -> PValue<Out>
where
Self: Sized,
Expand Down
124 changes: 124 additions & 0 deletions sdks/rust/src/internals/serialize.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
use std::collections::HashMap;
use std::iter::Iterator;
use std::marker::PhantomData;

use std::any::Any;
use std::boxed::Box;
use std::sync::Mutex;

use once_cell::sync::Lazy;

/// Process-wide registry mapping generated string names to boxed, type-erased
/// values (`Box<dyn Any + Sync + Send>`).
///
/// The map starts empty, is created lazily on first access, and is guarded by
/// a `Mutex` so it can be reached from any thread.
static SERIALIZED_FNS: Lazy<Mutex<HashMap<String, Box<dyn Any + Sync + Send>>>> =
Lazy::new(|| Mutex::new(HashMap::new()));

pub fn serialize_fn<T: Any + Sync + Send>(obj: Box<T>) -> String {
let name = format!("object{}", SERIALIZED_FNS.lock().unwrap().len());
SERIALIZED_FNS.lock().unwrap().insert(name.to_string(), obj);
name
}

/// Looks up the value registered under `name` and returns it as a `&'static T`.
///
/// Returns `None` when no entry with that name exists, or when the stored
/// value is not of type `T`.
///
/// # Panics
///
/// Panics if the registry mutex is poisoned.
pub fn deserialize_fn<T: Any + Sync + Send>(name: &String) -> Option<&'static T> {
    let registry = SERIALIZED_FNS.lock().unwrap();
    let found: Option<&T> = registry
        .get(name)
        .and_then(|boxed| boxed.downcast_ref::<T>());

    // SAFETY: the reference targets the heap allocation owned by the `Box`
    // stored in the map, not the map's own storage, so it is unaffected by the
    // map growing. NOTE(review): extending it to `'static` is only sound while
    // entries are never removed or overwritten for the life of the process;
    // nothing in this file removes entries, but confirm that no other code
    // does and that generated names can never collide.
    unsafe { std::mem::transmute::<Option<&T>, Option<&'static T>>(found) }
}

// ******* DoFn Wrappers, perhaps move elsewhere? *******

// TODO: Give these start/finish_bundles, etc.
/// A type-erased, per-element function.
///
/// It receives one element as `&dyn Any` and returns a boxed iterator whose
/// items are boxed `dyn Any` values. The boxed closure itself must be
/// `Sync + Send`.
pub type GenericDoFn =
Box<dyn Fn(&dyn Any) -> Box<dyn Iterator<Item = Box<dyn Any>>> + Sync + Send>;

/// Owns a single type-erased [`GenericDoFn`].
///
/// No manual `unsafe impl Send` is required: the only field is a
/// `Box<dyn Fn(..) + Sync + Send>`, so the compiler derives both `Send` and
/// `Sync` for this struct automatically. Relying on the auto traits also means
/// the guarantee is re-checked if the field type ever changes.
struct GenericDoFnWrapper {
    func: GenericDoFn,
}

/// Adapter that turns a typed iterator into one yielding `Box<dyn Any>` items.
///
/// `I` is the iterable type being adapted and `O` is its item type.
struct BoxedIter<O: Any, I: IntoIterator<Item = O>> {
    typed_iter: I::IntoIter,
}

impl<O: Any, I: IntoIterator<Item = O>> Iterator for BoxedIter<O, I> {
    type Item = Box<dyn Any>;

    /// Pulls the next typed item, if any, and boxes it as `dyn Any`.
    fn next(&mut self) -> Option<Box<dyn Any>> {
        self.typed_iter
            .next()
            .map(|item| Box::new(item) as Box<dyn Any>)
    }
}

/// Wraps a typed function pointer `fn(&T) -> I` as a type-erased [`GenericDoFn`].
///
/// The returned closure downcasts its `&dyn Any` argument to `&T`, calls
/// `func`, and yields each produced item boxed as `dyn Any`.
///
/// # Panics
///
/// The returned closure panics if it is called with a value that is not a `T`.
pub fn to_generic_dofn<T: Any, O: Any, I: IntoIterator<Item = O> + 'static>(
    func: fn(&T) -> I,
) -> GenericDoFn {
    Box::new(
        move |element: &dyn Any| -> Box<dyn Iterator<Item = Box<dyn Any>>> {
            let typed: &T = element.downcast_ref::<T>().unwrap();
            let typed_iter = func(typed).into_iter();
            Box::new(BoxedIter::<O, I> { typed_iter })
        },
    )
}

/// Wraps a boxed typed closure `Fn(&T) -> I` as a type-erased [`GenericDoFn`].
///
/// The returned closure downcasts its `&dyn Any` argument to `&T`, calls
/// `func`, and yields each produced item boxed as `dyn Any`.
///
/// # Panics
///
/// The returned closure panics if it is called with a value that is not a `T`.
pub fn to_generic_dofn_dyn<T: Any, O: Any, I: IntoIterator<Item = O> + 'static>(
    func: Box<dyn Fn(&T) -> I + Sync + Send>,
) -> GenericDoFn {
    Box::new(
        move |element: &dyn Any| -> Box<dyn Iterator<Item = Box<dyn Any>>> {
            let typed: &T = element.downcast_ref::<T>().unwrap();
            let outputs = func(typed);
            Box::new(BoxedIter::<O, I> {
                typed_iter: outputs.into_iter(),
            })
        },
    )
}

/// Type-erased operations for splitting a keyed element apart and for
/// assembling a result from a key and a collection of values.
///
/// Implementors must be `Sync + Send`.
pub trait KeyExtractor: Sync + Send {
/// Takes a type-erased element and returns a `String` key together with a
/// boxed, type-erased value.
fn extract(&self, kv: &dyn Any) -> (String, Box<dyn Any + Sync + Send>);
/// Takes a key and a vector of boxed, type-erased values and returns a
/// single boxed, type-erased result.
fn recombine(
&self,
key: &String,
values: &Box<Vec<Box<dyn Any + Sync + Send>>>,
) -> Box<dyn Any + Sync + Send>;
}

/// Zero-sized marker carrying the value type `V`.
///
/// It stores no data; `V` is recorded only through `PhantomData`.
pub struct TypedKeyExtractor<V: Clone + Sync + Send + 'static> {
    phantom_data: PhantomData<V>,
}

impl<V: Clone + Sync + Send + 'static> TypedKeyExtractor<V> {
    /// Creates a new extractor for value type `V`.
    ///
    /// Kept as an inherent method so existing
    /// `TypedKeyExtractor::<V>::default()` call sites keep working; it simply
    /// forwards to the [`Default`] implementation.
    pub fn default() -> Self {
        <Self as Default>::default()
    }
}

/// Makes the extractor usable anywhere a `Default` bound is required
/// (generic code, `..Default::default()`, `unwrap_or_default`, ...).
impl<V: Clone + Sync + Send + 'static> Default for TypedKeyExtractor<V> {
    fn default() -> Self {
        Self {
            phantom_data: PhantomData,
        }
    }
}

impl<V: Clone + Sync + Send + 'static> KeyExtractor for TypedKeyExtractor<V> {
    /// Downcasts `kv` to `(String, V)` and returns a clone of the key together
    /// with a boxed clone of the value.
    ///
    /// Panics if `kv` is not a `(String, V)`.
    fn extract(&self, kv: &dyn Any) -> (String, Box<dyn Any + Sync + Send>) {
        let (key, value) = kv.downcast_ref::<(String, V)>().unwrap();
        (key.clone(), Box::new(value.clone()))
    }

    /// Downcasts every entry of `values` to `V`, clones it, and returns a
    /// boxed `(String, Vec<V>)` holding a clone of `key` and the cloned values
    /// in their original order.
    ///
    /// Panics if any entry of `values` is not a `V`.
    fn recombine(
        &self,
        key: &String,
        values: &Box<Vec<Box<dyn Any + Sync + Send>>>,
    ) -> Box<dyn Any + Sync + Send> {
        let typed_values: Vec<V> = values
            .iter()
            .map(|untyped| untyped.downcast_ref::<V>().unwrap().clone())
            .collect();
        Box::new((key.clone(), typed_values))
    }
}
4 changes: 4 additions & 0 deletions sdks/rust/src/internals/urns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ pub const IMPULSE_BUFFER: &[u8] = "impulse".as_bytes();

pub const DATA_INPUT_URN: &str = "beam:runner:source:v1";
pub const DATA_OUTPUT_URN: &str = "beam:runner:sink:v1";
// URNs of the primitive transforms. These follow the `beam:transform:<name>:v1`
// scheme; the previous `beam:beam:` prefix on ParDo, GroupByKey and Flatten
// did not match it.
// NOTE(review): the corrected values are taken from the Beam Runner API's
// standard transform URNs — confirm against `beam_runner_api.proto`.
pub const IMPULSE_URN: &str = "beam:transform:impulse:v1";
pub const PAR_DO_URN: &str = "beam:transform:pardo:v1";
pub const GROUP_BY_KEY_URN: &str = "beam:transform:group_by_key:v1";
pub const FLATTEN_URN: &str = "beam:transform:flatten:v1";
pub const IDENTITY_DOFN_URN: &str = "beam:dofn:identity:0.1";

// TODO: move test urns elsewhere
Expand Down
2 changes: 1 addition & 1 deletion sdks/rust/src/runners/runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ pub trait RunnerI {
where
In: Clone + Send,
Out: Clone + Send,
F: FnOnce(PValue<In>) -> PValue<Out> + Send,
F: FnOnce(PValue<In>) -> PValue<Out> + Send, // TODO: Don't require a return value.
{
self.run_async(pipeline).await;
}
Expand Down

0 comments on commit 9564b4e

Please sign in to comment.