Skip to content

Commit

Permalink
feat(cubesql): Support prepared statements in MySQL protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
ovr committed Jan 31, 2022
1 parent df33d89 commit ed2bb93
Showing 1 changed file with 95 additions and 9 deletions.
104 changes: 95 additions & 9 deletions rust/cubesql/cubesql/src/mysql/service.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::collections::HashMap;
use std::io;

use std::sync::Arc;
Expand Down Expand Up @@ -34,25 +35,40 @@ use super::server_manager::ServerManager;
use super::AuthContext;
use super::SqlAuthService;

struct Backend {
struct PreparedStatements {
id: u32,
statements: HashMap<u32, String>,
}

impl PreparedStatements {
pub fn new() -> Self {
Self {
id: 1,
statements: HashMap::new(),
}
}
}

struct Connection {
server: Arc<ServerManager>,
// Props for execution queries
props: QueryPlannerExecutionProps,
// Context for Transport
context: Option<AuthContext>,
// Prepared statements
statements: Arc<RwLock<PreparedStatements>>,
}

enum QueryResponse {
Ok(StatusFlags),
ResultSet(StatusFlags, Arc<dataframe::DataFrame>),
}

impl Backend {
impl Connection {
async fn execute_query<'a>(&'a mut self, query: &'a str) -> Result<QueryResponse, CubeError> {
let _start = SystemTime::now();

let query = query.replace("SELECT FROM", "SELECT * FROM");
debug!("QUERY: {}", query);

let query_lower = query.to_lowercase();
let query_lower = query_lower.replace("db.`", "");
Expand Down Expand Up @@ -351,7 +367,7 @@ impl Backend {
}

#[async_trait]
impl<W: io::Write + Send> AsyncMysqlShim<W> for Backend {
impl<W: io::Write + Send> AsyncMysqlShim<W> for Connection {
type Error = io::Error;

fn server_version(&self) -> &str {
Expand All @@ -364,32 +380,101 @@ impl<W: io::Write + Send> AsyncMysqlShim<W> for Backend {

async fn on_prepare<'a>(
&'a mut self,
_query: &'a str,
query: &'a str,
info: StatementMetaWriter<'a, W>,
) -> Result<(), Self::Error> {
info.reply(42, &[], &[])
debug!("on_execute: {}", query);

let mut state = self.statements.write().await;
state.id = state.id + 1;

let next_id = state.id;
state.statements.insert(next_id, query.to_string());

info.reply(state.id, &[], &[])
}

async fn on_execute<'a>(
&'a mut self,
_id: u32,
id: u32,
_params: ParamParser<'a>,
results: QueryResultWriter<'a, W>,
) -> Result<(), Self::Error> {
results.completed(0, 0, StatusFlags::empty())
debug!("on_execute: {}", id);

let mut state = self.statements.write().await;
let possible_statement = state.statements.remove(&id);

std::mem::drop(state);

let statement = if possible_statement.is_none() {
return results.error(ErrorKind::ER_INTERNAL_ERROR, b"Unknown statement");
} else {
possible_statement.unwrap()
};

let query = statement.as_str();
match self.execute_query(query).await {
Err(e) => {
error!("Error during processing {}: {}", query, e.to_string());
results.error(ErrorKind::ER_INTERNAL_ERROR, e.message.as_bytes())?;

Ok(())
}
Ok(QueryResponse::Ok(status)) => {
results.completed(0, 0, status)?;
Ok(())
}
Ok(QueryResponse::ResultSet(_, data_frame)) => {
let columns = data_frame
.get_columns()
.iter()
.map(|c| Column {
table: "result".to_string(), // TODO
column: c.get_name(),
coltype: c.get_type(),
colflags: c.get_flags(),
})
.collect::<Vec<_>>();

let mut rw = results.start(&columns)?;

for row in data_frame.get_rows().iter() {
for (_i, value) in row.values().iter().enumerate() {
match value {
dataframe::TableValue::String(s) => rw.write_col(s)?,
dataframe::TableValue::Timestamp(s) => rw.write_col(s.to_string())?,
dataframe::TableValue::Boolean(s) => rw.write_col(s.to_string())?,
dataframe::TableValue::Float64(s) => rw.write_col(s)?,
dataframe::TableValue::Int64(s) => rw.write_col(s)?,
dataframe::TableValue::Null => rw.write_col(Option::<String>::None)?,
}
}

rw.end_row()?;
}

rw.finish()?;

Ok(())
}
}
}

async fn on_close<'a>(&'a mut self, _stmt: u32)
where
W: 'async_trait,
{
trace!("on_close");
}

async fn on_query<'a>(
&'a mut self,
query: &'a str,
results: QueryResultWriter<'a, W>,
) -> Result<(), Self::Error> {
debug!("on_query: {}", query);

match self.execute_query(query).await {
Err(e) => {
error!("Error during processing {}: {}", query, e.to_string());
Expand Down Expand Up @@ -546,10 +631,11 @@ impl ProcessingLoop for MySqlServer {

tokio::spawn(async move {
if let Err(e) = AsyncMysqlIntermediary::run_on(
Backend {
Connection {
server,
props: QueryPlannerExecutionProps::new(connection_id, None, None),
context: None,
statements: Arc::new(RwLock::new(PreparedStatements::new())),
},
socket,
)
Expand Down

0 comments on commit ed2bb93

Please sign in to comment.