Skip to content

Commit

Permalink
Allow customizing logic for setting values in statements.
Browse files Browse the repository at this point in the history
  • Loading branch information
radeusgd committed Mar 28, 2023
1 parent a268b33 commit b5d374b
Show file tree
Hide file tree
Showing 9 changed files with 126 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import project.Internal.IR.Context.Context
import project.Internal.IR.SQL_Expression.SQL_Expression
import project.Internal.IR.Query.Query
import project.Internal.SQL_Type_Reference.SQL_Type_Reference

import project.Internal.Statement_Setter.Statement_Setter

from project.Internal.Result_Set import read_column, result_set_to_table
from project.Internal.JDBC_Connection import create_table_statement, handle_sql_errors
Expand Down Expand Up @@ -135,14 +135,16 @@ type Connection
False ->
Error.throw (Table_Not_Found.Error query sql_error treated_as_query=True)
SQL_Query.Raw_SQL raw_sql -> handle_sql_errors <|
columns = self.jdbc_connection.fetch_columns raw_sql
columns = self.jdbc_connection.fetch_columns raw_sql Statement_Setter.null
name = if alias == "" then (UUID.randomUUID.to_text) else alias
ctx = Context.for_query raw_sql name
Database_Table_Module.make_table self name columns ctx
SQL_Query.Table_Name name ->
result = handle_sql_errors <|
ctx = Context.for_table name (if alias == "" then name else alias)
columns = self.jdbc_connection.fetch_columns (self.dialect.generate_sql (Query.Select Nothing ctx))
statement = self.dialect.generate_sql (Query.Select Nothing ctx)
statement_setter = self.dialect.get_statement_setter
columns = self.jdbc_connection.fetch_columns statement statement_setter
Database_Table_Module.make_table self name columns ctx
result.catch SQL_Error sql_error->
Error.throw (Table_Not_Found.Error name sql_error treated_as_query=False)
Expand Down Expand Up @@ -171,7 +173,8 @@ type Connection
read_statement : SQL_Statement -> (Nothing | Vector SQL_Type_Reference) -> Materialized_Table
read_statement self statement column_type_suggestions=Nothing last_row_only=False =
type_overrides = self.dialect.get_type_mapping.prepare_type_overrides column_type_suggestions
self.jdbc_connection.with_prepared_statement statement stmt->
statement_setter = self.dialect.get_statement_setter
self.jdbc_connection.with_prepared_statement statement statement_setter stmt->
result_set_to_table stmt.executeQuery self.dialect.make_column_fetcher_for_type type_overrides last_row_only

## ADVANCED
Expand All @@ -185,7 +188,8 @@ type Connection
representing the query to execute.
execute_update : Text | SQL_Statement -> Integer
execute_update self query =
self.jdbc_connection.with_prepared_statement query stmt->
statement_setter = self.dialect.get_statement_setter
self.jdbc_connection.with_prepared_statement query statement_setter stmt->
Panic.catch UnsupportedOperationException stmt.executeLargeUpdate _->
stmt.executeUpdate

Expand Down Expand Up @@ -225,6 +229,7 @@ type Connection
pairs = db_table.internal_columns.map col->[col.name, SQL_Expression.Constant Nothing]
insert_query = self.dialect.generate_sql <| Query.Insert name pairs
insert_template = insert_query.prepare.first
self.jdbc_connection.load_table insert_template table batch_size
statement_setter = self.dialect.get_statement_setter
self.jdbc_connection.load_table insert_template statement_setter table batch_size

db_table
11 changes: 11 additions & 0 deletions distribution/lib/Standard/Database/0.0.0-dev/src/Data/Dialect.enso
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import project.Internal.Postgres.Postgres_Dialect
import project.Internal.Redshift.Redshift_Dialect
import project.Internal.SQLite.SQLite_Dialect
import project.Internal.SQL_Type_Mapping.SQL_Type_Mapping
import project.Internal.Statement_Setter.Statement_Setter
from project.Errors import Unsupported_Database_Operation

