Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

merykitty's attempt #114

Merged
merged 10 commits on Jan 6, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
20 changes: 20 additions & 0 deletions calculate_average_merykitty.sh
@@ -0,0 +1,20 @@
#!/bin/sh
#
# Copyright 2023 The original authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


# Launcher for CalculateAverage_merykitty.
# --enable-preview is needed for the java.lang.foreign API used by the solution
# (preview on this JDK), and the incubating Vector API must be added explicitly.
# The commented tail of the line holds diagnostic flags for dumping the JIT
# assembly of the hot iterate() loop when debugging performance.
JAVA_OPTS="--enable-preview --add-modules=jdk.incubator.vector" # -XX:+UnlockDiagnosticVMOptions -XX:PrintAssemblyOptions=intel -XX:CompileCommand=print,*.CalculateAverage_merykitty::iterate"
# JAVA_OPTS is intentionally unquoted so each option is passed separately.
time java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_merykitty
368 changes: 368 additions & 0 deletions src/main/java/dev/morling/onebrc/CalculateAverage_merykitty.java
@@ -0,0 +1,368 @@
/*
* Copyright 2023 The original authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package dev.morling.onebrc;

import java.io.IOException;
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.VarHandle;
import java.nio.ByteOrder;
import java.nio.channels.FileChannel;
import java.nio.channels.FileChannel.MapMode;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.TreeMap;
import jdk.incubator.vector.ByteVector;
import jdk.incubator.vector.VectorOperators;
import jdk.incubator.vector.VectorSpecies;

public class CalculateAverage_merykitty {
private static final String FILE = "./measurements.txt";
private static final VectorSpecies<Byte> BYTE_SPECIES = ByteVector.SPECIES_PREFERRED;
private static final ValueLayout.OfLong JAVA_LONG_LT = ValueLayout.JAVA_LONG_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN);
private static final VarHandle PREFETCH_HANDLE = MethodHandles.memorySegmentViewVarHandle(ValueLayout.JAVA_BYTE);
private static final long KEY_MAX_SIZE = 100;

/**
 * Immutable min/mean/max triple (degrees Celsius), rendered as
 * {@code "min/mean/max"} with each value rounded to one decimal place.
 */
private static record ResultRow(double min, double mean, double max) {
    @Override
    public String toString() {
        return String.join("/",
                String.valueOf(round(min)),
                String.valueOf(round(mean)),
                String.valueOf(round(max)));
    }

    // Round to one fractional digit; Math.round rounds half toward +infinity.
    private double round(double value) {
        return Math.round(value * 10.0) / 10.0;
    }
}

/**
 * Running statistics for one station; temperatures are stored as integers in
 * tenths of a degree.
 */
private static class Aggregator {
    private int min = Integer.MAX_VALUE;
    private int max = Integer.MIN_VALUE;
    // Must be long: after per-thread maps are merged in main(), a station's
    // total over ~1e9 rows (e.g. ~2.4M samples x up to 999 tenths) can exceed
    // Integer.MAX_VALUE, so an int sum silently overflows. All existing uses
    // ("sum += point", "(double) aggr.sum") remain valid with long.
    private long sum;
    private long count;
}

// An open-address map that is specialized for this task: keys are raw byte
// slices of the mapped input (station names), values are Aggregators.
// Collisions are resolved by linear probing and the table only ever grows.
private static class PoorManMap {
    // Reciprocal load factor: grow when size * 2 > capacity (i.e. > 50% full).
    static final int R_LOAD_FACTOR = 2;

    private static class PoorManMapNode {
        // Key bytes stored at offset vectorByteSize(); the leading padding lets
        // iterate() load a full vector starting before the key without going
        // out of bounds of this array.
        byte[] data;
        // Key length in bytes.
        long size;
        // Cached hash (already passed through rehash()) of the key.
        int hash;
        Aggregator aggr;

        PoorManMapNode(MemorySegment data, long offset, long size, int hash) {
            this.hash = hash;
            this.size = size;
            // vectorByteSize() padding bytes + up to KEY_MAX_SIZE key bytes.
            this.data = new byte[BYTE_SPECIES.vectorByteSize() + (int) KEY_MAX_SIZE];
            this.aggr = new Aggregator();
            MemorySegment.copy(data, offset, MemorySegment.ofArray(this.data), BYTE_SPECIES.vectorByteSize(), size);
        }
    }

