In [None]:
import torch
from torch import Tensor
import numpy as np

EPSILON = np.finfo(np.float32).tiny

In [None]:
class SubsetOperator(torch.nn.Module):
    """Relaxed (Gumbel-softmax) top-k subset sampler.

    The scores are perturbed with Gumbel(0, 1) noise. Then k temperature-scaled
    softmaxes are taken over dim 1. Before each softmax, log(1 - previous relaxed
    one-hot) is added to the running scores, which pushes already-selected items
    towards -inf. The relaxed k-hot is the sum of the k relaxed one-hots.

    Args:
        k: number of items to select along dim 1.
        tau: softmax temperature; smaller values give outputs closer to a hard k-hot.
        hard: if True, the forward value is the exact k-hot of the top-k entries of
            the relaxed k-hot, while gradients are those of the relaxed k-hot
            (straight-through estimator). This applies in both train and eval mode.
        eps: floor applied to the mask ``1 - onehot`` before taking its log, so the
            log never sees 0. ``None`` (default) uses the module-level ``EPSILON``.
    """

    def __init__(self, k, tau=1.0, hard=False, eps=None):
        super(SubsetOperator, self).__init__()
        self.k = k
        self.hard = hard
        self.tau = tau
        self.eps = eps

    def forward(self, scores):
        """Draw one (relaxed or straight-through hard) k-hot sample per slice.

        Args:
            scores: unnormalized log-weights (they may be negative); exp(scores) is the
                actual weight. Items lie along dim 1, e.g. shape [batch_size, n] or
                [batch_size, n, 1].

        Returns:
            Tensor shaped like ``scores``; every slice along dim 1 sums to k (up to
            floating-point rounding).
        """
        gumbel = torch.distributions.gumbel.Gumbel(torch.zeros_like(scores), torch.ones_like(scores))
        scores = scores + gumbel.sample()

        # The floor must live on the same device as `scores`. A CPU tensor of shape [1]
        # cannot be combined with a CUDA tensor. It is built once, outside the loop,
        # with the same construction as torch.tensor([EPSILON]).
        eps = EPSILON if self.eps is None else self.eps
        eps_floor = torch.tensor([eps], device=scores.device)

        khot = torch.zeros_like(scores)
        onehot_approx = torch.zeros_like(scores)
        for _ in range(self.k):
            # Differentiable mask: log(1 - a) down-weights the item just (softly) chosen;
            # the floor keeps log() finite when a == 1.
            khot_mask = torch.maximum(1.0 - onehot_approx, eps_floor)
            scores = scores + torch.log(khot_mask)
            onehot_approx = torch.nn.functional.softmax(scores / self.tau, dim=1)
            khot = khot + onehot_approx

        if not self.hard:
            return khot

        # Straight-through estimator. Forward value: khot_hard, because
        # -khot.detach() + khot cancel numerically. Backward: khot_hard and
        # khot.detach() carry no gradient, so the gradient is that of khot.
        _, ind = torch.topk(khot, self.k, dim=1)
        khot_hard = torch.zeros_like(khot).scatter_(1, ind, 1)
        return khot_hard - khot.detach() + khot

