Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an initial interface for nearest neighbor queries with a simple implementation #213

Merged
merged 3 commits into from
Feb 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions Math/src/main/java/org/tribuo/math/distance/DistanceType.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.tribuo.math.distance;

import org.tribuo.math.la.SGDVector;

/**
* The available distance functions.
*/
public enum DistanceType {
Craigacp marked this conversation as resolved.
Show resolved Hide resolved
/**
* L1 (or Manhattan) distance.
*/
L1,
/**
* L2 (or Euclidean) distance.
*/
L2,
/**
* Cosine similarity used as a distance measure.
*/
COSINE;

/**
* Calculates the distance between two vectors.
*
* @param vector1 A {@link SGDVector} representing a data point.
* @param vector2 A {@link SGDVector} representing a second data point.
* @param distanceType The {@link DistanceType} function to employ.
* @return A double representing the distance between the two points.
*/
public static double getDistance(SGDVector vector1, SGDVector vector2, DistanceType distanceType) {
double distance;
switch (distanceType) {
case L1:
distance = vector1.l1Distance(vector2);
break;
case L2:
distance = vector1.l2Distance(vector2);
break;
case COSINE:
distance = vector1.cosineDistance(vector2);
break;
default:
throw new IllegalStateException("Unknown distance " + distanceType);
}
return distance;
}
}
20 changes: 20 additions & 0 deletions Math/src/main/java/org/tribuo/math/distance/package-info.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
/*
* Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/**
* Provides the set of distance functions which are supported by {@link org.tribuo.math.la.SGDVector}.
*/
package org.tribuo.math.distance;
56 changes: 56 additions & 0 deletions Math/src/main/java/org/tribuo/math/neighbour/NeighboursQuery.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.tribuo.math.neighbour;

import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.math.la.SGDVector;

import java.util.List;

