# Install requirements / Clone repository

In [1]:
! git clone "https://github.com/mohsenfayyaz/DecompX"
! pip install -U datasets
! pip install transformers==4.18.0

fatal: destination path 'DecompX' already exists and is not an empty directory.


In [2]:
import os
os.environ["HF_TOKEN"] = "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"

# Config (Change model and sentence here)

In [3]:
import torch
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import matplotlib
from IPython.display import display, HTML
from transformers import AutoTokenizer
from DecompX.src.decompx_utils import DecompXConfig
from DecompX.src.modeling_bert import BertForSequenceClassification
from DecompX.src.modeling_roberta import RobertaForSequenceClassification

BERT_MODELS = ["lyeonii/bert-tiny", "lyeonii/bert-mini", "lyeonii/bert-small", "lyeonii/bert-medium", "google-bert/bert-base-uncased", "google-bert/bert-large-uncased"]
ROBERTA_MODELS = ["smallbenchnlp/roberta-small","JackBAI/roberta-medium","FacebookAI/roberta-base", "FacebookAI/roberta-large"]
SENTENCES = [
    "A deep and meaningful film.",
    "a good piece of work more often than not.",
]
CONFIGS = {
    "DecompX":
        DecompXConfig(
            include_biases=True,
            bias_decomp_type="absdot",
            include_LN1=True,
            include_FFN=True,
            FFN_approx_type="GeLU_ZO",
            include_LN2=True,
            aggregation="vector",
            include_classifier_w_pooler=True,
            tanh_approx_type="ZO",
            output_all_layers=True,
            output_attention=None,
            output_res1=None,
            output_LN1=None,
            output_FFN=None,
            output_res2=None,
            output_encoder=None,
            output_aggregated="norm",
            output_pooler="norm",
            output_classifier=True,
        ),
}

# Load corresponding model/tokenizer

In [4]:
def load_model_and_tokenizer(model_name, input_sentences):
  model = None
  tokenizer = AutoTokenizer.from_pretrained(model_name)
  tokenized_sentence = tokenizer(input_sentences, return_tensors="pt", padding=True)
  batch_lengths = tokenized_sentence['attention_mask'].sum(dim=-1)
  if "roberta" in model_name:
      model = RobertaForSequenceClassification.from_pretrained(model_name)
  elif "bert" in model_name:
      model = BertForSequenceClassification.from_pretrained(model_name)
  else:
      raise Exception(f"Not implented model: {model_name}")
  return model, tokenizer, tokenized_sentence, batch_lengths

# Compute DecompX

In [5]:
def compute_decompx(model, tokenizer, tokenized_sentence, batch_lengths):
  # logits ~ (8, 2)
  # hidden_states ~ (13, 8, 55, 768)
  # decompx_last_layer_outputs.aggregated ~ (1, 8, 55, 55)
  # decompx_last_layer_outputs.pooler ~ (1, 8, 55)
  # decompx_last_layer_outputs.classifier ~ (8, 55, 2)
  # decompx_all_layers_outputs.aggregated ~ (12, 8, 55, 55)
  with torch.no_grad():
    model.eval()
    logits, hidden_states, decompx_last_layer_outputs, decompx_all_layers_outputs = model(
        **tokenized_sentence,
        output_attentions=False,
        return_dict=False,
        output_hidden_states=True,
        decompx_config=CONFIGS["DecompX"]
    )

  predictions = torch.argmax(logits, dim=1).cpu().tolist()  # Predicted class
  decompx_outputs = {
    "tokens": [tokenizer.convert_ids_to_tokens(tokenized_sentence["input_ids"][i][:batch_lengths[i]]) for i in range(len(batch_lengths))],
    "logits": logits.cpu().detach().numpy().tolist(),  # (batch, classes)
    "cls": hidden_states[-1][:, 0, :].cpu().detach().numpy().tolist(),# Last layer & only CLS -> (batch, emb_dim)
    "predictions": predictions
  }

  ### decompx_last_layer_outputs.aggregated ~ (1, 8, 55, 55) ###
  importance = np.array([g.squeeze().cpu().detach().numpy() for g in decompx_last_layer_outputs.aggregated]).squeeze()  # (batch, seq_len, seq_len)
  importance = [importance[j][:batch_lengths[j],:batch_lengths[j]] for j in range(len(importance))]
  decompx_outputs["importance_last_layer_aggregated"] = importance

  ### decompx_last_layer_outputs.pooler ~ (1, 8, 55) ###
  importance = np.array([g.squeeze().cpu().detach().numpy() for g in decompx_last_layer_outputs.pooler]).squeeze()  # (batch, seq_len)
  importance = [importance[j][:batch_lengths[j]] for j in range(len(importance))]
  decompx_outputs["importance_last_layer_pooler"] = importance

  ### decompx_last_layer_outputs.classifier ~ (8, 55, 2) ###
  importance = np.array([g.squeeze().cpu().detach().numpy() for g in decompx_last_layer_outputs.classifier]).squeeze()  # (batch, seq_len, classes) num token in that sentence, classes, use classifier
  importance = [importance[j][:batch_lengths[j], :] for j in range(len(importance))]
  decompx_outputs["importance_last_layer_classifier"] = importance

  ### decompx_all_layers_outputs.aggregated ~ (12, 8, 55, 55) ###
  importance = np.array([g.squeeze().cpu().detach().numpy() for g in decompx_all_layers_outputs.aggregated])  # (layers, batch, seq_len, seq_len)
  importance = np.einsum('lbij->blij', importance)  # (batch, layers, seq_len, seq_len)
  importance = [importance[j][:, :batch_lengths[j], :batch_lengths[j]] for j in range(len(importance))]
  decompx_outputs["importance_all_layers_aggregated"] = importance

  decompx_outputs_df = pd.DataFrame(decompx_outputs)

  return decompx_outputs_df, importance