## PRIVATE
Expand Down Expand Up @@ -102,6 +103,16 @@ type Dialect
_ = sql_type
Unimplemented.throw "This is an interface only."

## PRIVATE
Returns a helper object that handles the logic of setting values in a
prepared statement.

This object may provide custom logic for handling dialect-specific
handling of some types.
get_statement_setter : Statement_Setter
get_statement_setter self =
Unimplemented.throw "This is an interface only."

## PRIVATE
Checks if the given aggregate is supported.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ polyglot java import java.sql.SQLTimeoutException
polyglot java import java.sql.Types as Java_Types

polyglot java import org.enso.database.JDBCProxy
polyglot java import org.enso.database.JDBCUtils

type JDBC_Connection
## PRIVATE
Expand Down Expand Up @@ -65,11 +64,11 @@ type JDBC_Connection

Runs the provided action with a prepared statement, adding contextual
information to any thrown SQL errors.
with_prepared_statement : Text | SQL_Statement -> (PreparedStatement -> Any) -> Any
with_prepared_statement self query action =
with_prepared_statement : Text | SQL_Statement -> Statement_Setter -> (PreparedStatement -> Any) -> Any
with_prepared_statement self query statement_setter action =
prepare template values = self.connection_resource.with java_connection->
stmt = java_connection.prepareStatement template
Panic.catch Any (set_statement_values stmt values) caught_panic->
Panic.catch Any (statement_setter.fill_values stmt values) caught_panic->
stmt.close
Panic.throw caught_panic
stmt
Expand All @@ -85,12 +84,11 @@ type JDBC_Connection
go compiled.first compiled.second

## PRIVATE

Given a prepared statement, gets the column names and types for the
result set.
fetch_columns : Text | SQL_Statement -> Any
fetch_columns self statement =
self.with_prepared_statement statement stmt->
fetch_columns : Text | SQL_Statement -> Statement_Setter -> Any
fetch_columns self statement statement_setter =
self.with_prepared_statement statement statement_setter stmt->
metadata = stmt.executeQuery.getMetaData

resolve_column ix =
Expand All @@ -104,8 +102,8 @@ type JDBC_Connection

Given an insert query template and the associated Database_Table, and a
Materialized_Table of data, load to the database.
load_table : Text -> Materialized_Table -> Integer -> Nothing
load_table self insert_template table batch_size =
load_table : Text -> Statement_Setter -> Materialized_Table -> Integer -> Nothing
load_table self insert_template statement_setter table batch_size =
self.with_connection java_connection->
default_autocommit = java_connection.getAutoCommit
java_connection.setAutoCommit False
Expand All @@ -121,7 +119,7 @@ type JDBC_Connection
Panic.throw <| Illegal_State.Error "A single update within the batch unexpectedly affected "+affected_rows.to_text+" rows."
0.up_to num_rows . each row_id->
values = columns.map col-> col.at row_id
set_statement_values stmt values
statement_setter.fill_values stmt values
stmt.addBatch
if (row_id+1 % batch_size) == 0 then check_rows stmt.executeBatch batch_size
if (num_rows % batch_size) != 0 then check_rows stmt.executeBatch (num_rows % batch_size)
Expand Down Expand Up @@ -172,18 +170,6 @@ handle_sql_errors ~action related_query=Nothing =
exc : SQLTimeoutException -> Error.throw (SQL_Timeout.Error exc related_query)
exc -> Error.throw (SQL_Error.Error exc related_query)

## PRIVATE
Sets values inside of a prepared statement.
set_statement_values : PreparedStatement -> Vector Any -> Nothing
set_statement_values stmt values =
values.map_with_index ix-> obj->
position = ix + 1
# TODO [RW] dialect specific logic!
case obj of
Nothing -> stmt.setNull position Java_Types.NULL
_ : Date_Time -> stmt.setTimestamp position (JDBCUtils.getTimestamp obj)
_ -> stmt.setObject position obj

