In [1]:
import torch

In [7]:
a = torch.tensor([0,0,0,0], dtype=torch.float32)
torch.nn.functional.normalize(a, dim=-1)

tensor([0., 0., 0., 0.])

In [2]:
import numpy as np

In [3]:
xx = np.random.rand(4)
print(xx)
xx[[0,2,3,1]]

[0.54760118 0.80149694 0.24119498 0.89909611]


array([0.54760118, 0.24119498, 0.89909611, 0.80149694])

In [None]:
#!/usr/bin/env python

"""
    auction_lap.py
    
    From
        https://dspace.mit.edu/bitstream/handle/1721.1/3265/P-2108-26912652.pdf;sequence=1
"""

from __future__ import print_function, division

import sys
import torch

def auction_lap(X, eps=None, compute_score=True):
    """
        X: n-by-n matrix w/ integer entries
        eps: "bid size" -- smaller values means higher accuracy w/ longer runtime
    """
    eps = 1 / X.shape[0] if eps is None else eps
    
    # --
    # Init
    
    cost     = torch.zeros((1, X.shape[1]))
    curr_ass = torch.zeros(X.shape[0]).long() - 1
    bids     = torch.zeros(X.shape)
    
    if X.is_cuda:
        cost, curr_ass, bids = cost.cuda(), curr_ass.cuda(), bids.cuda()
    
    counter = 0
    while (curr_ass == -1).any():
        counter += 1
        
        # --
        # Bidding
        
        unassigned = (curr_ass == -1).nonzero().squeeze(dim=1)
        
        value = X[unassigned] - cost
        top_value, top_idx = value.topk(2, dim=1)
        
        first_idx = top_idx[:,0]
        first_value, second_value = top_value[:,0], top_value[:,1]
        
        bid_increments = first_value - second_value + eps
        
        bids_ = bids[unassigned]
        bids_.zero_()
        bids_.scatter_(
            dim=1,
            index=first_idx.contiguous().view(-1, 1),
            src=bid_increments.view(-1, 1)
        )
        
        # --
        # Assignment
        
        have_bidder = (bids_ > 0).int().sum(dim=0).nonzero()
        
        high_bids, high_bidders = bids_[:,have_bidder].max(dim=0)
        high_bidders = unassigned[high_bidders.squeeze()]
        
        cost[:,have_bidder] += high_bids
        
        curr_ass[(curr_ass.view(-1, 1) == have_bidder.view(1, -1)).sum(dim=1)] = -1
        curr_ass[high_bidders] = have_bidder.squeeze()
    
    score = None
    if compute_score:
        score = int(X.gather(dim=1, index=curr_ass.view(-1, 1)).sum())
    
    return score, curr_ass, counter

In [20]:
from torch_linear_assignment import batch_linear_assignment

xx = torch.randn(2,5,5)
print(xx)
batch_linear_assignment(-xx)

tensor([[[ 1.0103,  0.2311, -0.0290,  0.2136,  0.6728],
         [-0.6718,  1.1757,  0.3734, -0.4217, -0.9985],
         [ 0.5374, -1.0040, -0.1930,  1.1668,  1.2807],
         [-1.3198,  0.5633,  0.1314,  0.9520,  0.5832],
         [ 0.1003,  1.2608,  0.1722, -2.2411,  1.4001]],

        [[ 1.2192,  2.2065, -1.6999,  0.0936,  0.6548],
         [ 1.2671, -0.3542, -0.9693, -0.4955, -0.5529],
         [-0.3040,  1.8883,  0.7900,  0.0718, -0.3346],
         [-0.5126, -1.6058,  0.3922, -0.4387, -0.2765],
         [ 0.4620, -0.1736,  1.4381, -1.1730,  0.2142]]])


tensor([[0, 1, 3, 2, 4],
        [4, 0, 1, 3, 2]])

In [18]:
from scipy.optimize import linear_sum_assignment

linear_sum_assignment(-xx.numpy()[1,:])[1]

array([1, 2, 0, 3, 4])

In [21]:
xx.shape

torch.Size([2, 5, 5])

In [28]:
torch.reshape(torch.ones(2,2), [1, torch.ones(2,2).shape[0], torch.ones(2,2).shape[1]])

