Why aggregation methods return Cursor<Document> instead of Cursor<T>? #1098

Closed
clarkmcc opened this issue May 10, 2024 · 5 comments

@clarkmcc

I'm fighting an issue right now: my find queries return Cursor<T>, but my aggregate queries return Cursor<Document>. When I map the stream to T, my trait implementations bounded by S: Stream<Item = Result<T, E>> no longer work for the mapped cursor.

That's a long way of saying the aggregation cursor user experience feels a bit worse than the other methods. This issue was also reported here: https://www.mongodb.com/community/forums/t/get-specific-data-type-from-aggregation-instead-of-document/188241 but I never saw any response.

I assume there's a reason behind aggregation cursors not returning Cursor<T> but I'm not sure what it is, because with the following quick and dirty patch, I was able to run the test suite and get aggregates to return Cursor<T>. Is there some reason I'm missing why this can't be officially supported?

Index: src/action/aggregate.rs
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/action/aggregate.rs b/src/action/aggregate.rs
--- a/src/action/aggregate.rs	(revision 241fe3ddbdcb68409315ffb7dd2db151dbae13f4)
+++ b/src/action/aggregate.rs	(date 1715373649510)
@@ -1,3 +1,4 @@
+use std::marker::PhantomData;
 use std::time::Duration;
 
 use bson::Document;
@@ -27,12 +28,13 @@
     /// `await` will return d[`Result<Cursor<Document>>`] or d[`Result<SessionCursor<Document>>`] if
     /// a `ClientSession` is provided.
     #[deeplink]
-    pub fn aggregate(&self, pipeline: impl IntoIterator<Item = Document>) -> Aggregate {
+    pub fn aggregate<T>(&self, pipeline: impl IntoIterator<Item = Document>) -> Aggregate<ImplicitSession, T> {
         Aggregate {
             target: AggregateTargetRef::Database(self),
             pipeline: pipeline.into_iter().collect(),
             options: None,
             session: ImplicitSession,
+            _phantom: PhantomData,
         }
     }
 }
@@ -49,12 +51,13 @@
     /// `await` will return d[`Result<Cursor<Document>>`] or d[`Result<SessionCursor<Document>>`] if
     /// a [`ClientSession`] is provided.
     #[deeplink]
-    pub fn aggregate(&self, pipeline: impl IntoIterator<Item = Document>) -> Aggregate {
+    pub fn aggregate(&self, pipeline: impl IntoIterator<Item = Document>) -> Aggregate<ImplicitSession, T> {
         Aggregate {
             target: AggregateTargetRef::Collection(CollRef::new(self)),
             pipeline: pipeline.into_iter().collect(),
             options: None,
             session: ImplicitSession,
+            _phantom: PhantomData,
         }
     }
 }
@@ -95,11 +98,12 @@
 /// Run an aggregation operation.  Construct with [`Database::aggregate`] or
 /// [`Collection::aggregate`].
 #[must_use]
-pub struct Aggregate<'a, Session = ImplicitSession> {
+pub struct Aggregate<'a, Session = ImplicitSession, T = Document> {
     target: AggregateTargetRef<'a>,
     pipeline: Vec<Document>,
     options: Option<AggregateOptions>,
     session: Session,
+    _phantom: PhantomData<T>,
 }
 
 impl<'a, Session> Aggregate<'a, Session> {
@@ -119,7 +123,7 @@
     );
 }
 
-impl<'a> Aggregate<'a, ImplicitSession> {
+impl<'a, T> Aggregate<'a, ImplicitSession, T> {
     /// Use the provided session when running the operation.
     pub fn session(
         self,
@@ -130,15 +134,16 @@
             pipeline: self.pipeline,
             options: self.options,
             session: ExplicitSession(value.into()),
+            _phantom: PhantomData,
         }
     }
 }
 
 #[action_impl(sync = crate::sync::Cursor<Document>)]
-impl<'a> Action for Aggregate<'a, ImplicitSession> {
+impl<'a, T> Action for Aggregate<'a, ImplicitSession, T> {
     type Future = AggregateFuture;
 
-    async fn execute(mut self) -> Result<Cursor<Document>> {
+    async fn execute(mut self) -> Result<Cursor<T>> {
         resolve_options!(
             self.target,
             self.options,
@@ -156,10 +161,10 @@
 }
 
 #[action_impl(sync = crate::sync::SessionCursor<Document>)]
-impl<'a> Action for Aggregate<'a, ExplicitSession<'a>> {
+impl<'a, T> Action for Aggregate<'a, ExplicitSession<'a>, T> {
     type Future = AggregateSessionFuture;
 
-    async fn execute(mut self) -> Result<SessionCursor<Document>> {
+    async fn execute(mut self) -> Result<SessionCursor<T>> {
         resolve_read_concern_with_session!(self.target, self.options, Some(&mut *self.session.0))?;
         resolve_write_concern_with_session!(self.target, self.options, Some(&mut *self.session.0))?;
         resolve_selection_criteria_with_session!(
Index: src/test/util.rs
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/test/util.rs b/src/test/util.rs
--- a/src/test/util.rs	(revision 241fe3ddbdcb68409315ffb7dd2db151dbae13f4)
+++ b/src/test/util.rs	(date 1715375273252)
@@ -282,6 +282,22 @@
         self.get_coll(db_name, coll_name)
     }
 
+    pub(crate) async fn create_fresh_typed<T: Send + Sync>(
+        &self,
+        db_name: &str,
+        coll_name: &str,
+        options: impl Into<Option<CreateCollectionOptions>>,
+    ) -> Collection<T> {
+        self.drop_collection(db_name, coll_name).await;
+        self.database(db_name)
+            .create_collection(coll_name)
+            .with_options(options)
+            .await
+            .unwrap();
+
+        self.database(db_name).collection(coll_name)
+    }
+
     pub(crate) fn supports_fail_command(&self) -> bool {
         let version = if self.is_sharded() {
             ">= 4.1.5"
Index: src/test/db.rs
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/test/db.rs b/src/test/db.rs
--- a/src/test/db.rs	(revision 241fe3ddbdcb68409315ffb7dd2db151dbae13f4)
+++ b/src/test/db.rs	(date 1715375343608)
@@ -1,6 +1,9 @@
+use std::borrow::Borrow;
 use std::cmp::Ord;
 
 use futures::stream::TryStreamExt;
+use futures_util::StreamExt;
+use serde::{Deserialize, Serialize};
 
 use crate::{
     action::Action,
@@ -217,6 +220,23 @@
     assert!(coll3.id_index.is_none());
 }
 
+#[tokio::test]
+async fn db_aggregate_2() {
+    let client = TestClient::new().await;
+
+    #[derive(Clone, Serialize, Deserialize, PartialEq, Debug)]
+    struct Record {
+        name: String
+    }
+
+
+    let col = client.create_fresh_typed::<Record>("default", "default", None).await;
+    col.insert_one(Record { name: "a".to_string() }).await.unwrap();
+    let cursor = col.aggregate(vec![doc!{"$match": {}}]).await.unwrap();
+    let docs = cursor.try_collect::<Vec<_>>().await.unwrap();
+    println!("{:?}", docs)
+}
+
 #[tokio::test]
 async fn db_aggregate() {
     let client = TestClient::new().await;
@@ -254,7 +274,7 @@
         },
     ];
 
-    db.aggregate(pipeline)
+    db.aggregate::<Document>(pipeline)
         .await
         .expect("aggregate should succeed");
 }
@clarkmcc clarkmcc changed the title Why Aggregate<Document> instead of Aggregate<T>? Why aggregation cursors return Cursor<Document> instead of Cursor<T>? May 10, 2024
@clarkmcc clarkmcc changed the title Why aggregation cursors return Cursor<Document> instead of Cursor<T>? Why aggregation methods return Cursor<Document> instead of Cursor<T>? May 10, 2024
@isabelatkinson
Contributor

isabelatkinson commented May 10, 2024

Hey, thanks for opening this issue! This is a great idea.

For some historical context, when we originally added generics to the return types from collection methods, all of the Ts were the same as the collection's T. However, an aggregation pipeline can change the shape of the data being returned, which could cause deserialization errors when attempting to iterate a cursor over the collection's T. Rust does not support default generics for method signatures, so something like this wouldn't work:

impl<T> Collection<T> {
    pub async fn aggregate<U>(...) -> Result<Cursor<U = Document>> { ... }
}

So we opted to stick with Cursor<Document> to avoid requiring type labels for every call to aggregate.
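
To illustrate the limitation described above, here is a minimal std-only sketch. `Document`, `Cursor`, `Aggregate`, `run`, and `item_type_name` are simplified, hypothetical stand-ins, not the driver's real types. It shows that a defaulted type parameter is rejected on a function but accepted on a type definition, which is why a builder struct can carry the default:

```rust
use std::any::type_name;
use std::marker::PhantomData;

/// Hypothetical stand-in for `bson::Document`.
#[allow(dead_code)]
struct Document;

/// Hypothetical stand-in for the driver's cursor type.
struct Cursor<T> {
    _item: PhantomData<T>,
}

// Rejected by the compiler: defaults for type parameters are only
// allowed on type definitions, not on functions or methods.
//
// fn aggregate<U = Document>() -> Cursor<U> { unimplemented!() }

/// A struct, by contrast, may declare a default type parameter.
struct Aggregate<T = Document> {
    _item: PhantomData<T>,
}

impl<T> Aggregate<T> {
    /// Stand-in for executing the operation.
    fn run(self) -> Cursor<T> {
        Cursor { _item: PhantomData }
    }
}

/// Writing the return type as plain `Aggregate` selects the default,
/// so callers get `Cursor<Document>` without any annotation.
fn aggregate() -> Aggregate {
    Aggregate { _item: PhantomData }
}

/// Helper for inspecting which item type a cursor carries.
fn item_type_name<T>(_cursor: &Cursor<T>) -> &'static str {
    type_name::<T>()
}
```

Calling `aggregate().run()` here yields a `Cursor<Document>` with no type label at the call site.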

However, that problem goes away with the new fluent-style API that will be released soon as shown by your diff. Would you like to submit a PR adding these changes? I can also take this on as it looks like a very simple change, but happy to let you take credit in the git history if desired :)

Edit: misread your code example a bit, I'm going to play around with this to figure out the exact API we want!

@clarkmcc
Author

Hey @isabelatkinson, this is great news! My example definitely doesn't have the right API; more than anything, it was just to illustrate that there weren't glaring type system issues with something like this. That said, I hear what you're saying, and I agree that tying Aggregate<T> to Collection<T> does not make sense in all cases. In my case it would be fine, but there are certainly times when that is not true.

What would be nice is the serde semantics, where the aggregate method accepts its own type parameter that implements the required traits for deserialization. Sure, users would always have to specify the type they want to deserialize to, but from my more limited perspective, that developer experience is far more tolerable than dealing with Document.

let cursor: Cursor<Foobar> = collection.aggregate(...).await?;
// or
let cursor = collection.aggregate::<Foobar>(...).await?;

@isabelatkinson
Contributor

Hey, I just put up a draft PR to add a with_type method to the Aggregate action. The basic usage would be:

#[derive(Deserialize)]
struct PipelineOutput {
    len: usize,
}

// returns Cursor<PipelineOutput>
let aggregate_cursor = collection
    .aggregate(pipeline)
    .with_type::<PipelineOutput>()
    .await?;

This change adds the functionality you're requesting without breaking any existing code that depends on a Cursor<Document> being returned. Let me know if that works for you!
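
As a rough sketch of how a `with_type`-style method can change the output type without affecting existing callers, the std-only example below re-tags a builder through a `PhantomData` marker. The types and the `output_type` helper are hypothetical simplifications, and the pipeline is modeled as plain strings; this is not the code from the PR:

```rust
use std::any::type_name;
use std::marker::PhantomData;

/// Hypothetical stand-in for `bson::Document`.
#[allow(dead_code)]
struct Document;

/// Mirrors the output struct from the comment above (no serde derive here).
#[allow(dead_code)]
struct PipelineOutput {
    len: usize,
}

/// Builder whose output type defaults to `Document`.
struct Aggregate<T = Document> {
    pipeline: Vec<String>,
    _output: PhantomData<T>,
}

impl<T> Aggregate<T> {
    /// Re-tag the builder with a new output type, keeping its settings.
    fn with_type<U>(self) -> Aggregate<U> {
        Aggregate {
            pipeline: self.pipeline,
            _output: PhantomData,
        }
    }

    /// Helper for inspecting the currently selected output type.
    fn output_type(&self) -> &'static str {
        type_name::<T>()
    }
}

/// Callers that never call `with_type` keep getting the `Document` default.
fn aggregate(pipeline: Vec<String>) -> Aggregate {
    Aggregate {
        pipeline,
        _output: PhantomData,
    }
}
```

Because the type switch is a separate, opt-in step on the builder, code that already relies on the default output type continues to compile unchanged.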

@clarkmcc
Author

@isabelatkinson this is a great usability improvement. Works for me!

@isabelatkinson
Contributor

Just merged #1100, going to close this out
