In [7]:
%load_ext autoreload
%autoreload 2

import torch
import torch.nn as nn

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [8]:
def bpr_K(y_pred_BD, y_true_BD, K=4, perturbed_top_K_func=None):

    top_K_ids_BKD = perturbed_top_K_func(y_pred_BD)
    # Sum over k dim
    top_K_ids_BD = top_K_ids_BKD.sum(dim=-2)

    true_top_K_val_BD, _  = torch.topk(y_true_BD, K) 
    denominator_B = torch.sum(true_top_K_val_BD, dim=-1)
    numerator_B = torch.sum(top_K_ids_BD * y_true_BD, dim=-1)
    bpr_B = numerator_B/denominator_B

    return bpr_B

In [9]:
class PerturbedTopK(nn.Module):

    def __init__(self, k: int, num_samples: int = 500, sigma: float = 0.05):
        super(PerturbedTopK, self).__init__()
    
        self.num_samples = num_samples
        self.sigma = sigma
        self.k = k
    
    def __call__(self, x):
        # Return the output of the PerturbedTopKFunction, applied to the input tensor
        # using the k, num_samples, and sigma attributes as arguments
        return PerturbedTopKFunction.apply(x, self.k, self.num_samples, self.sigma)


class PerturbedTopKFunction(torch.autograd.Function):
    
    @staticmethod
    def forward(ctx, x, k: int, num_samples: int = 500, sigma: float = 0.05):
    
        b, d = x.shape
        
        # Generate Gaussian noise with specified number of samples and standard deviation
        noise = torch.normal(mean=0.0, std=1.0, size=(b, num_samples, d)).to(x.device)

        # Add noise to the input tensor
        perturbed_x = x[:, None, :] + noise * sigma # b, nS, d
        
        # Perform top-k pooling on the perturbed tensor
        topk_results = torch.topk(perturbed_x, k=k, dim=-1, sorted=False)
        
        # Get the indices of the top k elements
        indices = topk_results.indices # b, nS, k
        
        # Sort the indices in ascending order
        indices = torch.sort(indices, dim=-1).values # b, nS, k

        # Convert the indices to one-hot tensors
        perturbed_output = torch.nn.functional.one_hot(indices, num_classes=d).float()
        
        # Average the one-hot tensors to get the final output
        indicators = perturbed_output.mean(dim=1) # b, k, d

        # Save constants and tensors for backward pass
        ctx.k = k
        ctx.num_samples = num_samples
        ctx.sigma = sigma

        ctx.perturbed_output = perturbed_output
        ctx.noise = noise

        return indicators


    @staticmethod
    def backward(ctx, grad_output):
        # If there is no gradient to backpropagate, return tuple of None values
        if grad_output is None:
            return tuple([None] * 5)

        noise_gradient = ctx.noise
        
        # Calculate expected gradient
        expected_gradient = (
            torch.einsum("bnkd,bne->bkde", ctx.perturbed_output, noise_gradient)
            / ctx.num_samples
            / ctx.sigma
        ) * float(ctx.k)
        
        grad_input = torch.einsum("bkd,bkde->be", grad_output, expected_gradient)
        
        return (grad_input,) + tuple([None] * 5)

In [14]:
# batch dimension
B = 100
# feature dimension
D = 10
# top-K
K=4
rating_BD = torch.rand(B, D)
rating_BD.requires_grad_(True)
labels_BD = torch.rand(B, D)
labels_BD.requires_grad_(True)
perturbed_top_K_func = PerturbedTopK(k=K)

In [15]:
def loss(rating_BD):
    """Compute loss. Labels are not an input so we can just get gradient of rating"""
    return -bpr_K(rating_BD, labels_BD, perturbed_top_K_func=perturbed_top_K_func, K=K)

In [18]:
loss_tensor = loss(rating_BD)
loss_val = torch.sum(loss_tensor)

In [19]:
loss_val.backward()

In [20]:
grad_loss_TD = rating_BD.grad

In [22]:
grad_loss_TD.shape

torch.Size([100, 10])

In [6]:
# Goal: gradient of loss w.r.t. rating_BD for each B separately
grad_loss_TD = torch.autograd.functional.jacobian(loss, (rating_BD), vectorize=True,  strategy='forward-mode')

RuntimeError: vmap: We do not yet support calling random operations inside of vmap. Please perform random operations outside of vmap as a workaround