Skip to content

Commit

Permalink
Merge branch 'feat/statically-typed-pipeline' into rust_sdk
Browse files Browse the repository at this point in the history
  • Loading branch information
laysakura committed Apr 4, 2023
2 parents d1b468a + d8071cf commit 87b9939
Show file tree
Hide file tree
Showing 20 changed files with 383 additions and 319 deletions.
125 changes: 125 additions & 0 deletions sdks/rust/src/coders/coder_resolver/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
use std::{fmt, marker::PhantomData};

use integer_encoding::VarInt;

use crate::{
coders::{
required_coders::{BytesCoder, Iterable, IterableCoder, KVCoder, KV},
standard_coders::{StrUtf8Coder, VarIntCoder},
CoderI,
},
elem_types::ElemType,
};

/// Resolve a coder (implementing `CoderI) from a coder URN and an `ElemType`.
///
/// You may use original coders by implementing the `CoderResolver` trait.
pub trait CoderResolver {
type E: ElemType;
type C: CoderI<E = Self::E>;

/// Resolve a coder from a coder URN.
///
/// # Returns
///
/// `Some(C)` if the coder was resolved, `None` otherwise.
fn resolve(coder_urn: &str) -> Option<Self::C> {
(coder_urn == Self::C::get_coder_urn()).then_some(Self::C::default())
}
}

/// `Vec<u8>` -> `BytesCoder`.
#[derive(Debug)]
pub struct BytesCoderResolverDefault;

impl CoderResolver for BytesCoderResolverDefault {
type E = Vec<u8>;
type C = BytesCoder;
}

/// `KV` -> `KVCoder`.
#[derive(Debug)]
pub struct KVCoderResolverDefault<K, V> {
phantom: PhantomData<KV<K, V>>,
}

impl<K, V> CoderResolver for KVCoderResolverDefault<K, V>
where
K: Clone + fmt::Debug + Send + Sync + 'static,
V: Clone + fmt::Debug + Send + Sync + 'static,
{
type E = KV<K, V>;
type C = KVCoder<Self::E>;
}

/// `Iterable` -> `IterableCoder`.
#[derive(Debug)]
pub struct IterableCoderResolverDefault<E>
where
E: ElemType,
{
phantom: PhantomData<E>,
}

impl<E> CoderResolver for IterableCoderResolverDefault<E>
where
E: ElemType + fmt::Debug,
{
type E = Iterable<E>;
type C = IterableCoder<Self::E>;
}

/// `String` -> `StrUtf8Coder`.
#[derive(Debug)]
pub struct StrUtf8CoderResolverDefault;

impl CoderResolver for StrUtf8CoderResolverDefault {
type E = String;
type C = StrUtf8Coder;
}

#[derive(Debug)]
pub struct VarIntCoderResolverDefault<N: fmt::Debug + VarInt> {
phantom: PhantomData<N>,
}

impl CoderResolver for VarIntCoderResolverDefault<i8> {
type E = i8;
type C = VarIntCoder<i8>;
}
impl CoderResolver for VarIntCoderResolverDefault<i16> {
type E = i16;
type C = VarIntCoder<i16>;
}
impl CoderResolver for VarIntCoderResolverDefault<i32> {
type E = i32;
type C = VarIntCoder<i32>;
}
impl CoderResolver for VarIntCoderResolverDefault<i64> {
type E = i64;
type C = VarIntCoder<i64>;
}
impl CoderResolver for VarIntCoderResolverDefault<isize> {
type E = isize;
type C = VarIntCoder<isize>;
}
impl CoderResolver for VarIntCoderResolverDefault<u8> {
type E = u8;
type C = VarIntCoder<u8>;
}
impl CoderResolver for VarIntCoderResolverDefault<u16> {
type E = u16;
type C = VarIntCoder<u16>;
}
impl CoderResolver for VarIntCoderResolverDefault<u32> {
type E = u32;
type C = VarIntCoder<u32>;
}
impl CoderResolver for VarIntCoderResolverDefault<u64> {
type E = u64;
type C = VarIntCoder<u64>;
}
impl CoderResolver for VarIntCoderResolverDefault<usize> {
type E = usize;
type C = VarIntCoder<usize>;
}
77 changes: 11 additions & 66 deletions sdks/rust/src/coders/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,90 +16,35 @@
* limitations under the License.
*/

pub mod coder_resolver;
pub mod required_coders;
pub mod rust_coders;
pub mod standard_coders;
pub mod urns;

use std::collections::HashMap;
use std::fmt;
use std::io::{self, Read, Write};

use crate::coders::urns::*;

pub struct CoderRegistry {
internal_registry: HashMap<&'static str, CoderTypeDiscriminants>,
}

impl CoderRegistry {
pub fn new() -> Self {
let internal_registry: HashMap<&'static str, CoderTypeDiscriminants> = HashMap::from([
(BYTES_CODER_URN, CoderTypeDiscriminants::Bytes),
(
GENERAL_OBJECT_CODER_URN,
CoderTypeDiscriminants::GeneralObject,
),
(KV_CODER_URN, CoderTypeDiscriminants::KV),
(ITERABLE_CODER_URN, CoderTypeDiscriminants::Iterable),
(STR_UTF8_CODER_URN, CoderTypeDiscriminants::StrUtf8),
(VARINT_CODER_URN, CoderTypeDiscriminants::VarIntCoder),
]);

Self { internal_registry }
}

pub fn get_coder_type(&self, urn: &str) -> &CoderTypeDiscriminants {
let coder_type = self
.internal_registry
.get(urn)
.unwrap_or_else(|| panic!("No coder type registered for URN {urn}"));

coder_type
}

pub fn register(&mut self, urn: &'static str, coder_type: CoderTypeDiscriminants) {
self.internal_registry.insert(urn, coder_type);
}
}

impl Default for CoderRegistry {
fn default() -> Self {
Self::new()
}
}

#[derive(Clone, EnumDiscriminants)]
pub enum CoderType {
// ******* Required coders *******
Bytes,
Iterable,
KV,

// ******* Rust coders *******
GeneralObject,

// ******* Standard coders *******
StrUtf8,
VarIntCoder,
}

// TODO: create and use separate AnyCoder trait instead of Any
// ...

/// This is the base interface for coders, which are responsible in Apache Beam to encode and decode
/// elements of a PCollection.
pub trait CoderI<T> {
fn get_coder_type(&self) -> &CoderTypeDiscriminants;
pub trait CoderI: fmt::Debug + Default {
/// The type of the elements to be encoded/decoded.
type E;

fn get_coder_urn() -> &'static str
where
Self: Sized;

/// Encode an element into a stream of bytes
fn encode(
&self,
element: T,
element: Self::E,
writer: &mut dyn Write,
context: &Context,
) -> Result<usize, io::Error>;

/// Decode an element from an incoming stream of bytes
fn decode(&self, reader: &mut dyn Read, context: &Context) -> Result<T, io::Error>;
fn decode(&self, reader: &mut dyn Read, context: &Context) -> Result<Self::E, io::Error>;
}

/// The context for encoding a PCollection element.
Expand Down

0 comments on commit 87b9939

Please sign in to comment.