In [1]:
import torch
import torch.nn as nn
class weight_single(torch.autograd.Function):
    """Identity op for autograd: the tensor and its gradient both pass through unchanged."""

    @staticmethod
    def forward(ctx, weight):
        # Nothing is stored on ctx; backward needs no state from this pass.
        return weight

    @staticmethod
    def backward(ctx, grad_output):
        # Hand the upstream gradient straight back to `weight`.
        return grad_output
class weight_multi_mask_grad(torch.autograd.Function):
    """Custom autograd op that masks the forward value and the gradient independently.

    Forward:  returns ``weight * weight_mask_1``.
    Backward: returns ``grad_output * weight_mask_2`` as the gradient for
    ``weight``; both masks get ``None`` (no gradient).
    """

    @staticmethod
    def forward(ctx, weight, weight_mask_1, weight_mask_2):
        masked_weight = weight * weight_mask_1
        # Only the second mask is needed later; it gates the gradient in backward.
        ctx.save_for_backward(weight_mask_2)
        return masked_weight

    @staticmethod
    def backward(ctx, grad_output):
        (grad_gate,) = ctx.saved_tensors
        grad_weight = grad_output * grad_gate
        # One entry per forward input: weight, weight_mask_1, weight_mask_2.
        return grad_weight, None, None

In [2]:
# Random 3x3 trainable parameter; it is the input to the masking demo below.
weight = nn.Parameter(torch.randn((3, 3)), requires_grad=True)
print(weight)

Parameter containing:
tensor([[-0.5801,  1.2196,  0.5998],
        [-0.8279,  1.8570, -1.4480],
        [ 0.3319, -0.1356,  0.4788]], requires_grad=True)


In [3]:
# Two independent random 0/1 masks, kept as float32 so they can multiply the weight.
weight_mask_1 = torch.randint(0, 2, (3, 3), dtype=torch.float32)
weight_mask_2 = torch.randint(0, 2, (3, 3), dtype=torch.float32)
# The masks are constants: make sure autograd does not track them.
weight_mask_1.requires_grad_(False)
weight_mask_2.requires_grad_(False)
print(weight_mask_1, "\n", weight_mask_2)

tensor([[1., 0., 1.],
        [1., 0., 0.],
        [1., 1., 1.]]) 
 tensor([[0., 0., 1.],
        [0., 1., 1.],
        [0., 1., 0.]])


In [4]:
# Apply the custom op: the forward value is masked by weight_mask_1, while the
# gradient that reaches `weight` is masked by weight_mask_2.
wgt_out = weight_multi_mask_grad.apply(weight, weight_mask_1, weight_mask_2)
# Alternative for comparison (plain pass-through, no masking):
# wgt_out = weight_single.apply(weight)
print(wgt_out)

# backward() accumulates into weight.grad, so clear anything left over from a
# previous run of this cell; otherwise re-running would add the gradients up.
weight.grad = None

# Named `masked_total` rather than `sum` so the builtin sum() is not shadowed.
masked_total = torch.sum(wgt_out)
masked_total.backward()

tensor([[-0.5801,  0.0000,  0.5998],
        [-0.8279,  0.0000, -0.0000],
        [ 0.3319, -0.1356,  0.4788]], grad_fn=<weight_multi_mask_gradBackward>)


In [5]:
# Show the gradient stored on `weight` after the backward() call in the cell above.
print(weight.grad)

tensor([[0., 0., 1.],
        [0., 1., 1.],
        [0., 1., 0.]])


In [6]:
# Baseline without the custom op: summing a parameter directly sends a gradient
# of 1 to every element.
# nn.Parameter already defaults to requires_grad=True, so no extra flag is set.
weight = nn.Parameter(torch.randn(3, 3))
# Named `plain_total` rather than `sum` so the builtin sum() is not shadowed.
plain_total = torch.sum(weight)
plain_total.backward()
print(weight.grad)

tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])


In [10]:
import torch

# Example 1D tensor of values to threshold.
x = torch.tensor([1., 0., 0., 0., 5., 6., 7., 8., 9., 10.])

# Quantile level defined once so the computation and the printed label stay in sync.
quantile_level = 0.8

# compute the 80th percentile (0.8 quantile); by default torch interpolates
# linearly between the two nearest sorted values
threshold = x.quantile(quantile_level)

print(f'{quantile_level:.0%} threshold:', threshold.item())

80% threshold: 8.200000762939453
