# Simplex distance measure
Calculate the distance for a test point as the residual of its latent approximation.

The total distance is then the mean residual (or relative to a calibration set?).

## 1. Fit a simplex model to some data

In [1]:
import torch
from simplexai.explainers.simplex import Simplex
from simplexai.models.image_recognition import MnistClassifier

from xai.data_handlers.mnist import load_mnist

In [2]:
# Get a model
model = MnistClassifier() # Model should have the BlackBox interface

In [3]:
# Load corpus and test inputs
corpus_loader = load_mnist(subset_size=100, train=True, batch_size=100) # MNIST train loader
test_loader = load_mnist(subset_size=10, train=True, batch_size=10) # MNIST test loader
corpus_inputs, _ = next(iter(corpus_loader)) # A tensor of corpus inputs
test_inputs, _ = next(iter(test_loader)) # A set of inputs to explain

In [4]:
# Compute the corpus and test latent representations
corpus_latents = model.latent_representation(corpus_inputs).detach()
test_latents = model.latent_representation(test_inputs).detach()

In [5]:
corpus_latents.shape

torch.Size([100, 50])

In [6]:
test_latents.shape

torch.Size([10, 50])

In [7]:
# Initialize SimplEX, fit it on test examples
simplex = Simplex(corpus_examples=corpus_inputs,
                  corpus_latent_reps=corpus_latents)
simplex.fit(test_examples=test_inputs,
            test_latent_reps=test_latents,
            reg_factor=0)


Weight Fitting Epoch: 2000/10000 ; Error: 20.6 ; Regulator: 6.51 ; Reg Factor: 0
Weight Fitting Epoch: 4000/10000 ; Error: 16.7 ; Regulator: 2.96 ; Reg Factor: 0
Weight Fitting Epoch: 6000/10000 ; Error: 16.1 ; Regulator: 2.18 ; Reg Factor: 0
Weight Fitting Epoch: 8000/10000 ; Error: 16 ; Regulator: 2.04 ; Reg Factor: 0
Weight Fitting Epoch: 10000/10000 ; Error: 15.9 ; Regulator: 2 ; Reg Factor: 0


## 2. Calculate residuals of each test data point

In [8]:
test_latents_approx = simplex.latent_approx()
test_latents_approx.shape

torch.Size([10, 50])

In [9]:
residual = torch.sqrt(torch.sum((test_latents - test_latents_approx) ** 2))
residual

tensor(3.9907)

In [10]:
float(residual)

3.9906599521636963

## 3. Try the SimplexDistance class

In [11]:
from xai.evaluation_metrics.distance.simplex_distance import SimplexDistance

In [14]:
simplex_dist = SimplexDistance(model, corpus_inputs, test_inputs)

In [15]:
simplex_dist._fit_simplex()

Weight Fitting Epoch: 2000/10000 ; Error: 27 ; Regulator: 6.36 ; Reg Factor: 0
Weight Fitting Epoch: 4000/10000 ; Error: 22 ; Regulator: 2.76 ; Reg Factor: 0
Weight Fitting Epoch: 6000/10000 ; Error: 21.3 ; Regulator: 2.13 ; Reg Factor: 0
Weight Fitting Epoch: 8000/10000 ; Error: 21.2 ; Regulator: 1.99 ; Reg Factor: 0
Weight Fitting Epoch: 10000/10000 ; Error: 21.1 ; Regulator: 1.94 ; Reg Factor: 0


In [16]:
simplex_dist.distance()

4.5949602127075195

# 4. Distance based on the latent space vectors
## 4.1. Pointwise average
For each test point, calculate the distance to each training point and take the average.
Do this for every test point and take the average to get a single number for the entire training set.


In [17]:
corpus_latents

tensor([[0.0000, 0.2516, 0.2418,  ..., 0.3681, 0.0000, 0.0000],
        [0.0000, 0.6343, 0.0000,  ..., 0.1816, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0497,  ..., 0.6649, 0.7968, 0.0000],
        ...,
        [0.0000, 1.2644, 0.0000,  ..., 0.1685, 0.0000, 1.6109],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.2790, 0.0000, 0.0000]])