tensor([[[1., 1.],
         [1., 1.]]])

In [31]:
xx.device

device(type='cpu')

In [None]:
[0.6, 0.8], [0.3,0.7]
0.75, 0.25
[]

In [7]:
xx.device

device(type='cpu')

In [8]:
xx.requires_grad

False

In [8]:
from torch.autograd import Function

def sinkhorn_forward(C, mu, nu, epsilon, max_iter):
    bs, n, k_ = C.size()

    v = torch.ones([bs, 1, k_])/(k_)
    G = torch.exp(-C/epsilon)
    if torch.cuda.is_available():
        v = v.cuda()

    for i in range(max_iter):
        u = mu/(G*v).sum(-1, keepdim=True)
        v = nu/(G*u).sum(-2, keepdim=True)

    Gamma = u*G*v
    return Gamma

def sinkhorn_forward_stablized(C, mu, nu, epsilon, max_iter):
    bs, n, k_ = C.size()
    k = k_-1

    f = torch.zeros([bs, n, 1])
    g = torch.zeros([bs, 1, k+1])
    if torch.cuda.is_available():
        f = f.cuda()
        g = g.cuda()

    epsilon_log_mu = epsilon*torch.log(mu)
    epsilon_log_nu = epsilon*torch.log(nu)

    def min_epsilon_row(Z, epsilon):
        return -epsilon*torch.logsumexp((-Z)/epsilon, -1, keepdim=True)
    
    def min_epsilon_col(Z, epsilon):
        return -epsilon*torch.logsumexp((-Z)/epsilon, -2, keepdim=True)

    for i in range(max_iter):
        f = min_epsilon_row(C-g, epsilon)+epsilon_log_mu
        g = min_epsilon_col(C-f, epsilon)+epsilon_log_nu
        
    Gamma = torch.exp((-C+f+g)/epsilon)
    return Gamma
    
def sinkhorn_backward(grad_output_Gamma, Gamma, mu, nu, epsilon):
    
    nu_ = nu[:,:,:-1]
    Gamma_ = Gamma[:,:,:-1]

    bs, n, k_ = Gamma.size()
    
    inv_mu = 1./(mu.view([1,-1]))  #[1, n]
    Kappa = torch.diag_embed(nu_.squeeze(-2)) \
            -torch.matmul(Gamma_.transpose(-1, -2) * inv_mu.unsqueeze(-2), Gamma_)   #[bs, k, k]
    
    inv_Kappa = torch.inverse(Kappa) #[bs, k, k]
    
    Gamma_mu = inv_mu.unsqueeze(-1)*Gamma_
    L = Gamma_mu.matmul(inv_Kappa) #[bs, n, k]
    G1 = grad_output_Gamma * Gamma #[bs, n, k+1]
    
    g1 = G1.sum(-1)
    G21 = (g1*inv_mu).unsqueeze(-1)*Gamma  #[bs, n, k+1]
    g1_L = g1.unsqueeze(-2).matmul(L)  #[bs, 1, k]
    G22 = g1_L.matmul(Gamma_mu.transpose(-1,-2)).transpose(-1,-2)*Gamma  #[bs, n, k+1]
    G23 = - F.pad(g1_L, pad=(0, 1), mode='constant', value=0)*Gamma  #[bs, n, k+1]
    G2 = G21 + G22 + G23  #[bs, n, k+1]
    
    del g1, G21, G22, G23, Gamma_mu
    
    g2 = G1.sum(-2).unsqueeze(-1) #[bs, k+1, 1]
    g2 = g2[:,:-1,:]  #[bs, k, 1]
    G31 = - L.matmul(g2)*Gamma  #[bs, n, k+1]
    G32 = F.pad(inv_Kappa.matmul(g2).transpose(-1,-2), pad=(0, 1), mode='constant', value=0)*Gamma  #[bs, n, k+1]
    G3 = G31 + G32  #[bs, n, k+1]

    grad_C = (-G1+G2+G3)/epsilon  #[bs, n, k+1]
    return grad_C

