### Example: optimize a tensor of values so that their rank matches a target

In [1]:
import torch
from perturbations_torch.fenchel_young import FenchelYoungLoss


def ranks(inputs, dim=-1):
    """Returns the ranks of the input values among the given axis."""
    return 1 + inputs.argsort(dim).argsort(dim).type(inputs.dtype)

# We initialize a random tensor
x = torch.randn([3, 5]).float()
print(x)

# Turn its grad on, since we will change this tensor to minimize our loss
x.requires_grad = True
y_true = torch.arange(5).float().unsqueeze(0).repeat([x.shape[0], 1])

print("Initially, the values in our tensor do not result in the desired argsort")
print(x.argsort(-1))

tensor([[[ 1.7713, -1.0994, -0.1573,  0.0531, -0.9103],
         [-2.0598,  0.9478,  0.2425,  0.9996, -1.8692],
         [-2.2766, -0.6376,  0.5263,  1.5816, -3.3394]]])
Initially, the values in our tensor do not result in the desired argsort
tensor([[[1, 4, 2, 3, 0],
         [0, 4, 2, 1, 3],
         [4, 0, 1, 2, 3]]])


In [2]:
# Initialize an SGD optimizer and do 200 steps
optim = torch.optim.SGD([x], 0.01)

for iteration in range(200):
    optim.zero_grad()
    criterion = FenchelYoungLoss(ranks)
    loss = criterion(y_true, x).sum()
    loss.backward()
    optim.step()
    if iteration % 50 == 0:
        print(x.argsort(-1))
        print(loss.item())
        
print("SGD has succesfully changed our tensor to match the desired argsort!")

tensor([[[1, 4, 2, 3, 0],
         [0, 4, 2, 1, 3],
         [4, 0, 1, 2, 3]]])
75.0
tensor([[[1, 0, 2, 3, 4],
         [0, 4, 1, 2, 3],
         [0, 4, 1, 2, 3]]])
41.00177764892578
tensor([[[0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4],
         [0, 1, 2, 4, 3]]])
20.642358779907227
tensor([[[0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4],
         [0, 1, 2, 4, 3]]])
17.00006675720215
SGD has succesfully changed our tensor to match the desired argsort!