In [19]:
corpus_latents.shape

torch.Size([100, 50])

In [20]:
test_latents.shape

torch.Size([10, 50])

In [21]:
test_latents[:3, :5]

tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.7705, 0.0297],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])

## 4.2 Approximation based on centroids and variance

In [26]:
torch.mean(corpus_latents, dim=0)

tensor([0.0618, 0.4266, 0.2496, 0.0421, 0.0312, 0.0860, 0.1952, 0.1266, 0.0053,
        0.0150, 0.1124, 0.3067, 0.0658, 0.0211, 0.1790, 0.0143, 0.0293, 0.0648,
        0.0071, 0.0341, 0.0312, 0.0740, 0.0300, 0.1380, 0.0267, 0.1938, 0.3217,
        0.2180, 0.2656, 0.2960, 0.1429, 0.1917, 0.3171, 0.1794, 0.1588, 0.0048,
        0.0519, 0.0796, 0.0000, 0.0007, 0.0677, 0.1373, 0.4330, 0.2644, 0.0671,
        0.0655, 0.1469, 0.2028, 0.0823, 0.4581])

In [29]:
sigma, centroid = torch.std_mean(corpus_latents, dim=0)

In [31]:
centroid

tensor([0.0618, 0.4266, 0.2496, 0.0421, 0.0312, 0.0860, 0.1952, 0.1266, 0.0053,
        0.0150, 0.1124, 0.3067, 0.0658, 0.0211, 0.1790, 0.0143, 0.0293, 0.0648,
        0.0071, 0.0341, 0.0312, 0.0740, 0.0300, 0.1380, 0.0267, 0.1938, 0.3217,
        0.2180, 0.2656, 0.2960, 0.1429, 0.1917, 0.3171, 0.1794, 0.1588, 0.0048,
        0.0519, 0.0796, 0.0000, 0.0007, 0.0677, 0.1373, 0.4330, 0.2644, 0.0671,
        0.0655, 0.1469, 0.2028, 0.0823, 0.4581])

In [30]:
sigma

tensor([0.1559, 0.5560, 0.4199, 0.1527, 0.1261, 0.2347, 0.3406, 0.2670, 0.0534,
        0.0751, 0.2903, 0.5012, 0.1522, 0.0727, 0.3764, 0.0642, 0.1045, 0.2152,
        0.0494, 0.1231, 0.1140, 0.2465, 0.1104, 0.2742, 0.0920, 0.3056, 0.5352,
        0.3839, 0.4772, 0.4805, 0.2842, 0.3300, 0.4272, 0.3560, 0.3068, 0.0280,
        0.1702, 0.1928, 0.0000, 0.0065, 0.1864, 0.2308, 0.5397, 0.4220, 0.1825,
        0.1739, 0.2839, 0.3638, 0.1831, 0.5855])

In [32]:
test_latents.shape

torch.Size([10, 50])

In [34]:
centroid.shape

torch.Size([50])

In [33]:
sigma.shape

torch.Size([50])

In [36]:
(test_latents - centroid).shape

torch.Size([10, 50])

In [37]:
_distance_per_point = test_latents - centroid
_distance_per_point.shape

torch.Size([10, 50])

In [45]:
_distance_per_point