class TopKFunc(Function):
    @staticmethod
    def forward(ctx, C, mu, nu, epsilon, max_iter):
        
        with torch.no_grad():
            if epsilon>1e-2:
                Gamma = sinkhorn_forward(C, mu, nu, epsilon, max_iter)
                if bool(torch.any(Gamma!=Gamma)):
                    print('Nan appeared in Gamma, re-computing...')
                    Gamma = sinkhorn_forward_stablized(C, mu, nu, epsilon, max_iter)
            else:
                Gamma = sinkhorn_forward_stablized(C, mu, nu, epsilon, max_iter)
            ctx.save_for_backward(mu, nu, Gamma)
            ctx.epsilon = epsilon
        return Gamma

    @staticmethod
    def backward(ctx, grad_output_Gamma):
        
        epsilon = ctx.epsilon
        mu, nu, Gamma = ctx.saved_tensors
        # mu [1, n, 1]
        # nu [1, 1, k+1]
        #Gamma [bs, n, k+1]   
        with torch.no_grad():
            grad_C = sinkhorn_backward(grad_output_Gamma, Gamma, mu, nu, epsilon)
        return grad_C, None, None, None, None


class TopK_custom(torch.nn.Module):
    def __init__(self, k, epsilon=0.1, max_iter = 200):
        super(TopK_custom, self).__init__()
        self.k = k
        self.epsilon = epsilon
        self.anchors = torch.FloatTensor([k-i for i in range(k+1)]).view([1,1, k+1])
        self.max_iter = max_iter
        
        if torch.cuda.is_available():
            self.anchors = self.anchors.cuda()

    def forward(self, scores):
        bs, n = scores.size()
        scores = scores.view([bs, n, 1])
        
        #find the -inf value and replace it with the minimum value except -inf
        scores_ = scores.clone().detach()
        max_scores = torch.max(scores_).detach()
        scores_[scores_==float('-inf')] = float('inf')
        min_scores = torch.min(scores_).detach()
        filled_value = min_scores - (max_scores-min_scores)
        mask = scores==float('-inf')
        scores = scores.masked_fill(mask, filled_value)
        
        C = (scores-self.anchors)**2
        C = C / (C.max().detach())
      
        mu = torch.ones([1, n, 1], requires_grad=False)/n
        nu = [1./n for _ in range(self.k)]
        nu.append((n-self.k)/n)
        nu = torch.FloatTensor(nu).view([1, 1, self.k+1])
        
        if torch.cuda.is_available():
            mu = mu.cuda()
            nu = nu.cuda()
            
        Gamma = TopKFunc.apply(C, mu, nu, self.epsilon, self.max_iter)
 
        A = Gamma[:,:,:self.k]*n
        
        return A, None

In [10]:
top = TopK_custom(10)

In [21]:
top(torch.randn(1,20).cuda())

