diff --git a/tests/sqllogictests/src/client/mod.rs b/tests/sqllogictests/src/client/mod.rs index dde2798a6ede..827133b91fb3 100644 --- a/tests/sqllogictests/src/client/mod.rs +++ b/tests/sqllogictests/src/client/mod.rs @@ -19,3 +19,45 @@ mod mysql_client; pub use clickhouse_client::ClickhouseHttpClient; pub use http_client::HttpClient; pub use mysql_client::MysqlClient; +use sqllogictest::DBOutput; + +use crate::error::Result; + +pub enum ClientType { + Mysql, + Http, + Clickhouse, +} + +pub enum Client { + Mysql(MysqlClient), + Http(HttpClient), + Clickhouse(ClickhouseHttpClient), +} + +impl Client { + pub async fn query(&mut self, sql: &str) -> Result { + match self { + Client::Mysql(client) => { + println!("Running sql with mysql client: [{}]", sql); + client.query(sql).await + } + Client::Http(client) => { + println!("Running sql with http client: [{}]", sql); + client.query(sql).await + } + Client::Clickhouse(client) => { + println!("Running sql with clickhouse client: [{}]", sql); + client.query(sql).await + } + } + } + + pub fn engine_name(&self) -> &str { + match self { + Client::Mysql(_) => "mysql", + Client::Http(_) => "http", + Client::Clickhouse(_) => "clickhouse", + } + } +} diff --git a/tests/sqllogictests/src/main.rs b/tests/sqllogictests/src/main.rs index 9328c7ca8496..7769ff2702e3 100644 --- a/tests/sqllogictests/src/main.rs +++ b/tests/sqllogictests/src/main.rs @@ -24,6 +24,8 @@ use walkdir::DirEntry; use walkdir::WalkDir; use crate::arg::SqlLogicTestArgs; +use crate::client::Client; +use crate::client::ClientType; use crate::client::HttpClient; use crate::client::MysqlClient; use crate::error::DSqlLogicTestError; @@ -36,22 +38,12 @@ mod error; mod util; pub struct Databend { - mysql_client: Option, - http_client: Option, - ck_client: Option, + client: Client, } impl Databend { - pub fn create( - mysql_client: Option, - http_client: Option, - ck_client: Option, - ) -> Self { - Databend { - mysql_client, - http_client, - ck_client, - } + pub fn create(client: Client) -> Self { + Databend { client } } } @@ -60,26 +52,11 @@ impl sqllogictest::AsyncDB for Databend { type Error = DSqlLogicTestError; async fn run(&mut self, sql: &str) -> Result { - if let Some(mysql_client) = &mut self.mysql_client { - println!("Running sql with mysql client: [{}]", sql); - return mysql_client.query(sql).await; - } - if let Some(http_client) = &mut self.http_client { - println!("Running sql with http client: [{}]", sql); - return http_client.query(sql).await; - } - println!("Running sql with clickhouse client: [{}]", sql); - self.ck_client.as_mut().unwrap().query(sql).await + self.client.query(sql).await } fn engine_name(&self) -> &str { - if self.mysql_client.is_some() { - return "mysql"; - } - if self.ck_client.is_some() { - return "clickhouse"; - } - "http" + self.client.engine_name() } } @@ -127,32 +104,43 @@ pub async fn main() -> Result<()> { async fn run_mysql_client() -> Result<()> { let suits = SqlLogicTestArgs::parse().suites; let suits = std::fs::read_dir(suits).unwrap(); - let mysql_client = MysqlClient::create().await?; - let databend = Databend::create(Some(mysql_client), None, None); - run_suits(suits, databend).await?; + run_suits(suits, ClientType::Mysql).await?; Ok(()) } async fn run_http_client() -> Result<()> { let suits = SqlLogicTestArgs::parse().suites; let suits = std::fs::read_dir(suits).unwrap(); - let http_client = HttpClient::create()?; - let databend = Databend::create(None, Some(http_client), None); - run_suits(suits, databend).await?; + run_suits(suits, ClientType::Http).await?; Ok(()) } async fn run_ck_http_client() -> Result<()> { let suits = SqlLogicTestArgs::parse().suites; let suits = std::fs::read_dir(suits).unwrap(); - let ck_client = ClickhouseHttpClient::create()?; - let databend = Databend::create(None, None, Some(ck_client)); - run_suits(suits, databend).await?; + run_suits(suits, ClientType::Clickhouse).await?; Ok(()) } -async fn run_suits(suits: ReadDir, databend: Databend) -> Result<()> { - let mut runner = sqllogictest::Runner::new(databend); +// Create new databend with client type +async fn create_databend(client_type: &ClientType) -> Result { + match client_type { + ClientType::Mysql => { + let mysql_client = MysqlClient::create().await?; + Ok(Databend::create(Client::Mysql(mysql_client))) + } + ClientType::Http => { + let http_client = HttpClient::create()?; + Ok(Databend::create(Client::Http(http_client))) + } + ClientType::Clickhouse => { + let ck_client = ClickhouseHttpClient::create()?; + Ok(Databend::create(Client::Clickhouse(ck_client))) + } + } +} + +async fn run_suits(suits: ReadDir, client_type: ClientType) -> Result<()> { // Todo: set validator to process regex let args = SqlLogicTestArgs::parse(); // Walk each suit dir and read all files in it @@ -163,6 +151,8 @@ async fn run_suits(suits: ReadDir, databend: Databend) -> Result<()> { // Parse the suit and find all slt files let files = get_files(suit)?; for file in files.into_iter() { + // For each file, create new client to run. + let mut runner = sqllogictest::Runner::new(create_databend(&client_type).await?); let file_name = file .as_ref() .unwrap()