From 839b5ab108c1592a471bc683328f68c1b46486f8 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Thu, 24 May 2018 14:45:21 +0200 Subject: [PATCH] Clean up after PR review --- .../internal/runtime/LongArrayHash.java | 8 +- .../internal/runtime/LongArrayHashMap.java | 28 +++---- .../runtime/LongArrayHashMultiMap.java | 20 ++--- .../internal/runtime/LongArrayHashTable.java | 46 +++++------ .../runtime/LongArrayHashMapTest.scala | 81 ++++++++++++++++--- .../runtime/LongArrayHashMultiMapTest.scala | 31 ++++++- .../runtime/LongArrayHashSetTest.scala | 60 ++++++-------- ...EagerAggregationSlottedPrimitivePipe.scala | 2 +- 8 files changed, 174 insertions(+), 102 deletions(-) diff --git a/community/cypher/runtime-util/src/main/java/org/neo4j/cypher/internal/runtime/LongArrayHash.java b/community/cypher/runtime-util/src/main/java/org/neo4j/cypher/internal/runtime/LongArrayHash.java index 77f75f00b88bb..6d425d1e6853c 100644 --- a/community/cypher/runtime-util/src/main/java/org/neo4j/cypher/internal/runtime/LongArrayHash.java +++ b/community/cypher/runtime-util/src/main/java/org/neo4j/cypher/internal/runtime/LongArrayHash.java @@ -19,8 +19,6 @@ */ package org.neo4j.cypher.internal.runtime; -import org.neo4j.helpers.collection.Pair; - public class LongArrayHash { static final long NOT_IN_USE = -2; @@ -28,6 +26,11 @@ public class LongArrayHash static final int VALUE_FOUND = 1; static final int CONTINUE_PROBING = -1; + // Static class only + private LongArrayHash() + { + } + public static int hashCode( long[] arr, int from, int numberOfElements ) { // This way of producing a hashcode for an array of longs is the @@ -58,5 +61,4 @@ static boolean validValue( long[] arr, int width ) } return true; } - } diff --git a/community/cypher/runtime-util/src/main/java/org/neo4j/cypher/internal/runtime/LongArrayHashMap.java b/community/cypher/runtime-util/src/main/java/org/neo4j/cypher/internal/runtime/LongArrayHashMap.java index f110134516ec3..90622a649f1d3 100644 --- a/community/cypher/runtime-util/src/main/java/org/neo4j/cypher/internal/runtime/LongArrayHashMap.java +++ b/community/cypher/runtime-util/src/main/java/org/neo4j/cypher/internal/runtime/LongArrayHashMap.java @@ -38,23 +38,23 @@ */ public class LongArrayHashMap { - private final int width; + private final int keySize; private LongArrayHashTable table; private Object[] values; - public LongArrayHashMap( int initialCapacity, int width ) + public LongArrayHashMap( int initialCapacity, int keySize ) { - assert (initialCapacity & (initialCapacity - 1)) == 0 : "Size must be a power of 2"; - assert width > 0 : "Number of elements must be larger than 0"; + assert (initialCapacity & (initialCapacity - 1)) == 0 : "Capacity must be a power of 2"; + assert keySize > 0 : "Number of elements must be larger than 0"; - this.width = width; - table = new LongArrayHashTable( initialCapacity, width ); + this.keySize = keySize; + table = new LongArrayHashTable( initialCapacity, keySize ); values = new Object[initialCapacity]; } - public VALUE getOrCreateAndAdd( long[] key, Supplier creator ) + public VALUE computeIfAbsent( long[] key, Supplier creator ) { - assert LongArrayHash.validValue( key, width ); + assert LongArrayHash.validValue( key, keySize ); int slotNr = slotFor( key ); while ( true ) { @@ -96,7 +96,7 @@ public VALUE getOrCreateAndAdd( long[] key, Supplier creator ) public VALUE get( long[] key ) { - assert LongArrayHash.validValue( key, width ); + assert LongArrayHash.validValue( key, keySize ); int slotNr = slotFor( key ); while ( true ) { @@ -135,7 +135,7 @@ private void resize() private int slotFor( long[] value ) { - return LongArrayHash.hashCode( value, 0, width ) & table.tableMask; + return LongArrayHash.hashCode( value, 0, keySize ) & table.tableMask; } public Iterator> iterator() @@ -148,7 +148,7 @@ public Iterator> iterator() protected Map.Entry fetchNextOrNull() { // First, find a good spot - while ( current < table.capacity && table.keys[current * width] == NOT_IN_USE ) + while ( current < table.capacity && table.keys[current * keySize] == NOT_IN_USE ) { current = current + 1; } @@ -160,8 +160,8 @@ protected Map.Entry fetchNextOrNull() } // Otherwise, let's create the return object. - long[] key = new long[width]; - System.arraycopy( table.keys, current * width, key, 0, width ); + long[] key = new long[keySize]; + System.arraycopy( table.keys, current * keySize, key, 0, keySize ); @SuppressWarnings( "unchecked" ) VALUE value = (VALUE) values[current]; @@ -201,7 +201,7 @@ public VALUE getValue() @Override public VALUE setValue( VALUE value ) { - return null; + throw new UnsupportedOperationException(); } } } diff --git a/community/cypher/runtime-util/src/main/java/org/neo4j/cypher/internal/runtime/LongArrayHashMultiMap.java b/community/cypher/runtime-util/src/main/java/org/neo4j/cypher/internal/runtime/LongArrayHashMultiMap.java index 53c5604802c14..c45fa25a140bc 100644 --- a/community/cypher/runtime-util/src/main/java/org/neo4j/cypher/internal/runtime/LongArrayHashMultiMap.java +++ b/community/cypher/runtime-util/src/main/java/org/neo4j/cypher/internal/runtime/LongArrayHashMultiMap.java @@ -35,23 +35,23 @@ */ public class LongArrayHashMultiMap { - private final int width; + private final int keySize; private LongArrayHashTable table; private Object[] values; - public LongArrayHashMultiMap( int initialCapacity, int width ) + public LongArrayHashMultiMap( int initialCapacity, int keySize ) { - assert (initialCapacity & (initialCapacity - 1)) == 0 : "Size must be a power of 2"; - assert width > 0 : "Number of elements must be larger than 0"; + assert (initialCapacity & (initialCapacity - 1)) == 0 : "Capacity must be a power of 2"; + assert keySize > 0 : "Number of elements must be larger than 0"; - this.width = width; - table = new LongArrayHashTable( initialCapacity, width ); + this.keySize = keySize; + table = new LongArrayHashTable( initialCapacity, keySize ); values = new Object[initialCapacity]; } public void add( long[] key, VALUE value ) { - assert LongArrayHash.validValue( key, width ); + assert LongArrayHash.validValue( key, keySize ); int slotNr = slotFor( key ); while ( true ) @@ -62,7 +62,7 @@ public void add( long[] key, VALUE value ) case SLOT_EMPTY: if ( table.timeToResize() ) { - // We know we need to add the value to the set, but there is no space left + // We know we need to add the value to the map, but there is no space left resize(); // Need to restart linear probe after resizing slotNr = slotFor( key ); @@ -95,7 +95,7 @@ public void add( long[] key, VALUE value ) public Iterator get( long[] key ) { - assert LongArrayHash.validValue( key, width ); + assert LongArrayHash.validValue( key, keySize ); int slot = slotFor( key ); // Here we'll spin while the slot is taken by a different value. @@ -124,7 +124,7 @@ private void resize() private int slotFor( long[] value ) { - return LongArrayHash.hashCode( value, 0, width ) & table.tableMask; + return LongArrayHash.hashCode( value, 0, keySize ) & table.tableMask; } class Node diff --git a/community/cypher/runtime-util/src/main/java/org/neo4j/cypher/internal/runtime/LongArrayHashTable.java b/community/cypher/runtime-util/src/main/java/org/neo4j/cypher/internal/runtime/LongArrayHashTable.java index 7189ea9875759..5c14ae7cb4e03 100644 --- a/community/cypher/runtime-util/src/main/java/org/neo4j/cypher/internal/runtime/LongArrayHashTable.java +++ b/community/cypher/runtime-util/src/main/java/org/neo4j/cypher/internal/runtime/LongArrayHashTable.java @@ -59,7 +59,7 @@ class LongArrayHashTable */ boolean timeToResize() { - return numberOfEntries == resizeLimit; + return numberOfEntries >= resizeLimit; } /*** @@ -101,20 +101,14 @@ int checkSlot( int slot, long[] key ) void claimSlot( int slot, long[] key ) { int offset = slot * width; + assert keys[offset] == NOT_IN_USE : "Tried overwriting an already used slot"; System.arraycopy( key, 0, keys, offset, width ); numberOfEntries++; } public boolean isEmpty() { - for ( int i = 0; i < keys.length; i = i + width ) - { - if ( keys[i] != NOT_IN_USE ) - { - return false; - } - } - return true; + return numberOfEntries == 0; } /** @@ -136,39 +130,39 @@ private int findUnusedSlot( int fromSlot ) LongArrayHashTable doubleCapacity() { - LongArrayHashTable newTable = new LongArrayHashTable( capacity * 2, width ); - newTable.numberOfEntries = numberOfEntries; + LongArrayHashTable toTable = new LongArrayHashTable( capacity * 2, width ); + toTable.numberOfEntries = numberOfEntries; for ( int fromOffset = 0; fromOffset < capacity * width; fromOffset = fromOffset + width ) { if ( keys[fromOffset] != NOT_IN_USE ) { - int toSlot = LongArrayHash.hashCode( keys, fromOffset, width ) & newTable.tableMask; - toSlot = newTable.findUnusedSlot( toSlot ); - System.arraycopy( keys, fromOffset, newTable.keys, toSlot * width, width ); + int toSlot = LongArrayHash.hashCode( keys, fromOffset, width ) & toTable.tableMask; + toSlot = toTable.findUnusedSlot( toSlot ); + System.arraycopy( keys, fromOffset, toTable.keys, toSlot * width, width ); } } - return newTable; + return toTable; } - Pair doubleCapacity( Object[] srcValues ) + Pair doubleCapacity( Object[] fromValues ) { - LongArrayHashTable dstTable = new LongArrayHashTable( capacity * 2, width ); - Object[] dstValues = new Object[capacity * 2]; - long[] srcKeys = keys; - dstTable.numberOfEntries = numberOfEntries; + LongArrayHashTable toTable = new LongArrayHashTable( capacity * 2, width ); + Object[] toValues = new Object[capacity * 2]; + long[] fromKeys = keys; + toTable.numberOfEntries = numberOfEntries; for ( int fromSlot = 0; fromSlot < capacity; fromSlot = fromSlot + 1 ) { int fromOffset = fromSlot * width; - if ( srcKeys[fromOffset] != NOT_IN_USE ) + if ( fromKeys[fromOffset] != NOT_IN_USE ) { - int toSlot = LongArrayHash.hashCode( srcKeys, fromOffset, width ) & dstTable.tableMask; - toSlot = dstTable.findUnusedSlot( toSlot ); - System.arraycopy( srcKeys, fromOffset, dstTable.keys, toSlot * width, width ); - dstValues[toSlot] = srcValues[fromSlot]; + int toSlot = LongArrayHash.hashCode( fromKeys, fromOffset, width ) & toTable.tableMask; + toSlot = toTable.findUnusedSlot( toSlot ); + System.arraycopy( fromKeys, fromOffset, toTable.keys, toSlot * width, width ); + toValues[toSlot] = fromValues[fromSlot]; } } - return Pair.of( dstTable, dstValues ); + return Pair.of( toTable, toValues ); } } diff --git a/community/cypher/runtime-util/src/test/scala/org/neo4j/cypher/internal/runtime/LongArrayHashMapTest.scala b/community/cypher/runtime-util/src/test/scala/org/neo4j/cypher/internal/runtime/LongArrayHashMapTest.scala index 5447d497af0b6..7ffc9f2cb7ff8 100644 --- a/community/cypher/runtime-util/src/test/scala/org/neo4j/cypher/internal/runtime/LongArrayHashMapTest.scala +++ b/community/cypher/runtime-util/src/test/scala/org/neo4j/cypher/internal/runtime/LongArrayHashMapTest.scala @@ -24,12 +24,14 @@ import java.util.function.Supplier import org.scalatest.{FunSuite, Matchers} import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.util.Random -class LongArrayHashMapTest extends FunSuite with Matchers { +class LongArrayHashMapTest extends FunSuite with Matchers with RandomTester { test("basic") { val map = new LongArrayHashMap[String](32, 3) - map.getOrCreateAndAdd(Array(1L, 2L, 3L), () => "hello") should equal("hello") - map.getOrCreateAndAdd(Array(1L, 2L, 3L), () => "world") should equal("hello") + map.computeIfAbsent(Array(1L, 2L, 3L), () => "hello") should equal("hello") + map.computeIfAbsent(Array(1L, 2L, 3L), () => "world") should equal("hello") map.get(Array(1L, 2L, 3L)) should equal("hello") resultAsSet(map) should equal(Set(List(1L, 2L, 3L) -> "hello")) @@ -44,14 +46,14 @@ class LongArrayHashMapTest extends FunSuite with Matchers { test("fill and doubleCapacity") { val map = new LongArrayHashMap[String](8, 3) - map.getOrCreateAndAdd(Array(0L, 8L, 1L), () => "hello") - map.getOrCreateAndAdd(Array(0L, 7L, 2L), () => "is") - map.getOrCreateAndAdd(Array(0L, 6L, 3L), () => "it") - map.getOrCreateAndAdd(Array(0L, 5L, 4L), () => "me") - map.getOrCreateAndAdd(Array(0L, 4L, 5L), () => "you") - map.getOrCreateAndAdd(Array(0L, 3L, 6L), () => "are") - map.getOrCreateAndAdd(Array(0L, 2L, 7L), () => "looking") - map.getOrCreateAndAdd(Array(0L, 1L, 8L), () => "for") + map.computeIfAbsent(Array(0L, 8L, 1L), () => "hello") + map.computeIfAbsent(Array(0L, 7L, 2L), () => "is") + map.computeIfAbsent(Array(0L, 6L, 3L), () => "it") + map.computeIfAbsent(Array(0L, 5L, 4L), () => "me") + map.computeIfAbsent(Array(0L, 4L, 5L), () => "you") + map.computeIfAbsent(Array(0L, 3L, 6L), () => "are") + map.computeIfAbsent(Array(0L, 2L, 7L), () => "looking") + map.computeIfAbsent(Array(0L, 1L, 8L), () => "for") map.get(Array(0L, 8L, 1L)) should equal("hello") map.get(Array(0L, 7L, 2L)) should equal("is") @@ -85,4 +87,61 @@ class LongArrayHashMapTest extends FunSuite with Matchers { override def get(): T = f() } + randomTest { randomer => + val r = randomer.r + val width = r.nextInt(10) + 2 + val size = r.nextInt(10000) + val tested = new LongArrayHashMap[String](16, width) + val validator = new mutable.HashMap[Array[Long], String]() + (0 to size) foreach { _ => + val key = new Array[Long](width) + (0 until width) foreach { i => key(i) = randomer.randomLong() } + tested.computeIfAbsent(key, () => key.toString) + validator.getOrElseUpdate(key, key.toString) + } + + validator foreach { case (key: Array[Long], expectedValue: String) => + val v = tested.get(key) + v should equal(expectedValue) + } + + (0 to size) foreach { _ => + val tuple = new Array[Long](width) + (0 until width) foreach { i => tuple(i) = randomer.randomLong() } + val a = tested.get(tuple) + val b = validator.getOrElse(tuple, null) + a should equal(b) + } + + } } + +trait RandomTester { + self: FunSuite => + def randomTest(f: Randomer => Unit): Unit = { + val seed = System.nanoTime() + val rand = new Random(seed) + val input = new Randomer { + override val r: Random = rand + } + + (0 to 100) foreach { i => + test(s"random test with seed $seed uniquefier $i") { + f(input) + } + } + } + + trait Randomer { + val r: Random + + def randomLong(): Long = { + val x = r.nextLong() + if (x == -1 || x == -2) + randomLong() + else + x + } + } + +} \ No newline at end of file diff --git a/community/cypher/runtime-util/src/test/scala/org/neo4j/cypher/internal/runtime/LongArrayHashMultiMapTest.scala b/community/cypher/runtime-util/src/test/scala/org/neo4j/cypher/internal/runtime/LongArrayHashMultiMapTest.scala index e5c03d88d9acb..23608b6cad5a7 100644 --- a/community/cypher/runtime-util/src/test/scala/org/neo4j/cypher/internal/runtime/LongArrayHashMultiMapTest.scala +++ b/community/cypher/runtime-util/src/test/scala/org/neo4j/cypher/internal/runtime/LongArrayHashMultiMapTest.scala @@ -22,8 +22,9 @@ package org.neo4j.cypher.internal.runtime import org.scalatest.{FunSuite, Matchers} import scala.collection.JavaConverters._ +import scala.collection.{immutable, mutable} -class LongArrayHashMultiMapTest extends FunSuite with Matchers { +class LongArrayHashMultiMapTest extends FunSuite with Matchers with RandomTester { test("basic") { val map = new LongArrayHashMultiMap[String](32, 3) map.add(Array(1L, 2L, 3L), "hello") @@ -85,4 +86,32 @@ class LongArrayHashMultiMapTest extends FunSuite with Matchers { map.get(Array(0L, 0L)).asScala.toList should equal(List.empty) } + randomTest { randomer => + val r = randomer.r + val width = r.nextInt(10) + 2 + val size = r.nextInt(10000) + val tested = new LongArrayHashMultiMap[String](16, width) + val validator = new mutable.HashMap[Array[Long], mutable.ListBuffer[String]]() + (0 to size) foreach { _ => + val key = new Array[Long](width) + (0 until width) foreach { i => key(i) = randomer.randomLong() } + val value = System.nanoTime().toString + tested.add(key, value) + val values: mutable.ListBuffer[String] = validator.getOrElseUpdate(key, new mutable.ListBuffer[String]) + values.append(value) + } + + validator.foreach { case (key, expectedValues) => + val v = tested.get(key).asScala.toList + v should equal(expectedValues.toList) + } + + (0 to size) foreach { _ => + val tuple = new Array[Long](width) + (0 until width) foreach { i => tuple(i) = randomer.randomLong() } + val a = tested.get(tuple).asScala.toList + val b = validator.getOrElse(tuple, List.empty) + a should equal(b.toList) + } + } } diff --git a/community/cypher/runtime-util/src/test/scala/org/neo4j/cypher/internal/runtime/LongArrayHashSetTest.scala b/community/cypher/runtime-util/src/test/scala/org/neo4j/cypher/internal/runtime/LongArrayHashSetTest.scala index 4b39b94d97c0a..3a4f151d06241 100644 --- a/community/cypher/runtime-util/src/test/scala/org/neo4j/cypher/internal/runtime/LongArrayHashSetTest.scala +++ b/community/cypher/runtime-util/src/test/scala/org/neo4j/cypher/internal/runtime/LongArrayHashSetTest.scala @@ -24,50 +24,38 @@ import java.util import org.scalatest.{FunSuite, Matchers} import scala.collection.mutable -import scala.util.Random -class LongArrayHashSetTest extends FunSuite with Matchers { +class LongArrayHashSetTest extends FunSuite with Matchers with RandomTester { - val r = new Random() - - (0 to 100) foreach { i => - test(s"test #$i") { - val width = r.nextInt(10) + 2 - val size = r.nextInt(10000) - val tested = new LongArrayHashSet(16, width) - val validator = new mutable.HashSet[Array[Long]]() - (0 to size) foreach { _ => - val tuple = new Array[Long](width) - (0 until width) foreach { i => tuple(i) = randomLong() } - tested.add(tuple) - validator.add(tuple) - } + randomTest { randomer => + val r = randomer.r + val width = r.nextInt(10) + 2 + val size = r.nextInt(10000) + val tested = new LongArrayHashSet(16, width) + val validator = new mutable.HashSet[Array[Long]]() + (0 to size) foreach { _ => + val tuple = new Array[Long](width) + (0 until width) foreach { i => tuple(i) = randomer.randomLong() } + tested.add(tuple) + validator.add(tuple) + } - validator foreach { x => - if (!tested.contains(x)) - fail(s"Value was missing: ${util.Arrays.toString(x)}") - } + validator foreach { x => + if (!tested.contains(x)) + fail(s"Value was missing: ${util.Arrays.toString(x)}") + } - (0 to size) foreach { _ => - val tuple = new Array[Long](width) - (0 until width) foreach { i => tuple(i) = randomLong() } - val a = tested.contains(tuple) - val b = validator.contains(tuple) + (0 to size) foreach { _ => + val tuple = new Array[Long](width) + (0 until width) foreach { i => tuple(i) = randomer.randomLong() } + val a = tested.contains(tuple) + val b = validator.contains(tuple) - if(a != b) - fail(s"Value: ${util.Arrays.toString(tuple)} LongArrayHashSet $a mutable.HashSet") - } + if (a != b) + fail(s"Value: ${util.Arrays.toString(tuple)} LongArrayHashSet $a mutable.HashSet") } } - private def randomLong(): Long = { - val x = r.nextLong() - if (x == -1 || x == -2) - randomLong() - else - x - } - test("manual test to help with debugging") { val set = new LongArrayHashSet(8, 3) set.add(Array(1, 2, 3)) diff --git a/enterprise/cypher/slotted-runtime/src/main/scala/org/neo4j/cypher/internal/runtime/slotted/pipes/EagerAggregationSlottedPrimitivePipe.scala b/enterprise/cypher/slotted-runtime/src/main/scala/org/neo4j/cypher/internal/runtime/slotted/pipes/EagerAggregationSlottedPrimitivePipe.scala index 2db6cb34ed2a0..5d5abef360bd3 100644 --- a/enterprise/cypher/slotted-runtime/src/main/scala/org/neo4j/cypher/internal/runtime/slotted/pipes/EagerAggregationSlottedPrimitivePipe.scala +++ b/enterprise/cypher/slotted-runtime/src/main/scala/org/neo4j/cypher/internal/runtime/slotted/pipes/EagerAggregationSlottedPrimitivePipe.scala @@ -86,7 +86,7 @@ case class EagerAggregationSlottedPrimitivePipe(source: Pipe, // Consume all input and aggregate input.foreach(ctx => { setKeyFromCtx(ctx) - val aggregationFunctions = result.getOrCreateAndAdd(keys, supplier) + val aggregationFunctions = result.computeIfAbsent(keys, supplier) aggregationFunctions.foreach(func => func(ctx, state)) })