tensor([[-6.1849e-02, -4.2658e-01, -2.4956e-01, -4.2092e-02, -3.1193e-02,
         -8.6043e-02, -1.9516e-01, -3.9627e-02, -5.3398e-03, -1.5009e-02,
         -1.1239e-01,  5.9798e-01,  3.1712e-01,  3.8384e-03, -1.7898e-01,
         -1.4296e-02, -2.9334e-02, -6.4760e-02, -7.0978e-03, -3.4052e-02,
         -3.1164e-02, -6.2328e-02, -2.9980e-02, -1.3795e-01, -2.6749e-02,
         -2.6551e-02, -3.2167e-01, -7.8985e-02, -2.6561e-01, -2.9603e-01,
          3.9681e-01, -1.9175e-01,  1.8132e-01,  4.5382e-02, -1.5884e-01,
         -4.7947e-03, -5.1928e-02, -7.9578e-02,  0.0000e+00, -6.5224e-04,
          1.5474e-01, -1.3733e-01, -4.3304e-01,  2.7550e-01, -6.7101e-02,
         -6.5455e-02, -1.4686e-01, -2.0281e-01,  7.6638e-01, -4.5809e-01],
        [-6.1849e-02, -4.2658e-01, -2.4956e-01,  7.2837e-01, -1.4876e-03,
         -8.6043e-02, -1.9516e-01,  5.7009e-01, -5.3398e-03, -1.5009e-02,
         -1.1239e-01, -3.0674e-01, -6.5807e-02, -2.1116e-02,  6.5552e-01,
          6.3528e-02, -2.9334e-02, -6

In [46]:
sigma

tensor([0.1559, 0.5560, 0.4199, 0.1527, 0.1261, 0.2347, 0.3406, 0.2670, 0.0534,
        0.0751, 0.2903, 0.5012, 0.1522, 0.0727, 0.3764, 0.0642, 0.1045, 0.2152,
        0.0494, 0.1231, 0.1140, 0.2465, 0.1104, 0.2742, 0.0920, 0.3056, 0.5352,
        0.3839, 0.4772, 0.4805, 0.2842, 0.3300, 0.4272, 0.3560, 0.3068, 0.0280,
        0.1702, 0.1928, 0.0000, 0.0065, 0.1864, 0.2308, 0.5397, 0.4220, 0.1825,
        0.1739, 0.2839, 0.3638, 0.1831, 0.5855])

In [53]:
_distance_per_point_scaled = torch.nan_to_num(_distance_per_point / sigma)
_distance_per_point_scaled.shape

torch.Size([10, 50])

In [50]:
_distance_per_point[0][38]

tensor(0.)

In [51]:
sigma[38]

tensor(0.)

In [49]:
_distance_per_point_scaled[0][38]

tensor(nan)

In [41]:
_distance_per_point_scaled

tensor([[-0.3967, -0.7673, -0.5943, -0.2757, -0.2473, -0.3666, -0.5730, -0.1484,
         -0.1000, -0.1999, -0.3872,  1.1930,  2.0834,  0.0528, -0.4755, -0.2226,
         -0.2808, -0.3009, -0.1437, -0.2767, -0.2735, -0.2528, -0.2716, -0.5030,
         -0.2907, -0.0869, -0.6010, -0.2057, -0.5566, -0.6161,  1.3961, -0.5811,
          0.4244,  0.1275, -0.5178, -0.1709, -0.3051, -0.4128,     nan, -0.1000,
          0.8301, -0.5950, -0.8023,  0.6528, -0.3677, -0.3764, -0.5173, -0.5575,
          4.1866, -0.7824],
        [-0.3967, -0.7673, -0.5943,  4.7706, -0.0118, -0.3666, -0.5730,  2.1354,
         -0.1000, -0.1999, -0.3872, -0.6120, -0.4323, -0.2903,  1.7416,  0.9891,
         -0.2808, -0.3009, -0.1437, -0.2767, -0.2735, -0.3002, -0.2716, -0.5030,
         -0.2907, -0.6341,  1.4502, -0.5678, -0.5566, -0.6161, -0.5028, -0.5811,
          2.0332, -0.5040,  0.1962, -0.1709, -0.3051, -0.4128,     nan, -0.1000,
         -0.3634,  1.3563, -0.8023, -0.6265,  3.2023, -0.3764, -0.5173,  1.0824,


In [55]:
torch.sqrt(torch.sum(_distance_per_point_scaled ** 2))

tensor(22.6203)