# Visualization

In [6]:
def print_importance(importance, tokenized_text, discrete=False, prefix="", no_cls_sep=False):
    """
    importance: (sent_len)
    """
    if no_cls_sep:
        importance = importance[1:-1]
        tokenized_text = tokenized_text[1:-1]
    importance = importance / np.abs(importance).max() / 1.5  # Normalize
    if discrete:
        importance = np.argsort(np.argsort(importance)) / len(importance) / 1.6

    html = "<pre style='color:black; padding: 3px;'>"+prefix
    for i in range(len(tokenized_text)):
        if importance[i] >= 0:
            rgba = matplotlib.colormaps.get_cmap('Greens')(importance[i])   # Wistia
        else:
            rgba = matplotlib.colormaps.get_cmap('Reds')(np.abs(importance[i]))   # Wistia
        text_color = "color: rgba(255, 255, 255, 1.0); " if np.abs(importance[i]) > 0.9 else ""
        color = f"background-color: rgba({rgba[0]*255}, {rgba[1]*255}, {rgba[2]*255}, {rgba[3]}); " + text_color
        html += (f"<span style='"
                 f"{color}"
                 f"border-radius: 5px; padding: 3px;"
                 f"font-weight: {int(800)};"
                 "'>")
        html += tokenized_text[i].replace('<', "[").replace(">", "]")
        html += "</span> "
    display(HTML(html))
#     print(html)
    return html

def print_preview(model, tokenizer, tokenized_sentence, batch_lengths, idx=0, discrete=False):
    NO_CLS_SEP = False
    df, _ = compute_decompx(model, tokenizer, tokenized_sentence, batch_lengths)

    for col in ["importance_last_layer_aggregated", "importance_last_layer_classifier"]:
        if col in df and df[col][idx] is not None:
            if "aggregated" in col:
                sentence_importance = df[col].iloc[idx][0, :]
            if "classifier" in col:
                for label in range(df[col].iloc[idx].shape[-1]):
                    sentence_importance = df[col].iloc[idx][:, label]
                    print_importance(
                        sentence_importance,
                        df["tokens"].iloc[idx],
                        prefix=f"{col.split('_')[-1]} Label{label}:".ljust(20),
                        no_cls_sep=NO_CLS_SEP,
                        discrete=False
                    )
                break
                sentence_importance = df[col].iloc[idx][:, df["label"].iloc[idx]]
            if "pooler" in col:
                sentence_importance = df[col].iloc[idx]
            print_importance(
                sentence_importance,
                df["tokens"].iloc[idx],
                prefix=f"{col.split('_')[-1]}:".ljust(20),
                no_cls_sep=NO_CLS_SEP,
                discrete=discrete
            )
    print("------------------------------------")
    return df

In [7]:
from datasets import load_dataset

# Load the dataset
dataset = load_dataset('heegyu/toxic-spans')

# Print the size of the dataset
for split in dataset:
    print(f"Split: {split}, Size: {len(dataset[split])}")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Split: train, Size: 10006
Split: test, Size: 1000


In [8]:
def visual_evaluation(model_name, examples, labels=None):
    correct_predictions = 0
    total_predictions = 0

    # Load the model and tokenizer
    model, tokenizer, tokenized_sentence, batch_lengths = load_model_and_tokenizer(model_name, examples)

    # Evaluate each example
    for i in range(len(examples)):
        df = print_preview(model, tokenizer, tokenized_sentence, batch_lengths, idx=i)

        if labels:
          # Compute accuracy
          if df["predictions"][i] == labels[i]:
            correct_predictions += 1
          total_predictions += 1

    if labels:
      # Print accuracy
      accuracy = correct_predictions / total_predictions
      print(f"Accuracy for {model_name}: {accuracy:.2%}")

In [9]:
"""for model_name in BERT_MODELS + ROBERTA_MODELS:
    print(f"Evaluating Model: {model_name}")
    visual_evaluation(model_name, dataset['train'][10:12]['text_of_post'], dataset['train'][10:12]['toxic'])"""

'for model_name in BERT_MODELS + ROBERTA_MODELS:\n    print(f"Evaluating Model: {model_name}")\n    visual_evaluation(model_name, dataset[\'train\'][10:12][\'text_of_post\'], dataset[\'train\'][10:12][\'toxic\'])'

In [14]:
def get_token_importance_for_sentences(model, tokenizer, sentences, batch_lengths, labels):
    """
    Compute token importances for a list of sentences.

    Args:
    - model: The model to use for computation.
    - tokenizer: The tokenizer associated with the model.
    - sentences: List of input sentences as strings.
    - batch_lengths: Lengths of the tokenized batch for the input.
    - labels: List of labels corresponding to the sentences.

    Returns:
    - List of tuples with sentences, tokens, and their importance scores.
    """
    # Tokenize the input sentences
    tokenized_sentences = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True)
    batch_lengths = tokenized_sentences["attention_mask"].sum(dim=-1)

    # Compute decompositions
    decompx_outputs_df, _ = compute_decompx(model, tokenizer, tokenized_sentences, batch_lengths)

    results = []
    # Process each sentence
    for idx, sentence in enumerate(sentences):
        tokens = decompx_outputs_df["tokens"][idx]  # Retrieve tokens for this sentence
        importances = decompx_outputs_df["importance_last_layer_classifier"][idx][:, labels[idx]]  # Importance for the corresponding label
        token_importance_pairs = [(token, importance) for token, importance in zip(tokens, importances)]
        results.append((sentence, token_importance_pairs))

    return results

THE METRIC

In [16]:
import ast

# Load examples and labels
examples = dataset['train'][0:15]['text_of_post']  # Input sentences
labels = dataset['train'][0:15]['toxic']  # Corresponding labels

model_name = "lyeonii/bert-tiny"
model, tokenizer, tokenized_sentence, batch_lengths = load_model_and_tokenizer(model_name, examples)

# Parse the 'text' field to extract token groups (phrases) for each example
token_groups_list = []
for i in range(0, 15):
    text_dict = ast.literal_eval(dataset['train'][i]['text'])
    tokenized_groups = [tokenizer.tokenize(phrase) for phrase in text_dict.keys()]
    token_groups_list.append(tokenized_groups)

# Fetch token importances
token_importance_results = get_token_importance_for_sentences(model, tokenizer, examples, batch_lengths, labels)

