In [1]:
import torch

In [121]:
n_debris = 3

tensor1 = torch.tensor([[10, 10, 20, 20],
                        [30, 30, 40, 40],
                        [50, 50, 60, 60]], dtype=torch.float32)

tensor2 = torch.tensor([[35, 35, 45, 45],
                        [55, 55, 65, 65],
                        [12, 12, 22, 22]], dtype=torch.float32)


In [122]:
distances = torch.cdist(tensor1, tensor2, p=2)
#distances[1, 0] = 100
distances

tensor([[50., 90.,  4.],
        [10., 50., 36.],
        [30., 10., 76.]])

In [123]:
compare_tensor1 = tensor1.clone()
compare_tensor2 = tensor2.clone()
pairs_tensors = []

while compare_tensor1.shape[0] >= 1:

    if compare_tensor1.shape[0] > 1:
        
        distances = torch.cdist(compare_tensor1, compare_tensor2, p=2)
        
        
        row_idx, col_idx = divmod(torch.argmin(distances).item(), distances.size(1))
        pairs_tensors.append((compare_tensor1[row_idx], compare_tensor2[col_idx]))
        
        compare_tensor1 = torch.cat((compare_tensor1[:row_idx], compare_tensor1[row_idx + 1:]))
        compare_tensor2 = torch.cat((compare_tensor2[:col_idx], compare_tensor2[col_idx + 1:]))
        
    else:
        pairs_tensors.append((compare_tensor1, compare_tensor2))
        break

In [124]:
tensor1, tensor2


(tensor([[10., 10., 20., 20.],
         [30., 30., 40., 40.],
         [50., 50., 60., 60.]]),
 tensor([[35., 35., 45., 45.],
         [55., 55., 65., 65.],
         [12., 12., 22., 22.]]))

In [125]:
pairs_tensors

[(tensor([10., 10., 20., 20.]), tensor([12., 12., 22., 22.])),
 (tensor([30., 30., 40., 40.]), tensor([35., 35., 45., 45.])),
 (tensor([[50., 50., 60., 60.]]), tensor([[55., 55., 65., 65.]]))]

In [None]:

max = 0
max_row = 0
max_col = 0

for row in range(distances.shape[0]):
    for col in range(distances.shape[1]):
        if distances[row, col] > max:
            max = distances[row, col]
            max_row = row
            max_col = col
max_index = (max_row, max_col)
    

In [4]:
paired_indices = []
used_rows = set()
used_cols = set()

In [5]:
for _ in range(n):
    # Find the minimum distance and corresponding indices
    min_val, min_idx = torch.min(distances, dim=1)  # Min per row
    row_idx = torch.argmin(min_val).item()
    col_idx = min_idx[row_idx].item()

    # Record the pair
    paired_indices.append((row_idx, col_idx))
    used_rows.add(row_idx)
    used_cols.add(col_idx)

    # Mask out the used row and column to avoid reuse
    distances[row_idx, :] = float('inf')
    distances[:, col_idx] = float('inf')

In [6]:
paired_tensor1 = torch.stack([tensor1[row] for row, _ in paired_indices])
paired_tensor2 = torch.stack([tensor2[col] for _, col in paired_indices])


In [9]:
tensor1, tensor2

(tensor([[0.7587, 0.3375, 0.3928, 0.5462],
         [0.9204, 0.5924, 0.9057, 0.1996]]),
 tensor([[0.8186, 0.0626, 0.8658, 0.8821],
         [0.7791, 0.5983, 0.4437, 0.5704]]))

In [7]:
paired_tensor1

tensor([[0.7587, 0.3375, 0.3928, 0.5462],
        [0.9204, 0.5924, 0.9057, 0.1996]])

In [8]:
paired_tensor2

tensor([[0.7791, 0.5983, 0.4437, 0.5704],
        [0.8186, 0.0626, 0.8658, 0.8821]])

In [10]:
from scipy.optimize import linear_sum_assignment

# Step 1: Compute pairwise distances
distances = torch.cdist(tensor1, tensor2, p=2).numpy()  # Convert to NumPy

# Step 2: Solve assignment problem
row_indices, col_indices = linear_sum_assignment(distances)

# Step 3: Pair tensors based on assignment
paired_tensor1 = tensor1[row_indices]
paired_tensor2 = tensor2[col_indices]

# Print results
print("Tensor 1 paired rows:\n", paired_tensor1)
print("Tensor 2 paired rows:\n", paired_tensor2)

Tensor 1 paired rows:
 tensor([[0.7587, 0.3375, 0.3928, 0.5462],
        [0.9204, 0.5924, 0.9057, 0.1996]])
Tensor 2 paired rows:
 tensor([[0.7791, 0.5983, 0.4437, 0.5704],
        [0.8186, 0.0626, 0.8658, 0.8821]])
