# 轨迹生成算法

In [3]:
import itertools
import os
import os.path
from collections import deque

import numpy as np
import shapefile
import tqdm
from anytree import AnyNode, RenderTree, PreOrderIter
from anytree.search import findall_by_attr
from rtree import index
from shapely.geometry import LineString

In [22]:
# Directory the input track shapefile is read from (the file name is appended
# by plain string concatenation, so the trailing slash matters).
from_path = 'E://Netease//智慧城市//4.处理数据/workingdata4/'
# Directory that receives one "<root index>.shp" file per exported route.
to_path = 'E://Netease//智慧城市//4.处理数据/workingdata5/'

## 读取文件

In [23]:
# Load all attribute records, the field schema and all geometries into memory.
# The reader is opened as a context manager so its file handles are released
# even if one of the reads raises; the loaded lists stay usable afterwards.
with shapefile.Reader(from_path + "track_山地骑行_clean.shp") as fi:
    records = fi.records()
    fields = fi.fields
    geoms = fi.shapes()

In [24]:
# Index every geometry's bounding box in an R-tree, padded by `offset` on all
# four sides. Original author's note: the data uses a metre-based coordinate
# system, so the padding is in metres.
offset = 10
spatial_index = index.Index()
for feature_id, shape in enumerate(geoms):
    min_x, min_y, max_x, max_y = shape.bbox
    padded_bbox = (
        min_x - offset,
        min_y - offset,
        max_x + offset,
        max_y + offset,
    )
    spatial_index.insert(feature_id, padded_bbox)

## 构建邻接表

In [None]:
# Build the adjacency matrix: adjacent[i, j] is set to 1 when line i and
# line j intersect. Candidates for each line come from the spatial index
# (bounding-box overlap) and are then confirmed with an exact geometry test.
adjacent = np.zeros((len(geoms), len(geoms)), dtype=np.int32)

# Construct each LineString exactly once; building it inside the nested loop
# would repeat the same construction for every candidate pair.
lines = [LineString(geom.points) for geom in geoms]

for gi, geom in enumerate(geoms):
    line_base = lines[gi]
    for fidx in spatial_index.intersection(geom.bbox):
        if line_base.intersects(lines[fidx]):
            adjacent[gi, fidx] = 1

## 获取典型线段

In [None]:
# Settings for the breadth-first traversal that collects qualifying paths.
# A segment may seed a traversal when its slope_mean is at least 39.
target_field = "slope_mean"
length_field = "length_n"
# Original author's note: number of consecutive non-qualifying elements that
# is tolerated while extending a path.
tolerence = 8
# Routes whose summed `length_field` is below this value are not exported.
total_length_tolerence = 3000
# Feature indices that have already been placed in some traversal tree.
nodes_set = set()

# Slope per record; empty (falsy) values are replaced by 0.0 so they can
# never reach the root threshold.
target_field_records = np.array(
    [
        float(value) if value else 0.0
        for value in (rec[target_field] for rec in records)
    ]
)
is_root = target_field_records >= 39
roots_idx = np.where(is_root)[0].tolist()
progress = tqdm.tqdm(total=len(roots_idx))

In [None]:
# Node names have the form "<feature index>_<flag>", where flag is 0 for a
# segment that meets the slope threshold and 1 for one that does not.
# The helpers below read the module-level state prepared by earlier cells
# (adjacent, records, fields, geoms, nodes_set, tolerence, ...).


def _feature_idx(node):
    """Return the feature index encoded before the underscore in ``node.name``."""
    return int(node.name.split("_")[0])


def _exception_flag(node):
    """Return the 0/1 flag encoded after the underscore in ``node.name``."""
    return int(node.name.split("_")[1])


def _grow_tree(root_idx):
    """Grow a breadth-first tree of adjacent segments starting at ``root_idx``.

    Every feature index placed in the tree is also recorded in the global
    ``nodes_set``, so a segment belongs to at most one tree.

    Returns:
        The root ``AnyNode`` of the tree.
    """
    root = AnyNode(name=str(root_idx) + "_0")
    queue = deque([root])  # deque gives O(1) pops from the front

    while queue:
        current = queue.popleft()
        current_idx = _feature_idx(current)
        nodes_set.add(current_idx)

        # Flags summed over the `tolerence` nodes nearest to `current` on its
        # path towards the root. This depends only on `current`, so it is
        # computed once rather than once per neighbour.
        recent_exceptions = sum(
            _exception_flag(node)
            for node in itertools.islice(current.iter_path_reverse(), tolerence)
        )

        for neighbour_idx in np.nonzero(adjacent[current_idx, :])[0].tolist():
            # Every node already in this tree is also in nodes_set, so this
            # membership test makes a separate search of the tree for a
            # same-named node unnecessary.
            if neighbour_idx in nodes_set:
                continue
            slope = records[neighbour_idx][target_field]
            if not slope:
                continue
            flag = not float(slope) >= 39
            if recent_exceptions + int(flag) < tolerence:
                child = AnyNode(name="%d_%d" % (neighbour_idx, flag), parent=current)
                nodes_set.add(neighbour_idx)
                queue.append(child)

    return root


def _collect_route_features(root):
    """Return the sorted feature indices to keep from the tree under ``root``.

    Each leaf-to-root path is walked from the leaf upwards; nodes are kept
    from the first node whose flag is 0 onwards, which drops any run of
    flagged nodes at the leaf end. Sorting makes the output order
    deterministic.
    """
    kept = set()
    for leaf in root.leaves:
        keeping = False
        for node in leaf.iter_path_reverse():
            if _exception_flag(node) == 0:
                keeping = True
            if keeping:
                kept.add(_feature_idx(node))
    return sorted(kept)


def _write_route(root_idx, feature_ids):
    """Write the given features to ``<to_path>/<root_idx>.shp``.

    The output carries the input attribute fields plus ``RouteID`` (the root
    index) and ``FID_Copy`` (the source feature index).
    """
    # The context manager closes the writer even if a record fails to write.
    with shapefile.Writer(os.path.join(to_path, "%d.shp" % (root_idx))) as fo:
        fo.autoBalance = 1
        # The first entry of `fields` is skipped (in pyshp this is presumably
        # the reader's DeletionFlag placeholder, which is not a user field).
        for field in fields[1:]:
            fo.field(*field)
        fo.field("RouteID", 'N')
        fo.field("FID_Copy", 'N')
        for feature_idx in feature_ids:
            record = records[feature_idx] + [root_idx, feature_idx]
            fo.line([geoms[feature_idx].points])
            fo.record(*record)


def _process_root(root_idx):
    """Build the route seeded at ``root_idx`` and export it if long enough."""
    feature_ids = _collect_route_features(_grow_tree(root_idx))
    total_length = sum(records[idx][length_field] for idx in feature_ids)
    if total_length < total_length_tolerence:
        return
    _write_route(root_idx, feature_ids)


for root_idx in roots_idx:
    # A root that was already absorbed into an earlier tree is skipped.
    if root_idx not in nodes_set:
        _process_root(root_idx)
    # Advance for every root, including skipped and too-short ones, so the
    # bar actually reaches its total.
    progress.update(1)
progress.close()