Skip to content

Commit

Permalink
Support multiple statements simultaniously and running concurrently
Browse files Browse the repository at this point in the history
read write lock will be added later to prevent btree from being modified
while querying.
  • Loading branch information
kawasin73 committed Sep 29, 2023
1 parent 7f147b5 commit ee35cbb
Show file tree
Hide file tree
Showing 9 changed files with 135 additions and 66 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ use prsqlite::Connection;
use prsqlite::NextRow;
use prsqlite::Value;

let mut conn = Connection::open(Path::new("path/to/sqlite.db")).unwrap();
let conn = Connection::open(Path::new("path/to/sqlite.db")).unwrap();

let mut stmt = conn.prepare("INSERT INTO example (col) VALUES (1), (2);").unwrap();
let stmt = conn.prepare("INSERT INTO example (col) VALUES (1), (2);").unwrap();
assert_eq!(stmt.execute().unwrap(), 2);

let stmt = conn.prepare("SELECT * FROM example WHERE col = 1;").unwrap();
Expand Down
9 changes: 6 additions & 3 deletions src/cursor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,8 @@ impl CursorPage {

/// The cursor of btree.
///
/// This does not support creating multiple cursors for the same btree.
/// Otherwise, [BtreeCursor::insert()] fails to get a writable buffer from the
/// pager.
/// [BtreeCursor::insert()] may fail to get a writable buffer from the pager if
/// there are another [BtreeCursor] pointing the same btree simultaniously.
pub struct BtreeCursor<'a> {
pager: &'a Pager,
btree_ctx: &'a BtreeContext,
Expand Down Expand Up @@ -387,6 +386,10 @@ impl<'a> BtreeCursor<'a> {
Ok(())
}

/// Insert or update a new item to the table.
///
/// There should not be other [BtreeCursor]s pointing the same btree.
/// Otherwise, this fails.
pub fn insert(&mut self, key: i64, payload: &[u8]) -> anyhow::Result<()> {
let current_cell_key = self.table_move_to(key)?;

Expand Down
34 changes: 18 additions & 16 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ mod token;
mod utils;
mod value;

use std::cell::RefCell;
use std::cmp::Ordering;
use std::fmt::Display;
use std::fs::OpenOptions;
Expand Down Expand Up @@ -143,7 +144,7 @@ impl<'a> DatabaseHeader<'a> {
pub struct Connection {
pager: Pager,
btree_ctx: BtreeContext,
schema: Option<Schema>,
schema: RefCell<Option<Schema>>,
}

impl Connection {
Expand All @@ -168,11 +169,11 @@ impl Connection {
Ok(Self {
pager,
btree_ctx: BtreeContext::new(header.usable_size()),
schema: None,
schema: RefCell::new(None),
})
}

pub fn prepare<'a>(&mut self, sql: &'a str) -> Result<'a, Statement> {
pub fn prepare<'a>(&self, sql: &'a str) -> Result<'a, Statement> {
let input = sql.as_bytes();
let mut parser = Parser::new(input);
let statement = parse_sql(&mut parser)?;
Expand All @@ -185,24 +186,25 @@ impl Connection {
}
}

fn load_schema(&mut self) -> anyhow::Result<()> {
fn load_schema(&self) -> anyhow::Result<()> {
let schema_table = Schema::schema_table();
let columns = schema_table
.get_all_columns()
.map(Expression::Column)
.collect::<Vec<_>>();
self.schema = Some(Schema::generate(
*self.schema.borrow_mut() = Some(Schema::generate(
SelectStatement::new(self, schema_table.root_page_id, columns, None),
schema_table,
)?);
Ok(())
}

fn prepare_select<'a>(&mut self, select: Select<'a>) -> Result<'a, SelectStatement> {
if self.schema.is_none() {
fn prepare_select<'a>(&self, select: Select<'a>) -> Result<'a, SelectStatement> {
if self.schema.borrow().is_none() {
self.load_schema()?;
}
let schema = self.schema.as_ref().unwrap();
let schema_cell = self.schema.borrow();
let schema = schema_cell.as_ref().unwrap();
let table_name = select.table_name.dequote();
let table = schema.get_table(&table_name).ok_or(anyhow::anyhow!(
"table not found: {:?}",
Expand Down Expand Up @@ -290,12 +292,12 @@ impl Connection {
}
}

fn prepare_insert<'a>(&mut self, insert: Insert<'a>) -> Result<'a, InsertStatement> {
if self.schema.is_none() {
fn prepare_insert<'a>(&self, insert: Insert<'a>) -> Result<'a, InsertStatement> {
if self.schema.borrow().is_none() {
self.load_schema()?;
}

let schema = self.schema.as_ref().unwrap();
let schema_cell = self.schema.borrow();
let schema = schema_cell.as_ref().unwrap();
let table_name = insert.table_name.dequote();
let table = schema.get_table(&table_name).ok_or(anyhow::anyhow!(
"table not found: {:?}",
Expand Down Expand Up @@ -657,7 +659,7 @@ impl<'conn> Statement<'conn> {

// TODO: make Connection non mut and support multiple statements.
pub struct SelectStatement<'conn> {
conn: &'conn mut Connection,
conn: &'conn Connection,
table_page_id: PageId,
columns: Vec<Expression>,
filter: Option<Expression>,
Expand All @@ -667,7 +669,7 @@ pub struct SelectStatement<'conn> {

impl<'conn> SelectStatement<'conn> {
pub(crate) fn new(
conn: &'conn mut Connection,
conn: &'conn Connection,
table_page_id: PageId,
columns: Vec<Expression>,
filter: Option<Expression>,
Expand Down Expand Up @@ -701,7 +703,7 @@ impl<'conn> SelectStatement<'conn> {
}

fn with_index(
conn: &'conn mut Connection,
conn: &'conn Connection,
table_page_id: PageId,
columns: Vec<Expression>,
filter: Option<Expression>,
Expand Down Expand Up @@ -970,7 +972,7 @@ struct InsertRecord {
}

pub struct InsertStatement<'conn> {
conn: &'conn mut Connection,
conn: &'conn Connection,
table_page_id: PageId,
records: Vec<InsertRecord>,
}
Expand Down
2 changes: 1 addition & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ fn main() {
std::process::exit(1);
}
let file_path = args.nth(1).unwrap();
let mut conn = Connection::open(Path::new(&file_path)).expect("open database");
let conn = Connection::open(Path::new(&file_path)).expect("open database");

let mut stdout = io::stdout();
let stdin = io::stdin();
Expand Down
4 changes: 2 additions & 2 deletions src/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -474,14 +474,14 @@ mod tests {
use crate::Expression;

fn generate_schema(filepath: &Path) -> Schema {
let mut conn = Connection::open(filepath).unwrap();
let conn = Connection::open(filepath).unwrap();
let schema_table = Schema::schema_table();
let columns = schema_table
.get_all_columns()
.map(Expression::Column)
.collect::<Vec<_>>();
Schema::generate(
SelectStatement::new(&mut conn, schema_table.root_page_id, columns, None),
SelectStatement::new(&conn, schema_table.root_page_id, columns, None),
schema_table,
)
.unwrap()
Expand Down
8 changes: 4 additions & 4 deletions src/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,29 +67,29 @@ pub fn buffer_to_hex(buf: &[u8]) -> String {
}

pub fn find_table_page_id(table: &str, filepath: &Path) -> PageId {
let mut conn = Connection::open(filepath).unwrap();
let conn = Connection::open(filepath).unwrap();
let schema_table = Schema::schema_table();
let columns = schema_table
.get_all_columns()
.map(Expression::Column)
.collect::<Vec<_>>();
let schema = Schema::generate(
SelectStatement::new(&mut conn, schema_table.root_page_id, columns, None),
SelectStatement::new(&conn, schema_table.root_page_id, columns, None),
schema_table,
)
.unwrap();
schema.get_table(table.as_bytes()).unwrap().root_page_id
}

pub fn find_index_page_id(index: &str, filepath: &Path) -> PageId {
let mut conn = Connection::open(filepath).unwrap();
let conn = Connection::open(filepath).unwrap();
let schema_table = Schema::schema_table();
let columns = schema_table
.get_all_columns()
.map(Expression::Column)
.collect::<Vec<_>>();
let schema = Schema::generate(
SelectStatement::new(&mut conn, schema_table.root_page_id, columns, None),
SelectStatement::new(&conn, schema_table.root_page_id, columns, None),
schema_table,
)
.unwrap();
Expand Down
4 changes: 2 additions & 2 deletions tests/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ pub fn create_sqlite_database(queries: &[&str]) -> NamedTempFile {
}

#[allow(dead_code)]
pub fn load_rowids(conn: &mut Connection, query: &str) -> Vec<i64> {
pub fn load_rowids(conn: &Connection, query: &str) -> Vec<i64> {
let stmt = conn.prepare(query).unwrap();
let mut rows = stmt.query().unwrap();
let mut results = Vec::new();
Expand Down Expand Up @@ -72,7 +72,7 @@ pub fn assert_same_results(
expected: &[&[Value]],
query: &str,
test_conn: &rusqlite::Connection,
conn: &mut Connection,
conn: &Connection,
) {
let mut stmt = test_conn.prepare(query).unwrap();
let mut rows = stmt.query([]).unwrap();
Expand Down
51 changes: 43 additions & 8 deletions tests/insert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use prsqlite::Value;
#[test]
fn test_insert() {
let file = create_sqlite_database(&["CREATE TABLE example(col1, col2, col3);"]);
let mut conn = Connection::open(file.path()).unwrap();
let conn = Connection::open(file.path()).unwrap();

let mut stmt = conn
.prepare("INSERT INTO example (col1, col2, col3) VALUES (0, 1, 2);")
Expand Down Expand Up @@ -77,14 +77,14 @@ fn test_insert() {
],
"SELECT rowid, * FROM example;",
&test_conn,
&mut conn,
&conn,
)
}

#[test]
fn test_insert_with_rowid() {
let file = create_sqlite_database(&["CREATE TABLE example(col);"]);
let mut conn = Connection::open(file.path()).unwrap();
let conn = Connection::open(file.path()).unwrap();

let mut stmt = conn
.prepare("INSERT INTO example (rowid, col) VALUES (-10, 2), (10, 5);")
Expand Down Expand Up @@ -115,7 +115,7 @@ fn test_insert_with_rowid() {
],
"SELECT rowid, * FROM example;",
&test_conn,
&mut conn,
&conn,
)
}

Expand All @@ -126,7 +126,7 @@ fn test_insert_into_existing_table() {
"INSERT INTO example(rowid, col) VALUES (1, 1);",
"INSERT INTO example(rowid, col) VALUES (10, 2);",
]);
let mut conn = Connection::open(file.path()).unwrap();
let conn = Connection::open(file.path()).unwrap();

let mut stmt = conn
.prepare("INSERT INTO example (rowid, col) VALUES (2, 3), (8, 4);")
Expand All @@ -149,14 +149,14 @@ fn test_insert_into_existing_table() {
],
"SELECT rowid, * FROM example;",
&test_conn,
&mut conn,
&conn,
)
}

#[test]
fn test_insert_rowid_conflict() {
let file = create_sqlite_database(&["CREATE TABLE example(col);"]);
let mut conn = Connection::open(file.path()).unwrap();
let conn = Connection::open(file.path()).unwrap();

let mut stmt = conn
.prepare("INSERT INTO example (col) VALUES (123);")
Expand All @@ -173,6 +173,41 @@ fn test_insert_rowid_conflict() {
&[&[Value::Integer(1), Value::Integer(123)]],
"SELECT rowid, * FROM example;",
&test_conn,
&mut conn,
&conn,
)
}

#[test]
fn test_insert_multiple_statements() {
let file = create_sqlite_database(&["CREATE TABLE example(col);"]);
let conn = Connection::open(file.path()).unwrap();

let mut stmt_i1 = conn
.prepare("INSERT INTO example (col) VALUES (123);")
.unwrap();
let stmt_s1 = conn.prepare("SELECT * FROM example;").unwrap();
let mut stmt_i2 = conn
.prepare("INSERT INTO example (col) VALUES (456);")
.unwrap();
let stmt_s2 = conn.prepare("SELECT * FROM example;").unwrap();

assert_eq!(stmt_i1.execute().unwrap(), 1);

let mut rows = stmt_s1.query().unwrap();
assert_same_result_prsqlite!(rows, [Value::Integer(123)], "");
assert!(rows.next_row().unwrap().is_none());
let mut rows = stmt_s2.query().unwrap();
assert_same_result_prsqlite!(rows, [Value::Integer(123)], "");
assert!(rows.next_row().unwrap().is_none());

assert_eq!(stmt_i2.execute().unwrap(), 1);

let mut rows = stmt_s1.query().unwrap();
assert_same_result_prsqlite!(rows, [Value::Integer(123)], "");
assert_same_result_prsqlite!(rows, [Value::Integer(456)], "");
assert!(rows.next_row().unwrap().is_none());
let mut rows = stmt_s2.query().unwrap();
assert_same_result_prsqlite!(rows, [Value::Integer(123)], "");
assert_same_result_prsqlite!(rows, [Value::Integer(456)], "");
assert!(rows.next_row().unwrap().is_none());
}
Loading

0 comments on commit ee35cbb

Please sign in to comment.