In [1]:
import torch

import sys
# Make the project sources importable. The guard keeps the entry from being
# appended again when this cell is re-run.
# NOTE: the path is relative, so it resolves against the kernel's current
# working directory.
if "viashap_paper/src" not in sys.path:
    sys.path.append("viashap_paper/src")

from samplers.uniform_sampler import UniformFeatureSampler
from samplers.kernel_shap_sampler import KernelShapSampler

In [2]:
# Sanity-check UniformFeatureSampler on a tiny hand-written batch.
x = torch.tensor([
    [1.0, 2.0, 3.0, 4.0],
    [5.0, 6.0, 7.0, 8.0]
])  # batch_size=2, n_features=4

sampler = UniformFeatureSampler(baseline=10.0)

n_coalitions = 3
x_S, masks = sampler.sample(x, n_coalitions, random_seed=101)

# The sampler is expected to return one row per (instance, coalition) pair.
batch_size, n_features = x.shape
assert x_S.shape == (batch_size * n_coalitions, n_features)
assert masks.shape == (batch_size * n_coalitions, n_features)

# Masks must be strictly binary.
assert torch.all((masks == 0) | (masks == 1))

# With only batch_size * n_coalitions rows this is a rough eyeball check,
# not a statistical test.
feature_inclusion_rates = masks.float().mean(dim=0)
print("Feature inclusion rates (should be close to 0.5):")
print(feature_inclusion_rates)

# The checks below assume row b * n_coalitions + c of the output belongs to
# instance b, so repeating every row of x n_coalitions times lines the
# original values up with x_S row for row.
x_repeated = x.repeat_interleave(n_coalitions, dim=0)
included = masks == 1
# Included features keep their original value ...
assert torch.all(x_S[included] == x_repeated[included])
# ... and excluded features are replaced by the baseline.
assert torch.all(x_S[~included] == sampler.baseline)

# Only the first n_coalitions rows belong to the first instance, so never
# preview more rows than that (they would be mislabeled against x[0]).
n_shown = min(3, n_coalitions)
print("\nExample of first few coalitions for first batch:")
for i in range(n_shown):
    print(f"\nCoalition {i + 1}:")
    print(f"Mask:     {masks[i].numpy()}")
    print(f"Values:   {x_S[i].numpy()}")
    print(f"Original: {x[0].numpy()}")

Feature inclusion rates (should be close to 0.5):
tensor([0.6667, 0.6667, 0.6667, 0.3333])

Example of first few coalitions for first batch:

Coalition 1:
Mask:     [1. 1. 1. 0.]
Values:   [ 1.  2.  3. 10.]
Original: [1. 2. 3. 4.]

Coalition 2:
Mask:     [1. 1. 1. 1.]
Values:   [1. 2. 3. 4.]
Original: [1. 2. 3. 4.]

Coalition 3:
Mask:     [1. 1. 1. 0.]
Values:   [ 1.  2.  3. 10.]
Original: [1. 2. 3. 4.]


In [3]:
# Random seed handed to the sampler via `random_seed`; same value as the
# literal 101 used in the UniformFeatureSampler check.
seed = 101

In [4]:
# Toy input for the KernelShapSampler run: batch_size=2, n_features=4.
x = torch.tensor([[1.0, 2.0, 3.0, 4.0],
                  [5.0, 6.0, 7.0, 8.0]])

n_coalitions = 5
n_features = x.size(1)

# Draw n_coalitions coalitions per instance; x1 holds the sampled inputs and
# m1 the matching masks.
orig = KernelShapSampler(baseline=10.0, n_features=n_features)
x1, m1 = orig.sample(x, random_seed=seed, n_coalitions=n_coalitions)

In [5]:
# Number of columns (features) in x; the same value stored in n_features.
x.shape[1]

4

In [6]:
# Sampled inputs returned by the KernelShapSampler: entries equal to 10.0 are
# the baseline filled in where the corresponding mask entry is 0.
x1

tensor([[ 1.,  2., 10., 10.],
        [ 1.,  2.,  3., 10.],
        [10., 10., 10.,  4.],
        [10., 10.,  3., 10.],
        [ 1., 10.,  3., 10.],
        [10.,  6.,  7.,  8.],
        [ 5.,  6., 10.,  8.],
        [ 5., 10., 10., 10.],
        [ 5.,  6.,  7., 10.],
        [10.,  6.,  7.,  8.]])

In [7]:
# Coalition masks paired row-for-row with x1 (1 = original value kept,
# 0 = replaced by the baseline).
m1

tensor([[1., 1., 0., 0.],
        [1., 1., 1., 0.],
        [0., 0., 0., 1.],
        [0., 0., 1., 0.],
        [1., 0., 1., 0.],
        [0., 1., 1., 1.],
        [1., 1., 0., 1.],
        [1., 0., 0., 0.],
        [1., 1., 1., 0.],
        [0., 1., 1., 1.]])

In [8]:
# Shape of the input batch: (batch_size, n_features).
x.shape

torch.Size([2, 4])

In [30]:
# Compare the sampled inputs in x1 against a hard-coded expected tensor.
# NOTE(review): these expected rows do not match the x1 displayed earlier in
# this notebook (its first row there is [1, 2, 10, 10]), yet the recorded
# result is True and the execution count jumps from 8 to 30. Presumably x1 was
# regenerated in between (different sampler code or kernel state); re-run
# from a fresh kernel (Restart & Run All) to confirm which values are current.
torch.allclose(x1, torch.tensor([
    [10., 10., 10.,  4.],
    [10., 10.,  3., 10.],
    [10., 10.,  3., 10.],
    [ 1., 10., 10., 10.],
    [ 1.,  2., 10., 10.],
    [10., 10.,  7.,  8.],
    [ 5., 10.,  7.,  8.],
    [ 5.,  6., 10.,  8.],
    [10.,  6., 10., 10.],
    [10., 10.,  7., 10.]
]))

True