<a href="https://colab.research.google.com/github/fengfrankgthb/Demonstrations/blob/main/LIT_sat_embeddings.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Using the Learning Interpretability Tool in Notebooks

This notebook demonstrates the [Learning Interpretability Tool](https://pair-code.github.io/lit) (LIT) on a binary sentiment classifier that labels each statement 0 (negative) or 1 (positive).

The `LitWidget` constructor takes two dicts: one mapping model names to model objects, and one mapping dataset names to dataset objects. These are the models and datasets displayed in LIT. Running the constructor starts the LIT server in the background, which loads the models and datasets and serves the UI.

To render the LIT UI in an output cell, call the `render` method on the `LitWidget` object. The UI can be rendered multiple times in separate cells if desired. The widget also provides a `stop` method that shuts down the LIT server.

Copyright 2020 Google LLC.
SPDX-License-Identifier: Apache-2.0

In [None]:
# Step 1: Install LIT
# This will install lit-nlp and its compatible dependencies, including a suitable numpy version.
!pip install lit-nlp
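
To confirm that the install succeeded before importing anything, the installed version can be read with the standard library. This is a minimal sketch; `installed_version` is a hypothetical helper name and is not part of LIT.

```python
from importlib import metadata


def installed_version(package_name):
    """Return the installed version string, or None if the package is absent."""
    try:
        return metadata.version(package_name)
    except metadata.PackageNotFoundError:
        return None


# 'lit-nlp' is the distribution name used by the pip command above.
print("lit-nlp version:", installed_version("lit-nlp"))
```

A result of `None` suggests the install step did not complete in the current runtime.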

In [None]:
# Step 2: Import necessary modules
import os
import csv
import time # For adding a small delay before rendering
from lit_nlp import notebook
from lit_nlp.examples.glue import models # We'll use the SST2Model
from lit_nlp.api import dataset as lit_dataset # For creating custom datasets
from lit_nlp.api import types as lit_types # For defining dataset specifications
from absl import logging

# Step 3: Configure Logging
# Set to INFO for more detailed output, helpful for debugging.
logging.set_verbosity(logging.INFO)

In [None]:
# Step 4: Download and Extract Pre-trained Model
!wget -nc https://storage.googleapis.com/what-if-tool-resources/lit-models/sst2_tiny.tar.gz
!tar -xvf sst2_tiny.tar.gz # Extract the model files

In [None]:
# Step 5: Load the LIT Model
# The path './' assumes the extracted model files (e.g., config.json, vocab.txt) are in the
# current directory, which is where SST2Model will look for them.
# If 'tar -xvf' creates a directory like 'sst2_tiny/', change model_path to './sst2_tiny/'.
model_path = './'
try:
    loaded_sst_model = models.SST2Model(model_path)
    lit_models = {'sst_tiny_sentiment': loaded_sst_model}
    print(f"LIT Model loaded successfully from '{model_path}'.")
    # To verify embeddings are available (they should be for SST2Model):
    # print("Model output_spec:", loaded_sst_model.output_spec())
except Exception as e:
    print(f"Error loading SST2Model from path '{model_path}': {e}")
    print("Please ensure model files are correctly extracted and accessible.")
    raise
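
If loading fails, a quick check for missing files can narrow down the cause. The sketch below is one way to do that; `missing_model_files` is a hypothetical helper, and the file names in `ASSUMED_FILES` are assumptions for illustration that should be adjusted to whatever the archive actually contains.

```python
import os
import tempfile


def missing_model_files(model_dir, required_files):
    """Return the names in required_files that are not present in model_dir."""
    return [
        name for name in required_files
        if not os.path.isfile(os.path.join(model_dir, name))
    ]


# Assumed file names, for illustration only.
ASSUMED_FILES = ["config.json", "vocab.txt"]

# Demonstration against a throwaway directory that holds only vocab.txt.
with tempfile.TemporaryDirectory() as demo_dir:
    open(os.path.join(demo_dir, "vocab.txt"), "w").close()
    print(missing_model_files(demo_dir, ASSUMED_FILES))  # ['config.json']
```

In the notebook, the same function could be called with `model_path` in place of the temporary directory.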

In [None]:
# Step 6: Load Your Custom "sat_dev.csv" Dataset
csv_file_path = "sat_dev.csv"  # Your CSV file name
sentence_col_name = "sentence" # The header name for the sentence column in your CSV
# number_col_name = "number" # The header name for the number column (used if needed)

custom_examples = []
# Define the input specification for the dataset, matching the model's input.
# SST2Model expects a 'sentence' field.
dataset_spec = {'sentence': lit_types.TextSegment()}
# If your CSV also had labels you wanted to use, you'd add them to the spec, e.g.:
# dataset_spec['label'] = lit_types.CategoryLabel(vocab=loaded_sst_model.output_spec()['label'].vocab)


if os.path.exists(csv_file_path):
    print(f"Loading dataset from: {csv_file_path}")
    try:
        with open(csv_file_path, 'r', encoding='utf-8', newline='') as file:
            reader = csv.DictReader(file)
            # fieldnames is None for an empty file, so fall back to an empty list.
            if sentence_col_name not in (reader.fieldnames or []):
                print(f"ERROR: Sentence column '{sentence_col_name}' not found in CSV headers: {reader.fieldnames}")
                print(f"Please ensure your CSV has a column named '{sentence_col_name}'.")
            else:
                for row_num, row_data in enumerate(reader):
                    sentence_text = row_data.get(sentence_col_name, "").strip()
                    if sentence_text:
                        example = {'sentence': sentence_text}
                        # If you were using the 'number' column or labels, you'd add them here:
                        # example['number_id'] = row_data.get(number_col_name, "")
                        # example['label'] = row_data.get(label_col_name, "") # if using labels
                        custom_examples.append(example)
                    # else: # Optional: handle empty sentences if necessary
                    #     print(f"Warning: Empty sentence found at row {row_num + 1}.")
        if custom_examples:
            print(f"Successfully loaded {len(custom_examples)} examples from '{csv_file_path}'.")
        else:
            print(f"Warning: No examples were loaded. Check if the '{sentence_col_name}' column has data or if the file is empty.")

    except Exception as e:
        print(f"Error reading CSV file '{csv_file_path}': {e}")
        print("An empty dataset will be used as a fallback.")
else:
    print(f"WARNING: Dataset file '{csv_file_path}' not found.")
    print("Please ensure the file exists in the same directory as your notebook, or provide the correct path.")
    print("An empty dataset will be used as a fallback.")

# Create the LIT Dataset object
lit_datasets = {
    'sat_dev_custom': lit_dataset.Dataset(
        spec=dataset_spec,
        examples=custom_examples,
        description=f"Custom data loaded from {csv_file_path}",
    )
}
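
The loader above expects a CSV with a header row that includes a `sentence` column. For a quick trial without the real file, a small sample can be generated as sketched below. The rows are made up for illustration, and `write_sample_csv` and the `sat_dev_sample.csv` file name are hypothetical.

```python
import csv
import os
import tempfile

# Made-up rows for illustration; they are not taken from any real dataset.
SAMPLE_ROWS = [
    {"number": "1", "sentence": "The instructions were clear and easy to follow."},
    {"number": "2", "sentence": "The passage was confusing and far too long."},
]


def write_sample_csv(path, rows):
    """Write rows to path with a header line for the 'number' and 'sentence' columns."""
    with open(path, "w", encoding="utf-8", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=["number", "sentence"])
        writer.writeheader()
        writer.writerows(rows)
    return path


# Written to the system temp directory so an existing sat_dev.csv is never overwritten.
sample_path = write_sample_csv(
    os.path.join(tempfile.gettempdir(), "sat_dev_sample.csv"), SAMPLE_ROWS
)
print("Wrote sample file to", sample_path)
```

Setting `csv_file_path = sample_path` before running the loading cell would then exercise the same code path as the real data.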

In [None]:
# Step 7: Define the Custom UI Layout for the Widget
# This layout ensures the data table, editor, classification, and embeddings modules are visible.
ui_layout_config = {
    "MyCustomAnalysisView": { # Name for this layout configuration
        "components": [
            {"tabs": [ # Using a tabbed interface for organization
                {
                    "title": "Data Overview & Predictions",
                    "components": [
                        {"module": "data-table"},        # To display your loaded CSV data
                        {"module": "datapoint-editor"},  # To inspect/edit individual data points
                        {"module": "classification-module"} # To see model predictions
                    ]
                },
                {
                    "title": "Embedding Visualizations",
                    "components": [
                        {"module": "embeddings"} # This module provides UMAP and PCA plots
                    ]
                }
            ]}
        ]
    }
}
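
Because the layout is a deeply nested dict, it is easy to drop a module by accident. The sketch below walks a config with the same nesting as the one above and lists every `module` entry. `collect_modules` is a hypothetical helper, and the code only inspects the dict; it does not check whether LIT accepts the layout.

```python
def collect_modules(node):
    """Recursively gather every 'module' value from a nested layout config."""
    found = []
    if isinstance(node, dict):
        if "module" in node:
            found.append(node["module"])
        for value in node.values():
            found.extend(collect_modules(value))
    elif isinstance(node, list):
        for item in node:
            found.extend(collect_modules(item))
    return found


# Stand-in config with the same nesting as the layout defined above.
demo_layout = {
    "DemoView": {
        "components": [
            {"tabs": [
                {"title": "A", "components": [{"module": "data-table"}]},
                {"title": "B", "components": [{"module": "embeddings"}]},
            ]}
        ]
    }
}
print(collect_modules(demo_layout))  # ['data-table', 'embeddings']
```

Calling `collect_modules(ui_layout_config)` in the notebook would list the modules configured in the cell above.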

In [None]:
# Step 8: Create and Render the LIT Widget
lit_port_number = 8890 # You can change this if the port is in use
print(f"Attempting to initialize LIT on port: {lit_port_number}")

lit_widget = notebook.LitWidget(
    models=lit_models,
    datasets=lit_datasets,
    port=lit_port_number,
    layouts=ui_layout_config, # Apply the custom layout
    default_layout="MyCustomAnalysisView",  # Passed through as a LIT server flag to set the default layout
)

print("Waiting a moment for the LIT server to start...")
time.sleep(3) # Brief pause to allow server initialization

print("Rendering the LIT widget in a new browser tab...")
lit_widget.render(open_in_new_tab=True)
print(f"Requested a new browser tab for the LIT UI (server port: {lit_port_number}).")
print("If the UI doesn't appear or is frozen, check the logs above for any errors.")
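
If the chosen port may already be in use, it can be tested before the widget is constructed. The following is a minimal sketch using only the standard library; `is_port_free` is a hypothetical helper, and the check is advisory because another process could still claim the port between the check and the server start.

```python
import socket


def is_port_free(port, host="127.0.0.1"):
    """Return True if a TCP socket can currently be bound to (host, port)."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        try:
            sock.bind((host, port))
        except OSError:
            return False
    return True


# 8890 matches the port number used in the cell above.
print("Port 8890 free:", is_port_free(8890))
```

When the result is `False`, picking a different value for the port before creating the widget avoids a startup failure.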

If you've found interesting examples using the LIT UI, you can access them in Python through `lit_widget.ui_state`:

In [None]:
lit_widget.ui_state.primary  # the main selected datapoint

In [None]:
lit_widget.ui_state.selection  # the full selected set, if you have multiple points selected

In [None]:
lit_widget.ui_state.pinned  # the pinned datapoint, if you use the 📌 icon or comparison mode

Note that these records include some metadata; the bare example is in the `['data']` field of each record:

In [None]:
lit_widget.ui_state.primary['data']

In [None]:
[ex['data'] for ex in lit_widget.ui_state.selection]
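
To keep a selection for later analysis, the `['data']` fields can be written out as CSV. The sketch below uses stand-in records shaped like the description above, so it runs without a live widget. `examples_to_csv` is a hypothetical helper, and the `id` key is only a placeholder for whatever metadata the real records carry.

```python
import csv
import io


def examples_to_csv(records, fieldnames):
    """Serialize the 'data' field of each record to CSV text."""
    buffer = io.StringIO(newline="")
    # extrasaction="ignore" skips any example fields not listed in fieldnames.
    writer = csv.DictWriter(buffer, fieldnames=fieldnames, extrasaction="ignore")
    writer.writeheader()
    for record in records:
        writer.writerow(record["data"])
    return buffer.getvalue()


# Stand-in records: metadata plus the bare example under 'data'.
demo_records = [
    {"id": "a1", "data": {"sentence": "A clear and well argued essay."}},
    {"id": "b2", "data": {"sentence": "The ending felt rushed."}},
]
print(examples_to_csv(demo_records, ["sentence"]))
```

With a running widget, the `records` argument would be the widget's `ui_state.selection` instead of `demo_records`.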