In [None]:
import os
from pathlib import Path

from BiopTort import compute_snake, apply_gaussian_filter
from scipy.interpolate import UnivariateSpline, interp1d
import numpy as np
import matplotlib.pyplot as plt
import plotly.graph_objs as go
from PIL import Image
from shapely import LineString
from genepy3d.obj import curves
from shapely.validation import make_valid
import shapely
from shapely.geometry import LineString, Polygon

In [None]:
def filter_points_by_mask(points, mask):
    """Keep only the points whose (row, col) pixel is truthy in ``mask``.

    Parameters
    ----------
    points : array-like, shape (N, D) with D >= 2
        Point coordinates. The first two columns are used as (row, col)
        indices into ``mask`` after truncation toward zero (as ``int()`` does).
    mask : 2-D array-like
        Pixel mask; any non-zero / truthy value means "keep".

    Returns
    -------
    numpy.ndarray, shape (M, D)
        The points (in their original order) that land on the mask. The
        result is always 2-D, so an empty selection has shape (0, D) rather
        than (0,) and can still be indexed with ``[:, k]``.
    """
    pts = np.asarray(points)
    if pts.size == 0:
        # No points at all: still hand back a 2-D empty array.
        return pts.reshape(0, pts.shape[-1] if pts.ndim == 2 else 2)

    mask_arr = np.asarray(mask)
    # astype(int) truncates toward zero, matching int(p[0]) / int(p[1]).
    rows = pts[:, 0].astype(int)
    cols = pts[:, 1].astype(int)
    keep = mask_arr[rows, cols].astype(bool)
    return pts[keep]

In [None]:
def interpolate_points2(points, upsample=10):
    """Resample a polyline at evenly spaced arc-length positions.

    The polyline is parameterized by normalized cumulative arc length (0 at
    the first vertex, 1 at the last). It is then linearly interpolated at
    ``len(points) * upsample`` equally spaced parameter values.

    Parameters
    ----------
    points : array-like, shape (N, D) with D >= 2
        Polyline vertices in order. Only the first two coordinates are returned.
    upsample : int, default 10
        Output-size multiplier: the result has ``len(points) * upsample`` rows.

    Returns
    -------
    numpy.ndarray, shape (len(points) * upsample, 2)
        Resampled points. The first and last rows equal the first and last
        input vertices. A polyline with only one distinct vertex yields that
        vertex repeated.
    """
    pts = np.asarray(points, dtype=float)
    n_out = len(pts) * upsample
    if n_out == 0:
        return np.empty((0, 2))

    # Drop consecutive duplicate vertices. Zero-length segments give repeated
    # arc-length values, which interp1d(kind='slinear') rejects. An
    # all-identical polyline would also divide by a zero total length.
    seg_len = np.sqrt(np.sum(np.diff(pts, axis=0) ** 2, axis=1))
    keep = np.concatenate(([True], seg_len > 0))
    pts = pts[keep]
    seg_len = seg_len[seg_len > 0]

    if len(pts) < 2:
        # Only one distinct vertex: nothing to interpolate between.
        return np.repeat(pts[:1, :2], n_out, axis=0)

    # Normalized cumulative arc length: strictly increasing, from 0.0 to exactly 1.0.
    distance = np.concatenate(([0.0], np.cumsum(seg_len)))
    distance /= distance[-1]

    alpha = np.linspace(0.0, 1.0, n_out)
    interpolator = interp1d(distance, pts, kind='slinear', axis=0)
    return interpolator(alpha)[:, :2]

In [None]:
def segments_intersect(s1, s2):
    """Return True when the two point sequences, read as polylines, share any point.

    Both inputs are wrapped in shapely ``LineString`` objects and tested with
    ``intersects``. Touching at a shared endpoint therefore counts as an
    intersection.
    """
    first = LineString(s1)
    second = LineString(s2)
    return first.intersects(second)

