In [1]:
# 基于空间索引 给时序数据area打标
import findspark
findspark.init('/opt/spark')
from pyspark.sql import SparkSession
# https://sedona.apache.org/1.3.0-incubating/archive/tutorial/geospark-sql-python/
from geospark.register import upload_jars
from geospark.register import GeoSparkRegistrator
app_name = "dwd_to_dwm"

# GeoSpark has a suite of well-written geometry and index serializers. 
# Forgetting to enable these serializers will lead to high memory consumption. (序列化器优化内存使用)
# https://stackoverflow.com/questions/65213369/unable-to-configure-geospark-in-spark-session
upload_jars()

spark = SparkSession.builder \
    .appName(app_name) \
    .enableHiveSupport() \
    .config("spark.executor.memory", "40g") \
    .config("spark.driver.memory", "40g") \
    .config("spark.driver.maxResultSize","4g")\
    .getOrCreate()

# GeoSparkRegistrator.registerAll(spark)




# 打印集群信息
print("Spark 集群名称: ", spark.conf.get("spark.app.name"))
print("Spark 集群版本: ", spark.version)
print("Spark ID: ", spark.sparkContext.applicationId)

print("Spark 集群节点数: ", spark.sparkContext._jsc.sc().getExecutorMemoryStatus().keySet())
print("每个 Executor 的内存容量: ", spark.conf.get("spark.executor.memory"))

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


Spark 集群名称:  dwd_to_dwm
Spark 集群版本:  3.3.3
Spark ID:  application_1710158715244_0002
Spark 集群节点数:  Set(worker1:44749, jupyter:36873, worker2:45849)
每个 Executor 的内存容量:  40g


In [2]:
# create database time_series
spark.sql("DROP TABLE IF EXISTS dwm_time_series;")

sql = '''
CREATE TABLE IF NOT EXISTS dwm_time_series
         (id bigint,
          time TIMESTAMP,
          lon float,
          lat float)
PARTITIONED BY(area int)
STORED AS PARQUET;
'''

# create database job
spark.sql("DROP TABLE IF EXISTS dwm_job;")
sql = """

CREATE TABLE IF NOT EXISTS dwm_job
(id bigint,start_time TIMESTAMP,end_time TIMESTAMP)
PARTITIONED BY(yearMonth string)
STORED AS PARQUET;

"""
spark.sql(sql)

# 创建输出表
#CREATE TABLE output(id SERIAL PRIMARY KEY,job_id int4);

print("初始化 dwm_time_series，dwm_job成功")

初始化 dwm_time_series，dwm_job成功


In [3]:
# 加载job 表=> 

sql = '''
INSERT INTO dwm_job partition(yearMonth)
SELECT
 id as id,
 start_time as start_time,
 end_time as end_time,
 yearMonth as yearMonth
 FROM dwd_job;
'''
result = spark.sql(sql)

# 查看一条数据
result = spark.sql("SELECT * FROM dwm_job LIMIT 1")
result.show()

                                                                                

+-----+-------------------+-------------------+---------+
|   id|         start_time|           end_time|yearMonth|
+-----+-------------------+-------------------+---------+
|13136|2025-07-01 00:00:00|2025-07-01 01:00:00|   202507|
+-----+-------------------+-------------------+---------+



In [4]:
# 查看时序数据总量
result = spark.sql("SELECT count(*) FROM dwd_time_series")
result.show()



+---------+
| count(1)|
+---------+
|315359976|
+---------+



                                                                                

In [5]:
sql = '''

SELECT id,time,lon,lat
FROM (
	SELECT id,time,lon, lat
	FROM dwd_time_series as time_series
	WHERE EXISTS (
		SELECT *
		FROM dwm_job as job
		WHERE job.id = 1
			AND time_series.time BETWEEN job.start_time AND job.end_time
	)
) dwd_time_series;

'''
df_input = spark.sql(sql)
df_input.show()