In [None]:
class SubsetOperator_Test(torch.nn.Module):
    """Diagnostic relaxed (Gumbel-softmax) top-k subset sampler.

    This uses the same procedure as a relaxed top-k sampler. Scores are perturbed
    with Gumbel(0, 1) noise, then k temperature-scaled softmaxes are taken over
    dim 1, adding log(1 - previous relaxed one-hot) to the running scores before
    each one. ``forward`` additionally returns the k individual relaxed one-hot
    vectors so each step can be inspected.

    Args:
        k: number of items to select along dim 1.
        tau: softmax temperature; smaller values give outputs closer to a hard k-hot.
        hard: if True, the returned k-hot is the exact top-k of the relaxed k-hot in
            the forward pass with the relaxed k-hot's gradients (straight-through),
            in both train and eval mode.
        eps: floor applied to the mask ``1 - onehot`` before taking its log, so the
            log never sees 0. ``None`` (default) uses the module-level ``EPSILON``.
    """

    def __init__(self, k, tau=1.0, hard=False, eps=None):
        super(SubsetOperator_Test, self).__init__()
        self.k = k
        self.hard = hard
        self.tau = tau
        self.eps = eps

    def forward(self, scores):
        """Draw one sample per slice and also return the per-step relaxed one-hots.

        Args:
            scores: unnormalized log-weights (they may be negative); exp(scores) is the
                actual weight. Items lie along dim 1, e.g. shape [batch_size, n] or
                [batch_size, n, 1].

        Returns:
            (res, khot_all):
            - res is shaped like ``scores``. Every slice along dim 1 sums to k (up to
              rounding). It is the relaxed k-hot, or the straight-through hard k-hot
              when ``hard``.
            - khot_all is a list of k tensors shaped like ``scores``, each a relaxed
              one-hot summing to 1 along dim 1. The relaxed k-hot is their sum.
        """
        gumbel = torch.distributions.gumbel.Gumbel(torch.zeros_like(scores), torch.ones_like(scores))
        scores = scores + gumbel.sample()

        # The floor must live on the same device as `scores`. A CPU tensor of shape [1]
        # cannot be combined with a CUDA tensor. It is built once, outside the loop,
        # with the same construction as torch.tensor([EPSILON]).
        eps = EPSILON if self.eps is None else self.eps
        eps_floor = torch.tensor([eps], device=scores.device)

        khot = torch.zeros_like(scores)
        onehot_approx = torch.zeros_like(scores)
        khot_all = []
        for _ in range(self.k):
            # Differentiable mask: log(1 - a) down-weights the item just (softly) chosen;
            # the floor keeps log() finite when a == 1.
            khot_mask = torch.maximum(1.0 - onehot_approx, eps_floor)
            scores = scores + torch.log(khot_mask)
            onehot_approx = torch.nn.functional.softmax(scores / self.tau, dim=1)
            khot = khot + onehot_approx
            khot_all.append(onehot_approx)

        if not self.hard:
            return khot, khot_all

        # Straight-through estimator. Forward value: khot_hard, because
        # -khot.detach() + khot cancel numerically. Backward: khot_hard and
        # khot.detach() carry no gradient, so the gradient is that of khot.
        _, ind = torch.topk(khot, self.k, dim=1)
        khot_hard = torch.zeros_like(khot).scatter_(1, ind, 1)
        return khot_hard - khot.detach() + khot, khot_all

In [None]:
# REMINDER: scores are log-weights, i.e. P(item i is picked first) = exp(s_i) / sum_{j=1}^n exp(s_j).

# Test 1a -- k=1, soft samples: the average of the relaxed one-hot vectors should approach
# that distribution. It will not match exactly because of Monte-Carlo error from the
# finite number of Gumbel samples, and because of the temperature.
#
# All n+1 draws are made in ONE batched forward call: each row of the batch receives its
# own independent Gumbel noise, so this is the same estimator as n+1 separate calls.

torch.manual_seed(0)

scores = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]])
temp = 0.00001
n = 100000
k = 1
subset_operator = SubsetOperator_Test(k, temp, False)

res, khot_all = subset_operator(scores.repeat(n + 1, 1))
khot_all_final = khot_all[0].mean(dim=0, keepdim=True)

print("khot_all_final:", khot_all_final)

expected = torch.softmax(scores, dim=1)
print("expected:", expected)




khot_all_final: tensor([[5.1999e-04, 1.3800e-03, 4.1900e-03, 1.1920e-02, 3.2120e-02, 8.4939e-02,
         2.3277e-01, 6.3216e-01]])
expected: tensor([[5.7661e-04, 1.5674e-03, 4.2606e-03, 1.1582e-02, 3.1482e-02, 8.5577e-02,
         2.3262e-01, 6.3233e-01]])