## PRIVATE
Given a Materialized_Table, create a SQL statement to build the table.
create_table_statement : (Value_Type -> SQL_Type) -> Text -> Materialized_Table -> Boolean -> SQL_Statement
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,13 @@ import project.Internal.IR.Order_Descriptor.Order_Descriptor
import project.Internal.IR.Nulls_Order.Nulls_Order
import project.Internal.IR.SQL_Join_Kind.SQL_Join_Kind
import project.Internal.IR.Query.Query
import project.Internal.SQL_Type_Mapping.SQL_Type_Mapping
import project.Internal.Postgres.Postgres_Type_Mapping.Postgres_Type_Mapping
import project.Internal.SQL_Type_Mapping.SQL_Type_Mapping
import project.Internal.Statement_Setter.Statement_Setter
from project.Errors import Unsupported_Database_Operation

polyglot java import org.enso.database.JDBCUtils

## PRIVATE

The dialect of PostgreSQL databases.
Expand Down Expand Up @@ -127,6 +130,10 @@ type Postgres_Dialect
value_type = type_mapping.sql_type_to_value_type sql_type
Column_Fetcher_Module.default_fetcher_for_value_type value_type

## PRIVATE
get_statement_setter : Statement_Setter
get_statement_setter self = postgres_statement_setter

## PRIVATE
check_aggregate_support : Aggregate_Column -> Boolean ! Unsupported_Database_Operation
check_aggregate_support self aggregate =
Expand Down Expand Up @@ -315,3 +322,16 @@ decimal_div = Base_Generator.lift_binary_op "/" x-> y->
## PRIVATE
mod_op = Base_Generator.lift_binary_op "mod" x-> y->
x ++ " - FLOOR(CAST(" ++ x ++ " AS double precision) / CAST(" ++ y ++ " AS double precision)) * " ++ y

## PRIVATE
postgres_statement_setter : Statement_Setter
postgres_statement_setter =
default = Statement_Setter.default
fill_holes stmt i value = case value of
# TODO [RW] Postgres date handling
_ : Date_Time ->
stmt.setTimestamp position (JDBCUtils.getTimestamp obj)
# _ : Date ->
# _ : Time_Of_Day ->
_ -> default.fill_holes stmt i value
Statement_Setter.Value fill_holes
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ type Redshift_Dialect
value_type = type_mapping.sql_type_to_value_type sql_type
Column_Fetcher_Module.default_fetcher_for_value_type value_type

## PRIVATE
get_statement_setter : Statement_Setter
get_statement_setter self = Postgres_Dialect.postgres_statement_setter

## PRIVATE
check_aggregate_support : Aggregate_Column -> Boolean ! Unsupported_Database_Operation
check_aggregate_support self aggregate =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ type SQL_Type_Reference
new : Connection -> Context -> SQL_Expression -> SQL_Type_Reference
new connection context expression =
do_fetch =
# TODO [RW] remove type from here ?
empty_context = context.add_where_filters [SQL_Expression.Constant False]
statement = connection.dialect.generate_sql (Query.Select [["typed_column", expression]] empty_context)
columns = connection.jdbc_connection.fetch_columns statement
statement_setter = connection.dialect.get_statement_setter
columns = connection.jdbc_connection.fetch_columns statement statement_setter
only_column = columns.first
only_column.second
SQL_Type_Reference.Computed_By_Database (Lazy.new do_fetch)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import project.Internal.Common.Database_Distinct_Helper
import project.Internal.Common.Database_Join_Helper
import project.Internal.SQL_Type_Mapping.SQL_Type_Mapping
import project.Internal.SQLite.SQLite_Type_Mapping.SQLite_Type_Mapping
import project.Internal.Statement_Setter.Statement_Setter
from project.Errors import Unsupported_Database_Operation

