Skip to content

Commit

Permalink
Merge 429de85 into 8020b2e
Browse files Browse the repository at this point in the history
  • Loading branch information
eyalleshem committed Jul 28, 2020
2 parents 8020b2e + 429de85 commit b7d58c7
Show file tree
Hide file tree
Showing 6 changed files with 342 additions and 49 deletions.
7 changes: 6 additions & 1 deletion src/dialect/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@ pub mod keywords;
mod mssql;
mod mysql;
mod postgresql;

mod snowflake;
use std::fmt::Debug;

pub use self::ansi::AnsiDialect;
pub use self::generic::GenericDialect;
pub use self::mssql::MsSqlDialect;
pub use self::mysql::MySqlDialect;
pub use self::postgresql::PostgreSqlDialect;
pub use self::snowflake::SnowflakeDialect;

pub trait Dialect: Debug {
/// Determine if a character starts a quoted identifier. The default
Expand All @@ -38,4 +39,8 @@ pub trait Dialect: Debug {
fn is_identifier_start(&self, ch: char) -> bool;
/// Determine if a character is a valid unquoted identifier character
fn is_identifier_part(&self, ch: char) -> bool;

/// Whether the parser should accept a single table wrapped in parentheses,
/// optionally followed by an alias after the closing parenthesis.
///
/// The parser consults this flag right after it has consumed the `)` of a
/// parenthesized table expression: when it returns `true` the parser looks
/// for an alias following the parenthesis, otherwise it validates that the
/// parenthesized content is a real join and reports an error if it is not.
///
/// The default is `false`; dialects that permit the syntax override it.
// NOTE(review): the name is spelled with three `l`s; renaming it would
// require updating every implementor and caller, so it is left as is.
fn alllow_single_table_in_parenthesis(&self) -> bool {
false
}
}
26 changes: 26 additions & 0 deletions src/dialect/snowflake.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
use crate::dialect::Dialect;

/// Dialect marker type for Snowflake SQL.
#[derive(Debug, Default)]
pub struct SnowflakeDialect;

impl Dialect for SnowflakeDialect {
    // Revisit: currently copied from the Generic dialect
    /// Returns `true` for ASCII letters, `_`, `#` and `@`.
    fn is_identifier_start(&self, ch: char) -> bool {
        match ch {
            'a'..='z' | 'A'..='Z' | '_' | '#' | '@' => true,
            _ => false,
        }
    }

    // Revisit: currently copied from the Generic dialect
    /// Returns `true` for ASCII letters, ASCII digits, `@`, `$`, `#` and `_`.
    fn is_identifier_part(&self, ch: char) -> bool {
        match ch {
            'a'..='z' | 'A'..='Z' | '0'..='9' | '@' | '$' | '#' | '_' => true,
            _ => false,
        }
    }

