Skip to content

Commit

Permalink
Allow to implement a dynamic select clause feature in
Browse files Browse the repository at this point in the history
diesel_dynamic_query

This commit contains the following changes:
* Make SelectClause traits public to allow implementing a dynamic
select clause variant based on a vector instead of tuples like the
existing ones. This allows to construct dynamically sized select
clauses
* Add two methods to the `Row` traits to which allows to get more
information (current column name and column count) of the result
set. This is required to implement compatible result types for
dynamically sized select clauses
  • Loading branch information
weiznich committed Oct 4, 2019
1 parent bd13f24 commit bc8acce
Show file tree
Hide file tree
Showing 9 changed files with 94 additions and 1 deletion.
4 changes: 4 additions & 0 deletions diesel/src/mysql/connection/bind.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ impl Binds {
pub fn field_data(&self, idx: usize) -> Option<MysqlValue<'_>> {
self.data[idx].bytes().map(MysqlValue::new)
}

pub fn len(&self) -> usize {
self.data.len()
}
}

struct BindData {
Expand Down
17 changes: 17 additions & 0 deletions diesel/src/mysql/connection/stmt/iterator.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::collections::HashMap;
use std::ffi::CStr;

use super::{ffi, libc, Binds, Statement, StatementMetadata};
use mysql::{Mysql, MysqlTypeMetadata, MysqlValue};
Expand Down Expand Up @@ -37,6 +38,7 @@ impl<'a> StatementIterator<'a> {
Ok(Some(())) => Some(Ok(MysqlRow {
col_idx: 0,
binds: &mut self.output_binds,
stmt: &self.stmt,
})),
Ok(None) => None,
Err(e) => Some(Err(e)),
Expand All @@ -47,6 +49,7 @@ impl<'a> StatementIterator<'a> {
pub struct MysqlRow<'a> {
col_idx: usize,
binds: &'a Binds,
stmt: &'a Statement,
}

impl<'a> Row<Mysql> for MysqlRow<'a> {
Expand All @@ -59,6 +62,20 @@ impl<'a> Row<Mysql> for MysqlRow<'a> {
fn next_is_null(&self, count: usize) -> bool {
(0..count).all(|i| self.binds.field_data(self.col_idx + i).is_none())
}

fn column_count(&self) -> usize {
self.binds.len()
}

fn column_name(&self) -> &str {
let metadata = self.stmt.metadata().expect("Failed to get metadata");
let field = if self.col_idx == 0 {
metadata.fields()[0]
} else {
metadata.fields()[self.col_idx - 1]
};
unsafe { CStr::from_ptr(field.name).to_str().expect("It's utf8") }
}
}

pub struct NamedStatementIterator<'a> {
Expand Down
15 changes: 15 additions & 0 deletions diesel/src/pg/connection/result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,21 @@ impl<'a> PgResult<'a> {
}
}

pub fn column_name(&self, col_idx: usize) -> &str {
unsafe {
CStr::from_ptr(PQfname(
self.internal_result.as_ptr(),
col_idx as libc::c_int,
))
.to_str()
.expect("Utf8")
}
}

pub fn column_count(&self) -> usize {
unsafe { PQnfields(self.internal_result.as_ptr()) as usize }
}

pub fn field_number(&self, column_name: &str) -> Option<usize> {
let cstr = CString::new(column_name).unwrap_or_default();
let fnum = unsafe { PQfnumber(self.internal_result.as_ptr(), cstr.as_ptr()) };
Expand Down
8 changes: 8 additions & 0 deletions diesel/src/pg/connection/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,14 @@ impl<'a> Row<Pg> for PgRow<'a> {
fn next_is_null(&self, count: usize) -> bool {
(0..count).all(|i| self.db_result.is_null(self.row_idx, self.col_idx + i))
}

fn column_count(&self) -> usize {
self.db_result.column_count()
}

fn column_name(&self) -> &str {
self.db_result.column_name(self.col_idx)
}
}

pub struct PgNamedRow<'a> {
Expand Down
4 changes: 3 additions & 1 deletion diesel/src/query_builder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ pub mod nodes;
mod offset_clause;
mod order_clause;
mod returning_clause;
mod select_clause;
pub(crate) mod select_clause;
mod select_statement;
mod sql_query;
mod update_statement;
Expand All @@ -40,6 +40,8 @@ pub use self::insert_statement::{
IncompleteInsertStatement, InsertStatement, UndecoratedInsertRecord, ValuesClause,
};
pub use self::query_id::QueryId;
#[doc(inline)]
pub use self::select_clause::{SelectClauseExpression, SelectClauseQueryFragment};
#[doc(hidden)]
pub use self::select_statement::{BoxedSelectStatement, SelectStatement};
pub use self::sql_query::{BoxedSqlQuery, SqlQuery};
Expand Down
16 changes: 16 additions & 0 deletions diesel/src/query_builder/select_clause.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,12 @@ pub struct DefaultSelectClause;
#[derive(Debug, Clone, Copy, QueryId)]
pub struct SelectClause<T>(pub T);

/// Specialised variant of `Expression` for select clause types
///
/// The difference to the normal `Expression` trait is the query source (`QS`)
/// generic type parameter. This allows to access the query source in generic code.
pub trait SelectClauseExpression<QS> {
/// SQL type of the select clause
type SelectClauseSqlType;
}

Expand All @@ -26,7 +31,18 @@ where
type SelectClauseSqlType = <QS::DefaultSelection as Expression>::SqlType;
}

/// Specialised variant of `QueryFragment` for select clause types
///
/// The difference to the normal `QueryFragment` trait is the query source (`QS`)
/// generic type parameter.
pub trait SelectClauseQueryFragment<QS, DB: Backend> {
/// Walk over this `SelectClauseQueryFragment` for all passes.
///
/// This method is where the actual behavior of an select clause is implemented.
/// This method will contain the behavior required for all possible AST
/// passes. See [`AstPass`] for more details.
///
/// [`AstPass`]: struct.AstPass.html
fn walk_ast(&self, source: &QS, pass: AstPass<DB>) -> QueryResult<()>;
}

Expand Down
6 changes: 6 additions & 0 deletions diesel/src/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ pub trait Row<DB: Backend> {
self.take();
}
}

/// Number of columns in the current result set
fn column_count(&self) -> usize;

/// Name of the current column
fn column_name(&self) -> &str;
}

