diff --git a/firebase/firestore.rules b/firebase/firestore.rules index 9a8a22b..badc9ce 100644 --- a/firebase/firestore.rules +++ b/firebase/firestore.rules @@ -7,5 +7,8 @@ service cloud.firestore { match /users/{user_id} { allow read, write: if true; } + match /transactions/{transaction_id} { + allow read, write: if true; + } } } diff --git a/rust/crates/web/src/infra/firestore.rs b/rust/crates/web/src/infra/firestore.rs index 7839765..6621b48 100644 --- a/rust/crates/web/src/infra/firestore.rs +++ b/rust/crates/web/src/infra/firestore.rs @@ -86,10 +86,48 @@ mod tests { endpoint, ) .await?; - // let collection_path = client.collection("repositories".to_string()); + let collection_path = client.collection("transactions".to_string()); - let transaction = client.begin_transaction().await?; - client.commit(transaction, vec![]).await?; + // reset + let (documents, _) = client.list::(&collection_path).await?; + for doc in documents { + client.delete(doc.name(), doc.update_time()).await?; + } + + let document_path = collection_path.doc("1".to_string()); + + #[derive(Clone, Debug, Eq, PartialEq, serde::Deserialize, serde::Serialize)] + struct V { + k1: String, + } + + let input = V { + k1: "v1".to_string(), + }; + let mut transaction = client.begin_transaction().await?; + transaction.create(&document_path, input)?; + client.commit(transaction).await?; + + let got = client.get::(&document_path).await?; + let current_update_time = got.update_time(); + + let mut transaction = client.begin_transaction().await?; + transaction.delete(&document_path, current_update_time)?; + client.rollback(transaction).await?; + + let got = client.get::(&document_path).await?; + let current_update_time = got.update_time(); + + let mut transaction = client.begin_transaction().await?; + transaction.delete(&document_path, current_update_time)?; + client.commit(transaction).await?; + + let err = client.get::(&document_path).await.unwrap_err(); + if let crate::infra::firestore::client::Error::Status(status) = err { + assert_eq!(status.code(), tonic::Code::NotFound); + } else { + panic!("unexpected error: {:?}", err); + } Ok(()) } diff --git a/rust/crates/web/src/infra/firestore/client.rs b/rust/crates/web/src/infra/firestore/client.rs index 68ded31..15f188b 100644 --- a/rust/crates/web/src/infra/firestore/client.rs +++ b/rust/crates/web/src/infra/firestore/client.rs @@ -1,9 +1,12 @@ use google_api_proto::google::firestore::v1::{ - firestore_client::FirestoreClient, precondition::ConditionType, value::ValueType, + firestore_client::FirestoreClient, + precondition::ConditionType, + value::ValueType, + write::{self, Operation}, BeginTransactionRequest, BeginTransactionResponse, CommitRequest, CommitResponse, CreateDocumentRequest, DeleteDocumentRequest, Document as FirestoreDocument, GetDocumentRequest, ListDocumentsRequest, ListDocumentsResponse, MapValue, Precondition, - UpdateDocumentRequest, Write, + RollbackRequest, UpdateDocumentRequest, Write, }; use google_authz::{Credentials, GoogleAuthz}; use serde::{de::DeserializeOwned, Serialize}; @@ -18,7 +21,57 @@ use super::{ timestamp::Timestamp, }; -pub struct Transaction(prost::bytes::Bytes); +pub struct Transaction { + transaction: prost::bytes::Bytes, + writes: Vec, +} + +impl Transaction { + pub fn create(&mut self, document_path: &DocumentPath, fields: T) -> Result<(), Error> + where + T: Serialize, + { + self.writes.push(Write { + operation: Some(Operation::Update(FirestoreDocument { + name: document_path.path(), + fields: { + let ser = to_value(&fields)?; + if let Some(ValueType::MapValue(MapValue { fields })) = ser.value_type { + fields + } else { + return Err(Error::ValueType); + } + }, + create_time: None, + update_time: None, + })), + update_mask: None, + update_transforms: vec![], + current_document: Some(Precondition { + condition_type: Some(ConditionType::Exists(false)), + }), + }); + Ok(()) + } + + pub fn delete( + &mut self, + document_path: &DocumentPath, + current_update_time: Timestamp, + ) -> Result<(), Error> { + self.writes.push(Write { + operation: Some(Operation::Delete(document_path.path())), + update_mask: None, + update_transforms: vec![], + current_document: Some(Precondition { + condition_type: Some(ConditionType::UpdateTime(prost_types::Timestamp::from( + current_update_time, + ))), + }), + }); + Ok(()) + } +} #[derive(Debug, thiserror::Error)] pub enum Error { @@ -72,7 +125,10 @@ impl Client { }) .await?; let BeginTransactionResponse { transaction } = response.into_inner(); - Ok(Transaction(transaction)) + Ok(Transaction { + transaction, + writes: vec![], + }) } pub fn collection(&self, collection_id: String) -> CollectionPath { @@ -82,14 +138,13 @@ impl Client { pub async fn commit( &mut self, transaction: Transaction, - writes: Vec, ) -> Result<((), Option), Error> { let response = self .client .commit(CommitRequest { database: self.root_path.database_name(), - writes, - transaction: transaction.0, + writes: transaction.writes, + transaction: transaction.transaction, }) .await?; // TODO: write_results @@ -197,6 +252,16 @@ impl Client { .map(|documents| (documents, next_page_token)) } + pub async fn rollback(&mut self, transaction: Transaction) -> Result<(), Error> { + self.client + .rollback(RollbackRequest { + database: self.root_path.database_name(), + transaction: transaction.transaction, + }) + .await?; + Ok(()) + } + pub async fn update( &mut self, document_path: &DocumentPath,