diff --git a/.travis.yml b/.travis.yml index f2cb4adc..90444c32 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,6 +1,9 @@ language: java +env: + - GRADLE_OPTS="-Xms2048m -Xmx2048m" + services: - mysql diff --git a/ktorm-core/src/main/kotlin/me/liuwj/ktorm/dsl/Aggregation.kt b/ktorm-core/src/main/kotlin/me/liuwj/ktorm/dsl/Aggregation.kt index 21c840cd..236efb3d 100644 --- a/ktorm-core/src/main/kotlin/me/liuwj/ktorm/dsl/Aggregation.kt +++ b/ktorm-core/src/main/kotlin/me/liuwj/ktorm/dsl/Aggregation.kt @@ -1,8 +1,8 @@ package me.liuwj.ktorm.dsl -import me.liuwj.ktorm.database.Database -import me.liuwj.ktorm.database.prepareStatement -import me.liuwj.ktorm.expression.* +import me.liuwj.ktorm.entity.* +import me.liuwj.ktorm.expression.AggregateExpression +import me.liuwj.ktorm.expression.AggregateType import me.liuwj.ktorm.schema.ColumnDeclaring import me.liuwj.ktorm.schema.DoubleSqlType import me.liuwj.ktorm.schema.IntSqlType @@ -40,142 +40,87 @@ fun sumDistinct(column: ColumnDeclaring): AggregateExpression return AggregateExpression(AggregateType.SUM, column.asExpression(), true, column.sqlType) } -fun count(column: ColumnDeclaring): AggregateExpression { - return AggregateExpression(AggregateType.COUNT, column.asExpression(), false, IntSqlType) +fun count(column: ColumnDeclaring<*>? = null): AggregateExpression { + return AggregateExpression(AggregateType.COUNT, column?.asExpression(), false, IntSqlType) } -fun countDistinct(column: ColumnDeclaring): AggregateExpression { - return AggregateExpression(AggregateType.COUNT, column.asExpression(), true, IntSqlType) +fun countDistinct(column: ColumnDeclaring<*>? = null): AggregateExpression { + return AggregateExpression(AggregateType.COUNT, column?.asExpression(), true, IntSqlType) } /** * 如果表中的所有行都符合指定条件,返回 true,否则 false */ -fun > T.all(block: (T) -> ScalarExpression): Boolean { - return none { !block(this) } +inline fun , T : Table> T.all(predicate: (T) -> ColumnDeclaring): Boolean { + return asSequence().all(predicate) } /** * 如果表中有数据,返回 true,否则 false */ -fun Table<*>.any(): Boolean { - return count() > 0 +fun , T : Table> T.any(): Boolean { + return asSequence().any() } /** * 如果表中存在任何一条记录满足指定条件,返回 true,否则 false */ -fun > T.any(block: (T) -> ScalarExpression): Boolean { - return count(block) > 0 +inline fun , T : Table> T.any(predicate: (T) -> ColumnDeclaring): Boolean { + return asSequence().any(predicate) } /** * 如果表中没有数据,返回 true,否则 false */ -fun Table<*>.none(): Boolean { - return count() == 0 +fun , T : Table> T.none(): Boolean { + return asSequence().none() } /** * 如果表中所有记录都不满足指定条件,返回 true,否则 false */ -fun > T.none(block: (T) -> ScalarExpression): Boolean { - return count(block) == 0 +inline fun , T : Table> T.none(predicate: (T) -> ColumnDeclaring): Boolean { + return asSequence().none(predicate) } /** * 返回表中的记录数 */ -fun Table<*>.count(): Int { - return doCount(null) +fun , T : Table> T.count(): Int { + return asSequence().count() } /** * 返回表中满足指定条件的记录数 */ -fun > T.count(block: (T) -> ScalarExpression): Int { - return doCount(block) -} - -private fun > T.doCount(block: ((T) -> ScalarExpression)?): Int { - val expression = SelectExpression( - columns = listOf( - ColumnDeclaringExpression( - expression = AggregateExpression( - type = AggregateType.COUNT, - argument = null, - isDistinct = false, - sqlType = IntSqlType - ) - ) - ), - from = this.asExpression(), - where = block?.invoke(this) - ) - - expression.prepareStatement { statement, logger -> - statement.executeQuery().use { rs -> - if (rs.next()) { - return rs.getInt(1).also { logger.debug("Count: {}", it) } - } else { - val (sql, _) = Database.global.formatExpression(expression, beautifySql = true) - throw IllegalStateException("No result return for sql: $sql") - } - } - } +inline fun , T : Table> T.count(predicate: (T) -> ColumnDeclaring): Int { + return asSequence().count(predicate) } /** * 返回表中指定字段的和,若表中没有数据,返回 null */ -fun , C : Number> T.sumBy(block: (T) -> ColumnDeclaring): C? { - return doAggregation(sum(block(this))) +inline fun , T : Table, C : Number> T.sumBy(selector: (T) -> ColumnDeclaring): C? { + return asSequence().sumBy(selector) } /** * 返回表中指定字段的最大值,若表中没有数据,返回 null */ -fun , C : Number> T.maxBy(block: (T) -> ColumnDeclaring): C? { - return doAggregation(max(block(this))) +inline fun , T : Table, C : Number> T.maxBy(selector: (T) -> ColumnDeclaring): C? { + return asSequence().maxBy(selector) } /** * 返回表中指定字段的最小值,若表中没有数据,返回 null */ -fun , C : Number> T.minBy(block: (T) -> ColumnDeclaring): C? { - return doAggregation(min(block(this))) +inline fun , T : Table, C : Number> T.minBy(selector: (T) -> ColumnDeclaring): C? { + return asSequence().minBy(selector) } /** * 返回表中指定字段的平均值,若表中没有数据,返回 null */ -fun > T.avgBy(block: (T) -> ColumnDeclaring): Double? { - return doAggregation(avg(block(this))) -} - -private fun Table<*>.doAggregation(aggregation: AggregateExpression): R? { - val expression = SelectExpression( - columns = listOf( - ColumnDeclaringExpression( - expression = aggregation.asExpression() - ) - ), - from = this.asExpression() - ) - - expression.prepareStatement { statement, logger -> - statement.executeQuery().use { rs -> - if (rs.next()) { - val result = aggregation.sqlType.getResult(rs, 1) - - if (logger.isDebugEnabled) { - logger.debug("{}: {}", aggregation.type.toString().capitalize(), result) - } - - return result - } else { - return null - } - } - } +inline fun , T : Table> T.averageBy(selector: (T) -> ColumnDeclaring): Double? { + return asSequence().averageBy(selector) } \ No newline at end of file diff --git a/ktorm-core/src/main/kotlin/me/liuwj/ktorm/dsl/CountExpression.kt b/ktorm-core/src/main/kotlin/me/liuwj/ktorm/dsl/CountExpression.kt index 68b79a95..1f9f15ea 100644 --- a/ktorm-core/src/main/kotlin/me/liuwj/ktorm/dsl/CountExpression.kt +++ b/ktorm-core/src/main/kotlin/me/liuwj/ktorm/dsl/CountExpression.kt @@ -1,21 +1,10 @@ package me.liuwj.ktorm.dsl import me.liuwj.ktorm.expression.* -import me.liuwj.ktorm.schema.IntSqlType internal fun QueryExpression.toCountExpression(): SelectExpression { val expression = OrderByRemover.visit(this) as QueryExpression - - val countColumns = listOf( - ColumnDeclaringExpression( - expression = AggregateExpression( - type = AggregateType.COUNT, - argument = null, - isDistinct = false, - sqlType = IntSqlType - ) - ) - ) + val countColumns = listOf(count().asDeclaringExpression()) if (expression is SelectExpression && expression.isSimpleSelect()) { return expression.copy(columns = countColumns, offset = null, limit = null) diff --git a/ktorm-core/src/main/kotlin/me/liuwj/ktorm/dsl/Dml.kt b/ktorm-core/src/main/kotlin/me/liuwj/ktorm/dsl/Dml.kt index b3207011..42c9bff8 100644 --- a/ktorm-core/src/main/kotlin/me/liuwj/ktorm/dsl/Dml.kt +++ b/ktorm-core/src/main/kotlin/me/liuwj/ktorm/dsl/Dml.kt @@ -17,7 +17,7 @@ fun > T.update(block: UpdateStatementBuilder.(T) -> Unit): Int { val assignments = ArrayList>() val builder = UpdateStatementBuilder(assignments).apply { block(this@update) } - val expression = AliasRemover.visit(UpdateExpression(asExpression(), assignments, builder.where)) + val expression = AliasRemover.visit(UpdateExpression(asExpression(), assignments, builder.where?.asExpression())) expression.prepareStatement { statement, logger -> return statement.executeUpdate().also { logger.debug("Effects: {}", it) } @@ -145,8 +145,8 @@ fun Query.insertTo(table: Table<*>, vararg columns: Column<*>): Int { /** * 根据条件删除表中的记录,返回受影响的记录数 */ -fun > T.delete(block: (T) -> ScalarExpression): Int { - val expression = AliasRemover.visit(DeleteExpression(asExpression(), block(this))) +fun > T.delete(block: (T) -> ColumnDeclaring): Int { + val expression = AliasRemover.visit(DeleteExpression(asExpression(), block(this).asExpression())) expression.prepareStatement { statement, logger -> return statement.executeUpdate().also { logger.debug("Effects: {}", it) } @@ -191,9 +191,9 @@ open class AssignmentsBuilder(private val assignments: MutableList>) : AssignmentsBuilder(assignments) { - internal var where: ScalarExpression? = null + internal var where: ColumnDeclaring? = null - fun where(block: () -> ScalarExpression) { + fun where(block: () -> ColumnDeclaring) { this.where = block() } } @@ -208,7 +208,7 @@ class BatchUpdateStatementBuilder>(internal val table: T) { val builder = UpdateStatementBuilder(assignments) builder.block(table) - val expr = UpdateExpression(table.asExpression(), assignments, builder.where) + val expr = UpdateExpression(table.asExpression(), assignments, builder.where?.asExpression()) val (sql, _) = Database.global.formatExpression(expr, beautifySql = true) diff --git a/ktorm-core/src/main/kotlin/me/liuwj/ktorm/dsl/Join.kt b/ktorm-core/src/main/kotlin/me/liuwj/ktorm/dsl/Join.kt index 52c3375c..84ee3c92 100644 --- a/ktorm-core/src/main/kotlin/me/liuwj/ktorm/dsl/Join.kt +++ b/ktorm-core/src/main/kotlin/me/liuwj/ktorm/dsl/Join.kt @@ -4,68 +4,69 @@ package me.liuwj.ktorm.dsl import me.liuwj.ktorm.expression.* +import me.liuwj.ktorm.schema.ColumnDeclaring import me.liuwj.ktorm.schema.Table -fun QuerySourceExpression.crossJoin(right: QuerySourceExpression, on: ScalarExpression? = null): JoinExpression { - return JoinExpression(type = JoinType.CROSS_JOIN, left = this, right = right, condition = on) +fun QuerySourceExpression.crossJoin(right: QuerySourceExpression, on: ColumnDeclaring? = null): JoinExpression { + return JoinExpression(type = JoinType.CROSS_JOIN, left = this, right = right, condition = on?.asExpression()) } -fun QuerySourceExpression.crossJoin(right: Table<*>, on: ScalarExpression? = null): JoinExpression { +fun QuerySourceExpression.crossJoin(right: Table<*>, on: ColumnDeclaring? = null): JoinExpression { return crossJoin(right.asExpression(), on) } -fun Table<*>.crossJoin(right: QuerySourceExpression, on: ScalarExpression? = null): JoinExpression { +fun Table<*>.crossJoin(right: QuerySourceExpression, on: ColumnDeclaring? = null): JoinExpression { return asExpression().crossJoin(right, on) } -fun Table<*>.crossJoin(right: Table<*>, on: ScalarExpression? = null): JoinExpression { +fun Table<*>.crossJoin(right: Table<*>, on: ColumnDeclaring? = null): JoinExpression { return crossJoin(right.asExpression(), on) } -fun QuerySourceExpression.innerJoin(right: QuerySourceExpression, on: ScalarExpression? = null): JoinExpression { - return JoinExpression(type = JoinType.INNER_JOIN, left = this, right = right, condition = on) +fun QuerySourceExpression.innerJoin(right: QuerySourceExpression, on: ColumnDeclaring? = null): JoinExpression { + return JoinExpression(type = JoinType.INNER_JOIN, left = this, right = right, condition = on?.asExpression()) } -fun QuerySourceExpression.innerJoin(right: Table<*>, on: ScalarExpression? = null): JoinExpression { +fun QuerySourceExpression.innerJoin(right: Table<*>, on: ColumnDeclaring? = null): JoinExpression { return innerJoin(right.asExpression(), on) } -fun Table<*>.innerJoin(right: QuerySourceExpression, on: ScalarExpression? = null): JoinExpression { +fun Table<*>.innerJoin(right: QuerySourceExpression, on: ColumnDeclaring? = null): JoinExpression { return asExpression().innerJoin(right, on) } -fun Table<*>.innerJoin(right: Table<*>, on: ScalarExpression? = null): JoinExpression { +fun Table<*>.innerJoin(right: Table<*>, on: ColumnDeclaring? = null): JoinExpression { return innerJoin(right.asExpression(), on) } -fun QuerySourceExpression.leftJoin(right: QuerySourceExpression, on: ScalarExpression? = null): JoinExpression { - return JoinExpression(type = JoinType.LEFT_JOIN, left = this, right = right, condition = on) +fun QuerySourceExpression.leftJoin(right: QuerySourceExpression, on: ColumnDeclaring? = null): JoinExpression { + return JoinExpression(type = JoinType.LEFT_JOIN, left = this, right = right, condition = on?.asExpression()) } -fun QuerySourceExpression.leftJoin(right: Table<*>, on: ScalarExpression? = null): JoinExpression { +fun QuerySourceExpression.leftJoin(right: Table<*>, on: ColumnDeclaring? = null): JoinExpression { return leftJoin(right.asExpression(), on) } -fun Table<*>.leftJoin(right: QuerySourceExpression, on: ScalarExpression? = null): JoinExpression { +fun Table<*>.leftJoin(right: QuerySourceExpression, on: ColumnDeclaring? = null): JoinExpression { return asExpression().leftJoin(right, on) } -fun Table<*>.leftJoin(right: Table<*>, on: ScalarExpression? = null): JoinExpression { +fun Table<*>.leftJoin(right: Table<*>, on: ColumnDeclaring? = null): JoinExpression { return leftJoin(right.asExpression(), on) } -fun QuerySourceExpression.rightJoin(right: QuerySourceExpression, on: ScalarExpression? = null): JoinExpression { - return JoinExpression(type = JoinType.RIGHT_JOIN, left = this, right = right, condition = on) +fun QuerySourceExpression.rightJoin(right: QuerySourceExpression, on: ColumnDeclaring? = null): JoinExpression { + return JoinExpression(type = JoinType.RIGHT_JOIN, left = this, right = right, condition = on?.asExpression()) } -fun QuerySourceExpression.rightJoin(right: Table<*>, on: ScalarExpression? = null): JoinExpression { +fun QuerySourceExpression.rightJoin(right: Table<*>, on: ColumnDeclaring? = null): JoinExpression { return rightJoin(right.asExpression(), on) } -fun Table<*>.rightJoin(right: QuerySourceExpression, on: ScalarExpression? = null): JoinExpression { +fun Table<*>.rightJoin(right: QuerySourceExpression, on: ColumnDeclaring? = null): JoinExpression { return asExpression().rightJoin(right, on) } -fun Table<*>.rightJoin(right: Table<*>, on: ScalarExpression? = null): JoinExpression { +fun Table<*>.rightJoin(right: Table<*>, on: ColumnDeclaring? = null): JoinExpression { return rightJoin(right.asExpression(), on) } diff --git a/ktorm-core/src/main/kotlin/me/liuwj/ktorm/dsl/Query.kt b/ktorm-core/src/main/kotlin/me/liuwj/ktorm/dsl/Query.kt index 2227796f..d9dd76ce 100644 --- a/ktorm-core/src/main/kotlin/me/liuwj/ktorm/dsl/Query.kt +++ b/ktorm-core/src/main/kotlin/me/liuwj/ktorm/dsl/Query.kt @@ -123,17 +123,17 @@ fun Table<*>.selectDistinct(vararg columns: ColumnDeclaring<*>): Query { return asExpression().selectDistinct(columns.asList()) } -inline fun Query.where(block: () -> ScalarExpression): Query { +inline fun Query.where(block: () -> ColumnDeclaring): Query { return this.copy( expression = when (expression) { - is SelectExpression -> expression.copy(where = block()) + is SelectExpression -> expression.copy(where = block().asExpression()) is UnionExpression -> throw IllegalStateException("Where clause is not supported in a union expression.") } ) } -inline fun Query.whereWithConditions(block: (MutableList>) -> Unit): Query { - val conditions = ArrayList>().apply(block) +inline fun Query.whereWithConditions(block: (MutableList>) -> Unit): Query { + val conditions = ArrayList>().apply(block) if (conditions.isEmpty()) { return this @@ -142,8 +142,8 @@ inline fun Query.whereWithConditions(block: (MutableList>) -> Unit): Query { - val conditions = ArrayList>().apply(block) +inline fun Query.whereWithOrConditions(block: (MutableList>) -> Unit): Query { + val conditions = ArrayList>().apply(block) if (conditions.isEmpty()) { return this @@ -152,7 +152,7 @@ inline fun Query.whereWithOrConditions(block: (MutableList>.combineConditions(): ScalarExpression { +fun Iterable>.combineConditions(): ColumnDeclaring { if (this.any()) { return this.reduce { a, b -> a and b } } else { @@ -169,10 +169,10 @@ fun Query.groupBy(vararg columns: ColumnDeclaring<*>): Query { ) } -inline fun Query.having(block: () -> ScalarExpression): Query { +inline fun Query.having(block: () -> ColumnDeclaring): Query { return this.copy( expression = when (expression) { - is SelectExpression -> expression.copy(having = block()) + is SelectExpression -> expression.copy(having = block().asExpression()) is UnionExpression -> throw IllegalStateException("Having clause is not supported in a union expression.") } ) diff --git a/ktorm-core/src/main/kotlin/me/liuwj/ktorm/entity/Entity.kt b/ktorm-core/src/main/kotlin/me/liuwj/ktorm/entity/Entity.kt index 0bab10ad..9f8af39a 100644 --- a/ktorm-core/src/main/kotlin/me/liuwj/ktorm/entity/Entity.kt +++ b/ktorm-core/src/main/kotlin/me/liuwj/ktorm/entity/Entity.kt @@ -56,6 +56,11 @@ interface Entity> : Serializable { */ operator fun set(name: String, value: Any?) + /** + * 复制一个当前实体对象的拷贝,返回的对象具有与当前对象完全相同的属性值和状态 + */ + fun copy(): E + companion object { /** diff --git a/ktorm-core/src/main/kotlin/me/liuwj/ktorm/entity/EntityDml.kt b/ktorm-core/src/main/kotlin/me/liuwj/ktorm/entity/EntityDml.kt index 154f86e9..826784b3 100644 --- a/ktorm-core/src/main/kotlin/me/liuwj/ktorm/entity/EntityDml.kt +++ b/ktorm-core/src/main/kotlin/me/liuwj/ktorm/entity/EntityDml.kt @@ -10,6 +10,8 @@ import me.liuwj.ktorm.schema.* */ @Suppress("UNCHECKED_CAST") fun > Table.add(entity: E): Int { + entity.implementation.checkUnexpectedDiscarding(this) + val assignments = findInsertColumns(entity).takeIf { it.isNotEmpty() } ?: return 0 val expression = AliasRemover.visit( @@ -66,6 +68,8 @@ private fun Table<*>.findInsertColumns(entity: Entity<*>): Map, Any?> @Suppress("UNCHECKED_CAST") internal fun EntityImplementation.doFlushChanges(): Int { val fromTable = this.fromTable?.takeIf { this.parent == null } ?: error("The entity is not associated with any table yet.") + checkUnexpectedDiscarding(fromTable) + val primaryKey = fromTable.primaryKey ?: error("Table ${fromTable.tableName} doesn't have a primary key.") val assignments = findChangedColumns(fromTable).takeIf { it.isNotEmpty() } ?: return 0 @@ -111,25 +115,17 @@ private fun EntityImplementation.findChangedColumns(fromTable: Table<*>): Map) { curr = curr.implementation } check(curr is EntityImplementation?) - val changed = if (curr == null) false else prop.name in curr.changedProperties - - if (changed && i > 0) { - check(curr != null) - - if (curr.fromTable != null && curr.getRoot() != this) { - val propPath = binding.properties.subList(0, i + 1).joinToString(separator = ".", prefix = "this.") { it.name } - throw IllegalStateException("$propPath may be unexpectedly discarded after flushChanges, please save it to database first.") - } + if (curr != null && prop.name in curr.changedProperties) { + anyChanged = true } - anyChanged = anyChanged || changed curr = curr?.getProperty(prop.name) } @@ -143,15 +139,6 @@ private fun EntityImplementation.findChangedColumns(fromTable: Table<*>): Map) { + for (column in fromTable.columns) { + val binding = column.binding?.takeIf { column is SimpleColumn } ?: continue + + if (binding is NestedBinding) { + var curr: Any? = this + + for ((i, prop) in binding.properties.withIndex()) { + if (curr == null) { + break + } + if (curr is Entity<*>) { + curr = curr.implementation + } + + check(curr is EntityImplementation) + + if (i > 0 && prop.name in curr.changedProperties && curr.fromTable != null && curr.getRoot() != this) { + val propPath = binding.properties.subList(0, i + 1).joinToString(separator = ".", prefix = "this.") { it.name } + throw IllegalStateException("$propPath may be unexpectedly discarded, please save it to database first.") + } + + curr = curr.getProperty(prop.name) + } + } + } +} + +private tailrec fun EntityImplementation.getRoot(): EntityImplementation { + val parent = this.parent + if (parent == null) { + return this + } else { + return parent.getRoot() + } +} + +internal fun Entity<*>.clearChangesRecursively() { + implementation.changedProperties.clear() + + for ((_, value) in properties) { + if (value is Entity<*>) { + value.clearChangesRecursively() + } + } +} + @Suppress("UNCHECKED_CAST") internal fun EntityImplementation.doDelete(): Int { val fromTable = this.fromTable?.takeIf { this.parent == null } ?: error("The entity is not associated with any table yet.") diff --git a/ktorm-core/src/main/kotlin/me/liuwj/ktorm/entity/EntityExtensions.kt b/ktorm-core/src/main/kotlin/me/liuwj/ktorm/entity/EntityExtensions.kt index 787d68ac..8c9a7250 100644 --- a/ktorm-core/src/main/kotlin/me/liuwj/ktorm/entity/EntityExtensions.kt +++ b/ktorm-core/src/main/kotlin/me/liuwj/ktorm/entity/EntityExtensions.kt @@ -30,12 +30,12 @@ internal fun EntityImplementation.getColumnValue(column: Column<*>): Any? { } } -internal fun EntityImplementation.setPrimaryKeyValue(fromTable: Table<*>, value: Any?) { +internal fun EntityImplementation.setPrimaryKeyValue(fromTable: Table<*>, value: Any?, forceSet: Boolean = false) { val primaryKey = fromTable.primaryKey ?: error("Table ${fromTable.tableName} doesn't have a primary key.") - setColumnValue(primaryKey, value) + setColumnValue(primaryKey, value, forceSet) } -internal fun EntityImplementation.setColumnValue(column: Column<*>, value: Any?) { +internal fun EntityImplementation.setColumnValue(column: Column<*>, value: Any?, forceSet: Boolean = false) { val binding = column.binding ?: error("Column $column has no bindings to any entity field.") when (binding) { @@ -43,10 +43,10 @@ internal fun EntityImplementation.setColumnValue(column: Column<*>, value: Any?) var child = this.getProperty(binding.onProperty.name) as Entity<*>? if (child == null) { child = Entity.create(binding.onProperty.returnType.classifier as KClass<*>, fromTable = binding.referenceTable) - this.setProperty(binding.onProperty.name, child) + this.setProperty(binding.onProperty.name, child, forceSet) } - child.implementation.setPrimaryKeyValue(binding.referenceTable, value) + child.implementation.setPrimaryKeyValue(binding.referenceTable, value, forceSet) } is NestedBinding -> { var curr: EntityImplementation = this @@ -55,51 +55,14 @@ internal fun EntityImplementation.setColumnValue(column: Column<*>, value: Any?) var child = curr.getProperty(prop.name) as Entity<*>? if (child == null) { child = Entity.create(prop.returnType.classifier as KClass<*>, parent = curr) - curr.setProperty(prop.name, child) + curr.setProperty(prop.name, child, forceSet) } curr = child.implementation } } - curr.setProperty(binding.properties.last().name, value) - } - } -} - -internal fun EntityImplementation.forceSetPrimaryKeyValue(fromTable: Table<*>, value: Any?) { - val primaryKey = fromTable.primaryKey ?: error("Table ${fromTable.tableName} doesn't have a primary key.") - forceSetColumnValue(primaryKey, value) -} - -internal fun EntityImplementation.forceSetColumnValue(column: Column<*>, value: Any?) { - val binding = column.binding ?: error("Column $column has no bindings to any entity field.") - - when (binding) { - is ReferenceBinding -> { - var child = this.getProperty(binding.onProperty.name) as Entity<*>? - if (child == null) { - child = Entity.create(binding.onProperty.returnType.classifier as KClass<*>, fromTable = binding.referenceTable) - this.setProperty(binding.onProperty.name, child, forceSet = true) - } - - child.implementation.forceSetPrimaryKeyValue(binding.referenceTable, value) - } - is NestedBinding -> { - var curr: EntityImplementation = this - for ((i, prop) in binding.properties.withIndex()) { - if (i != binding.properties.lastIndex) { - var child = curr.getProperty(prop.name) as Entity<*>? - if (child == null) { - child = Entity.create(prop.returnType.classifier as KClass<*>, parent = curr) - curr.setProperty(prop.name, child, forceSet = true) - } - - curr = child.implementation - } - } - - curr.setProperty(binding.properties.last().name, value, forceSet = true) + curr.setProperty(binding.properties.last().name, value, forceSet) } } } diff --git a/ktorm-core/src/main/kotlin/me/liuwj/ktorm/entity/EntityFinding.kt b/ktorm-core/src/main/kotlin/me/liuwj/ktorm/entity/EntityFinding.kt index 91563cfb..4505b5c9 100644 --- a/ktorm-core/src/main/kotlin/me/liuwj/ktorm/entity/EntityFinding.kt +++ b/ktorm-core/src/main/kotlin/me/liuwj/ktorm/entity/EntityFinding.kt @@ -4,8 +4,8 @@ import me.liuwj.ktorm.dsl.* import me.liuwj.ktorm.expression.BinaryExpression import me.liuwj.ktorm.expression.BinaryExpressionType import me.liuwj.ktorm.expression.QuerySourceExpression -import me.liuwj.ktorm.expression.ScalarExpression import me.liuwj.ktorm.schema.* +import kotlin.reflect.KClass /** * 根据 ID 批量获取实体对象,会自动 left join 所有的引用表 @@ -40,8 +40,8 @@ fun > Table.findById(id: Any): E? { /** * 根据指定条件获取对象,会自动 left join 所有的引用表 */ -inline fun , T : Table> T.findOne(block: (T) -> ScalarExpression): E? { - val list = findList(block) +inline fun , T : Table> T.findOne(predicate: (T) -> ColumnDeclaring): E? { + val list = findList(predicate) when (list.size) { 0 -> return null 1 -> return list[0] @@ -53,6 +53,7 @@ inline fun , T : Table> T.findOne(block: (T) -> ScalarExpressio * 获取表中的所有记录,会自动 left join 所有的引用表 */ fun > Table.findAll(): List { + // return this.asSequence().toList() return this .joinReferencesAndSelect() .map { row -> this.createEntity(row) } @@ -61,10 +62,11 @@ fun > Table.findAll(): List { /** * 根据指定条件获取对象列表,会自动 left join 所有的引用表 */ -inline fun , T : Table> T.findList(block: (T) -> ScalarExpression): List { +inline fun , T : Table> T.findList(predicate: (T) -> ColumnDeclaring): List { + // return this.asSequence().filter(predicate).toList() return this .joinReferencesAndSelect() - .where { block(this) } + .where { predicate(this) } .map { row -> this.createEntity(row) } } @@ -111,45 +113,58 @@ private infix fun ColumnDeclaring<*>.eq(column: ColumnDeclaring<*>): BinaryExpre */ @Suppress("UNCHECKED_CAST") fun > Table.createEntity(row: QueryRowSet): E { - return doCreateEntity(row) as E + val entity = doCreateEntity(row, skipReferences = false) as E + return entity.apply { clearChangesRecursively() } } -private fun Table<*>.doCreateEntity(row: QueryRowSet, foreignKey: Column<*>? = null): Entity<*> { +/** + * 从结果集中创建实体对象,不会自动级联创建引用表的实体对象 + */ +@Suppress("UNCHECKED_CAST") +fun > Table.createEntityWithoutReferences(row: QueryRowSet): E { + val entity = doCreateEntity(row, skipReferences = true) as E + return entity.apply { clearChangesRecursively() } +} + +private fun Table<*>.doCreateEntity(row: QueryRowSet, skipReferences: Boolean = false): Entity<*> { val entityClass = this.entityClass ?: error("No entity class configured for table: $tableName") val entity = Entity.create(entityClass, fromTable = this) for (column in columns) { try { - row.retrieveColumn(column, intoEntity = entity) + row.retrieveColumn(column, intoEntity = entity, skipReferences = skipReferences) } catch (e: Throwable) { throw IllegalStateException("Error occur while retrieving column: $column, binding: ${column.binding}", e) } } - val foreignKeyValue = if (foreignKey != null && row.hasColumn(foreignKey)) row[foreignKey] else null - if (foreignKeyValue != null) { - entity.implementation.forceSetPrimaryKeyValue(this, foreignKeyValue) - } - - return entity.apply { discardChanges() } + return entity } -private fun QueryRowSet.retrieveColumn(column: Column<*>, intoEntity: Entity<*>) { +private fun QueryRowSet.retrieveColumn(column: Column<*>, intoEntity: Entity<*>, skipReferences: Boolean) { + val columnValue = (if (this.hasColumn(column)) this[column] else null) ?: return + val binding = column.binding ?: return when (binding) { is ReferenceBinding -> { val rightTable = binding.referenceTable val primaryKey = rightTable.primaryKey ?: error("Table ${rightTable.tableName} doesn't have a primary key.") - if (this.hasColumn(primaryKey) && this[primaryKey] != null) { - intoEntity[binding.onProperty.name] = rightTable.doCreateEntity(this, foreignKey = column) + when { + skipReferences -> { + val child = Entity.create(binding.onProperty.returnType.classifier as KClass<*>, fromTable = rightTable) + child.implementation.setColumnValue(primaryKey, columnValue) + intoEntity[binding.onProperty.name] = child + } + this.hasColumn(primaryKey) && this[primaryKey] != null -> { + val child = rightTable.doCreateEntity(this) + child.implementation.setColumnValue(primaryKey, columnValue, forceSet = true) + intoEntity[binding.onProperty.name] = child + } } } is NestedBinding -> { - val columnValue = if (this.hasColumn(column)) this[column] else null - if (columnValue != null) { - intoEntity.implementation.setColumnValue(column, columnValue) - } + intoEntity.implementation.setColumnValue(column, columnValue) } } } diff --git a/ktorm-core/src/main/kotlin/me/liuwj/ktorm/entity/EntityGrouping.kt b/ktorm-core/src/main/kotlin/me/liuwj/ktorm/entity/EntityGrouping.kt new file mode 100644 index 00000000..192b29a1 --- /dev/null +++ b/ktorm-core/src/main/kotlin/me/liuwj/ktorm/entity/EntityGrouping.kt @@ -0,0 +1,180 @@ +package me.liuwj.ktorm.entity + +import me.liuwj.ktorm.dsl.* +import me.liuwj.ktorm.expression.ColumnDeclaringExpression +import me.liuwj.ktorm.schema.ColumnDeclaring +import me.liuwj.ktorm.schema.Table + +data class EntityGrouping, T : Table, K : Any>( + val sequence: EntitySequence, + val keySelector: (T) -> ColumnDeclaring +) { + fun asKotlinGrouping() = object : Grouping { + private val allEntities = LinkedHashMap() + + init { + val keyColumn = keySelector(sequence.sourceTable) + val expr = sequence.expression.copy(columns = sequence.expression.columns + keyColumn.asDeclaringExpression()) + + for (row in Query(expr)) { + val entity = sequence.sourceTable.createEntity(row) + val groupKey = keyColumn.sqlType.getResult(row, expr.columns.size) + allEntities[entity] = groupKey + } + } + + override fun sourceIterator(): Iterator { + return allEntities.keys.iterator() + } + + override fun keyOf(element: E): K? { + return allEntities[element] + } + } +} + +inline fun , T : Table, K : Any, C : Any> EntityGrouping.aggregate( + aggregationSelector: (T) -> ColumnDeclaring +): MutableMap { + return aggregateTo(LinkedHashMap(), aggregationSelector) +} + +inline fun , T : Table, K : Any, C : Any, M : MutableMap> EntityGrouping.aggregateTo( + destination: M, + aggregationSelector: (T) -> ColumnDeclaring +): M { + val keyColumn = keySelector(sequence.sourceTable).asExpression() + val aggregation = aggregationSelector(sequence.sourceTable) + + val expr = sequence.expression.copy( + columns = listOf(keyColumn, aggregation.asExpression()).map { ColumnDeclaringExpression(it) }, + groupBy = listOf(keyColumn) + ) + + for (row in Query(expr)) { + val key = keyColumn.sqlType.getResult(row, 1) + val value = aggregation.sqlType.getResult(row, 2) + destination[key] = value + } + + return destination +} + +inline fun , K : Any, R> EntityGrouping.aggregate( + operation: (key: K?, accumulator: R?, element: E, first: Boolean) -> R +): Map { + return asKotlinGrouping().aggregate(operation) +} + +inline fun , K : Any, R, M : MutableMap> EntityGrouping.aggregateTo( + destination: M, + operation: (key: K?, accumulator: R?, element: E, first: Boolean) -> R +): M { + return asKotlinGrouping().aggregateTo(destination, operation) +} + +inline fun , K : Any, R> EntityGrouping.fold( + initialValueSelector: (key: K?, element: E) -> R, + operation: (key: K?, accumulator: R, element: E) -> R +): Map { + return asKotlinGrouping().fold(initialValueSelector, operation) +} + +inline fun , K : Any, R, M : MutableMap> EntityGrouping.foldTo( + destination: M, + initialValueSelector: (key: K?, element: E) -> R, + operation: (key: K?, accumulator: R, element: E) -> R +): M { + return asKotlinGrouping().foldTo(destination, initialValueSelector, operation) +} + +inline fun , K : Any, R> EntityGrouping.fold( + initialValue: R, + operation: (accumulator: R, element: E) -> R +): Map { + return asKotlinGrouping().fold(initialValue, operation) +} + +inline fun , K : Any, R, M : MutableMap> EntityGrouping.foldTo( + destination: M, + initialValue: R, + operation: (accumulator: R, element: E) -> R +): M { + return asKotlinGrouping().foldTo(destination, initialValue, operation) +} + +inline fun , K : Any> EntityGrouping.reduce( + operation: (key: K?, accumulator: E, element: E) -> E +): Map { + return asKotlinGrouping().reduce(operation) +} + +inline fun , K : Any, M : MutableMap> EntityGrouping.reduceTo( + destination: M, + operation: (key: K?, accumulator: E, element: E) -> E +): M { + return asKotlinGrouping().reduceTo(destination, operation) +} + +fun , T : Table, K : Any> EntityGrouping.eachCount(): Map { + return eachCountTo(LinkedHashMap()) +} + +@Suppress("RedundantLambdaArrow", "UNCHECKED_CAST") +fun , T : Table, K : Any, M : MutableMap> EntityGrouping.eachCountTo( + destination: M +): M { + return aggregateTo(destination as MutableMap) { _ -> count() } as M +} + +inline fun , T : Table, K : Any, C : Number> EntityGrouping.eachSumBy( + columnSelector: (T) -> ColumnDeclaring +): Map { + return eachSumByTo(LinkedHashMap(), columnSelector) +} + +inline fun , T : Table, K : Any, C : Number, M : MutableMap> EntityGrouping.eachSumByTo( + destination: M, + columnSelector: (T) -> ColumnDeclaring +): M { + return aggregateTo(destination) { sum(columnSelector(it)) } +} + +inline fun , T : Table, K : Any, C : Number> EntityGrouping.eachMaxBy( + columnSelector: (T) -> ColumnDeclaring +): Map { + return eachMaxByTo(LinkedHashMap(), columnSelector) +} + +inline fun , T : Table, K : Any, C : Number, M : MutableMap> EntityGrouping.eachMaxByTo( + destination: M, + columnSelector: (T) -> ColumnDeclaring +): M { + return aggregateTo(destination) { max(columnSelector(it)) } +} + +inline fun , T : Table, K : Any, C : Number> EntityGrouping.eachMinBy( + columnSelector: (T) -> ColumnDeclaring +): Map { + return eachMinByTo(LinkedHashMap(), columnSelector) +} + +inline fun , T : Table, K : Any, C : Number, M : MutableMap> EntityGrouping.eachMinByTo( + destination: M, + columnSelector: (T) -> ColumnDeclaring +): M { + return aggregateTo(destination) { min(columnSelector(it)) } +} + +inline fun , T : Table, K : Any> EntityGrouping.eachAverageBy( + columnSelector: (T) -> ColumnDeclaring +): Map { + return eachAverageByTo(LinkedHashMap(), columnSelector) +} + +inline fun , T : Table, K : Any, M : MutableMap> EntityGrouping.eachAverageByTo( + destination: M, + columnSelector: (T) -> ColumnDeclaring +): M { + return aggregateTo(destination) { avg(columnSelector(it)) } +} \ No newline at end of file diff --git a/ktorm-core/src/main/kotlin/me/liuwj/ktorm/entity/EntityImplementation.kt b/ktorm-core/src/main/kotlin/me/liuwj/ktorm/entity/EntityImplementation.kt index a2bfda0c..10b7a73e 100644 --- a/ktorm-core/src/main/kotlin/me/liuwj/ktorm/entity/EntityImplementation.kt +++ b/ktorm-core/src/main/kotlin/me/liuwj/ktorm/entity/EntityImplementation.kt @@ -52,6 +52,7 @@ internal class EntityImplementation( "delete" -> this.doDelete() "get" -> this.getProperty(args!![0] as String) "set" -> this.setProperty(args!![0] as String, args[1]) + "copy" -> this.copy() else -> throw IllegalStateException("Unrecognized method: $method") } } @@ -126,6 +127,13 @@ internal class EntityImplementation( changedProperties.add(name) } + private fun copy(): Entity<*> { + val entity = Entity.create(entityClass, parent, fromTable) + entity.implementation.values.putAll(values) + entity.implementation.changedProperties.addAll(changedProperties) + return entity + } + private fun writeObject(output: ObjectOutputStream) { output.writeUTF(entityClass.jvmName) output.writeObject(values) diff --git a/ktorm-core/src/main/kotlin/me/liuwj/ktorm/entity/EntitySequence.kt b/ktorm-core/src/main/kotlin/me/liuwj/ktorm/entity/EntitySequence.kt new file mode 100644 index 00000000..5fbe66f6 --- /dev/null +++ b/ktorm-core/src/main/kotlin/me/liuwj/ktorm/entity/EntitySequence.kt @@ -0,0 +1,461 @@ +package me.liuwj.ktorm.entity + +import me.liuwj.ktorm.database.Database +import me.liuwj.ktorm.dsl.* +import me.liuwj.ktorm.expression.OrderByExpression +import me.liuwj.ktorm.expression.SelectExpression +import me.liuwj.ktorm.schema.Column +import me.liuwj.ktorm.schema.ColumnDeclaring +import me.liuwj.ktorm.schema.Table +import java.util.* +import kotlin.collections.ArrayList +import kotlin.math.min + +data class EntitySequence, T : Table>(val sourceTable: T, val expression: SelectExpression) { + + val query = Query(expression) + + val sql get() = query.sql + + val rowSet get() = query.rowSet + + val totalRecords get() = query.totalRecords + + fun asKotlinSequence() = Sequence { iterator() } + + operator fun iterator() = object : Iterator { + private val queryIterator = query.iterator() + + override fun hasNext(): Boolean { + return queryIterator.hasNext() + } + + override fun next(): E { + return sourceTable.createEntity(queryIterator.next()) + } + } +} + +fun , T : Table> T.asSequence(): EntitySequence { + val query = this.joinReferencesAndSelect() + return EntitySequence(this, query.expression as SelectExpression) +} + +fun , C : MutableCollection> EntitySequence.toCollection(destination: C): C { + return asKotlinSequence().toCollection(destination) +} + +fun > EntitySequence.toList(): List { + return asKotlinSequence().toList() +} + +fun > EntitySequence.toMutableList(): MutableList { + return asKotlinSequence().toMutableList() +} + +fun > EntitySequence.toSet(): Set { + return asKotlinSequence().toSet() +} + +fun > EntitySequence.toMutableSet(): MutableSet { + return asKotlinSequence().toMutableSet() +} + +fun > EntitySequence.toHashSet(): HashSet { + return asKotlinSequence().toHashSet() +} + +fun EntitySequence.toSortedSet(): SortedSet where E : Entity, E : Comparable { + return asKotlinSequence().toSortedSet() +} + +fun EntitySequence.toSortedSet( + comparator: Comparator +): SortedSet where E : Entity, E : Comparable { + return asKotlinSequence().toSortedSet(comparator) +} + +inline fun , T : Table> EntitySequence.filterColumns( + selector: (T) -> List> +): EntitySequence { + val columns = selector(sourceTable) + if (columns.isEmpty()) { + return this + } else { + return this.copy(expression = expression.copy(columns = columns.map { it.asDeclaringExpression() })) + } +} + +inline fun , T : Table> EntitySequence.filter( + predicate: (T) -> ColumnDeclaring +): EntitySequence { + if (expression.where == null) { + return this.copy(expression = expression.copy(where = predicate(sourceTable).asExpression())) + } else { + return this.copy(expression = expression.copy(where = expression.where and predicate(sourceTable))) + } +} + +inline fun , T : Table> EntitySequence.filterNot( + predicate: (T) -> ColumnDeclaring +): EntitySequence { + return filter { !predicate(it) } +} + +inline fun , T : Table, C : MutableCollection> EntitySequence.filterTo( + destination: C, + predicate: (T) -> ColumnDeclaring +): C { + return filter(predicate).toCollection(destination) +} + +inline fun , T : Table, C : MutableCollection> EntitySequence.filterNotTo( + destination: C, + predicate: (T) -> ColumnDeclaring +): C { + return filterNot(predicate).toCollection(destination) +} + +inline fun , R> EntitySequence.map(transform: (E) -> R): List { + return mapTo(ArrayList(), transform) +} + +inline fun , R, C : MutableCollection> EntitySequence.mapTo( + destination: C, + transform: (E) -> R +): C { + for (item in this) destination += transform(item) + return destination +} + +inline fun , R> EntitySequence.mapIndexed(transform: (index: Int, E) -> R): List { + return mapIndexedTo(ArrayList(), transform) +} + +inline fun , R, C : MutableCollection> EntitySequence.mapIndexedTo( + destination: C, + transform: (index: Int, E) -> R +): C { + var index = 0 + return mapTo(destination) { transform(index++, it) } +} + +inline fun , T : Table> EntitySequence.sorted( + selector: (T) -> List +): EntitySequence { + return this.copy(expression = expression.copy(orderBy = selector(sourceTable))) +} + +inline fun , T : Table> EntitySequence.sortedBy( + selector: (T) -> ColumnDeclaring<*> +): EntitySequence { + return sorted { listOf(selector(it).asc()) } +} + +inline fun , T : Table> EntitySequence.sortedByDescending( + selector: (T) -> ColumnDeclaring<*> +): EntitySequence { + return sorted { listOf(selector(it).desc()) } +} + +fun , T : Table> EntitySequence.drop(n: Int): EntitySequence { + if (n == 0) { + return this + } else { + val offset = expression.offset ?: 0 + return this.copy(expression = expression.copy(offset = offset + n)) + } +} + +fun , T : Table> EntitySequence.take(n: Int): EntitySequence { + val limit = expression.limit ?: Int.MAX_VALUE + return this.copy(expression = expression.copy(limit = min(limit, n))) +} + +inline fun , T : Table, C : Any> EntitySequence.aggregate( + aggregationSelector: (T) -> ColumnDeclaring +): C? { + val aggregation = aggregationSelector(sourceTable) + + val expr = expression.copy( + columns = listOf(aggregation.asDeclaringExpression()) + ) + + val rowSet = Query(expr).rowSet + + if (rowSet.size() == 1) { + assert(rowSet.next()) + return aggregation.sqlType.getResult(rowSet, 1) + } else { + val (sql, _) = Database.global.formatExpression(expr, beautifySql = true) + throw IllegalStateException("Expected 1 result but ${rowSet.size()} returned from sql: \n\n$sql") + } +} + +fun , T : Table> EntitySequence.count(): Int { + return aggregate { me.liuwj.ktorm.dsl.count() } ?: error("Count expression returns null, which never happens.") +} + +inline fun , T : Table> EntitySequence.count( + predicate: (T) -> ColumnDeclaring +): Int { + return filter(predicate).count() +} + +fun , T : Table> EntitySequence.none(): Boolean { + return count() == 0 +} + +inline fun , T : Table> EntitySequence.none( + predicate: (T) -> ColumnDeclaring +): Boolean { + return count(predicate) == 0 +} + +fun , T : Table> EntitySequence.any(): Boolean { + return count() > 0 +} + +inline fun , T : Table> EntitySequence.any( + predicate: (T) -> ColumnDeclaring +): Boolean { + return count(predicate) > 0 +} + +inline fun , T : Table> EntitySequence.all( + predicate: (T) -> ColumnDeclaring +): Boolean { + return none { !predicate(it) } +} + +inline fun , T : Table, C : Number> EntitySequence.sumBy( + selector: (T) -> ColumnDeclaring +): C? { + return aggregate { sum(selector(it)) } +} + +inline fun , T : Table, C : Number> EntitySequence.maxBy( + selector: (T) -> ColumnDeclaring +): C? { + return aggregate { max(selector(it)) } +} + +inline fun , T : Table, C : Number> EntitySequence.minBy( + selector: (T) -> ColumnDeclaring +): C? { + return aggregate { min(selector(it)) } +} + +inline fun , T : Table> EntitySequence.averageBy( + selector: (T) -> ColumnDeclaring +): Double? { + return aggregate { avg(selector(it)) } +} + +inline fun , K, V> EntitySequence.associate( + transform: (E) -> Pair +): Map { + return asKotlinSequence().associate(transform) +} + +inline fun , K> EntitySequence.associateBy( + keySelector: (E) -> K +): Map { + return asKotlinSequence().associateBy(keySelector) +} + +inline fun , K, V> EntitySequence.associateBy( + keySelector: (E) -> K, + valueTransform: (E) -> V +): Map { + return asKotlinSequence().associateBy(keySelector, valueTransform) +} + +inline fun , V> EntitySequence.associateWith( + valueTransform: (K) -> V +): Map { + return asKotlinSequence().associateWith(valueTransform) +} + +inline fun , K, V, M : MutableMap> EntitySequence.associateTo( + destination: M, + transform: (E) -> Pair +): M { + return asKotlinSequence().associateTo(destination, transform) +} + +inline fun , K, M : MutableMap> EntitySequence.associateByTo( + destination: M, + keySelector: (E) -> K +): M { + return asKotlinSequence().associateByTo(destination, keySelector) +} + +inline fun , K, V, M : MutableMap> EntitySequence.associateByTo( + destination: M, + keySelector: (E) -> K, + valueTransform: (E) -> V +): M { + return asKotlinSequence().associateByTo(destination, keySelector, valueTransform) +} + +inline fun , V, M : MutableMap> EntitySequence.associateWithTo( + destination: M, + valueTransform: (K) -> V +): M { + return asKotlinSequence().associateWithTo(destination, valueTransform) +} + +fun , T : Table> EntitySequence.elementAtOrNull(index: Int): E? { + try { + return drop(index).take(1).asKotlinSequence().firstOrNull() + } catch (e: UnsupportedOperationException) { + return asKotlinSequence().elementAtOrNull(index) + } +} + +inline fun , T : Table> EntitySequence.elementAtOrElse(index: Int, defaultValue: (Int) -> E): E { + return elementAtOrNull(index) ?: defaultValue(index) +} + +fun , T : Table> EntitySequence.elementAt(index: Int): E { + return elementAtOrNull(index) ?: throw IndexOutOfBoundsException("Sequence doesn't contain element at index $index.") +} + +fun , T : Table> EntitySequence.firstOrNull(): E? { + return elementAtOrNull(0) +} + +inline fun , T : Table> EntitySequence.firstOrNull(predicate: (T) -> ColumnDeclaring): E? { + return filter(predicate).elementAtOrNull(0) +} + +fun , T : Table> EntitySequence.first(): E { + return elementAt(0) +} + +inline fun , T : Table> EntitySequence.first(predicate: (T) -> ColumnDeclaring): E { + return filter(predicate).elementAt(0) +} + +fun > EntitySequence.lastOrNull(): E? { + return asKotlinSequence().lastOrNull() +} + +inline fun , T : Table> EntitySequence.lastOrNull(predicate: (T) -> ColumnDeclaring): E? { + return filter(predicate).lastOrNull() +} + +fun > EntitySequence.last(): E { + return lastOrNull() ?: throw NoSuchElementException("Sequence is empty.") +} + +inline fun , T : Table> EntitySequence.last(predicate: (T) -> ColumnDeclaring): E { + return filter(predicate).last() +} + +inline fun , T : Table> EntitySequence.find(predicate: (T) -> ColumnDeclaring): E? { + return firstOrNull(predicate) +} + +inline fun , T : Table> EntitySequence.findLast(predicate: (T) -> ColumnDeclaring): E? { + return lastOrNull(predicate) +} + +fun , T : Table> EntitySequence.singleOrNull(): E? { + return asKotlinSequence().singleOrNull() +} + +inline fun , T : Table> EntitySequence.singleOrNull(predicate: (T) -> ColumnDeclaring): E? { + return filter(predicate).singleOrNull() +} + +fun , T : Table> EntitySequence.single(): E { + return asKotlinSequence().single() +} + +inline fun , T : Table> EntitySequence.single(predicate: (T) -> ColumnDeclaring): E { + return filter(predicate).single() +} + +inline fun , R> EntitySequence.fold(initial: R, operation: (acc: R, E) -> R): R { + return asKotlinSequence().fold(initial, operation) +} + +inline fun , R> EntitySequence.foldIndexed(initial: R, operation: (index: Int, acc: R, E) -> R): R { + return asKotlinSequence().foldIndexed(initial, operation) +} + +inline fun > EntitySequence.reduce(operation: (acc: E, E) -> E): E { + return asKotlinSequence().reduce(operation) +} + +inline fun > EntitySequence.reduceIndexed(operation: (index: Int, acc: E, E) -> E): E { + return asKotlinSequence().reduceIndexed(operation) +} + +inline fun > EntitySequence.forEach(action: (E) -> Unit) { + for (item in this) action(item) +} + +inline fun > EntitySequence.forEachIndexed(action: (index: Int, E) -> Unit) { + var index = 0 + for (item in this) action(index++, item) +} + +inline fun , K> EntitySequence.groupBy( + keySelector: (E) -> K +): Map> { + return asKotlinSequence().groupBy(keySelector) +} + +inline fun , K, V> EntitySequence.groupBy( + keySelector: (E) -> K, + valueTransform: (E) -> V +): Map> { + return asKotlinSequence().groupBy(keySelector, valueTransform) +} + +inline fun , K, M : MutableMap>> EntitySequence.groupByTo( + destination: M, + keySelector: (E) -> K +): M { + return asKotlinSequence().groupByTo(destination, keySelector) +} + +inline fun , K, V, M : MutableMap>> EntitySequence.groupByTo( + destination: M, + keySelector: (E) -> K, + valueTransform: (E) -> V +): M { + return asKotlinSequence().groupByTo(destination, keySelector, valueTransform) +} + +fun , T : Table, K : Any> EntitySequence.groupingBy( + keySelector: (T) -> ColumnDeclaring +): EntityGrouping { + return EntityGrouping(this, keySelector) +} + +fun , A : Appendable> EntitySequence.joinTo( + buffer: A, + separator: CharSequence = ", ", + prefix: CharSequence = "", + postfix: CharSequence = "", + limit: Int = -1, + truncated: CharSequence = "...", + transform: ((E) -> CharSequence)? = null +): A { + return asKotlinSequence().joinTo(buffer, separator, prefix, postfix, limit, truncated, transform) +} + +fun > EntitySequence.joinToString( + separator: CharSequence = ", ", + prefix: CharSequence = "", + postfix: CharSequence = "", + limit: Int = -1, + truncated: CharSequence = "...", + transform: ((E) -> CharSequence)? = null +): String { + return asKotlinSequence().joinToString(separator, prefix, postfix, limit, truncated, transform) +} \ No newline at end of file diff --git a/ktorm-core/src/test/kotlin/me/liuwj/ktorm/BaseTest.kt b/ktorm-core/src/test/kotlin/me/liuwj/ktorm/BaseTest.kt index f0cc37cb..0a5c448f 100644 --- a/ktorm-core/src/test/kotlin/me/liuwj/ktorm/BaseTest.kt +++ b/ktorm-core/src/test/kotlin/me/liuwj/ktorm/BaseTest.kt @@ -83,5 +83,6 @@ open class BaseTest { val hireDate by date("hire_date").bindTo { it.hireDate } val salary by long("salary").bindTo { it.salary } val departmentId by int("department_id").references(Departments) { it.department } + val department get() = departmentId.referenceTable as Departments } } \ No newline at end of file diff --git a/ktorm-core/src/test/kotlin/me/liuwj/ktorm/dsl/AggregationTest.kt b/ktorm-core/src/test/kotlin/me/liuwj/ktorm/dsl/AggregationTest.kt index ed690cf1..b94f8964 100644 --- a/ktorm-core/src/test/kotlin/me/liuwj/ktorm/dsl/AggregationTest.kt +++ b/ktorm-core/src/test/kotlin/me/liuwj/ktorm/dsl/AggregationTest.kt @@ -1,6 +1,8 @@ package me.liuwj.ktorm.dsl import me.liuwj.ktorm.BaseTest +import me.liuwj.ktorm.entity.aggregate +import me.liuwj.ktorm.entity.asSequence import org.junit.Test /** @@ -40,7 +42,7 @@ class AggregationTest : BaseTest() { @Test fun testAvg() { - val avg = Employees.avgBy { it.salary } + val avg = Employees.averageBy { it.salary } println(avg) } @@ -58,4 +60,11 @@ class AggregationTest : BaseTest() { fun testAll() { assert(Employees.all { it.salary greater 0L }) } + + @Test + fun testAggregate() { + val result = Employees.asSequence().aggregate { max(it.salary) - min(it.salary) } + println(result) + assert(result == 150L) + } } \ No newline at end of file diff --git a/ktorm-core/src/test/kotlin/me/liuwj/ktorm/entity/EntitySequenceTest.kt b/ktorm-core/src/test/kotlin/me/liuwj/ktorm/entity/EntitySequenceTest.kt new file mode 100644 index 00000000..f20841c6 --- /dev/null +++ b/ktorm-core/src/test/kotlin/me/liuwj/ktorm/entity/EntitySequenceTest.kt @@ -0,0 +1,196 @@ +package me.liuwj.ktorm.entity + +import me.liuwj.ktorm.BaseTest +import me.liuwj.ktorm.dsl.* +import org.junit.Test + +/** + * Created by vince on Mar 22, 2019. + */ +class EntitySequenceTest : BaseTest() { + + @Test + fun testRealSequence() { + val sequence = listOf(1, 2, 3).asSequence() + sequence.toSet() + } + + @Test + fun testToList() { + val employees = Employees.asSequence().toList() + assert(employees.size == 4) + assert(employees[0].name == "vince") + assert(employees[0].department.name == "tech") + } + + @Test + fun testFilter() { + val names = Employees + .asSequence() + .filter { it.departmentId eq 1 } + .filterNot { it.managerId.isNull() } + .toList() + .map { it.name } + + assert(names.size == 1) + assert(names[0] == "marry") + } + + @Test + fun testFilterTo() { + val names = Employees + .asSequence() + .filter { it.departmentId eq 1 } + .filterTo(ArrayList()) { it.managerId.isNull() } + .map { it.name } + + assert(names.size == 1) + assert(names[0] == "vince") + } + + @Test + fun testCount() { + assert(Employees.asSequence().filter { it.departmentId eq 1 }.count() == 2) + assert(Employees.asSequence().count { it.departmentId eq 1 } == 2) + } + + @Test + fun testAll() { + assert(Employees.asSequence().filter { it.departmentId eq 1 }.all { it.salary greater 49L }) + } + + @Test + fun testAssociate() { + val employees = Employees.asSequence().filter { it.departmentId eq 1 }.associateBy { it.id } + assert(employees.size == 2) + assert(employees[1]!!.name == "vince") + } + + @Test + fun testDrop() { + try { + val employees = Employees.asSequence().drop(3).toList() + assert(employees.size == 1) + assert(employees[0].name == "penny") + } catch (e: UnsupportedOperationException) { + // Expected, pagination should be provided by dialects... + } + } + + @Test + fun testTake() { + try { + val employees = Employees.asSequence().take(1).toList() + assert(employees.size == 1) + assert(employees[0].name == "vince") + } catch (e: UnsupportedOperationException) { + // Expected, pagination should be provided by dialects... + } + } + + @Test + fun testFindLast() { + val employee = Employees + .asSequence() + .elementAt(3) + + assert(employee.name == "penny") + assert(Employees.asSequence().elementAtOrNull(4) == null) + } + + @Test + fun testFold() { + val totalSalary = Employees.asSequence().fold(0L) { acc, employee -> acc + employee.salary } + assert(totalSalary == 450L) + } + + @Test + fun testSorted() { + val employee = Employees.asSequence().sortedByDescending { it.salary }.first() + assert(employee.name == "tom") + } + + @Test + fun testFilterColumns() { + val employee = Employees + .asSequence() + .filterColumns { it.columns + it.department.columns - it.department.location } + .filter { it.department.id eq 1 } + .first() + + assert(employee.department.location.isEmpty()) + } + + @Test + fun testGroupBy() { + val employees = Employees + .asSequence() + .groupBy { it.department.id } + + println(employees) + assert(employees.size == 2) + assert(employees[1]!!.sumBy { it.salary.toInt() } == 150) + assert(employees[2]!!.sumBy { it.salary.toInt() } == 300) + } + + @Test + fun testGroupingBy() { + val salaries = Employees + .asSequence() + .groupingBy { it.departmentId * 2 } + .fold(0L) { acc, employee -> + acc + employee.salary + } + + println(salaries) + assert(salaries.size == 2) + assert(salaries[2] == 150L) + assert(salaries[4] == 300L) + } + + @Test + fun testEachCount() { + val counts = Employees + .asSequence() + .filter { it.salary less 100000L } + .groupingBy { it.departmentId } + .eachCount() + + println(counts) + assert(counts.size == 2) + assert(counts[1] == 2) + assert(counts[2] == 2) + } + + @Test + fun testEachSum() { + val sums = Employees + .asSequence() + .filter { it.salary lessEq 100000L } + .groupingBy { it.departmentId } + .eachSumBy { it.salary } + + println(sums) + assert(sums.size == 2) + assert(sums[1] == 150L) + assert(sums[2] == 300L) + } + + @Test + fun testJoinToString() { + val salaries = Employees.asSequence().joinToString { it.id.toString() } + assert(salaries == "1, 2, 3, 4") + } + + @Test + fun testReduce() { + val emp = Employees.asSequence().reduce { acc, employee -> acc.apply { salary += employee.salary } } + assert(emp.salary == 450L) + } + + @Test + fun testSingle() { + val employee = Employees.asSequence().singleOrNull { it.departmentId eq 1 } + assert(employee == null) + } +} \ No newline at end of file diff --git a/ktorm-core/src/test/kotlin/me/liuwj/ktorm/entity/EntityTest.kt b/ktorm-core/src/test/kotlin/me/liuwj/ktorm/entity/EntityTest.kt index a83c4fec..32703c55 100644 --- a/ktorm-core/src/test/kotlin/me/liuwj/ktorm/entity/EntityTest.kt +++ b/ktorm-core/src/test/kotlin/me/liuwj/ktorm/entity/EntityTest.kt @@ -2,9 +2,7 @@ package me.liuwj.ktorm.entity import me.liuwj.ktorm.BaseTest import me.liuwj.ktorm.dsl.* -import me.liuwj.ktorm.schema.Table -import me.liuwj.ktorm.schema.int -import me.liuwj.ktorm.schema.varchar +import me.liuwj.ktorm.schema.* import org.junit.Test import java.io.ByteArrayInputStream import java.io.ByteArrayOutputStream @@ -247,6 +245,22 @@ class EntityTest : BaseTest() { employees.forEach { println(it) } } + @Test + fun testCreateEntityWithoutReferences() { + val employees = Employees + .leftJoin(Departments, on = Employees.departmentId eq Departments.id) + .select(Employees.columns + Departments.columns) + .map { Employees.createEntityWithoutReferences(it) } + + employees.forEach { println(it) } + + assert(employees.size == 4) + assert(employees[0].department.id == 1) + assert(employees[1].department.id == 1) + assert(employees[2].department.id == 2) + assert(employees[3].department.id == 2) + } + @Test fun testAutoDiscardChanges() { var department = Departments.findById(2) ?: return @@ -270,9 +284,13 @@ class EntityTest : BaseTest() { } interface Emp : Entity { + companion object : Entity.Factory() val id: Int var employee: Employee var manager: Employee + var hireDate: LocalDate + var salary: Long + var departmentId: Int } object Emps : Table("t_employee") { @@ -280,6 +298,9 @@ class EntityTest : BaseTest() { val name by varchar("name").bindTo { it.employee.name } val job by varchar("job").bindTo { it.employee.job } val managerId by int("manager_id").bindTo { it.manager.id } + val hireDate by date("hire_date").bindTo { it.hireDate } + val salary by long("salary").bindTo { it.salary } + val departmentId by int("department_id").bindTo { it.departmentId } } @Test @@ -288,6 +309,28 @@ class EntityTest : BaseTest() { emp1.employee.name = "jerry" // emp1.flushChanges() + val emp2 = Emp { + employee = emp1.employee + hireDate = LocalDate.now() + salary = 100 + departmentId = 1 + } + + try { + Emps.add(emp2) + throw AssertionError("failed") + + } catch (e: IllegalStateException) { + assert(e.message == "this.employee.name may be unexpectedly discarded, please save it to database first.") + } + } + + @Test + fun testCheckUnexpectedFlush0() { + val emp1 = Emps.findById(1) ?: return + emp1.employee.name = "jerry" + // emp1.flushChanges() + val emp2 = Emps.findById(2) ?: return emp2.employee = emp1.employee @@ -296,7 +339,7 @@ class EntityTest : BaseTest() { throw AssertionError("failed") } catch (e: IllegalStateException) { - assert(e.message == "this.employee.name may be unexpectedly discarded after flushChanges, please save it to database first.") + assert(e.message == "this.employee.name may be unexpectedly discarded, please save it to database first.") } } @@ -314,7 +357,7 @@ class EntityTest : BaseTest() { throw AssertionError("failed") } catch (e: IllegalStateException) { - assert(e.message == "this.employee.name may be unexpectedly discarded after flushChanges, please save it to database first.") + assert(e.message == "this.employee.name may be unexpectedly discarded, please save it to database first.") } } @@ -327,4 +370,14 @@ class EntityTest : BaseTest() { emp = Emps.findById(1) ?: return assert(emp.manager.id == 2) } + + @Test + fun testCopy() { + var employee = Employees.findById(1)?.copy() ?: return + employee.name = "jerry" + employee.flushChanges() + + employee = Employees.findById(1) ?: return + assert(employee.name == "jerry") + } } \ No newline at end of file diff --git a/ktorm-support-mysql/src/test/kotlin/me/liuwj/ktorm/support/mysql/MySqlTest.kt b/ktorm-support-mysql/src/test/kotlin/me/liuwj/ktorm/support/mysql/MySqlTest.kt index e1b3f35b..a05128e8 100644 --- a/ktorm-support-mysql/src/test/kotlin/me/liuwj/ktorm/support/mysql/MySqlTest.kt +++ b/ktorm-support-mysql/src/test/kotlin/me/liuwj/ktorm/support/mysql/MySqlTest.kt @@ -4,7 +4,7 @@ import me.liuwj.ktorm.BaseTest import me.liuwj.ktorm.database.Database import me.liuwj.ktorm.database.useConnection import me.liuwj.ktorm.dsl.* -import me.liuwj.ktorm.entity.findById +import me.liuwj.ktorm.entity.* import org.junit.Test import java.time.LocalDate @@ -152,4 +152,29 @@ class MySqlTest : BaseTest() { assert(query.totalRecords == 4) } + + @Test + fun testDrop() { + val employees = Employees.asSequence().drop(1).drop(1).drop(1).toList() + assert(employees.size == 1) + assert(employees[0].name == "penny") + } + + @Test + fun testTake() { + val employees = Employees.asSequence().take(2).take(1).toList() + assert(employees.size == 1) + assert(employees[0].name == "vince") + } + + @Test + fun testElementAt() { + val employee = Employees + .asSequence() + .drop(2) + .elementAt(1) + + assert(employee.name == "penny") + assert(Employees.asSequence().elementAtOrNull(4) == null) + } } \ No newline at end of file