# Function to calculate the metric and check token group coverage
def calculate_phrase_metric_unordered(token_importance_pairs, token_groups):
    """
    Calculate the metric for token groups in a sentence considering unordered matches.

    Args:
        token_importance_pairs: List of (token, importance) pairs.
        token_groups: List of tokenized groups (phrases) to match.

    Returns:
        The metric value for the token groups, coverage status, unmatched groups, and unmatched tokens.
    """
    phrase_importance_sum = 0
    total_importance = sum(abs(importance) for _, importance in token_importance_pairs)  # Total importance
    token_list = [pair[0] for pair in token_importance_pairs]  # List of tokens in the sentence
    covered_groups = []  # Track covered token groups
    matched_positions = set()  # Track matched token indices

    # Match each token group (unordered)
    for group in token_groups:
        print(f"Checking Token Group: {group}")
        for start_idx in range(len(token_list) - len(group) + 1):
            # Check if the group matches at any position in the token list
            if token_list[start_idx:start_idx + len(group)] == group:
                # Add the importance of the matched tokens
                phrase_importance_sum += sum(
                    abs(token_importance_pairs[start_idx + offset][1]) for offset in range(len(group))
                )
                covered_groups.append(group)  # Mark the group as covered
                # Mark matched positions
                matched_positions.update(range(start_idx, start_idx + len(group)))
                print(f"  Match Found for Group: {group} at Indices {list(range(start_idx, start_idx + len(group)))}")
                break
        else:
            print(f"  No Match Found for Group: {group}")

    # Determine unmatched groups
    unmatched_groups = [group for group in token_groups if group not in covered_groups]

    # Identify remaining unmatched tokens
    unmatched_tokens = [
        token_list[idx] for idx in range(len(token_list)) if idx not in matched_positions
    ]

    # Return the metric, unmatched groups, and unmatched tokens
    metric = phrase_importance_sum / total_importance if total_importance > 0 else 0
    return metric, unmatched_groups, unmatched_tokens

# Process each sentence and calculate the metric
for idx, (sentence, token_importance_pairs) in enumerate(token_importance_results):
    token_groups = token_groups_list[idx]  # Token groups for this sentence
    metric, unmatched_groups, unmatched_tokens = calculate_phrase_metric_unordered(token_importance_pairs, token_groups)

    print(f"\nSentence: {sentence}")
    print(f"Tokenized Sentence: {[pair[0] for pair in token_importance_pairs]}")
    print(f"Token Groups: {token_groups}")
    print(f"Metric: {metric:.4f}")
    if not unmatched_groups:
        print("All token groups were successfully matched and covered!")
    else:
        print("Some token groups were not matched.")
        print(f"Unmatched Token Groups: {unmatched_groups}")
        print(f"Unmatched Tokens in Sentence: {unmatched_tokens}")
    print("-" * 50)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at lyeonii/bert-tiny and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Checking Token Group: ['stupid']
  Match Found for Group: ['stupid'] at Indices [19]
Checking Token Group: ['clown']
  Match Found for Group: ['clown'] at Indices [2]

Sentence: Another clown in favour of more tax in this country. Blows my mind people can be this stupid.
Tokenized Sentence: ['[CLS]', 'another', 'clown', 'in', 'favour', 'of', 'more', 'tax', 'in', 'this', 'country', '.', 'blows', 'my', 'mind', 'people', 'can', 'be', 'this', 'stupid', '.', '[SEP]']
Token Groups: [['stupid'], ['clown']]
Metric: 0.0460
All token groups were successfully matched and covered!
--------------------------------------------------
Checking Token Group: ['sex']
  Match Found for Group: ['sex'] at Indices [1]

Sentence: Sex in prison ought to be interesting for him
Tokenized Sentence: ['[CLS]', 'sex', 'in', 'prison', 'ought', 'to', 'be', 'interesting', 'for', 'him', '[SEP]']
Token Groups: [['sex']]
Metric: 0.0623
All token groups were successfully matched and covered!
-------------------------------

In [15]:
import ast
import math
from tqdm import tqdm