    /// This dialect opts in to a single table wrapped in parentheses.
    fn alllow_single_table_in_parenthesis(&self) -> bool {
        true
    }
}
125 changes: 115 additions & 10 deletions src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,23 +83,28 @@ impl fmt::Display for ParserError {
impl Error for ParserError {}

/// SQL Parser
pub struct Parser {
pub struct Parser<'a> {
tokens: Vec<Token>,
/// The index of the first unprocessed token in `self.tokens`
index: usize,
dialect: &'a dyn Dialect,
}

impl Parser {
impl<'a> Parser<'a> {
/// Parse the specified tokens
pub fn new(tokens: Vec<Token>) -> Self {
Parser { tokens, index: 0 }
pub fn new(tokens: Vec<Token>, dialect: &'a dyn Dialect) -> Self {
Parser {
tokens,
index: 0,
dialect,
}
}

/// Parse a SQL statement and produce an Abstract Syntax Tree (AST)
pub fn parse_sql(dialect: &dyn Dialect, sql: &str) -> Result<Vec<Statement>, ParserError> {
let mut tokenizer = Tokenizer::new(dialect, &sql);
let tokens = tokenizer.tokenize()?;
let mut parser = Parser::new(tokens);
let mut parser = Parser::new(tokens, dialect);
let mut stmts = Vec::new();
let mut expecting_statement_delimiter = false;
debug!("Parsing sql '{}'...", sql);
Expand Down Expand Up @@ -950,7 +955,7 @@ impl Parser {
/// Parse a comma-separated list of 1+ items accepted by `F`
pub fn parse_comma_separated<T, F>(&mut self, mut f: F) -> Result<Vec<T>, ParserError>
where
F: FnMut(&mut Parser) -> Result<T, ParserError>,
F: FnMut(&mut Parser<'a>) -> Result<T, ParserError>,
{
let mut values = vec![];
loop {
Expand Down Expand Up @@ -2048,9 +2053,102 @@ impl Parser {
};
joins.push(join);
}

Ok(TableWithJoins { relation, joins })
}

/// Attach `consumed_alias` (an alias parsed after a closing parenthesis) to
/// the single table that the parentheses wrap.
///
/// Extra levels of parentheses (`TableFactor::NestedJoin`) are unwrapped
/// recursively and rebuilt around the aliased table.
///
/// Returns an error when the parenthesized content carries joins, or when
/// the wrapped table or derived table already has an alias.
fn add_alias_to_single_table_in_parenthesis(
    &self,
    table_and_joins: TableWithJoins,
    consumed_alias: TableAlias,
) -> Result<TableWithJoins, ParserError> {
    let TableWithJoins { relation, joins } = table_and_joins;

    // An alias is not allowed on a join (at least in Snowflake).
    if !joins.is_empty() {
        return Err(ParserError::ParserError(
            "alias not allowed on multiple table join".to_owned(),
        ));
    }

    let relation = match relation {
        // Another level of parentheses: push the alias further down.
        TableFactor::NestedJoin(inner) => TableFactor::NestedJoin(Box::new(
            self.add_alias_to_single_table_in_parenthesis(*inner, consumed_alias)?,
        )),
        // The wrapped table was already aliased inside the parentheses.
        TableFactor::Derived {
            alias: Some(existing),
            ..
        }
        | TableFactor::Table {
            alias: Some(existing),
            ..
        } => {
            return Err(ParserError::ParserError(format!(
                "duplicate alias {}",
                existing
            )));
        }
        TableFactor::Derived {
            lateral,
            subquery,
            alias: None,
        } => TableFactor::Derived {
            lateral,
            subquery,
            alias: Some(consumed_alias),
        },
        TableFactor::Table {
            name,
            args,
            with_hints,
            alias: None,
        } => TableFactor::Table {
            name,
            alias: Some(consumed_alias),
            args,
            with_hints,
        },
    };

    // `joins` is known to be empty at this point.
    Ok(TableWithJoins { relation, joins })
}

/// Try to parse an optional table alias at the current position and, when
/// one is present, attach it to the table held by `table_and_joins`.
///
/// When no alias follows, `table_and_joins` is returned unchanged.
fn check_for_alias_after_parenthesis(
    &mut self,
    table_and_joins: TableWithJoins,
) -> Result<TableWithJoins, ParserError> {
    let maybe_alias = self.parse_optional_table_alias(keywords::RESERVED_FOR_TABLE_ALIAS)?;

    match maybe_alias {
        Some(alias) => self.add_alias_to_single_table_in_parenthesis(table_and_joins, alias),
        None => Ok(table_and_joins),
    }
}

/// Check that the content of a pair of parentheses is a legitimate nested
/// join: either another nested join, or a table followed by at least one
/// join.
///
/// Returns an "Expected joined table" error, reported against the next
/// token, when a lone table or derived table was parenthesized.
fn validate_nested_join(&self, table_and_joins: &TableWithJoins) -> Result<(), ParserError> {
    // A further level of nesting is always accepted here.
    if let TableFactor::NestedJoin(_) = table_and_joins.relation {
        return Ok(());
    }

    // The SQL spec prohibits derived tables and bare
    // tables from appearing alone in parentheses.
    if table_and_joins.joins.is_empty() {
        return self.expected("joined table", self.peek_token());
    }

    Ok(())
}

/// A table name or a parenthesized subquery, followed by optional `[AS] alias`
pub fn parse_table_factor(&mut self) -> Result<TableFactor, ParserError> {
if self.parse_keyword(Keyword::LATERAL) {
Expand Down Expand Up @@ -2094,10 +2192,17 @@ impl Parser {
// followed by some joins or another level of nesting.
let table_and_joins = self.parse_table_and_joins()?;
self.expect_token(&Token::RParen)?;
// The SQL spec prohibits derived and bare tables from appearing
// alone in parentheses. We don't enforce this as some databases
// (e.g. Snowflake) allow such syntax.
Ok(TableFactor::NestedJoin(Box::new(table_and_joins)))

if self.dialect.alllow_single_table_in_parenthesis() {
let table_and_joins = self.check_for_alias_after_parenthesis(table_and_joins)?;
Ok(TableFactor::NestedJoin(Box::new(table_and_joins)))
} else {
self.validate_nested_join(&table_and_joins)?;
// The SQL spec prohibits derived and bare tables from appearing
// alone in parentheses. `validate_nested_join` above enforces this
// for dialects that do not opt in to such syntax (e.g. Snowflake does).
Ok(TableFactor::NestedJoin(Box::new(table_and_joins)))
}
} else {
let name = self.parse_object_name()?;
// Postgres, MSSQL: table-valued functions:
Expand Down
6 changes: 4 additions & 2 deletions src/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ impl TestedDialects {
self.one_of_identical_results(|dialect| {
let mut tokenizer = Tokenizer::new(dialect, sql);
let tokens = tokenizer.tokenize().unwrap();
f(&mut Parser::new(tokens))
f(&mut Parser::new(tokens, dialect))
})
}

Expand Down Expand Up @@ -104,7 +104,9 @@ impl TestedDialects {
/// Ensures that `sql` parses as an expression, and is not modified
/// after a serialization round-trip.
pub fn verified_expr(&self, sql: &str) -> Expr {
let ast = self.run_parser_method(sql, Parser::parse_expr).unwrap();
let ast = self
.run_parser_method(sql, |parser| parser.parse_expr())
.unwrap();
assert_eq!(sql, &ast.to_string(), "round-tripping without changes");
ast
}
Expand Down
46 changes: 10 additions & 36 deletions tests/sqlparser_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use matches::assert_matches;

use sqlparser::ast::*;
use sqlparser::dialect::keywords::ALL_KEYWORDS;
use sqlparser::parser::{Parser, ParserError};
use sqlparser::parser::ParserError;
use sqlparser::test_utils::{all_dialects, expr_from_projection, number, only};

#[test]
Expand Down Expand Up @@ -147,13 +147,14 @@ fn parse_update() {

#[test]
fn parse_invalid_table_name() {
let ast = all_dialects().run_parser_method("db.public..customer", Parser::parse_object_name);
let ast = all_dialects()
.run_parser_method("db.public..customer", |parser| parser.parse_object_name());
assert!(ast.is_err());
}

#[test]
fn parse_no_table_name() {
let ast = all_dialects().run_parser_method("", Parser::parse_object_name);
let ast = all_dialects().run_parser_method("", |parser| parser.parse_object_name());
assert!(ast.is_err());
}

Expand Down Expand Up @@ -2273,19 +2274,12 @@ fn parse_join_nesting() {
vec![join(nest!(nest!(nest!(table("b"), table("c")))))]
);

// Parenthesized table names are non-standard, but supported in Snowflake SQL
let sql = "SELECT * FROM (a NATURAL JOIN (b))";
let select = verified_only_select(sql);
let from = only(select.from);

assert_eq!(from.relation, nest!(table("a"), nest!(table("b"))));

// Double parentheses around table names are non-standard, but supported in Snowflake SQL
let sql = "SELECT * FROM (a NATURAL JOIN ((b)))";
let select = verified_only_select(sql);
let from = only(select.from);

assert_eq!(from.relation, nest!(table("a"), nest!(nest!(table("b")))));
// Nesting a subquery in parentheses is non-standard (Snowflake SQL allows it); here it must be rejected
let res = parse_sql_statements("SELECT * FROM ((SELECT 1) AS t)");
assert_eq!(
ParserError::ParserError("Expected joined table, found: EOF".to_string()),
res.unwrap_err()
);
}

#[test]
Expand Down Expand Up @@ -2427,26 +2421,6 @@ fn parse_derived_tables() {
}],
}))
);

// Nesting a subquery in parentheses is non-standard, but supported in Snowflake SQL
let sql = "SELECT * FROM ((SELECT 1) AS t)";
let select = verified_only_select(sql);
let from = only(select.from);

assert_eq!(
from.relation,
TableFactor::NestedJoin(Box::new(TableWithJoins {
relation: TableFactor::Derived {
lateral: false,
subquery: Box::new(verified_query("SELECT 1")),
alias: Some(TableAlias {
name: "t".into(),
columns: vec![],
})
},
joins: Vec::new(),
}))
);
}

#[test]
Expand Down

0 comments on commit b7d58c7

Please sign in to comment.