<a href="https://colab.research.google.com/github/brandonwagstaff/lidar_feature_matching/blob/main/optimal_transport/keypoint_matching_with_dustbins.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
"""

Based on:
https://github.com/MichielStock/Teaching/blob/master/Optimal_transport/sinkhorn_knopp.py
@author: Michiel Stock
michielfmstock@gmail.com

Demonstration of the Sinkhorn-Knopp algorithm for optimal transport,
showing how adding a dustbin can account for non-matches or bad matches.
"""

import numpy as np
import optimal_transport
from matplotlib import pyplot as plt
from sklearn.datasets import make_moons
from scipy.spatial import distance_matrix




# Toy data: the two make_moons classes become the two point sets to match.
# The class-0 points are negated (reflected through the origin).
X, y = make_moons(n_samples=20, noise=0.1, shuffle=False)
X1 = X[y == 1]
X2 = -X[y == 0]

# X2 = X2[:40]  # different size X1 and X2

n = len(X1)
m = len(X2)

# Uniform marginals: each point in a set carries an equal share of unit mass.
r = np.full(n, 1.0 / n)
c = np.full(m, 1.0 / m)

# Transport cost between a pair of points is their pairwise distance.
M = distance_matrix(X1, X2)


## Add dustbins

# One extra bin with mass 1 is appended to each marginal.
r = np.append(r, 1.0)
c = np.append(c, 1.0)

# Pad the cost matrix with one extra row and one extra column; every entry
# that involves a dustbin (corner included) costs a flat 2.5.
M = np.pad(M, ((0, 1), (0, 1)), mode='constant', constant_values=2.5)


# Solve the dustbin-augmented transport problem. Only the module
# `optimal_transport` is imported in this file, so the solver has to be
# reached through it; the bare name `compute_optimal_transport` is never bound
# here and would raise NameError.
# NOTE(review): assumes `optimal_transport` exposes `compute_optimal_transport`
# with this signature and returns (transport matrix, value) — confirm against
# that module.
P, d = optimal_transport.compute_optimal_transport(M, r, c, lam=30, epsilon=1e-5)

# One row of three panels: cost matrix, transport matrix, and the matching.
fig, (ax1, ax2, ax) = plt.subplots(1, 3)

# Draw both point sets on the matching panel (X1 in blue, X2 in cyan).
# Transposing turns each (k, 2) array into its x-row and y-row.
ax.scatter(*X1.T, color='b')
ax.scatter(*X2.T, color='c')


# P is expected to have the shape of the augmented cost matrix, (n + 1, m + 1):
# row index n and column index m are the dustbins.
#
# A point of X1 (row i < n) counts as unmatched when the largest entry of its
# row sits in the dustbin *column*, i.e. at column index m. A point of X2
# (column j < m) counts as unmatched when the largest entry of its column sits
# in the dustbin *row*, i.e. at row index n. The dustbin row and column are
# not points themselves, so they are left out of the counts.
unmatched_n = int(np.count_nonzero(np.argmax(P[:n, :], axis=1) == m))
unmatched_m = int(np.count_nonzero(np.argmax(P[:, :m], axis=0) == n))

print("unmatched n: {}, unmatched m: {}".format(unmatched_n, unmatched_m))

# Keep the dustbin masses before cropping them off.
# unmatch_thresh_n is the dustbin row (one entry per column of P);
# unmatch_thresh_m is the dustbin column (one entry per row of P), which is
# what the per-row match test further down compares against.
unmatch_thresh_n = P[-1,:]
unmatch_thresh_m = P[:,-1]

# From here on P only holds the point-to-point transport, shape (n, m).
P = P[0:-1, 0:-1]


### ORIGINAL: all soft matches plotted
# for i in range(n):
#     for j in range(m):
#         ax.plot([X1[i,0], X2[j,0]], [X1[i,1], X2[j,1]], color='r',
#                 alpha=P[i,j] * n)


### MODIFIED: plot only the highest matches and ensure 1:1 correspondence
# P_copy = np.array(P)
# match_idx_x = []
# match_idx_y = []

# count = 0
# while P_copy.sum() > 0:
#     count += 1
#     max_i_idx = np.argmax(np.max(P_copy,axis=1))
#     max_j_idx = np.argmax(P_copy[max_i_idx])

#     match_idx_x.append(max_i_idx)
#     match_idx_y.append(max_j_idx)

#     P_copy[max_i_idx,:] = 0
#     P_copy[:,max_j_idx] = 0

# print(n,m,count)

# for i, j in zip(match_idx_x, match_idx_y):
#     ax.plot([X1[i,0], X2[j,0]], [X1[i,1], X2[j,1]], color='r', alpha=0.1+P[i,j])

## MODIFIED 2: soft plotting, but only show the strongest connection

# For every point of X1, draw only its strongest connection, or mark the point
# with an 'x' when the dustbin received at least as much mass as that match.
for i in range(n):
    P_i = P[i]
    j = np.argmax(P_i)

    # Only plot matches (P_ij > P_dustbin)
    # unmatch_thresh_m[i] is the dustbin-column entry of row i.
    if P_i[j] > unmatch_thresh_m[i]:
        # P[i, j] * n rescales the entry by the per-point mass 1/n. The solver
        # only converges to a tolerance, so the product can land slightly
        # above 1; matplotlib rejects alpha outside [0, 1], hence the clamp.
        ax.plot([X1[i,0], X2[j,0]], [X1[i,1], X2[j,1]], color='r',
                alpha=min(1.0, P[i,j]*n))
    else:
        ax.scatter([X1[i,0]], [X1[i,1]], marker='x', s=100)




# Heat-map panels: the dustbin-padded cost matrix and the cropped transport
# matrix, each with its own title.
for panel, image, title in ((ax1, M, 'Cost matrix'),
                            (ax2, P, 'Transport matrix')):
    panel.imshow(image)
    panel.set_title(title)

# The scatter/line panel was drawn earlier; it only needs its title.
ax.set_title('Optimal matching')

# Print the second value returned by the solver, then display the figure.
print(d)
plt.show()