/// Represents a row of a SQL query, where the values are accessed by name
Expand Down
8 changes: 8 additions & 0 deletions diesel/src/sqlite/connection/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,12 @@ impl<'a> Row<Sqlite> for FunctionRow<'a> {
.iter()
.all(|&p| unsafe { SqliteValue::new(p) }.is_none())
}

fn column_count(&self) -> usize {
self.args.len()
}

fn column_name(&self) -> &str {
unimplemented!()
}
}
17 changes: 17 additions & 0 deletions diesel/src/sqlite/connection/sqlite_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,23 @@ impl Row<Sqlite> for SqliteRow {
tpe == ffi::SQLITE_NULL
})
}

fn column_name(&self) -> &str {
unsafe {
let ptr = if self.next_col_index == 0 {
ffi::sqlite3_column_name(self.stmt.as_ptr(), 0)
} else {
ffi::sqlite3_column_name(self.stmt.as_ptr(), self.next_col_index - 1)
};
std::ffi::CStr::from_ptr(ptr)
.to_str()
.expect("Sqlite3 doc's say it's UTF8")
}
}

fn column_count(&self) -> usize {
unsafe { ffi::sqlite3_column_count(self.stmt.as_ptr()) as usize }
}
}

pub struct SqliteNamedRow<'a> {
Expand Down

0 comments on commit bc8acce

Please sign in to comment.