<a href="https://colab.research.google.com/github/fengfrankgthb/Demonstrations/blob/main/LIT_components_example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# LIT Standalone Components

This notebook shows use of the [Learning Interpretability Tool](https://pair-code.github.io/lit) components on a binary classifier for labelling statement sentiment (0 for negative, 1 for positive).

All LIT backend components (models, datasets, metrics, generators, etc.) are standalone Python classes, and can easily be used from Colab or another Python context without starting a server. This can be handy for development, or if you want to re-use components in an offline workflow.

Copyright 2021 Google LLC.
SPDX-License-Identifier: Apache-2.0

In [None]:
# The pip installation will install all necessary prerequisite packages for use of the core LIT package.
!pip install lit-nlp

In [None]:
# Replace the installed NumPy with 1.26.4; -y skips pip's confirmation prompt.
!pip uninstall -y numpy
!pip install numpy==1.26.4

In [16]:
import attr
import pandas as pd
import lit_nlp
from lit_nlp import notebook
from lit_nlp.examples.glue import data
from lit_nlp.examples.glue import models

# Hide INFO and lower logs. Comment this out for debugging.
from absl import logging
logging.set_verbosity(logging.WARNING)

## Load data

LIT's `Dataset` classes are just lists of records, plus spec information to describe each field.

In [None]:
sst_data = data.SST2Data('validation')
sst_data.spec()

In [None]:
sst_data.examples[:10]

You can easily convert this to tabular form, too:

In [None]:
pd.DataFrame(sst_data.examples)
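
To make the "list of records plus spec" idea concrete, here is a minimal plain-Python sketch. It does not use LIT's `Dataset` class; the records are made up, the type names in `toy_spec` are hypothetical, and `check_examples` is an illustrative helper, not part of the library.

```python
# Plain-Python stand-in for the "records plus spec" idea; this is not LIT's API.
# The records are made up and the type names in the spec are hypothetical.
toy_spec = {'sentence': 'text', 'label': 'category'}
toy_examples = [
    {'sentence': 'a charming and often affecting journey', 'label': '1'},
    {'sentence': 'unflinchingly bleak and desperate', 'label': '0'},
]

def check_examples(examples, spec):
  """Returns the indices of records whose keys do not match the spec."""
  expected = set(spec)
  return [i for i, ex in enumerate(examples) if set(ex) != expected]

bad_rows = check_examples(toy_examples, toy_spec)
print(bad_rows)  # [] when every record has exactly the fields in the spec
```

Keeping the field description separate from the records is what lets other components decide which fields they can work with.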

## Load a model and run inference

LIT's `Model` class defines a `predict()` function to perform inference. The `input_spec()` describes the expected inputs (it should be a subset of the dataset fields), and `output_spec()` describes the output.

In [None]:
# Fetch the trained model weights and load the model to analyze
!wget https://storage.googleapis.com/what-if-tool-resources/lit-models/sst2_tiny.tar.gz
!mkdir sst2_tiny
!tar -xvf sst2_tiny.tar.gz -C sst2_tiny

sentiment_model = models.SST2Model('./sst2_tiny')
sentiment_model.input_spec(), sentiment_model.output_spec()

There are a lot of fields in the output spec, since this model returns embeddings, gradients, attention, and more. We can view the predictions using Pandas to avoid too much clutter:

In [None]:
preds = list(sentiment_model.predict(sst_data.examples[:10]))
pd.DataFrame(preds)

If we just want the predicted probabilities for each class, we can look at the `probas` field:

In [None]:
labels = sentiment_model.output_spec()['probas'].vocab
pd.DataFrame([p['probas'] for p in preds], columns=pd.Index(labels, name='label'))
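
If we only need the predicted class rather than the full distribution, we can take the highest-scoring entry in each row. The sketch below uses made-up stand-in values so it runs on its own; `predicted_labels`, `toy_labels`, and `toy_probas` are illustrative names, and in the notebook the inputs would presumably be the `probas` rows from `preds` and the `labels` vocab.

```python
# Stand-in values; in the notebook these would come from `preds` and `labels`.
toy_labels = ['0', '1']
toy_probas = [[0.9, 0.1], [0.2, 0.8], [0.45, 0.55]]

def predicted_labels(probas, vocab):
  """Maps each probability row to the vocab entry with the highest score."""
  return [vocab[max(range(len(row)), key=row.__getitem__)] for row in probas]

toy_predictions = predicted_labels(toy_probas, toy_labels)
print(toy_predictions)  # ['0', '1', '1']
```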

## Salience methods

We can use different interpretability components as well. Here's an example running LIME to get a salience map. The output has entries for each input field, though here that's just one field named "sentence":

In [None]:
from lit_nlp.components import lime_explainer
lime = lime_explainer.LIME()

lime_results = lime.run(sst_data.examples[:1], sentiment_model, sst_data)[0]
lime_results

In [None]:
# Again, pretty-print output with Pandas. The SalienceMap object is just a dataclass defined using attr.s.
pd.DataFrame(attr.asdict(lime_results['sentence']))

In [None]:
from lit_nlp.components import gradient_maps
ig = gradient_maps.IntegratedGradients()

ig_results = ig.run(sst_data.examples[:1], sentiment_model, sst_data)[0]
ig_results

In [None]:
# Again, pretty-print output with Pandas. The SalienceMap object is just a dataclass defined using attr.s.
pd.DataFrame(attr.asdict(ig_results['token_grad_sentence']))
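
A salience table is often easier to read once it is sorted by magnitude. Here is a small sketch that picks the most influential tokens from parallel lists of tokens and scores. The inputs are made up, `top_salient_tokens` is a hypothetical helper, and it is an assumption that a salience result exposes its values as parallel `tokens` and `salience` sequences.

```python
def top_salient_tokens(tokens, salience, k=3):
  """Returns the k (token, score) pairs with the largest absolute salience."""
  pairs = zip(tokens, salience)
  return sorted(pairs, key=lambda pair: abs(pair[1]), reverse=True)[:k]

# Made-up tokens and scores standing in for a real salience result.
toy_tokens = ['it', "'s", 'a', 'charming', 'journey']
toy_salience = [0.01, -0.02, 0.0, 0.61, 0.17]
top_tokens = top_salient_tokens(toy_tokens, toy_salience, k=2)
print(top_tokens)  # [('charming', 0.61), ('journey', 0.17)]
```

Sorting by absolute value keeps strongly negative tokens in view, since they matter as much to the prediction as strongly positive ones.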

## Metrics

We can also compute metrics. The metrics components (via the `SimpleMetrics` API) automatically detect compatible fields using the `parent` attribute of each output field. In this case, the model's `probas` field is scored against the `label` field in the input.

In [None]:
from lit_nlp.components import metrics
classification_metrics = metrics.MulticlassMetrics()
classification_metrics.run(sst_data.examples[:100], sentiment_model, sst_data)
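
As a sanity check on what a classification metric measures, accuracy can be recomputed by hand from the predicted probabilities and the gold labels. This is a plain-Python sketch on made-up values, not the `MulticlassMetrics` implementation; `accuracy` and the `metric_*` names are illustrative.

```python
def accuracy(probas, gold_labels, vocab):
  """Fraction of rows whose highest-probability class matches the gold label."""
  if not gold_labels:
    raise ValueError('need at least one example')
  correct = 0
  for row, gold in zip(probas, gold_labels):
    best = max(range(len(row)), key=lambda i: row[i])
    correct += vocab[best] == gold
  return correct / len(gold_labels)

# Made-up predictions and gold labels for four examples.
metric_vocab = ['0', '1']
metric_probas = [[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.3, 0.7]]
metric_gold = ['0', '1', '1', '1']
toy_accuracy = accuracy(metric_probas, metric_gold, metric_vocab)
print(toy_accuracy)  # 0.75
```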

## Generators

We can use counterfactual generators as well. Here's an example with a generator that simply scrambles words in a text segment.

In [None]:
from lit_nlp.components import scrambler
sc = scrambler.Scrambler()

sc_in = sst_data.examples[:5]
sc_out = sc.generate_all(sc_in, model=None, dataset=sst_data,
                         config={'Fields to scramble': ['sentence']})
# The output is a list-of-lists, generated from each original example.
sc_out

In [None]:
# Format as a flat table for display, including original sentences
import itertools
for ex_in, exs_out in zip(sc_in, sc_out):
  for ex_out in exs_out:
    ex_out['original_sentence'] = ex_in['sentence']
pd.DataFrame(itertools.chain.from_iterable(sc_out), columns=['original_sentence', 'sentence', 'label'])
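
The idea behind a word scrambler can be sketched in a few lines of standard-library Python. This is an illustration of the technique only, not the LIT `Scrambler` implementation; `scramble_words` and its `seed` parameter are hypothetical.

```python
import random

def scramble_words(text, seed=0):
  """Returns `text` with its whitespace-separated words in shuffled order."""
  words = text.split()
  # A dedicated Random instance makes the shuffle repeatable for a given seed.
  random.Random(seed).shuffle(words)
  return ' '.join(words)

original = 'the movie was surprisingly good'
scrambled = scramble_words(original)
print(scrambled)
```

The scrambled sentence contains the same words as the original, so any change in the model's prediction can be attributed to word order alone.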

## Running the LIT UI

You can also use these components in the LIT UI without leaving Colab.

In [None]:
widget = notebook.LitWidget(models={'sentiment': sentiment_model},
                            datasets={'sst2': sst_data}, port=8890)

In [None]:
widget.render(height=600)

Alternatively, you can serve the same model and dataset from a standalone LIT server:

In [None]:
from lit_nlp import dev_server

lit_demo = dev_server.Server({'sentiment': sentiment_model}, {'sst2': sst_data}, port=4321)
lit_demo.serve()

If you've found interesting examples using the LIT UI, you can access them in Python using `widget.ui_state`:

In [29]:
widget.ui_state.primary  # the main selected datapoint

In [None]:
widget.ui_state.selection  # the full selected set, if you have multiple points selected

In [31]:
widget.ui_state.pinned  # the pinned datapoint, if you use the 📌 icon or comparison mode

Note that these include some metadata; the bare example is in the `['data']` field for each record:

In [None]:
widget.ui_state.primary['data']

In [None]:
[ex['data'] for ex in widget.ui_state.selection]