+---+-------------------+---------+---------+
| id|               time|      lon|      lat|
+---+-------------------+---------+---------+
|719|2023-12-31 17:00:00| 45.00023| 46.00023|
|720|2023-12-31 17:00:05| 45.00023| 46.00023|
|721|2023-12-31 17:00:10| 45.00023| 46.00023|
|722|2023-12-31 17:00:15| 45.00023| 46.00023|
|723|2023-12-31 17:00:20| 45.00023| 46.00023|
|724|2023-12-31 17:00:25| 45.00023| 46.00023|
|725|2023-12-31 17:00:30| 45.00023| 46.00023|
|726|2023-12-31 17:00:35| 45.00023| 46.00023|
|727|2023-12-31 17:00:40| 45.00023| 46.00023|
|728|2023-12-31 17:00:45|45.000233|46.000233|
|729|2023-12-31 17:00:50|45.000233|46.000233|
|730|2023-12-31 17:00:55|45.000233|46.000233|
|731|2023-12-31 17:01:00|45.000233|46.000233|
|732|2023-12-31 17:01:05|45.000233|46.000233|
|733|2023-12-31 17:01:10|45.000233|46.000233|
|734|2023-12-31 17:01:15|45.000233|46.000233|
|735|2023-12-31 17:01:20|45.000233|46.000233|
|736|2023-12-31 17:01:25|45.000233|46.000233|
|737|2023-12-31 17:01:30|45.000233



In [6]:
from pyspark.sql.types import StructType, StructField, LongType, TimestampType, FloatType, IntegerType

# 定义模式（schema）
schema = StructType([
    StructField("id", LongType(), nullable=False),
    StructField("time", TimestampType(), nullable=False),
    StructField("lon", FloatType(), nullable=False),
    StructField("lat", FloatType(), nullable=False)
])

# 添加分区列
schema.add(StructField("area", IntegerType(), nullable=False))

# 打印模式（schema）
print(schema)

StructType([StructField('id', LongType(), False), StructField('time', TimestampType(), False), StructField('lon', FloatType(), False), StructField('lat', FloatType(), False), StructField('area', IntegerType(), False)])


In [12]:
from pyspark.sql import Row

import rtree
# 创建R树索引
idx = rtree.index.Index()

def build_rtree_idx():
    id_list =  df_input.select("id").rdd.flatMap(lambda x: x).collect()
    lon_list = df_input.select("lon").rdd.flatMap(lambda x: x).collect()
    lat_list = df_input.select("lat").rdd.flatMap(lambda x: x).collect()
    
    for id,lon,lat in zip(id_list,lon_list,lat_list):
        idx.insert(id,lon,lat,lon,lat)
    
def mapper(rows):
    region_numbers={}
    ResRow = Row()
    # 针对一个part的数据做map
    for row in rows:
        area = 0
        lon = row.lon
        lat = row.lat
        
        nearest_points = list(idx.nearest(query_point, 1))  # 从R树索引中获取离查询点最近的1个点的ID
        yield Row(id= row.id, time=row.time, lon=row.lon, lat=row.lat,area = area)
        

build_rtree_idx()


#df_output = df_input.rdd.repartition(30).mapPartitions(mapper).toDF()


                                                                                

In [25]:
query_point = (2, 2)  # 示例查询点的坐标

# 使用R树索引查找最近点
nearest_points = list(idx.nearest(query_point, 1))  # 从R树索引中获取离查询点最近的1个点的ID
idx

rtree.index.Index(bounds=[1.7976931348623157e+308, 1.7976931348623157e+308, -1.7976931348623157e+308, -1.7976931348623157e+308], size=0)

In [14]:
df_output.show()

+----+-------------------+------------------+------------------+----+
|  id|               time|               lon|               lat|area|
+----+-------------------+------------------+------------------+----+
| 979|2023-12-31 17:21:40|45.000308990478516|46.000308990478516|   0|
| 980|2023-12-31 17:21:45|45.000308990478516|46.000308990478516|   0|
| 981|2023-12-31 17:21:50| 45.00031280517578| 46.00031280517578|   0|
| 982|2023-12-31 17:21:55| 45.00031280517578| 46.00031280517578|   0|
| 983|2023-12-31 17:22:00| 45.00031280517578| 46.00031280517578|   0|
| 984|2023-12-31 17:22:05| 45.00031280517578| 46.00031280517578|   0|
| 985|2023-12-31 17:22:10| 45.00031280517578| 46.00031280517578|   0|
| 986|2023-12-31 17:22:15| 45.00031280517578| 46.00031280517578|   0|
| 987|2023-12-31 17:22:20| 45.00031280517578| 46.00031280517578|   0|
| 988|2023-12-31 17:22:25| 45.00031280517578| 46.00031280517578|   0|
|1279|2023-12-31 17:46:40|45.000404357910156|46.000404357910156|   0|
|1280|2023-12-31 17: