In [1]:
import tval
from tval import trace_validator, trace_builder
import geopandas as gpd
import shapely
from shapely.geometry import LineString, Point, MultiPoint
from shapely.ops import snap, split
import shapely
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt



In [2]:
# Distance used by the helper functions in this notebook: nodes inside each
# other's buffer of this radius are merged, and traces closer than this to a
# node are considered "near" it.
snap_threshold = 2e-3

In [3]:
# Configure the project validator's snapping behaviour.
# NOTE(review): the validator is given a snap threshold of 0.001 while this
# notebook's own `snap_threshold` is 0.002 - confirm the mismatch is intended.
# The two 1.1 arguments are presumably the "multipliers" from the method name;
# their exact meaning is not visible here.
trace_validator.BaseValidator.set_snap_threshold_and_multipliers(0.001, 1.1, 1.1)

In [4]:
# Small synthetic fixture: two lines crossing at the origin, plus a vertical
# line 0.01 to the left of where the horizontal line ends.
# NOTE: `gdf` is overwritten by the next cell, which loads the real traces.
geoms = [
    LineString(coords)
    for coords in (
        [(-1, 0), (1, 0)],
        [(0, -1), (0, 1)],
        [(-1.01, -1), (-1.01, 1)],
    )
]
gdf = gpd.GeoDataFrame({"geometry": geoms})

In [5]:
# Load the real trace layer; this replaces the synthetic `gdf` built above.
gdf = gpd.read_file("dev/Inkoo_traces.gpkg", layer="Inkoo_Infinity_Lineaments_traces")
# NOTE(review): `valid_geoseries` is not used by any later cell (its only
# consumer below is commented out); confirm trace_builder.main() is still needed.
valid_geoseries = trace_builder.main()[0]
#gdf = gpd.GeoDataFrame({"geometry": valid_geoseries})

In [6]:
# Mean trace length, in the units of the layer's CRS.
gdf.length.mean()

1537.4799286285645

In [7]:
# Node geometries computed by the project validator from the traces
# (presumably trace intersections/endpoints - TODO confirm); the second
# return value is not needed here.
nodes, _ = trace_validator.BaseValidator.determine_nodes(gdf)

In [8]:
geosrs = gpd.GeoSeries(data=nodes)
# Scratch check (unrelated to the pipeline): bounds of a 0.01 buffer around (1, 1).
Point(1, 1).buffer(0.01).bounds

(0.99, 0.99, 1.01, 1.01)

In [9]:
def remove_identical(geosrs):
    """Drop nodes that lie within ``snap_threshold`` of another kept node.

    The nodes are walked in order. Each node that has not already been
    dropped causes every other node inside its ``snap_threshold`` buffer to be
    dropped. This brute-force variant intersects each kept node's buffer with
    every other node, so it is O(n^2). ``remove_identical_sindex`` does the
    same job using a spatial index.

    The caller's series is NOT modified. Internally the nodes are re-labelled
    0..n-1, and the returned series keeps those labels for the surviving nodes
    (the labels are not renumbered).

    Parameters
    ----------
    geosrs : geopandas.GeoSeries
        Node geometries (presumably Points).

    Returns
    -------
    geopandas.GeoSeries
        The nodes with near-duplicates removed.
    """
    # Re-index a copy: labels must equal enumerate() positions below, and the
    # caller's series must not be mutated as a side effect.
    geosrs = geosrs.reset_index(drop=True)

    # A set gives O(1) membership checks in the loop.
    marked_for_death = set()
    for idx, node in enumerate(geosrs):
        if idx in marked_for_death:
            continue
        inter = geosrs.drop(idx).intersection(node.buffer(snap_threshold))
        marked_for_death.update(
            label for label, geom in inter.items() if not geom.is_empty
        )
    return geosrs.drop(sorted(marked_for_death))

In [10]:
def remove_identical_sindex(geosrs):
    """Spatial-index version of ``remove_identical``.

    The nodes are walked in order. Each node that has not already been
    dropped causes every other node intersecting its ``snap_threshold``
    buffer to be dropped. Candidates are pre-filtered with the GeoSeries
    spatial index (bounding-box query), so only nearby nodes are tested
    exactly.

    The caller's series is NOT modified. Internally the nodes are re-labelled
    0..n-1, and the returned series keeps those labels for the surviving nodes.

    Parameters
    ----------
    geosrs : geopandas.GeoSeries
        Node geometries (presumably Points).

    Returns
    -------
    geopandas.GeoSeries
        The nodes with near-duplicates removed.
    """
    # Re-index a copy so positional ids from the spatial index equal labels,
    # without mutating the caller's series.
    geosrs = geosrs.reset_index(drop=True)
    spatial_index = geosrs.sindex

    marked_for_death = set()
    for idx, node in enumerate(geosrs):
        if idx in marked_for_death:
            continue
        # Build the buffer once and reuse it for the bbox query and the exact test.
        snap_area = node.buffer(snap_threshold)
        # Filter out the node itself instead of list.remove(), which raises
        # ValueError if the query happens not to return it.
        candidate_idxs = [
            cand for cand in spatial_index.intersection(snap_area.bounds) if cand != idx
        ]
        if not candidate_idxs:
            continue
        hits = geosrs.iloc[candidate_idxs].intersects(snap_area)
        marked_for_death.update(hits[hits].index.to_list())
    return geosrs.drop(sorted(marked_for_death))