    // The whole mapped input file; keys passed to indexSimple are slices of it.
    MemorySegment data;
    // Power-of-two sized probe table.
    PoorManMapNode[] nodes;
    // Number of occupied slots.
    int size;

    PoorManMap(MemorySegment data) {
        this.data = data;
        this.nodes = new PoorManMapNode[1 << 10];
    }

    // Returns the Aggregator for the key at data[offset, offset + size),
    // inserting a fresh node if the key is absent. hash is the raw
    // (pre-rehash) hash of the key.
    Aggregator indexSimple(long offset, long size, int hash) {
        hash = rehash(hash);
        int bucketMask = nodes.length - 1;
        int bucket = hash & bucketMask;
        // Linear probe until we find either the key or an empty slot.
        for (;; bucket = (bucket + 1) & bucketMask) {
            PoorManMapNode node = nodes[bucket];
            if (node == null) {
                this.size++;
                if (this.size * R_LOAD_FACTOR > nodes.length) {
                    grow();
                    // The table was rebuilt: re-probe for the key's new empty
                    // slot (the key is known to be absent at this point).
                    bucketMask = nodes.length - 1;
                    for (bucket = hash & bucketMask; nodes[bucket] != null; bucket = (bucket + 1) & bucketMask) {
                    }
                }
                node = new PoorManMapNode(this.data, offset, size, hash);
                nodes[bucket] = node;
                return node.aggr;
            }
            else if (keyEqualScalar(node, offset, size, hash)) {
                return node.aggr;
            }
        }
    }

    // Double the table and reinsert every node; hashes are cached on the
    // nodes, so no key bytes are re-read here.
    void grow() {
        var oldNodes = this.nodes;
        var newNodes = new PoorManMapNode[oldNodes.length * 2];
        int bucketMask = newNodes.length - 1;
        for (var node : oldNodes) {
            if (node == null) {
                continue;
            }
            int bucket = node.hash & bucketMask;
            for (; newNodes[bucket] != null; bucket = (bucket + 1) & bucketMask) {
            }
            newNodes[bucket] = node;
        }
        this.nodes = newNodes;
    }

    // xorshift-multiply avalanche mixer: spreads the entropy of the raw hash
    // (only the first few key bytes) across all bits before masking to a bucket.
    static int rehash(int x) {
        x = ((x >>> 16) ^ x) * 0x45d9f3b;
        x = ((x >>> 16) ^ x) * 0x45d9f3b;
        x = (x >>> 16) ^ x;
        return x;
    }

    // Scalar key comparison: cheap hash/length checks first, then bytewise
    // compare of the stored key against the input segment.
    private boolean keyEqualScalar(PoorManMapNode node, long offset, long size, int hash) {
        if (node.hash != hash || node.size != size) {
            return false;
        }

        // Be simple
        for (int i = 0; i < size; i++) {
            int c1 = node.data[BYTE_SPECIES.vectorByteSize() + i];
            int c2 = data.get(ValueLayout.JAVA_BYTE, offset + i);
            if (c1 != c2) {
                return false;
            }
        }
        return true;
    }
}

private static long parseDataPoint(Aggregator aggr, MemorySegment data, long offset) {
long word = data.get(JAVA_LONG_LT, offset);
// This can be 12, 20, 28
int decimalSepPos = Long.numberOfTrailingZeros(~word & 0x10101000);
int shift = 28 - decimalSepPos;
// signed is -1 if negative, 0 otherwise
long signed = (~word << 59) >> 63;
long designMask = ~(signed & 0xFF);
long digits = ((word & designMask) << shift) & 0x0F000F0F00L;

// Now digits is in the form 0xUU00TTHH00
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is TT, HH, etc? As for the main loop, could you provide an overview of how this thing works?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

UU stands for the value of the units digit, TT is the tens digit and HH is the hundreds digit.

// 0xUU00TTHH00 * (100 * 0x1000000 + 10 * 0x10000 + 1) =
// 0x000000UU00TTHH00 +
// 0x00UU00TTHH000000 * 10 +
// 0xUU00TTHH00000000 * 100
// Now TT * 100 has 2 trailing zeroes and HH * 100 + TT * 10 + UU < 0x400
// This results in our value lies in the bit 32 to 41 of this product
// That was close :)
long value = ((digits * 0x640a0001) >>> 32) & 0x3FF;
int point = (int) ((value ^ signed) - signed);
aggr.min = Math.min(point, aggr.min);
aggr.max = Math.max(point, aggr.max);
aggr.sum += point;
aggr.count++;
return offset + (decimalSepPos >>> 3) + 3;
}

