In [446]:
import os
import torch
import numpy as np
import math

In [733]:
class SubsetSampler(torch.nn.Module):
    def __init__(self, k, tau=1.0, hard=False):
        super().__init__()
        self.k = k
        self.hard = hard
        self.tau = tau

    def forward(self, x):
        x = torch.nn.functional.softmax(x, dim=-1)
        n = x.shape[-1]
        gumbel = torch.distributions.gumbel.Gumbel(torch.zeros_like(x), torch.ones_like(x))
        x = torch.log(x) + gumbel.sample()

        
        if self.hard:
            values, indices = torch.topk(x, self.k)
            return indices
        else:
            values, indices = torch.topk(x, self.k)
            khot_prime = torch.nn.functional.one_hot(indices, num_classes=n).sum(dim=0)
            
            # Top-k relaxation
            y = torch.zeros_like(x)
            khot = torch.zeros_like(x)
            
            for i in range(k):
                y = torch.nn.functional.softmax(x, dim=-1)
                khot += y
                x = x + torch.log(1 - y)

            khot = torch.nn.functional.softmax(khot, dim=-1) * k
            
            return khot

In [726]:
# Test top-k relaxation
k = 3
sampler = SubsetSampler(k=k, tau=1.0, hard=False)
n = 8
x = torch.ones(n)
y = sampler(x)

tensor([0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250])


In [729]:
# Test uniform distribution over subsets of size 3
k = 3
sampler = SubsetSampler(k=k, tau=1.0, hard=True)
n = 8
x = torch.ones(n)
y = torch.zeros_like(x)

for i in range(1000):
    subset = sampler(x)
    khot = torch.nn.functional.one_hot(subset, num_classes=n).sum(dim=0)
    y += khot

z = y / 1000 - (k/n)
print(f"probability error: {z}")

probability error: tensor([-0.0270, -0.0040,  0.0210, -0.0020, -0.0110,  0.0050, -0.0110,  0.0290])


In [693]:
# Test subset's probability
n = 8
k = 3
m = 1000000
x = torch.tensor([1,2,3,4,5,6,7,8], dtype=torch.float)
sampler = SubsetSampler(k=k, tau=1.0, hard=True)

probs = torch.nn.functional.softmax(x, dim=-1)
target = [7,6,5]
count = 0

for i in range(m):
    indices = sampler(x).tolist()
    if indices == target:
        count += 1

predicted_p = probs[-1] * (probs[-2] / (1 - probs[-1])) * (probs[-3] / (1 - probs[-1] - probs[-2]))
real_p = count / m

print(f"predict: {predicted_p}")
print(f"statistic: {real_p}")

predict: 0.2535242736339569
statistic: 0.254265


In [750]:
def subset_prob(S: tuple, a: torch.tensor):
    k = len(S)
    n = a.shape[0]
    a = torch.nn.functional.softmax(a, dim=-1)
    A = a[list(S)].prod()
    Z = torch.zeros(k)

    # dynamic programming: E[i,j] = E[i-1,j] + a[i] * E[i-1,j-1]
    for i in range(n):
        shift_Z = torch.cat([Z.new_ones(1), Z[:-1]])
        Z = Z + a[i] * shift_Z

    return A / Z[-1]

In [751]:
# Test uniform distribution over size-k subsets
k = 3
n = 8
a = torch.ones(8, requires_grad=True)
S = (1,2,3)
p = subset_prob(S, a)
# print(p.requires_grad)
print(f"prob.: {p}")
print(f"prob. error: {torch.abs((1/math.comb(n,k)) - p)}")

prob.: 0.01785714365541935
prob. error: 0.0


In [752]:
import itertools
def subset_prob_seq(S: tuple, a: torch.Tensor):
    # Sequential form: sum over permutations
    a = torch.nn.functional.softmax(a, dim=-1)
    A = 0.0
    for perm in itertools.permutations(S):
        prob = 1.0
        remaining = list(range(len(a)))
        for idx in perm:
            prob *= a[idx].item() / sum(a[remaining]).item()
            remaining.remove(idx)
        A += prob
    return A

In [759]:
n = 8
k = 3
m = 10000
# x = torch.tensor([1] * n, dtype=torch.float)
x = torch.tensor([1,2,3,4,5,6,7,8], dtype=torch.float)
sampler = SubsetSampler(k=k, tau=1.0, hard=True)
target = [5,6,7]
count = 0
for i in range(m):
    indices = sampler(x).tolist()
    indices.sort()
    if indices == target:
        count += 1

prob_predict = subset_prob(target, x)
l = subset_prob_seq(target, x)

l1 = probs[-1] * (probs[-2] / (1 - probs[-1])) * (probs[-3] / (1 - probs[-1] - probs[-2]))
l2 = probs[-1] * (probs[-3] / (1 - probs[-1])) * (probs[-2] / (1 - probs[-1] - probs[-3]))
l3 = probs[-2] * (probs[-1] / (1 - probs[-2])) * (probs[-3] / (1 - probs[-2] - probs[-1]))
l4 = probs[-2] * (probs[-3] / (1 - probs[-2])) * (probs[-1] / (1 - probs[-2] - probs[-3]))
l5 = probs[-3] * (probs[-1] / (1 - probs[-3])) * (probs[-2] / (1 - probs[-3] - probs[-1]))
l6 = probs[-3] * (probs[-2] / (1 - probs[-3])) * (probs[-1] / (1 - probs[-3] - probs[-2]))


print(f"target subset: {target}")
print("ans:", l1 + l2 + l3 + l4 + l5 + l6)
print("P_seq: ", l)
print(f"P_dp: {prob_predict}")
print(f"SubsetSampler ({m} times): {count / m}")

target subset: [5, 6, 7]
ans: tensor(0.5894)
P_seq:  0.5894127898446275
P_dp: 0.5213007926940918
SubsetSampler (10000 times): 0.5828


In [749]:
# Conclusion: P_seq is correct and P_dp is NOT equivalent to P_seq
# TODO: ChatGPT claims that P_seq can be calculated in O(k * 2^k) time using DP technique