/**
* An interface for nearest neighbour query objects.
* <p>
* Each result is reported as a {@link Pair} of (index of the neighbouring point in the original data,
* distance between that neighbour and the query point).
* <p>
* NOTE(review): this interface does not itself state the ordering of the returned pairs, nor what happens
* when k exceeds the number of available points - confirm against the implementation in use.
*/
public interface NeighboursQuery {

/**
* Queries a set of {@link SGDVector}s to determine the k points nearest to the provided point.
* @param point The point to determine the nearest k points for.
* @param k The number of neighbouring points to identify.
* @return A list of k {@link Pair}s, where a pair contains the index of the neighbouring point in the original
* data and the distance between this point and the provided point.
*/
List<Pair<Integer, Double>> query(SGDVector point, int k);

/**
* Queries a set of {@link SGDVector}s to determine the k points nearest to each of the provided points.
* @param points An array of points to determine the nearest k points for.
* @param k The number of neighbouring points to identify.
* @return A list containing lists of k {@link Pair}s. There is a list entry for each provided point, which is a
* list of k pairs. Each pair contains the index of the neighbouring point in the original data and the
* distance between this point and the provided point.
*/
List<List<Pair<Integer, Double>>> query(SGDVector[] points, int k);

/**
* Queries a set of {@link SGDVector}s to determine the k points nearest to every point in the set.
* @param k The number of neighbouring points to identify.
* @return A list containing lists of k {@link Pair}s. There is a list entry for each point in the set, which is a
* list of k pairs. Each pair contains the index of the neighbouring point in the original data and the
* distance between this point and the point being queried.
*/
List<List<Pair<Integer, Double>>> queryAll(int k);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.tribuo.math.neighbour;

import com.oracle.labs.mlrg.olcut.config.Configurable;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenancable;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import org.tribuo.math.la.SGDVector;

/**
* An interface for factories which create nearest neighbour query objects.
*/
public interface NeighboursQueryFactory extends Configurable, Provenancable<ConfiguredObjectProvenance> {

/**
* Constructs a nearest neighbour query object using the supplied array of {@link SGDVector}.
* @param data An array of {@link SGDVector}.
*/
NeighboursQuery createNeighboursQuery(SGDVector[] data);

@Override
default ConfiguredObjectProvenance getProvenance() {
Craigacp marked this conversation as resolved.
Show resolved Hide resolved
return new ConfiguredObjectProvenanceImpl(this,"NeighboursQueryFactory");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
/*
* Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.tribuo.math.neighbour.bruteforce;

import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.math.distance.DistanceType;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.neighbour.NeighboursQuery;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.PriorityQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

/**
* A brute-force nearest neighbour query implementation.
*/
public final class NeighboursBruteForce implements NeighboursQuery {

private final SGDVector[] data;
private final DistanceType distanceType;
private final int numThreads;

/**
* Constructs a brute-force nearest neighbour query object using the supplied parameters.
* @param data the data that will be used for neighbour queries.
* @param distanceType The distance function.
* @param numThreads The number of threads to be used to parallelize the computation.
*/
NeighboursBruteForce(SGDVector[] data, DistanceType distanceType, int numThreads) {
this.data = data;
this.distanceType = distanceType;
this.numThreads = numThreads;
}

@Override
public List<Pair<Integer, Double>> query(SGDVector point, int k) {
PriorityQueue<MutablePair> queue = new PriorityQueue<>(k);

for (int neighbor = 0; neighbor < data.length; neighbor++) {
double distance = DistanceType.getDistance(point, data[neighbor], distanceType);
if (queue.size() < k) {
MutablePair newPair = new MutablePair(neighbor, distance);
queue.offer(newPair);
} else if (Double.compare(distance, queue.peek().value) < 0) {
MutablePair pair = queue.poll();
pair.index = neighbor;
pair.value = distance;
queue.offer(pair);
}
}

@SuppressWarnings("unchecked")
Pair<Integer, Double>[] indexDistanceArr = (Pair<Integer, Double>[]) new Pair[k];
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be an ArrayList too.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've left this as an array for now too, for the same reasons mentioned above.

int i = 1;
// Use an array to put the polled items from the queue into a sorted ascending order, by distance.
while (!queue.isEmpty()) {
MutablePair mutablePair = queue.poll();
indexDistanceArr[k - i] = new Pair<>(mutablePair.index, mutablePair.value);
i++;
}
return new ArrayList<>(Arrays.asList(indexDistanceArr));
}

@Override
public List<List<Pair<Integer, Double>>> query(SGDVector[] points, int k) {
int numQueries = points.length;

@SuppressWarnings("unchecked")
List<Pair<Integer, Double>>[] indexDistancePairListArray = (List<Pair<Integer, Double>>[]) new ArrayList[numQueries];
Craigacp marked this conversation as resolved.
Show resolved Hide resolved

// When the number of threads is 1, the overhead of thread pools must be avoided
if (numThreads == 1) {
for (int point = 0; point < numQueries; point++) {
indexDistancePairListArray[point] = query(points[point], k);
}
} else { // This makes the nearest neighbor queries with multiple threads
ExecutorService executorService = Executors.newFixedThreadPool(numThreads);
for (int pointInd = 0; pointInd < numQueries; pointInd++) {
executorService.execute(new SingleQueryRunnable(pointInd, points[pointInd], k, indexDistancePairListArray));
}
executorService.shutdown();
try {
boolean finished = executorService.awaitTermination(Long.MAX_VALUE, TimeUnit.MINUTES);
if (!finished) {
throw new RuntimeException("Parallel execution failed");
}
} catch (InterruptedException e) {
throw new RuntimeException("Parallel execution failed", e);
}
}
return new ArrayList<>(Arrays.asList(indexDistancePairListArray));
}

@Override
public List<List<Pair<Integer, Double>>> queryAll(int k) {
return this.query(this.data, k);
}

/**
* This is a specific mutable pair used for an internal queue to reduce object creation. Note that this class's
* ordering is not consistent with its equals method. Furthermore, the ordering of this class is the inverse of the
* natural ordering on doubles.
*/
private static final class MutablePair implements Comparable<MutablePair> {
Craigacp marked this conversation as resolved.
Show resolved Hide resolved
int index;
double value;

public MutablePair(int index, double value) {
this.index = index;
this.value = value;
}

@Override
public int compareTo(MutablePair o) {
// Pass the provided value as the first param to give an inverse natural ordering.
return Double.compare(o.value, this.value);
}
}

/**
* A Runnable implementation to make a brute-force nearest neighbour query for parallelization of large numbers
* of queries. To be used with an {@link ExecutorService}
*/
private final class SingleQueryRunnable implements Runnable {

final private SGDVector point;
final private int k;
final private int index;
final List<Pair<Integer, Double>>[] indexDistancePairListArray;

SingleQueryRunnable(int index, SGDVector point, int k, List<Pair<Integer, Double>>[] indexDistancePairListArray) {
this.point = point;
this.k = k;
this.index = index;
this.indexDistancePairListArray = indexDistancePairListArray;
}

@Override
public void run() {
indexDistancePairListArray[index] = query(point, k);
}
}
}
Loading