From d1e98452e9e81d6bfe9502276fe0c8a94174b819 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Tue, 18 Jul 2023 10:56:13 +0200 Subject: [PATCH] fully async result builder --- Cargo.lock | 3 + libsqlx-server/Cargo.toml | 3 +- libsqlx-server/src/allocation/mod.rs | 63 +++-- libsqlx-server/src/hrana/batch.rs | 24 +- libsqlx-server/src/hrana/result_builder.rs | 212 +++++++++------- libsqlx-server/src/hrana/stmt.rs | 12 +- libsqlx/Cargo.toml | 5 + libsqlx/src/connection.rs | 90 ++----- libsqlx/src/database/libsql/connection.rs | 41 ++-- libsqlx/src/database/libsql/mod.rs | 24 +- .../database/libsql/replication_log/logger.rs | 8 +- libsqlx/src/database/proxy/connection.rs | 227 +++++++++++------- libsqlx/src/database/proxy/mod.rs | 1 + libsqlx/src/lib.rs | 2 + libsqlx/src/result_builder.rs | 87 +++---- 15 files changed, 442 insertions(+), 360 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2a3cc0f6..d08d5034 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2519,6 +2519,7 @@ dependencies = [ "bytesize", "crc", "crossbeam", + "either", "fallible-iterator 0.3.0", "itertools 0.11.0", "nix", @@ -2533,6 +2534,7 @@ dependencies = [ "sqlite3-parser 0.9.0", "tempfile", "thiserror", + "tokio", "tracing", "uuid", ] @@ -2549,6 +2551,7 @@ dependencies = [ "bytes 1.4.0", "clap", "color-eyre", + "either", "futures", "hmac", "hyper", diff --git a/libsqlx-server/Cargo.toml b/libsqlx-server/Cargo.toml index 86beceda..a5a11437 100644 --- a/libsqlx-server/Cargo.toml +++ b/libsqlx-server/Cargo.toml @@ -14,11 +14,12 @@ bincode = "1.3.3" bytes = { version = "1.4.0", features = ["serde"] } clap = { version = "4.3.11", features = ["derive"] } color-eyre = "0.6.2" +either = "1.8.1" futures = "0.3.28" hmac = "0.12.1" hyper = { version = "0.14.27", features = ["h2", "server"] } itertools = "0.11.0" -libsqlx = { version = "0.1.0", path = "../libsqlx" } +libsqlx = { version = "0.1.0", path = "../libsqlx", features = ["tokio"] } moka = { version = "0.11.2", features = ["future"] } parking_lot = "0.12.1" priority-queue = "1.3.2" diff --git a/libsqlx-server/src/allocation/mod.rs b/libsqlx-server/src/allocation/mod.rs index fdd08a88..7e7e2b9b 100644 --- a/libsqlx-server/src/allocation/mod.rs +++ b/libsqlx-server/src/allocation/mod.rs @@ -1,12 +1,14 @@ -use std::collections::HashMap; use std::collections::hash_map::Entry; +use std::collections::HashMap; use std::path::PathBuf; use std::sync::Arc; use std::time::{Duration, Instant}; use bytes::Bytes; +use either::Either; use libsqlx::libsql::{LibsqlDatabase, LogCompactor, LogFile, PrimaryType, ReplicaType}; -use libsqlx::proxy::WriteProxyDatabase; +use libsqlx::proxy::{WriteProxyConnection, WriteProxyDatabase}; +use libsqlx::result_builder::ResultBuilder; use libsqlx::{ Database as _, DescribeResponse, Frame, FrameNo, InjectableDatabase, Injector, LogReadError, ReplicationLogger, @@ -27,7 +29,11 @@ use self::config::{AllocConfig, DbConfig}; pub mod config; -type ExecFn = Box; +type LibsqlConnection = Either< + libsqlx::libsql::LibsqlConnection, + WriteProxyConnection, DummyConn>, +>; +type ExecFn = Box; #[derive(Clone)] pub struct ConnectionId { @@ -47,10 +53,10 @@ pub struct DummyDb; pub struct DummyConn; impl libsqlx::Connection for DummyConn { - fn execute_program( + fn execute_program( &mut self, - _pgm: libsqlx::program::Program, - _result_builder: &mut dyn libsqlx::result_builder::ResultBuilder, + _pgm: &libsqlx::program::Program, + _result_builder: B, ) -> libsqlx::Result<()> { todo!() } @@ -207,7 +213,12 @@ impl FrameStreamer { if !self.buffer.is_empty() { self.send_frames().await; } - if self.notifier.wait_for(|fno| dbg!(*fno) >= self.next_frame_no).await.is_err() { + if self + .notifier + .wait_for(|fno| *fno >= self.next_frame_no) + .await + .is_err() + { break; } } @@ -244,7 +255,9 @@ impl Database { path, Compactor, false, - Box::new(move |fno| { let _ = sender.send(fno); } ), + Box::new(move |fno| { + let _ = sender.send(fno); + }), ) .unwrap(); @@ -253,7 +266,7 @@ impl Database { replica_streams: HashMap::new(), frame_notifier: receiver, } - }, + } DbConfig::Replica { primary_node_id } => { let rdb = LibsqlDatabase::new_replica(path, MAX_INJECTOR_BUFFER_CAP, ()).unwrap(); let wdb = DummyDb; @@ -285,10 +298,10 @@ impl Database { } } - fn connect(&self) -> Box { + fn connect(&self) -> LibsqlConnection { match self { - Database::Primary { db, .. } => Box::new(db.connect().unwrap()), - Database::Replica { db, .. } => Box::new(db.connect().unwrap()), + Database::Primary { db, .. } => Either::Left(db.connect().unwrap()), + Database::Replica { db, .. } => Either::Right(db.connect().unwrap()), } } } @@ -315,11 +328,11 @@ pub struct ConnectionHandle { impl ConnectionHandle { pub async fn exec(&self, f: F) -> crate::Result where - F: for<'a> FnOnce(&'a mut (dyn libsqlx::Connection + 'a)) -> R + Send + 'static, + F: for<'a> FnOnce(&'a mut LibsqlConnection) -> R + Send + 'static, R: Send + 'static, { let (sender, ret) = oneshot::channel(); - let cb = move |conn: &mut dyn libsqlx::Connection| { + let cb = move |conn: &mut LibsqlConnection| { let res = f(conn); let _ = sender.send(res); }; @@ -371,9 +384,15 @@ impl Allocation { Message::Handshake { .. } => unreachable!("handshake should have been caught earlier"), Message::ReplicationHandshake { .. } => todo!(), Message::ReplicationHandshakeResponse { .. } => todo!(), - Message::Replicate { req_no, next_frame_no } => match &mut self.database { - Database::Primary { db, replica_streams, frame_notifier } => { - dbg!(next_frame_no); + Message::Replicate { + req_no, + next_frame_no, + } => match &mut self.database { + Database::Primary { + db, + replica_streams, + frame_notifier, + } => { let streamer = FrameStreamer { logger: db.logger(), database_id: DatabaseId::from_name(&self.db_name), @@ -396,15 +415,15 @@ impl Allocation { *old_req_no = req_no; old_handle.abort(); } - }, + } Entry::Vacant(e) => { let handle = tokio::spawn(streamer.run()); // For some reason, not yielding causes the task not to be spawned tokio::task::yield_now().await; e.insert((req_no, handle)); - }, + } } - }, + } Database::Replica { .. } => todo!("not a primary!"), }, Message::Frames(frames) => match &mut self.database { @@ -459,7 +478,7 @@ impl Allocation { struct Connection { id: u32, - conn: Box, + conn: LibsqlConnection, exit: oneshot::Receiver<()>, exec: mpsc::Receiver, } @@ -470,7 +489,7 @@ impl Connection { tokio::select! { _ = &mut self.exit => break, Some(exec) = self.exec.recv() => { - tokio::task::block_in_place(|| exec(&mut *self.conn)); + tokio::task::block_in_place(|| exec(&mut self.conn)); } } } diff --git a/libsqlx-server/src/hrana/batch.rs b/libsqlx-server/src/hrana/batch.rs index 1368991e..c4131c45 100644 --- a/libsqlx-server/src/hrana/batch.rs +++ b/libsqlx-server/src/hrana/batch.rs @@ -8,10 +8,12 @@ use super::stmt::{proto_stmt_to_query, stmt_error_from_sqld_error}; use super::{proto, ProtocolError, Version}; use color_eyre::eyre::anyhow; +use libsqlx::Connection; use libsqlx::analysis::Statement; use libsqlx::program::{Cond, Program, Step}; use libsqlx::query::{Params, Query}; use libsqlx::result_builder::{StepResult, StepResultsBuilder}; +use tokio::sync::oneshot; fn proto_cond_to_cond(cond: &proto::BatchCond, max_step_i: usize) -> color_eyre::Result { let try_convert_step = |step: i32| -> Result { @@ -73,15 +75,15 @@ pub async fn execute_batch( db: &ConnectionHandle, pgm: Program, ) -> color_eyre::Result { - let builder = db + let fut = db .exec(move |conn| -> color_eyre::Result<_> { - let mut builder = HranaBatchProtoBuilder::default(); - conn.execute_program(pgm, &mut builder)?; - Ok(builder) + let (builder, ret) = HranaBatchProtoBuilder::new(); + conn.execute_program(&pgm, builder)?; + Ok(ret) }) .await??; - Ok(builder.into_ret()) + Ok(fut.await?) } pub fn proto_sequence_to_program(sql: &str) -> color_eyre::Result { @@ -110,17 +112,17 @@ pub fn proto_sequence_to_program(sql: &str) -> color_eyre::Result { } pub async fn execute_sequence(conn: &ConnectionHandle, pgm: Program) -> color_eyre::Result<()> { - let builder = conn + let fut = conn .exec(move |conn| -> color_eyre::Result<_> { - let mut builder = StepResultsBuilder::default(); - conn.execute_program(pgm, &mut builder)?; + let (snd, rcv) = oneshot::channel(); + let builder = StepResultsBuilder::new(snd); + conn.execute_program(&pgm, builder)?; - Ok(builder) + Ok(rcv) }) .await??; - builder - .into_ret() + fut.await? .into_iter() .try_for_each(|result| match result { StepResult::Ok => Ok(()), diff --git a/libsqlx-server/src/hrana/result_builder.rs b/libsqlx-server/src/hrana/result_builder.rs index b6b8c635..1047f091 100644 --- a/libsqlx-server/src/hrana/result_builder.rs +++ b/libsqlx-server/src/hrana/result_builder.rs @@ -3,76 +3,90 @@ use std::io; use bytes::Bytes; use libsqlx::{result_builder::*, FrameNo}; +use tokio::sync::oneshot; use crate::hrana::stmt::{proto_error_from_stmt_error, stmt_error_from_sqld_error}; use super::proto; -#[derive(Debug, Default)] pub struct SingleStatementBuilder { - has_step: bool, - cols: Vec, - rows: Vec>, - err: Option, - affected_row_count: u64, - last_insert_rowid: Option, - current_size: u64, - max_response_size: u64, + builder: StatementBuilder, + ret: oneshot::Sender>, } impl SingleStatementBuilder { - pub fn into_ret(self) -> Result { - match self.err { - Some(err) => Err(err), - None => Ok(proto::StmtResult { - cols: self.cols, - rows: self.rows, - affected_row_count: self.affected_row_count, - last_insert_rowid: self.last_insert_rowid, - }), - } + pub fn new() -> (Self, oneshot::Receiver>) { + let (ret, rcv) = oneshot::channel(); + (Self { + builder: StatementBuilder::default(), + ret, + }, rcv) } } -struct SizeFormatter(u64); +impl ResultBuilder for SingleStatementBuilder { + fn init(&mut self, config: &QueryBuilderConfig) -> Result<(), QueryResultBuilderError> { + self.builder.init(config) + } -impl io::Write for SizeFormatter { - fn write(&mut self, buf: &[u8]) -> io::Result { - self.0 += buf.len() as u64; - Ok(buf.len()) + fn begin_step(&mut self) -> Result<(), QueryResultBuilderError> { + self.builder.begin_step() } - fn flush(&mut self) -> io::Result<()> { - Ok(()) + fn finish_step( + &mut self, + affected_row_count: u64, + last_insert_rowid: Option, + ) -> Result<(), QueryResultBuilderError> { + self.builder.finish_step(affected_row_count, last_insert_rowid) } -} -impl fmt::Write for SizeFormatter { - fn write_str(&mut self, s: &str) -> fmt::Result { - self.0 += s.len() as u64; - Ok(()) + fn step_error(&mut self, error: libsqlx::error::Error) -> Result<(), QueryResultBuilderError> { + self.builder.step_error(error) } -} -fn value_json_size(v: &ValueRef) -> u64 { - let mut f = SizeFormatter(0); - match v { - ValueRef::Null => write!(&mut f, r#"{{"type":"null"}}"#).unwrap(), - ValueRef::Integer(i) => write!(&mut f, r#"{{"type":"integer", "value": "{i}"}}"#).unwrap(), - ValueRef::Real(x) => write!(&mut f, r#"{{"type":"integer","value": {x}"}}"#).unwrap(), - ValueRef::Text(s) => { - // error will be caught later. - if let Ok(s) = std::str::from_utf8(s) { - write!(&mut f, r#"{{"type":"text","value":"{s}"}}"#).unwrap() - } - } - ValueRef::Blob(b) => return b.len() as u64, + fn cols_description( + &mut self, + cols: &mut dyn Iterator, + ) -> Result<(), QueryResultBuilderError> { + self.builder.cols_description(cols) } - f.0 + fn begin_row(&mut self) -> Result<(), QueryResultBuilderError> { + self.builder.begin_row() + } + + fn add_row_value(&mut self, v: ValueRef) -> Result<(), QueryResultBuilderError> { + self.builder.add_row_value(v) + } + + fn finnalize( + self, + _is_txn: bool, + _frame_no: Option, + ) -> Result + where Self: Sized + { + let res = self.builder.into_ret(); + let _ = self.ret.send(res); + Ok(true) + } } -impl ResultBuilder for SingleStatementBuilder { + +#[derive(Debug, Default)] +struct StatementBuilder { + has_step: bool, + cols: Vec, + rows: Vec>, + err: Option, + affected_row_count: u64, + last_insert_rowid: Option, + current_size: u64, + max_response_size: u64, +} + +impl StatementBuilder { fn init(&mut self, config: &QueryBuilderConfig) -> Result<(), QueryResultBuilderError> { *self = Self { max_response_size: config.max_size.unwrap_or(u64::MAX), @@ -138,12 +152,6 @@ impl ResultBuilder for SingleStatementBuilder { Ok(()) } - fn begin_rows(&mut self) -> Result<(), QueryResultBuilderError> { - assert!(self.err.is_none()); - assert!(self.rows.is_empty()); - Ok(()) - } - fn begin_row(&mut self) -> Result<(), QueryResultBuilderError> { assert!(self.err.is_none()); self.rows.push(Vec::with_capacity(self.cols.len())); @@ -183,25 +191,57 @@ impl ResultBuilder for SingleStatementBuilder { Ok(()) } - fn finish_row(&mut self) -> Result<(), QueryResultBuilderError> { - assert!(self.err.is_none()); - Ok(()) + pub fn into_ret(self) -> Result { + match self.err { + Some(err) => Err(err), + None => Ok(proto::StmtResult { + cols: self.cols, + rows: self.rows, + affected_row_count: self.affected_row_count, + last_insert_rowid: self.last_insert_rowid, + }), + } } +} - fn finish_rows(&mut self) -> Result<(), QueryResultBuilderError> { - assert!(self.err.is_none()); +struct SizeFormatter(u64); + +impl io::Write for SizeFormatter { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.0 += buf.len() as u64; + Ok(buf.len()) + } + + fn flush(&mut self) -> io::Result<()> { Ok(()) } +} - fn finish( - &mut self, - _is_txn: bool, - _frame_no: Option, - ) -> Result<(), QueryResultBuilderError> { +impl fmt::Write for SizeFormatter { + fn write_str(&mut self, s: &str) -> fmt::Result { + self.0 += s.len() as u64; Ok(()) } } +fn value_json_size(v: &ValueRef) -> u64 { + let mut f = SizeFormatter(0); + match v { + ValueRef::Null => write!(&mut f, r#"{{"type":"null"}}"#).unwrap(), + ValueRef::Integer(i) => write!(&mut f, r#"{{"type":"integer", "value": "{i}"}}"#).unwrap(), + ValueRef::Real(x) => write!(&mut f, r#"{{"type":"integer","value": {x}"}}"#).unwrap(), + ValueRef::Text(s) => { + // error will be caught later. + if let Ok(s) = std::str::from_utf8(s) { + write!(&mut f, r#"{{"type":"text","value":"{s}"}}"#).unwrap() + } + } + ValueRef::Blob(b) => return b.len() as u64, + } + + f.0 +} + fn estimate_cols_json_size(c: &Column) -> u64 { let mut f = SizeFormatter(0); write!( @@ -214,17 +254,32 @@ fn estimate_cols_json_size(c: &Column) -> u64 { f.0 } -#[derive(Debug, Default)] +#[derive(Debug)] pub struct HranaBatchProtoBuilder { step_results: Vec>, step_errors: Vec>, - stmt_builder: SingleStatementBuilder, + stmt_builder: StatementBuilder, current_size: u64, max_response_size: u64, step_empty: bool, + ret: oneshot::Sender } impl HranaBatchProtoBuilder { + pub fn new() -> (Self, oneshot::Receiver) { + let (ret, rcv) = oneshot::channel(); + (Self { + step_results: Vec::new(), + step_errors: Vec::new(), + stmt_builder: StatementBuilder::default(), + current_size: 0, + max_response_size: u64::MAX, + step_empty: false, + ret, + }, + rcv) + + } pub fn into_ret(self) -> proto::BatchResult { proto::BatchResult { step_results: self.step_results, @@ -235,10 +290,7 @@ impl HranaBatchProtoBuilder { impl ResultBuilder for HranaBatchProtoBuilder { fn init(&mut self, config: &QueryBuilderConfig) -> Result<(), QueryResultBuilderError> { - *self = Self { - max_response_size: config.max_size.unwrap_or(u64::MAX), - ..Default::default() - }; + self.max_response_size = config.max_size.unwrap_or(u64::MAX); self.stmt_builder.init(config)?; Ok(()) } @@ -257,7 +309,7 @@ impl ResultBuilder for HranaBatchProtoBuilder { .finish_step(affected_row_count, last_insert_rowid)?; self.current_size += self.stmt_builder.current_size; - let new_builder = SingleStatementBuilder { + let new_builder = StatementBuilder { current_size: 0, max_response_size: self.max_response_size - self.current_size, ..Default::default() @@ -290,10 +342,6 @@ impl ResultBuilder for HranaBatchProtoBuilder { self.stmt_builder.cols_description(cols) } - fn begin_rows(&mut self) -> Result<(), QueryResultBuilderError> { - self.stmt_builder.begin_rows() - } - fn begin_row(&mut self) -> Result<(), QueryResultBuilderError> { self.stmt_builder.begin_row() } @@ -301,20 +349,4 @@ impl ResultBuilder for HranaBatchProtoBuilder { fn add_row_value(&mut self, v: ValueRef) -> Result<(), QueryResultBuilderError> { self.stmt_builder.add_row_value(v) } - - fn finish_row(&mut self) -> Result<(), QueryResultBuilderError> { - self.stmt_builder.finish_row() - } - - fn finish_rows(&mut self) -> Result<(), QueryResultBuilderError> { - Ok(()) - } - - fn finish( - &mut self, - _is_txn: bool, - _frame_no: Option, - ) -> Result<(), QueryResultBuilderError> { - Ok(()) - } } diff --git a/libsqlx-server/src/hrana/stmt.rs b/libsqlx-server/src/hrana/stmt.rs index 5453ab5c..e6c002a1 100644 --- a/libsqlx-server/src/hrana/stmt.rs +++ b/libsqlx-server/src/hrana/stmt.rs @@ -3,6 +3,7 @@ use std::collections::HashMap; use color_eyre::eyre::{anyhow, bail}; use libsqlx::analysis::Statement; use libsqlx::query::{Params, Query, Value}; +use libsqlx::Connection; use super::result_builder::SingleStatementBuilder; use super::{proto, ProtocolError, Version}; @@ -47,18 +48,17 @@ pub async fn execute_stmt( conn: &ConnectionHandle, query: Query, ) -> color_eyre::Result { - let builder = conn + let fut = conn .exec(move |conn| -> color_eyre::Result<_> { - let mut builder = SingleStatementBuilder::default(); + let (builder, ret) = SingleStatementBuilder::new(); let pgm = libsqlx::program::Program::from_queries(std::iter::once(query)); - conn.execute_program(pgm, &mut builder)?; + conn.execute_program(&pgm, builder)?; - Ok(builder) + Ok(ret) }) .await??; - builder - .into_ret() + fut.await? .map_err(|sqld_error| match stmt_error_from_sqld_error(sqld_error) { Ok(stmt_error) => anyhow!(stmt_error), Err(sqld_error) => anyhow!(sqld_error), diff --git a/libsqlx/Cargo.toml b/libsqlx/Cargo.toml index abdb39ad..85fd7a9d 100644 --- a/libsqlx/Cargo.toml +++ b/libsqlx/Cargo.toml @@ -27,8 +27,13 @@ crc = "3.0.1" once_cell = "1.18.0" regex = "1.8.4" tempfile = "3.6.0" +either = "1.8.1" +tokio = { version = "1", optional = true, features = ["sync"] } [dev-dependencies] arbitrary = { version = "1.3.0", features = ["derive"] } itertools = "0.11.0" rand = "0.8.5" + +[features] +tokio = ["dep:tokio"] diff --git a/libsqlx/src/connection.rs b/libsqlx/src/connection.rs index e2fd05f8..fa027997 100644 --- a/libsqlx/src/connection.rs +++ b/libsqlx/src/connection.rs @@ -1,8 +1,7 @@ -use rusqlite::types::Value; +use either::Either; -use crate::program::{Program, Step}; -use crate::query::Query; -use crate::result_builder::{QueryBuilderConfig, QueryResultBuilderError, ResultBuilder}; +use crate::program::Program; +use crate::result_builder::ResultBuilder; #[derive(Debug, Clone)] pub struct DescribeResponse { @@ -25,81 +24,36 @@ pub struct DescribeCol { pub trait Connection { /// Executes a query program - fn execute_program( + fn execute_program( &mut self, - pgm: Program, - result_builder: &mut dyn ResultBuilder, + pgm: &Program, + result_builder: B, ) -> crate::Result<()>; /// Parse the SQL statement and return information about it. fn describe(&self, sql: String) -> crate::Result; - - /// execute a single query - fn execute(&mut self, query: Query) -> crate::Result>> { - #[derive(Default)] - struct RowsBuilder { - error: Option, - rows: Vec>, - current_row: Vec, - } - - impl ResultBuilder for RowsBuilder { - fn init( - &mut self, - _config: &QueryBuilderConfig, - ) -> std::result::Result<(), QueryResultBuilderError> { - self.error = None; - self.rows.clear(); - self.current_row.clear(); - - Ok(()) - } - - fn add_row_value( - &mut self, - v: rusqlite::types::ValueRef, - ) -> Result<(), QueryResultBuilderError> { - self.current_row.push(v.into()); - Ok(()) - } - - fn finish_row(&mut self) -> Result<(), QueryResultBuilderError> { - let row = std::mem::take(&mut self.current_row); - self.rows.push(row); - - Ok(()) - } - - fn step_error( - &mut self, - error: crate::error::Error, - ) -> Result<(), QueryResultBuilderError> { - self.error.replace(error); - Ok(()) - } - } - - let pgm = Program::new(vec![Step { cond: None, query }]); - let mut builder = RowsBuilder::default(); - self.execute_program(pgm, &mut builder)?; - if let Some(err) = builder.error.take() { - Err(err) - } else { - Ok(builder.rows) - } - } } -impl Connection for Box { - fn execute_program( +impl Connection for Either +where + T: Connection, + X: Connection, +{ + fn execute_program( &mut self, - pgm: Program, - result_builder: &mut dyn ResultBuilder, + pgm: &Program, + result_builder: B, ) -> crate::Result<()> { - self.as_mut().execute_program(pgm, result_builder) + match self { + Either::Left(c) => c.execute_program(pgm, result_builder), + Either::Right(c) => c.execute_program(pgm, result_builder), + } } fn describe(&self, sql: String) -> crate::Result { - self.as_ref().describe(sql) + match self { + Either::Left(c) => c.describe(sql), + Either::Right(c) => c.describe(sql), + } } } diff --git a/libsqlx/src/database/libsql/connection.rs b/libsqlx/src/database/libsql/connection.rs index cd5bcff3..27ee59e1 100644 --- a/libsqlx/src/database/libsql/connection.rs +++ b/libsqlx/src/database/libsql/connection.rs @@ -14,7 +14,7 @@ use crate::result_builder::{QueryBuilderConfig, ResultBuilder}; use crate::seal::Seal; use crate::Result; -use super::RowStatsHandler; +use super::{LibsqlDbType, RowStatsHandler}; pub struct RowStats { pub rows_read: u64, @@ -49,23 +49,23 @@ where sqld_libsql_bindings::Connection::open(path, flags, wal_methods, hook_ctx) } -pub struct LibsqlConnection { +pub struct LibsqlConnection { timeout_deadline: Option, conn: sqld_libsql_bindings::Connection<'static>, // holds a ref to _context, must be dropped first. row_stats_handler: Option>, builder_config: QueryBuilderConfig, - _context: Seal>, + _context: Seal::Context>>, } -impl LibsqlConnection { - pub(crate) fn new( +impl LibsqlConnection { + pub(crate) fn new( path: &Path, extensions: Option>, - wal_methods: &'static WalMethodsHook, - hook_ctx: W::Context, + wal_methods: &'static WalMethodsHook, + hook_ctx: ::Context, row_stats_callback: Option>, builder_config: QueryBuilderConfig, - ) -> Result> { + ) -> Result> { let mut ctx = Box::new(hook_ctx); let this = LibsqlConnection { conn: open_db( @@ -101,14 +101,14 @@ impl LibsqlConnection { &self.conn } - fn run(&mut self, pgm: Program, builder: &mut dyn ResultBuilder) -> Result<()> { + fn run(&mut self, pgm: &Program, mut builder: B) -> Result<()> { let mut results = Vec::with_capacity(pgm.steps.len()); builder.init(&self.builder_config)?; let is_autocommit_before = self.conn.is_autocommit(); for step in pgm.steps() { - let res = self.execute_step(step, &results, builder)?; + let res = self.execute_step(step, &results, &mut builder)?; results.push(res); } @@ -117,16 +117,19 @@ impl LibsqlConnection { self.timeout_deadline = Some(Instant::now() + TXN_TIMEOUT) } - builder.finish(!self.conn.is_autocommit(), None)?; + let is_txn = !self.conn.is_autocommit(); + if !builder.finnalize(is_txn, None)? && is_txn { + let _ = self.conn.execute("ROLLBACK", ()); + } Ok(()) } - fn execute_step( + fn execute_step( &mut self, step: &Step, results: &[bool], - builder: &mut dyn ResultBuilder, + builder: &mut B, ) -> Result { builder.begin_step()?; let mut enabled = match step.cond.as_ref() { @@ -160,10 +163,10 @@ impl LibsqlConnection { Ok(enabled) } - fn execute_query( + fn execute_query( &self, query: &Query, - builder: &mut dyn ResultBuilder, + builder: &mut B, ) -> Result<(u64, Option)> { tracing::trace!("executing query: {}", query.stmt.stmt); @@ -236,11 +239,11 @@ fn eval_cond(cond: &Cond, results: &[bool]) -> Result { }) } -impl Connection for LibsqlConnection { - fn execute_program( +impl Connection for LibsqlConnection { + fn execute_program( &mut self, - pgm: Program, - builder: &mut dyn ResultBuilder, + pgm: &Program, + builder: B, ) -> crate::Result<()> { self.run(pgm, builder) } diff --git a/libsqlx/src/database/libsql/mod.rs b/libsqlx/src/database/libsql/mod.rs index 44952df6..c0aaed79 100644 --- a/libsqlx/src/database/libsql/mod.rs +++ b/libsqlx/src/database/libsql/mod.rs @@ -163,21 +163,19 @@ impl LibsqlDatabase { } impl Database for LibsqlDatabase { - type Connection = LibsqlConnection<::Context>; + type Connection = LibsqlConnection; fn connect(&self) -> Result { - Ok( - LibsqlConnection::<::Context>::new( - &self.db_path, - self.extensions.clone(), - T::hook(), - self.ty.hook_context(), - self.row_stats_callback.clone(), - QueryBuilderConfig { - max_size: Some(self.response_size_limit), - }, - )?, - ) + Ok(LibsqlConnection::::new( + &self.db_path, + self.extensions.clone(), + T::hook(), + self.ty.hook_context(), + self.row_stats_callback.clone(), + QueryBuilderConfig { + max_size: Some(self.response_size_limit), + }, + )?) } } diff --git a/libsqlx/src/database/libsql/replication_log/logger.rs b/libsqlx/src/database/libsql/replication_log/logger.rs index e17c286c..187f3b25 100644 --- a/libsqlx/src/database/libsql/replication_log/logger.rs +++ b/libsqlx/src/database/libsql/replication_log/logger.rs @@ -441,7 +441,6 @@ impl LogFile { } pub fn commit(&mut self) -> crate::Result<()> { - dbg!(&self); self.header.frame_count += self.uncommitted_frame_count; self.uncommitted_frame_count = 0; self.commited_checksum = self.uncommitted_checksum; @@ -551,7 +550,6 @@ impl LogFile { /// If the requested frame is before the first frame in the log, or after the last frame, /// Ok(None) is returned. pub fn frame(&self, frame_no: FrameNo) -> std::result::Result { - dbg!(frame_no); if frame_no < self.header.start_frame_no { return Err(LogReadError::SnapshotRequired); } @@ -877,7 +875,6 @@ impl ReplicationLogger { /// Returns the new frame count and checksum to commit fn write_pages(&self, pages: &[WalPage]) -> anyhow::Result<()> { let mut log_file = self.log_file.write(); - dbg!(); for page in pages.iter() { log_file.push_page(page)?; } @@ -906,7 +903,10 @@ impl ReplicationLogger { fn commit(&self) -> anyhow::Result { let mut log_file = self.log_file.write(); log_file.commit()?; - Ok(log_file.header().last_frame_no().expect("there should be at least one frame after commit")) + Ok(log_file + .header() + .last_frame_no() + .expect("there should be at least one frame after commit")) } pub fn get_snapshot_file(&self, from: FrameNo) -> anyhow::Result> { diff --git a/libsqlx/src/database/proxy/connection.rs b/libsqlx/src/database/proxy/connection.rs index 24c10a47..68d10c00 100644 --- a/libsqlx/src/database/proxy/connection.rs +++ b/libsqlx/src/database/proxy/connection.rs @@ -1,7 +1,7 @@ use crate::connection::{Connection, DescribeResponse}; use crate::database::FrameNo; use crate::program::Program; -use crate::result_builder::{QueryBuilderConfig, ResultBuilder}; +use crate::result_builder::{Column, QueryBuilderConfig, QueryResultBuilderError, ResultBuilder}; use crate::Result; use super::WaitFrameNoCb; @@ -18,7 +18,92 @@ pub struct WriteProxyConnection { pub(crate) read_db: ReadDb, pub(crate) write_db: WriteDb, pub(crate) wait_frame_no_cb: WaitFrameNoCb, - pub(crate) state: parking_lot::Mutex, + pub(crate) state: ConnState, +} + +struct MaybeRemoteExecBuilder<'a, 'b, B, W> { + builder: B, + conn: &'a mut W, + pgm: &'b Program, + state: &'a mut ConnState, +} + +impl<'a, 'b, B, W> ResultBuilder for MaybeRemoteExecBuilder<'a, 'b, B, W> +where + W: Connection, + B: ResultBuilder, +{ + fn init(&mut self, config: &QueryBuilderConfig) -> Result<(), QueryResultBuilderError> { + self.builder.init(config) + } + + fn begin_step(&mut self) -> Result<(), QueryResultBuilderError> { + self.builder.begin_step() + } + + fn finish_step( + &mut self, + affected_row_count: u64, + last_insert_rowid: Option, + ) -> Result<(), QueryResultBuilderError> { + self.builder + .finish_step(affected_row_count, last_insert_rowid) + } + + fn step_error(&mut self, error: crate::error::Error) -> Result<(), QueryResultBuilderError> { + self.builder.step_error(error) + } + + fn cols_description( + &mut self, + cols: &mut dyn Iterator, + ) -> Result<(), QueryResultBuilderError> { + self.builder.cols_description(cols) + } + + fn begin_rows(&mut self) -> Result<(), QueryResultBuilderError> { + self.builder.begin_rows() + } + + fn begin_row(&mut self) -> Result<(), QueryResultBuilderError> { + self.builder.begin_row() + } + + fn add_row_value( + &mut self, + v: rusqlite::types::ValueRef, + ) -> Result<(), QueryResultBuilderError> { + self.builder.add_row_value(v) + } + + fn finish_row(&mut self) -> Result<(), QueryResultBuilderError> { + self.builder.finish_row() + } + + fn finish_rows(&mut self) -> Result<(), QueryResultBuilderError> { + self.builder.finish_rows() + } + + fn finnalize( + self, + is_txn: bool, + frame_no: Option, + ) -> Result { + if is_txn { + // a read only connection is not allowed to leave an open transaction. We mispredicted the + // final state of the connection, so we rollback, and execute again on the write proxy. + let builder = ExtractFrameNoBuilder { + builder: self.builder, + state: self.state, + }; + + self.conn.execute_program(self.pgm, builder).unwrap(); + + Ok(false) + } else { + self.builder.finnalize(is_txn, frame_no) + } + } } impl Connection for WriteProxyConnection @@ -26,137 +111,111 @@ where ReadDb: Connection, WriteDb: Connection, { - fn execute_program(&mut self, pgm: Program, builder: &mut dyn ResultBuilder) -> Result<()> { - let mut state = self.state.lock(); - let mut builder = ExtractFrameNoBuilder::new(builder); - if !state.is_txn && pgm.is_read_only() { - if let Some(frame_no) = state.last_frame_no { + fn execute_program( + &mut self, + pgm: &Program, + builder: B, + ) -> crate::Result<()> { + if !self.state.is_txn && pgm.is_read_only() { + if let Some(frame_no) = self.state.last_frame_no { (self.wait_frame_no_cb)(frame_no); } + + let builder = MaybeRemoteExecBuilder { + builder, + conn: &mut self.write_db, + state: &mut self.state, + pgm, + }; // We know that this program won't perform any writes. We attempt to run it on the // replica. If it leaves an open transaction, then this program is an interactive // transaction, so we rollback the replica, and execute again on the primary. - self.read_db.execute_program(pgm.clone(), &mut builder)?; - - // still in transaction state after running a read-only txn - if builder.is_txn { - // TODO: rollback - // self.read_db.rollback().await?; - self.write_db.execute_program(pgm, &mut builder)?; - state.is_txn = builder.is_txn; - state.last_frame_no = builder.frame_no; - Ok(()) - } else { - Ok(()) - } + self.read_db.execute_program(pgm, builder)?; + // rollback(&mut self.conn.read_db); + Ok(()) } else { - self.write_db.execute_program(pgm, &mut builder)?; - state.is_txn = builder.is_txn; - state.last_frame_no = builder.frame_no; + let builder = ExtractFrameNoBuilder { + builder, + state: &mut self.state, + }; + self.write_db.execute_program(pgm, builder)?; Ok(()) } } - fn describe(&self, sql: String) -> Result { - if let Some(frame_no) = self.state.lock().last_frame_no { + fn describe(&self, sql: String) -> crate::Result { + if let Some(frame_no) = self.state.last_frame_no { (self.wait_frame_no_cb)(frame_no); } self.read_db.describe(sql) } } -struct ExtractFrameNoBuilder<'a> { - inner: &'a mut dyn ResultBuilder, - frame_no: Option, - is_txn: bool, +struct ExtractFrameNoBuilder<'a, B> { + builder: B, + state: &'a mut ConnState, } -impl<'a> ExtractFrameNoBuilder<'a> { - fn new(inner: &'a mut dyn ResultBuilder) -> Self { - Self { - inner, - frame_no: None, - is_txn: false, - } +impl ResultBuilder for ExtractFrameNoBuilder<'_, B> { + fn init(&mut self, config: &QueryBuilderConfig) -> Result<(), QueryResultBuilderError> { + self.builder.init(config) } -} -impl<'a> ResultBuilder for ExtractFrameNoBuilder<'a> { - fn init( - &mut self, - config: &QueryBuilderConfig, - ) -> std::result::Result<(), crate::result_builder::QueryResultBuilderError> { - self.inner.init(config) - } - - fn begin_step( - &mut self, - ) -> std::result::Result<(), crate::result_builder::QueryResultBuilderError> { - self.inner.begin_step() + fn begin_step(&mut self) -> Result<(), QueryResultBuilderError> { + self.builder.begin_step() } fn finish_step( &mut self, affected_row_count: u64, last_insert_rowid: Option, - ) -> std::result::Result<(), crate::result_builder::QueryResultBuilderError> { - self.inner + ) -> Result<(), QueryResultBuilderError> { + self.builder .finish_step(affected_row_count, last_insert_rowid) } - fn step_error( - &mut self, - error: crate::error::Error, - ) -> std::result::Result<(), crate::result_builder::QueryResultBuilderError> { - self.inner.step_error(error) + fn step_error(&mut self, error: crate::error::Error) -> Result<(), QueryResultBuilderError> { + self.builder.step_error(error) } fn cols_description( &mut self, - cols: &mut dyn Iterator, - ) -> std::result::Result<(), crate::result_builder::QueryResultBuilderError> { - self.inner.cols_description(cols) + cols: &mut dyn Iterator, + ) -> Result<(), QueryResultBuilderError> { + self.builder.cols_description(cols) } - fn begin_rows( - &mut self, - ) -> std::result::Result<(), crate::result_builder::QueryResultBuilderError> { - self.inner.begin_rows() + fn begin_rows(&mut self) -> Result<(), QueryResultBuilderError> { + self.builder.begin_rows() } - fn begin_row( - &mut self, - ) -> std::result::Result<(), crate::result_builder::QueryResultBuilderError> { - self.inner.begin_row() + fn begin_row(&mut self) -> Result<(), QueryResultBuilderError> { + self.builder.begin_row() } fn add_row_value( &mut self, v: rusqlite::types::ValueRef, - ) -> std::result::Result<(), crate::result_builder::QueryResultBuilderError> { - self.inner.add_row_value(v) + ) -> Result<(), QueryResultBuilderError> { + self.builder.add_row_value(v) } - fn finish_row( - &mut self, - ) -> std::result::Result<(), crate::result_builder::QueryResultBuilderError> { - self.inner.finish_row() + fn finish_row(&mut self) -> Result<(), QueryResultBuilderError> { + self.builder.finish_row() } - fn finish_rows( - &mut self, - ) -> std::result::Result<(), crate::result_builder::QueryResultBuilderError> { - self.inner.finish_rows() + fn finish_rows(&mut self) -> Result<(), QueryResultBuilderError> { + self.builder.finish_rows() } - fn finish( - &mut self, + fn finnalize( + self, is_txn: bool, frame_no: Option, - ) -> std::result::Result<(), crate::result_builder::QueryResultBuilderError> { - self.frame_no = frame_no; - self.is_txn = is_txn; - self.inner.finish(is_txn, frame_no) + ) -> Result { + self.state.last_frame_no = frame_no; + self.state.is_txn = is_txn; + self.builder.finnalize(is_txn, frame_no) } } @@ -177,7 +236,7 @@ mod test { let write_db = MockDatabase::new().with_execute({ let write_called = write_called.clone(); move |_, b| { - b.finish(false, Some(42)).unwrap(); + b.finnalize(false, Some(42)).unwrap(); write_called.set(true); Ok(()) } diff --git a/libsqlx/src/database/proxy/mod.rs b/libsqlx/src/database/proxy/mod.rs index 62c6925d..0fdf7ceb 100644 --- a/libsqlx/src/database/proxy/mod.rs +++ b/libsqlx/src/database/proxy/mod.rs @@ -5,6 +5,7 @@ use super::FrameNo; mod connection; mod database; +pub use connection::WriteProxyConnection; pub use database::WriteProxyDatabase; // Waits until passed frameno has been replicated back to the database diff --git a/libsqlx/src/lib.rs b/libsqlx/src/lib.rs index 13223d22..899a7912 100644 --- a/libsqlx/src/lib.rs +++ b/libsqlx/src/lib.rs @@ -18,4 +18,6 @@ pub use database::proxy; pub use database::Frame; pub use database::{Database, InjectableDatabase, Injector}; +pub use sqld_libsql_bindings::wal_hook::WalHook; + pub use rusqlite; diff --git a/libsqlx/src/result_builder.rs b/libsqlx/src/result_builder.rs index be5e27a7..98f598c1 100644 --- a/libsqlx/src/result_builder.rs +++ b/libsqlx/src/result_builder.rs @@ -130,12 +130,16 @@ pub trait ResultBuilder { Ok(()) } /// finish the builder, and pass the transaction state. - fn finish( - &mut self, + /// If false is returned, and is_txn is true, then the transaction is rolledback. + fn finnalize( + self, _is_txn: bool, _frame_no: Option, - ) -> Result<(), QueryResultBuilderError> { - Ok(()) + ) -> Result + where + Self: Sized, + { + Ok(true) } } @@ -163,22 +167,40 @@ pub enum StepResult { } /// A `QueryResultBuilder` that ignores rows, but records the outcome of each step in a `StepResult` -#[derive(Debug, Default)] -pub struct StepResultsBuilder { +pub struct StepResultsBuilder { current: Option, step_results: Vec, is_skipped: bool, + ret: R +} + +pub trait RetChannel { + fn send(self, t: T); +} + +#[cfg(feature = "tokio")] +impl RetChannel for tokio::sync::oneshot::Sender { + fn send(self, t: T) { + let _ = self.send(t); + } } -impl StepResultsBuilder { - pub fn into_ret(self) -> Vec { - self.step_results +impl StepResultsBuilder { + pub fn new(ret: R) -> Self { + Self { + current: None, + step_results: Vec::new(), + is_skipped: false, + ret, + } } } -impl ResultBuilder for StepResultsBuilder { +impl>> ResultBuilder for StepResultsBuilder { fn init(&mut self, _config: &QueryBuilderConfig) -> Result<(), QueryResultBuilderError> { - *self = Default::default(); + self.current = None; + self.step_results.clear(); + self.is_skipped = false; Ok(()) } @@ -218,32 +240,13 @@ impl ResultBuilder for StepResultsBuilder { Ok(()) } - fn begin_rows(&mut self) -> Result<(), QueryResultBuilderError> { - Ok(()) - } - - fn begin_row(&mut self) -> Result<(), QueryResultBuilderError> { - Ok(()) - } - - fn add_row_value(&mut self, _v: ValueRef) -> Result<(), QueryResultBuilderError> { - Ok(()) - } - - fn finish_row(&mut self) -> Result<(), QueryResultBuilderError> { - Ok(()) - } - - fn finish_rows(&mut self) -> Result<(), QueryResultBuilderError> { - Ok(()) - } - - fn finish( - &mut self, + fn finnalize( + self, _is_txn: bool, _frame_no: Option, - ) -> Result<(), QueryResultBuilderError> { - Ok(()) + ) -> Result { + self.ret.send(self.step_results); + Ok(true) } } @@ -349,12 +352,12 @@ impl ResultBuilder for Take { } } - fn finish( - &mut self, + fn finnalize( + self, is_txn: bool, frame_no: Option, - ) -> Result<(), QueryResultBuilderError> { - self.inner.finish(is_txn, frame_no) + ) -> Result { + self.inner.finnalize(is_txn, frame_no) } } @@ -500,7 +503,7 @@ pub mod test { FinishRow => b.finish_row().unwrap(), FinishRows => b.finish_rows().unwrap(), Finish => { - b.finish(false, None).unwrap(); + b.finnalize(false, None).unwrap(); break; } BuilderError => return b, @@ -643,7 +646,7 @@ pub mod test { self.transition(FinishRows) } - fn finish( + fn finnalize( &mut self, _is_txn: bool, _frame_no: Option, @@ -700,7 +703,7 @@ pub mod test { builder.finish_rows().unwrap(); builder.finish_step(0, None).unwrap(); - builder.finish(false, None).unwrap(); + builder.finnalize(false, None).unwrap(); } #[test]