diff --git a/build.sbt b/build.sbt index cc5ed16..724b868 100644 --- a/build.sbt +++ b/build.sbt @@ -4,20 +4,21 @@ val gatlingVersion = "2.3.0" scalacOptions += "-target:jvm-1.8" -libraryDependencies += "com.datastax.dse" % "dse-java-driver-core" % "1.6.8" -libraryDependencies += "com.datastax.dse" % "dse-java-driver-graph" % "1.6.8" -libraryDependencies += "com.github.nscala-time" %% "nscala-time" % "2.18.0" -libraryDependencies += "com.fasterxml.jackson.module" %% "jackson-module-scala" % "2.9.1" -libraryDependencies += "org.hdrhistogram" % "HdrHistogram" % "2.1.10" +libraryDependencies += "com.datastax.oss" % "java-driver-core" % "4.5.0" +libraryDependencies += "com.github.nscala-time" %% "nscala-time" % "2.18.0" +libraryDependencies += "com.fasterxml.jackson.module" %% "jackson-module-scala" % "2.9.1" +libraryDependencies += "org.hdrhistogram" % "HdrHistogram" % "2.1.10" -libraryDependencies += "io.gatling.highcharts" % "gatling-charts-highcharts" % gatlingVersion % Provided +libraryDependencies += "io.gatling.highcharts" % "gatling-charts-highcharts" % gatlingVersion % Provided + +libraryDependencies += "org.fusesource" % "sigar" % "1.6.4" % Test +libraryDependencies += "org.scalatest" %% "scalatest" % "3.0.5" % Test +libraryDependencies += "org.easymock" % "easymock" % "3.5" % Test +libraryDependencies += "org.cassandraunit" % "cassandra-unit" % "4.3.1.0" % Test +libraryDependencies += "org.pegdown" % "pegdown" % "1.6.0" % Test +libraryDependencies += "com.typesafe.akka" %% "akka-testkit" % "2.5.11" % Test +libraryDependencies += "com.datastax.oss" % "java-driver-query-builder" % "4.4.0" % Test -libraryDependencies += "org.fusesource" % "sigar" % "1.6.4" % Test -libraryDependencies += "org.scalatest" %% "scalatest" % "3.0.5" % Test -libraryDependencies += "org.easymock" % "easymock" % "3.5" % Test -libraryDependencies += "org.cassandraunit" % "cassandra-unit" % "3.3.0.2" % Test -libraryDependencies += "org.pegdown" % "pegdown" % "1.6.0" % Test -libraryDependencies += "com.typesafe.akka" %% "akka-testkit" % "2.5.11" % Test resolvers += Resolver.mavenLocal resolvers += Resolver.mavenCentral diff --git a/src/main/scala/com/datastax/gatling/plugin/DseProtocol.scala b/src/main/scala/com/datastax/gatling/plugin/DseProtocol.scala index e1c8863..17c8f88 100644 --- a/src/main/scala/com/datastax/gatling/plugin/DseProtocol.scala +++ b/src/main/scala/com/datastax/gatling/plugin/DseProtocol.scala @@ -11,10 +11,10 @@ import java.util.concurrent.atomic.AtomicLong import akka.Done import akka.actor.ActorSystem -import com.datastax.driver.dse.DseSession import com.datastax.gatling.plugin.metrics.MetricsLogger import com.datastax.gatling.plugin.request.{CqlRequestActionBuilder, GraphRequestActionBuilder} import com.datastax.gatling.plugin.utils.GatlingTimingSource +import com.datastax.oss.driver.api.core.CqlSession import com.typesafe.scalalogging.StrictLogging import io.gatling.core.CoreComponents import io.gatling.core.config.GatlingConfiguration @@ -63,7 +63,7 @@ object DseProtocol extends StrictLogging { } } -case class DseProtocol(session: DseSession) extends Protocol +case class DseProtocol(session: CqlSession) extends Protocol object DseComponents { private val componentsCache = mutable.Map[ActorSystem, DseComponents]() @@ -126,10 +126,10 @@ case class DseComponents(dseProtocol: DseProtocol, object DseProtocolBuilder { - def session(session: DseSession) = DseProtocolBuilder(session) + def session(session: CqlSession) = DseProtocolBuilder(session) } -case class DseProtocolBuilder(session: DseSession) { +case class DseProtocolBuilder(session: CqlSession) { def build = DseProtocol(session) } diff --git a/src/main/scala/com/datastax/gatling/plugin/Predef.scala b/src/main/scala/com/datastax/gatling/plugin/Predef.scala index ac3f89f..eb1eeea 100644 --- a/src/main/scala/com/datastax/gatling/plugin/Predef.scala +++ b/src/main/scala/com/datastax/gatling/plugin/Predef.scala @@ -36,9 +36,9 @@ trait DsePredefBase extends DseCheckSupport { implicit def protocolBuilder2DseProtocol(builder: DseProtocolBuilder): DseProtocol = builder.build - implicit def cqlRequestAttributes2ActionBuilder(builder: DseCqlAttributesBuilder): ActionBuilder = builder.build() + implicit def cqlRequestAttributes2ActionBuilder(builder: DseCqlAttributesBuilder[_,_]): ActionBuilder = builder.build() - implicit def graphRequestAttributes2ActionBuilder(builder: DseGraphAttributesBuilder): ActionBuilder = builder.build() + implicit def graphRequestAttributes2ActionBuilder(builder: DseGraphAttributesBuilder[_, _]): ActionBuilder = builder.build() } /** diff --git a/src/main/scala/com/datastax/gatling/plugin/checks/CqlChecks.scala b/src/main/scala/com/datastax/gatling/plugin/checks/CqlChecks.scala index 567f533..2729904 100644 --- a/src/main/scala/com/datastax/gatling/plugin/checks/CqlChecks.scala +++ b/src/main/scala/com/datastax/gatling/plugin/checks/CqlChecks.scala @@ -6,16 +6,15 @@ package com.datastax.gatling.plugin.checks -import com.datastax.driver.core.{ResultSet, Row} +import com.datastax.oss.driver.api.core.cql.{AsyncResultSet, Statement, StatementBuilder} import com.datastax.gatling.plugin.response.CqlResponse import io.gatling.commons.validation.{SuccessWrapper, Validation} import io.gatling.core.check._ -import io.gatling.core.check.extractor.{CountArity, CriterionExtractor, Extractor, FindAllArity, FindArity, SingleArity, _} +import io.gatling.core.check.extractor.{Extractor, SingleArity} import io.gatling.core.session.{Expression, ExpressionSuccessWrapper, Session} import scala.collection.mutable - /** * This class serves as model for the CQL-specific checks. * @@ -51,50 +50,10 @@ private class CqlResponseExtractor[X](val name: String, } } -private abstract class ColumnValueExtractor[X] extends CriterionExtractor[CqlResponse, Any, X] { - val criterionName = "columnValue" -} - -private class SingleColumnValueExtractor(val criterion: String, val occurrence: Int) extends ColumnValueExtractor[Any] with FindArity { - def extract(response: CqlResponse): Validation[Option[Any]] = - response.getCqlResultColumnValues(criterion).lift(occurrence).success -} - -private class MultipleColumnValueExtractor(val criterion: String) extends ColumnValueExtractor[Seq[Any]] with FindAllArity { - def extract(response: CqlResponse): Validation[Option[Seq[Any]]] = - response.getCqlResultColumnValues(criterion).liftSeqOption.success -} - -private class CountColumnValueExtractor(val criterion: String) extends ColumnValueExtractor[Int] with CountArity { - def extract(response: CqlResponse): Validation[Option[Int]] = - response.getCqlResultColumnValues(criterion).liftSeqOption.map(_.size).success -} - object CqlChecks { - val resultSet = - new CqlResponseExtractor[ResultSet]( + val resultSet:CqlCheckBuilder[AsyncResultSet] = + new CqlResponseExtractor[AsyncResultSet]( "resultSet", - r => r.getCqlResultSet) - .toCheckBuilder - - val allRows = - new CqlResponseExtractor[Seq[Row]]( - "allRows", - r => r.getAllRows) + r => r.resultSet) .toCheckBuilder - - val oneRow = - new CqlResponseExtractor[Row]( - "oneRow", - r => r.getOneRow) - .toCheckBuilder - - def columnValue(columnName: Expression[String]) = { - val cqlResponseExtender: Extender[DseCqlCheck, CqlResponse] = wrapped => DseCqlCheck(wrapped) - new DefaultMultipleFindCheckBuilder[DseCqlCheck, CqlResponse, CqlResponse, Any](cqlResponseExtender, x => x.success) { - def findExtractor(occurrence: Int) = columnName.map(new SingleColumnValueExtractor(_, occurrence)) - def findAllExtractor = columnName.map(new MultipleColumnValueExtractor(_)) - def countExtractor = columnName.map(new CountColumnValueExtractor(_)) - } - } } diff --git a/src/main/scala/com/datastax/gatling/plugin/checks/DseCheckSupport.scala b/src/main/scala/com/datastax/gatling/plugin/checks/DseCheckSupport.scala index fe0e808..bef0fb6 100644 --- a/src/main/scala/com/datastax/gatling/plugin/checks/DseCheckSupport.scala +++ b/src/main/scala/com/datastax/gatling/plugin/checks/DseCheckSupport.scala @@ -6,49 +6,75 @@ package com.datastax.gatling.plugin.checks -import io.gatling.core.session.ExpressionSuccessWrapper +import com.datastax.dse.driver.api.core.graph.AsyncGraphResultSet +import com.datastax.oss.driver.api.core.cql.AsyncResultSet +import com.datastax.gatling.plugin.utils.ResultSetUtils +/** + * Make both CQL and Graph checks available to the DSL. + * + * Note that as of 1.3.5 (and the upgrade to the unified OSS driver it brings along) the API here has changed. + * The old check API exposed a rich set of checks for various operations including row counts and validating + * data in individual rows. Several (most?) of these checks were built on the idea that all rows were immediately + * available in memory. This design has changed in 1.3.5, so maintaining the old API would've proven quite difficult + * (if not impossible). Additionally, the rich API shields the user from the intricacies of the driver API at the + * cost of limited flexibility; implementing new functionality requires modifications to the plugin itself (or at + * least an awareness of it's innards). + * + * With 1.3.5 this relationship has been inverted. The check API has been reduced to a single check which makes the + * underlying [[AsyncResultSet]] or [[AsyncGraphResultSet]] available. Simulations can then use transform() to + * extract values and evaluate them as necessary. So, for instance, something like this: + * + * {{{ + * .check(columnValue("counter_type") is 2) + * }}} + * + * now becomes: + * + * {{{ + * .check(resultSet.transform(rs => rs.one().getLong("counter_type")) is 2L) + * }}} + * + * Note that these transforms are now managed as Scala code within the simulations so they can be abstracted and + * built into libraries which can be re-used across simulations. Also note that this abstraction can be implemented + * without modifying the plugin itself. + * + * A similar pattern applies to checks based on metadata. So this: + * + * {{{ + * .check(exhausted is true) + * }}} + * + * now becomes: + * + * {{{ + * .check(resultSet.transform(_.hasMorePages) is false) + * .check(resultSet.transform(_.remaining) is 0) + * }}} + * + * At this point we should also note that checks are now explicitly executed in the order in which hey are declared + * in the simulation. This matters because iterating through rows will impact methods such as remaining(). So, for + * example, if you want to validate that a single row was returned and it contained a specific value you should do + * something like: + * + * {{{ + * .check(resultSet.transform(_.remaining) is 1) + * .check(resultSet.transform(rs => rs.one().getLong("counter_type")) is 2L) + * }}} + * + * and not: + * + * {{{ + * .check(resultSet.transform(rs => rs.one().getLong("counter_type")) is 2L) + * .check(resultSet.transform(_.remaining) is 1) + * }}} + * + * Finally, in general the expectation is that you won't need to realize all rows in a result set, but if for some + * reason you find this necessary this functionality is supplied in [[ResultSetUtils]]. This class also serves as + * an example of the kind of abstraction over common extraction operations discussed above. + */ trait DseCheckSupport { - // start global checks - lazy val exhausted = GenericChecks.exhausted - lazy val applied = GenericChecks.applied - lazy val rowCount = GenericChecks.rowCount - - // execution info and subsets - lazy val executionInfo = GenericChecks.executionInfo - lazy val achievedCL = GenericChecks.achievedConsistencyLevel - lazy val pagingState = GenericChecks.pagingState - lazy val queriedHost = GenericChecks.queriedHost - lazy val schemaAgreement = GenericChecks.schemaInAgreement - lazy val successfulExecutionIndex = GenericChecks.successfulExecutionIndex - lazy val triedHosts = GenericChecks.triedHosts - lazy val warnings = GenericChecks.warnings - - // start CQL only checks lazy val resultSet = CqlChecks.resultSet - lazy val allRows = CqlChecks.allRows - lazy val oneRow = CqlChecks.oneRow - - // start Graph only checks - lazy val graphResultSet = GraphChecks.graphResultSet - lazy val allNodes = GraphChecks.allNodes - lazy val oneNode = GraphChecks.oneNode - - def edges(columnName: String) = GraphChecks.edges(columnName) - - def vertexes(columnName: String) = GraphChecks.vertexes(columnName) - - def paths(columnName: String) = GraphChecks.paths(columnName) - - def properties(columnName: String) = GraphChecks.paths(columnName) - - def vertexProperties(columnName: String) = GraphChecks.vertexProperties(columnName) - - /** - * Get a column by name returned by the CQL statement. - * Note that this statement implicitly fetches all rows from the result set! - */ - def columnValue(columnName: String) = CqlChecks.columnValue(columnName.expressionSuccess) + lazy val graphResultSet = GraphChecks.resultSet } - diff --git a/src/main/scala/com/datastax/gatling/plugin/checks/GenericChecks.scala b/src/main/scala/com/datastax/gatling/plugin/checks/GenericChecks.scala deleted file mode 100644 index 13ea365..0000000 --- a/src/main/scala/com/datastax/gatling/plugin/checks/GenericChecks.scala +++ /dev/null @@ -1,129 +0,0 @@ -/* - * Copyright (c) 2018 Datastax Inc. - * - * This software can be used solely with DataStax products. Please consult the file LICENSE.md. - */ - -package com.datastax.gatling.plugin.checks - -import com.datastax.driver.core._ -import com.datastax.gatling.plugin.response.DseResponse -import io.gatling.commons.validation.{SuccessWrapper, Validation} -import io.gatling.core.check.extractor.{Extractor, SingleArity} -import io.gatling.core.check._ -import io.gatling.core.session.{Expression, ExpressionSuccessWrapper, Session} - -import scala.collection.mutable - -/** - * This class allows to execute checks on either CQL or Graph responses. - * - * There is no class hierarchy between DseResponseCheck and DseCqlCheck or - * DseGraphCheck on purpose. Otherwise any method that accept DseResponseCheck - * would make it possible to execute CQL checks on Graph responses, or - * vice-versa. - */ -case class GenericCheck(wrapped: Check[DseResponse]) extends Check[DseResponse] { - override def check(response: DseResponse, session: Session)(implicit cache: mutable.Map[Any, Any]): Validation[CheckResult] = { - wrapped.check(response, session) - } -} - -class GenericCheckBuilder[X](extractor: Expression[Extractor[DseResponse, X]]) - extends FindCheckBuilder[GenericCheck, DseResponse, DseResponse, X] { - - private val dseResponseExtender: Extender[GenericCheck, DseResponse] = - wrapped => GenericCheck(wrapped) - - def find: ValidatorCheckBuilder[GenericCheck, DseResponse, DseResponse, X] = { - ValidatorCheckBuilder(dseResponseExtender, x => x.success, extractor) - } -} - -private class GenericResponseExtractor[X](val name: String, - val extractor: DseResponse => X) - extends Extractor[DseResponse, X] with SingleArity { - - override def apply(response: DseResponse): Validation[Option[X]] = { - Some(extractor.apply(response)).success - } - - def toCheckBuilder: GenericCheckBuilder[X] = { - new GenericCheckBuilder[X](this.expressionSuccess) - } -} - -object GenericChecks { - val executionInfo = - new GenericResponseExtractor[ExecutionInfo]( - "executionInfo", - r => r.executionInfo()) - .toCheckBuilder - - val queriedHost = - new GenericResponseExtractor[Host]( - "queriedHost", - r => r.queriedHost()) - .toCheckBuilder - - val achievedConsistencyLevel = - new GenericResponseExtractor[ConsistencyLevel]( - "achievedConsistencyLevel", - r => r.achievedConsistencyLevel()) - .toCheckBuilder - - val speculativeExecutionsExtractor = - new GenericResponseExtractor[Int]( - "speculativeExecutions", - r => r.speculativeExecutions()) - .toCheckBuilder - - val pagingState = - new GenericResponseExtractor[PagingState]( - "pagingState", - r => r.pagingState()) - .toCheckBuilder - - val triedHosts = - new GenericResponseExtractor[List[Host]]( - "triedHost", - r => r.triedHosts()) - .toCheckBuilder - - val warnings = - new GenericResponseExtractor[List[String]]( - "warnings", - r => r.warnings()) - .toCheckBuilder - - val successfulExecutionIndex = - new GenericResponseExtractor[Int]( - "successfulExecutionIndex", - r => r.successFullExecutionIndex()) - .toCheckBuilder - - val schemaInAgreement = - new GenericResponseExtractor[Boolean]( - "schemaInAgreement", - r => r.schemaInAgreement()) - .toCheckBuilder - - val rowCount = - new GenericResponseExtractor[Int]( - "rowCount", - r => r.rowCount()) - .toCheckBuilder - - val applied = - new GenericResponseExtractor[Boolean]( - "applied", - r => r.applied()) - .toCheckBuilder - - val exhausted = - new GenericResponseExtractor[Boolean]( - "exhausted", - r => r.exhausted()) - .toCheckBuilder -} - diff --git a/src/main/scala/com/datastax/gatling/plugin/checks/GraphChecks.scala b/src/main/scala/com/datastax/gatling/plugin/checks/GraphChecks.scala index d3ed9ac..a04396c 100644 --- a/src/main/scala/com/datastax/gatling/plugin/checks/GraphChecks.scala +++ b/src/main/scala/com/datastax/gatling/plugin/checks/GraphChecks.scala @@ -6,7 +6,7 @@ package com.datastax.gatling.plugin.checks -import com.datastax.driver.dse.graph._ +import com.datastax.dse.driver.api.core.graph._ import com.datastax.gatling.plugin.response.GraphResponse import io.gatling.commons.validation.{SuccessWrapper, Validation} import io.gatling.core.check.extractor.{Extractor, SingleArity} @@ -15,7 +15,6 @@ import io.gatling.core.session.{Expression, ExpressionSuccessWrapper, Session} import scala.collection.mutable - /** * This class serves as model for the Graph-specific checks. * @@ -52,51 +51,9 @@ private class GraphResponseExtractor[X](val name: String, } object GraphChecks { - val graphResultSet = - new GraphResponseExtractor[GraphResultSet]( - "graphResultSet", - r => r.getGraphResultSet) - .toCheckBuilder - - val allNodes = - new GraphResponseExtractor[Seq[GraphNode]]( - "allNodes", - r => r.getAllNodes) - .toCheckBuilder - - val oneNode = - new GraphResponseExtractor[GraphNode]( - "oneNode", - r => r.getOneNode) - .toCheckBuilder - - def edges(column: String) = - new GraphResponseExtractor[Seq[Edge]]( - "edges", - r => r.getEdges(column)) - .toCheckBuilder - - def vertexes(column: String) = - new GraphResponseExtractor[Seq[Vertex]]( - "vertices", - r => r.getVertexes(column)) - .toCheckBuilder - - def paths(column: String) = - new GraphResponseExtractor[Seq[Path]]( - "paths", - r => r.getPaths(column)) - .toCheckBuilder - - def properties(column: String) = - new GraphResponseExtractor[Seq[Property]]( - "properties", - r => r.getProperties(column)) - .toCheckBuilder - - def vertexProperties(column: String) = - new GraphResponseExtractor[Seq[Property]]( - "vertexProperties", - r => r.getVertexProperties(column)) - .toCheckBuilder + val resultSet:GraphCheckBuilder[AsyncGraphResultSet] = + new GraphResponseExtractor[AsyncGraphResultSet]( + "graphResultSet", + r => r.resultSet) + .toCheckBuilder } \ No newline at end of file diff --git a/src/main/scala/com/datastax/gatling/plugin/model/DseCqlAttributes.scala b/src/main/scala/com/datastax/gatling/plugin/model/DseCqlAttributes.scala index 455b820..d523c75 100644 --- a/src/main/scala/com/datastax/gatling/plugin/model/DseCqlAttributes.scala +++ b/src/main/scala/com/datastax/gatling/plugin/model/DseCqlAttributes.scala @@ -7,11 +7,14 @@ package com.datastax.gatling.plugin.model import java.nio.ByteBuffer +import java.time.Duration -import com.datastax.driver.core.policies.RetryPolicy -import com.datastax.driver.core.{ConsistencyLevel, PagingState, Statement} -import com.datastax.gatling.plugin.response.{CqlResponse, DseResponse} -import io.gatling.core.check.Check +import com.datastax.gatling.plugin.checks.DseCqlCheck +import com.datastax.oss.driver.api.core.ConsistencyLevel +import com.datastax.oss.driver.api.core.cql.{Statement, StatementBuilder} +import com.datastax.oss.driver.api.core.metadata.Node +import com.datastax.oss.driver.api.core.metadata.token.Token +import com.datastax.oss.driver.api.core.time.TimestampGenerator /** * CQL Query Attributes to be applied to the current query @@ -22,33 +25,41 @@ import io.gatling.core.check.Check * @param statement CQL Statement to be sent to Cluster * @param cl Consistency Level to be used * @param cqlChecks Data-level checks to be run after response is returned - * @param genericChecks Low-level checks to be run after response is returned - * @param userOrRole User or role to be used when proxy auth is enabled - * @param readTimeout Read timeout to be used * @param idempotent Set request to be idempotent i.e. whether it can be applied multiple times - * @param defaultTimestamp Set default timestamp on request, overriding current system time + * @param node Set the node that should handle this query + * @param userOrRole Set the user/role for this query if proxy authentication is used + * @param customPayload Custom payload for this request * @param enableTrace Whether tracing should be enabled - * @param outGoingPayload Set ByteBuffer custom outgoing Payload - * @param serialCl Serial Consistency Level to be used - * @param fetchSize Set fetchSize of request i.e. paging size - * @param retryPolicy Retry Policy to be used + * @param pageSize Set pageSize (formerly known as fetchSize) * @param pagingState Set paging State wanted + * @param queryTimestamp Set a timestamp to use for this query. If equal to Some(Long.MIN_VALUE) a timestamp + * will be generated by the configured [[TimestampGenerator]] + * @param routingKey Sets the key for token-aware routing + * @param routingKeyspace Sets the keyspace for token-aware routing + * @param routingToken Sets the token to use for token-aware routing + * @param serialCl Serial Consistency Level to be used + * @param timeout How long to wait for this request to complete * @param cqlStatements String version of the CQL statement that is sent * */ -case class DseCqlAttributes(tag: String, - statement: DseStatement[Statement], - cl: Option[ConsistencyLevel] = None, - cqlChecks: List[Check[CqlResponse]] = List.empty, - genericChecks: List[Check[DseResponse]] = List.empty, - userOrRole: Option[String] = None, - readTimeout: Option[Int] = None, - idempotent: Option[Boolean] = None, - defaultTimestamp: Option[Long] = None, - enableTrace: Option[Boolean] = None, - outGoingPayload: Option[Map[String, ByteBuffer]] = None, - serialCl: Option[ConsistencyLevel] = None, - fetchSize: Option[Int] = None, - retryPolicy: Option[RetryPolicy] = None, - pagingState: Option[PagingState] = None, - cqlStatements: Seq[String] = Seq.empty) +case class DseCqlAttributes[T <: Statement[T], B <: StatementBuilder[B,T]] + (tag: String, + statement: DseCqlStatement[T, B], + cqlChecks: List[DseCqlCheck] = List.empty, + cqlStatements: Seq[String] = Seq.empty, + /* General attributes */ + cl: Option[ConsistencyLevel] = None, + idempotent: Option[Boolean] = None, + node: Option[Node] = None, + userOrRole: Option[String] = None, + /* CQL-specific attributes */ + customPayload: Option[Map[String, ByteBuffer]] = None, + enableTrace: Option[Boolean] = None, + pageSize: Option[Int] = None, + pagingState: Option[ByteBuffer] = None, + queryTimestamp: Option[Long] = None, + routingKey: Option[ByteBuffer] = None, + routingKeyspace: Option[String] = None, + routingToken: Option[Token] = None, + serialCl: Option[ConsistencyLevel] = None, + timeout: Option[Duration] = None) diff --git a/src/main/scala/com/datastax/gatling/plugin/model/DseCqlAttributesBuilder.scala b/src/main/scala/com/datastax/gatling/plugin/model/DseCqlAttributesBuilder.scala index b8283ad..b9231a4 100644 --- a/src/main/scala/com/datastax/gatling/plugin/model/DseCqlAttributesBuilder.scala +++ b/src/main/scala/com/datastax/gatling/plugin/model/DseCqlAttributesBuilder.scala @@ -6,11 +6,15 @@ package com.datastax.gatling.plugin.model -import com.datastax.driver.core.policies.RetryPolicy -import com.datastax.driver.core.{ConsistencyLevel, PagingState} -import com.datastax.gatling.plugin.checks.{DseCqlCheck, GenericCheck} +import java.nio.ByteBuffer +import java.time.Duration + +import com.datastax.oss.driver.api.core.ConsistencyLevel +import com.datastax.oss.driver.api.core.cql.{Statement, StatementBuilder} +import com.datastax.gatling.plugin.checks.DseCqlCheck import com.datastax.gatling.plugin.request.CqlRequestActionBuilder -import io.gatling.core.action.builder.ActionBuilder +import com.datastax.oss.driver.api.core.metadata.Node +import com.datastax.oss.driver.api.core.metadata.token.Token /** @@ -18,13 +22,13 @@ import io.gatling.core.action.builder.ActionBuilder * * @param attr Addition Attributes */ -case class DseCqlAttributesBuilder(attr: DseCqlAttributes) { +case class DseCqlAttributesBuilder[T <: Statement[T], B <: StatementBuilder[B,T]](attr: DseCqlAttributes[T, B]) { /** * Builds to final action to run * * @return */ - def build(): CqlRequestActionBuilder = new CqlRequestActionBuilder(attr) + def build(): CqlRequestActionBuilder[T, B] = new CqlRequestActionBuilder(attr) /** * Set Consistency Level @@ -32,122 +36,136 @@ case class DseCqlAttributesBuilder(attr: DseCqlAttributes) { * @param level ConsistencyLevel * @return */ - def withConsistencyLevel(level: ConsistencyLevel) = DseCqlAttributesBuilder(attr.copy(cl = Some(level))) + def withConsistencyLevel(level: ConsistencyLevel):DseCqlAttributesBuilder[T, B] = + DseCqlAttributesBuilder(attr.copy(cl = Some(level))) /** - * Execute a query as another user or another role, provided the current logged in user has PROXY.EXECUTE permission. - * - * This permission MUST be granted to the currently logged in user using the CQL statement: `GRANT PROXY.EXECUTE ON - * ROLE someRole TO alice`. The user MUST be logged in with - * [[com.datastax.driver.dse.auth.DsePlainTextAuthProvider]] or - * [[com.datastax.driver.dse.auth.DseGSSAPIAuthProvider]] + * Set custom payload * - * @param userOrRole String + * @param k the key for this custom payload + * @param v the value for this custom payload * @return */ - def withUserOrRole(userOrRole: String) = DseCqlAttributesBuilder(attr.copy(userOrRole = Some(userOrRole))) + def addCustomPayload(k:String, v:ByteBuffer):DseCqlAttributesBuilder[T, B] = { + val newVal = + attr.customPayload + .orElse(Some(Map[String, ByteBuffer]())) + .map(m => m + (k -> v)) + DseCqlAttributesBuilder(attr.copy(customPayload = newVal)) + } /** - * Override the current system time for write time of query + * Set query to be idempotent i.e. run only once * - * @param epochTsInMs timestamp to use * @return */ - def withDefaultTimestamp(epochTsInMs: Long) = DseCqlAttributesBuilder(attr.copy(defaultTimestamp = Some(epochTsInMs))) - + def withIdempotency():DseCqlAttributesBuilder[T, B] = + DseCqlAttributesBuilder(attr.copy(idempotent = Some(true))) /** * Set query to be idempotent i.e. run only once * * @return */ - def withIdempotency() = DseCqlAttributesBuilder(attr.copy(idempotent = Some(true))) - + def withIdempotency(idempotency:Boolean):DseCqlAttributesBuilder[T, B] = + DseCqlAttributesBuilder(attr.copy(idempotent = Some(idempotency))) /** - * Set Read timeout of the query - * - * @param readTimeoutInMs time in milliseconds + * Set the node that should handle this query + * @param node Node * @return */ - def withReadTimeout(readTimeoutInMs: Int) = DseCqlAttributesBuilder(attr.copy(readTimeout = Some(readTimeoutInMs))) + def withNode(node: Node):DseCqlAttributesBuilder[T, B] = + DseCqlAttributesBuilder(attr.copy(node = Some(node))) + /** + * Set the user or role to use for proxy auth + * @param userOrRole String + * @return + */ + def executeAs(userOrRole: String):DseCqlAttributesBuilder[T, B] = + DseCqlAttributesBuilder(attr.copy(userOrRole = Some(userOrRole))) /** - * Set Serial Consistency + * Enable CQL Tracing on the query * - * @param level ConsistencyLevel * @return */ - def withSerialConsistencyLevel(level: ConsistencyLevel) = DseCqlAttributesBuilder(attr.copy(serialCl = Some(level))) - + def withTracingEnabled():DseCqlAttributesBuilder[T, B] = + DseCqlAttributesBuilder(attr.copy(enableTrace = Some(true))) /** - * Define the [[com.datastax.driver.core.policies.RetryPolicy]] to be used for query + * Set the page size * - * @param retryPolicy DataStax drivers retry policy + * @param pageSize CQL page size * @return */ - def withRetryPolicy(retryPolicy: RetryPolicy) = DseCqlAttributesBuilder(attr.copy(retryPolicy = Some(retryPolicy))) + def withPageSize(pageSize: Int):DseCqlAttributesBuilder[T, B] = + DseCqlAttributesBuilder(attr.copy(pageSize = Some(pageSize))) /** - * Set fetchSize of query for paging + * Set the paging state * - * @param rowCnt number of rows to fetch at one time + * @param pagingState CQL Paging state * @return */ - def withFetchSize(rowCnt: Int) = DseCqlAttributesBuilder(attr.copy(fetchSize = Some(rowCnt))) - + def withPagingState(pagingState: ByteBuffer):DseCqlAttributesBuilder[T, B] = + DseCqlAttributesBuilder(attr.copy(pagingState = Some(pagingState))) /** - * Enable CQL Tracing on the query + * Set the query timestamp * + * @param queryTimestamp CQL query timestamp * @return */ - def withTracingEnabled() = DseCqlAttributesBuilder(attr.copy(enableTrace = Some(true))) - + def withQueryTimestamp(queryTimestamp: Long):DseCqlAttributesBuilder[T, B] = + DseCqlAttributesBuilder(attr.copy(queryTimestamp = Some(queryTimestamp))) /** - * Set the paging state + * Set the routing key * - * @param pagingState CQL Paging state + * @param routingKey the routing key to use * @return */ - def withPagingState(pagingState: PagingState) = DseCqlAttributesBuilder(attr.copy(pagingState = Some(pagingState))) - + def withRoutingKey(routingKey: ByteBuffer):DseCqlAttributesBuilder[T, B] = + DseCqlAttributesBuilder(attr.copy(routingKey = Some(routingKey))) /** - * For backwards compatibility + * Set the routing keyspace * - * @param level + * @param routingKeyspace the routing keyspace to set * @return */ - @deprecated("Replaced by withSerialConsistencyLevel") - def serialConsistencyLevel(level: ConsistencyLevel) = withSerialConsistencyLevel(level) - + def withRoutingKeyspace(routingKeyspace: String):DseCqlAttributesBuilder[T, B] = + DseCqlAttributesBuilder(attr.copy(routingKeyspace = Some(routingKeyspace))) /** - * Backwards compatibility to set consistencyLevel + * Set the routing token * - * @see [[DseCqlAttributesBuilder.withConsistencyLevel]] - * @param level Consistency Level to use + * @param routingToken the routing token to set * @return */ - @deprecated("Replaced by withConsistencyLevel") - def consistencyLevel(level: ConsistencyLevel) = withConsistencyLevel(level) - + def withRoutingToken(routingToken: Token):DseCqlAttributesBuilder[T, B] = + DseCqlAttributesBuilder(attr.copy(routingToken = Some(routingToken))) /** - * For Backwards compatibility + * Set Serial Consistency * - * @see [[DseCqlAttributesBuilder.executeAs]] - * @param userOrRole User or role to use + * @param level ConsistencyLevel * @return */ - @deprecated("Replaced by withUserOrRole") - def executeAs(userOrRole: String) = withUserOrRole(userOrRole: String) + def withSerialConsistencyLevel(level: ConsistencyLevel):DseCqlAttributesBuilder[T, B] = + DseCqlAttributesBuilder(attr.copy(serialCl = Some(level))) - def check(check: DseCqlCheck) = DseCqlAttributesBuilder(attr.copy(cqlChecks = check :: attr.cqlChecks)) + /** + * Set timeout + * + * @param timeout the timeout to set + * @return + */ + def withTimeout(timeout: Duration):DseCqlAttributesBuilder[T, B] = + DseCqlAttributesBuilder(attr.copy(timeout = Some(timeout))) - def check(check: GenericCheck) = DseCqlAttributesBuilder(attr.copy(genericChecks = check :: attr.genericChecks)) + def check(check: DseCqlCheck):DseCqlAttributesBuilder[T, B] = + DseCqlAttributesBuilder(attr.copy(cqlChecks = (attr.cqlChecks :+ check))) } diff --git a/src/main/scala/com/datastax/gatling/plugin/model/DseCqlStatementBuilders.scala b/src/main/scala/com/datastax/gatling/plugin/model/DseCqlStatementBuilders.scala index 1812102..fa1aa1c 100644 --- a/src/main/scala/com/datastax/gatling/plugin/model/DseCqlStatementBuilders.scala +++ b/src/main/scala/com/datastax/gatling/plugin/model/DseCqlStatementBuilders.scala @@ -6,7 +6,7 @@ package com.datastax.gatling.plugin.model -import com.datastax.driver.core.{PreparedStatement, SimpleStatement} +import com.datastax.oss.driver.api.core.cql._ import com.datastax.gatling.plugin._ import com.datastax.gatling.plugin.utils.CqlPreparedStatementUtil import io.gatling.core.session.Expression @@ -22,6 +22,8 @@ import io.gatling.core.session.Expression */ case class DseCqlStatementBuilder(tag: String) { + implicit val defaultBuilderFn = (s:BoundStatement) => new BoundStatementBuilder(s) + /** * Execute a simple Statement built from a CQL string. * @@ -29,7 +31,7 @@ case class DseCqlStatementBuilder(tag: String) { * @return */ @deprecated("Replaced by executeStatement(String)") - def executeCql(query: String): DseCqlAttributesBuilder = + def executeCql(query: String): DseCqlAttributesBuilder[SimpleStatement, SimpleStatementBuilder] = executeStatement(query) /** @@ -38,8 +40,8 @@ case class DseCqlStatementBuilder(tag: String) { * @param query Simple string query * @return */ - def executeStatement(query: String): DseCqlAttributesBuilder = - executeStatement(new SimpleStatement(query)) + def executeStatement(query: String): DseCqlAttributesBuilder[SimpleStatement, SimpleStatementBuilder] = + executeStatement(new SimpleStatementBuilder(query).build) /** * Execute a Simple Statement @@ -47,12 +49,12 @@ case class DseCqlStatementBuilder(tag: String) { * @param statement SimpleStatement * @return */ - def executeStatement(statement: SimpleStatement): DseCqlAttributesBuilder = + def executeStatement(statement: SimpleStatement): DseCqlAttributesBuilder[SimpleStatement, SimpleStatementBuilder] = DseCqlAttributesBuilder( DseCqlAttributes( tag, DseCqlSimpleStatement(statement), - cqlStatements = Seq(statement.getQueryString)) + cqlStatements = Seq(statement.getQuery)) ) /** @@ -78,7 +80,7 @@ case class DseCqlStatementBuilder(tag: String) { * @param preparedStatement CQL Prepared Statement w/ anon ?'s * @return */ - def executeStatement(preparedStatement: PreparedStatement) = + def executeStatement(preparedStatement: PreparedStatement): DsePreparedCqlStatementBuilder = DsePreparedCqlStatementBuilder(tag, preparedStatement) /** @@ -90,7 +92,7 @@ case class DseCqlStatementBuilder(tag: String) { * * @param preparedStatement CQL Prepared statement with named parameters */ - def executeNamed(preparedStatement: PreparedStatement): DseCqlAttributesBuilder = + def executeNamed(preparedStatement: PreparedStatement): DseCqlAttributesBuilder[BoundStatement, BoundStatementBuilder] = DsePreparedCqlStatementBuilder(tag, preparedStatement).withSessionParams() /** @@ -98,11 +100,12 @@ case class DseCqlStatementBuilder(tag: String) { * * @param preparedStatements Array of prepared statements */ - def executePreparedBatch(preparedStatements: Array[PreparedStatement]) = DseCqlAttributesBuilder( - DseCqlAttributes( - tag, - DseCqlBoundBatchStatement(CqlPreparedStatementUtil, preparedStatements), - cqlStatements = preparedStatements.map(_.getQueryString) + def executePreparedBatch(preparedStatements: Array[PreparedStatement]): DseCqlAttributesBuilder[BatchStatement, BatchStatementBuilder] = + DseCqlAttributesBuilder( + DseCqlAttributes( + tag, + DseCqlBoundBatchStatement(CqlPreparedStatementUtil, preparedStatements), + cqlStatements = preparedStatements.map(_.getQuery) ) ) @@ -113,14 +116,14 @@ case class DseCqlStatementBuilder(tag: String) { * @param payloadSessionKey Session key of the payload from session/feed * @return */ - def executeCustomPayload(statement: SimpleStatement, payloadSessionKey: String): DseCqlAttributesBuilder = + def executeCustomPayload(statement: SimpleStatement, payloadSessionKey: String): DseCqlAttributesBuilder[SimpleStatement, SimpleStatementBuilder] = DseCqlAttributesBuilder( DseCqlAttributes( tag, DseCqlCustomPayloadStatement(statement, payloadSessionKey), - cqlStatements = Seq(statement.getQueryString))) + cqlStatements = Seq(statement.getQuery))) - def executePreparedFromSession(key: String): DseCqlAttributesBuilder = + def executePreparedFromSession(key: String): DseCqlAttributesBuilder[BoundStatement, BoundStatementBuilder] = DseCqlAttributesBuilder( DseCqlAttributes( tag, @@ -135,17 +138,19 @@ case class DseCqlStatementBuilder(tag: String) { */ case class DsePreparedCqlStatementBuilder(tag: String, prepared: PreparedStatement) { + implicit val defaultBuilderFn = (s:BoundStatement) => new BoundStatementBuilder(s) + /** * Alias for the behavior of executeNamed function * * @return */ - def withSessionParams(): DseCqlAttributesBuilder = + def withSessionParams(): DseCqlAttributesBuilder[BoundStatement, BoundStatementBuilder] = DseCqlAttributesBuilder( DseCqlAttributes( tag, DseCqlBoundStatementNamed(CqlPreparedStatementUtil, prepared), - cqlStatements = Seq(prepared.getQueryString))) + cqlStatements = Seq(prepared.getQuery))) /** * Bind Gatling Session Values to CQL Prepared Statement @@ -153,12 +158,12 @@ case class DsePreparedCqlStatementBuilder(tag: String, prepared: PreparedStateme * @param params Gatling Session variables * @return */ - def withParams(params: Expression[AnyRef]*): DseCqlAttributesBuilder = + def withParams(params: Expression[AnyRef]*): DseCqlAttributesBuilder[BoundStatement, BoundStatementBuilder] = DseCqlAttributesBuilder( DseCqlAttributes( tag, DseCqlBoundStatementWithPassedParams(CqlPreparedStatementUtil, prepared, params: _*), - cqlStatements = Seq(prepared.getQueryString)) + cqlStatements = Seq(prepared.getQuery)) ) /** @@ -167,11 +172,11 @@ case class DsePreparedCqlStatementBuilder(tag: String, prepared: PreparedStateme * @param sessionKeys Gatling Session Keys * @return */ - def withParams(sessionKeys: List[String]) = + def withParams(sessionKeys: List[String]): DseCqlAttributesBuilder[BoundStatement, BoundStatementBuilder] = DseCqlAttributesBuilder( DseCqlAttributes( tag, DseCqlBoundStatementWithParamList(CqlPreparedStatementUtil, prepared, sessionKeys), - cqlStatements = Seq(prepared.getQueryString)) + cqlStatements = Seq(prepared.getQuery)) ) } diff --git a/src/main/scala/com/datastax/gatling/plugin/model/DseCqlStatements.scala b/src/main/scala/com/datastax/gatling/plugin/model/DseCqlStatements.scala index 1561482..9fddc5f 100644 --- a/src/main/scala/com/datastax/gatling/plugin/model/DseCqlStatements.scala +++ b/src/main/scala/com/datastax/gatling/plugin/model/DseCqlStatements.scala @@ -8,29 +8,28 @@ package com.datastax.gatling.plugin.model import java.nio.ByteBuffer -import com.datastax.driver.core._ +import com.datastax.oss.driver.api.core.cql._ import com.datastax.gatling.plugin.exceptions.DseCqlStatementException import com.datastax.gatling.plugin.utils.CqlPreparedStatementUtil +import com.datastax.oss.driver.api.core.`type`.DataType import io.gatling.commons.validation._ -import io.gatling.core.session.{Session, _} +import io.gatling.core.session._ import scala.collection.JavaConverters._ -import scala.collection.mutable.ArrayBuffer import scala.util.{Try, Failure => TryFailure, Success => TrySuccess} - -trait DseCqlStatement extends DseStatement[Statement] { - def buildFromSession(session: Session): Validation[Statement] -} +trait DseCqlStatement[T <: Statement[T], B <: StatementBuilder[B,T]] extends DseStatement[B] /** * Simple CQL Statement from the java driver * * @param statement the statement to execute */ -case class DseCqlSimpleStatement(statement: SimpleStatement) extends DseCqlStatement { - def buildFromSession(gatlingSession: Session): Validation[SimpleStatement] = { - statement.success +case class DseCqlSimpleStatement(statement: SimpleStatement) + extends DseCqlStatement[SimpleStatement, SimpleStatementBuilder] { + + def buildFromSession(gatlingSession: Session): Validation[SimpleStatementBuilder] = { + SimpleStatement.builder(statement).success } } @@ -39,30 +38,38 @@ case class DseCqlSimpleStatement(statement: SimpleStatement) extends DseCqlState * * @param preparedStatement the prepared statement on which to bind parameters */ -case class DseCqlBoundStatementNamed(cqlTypes: CqlPreparedStatementUtil, preparedStatement: PreparedStatement) - extends DseCqlStatement { +case class DseCqlBoundStatementNamed(cqlTypes: CqlPreparedStatementUtil, + preparedStatement: PreparedStatement) + (implicit builderFn: (BoundStatement) => BoundStatementBuilder) + extends DseCqlStatement[BoundStatement, BoundStatementBuilder] { - def buildFromSession(gatlingSession: Session): Validation[BoundStatement] = - bindParams( + def buildFromSession(gatlingSession: Session): Validation[BoundStatementBuilder] = { + val template:BoundStatement = bindParams( gatlingSession, preparedStatement.bind(), - cqlTypes.getParamsMap(preparedStatement)).success + cqlTypes.getParamsMap(preparedStatement)) + builderFn(template).success + } /** * Bind Gatling Session Params to CQL Statement by Name and Type * * @param gatlingSession Gatling Session - * @param boundStatement CQL Prepared Statement + * @param template CQL Prepared Statement * @param queryParams CQL Query Named Params and Types * @return */ - protected def bindParams(gatlingSession: Session, boundStatement: BoundStatement, - queryParams: Map[String, DataType.Name]): BoundStatement = { - queryParams.foreach { - case (gatlingSessionKey, valType) => - cqlTypes.bindParamByName(gatlingSession, boundStatement, valType, gatlingSessionKey) - } - boundStatement + protected def bindParams(gatlingSession: Session, template: BoundStatement, + queryParams: Map[String, DataType]): BoundStatement = { + val completedBuilder = + queryParams.foldLeft(builderFn(template)) { + (builder, kv) => + kv match { + case (gatlingSessionKey, valType) => + cqlTypes.bindParamByName(gatlingSession, builder, valType, gatlingSessionKey) + } + } + completedBuilder.build() } } @@ -74,9 +81,11 @@ case class DseCqlBoundStatementNamed(cqlTypes: CqlPreparedStatementUtil, prepare */ case class DseCqlBoundStatementWithPassedParams(cqlTypes: CqlPreparedStatementUtil, preparedStatement: PreparedStatement, - params: Expression[AnyRef]*) extends DseCqlStatement { + params: Expression[AnyRef]*) + (implicit builderFn: (BoundStatement) => BoundStatementBuilder) + extends DseCqlStatement[BoundStatement, BoundStatementBuilder] { - def buildFromSession(gatlingSession: Session): Validation[BoundStatement] = { + def buildFromSession(gatlingSession: Session): Validation[BoundStatementBuilder] = { val parsedParams: Seq[Validation[AnyRef]] = params.map(param => param(gatlingSession)) if (parsedParams.exists(_.isInstanceOf[Failure])) { val firstError = StringBuilder.newBuilder @@ -86,7 +95,8 @@ case class DseCqlBoundStatementWithPassedParams(cqlTypes: CqlPreparedStatementUt .onFailure(msg => firstError.append(msg)) firstError.toString().failure } else { - preparedStatement.bind(parsedParams.map(_.get): _*).success + val template:BoundStatement = preparedStatement.bind(parsedParams.map(_.get): _*) + builderFn(template).success } } } @@ -98,109 +108,103 @@ case class DseCqlBoundStatementWithPassedParams(cqlTypes: CqlPreparedStatementUt */ case class DseCqlBoundStatementWithParamList(cqlTypes: CqlPreparedStatementUtil, preparedStatement: PreparedStatement, - sessionKeys: Seq[String]) extends DseCqlStatement { + sessionKeys: Seq[String]) + (implicit builderFn: (BoundStatement) => BoundStatementBuilder) + extends DseCqlStatement[BoundStatement, BoundStatementBuilder] { /** * Apply the Gatling session params to the Prepared statement * - * @param gatlingSession DseSession + * @param gatlingSession current Gatling session * @return */ - def buildFromSession(gatlingSession: Session): Validation[BoundStatement] = { - bindParams(gatlingSession, preparedStatement.bind(), sessionKeys).success + def buildFromSession(gatlingSession: Session): Validation[BoundStatementBuilder] = { + val template:BoundStatement = bindParams(gatlingSession, preparedStatement.bind(), sessionKeys) + builderFn(template).success } /** * Bind the Gatling session params to the CQL Prepared Statement * * @param gatlingSession Gatling Session - * @param boundStatement CQL Prepared Statement + * @param template CQL Prepared Statement * @param sessionKeys List of session params to apply, put in order of query ?'s * @return */ - protected def bindParams(gatlingSession: Session, boundStatement: BoundStatement, + protected def bindParams(gatlingSession: Session, template: BoundStatement, sessionKeys: Seq[String]): BoundStatement = { - val params = cqlTypes.getParamsList(preparedStatement) - var cnt = 0 - - sessionKeys.foreach { gatlingSessionKey => - cqlTypes.bindParamByOrder(gatlingSession, boundStatement, params(cnt), gatlingSessionKey, cnt) - cnt += 1 - } - - boundStatement + val completedBuilder = + sessionKeys.zip(Iterable.range(0,sessionKeys.size)).foldLeft(builderFn(template)) { + (builder, kv) => + kv match { + case (sessionKey, cnt) => cqlTypes.bindParamByOrder(gatlingSession, builder, params(cnt), sessionKey, cnt) + } + } + completedBuilder.build() } - } - /** * Bound CQL Prepared Statement from Named Params * * @param statements CQL Prepared Statements */ -case class DseCqlBoundBatchStatement(cqlTypes: CqlPreparedStatementUtil, statements: Seq[PreparedStatement]) - extends DseCqlStatement { - - def buildFromSession(gatlingSession: Session): Validation[BatchStatement] = { - - val batch = new BatchStatement() - - statements.foreach(s => - batch.add(bindParams(gatlingSession, s, cqlTypes.getParamsMap(s)))) - - batch.success +case class DseCqlBoundBatchStatement(cqlTypes: CqlPreparedStatementUtil, + statements: Seq[PreparedStatement]) + (implicit builderFn: (BoundStatement) => BoundStatementBuilder) + extends DseCqlStatement[BatchStatement, BatchStatementBuilder] { + + def buildFromSession(gatlingSession: Session): Validation[BatchStatementBuilder] = { + val builder:BatchStatementBuilder = BatchStatement.builder(DefaultBatchType.LOGGED) + val batchables = statements.map(bindParams(gatlingSession)) + builder.addStatements(batchables:_*).success } - /** * Bind Gatling Session Params to CQL Statement by Name and Type * * @param gatlingSession Gatling Session * @param statement CQL Prepared Statement - * @param queryParams CQL Query Named Params and Types * @return */ - protected def bindParams(gatlingSession: Session, statement: PreparedStatement, - queryParams: Map[String, DataType.Name]): BoundStatement = { - - val boundStatement = statement.bind() - - if (queryParams.nonEmpty) { - queryParams.foreach { - case (gatlingSessionKey, valType) => - cqlTypes.bindParamByName(gatlingSession, boundStatement, valType, gatlingSessionKey) + def bindParams(gatlingSession: Session)(statement: PreparedStatement): BoundStatement = { + val queryParams: Map[String, DataType] = cqlTypes.getParamsMap(statement) + val initBuilder:BoundStatementBuilder = builderFn(statement.bind()) + val completedBuilder = + queryParams.foldLeft(initBuilder) { + (builder, kv) => + kv match { + case (gatlingSessionKey, valType) => + cqlTypes.bindParamByName(gatlingSession, builder, valType, gatlingSessionKey) + } } - } - boundStatement + completedBuilder.build() } } - /** * Set a custom payload on the statement * * @param statement SimpleStaten * @param payloadRef session variable for custom payload */ -case class DseCqlCustomPayloadStatement(statement: SimpleStatement, payloadRef: String) extends DseCqlStatement { - - def buildFromSession(gatlingSession: Session): Validation[Statement] = { +case class DseCqlCustomPayloadStatement(statement: SimpleStatement, payloadRef: String) + extends DseCqlStatement[SimpleStatement, SimpleStatementBuilder] { + def buildFromSession(gatlingSession: Session): Validation[SimpleStatementBuilder] = { if (!gatlingSession.contains(payloadRef)) { throw new DseCqlStatementException(s"Passed sessionKey: {$payloadRef} does not exist in Session.") } Try { val payload = gatlingSession(payloadRef).as[Map[String, ByteBuffer]].asJava - statement.setOutgoingPayload(payload) + statement.setCustomPayload(payload) } match { - case TrySuccess(stmt) => stmt.success + case TrySuccess(stmt) => SimpleStatement.builder(stmt).success case TryFailure(error) => error.getMessage.failure } } - } /** @@ -209,13 +213,16 @@ case class DseCqlCustomPayloadStatement(statement: SimpleStatement, payloadRef: * * @param sessionKey the session key which is associated to a PreparedStatement */ -case class DseCqlBoundStatementNamedFromSession(cqlTypes: CqlPreparedStatementUtil, sessionKey: String) extends DseCqlStatement { +case class DseCqlBoundStatementNamedFromSession(cqlTypes: CqlPreparedStatementUtil, + sessionKey: String) + (implicit builderFn: (BoundStatement) => BoundStatementBuilder) + extends DseCqlStatement[BoundStatement, BoundStatementBuilder] { - def buildFromSession(gatlingSession: Session): Validation[BoundStatement] = { + def buildFromSession(gatlingSession: Session): Validation[BoundStatementBuilder] = { if (!gatlingSession.contains(sessionKey)) { throw new DseCqlStatementException(s"Passed sessionKey: {$sessionKey} does not exist in Session.") } val preparedStatement = gatlingSession(sessionKey).as[PreparedStatement] - DseCqlBoundStatementNamed(cqlTypes, preparedStatement).buildFromSession(gatlingSession) + DseCqlBoundStatementNamed(cqlTypes, preparedStatement)(builderFn).buildFromSession(gatlingSession) } } diff --git a/src/main/scala/com/datastax/gatling/plugin/model/DseGraphAttributes.scala b/src/main/scala/com/datastax/gatling/plugin/model/DseGraphAttributes.scala index ad1e45f..4afbe1d 100644 --- a/src/main/scala/com/datastax/gatling/plugin/model/DseGraphAttributes.scala +++ b/src/main/scala/com/datastax/gatling/plugin/model/DseGraphAttributes.scala @@ -6,9 +6,12 @@ package com.datastax.gatling.plugin.model -import com.datastax.driver.core.{ConsistencyLevel, Row} -import com.datastax.driver.dse.graph.{GraphNode, GraphStatement} -import com.datastax.gatling.plugin.checks.{DseGraphCheck, GenericCheck} +import java.time.Duration + +import com.datastax.oss.driver.api.core.ConsistencyLevel +import com.datastax.dse.driver.api.core.graph.{GraphStatement, GraphStatementBuilderBase} +import com.datastax.gatling.plugin.checks.DseGraphCheck +import com.datastax.oss.driver.api.core.metadata.Node /** * Graph Query Attributes to be applied to the current query @@ -19,34 +22,31 @@ import com.datastax.gatling.plugin.checks.{DseGraphCheck, GenericCheck} * @param statement Graph Statement to be sent to Cluster * @param cl Consistency Level to be used * @param graphChecks Data-level checks to be run after response is returned - * @param genericChecks Low-level checks to be run after response is returned - * @param userOrRole User or role to be used when proxy auth is enabled - * @param readTimeout Read timeout to be used * @param idempotent Set request to be idempotent i.e. whether it can be applied multiple times - * @param defaultTimestamp Set default timestamp on request, overriding current system time - * @param readCL Consistency level to use for the read part of the query - * @param writeCL Consistency level to use for the write part of the query - * @param graphName Name of the graph to use if different from the one used when connecting - * @param graphLanguage Language used in the query - * @param graphSource Graph source to use if different from the one used when connecting - * @param isSystemQuery Whether the query is a system one and should be used without any graph name - * @param graphInternalOptions Query-specific options not available in the driver public API - * @param graphTransformResults Function to use in order to transform a row into a Graph node + * @param node Set the node that should handle this query + * @param userOrRole Set the user/role for this query if proxy authentication is used + * @param graphName Name of the graph to use if different from the one used when connecting + * @param readCL Consistency level to use for the read part of the query + * @param subProtocol Name of the graph protocol to use for encoding/decoding + * @param timeout Timeout to use for this request + * @param timestamp Timestamp to use for this request + * @param traversalSource The traversal source for this request + * @param writeCL Consistency level to use for the write part of the query */ -case class DseGraphAttributes(tag: String, - statement: DseStatement[GraphStatement], - cl: Option[ConsistencyLevel] = None, - graphChecks: List[DseGraphCheck] = List.empty, - genericChecks: List[GenericCheck] = List.empty, - userOrRole: Option[String] = None, - readTimeout: Option[Int] = None, - idempotent: Option[Boolean] = None, - defaultTimestamp: Option[Long] = None, - readCL: Option[ConsistencyLevel] = None, - writeCL: Option[ConsistencyLevel] = None, - graphName: Option[String] = None, - graphLanguage: Option[String] = None, - graphSource: Option[String] = None, - isSystemQuery: Option[Boolean] = None, - graphInternalOptions: Option[Seq[(String, String)]] = None, - graphTransformResults: Option[com.google.common.base.Function[Row, GraphNode]] = None) +case class DseGraphAttributes[T <: GraphStatement[T], B <: GraphStatementBuilderBase[B,T]] + (tag: String, + statement: DseGraphStatement[T, B], + graphChecks: List[DseGraphCheck] = List.empty, + /* General attributes */ + cl: Option[ConsistencyLevel] = None, + idempotent: Option[Boolean] = None, + node: Option[Node] = None, + userOrRole: Option[String] = None, + /* Graph-specific attributes */ + graphName: Option[String] = None, + readCL: Option[ConsistencyLevel] = None, + subProtocol: Option[String] = None, + timeout: Option[Duration] = None, + timestamp: Option[Long] = None, + traversalSource: Option[String] = None, + writeCL: Option[ConsistencyLevel] = None) diff --git a/src/main/scala/com/datastax/gatling/plugin/model/DseGraphAttributesBuilder.scala b/src/main/scala/com/datastax/gatling/plugin/model/DseGraphAttributesBuilder.scala index 427b7e9..e13f698 100644 --- a/src/main/scala/com/datastax/gatling/plugin/model/DseGraphAttributesBuilder.scala +++ b/src/main/scala/com/datastax/gatling/plugin/model/DseGraphAttributesBuilder.scala @@ -6,25 +6,26 @@ package com.datastax.gatling.plugin.model -import com.datastax.driver.core.{ConsistencyLevel, Row} -import com.datastax.driver.dse.graph.GraphNode -import com.datastax.gatling.plugin.checks.{DseGraphCheck, GenericCheck} -import com.datastax.gatling.plugin.request.GraphRequestActionBuilder -import io.gatling.core.action.builder.ActionBuilder +import java.time.Duration +import com.datastax.oss.driver.api.core.ConsistencyLevel +import com.datastax.dse.driver.api.core.graph.{GraphStatement, GraphStatementBuilderBase} +import com.datastax.gatling.plugin.checks.DseGraphCheck +import com.datastax.gatling.plugin.request.GraphRequestActionBuilder +import com.datastax.oss.driver.api.core.metadata.Node /** * Request Builder for Graph Requests * * @param attr Addition Attributes */ -case class DseGraphAttributesBuilder(attr: DseGraphAttributes) { +case class DseGraphAttributesBuilder[T <: GraphStatement[T], B <: GraphStatementBuilderBase[B,T]](attr: DseGraphAttributes[T, B]) { /** * Builds to final action to run * * @return */ - def build(): ActionBuilder = new GraphRequestActionBuilder(attr) + def build(): GraphRequestActionBuilder[T, B] = new GraphRequestActionBuilder(attr) /** * Set Consistency Level @@ -32,54 +33,33 @@ case class DseGraphAttributesBuilder(attr: DseGraphAttributes) { * @param level ConsistencyLevel * @return */ - def withConsistencyLevel(level: ConsistencyLevel) = DseGraphAttributesBuilder(attr.copy(cl = Some(level))) - - /** - * Execute a query as another user or another role, provided the current logged in user has PROXY.EXECUTE permission. - * - * This permission MUST be granted to the currently logged in user using the CQL statement: `GRANT PROXY.EXECUTE ON - * ROLE someRole TO alice`. The user MUST be logged in with - * [[com.datastax.driver.dse.auth.DsePlainTextAuthProvider]] or - * [[com.datastax.driver.dse.auth.DseGSSAPIAuthProvider]] - * - * @param userOrRole String - * @return - */ - def withUserOrRole(userOrRole: String) = DseGraphAttributesBuilder(attr.copy(userOrRole = Some(userOrRole))) - - /** - * Override the current system time for write time of query - * - * @param epochTsInMs timestamp to use - * @return - */ - def withDefaultTimestamp(epochTsInMs: Long) = DseGraphAttributesBuilder(attr.copy(defaultTimestamp = Some(epochTsInMs))) - + def withConsistencyLevel(level: ConsistencyLevel):DseGraphAttributesBuilder[T, B] = + DseGraphAttributesBuilder(attr.copy(cl = Some(level))) /** * Set query to be idempotent i.e. run only once * * @return */ - def withIdempotency() = DseGraphAttributesBuilder(attr.copy(idempotent = Some(true))) - + def withIdempotency():DseGraphAttributesBuilder[T, B] = + DseGraphAttributesBuilder(attr.copy(idempotent = Some(true))) /** - * Set Read timeout of the query - * - * @param readTimeoutInMs time in milliseconds + * Set the node that should handle this query + * @param node Node * @return */ - def withReadTimeout(readTimeoutInMs: Int) = DseGraphAttributesBuilder(attr.copy(readTimeout = Some(readTimeoutInMs))) - + def withNode(node: Node):DseGraphAttributesBuilder[T, B] = + DseGraphAttributesBuilder(attr.copy(node = Some(node))) /** - * Sets the graph language - * - * @param language graph language to use + * Set the user or role to use for proxy auth + * @param userOrRole String * @return */ - def withLanguage(language: String) = DseGraphAttributesBuilder(attr.copy(graphLanguage = Some(language))) + def executeAs(userOrRole: String):DseGraphAttributesBuilder[T, B] = + DseGraphAttributesBuilder(attr.copy(userOrRole = Some(userOrRole))) + /** * Sets the graph name to use @@ -87,59 +67,53 @@ case class DseGraphAttributesBuilder(attr: DseGraphAttributes) { * @param name Graph name * @return */ - def withName(name: String) = DseGraphAttributesBuilder(attr.copy(graphName = Some(name))) - + def withName(name: String):DseGraphAttributesBuilder[T, B] = + DseGraphAttributesBuilder(attr.copy(graphName = Some(name))) /** - * Set the source of the graph - * - * @param source graph source - * @return - */ - def withSource(source: String) = DseGraphAttributesBuilder(attr.copy(graphSource = Some(source))) - - /** - * Set the query to be system level + * Define [[ConsistencyLevel]] to be used for read queries * + * @param readCL Consistency Level to use * @return */ - def withSystemQuery() = DseGraphAttributesBuilder(attr.copy(isSystemQuery = Some(true))) + def withReadConsistency(readCL: ConsistencyLevel):DseGraphAttributesBuilder[T, B] = + DseGraphAttributesBuilder(attr.copy(readCL = Some(readCL))) /** - * Set Options on graph + * Set the sub-protocol * - * @param options options in key/value par to set against the query + * @param subProtocol the sub-protocol to use * @return */ - def withOptions(options: (String, String)*) = DseGraphAttributesBuilder(attr.copy(graphInternalOptions = Some(options))) - + def withSubProtocol(subProtocol: String):DseGraphAttributesBuilder[T, B] = + DseGraphAttributesBuilder(attr.copy(subProtocol = Some(subProtocol))) /** - * Set Option on graph + * Set the timeout * - * @param option options in key/value par to set against the query + * @param timeout the timeout to use * @return */ - def withOption(option: (String, String)) = withOptions(option) - + def withTimeout(timeout: Duration):DseGraphAttributesBuilder[T, B] = + DseGraphAttributesBuilder(attr.copy(timeout = Some(timeout))) /** - * Transform results function + * Set the timestamp * - * @param transform Transform Function + * @param timestamp the timestamp to use * @return */ - def withTransformResults(transform: com.google.common.base.Function[Row, GraphNode]) = { - DseGraphAttributesBuilder(attr.copy(graphTransformResults = Some(transform))) - } + def withTimestamp(timestamp: Long):DseGraphAttributesBuilder[T, B] = + DseGraphAttributesBuilder(attr.copy(timestamp = Some(timestamp))) /** - * Define [[ConsistencyLevel]] to be used for read queries + * Set the sub-protocol * - * @param readCL Consistency Level to use + * @param traversalSource the traversal source to use * @return */ - def withReadConsistency(readCL: ConsistencyLevel) = DseGraphAttributesBuilder(attr.copy(readCL = Some(readCL))) + def withTraversalSource(traversalSource: String):DseGraphAttributesBuilder[T, B] = + DseGraphAttributesBuilder(attr.copy(traversalSource = Some(traversalSource))) /** * Define [[ConsistencyLevel]] to be used for write queries @@ -147,30 +121,9 @@ case class DseGraphAttributesBuilder(attr: DseGraphAttributes) { * @param writeCL Consistency Level to use * @return */ - def withWriteConsistency(writeCL: ConsistencyLevel) = DseGraphAttributesBuilder(attr.copy(writeCL = Some(writeCL))) - - - /** - * Backwards compatibility to set consistencyLevel - * - * @see [[DseGraphAttributesBuilder.withConsistencyLevel]] - * @param level Consistency Level to use - * @return - */ - @deprecated("use withConsistencyLevel() instead, will be removed in future version") - def consistencyLevel(level: ConsistencyLevel) = withConsistencyLevel(level) - - - /** - * For Backwards compatibility - * - * @see [[DseGraphAttributesBuilder.executeAs]] - * @param userOrRole User or role to use - * @return - */ - @deprecated("use withUserOrRole() instead, will be removed in future version") - def executeAs(userOrRole: String) = withUserOrRole(userOrRole: String) + def withWriteConsistency(writeCL: ConsistencyLevel):DseGraphAttributesBuilder[T, B] = + DseGraphAttributesBuilder(attr.copy(writeCL = Some(writeCL))) - def check(check: DseGraphCheck) = DseGraphAttributesBuilder(attr.copy(graphChecks = check :: attr.graphChecks)) - def check(check: GenericCheck) = DseGraphAttributesBuilder(attr.copy(genericChecks = check :: attr.genericChecks)) + def check(check: DseGraphCheck):DseGraphAttributesBuilder[T, B] = + DseGraphAttributesBuilder(attr.copy(graphChecks = check :: attr.graphChecks)) } diff --git a/src/main/scala/com/datastax/gatling/plugin/model/DseGraphStatementBuilders.scala b/src/main/scala/com/datastax/gatling/plugin/model/DseGraphStatementBuilders.scala index f559562..ce365f5 100644 --- a/src/main/scala/com/datastax/gatling/plugin/model/DseGraphStatementBuilders.scala +++ b/src/main/scala/com/datastax/gatling/plugin/model/DseGraphStatementBuilders.scala @@ -6,7 +6,7 @@ package com.datastax.gatling.plugin.model -import com.datastax.driver.dse.graph.{GraphStatement, SimpleGraphStatement} +import com.datastax.dse.driver.api.core.graph.{FluentGraphStatement, FluentGraphStatementBuilder, ScriptGraphStatement, ScriptGraphStatementBuilder} import io.gatling.core.session.{Expression, Session} /** @@ -22,8 +22,11 @@ case class DseGraphStatementBuilder(tag: String) { * @param strStatement Graph Query String * @return */ - def executeGraph(strStatement: Expression[String]) = { - DseGraphAttributesBuilder(DseGraphAttributes(tag, GraphStringStatement(strStatement))) + def executeGraph(strStatement: Expression[String]): DseGraphAttributesBuilder[ScriptGraphStatement, ScriptGraphStatementBuilder] = { + DseGraphAttributesBuilder( + DseGraphAttributes( + tag, + GraphStringStatement(strStatement))) } /** @@ -33,8 +36,8 @@ case class DseGraphStatementBuilder(tag: String) { * @param gStatement Simple Graph Statement * @return */ - @deprecated("Replaced by executeGraph(SimpleGraphStatement)") - def executeGraphStatement(gStatement: SimpleGraphStatement) = + @deprecated("Replaced by executeGraph(ScriptGraphStatement)") + def executeGraphStatement(gStatement: ScriptGraphStatement): DseGraphParametrizedStatementBuilder = executeGraph(gStatement) /** @@ -44,8 +47,8 @@ case class DseGraphStatementBuilder(tag: String) { * @param gStatement Simple Graph Statement * @return */ - def executeGraph(gStatement: SimpleGraphStatement) = { - DseGraphParametrizedStatementBuilder(tag, gStatement) + def executeGraph(gStatement: ScriptGraphStatement):DseGraphParametrizedStatementBuilder = { + DseGraphParametrizedStatementBuilder(tag, new ScriptGraphStatementBuilder(gStatement)) } /** @@ -54,8 +57,11 @@ case class DseGraphStatementBuilder(tag: String) { * @param gStatement Graph Statement from a Fluent API builder * @return */ - def executeGraphFluent(gStatement: GraphStatement) = { - DseGraphAttributesBuilder(DseGraphAttributes(tag, GraphFluentStatement(gStatement))) + def executeGraphFluent(gStatement: FluentGraphStatement): DseGraphAttributesBuilder[FluentGraphStatement, FluentGraphStatementBuilder] = { + DseGraphAttributesBuilder( + DseGraphAttributes( + tag, + GraphFluentStatement(gStatement))) } /** @@ -71,8 +77,11 @@ case class DseGraphStatementBuilder(tag: String) { * @param gLambda The lambda * @return */ - def executeGraphFluent(gLambda: Session => GraphStatement) = { - DseGraphAttributesBuilder(DseGraphAttributes(tag, GraphFluentStatementFromScalaLambda(gLambda))) + def executeGraphFluent(gLambda: Session => FluentGraphStatement): DseGraphAttributesBuilder[FluentGraphStatement, FluentGraphStatementBuilder] = { + DseGraphAttributesBuilder( + DseGraphAttributes( + tag, + GraphFluentStatementFromScalaLambda(gLambda))) } /** @@ -82,8 +91,11 @@ case class DseGraphStatementBuilder(tag: String) { * @return */ @deprecated("Replaced by executeGraphFluent{session => session(feederKey)}") - def executeGraphFeederTraversal(feederKey: String): DseGraphAttributesBuilder = { - DseGraphAttributesBuilder(DseGraphAttributes(tag, GraphFluentSessionKey(feederKey))) + def executeGraphFeederTraversal(feederKey: String): DseGraphAttributesBuilder[FluentGraphStatement, FluentGraphStatementBuilder] = { + DseGraphAttributesBuilder( + DseGraphAttributes( + tag, + GraphFluentSessionKey(feederKey))) } } @@ -91,9 +103,9 @@ case class DseGraphStatementBuilder(tag: String) { * Builder for Graph queries that do not have bound parameters yet. * * @param tag Query tag - * @param gStatement Simple Graph Staetment + * @param builder Simple Graph Staetment */ -case class DseGraphParametrizedStatementBuilder(tag: String, gStatement: SimpleGraphStatement) { +case class DseGraphParametrizedStatementBuilder(tag: String, builder: ScriptGraphStatementBuilder) { /** * Included for compatibility @@ -102,7 +114,8 @@ case class DseGraphParametrizedStatementBuilder(tag: String, gStatement: SimpleG * @return */ @deprecated("Replaced by withParams") - def withSetParams(paramNames: Array[String]): DseGraphAttributesBuilder = withParams(paramNames.toList) + def withSetParams(paramNames: Array[String]): DseGraphAttributesBuilder[ScriptGraphStatement, ScriptGraphStatementBuilder] = + withParams(paramNames.toList) /** * Params to set from strings @@ -110,7 +123,7 @@ case class DseGraphParametrizedStatementBuilder(tag: String, gStatement: SimpleG * @param paramNames List of strings to use * @return */ - def withParams(paramNames: String*): DseGraphAttributesBuilder = + def withParams(paramNames: String*): DseGraphAttributesBuilder[ScriptGraphStatement, ScriptGraphStatementBuilder] = withParams(paramNames.toList) /** @@ -119,8 +132,13 @@ case class DseGraphParametrizedStatementBuilder(tag: String, gStatement: SimpleG * @param paramNames List of strings to use * @return */ - def withParams(paramNames: List[String]): DseGraphAttributesBuilder = DseGraphAttributesBuilder( - DseGraphAttributes(tag, GraphBoundStatement(gStatement, paramNames.map(key => key -> key).toMap)) + def withParams(paramNames: List[String]): DseGraphAttributesBuilder[ScriptGraphStatement, ScriptGraphStatementBuilder] = + DseGraphAttributesBuilder( + DseGraphAttributes( + tag, + GraphBoundStatement( + builder, + paramNames.map(key => key -> key).toMap)) ) /** @@ -130,7 +148,7 @@ case class DseGraphParametrizedStatementBuilder(tag: String, gStatement: SimpleG * @return */ @deprecated("Replaced with withParams") - def withSetParams(paramNamesAndOverrides: Map[String, String]): DseGraphAttributesBuilder = + def withSetParams(paramNamesAndOverrides: Map[String, String]): DseGraphAttributesBuilder[ScriptGraphStatement, ScriptGraphStatementBuilder] = withParams(paramNamesAndOverrides) @@ -142,7 +160,7 @@ case class DseGraphParametrizedStatementBuilder(tag: String, gStatement: SimpleG * @return */ @deprecated("Replaced with withParams") - def withParamOverrides(paramNamesAndOverrides: Map[String, String]): DseGraphAttributesBuilder = + def withParamOverrides(paramNamesAndOverrides: Map[String, String]): DseGraphAttributesBuilder[ScriptGraphStatement, ScriptGraphStatementBuilder] = withParams(paramNamesAndOverrides) /** @@ -152,8 +170,11 @@ case class DseGraphParametrizedStatementBuilder(tag: String, gStatement: SimpleG * @param paramNamesAndOverrides a Map of Session parameter names to their GraphStatement parameter names * @return */ - def withParams(paramNamesAndOverrides: Map[String, String]): DseGraphAttributesBuilder = { - DseGraphAttributesBuilder(DseGraphAttributes(tag, GraphBoundStatement(gStatement, paramNamesAndOverrides))) + def withParams(paramNamesAndOverrides: Map[String, String]): DseGraphAttributesBuilder[ScriptGraphStatement, ScriptGraphStatementBuilder] = { + DseGraphAttributesBuilder( + DseGraphAttributes( + tag, + GraphBoundStatement(builder, paramNamesAndOverrides))) } /** @@ -165,9 +186,9 @@ case class DseGraphParametrizedStatementBuilder(tag: String, gStatement: SimpleG * @return */ @deprecated("Replaced by withRepeatedParams") - def withRepeatedSetParams(batchSize: Int, paramNamesAndOverrides: Map[String, String]): DseGraphAttributesBuilder = { + def withRepeatedSetParams(batchSize: Int, + paramNamesAndOverrides: Map[String, String]): DseGraphAttributesBuilder[ScriptGraphStatement, ScriptGraphStatementBuilder] = withRepeatedParams(batchSize, paramNamesAndOverrides) - } /** * Repeat the parameters given by suffixing their names and overridden names by a number picked from 1 to batchSize. @@ -177,7 +198,8 @@ case class DseGraphParametrizedStatementBuilder(tag: String, gStatement: SimpleG * @param paramNamesAndOverrides a Map of Session parameter names to their GraphStatement parameter names * @return */ - def withRepeatedParams(batchSize: Int, paramNamesAndOverrides: Map[String, String]): DseGraphAttributesBuilder = { + def withRepeatedParams(batchSize: Int, paramNamesAndOverrides: Map[String, String]): + DseGraphAttributesBuilder[ScriptGraphStatement, ScriptGraphStatementBuilder] = { def repeatParameters(params: Map[String, String]): Map[String, String] = batchSize match { // Gatling has a weird behavior when feeding multiple values // Feeding 1 value gives non-suffixed variables whereas feeding more gives suffixed variables starting by the @@ -190,7 +212,11 @@ case class DseGraphParametrizedStatementBuilder(tag: String, gStatement: SimpleG } DseGraphAttributesBuilder( - DseGraphAttributes(tag, GraphBoundStatement(gStatement, repeatParameters(paramNamesAndOverrides))) + DseGraphAttributes( + tag, + GraphBoundStatement( + builder, + repeatParameters(paramNamesAndOverrides))) ) } } diff --git a/src/main/scala/com/datastax/gatling/plugin/model/DseGraphStatements.scala b/src/main/scala/com/datastax/gatling/plugin/model/DseGraphStatements.scala index af2be96..d007af3 100644 --- a/src/main/scala/com/datastax/gatling/plugin/model/DseGraphStatements.scala +++ b/src/main/scala/com/datastax/gatling/plugin/model/DseGraphStatements.scala @@ -6,8 +6,7 @@ package com.datastax.gatling.plugin.model -import com.datastax.driver.dse.graph.{GraphStatement, SimpleGraphStatement} -import com.datastax.dse.graph.api.DseGraph +import com.datastax.dse.driver.api.core.graph._ import com.datastax.gatling.plugin.exceptions.DseGraphStatementException import io.gatling.commons.validation._ import io.gatling.core.session.{Expression, Session} @@ -15,19 +14,18 @@ import org.apache.tinkerpop.gremlin.process.traversal.dsl.graph.GraphTraversal import scala.util.{Try, Failure => TryFailure, Success => TrySuccess} - -trait DseGraphStatement extends DseStatement[GraphStatement] { - def buildFromSession(session: Session): Validation[GraphStatement] -} +trait DseGraphStatement[T <: GraphStatement[T], B <: GraphStatementBuilderBase[B,T]] extends DseStatement[B] /** * Simple DSE Graph Statement from a String * * @param statement the Gremlin String to execute */ -case class GraphStringStatement(statement: Expression[String]) extends DseGraphStatement { - def buildFromSession(gatlingSession: Session): Validation[GraphStatement] = { - statement(gatlingSession).flatMap(stmt => new SimpleGraphStatement(stmt).success) +case class GraphStringStatement(statement: Expression[String]) + extends DseGraphStatement[ScriptGraphStatement, ScriptGraphStatementBuilder] { + + def buildFromSession(gatlingSession: Session): Validation[ScriptGraphStatementBuilder] = { + statement(gatlingSession).flatMap(stmt => ScriptGraphStatement.builder(stmt).success) } } @@ -36,9 +34,11 @@ case class GraphStringStatement(statement: Expression[String]) extends DseGraphS * * @param statement the Fluent Statement */ -case class GraphFluentStatement(statement: GraphStatement) extends DseGraphStatement { - def buildFromSession(gatlingSession: Session): Validation[GraphStatement] = { - statement.success +case class GraphFluentStatement(statement: FluentGraphStatement) + extends DseGraphStatement[FluentGraphStatement, FluentGraphStatementBuilder] { + + def buildFromSession(gatlingSession: Session): Validation[FluentGraphStatementBuilder] = { + FluentGraphStatement.builder(statement).success } } @@ -49,9 +49,11 @@ case class GraphFluentStatement(statement: GraphStatement) extends DseGraphState * @param lambda Scala lambda that takes a Gatling User Session (from which it can retrieve parameters) * and returns a fluent Graph Statement */ -case class GraphFluentStatementFromScalaLambda(lambda: Session => GraphStatement) extends DseGraphStatement { - def buildFromSession(gatlingSession: Session): Validation[GraphStatement] = { - lambda(gatlingSession).success +case class GraphFluentStatementFromScalaLambda(lambda: Session => FluentGraphStatement) + extends DseGraphStatement[FluentGraphStatement, FluentGraphStatementBuilder] { + + def buildFromSession(gatlingSession: Session): Validation[FluentGraphStatementBuilder] = { + FluentGraphStatement.builder(lambda(gatlingSession)).success } } @@ -61,31 +63,32 @@ case class GraphFluentStatementFromScalaLambda(lambda: Session => GraphStatement * * @param sessionKey Place a GraphTraversal in your session with this key name */ -case class GraphFluentSessionKey(sessionKey: String) extends DseGraphStatement { +case class GraphFluentSessionKey(sessionKey: String) + extends DseGraphStatement[FluentGraphStatement, FluentGraphStatementBuilder] { - def buildFromSession(gatlingSession: Session): Validation[GraphStatement] = { + def buildFromSession(gatlingSession: Session): Validation[FluentGraphStatementBuilder] = { if (!gatlingSession.contains(sessionKey)) { throw new DseGraphStatementException(s"Passed sessionKey: {$sessionKey} does not exist in Session.") } Try { - DseGraph.statementFromTraversal(gatlingSession(sessionKey).as[GraphTraversal[_, _]]) + FluentGraphStatement.builder(gatlingSession(sessionKey).as[GraphTraversal[_, _]]) } match { - case TrySuccess(stmt) => stmt.success + case TrySuccess(builder) => builder.success case TryFailure(error) => error.getMessage.failure } } - } /** * Set/Bind Gatling Session key/vals to GraphStatement * - * @param statement SimpleGraphStatement + * @param builder SimpleGraphStatementBuilder * @param sessionKeys Gatling session param keys mapped to their bind name, to allow name override */ -case class GraphBoundStatement(statement: SimpleGraphStatement, sessionKeys: Map[String, String]) extends DseGraphStatement { +case class GraphBoundStatement(builder: ScriptGraphStatementBuilder, sessionKeys: Map[String, String]) + extends DseGraphStatement[ScriptGraphStatement, ScriptGraphStatementBuilder] { /** * Apply the Gatling session params passed to the GraphStatement @@ -93,28 +96,19 @@ case class GraphBoundStatement(statement: SimpleGraphStatement, sessionKeys: Map * @param gatlingSession Gatling Session * @return */ - def buildFromSession(gatlingSession: Session): Validation[GraphStatement] = { + + def buildFromSession(gatlingSession: Session): Validation[ScriptGraphStatementBuilder] = { Try { - sessionKeys.foreach((tuple: (String, String)) => setParam(gatlingSession, tuple._1, tuple._2)).success - statement + sessionKeys foreach { k => + k match { + case (k, v) => builder.setQueryParam(v, gatlingSession(k).as[Object]) + case _ => throw new RuntimeException(s"Observed ${k} instead of expected key-value pair") + } + } + builder } match { - case TrySuccess(stmt) => stmt.success + case TrySuccess(builder) => builder.success case TryFailure(error) => error.getMessage.failure } } - - /** - * Set a parameter to the current gStatement. - * - * The value of this parameter is retrieved by looking up paramName from the gatlingSession. - * It is then bound to overriddenParamName in gStatement. - * - * @param gatlingSession Gatling session - * @param paramName Parameter name as accessible in Gatling session - * @param overriddenParamName Parameter name used to bind the statement - * @return gStatement - */ - private def setParam(gatlingSession: Session, paramName: String, overriddenParamName: String): GraphStatement = { - statement.set(overriddenParamName, gatlingSession(paramName).as[Object]) - } } \ No newline at end of file diff --git a/src/main/scala/com/datastax/gatling/plugin/request/CqlRequestAction.scala b/src/main/scala/com/datastax/gatling/plugin/request/CqlRequestAction.scala index bbf630a..fe98e93 100644 --- a/src/main/scala/com/datastax/gatling/plugin/request/CqlRequestAction.scala +++ b/src/main/scala/com/datastax/gatling/plugin/request/CqlRequestAction.scala @@ -12,18 +12,20 @@ import java.util.concurrent.ExecutorService import java.util.concurrent.TimeUnit.MICROSECONDS import akka.actor.ActorSystem +import com.datastax.dse.driver.api.core.auth.ProxyAuthentication import com.datastax.gatling.plugin.DseProtocol import com.datastax.gatling.plugin.metrics.MetricsLogger import com.datastax.gatling.plugin.model.DseCqlAttributes import com.datastax.gatling.plugin.response.CqlResponseHandler import com.datastax.gatling.plugin.utils._ +import com.datastax.oss.driver.api.core.cql.{Statement, StatementBuilder} import io.gatling.commons.stats.KO import io.gatling.commons.validation.safely import io.gatling.core.action.{Action, ExitableAction} import io.gatling.core.session.Session import io.gatling.core.stats.StatsEngine -import scala.collection.JavaConverters._ +import scala.compat.java8.FutureConverters import scala.concurrent.{ExecutionContext, ExecutionContextExecutor} /** @@ -34,22 +36,22 @@ import scala.concurrent.{ExecutionContext, ExecutionContextExecutor} * free the Gatling `injector` actor as fast as possible. * * The plugin router (and its actors) execute the driver code in order to locate the best replica that should receive - * each query, through `DseSession.executeAsync()` or `DseSession.executeGraphAsync()`. A driver I/O thread encodes + * each query, through `CqlSession.executeAsync()` or `CqlSession.executeGraphAsync()`. A driver I/O thread encodes * the request and sends it over the wire. * * Once the response is received, a driver I/O thread (Netty) decodes the response into a Java Object. It then - * completes the `Future` that was returned by `DseSession.executeAsync()`. + * completes the `Future` that was returned by `CqlSession.executeAsync()`. * * Completing that future results in immediately delegating the latency recording work to the plugin router. That * work includes recording it in HDR histograms through non-blocking data structures, and forwarding the result to * other Gatling data writers, like the console reporter. */ -class CqlRequestAction(val name: String, +class CqlRequestAction[T <: Statement[T], B <: StatementBuilder[B,T]](val name: String, val next: Action, val system: ActorSystem, val statsEngine: StatsEngine, val protocol: DseProtocol, - val dseAttributes: DseCqlAttributes, + val dseAttributes: DseCqlAttributes[T, B], val metricsLogger: MetricsLogger, val dseExecutorService: ExecutorService, val gatlingTimingSource: GatlingTimingSource) @@ -61,48 +63,57 @@ class CqlRequestAction(val name: String, }) } - def sendQuery(session: Session): Unit = { - val enableCO = Boolean.getBoolean("gatling.dse.plugin.measure_service_time") - val responseTimeBuilder: ResponseTimeBuilder = if (enableCO) { - // The throughput checker is useless in CO affected scenarios since throughput is not known in advance - COAffectedResponseTime.startingNow(gatlingTimingSource) - } else { - ThroughputVerifier.checkForGatlingOverloading(session, gatlingTimingSource) - GatlingResponseTime.startedByGatling(session, gatlingTimingSource) - } - val stmt = safely()(dseAttributes.statement.buildFromSession(session)) - - stmt.onFailure(err => { - handleFailure(session, responseTimeBuilder, err) - }) + private def buildStatement(builder:B):T = { - stmt.onSuccess({ stmt => - // global options - dseAttributes.cl.map(stmt.setConsistencyLevel) - dseAttributes.userOrRole.map(stmt.executingAs) - dseAttributes.readTimeout.map(stmt.setReadTimeoutMillis) - dseAttributes.idempotent.map(stmt.setIdempotent) - dseAttributes.defaultTimestamp.map(stmt.setDefaultTimestamp) + // global options + dseAttributes.cl.foreach(builder.setConsistencyLevel) + dseAttributes.idempotent.foreach(builder.setIdempotence(_)) + dseAttributes.node.foreach(builder.setNode) - // CQL Only Options - dseAttributes.outGoingPayload.map(x => stmt.setOutgoingPayload(x.asJava)) - dseAttributes.serialCl.map(stmt.setSerialConsistencyLevel) - dseAttributes.retryPolicy.map(stmt.setRetryPolicy) - dseAttributes.fetchSize.map(stmt.setFetchSize) - dseAttributes.pagingState.map(stmt.setPagingState) - if (dseAttributes.enableTrace.isDefined && dseAttributes.enableTrace.get) { - stmt.enableTracing + // CQL Only Options + dseAttributes.customPayload.foreach { + _.foreach { + _ match { + case (k, v) => builder.addCustomPayload(k, v) + } } + } + dseAttributes.enableTrace.foreach(builder.setTracing) + dseAttributes.pageSize.foreach(builder.setPageSize) + dseAttributes.pagingState.foreach(builder.setPagingState) + dseAttributes.queryTimestamp.foreach(builder.setQueryTimestamp) + dseAttributes.routingKey.foreach(builder.setRoutingKey) + dseAttributes.routingKeyspace.foreach(builder.setRoutingKeyspace) + dseAttributes.routingToken.foreach(builder.setRoutingToken) + dseAttributes.serialCl.foreach(builder.setSerialConsistencyLevel) + dseAttributes.timeout.foreach(builder.setTimeout) + builder.build + } - val responseHandler = new CqlResponseHandler(next, session, system, statsEngine, responseTimeBuilder, stmt, dseAttributes, metricsLogger) - implicit val sameThreadExecutionContext: ExecutionContextExecutor = ExecutionContext.fromExecutorService(dseExecutorService) - FutureUtils - .toScalaFuture(protocol.session.executeAsync(stmt)) - .onComplete(t => DseRequestActor.recordResult(RecordResult(t, responseHandler))) - }) + private def handleSuccess(session: Session, responseTimeBuilder: ResponseTimeBuilder)(builder:B): Unit = { + + val baseStmt:T = buildStatement(builder) + val stmt:T = + dseAttributes.userOrRole + .map(ProxyAuthentication.executeAs(_,baseStmt)) + .getOrElse(baseStmt) + val responseHandler = + new CqlResponseHandler[T, B]( + next, + session, + system, + statsEngine, + responseTimeBuilder, + stmt, + dseAttributes, + metricsLogger) + implicit val sameThreadExecutionContext: ExecutionContextExecutor = ExecutionContext.fromExecutorService(dseExecutorService) + FutureConverters + .toScala(protocol.session.executeAsync(stmt)) + .onComplete(t => DseRequestActor.recordResult(RecordResult(t, responseHandler))) } - private def handleFailure(session: Session, responseTimeBuilder: ResponseTimeBuilder, err: String) = { + private def handleFailure(session: Session, responseTimeBuilder: ResponseTimeBuilder)(err: String) = { val responseTime: ResponseTime = responseTimeBuilder.build() val logUuid = UUID.randomUUID.toString val tagString = if (session.groupHierarchy.nonEmpty) session.groupHierarchy.mkString("/") + "/" + dseAttributes.tag else dseAttributes.tag @@ -113,4 +124,18 @@ class CqlRequestAction(val name: String, logger.error("[{}] {} - Err: {} - Attrs: {}", logUuid, tagString, err, session.attributes.mkString(",")) next ! session.markAsFailed } + + def sendQuery(session: Session): Unit = { + val enableCO = Boolean.getBoolean("gatling.dse.plugin.measure_service_time") + val responseTimeBuilder: ResponseTimeBuilder = if (enableCO) { + // The throughput checker is useless in CO affected scenarios since throughput is not known in advance + COAffectedResponseTime.startingNow(gatlingTimingSource) + } else { + ThroughputVerifier.checkForGatlingOverloading(session, gatlingTimingSource) + GatlingResponseTime.startedByGatling(session, gatlingTimingSource) + } + val stmtBuilder = safely()(dseAttributes.statement.buildFromSession(session)) + stmtBuilder.onFailure(handleFailure(session,responseTimeBuilder)) + stmtBuilder.onSuccess(handleSuccess(session,responseTimeBuilder)) + } } diff --git a/src/main/scala/com/datastax/gatling/plugin/request/CqlRequestActionBuilder.scala b/src/main/scala/com/datastax/gatling/plugin/request/CqlRequestActionBuilder.scala index 455d1bd..2816be0 100644 --- a/src/main/scala/com/datastax/gatling/plugin/request/CqlRequestActionBuilder.scala +++ b/src/main/scala/com/datastax/gatling/plugin/request/CqlRequestActionBuilder.scala @@ -8,17 +8,19 @@ package com.datastax.gatling.plugin.request import com.datastax.gatling.plugin.DseProtocol import com.datastax.gatling.plugin.model.DseCqlAttributes +import com.datastax.oss.driver.api.core.cql.{Statement, StatementBuilder} import io.gatling.core.action.Action import io.gatling.core.action.builder.ActionBuilder import io.gatling.core.structure.ScenarioContext import io.gatling.core.util.NameGen -class CqlRequestActionBuilder(val dseAttributes: DseCqlAttributes) extends ActionBuilder with - NameGen { +class CqlRequestActionBuilder[T <: Statement[T], B <: StatementBuilder[B,T]](val dseAttributes: DseCqlAttributes[T, B]) + extends ActionBuilder + with NameGen { def build(ctx: ScenarioContext, next: Action): Action = { val dseComponents = ctx.protocolComponentsRegistry.components(DseProtocol.DseProtocolKey) - new CqlRequestAction( + new CqlRequestAction[T, B]( dseAttributes.tag, next, ctx.system, @@ -30,4 +32,3 @@ class CqlRequestActionBuilder(val dseAttributes: DseCqlAttributes) extends Actio dseComponents.gatlingTimingSource) } } - diff --git a/src/main/scala/com/datastax/gatling/plugin/request/DseRequestActor.scala b/src/main/scala/com/datastax/gatling/plugin/request/DseRequestActor.scala index c55c996..789d3bc 100644 --- a/src/main/scala/com/datastax/gatling/plugin/request/DseRequestActor.scala +++ b/src/main/scala/com/datastax/gatling/plugin/request/DseRequestActor.scala @@ -8,19 +8,19 @@ package com.datastax.gatling.plugin.request import akka.actor.Actor -import com.datastax.driver.core.ResultSet -import com.datastax.driver.dse.graph.GraphResultSet -import com.google.common.util.concurrent.FutureCallback +import com.datastax.dse.driver.api.core.graph.{GraphStatement, GraphStatementBuilderBase} +import com.datastax.gatling.plugin.response.DseResponseCallback +import com.datastax.oss.driver.api.core.cql.{Statement, StatementBuilder} import com.typesafe.scalalogging.StrictLogging import io.gatling.core.session.Session import scala.concurrent.ExecutionException import scala.util.{Failure, Success, Try} -case class SendCqlQuery(dseRequestAction: CqlRequestAction, session: Session) -case class SendGraphQuery(dseRequestAction: GraphRequestAction, session: Session) +case class SendCqlQuery[T <: Statement[T], B <: StatementBuilder[B,T]](dseRequestAction: CqlRequestAction[T,B], session: Session) +case class SendGraphQuery[T <: GraphStatement[T], B <: GraphStatementBuilderBase[B,T]](dseRequestAction: GraphRequestAction[T,B], session: Session) -case class RecordResult[T](t: Try[T], callback: FutureCallback[T]) +case class RecordResult[T](t: Try[T], callback: DseResponseCallback[T]) class DseRequestActor extends Actor with StrictLogging { override def receive: Actor.Receive = { diff --git a/src/main/scala/com/datastax/gatling/plugin/request/GraphRequestAction.scala b/src/main/scala/com/datastax/gatling/plugin/request/GraphRequestAction.scala index e4abe66..1bfca0e 100644 --- a/src/main/scala/com/datastax/gatling/plugin/request/GraphRequestAction.scala +++ b/src/main/scala/com/datastax/gatling/plugin/request/GraphRequestAction.scala @@ -12,6 +12,8 @@ import java.util.concurrent.ExecutorService import java.util.concurrent.TimeUnit.MICROSECONDS import akka.actor.ActorSystem +import com.datastax.dse.driver.api.core.auth.ProxyAuthentication +import com.datastax.dse.driver.api.core.graph.{GraphStatement, GraphStatementBuilderBase} import com.datastax.gatling.plugin.DseProtocol import com.datastax.gatling.plugin.metrics.MetricsLogger import com.datastax.gatling.plugin.model.DseGraphAttributes @@ -22,6 +24,7 @@ import io.gatling.core.action.{Action, ExitableAction} import io.gatling.core.session.Session import io.gatling.core.stats.StatsEngine +import scala.compat.java8.FutureConverters import scala.concurrent.{ExecutionContext, ExecutionContextExecutor} /** @@ -32,22 +35,22 @@ import scala.concurrent.{ExecutionContext, ExecutionContextExecutor} * free the Gatling `injector` actor as fast as possible. * * The plugin router (and its actors) execute the driver code in order to locate the best replica that should receive - * each query, through `DseSession.executeAsync()` or `DseSession.executeGraphAsync()`. A driver I/O thread encodes + * each query, through `CqlSession.executeAsync()` or `CqlSession.executeGraphAsync()`. A driver I/O thread encodes * the request and sends it over the wire. * * Once the response is received, a driver I/O thread (Netty) decodes the response into a Java Object. It then - * completes the `Future` that was returned by `DseSession.executeAsync()`. + * completes the `Future` that was returned by `CqlSession.executeAsync()`. * * Completing that future results in immediately delegating the latency recording work to the plugin router. That * work includes recording it in HDR histograms through non-blocking data structures, and forwarding the result to * other Gatling data writers, like the console reporter. */ -class GraphRequestAction(val name: String, +class GraphRequestAction[T <: GraphStatement[T], B <: GraphStatementBuilderBase[B,T]](val name: String, val next: Action, val system: ActorSystem, val statsEngine: StatsEngine, val protocol: DseProtocol, - val dseAttributes: DseGraphAttributes, + val dseAttributes: DseGraphAttributes[T, B], val metricsLogger: MetricsLogger, val dseExecutorService: ExecutorService, val gatlingTimingSource: GatlingTimingSource) @@ -59,6 +62,60 @@ class GraphRequestAction(val name: String, }) } + private def buildStatement(builder:B):T = { + + // global options + dseAttributes.cl.foreach(builder.setConsistencyLevel) + dseAttributes.idempotent.foreach(builder.setIdempotence(_)) + dseAttributes.node.foreach(builder.setNode) + + // Graph only Options + dseAttributes.graphName.foreach(builder.setGraphName) + dseAttributes.readCL.foreach(builder.setReadConsistencyLevel) + dseAttributes.subProtocol.foreach(builder.setSubProtocol) + dseAttributes.timeout.foreach(builder.setTimeout) + dseAttributes.timestamp.foreach(builder.setTimestamp) + dseAttributes.traversalSource.foreach(builder.setTraversalSource) + dseAttributes.writeCL.foreach(builder.setWriteConsistencyLevel) + + builder.build + } + + private def handleSuccess(session: Session, responseTimeBuilder: ResponseTimeBuilder)(builder:B): Unit = { + + val baseStmt:T = buildStatement(builder) + val stmt:T = + dseAttributes.userOrRole + .map(ProxyAuthentication.executeAs(_,baseStmt)) + .getOrElse(baseStmt) + val responseHandler = + new GraphResponseHandler[T, B]( + next, + session, + system, + statsEngine, + responseTimeBuilder, + stmt, + dseAttributes, + metricsLogger) + implicit val sameThreadExecutionContext: ExecutionContextExecutor = ExecutionContext.fromExecutorService(dseExecutorService) + FutureConverters + .toScala(protocol.session.executeAsync(stmt)) + .onComplete(t => DseRequestActor.recordResult(RecordResult(t, responseHandler))) + } + + private def handleFailure(session: Session, responseTimeBuilder: ResponseTimeBuilder)(err: String) = { + + val responseTime = responseTimeBuilder.build() + val logUuid = UUID.randomUUID.toString + val tagString = if (session.groupHierarchy.nonEmpty) session.groupHierarchy.mkString("/") + "/" + dseAttributes.tag else dseAttributes.tag + + statsEngine.logResponse(session, name, responseTime.toGatlingResponseTimings, KO, None, + Some(s"$tagString - Preparing: ${err.take(50)}"), List(responseTime.latencyIn(MICROSECONDS), "PRE", logUuid)) + + logger.error("[{}] {} - Preparing: {} - Attrs: {}", logUuid, tagString, err, session.attributes.mkString(",")) + next ! session.markAsFailed + } def sendQuery(session: Session): Unit = { val enableCO = Boolean.getBoolean("gatling.dse.plugin.measure_service_time") @@ -69,51 +126,8 @@ class GraphRequestAction(val name: String, ThroughputVerifier.checkForGatlingOverloading(session, gatlingTimingSource) GatlingResponseTime.startedByGatling(session, gatlingTimingSource) } - val stmt = dseAttributes.statement.buildFromSession(session) - - stmt.onFailure(err => { - val responseTime = responseTimeBuilder.build() - val logUuid = UUID.randomUUID.toString - val tagString = if (session.groupHierarchy.nonEmpty) session.groupHierarchy.mkString("/") + "/" + dseAttributes.tag else dseAttributes.tag - - statsEngine.logResponse(session, name, responseTime.toGatlingResponseTimings, KO, None, - Some(s"$tagString - Preparing: ${err.take(50)}"), List(responseTime.latencyIn(MICROSECONDS), "PRE", logUuid)) - - logger.error("[{}] {} - Preparing: {} - Attrs: {}", logUuid, tagString, err, session.attributes.mkString(",")) - next ! session.markAsFailed - }) - - stmt.onSuccess({ gStmt => - // global options - dseAttributes.cl.map(gStmt.setConsistencyLevel) - dseAttributes.defaultTimestamp.map(gStmt.setDefaultTimestamp) - dseAttributes.userOrRole.map(gStmt.executingAs) - dseAttributes.readTimeout.map(gStmt.setReadTimeoutMillis) - dseAttributes.idempotent.map(gStmt.setIdempotent) - - // Graph only Options - dseAttributes.readCL.map(gStmt.setGraphReadConsistencyLevel) - dseAttributes.writeCL.map(gStmt.setGraphWriteConsistencyLevel) - dseAttributes.graphLanguage.map(gStmt.setGraphLanguage) - dseAttributes.graphName.map(gStmt.setGraphName) - dseAttributes.graphSource.map(gStmt.setGraphSource) - dseAttributes.graphTransformResults.map(gStmt.setTransformResultFunction) - - if (dseAttributes.graphInternalOptions.isDefined) { - dseAttributes.graphInternalOptions.get.foreach { t => - gStmt.setGraphInternalOption(t._1, t._2) - } - } - - if (dseAttributes.isSystemQuery.isDefined && dseAttributes.isSystemQuery.get) { - gStmt.setSystemQuery() - } - - val responseHandler = new GraphResponseHandler(next, session, system, statsEngine, responseTimeBuilder, gStmt, dseAttributes, metricsLogger) - implicit val sameThreadExecutionContext: ExecutionContextExecutor = ExecutionContext.fromExecutorService(dseExecutorService) - FutureUtils - .toScalaFuture(protocol.session.executeGraphAsync(gStmt)) - .onComplete(t => DseRequestActor.recordResult(RecordResult(t, responseHandler))) - }) + val stmtBuilder = dseAttributes.statement.buildFromSession(session) + stmtBuilder.onFailure(handleFailure(session, responseTimeBuilder)) + stmtBuilder.onSuccess(handleSuccess(session, responseTimeBuilder)) } } diff --git a/src/main/scala/com/datastax/gatling/plugin/request/GraphRequestActionBuilder.scala b/src/main/scala/com/datastax/gatling/plugin/request/GraphRequestActionBuilder.scala index eb2d367..6f6e7f3 100644 --- a/src/main/scala/com/datastax/gatling/plugin/request/GraphRequestActionBuilder.scala +++ b/src/main/scala/com/datastax/gatling/plugin/request/GraphRequestActionBuilder.scala @@ -6,6 +6,7 @@ package com.datastax.gatling.plugin.request +import com.datastax.dse.driver.api.core.graph.{GraphStatement, GraphStatementBuilderBase} import com.datastax.gatling.plugin.DseProtocol import com.datastax.gatling.plugin.model.DseGraphAttributes import io.gatling.core.action.Action @@ -13,12 +14,13 @@ import io.gatling.core.action.builder.ActionBuilder import io.gatling.core.structure.ScenarioContext import io.gatling.core.util.NameGen -class GraphRequestActionBuilder(dseAttributes: DseGraphAttributes) extends ActionBuilder with - NameGen { +class GraphRequestActionBuilder[T <: GraphStatement[T], B <: GraphStatementBuilderBase[B,T]](dseAttributes: DseGraphAttributes[T, B]) + extends ActionBuilder + with NameGen { def build(ctx: ScenarioContext, next: Action): Action = { val dseComponents = ctx.protocolComponentsRegistry.components(DseProtocol.DseProtocolKey) - new GraphRequestAction( + new GraphRequestAction[T, B]( dseAttributes.tag, next, ctx.system, diff --git a/src/main/scala/com/datastax/gatling/plugin/response/DseResponse.scala b/src/main/scala/com/datastax/gatling/plugin/response/DseResponse.scala index db03c27..b3bcd21 100644 --- a/src/main/scala/com/datastax/gatling/plugin/response/DseResponse.scala +++ b/src/main/scala/com/datastax/gatling/plugin/response/DseResponse.scala @@ -6,147 +6,17 @@ package com.datastax.gatling.plugin.response -import com.datastax.driver.core._ -import com.datastax.driver.dse.graph._ -import com.datastax.gatling.plugin.model.{DseCqlAttributes, DseGraphAttributes} +import com.datastax.gatling.plugin.model.{ DseCqlAttributes, DseGraphAttributes } +import com.datastax.oss.driver.api.core.cql._ +import com.datastax.dse.driver.api.core.graph._ import com.typesafe.scalalogging.LazyLogging -import scala.collection.JavaConverters._ -import scala.util.Try - -abstract class DseResponse { - def executionInfo(): ExecutionInfo - def rowCount(): Int - def applied(): Boolean - def exhausted(): Boolean - def queriedHost(): Host = executionInfo().getQueriedHost - def achievedConsistencyLevel(): ConsistencyLevel = executionInfo().getAchievedConsistencyLevel - def speculativeExecutions(): Int = executionInfo().getSpeculativeExecutions - def pagingState(): PagingState = executionInfo().getPagingState - def triedHosts(): List[Host] = executionInfo().getTriedHosts.asScala.toList - def warnings(): List[String] = executionInfo().getWarnings.asScala.toList - def successFullExecutionIndex(): Int = executionInfo().getSuccessfulExecutionIndex - def schemaInAgreement(): Boolean = executionInfo().isSchemaInAgreement -} - - -class GraphResponse(graphResultSet: GraphResultSet, dseAttributes: DseGraphAttributes) extends DseResponse with LazyLogging { - private lazy val allGraphNodes: Seq[GraphNode] = collection.JavaConverters.asScalaBuffer(graphResultSet.all()) - - override def executionInfo(): ExecutionInfo = graphResultSet.getExecutionInfo() - override def applied(): Boolean = false // graph doesn't support LWTs so always return false - override def exhausted(): Boolean = graphResultSet.isExhausted() - - /** - * Get the number of all rows returned by the query. - * Note: Calling this function fetches all rows from the result set! - */ - override def rowCount(): Int = allGraphNodes.size - - def getGraphResultSet: GraphResultSet = { - graphResultSet - } - - def getAllNodes: Seq[GraphNode] = allGraphNodes - - def getOneNode: GraphNode = { - graphResultSet.one() - } - - - def getGraphResultColumnValues(column: String): Seq[Any] = { - if (allGraphNodes.isEmpty) return Seq.empty - - allGraphNodes.map(node => if (node.get(column) != null) node.get(column)) - } - - - def getEdges(name: String): Seq[Edge] = { - val columnValues = collection.mutable.Buffer[Edge]() - graphResultSet.forEach { n => - Try( - if (n.get(name).isEdge) { - columnValues.append(n.get(name).asEdge()) - } - ) - } - columnValues - } - - - def getVertexes(name: String): Seq[Vertex] = { - val columnValues = collection.mutable.Buffer[Vertex]() - graphResultSet.forEach { n => Try(if (n.get(name).isVertex) columnValues.append(n.get(name).asVertex())) } - columnValues - } - - - def getPaths(name: String): Seq[Path] = { - val columnValues = collection.mutable.Buffer[Path]() - graphResultSet.forEach { n => Try(columnValues.append(n.get(name).asPath())) } - columnValues - } - - - def getProperties(name: String): Seq[Property] = { - val columnValues = collection.mutable.Buffer[Property]() - graphResultSet.forEach { n => Try(columnValues.append(n.get(name).asProperty())) } - columnValues - } - - - def getVertexProperties(name: String): Seq[VertexProperty] = { - val columnValues = collection.mutable.Buffer[VertexProperty]() - graphResultSet.forEach { n => Try(columnValues.append(n.get(name).asVertexProperty())) } - columnValues - } - - def getDseAttributes: DseGraphAttributes = dseAttributes +class CqlResponse(cqlResultSet: AsyncResultSet, dseAttributes: DseCqlAttributes[_, _]) extends LazyLogging { + def attributes:DseCqlAttributes[_,_] = dseAttributes + def resultSet:AsyncResultSet = cqlResultSet } -class CqlResponse(cqlResultSet: ResultSet, dseAttributes: DseCqlAttributes) extends DseResponse with LazyLogging { - private lazy val allCqlRows: Seq[Row] = collection.JavaConverters.asScalaBuffer(cqlResultSet.all()) - - override def executionInfo(): ExecutionInfo = cqlResultSet.getExecutionInfo() - override def applied(): Boolean = cqlResultSet.wasApplied() - override def exhausted(): Boolean = cqlResultSet.isExhausted() - - /** - * Get the number of all rows returned by the query. - * Note: Calling this function fetches all rows from the result set! - */ - def rowCount: Int = allCqlRows.size - - /** - * Get CQL ResultSet - * - * @return - */ - def getCqlResultSet: ResultSet = { - cqlResultSet - } - - - /** - * Get all the rows in a Seq[Row] format - * - * @return - */ - def getAllRows: Seq[Row] = allCqlRows - - - def getOneRow: Row = { - cqlResultSet.one() - } - - - def getCqlResultColumnValues(column: String): Seq[Any] = { - if (allCqlRows.isEmpty || !allCqlRows.head.getColumnDefinitions.contains(column)) { - return Seq.empty - } - - allCqlRows.map(row => row.getObject(column)) - } - - def getDseAttributes: DseCqlAttributes = dseAttributes +class GraphResponse(graphResultSet: AsyncGraphResultSet, dseAttributes: DseGraphAttributes[_,_]) extends LazyLogging { + def attributes:DseGraphAttributes[_,_] = dseAttributes + def resultSet:AsyncGraphResultSet = graphResultSet } diff --git a/src/main/scala/com/datastax/gatling/plugin/response/DseResponseHandler.scala b/src/main/scala/com/datastax/gatling/plugin/response/DseResponseHandler.scala index ebd589e..4b0e145 100644 --- a/src/main/scala/com/datastax/gatling/plugin/response/DseResponseHandler.scala +++ b/src/main/scala/com/datastax/gatling/plugin/response/DseResponseHandler.scala @@ -10,12 +10,13 @@ import java.util.UUID import java.util.concurrent.TimeUnit.MICROSECONDS import akka.actor.ActorSystem -import com.datastax.driver.core._ -import com.datastax.driver.dse.graph.{GraphProtocol, GraphResultSet, GraphStatement} +import com.datastax.oss.driver.api.core.cql._ +import com.datastax.dse.driver.api.core.graph.{AsyncGraphResultSet, GraphStatement, GraphStatementBuilderBase} +import com.datastax.gatling.plugin.checks.{DseCqlCheck, DseGraphCheck} import com.datastax.gatling.plugin.metrics.MetricsLogger import com.datastax.gatling.plugin.model.{DseCqlAttributes, DseGraphAttributes} import com.datastax.gatling.plugin.utils.{ResponseTime, ResponseTimeBuilder} -import com.google.common.util.concurrent.FutureCallback +import com.datastax.oss.driver.api.core.metadata.Node import com.typesafe.scalalogging.StrictLogging import io.gatling.commons.stats._ import io.gatling.commons.validation.Failure @@ -25,7 +26,7 @@ import io.gatling.core.session.Session import io.gatling.core.stats.StatsEngine import io.gatling.core.stats.message.ResponseTimings -import scala.util.Try +import collection.JavaConverters._ object DseResponseHandler { def sanitize(s: String): String = s.replaceAll("""(\r|\n)""", " ") @@ -35,20 +36,25 @@ object DseResponseHandler { .mkString(",") } -abstract class DseResponseHandler[RS, Response <: DseResponse] extends StrictLogging with FutureCallback[RS] { +trait DseResponseCallback[RS] { + def onFailure(t: Throwable): Unit + + def onSuccess(result: RS): Unit +} + +abstract class DseResponseHandler[S, RS, R] extends StrictLogging with DseResponseCallback[RS] { protected def responseTimeBuilder: ResponseTimeBuilder protected def system: ActorSystem protected def statsEngine: StatsEngine protected def metricsLogger: MetricsLogger protected def next: Action protected def session: Session - protected def stmt: Any + protected def stmt: S protected def tag: String protected def queries: Seq[String] - protected def specificChecks: List[Check[Response]] - protected def genericChecks: List[Check[DseResponse]] - protected def newResponse(rs: RS): Response - protected def queriedHost(rs: RS): String + protected def specificChecks: List[Check[R]] + protected def newResponse(rs: RS): R + protected def coordinator(rs: RS): Node private def writeGatlingLog(status: Status, respTimings: ResponseTimings, message: Option[String], extraInfo: List[Any]): Unit = statsEngine.logResponse(session, tag, respTimings, status, None, message, extraInfo) @@ -70,8 +76,8 @@ abstract class DseResponseHandler[RS, Response <: DseResponse] extends StrictLog List(responseTime.latencyIn(MICROSECONDS), "CHK", logUuid) ) - logger.warn("[{}] {} - Check: {}, Query: {}, Host: {}", - logUuid, tagString, checkRes._2.get.message, DseResponseHandler.sanitizeAndJoin(queries), queriedHost(resultSet) + logger.warn("[{}] {} - Check: {}, Query: {}, Coordinator: {}", + logUuid, tagString, checkRes._2.get.message, DseResponseHandler.sanitizeAndJoin(queries), coordinator(resultSet).toString ) } @@ -88,13 +94,9 @@ abstract class DseResponseHandler[RS, Response <: DseResponse] extends StrictLog ) stmt match { - case Some(gs: GraphStatement) => - val unwrap = Try( - DseResponseHandler.sanitize(gs.unwrap(GraphProtocol.GRAPHSON_2_0).toString) - ).getOrElse(DseResponseHandler.sanitize(gs.unwrap(GraphProtocol.GRAPHSON_1_0).toString)) - + case Some(gs: GraphStatement[_]) => logger.warn("[{}] {} - Execute: {} - Attrs: {}", - logUuid, tagString, unwrap, session.attributes.mkString(","), t + logUuid, tagString, gs, session.attributes.mkString(","), t ) case _ => logger.warn("[{}] {} - Execute: {}, Query: {}", @@ -114,58 +116,65 @@ abstract class DseResponseHandler[RS, Response <: DseResponse] extends StrictLog val responseTime = responseTimeBuilder.build() val response = newResponse(result) - val genericResult: (Session => Session, Option[Failure]) = Check.check(response, session, genericChecks) - val genericChecksPassed = genericResult._2.isEmpty - val sessionAfterGenericChecks = genericResult._1(session) - if (genericChecksPassed) { - val specificResult: (Session => Session, Option[Failure]) = Check.check(response, sessionAfterGenericChecks, specificChecks) - val sessionAfterSpecificChecks = genericResult._1(session) - val specificChecksPassed = specificResult._2.isEmpty - if (specificChecksPassed) { - writeSuccess(responseTime) - next ! sessionAfterSpecificChecks.markAsSucceeded - } else { - writeCheckFailure(specificResult, result, responseTime) - next ! sessionAfterSpecificChecks.markAsFailed - } + val specificResult: (Session => Session, Option[Failure]) = Check.check(response, session, specificChecks) + val sessionAfterSpecificChecks = specificResult._1(session) + val specificChecksPassed = specificResult._2.isEmpty + if (specificChecksPassed) { + writeSuccess(responseTime) + next ! sessionAfterSpecificChecks.markAsSucceeded } else { - // Do not run specific checks as the response is already error'ed - writeCheckFailure(genericResult, result, responseTime) - next ! sessionAfterGenericChecks.markAsFailed + writeCheckFailure(specificResult, result, responseTime) + next ! sessionAfterSpecificChecks.markAsFailed } } } -class GraphResponseHandler(val next: Action, +class GraphResponseHandler[T <: GraphStatement[T], B <: GraphStatementBuilderBase[B,T]](val next: Action, val session: Session, val system: ActorSystem, val statsEngine: StatsEngine, val responseTimeBuilder: ResponseTimeBuilder, - val stmt: GraphStatement, - val dseAttributes: DseGraphAttributes, + val stmt: T, + val dseAttributes: DseGraphAttributes[T, B], val metricsLogger: MetricsLogger) - extends DseResponseHandler[GraphResultSet, GraphResponse] { + extends DseResponseHandler[T, AsyncGraphResultSet, GraphResponse] { override protected def tag: String = dseAttributes.tag override protected def queries: Seq[String] = Seq.empty - override protected def specificChecks: List[Check[GraphResponse]] = dseAttributes.graphChecks - override protected def genericChecks: List[Check[DseResponse]] = dseAttributes.genericChecks - override protected def newResponse(rs: GraphResultSet): GraphResponse = new GraphResponse(rs, dseAttributes) - override protected def queriedHost(rs: GraphResultSet): String = rs.getExecutionInfo.getQueriedHost.toString + override protected def specificChecks: List[DseGraphCheck] = dseAttributes.graphChecks + override protected def newResponse(rs: AsyncGraphResultSet): GraphResponse = new GraphResponse(rs, dseAttributes) + override protected def coordinator(rs: AsyncGraphResultSet): Node = rs.getExecutionInfo.getCoordinator } -class CqlResponseHandler(val next: Action, +class CqlResponseHandler[T <: Statement[T], B <: StatementBuilder[B,T]](val next: Action, val session: Session, val system: ActorSystem, val statsEngine: StatsEngine, val responseTimeBuilder: ResponseTimeBuilder, - val stmt: Statement, - val dseAttributes: DseCqlAttributes, + val stmt: T, + val dseAttributes: DseCqlAttributes[T, B], val metricsLogger: MetricsLogger) - extends DseResponseHandler[ResultSet, CqlResponse] { + extends DseResponseHandler[T, AsyncResultSet, CqlResponse] { override protected def tag: String = dseAttributes.tag - override protected def queries: Seq[String] = Seq.empty - override protected def specificChecks: List[Check[CqlResponse]] = dseAttributes.cqlChecks - override protected def genericChecks: List[Check[DseResponse]] = dseAttributes.genericChecks - override protected def newResponse(rs: ResultSet): CqlResponse = new CqlResponse(rs, dseAttributes) - override protected def queriedHost(rs: ResultSet): String = rs.getExecutionInfo.getQueriedHost.toString + override protected def queries: Seq[String] = getQueryStrings(stmt) + override protected def specificChecks: List[DseCqlCheck] = dseAttributes.cqlChecks + override protected def newResponse(rs: AsyncResultSet): CqlResponse = new CqlResponse(rs, dseAttributes) + override protected def coordinator(rs: AsyncResultSet): Node = rs.getExecutionInfo.getCoordinator + + def getQueryString(s:SimpleStatement):String = s.getQuery + + def getQueryString(s:BoundStatement):String = s.getPreparedStatement.getQuery + + def getQueryStrings(stmt:Statement[T]):Seq[String] = { + + stmt match { + case s:SimpleStatement => Seq(getQueryString(s)) + case s:BoundStatement => Seq(getQueryString(s)) + case s:BatchStatement => s.iterator.asScala.map((stmt) => { + stmt match { + case s:SimpleStatement => getQueryString(s) + case s:BoundStatement => getQueryString(s) + } + }).toSeq + } + } } diff --git a/src/main/scala/com/datastax/gatling/plugin/utils/CqlPreparedStatementUtil.scala b/src/main/scala/com/datastax/gatling/plugin/utils/CqlPreparedStatementUtil.scala index 9596de1..6000ccf 100644 --- a/src/main/scala/com/datastax/gatling/plugin/utils/CqlPreparedStatementUtil.scala +++ b/src/main/scala/com/datastax/gatling/plugin/utils/CqlPreparedStatementUtil.scala @@ -9,37 +9,40 @@ package com.datastax.gatling.plugin.utils import java.math.BigInteger import java.net.InetAddress import java.nio.ByteBuffer +import java.time.{Duration, Instant, LocalDate, LocalTime} import java.util -import java.util.Date -import java.util.concurrent.TimeUnit -import com.datastax.driver.core.DataType.Name._ -import com.datastax.driver.core._ -import com.datastax.driver.dse.geometry._ -import com.datastax.driver.dse.geometry.codecs.PointCodec +import com.datastax.dse.driver.api.core.data.geometry._ +import com.datastax.oss.driver.api.core.cql._ +import com.datastax.oss.driver.api.core.`type`._ +import com.datastax.oss.protocol.internal.ProtocolConstants.DataType._ import com.datastax.gatling.plugin.exceptions.CqlTypeException +import com.datastax.oss.driver.api.core.data.{TupleValue, UdtValue} import com.github.nscala_time.time.Imports.DateTime import io.gatling.core.session.Session import scala.collection.JavaConverters._ import scala.util.matching.Regex - trait CqlPreparedStatementUtil { protected val hourMinSecRegEx: Regex = """(\d+):(\d+):(\d+)""".r protected val hourMinSecNanoRegEx: Regex = """(\d+):(\d+):(\d+).(\d+{1,9})""".r - def bindParamByOrder(gatlingSession: Session, - boundStatement: BoundStatement, paramType: DataType.Name, - paramName: String, key: Int): BoundStatement + def bindParamByOrder[T <: Bindable[T]](gatlingSession: Session, + bindable: T, + paramType: DataType, + paramName: String, + key: Int): T - def bindParamByName(gatlingSession: Session, boundStatement: BoundStatement, paramType: DataType.Name, - paramName: String): BoundStatement + def bindParamByName[T <: Bindable[T]](gatlingSession: Session, + bindable: T, + paramType: DataType, + paramName: String): T - def getParamsMap(preparedStatement: PreparedStatement): Map[String, DataType.Name] + def getParamsMap(preparedStatement: PreparedStatement): Map[String, DataType] - def getParamsList(preparedStatement: PreparedStatement): List[DataType.Name] + def getParamsList(preparedStatement: PreparedStatement): List[DataType] } /** @@ -47,89 +50,111 @@ trait CqlPreparedStatementUtil { */ object CqlPreparedStatementUtil extends CqlPreparedStatementUtil { - /** * Bind CQL Prepared statement params by key order * * @param gatlingSession Gatling Session - * @param boundStatement CQL BoundStatement + * @param bindable CQL Bindable impl * @param paramType Type of param ie String, int, boolean * @param paramName Gatling Session Attribute Name * @param key Key/Order of param */ - def bindParamByOrder(gatlingSession: Session, boundStatement: BoundStatement, paramType: DataType.Name, - paramName: String, key: Int): BoundStatement = { + def bindParamByOrder[T <: Bindable[T]](gatlingSession: Session, bindable: T, paramType: DataType, + paramName: String, key: Int): T = { if (!gatlingSession.attributes.contains(paramName)) { - if (boundStatement.isSet(paramName)) { - boundStatement.unset(paramName) + return if (bindable.isSet(paramName)) { + bindable.unset(paramName) + } else { + bindable } - return boundStatement } gatlingSession.attributes.get(paramName) match { case Some(null) => - boundStatement.setToNull(paramName) - boundStatement + bindable.setToNull(paramName) case Some(None) => - if (boundStatement.isSet(paramName)) { - boundStatement.unset(paramName) + if (bindable.isSet(paramName)) { + bindable.unset(paramName) + } else { + bindable } - boundStatement case _ => - paramType match { - case (VARCHAR | TEXT | ASCII) => - boundStatement.setString(key, asString(gatlingSession, paramName)) + paramType.getProtocolCode match { + case (VARCHAR | ASCII) => + bindable.setString(key, asString(gatlingSession, paramName)) case INT => - boundStatement.setInt(key, asInteger(gatlingSession, paramName)) + bindable.setInt(key, asInteger(gatlingSession, paramName)) case BOOLEAN => - boundStatement.setBool(key, asBoolean(gatlingSession, paramName)) + bindable.setBoolean(key, asBoolean(gatlingSession, paramName)) case (UUID | TIMEUUID) => - boundStatement.setUUID(key, asUuid(gatlingSession, paramName)) + bindable.setUuid(key, asUuid(gatlingSession, paramName)) case FLOAT => - boundStatement.setFloat(key, asFloat(gatlingSession, paramName)) + bindable.setFloat(key, asFloat(gatlingSession, paramName)) case DOUBLE => - boundStatement.setDouble(key, asDouble(gatlingSession, paramName)) + bindable.setDouble(key, asDouble(gatlingSession, paramName)) case DECIMAL => - boundStatement.setDecimal(key, asDecimal(gatlingSession, paramName)) + bindable.setBigDecimal(key, asDecimal(gatlingSession, paramName)) case INET => - boundStatement.setInet(key, asInet(gatlingSession, paramName)) + bindable.setInetAddress(key, asInet(gatlingSession, paramName)) case TIMESTAMP => - boundStatement.setTimestamp(key, asTimestamp(gatlingSession, paramName)) + bindable.setInstant(key, asInstant(gatlingSession, paramName)) case COUNTER => - boundStatement.setLong(key, asCounter(gatlingSession, paramName)) + bindable.setLong(key, asCounter(gatlingSession, paramName)) case BIGINT => - boundStatement.setLong(key, asBigInt(gatlingSession, paramName)) + bindable.setLong(key, asBigInt(gatlingSession, paramName)) case BLOB => - boundStatement.setBytes(key, asByte(gatlingSession, paramName)) + bindable.setByteBuffer(key, asByte(gatlingSession, paramName)) case VARINT => - boundStatement.setVarint(key, asVarInt(gatlingSession, paramName)) + bindable.setBigInteger(key, asVarInt(gatlingSession, paramName)) case LIST => - boundStatement.setList(key, asList(gatlingSession, paramName)) + val dataType = bindable.getType(key) + dataType match { + case l: ListType => { + val memberClz = clzFromCodec(bindable, l.getElementType) + bindable.setList(key, asList(gatlingSession, paramName, memberClz), memberClz) + } + case _ => throw new IllegalStateException("Observed something other than ListType for a LIST param") + } case SET => - boundStatement.setSet(key, asSet(gatlingSession, paramName)) + val dataType = bindable.getType(key) + dataType match { + case s: SetType => { + val memberClz = clzFromCodec(bindable, s.getElementType) + bindable.setSet(key, asSet(gatlingSession, paramName, memberClz), memberClz) + } + case _ => throw new IllegalStateException("Observed something other than SetType for a SET param") + } case MAP => - boundStatement.setMap(key, asMap(gatlingSession, paramName)) + val dataType = bindable.getType(key) + dataType match { + case m: MapType => { + val keyClz = clzFromCodec(bindable, m.getKeyType) + val valClz = clzFromCodec(bindable, m.getValueType) + bindable.setMap(key, asMap(gatlingSession, paramName, keyClz, valClz), keyClz, valClz) + } + case _ => throw new IllegalStateException("Observed something other than MapType for a MAP param") + } case UDT => - boundStatement.setUDTValue(key, asUdt(gatlingSession, paramName)) + bindable.setUdtValue(key, asUdt(gatlingSession, paramName)) case TUPLE => - boundStatement.setTupleValue(key, asTuple(gatlingSession, paramName)) + bindable.setTupleValue(key, asTuple(gatlingSession, paramName)) case DATE => - boundStatement.setDate(key, asDate(gatlingSession, paramName)) + bindable.setLocalDate(key, asLocalDate(gatlingSession, paramName)) case SMALLINT => - boundStatement.setShort(key, asSmallInt(gatlingSession, paramName)) + bindable.setShort(key, asSmallInt(gatlingSession, paramName)) case TINYINT => - boundStatement.setByte(key, asTinyInt(gatlingSession, paramName)) + bindable.setByte(key, asTinyInt(gatlingSession, paramName)) case TIME => - boundStatement.setTime(key, asTime(gatlingSession, paramName)) + bindable.setLocalTime(key, asTime(gatlingSession, paramName)) case CUSTOM => gatlingSession.attributes.get(paramName) match { case Some(p: Point) => - boundStatement.set(key, asPoint(gatlingSession, paramName), classOf[Point]) + bindable.set(key, asPoint(gatlingSession, paramName), classOf[Point]) case Some(p: LineString) => - boundStatement.set(key, asLineString(gatlingSession, paramName), classOf[LineString]) + bindable.set(key, asLineString(gatlingSession, paramName), classOf[LineString]) case Some(p: Polygon) => - boundStatement.set(key, asPolygon(gatlingSession, paramName), classOf[Polygon]) + bindable.set(key, asPolygon(gatlingSession, paramName), classOf[Polygon]) case _ => throw new UnsupportedOperationException(s"$paramName on unknown CUSTOM type") } @@ -143,83 +168,106 @@ object CqlPreparedStatementUtil extends CqlPreparedStatementUtil { * Bind CQL Prepared statement params by anem * * @param gatlingSession Gatling Session - * @param boundStatement CQL BoundStatement + * @param bindable CQL Bindable impl * @param paramType Type of param ie String, int, boolean * @param paramName Gatling Session Attribute Value */ - def bindParamByName(gatlingSession: Session, boundStatement: BoundStatement, paramType: DataType.Name, - paramName: String): BoundStatement = { + def bindParamByName[T <: Bindable[T]](gatlingSession: Session, bindable: T, paramType: DataType, + paramName: String): T = { if (!gatlingSession.attributes.contains(paramName)) { - if (boundStatement.isSet(paramName)) { - boundStatement.unset(paramName) + return if (bindable.isSet(paramName)) { + bindable.unset(paramName) + } else { + bindable } - return boundStatement } gatlingSession.attributes.get(paramName) match { case Some(null) => - boundStatement.setToNull(paramName) - boundStatement + bindable.setToNull(paramName) case Some(None) => - if (boundStatement.isSet(paramName)) { - boundStatement.unset(paramName) + if (bindable.isSet(paramName)) { + bindable.unset(paramName) + } else { + bindable } - boundStatement case _ => - paramType match { - case (VARCHAR | TEXT | ASCII) => - boundStatement.setString(paramName, asString(gatlingSession, paramName)) + paramType.getProtocolCode match { + case (VARCHAR | ASCII) => + bindable.setString(paramName, asString(gatlingSession, paramName)) case INT => - boundStatement.setInt(paramName, asInteger(gatlingSession, paramName)) + bindable.setInt(paramName, asInteger(gatlingSession, paramName)) case BOOLEAN => - boundStatement.setBool(paramName, asBoolean(gatlingSession, paramName)) + bindable.setBoolean(paramName, asBoolean(gatlingSession, paramName)) case (UUID | TIMEUUID) => - boundStatement.setUUID(paramName, asUuid(gatlingSession, paramName)) + bindable.setUuid(paramName, asUuid(gatlingSession, paramName)) case FLOAT => - boundStatement.setFloat(paramName, asFloat(gatlingSession, paramName)) + bindable.setFloat(paramName, asFloat(gatlingSession, paramName)) case DOUBLE => - boundStatement.setDouble(paramName, asDouble(gatlingSession, paramName)) + bindable.setDouble(paramName, asDouble(gatlingSession, paramName)) case DECIMAL => - boundStatement.setDecimal(paramName, asDecimal(gatlingSession, paramName)) + bindable.setBigDecimal(paramName, asDecimal(gatlingSession, paramName)) case INET => - boundStatement.setInet(paramName, asInet(gatlingSession, paramName)) + bindable.setInetAddress(paramName, asInet(gatlingSession, paramName)) case TIMESTAMP => - boundStatement.setTimestamp(paramName, asTimestamp(gatlingSession, paramName)) + bindable.setInstant(paramName, asInstant(gatlingSession, paramName)) case BIGINT => - boundStatement.setLong(paramName, asBigInt(gatlingSession, paramName)) + bindable.setLong(paramName, asBigInt(gatlingSession, paramName)) case COUNTER => - boundStatement.setLong(paramName, asCounter(gatlingSession, paramName)) + bindable.setLong(paramName, asCounter(gatlingSession, paramName)) case BLOB => - boundStatement.setBytes(paramName, asByte(gatlingSession, paramName)) + bindable.setByteBuffer(paramName, asByte(gatlingSession, paramName)) case VARINT => - boundStatement.setVarint(paramName, asVarInt(gatlingSession, paramName)) + bindable.setBigInteger(paramName, asVarInt(gatlingSession, paramName)) case LIST => - boundStatement.setList(paramName, asList(gatlingSession, paramName)) + val dataType = bindable.getType(paramName) + dataType match { + case l: ListType => { + val memberClz = clzFromCodec(bindable, l.getElementType) + bindable.setList(paramName, asList(gatlingSession, paramName, memberClz), memberClz) + } + case _ => throw new IllegalStateException("Observed something other than ListType for a LIST param") + } case SET => - boundStatement.setSet(paramName, asSet(gatlingSession, paramName)) + val dataType = bindable.getType(paramName) + dataType match { + case s: SetType => { + val memberClz = clzFromCodec(bindable, s.getElementType) + bindable.setSet(paramName, asSet(gatlingSession, paramName, memberClz), memberClz) + } + case _ => throw new IllegalStateException("Observed something other than SetType for a SET param") + } case MAP => - boundStatement.setMap(paramName, asMap(gatlingSession, paramName)) + val dataType = bindable.getType(paramName) + dataType match { + case m: MapType => { + val keyClz = bindable.codecRegistry().codecFor(m.getKeyType).getJavaType.getRawType.asInstanceOf[Class[Any]] + val valClz = bindable.codecRegistry().codecFor(m.getValueType).getJavaType.getRawType.asInstanceOf[Class[Any]] + bindable.setMap(paramName, asMap(gatlingSession, paramName, keyClz, valClz), keyClz, valClz) + } + case _ => throw new IllegalStateException("Observed something other than MapType for a MAP param") + } case UDT => - boundStatement.setUDTValue(paramName, asUdt(gatlingSession, paramName)) + bindable.setUdtValue(paramName, asUdt(gatlingSession, paramName)) case TUPLE => - boundStatement.setTupleValue(paramName, asTuple(gatlingSession, paramName)) + bindable.setTupleValue(paramName, asTuple(gatlingSession, paramName)) case DATE => - boundStatement.setDate(paramName, asDate(gatlingSession, paramName)) + bindable.setLocalDate(paramName, asLocalDate(gatlingSession, paramName)) case SMALLINT => - boundStatement.setShort(paramName, asSmallInt(gatlingSession, paramName)) + bindable.setShort(paramName, asSmallInt(gatlingSession, paramName)) case TINYINT => - boundStatement.setByte(paramName, asTinyInt(gatlingSession, paramName)) + bindable.setByte(paramName, asTinyInt(gatlingSession, paramName)) case TIME => - boundStatement.setTime(paramName, asTime(gatlingSession, paramName)) + bindable.setLocalTime(paramName, asTime(gatlingSession, paramName)) case CUSTOM => gatlingSession.attributes.get(paramName) match { case Some(p: Point) => - boundStatement.set(paramName, asPoint(gatlingSession, paramName), classOf[Point]) + bindable.set(paramName, asPoint(gatlingSession, paramName), classOf[Point]) case Some(p: LineString) => - boundStatement.set(paramName, asLineString(gatlingSession, paramName), classOf[LineString]) + bindable.set(paramName, asLineString(gatlingSession, paramName), classOf[LineString]) case Some(p: Polygon) => - boundStatement.set(paramName, asPolygon(gatlingSession, paramName), classOf[Polygon]) + bindable.set(paramName, asPolygon(gatlingSession, paramName), classOf[Polygon]) case _ => throw new UnsupportedOperationException(s"$paramName on unknown CUSTOM type") } @@ -236,10 +284,10 @@ object CqlPreparedStatementUtil extends CqlPreparedStatementUtil { * @param preparedStatement CQL Prepared Stated * @return */ - def getParamsMap(preparedStatement: PreparedStatement): Map[String, DataType.Name] = { - val paramVariables = preparedStatement.getVariables + def getParamsMap(preparedStatement: PreparedStatement): Map[String, DataType] = { + val paramVariables = preparedStatement.getVariableDefinitions val paramIterator = paramVariables.iterator.asScala - paramIterator.map(p => (p.getName, p.getType.getName)).toMap + paramIterator.map(p => (p.getName.asCql(true), p.getType)).toMap } @@ -249,9 +297,9 @@ object CqlPreparedStatementUtil extends CqlPreparedStatementUtil { * @param preparedStatement CQL Prepared Stated * @return */ - def getParamsList(preparedStatement: PreparedStatement): List[DataType.Name] = { - val paramVariables = preparedStatement.getVariables - paramVariables.iterator.asScala.map(p => p.getType.getName).toList + def getParamsList(preparedStatement: PreparedStatement): List[DataType] = { + val paramVariables = preparedStatement.getVariableDefinitions + paramVariables.iterator.asScala.map(_.getType).toList } @@ -484,16 +532,16 @@ object CqlPreparedStatementUtil extends CqlPreparedStatementUtil { * @param paramName CQL prepared statement parameter name * @return */ - def asTimestamp(gatlingSession: Session, paramName: String): java.util.Date = { + def asInstant(gatlingSession: Session, paramName: String): Instant = { gatlingSession.attributes.get(paramName).flatMap(Option(_)) match { case Some(l: Long) => - new Date(l) + Instant.ofEpochMilli(l) case Some(s: String) => - DateTime.parse(s).toDate - case Some(d: Date) => - d + Instant.ofEpochMilli(DateTime.parse(s).getMillis) + case Some(i: Instant) => + i case _ => - throw new CqlTypeException(s"$paramName expected to be type of Long, String or java.util.Date") + throw new CqlTypeException(s"$paramName expected to be type of Long, String or java.time.Instant") } } @@ -507,13 +555,13 @@ object CqlPreparedStatementUtil extends CqlPreparedStatementUtil { * @param paramName CQL prepared statement parameter name * @return */ - def asSet(gatlingSession: Session, paramName: String): util.Set[Any] = { + def asSet[T](gatlingSession: Session, paramName: String, elementType: Class[T]): util.Set[T] = { gatlingSession.attributes.get(paramName).flatMap(Option(_)) match { - case Some(m: Set[Any]@unchecked) => + case Some(m: Set[T]@unchecked) => m.asJava - case Some(s: Seq[Any]@unchecked) => + case Some(s: Seq[T]@unchecked) => s.toSet.asJava - case Some(s: util.Set[Any]@unchecked) => + case Some(s: util.Set[T]@unchecked) => s case _ => throw new CqlTypeException(s"$paramName expected to be type of Set") @@ -530,13 +578,13 @@ object CqlPreparedStatementUtil extends CqlPreparedStatementUtil { * @param paramName CQL prepared statement parameter name * @return */ - def asList(gatlingSession: Session, paramName: String): util.List[Any] = { + def asList[T](gatlingSession: Session, paramName: String, elementType: Class[T]): util.List[T] = { gatlingSession.attributes.get(paramName).flatMap(Option(_)) match { - case Some(m: List[Any]@unchecked) => + case Some(m: List[T]@unchecked) => m.asJava - case Some(s: Seq[Any]@unchecked) => + case Some(s: Seq[T]@unchecked) => s.toList.asJava - case Some(l: util.List[Any]@unchecked) => + case Some(l: util.List[T]@unchecked) => l case _ => throw new CqlTypeException(s"$paramName expected to be type of List") @@ -553,11 +601,11 @@ object CqlPreparedStatementUtil extends CqlPreparedStatementUtil { * @param paramName CQL prepared statement parameter name * @return */ - def asMap(gatlingSession: Session, paramName: String): util.Map[Any, Any] = { + def asMap[K, V](gatlingSession: Session, paramName: String, keyType: Class[K], valType: Class[V]): util.Map[K, V] = { gatlingSession.attributes.get(paramName).flatMap(Option(_)) match { - case Some(m: Map[Any, Any]@unchecked) => + case Some(m: Map[K, V]@unchecked) => m.asJava - case Some(mj: util.Map[Any, Any]@unchecked) => + case Some(mj: util.Map[K, V]@unchecked) => mj case _ => throw new CqlTypeException(s"$paramName expected to be type of Set") @@ -595,7 +643,7 @@ object CqlPreparedStatementUtil extends CqlPreparedStatementUtil { throw new CqlTypeException(s"$paramName expected to be type of LineString") } } - + /** * Returns CQL compatible Polygon type * @@ -695,10 +743,12 @@ object CqlPreparedStatementUtil extends CqlPreparedStatementUtil { * @param paramName CQL prepared statement parameter name * @return */ - def asTime(gatlingSession: Session, paramName: String): Long = { + def asTime(gatlingSession: Session, paramName: String): LocalTime = { gatlingSession.attributes.get(paramName).flatMap(Option(_)) match { - case Some(l: Long) => + case Some(l: LocalTime) => l + case Some(l: Long) => + LocalTime.ofNanoOfDay(l) case Some(s: String) => s.trim match { case hourMinSecNanoRegEx(hour, min, second, nano) => parseTime(paramName, hour, min, second, nano) @@ -711,7 +761,7 @@ object CqlPreparedStatementUtil extends CqlPreparedStatementUtil { } private def parseTime(paramName: String, hourStr: String, minStr: String, - secStr: String, nanoStr: String = null) = { + secStr: String, nanoStr: String = null): LocalTime = { val hour = Integer.parseInt(hourStr) val min = Integer.parseInt(minStr) @@ -740,13 +790,7 @@ object CqlPreparedStatementUtil extends CqlPreparedStatementUtil { throw new CqlTypeException(s"$paramName Seconds out of bounds.") } - var rawTime: Long = 0 - rawTime += TimeUnit.HOURS.toNanos(hour) - rawTime += TimeUnit.MINUTES.toNanos(min) - rawTime += TimeUnit.SECONDS.toNanos(sec) - rawTime += nanos - - rawTime + LocalTime.of(hour, min, sec, nanos) } @@ -761,16 +805,16 @@ object CqlPreparedStatementUtil extends CqlPreparedStatementUtil { * @param paramName CQL prepared statement parameter name * @return */ - def asDate(gatlingSession: Session, paramName: String): com.datastax.driver.core.LocalDate = { + def asLocalDate(gatlingSession: Session, paramName: String): LocalDate = { gatlingSession.attributes.get(paramName).flatMap(Option(_)) match { case Some(s: String) => val dateSplit = s.split("-").toList - LocalDate.fromYearMonthDay(dateSplit.head.toInt, dateSplit(1).toInt, dateSplit(2).toInt) + LocalDate.of(dateSplit.head.toInt, dateSplit(1).toInt, dateSplit(2).toInt) case Some(l: Long) => - LocalDate.fromMillisSinceEpoch(l) + toLocalDate(l) case Some(i: Int) => - LocalDate.fromDaysSinceEpoch(i) - case Some(ld: com.datastax.driver.core.LocalDate) => + toLocalDate(i) + case Some(ld: LocalDate) => ld case _ => throw new CqlTypeException(s"$paramName expected to be type of String, Long, Int or LocalDate") @@ -807,11 +851,12 @@ object CqlPreparedStatementUtil extends CqlPreparedStatementUtil { * @param paramName CQL prepared statement parameter name * @return */ - def asUdt(gatlingSession: Session, paramName: String): UDTValue = { + def asUdt(gatlingSession: Session, paramName: String): UdtValue = { gatlingSession.attributes.get(paramName).flatMap(Option(_)) match { - case Some(udt: UDTValue) => + case Some(udt: UdtValue) => udt case _ => + throw new CqlTypeException(s"$paramName expected to be type of UDTValue") } } @@ -857,4 +902,14 @@ object CqlPreparedStatementUtil extends CqlPreparedStatementUtil { throw new CqlTypeException(s"$paramName expected to be type of ByteBuffer, Array[Byte] or Byte") } } + + def toLocalDate(epochMillis: Long): LocalDate = { + val end = Instant.ofEpochMilli(epochMillis) + val d = Duration.between(Instant.EPOCH, end) + LocalDate.ofEpochDay(d.toDays) + } + + def clzFromCodec(bindable: Bindable[_], genType: DataType): Class[Any] = { + bindable.codecRegistry().codecFor(genType).getJavaType.getRawType.asInstanceOf[Class[Any]] + } } diff --git a/src/main/scala/com/datastax/gatling/plugin/utils/FutureUtils.scala b/src/main/scala/com/datastax/gatling/plugin/utils/FutureUtils.scala deleted file mode 100644 index 8111056..0000000 --- a/src/main/scala/com/datastax/gatling/plugin/utils/FutureUtils.scala +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright (c) 2018 Datastax Inc. - * - * This software can be used solely with DataStax products. Please consult the file LICENSE.md. - */ - -package com.datastax.gatling.plugin.utils - -import com.google.common.util.concurrent.{FutureCallback, Futures, ListenableFuture} - -import scala.concurrent.{Future, Promise} - -object FutureUtils { - /** - * Converts a Guava future into a Scala future containing the exact same - * result. - * - * The conversion is operated by the same thread that will eventually - * complete the guava future. It is not an expensive operation, though it - * may involve CAS operations under the hood. - * - * @param guavaFuture the Guava future to convert - * @tparam T the type of object returned by the Guava future - * @return a Scala [[Future]] that will be completed when the Guava future - * completes - */ - def toScalaFuture[T](guavaFuture: ListenableFuture[T]): Future[T] = { - val scalaPromise = Promise[T]() - Futures.addCallback( - guavaFuture, - new FutureCallback[T] { - def onSuccess(result: T): Unit = scalaPromise.success(result) - - def onFailure(exception: Throwable): Unit = scalaPromise.failure(exception) - } - ) - scalaPromise.future - } -} diff --git a/src/main/scala/com/datastax/gatling/plugin/utils/ResultSetUtils.scala b/src/main/scala/com/datastax/gatling/plugin/utils/ResultSetUtils.scala new file mode 100644 index 0000000..264eaf1 --- /dev/null +++ b/src/main/scala/com/datastax/gatling/plugin/utils/ResultSetUtils.scala @@ -0,0 +1,73 @@ +package com.datastax.gatling.plugin.utils + +import java.util.concurrent.CompletionStage + +import com.datastax.dse.driver.api.core.graph.{ AsyncGraphResultSet, GraphNode } +import com.datastax.oss.driver.api.core.cql.{ AsyncResultSet, Row } +import org.apache.tinkerpop.gremlin.process.traversal.Path +import org.apache.tinkerpop.gremlin.structure.{ Edge, Property, Vertex, VertexProperty } + +import collection.JavaConverters._ + +object ResultSetUtils { + + def asyncResultSetToIterator(resultSet: AsyncResultSet): Iterator[Row] = + new IteratorFromAsyncResultSet(new CqlAsyncRS(resultSet)) + + def asyncGraphResultSetToIterator(resultSet: AsyncGraphResultSet): Iterator[GraphNode] = + new IteratorFromAsyncResultSet(new GraphAsyncRS(resultSet)) +} + +object GraphResultSetUtils { + + private def buildFilterAndMapFn[T](filterFn: GraphNode => Boolean, mapFn: GraphNode => T)(resultSet: AsyncGraphResultSet, key: String): Iterator[T] = + ResultSetUtils + .asyncGraphResultSetToIterator(resultSet) + .map(graphNode => graphNode.getByKey(key)) + .filter(filterFn) + .map(mapFn) + + def edges: (AsyncGraphResultSet, String) => Iterator[Edge] = buildFilterAndMapFn(_.isEdge, _.asEdge) + + def vertexes: (AsyncGraphResultSet, String) => Iterator[Vertex] = buildFilterAndMapFn(_.isVertex, _.asVertex) + + def paths: (AsyncGraphResultSet, String) => Iterator[Path] = buildFilterAndMapFn(_.isPath, _.asPath) + + def properties: (AsyncGraphResultSet, String) => Iterator[Property[_]] = + buildFilterAndMapFn(_.isProperty, _.asProperty.asInstanceOf[Property[_]]) + + def vertexProperties: (AsyncGraphResultSet, String) => Iterator[VertexProperty[_]] = + buildFilterAndMapFn(_.isVertexProperty, _.asVertexProperty.asInstanceOf[VertexProperty[_]]) +} + +trait AsyncRS[T] { + def hasMorePages: Boolean + def currentPage: Iterable[T] + def fetchNextPage: CompletionStage[AsyncRS[T]] +} + +class CqlAsyncRS(rs: AsyncResultSet) extends AsyncRS[Row] { + override def hasMorePages: Boolean = rs.hasMorePages + override def currentPage: Iterable[Row] = rs.currentPage.asScala + override def fetchNextPage: CompletionStage[AsyncRS[Row]] = rs.fetchNextPage.thenApplyAsync(new CqlAsyncRS(_)) +} + +class GraphAsyncRS(rs: AsyncGraphResultSet) extends AsyncRS[GraphNode] { + override def hasMorePages: Boolean = rs.hasMorePages + override def currentPage: Iterable[GraphNode] = rs.currentPage.asScala + override def fetchNextPage: CompletionStage[AsyncRS[GraphNode]] = rs.fetchNextPage.thenApplyAsync(new GraphAsyncRS(_)) +} + +/* Note that this iterator isn't thread-safe */ +class IteratorFromAsyncResultSet[T](rs: AsyncRS[T]) extends Iterator[T] { + var working = rs + var iter = rs.currentPage.iterator + override def hasNext: Boolean = iter.hasNext || working.hasMorePages + override def next: T = { + if (!iter.hasNext && working.hasMorePages) { + working = working.fetchNextPage.toCompletableFuture.get + iter = working.currentPage.iterator + } + iter.next + } +} diff --git a/src/test/scala/com/datastax/gatling/plugin/DseCqlStatementSpec.scala b/src/test/scala/com/datastax/gatling/plugin/DseCqlStatementSpec.scala index 632b445..9b58156 100644 --- a/src/test/scala/com/datastax/gatling/plugin/DseCqlStatementSpec.scala +++ b/src/test/scala/com/datastax/gatling/plugin/DseCqlStatementSpec.scala @@ -2,12 +2,12 @@ package com.datastax.gatling.plugin import java.nio.ByteBuffer -import com.datastax.driver.core.ColumnDefinitions.Definition -import com.datastax.driver.core._ import com.datastax.gatling.plugin.base.BaseSpec import com.datastax.gatling.plugin.exceptions.DseCqlStatementException import com.datastax.gatling.plugin.model._ import com.datastax.gatling.plugin.utils.CqlPreparedStatementUtil +import com.datastax.oss.driver.api.core.`type`.{DataType, DataTypes} +import com.datastax.oss.driver.api.core.cql._ import io.gatling.commons.validation._ import io.gatling.core.session.Session import io.gatling.core.session.el.ElCompiler @@ -15,13 +15,10 @@ import org.easymock.EasyMock._ import scala.collection.JavaConverters._ - class DseCqlStatementSpec extends BaseSpec { val prepared = mock[PreparedStatement] val mockColDefinitions = mock[ColumnDefinitions] - val mockDefinitions = mock[Definition] - val mockDefinitionId = mock[Definition] val mockBoundStatement = mock[BoundStatement] val mockCqlTypes = mock[CqlPreparedStatementUtil] @@ -41,47 +38,46 @@ class DseCqlStatementSpec extends BaseSpec { reset(prepared, mockBoundStatement, mockCqlTypes) } - describe("DseCqlSimpleStatement") { it("should succeed with a passed SimpleStatement", CqlTest) { - val stmt = new SimpleStatement("select * from keyspace.table where id = 5") + val stmt = SimpleStatement.builder("select * from keyspace.table where id = 5").build() val result = DseCqlSimpleStatement(stmt).buildFromSession(validGatlingSession) result shouldBe a[Success[_]] - result.get.toString shouldBe stmt.toString + result.get.build.getQuery shouldBe stmt.getQuery } - } - describe("DseCqlBoundStatementWithPassedParams") { val e1 = ElCompiler.compile[AnyRef]("${foo}") val e2 = ElCompiler.compile[AnyRef]("${bar}") + val mockBuilder = mock[BoundStatementBuilder] + it("correctly bind values to a prepared statement") { expecting { prepared.bind(fooValue, barValue).andReturn(mockBoundStatement) + mockBuilder.build().andReturn(mockBoundStatement) } whenExecuting(prepared, mockCqlTypes, mockBoundStatement) { - DseCqlBoundStatementWithPassedParams(mockCqlTypes, prepared, e1, e2) + DseCqlBoundStatementWithPassedParams(mockCqlTypes, prepared, e1, e2)((_) => mockBuilder) .buildFromSession(validGatlingSession) shouldBe a[Success[_]] } } - it("should fail if the expression is wrong and return the 1st error") { expecting { - prepared.getVariables.andStubReturn(mockColDefinitions) + prepared.getVariableDefinitions.andStubReturn(mockColDefinitions) } whenExecuting(prepared, mockCqlTypes, mockBoundStatement) { - val r = DseCqlBoundStatementWithPassedParams(mockCqlTypes, prepared, e1, e2) + val r = DseCqlBoundStatementWithPassedParams(mockCqlTypes, prepared, e1, e2)((_) => mockBuilder) .buildFromSession(invalidGatlingSession) r shouldBe a[Failure] r shouldBe "No attribute named 'foo' is defined".failure @@ -89,64 +85,69 @@ class DseCqlStatementSpec extends BaseSpec { } } - describe("DseCqlBoundStatementWithParamList") { val validParamList = Seq("foo", "bar") - val paramsList = List[DataType.Name](DataType.Name.TEXT, DataType.Name.INT) + val paramsList = List(DataTypes.TEXT, DataTypes.INT) + + val mockBuilder = mock[BoundStatementBuilder] it("correctly bind values to a prepared statement") { expecting { prepared.bind().andReturn(mockBoundStatement) mockCqlTypes.getParamsList(prepared).andReturn(paramsList) - mockCqlTypes.bindParamByOrder(validGatlingSession, mockBoundStatement, DataType.Name.TEXT, "foo", 0) - .andReturn(mockBoundStatement) - mockCqlTypes.bindParamByOrder(validGatlingSession, mockBoundStatement, DataType.Name.INT, "bar", 1) - .andReturn(mockBoundStatement) + mockCqlTypes.bindParamByOrder(validGatlingSession, mockBuilder, DataTypes.TEXT, "foo", 0) + .andReturn(mockBuilder) + mockCqlTypes.bindParamByOrder(validGatlingSession, mockBuilder, DataTypes.INT, "bar", 1) + .andReturn(mockBuilder) + mockBuilder.build().andReturn(mockBoundStatement) } whenExecuting(prepared, mockCqlTypes, mockBoundStatement) { - DseCqlBoundStatementWithParamList(mockCqlTypes, prepared, validParamList) + DseCqlBoundStatementWithParamList(mockCqlTypes, prepared, validParamList)((_) => mockBuilder) .buildFromSession(validGatlingSession) shouldBe a[Success[_]] } } } - - describe("DseCqlBoundStatementNamed") { + val mockBuilder = mock[BoundStatementBuilder] + it("correctly bind values to a prepared statement") { expecting { prepared.bind().andReturn(mockBoundStatement) - mockCqlTypes.getParamsMap(prepared).andReturn(Map(fooKey -> DataType.Name.INT)) - mockCqlTypes.bindParamByName(validGatlingSession, mockBoundStatement, DataType.Name.INT, "foo") - .andReturn(mockBoundStatement) + mockCqlTypes.getParamsMap(prepared).andReturn(Map(fooKey -> DataTypes.INT)) + mockCqlTypes.bindParamByName(validGatlingSession, mockBuilder, DataTypes.INT, "foo") + .andReturn(mockBuilder) + mockBuilder.build().andReturn(mockBoundStatement) } whenExecuting(prepared, mockCqlTypes, mockBoundStatement) { - DseCqlBoundStatementNamed(mockCqlTypes, prepared) + DseCqlBoundStatementNamed(mockCqlTypes, prepared)((_) => mockBuilder) .buildFromSession(validGatlingSession) shouldBe a[Success[_]] } } } + describe("DseCqlBoundStatementNamedFromSession") { + val mockBuilder = mock[BoundStatementBuilder] - - describe("DseCqlBoundStatementNamedFromSession") { it("correctly bind values to a prepared statement in session") { val sessionWithStatement: Session = validGatlingSession.set("statementKey", prepared) expecting { prepared.bind().andReturn(mockBoundStatement) - mockCqlTypes.getParamsMap(prepared).andReturn(Map(fooKey -> DataType.Name.INT)) - mockCqlTypes.bindParamByName(sessionWithStatement, mockBoundStatement, DataType.Name.INT, "foo") - .andReturn(mockBoundStatement) + mockCqlTypes.getParamsMap(prepared).andReturn(Map(fooKey -> DataTypes.INT)) + mockCqlTypes.bindParamByName(sessionWithStatement, mockBuilder, DataTypes.INT, fooKey) + .andReturn(mockBuilder) + mockBuilder.build().andReturn(mockBoundStatement) } + whenExecuting(prepared, mockCqlTypes, mockBoundStatement) { - DseCqlBoundStatementNamedFromSession(mockCqlTypes, "statementKey") + DseCqlBoundStatementNamedFromSession(mockCqlTypes, "statementKey")((_) => mockBuilder) .buildFromSession(sessionWithStatement) shouldBe a[Success[_]] } } @@ -156,7 +157,7 @@ class DseCqlStatementSpec extends BaseSpec { } whenExecuting(prepared, mockCqlTypes, mockBoundStatement) { val thrown = intercept[DseCqlStatementException] { - DseCqlBoundStatementNamedFromSession(mockCqlTypes, "statementKey") + DseCqlBoundStatementNamedFromSession(mockCqlTypes, "statementKey")((_) => mockBuilder) .buildFromSession(validGatlingSession) shouldBe a[Failure] } thrown.getMessage shouldBe "Passed sessionKey: {statementKey} does not exist in Session." @@ -164,45 +165,45 @@ class DseCqlStatementSpec extends BaseSpec { } } - - - describe("DseCqlBoundBatchStatement") { + val mockBuilder = mock[BoundStatementBuilder] + it("correctly bind values to a prepared statement") { expecting { - mockBoundStatement.getOutgoingPayload.andReturn(Map("test" -> ByteBuffer.wrap(Array(12.toByte))).asJava) - mockBoundStatement.getOutgoingPayload.andReturn(Map("test" -> ByteBuffer.wrap(Array(12.toByte))).asJava) prepared.bind().andReturn(mockBoundStatement) - mockCqlTypes.getParamsMap(prepared).andReturn(Map(fooKey -> DataType.Name.INT)) - mockCqlTypes.bindParamByName(validGatlingSession, mockBoundStatement, DataType.Name.INT, "foo") - .andReturn(mockBoundStatement) + mockCqlTypes.getParamsMap(prepared).andReturn(Map(fooKey -> DataTypes.INT)) + mockCqlTypes.bindParamByName(validGatlingSession, mockBuilder, DataTypes.INT, fooKey) + .andReturn(mockBuilder).anyTimes() + mockBuilder.build().andReturn(mockBoundStatement) } - whenExecuting(prepared, mockCqlTypes, mockBoundStatement) { - DseCqlBoundBatchStatement(mockCqlTypes, Seq(prepared)) + whenExecuting(prepared, mockCqlTypes, mockBoundStatement, mockBuilder) { + DseCqlBoundBatchStatement(mockCqlTypes, Seq(prepared))((_) => mockBuilder) .buildFromSession(validGatlingSession) shouldBe a[Success[_]] } } } - describe("DseCqlCustomPayloadStatement") { - val stmt = new SimpleStatement("select * from keyspace.table where id = 5") + val stmt = SimpleStatement.builder("select * from keyspace.table where id = 5").build() it("should succeed with a passed SimpleStatement", CqlTest) { + val expectedCustomPayload = Map("test" -> ByteBuffer.wrap(Array(12.toByte))) val payloadGatlingSession = new Session("name", 1, Map( - "payload" -> Map("test" -> ByteBuffer.wrap(Array(12.toByte)))) + "payload" -> expectedCustomPayload) ) val result = DseCqlCustomPayloadStatement(stmt, "payload") .buildFromSession(payloadGatlingSession) result shouldBe a[Success[_]] - result.get.toString shouldBe stmt.toString + val resultStmt = result.get.build + resultStmt.getCustomPayload shouldBe expectedCustomPayload.asJava + resultStmt.getQuery shouldBe stmt.getQuery } it("should fail with non existent sessionKey", CqlTest) { @@ -220,8 +221,6 @@ class DseCqlStatementSpec extends BaseSpec { DseCqlCustomPayloadStatement(stmt, "payload") .buildFromSession(payloadGatlingSession) shouldBe a[Failure] } - } - } diff --git a/src/test/scala/com/datastax/gatling/plugin/DseGraphStatementSpec.scala b/src/test/scala/com/datastax/gatling/plugin/DseGraphStatementSpec.scala index 69c6bd1..63371d4 100644 --- a/src/test/scala/com/datastax/gatling/plugin/DseGraphStatementSpec.scala +++ b/src/test/scala/com/datastax/gatling/plugin/DseGraphStatementSpec.scala @@ -1,24 +1,23 @@ package com.datastax.gatling.plugin -import com.datastax.driver.dse.DseSession -import com.datastax.driver.dse.graph.SimpleGraphStatement -import com.datastax.dse.graph.api.DseGraph +import com.datastax.dse.driver.api.core.graph.{FluentGraphStatement, ScriptGraphStatement} +import com.datastax.dse.driver.api.core.graph.DseGraph.g import com.datastax.gatling.plugin.base.BaseSpec import com.datastax.gatling.plugin.model.{GraphBoundStatement, GraphFluentStatement, GraphStringStatement} +import com.datastax.oss.driver.api.core.CqlSession import io.gatling.commons.validation.{Failure, Success} import io.gatling.core.session.Session import io.gatling.core.session.el.ElCompiler import org.easymock.EasyMock.reset - class DseGraphStatementSpec extends BaseSpec { - val mockDseSession = mock[DseSession] + val mockCqlSession = mock[CqlSession] val validGatlingSession = new Session("name", 1, Map("test" -> "5")) val invalidGatlingSession = new Session("name", 1, Map("buzz" -> Map("test" -> "this"))) before { - reset(mockDseSession) + reset(mockCqlSession) } @@ -42,8 +41,7 @@ class DseGraphStatementSpec extends BaseSpec { describe("FluentStatement") { - val g = DseGraph.traversal(mockDseSession) - val gStatement = DseGraph.statementFromTraversal(g.V().limit(5)) + val gStatement = FluentGraphStatement.newInstance(g.V().limit(5)) val target = GraphFluentStatement(gStatement) it("should correctly return StringStatement for a valid expression") { @@ -54,7 +52,7 @@ class DseGraphStatementSpec extends BaseSpec { describe("GraphBoundStatement") { - val graphStatement = new SimpleGraphStatement("g.addV(label, vertexLabel).property('type', myType)") + val graphStatement = ScriptGraphStatement.builder("g.addV(label, vertexLabel).property('type', myType)") val target = GraphBoundStatement(graphStatement, Map("test" -> "type")) it("should suceeed with a valid session") { diff --git a/src/test/scala/com/datastax/gatling/plugin/base/BaseCassandraServerSpec.scala b/src/test/scala/com/datastax/gatling/plugin/base/BaseCassandraServerSpec.scala index aee98db..7bcc96e 100644 --- a/src/test/scala/com/datastax/gatling/plugin/base/BaseCassandraServerSpec.scala +++ b/src/test/scala/com/datastax/gatling/plugin/base/BaseCassandraServerSpec.scala @@ -1,15 +1,16 @@ package com.datastax.gatling.plugin.base import java.nio.file.Files +import java.util.concurrent.atomic.AtomicBoolean -import com.datastax.driver.dse.DseSession +import com.datastax.oss.driver.api.core.CqlSession import org.cassandraunit.utils.EmbeddedCassandraServerHelper /** * Used for Specs that require a running Cassandra instance to run */ class BaseCassandraServerSpec extends BaseSpec { - protected val dseSession: DseSession = BaseCassandraServerSpec.dseSession + protected val cqlSession: CqlSession = BaseCassandraServerSpec.dseSession protected def cleanCassandra(keyspace: String = ""): Unit = { if (keyspace.isEmpty) { @@ -20,14 +21,14 @@ class BaseCassandraServerSpec extends BaseSpec { } protected def createKeyspace(keyspace: String): Boolean = { - dseSession.execute( + cqlSession.execute( s"CREATE KEYSPACE IF NOT EXISTS $keyspace WITH replication = " + "{ 'class' : 'SimpleStrategy', 'replication_factor': '1'}") .wasApplied() } protected def createTable(keyspace: String, name: String, columns: String): Boolean = { - dseSession.execute( + cqlSession.execute( s""" CREATE TABLE IF NOT EXISTS $keyspace.$name ( $columns @@ -36,7 +37,7 @@ class BaseCassandraServerSpec extends BaseSpec { protected def createType(keyspace: String, name: String, columns: String): Boolean = { - dseSession.execute( + cqlSession.execute( s""" CREATE TYPE IF NOT EXISTS $keyspace.$name ( $columns @@ -45,12 +46,11 @@ class BaseCassandraServerSpec extends BaseSpec { } object BaseCassandraServerSpec { - if (EmbeddedCassandraServerHelper.getSession == null) { + EmbeddedCassandraServerHelper.startEmbeddedCassandra( "cassandra.yaml", Files.createTempDirectory("gatling-dse-plugin.").toString, 30000L) - } - private val dseSession: DseSession = GatlingDseSession.getSession + private val dseSession: CqlSession = GatlingCqlSession.getSession } diff --git a/src/test/scala/com/datastax/gatling/plugin/base/BaseCqlSimulation.scala b/src/test/scala/com/datastax/gatling/plugin/base/BaseCqlSimulation.scala index bd6b95e..dfe9d9d 100644 --- a/src/test/scala/com/datastax/gatling/plugin/base/BaseCqlSimulation.scala +++ b/src/test/scala/com/datastax/gatling/plugin/base/BaseCqlSimulation.scala @@ -6,7 +6,7 @@ abstract class BaseCqlSimulation extends Simulation { val testKeyspace = "gatling_cql_unittests" - val session = GatlingDseSession.createDseSession() + val session = GatlingCqlSession.createCqlSession() def createTestKeyspace = { session.execute( diff --git a/src/test/scala/com/datastax/gatling/plugin/base/BaseGraphSimulation.scala b/src/test/scala/com/datastax/gatling/plugin/base/BaseGraphSimulation.scala index fd0563d..0891137 100644 --- a/src/test/scala/com/datastax/gatling/plugin/base/BaseGraphSimulation.scala +++ b/src/test/scala/com/datastax/gatling/plugin/base/BaseGraphSimulation.scala @@ -6,7 +6,7 @@ abstract class BaseGraphSimulation extends Simulation { val testKeyspace = "gatling_cql_unittests" - val session = GatlingDseSession.createDseSession("10.10.10.2", 9042) + val session = GatlingCqlSession.createCqlSession("10.10.10.2", 9042) def createTestKeyspace = { session.execute( diff --git a/src/test/scala/com/datastax/gatling/plugin/base/GatlingCqlSession.scala b/src/test/scala/com/datastax/gatling/plugin/base/GatlingCqlSession.scala new file mode 100644 index 0000000..93c8c7b --- /dev/null +++ b/src/test/scala/com/datastax/gatling/plugin/base/GatlingCqlSession.scala @@ -0,0 +1,58 @@ +package com.datastax.gatling.plugin.base + +import java.net.InetSocketAddress +import java.util.concurrent.atomic.AtomicReference + +import com.datastax.oss.driver.api.core.CqlSession +import org.cassandraunit.utils.EmbeddedCassandraServerHelper + +trait GatlingCqlSession { + + // TODO: Probably should eventually be an actor, using AtomicRef for now to cheat + private val session: AtomicReference[CqlSession] = new AtomicReference[CqlSession](null) + + /** + * Create new Dse Session to either the embedded C* instance or a remote instance + * + * Note: This includes a hack to get around issue with DseCluster builder and Scala + * + * @param cassandraHost Cassandra Server IP + * @param cassandraPort Cassandra Port, default will use Embedded Cassandra's port + * @return + */ + def createCqlSession(cassandraHost: String = "127.0.0.1", + cassandraPort: Int = EmbeddedCassandraServerHelper.getNativeTransportPort, + localDc:String = "datacenter1"): CqlSession = { + + session.updateAndGet((v) => { + if (v != null) { + v + } + else { + val addr = new InetSocketAddress(cassandraHost, cassandraPort) + CqlSession.builder().addContactPoint(addr).withLocalDatacenter(localDc).build() + } + }) + } + + + /** + * Get the session of the current DSE Session + * + * @return + */ + def getSession: CqlSession = { + createCqlSession() + } + + + /** + * Close the current session + */ + def closeSession(): Unit = { + session.getAndSet(null).close() + } + +} + +object GatlingCqlSession extends GatlingCqlSession diff --git a/src/test/scala/com/datastax/gatling/plugin/base/GatlingDseSession.scala b/src/test/scala/com/datastax/gatling/plugin/base/GatlingDseSession.scala deleted file mode 100644 index 8af4d84..0000000 --- a/src/test/scala/com/datastax/gatling/plugin/base/GatlingDseSession.scala +++ /dev/null @@ -1,67 +0,0 @@ -package com.datastax.gatling.plugin.base - -import com.datastax.driver.dse.{DseCluster, DseSession} -import org.cassandraunit.utils.EmbeddedCassandraServerHelper - -trait GatlingDseSession { - - private var dseCluster: DseCluster = _ - - private var session: DseSession = _ - - /** - * Create new Dse Session to either the embedded C* instance or a remote instance - * - * Note: This includes a hack to get around issue with DseCluster builder and Scala - * - * @param cassandraHost Cassandra Server IP - * @param cassandraPort Cassandra Port, default will use Embedded Cassandra's port - * @return - */ - def createDseSession(cassandraHost: String = "127.0.0.1", cassandraPort: Int = -1): DseSession = { - - if (session != null) { - return session - } - - var cPort = cassandraPort - if (cPort == -1) { - cPort = EmbeddedCassandraServerHelper.getNativeTransportPort - } - - dseCluster = - try { - DseCluster.builder().addContactPoint(cassandraHost).withPort(cPort).build() - } - catch { - case _: Exception => DseCluster.builder().addContactPoint(cassandraHost).withPort(cPort).build() - } - - session = dseCluster.connect() - session - } - - - /** - * Get the session of the current DSE Session - * - * @return - */ - def getSession: DseSession = { - if (session == null) { - createDseSession() - } - session - } - - - /** - * Close the current session - */ - def closeSession(): Unit = { - session.close() - } - -} - -object GatlingDseSession extends GatlingDseSession diff --git a/src/test/scala/com/datastax/gatling/plugin/model/CqlStatementBuildersSpec.scala b/src/test/scala/com/datastax/gatling/plugin/model/CqlStatementBuildersSpec.scala index de40a8f..b3b6189 100644 --- a/src/test/scala/com/datastax/gatling/plugin/model/CqlStatementBuildersSpec.scala +++ b/src/test/scala/com/datastax/gatling/plugin/model/CqlStatementBuildersSpec.scala @@ -1,11 +1,14 @@ package com.datastax.gatling.plugin.model -import com.datastax.driver.core.ConsistencyLevel.{EACH_QUORUM, THREE} -import com.datastax.driver.core._ -import com.datastax.driver.core.policies.FallthroughRetryPolicy +import java.nio.ByteBuffer +import java.time.Duration + import com.datastax.gatling.plugin.DsePredef._ -import com.datastax.gatling.plugin.checks.{CqlChecks, DseCqlCheck, GenericCheck, GenericChecks} -import io.gatling.core.Predef._ +import com.datastax.gatling.plugin.checks.CqlChecks +import com.datastax.oss.driver.api.core.ConsistencyLevel.{EACH_QUORUM, THREE} +import com.datastax.oss.driver.api.core.cql.{AsyncResultSet, SimpleStatement, SimpleStatementBuilder} +import com.datastax.oss.driver.api.core.metadata.Node +import com.datastax.oss.driver.api.core.metadata.token.Token import io.gatling.core.session.{ExpressionSuccessWrapper, Session} import org.scalatest.easymock.EasyMockSugar import org.scalatest.{FlatSpec, Matchers} @@ -15,63 +18,77 @@ class CqlStatementBuildersSpec extends FlatSpec with Matchers with EasyMockSugar it should "build statements from a CQL String" in { - val statementAttributes: DseCqlAttributes = cql("the-tag") - .executeCql("SELECT foo FROM bar.baz LIMIT 1") + val statementAttributes: DseCqlAttributes[SimpleStatement,SimpleStatementBuilder] = cql("the-tag") + .executeStatement("SELECT foo FROM bar.baz LIMIT 1") .build() .dseAttributes - val statement: SimpleStatement = statementAttributes.statement + val statement = statementAttributes.statement .buildFromSession(Session("the-tag", 42)) - .get.asInstanceOf[SimpleStatement] + .get.build statementAttributes.cqlStatements should contain only "SELECT foo FROM bar.baz LIMIT 1" - statement.getQueryString() should be("SELECT foo FROM bar.baz LIMIT 1") + statement.getQuery should be("SELECT foo FROM bar.baz LIMIT 1") } it should "forward all attributs to DseCqlAttributes" in { - val pagingState = mock[PagingState] - val genericCheck = GenericCheck(GenericChecks.exhausted.is(true)) - val cqlCheck = DseCqlCheck(CqlChecks.oneRow.is(mock[Row].expressionSuccess)) - val statementAttributes: DseCqlAttributes = cql("the-session-tag") - .executeCql("FOO") + val node = mock[Node] + val userOrRole = "userOrRole" + val customPayloadKey = "key" + val customPayloadVal = mock[ByteBuffer] + val pagingState = mock[ByteBuffer] + val queryTimestamp = 123L + val routingKey = mock[ByteBuffer] + val routingKeyspace = "some_keyspace" + val routingToken = mock[Token] + val timeout = Duration.ofHours(1) + val cqlCheck = CqlChecks.resultSet.find.is(mock[AsyncResultSet].expressionSuccess).build + val statementAttributes: DseCqlAttributes[_,_] = cql("the-session-tag") + .executeStatement("FOO") .withConsistencyLevel(EACH_QUORUM) - .withUserOrRole("User or role") - .withDefaultTimestamp(-76) + .addCustomPayload(customPayloadKey, customPayloadVal) .withIdempotency() - .withReadTimeout(99) - .withSerialConsistencyLevel(THREE) - .withRetryPolicy(FallthroughRetryPolicy.INSTANCE) - .withFetchSize(3) + .withNode(node) + .executeAs(userOrRole) .withTracingEnabled() + .withPageSize(3) .withPagingState(pagingState) - .check(genericCheck) + .withQueryTimestamp(queryTimestamp) + .withRoutingKey(routingKey) + .withRoutingKeyspace(routingKeyspace) + .withRoutingToken(routingToken) + .withSerialConsistencyLevel(THREE) + .withTimeout(timeout) .check(cqlCheck) .build() .dseAttributes statementAttributes.tag should be("the-session-tag") statementAttributes.cl should be(Some(EACH_QUORUM)) statementAttributes.cqlChecks should contain only cqlCheck - statementAttributes.genericChecks should contain only genericCheck - statementAttributes.userOrRole should be(Some("User or role")) - statementAttributes.readTimeout should be(Some(99)) + statementAttributes.customPayload should be(Some(Map(customPayloadKey -> customPayloadVal))) statementAttributes.idempotent should be(Some(true)) - statementAttributes.defaultTimestamp should be(Some(-76)) + statementAttributes.node should be(Some(node)) + statementAttributes.userOrRole should be(Some(userOrRole)) statementAttributes.enableTrace should be(Some(true)) - statementAttributes.serialCl should be(Some(THREE)) - statementAttributes.fetchSize should be(Some(3)) - statementAttributes.retryPolicy should be(Some(FallthroughRetryPolicy.INSTANCE)) + statementAttributes.pageSize should be(Some(3)) statementAttributes.pagingState should be(Some(pagingState)) + statementAttributes.queryTimestamp should be(Some(queryTimestamp)) + statementAttributes.routingKey should be(Some(routingKey)) + statementAttributes.routingKeyspace should be(Some(routingKeyspace)) + statementAttributes.routingToken should be(Some(routingToken)) + statementAttributes.serialCl should be(Some(THREE)) + statementAttributes.timeout should be(Some(timeout)) statementAttributes.cqlStatements should contain only "FOO" } it should "build statements from a SimpleStatement" in { - val statementAttributes: DseCqlAttributes = cql("the-tag") - .executeStatement(new SimpleStatement("Some CQL")) + val statementAttributes: DseCqlAttributes[SimpleStatement,SimpleStatementBuilder] = cql("the-tag") + .executeStatement(SimpleStatement.newInstance("Some CQL")) .build() .dseAttributes - val statement: SimpleStatement = statementAttributes.statement + val statement = statementAttributes.statement .buildFromSession(Session("the-tag", 42)) - .get.asInstanceOf[SimpleStatement] + .get.build statementAttributes.cqlStatements should contain only "Some CQL" - statement.getQueryString() should be("Some CQL") + statement.getQuery should be("Some CQL") } // it should "build statements from a PreparedStatement" in { diff --git a/src/test/scala/com/datastax/gatling/plugin/request/CqlRequestActionSpec.scala b/src/test/scala/com/datastax/gatling/plugin/request/CqlRequestActionSpec.scala index 553e778..73a7e74 100644 --- a/src/test/scala/com/datastax/gatling/plugin/request/CqlRequestActionSpec.scala +++ b/src/test/scala/com/datastax/gatling/plugin/request/CqlRequestActionSpec.scala @@ -1,21 +1,23 @@ package com.datastax.gatling.plugin.request -import java.util.concurrent.{Executor, TimeUnit} +import java.nio.ByteBuffer +import java.time.Duration +import java.util.concurrent.{CompletableFuture, CompletionStage} import akka.actor.{ActorSystem, Props} import akka.testkit.TestKitBase import ch.qos.logback.classic.{Level, Logger} import ch.qos.logback.classic.spi.ILoggingEvent import ch.qos.logback.core.read.ListAppender -import com.datastax.driver.core._ -import com.datastax.driver.core.policies.FallthroughRetryPolicy -import com.datastax.driver.dse.DseSession import com.datastax.gatling.plugin.base.BaseSpec import com.datastax.gatling.plugin.metrics.NoopMetricsLogger import com.datastax.gatling.plugin.utils.GatlingTimingSource import com.datastax.gatling.plugin.DseProtocol import com.datastax.gatling.plugin.model.{DseCqlAttributes, DseCqlStatement} -import com.google.common.util.concurrent.{Futures, ListenableFuture} +import com.datastax.oss.driver.api.core.{ConsistencyLevel, CqlSession} +import com.datastax.oss.driver.api.core.cql.{AsyncResultSet, SimpleStatement, SimpleStatementBuilder} +import com.datastax.oss.driver.api.core.metadata.Node +import com.datastax.oss.driver.api.core.metadata.token.Token import io.gatling.commons.validation.SuccessWrapper import io.gatling.core.action.Exit import io.gatling.core.config.GatlingConfiguration @@ -28,39 +30,36 @@ import org.slf4j.LoggerFactory class CqlRequestActionSpec extends BaseSpec with TestKitBase { implicit lazy val system:ActorSystem = ActorSystem() val gatlingTestConfig: GatlingConfiguration = GatlingConfiguration.loadForTest() - val dseSession: DseSession = mock[DseSession] - val dseCqlStatement: DseCqlStatement = mock[DseCqlStatement] - val pagingState: PagingState = mock[PagingState] + val cqlSession: CqlSession = mock[CqlSession] + val dseCqlStatement: DseCqlStatement[SimpleStatement,SimpleStatementBuilder] = mock[DseCqlStatement[SimpleStatement,SimpleStatementBuilder]] + val node:Node = mock[Node] + val pageSize = 3 + val pagingState: ByteBuffer = mock[ByteBuffer] + val queryTimestamp = 123L + val routingKey:ByteBuffer = mock[ByteBuffer] + val routingKeyspace = "some_keyspace" + val routingToken:Token = mock[Token] + val timeout:Duration = Duration.ofHours(1) val statsEngine: StatsEngine = mock[StatsEngine] val gatlingSession = Session("scenario", 1) - def getTarget(dseAttributes: DseCqlAttributes): CqlRequestAction = { + def getTarget(dseAttributes: DseCqlAttributes[SimpleStatement,SimpleStatementBuilder]): CqlRequestAction[SimpleStatement,SimpleStatementBuilder] = { new CqlRequestAction( "sample-dse-request", new Exit(system.actorOf(Props[DseRequestActor]), statsEngine), system, statsEngine, - DseProtocol(dseSession), + DseProtocol(cqlSession), dseAttributes, NoopMetricsLogger(), executorServiceForTests(), GatlingTimingSource()) } - private def mockResultSetFuture(): ResultSetFuture = new ResultSetFuture { - val delegate: ListenableFuture[ResultSet] = Futures.immediateFuture(mock[ResultSet]) - override def cancel(b: Boolean): Boolean = false - override def getUninterruptibly: ResultSet = delegate.get() - override def getUninterruptibly(duration: Long, timeUnit: TimeUnit): ResultSet = delegate.get(duration, timeUnit) - override def addListener(listener: Runnable, executor: Executor): Unit = delegate.addListener(listener, executor) - override def isCancelled: Boolean = delegate.isCancelled - override def isDone: Boolean = delegate.isDone - override def get(): ResultSet = delegate.get() - override def get(timeout: Long, unit: TimeUnit): ResultSet = delegate.get(timeout, unit) - } + private def mockAsyncResultSetFuture(): CompletionStage[AsyncResultSet] = CompletableFuture.completedFuture(mock[AsyncResultSet]) before { - reset(dseCqlStatement, dseSession, pagingState, statsEngine) + reset(dseCqlStatement, cqlSession, pagingState, statsEngine) } override protected def afterAll(): Unit = { @@ -68,19 +67,19 @@ class CqlRequestActionSpec extends BaseSpec with TestKitBase { } describe("CQL") { - val statementCapture = EasyMock.newCapture[RegularStatement] + val statementCapture = EasyMock.newCapture[SimpleStatement] it("should have default CQL attributes set if nothing passed") { val cqlAttributesWithDefaults = DseCqlAttributes( "test", dseCqlStatement) expecting { - dseCqlStatement.buildFromSession(gatlingSession).andReturn(new SimpleStatement("select * from test") + dseCqlStatement.buildFromSession(gatlingSession) andReturn(SimpleStatement.builder("select * from test") .success) - dseSession.executeAsync(capture(statementCapture)) andReturn mockResultSetFuture() + cqlSession.executeAsync(capture(statementCapture)) andReturn mockAsyncResultSetFuture() } - whenExecuting(dseCqlStatement, dseSession) { + whenExecuting(dseCqlStatement, cqlSession) { getTarget(cqlAttributesWithDefaults).sendQuery(gatlingSession) } @@ -88,50 +87,53 @@ class CqlRequestActionSpec extends BaseSpec with TestKitBase { capturedStatement shouldBe a[SimpleStatement] capturedStatement.getConsistencyLevel shouldBe null capturedStatement.getSerialConsistencyLevel shouldBe null - capturedStatement.getFetchSize shouldBe 0 - capturedStatement.getDefaultTimestamp shouldBe -9223372036854775808L - capturedStatement.getReadTimeoutMillis shouldBe -2147483648 - capturedStatement.getRetryPolicy shouldBe null + capturedStatement.getPageSize should be <= 0 capturedStatement.isIdempotent shouldBe null capturedStatement.isTracing shouldBe false - capturedStatement.getQueryString should be("select * from test") + capturedStatement.getQuery should be("select * from test") } it("should enable all the CQL Attributes in DseAttributes") { - val cqlAttributes = DseCqlAttributes( + val cqlAttributes = new DseCqlAttributes[SimpleStatement,SimpleStatementBuilder]( "test", dseCqlStatement, cl = Some(ConsistencyLevel.ANY), - userOrRole = Some("test_user"), - readTimeout = Some(12), - defaultTimestamp = Some(1498167845000L), idempotent = Some(true), - fetchSize = Some(50), + node = Some(node), + enableTrace = Some(true), + pagingState = Some(pagingState), + pageSize = Some(pageSize), + queryTimestamp = Some(queryTimestamp), + routingKey = Some(routingKey), + routingKeyspace = Some(routingKeyspace), + routingToken = Some(routingToken), serialCl = Some(ConsistencyLevel.LOCAL_SERIAL), - retryPolicy = Some(FallthroughRetryPolicy.INSTANCE), - enableTrace = Some(true)) + timeout = Some(timeout)) expecting { - dseCqlStatement.buildFromSession(gatlingSession).andReturn(new SimpleStatement("select * from test") + dseCqlStatement.buildFromSession(gatlingSession) andReturn(SimpleStatement.builder("select * from test") .success) - dseSession.executeAsync(capture(statementCapture)) andReturn mockResultSetFuture() + cqlSession.executeAsync(capture(statementCapture)) andReturn mockAsyncResultSetFuture() } - whenExecuting(dseCqlStatement, dseSession) { + whenExecuting(dseCqlStatement, cqlSession) { getTarget(cqlAttributes).sendQuery(gatlingSession) } val capturedStatement = statementCapture.getValue capturedStatement shouldBe a[SimpleStatement] capturedStatement.getConsistencyLevel shouldBe ConsistencyLevel.ANY - capturedStatement.getDefaultTimestamp shouldBe 1498167845000L - capturedStatement.getReadTimeoutMillis shouldBe 12 capturedStatement.isIdempotent shouldBe true - capturedStatement.getFetchSize shouldBe 50 - capturedStatement.getSerialConsistencyLevel shouldBe ConsistencyLevel.LOCAL_SERIAL - capturedStatement.getRetryPolicy shouldBe FallthroughRetryPolicy.INSTANCE - capturedStatement.getQueryString should be("select * from test") + capturedStatement.getNode shouldBe node capturedStatement.isTracing shouldBe true + capturedStatement.getPageSize shouldBe pageSize + capturedStatement.getPagingState shouldBe pagingState + capturedStatement.getQueryTimestamp shouldBe queryTimestamp + capturedStatement.getRoutingKey shouldBe routingKey + capturedStatement.getRoutingKeyspace.toString shouldBe routingKeyspace + capturedStatement.getRoutingToken shouldBe routingToken + capturedStatement.getSerialConsistencyLevel shouldBe ConsistencyLevel.LOCAL_SERIAL + capturedStatement.getTimeout shouldBe timeout } it("should log exceptions encountered") { @@ -147,12 +149,12 @@ class CqlRequestActionSpec extends BaseSpec with TestKitBase { val cqlRequestAction = getTarget(cqlAttributesWithDefaults) - val classLogger = LoggerFactory.getLogger(classOf[CqlRequestAction]).asInstanceOf[Logger] + val classLogger = LoggerFactory.getLogger(classOf[CqlRequestAction[SimpleStatement,SimpleStatementBuilder]]).asInstanceOf[Logger] val listAppender: ListAppender[ILoggingEvent] = new ListAppender[ILoggingEvent] listAppender.start() classLogger.addAppender(listAppender) - whenExecuting(dseCqlStatement, dseSession) { + whenExecuting(dseCqlStatement, cqlSession) { cqlRequestAction.sendQuery(gatlingSession) } diff --git a/src/test/scala/com/datastax/gatling/plugin/request/GraphRequestActionSpec.scala b/src/test/scala/com/datastax/gatling/plugin/request/GraphRequestActionSpec.scala index 2c61336..5d02568 100644 --- a/src/test/scala/com/datastax/gatling/plugin/request/GraphRequestActionSpec.scala +++ b/src/test/scala/com/datastax/gatling/plugin/request/GraphRequestActionSpec.scala @@ -1,18 +1,20 @@ package com.datastax.gatling.plugin.request -import java.util.concurrent.{Executor, TimeUnit} +import java.nio.ByteBuffer +import java.time.Duration +import java.util.concurrent.{CompletableFuture, CompletionStage} import akka.actor.{ActorSystem, Props} import akka.testkit.TestKitBase -import com.datastax.driver.core._ -import com.datastax.driver.dse.DseSession -import com.datastax.driver.dse.graph.{GraphResultSet, RegularGraphStatement, SimpleGraphStatement} +import com.datastax.dse.driver.api.core.DseSession +import com.datastax.dse.driver.api.core.graph.{AsyncGraphResultSet, ScriptGraphStatement, ScriptGraphStatementBuilder} +import com.datastax.gatling.plugin.DseProtocol import com.datastax.gatling.plugin.base.BaseSpec import com.datastax.gatling.plugin.metrics.NoopMetricsLogger import com.datastax.gatling.plugin.utils.GatlingTimingSource -import com.datastax.gatling.plugin.DseProtocol -import com.datastax.gatling.plugin.model.{DseGraphStatement, DseGraphAttributes} -import com.google.common.util.concurrent.{Futures, ListenableFuture} +import com.datastax.gatling.plugin.model.{DseGraphAttributes, DseGraphStatement} +import com.datastax.oss.driver.api.core.ConsistencyLevel +import com.datastax.oss.driver.api.core.metadata.Node import io.gatling.commons.validation.SuccessWrapper import io.gatling.core.action.Exit import io.gatling.core.config.GatlingConfiguration @@ -24,108 +26,82 @@ import org.easymock.EasyMock._ class GraphRequestActionSpec extends BaseSpec with TestKitBase { implicit lazy val system = ActorSystem() val gatlingTestConfig = GatlingConfiguration.loadForTest() - val dseSession = mock[DseSession] - val dseGraphStatement = mock[DseGraphStatement] - val pagingState = mock[PagingState] + val cqlSession = mock[DseSession] + val dseGraphStatement = mock[DseGraphStatement[ScriptGraphStatement,ScriptGraphStatementBuilder]] + val node:Node = mock[Node] + val readConsistencyLevel = ConsistencyLevel.LOCAL_QUORUM + val subProtocol = "graph-binary-3.0" + val timeout:Duration = Duration.ofHours(1) + val timestamp = 123L + val traversalSource = "g.V()" + val writeConsistencyLevel = ConsistencyLevel.LOCAL_QUORUM + + val pagingState:ByteBuffer = mock[ByteBuffer] val statsEngine: StatsEngine = mock[StatsEngine] val gatlingSession = Session("scenario", 1) - def getTarget(dseAttributes: DseGraphAttributes): GraphRequestAction = { + def getTarget(dseAttributes: DseGraphAttributes[ScriptGraphStatement,ScriptGraphStatementBuilder]): + GraphRequestAction[ScriptGraphStatement,ScriptGraphStatementBuilder] = { new GraphRequestAction( "sample-dse-request", new Exit(system.actorOf(Props[DseRequestActor]), statsEngine), system, statsEngine, - DseProtocol(dseSession), + DseProtocol(cqlSession), dseAttributes, NoopMetricsLogger(), executorServiceForTests(), GatlingTimingSource()) } - private def mockResultSetFuture(): ResultSetFuture = new ResultSetFuture { - val delegate: ListenableFuture[ResultSet] = Futures.immediateFuture(mock[ResultSet]) - override def cancel(b: Boolean): Boolean = false - override def getUninterruptibly: ResultSet = delegate.get() - override def getUninterruptibly(duration: Long, timeUnit: TimeUnit): ResultSet = delegate.get(duration, timeUnit) - override def addListener(listener: Runnable, executor: Executor): Unit = delegate.addListener(listener, executor) - override def isCancelled: Boolean = delegate.isCancelled - override def isDone: Boolean = delegate.isDone - override def get(): ResultSet = delegate.get() - override def get(timeout: Long, unit: TimeUnit): ResultSet = delegate.get(timeout, unit) - } + private def mockAsyncGraphResultSetFuture(): CompletionStage[AsyncGraphResultSet] = + CompletableFuture.completedFuture(mock[AsyncGraphResultSet]) before { - reset(dseGraphStatement, dseSession, pagingState, statsEngine) + reset(dseGraphStatement, cqlSession, pagingState, statsEngine) } override protected def afterAll(): Unit = { shutdown(system) } - + describe("Graph") { - val statementCapture = EasyMock.newCapture[RegularGraphStatement] + val statementCapture = EasyMock.newCapture[ScriptGraphStatement] it("should enable all the Graph Attributes in DseAttributes") { - val graphAttributes = DseGraphAttributes("test", dseGraphStatement, + val graphAttributes = new DseGraphAttributes("test", dseGraphStatement, cl = Some(ConsistencyLevel.ANY), - userOrRole = Some("test_user"), - readTimeout = Some(12), - defaultTimestamp = Some(1498167845000L), idempotent = Some(true), - readCL = Some(ConsistencyLevel.LOCAL_QUORUM), - writeCL = Some(ConsistencyLevel.LOCAL_QUORUM), + node = Some(node), graphName = Some("MyGraph"), - graphLanguage = Some("english"), - graphSource = Some("mysource"), - graphInternalOptions = Some(Seq(("get", "this"))), - graphTransformResults = None + readCL = Some(readConsistencyLevel), + subProtocol = Some(subProtocol), + timeout = Some(timeout), + timestamp = Some(timestamp), + traversalSource = Some(traversalSource), + writeCL = Some(writeConsistencyLevel) ) expecting { - dseGraphStatement.buildFromSession(gatlingSession).andReturn(new SimpleGraphStatement("g.V()").success) - dseSession.executeGraphAsync(capture(statementCapture)) andReturn Futures.immediateFuture(mock[GraphResultSet]) + dseGraphStatement.buildFromSession(gatlingSession) andReturn(ScriptGraphStatement.builder("g.V()").success) + cqlSession.executeAsync(capture(statementCapture)) andReturn mockAsyncGraphResultSetFuture() } - whenExecuting(dseGraphStatement, dseSession) { + whenExecuting(dseGraphStatement, cqlSession) { getTarget(graphAttributes).sendQuery(gatlingSession) } val capturedStatement = statementCapture.getValue - capturedStatement shouldBe a[SimpleGraphStatement] + capturedStatement shouldBe a[ScriptGraphStatement] capturedStatement.getConsistencyLevel shouldBe ConsistencyLevel.ANY - capturedStatement.getDefaultTimestamp shouldBe 1498167845000L - capturedStatement.getReadTimeoutMillis shouldBe 12 capturedStatement.isIdempotent shouldBe true + capturedStatement.getNode shouldBe node capturedStatement.getGraphName shouldBe "MyGraph" - capturedStatement.getGraphLanguage shouldBe "english" - capturedStatement.getGraphReadConsistencyLevel shouldBe ConsistencyLevel.LOCAL_QUORUM - capturedStatement.getGraphWriteConsistencyLevel shouldBe ConsistencyLevel.LOCAL_QUORUM - capturedStatement.getGraphSource shouldBe "mysource" - capturedStatement.isSystemQuery shouldBe false - capturedStatement.getGraphInternalOption("get") shouldBe "this" - } - - it("should override the graph name if system") { - val graphAttributes = DseGraphAttributes( - "test", - dseGraphStatement, - graphName = Some("MyGraph"), - isSystemQuery = Some(true), - ) - - expecting { - dseGraphStatement.buildFromSession(gatlingSession).andReturn(new SimpleGraphStatement("g.V()").success) - dseSession.executeGraphAsync(capture(statementCapture)) andReturn Futures.immediateFuture(mock[GraphResultSet]) - } - - whenExecuting(dseGraphStatement, dseSession) { - getTarget(graphAttributes).sendQuery(gatlingSession) - } - - val capturedStatement = statementCapture.getValue - capturedStatement shouldBe a[SimpleGraphStatement] - capturedStatement.getGraphName shouldBe null - capturedStatement.isSystemQuery shouldBe true + capturedStatement.getReadConsistencyLevel shouldBe readConsistencyLevel + capturedStatement.getSubProtocol shouldBe subProtocol + capturedStatement.getTimeout shouldBe timeout + capturedStatement.getTimestamp shouldBe timestamp + capturedStatement.getTraversalSource shouldBe traversalSource + capturedStatement.getWriteConsistencyLevel shouldBe writeConsistencyLevel } } } diff --git a/src/test/scala/com/datastax/gatling/plugin/simulations/cql/BatchStatementSimulation.scala b/src/test/scala/com/datastax/gatling/plugin/simulations/cql/BatchStatementSimulation.scala index e4a7bb9..c9a1d67 100644 --- a/src/test/scala/com/datastax/gatling/plugin/simulations/cql/BatchStatementSimulation.scala +++ b/src/test/scala/com/datastax/gatling/plugin/simulations/cql/BatchStatementSimulation.scala @@ -1,8 +1,8 @@ package com.datastax.gatling.plugin.simulations.cql -import com.datastax.driver.core.ResultSet import com.datastax.gatling.plugin.DsePredef._ import com.datastax.gatling.plugin.base.BaseCqlSimulation +import com.datastax.oss.driver.api.core.`type`.UserDefinedType import io.gatling.core.Predef._ import scala.concurrent.duration.DurationInt @@ -17,7 +17,7 @@ class BatchStatementSimulation extends BaseCqlSimulation { val cqlConfig = cql.session(session) //Initialize Gatling DSL with your session - val addressType = session.getCluster.getMetadata.getKeyspace(testKeyspace).getUserType("fullname") + val addressType:UserDefinedType = session.getMetadata.getKeyspace(testKeyspace).flatMap(_.getUserDefinedType(udt_name)).get val simpleId = 1 val preparedId = 2 @@ -42,8 +42,8 @@ class BatchStatementSimulation extends BaseCqlSimulation { val scn = scenario("BatchStatement") .feed(preparedFeed) .exec(insertPreparedCql - .check(exhausted is true) - .check(rowCount is 0) // "normal" INSERTs don't return anything + .check(resultSet.transform(_.hasMorePages) is false) + .check(resultSet.transform(_.remaining) is 0) // "normal" INSERTs don't return anything ) setUp( @@ -53,7 +53,7 @@ class BatchStatementSimulation extends BaseCqlSimulation { ) - def createTable: ResultSet = { + private def createTable = { val udt = s""" diff --git a/src/test/scala/com/datastax/gatling/plugin/simulations/cql/BoundCqlTypesSimulation.scala b/src/test/scala/com/datastax/gatling/plugin/simulations/cql/BoundCqlTypesSimulation.scala index 89dcf8a..c742a38 100644 --- a/src/test/scala/com/datastax/gatling/plugin/simulations/cql/BoundCqlTypesSimulation.scala +++ b/src/test/scala/com/datastax/gatling/plugin/simulations/cql/BoundCqlTypesSimulation.scala @@ -2,11 +2,14 @@ package com.datastax.gatling.plugin.simulations.cql import java.nio.ByteBuffer import java.sql.Timestamp +import java.time.Instant -import com.datastax.driver.core.utils.UUIDs -import com.datastax.driver.core.{DataType, ResultSet} import com.datastax.gatling.plugin.DsePredef._ import com.datastax.gatling.plugin.base.BaseCqlSimulation +import com.datastax.oss.driver.api.core.`type`.{DataTypes, UserDefinedType} +import com.datastax.oss.driver.api.core.cql.Row +import com.datastax.oss.driver.api.core.uuid.Uuids +import com.datastax.oss.driver.api.querybuilder.QueryBuilder import io.gatling.core.Predef._ import scala.concurrent.duration.DurationInt @@ -20,16 +23,16 @@ class BoundCqlTypesSimulation extends BaseCqlSimulation { createTable val cqlConfig = cql.session(session) - val udtType = session.getCluster.getMetadata.getKeyspace(testKeyspace).getUserType("fullname") + val addressType:UserDefinedType = session.getMetadata.getKeyspace(testKeyspace).flatMap(_.getUserDefinedType("fullname")).get - val insertFullName = udtType.newValue() + val insertFullName = addressType.newValue() .setString("firstname", "John") .setString("lastname", "Smith") - val tupleType = session.getCluster.getMetadata.newTupleType(DataType.text(), DataType.text()) + val tupleType = DataTypes.tupleOf(DataTypes.TEXT, DataTypes.TEXT) val insertTuple = tupleType.newValue("one", "two") - val uuid = UUIDs.random() + val uuid = Uuids.random() val preparedStatementInsert = s"""INSERT INTO $testKeyspace.$table_name ( @@ -44,7 +47,10 @@ class BoundCqlTypesSimulation extends BaseCqlSimulation { |:tinyint_type, :time_type, :null_type, :udt_type, :tuple_type, :frozen_set_type, :set_string_type |)""".stripMargin - val preparedStatementSelect = s"""SELECT * FROM $testKeyspace.$table_name WHERE uuid_type = ?""" + val preparedStatementSelect = QueryBuilder.selectFrom(testKeyspace, table_name) + .all() + .whereColumn("uuid_type").isEqualTo(QueryBuilder.bindMarker()) + .build() val preparedInsert = session.prepare(preparedStatementInsert) val preparedSelect = session.prepare(preparedStatementSelect) @@ -58,7 +64,7 @@ class BoundCqlTypesSimulation extends BaseCqlSimulation { val preparedFeed = Iterator.continually( Map( "uuid_type" -> uuid, - "timeuuid_type" -> UUIDs.timeBased(), + "timeuuid_type" -> Uuids.timeBased(), "int_type" -> 1, "text_type" -> "text", "float_type" -> 4.50, @@ -110,37 +116,38 @@ class BoundCqlTypesSimulation extends BaseCqlSimulation { ) // End counter details + // A predicate function can be useful when we want to do multiple comparisons on data in a single row + private def preparedCqlPredicate(row:Row):Boolean = + row.getBoolean("boolean_type") && row.isNull("null_type") val scn = scenario("BoundCqlStatement") .feed(preparedFeed) .exec(insertPreparedCql - .check(exhausted is true) - .check(rowCount is 0) // "normal" INSERTs don't return anything + .check(resultSet.transform(_.hasMorePages) is false) + .check(resultSet.transform(_.remaining) is 0) // "normal" INSERTs don't return anything ) .pause(100.millis) .exec(selectPreparedCql .withParams(List("uuid_type")) - .check(rowCount is 1) - .check(columnValue("name") not "") - .check(columnValue("null_type") not "test") - .check(columnValue("boolean_type") is true) + .check(resultSet.transform(_.remaining) is 1) + .check(resultSet.transform(rs => preparedCqlPredicate(rs.one)) is true) ) .pause(100.millis) .feed(counterFeed) .exec(insertCounterPreparedCql - .check(exhausted is true) - .check(rowCount is 0) // "normal" INSERTs don't return anything + .check(resultSet.transform(_.hasMorePages) is false) + .check(resultSet.transform(_.remaining) is 0) // "normal" INSERTs don't return anything ) .pause(100.millis) .exec(selectCounterPreparedCql .withParams(List("uuid_type")) - .check(rowCount is 1) - .check(columnValue("counter_type") is 2) + .check(resultSet.transform(_.remaining) is 1) + .check(resultSet.transform(rs => rs.one().getLong("counter_type")) is 2L) ) .pause(100.millis) @@ -151,8 +158,7 @@ class BoundCqlTypesSimulation extends BaseCqlSimulation { global.failedRequests.count.is(0) ) - - def createTable: ResultSet = { + private def createTable = { val udt = s""" @@ -206,12 +212,11 @@ class BoundCqlTypesSimulation extends BaseCqlSimulation { } - def getRandomEpoch: Timestamp = { - val offset: Long = Timestamp.valueOf("2012-01-01 00:00:00").getTime + def getRandomEpoch: Instant = { + val offset = Timestamp.valueOf("2012-01-01 00:00:00").getTime val end = Timestamp.valueOf("2017-01-01 00:00:00").getTime val diff = end - offset + 1 - val time: Long = (offset + (Math.random() * diff)).toLong - new Timestamp(time) + val time = (offset + (Math.random() * diff)).toLong + Instant.ofEpochMilli(time) } - } diff --git a/src/test/scala/com/datastax/gatling/plugin/simulations/cql/NamedStatementSimulation.scala b/src/test/scala/com/datastax/gatling/plugin/simulations/cql/NamedStatementSimulation.scala index cf88192..9000064 100644 --- a/src/test/scala/com/datastax/gatling/plugin/simulations/cql/NamedStatementSimulation.scala +++ b/src/test/scala/com/datastax/gatling/plugin/simulations/cql/NamedStatementSimulation.scala @@ -2,6 +2,7 @@ package com.datastax.gatling.plugin.simulations.cql import com.datastax.gatling.plugin.DsePredef._ import com.datastax.gatling.plugin.base.BaseCqlSimulation +import com.datastax.oss.driver.api.core.cql.Row import io.gatling.core.Predef._ import scala.concurrent.duration.DurationInt @@ -39,17 +40,20 @@ class NamedStatementSimulation extends BaseCqlSimulation { val selectCql = cql("NamedParam Select Statement") .executeNamed(preparedSelect) + private def selectCqlExtract(row:Row):String = + row.getString("str") + val scn = scenario("NamedStatement") .feed(feeder) .exec(insertCql - .check(exhausted is true) - .check(rowCount is 0) // "normal" INSERTs don't return anything + .check(resultSet.transform(_.hasMorePages) is false) + .check(resultSet.transform(_.remaining) is 0) // "normal" INSERTs don't return anything ) .pause(1.seconds) .exec(selectCql - .check(rowCount is 1) - .check(columnValue("str") is insertStr) + .check(resultSet.transform(_.remaining) is 1) + .check(resultSet.transform(rs => selectCqlExtract(rs.one)) is insertStr) ) diff --git a/src/test/scala/com/datastax/gatling/plugin/simulations/cql/PreparedStatementSimulation.scala b/src/test/scala/com/datastax/gatling/plugin/simulations/cql/PreparedStatementSimulation.scala index f634abd..45103e3 100644 --- a/src/test/scala/com/datastax/gatling/plugin/simulations/cql/PreparedStatementSimulation.scala +++ b/src/test/scala/com/datastax/gatling/plugin/simulations/cql/PreparedStatementSimulation.scala @@ -2,6 +2,8 @@ package com.datastax.gatling.plugin.simulations.cql import com.datastax.gatling.plugin.DsePredef._ import com.datastax.gatling.plugin.base.BaseCqlSimulation +import com.datastax.oss.driver.api.core.cql.Row +import com.datastax.oss.driver.api.querybuilder.QueryBuilder import io.gatling.core.Predef._ import scala.concurrent.duration.DurationInt @@ -19,8 +21,16 @@ class PreparedStatementSimulation extends BaseCqlSimulation { val insertStr = "two" val insertName = "test" - val statementInsert = s"""INSERT INTO $testKeyspace.$table_name (id, str, name) VALUES (?, ?, ?)""" - val statementSelect = s"""SELECT * FROM $testKeyspace.$table_name WHERE id = ? AND str = ?""" + val statementInsert = QueryBuilder.insertInto(testKeyspace, table_name) + .value("id", QueryBuilder.bindMarker()) + .value("str", QueryBuilder.bindMarker()) + .value("name", QueryBuilder.bindMarker()) + .build() + val statementSelect = QueryBuilder.selectFrom(testKeyspace, table_name) + .all() + .whereColumn("id").isEqualTo(QueryBuilder.bindMarker()) + .whereColumn("str").isEqualTo(QueryBuilder.bindMarker()) + .build() val preparedInsert = session.prepare(statementInsert) val preparedSelect = session.prepare(statementSelect) @@ -36,35 +46,37 @@ class PreparedStatementSimulation extends BaseCqlSimulation { ) val insertCql = cql("Insert_Statement") - .executePrepared(preparedInsert) + .executeStatement(preparedInsert) .withParams("${id}", "${str}", "${name}") val selectCql = cql("Select_Statement") - .executePrepared(preparedSelect) + .executeStatement(preparedSelect) .withParams("${id}", "${str}") val selectCqlSessionParam = cql("Select_Statement_Array") - .executePrepared(preparedSelect) + .executeStatement(preparedSelect) .withParams(List("id", "str")) + private def selectCqlExtract(row:Row):String = + row.getString("name") val scnPassed = scenario("ABCPreparedStatement") .feed(feeder) .exec(insertCql - .check(exhausted is true) - .check(rowCount is 0) // "normal" INSERTs don't return anything + .check(resultSet.transform(_.hasMorePages) is false) + .check(resultSet.transform(_.remaining) is 0) // "normal" INSERTs don't return anything ) .pause(1.seconds) .exec(selectCql - .check(rowCount is 1) - .check(columnValue("name") is insertName) + .check(resultSet.transform(_.remaining) is 1) + .check(resultSet.transform(rs => selectCqlExtract(rs.one)) is insertName) ) .pause(1.seconds) .exec(selectCqlSessionParam - .check(rowCount is 1) - .check(columnValue("name") is insertName) + .check(resultSet.transform(_.remaining) is 1) + .check(resultSet.transform(rs => selectCqlExtract(rs.one)) is insertName) ) diff --git a/src/test/scala/com/datastax/gatling/plugin/simulations/cql/SimpleStatementSimulation.scala b/src/test/scala/com/datastax/gatling/plugin/simulations/cql/SimpleStatementSimulation.scala index a314aad..51fea11 100644 --- a/src/test/scala/com/datastax/gatling/plugin/simulations/cql/SimpleStatementSimulation.scala +++ b/src/test/scala/com/datastax/gatling/plugin/simulations/cql/SimpleStatementSimulation.scala @@ -2,6 +2,9 @@ package com.datastax.gatling.plugin.simulations.cql import com.datastax.gatling.plugin.DsePredef._ import com.datastax.gatling.plugin.base.BaseCqlSimulation +import com.datastax.oss.driver.api.core.`type`.DataTypes +import com.datastax.oss.driver.api.core.cql.Row +import com.datastax.oss.driver.api.querybuilder.{QueryBuilder, SchemaBuilder} import io.gatling.core.Predef._ import scala.concurrent.duration.DurationInt @@ -11,33 +14,46 @@ class SimpleStatementSimulation extends BaseCqlSimulation { val table_name = "test_table_simple" createTestKeyspace - createTable + + val createTable = SchemaBuilder.createTable(testKeyspace, table_name) + .ifNotExists + .withPartitionKey("id",DataTypes.INT) + .withClusteringColumn("str", DataTypes.TEXT) + .build + session.execute(createTable) val dseProtocol = dseProtocolBuilder.session(session) //Initialize Gatling DSL with your session val insertId = 1 val insertStr = "one" - val simpleStatementInsert = s"""INSERT INTO $testKeyspace.$table_name (id, str) VALUES ($insertId, '$insertStr')""" - val simpleStatementSelect = s"""SELECT * FROM $testKeyspace.$table_name WHERE id = $insertId""" - - val insertCql = cql("Insert_Statement") - .executeCql(simpleStatementInsert) + val simpleStatementInsert = QueryBuilder.insertInto(testKeyspace, table_name) + .value("id", QueryBuilder.literal(insertId)) + .value("str", QueryBuilder.literal(insertStr)) + .build() + val simpleStatementSelect = QueryBuilder.selectFrom(testKeyspace, table_name) + .all() + .whereColumn("id").isEqualTo(QueryBuilder.literal(insertId)) + .build() - val selectCql = cql("Select_Statement") - .executeCql(simpleStatementSelect) + private def selectCqlExtract(row:Row):String = + row.getString("str") val scn = scenario("SimpleStatement") - .exec(insertCql - .check(exhausted is true) - .check(rowCount is 0) // "normal" INSERTs don't return anything + .exec( + cql("Insert_Statement") + .executeStatement(simpleStatementInsert) + .check(resultSet.transform(_.hasMorePages) is false) + .check(resultSet.transform(_.remaining()) is 0) // "normal" INSERTs don't return anything ) .pause(1.seconds) .group("TestGroup") { - exec(selectCql - .check(rowCount is 1) - .check(columnValue("str") is insertStr) - ) + exec( + cql("Select_Statement") + .executeStatement(simpleStatementSelect) + .check(resultSet.transform(_.getExecutionInfo.getWarnings.size).is(0)) + .check(resultSet.transform(_.remaining()) is 1) + .check(resultSet.transform(rs => selectCqlExtract(rs.one)) is insertStr)) } setUp( @@ -45,16 +61,4 @@ class SimpleStatementSimulation extends BaseCqlSimulation { ).assertions( global.failedRequests.count.is(0) ) - - def createTable = { - val table = - s""" - CREATE TABLE IF NOT EXISTS $testKeyspace.$table_name ( - id int, - str text, - PRIMARY KEY (id) - );""" - - session.execute(table) - } } diff --git a/src/test/scala/com/datastax/gatling/plugin/simulations/cql/UdtStatementSimulation.scala b/src/test/scala/com/datastax/gatling/plugin/simulations/cql/UdtStatementSimulation.scala index 1406973..87b7e36 100644 --- a/src/test/scala/com/datastax/gatling/plugin/simulations/cql/UdtStatementSimulation.scala +++ b/src/test/scala/com/datastax/gatling/plugin/simulations/cql/UdtStatementSimulation.scala @@ -1,8 +1,10 @@ package com.datastax.gatling.plugin.simulations.cql -import com.datastax.driver.core.ResultSet import com.datastax.gatling.plugin.DsePredef._ import com.datastax.gatling.plugin.base.BaseCqlSimulation +import com.datastax.oss.driver.api.core.`type`.UserDefinedType +import com.datastax.oss.driver.api.core.cql.Row +import com.datastax.oss.driver.api.querybuilder.QueryBuilder import io.gatling.core.Predef._ import scala.concurrent.duration.DurationInt @@ -17,7 +19,7 @@ class UdtStatementSimulation extends BaseCqlSimulation { val cqlConfig = cql.session(session) //Initialize Gatling DSL with your session - val addressType = session.getCluster.getMetadata.getKeyspace(testKeyspace).getUserType("fullname") + val addressType:UserDefinedType = session.getMetadata.getKeyspace(testKeyspace).flatMap(_.getUserDefinedType("fullname")).get val simpleId = 1 val preparedId = 2 @@ -26,15 +28,20 @@ class UdtStatementSimulation extends BaseCqlSimulation { .setString("firstname", "John") .setString("lastname", "Smith") - - val simpleStatementInsert = s"""INSERT INTO $testKeyspace.$table_name (id, name) VALUES ($simpleId, $insertFullName)""" - val simpleStatementSelect = s"""SELECT * FROM $testKeyspace.$table_name WHERE id = $simpleId""" + val simpleStatementInsert = QueryBuilder.insertInto(testKeyspace, table_name) + .value("id", QueryBuilder.literal(simpleId)) + .value("name", QueryBuilder.literal(insertFullName)) + .build() + val simpleStatementSelect = QueryBuilder.selectFrom(testKeyspace, table_name) + .all() + .whereColumn("id").isEqualTo(QueryBuilder.literal(simpleId)) + .build() val insertCql = cql("Insert Simple Statements") - .executeCql(simpleStatementInsert) + .executeStatement(simpleStatementInsert) val selectCql = cql("Select Simple Statement") - .executeCql(simpleStatementSelect) + .executeStatement(simpleStatementSelect) val preparedStatementInsert = s"""INSERT INTO $testKeyspace.$table_name (id, name) VALUES (?, ?)""" val preparedStatementSelect = s"""SELECT * FROM $testKeyspace.$table_name WHERE id = ?""" @@ -74,44 +81,47 @@ class UdtStatementSimulation extends BaseCqlSimulation { ) ) + private def extractFirstNameFromRow(row:Row):String = + row.getUdtValue("name").getString(0) + val scn = scenario("SimpleStatement") .exec(insertCql - .check(exhausted is true) - .check(rowCount is 0) // "normal" INSERTs don't return anything + .check(resultSet.transform(_.hasMorePages) is false) + .check(resultSet.transform(_.remaining) is 0) // "normal" INSERTs don't return anything ) .pause(100.millis) .exec(selectCql - .check(rowCount is 1) - .check(columnValue("name").find(0) not "") + .check(resultSet.transform(_.remaining) is 1) + .check(resultSet.transform(rs => extractFirstNameFromRow(rs.one)) not "") ) .pause(100.millis) .feed(preparedFeed) .exec(insertPreparedCql .withParams(List("id", "fullname")) - .check(exhausted is true) - .check(rowCount is 0) // "normal" INSERTs don't return anything + .check(resultSet.transform(_.hasMorePages) is false) + .check(resultSet.transform(_.remaining) is 0) // "normal" INSERTs don't return anything ) .pause(100.millis) .exec(selectPreparedCql .withParams(List("id")) - .check(rowCount is 1) - .check(columnValue("name").find(0) not "") + .check(resultSet.transform(_.remaining) is 1) + .check(resultSet.transform(rs => extractFirstNameFromRow(rs.one)) not "") ) .pause(100.millis) .feed(namedFeed) .exec(insertNamedCql - .check(exhausted is true) - .check(rowCount is 0) // "normal" INSERTs don't return anything + .check(resultSet.transform(_.hasMorePages) is false) + .check(resultSet.transform(_.remaining) is 0) // "normal" INSERTs don't return anything ) .pause(100.millis) .exec(selectNamedCql - .check(rowCount is 1) - .check(columnValue("name").find(0) not "") + .check(resultSet.transform(_.remaining) is 1) + .check(resultSet.transform(rs => extractFirstNameFromRow(rs.one)) not "") ) setUp( @@ -120,8 +130,7 @@ class UdtStatementSimulation extends BaseCqlSimulation { global.failedRequests.count.is(0) ) - - def createTable: ResultSet = { + private def createTable = { val udt = s""" diff --git a/src/test/scala/com/datastax/gatling/plugin/simulations/graph/GraphStatementSimulation.scala b/src/test/scala/com/datastax/gatling/plugin/simulations/graph/GraphStatementSimulation.scala index 8f2fce2..8476488 100644 --- a/src/test/scala/com/datastax/gatling/plugin/simulations/graph/GraphStatementSimulation.scala +++ b/src/test/scala/com/datastax/gatling/plugin/simulations/graph/GraphStatementSimulation.scala @@ -1,12 +1,10 @@ package com.datastax.gatling.plugin.simulations.graph -import com.datastax.driver.core.ConsistencyLevel -import com.datastax.driver.dse.graph.{GraphStatement, SimpleGraphStatement} -import com.datastax.dse.graph.api.DseGraph +import com.datastax.dse.driver.api.core.graph.{DseGraph, FluentGraphStatement, ScriptGraphStatement} import com.datastax.gatling.plugin.DsePredef._ import com.datastax.gatling.plugin.base.BaseGraphSimulation +import com.datastax.oss.driver.api.core.ConsistencyLevel import io.gatling.core.Predef._ -import org.apache.tinkerpop.gremlin.process.traversal.dsl.graph.GraphTraversal import org.scalatest.Ignore import scala.concurrent.duration.DurationInt @@ -15,42 +13,38 @@ import scala.concurrent.duration.DurationInt class GraphStatementSimulation extends BaseGraphSimulation { val table_name = "test_table" - - session.getCluster.getConfiguration.getGraphOptions.setGraphName("demo") + val graph_name = "demo" val graphConfig = graph.session(session) //Initialize Gatling DSL with your session val r = scala.util.Random - val graphStatement = new SimpleGraphStatement("g.addV(label, vertexLabel).property('type', myType)") - def getInt: String = { "test_" + r.nextInt(100).toString } + val insertStatement = ScriptGraphStatement.newInstance("g.addV(label, vertexLabel).property('type', myType)") val insertGraph = graph("Graph Statement") - .executeGraphStatement(graphStatement) - .withSetParams(Array("vertexLabel", "myType")) - .consistencyLevel(ConsistencyLevel.LOCAL_ONE) + .executeGraph(insertStatement) + .withParams("vertexLabel", "myType") + .withConsistencyLevel(ConsistencyLevel.LOCAL_ONE) + .withTraversalSource("g") + .withName(graph_name) val queryGraph = graph("Graph Query") .executeGraph("g.V().limit(5)") + .withTraversalSource("g") + .withName(graph_name) - val g = DseGraph.traversal(session) - val t: GraphTraversal[_,_] = g.V().limit(5) - val st: GraphStatement = DseGraph.statementFromTraversal(t) - + val queryStatement: FluentGraphStatement = FluentGraphStatement.newInstance(DseGraph.g.V().limit(5)) val queryGraphNative = graph("Graph Fluent") - .executeGraphFluent(st) - - val queryGraphFeederTraversal = graph("Graph Feeder") - .executeGraphFeederTraversal("traversal") + .executeGraphFluent(queryStatement) + .withName(graph_name) val feeder = Iterator.continually( Map[String, Any]( "vertexLabel" -> getInt, - "myType" -> r.nextInt(100), - "traversal" -> t + "myType" -> r.nextInt(100) ) ) @@ -60,12 +54,10 @@ class GraphStatementSimulation extends BaseGraphSimulation { .pause(1.seconds) .exec(queryGraph - .check(rowCount greaterThan 1) + .check(graphResultSet.transform(_.remaining) greaterThan 1) ) .pause(1.seconds) .exec(queryGraphNative) - .pause(1.seconds) - .exec(queryGraphFeederTraversal) .exec(session => { // println(session("test").asOption[String].toString) session diff --git a/src/test/scala/com/datastax/gatling/plugin/utils/CqlPreparedStatementUtilSpec.scala b/src/test/scala/com/datastax/gatling/plugin/utils/CqlPreparedStatementUtilSpec.scala index d373c2c..23435c9 100644 --- a/src/test/scala/com/datastax/gatling/plugin/utils/CqlPreparedStatementUtilSpec.scala +++ b/src/test/scala/com/datastax/gatling/plugin/utils/CqlPreparedStatementUtilSpec.scala @@ -3,12 +3,16 @@ package com.datastax.gatling.plugin.utils import java.math.BigInteger import java.net.InetAddress import java.nio.ByteBuffer +import java.time.{Instant, LocalDate, LocalTime} +import java.util.Optional -import com.datastax.driver.core.utils.UUIDs -import com.datastax.driver.core.{DataType, _} -import com.datastax.driver.dse.geometry._ +import com.datastax.dse.driver.api.core.data.geometry._ import com.datastax.gatling.plugin.base.BaseCassandraServerSpec import com.datastax.gatling.plugin.exceptions.CqlTypeException +import com.datastax.oss.driver.api.core.`type`.{DataTypes, UserDefinedType} +import com.datastax.oss.driver.api.core.cql.BoundStatement +import com.datastax.oss.driver.api.core.data.{TupleValue, UdtValue} +import com.datastax.oss.driver.api.core.uuid.Uuids import com.github.nscala_time.time.Imports.DateTime import io.gatling.core.session.Session @@ -39,12 +43,13 @@ class CqlPreparedStatementUtilSpec extends BaseCassandraServerSpec { "floatStr" -> "12.4", "double" -> 12.0, "epoch" -> 1483299340813L, + "epochInstant" -> Instant.ofEpochMilli(1483299340813L), "number" -> 12.asInstanceOf[Number], "inetStr" -> "127.0.0.1", "inet" -> InetAddress.getByName("127.0.0.1"), - "localDate" -> LocalDate.fromMillisSinceEpoch(1483299340813L), + "localDate" -> CqlPreparedStatementUtil.toLocalDate(1483299340813L), "stringDate" -> "2016-10-05", "set" -> Set(1), @@ -68,12 +73,13 @@ class CqlPreparedStatementUtilSpec extends BaseCassandraServerSpec { "javaNumber" -> 12.asInstanceOf[Number], "javaDate" -> new java.util.Date(), + "javaInstant" -> Instant.now(), "isoDateString" -> "2008-03-01T13:00:00Z", "dateString" -> "2016-10-05", - "uuid" -> UUIDs.random(), + "uuid" -> Uuids.random(), "uuidString" -> "252a3806-b8be-42d3-929d-4cbb380a433e", - "timeUuid" -> UUIDs.timeBased(), + "timeUuid" -> Uuids.timeBased(), "byte" -> 12.toByte, "short" -> 12.toShort, @@ -88,9 +94,9 @@ class CqlPreparedStatementUtilSpec extends BaseCassandraServerSpec { "null_type" -> null, - "point_type" -> new Point(1.0, 1.0), - "linestring_type" -> new LineString(new Point(1.0, 1.0), new Point(2.0, 2.0)), - "polygon_type" -> new Polygon(new Point(1.0, 1.0), new Point(2.0, 2.0), new Point(3.0, 3.0), new Point(4.0, 4.0)) + "point_type" -> Point.fromCoordinates(1.0, 1.0), + "linestring_type" -> LineString.fromPoints(Point.fromCoordinates(1.0, 1.0), Point.fromCoordinates(2.0, 2.0)), + "polygon_type" -> Polygon.fromPoints(Point.fromCoordinates(1.0, 1.0), Point.fromCoordinates(2.0, 2.0), Point.fromCoordinates(3.0, 3.0), Point.fromCoordinates(4.0, 4.0)) ) val defaultGatlingSession: Session = gatlingSession.setAll(defaultSessionVars) @@ -105,24 +111,22 @@ class CqlPreparedStatementUtilSpec extends BaseCassandraServerSpec { it("should return a list of types") { - val preparedStatement = dseSession.prepare(s"SELECT * FROM $keyspace.$table where id = ?") + val preparedStatement = cqlSession.prepare(s"SELECT * FROM $keyspace.$table where id = ?") val paramList = CqlPreparedStatementUtil.getParamsList(preparedStatement) - paramList should contain(DataType.Name.INT) + paramList should contain(DataTypes.INT) } - } describe("getParamsMap") { it("should return a map of types") { - val preparedStatement = dseSession.prepare(s"SELECT * FROM $keyspace.$table where id = :id") + val preparedStatement = cqlSession.prepare(s"SELECT * FROM $keyspace.$table where id = :id") val paramsMap = CqlPreparedStatementUtil.getParamsMap(preparedStatement) - paramsMap("id") shouldBe DataType.Name.INT + paramsMap("id") shouldBe DataTypes.INT } - } } @@ -387,28 +391,28 @@ class CqlPreparedStatementUtilSpec extends BaseCassandraServerSpec { describe("asDate") { it("should accept a date string") { - CqlPreparedStatementUtil.asDate(defaultGatlingSession, "stringDate") shouldBe a[LocalDate] - CqlPreparedStatementUtil.asDate(defaultGatlingSession, "stringDate").getDay.equals(5) + CqlPreparedStatementUtil.asLocalDate(defaultGatlingSession, "stringDate") shouldBe a[LocalDate] + CqlPreparedStatementUtil.asLocalDate(defaultGatlingSession, "stringDate").getDayOfMonth.equals(5) } it("should accept a long") { - CqlPreparedStatementUtil.asDate(defaultGatlingSession, "long") shouldBe a[LocalDate] - CqlPreparedStatementUtil.asDate(defaultGatlingSession, "long").getDay.equals(1) + CqlPreparedStatementUtil.asLocalDate(defaultGatlingSession, "long") shouldBe a[LocalDate] + CqlPreparedStatementUtil.asLocalDate(defaultGatlingSession, "long").getDayOfMonth.equals(1) } it("should accept an int") { - CqlPreparedStatementUtil.asDate(defaultGatlingSession, "int") shouldBe a[LocalDate] - CqlPreparedStatementUtil.asDate(defaultGatlingSession, "int").getDay.equals(13) + CqlPreparedStatementUtil.asLocalDate(defaultGatlingSession, "int") shouldBe a[LocalDate] + CqlPreparedStatementUtil.asLocalDate(defaultGatlingSession, "int").getDayOfMonth.equals(13) } it("should accept an native localDate") { - CqlPreparedStatementUtil.asDate(defaultGatlingSession, "localDate") shouldBe a[LocalDate] - CqlPreparedStatementUtil.asDate(defaultGatlingSession, "localDate").getDay.equals(1) + CqlPreparedStatementUtil.asLocalDate(defaultGatlingSession, "localDate") shouldBe a[LocalDate] + CqlPreparedStatementUtil.asLocalDate(defaultGatlingSession, "localDate").getDayOfMonth.equals(1) } it("should not accept a float and produce a CqlTypeException") { intercept[CqlTypeException] { - CqlPreparedStatementUtil.asDate(defaultGatlingSession, "float") shouldBe a[LocalDate] + CqlPreparedStatementUtil.asLocalDate(defaultGatlingSession, "float") shouldBe a[LocalDate] } } @@ -418,83 +422,78 @@ class CqlPreparedStatementUtilSpec extends BaseCassandraServerSpec { describe("asSet") { it("should accept a scala set") { - CqlPreparedStatementUtil.asSet(defaultGatlingSession, "set") shouldBe a[java.util.Set[_]] - CqlPreparedStatementUtil.asSet(defaultGatlingSession, "set") shouldBe + CqlPreparedStatementUtil.asSet(defaultGatlingSession, "set", classOf[Int]) shouldBe a[java.util.Set[_]] + CqlPreparedStatementUtil.asSet(defaultGatlingSession, "set", classOf[Int]) shouldBe defaultSessionVars("set").asInstanceOf[Set[Int]].asJava } it("should accept a java set") { - CqlPreparedStatementUtil.asSet(defaultGatlingSession, "setJava") shouldBe a[java.util.Set[_]] - CqlPreparedStatementUtil.asSet(defaultGatlingSession, "setJava") shouldBe + CqlPreparedStatementUtil.asSet(defaultGatlingSession, "setJava", classOf[Int]) shouldBe a[java.util.Set[_]] + CqlPreparedStatementUtil.asSet(defaultGatlingSession, "setJava", classOf[Int]) shouldBe defaultSessionVars("set").asInstanceOf[Set[Int]].asJava } it("should accept a scala seq of ints") { - CqlPreparedStatementUtil.asSet(defaultGatlingSession, "seq") shouldBe a[java.util.Set[_]] - CqlPreparedStatementUtil.asSet(defaultGatlingSession, "seq") shouldBe + CqlPreparedStatementUtil.asSet(defaultGatlingSession, "seq", classOf[Int]) shouldBe a[java.util.Set[_]] + CqlPreparedStatementUtil.asSet(defaultGatlingSession, "seq", classOf[Int]) shouldBe defaultSessionVars("set").asInstanceOf[Set[Int]].asJava } it("should accept a scala seq of strings") { - CqlPreparedStatementUtil.asSet(defaultGatlingSession, "seqString") shouldBe a[java.util.Set[_]] - CqlPreparedStatementUtil.asSet(defaultGatlingSession, "seqString") shouldBe + CqlPreparedStatementUtil.asSet(defaultGatlingSession, "seqString", classOf[String]) shouldBe a[java.util.Set[_]] + CqlPreparedStatementUtil.asSet(defaultGatlingSession, "seqString", classOf[String]) shouldBe defaultSessionVars("setString").asInstanceOf[Set[String]].asJava } it("should not accept a float and produce a CqlTypeException") { intercept[CqlTypeException] { - CqlPreparedStatementUtil.asSet(defaultGatlingSession, "float") shouldBe a[java.util.Set[_]] + CqlPreparedStatementUtil.asSet(defaultGatlingSession, "float", classOf[Float]) shouldBe a[java.util.Set[_]] } } - } - describe("asList") { it("should accept a scala set") { - CqlPreparedStatementUtil.asList(defaultGatlingSession, "list") shouldBe a[java.util.List[_]] - CqlPreparedStatementUtil.asList(defaultGatlingSession, "list") shouldBe + CqlPreparedStatementUtil.asList(defaultGatlingSession, "list", classOf[Int]) shouldBe a[java.util.List[_]] + CqlPreparedStatementUtil.asList(defaultGatlingSession, "list", classOf[Int]) shouldBe defaultSessionVars("list").asInstanceOf[List[Int]].asJava } it("should accept a java set") { - CqlPreparedStatementUtil.asList(defaultGatlingSession, "listJava") shouldBe a[java.util.List[_]] + CqlPreparedStatementUtil.asList(defaultGatlingSession, "listJava", classOf[Int]) shouldBe a[java.util.List[_]] } it("should accept a scala seq") { - CqlPreparedStatementUtil.asList(defaultGatlingSession, "seq") shouldBe a[java.util.List[_]] - CqlPreparedStatementUtil.asList(defaultGatlingSession, "seq") shouldBe + CqlPreparedStatementUtil.asList(defaultGatlingSession, "seq", classOf[Int]) shouldBe a[java.util.List[_]] + CqlPreparedStatementUtil.asList(defaultGatlingSession, "seq", classOf[Int]) shouldBe defaultSessionVars("list").asInstanceOf[List[Int]].asJava } it("should not accept a float and produce a CqlTypeException") { intercept[CqlTypeException] { - CqlPreparedStatementUtil.asList(defaultGatlingSession, "float") shouldBe a[java.util.List[_]] + CqlPreparedStatementUtil.asList(defaultGatlingSession, "float", classOf[Float]) shouldBe a[java.util.List[_]] } } - } - describe("asMap") { - it("should accept a scala set") { - CqlPreparedStatementUtil.asMap(defaultGatlingSession, "map") shouldBe a[java.util.Map[_, _]] - CqlPreparedStatementUtil.asMap(defaultGatlingSession, "map") shouldBe + it("should accept a scala map") { + CqlPreparedStatementUtil.asMap(defaultGatlingSession, "map", classOf[Int], classOf[Int]) shouldBe a[java.util.Map[_,_]] + CqlPreparedStatementUtil.asMap(defaultGatlingSession, "map", classOf[Int], classOf[Int]) shouldBe defaultSessionVars("map").asInstanceOf[Map[Int, Int]].asJava } - it("should accept a java set") { - CqlPreparedStatementUtil.asMap(defaultGatlingSession, "mapJava") shouldBe a[java.util.Map[_, _]] + it("should accept a java map") { + CqlPreparedStatementUtil.asMap(defaultGatlingSession, "mapJava", classOf[Int], classOf[Int]) shouldBe a[java.util.Map[_,_]] } it("should not accept a float and produce a CqlTypeException") { intercept[CqlTypeException] { - CqlPreparedStatementUtil.asMap(defaultGatlingSession, "float").get(1) shouldBe 1 + CqlPreparedStatementUtil.asMap(defaultGatlingSession, "float", classOf[Int], classOf[Int]).get(1) shouldBe 1 } } - } describe("asInet") { @@ -520,21 +519,21 @@ class CqlPreparedStatementUtilSpec extends BaseCassandraServerSpec { describe("asTime") { it("should accept a Long") { - CqlPreparedStatementUtil.asTime(defaultGatlingSession, "long") shouldBe a[java.lang.Long] + CqlPreparedStatementUtil.asTime(defaultGatlingSession, "long") shouldBe a[LocalTime] } describe("should accept a String") { it("should accept a String time w/o nanoseconds") { val validHour = CqlPreparedStatementUtil.asTime(defaultGatlingSession, "hourTime") - validHour shouldBe a[java.lang.Long] - validHour shouldBe 3661000000000L + validHour shouldBe a[LocalTime] + validHour.toNanoOfDay shouldBe 3661000000000L } it("should accept a String time w/ nanoseconds") { val validNano = CqlPreparedStatementUtil.asTime(defaultGatlingSession, "nanoTime") - validNano shouldBe a[java.lang.Long] - validNano shouldBe 3661343000000L + validNano shouldBe a[LocalTime] + validNano.toNanoOfDay shouldBe 3661343000000L } it("should not accept invalid hour and produce a CqlTypeException") { @@ -704,38 +703,37 @@ class CqlPreparedStatementUtilSpec extends BaseCassandraServerSpec { } - describe("asTimestamp") { + describe("asInstant") { it("should accept an epoch long") { - CqlPreparedStatementUtil.asTimestamp(defaultGatlingSession, "epoch") shouldBe a[java.util.Date] - CqlPreparedStatementUtil.asTimestamp(defaultGatlingSession, "epoch") shouldBe - new java.util.Date(defaultSessionVars("epoch").asInstanceOf[Long]) + CqlPreparedStatementUtil.asInstant(defaultGatlingSession, "epoch") shouldBe a[Instant] + CqlPreparedStatementUtil.asInstant(defaultGatlingSession, "epoch") shouldBe + defaultSessionVars("epochInstant") } - it("should accept a java Date") { - CqlPreparedStatementUtil.asTimestamp(defaultGatlingSession, "javaDate") shouldBe a[java.util.Date] - CqlPreparedStatementUtil.asTimestamp(defaultGatlingSession, "javaDate") shouldBe - defaultSessionVars("javaDate").asInstanceOf[java.util.Date] + it("should accept a java Instant") { + CqlPreparedStatementUtil.asInstant(defaultGatlingSession, "javaInstant") shouldBe a[Instant] + CqlPreparedStatementUtil.asInstant(defaultGatlingSession, "javaInstant") shouldBe + defaultSessionVars("javaInstant") } it("should accept a date string") { - CqlPreparedStatementUtil.asTimestamp(defaultGatlingSession, "dateString") shouldBe a[java.util.Date] - CqlPreparedStatementUtil.asTimestamp(defaultGatlingSession, "dateString") shouldBe - DateTime.parse(defaultSessionVars("dateString").asInstanceOf[String]).toDate + CqlPreparedStatementUtil.asInstant(defaultGatlingSession, "dateString") shouldBe a[Instant] + CqlPreparedStatementUtil.asInstant(defaultGatlingSession, "dateString") shouldBe + Instant.ofEpochMilli(DateTime.parse(defaultSessionVars("dateString").toString).getMillis) } it("should accept a isoDateString string") { - CqlPreparedStatementUtil.asTimestamp(defaultGatlingSession, "isoDateString") shouldBe a[java.util.Date] - CqlPreparedStatementUtil.asTimestamp(defaultGatlingSession, "isoDateString") shouldBe - DateTime.parse(defaultSessionVars("isoDateString").asInstanceOf[String]).toDate + CqlPreparedStatementUtil.asInstant(defaultGatlingSession, "isoDateString") shouldBe a[Instant] + CqlPreparedStatementUtil.asInstant(defaultGatlingSession, "isoDateString") shouldBe + Instant.ofEpochMilli(DateTime.parse(defaultSessionVars("isoDateString").toString).getMillis) } it("should not accept a float and produce a CqlTypeException") { intercept[CqlTypeException] { - CqlPreparedStatementUtil.asTimestamp(defaultGatlingSession, "float") shouldBe a[java.util.Date] + CqlPreparedStatementUtil.asInstant(defaultGatlingSession, "float") shouldBe a[Instant] } } - } describe("asByte") { @@ -770,9 +768,10 @@ class CqlPreparedStatementUtilSpec extends BaseCassandraServerSpec { createType(keyspace, typeName, "firstname text, lastname text") createTable(keyspace, table, "id int, name frozen, PRIMARY KEY(id)") - val addressType = dseSession.getCluster.getMetadata.getKeyspace(keyspace).getUserType(typeName) + val addressType:Optional[UserDefinedType] = cqlSession.getMetadata.getKeyspace(keyspace).flatMap(_.getUserDefinedType(typeName)) + addressType should not be Optional.empty - val insertFullName = addressType.newValue() + val insertFullName = addressType.get.newValue() .setString("firstname", "John") .setString("lastname", "Smith") @@ -781,23 +780,19 @@ class CqlPreparedStatementUtilSpec extends BaseCassandraServerSpec { val udtSession: Session = gatlingSession.setAll(newSessionVars) it("should accept a UDTValue") { - CqlPreparedStatementUtil.asUdt(udtSession, "fullname") shouldBe a[UDTValue] + CqlPreparedStatementUtil.asUdt(udtSession, "fullname") shouldBe a[UdtValue] } it("should not accept a float and produce a CqlTypeException") { intercept[CqlTypeException] { - CqlPreparedStatementUtil.asUdt(udtSession, "invalid") shouldBe a[UDTValue] + CqlPreparedStatementUtil.asUdt(udtSession, "invalid") shouldBe a[UdtValue] } } - } describe("asTuple") { - val table = "tuple_test" - createTable(keyspace, table, "id int, tuple_type tuple, PRIMARY KEY (id)") - - val tupleType = dseSession.getCluster.getMetadata.newTupleType(DataType.varchar(), DataType.varchar()) + val tupleType = DataTypes.tupleOf(DataTypes.TEXT, DataTypes.TEXT) val insertTuple = tupleType.newValue("test", "test2") val newSessionVars = Map("tuple_type" -> insertTuple, "invalid" -> "string") val tupleSession: Session = gatlingSession.setAll(newSessionVars) @@ -811,21 +806,17 @@ class CqlPreparedStatementUtilSpec extends BaseCassandraServerSpec { CqlPreparedStatementUtil.asTuple(tupleSession, "invalid") shouldBe a[TupleValue] } } - } - } - describe("boundStatementFunctions") { - val typeName = "fullname2" createType(keyspace, typeName, "firstname text, lastname text") val tableName = "type_table" createTable(keyspace, tableName, "uuid_type uuid, timeuuid_type timeuuid, int_type int, text_type text, " + - "varchar_type varchar, ascii_type ascii, float_type float, double_type double, decimal_type decimal, " + + "ascii_type ascii, float_type float, double_type double, decimal_type decimal, " + "boolean_type boolean, inet_type inet, timestamp_type timestamp, bigint_type bigint, blob_type blob, " + "varint_type varint, list_type list, set_type set, map_type map, date_type date, " + "smallint_type smallint, tinyint_type tinyint, time_type time, tuple_type tuple, " + @@ -843,7 +834,6 @@ class CqlPreparedStatementUtilSpec extends BaseCassandraServerSpec { |timeuuid_type, |int_type, |text_type, - |varchar_type, |ascii_type, |float_type, |double_type, @@ -866,168 +856,165 @@ class CqlPreparedStatementUtilSpec extends BaseCassandraServerSpec { |null_type, |none_type) |VALUES - |(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""".stripMargin + |(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""".stripMargin - val boundStatementKeys = dseSession.prepare(preparedStatementInsert).bind() + val boundStatementKeys = cqlSession.prepare(preparedStatementInsert).bind() it("should bind with a UUID") { - val result = CqlPreparedStatementUtil.bindParamByOrder(defaultGatlingSession, boundStatementKeys, DataType.Name.UUID, "uuid", 0) + val result = CqlPreparedStatementUtil.bindParamByOrder(defaultGatlingSession, boundStatementKeys, DataTypes.UUID, "uuid", 0) result shouldBe a[BoundStatement] result.isSet(0) shouldBe true } it("should bind with a timeUuid") { - val result = CqlPreparedStatementUtil.bindParamByOrder(defaultGatlingSession, boundStatementKeys, DataType.Name.TIMEUUID, "timeUuid", 1) + val result = CqlPreparedStatementUtil.bindParamByOrder(defaultGatlingSession, boundStatementKeys, DataTypes.TIMEUUID, "timeUuid", 1) result shouldBe a[BoundStatement] result.isSet(1) shouldBe true } it("should bind with a int") { - val result = CqlPreparedStatementUtil.bindParamByOrder(defaultGatlingSession, boundStatementKeys, DataType.Name.INT, "int", 2) + val result = CqlPreparedStatementUtil.bindParamByOrder(defaultGatlingSession, boundStatementKeys, DataTypes.INT, "int", 2) result shouldBe a[BoundStatement] result.isSet(2) shouldBe true } it("should bind with a text") { - val result = CqlPreparedStatementUtil.bindParamByOrder(defaultGatlingSession, boundStatementKeys, DataType.Name.TEXT, "string", 3) + val result = CqlPreparedStatementUtil.bindParamByOrder(defaultGatlingSession, boundStatementKeys, DataTypes.TEXT, "string", 3) result shouldBe a[BoundStatement] result.isSet(3) shouldBe true } - it("should bind with a varchar") { - val result = CqlPreparedStatementUtil.bindParamByOrder(defaultGatlingSession, boundStatementKeys, DataType.Name.VARCHAR, "string", 4) - result shouldBe a[BoundStatement] - result.isSet(4) shouldBe true - } - it("should bind with a ascii") { - val result = CqlPreparedStatementUtil.bindParamByOrder(defaultGatlingSession, boundStatementKeys, DataType.Name.ASCII, "string", 5) + val result = CqlPreparedStatementUtil.bindParamByOrder(defaultGatlingSession, boundStatementKeys, DataTypes.ASCII, "string", 4) result shouldBe a[BoundStatement] - result.isSet(5) shouldBe true + result.isSet(4) shouldBe true } it("should bind with a float") { - val result = CqlPreparedStatementUtil.bindParamByOrder(defaultGatlingSession, boundStatementKeys, DataType.Name.FLOAT, "float", 6) + val result = CqlPreparedStatementUtil.bindParamByOrder(defaultGatlingSession, boundStatementKeys, DataTypes.FLOAT, "float", 5) result shouldBe a[BoundStatement] - result.isSet(6) shouldBe true + result.isSet(5) shouldBe true } it("should bind with a double") { - val result = CqlPreparedStatementUtil.bindParamByOrder(defaultGatlingSession, boundStatementKeys, DataType.Name.DOUBLE, "double", 7) + val result = CqlPreparedStatementUtil.bindParamByOrder(defaultGatlingSession, boundStatementKeys, DataTypes.DOUBLE, "double", 6) result shouldBe a[BoundStatement] - result.isSet(7) shouldBe true + result.isSet(6) shouldBe true } it("should bind with a decimal") { - val result = CqlPreparedStatementUtil.bindParamByOrder(defaultGatlingSession, boundStatementKeys, DataType.Name.DECIMAL, "double", 8) + val result = CqlPreparedStatementUtil.bindParamByOrder(defaultGatlingSession, boundStatementKeys, DataTypes.DECIMAL, "double", 7) result shouldBe a[BoundStatement] - result.isSet(8) shouldBe true + result.isSet(7) shouldBe true } it("should bind with a boolean") { - val result = CqlPreparedStatementUtil.bindParamByOrder(defaultGatlingSession, boundStatementKeys, DataType.Name.BOOLEAN, "boolean", 9) + val result = CqlPreparedStatementUtil.bindParamByOrder(defaultGatlingSession, boundStatementKeys, DataTypes.BOOLEAN, "boolean", 8) result shouldBe a[BoundStatement] - result.isSet(9) shouldBe true + result.isSet(8) shouldBe true } it("should bind with a inetAddress") { - val result = CqlPreparedStatementUtil.bindParamByOrder(defaultGatlingSession, boundStatementKeys, DataType.Name.INET, "inetStr", 10) + val result = CqlPreparedStatementUtil.bindParamByOrder(defaultGatlingSession, boundStatementKeys, DataTypes.INET, "inetStr", 9) result shouldBe a[BoundStatement] - result.isSet(10) shouldBe true + result.isSet(9) shouldBe true } it("should bind with a timestamp") { - val result = CqlPreparedStatementUtil.bindParamByOrder(defaultGatlingSession, boundStatementKeys, DataType.Name.TIMESTAMP, "epoch", 11) + val result = CqlPreparedStatementUtil.bindParamByOrder(defaultGatlingSession, boundStatementKeys, DataTypes.TIMESTAMP, "epochInstant", 10) result shouldBe a[BoundStatement] - result.isSet(11) shouldBe true + result.isSet(10) shouldBe true } it("should bind with a bigInt") { - val result = CqlPreparedStatementUtil.bindParamByOrder(defaultGatlingSession, boundStatementKeys, DataType.Name.BIGINT, "bigInteger", 12) + val result = CqlPreparedStatementUtil.bindParamByOrder(defaultGatlingSession, boundStatementKeys, DataTypes.BIGINT, "bigInteger", 11) result shouldBe a[BoundStatement] - result.isSet(12) shouldBe true + result.isSet(11) shouldBe true } it("should bind with a blob_type") { - val result = CqlPreparedStatementUtil.bindParamByOrder(defaultGatlingSession, boundStatementKeys, DataType.Name.BLOB, "byteArray", 13) + val result = CqlPreparedStatementUtil.bindParamByOrder(defaultGatlingSession, boundStatementKeys, DataTypes.BLOB, "byteArray", 12) result shouldBe a[BoundStatement] - result.isSet(13) shouldBe true + result.isSet(12) shouldBe true } it("should bind with a varint") { - val result = CqlPreparedStatementUtil.bindParamByOrder(defaultGatlingSession, boundStatementKeys, DataType.Name.VARINT, "int", 14) + val result = CqlPreparedStatementUtil.bindParamByOrder(defaultGatlingSession, boundStatementKeys, DataTypes.VARINT, "int", 13) result shouldBe a[BoundStatement] - result.isSet(14) shouldBe true + result.isSet(13) shouldBe true } it("should bind with a list") { - val result = CqlPreparedStatementUtil.bindParamByOrder(defaultGatlingSession, boundStatementKeys, DataType.Name.LIST, "list", 15) + val result = CqlPreparedStatementUtil.bindParamByOrder(defaultGatlingSession, boundStatementKeys, DataTypes.listOf(DataTypes.INT), "list", 14) result shouldBe a[BoundStatement] - result.isSet(15) shouldBe true + result.isSet(14) shouldBe true } it("should bind with a set") { - val result = CqlPreparedStatementUtil.bindParamByOrder(defaultGatlingSession, boundStatementKeys, DataType.Name.SET, "set", 16) + val result = CqlPreparedStatementUtil.bindParamByOrder(defaultGatlingSession, boundStatementKeys, DataTypes.setOf(DataTypes.INT), "set", 15) result shouldBe a[BoundStatement] - result.isSet(16) shouldBe true + result.isSet(15) shouldBe true } it("should bind with a map") { - val result = CqlPreparedStatementUtil.bindParamByOrder(defaultGatlingSession, boundStatementKeys, DataType.Name.MAP, "map", 17) + val result = CqlPreparedStatementUtil.bindParamByOrder(defaultGatlingSession, boundStatementKeys, DataTypes.mapOf(DataTypes.INT, DataTypes.INT), "map", 16) result shouldBe a[BoundStatement] - result.isSet(17) shouldBe true + result.isSet(16) shouldBe true } it("should bind with a date") { - val result = CqlPreparedStatementUtil.bindParamByOrder(defaultGatlingSession, boundStatementKeys, DataType.Name.DATE, "epoch", 18) + val result = CqlPreparedStatementUtil.bindParamByOrder(defaultGatlingSession, boundStatementKeys, DataTypes.DATE, "epoch", 17) result shouldBe a[BoundStatement] - result.isSet(18) shouldBe true + result.isSet(17) shouldBe true } it("should bind with a smallInt") { - val result = CqlPreparedStatementUtil.bindParamByOrder(defaultGatlingSession, boundStatementKeys, DataType.Name.SMALLINT, "int", 19) + val result = CqlPreparedStatementUtil.bindParamByOrder(defaultGatlingSession, boundStatementKeys, DataTypes.SMALLINT, "int", 18) result shouldBe a[BoundStatement] - result.isSet(19) shouldBe true + result.isSet(18) shouldBe true } it("should bind with a tinyint") { - val result = CqlPreparedStatementUtil.bindParamByOrder(defaultGatlingSession, boundStatementKeys, DataType.Name.TINYINT, "int", 20) + val result = CqlPreparedStatementUtil.bindParamByOrder(defaultGatlingSession, boundStatementKeys, DataTypes.TINYINT, "int", 19) result shouldBe a[BoundStatement] - result.isSet(20) shouldBe true + result.isSet(19) shouldBe true } it("should bind with a time") { - val result = CqlPreparedStatementUtil.bindParamByOrder(defaultGatlingSession, boundStatementKeys, DataType.Name.TIME, "epoch", 21) + val result = CqlPreparedStatementUtil.bindParamByOrder(defaultGatlingSession, boundStatementKeys, DataTypes.TIME, "epoch", 20) result shouldBe a[BoundStatement] - result.isSet(21) shouldBe true + result.isSet(20) shouldBe true } it("should bind with a tuple") { - val tupleType = dseSession.getCluster.getMetadata.newTupleType(DataType.varchar(), DataType.varchar()) + val tupleType = DataTypes.tupleOf(DataTypes.TEXT, DataTypes.TEXT) val insertTuple = tupleType.newValue("test", "test2") val newSessionVars = Map("tuple_type" -> insertTuple, "invalid" -> "string") val tupleSession: Session = gatlingSession.setAll(newSessionVars) - val result = CqlPreparedStatementUtil.bindParamByOrder(tupleSession, boundStatementKeys, DataType.Name.TUPLE, "tuple_type", 22) + val result = CqlPreparedStatementUtil.bindParamByOrder(tupleSession, boundStatementKeys, tupleType, "tuple_type", 21) result shouldBe a[BoundStatement] - result.isSet(22) shouldBe true + result.isSet(21) shouldBe true } it("should bind with a udt") { - val addressType = dseSession.getCluster.getMetadata.getKeyspace(keyspace).getUserType("fullname2") - val insertFullName = addressType.newValue() - .setString("firstname", "John") - .setString("lastname", "Smith") + val addressType:Optional[UserDefinedType] = cqlSession.getMetadata.getKeyspace(keyspace).flatMap(_.getUserDefinedType("fullname2")) + addressType should not be Optional.empty + + val insertFullName = addressType.get.newValue() + .setString("firstname", "John") + .setString("lastname", "Smith") + val newSessionVars = Map("fullname2" -> insertFullName, "invalid" -> "string") val udtSession: Session = gatlingSession.setAll(newSessionVars) - val result = CqlPreparedStatementUtil.bindParamByOrder(udtSession, boundStatementKeys, DataType.Name.UDT, "fullname2", 23) + val result = CqlPreparedStatementUtil.bindParamByOrder(udtSession, boundStatementKeys, addressType.get, "fullname2", 22) result shouldBe a[BoundStatement] - result.isSet(23) shouldBe true + result.isSet(22) shouldBe true } @@ -1036,9 +1023,9 @@ class CqlPreparedStatementUtilSpec extends BaseCassandraServerSpec { val preparedStatementInsertCounter = s"""UPDATE $keyspace.$counterTableName SET counter_type = counter_type + ? WHERE uuid_type = ?""" - val boundStatementCounter = dseSession.prepare(preparedStatementInsertCounter).bind() + val boundStatementCounter = cqlSession.prepare(preparedStatementInsertCounter).bind() - val result = CqlPreparedStatementUtil.bindParamByOrder(defaultGatlingSession, boundStatementCounter, DataType.Name.COUNTER, "int", 0) + val result = CqlPreparedStatementUtil.bindParamByOrder(defaultGatlingSession, boundStatementCounter, DataTypes.COUNTER, "int", 0) result shouldBe a[BoundStatement] result.isSet(0) shouldBe true } @@ -1047,29 +1034,33 @@ class CqlPreparedStatementUtilSpec extends BaseCassandraServerSpec { boundStatementKeys.isSet("null_type") shouldBe false - val result = CqlPreparedStatementUtil.bindParamByOrder(defaultGatlingSession, boundStatementKeys, DataType.Name.TINYINT, "null_type", 24) + val result = CqlPreparedStatementUtil.bindParamByOrder(defaultGatlingSession, boundStatementKeys, DataTypes.TINYINT, "null_type", 23) result shouldBe a[BoundStatement] - boundStatementKeys.isSet("null_type") shouldBe true - boundStatementKeys.isNull("null_type") shouldBe true + result.isSet("null_type") shouldBe true + result.isNull("null_type") shouldBe true } it("should not set and unset a None value") { val field = "none_type" - - CqlPreparedStatementUtil.bindParamByOrder(defaultGatlingSession, boundStatementKeys, DataType.Name.TEXT, field, 25) shouldBe a[BoundStatement] boundStatementKeys.isSet(field) shouldBe false + val result1 = CqlPreparedStatementUtil.bindParamByOrder(defaultGatlingSession, boundStatementKeys, DataTypes.TEXT, field, 24) + result1 shouldBe a[BoundStatement] + result1.isSet(field) shouldBe false + val newSessionVars = Map(field -> "test") val newSession: Session = gatlingSession.setAll(newSessionVars) - CqlPreparedStatementUtil.bindParamByOrder(newSession, boundStatementKeys, DataType.Name.TEXT, field, 25) shouldBe a[BoundStatement] - boundStatementKeys.isSet(field) shouldBe true - boundStatementKeys.getString(field) shouldBe "test" + val result2 = CqlPreparedStatementUtil.bindParamByOrder(newSession, result1, DataTypes.TEXT, field, 24) + result2 shouldBe a[BoundStatement] + result2.isSet(field) shouldBe true + result2.getString(field) shouldBe "test" - CqlPreparedStatementUtil.bindParamByOrder(defaultGatlingSession, boundStatementKeys, DataType.Name.TEXT, field, 25) shouldBe a[BoundStatement] - boundStatementKeys.isSet(field) shouldBe false + val result3 = CqlPreparedStatementUtil.bindParamByOrder(defaultGatlingSession, result2, DataTypes.TEXT, field, 24) + result3 shouldBe a[BoundStatement] + result3.isSet(field) shouldBe false } it("should not set a missing session value") { @@ -1078,7 +1069,7 @@ class CqlPreparedStatementUtilSpec extends BaseCassandraServerSpec { val newSessionVars = Map("missing" -> "test") val newSession: Session = gatlingSession.setAll(newSessionVars) - CqlPreparedStatementUtil.bindParamByOrder(newSession, boundStatementKeys, DataType.Name.TEXT, field, 25) shouldBe a[BoundStatement] + CqlPreparedStatementUtil.bindParamByOrder(newSession, boundStatementKeys, DataTypes.TEXT, field, 25) shouldBe a[BoundStatement] boundStatementKeys.isSet(field) shouldBe false } } @@ -1092,7 +1083,6 @@ class CqlPreparedStatementUtilSpec extends BaseCassandraServerSpec { |timeuuid_type, |int_type, |text_type, - |varchar_type, |ascii_type, |float_type, |double_type, @@ -1119,7 +1109,6 @@ class CqlPreparedStatementUtilSpec extends BaseCassandraServerSpec { |:timeuuid_type, |:int_type, |:text_type, - |:varchar_type, |:ascii_type, |:float_type, |:double_type, @@ -1143,14 +1132,13 @@ class CqlPreparedStatementUtilSpec extends BaseCassandraServerSpec { |:none_type |)""".stripMargin - val boundStatementNames = dseSession.prepare(preparedStatementInsert).bind() + val boundStatementNames = cqlSession.prepare(preparedStatementInsert).bind() val defaultSessionVars = Map( "uuid_type" -> java.util.UUID.randomUUID(), - "timeuuid_type" -> UUIDs.timeBased(), + "timeuuid_type" -> Uuids.timeBased(), "int_type" -> 12, "text_type" -> "string", - "varchar_type" -> "string", "ascii_type" -> "string", "float_type" -> 12.0.toFloat, "double_type" -> 12.0, @@ -1177,184 +1165,179 @@ class CqlPreparedStatementUtilSpec extends BaseCassandraServerSpec { it("should bind with a UUID") { val paramName = "uuid_type" - val result = CqlPreparedStatementUtil.bindParamByName(typeSession, boundStatementNames, DataType.Name.UUID, paramName) + val result = CqlPreparedStatementUtil.bindParamByName(typeSession, boundStatementNames, DataTypes.UUID, paramName) result shouldBe a[BoundStatement] - boundStatementNames.isSet(paramName) shouldBe true + result.isSet(paramName) shouldBe true } it("should bind with a timeUuid") { val paramName = "timeuuid_type" - val result = CqlPreparedStatementUtil.bindParamByName(typeSession, boundStatementNames, DataType.Name.TIMEUUID, paramName) + val result = CqlPreparedStatementUtil.bindParamByName(typeSession, boundStatementNames, DataTypes.TIMEUUID, paramName) result shouldBe a[BoundStatement] - boundStatementNames.isSet(paramName) shouldBe true + result.isSet(paramName) shouldBe true } it("should bind with a int") { val paramName = "int_type" - val result = CqlPreparedStatementUtil.bindParamByName(typeSession, boundStatementNames, DataType.Name.INT, paramName) + val result = CqlPreparedStatementUtil.bindParamByName(typeSession, boundStatementNames, DataTypes.INT, paramName) result shouldBe a[BoundStatement] - boundStatementNames.isSet(paramName) shouldBe true + result.isSet(paramName) shouldBe true } it("should bind with a text") { val paramName = "text_type" - val result = CqlPreparedStatementUtil.bindParamByName(typeSession, boundStatementNames, DataType.Name.TEXT, paramName) + val result = CqlPreparedStatementUtil.bindParamByName(typeSession, boundStatementNames, DataTypes.TEXT, paramName) result shouldBe a[BoundStatement] - boundStatementNames.isSet(paramName) shouldBe true - } - - it("should bind with a varchar") { - val paramName = "varchar_type" - val result = CqlPreparedStatementUtil.bindParamByName(typeSession, boundStatementNames, DataType.Name.VARCHAR, paramName) - result shouldBe a[BoundStatement] - boundStatementNames.isSet(paramName) shouldBe true + result.isSet(paramName) shouldBe true } it("should bind with a ascii") { val paramName = "ascii_type" - val result = CqlPreparedStatementUtil.bindParamByName(typeSession, boundStatementNames, DataType.Name.ASCII, paramName) + val result = CqlPreparedStatementUtil.bindParamByName(typeSession, boundStatementNames, DataTypes.ASCII, paramName) result shouldBe a[BoundStatement] - boundStatementNames.isSet(paramName) shouldBe true + result.isSet(paramName) shouldBe true } it("should bind with a float") { val paramName = "float_type" - val result = CqlPreparedStatementUtil.bindParamByName(typeSession, boundStatementNames, DataType.Name.FLOAT, paramName) + val result = CqlPreparedStatementUtil.bindParamByName(typeSession, boundStatementNames, DataTypes.FLOAT, paramName) result shouldBe a[BoundStatement] - boundStatementNames.isSet(paramName) shouldBe true + result.isSet(paramName) shouldBe true } it("should bind with a double") { val paramName = "double_type" - val result = CqlPreparedStatementUtil.bindParamByName(typeSession, boundStatementNames, DataType.Name.DOUBLE, paramName) + val result = CqlPreparedStatementUtil.bindParamByName(typeSession, boundStatementNames, DataTypes.DOUBLE, paramName) result shouldBe a[BoundStatement] - boundStatementNames.isSet(paramName) shouldBe true + result.isSet(paramName) shouldBe true } it("should bind with a decimal") { val paramName = "decimal_type" - val result = CqlPreparedStatementUtil.bindParamByName(typeSession, boundStatementNames, DataType.Name.DECIMAL, paramName) + val result = CqlPreparedStatementUtil.bindParamByName(typeSession, boundStatementNames, DataTypes.DECIMAL, paramName) result shouldBe a[BoundStatement] - boundStatementNames.isSet(paramName) shouldBe true + result.isSet(paramName) shouldBe true } it("should bind with a boolean") { val paramName = "boolean_type" - val result = CqlPreparedStatementUtil.bindParamByName(typeSession, boundStatementNames, DataType.Name.BOOLEAN, paramName) + val result = CqlPreparedStatementUtil.bindParamByName(typeSession, boundStatementNames, DataTypes.BOOLEAN, paramName) result shouldBe a[BoundStatement] - boundStatementNames.isSet(paramName) shouldBe true + result.isSet(paramName) shouldBe true } it("should bind with a inetAddress") { val paramName = "inet_type" - val result = CqlPreparedStatementUtil.bindParamByName(typeSession, boundStatementNames, DataType.Name.INET, paramName) + val result = CqlPreparedStatementUtil.bindParamByName(typeSession, boundStatementNames, DataTypes.INET, paramName) result shouldBe a[BoundStatement] - boundStatementNames.isSet(paramName) shouldBe true + result.isSet(paramName) shouldBe true } it("should bind with a timestamp") { val paramName = "timestamp_type" - val result = CqlPreparedStatementUtil.bindParamByName(typeSession, boundStatementNames, DataType.Name.TIMESTAMP, paramName) + val result = CqlPreparedStatementUtil.bindParamByName(typeSession, boundStatementNames, DataTypes.TIMESTAMP, paramName) result shouldBe a[BoundStatement] - boundStatementNames.isSet(paramName) shouldBe true - + result.isSet(paramName) shouldBe true } it("should bind with a bigInt") { val paramName = "bigint_type" - val result = CqlPreparedStatementUtil.bindParamByName(typeSession, boundStatementNames, DataType.Name.BIGINT, paramName) + val result = CqlPreparedStatementUtil.bindParamByName(typeSession, boundStatementNames, DataTypes.BIGINT, paramName) result shouldBe a[BoundStatement] - boundStatementNames.isSet(paramName) shouldBe true + result.isSet(paramName) shouldBe true } it("should bind with a blob_type") { val paramName = "blob_type" - val result = CqlPreparedStatementUtil.bindParamByName(typeSession, boundStatementNames, DataType.Name.BLOB, paramName) + val result = CqlPreparedStatementUtil.bindParamByName(typeSession, boundStatementNames, DataTypes.BLOB, paramName) result shouldBe a[BoundStatement] - boundStatementNames.isSet(paramName) shouldBe true + result.isSet(paramName) shouldBe true } it("should bind with a varint") { val paramName = "varint_type" - val result = CqlPreparedStatementUtil.bindParamByName(typeSession, boundStatementNames, DataType.Name.VARINT, paramName) + val result = CqlPreparedStatementUtil.bindParamByName(typeSession, boundStatementNames, DataTypes.VARINT, paramName) result shouldBe a[BoundStatement] - boundStatementNames.isSet(paramName) shouldBe true + result.isSet(paramName) shouldBe true } it("should bind with a list") { val paramName = "list_type" - val result = CqlPreparedStatementUtil.bindParamByName(typeSession, boundStatementNames, DataType.Name.LIST, paramName) + val result = CqlPreparedStatementUtil.bindParamByName(typeSession, boundStatementNames, DataTypes.listOf(DataTypes.INT), paramName) result shouldBe a[BoundStatement] - boundStatementNames.isSet(paramName) shouldBe true + result.isSet(paramName) shouldBe true } it("should bind with a set") { val paramName = "set_type" - val result = CqlPreparedStatementUtil.bindParamByName(typeSession, boundStatementNames, DataType.Name.SET, paramName) + val result = CqlPreparedStatementUtil.bindParamByName(typeSession, boundStatementNames, DataTypes.setOf(DataTypes.INT), paramName) result shouldBe a[BoundStatement] - boundStatementNames.isSet(paramName) shouldBe true + result.isSet(paramName) shouldBe true } it("should bind with a map") { val paramName = "map_type" - val result = CqlPreparedStatementUtil.bindParamByName(typeSession, boundStatementNames, DataType.Name.MAP, paramName) + val result = CqlPreparedStatementUtil.bindParamByName(typeSession, boundStatementNames, DataTypes.mapOf(DataTypes.INT, DataTypes.INT), paramName) result shouldBe a[BoundStatement] - boundStatementNames.isSet(paramName) shouldBe true + result.isSet(paramName) shouldBe true } it("should bind with a date") { val paramName = "date_type" - val result = CqlPreparedStatementUtil.bindParamByName(typeSession, boundStatementNames, DataType.Name.DATE, paramName) + val result = CqlPreparedStatementUtil.bindParamByName(typeSession, boundStatementNames, DataTypes.DATE, paramName) result shouldBe a[BoundStatement] - boundStatementNames.isSet(paramName) shouldBe true + result.isSet(paramName) shouldBe true } it("should bind with a smallInt") { val paramName = "smallint_type" - val result = CqlPreparedStatementUtil.bindParamByName(typeSession, boundStatementNames, DataType.Name.SMALLINT, paramName) + val result = CqlPreparedStatementUtil.bindParamByName(typeSession, boundStatementNames, DataTypes.SMALLINT, paramName) result shouldBe a[BoundStatement] - boundStatementNames.isSet(paramName) shouldBe true + result.isSet(paramName) shouldBe true } it("should bind with a tinyint") { val paramName = "tinyint_type" - val result = CqlPreparedStatementUtil.bindParamByName(typeSession, boundStatementNames, DataType.Name.TINYINT, paramName) + val result = CqlPreparedStatementUtil.bindParamByName(typeSession, boundStatementNames, DataTypes.TINYINT, paramName) result shouldBe a[BoundStatement] - boundStatementNames.isSet(paramName) shouldBe true + result.isSet(paramName) shouldBe true } it("should bind with a time") { val paramName = "time_type" - val result = CqlPreparedStatementUtil.bindParamByName(typeSession, boundStatementNames, DataType.Name.TIME, paramName) + val result = CqlPreparedStatementUtil.bindParamByName(typeSession, boundStatementNames, DataTypes.TIME, paramName) result shouldBe a[BoundStatement] - boundStatementNames.isSet(paramName) shouldBe true + result.isSet(paramName) shouldBe true } it("should bind with a tuple") { val paramName = "tuple_type" - val tupleType = dseSession.getCluster.getMetadata.newTupleType(DataType.varchar(), DataType.varchar()) + val tupleType = DataTypes.tupleOf(DataTypes.TEXT, DataTypes.TEXT) val insertTuple = tupleType.newValue("test", "test2") - val newSessionVars = Map("tuple_type" -> insertTuple, "invalid" -> "string") + val newSessionVars = Map(paramName -> insertTuple, "invalid" -> "string") val tupleSession: Session = gatlingSession.setAll(newSessionVars) - val result = CqlPreparedStatementUtil.bindParamByName(tupleSession, boundStatementNames, DataType.Name.TUPLE, paramName) + val result = CqlPreparedStatementUtil.bindParamByName(tupleSession, boundStatementNames, tupleType, paramName) result shouldBe a[BoundStatement] - boundStatementNames.isSet(paramName) shouldBe true + result.isSet(paramName) shouldBe true } - it("should bind with a udt") { + val paramName = "udt_type" - val addressType = dseSession.getCluster.getMetadata.getKeyspace(keyspace).getUserType("fullname2") - val insertFullName = addressType.newValue() - .setString("firstname", "John") - .setString("lastname", "Smith") - val newSessionVars = Map("udt_type" -> insertFullName, "invalid" -> "string") + + val addressType:Optional[UserDefinedType] = cqlSession.getMetadata.getKeyspace(keyspace).flatMap(_.getUserDefinedType("fullname2")) + addressType should not be Optional.empty + + val insertFullName = addressType.get.newValue() + .setString("firstname", "John") + .setString("lastname", "Smith") + val newSessionVars = Map(paramName -> insertFullName, "invalid" -> "string") val udtSession: Session = gatlingSession.setAll(newSessionVars) - val result = CqlPreparedStatementUtil.bindParamByName(udtSession, boundStatementNames, DataType.Name.UDT, paramName) + val result = CqlPreparedStatementUtil.bindParamByName(udtSession, boundStatementNames, addressType.get, paramName) result shouldBe a[BoundStatement] - boundStatementNames.isSet(paramName) shouldBe true + result.isSet(paramName) shouldBe true } @@ -1363,9 +1346,9 @@ class CqlPreparedStatementUtilSpec extends BaseCassandraServerSpec { val preparedStatementInsertCounter = s"""UPDATE $keyspace.$counterTableName SET counter_type = counter_type + :counter_type WHERE uuid_type = :uuid_type""" - val boundNamedStatementCounter = dseSession.prepare(preparedStatementInsertCounter).bind() + val boundNamedStatementCounter = cqlSession.prepare(preparedStatementInsertCounter).bind() - val result = CqlPreparedStatementUtil.bindParamByName(typeSession, boundNamedStatementCounter, DataType.Name.COUNTER, paramName) + val result = CqlPreparedStatementUtil.bindParamByName(typeSession, boundNamedStatementCounter, DataTypes.COUNTER, paramName) result shouldBe a[BoundStatement] } @@ -1374,32 +1357,33 @@ class CqlPreparedStatementUtilSpec extends BaseCassandraServerSpec { val paramName = "null_type" boundStatementNames.isSet(paramName) shouldBe false - val result = CqlPreparedStatementUtil.bindParamByName(typeSession, boundStatementNames, DataType.Name.TINYINT, paramName) + val result = CqlPreparedStatementUtil.bindParamByName(typeSession, boundStatementNames, DataTypes.TINYINT, paramName) result shouldBe a[BoundStatement] - boundStatementNames.isSet(paramName) shouldBe true - boundStatementNames.isNull(paramName) shouldBe true + result.isSet(paramName) shouldBe true + result.isNull(paramName) shouldBe true } it("should not set and unset a None value") { val paramName = "none_type" + boundStatementNames.isSet(paramName) shouldBe false - val result = CqlPreparedStatementUtil.bindParamByName(typeSession, boundStatementNames, DataType.Name.TEXT, paramName) + val result = CqlPreparedStatementUtil.bindParamByName(typeSession, boundStatementNames, DataTypes.TEXT, paramName) result shouldBe a[BoundStatement] - boundStatementNames.isSet(paramName) shouldBe false + result.isSet(paramName) shouldBe false val newSessionVars = Map(paramName -> "test") val newSession: Session = gatlingSession.setAll(newSessionVars) - val result2 = CqlPreparedStatementUtil.bindParamByName(newSession, boundStatementNames, DataType.Name.TEXT, paramName) + val result2 = CqlPreparedStatementUtil.bindParamByName(newSession, result, DataTypes.TEXT, paramName) result2 shouldBe a[BoundStatement] - boundStatementNames.isSet(paramName) shouldBe true - boundStatementNames.getString(paramName) shouldBe "test" + result2.isSet(paramName) shouldBe true + result2.getString(paramName) shouldBe "test" - val result3 = CqlPreparedStatementUtil.bindParamByName(typeSession, boundStatementNames, DataType.Name.TEXT, paramName) + val result3 = CqlPreparedStatementUtil.bindParamByName(typeSession, result2, DataTypes.TEXT, paramName) result3 shouldBe a[BoundStatement] - boundStatementNames.isSet(paramName) shouldBe false + result3.isSet(paramName) shouldBe false } it("should not set a missing session value") { @@ -1408,11 +1392,10 @@ class CqlPreparedStatementUtilSpec extends BaseCassandraServerSpec { val newSessionVars = Map("missing" -> "test") val newSession: Session = gatlingSession.setAll(newSessionVars) - val result = CqlPreparedStatementUtil.bindParamByName(newSession, boundStatementNames, DataType.Name.TEXT, field) + val result = CqlPreparedStatementUtil.bindParamByName(newSession, boundStatementNames, DataTypes.TEXT, field) result shouldBe a[BoundStatement] - boundStatementNames.isSet(field) shouldBe false + result.isSet(field) shouldBe false } - } } }