## PRIVATE
Expand Down Expand Up @@ -146,6 +147,10 @@ type SQLite_Dialect
value_type = type_mapping.sql_type_to_value_type sql_type
Column_Fetcher_Module.default_fetcher_for_value_type value_type

## PRIVATE
get_statement_setter : Statement_Setter
get_statement_setter self = Statement_Setter.default

## PRIVATE
check_aggregate_support : Aggregate_Column -> Boolean ! Unsupported_Database_Operation
check_aggregate_support self aggregate = case aggregate of
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from Standard.Base import all
import Standard.Base.Errors.Illegal_State.Illegal_State

polyglot java import java.sql.PreparedStatement

type Statement_Setter
## PRIVATE
Encapsulates the logic for filling a hole in a prepared statement.
Value (fill_hole : PreparedStatement -> Integer -> Any -> Nothing)

## PRIVATE
A helper that gets a list of values and fills their corresponding holes
in the prepared statement.

It assumes that the provided vector contains all values expected in this
prepared statement. It should not be called multiple times on the same
statement.
fill_values : PreparedStatement -> Vector Any -> Nothing
fill_values self stmt values =
values.each_with_index ix-> value->
self.fill_hole stmt (ix + 1) value

## PRIVATE
The default setter that is handling simple commonly supported types.
default : Statement_Setter
default = Statement_Setter.Value fill_hole_default

## PRIVATE
Used internally to mark statements that do not expect to have any values
to set.

It will panic if called.
null : Statement_Setter
null =
fill_hole_unexpected _ _ _ =
Panic.throw (Illegal_State.Error "The associated statement does not expect any values to be set. This is a bug in the Database library.")
Statement_Setter.Value fill_hole_unexpected

## PRIVATE
fill_hole_default stmt i value = case value of
Nothing -> stmt.setNull i Java_Types.NULL
_ : Boolean -> stmt.setBoolean i value
_ : Integer -> stmt.setLong i value
_ : Decimal -> stmt.setDouble i value
_ : Text -> stmt.setString i value
_ -> stmt.setObject i value
17 changes: 17 additions & 0 deletions test/Table_Tests/src/Database/Common_Spec.enso
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,18 @@ spec prefix connection =
m2.at "b" . to_vector . should_equal [5]
m2.at "c" . should_fail_with No_Such_Column

Test.specify "should allow to access a Table by an SQL query" <|
t2 = connection.query (SQL_Query.Raw_SQL ('SELECT a, b FROM "' + name + '" WHERE a >= 3'))
m2 = t2.read
m2.column_names . should_equal ["a", "b"]
m2.at "a" . to_vector . should_equal [4]
m2.at "b" . to_vector . should_equal [5]
m2.at "c" . should_fail_with No_Such_Column

t3 = connection.query (SQL_Query.Raw_SQL ('SELECT 1+2'))
m3 = t3.read
m3.at 0 . to_vector . should_equal [3]

Test.specify "should use labels for column names" <|
t2 = connection.query (SQL_Query.Raw_SQL ('SELECT a AS c, b FROM "' + name + '" WHERE a >= 3'))
m2 = t2.read
Expand Down Expand Up @@ -93,6 +105,11 @@ spec prefix connection =
r3 = connection.query (SQL_Query.Raw_SQL "MALFORMED-QUERY")
r3.should_fail_with SQL_Error

Test.specify "should not allow interpolations in raw user-built queries" <|
r = connection.query (SQL_Query.Raw_SQL "SELECT 1 + ?")
IO.println r
r.should_fail_with SQL_Error

Test.specify "should make a best-effort attempt at returning a reasonable error for the short-hand" <|
r2 = connection.query "NONEXISTENT-TABLE"
r2.should_fail_with Table_Not_Found
Expand Down

0 comments on commit b5d374b

Please sign in to comment.