In [1]:
import csv

In [2]:
def sq_sum(x):
  """Return the sum of the squares of the items of the iterable x.

  An empty iterable yields 0.
  """
  return sum(map(lambda item: item ** 2, x))

def sq_dist(x, y):
  """Return the squared Euclidean distance between the points x and y.

  Coordinates are paired positionally with zip, so if one argument is
  longer than the other its extra trailing coordinates are ignored.
  """
  differences = (a - b for a, b in zip(x, y))
  return sq_sum(differences)

In [3]:
# Load the query points into memory as lists [query id, coordinate, ...].
query_file = "query_points.csv"  # name of query file
# For batch mode, take the file name from the command line instead:
#query_file = str(sys.argv[1])
with open(query_file, newline='') as csvfile:
  rows = csv.reader(csvfile, delimiter=',')
  # The first column of every row is dropped; the row's 0-based position
  # in the file is used as its query id.
  queries = [[q_id] + [float(x) for x in row[1:]] for q_id, row in enumerate(rows)]
q_id = len(queries)  # number of queries read, i.e. the next unused query id

In [4]:
# Load the pattern file as an RDD of strings, one per line of the file.
# NOTE(review): `sc` is not defined in this notebook; presumably it is the
# SparkContext provided by the PySpark session -- confirm.
input_file = sc.textFile('spline_boundary_enum.csv')

We split each string in the RDD into its components (an integer pattern id, the floating-point feature values, and an integer class label), converting each from text to a number.

In [5]:
def _parse_pattern(line):
  """Convert one comma-separated line into a numeric pattern record.

  The first field is parsed as an int (pattern id), the last field as an
  int (class label) and every field in between as a float (feature value).
  Returns the flat list [id, feature, ..., label].
  Raises ValueError if a field cannot be converted.
  """
  fields = line.split(',')  # split once and reuse for all three parts
  return [int(fields[0])] + [float(x) for x in fields[1:-1]] + [int(fields[-1])]

patterns = input_file.map(_parse_pattern)

Generate a distance value for each pattern and each query point, placing the query id in first position so that it serves as the key for a subsequent reduction.  The value is the triple (pattern id, distance, class label); the class label is carried along so that the reducing function can attach it to the query point if the pattern turns out to be the closest one.  Note that the mapping function produces the distances from all query points to one pattern as a list of lists, so with a plain map they would be collected as a list of lists of lists from the output RDD.  To eliminate one level of nesting, use flatMap.

In [7]:
def _distances_to_queries(pattern):
  """Pair one pattern with every query point.

  Returns one item [query id, [pattern id, squared distance, class label]]
  per entry of the driver-side list `queries`.
  """
  pattern_id, features, label = pattern[0], pattern[1:-1], pattern[-1]
  return [[q[0], [pattern_id, sq_dist(features, q[1:]), label]] for q in queries]

# flatMap concatenates the per-pattern lists, so the RDD holds the items
# themselves rather than one list of items per pattern.
distances_queries = patterns.flatMap(_distances_to_queries)
print(type(distances_queries))
# collect() returns one item per (pattern, query) combination to the driver.
distances_queries.collect()

<class 'pyspark.rdd.PipelinedRDD'>


[[0, [0, 0.01028982676495123, 1]],
 [1, [0, 0.0011004363446421217, 1]],
 [2, [0, 0.00010523310899985434, 1]],
 [3, [0, 0.04338228948312224, 1]],
 [4, [0, 0.6082835294540995, 1]],
 [5, [0, 0.00916797377231797, 1]],
 [6, [0, 0.17008398753476348, 1]],
 [7, [0, 0.24111208137193715, 1]],
 [8, [0, 0.10193446179908194, 1]],
 [9, [0, 0.14068090801172561, 1]],
 [10, [0, 0.07372821695300241, 1]],
 [11, [0, 0.010878970604162022, 1]],
 [12, [0, 0.8863740413753523, 1]],
 [13, [0, 0.1231092296803019, 1]],
 [14, [0, 0.26502366218247575, 1]],
 [15, [0, 0.3454391451846368, 1]],
 [16, [0, 0.023644680204070064, 1]],
 [17, [0, 0.41885545185768697, 1]],
 [18, [0, 0.08019272425041989, 1]],
 [19, [0, 0.5662158283393409, 1]],
 [20, [0, 0.3926444285802357, 1]],
 [21, [0, 0.08258537122655074, 1]],
 [22, [0, 0.7780487926294932, 1]],
 [23, [0, 0.04467084166422701, 1]],
 [24, [0, 0.001017607882980364, 1]],
 [25, [0, 0.4928279477150911, 1]],
 [26, [0, 0.42103039069518644, 1]],
 [27, [0, 0.18278543208427597, 1]],
 [

Compute the closest pattern for each query id by reducing with the min function, which is associative and commutative, over the distances.  The key argument of min specifies the ordering criterion: the values are ordered by distance, the second element of each item.  Note that the function passed to reduceByKey is applied to the values only, not to the (key, value) pairs.  A collection of (query, predicted class) pairs could then easily be computed by another map.

In [8]:
# For every query id keep the value [pattern id, distance, class label] with
# the smallest distance.  Ordering by distance alone is not commutative when
# two patterns are equally close: min returns its first argument on a tie, so
# the winner would depend on the order in which Spark merges the values.
# Breaking ties on the pattern id makes the result deterministic.
nearest = distances_queries.reduceByKey(lambda a, b: min(a, b, key=lambda k: (k[1], k[0])))
print(type(nearest))
print(nearest)
nearest.collect()

<class 'pyspark.rdd.PipelinedRDD'>
PythonRDD[8] at RDD at PythonRDD.scala:53


[Stage 2:>                                                          (0 + 2) / 2]                                                                                

[(0, [982, 5.51249888612204e-09, 1]),
 (2, [627, 1.5056920077360826e-09, 1]),
 (4, [64, 3.9280915403147616e-08, 1]),
 (6, [56, 9.276161833196136e-08, 0]),
 (8, [70, 7.878456787875636e-07, 0]),
 (10, [13, 1.1264624125804748e-07, 0]),
 (12, [869, 4.800979420405994e-07, 1]),
 (14, [948, 7.414659240871708e-07, 1]),
 (16, [408, 1.4215392949415822e-07, 1]),
 (18, [819, 9.036508096969077e-08, 0]),
 (20, [561, 3.593053852820564e-07, 1]),
 (22, [380, 7.479881368887182e-08, 1]),
 (24, [633, 1.1677115131000437e-06, 1]),
 (26, [811, 2.679620622404119e-06, 0]),
 (28, [86, 2.9039834811883243e-09, 1]),
 (30, [242, 1.379118938405464e-07, 0]),
 (32, [881, 1.5711300667832944e-09, 0]),
 (34, [520, 2.8620339173949336e-08, 1]),
 (36, [294, 1.83623123316933e-12, 1]),
 (38, [274, 4.545622302967417e-08, 0]),
 (40, [791, 4.128529849811078e-07, 1]),
 (42, [267, 1.8292220509025994e-07, 1]),
 (44, [508, 3.329468310974441e-09, 1]),
 (46, [623, 2.925163886370152e-09, 1]),
 (48, [486, 3.1274277106281615e-09, 1]),
 (