Skip to content

Commit

Permalink
Add Client::{commit,rollback}, Transaction::{create,delete}
Browse files Browse the repository at this point in the history
  • Loading branch information
bouzuya committed Nov 4, 2023
1 parent 5f2c5f9 commit 9d1738a
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 10 deletions.
3 changes: 3 additions & 0 deletions firebase/firestore.rules
Expand Up @@ -7,5 +7,8 @@ service cloud.firestore {
match /users/{user_id} {
allow read, write: if true;
}
match /transactions/{transaction_id} {
allow read, write: if true;
}
}
}
44 changes: 41 additions & 3 deletions rust/crates/web/src/infra/firestore.rs
Expand Up @@ -86,10 +86,48 @@ mod tests {
endpoint,
)
.await?;
// let collection_path = client.collection("repositories".to_string());
let collection_path = client.collection("transactions".to_string());

let transaction = client.begin_transaction().await?;
client.commit(transaction, vec![]).await?;
// reset
let (documents, _) = client.list::<V>(&collection_path).await?;
for doc in documents {
client.delete(doc.name(), doc.update_time()).await?;
}

let document_path = collection_path.doc("1".to_string());

#[derive(Clone, Debug, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
struct V {
k1: String,
}

let input = V {
k1: "v1".to_string(),
};
let mut transaction = client.begin_transaction().await?;
transaction.create(&document_path, input)?;
client.commit(transaction).await?;

let got = client.get::<V>(&document_path).await?;
let current_update_time = got.update_time();

let mut transaction = client.begin_transaction().await?;
transaction.delete(&document_path, current_update_time)?;
client.rollback(transaction).await?;

let got = client.get::<V>(&document_path).await?;
let current_update_time = got.update_time();

let mut transaction = client.begin_transaction().await?;
transaction.delete(&document_path, current_update_time)?;
client.commit(transaction).await?;

let err = client.get::<V>(&document_path).await.unwrap_err();
if let crate::infra::firestore::client::Error::Status(status) = err {
assert_eq!(status.code(), tonic::Code::NotFound);
} else {
panic!("unexpected error: {:?}", err);
}

Ok(())
}
Expand Down
79 changes: 72 additions & 7 deletions rust/crates/web/src/infra/firestore/client.rs
@@ -1,9 +1,12 @@
use google_api_proto::google::firestore::v1::{
firestore_client::FirestoreClient, precondition::ConditionType, value::ValueType,
firestore_client::FirestoreClient,
precondition::ConditionType,
value::ValueType,
write::{self, Operation},
BeginTransactionRequest, BeginTransactionResponse, CommitRequest, CommitResponse,
CreateDocumentRequest, DeleteDocumentRequest, Document as FirestoreDocument,
GetDocumentRequest, ListDocumentsRequest, ListDocumentsResponse, MapValue, Precondition,
UpdateDocumentRequest, Write,
RollbackRequest, UpdateDocumentRequest, Write,
};
use google_authz::{Credentials, GoogleAuthz};
use serde::{de::DeserializeOwned, Serialize};
Expand All @@ -18,7 +21,57 @@ use super::{
timestamp::Timestamp,
};

pub struct Transaction(prost::bytes::Bytes);
pub struct Transaction {
transaction: prost::bytes::Bytes,
writes: Vec<Write>,
}

impl Transaction {
pub fn create<T>(&mut self, document_path: &DocumentPath, fields: T) -> Result<(), Error>
where
T: Serialize,
{
self.writes.push(Write {
operation: Some(Operation::Update(FirestoreDocument {
name: document_path.path(),
fields: {
let ser = to_value(&fields)?;
if let Some(ValueType::MapValue(MapValue { fields })) = ser.value_type {
fields
} else {
return Err(Error::ValueType);
}
},
create_time: None,
update_time: None,
})),
update_mask: None,
update_transforms: vec![],
current_document: Some(Precondition {
condition_type: Some(ConditionType::Exists(false)),
}),
});
Ok(())
}

pub fn delete(
&mut self,
document_path: &DocumentPath,
current_update_time: Timestamp,
) -> Result<(), Error> {
self.writes.push(Write {
operation: Some(Operation::Delete(document_path.path())),
update_mask: None,
update_transforms: vec![],
current_document: Some(Precondition {
condition_type: Some(ConditionType::UpdateTime(prost_types::Timestamp::from(
current_update_time,
))),
}),
});
Ok(())
}
}

#[derive(Debug, thiserror::Error)]
pub enum Error {
Expand Down Expand Up @@ -72,7 +125,10 @@ impl Client {
})
.await?;
let BeginTransactionResponse { transaction } = response.into_inner();
Ok(Transaction(transaction))
Ok(Transaction {
transaction,
writes: vec![],
})
}

pub fn collection(&self, collection_id: String) -> CollectionPath {
Expand All @@ -82,14 +138,13 @@ impl Client {
pub async fn commit(
&mut self,
transaction: Transaction,
writes: Vec<Write>,
) -> Result<((), Option<Timestamp>), Error> {
let response = self
.client
.commit(CommitRequest {
database: self.root_path.database_name(),
writes,
transaction: transaction.0,
writes: transaction.writes,
transaction: transaction.transaction,
})
.await?;
// TODO: write_results
Expand Down Expand Up @@ -197,6 +252,16 @@ impl Client {
.map(|documents| (documents, next_page_token))
}

pub async fn rollback(&mut self, transaction: Transaction) -> Result<(), Error> {
self.client
.rollback(RollbackRequest {
database: self.root_path.database_name(),
transaction: transaction.transaction,
})
.await?;
Ok(())
}

pub async fn update<T, U>(
&mut self,
document_path: &DocumentPath,
Expand Down

0 comments on commit 9d1738a

Please sign in to comment.