# Attributions for Siamese Encoders - Demo

In [None]:
import os
import sys
import zipfile
from os import PathLike
from os.path import join, exists

import torch
import wget

from xsbert import utils
from xsbert.models import XSMPNet, XSRoberta, load_model

### loading a model

You can load one of the two models that we provide with the `load_model()` function as follows.
The first call downloads the checkpoint, which takes a while. The checkpoint is then stored in the directory specified by `model_dir` and reused on later calls.

In [None]:
model_name = 'xs_mpnet'
model = load_model(model_name, model_dir='../xs_models/')
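# use a GPU if one is available; pass e.g. 'cuda:1' to select a specific device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)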

If you have already downloaded a checkpoint, or want to load one that you created yourself, you can instead load it directly using the respective model class.

In [None]:
# model_path = 'checkpoints/xs_mpnet/'
# model = XSMPNet(model_path)
# model_path = 'checkpoints/xs_distilroberta/'
# model = XSRoberta(model_path)
# model.to(torch.device('cuda:0'))

### initializing attributions

The `init_attribution_to_layer()` method of the `models.XSTransformer` class initializes attributions to the layer with index `idx`. `N_steps` is the number of approximation steps used to calculate the *integrated Jacobians* ($N$ in the paper).

`reset_attribution()` removes all hooks that are registered on the model for calculating attributions. After calling it, you can initialize attributions to a different layer.

In [None]:
model.reset_attribution()
model.init_attribution_to_layer(idx=8, N_steps=50)
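
To get a feeling for what `N_steps` controls, the following toy sketch approximates a path integral of a derivative with an $N$-step Riemann sum. It uses a scalar function instead of a transformer, and the names `f`, `df`, and `integrated_gradient` are illustrative assumptions, not part of `xsbert`. The integrated Jacobians in the library rest on the same idea of integrating along a straight line from a reference to the input, but this sketch does not reproduce their implementation.

```python
import math


def f(x: float) -> float:
    """Toy 'model': a smooth, non-linear scalar function."""
    return math.tanh(x)


def df(x: float) -> float:
    """Analytic derivative of f."""
    return 1.0 - math.tanh(x) ** 2


def integrated_gradient(x: float, r: float, n_steps: int) -> float:
    """Approximate (x - r) * integral_0^1 f'(r + alpha * (x - r)) d(alpha)
    with an n_steps-point Riemann sum over right endpoints."""
    total = 0.0
    for k in range(1, n_steps + 1):
        alpha = k / n_steps
        total += df(r + alpha * (x - r))
    return (x - r) * total / n_steps


x, r = 2.0, 0.0          # input and reference (hypothetical values)
exact = f(x) - f(r)      # exact value of the integral
errors = {n: abs(integrated_gradient(x, r, n) - exact) for n in (5, 50, 500)}
for n, err in errors.items():
    print(f"N = {n:3d}  error = {err:.2e}")
```

In this toy setting the approximation error shrinks as the number of steps grows, at the price of one derivative evaluation per step.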

### computing attributions

In this demo we compute the attribution matrix for a single pair of texts that you can define here:

In [None]:
texta = 'This is not a good coffee.'
textb = 'The coffee is bad.'

After initializing attributions (above), we use the method `attribute_prediction` in the `models.XSTransformer` class to compute the attribution matrix $A$.

When the argument `compute_lhs` is set to `True`, the method explicitly computes the four terms in the ansatz (left-hand side of Equation 2 in the paper), $f(a, b) - f(r, a) - f(r, b) + f(r, r)$. Below they are named `score`, `ra`, `rb`, and `rr`, in that order.

In [None]:
A, tokens_a, tokens_b, score, ra, rb, rr = model.attribute_prediction(
    texta, 
    textb, 
    move_to_cpu=False,
    compute_lhs=True
)
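
The relation between the four terms and the attribution matrix can be checked by hand on a toy model. The sketch below assumes a purely bilinear similarity (sum pooling followed by a dot product) and a single reference vector `r` shared by all tokens. These choices, and all names and numbers in the block, are hypothetical and much simpler than the actual Siamese transformer. For such a model the token-token matrix $A_{ij} = (x_i - r) \cdot (y_j - r)$ sums exactly to the four-term expression.

```python
from typing import List

Vector = List[float]


def dot(u: Vector, v: Vector) -> float:
    return sum(ui * vi for ui, vi in zip(u, v))


def sub(u: Vector, v: Vector) -> Vector:
    return [ui - vi for ui, vi in zip(u, v)]


def pool(tokens: List[Vector]) -> Vector:
    """Sum-pool token vectors into a single text embedding."""
    return [sum(col) for col in zip(*tokens)]


def f(a: List[Vector], b: List[Vector]) -> float:
    """Toy similarity: dot product of the two pooled embeddings."""
    return dot(pool(a), pool(b))


# hypothetical token embeddings for two texts and a shared reference vector
a = [[1.0, 2.0], [0.5, -1.0], [2.0, 0.0]]
b = [[0.0, 1.0], [1.5, 0.5]]
r = [0.1, -0.2]
ref_a = [r] * len(a)   # reference input with the length of text a
ref_b = [r] * len(b)   # reference input with the length of text b

score = f(a, b)        # f(a, b)
ra = f(ref_b, a)       # f(r, a)
rb = f(ref_a, b)       # f(r, b)
rr = f(ref_a, ref_b)   # f(r, r)
lhs = score - ra - rb + rr

# token-token attribution matrix of the bilinear toy model
A = [[dot(sub(x, r), sub(y, r)) for y in b] for x in a]
total = sum(sum(row) for row in A)

print("lhs:", lhs, " sum of A:", total)
```

In the toy model the reference terms do not vanish, so the sum of `A` matches the full four-term expression rather than `score` alone.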

### attribution accuracy

The first term, $f(a, b)$ (`score`), is the actual model prediction.
Due to the embedding shift implemented in the `models.ShiftingReferenceTransformer` (cf. Section 2.2 in the paper), by construction, the three terms including a reference $r$ must vanish. Below, we explicitly check that this is the case.

We can also calculate how accurate our attributions are by taking the absolute difference between their sum and the model prediction (as described in Section 3.2 of the paper): $\text{error} = \left|\sum_{ij} A_{ij} - f(a, b)\right|$.

You can change the number of approximation steps $N$ in the `init_attribution_to_layer()` method to see how this attribution error changes.
Generally, attributions to shallower layers require larger $N$ (cf. Section 3.2 in the paper).

In [None]:
tot_attr = A.sum().item()
attr_err = torch.abs(A.sum() - score).item()
print('model prediction: ', score)
print('total attribution: ', tot_attr)
print('reference terms: ', ra, rb, rr)
print('attribution error: ', attr_err)
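
To see how the attribution error depends on $N$, the steps of this demo can be wrapped in a small loop. The helper `attribution_error_sweep` below is a hypothetical convenience function, not part of `xsbert`. It only relies on the three methods used in this notebook (`reset_attribution`, `init_attribution_to_layer`, `attribute_prediction`) and assumes that `A.sum()` and `score` can be converted with `float()`.

```python
def attribution_error_sweep(model, text_a, text_b, layer_idx, step_counts):
    """Return {N: |sum(A) - f(a, b)|} for every N in step_counts.

    `model` is assumed to expose the attribution methods used in this demo.
    """
    errors = {}
    for n_steps in step_counts:
        # remove hooks from the previous run before re-initializing
        model.reset_attribution()
        model.init_attribution_to_layer(idx=layer_idx, N_steps=n_steps)
        A, _, _, score, _, _, _ = model.attribute_prediction(
            text_a,
            text_b,
            move_to_cpu=False,
            compute_lhs=True,
        )
        errors[n_steps] = abs(float(A.sum()) - float(score))
    return errors
```

With the model and texts from above, a call such as `attribution_error_sweep(model, texta, textb, layer_idx=8, step_counts=[10, 50, 250])` would return one error value per step count. Each entry requires a full attribution pass, so the runtime grows with the size of `step_counts`.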

### plotting attributions

Finally, we can plot the token-token attribution matrix.

In [None]:
utils.plot_attributions(
    A, 
    tokens_a, 
    tokens_b, 
    size=(2, 2),
    range=.1,
    show_colobar=True, 
    shrink_cbar=.5
)