In [11]:
def get_node_identities(geosrs, traces):
    """Classify each node as an "I", "Y" or "X" node.

    Each trace is clipped to a ``snap_threshold`` buffer around the node, and
    the non-empty pieces are inspected:

    * one piece -> "I";
    * two pieces, at least one of which is a ``Point`` -> "Y";
    * two pieces, one of whose endpoints intersects the node -> "Y";
    * otherwise -> "X".

    Parameters
    ----------
    geosrs : iterable of geometries
        The nodes to classify (presumably shapely Points).
    traces : geopandas.GeoSeries
        Trace geometries; ``traces.intersection`` is called with each node buffer.

    Returns
    -------
    list of str
        One identity per node, in the order of ``geosrs``.

    Raises
    ------
    AssertionError
        If no trace, or more than two traces, reach a node's buffer.
    """
    identities = []
    for node in geosrs:
        clipped = traces.intersection(node.buffer(snap_threshold))
        pieces = [geom for geom in clipped if not geom.is_empty]
        assert 0 < len(pieces) < 3, (
            f"expected 1 or 2 traces within snap_threshold of {node}, "
            f"got {len(pieces)}: {pieces}"
        )

        if len(pieces) == 1:
            identities.append("I")
            continue
        # Inspect only the non-empty pieces: `clipped` also holds empty
        # geometries for traces that miss the buffer, and those must not
        # affect the classification.
        if any(isinstance(piece, shapely.geometry.Point) for piece in pieces):
            identities.append("Y")
            continue

        endpoints = [
            endpoint
            for piece in pieces
            for endpoint in trace_validator.BaseValidator.get_trace_endpoints(piece)
        ]
        # A clipped piece ending on the node presumably means one trace abuts
        # the other there (Y); otherwise both pass through it (X).
        identities.append("Y" if any(node.intersects(ep) for ep in endpoints) else "X")
    return identities
        

In [12]:
def find_y_nodes_and_snap_em(traces, nodes, node_ids):
    """Snap Y-nodes onto a nearby trace they do not yet touch.

    X- and I-nodes are passed through unchanged. For every other node, the
    traces closer than ``snap_threshold`` are collected:

    * if no trace is that close, or every close trace already intersects the
      node, the node is kept as is;
    * otherwise each close trace that does not intersect the node contributes
      the node's projection onto that trace, in place of the node.

    Parameters
    ----------
    traces : geopandas.GeoSeries
        Trace geometries; must support boolean-mask ``.loc`` indexing.
    nodes : sized iterable of geometries
        Nodes, aligned with ``node_ids``.
    node_ids : sequence of str
        Node identities ("X", "Y" or "I").

    Returns
    -------
    list
        One (possibly snapped) node per input node, in input order.

    Raises
    ------
    AssertionError
        If the inputs differ in length, three or more traces are close to a
        node, none of the close traces touches it, or the output length does
        not match the input.
    """
    assert len(nodes) == len(node_ids)
    result = []
    for node, node_id in zip(nodes, node_ids):
        if node_id in ("X", "I"):
            result.append(node)
            continue

        dists = np.array([trace.distance(node) for trace in traces])
        is_close = dists < snap_threshold
        if not is_close.any():
            # Nothing close enough to snap onto.
            result.append(node)
            continue

        assert len(dists[is_close]) < 3
        close_traces = traces.loc[is_close]
        touches = [trace.intersects(node) for trace in close_traces]
        # At least one close trace must already pass through the node.
        assert any(touches)
        assert sum(touches) < 3
        assert len(close_traces) != 0

        if all(touches):
            # Every close trace meets the node: no snapping needed.
            result.append(node)
        for trace, touching in zip(close_traces, touches):
            if not touching:
                result.append(trace.interpolate(trace.project(node)))

    assert len(nodes) == len(result)
    return result
            
            
            

In [13]:
def split_to_branches(traces, nodes, node_identifiers) -> list:
    """Recursively split traces at X- and Y-nodes and return the branches.

    A node splits a trace when all of the following hold:

    * it lies within ``5 * snap_threshold`` of the trace;
    * its identity is "X" or "Y" (I-nodes and unknown ids never split);
    * it is not within ``snap_threshold`` of either trace endpoint;
    * it exactly intersects the trace.

    Each trace is split at the first such node. The resulting parts are
    processed again, so the remaining nodes split them further.

    Parameters
    ----------
    traces : iterable of LineStrings
        Traces to split.
    nodes : iterable of Points
        Nodes, aligned with ``node_identifiers``.
    node_identifiers : iterable of str
        Node identities ("X", "Y" or "I").

    Returns
    -------
    list
        Traces that no node split, in input order, followed by the branches
        produced from the traces that were split.

    Raises
    ------
    Exception
        If an X/Y node lies near a trace (away from its endpoints) but does not
        intersect it - presumably a node that was never snapped.
    """
    # Collect into two fresh lists instead of popping by the original index
    # from a list that has already shrunk/grown (which removed wrong traces).
    unsplit = []
    from_splits = []
    for trace in traces:
        parts = _split_at_first_node(trace, nodes, node_identifiers)
        if parts is None:
            unsplit.append(trace)
        else:
            from_splits.extend(split_to_branches(parts, nodes, node_identifiers))
    return unsplit + from_splits


