RANSAC

In [None]:
from sklearn.linear_model import LinearRegression, RANSACRegressor
import numpy as np

def fit_line_with_ransac(points, threshold, max_iterations, random_state=None):
    """Robustly fit a straight line y = a*x + b to 2-D points using RANSAC.

    Args:
        points: array-like with at least two columns; column 0 is used as x
            and column 1 as y. Lists are accepted and converted with
            ``np.asarray``.
        threshold: forwarded to ``RANSACRegressor`` as ``residual_threshold``,
            the maximum residual for a point to be counted as an inlier.
        max_iterations: forwarded as ``max_trials``, the maximum number of
            random subsets RANSAC tries.
        random_state: optional seed (or ``np.random.RandomState``) forwarded to
            ``RANSACRegressor`` so repeated calls give the same result.
            The default ``None`` keeps the original unseeded behaviour.

    Returns:
        tuple ``(coef, inliers)``:
            coef: ``coef_`` of the final ``LinearRegression`` fit, i.e. the
                slope only. The intercept (``estimator_.intercept_``) is NOT
                returned.
            inliers: the rows of ``points`` that RANSAC marked as inliers.
    """
    points = np.asarray(points)

    # scikit-learn expects a 2-D design matrix of shape (n_samples, 1).
    X = points[:, 0].reshape(-1, 1)
    y = points[:, 1]

    model = RANSACRegressor(
        LinearRegression(),
        residual_threshold=threshold,
        max_trials=max_iterations,
        random_state=random_state,
    )
    model.fit(X, y)

    inliers = points[model.inlier_mask_]

    return model.estimator_.coef_, inliers

# --- Example usage --------------------------------------------------------
# RANSAC settings: `threshold` becomes RANSACRegressor's residual_threshold and
# `max_iterations` its max_trials (see fit_line_with_ransac).
# NOTE(review): 0.1 is small relative to these coordinates — tune it for your data.
threshold = 0.1
max_iterations = 1000

# Sample (x, y) points — swap in your own data here.
data_points = np.array([
    [333.01822, 218.37819],
    [307.7363, 230.35153],
    [279.46, 237.551],
    [392.03934, 309.27982],
    [378.45688, 294.51904],
])

# Uncomment to inspect the raw points first:
# plot_scatter(data_points)

# best_line is the regression coef_ (slope only); inliers are the kept rows.
best_line, inliers = fit_line_with_ransac(data_points, threshold, max_iterations)
print("Best-fit line:", best_line)
print("Number of inliers:", len(inliers))

# plot_scatter is not defined in this view — presumably from an earlier cell.
plot_scatter(inliers)

In [None]:
# Quantitative analysis of DLC performance
#
# For each entry of df_features three error scores are computed:
#   err_1 - 1 minus the mean likelihood from ('features', 'p');
#   err_2 - variance / mean of the per-frame total chain length, capped at 1;
#   err_3 - mean (over frames) Euclidean distance between the normalised
#           segment lengths and IDEAL_LENGTH.
# Relies on df_features and plt, which are defined outside this view.
# NOTE(review): assumes each ('features', 'xy') entry is shaped (frames, 5, 2),
# i.e. five keypoints forming four consecutive segments to match IDEAL_LENGTH.
# Confirm against the DataFrame.

# Expected fraction of the total chain length taken by each segment.
IDEAL_LENGTH = np.array([1/2, 1/4, 1/8, 1/8])


def _dlc_errors(p_el, xy, ideal_length=IDEAL_LENGTH):
    """Return ``[err_1, err_2, err_3]`` for one entry's likelihoods and keypoints."""
    # Euclidean length of each segment between consecutive keypoints (axis -2).
    seg_len = np.sqrt(np.sum(np.square(np.diff(xy, axis=-2)), axis=-1))
    # Total chain length per frame; keepdims so it broadcasts over segments.
    total_len = np.sum(seg_len, axis=-1, keepdims=True)
    # NOTE(review): a frame whose keypoints all coincide has total_len == 0 and
    # yields NaN here, which propagates into err_3 for the whole entry.
    rel_len = seg_len / total_len

    err_1 = 1 - p_el.mean()
    # NOTE(review): var/mean scales with the coordinate units (not scale-free).
    err_2 = min(np.var(total_len) / np.mean(total_len), 1)
    err_3 = np.mean(np.linalg.norm(rel_len - ideal_length, axis=-1))
    return [err_1, err_2, err_3]


p_like = df_features['features', 'p'].values
XY = df_features['features', 'xy'].values

N = len(p_like)

# Rows are error kinds, columns are entries. reshape(-1, 3) keeps a (3, 0)
# layout when there is no data, so the indexing below does not raise.
err_list = np.array(
    [_dlc_errors(p_el, feat) for p_el, feat in zip(p_like, XY)],
    dtype=float,
).reshape(-1, 3).T

# NOTE(review): err_3 is plotted but not used in the "good data" summary.
print("Fraction of good data:")
print(f"1: {100*np.sum(err_list[0]<0.2)/N : 0.1f}%")
print(f"2: {100*np.sum(err_list[1]<0.1)/N : 0.1f}%")
print(f"All: {100*np.sum((err_list[0]<0.2) & (err_list[1]<0.1))/N : 0.1f}%")

plt.plot(err_list[0], label="Error 1")
plt.plot(err_list[1], label="Error 2")
plt.plot(err_list[2], label="Error 3")
plt.xlabel('Data index')
plt.ylabel('Error')
plt.legend()
plt.show()