/**
 * Scalar parse of a temperature {@code -?\d+\.\d} starting at {@code offset},
 * used for the file tail where parseDataPoint's unconditional 8-byte load
 * could run past the end of the mapped segment. Folds the value (in tenths of
 * a degree) into {@code aggr} and returns the offset just past the newline.
 */
private static long parseDataPointTail(Aggregator aggr, MemorySegment data, long offset) {
    long cursor = offset;
    boolean negative = data.get(ValueLayout.JAVA_BYTE, cursor) == '-';
    if (negative) {
        cursor++;
    }
    // Accumulate integral digits until the decimal separator.
    int magnitude = 0;
    int ch;
    while ((ch = data.get(ValueLayout.JAVA_BYTE, cursor)) != '.') {
        magnitude = magnitude * 10 + (ch - '0');
        cursor++;
    }
    // Exactly one fractional digit follows the '.'.
    magnitude = magnitude * 10 + (data.get(ValueLayout.JAVA_BYTE, cursor + 1) - '0');
    int point = negative ? -magnitude : magnitude;
    aggr.min = Math.min(point, aggr.min);
    aggr.max = Math.max(point, aggr.max);
    aggr.sum += point;
    aggr.count++;
    // Skip the '.', the fractional digit and the newline.
    return cursor + 3;
}

private static long iterate(PoorManMap aggrMap, MemorySegment data, long offset) {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you provide a description of how that iteration loop works? I roughly get it, but it's quite dense code and having a high-level description would be great.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

var line = ByteVector.fromMemorySegment(BYTE_SPECIES, data, offset, ByteOrder.nativeOrder());
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose this would fail if station names are e.g. 100 double-byte characters, right? Not a problem per se, we agreed yesterday on enforcing a 100 bytes limit for station names (rather than 100 UTF-8 characters), I'm just curious.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No if we cannot find the delimiter which means the key is longer than the segment we fall back to slow path which do so in a more straightforward way,

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah yes, got it. Thanks for adding all the comments!

long semicolons = line.compare(VectorOperators.EQ, ';').toLong();
if (semicolons == 0) {
long semicolonPos = BYTE_SPECIES.vectorByteSize();
for (; data.get(ValueLayout.JAVA_BYTE, offset + semicolonPos) != ';'; semicolonPos++) {
}
int hash = line.reinterpretAsInts().lane(0);
var aggr = aggrMap.indexSimple(offset, semicolonPos, hash);
return parseDataPoint(aggr, data, offset + 1 + semicolonPos);
}

long currOffset = offset;
while (true) {
int localOffset = (int) (currOffset - offset);
long semicolonPos = Long.numberOfTrailingZeros(semicolons) - localOffset;
int hash = data.get(ValueLayout.JAVA_INT_UNALIGNED, currOffset);
if (semicolonPos < Integer.BYTES) {
hash = (byte) hash;
}

Aggregator aggr;
hash = PoorManMap.rehash(hash);
int bucketMask = aggrMap.nodes.length - 1;
int bucket = hash & bucketMask;
for (;; bucket = (bucket + 1) & bucketMask) {
PoorManMap.PoorManMapNode node = aggrMap.nodes[bucket];
if (node == null) {
aggrMap.size++;
if (aggrMap.size * PoorManMap.R_LOAD_FACTOR > aggrMap.nodes.length) {
aggrMap.grow();
bucketMask = aggrMap.nodes.length - 1;
for (bucket = hash & bucketMask; aggrMap.nodes[bucket] != null; bucket = (bucket + 1) & bucketMask) {
}
}
node = new PoorManMap.PoorManMapNode(data, currOffset, semicolonPos, hash);
aggrMap.nodes[bucket] = node;
aggr = node.aggr;
break;
}

if (node.hash != hash || node.size != semicolonPos) {
continue;
}

var nodeKey = ByteVector.fromArray(BYTE_SPECIES, node.data, BYTE_SPECIES.length() - localOffset);
var eqMask = line.compare(VectorOperators.EQ, nodeKey).toLong();
long validMask = (-1L >>> -semicolonPos) << localOffset;
if ((eqMask & validMask) == validMask) {
aggr = node.aggr;
break;
}
}

long nextOffset = parseDataPoint(aggr, data, currOffset + 1 + semicolonPos);
semicolons &= (semicolons - 1);
if (semicolons == 0) {
return nextOffset;
}
currOffset = nextOffset;
}
}

