# Simplex distance measure
Calculate the distance for a test point as the residual of its latent approximation.

The total distance is then the mean residual (or relative to a calibration set?).

## 1. Fit a simplex model to some data

In [1]:
import torch
from simplexai.explainers.simplex import Simplex
from simplexai.models.image_recognition import MnistClassifier

from xai.data_handlers.mnist import load_mnist

In [2]:
# Get a model
model = MnistClassifier() # Model should have the BlackBox interface

In [3]:
# Load corpus and test inputs
corpus_loader = load_mnist(subset_size=100, train=True, batch_size=100) # MNIST train loader
test_loader = load_mnist(subset_size=10, train=True, batch_size=10) # MNIST test loader
corpus_inputs, _ = next(iter(corpus_loader)) # A tensor of corpus inputs
test_inputs, _ = next(iter(test_loader)) # A set of inputs to explain

In [4]:
# Compute the corpus and test latent representations
corpus_latents = model.latent_representation(corpus_inputs).detach()
test_latents = model.latent_representation(test_inputs).detach()

In [5]:
corpus_latents.shape

torch.Size([100, 50])

In [6]:
test_latents.shape

torch.Size([10, 50])

In [7]:
# Initialize SimplEX, fit it on test examples
simplex = Simplex(corpus_examples=corpus_inputs,
                  corpus_latent_reps=corpus_latents)
simplex.fit(test_examples=test_inputs,
            test_latent_reps=test_latents,
            reg_factor=0)


Weight Fitting Epoch: 2000/10000 ; Error: 20.6 ; Regulator: 6.51 ; Reg Factor: 0
Weight Fitting Epoch: 4000/10000 ; Error: 16.7 ; Regulator: 2.96 ; Reg Factor: 0
Weight Fitting Epoch: 6000/10000 ; Error: 16.1 ; Regulator: 2.18 ; Reg Factor: 0
Weight Fitting Epoch: 8000/10000 ; Error: 16 ; Regulator: 2.04 ; Reg Factor: 0
Weight Fitting Epoch: 10000/10000 ; Error: 15.9 ; Regulator: 2 ; Reg Factor: 0


## 2. Calculate residuals of each test data point

In [8]:
test_latents_approx = simplex.latent_approx()
test_latents_approx.shape

torch.Size([10, 50])

In [9]:
residual = torch.sqrt(torch.sum((test_latents - test_latents_approx) ** 2))
residual

tensor(3.9907)

In [10]:
float(residual)

3.9906599521636963

## 3. Try the SimplexDistance class

In [11]:
from xai.evaluation_metrics.distance.simplex_distance import SimplexDistance

In [14]:
simplex_dist = SimplexDistance(model, corpus_inputs, test_inputs)

In [15]:
simplex_dist._fit_simplex()

Weight Fitting Epoch: 2000/10000 ; Error: 27 ; Regulator: 6.36 ; Reg Factor: 0
Weight Fitting Epoch: 4000/10000 ; Error: 22 ; Regulator: 2.76 ; Reg Factor: 0
Weight Fitting Epoch: 6000/10000 ; Error: 21.3 ; Regulator: 2.13 ; Reg Factor: 0
Weight Fitting Epoch: 8000/10000 ; Error: 21.2 ; Regulator: 1.99 ; Reg Factor: 0
Weight Fitting Epoch: 10000/10000 ; Error: 21.1 ; Regulator: 1.94 ; Reg Factor: 0


In [16]:
simplex_dist.distance()

4.5949602127075195