(tensor([[[0.0041, 0.0059, 0.0084, 0.0117, 0.0161, 0.0218, 0.0291, 0.0380,
           0.0488, 0.0613],
          [0.0058, 0.0080, 0.0110, 0.0147, 0.0195, 0.0255, 0.0327, 0.0411,
           0.0508, 0.0615],
          [0.0340, 0.0380, 0.0420, 0.0457, 0.0491, 0.0520, 0.0540, 0.0552,
           0.0552, 0.0542],
          [0.0246, 0.0287, 0.0331, 0.0376, 0.0421, 0.0464, 0.0502, 0.0534,
           0.0558, 0.0570],
          [0.1205, 0.1105, 0.1001, 0.0894, 0.0787, 0.0683, 0.0582, 0.0487,
           0.0400, 0.0322],
          [0.0172, 0.0210, 0.0253, 0.0300, 0.0351, 0.0404, 0.0457, 0.0508,
           0.0555, 0.0593],
          [0.0043, 0.0062, 0.0087, 0.0121, 0.0165, 0.0223, 0.0295, 0.0384,
           0.0491, 0.0614],
          [0.0311, 0.0352, 0.0394, 0.0434, 0.0471, 0.0504, 0.0530, 0.0548,
           0.0555, 0.0551],
          [0.0390, 0.0429, 0.0465, 0.0497, 0.0524, 0.0544, 0.0555, 0.0556,
           0.0547, 0.0527],
          [0.0294, 0.0335, 0.0377, 0.0419, 0.0459, 0.0494, 0.0524, 0.0545

In [3]:
import torch
from torch.autograd import Function
import torch.nn.functional as F
from tqdm import tqdm

@torch.no_grad()
def _find_ts(xs, ks, binary_iter=16, newton_iter=1):
    n = xs.shape[-1]
    assert torch.all((0 < ks) & (ks < n)), "We don't support k=0 or k=n"
    # Lo should be small enough that all sigmoids are in the 0 area.
    # Similarly Hi is large enough that all are in their 1 area.
    lo = -xs.max(dim=-1, keepdims=True).values - 10
    hi = -xs.min(dim=-1, keepdims=True).values + 10
    assert torch.all(torch.sigmoid(xs + lo).sum(dim=-1) < 1)
    assert torch.all(torch.sigmoid(xs + hi).sum(dim=-1) > n - 1)
    # Batch binary search, solving sigmoid(xs + ts) = ks
    for _ in range(binary_iter):
        mid = (hi + lo) / 2
        mask = torch.sigmoid(xs + mid).sum(dim=-1) < ks
        lo[mask] = mid[mask]
        hi[~mask] = mid[~mask]
    ts = (lo + hi) / 2
    # Fine-tune using some Newton iterations
    for _ in range(newton_iter):
        sig = torch.sigmoid(xs + ts)
        den = sig.sum(dim=-1, keepdims=True) - ks[..., None]
        num = (sig * (1 - sig)).sum(dim=-1, keepdims=True)
        ts -= den / num
    # Test for success
    assert torch.allclose(torch.sigmoid(xs + ts).sum(dim=-1), ks.double())
    return ts


class TopK(Function):
    @staticmethod
    def forward(ctx, xs, ks):
        ts = _find_ts(xs, ks)
        ps = torch.sigmoid(xs + ts)
        ctx.save_for_backward(ps)
        return ps

    @staticmethod
    def backward(ctx, grad_output):
        # Compute vjp, that is grad_output.T @ J.
        (ps,) = ctx.saved_tensors
        # Let v = sigmoid'(x + t)
        v = ps * (1 - ps)  # sigmoid' = sigmoid * (1 - sigmoid)
        s = v.sum(dim=-1, keepdims=True)
        t_d = v / s
        # Jacobian is -vv.T/s + diag(v)
        uv = grad_output * v
        t1 = uv.sum(dim=-1, keepdims=True) * t_d
        return uv - t1, None


class TopK_BCE(Function):
    @staticmethod
    def forward(ctx, xs, ks, ys):
        xs = xs + _find_ts(xs, ks)
        ctx.save_for_backward(xs, ks, ys)
        loss = (ys - 1) * xs + F.logsigmoid(xs)
        return -loss

    @staticmethod
    def backward(ctx, grad_output):
        xts, ks, ys = ctx.saved_tensors
        # Compute d/dxi t = - sig'(x_i + t) / sum_j sig'(x_j + t)
        sig = torch.sigmoid(xts)
        sig_d = sig * (1 - sig)  # sigmoid' = sigmoid * (1 - sigmoid)
        num = sig_d.sum(dim=-1, keepdims=True)
        t_d = -sig_d / num
        # Jacobian is t'e^T - diag(e)
        e = ys - sig
        ev = e * grad_output
        b = ev + t_d * ev.sum(dim=-1, keepdims=True)
        return -b, None, xts


soft_topk = TopK.apply
bce_topk = TopK_BCE.apply


def main():
    from torch.autograd import gradcheck
    import tqdm

    n1, n2, d = 20, 2, 10
    xs = torch.randn(n1, n2, d, dtype=torch.double, requires_grad=True)

    # Test TopK function
    for _ in tqdm.tqdm(range(2)):
        ks = torch.randint(1, d, size=(n1, n2))
        assert gradcheck(soft_topk, (xs, ks), eps=1e-6, atol=1e-4)

    for _ in tqdm.tqdm(range(10)):
        ks = torch.randint(1, d, size=(n1, n2), dtype=torch.double)
        ys = torch.randint(0, 2, size=(n1, n2, d), dtype=torch.double)
        # Test forward method
        torch.testing.assert_close(
            F.binary_cross_entropy(soft_topk(xs, ks), ys, reduction="none"),
            bce_topk(xs, ks, ys),
        )
        # Test backward method
        assert gradcheck(bce_topk, (xs, ks, ys), eps=1e-6, atol=1e-4)



main()

100%|█████████████████████████████████████████████| 2/2 [00:04<00:00,  2.04s/it]
100%|███████████████████████████████████████████| 10/10 [00:15<00:00,  1.57s/it]


In [4]:
n1, n2, d = 20, 2, 10
print(torch.randint(1, d, size=(n1, n2)).shape)
torch.randn(n1, n2, d, dtype=torch.double, requires_grad=True).shape

torch.Size([20, 2])


torch.Size([20, 2, 10])

In [68]:
import torch
from torch_linear_assignment import batch_linear_assignment


def my_matching(matrix_batch):
    """
    Solves a matching problem for a batch of matrices using the Straight-Through Estimator (STE).

    Args:
        matrix_batch: A 3D tensor (a batch of matrices) with shape = [batch_size, N, N].
                      If 2D, the input is reshaped to 3D with batch_size = 1.

    Returns:
        A 2D integer tensor of permutations with shape [batch_size, N].
    """
    
    def hungarian(x):
        if x.ndim == 2:
            x = torch.reshape(x, [1, x.shape[0], x.shape[1]])
        sol = batch_linear_assignment(-x)
        return sol

    N = matrix_batch.shape[-1]
    index_layer = torch.nn.Linear(N, 1, bias=False)
        
    # Initialize the weights to the range of vocab_size and freeze the layer
    with torch.no_grad():
        index_layer.weight = torch.nn.Parameter(torch.arange(N).float().unsqueeze(0))
    index_layer.weight.requires_grad = False

    # Get hard permutations using Hungarian algorithm
    listperms_hard = hungarian(matrix_batch.detach()).to(matrix_batch.device)  # Detach to prevent gradient tracking

    batch_size, N, _ = matrix_batch.shape
    listperms_hard_onehot = torch.zeros(batch_size, N, N, device=matrix_batch.device)
    listperms_hard_onehot.scatter_(2, listperms_hard.unsqueeze(-1), 1)

    # STE: During backward, replace the gradient with the gradient of matrix_batch
    listperms_ste = listperms_hard_onehot + matrix_batch - matrix_batch.detach()

    # Convert back to hard permutations (indices) as the output
    #listperms_output = listperms_hard_onehot.argmax(dim=-1).float()
    
    return listperms_ste, index_layer(listperms_ste).squeeze(-1)
    #return listperms_ste


In [71]:
batch_size, N = 10, 5
matrix_batch = torch.rand(batch_size, N, N, requires_grad=True)
listperms = my_matching(matrix_batch)
print(listperms)
# Example loss and backward pass
loss = torch.sum(listperms)  # Define some loss function
loss.backward()  # Backpropagation

(tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 1.0000],
         [0.0000, 1.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 1.0000, 0.0000],
         [0.0000, 0.0000, 1.0000, 0.0000, 0.0000],
         [1.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

        [[0.0000, 0.0000, 1.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 1.0000],
         [0.0000, 0.0000, 0.0000, 1.0000, 0.0000],
         [0.0000, 1.0000, 0.0000, 0.0000, 0.0000],
         [1.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

        [[0.0000, 0.0000, 0.0000, 0.0000, 1.0000],
         [0.0000, 0.0000, 1.0000, 0.0000, 0.0000],
         [1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 1.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 1.0000, 0.0000]],

        [[0.0000, 0.0000, 0.0000, 0.0000, 1.0000],
         [0.0000, 1.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 1.0000, 0.0000],
         [1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000

TypeError: sum(): argument 'input' (position 1) must be Tensor, not tuple

In [75]:
torch.eye(5)[[4.0000, 1.0000, 3.0000, 2.0000, 0.0000]]

tensor([[0., 0., 0., 0., 1.],
        [0., 1., 0., 0., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 1., 0., 0.],
        [1., 0., 0., 0., 0.]])