Skip to content

Commit

Permalink
Refactor StatementBuilder and Database and introduce connection logging
Browse files Browse the repository at this point in the history
  • Loading branch information
Brendan Maginnis committed May 18, 2015
1 parent 2b93ecb commit ed8bc8b
Show file tree
Hide file tree
Showing 13 changed files with 208 additions and 206 deletions.
8 changes: 8 additions & 0 deletions examples/src/main/scala/sqlest/examples/Extractor.scala
Expand Up @@ -6,6 +6,14 @@ case class Fruit(name: String, juiciness: Int)
case class Smoothy(description: String, fruits: List[Fruit])

object ExtractorExamples extends App with DatabaseExample {

import org.apache.log4j.{ ConsoleAppender, PropertyConfigurator, SimpleLayout }
val p = new java.util.Properties
p.setProperty("log4j.rootLogger", "DEBUG,mylog")
p.setProperty("log4j.appender.mylog", classOf[ConsoleAppender].getName)
p.setProperty("log4j.appender.mylog.layout", classOf[SimpleLayout].getName)
PropertyConfigurator.configure(p)

InsertExamples.insertAll

val fruits =
Expand Down
4 changes: 2 additions & 2 deletions examples/src/main/scala/sqlest/examples/Insert.scala
Expand Up @@ -27,7 +27,7 @@ object InsertExamples extends App with DatabaseExample {
)

// We can get the setters for a case class using the extractor
val ingredients = List(Ingredient(1, 1), Ingredient(1, 3))
lazy val ingredients = List(Ingredient(1, 1), Ingredient(1, 3))
lazy val ingredientsTableInsertStatement =
insert
.into(IngredientsTable)
Expand All @@ -36,7 +36,7 @@ object InsertExamples extends App with DatabaseExample {
.values(extractor.settersFor(ingredients(1)))

// This also works for lists of the case class
val moreIngredients = List(Ingredient(2, 4), Ingredient(2, 3))
lazy val moreIngredients = List(Ingredient(2, 4), Ingredient(2, 3))
lazy val moreIngredientsInsertStatement =
insert
.into(IngredientsTable)
Expand Down
159 changes: 118 additions & 41 deletions sqlest/src/main/scala/sqlest/executor/Database.scala
Expand Up @@ -17,83 +17,94 @@
package sqlest.executor

import sqlest.ast._
import sqlest.extractor._
import sqlest.sql.base._
import sqlest.util._
import sqlest.util.Logging

import java.sql.{ Connection, DriverManager, ResultSet, Statement, SQLException }
import java.sql.{ Connection, Date => JdbcDate, DriverManager, ResultSet, PreparedStatement, Statement, SQLException, Timestamp => JdbcTimestamp, Types => JdbcTypes }
import javax.sql.DataSource
import java.util.Properties
import org.joda.time.{ DateTime, LocalDate }
import scala.util.DynamicVariable

object Database {
def withDataSource(dataSource: DataSource, builder: StatementBuilder): Database = new Database {
def getConnection: Connection = dataSource.getConnection
val statementBuilder = builder
}

def withDataSource(dataSource: DataSource, builder: StatementBuilder, connectionDescription: Connection => String): Database = {
val inConnectionDescription = connectionDescription
new Database {
def getConnection: Connection = dataSource.getConnection
val statementBuilder = builder
override val connectionDescription = Some(inConnectionDescription)
}
}
}

trait Database extends Logging {
protected def getConnection: Connection
protected def statementBuilder: StatementBuilder
protected def connectionDescription: Option[Connection => String] = None

private val transactionConnection = new DynamicVariable[Option[Connection]](None)

def executeSelect[A](select: Select[_, _])(extractor: ResultSet => A): A =
executeWithConnection { connection =>
val preparedStatement = statementBuilder(connection, select)
val (preprocessedSelect, sql, argumentLists) = statementBuilder(select)
try {
logger.debug(s"Executing select")
val resultSet = preparedStatement.executeQuery

val startTime = new DateTime
val preparedStatement = prepareStatement(connection, preprocessedSelect, sql, argumentLists)
try {
logger.debug(s"Extracting results")
extractor(resultSet)
val resultSet = preparedStatement.executeQuery
try {
val result = extractor(resultSet)
val endTime = new DateTime
logger.info(s"Ran sql in ${endTime.getMillis - startTime.getMillis}ms: ${logDetails(connection, sql, argumentLists)}")
result
} finally {
try {
if (resultSet != null) resultSet.close
} catch {
case e: SQLException =>
}
}
} finally {
try {
if (resultSet != null) resultSet.close
if (preparedStatement != null) preparedStatement.close
} catch {
case e: SQLException =>
}
}

} finally {
try {
if (preparedStatement != null) preparedStatement.close
} catch {
case e: SQLException =>
}
} catch {
case e: Throwable =>
logger.error(s"Exception running sql: ${logDetails(connection, sql, argumentLists)}", e)
throw e
}
}

def executeInsert(insert: Insert): Int = {
logger.debug(s"Executing insert")
executeCommand(insert)
}

def executeUpdate(update: Update): Int = {
logger.debug(s"Executing update")
executeCommand(update)
}

def executeDelete(delete: Delete): Int = {
logger.debug(s"Executing delete")
executeCommand(delete)
}

private def executeCommand(command: Command): Int = {
def executeCommand(command: Command): Int = {
checkInTransaction
executeWithConnection { connection =>
val preparedStatement = statementBuilder(connection, command)
val (preprocessedCommand, sql, argumentLists) = statementBuilder(command)
val startTime = new DateTime
try {
logger.debug(s"Executing command")
preparedStatement.executeBatch.sum
} finally {
val preparedStatement = prepareStatement(connection, preprocessedCommand, sql, argumentLists)
try {
if (preparedStatement != null) preparedStatement.close
} catch {
case e: SQLException =>
val result = preparedStatement.executeBatch.sum
val endTime = new DateTime
logger.info(s"Ran sql in ${endTime.getMillis - startTime.getMillis}ms: ${logDetails(connection, sql, argumentLists)}")
result
} finally {
try {
if (preparedStatement != null) preparedStatement.close
} catch {
case e: SQLException =>
}
}
} catch {
case e: Throwable =>
logger.error(s"Exception running sql: ${logDetails(connection, sql, argumentLists)}", e)
throw e
}
}
}
Expand Down Expand Up @@ -173,4 +184,70 @@ trait Database extends Logging {
private def checkInTransaction =
if (transactionConnection.value.isEmpty)
throw new AssertionError("Must run write operations in a transaction")

def prepareStatement(connection: Connection, operation: Operation, sql: String, argumentLists: List[List[LiteralColumn[_]]]) = {
val statement = connection.prepareStatement(sql)
setArguments(operation, statement, argumentLists)
statement
}

private def setArguments(operation: Operation, statement: PreparedStatement, argumentLists: List[List[LiteralColumn[_]]]) = {
def innerSetArguments(argumentList: List[LiteralColumn[_]]) = {
var index = 0
argumentList foreach { argument =>
index = index + 1 // prepared statement argument indices are 1-based
setArgument(statement, index, argument.columnType, argument.value)
}
}

argumentLists.foreach {
argumentList =>
innerSetArguments(argumentList)
statement.addBatch
}
}

private def setArgument[A](statement: PreparedStatement, index: Int, columnType: ColumnType[A], value: Any): Unit = columnType match {
case BooleanColumnType => statement.setBoolean(index, value.asInstanceOf[Boolean])
case IntColumnType => statement.setInt(index, value.asInstanceOf[Int])
case LongColumnType => statement.setLong(index, value.asInstanceOf[Long])
case DoubleColumnType => statement.setDouble(index, value.asInstanceOf[Double])
case BigDecimalColumnType => statement.setBigDecimal(index, value.asInstanceOf[BigDecimal].bigDecimal)
case StringColumnType => statement.setString(index, value.asInstanceOf[String])
case ByteArrayColumnType => statement.setBytes(index, value.asInstanceOf[Array[Byte]])
case DateTimeColumnType => statement.setTimestamp(index, new JdbcTimestamp(value.asInstanceOf[DateTime].getMillis))
case LocalDateColumnType => statement.setDate(index, new JdbcDate(value.asInstanceOf[LocalDate].toDate.getTime))
case mappedType: MappedColumnType[A, _] => setArgument(statement, index, mappedType.baseColumnType, mappedType.write(value.asInstanceOf[A]))
case optionType: OptionColumnType[_, _] => value.asInstanceOf[Option[_]] match {
case setNullOpt if setNullOpt.isEmpty && optionType.nullValue == null =>
statement.setNull(index, jdbcType(optionType.baseColumnType))
case nullValueOpt if nullValueOpt.isEmpty =>
setArgument(statement, index, optionType.baseColumnType, optionType.nullValue)
case definedOpt =>
setArgument(statement, index, optionType.innerColumnType, definedOpt.get)
}
}

private def jdbcType[A](columnType: ColumnType[A]): Int = columnType match {
case BooleanColumnType => JdbcTypes.BOOLEAN
case IntColumnType => JdbcTypes.INTEGER
case LongColumnType => JdbcTypes.INTEGER
case DoubleColumnType => JdbcTypes.DOUBLE
case BigDecimalColumnType => JdbcTypes.DECIMAL
case StringColumnType => JdbcTypes.CHAR
case ByteArrayColumnType => JdbcTypes.BINARY
case DateTimeColumnType => JdbcTypes.TIMESTAMP
case LocalDateColumnType => JdbcTypes.DATE
case optionType: OptionColumnType[_, _] => jdbcType(optionType.baseColumnType)
case mappedType: MappedColumnType[_, _] => jdbcType(mappedType.baseColumnType)
}

def logDetails(connection: Connection, sql: String, argumentLists: List[List[LiteralColumn[_]]]) = {
val connectionLog = connectionDescription.map(connectionDescription => s", connection [${connectionDescription(connection)}]").getOrElse("")
val argumentsLog =
if (argumentLists.size == 1) argumentLists.head.map(_.value).mkString(", ")
else argumentLists.map(_.map(_.value).mkString("(", ", ", ")")).mkString(", ")

s"sql [$sql], arguments [$argumentsLog]${connectionLog}"
}
}
6 changes: 3 additions & 3 deletions sqlest/src/main/scala/sqlest/executor/Executor.scala
Expand Up @@ -45,15 +45,15 @@ trait ExecutorSyntax extends QuerySyntax {
}

implicit class InsertExecutorOps(insert: Insert) {
def execute(implicit database: Database): Int = database.executeInsert(insert)
def execute(implicit database: Database): Int = database.executeCommand(insert)
}

implicit class UpdateExecutorOps(update: Update) {
def execute(implicit database: Database): Int = database.executeUpdate(update)
def execute(implicit database: Database): Int = database.executeCommand(update)
}

implicit class DeleteExecutorOps(delete: Delete) {
def execute(implicit database: Database): Int = database.executeDelete(delete)
def execute(implicit database: Database): Int = database.executeCommand(delete)
}

implicit class BatchExecutorOps(batchCommands: Seq[Command]) {
Expand Down
Expand Up @@ -39,11 +39,8 @@ trait InsertStatementBuilder extends BaseStatementBuilder {

// -------------------------------------------------

def insertArgs(insert: Insert): List[LiteralColumn[_]] = insert match {
case InsertValues(_, setterLists) => insertValuesArgs(setterLists)
case InsertFromSelect(_, _, select) => selectStatementBuilder.selectArgs(select)
def insertArgs(insert: Insert): List[List[LiteralColumn[_]]] = insert match {
case InsertValues(_, setterLists) => setterLists.map(_.toList.flatMap(setterArgs(_))).toList
case InsertFromSelect(_, _, select) => List(selectStatementBuilder.selectArgs(select))
}

def insertValuesArgs(setterLists: Seq[Seq[Setter[_, _]]]): List[LiteralColumn[_]] =
setterLists.flatten.toList.flatMap(setterArgs(_))
}

0 comments on commit ed8bc8b

Please sign in to comment.