def get_segments_from_points(points):
    """Split a polyline into its consecutive two-vertex segments.

    Returns a list of ``len(points) - 1`` arrays. Each array stacks
    ``[points[k], points[k + 1]]``. Fewer than two points yields an empty list.
    """
    return [np.array([start, end]) for start, end in zip(points[:-1], points[1:])]

def get_points_from_segments(segments):
    """Rebuild polyline vertices from a chain of two-vertex segments.

    This is the inverse of ``get_segments_from_points``. It takes the start
    vertex of every segment, plus the end vertex of the last segment.

    Parameters
    ----------
    segments : sequence of array-like, each of shape (2, D)
        Consecutive segments, where each segment's end is the next one's start.

    Returns
    -------
    numpy.ndarray
        ``len(segments) + 1`` vertices. Returns an empty (0, 2) array when
        ``segments`` is empty, instead of failing on ``segments[-1]``.
    """
    if len(segments) == 0:
        # Nothing to chain; keep the result 2-D for downstream [:, k] indexing.
        return np.empty((0, 2))
    vertices = [seg[0] for seg in segments]
    vertices.append(segments[-1][1])
    return np.array(vertices)

def remove_loops(points):
    """Cut self-intersection loops out of a polyline.

    Every pair of non-adjacent segments (i, j) that intersect is treated as
    a crossing. When the earlier segment is visited, its index is pushed on a
    stack once per later crossing partner. When the later segment is visited,
    one entry is popped per earlier crossing partner. Once a pop leaves the
    stack empty, every open crossing has been closed. The span from
    ``loopstart`` (the bottom of the stack, i.e. the earliest segment opened
    since the stack was last empty) to the current segment ``i`` is then
    recorded as one loop. Overlapping or nested crossings therefore merge
    into a single outer loop.

    For each recorded loop ``(a, b)`` the vertices ``a+1 .. b`` are deleted,
    so vertex ``a`` is joined directly to vertex ``b+1``. The actual
    intersection point is not inserted.

    Parameters
    ----------
    points : array-like, shape (N, D)
        Polyline vertices in order.

    Returns
    -------
    numpy.ndarray
        A copy of ``points`` with the loop vertices removed.

    Notes
    -----
    Calls ``segments_intersect`` for every ordered pair of segments, so this
    does O(N**2) shapely intersection tests.
    """
    stack = []   # indices of earlier segments whose crossing is still open
    loops = []   # (loopstart, i) segment-index pairs bounding each loop

    points_copy = np.copy(points)

    segments = get_segments_from_points(points)
    for i, s1 in enumerate(segments):
        for j, s2 in enumerate(segments):
            if segments_intersect(s1, s2): 
                # j == i (same segment) and j == i +/- 1 (shared endpoint) match neither branch below.
                if j > i + 1:   # if the segments are not adjacent we add the first segment to the stack
                    stack.append(i)
                elif j < i - 1:
                    # Close one crossing with an earlier segment. The matching push
                    # happened at step j, so the stack is non-empty here as long as
                    # segments_intersect is symmetric (NOTE(review): otherwise
                    # pop() could raise IndexError).
                    loopstart = stack.pop()
                
                    if len(stack) == 0:
                        # All open crossings are closed: segments loopstart..i bound one loop.
                        loops.append((loopstart, i))

    
    # Loops are recorded in increasing order of i and their vertex ranges do not
    # overlap. Deleting from the last loop backwards keeps earlier indices valid.
    for i in reversed(loops):
        points_copy = np.delete(points_copy, slice(i[0]+1, i[1]+1), axis=0) # not very efficient but it works

    return points_copy

In [None]:
def simplify(points):
    """Run ``points`` through shapely's ``make_valid`` and return the coordinates.

    The points are wrapped in a ``LineString`` and repaired with
    ``make_valid``. The resulting geometry's coordinates come back as a
    numpy array.

    NOTE(review): despite the name, nothing is simplified here. A LineString
    is OGC-valid even when it self-intersects, so ``make_valid`` presumably
    returns the input unchanged in the common case. If it ever produces a
    multi-part geometry (e.g. for degenerate input), ``.coords`` is expected
    to raise. Confirm the intended behavior.
    """
    repaired = make_valid(LineString(points))
    coords = repaired.coords
    return np.array(coords)

