## Soft Binary Argmax @k (SBAM) Illustration

In [None]:
from src.soft_binary_arg_max_ops import *

In [2]:
import torch

# Example score vector (float32) reused by the cells below.
x = torch.tensor((0.1, 1.6, 1.0))

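## Binary Argmax @ k = 1 (hard, sparse solution)

With a very small regularization strength, the output is the hard one-hot indicator of the largest score.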

In [3]:
# Tiny regularization: the output is the hard one-hot selection of the largest score.
soft_binary_argmax_k(x[None, :], k=1, regularization_strength=0.01)

tensor([[0., 1., 0.]])

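## Soft Binary Argmax @ k = 1

Larger `regularization_strength` values produce smoother outputs that move toward the uniform vector.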

In [4]:
# Moderate regularization: probability mass is spread over every entry.
soft_binary_argmax_k(x[None, :], k=1, regularization_strength=10)

tensor([[0.2533, 0.4033, 0.3433]])

In [5]:
# Stronger regularization: the output moves closer to uniform.
soft_binary_argmax_k(x[None, :], k=1, regularization_strength=50)

tensor([[0.3173, 0.3473, 0.3353]])

In [6]:
# Very strong regularization: the output approaches the uniform vector 1/n.
# NOTE: powers are written with `**`; `10^9` is bitwise XOR in Python and evaluates to 3.
soft_binary_argmax_k(x.unsqueeze(0), k=1, regularization_strength=10**9)

tensor([[0.3333, 0.3333, 0.3333]])

## Differentiation

In [21]:
# Allocate directly on the target device so `x` is a leaf tensor with requires_grad=True.
# `torch.tensor(..., requires_grad=True).cuda()` returns a non-leaf copy whose `.grad`
# is never populated by `backward()`. Falls back to CPU when no GPU is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.tensor([0.1, 1.6, 1], device=device, requires_grad=True)
y = soft_binary_argmax_k(x.unsqueeze(0), k=1, regularization_strength=10)


In [22]:
# Gradient of the first output coordinate with respect to the input scores.
torch.autograd.grad(outputs=y[0, 0], inputs=x)


(tensor([ 0.0667, -0.0333, -0.0333], device='cuda:0'),)

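## Hypersimplex Loss

The loss is the mean-squared error between the soft binary argmax @k output and the binary target labels.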

In [28]:
# Allocate directly on the target device so `logits` is a leaf tensor; calling `.cuda()`
# after `requires_grad=True` would yield a non-leaf copy. Falls back to CPU without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# One-hot target: position 1 is the true label.
true_labels = torch.tensor([[0, 1, 0]], dtype=torch.float32, device=device)
logits = torch.tensor([[2.1, 1.6, 1]], dtype=torch.float32, device=device, requires_grad=True)

def hyper_simplex_loss(logits, true_labels, k, regularization_strength):
    """Return the MSE between the soft binary argmax @k of ``logits`` and ``true_labels``.

    Args:
        logits: scores forwarded unchanged to ``soft_binary_argmax_k``. This notebook
            passes a (batch, n) tensor; other shapes depend on that helper (TODO confirm).
        true_labels: target tensor, expected to match the shape of the soft binary
            argmax output.
        k: forwarded as ``k`` to ``soft_binary_argmax_k``. Presumably this is the number
            of entries to select (with k=1 the outputs in this notebook sum to 1).
        regularization_strength: forwarded to ``soft_binary_argmax_k``. In this notebook,
            larger values give smoother outputs.

    Returns:
        Scalar loss tensor from ``torch.nn.functional.mse_loss`` (default mean reduction).
    """
    return torch.nn.functional.mse_loss(
        soft_binary_argmax_k(logits, k=k, regularization_strength=regularization_strength),
        true_labels,
    )

# Evaluate the hypersimplex loss for a single example and show it.
loss = hyper_simplex_loss(
    logits=logits,
    true_labels=true_labels,
    k=1,
    regularization_strength=1.5,
)
print(loss)

tensor(0.2963, device='cuda:0', grad_fn=<MseLossBackward0>)


In [29]:
# Gradient of the loss with respect to the logits.
torch.autograd.grad(outputs=loss, inputs=logits)

(tensor([[ 0.2963, -0.2963,  0.0000]], device='cuda:0'),)

### The gradient is positive at position 0 and negative at position 1 (the true label), so a gradient-descent step decreases logit 0 and increases logit 1, reducing the loss