In [None]:
# Test 1b -- k=1, hard=True: average of hard (straight-through) samples.
# For k=1 the hard sample is the one-hot of argmax(scores + Gumbel noise), so the
# temperature has no effect on it. The average still differs slightly from the expected
# distribution because of Monte-Carlo error from the finite number of draws.
# All n+1 draws are made in one batched forward call (one row per draw).

torch.manual_seed(0)

scores = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]])
temp = 0.00001
n = 100000
k = 1
subset_operator = SubsetOperator_Test(k, temp, True)

res, khot_all = subset_operator(scores.repeat(n + 1, 1))
khot_all_final = res.mean(dim=0, keepdim=True)

print("khot_all_final:", khot_all_final)

expected = torch.softmax(scores, dim=1)
print("expected:", expected)


khot_all_final: tensor([[5.5999e-04, 1.6700e-03, 4.4200e-03, 1.1540e-02, 3.2030e-02, 8.5059e-02,
         2.3305e-01, 6.3167e-01]])
expected: tensor([[5.7661e-04, 1.5674e-03, 4.2606e-03, 1.1582e-02, 3.1482e-02, 8.5577e-02,
         2.3262e-01, 6.3233e-01]])


In [None]:
# Test 1c -- k=4: compare the FIRST relaxed one-hot vector a^(1). It is drawn before any
# masking (the first mask is log(1) = 0), so its mean should still be
# exp(s_i) / sum_{j=1}^n exp(s_j).
# All n+1 draws are made in one batched forward call (one row per draw).

torch.manual_seed(0)

scores = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]])
temp = 0.00001
n = 100000
k = 4
subset_operator = SubsetOperator_Test(k, temp, False)

res, khot_all = subset_operator(scores.repeat(n + 1, 1))
khot_all_final = khot_all[0].mean(dim=0, keepdim=True)

print("khot_all_final:", khot_all_final)

expected = torch.softmax(scores, dim=1)
print("expected:", expected)




tensor([[6.2000e+01, 1.4200e+02, 4.0000e+02, 1.1210e+03, 3.2270e+03, 8.6231e+03,
         2.3209e+04, 6.3217e+04]])


In [None]:
# Test 2a -- can scores be negative? Yes: they are log-weights. k=1, soft samples.
# At temp=0.1 the soft average can deviate noticeably from the expected probabilities
# for low-weight items.
# All n+1 draws are made in one batched forward call (one row per draw).

torch.manual_seed(0)

scores = torch.tensor([[-1.0, -2.0, 3.0, -4.0, -5.0, 6.0, -7.0, -8.0]])
temp = 0.1
n = 10000
k = 1
subset_operator = SubsetOperator_Test(k, temp, False)

res, khot_all = subset_operator(scores.repeat(n + 1, 1))
khot_all_final = khot_all[0].mean(dim=0, keepdim=True)

print("khot_all_final:", khot_all_final)

expected = torch.softmax(scores, dim=1)
print("expected:", expected)

khot_all_final: tensor([[6.0246e-04, 2.0681e-04, 4.7026e-02, 4.6006e-08, 3.8561e-06, 9.5216e-01,
         1.0082e-25, 4.5451e-22]])
expected: tensor([[8.6755e-04, 3.1915e-04, 4.7367e-02, 4.3193e-05, 1.5890e-05, 9.5138e-01,
         2.1504e-06, 7.9110e-07]])


In [None]:
# Test 2b -- can scores be negative? Yes: they are log-weights. k=1, hard samples.
# All n+1 draws are made in one batched forward call (one row per draw).

torch.manual_seed(0)

scores = torch.tensor([[-1.0, -2.0, 3.0, -4.0, -5.0, 6.0, -7.0, -8.0]])
temp = 0.1
n = 10000
k = 1
subset_operator = SubsetOperator_Test(k, temp, True)

res, khot_all = subset_operator(scores.repeat(n + 1, 1))
khot_all_final = res.mean(dim=0, keepdim=True)

print("khot_all_final:", khot_all_final)

expected = torch.softmax(scores, dim=1)
print("expected:", expected)

