In [13]:
import warnings
warnings.filterwarnings('ignore')
from pyspark.sql.types import *
from pyspark.sql.functions import *

import matplotlib.pyplot as plt
import numpy as np

In [11]:
# Clustering parameters. MIN_PTS is not used in the cells shown here; it is
# presumably the DBSCAN minimum-neighbour count consumed by later cells.
MIN_PTS = 4
# Presumably the DBSCAN neighbourhood radius; it is expressed in the same units
# as the x / y coordinates because SPAN_MARGIN (derived from it) is added to them.
EPSILON = 900
# Width / height of one grid cell used to partition the plane into areas.
X_UNIT, Y_UNIT = 5000, 5000
# Padding added on every side of each grid cell so that neighbouring areas
# overlap. Integer division keeps every area bound integral (the coordinates are
# ints); for the current even EPSILON this equals EPSILON/2 + 1 == 451.
SPAN_MARGIN = EPSILON // 2 + 1

In [29]:
# Input file: one point per line. NOTE(review): absolute local path — move it to
# a configurable data directory before sharing the notebook.
DATA_PATH = "/Users/apple/spark/data/project/A-sets/a1.txt"
# Only the first N_ROWS points are kept to keep the experiment small.
N_ROWS = 800

raw_lines = spark.read.text(DATA_PATH)
# Each line is trimmed and split on runs of whitespace into [x, y] tokens.
# Blank lines are dropped first; otherwise they would parse to a null point.
coords = (
    raw_lines
    .filter(trim(col('value')) != '')
    .select(split(trim(col('value')), '\\s+').alias('coor'))
)
dataset = coords.select(
    col('coor')[0].cast('int').alias('x'),
    col('coor')[1].cast('int').alias('y'),
)

# limit() without an ordering is re-evaluated by every later action (extent
# aggregation, area join, plot) and is not guaranteed to pick the same rows each
# time; caching makes those actions reuse one materialised subset (as long as
# the cache is not evicted).
dataset = dataset.limit(N_ROWS).cache()
dataset.printSchema()
dataset.show(3)

StructType([StructField('x', IntegerType(), True), StructField('y', IntegerType(), True)])
+-----+-----+
|    x|    y|
+-----+-----+
|54620|43523|
|52694|42750|
|53253|43024|
+-----+-----+
only showing top 3 rows



In [30]:
def grid_edges(lo, hi, unit):
    """Return grid lines lo, lo + unit, ... covering [lo, hi] with at least one cell.

    The last edge is the first lo + k*unit that is >= hi, so every value in
    [lo, hi] lies in some [edge_i, edge_{i+1}] cell. When lo == hi a single cell
    [lo, lo + unit] is still produced, so the downstream area list is never empty.
    """
    # Ceil division via negated floor division; `or 1` floors the count at one
    # cell. (Builtin max() is unavailable here: the pyspark star import shadows it.)
    n_cells = -(-(hi - lo) // unit) or 1
    return [lo + i * unit for i in range(n_cells + 1)]


# All four extremes in a single Spark job. `min`/`max` are the
# pyspark.sql.functions aggregates pulled in by the star import.
x_min, x_max, y_min, y_max = dataset.agg(min('x'), max('x'), min('y'), max('y')).first()
assert x_min is not None and y_min is not None, "dataset has no valid (x, y) rows"

x_grid = grid_edges(x_min, x_max, X_UNIT)
y_grid = grid_edges(y_min, y_max, Y_UNIT)

In [31]:
# Pair consecutive grid lines into cells, then pad each cell by SPAN_MARGIN on
# every side so neighbouring areas overlap. Areas are numbered from 0, with the
# y index varying fastest inside each x column.
x_cells = list(zip(x_grid[:-1], x_grid[1:]))
y_cells = list(zip(y_grid[:-1], y_grid[1:]))
area_list = [
    [area_id, x_lo - SPAN_MARGIN, x_hi + SPAN_MARGIN, y_lo - SPAN_MARGIN, y_hi + SPAN_MARGIN]
    for area_id, ((x_lo, x_hi), (y_lo, y_hi)) in enumerate(
        (x_cell, y_cell) for x_cell in x_cells for y_cell in y_cells
    )
]
area_cnt = len(area_list)
assert len(area_list) == (len(x_grid) - 1) * (len(y_grid) - 1)

# One row per area; area_id is the 0-based index assigned above.
area_df = spark.createDataFrame(area_list, ['area_id', 'x_start', 'x_end', 'y_start', 'y_end'])
# The padded bounds are floats whenever SPAN_MARGIN is one, so coerce every
# column to an integer type.
area_df = area_df.select(*(col(name).cast('int') for name in area_df.columns))

area_df.show(3)

+-------+-------+-----+-------+-----+
|area_id|x_start|x_end|y_start|y_end|
+-------+-------+-----+-------+-----+
|      0|  35425|41327|  38905|44807|
|      1|  35425|41327|  43905|49807|
|      2|  35425|41327|  48905|54807|
+-------+-------+-----+-------+-----+
only showing top 3 rows



In [35]:
# Attach every point to each padded area whose box contains it (bounds inclusive
# on both ends). Because neighbouring areas overlap by the margin, a point near a
# cell border is emitted once per area it falls into.
inside_area = (
    col('x').between(col('x_start'), col('x_end'))
    & col('y').between(col('y_start'), col('y_end'))
)
dataset_with_area = (
    dataset
    .crossJoin(area_df)
    .where(inside_area)
    .select('x', 'y', 'area_id')
)
dataset_with_area.show(3)

+-----+-----+-------+
|    x|    y|area_id|
+-----+-----+-------+
|37900|43700|      0|
|38172|42792|      0|
|38870|44459|      0|
+-----+-----+-------+
only showing top 3 rows



In [None]:
# Quick visual check of the loaded points. Collect (x, y) pairs in a single job:
# two separate collect() calls are independent Spark jobs with no guarantee of
# identical row order, which could pair an x with the wrong y.
# dtype=float maps any null coordinate to NaN; reshape keeps an empty result 2-D.
points = np.array(dataset.select('x', 'y').collect(), dtype=float).reshape(-1, 2)
fig, ax = plt.subplots()
ax.scatter(points[:, 0], points[:, 1], s=3)
ax.set(title='Loaded points', xlabel='x', ylabel='y');