def evaluate_model_on_dataset(model_name, dataset):
    """
    Evaluate a model on a dataset using the custom metric for token group coverage.

    Args:
        model_name (str): Name of the model to evaluate.
        dataset: The dataset to evaluate on (assumed to have 'test' split).

    Returns:
        float: The average metric value across all examples in the test split.
    """
    # Filter examples with non-empty 'text_of_post' field
    test_split = dataset['test']
    valid_examples = [ex for ex in test_split if ex['text_of_post']]

    # Extract examples and token groups
    examples = [ex['text_of_post'] for ex in valid_examples]
    labels = [ex['toxic'] for ex in valid_examples]

    # Initialize model and tokenizer
    model, tokenizer, tokenized_sentence, batch_lengths = load_model_and_tokenizer(model_name, examples)

    token_groups_list = []
    for ex in valid_examples:
        text_dict = ast.literal_eval(ex['text'])
        tokenized_groups = [tokenizer.tokenize(phrase) for phrase in text_dict.keys()]
        token_groups_list.append(tokenized_groups)

    # Calculate metrics for each example
    metrics = []

    for i in range(math.ceil(len(examples)/10)):
        start_idx = i * 10
        end_idx = min((i + 1) * 10, len(examples))
        # Fetch token importances
        token_importance_results = get_token_importance_for_sentences(model, tokenizer, examples[start_idx:end_idx], batch_lengths, labels[start_idx:end_idx])

        # Function to calculate the metric for unordered token groups
        def calculate_phrase_metric_unordered(token_importance_pairs, token_groups):
            """
            Calculate the metric for token groups in a sentence considering unordered matches.

            Args:
                token_importance_pairs: List of (token, importance) pairs.
                token_groups: List of tokenized groups (phrases) to match.

            Returns:
                float: The metric value for the token groups.
            """
            phrase_importance_sum = 0
            total_importance = sum(abs(importance) for _, importance in token_importance_pairs)
            token_list = [pair[0] for pair in token_importance_pairs]
            matched_positions = set()

            # Match each token group
            for group in token_groups:
                for start_idx in range(len(token_list) - len(group) + 1):
                    if token_list[start_idx:start_idx + len(group)] == group:
                        phrase_importance_sum += sum(
                            abs(token_importance_pairs[start_idx + offset][1]) for offset in range(len(group))
                        )
                        matched_positions.update(range(start_idx, start_idx + len(group)))
                        break

            metric = phrase_importance_sum / total_importance if total_importance > 0 else 0
            return metric


        for idx, (sentence, token_importance_pairs) in tqdm(
            enumerate(token_importance_results),
            total=len(token_importance_results),
            desc="Processing Examples"
        ):
            token_groups = token_groups_list[idx]
            metric = calculate_phrase_metric_unordered(token_importance_pairs, token_groups)
            metrics.append(metric)

    # Return the average metric
    average_metric = sum(metrics) / len(metrics) if metrics else 0
    return average_metric

# Example usage
average_metric = evaluate_model_on_dataset("lyeonii/bert-tiny", dataset)
print(f"Average Metric: {average_metric:.4f}")

  return torch.load(checkpoint_file, map_location="cpu")
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at lyeonii/bert-tiny and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Processing Examples: 100%|██████████| 10/10 [00:00<00:00, 3267.36it/s]
Processing Examples: 100%|██████████| 10/10 [00:00<00:00, 11413.07it/s]
Processing Examples: 100%|██████████| 10/10 [00:00<00:00, 13600.21it/s]
Processing Examples: 100%|██████████| 10/10 [00:00<00:00, 6062.89it/s]
Processing Examples: 100%|██████████| 10/10 [00:00<00:00, 11745.46it/s]
Processing Examples: 100%|██████████| 10/10 [00:00<00:00, 8839.42it/s]
Processing Examples: 100%|██████████| 10/10 [00:00<00:00, 13099.01it/s]
Processing Examples: 100%|██████████| 10/10 [00:00<00:00, 12139.81it/s]
Processing Examples: 100%|██████████| 10/10 [00:00<00:00, 4706.36it/s]
Proces

Average Metric: 0.0057





In [None]:
examples = dataset['train'][10:15]['text_of_post']
labels = dataset['train'][10:15]['toxic']

model_name = "lyeonii/bert-tiny"
model, tokenizer, tokenized_sentence, batch_lengths = load_model_and_tokenizer(model_name, examples)

# Fetch token importances
token_importance_results = get_token_importance_for_sentences(model, tokenizer, examples, batch_lengths, labels)

# Print token importances for each sentence
for sentence, token_importance_pairs in token_importance_results:
    print(f"Sentence: {sentence}")
    for token, importance in token_importance_pairs:
        print(f"  Token: {token}, Importance: {importance:.4f}")
    print("-" * 30)
    print(f"Token: {token}, Importance: {importance:.4f}")