// Process all lines that start in [offset, limit), returning this chunk's
// private aggregation map.
private static PoorManMap processFile(MemorySegment data, long offset, long limit) {
    var aggrMap = new PoorManMap(data);
    // Unless this chunk begins at the start of the file, its first line
    // belongs to the previous chunk: advance past the next '\n'.
    if (offset != 0) {
        offset--;
        while (offset < limit) {
            if (data.get(ValueLayout.JAVA_BYTE, offset++) == '\n') {
                break;
            }
        }
    }
    if (offset == limit) {
        return aggrMap;
    }

    // Fast path while a full vector window plus the worst-case line head
    // (8-byte SWAR load + ';' + longest key) still fits before the limit.
    long safeLimit = limit - Math.max(BYTE_SPECIES.vectorByteSize(),
            Long.BYTES + 1 + KEY_MAX_SIZE);
    while (offset < safeLimit) {
        offset = iterate(aggrMap, data, offset);
    }

    // Now we are at the tail, just be simple
    while (offset < limit) {
        long separatorPos = 0;
        while (data.get(ValueLayout.JAVA_BYTE, offset + separatorPos) != ';') {
            separatorPos++;
        }
        // Same hashing rule as the fast path: 4 bytes, or 1 for short keys.
        int hash = separatorPos >= Integer.BYTES
                ? data.get(ValueLayout.JAVA_INT_UNALIGNED, offset)
                : data.get(ValueLayout.JAVA_BYTE, offset);
        var aggr = aggrMap.indexSimple(offset, separatorPos, hash);
        offset = parseDataPointTail(aggr, data, offset + 1 + separatorPos);
    }

    return aggrMap;
}

/**
 * Entry point: memory-maps the input file, splits it into one chunk per
 * available processor, aggregates each chunk on its own thread (with a
 * companion thread touching pages ahead of the worker), merges the per-thread
 * maps and prints the sorted result.
 */
public static void main(String[] args) throws InterruptedException, IOException {
    int threadCount = Runtime.getRuntime().availableProcessors();
    var merged = HashMap.<String, Aggregator> newHashMap(threadCount);
    try (var channel = FileChannel.open(Path.of(FILE), StandardOpenOption.READ);
            var arena = Arena.ofShared()) {
        var data = channel.map(MapMode.READ_ONLY, 0, channel.size(), arena);
        long chunkSize = Math.ceilDiv(data.byteSize(), threadCount);
        var workers = new Thread[threadCount];
        var partials = new PoorManMap[threadCount];
        for (int i = 0; i < threadCount; i++) {
            int slot = i;
            long begin = i * chunkSize;
            long end = Math.min(begin + chunkSize, data.byteSize());
            // Companion thread: opaque reads fault the chunk's pages in ahead
            // of (and concurrently with) the worker without being elided.
            var prefetcher = new Thread(() -> {
                for (long pos = begin; pos < end; pos += 1024) {
                    byte touched = (byte) PREFETCH_HANDLE.getOpaque(data, pos);
                }
            });
            prefetcher.start();
            var worker = new Thread(() -> partials[slot] = processFile(data, begin, end));
            workers[slot] = worker;
            worker.start();
        }
        for (var worker : workers) {
            worker.join();
        }

        // Merge the per-thread maps into a single name-keyed map.
        for (var partial : partials) {
            for (var node : partial.nodes) {
                if (node == null) {
                    continue;
                }
                byte[] keyBytes = Arrays.copyOfRange(node.data, BYTE_SPECIES.vectorByteSize(),
                        BYTE_SPECIES.vectorByteSize() + (int) node.size);
                String station = new String(keyBytes, StandardCharsets.UTF_8);
                var stats = merged.computeIfAbsent(station, k -> new Aggregator());
                stats.min = Math.min(stats.min, node.aggr.min);
                stats.max = Math.max(stats.max, node.aggr.max);
                stats.sum += node.aggr.sum;
                stats.count += node.aggr.count;
            }
        }
    }

    // TreeMap sorts stations alphabetically; values are converted from tenths
    // of a degree back to degrees.
    Map<String, ResultRow> measurements = new TreeMap<>();
    merged.forEach((station, stats) -> measurements.put(station,
            new ResultRow((double) stats.min / 10, (double) stats.sum / (stats.count * 10), (double) stats.max / 10)));
    System.out.println(measurements);
}
}