def _split_at_first_node(trace, nodes, node_identifiers):
    """Return the parts of ``trace`` split at the first applicable node, or None."""
    for node, node_id in zip(nodes, node_identifiers):
        if trace.distance(node) > 5 * snap_threshold:
            # Node has nothing to do with this trace.
            continue
        if node_id not in ("X", "Y"):
            # I-nodes (and unrecognised ids) do not split.
            continue
        endpoints = trace_validator.BaseValidator.get_trace_endpoints(trace)
        if any(ep.buffer(snap_threshold).intersects(node) for ep in endpoints):
            # Already split here, or this is the abutting trace's end in a Y-node.
            continue
        if not node.intersects(trace):
            raise Exception(f"No intersection. {node.wkt, trace.wkt}")
        # shapely.ops.split returns a GeometryCollection; use .geoms because
        # iterating the collection directly was removed in Shapely 2.0.
        parts = list(split(trace, node).geoms)
        if len(parts) < 2:
            # Nothing was actually split; recursing on the same trace would
            # never terminate.
            continue
        return parts
    return None
    

In [14]:
# Merge nodes closer than `snap_threshold` using the spatial-index variant.
# NOTE(review): the commented lines are leftovers from comparing it with the
# brute-force remove_identical; delete them once that check is no longer needed.
#geosrs = remove_identical(geosrs)
geosrs_sindex = remove_identical_sindex(geosrs)
#assert len(geosrs) == len(geosrs_sindex)
#print(geosrs, geosrs_sindex)

In [None]:
# Pipeline: de-duplicated nodes -> classify -> snap Y-nodes -> split traces.
geosrs = geosrs_sindex

print("rem ident")
node_identifiers = get_node_identities(geosrs, gdf.geometry)
print("nodes identified")
# find_y_nodes_and_snap_em returns a plain list; wrap it back into a GeoSeries
# (with the traces' CRS) so later cells can call .plot / .to_file on it.
geosrs = gpd.GeoSeries(
    find_y_nodes_and_snap_em(gdf.geometry, geosrs, node_identifiers), crs=gdf.crs
)
print("found y nodes")
# Build the splitter once instead of once per trace.
split_points = MultiPoint(list(geosrs))
splited = [split(geom, split_points) for geom in gdf.geometry]
print("splited")
# shapely.ops.split returns GeometryCollections; iterate their .geoms (direct
# iteration over a collection was removed in Shapely 2.0).
splited = gpd.GeoSeries(
    [part for collection in splited for part in collection.geoms], crs=gdf.crs
)

rem ident


In [None]:
# Write the snapped nodes and the split traces to GeoPackages for inspection.
# NOTE(review): this cell sets no CRS itself; confirm both series carry the
# CRS of the input traces before writing.
gpd.GeoSeries(geosrs).to_file("dev/nodes.gpkg", driver="GPKG")
splited.to_file("dev/splited.gpkg", driver="GPKG")

In [None]:
# `branches` was not defined anywhere in this notebook (presumably it came
# from a deleted cell), so a fresh kernel raised NameError here. Build it
# explicitly with split_to_branches, which returns a list of branch geometries.
branches = gpd.GeoSeries(
    split_to_branches(list(gdf.geometry), list(geosrs), node_identifiers),
    crs=gdf.crs,
)

In [None]:
# Overlay of traces, branches (dashed brown) and nodes coloured by identity.
color_dict = {"X": "red", "Y": "blue", "I": "purple"}
fig, ax = plt.subplots(figsize=(9, 9))
# geosrs may be a plain list after the snapping step, which has no .plot;
# zorder keeps the node markers above the line collections.
gpd.GeoSeries(geosrs).plot(
    ax=ax, color=[color_dict[n] for n in node_identifiers], zorder=3
)
gdf.plot(ax=ax)
branches.plot(ax=ax, linestyle="--", linewidth=2, color="brown")
ax.set_title("Traces, branches (dashed) and nodes: X red, Y blue, I purple");

In [None]:
# Number of branches produced.
branches.shape[0]

In [None]:
# Scratch check: which traces touch the point (1, 1).
gdf.intersects(Point(1, 1))

In [None]:
# Scratch check: intersection of every trace with the point (1, 1).
gdf.intersection(Point(1, 1))