In [60]:
import torch
from diffsort import DiffSortNet

# Mean-squared-error loss (torch.nn.MSELoss) between the softly permuted
# input vector and the target vector.
criterion = torch.nn.MSELoss()

# Number of elements to sort.
# NOTE(review): written as a power of two, presumably because the bitonic
# network requires it -- confirm against the diffsort documentation.
vector_length = 2**2

# Target ordering and the shuffled input that should be permuted into it.
gt_vec = torch.tensor([0.0, 1.0, 1.0, 0.0])
shuffled_vec = torch.tensor([1.0, 0.0, 0.0, 1.0])

# Learnable sort keys: a random permutation of 0..vector_length-1 stored as
# floats, reshaped to a single-row batch, and marked as a leaf tensor that
# requires gradients so the optimizer can update it in place.
idxs = (
    torch.randperm(vector_length, dtype=torch.float32, device='cpu')
    .view(1, -1)
    .requires_grad_(True)
)

# sort using a bitonic-sorting-network
sorter = DiffSortNet('bitonic', vector_length, steepness=15)

# Adam optimizes the sort keys directly; they are the only parameter.
optimizer = torch.optim.Adam([idxs], lr=0.1)

# The target never changes, so reshape it to a single-row batch once instead
# of on every iteration.
target = gt_vec.view(1, -1)

for i in range(100):  # Example loop
    # Forward pass: run the sort keys through the differentiable sorter and
    # apply the second returned value to the shuffled vector.
    # NOTE(review): presumably a batch of relaxed (soft) permutation
    # matrices -- confirm against the diffsort documentation.
    sorted_vectors, permutation_matrices = sorter(idxs)
    pred_vec = shuffled_vec @ permutation_matrices

    # MSELoss is documented as criterion(input, target); the value is the
    # same either way for MSE, but keep the conventional order.
    loss = criterion(pred_vec, target)

    # Clear stale gradients, backpropagate, and update the sort keys.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Optionally print the parameter's progress.
    # Note: the loss shown was computed before this iteration's step, while
    # the parameter shown is the value after the step.
    if i % 10 == 0:
        print(f'Iteration {i+1}, Loss: {loss.item()}, Param: {idxs}')

# The tensors left over from the loop were computed before the last optimizer
# step, so re-evaluate once with the final parameters before reporting.
# No gradients are needed for reporting.
with torch.no_grad():
    sorted_vectors, permutation_matrices = sorter(idxs)
    pred_vec = shuffled_vec @ permutation_matrices

print("\n\n")
print("Final Predicted Vector: ",  pred_vec)
print("GT Vector: ", gt_vec)
print("Final Permutation Matrix: ", permutation_matrices[0])


Iteration 1, Loss: 0.44809776544570923, Param: tensor([[0.1000, 0.9000, 3.1000, 1.9000]], requires_grad=True)
Iteration 11, Loss: 0.0021769609302282333, Param: tensor([[0.8895, 0.1191, 3.6313, 1.1061]], requires_grad=True)
Iteration 21, Loss: 0.0006753954803571105, Param: tensor([[ 1.2231, -0.2127,  4.0209,  0.8201]], requires_grad=True)
Iteration 31, Loss: 0.000526962336152792, Param: tensor([[ 1.3525, -0.3427,  4.3478,  0.8060]], requires_grad=True)
Iteration 41, Loss: 0.00044421988422982395, Param: tensor([[ 1.4037, -0.3952,  4.6283,  0.8839]], requires_grad=True)
Iteration 51, Loss: 0.00039250595727935433, Param: tensor([[ 1.4240, -0.4169,  4.8757,  0.9824]], requires_grad=True)
Iteration 61, Loss: 0.00035758185549639165, Param: tensor([[ 1.4321, -0.4263,  5.0996,  1.0785]], requires_grad=True)
Iteration 71, Loss: 0.00033222156343981624, Param: tensor([[ 1.4355, -0.4309,  5.3062,  1.1663]], requires_grad=True)
Iteration 81, Loss: 0.0003128025564365089, Param: tensor([[ 1.4371, -0.4