khot_all_final: tensor([[6.9993e-04, 9.9990e-05, 4.6495e-02, 0.0000e+00, 0.0000e+00, 9.5270e-01,
         0.0000e+00, 0.0000e+00]])
expected: tensor([[8.6755e-04, 3.1915e-04, 4.7367e-02, 4.3193e-05, 1.5890e-05, 9.5138e-01,
         2.1504e-06, 7.9110e-07]])


In [None]:
# Test 2c -- can scores be negative? Yes: they are log-weights. k=4, soft samples;
# compare the first relaxed one-hot vector a^(1) with exp(s_i) / sum_j exp(s_j).
# All n+1 draws are made in one batched forward call (one row per draw).

torch.manual_seed(0)

scores = torch.tensor([[-1.0, -2.0, 3.0, -4.0, -5.0, 6.0, -7.0, -8.0]])
temp = 0.1
n = 10000
k = 4
subset_operator = SubsetOperator_Test(k, temp, False)

res, khot_all = subset_operator(scores.repeat(n + 1, 1))
khot_all_final = khot_all[0].mean(dim=0, keepdim=True)

print("khot_all_final:", khot_all_final)

expected = torch.softmax(scores, dim=1)
print("expected:", expected)

khot_all_final: tensor([[1.0291e-03, 4.1546e-04, 4.4834e-02, 9.9992e-05, 5.8321e-07, 9.5362e-01,
         8.8493e-25, 3.9025e-26]])
expected: tensor([[8.6755e-04, 3.1915e-04, 4.7367e-02, 4.3193e-05, 1.5890e-05, 9.5138e-01,
         2.1504e-06, 7.9110e-07]])


In [None]:
class SubsetOperator_raw(torch.nn.Module):
    """Exact (non-relaxed) Gumbel-top-k subset sampler.

    Perturbs the scores with Gumbel(0, 1) noise and returns the hard k-hot indicator
    of the k largest perturbed entries along dim 1. The output is built by scattering
    constant 1s, so no gradient flows back to the scores.

    Args:
        k: number of items to select along dim 1.
    """

    def __init__(self, k):
        super().__init__()
        self.k = k

    def forward(self, scores):
        """Return a hard k-hot sample shaped like ``scores``.

        Args:
            scores: unnormalized log-weights; items lie along dim 1
                (e.g. [batch_size, n] or [batch_size, n, 1]).
        """
        noise_dist = torch.distributions.gumbel.Gumbel(torch.zeros_like(scores), torch.ones_like(scores))
        perturbed = scores + noise_dist.sample()

        # Only the indices of the top-k perturbed scores matter; their values are discarded.
        top_idx = torch.topk(perturbed, self.k, dim=1)[1]
        return torch.zeros_like(perturbed).scatter_(1, top_idx, 1)

In [None]:
# Test 3 -- k=4: the mean relaxed k-hot (soft, temp=0.2) versus the mean exact Gumbel-top-k
# indicator from SubsetOperator_raw. Each relaxed k-hot entry is a sum of k softmax
# values, so unlike the hard indicator it can exceed 1.
# All n+1 draws for each operator are made in one batched forward call (one row per draw).

torch.manual_seed(0)

scores = torch.tensor([[-1.0, -2.0, 3.0, -4.0, -5.0, 6.0, -7.0, -8.0]])
temp = 0.2
n = 100000
k = 4
subset_operator = SubsetOperator_Test(k, temp, False)
subset_operator_raw = SubsetOperator_raw(k)
batch = scores.repeat(n + 1, 1)

res, khot_all = subset_operator(batch)
res_final = res.mean(dim=0, keepdim=True)
print("res_final:", res_final)

exp = subset_operator_raw(batch).mean(dim=0, keepdim=True)
print("exp:", exp)

res_final: tensor([[0.9001, 0.7137, 1.0810, 0.1107, 0.0426, 1.1441, 0.0057, 0.0021]])
exp: tensor([[0.9688, 0.8477, 1.0000, 0.1261, 0.0479, 1.0000, 0.0071, 0.0023]])