In [None]:
# Input slide and its usable-tissue mask. The data root can be overridden with
# the TORTUOSITY_DATA_DIR environment variable instead of editing an absolute
# local path; the default is the original location.
DATA_DIR = Path(os.environ.get("TORTUOSITY_DATA_DIR", "/home/jackson/research/data/tortuosity_study"))
SLIDE_ID = "_1687394.svs"

fn = str(DATA_DIR / "tortuosity_training_set" / SLIDE_ID)
mask_fn = str(
    DATA_DIR / "QC_results" / "training_set_final_run" / "usable_tissue_masks"
    / f"{SLIDE_ID}_mask_use.png"
)

snake, img, contours, rp, b, num_points = compute_snake(fn, mask_fn)
image_array = np.array(img)  # background for the overlay figure


In [None]:
# Remove self-intersection loops from the raw snake. The result is kept
# separate from `snake` so both can be overlaid in the plotly figure.
no_loops = remove_loops(snake)


In [None]:
# Densify the snake (10x, evenly spaced in arc length), then smooth it.
# NOTE(review): this starts from `snake`, not the loop-free `no_loops`;
# confirm which one is intended.
# Earlier alternatives, kept for reference:
#   snake_out = filter_points_by_mask(snake, b)
#   snake_out_final = snake_out_final[0:-1:10]
# TODO: constrain linearity in background regions.
snake_out = interpolate_points2(snake)
snake_out_final = apply_gaussian_filter(snake_out, 20)  # 20: presumably the Gaussian width; confirm in BiopTort


In [None]:
curve = curves.Curve(snake_out_final)
curvature = curve.compute_curvature()
# Straight stretches may have curvature exactly 0; log then gives -inf.
# Silence the divide-by-zero warning rather than printing it.
with np.errstate(divide="ignore"):
    log_curvature = np.log(curvature)

# Curvature profile along the smoothed snake, with the y range fixed to [0, 10].
curv_fig, curv_ax = plt.subplots()
curv_ax.plot(curvature)
curv_ax.set(
    ylim=(0, 10),
    xlabel="point index",
    ylabel="curvature",
    title="Curvature along the smoothed snake",
);

In [None]:
# Log-scale curvature.
# NOTE(review): no threshold is applied despite the name. This recomputes the
# same values as `log_curvature`, and the remaining cells do not use it.
threshed = np.log(curvature)


In [None]:
# Inspect the loop-free snake coordinates.
# NOTE(review): this renders the whole array; consider `no_loops.shape` or `no_loops[:5]`.
no_loops

In [None]:
def _snake_trace(points, color, name):
    """Build a markers+lines overlay trace for an (N, 2) array of points.

    Column 1 is plotted as x and column 0 as y, so the overlay lines up with
    the ``go.Image`` background (pixel column = x, pixel row = y).
    """
    return go.Scatter(
        x=points[:, 1],
        y=points[:, 0],
        mode='markers+lines',
        name=name,
        # A single fixed colour. The previous colorbar/colorscale settings
        # ("Curvature (log scale)") were dropped: the markers are not coloured
        # by a numeric array, so those settings did not display anything.
        marker=dict(color=color, size=7),
    )


scatter_plot = _snake_trace(no_loops, 'yellow', 'loops removed')
scatter_plot3 = _snake_trace(snake, 'blue', 'raw snake')

fig = go.Figure()
fig.add_trace(go.Image(z=image_array))  # image as background
fig.add_trace(scatter_plot)             # snake overlays drawn on top
fig.add_trace(scatter_plot3)

# Size the figure to the image in pixels, with no grid, zero lines, or margins.
fig.update_layout(
    xaxis=dict(showgrid=False, zeroline=False),
    yaxis=dict(showgrid=False, zeroline=False),
    width=image_array.shape[1],
    height=image_array.shape[0],
    margin=dict(l=0, r=0, t=0, b=0)
)

fig.show()