Skip to content
This repository was archived by the owner on Apr 22, 2020. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,21 @@ SimilarityResult jaccard(double similarityCutoff, CategoricalInput e2) {
if (jaccard < similarityCutoff) return null;
return new SimilarityResult(id, e2.id, count1, count2, intersection, jaccard);
}

SimilarityResult overlap(double similarityCutoff, CategoricalInput e2) {
long intersection = Intersections.intersection3(targets, e2.targets);
if (similarityCutoff >= 0d && intersection == 0) return null;
int count1 = targets.length;
int count2 = e2.targets.length;
long denominator = Math.min(count1, count2);
double overlap = denominator == 0 ? 0 : (double)intersection / denominator;
if (overlap < similarityCutoff) return null;

if(count1 <= count2) {
return new SimilarityResult(id, e2.id, count1, count2, intersection, overlap, false, false);
} else {
return new SimilarityResult(e2.id, id, count2, count1, intersection, overlap, false, true);
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ public Stream<SimilaritySummaryResult> cosine(


boolean write = configuration.isWriteFlag(false) && similarityCutoff > 0.0;
return writeAndAggregateResults(configuration, stream, inputs.length, write);
return writeAndAggregateResults(configuration, stream, inputs.length, write, "SIMILAR");
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ public Stream<SimilaritySummaryResult> euclidean(
.map(SimilarityResult::squareRooted);

boolean write = configuration.isWriteFlag(false); // && similarityCutoff != 0.0;
return writeAndAggregateResults(configuration, stream, inputs.length, write);
return writeAndAggregateResults(configuration, stream, inputs.length, write, "SIMILAR");
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ public Stream<SimilaritySummaryResult> jaccard(
Stream<SimilarityResult> stream = topN(similarityStream(inputs, computer, configuration, similarityCutoff, getTopK(configuration)), getTopN(configuration));

boolean write = configuration.isWriteFlag(false) && similarityCutoff > 0.0;
return writeAndAggregateResults(configuration, stream, inputs.length, write);
return writeAndAggregateResults(configuration, stream, inputs.length, write, "SIMILAR");
}


Expand Down
70 changes: 70 additions & 0 deletions algo/src/main/java/org/neo4j/graphalgo/similarity/OverlapProc.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/**
* Copyright (c) 2017 "Neo4j, Inc." <http://neo4j.com>
*
* This file is part of Neo4j Graph Algorithms <http://github.com/neo4j-contrib/neo4j-graph-algorithms>.
*
* Neo4j Graph Algorithms is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.graphalgo.similarity;

import org.neo4j.graphalgo.core.ProcedureConfiguration;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Mode;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;

import java.util.List;
import java.util.Map;
import java.util.stream.Stream;

public class OverlapProc extends SimilarityProc {

@Procedure(name = "algo.similarity.overlap.stream", mode = Mode.READ)
@Description("CALL algo.similarity.overlap.stream([{source:id, targets:[ids]}], {similarityCutoff:-1,degreeCutoff:0}) " +
"YIELD item1, item2, count1, count2, intersection, similarity - computes jaccard similarities")
public Stream<SimilarityResult> similarityStream(
@Name(value = "data", defaultValue = "null") List<Map<String,Object>> data,
@Name(value = "config", defaultValue = "{}") Map<String, Object> config) {

SimilarityComputer<CategoricalInput> computer = (s, t, cutoff) -> s.overlap(cutoff, t);

ProcedureConfiguration configuration = ProcedureConfiguration.create(config);

CategoricalInput[] inputs = prepareCategories(data, getDegreeCutoff(configuration));

return topN(similarityStream(inputs, computer, configuration, getSimilarityCutoff(configuration), getTopK(configuration)), getTopN(configuration));
}

@Procedure(name = "algo.similarity.overlap", mode = Mode.WRITE)
@Description("CALL algo.similarity.overlap([{source:id, targets:[ids]}], {similarityCutoff:-1,degreeCutoff:0}) " +
"YIELD p50, p75, p90, p99, p999, p100 - computes jaccard similarities")
public Stream<SimilaritySummaryResult> overlap(
@Name(value = "data", defaultValue = "null") List<Map<String, Object>> data,
@Name(value = "config", defaultValue = "{}") Map<String, Object> config) {

SimilarityComputer<CategoricalInput> computer = (s,t,cutoff) -> s.overlap(cutoff, t);

ProcedureConfiguration configuration = ProcedureConfiguration.create(config);

CategoricalInput[] inputs = prepareCategories(data, getDegreeCutoff(configuration));

double similarityCutoff = getSimilarityCutoff(configuration);
Stream<SimilarityResult> stream = topN(similarityStream(inputs, computer, configuration, similarityCutoff, getTopK(configuration)), getTopN(configuration));

boolean write = configuration.isWriteFlag(false) && similarityCutoff > 0.0;
return writeAndAggregateResults(configuration, stream, inputs.length, write, "NARROWER_THAN");
}


}
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ static Stream<SimilarityResult> topN(Stream<SimilarityResult> stream, int topN)
if (topN > 10000) {
return stream.sorted(comparator).limit(topN);
}
return topK(stream,topN, comparator);
return topK(stream, topN, comparator);
}

private static <T> void put(BlockingQueue<T> queue, T items) {
Expand All @@ -66,8 +66,8 @@ Long getDegreeCutoff(ProcedureConfiguration configuration) {
return configuration.get("degreeCutoff", 0L);
}

Stream<SimilaritySummaryResult> writeAndAggregateResults(ProcedureConfiguration configuration, Stream<SimilarityResult> stream, int length, boolean write) {
String writeRelationshipType = configuration.get("writeRelationshipType", "SIMILAR");
Stream<SimilaritySummaryResult> writeAndAggregateResults(ProcedureConfiguration configuration, Stream<SimilarityResult> stream, int length, boolean write, String defaultWriteProperty) {
String writeRelationshipType = configuration.get("writeRelationshipType", defaultWriteProperty);
String writeProperty = configuration.getWriteProperty("score");

AtomicLong similarityPairs = new AtomicLong();
Expand All @@ -77,7 +77,7 @@ Stream<SimilaritySummaryResult> writeAndAggregateResults(ProcedureConfiguration
similarityPairs.getAndIncrement();
};

if(write) {
if (write) {
SimilarityExporter similarityExporter = new SimilarityExporter(api, writeRelationshipType, writeProperty);
similarityExporter.export(stream.peek(recorder));
} else {
Expand Down Expand Up @@ -114,17 +114,15 @@ <T> Stream<SimilarityResult> similarityStream(T[] inputs, SimilarityComputer<T>
private <T> Stream<SimilarityResult> similarityStream(T[] inputs, int length, double similiarityCutoff, SimilarityComputer<T> computer) {
return IntStream.range(0, length)
.boxed().flatMap(sourceId -> IntStream.range(sourceId + 1, length)
.mapToObj(targetId -> computer.similarity(inputs[sourceId],inputs[targetId],similiarityCutoff)).filter(Objects::nonNull));
.mapToObj(targetId -> computer.similarity(inputs[sourceId], inputs[targetId], similiarityCutoff)).filter(Objects::nonNull));
}

private <T> Stream<SimilarityResult> similarityStreamTopK(T[] inputs, int length, double cutoff, int topK, SimilarityComputer<T> computer) {
TopKConsumer<SimilarityResult>[] topKHolder = initializeTopKConsumers(length, topK);

for (int sourceId = 0;sourceId < length;sourceId++) {
computeSimilarityForSourceIndex(sourceId, inputs, length, cutoff, (sourceIndex, targetIndex, similarityResult) -> {
topKHolder[sourceIndex].accept(similarityResult);
topKHolder[targetIndex].accept(similarityResult.reverse());
}, computer);
SimilarityConsumer consumer = assignSimilarityPairs(topKHolder);
for (int sourceId = 0; sourceId < length; sourceId++) {
computeSimilarityForSourceIndex(sourceId, inputs, length, cutoff, consumer, computer);
}
return Arrays.stream(topKHolder).flatMap(TopKConsumer::stream);
}
Expand Down Expand Up @@ -176,13 +174,13 @@ private <T> Stream<SimilarityResult> similarityParallelStreamTopK(T[] inputs, in
ParallelUtil.runWithConcurrency(concurrency, tasks, terminationFlag, Pools.DEFAULT);

TopKConsumer<SimilarityResult>[] topKConsumers = initializeTopKConsumers(length, topK);
for (Runnable task : tasks) ((TopKTask)task).mergeInto(topKConsumers);
for (Runnable task : tasks) ((TopKTask) task).mergeInto(topKConsumers);
return Arrays.stream(topKConsumers).flatMap(TopKConsumer::stream);
}

private <T> void computeSimilarityForSourceIndex(int sourceId, T[] inputs, int length, double cutoff, SimilarityConsumer consumer, SimilarityComputer<T> computer) {
for (int targetId=sourceId+1;targetId<length;targetId++) {
SimilarityResult similarity = computer.similarity(inputs[sourceId], inputs[targetId],cutoff);
for (int targetId = sourceId + 1; targetId < length; targetId++) {
SimilarityResult similarity = computer.similarity(inputs[sourceId], inputs[targetId], cutoff);
if (similarity != null) {
consumer.accept(sourceId, targetId, similarity);
}
Expand All @@ -195,11 +193,11 @@ CategoricalInput[] prepareCategories(List<Map<String, Object>> data, long degree
for (Map<String, Object> row : data) {
List<Number> targetIds = extractValues(row.get("categories"));
int size = targetIds.size();
if ( size > degreeCutoff) {
if (size > degreeCutoff) {
long[] targets = new long[size];
int i=0;
int i = 0;
for (Number id : targetIds) {
targets[i++]=id.longValue();
targets[i++] = id.longValue();
}
Arrays.sort(targets);
ids[idx++] = new CategoricalInput((Long) row.get("item"), targets);
Expand All @@ -218,11 +216,11 @@ WeightedInput[] prepareWeights(List<Map<String, Object>> data, long degreeCutoff
List<Number> weightList = extractValues(row.get("weights"));

int size = weightList.size();
if ( size > degreeCutoff) {
if (size > degreeCutoff) {
double[] weights = new double[size];
int i=0;
int i = 0;
for (Number value : weightList) {
weights[i++]=value.doubleValue();
weights[i++] = value.doubleValue();
}
inputs[idx++] = new WeightedInput((Long) row.get("item"), weights);
}
Expand All @@ -233,7 +231,7 @@ WeightedInput[] prepareWeights(List<Map<String, Object>> data, long degreeCutoff
}

private List<Number> extractValues(Object rawValues) {
if(rawValues == null) {
if (rawValues == null) {
return Collections.emptyList();
}

Expand All @@ -259,24 +257,35 @@ protected int getTopK(ProcedureConfiguration configuration) {
}

protected int getTopN(ProcedureConfiguration configuration) {
return configuration.getInt("top",0);
return configuration.getInt("top", 0);
}

interface SimilarityComputer<T> {
SimilarityResult similarity(T source, T target, double cutoff);
}

public static SimilarityConsumer assignSimilarityPairs(TopKConsumer<SimilarityResult>[] topKConsumers) {
return (s, t, result) -> {
topKConsumers[result.reversed ? t : s].accept(result);

if (result.bidirectional) {
SimilarityResult reverse = result.reverse();
topKConsumers[reverse.reversed ? t : s].accept(reverse);
}
};
}

private class TopKTask<T> implements Runnable {
private final int batchSize;
private final int taskOffset;
private final int multiplier;
private final int length;
private final T[] ids;
private final double similiarityCutoff;
private final SimilarityComputer computer;
private final SimilarityComputer<T> computer;
private final TopKConsumer<SimilarityResult>[] topKConsumers;

TopKTask(int batchSize, int taskOffset, int multiplier, int length, T[] ids, double similiarityCutoff, int topK, SimilarityComputer computer) {
TopKTask(int batchSize, int taskOffset, int multiplier, int length, T[] ids, double similiarityCutoff, int topK, SimilarityComputer<T> computer) {
this.batchSize = batchSize;
this.taskOffset = taskOffset;
this.multiplier = multiplier;
Expand All @@ -289,16 +298,17 @@ private class TopKTask<T> implements Runnable {

@Override
public void run() {
SimilarityConsumer consumer = assignSimilarityPairs(topKConsumers);
for (int offset = 0; offset < batchSize; offset++) {
int sourceId = taskOffset * multiplier + offset;
if (sourceId < length) {
computeSimilarityForSourceIndex(sourceId, ids, length, similiarityCutoff, (s, t, result) -> {
topKConsumers[s].accept(result);
topKConsumers[t].accept(result.reverse());
}, computer);

computeSimilarityForSourceIndex(sourceId, ids, length, similiarityCutoff, consumer, computer);
}
}
}


void mergeInto(TopKConsumer<SimilarityResult>[] target) {
for (int i = 0; i < target.length; i++) {
target[i].accept(topKConsumers[i]);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,23 @@ public class SimilarityResult implements Comparable<SimilarityResult> {
public final long count2;
public final long intersection;
public double similarity;
public final boolean bidirectional;
public final boolean reversed;

public static SimilarityResult TOMB = new SimilarityResult(-1, -1, -1, -1, -1, -1);

public SimilarityResult(long item1, long item2, long count1, long count2, long intersection, double similarity) {
public SimilarityResult(long item1, long item2, long count1, long count2, long intersection, double similarity, boolean bidirectional, boolean reversed) {
this.item1 = item1;
this.item2 = item2;
this.count1 = count1;
this.count2 = count2;
this.intersection = intersection;
this.similarity = similarity;
this.bidirectional = bidirectional;
this.reversed = reversed;
}
public SimilarityResult(long item1, long item2, long count1, long count2, long intersection, double similarity) {
this(item1,item2, count1,count2,intersection,similarity, true, false);
}

@Override
Expand Down Expand Up @@ -70,7 +77,7 @@ public int compareTo(SimilarityResult o) {
}

public SimilarityResult reverse() {
return new SimilarityResult(item2, item1,count2,count1,intersection,similarity);
return new SimilarityResult(item2, item1,count2,count1,intersection,similarity,bidirectional,!reversed);
}

public SimilarityResult squareRooted() {
Expand Down
2 changes: 2 additions & 0 deletions doc/asciidoc/algorithms-similarity.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ These algorithms help calculate the similarity of nodes:
* <<algorithms-similarity-jaccard, Jaccard Similarity>> (`algo.similarity.jaccard`)
* <<algorithms-similarity-cosine, Cosine Similarity>> (`algo.similarity.cosine`)
* <<algorithms-similarity-euclidean, Euclidean Distance>> (`algo.similarity.euclidean`)
* <<algorithms-similarity-overlap, Overlap Similarity>> (`algo.similarity.overlap`)

include::similarity-jaccard.adoc[leveloffset=2]
include::similarity-cosine.adoc[leveloffset=2]
include::similarity-euclidean.adoc[leveloffset=2]
include::similarity-overlap.adoc[leveloffset=2]
Loading