In [29]:
! pip install -U datasets
! pip install transformers==4.18.0



# Install requirements / Clone repository

In [30]:
! git clone "https://github.com/mohsenfayyaz/DecompX"

fatal: destination path 'DecompX' already exists and is not an empty directory.


In [31]:
import os
os.environ["HF_TOKEN"] = "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"

# Config (Change model and sentence here)

In [32]:
import torch
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import matplotlib
from IPython.display import display, HTML
from transformers import AutoTokenizer
from DecompX.src.decompx_utils import DecompXConfig
from DecompX.src.modeling_bert import BertForSequenceClassification
from DecompX.src.modeling_roberta import RobertaForSequenceClassification

CONFIGS = {
    "DecompX":
        DecompXConfig(
            include_biases=True,
            bias_decomp_type="absdot",
            include_LN1=True,
            include_FFN=True,
            FFN_approx_type="GeLU_ZO",
            include_LN2=True,
            aggregation="vector",
            include_classifier_w_pooler=True,
            tanh_approx_type="ZO",
            output_all_layers=True,
            output_attention=None,
            output_res1=None,
            output_LN1=None,
            output_FFN=None,
            output_res2=None,
            output_encoder=None,
            output_aggregated="norm",
            output_pooler="norm",
            output_classifier=True,
        ),
}

# Load corresponding model/tokenizer

In [33]:
def load_model_and_tokenizer(model_name):
    model = None
    if "roberta" in model_name:
      model = RobertaForSequenceClassification.from_pretrained(model_name)
    elif "bert" in model_name:
      model = BertForSequenceClassification.from_pretrained(model_name)
    else:
      raise Exception(f"Not implemented model: {model_name}")

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    return model, tokenizer

# Compute DecompX

In [34]:
def compute_decompx_for_visualization(model, tokenizer, tokenized_sentence, batch_lengths, num_sentences):
  # logits ~ (8, 2)
  # hidden_states ~ (13, 8, 55, 768)
  # decompx_last_layer_outputs.aggregated ~ (1, 8, 55, 55)
  # decompx_last_layer_outputs.pooler ~ (1, 8, 55)
  # decompx_last_layer_outputs.classifier ~ (8, 55, 2)
  # decompx_all_layers_outputs.aggregated ~ (12, 8, 55, 55)
  with torch.no_grad():
    model.eval()
    logits, hidden_states, decompx_last_layer_outputs, decompx_all_layers_outputs = model(
        **tokenized_sentence,
        output_attentions=False,
        return_dict=False,
        output_hidden_states=True,
        decompx_config=CONFIGS["DecompX"]
    )

  predictions = torch.argmax(logits, dim=1).cpu().tolist()  # Predicted class
  decompx_outputs = {
    "tokens": [tokenizer.convert_ids_to_tokens(tokenized_sentence["input_ids"][i][:batch_lengths[i]]) for i in range(num_sentences)],
    "logits": logits.cpu().detach().numpy().tolist(),  # (batch, classes)
    "cls": hidden_states[-1][:, 0, :].cpu().detach().numpy().tolist(),# Last layer & only CLS -> (batch, emb_dim)
    "predictions": predictions
  }

  ### decompx_last_layer_outputs.aggregated ~ (1, 8, 55, 55) ###
  importance = np.array([g.squeeze().cpu().detach().numpy() for g in decompx_last_layer_outputs.aggregated]).squeeze()  # (batch, seq_len, seq_len)
  importance = [importance[j][:batch_lengths[j],:batch_lengths[j]] for j in range(len(importance))]
  decompx_outputs["importance_last_layer_aggregated"] = importance

  ### decompx_last_layer_outputs.pooler ~ (1, 8, 55) ###
  importance = np.array([g.squeeze().cpu().detach().numpy() for g in decompx_last_layer_outputs.pooler]).squeeze()  # (batch, seq_len)
  importance = [importance[j][:batch_lengths[j]] for j in range(len(importance))]
  decompx_outputs["importance_last_layer_pooler"] = importance

  ### decompx_last_layer_outputs.classifier ~ (8, 55, 2) ###
  importance = np.array([g.squeeze().cpu().detach().numpy() for g in decompx_last_layer_outputs.classifier]).squeeze()  # (batch, seq_len, classes) num token in that sentence, classes, use classifier
  importance = [importance[j][:batch_lengths[j], :] for j in range(len(importance))]
  decompx_outputs["importance_last_layer_classifier"] = importance

  ### decompx_all_layers_outputs.aggregated ~ (12, 8, 55, 55) ###
  importance = np.array([g.squeeze().cpu().detach().numpy() for g in decompx_all_layers_outputs.aggregated])  # (layers, batch, seq_len, seq_len)
  importance = np.einsum('lbij->blij', importance)  # (batch, layers, seq_len, seq_len)
  importance = [importance[j][:, :batch_lengths[j], :batch_lengths[j]] for j in range(len(importance))]
  decompx_outputs["importance_all_layers_aggregated"] = importance

  decompx_outputs_df = pd.DataFrame(decompx_outputs)

  return decompx_outputs_df

In [35]:
def compute_decompx(model, tokenizer, tokenized_sentence, batch_lengths, num_sentences):
  # logits ~ (8, 2)
  # hidden_states ~ (13, 8, 55, 768)
  # decompx_last_layer_outputs.aggregated ~ (1, 8, 55, 55)
  # decompx_last_layer_outputs.pooler ~ (1, 8, 55)
  # decompx_last_layer_outputs.classifier ~ (8, 55, 2)
  # decompx_all_layers_outputs.aggregated ~ (12, 8, 55, 55)
  with torch.no_grad():
    model.eval()
    logits, hidden_states, decompx_last_layer_outputs, decompx_all_layers_outputs = model(
        **tokenized_sentence,
        output_attentions=False,
        return_dict=False,
        output_hidden_states=True,
        decompx_config=CONFIGS["DecompX"]
    )

  predictions = torch.argmax(logits, dim=1).cpu().tolist()  # Predicted class
  decompx_outputs = {
    "tokens": [tokenizer.convert_ids_to_tokens(tokenized_sentence["input_ids"][i][:batch_lengths[i]]) for i in range(num_sentences)],
    "logits": logits.cpu().detach().numpy().tolist(),  # (batch, classes)
    "cls": hidden_states[-1][:, 0, :].cpu().detach().numpy().tolist(),# Last layer & only CLS -> (batch, emb_dim)
    "predictions": predictions
  }

  ### decompx_last_layer_outputs.classifier ~ (8, 55, 2) ###
  importance = np.array([g.squeeze().cpu().detach().numpy() for g in decompx_last_layer_outputs.classifier]).squeeze()  # (batch, seq_len, classes) num token in that sentence, classes, use classifier
  importance = [importance[j][:batch_lengths[j], :] for j in range(len(importance))]
  decompx_outputs["importance_last_layer_classifier"] = importance

  decompx_outputs_df = pd.DataFrame(decompx_outputs)

  return decompx_outputs_df

# Visualization

In [36]:
def print_importance(importance, tokenized_text, discrete=False, prefix="", no_cls_sep=False):
    """
    importance: (sent_len)
    """
    if no_cls_sep:
        importance = importance[1:-1]
        tokenized_text = tokenized_text[1:-1]
    importance = importance / np.abs(importance).max() / 1.5  # Normalize
    if discrete:
        importance = np.argsort(np.argsort(importance)) / len(importance) / 1.6

    html = "<pre style='color:black; padding: 3px;'>"+prefix
    for i in range(len(tokenized_text)):
        if importance[i] >= 0:
            rgba = matplotlib.colormaps.get_cmap('Greens')(importance[i])   # Wistia
        else:
            rgba = matplotlib.colormaps.get_cmap('Reds')(np.abs(importance[i]))   # Wistia
        text_color = "color: rgba(255, 255, 255, 1.0); " if np.abs(importance[i]) > 0.9 else ""
        color = f"background-color: rgba({rgba[0]*255}, {rgba[1]*255}, {rgba[2]*255}, {rgba[3]}); " + text_color
        html += (f"<span style='"
                 f"{color}"
                 f"border-radius: 5px; padding: 3px;"
                 f"font-weight: {int(800)};"
                 "'>")
        html += tokenized_text[i].replace('<', "[").replace(">", "]")
        html += "</span> "
    display(HTML(html))
#     print(html)
    return html

def print_preview(model, tokenizer, tokenized_sentence, batch_lengths, num_sentences, idx=0, discrete=False):
    NO_CLS_SEP = False
    df = compute_decompx_for_visualization(model, tokenizer, tokenized_sentence, batch_lengths, num_sentences)

    for col in ["importance_last_layer_aggregated", "importance_last_layer_classifier"]:
        if col in df and df[col][idx] is not None:
            if "aggregated" in col:
                sentence_importance = df[col].iloc[idx][0, :]
            if "classifier" in col:
                for label in range(df[col].iloc[idx].shape[-1]):
                    sentence_importance = df[col].iloc[idx][:, label]
                    print_importance(
                        sentence_importance,
                        df["tokens"].iloc[idx],
                        prefix=f"{col.split('_')[-1]} Label{label}:".ljust(20),
                        no_cls_sep=NO_CLS_SEP,
                        discrete=False
                    )
                break
                sentence_importance = df[col].iloc[idx][:, df["label"].iloc[idx]]
            if "pooler" in col:
                sentence_importance = df[col].iloc[idx]
            print_importance(
                sentence_importance,
                df["tokens"].iloc[idx],
                prefix=f"{col.split('_')[-1]}:".ljust(20),
                no_cls_sep=NO_CLS_SEP,
                discrete=discrete
            )
    print("------------------------------------")
    return df

In [37]:
from datasets import load_dataset

# Load the dataset
dataset = load_dataset('BoringAnt1793/paired_sentiment_datasets_with_spans')

# Print the size of the dataset
for split in dataset:
    print(f"Split: {split}, Size: {len(dataset[split])}")

Split: train, Size: 1707
Split: test, Size: 488
Split: dev, Size: 245


In [38]:
def visual_evaluation(model, tokenizer, sentences, labels=None):
    correct_predictions = 0
    total_predictions = 0

    tokenized_sentence = tokenizer(sentences, return_tensors="pt", padding=True)
    batch_lengths = tokenized_sentence['attention_mask'].sum(dim=-1)

    # Evaluate each example
    for i in range(len(sentences)):
        df = print_preview(model, tokenizer, tokenized_sentence, batch_lengths, len(sentences), idx=i)

        if labels:
            # Compute accuracy
            if df["predictions"][i] == labels[i]:
                correct_predictions += 1
            total_predictions += 1

    if labels:
        # Print accuracy
        accuracy = correct_predictions / total_predictions
        print(f"Accuracy for the model: {accuracy:.2%}")

In [39]:
model, tokenizer = load_model_and_tokenizer('charleyisballer/toxic-spans-lyeonii-bert-tiny')

In [40]:
visual_evaluation(model, tokenizer, dataset['train'][10:15]['second_sentence'], labels=dataset['train'][10:15]['second_sentence_sentiment'])

------------------------------------


------------------------------------


------------------------------------


------------------------------------


------------------------------------
Accuracy for the model: 20.00%


THE METRIC

In [46]:
def get_token_importance_for_sentences(model, tokenizer, sentences, labels):
    """
    Compute token importances for a list of sentences.

    Args:
    - model: The model to use for computation.
    - tokenizer: The tokenizer associated with the model.
    - sentences: List of input sentences as strings.
    - labels: List of labels corresponding to the sentences.

    Returns:
    - List of tuples with sentences, tokens, and their importance scores.
    """
    # Tokenize the input sentences
    tokenized_sentences = tokenizer(sentences, return_tensors="pt", padding=True)
    batch_lengths = tokenized_sentences["attention_mask"].sum(dim=-1)

    # Compute decompositions
    decompx_outputs_df = compute_decompx(model, tokenizer, tokenized_sentences, batch_lengths, len(sentences))

    results = []
    # Process each sentence
    for idx, sentence in enumerate(sentences):
        tokens = decompx_outputs_df["tokens"][idx]  # Retrieve tokens for this sentence
        importances = decompx_outputs_df["importance_last_layer_classifier"][idx][:, labels[idx]]  # Importance for the corresponding label
        token_importance_pairs = [(token, importance) for token, importance in zip(tokens, importances)]
        results.append((sentence, token_importance_pairs))

    return results

In [42]:
import ast
from tqdm import tqdm


def calculate_phrase_metric_unordered(token_importance_pairs, token_groups):
    """
    Calculate the metric for token groups in a sentence considering unordered matches.

    Args:
        token_importance_pairs: List of (token, importance) pairs.
        token_groups: List of tokenized groups (phrases) to match.

    Returns:
        float: The metric value for the token groups.
    """
    # Filter out [CLS] and [SEP] tokens
    token_importance_pairs = [
        pair for pair in token_importance_pairs if pair[0] not in ["[CLS]", "[SEP]"]
    ]

    phrase_importance_sum = 0
    total_importance = sum(abs(importance) for _, importance in token_importance_pairs)
    token_list = [pair[0] for pair in token_importance_pairs]
    matched_positions = set()

    # Match each token group
    for group in token_groups:
        for start_idx in range(len(token_list) - len(group) + 1):
            if token_list[start_idx:start_idx + len(group)] == group:
                phrase_importance_sum += sum(
                    abs(token_importance_pairs[start_idx + offset][1]) for offset in range(len(group))
                )
                matched_positions.update(range(start_idx, start_idx + len(group)))
                break

    metric = phrase_importance_sum / total_importance if total_importance > 0 else 0
    return metric

In [47]:
from datasets import concatenate_datasets
from tqdm import tqdm

def evaluate_model_on_dataset(dataset, model, tokenizer):
    """
    Evaluate a model on a dataset using the custom metric for token group coverage.

    Args:
        dataset: The dataset to evaluate on (assumed to have 'test' and 'dev' splits).
        model: The model to evaluate.
        tokenizer: The tokenizer corresponding to the model.

    Returns:
        float: The average metric value across all examples in the test split.
    """
    # Combine test and dev splits
    test_split = concatenate_datasets([dataset['test'], dataset['dev']])

    # Filter examples with non-empty 'second_sentence_counterfactual_words'
    valid_examples = [ex for ex in test_split if ex['second_sentence_counterfactual_words']]

    # Ensure even number of examples
    if len(valid_examples) % 2 != 0:
        valid_examples = valid_examples[:-1]  # Drop the last example if odd

    # Extract examples and token groups
    examples = [ex['second_sentence'] for ex in valid_examples]
    labels = [ex['second_sentence_sentiment'] for ex in valid_examples]
    token_groups_list = [
        [tokenizer.tokenize(phrase) for phrase in ex['second_sentence_counterfactual_words']]
        for ex in valid_examples
    ]

    # Calculate metrics for each batch of examples
    metrics = []
    batch_size = 2  # Process in batches of 2

    for start_idx in tqdm(range(0, len(examples), batch_size), desc="Processing Examples"):
        batch_examples = examples[start_idx:start_idx + batch_size]
        batch_labels = labels[start_idx:start_idx + batch_size]
        batch_groups = token_groups_list[start_idx:start_idx + batch_size]

        # Fetch token importances for the current batch
        token_importance_results = get_token_importance_for_sentences(
            model, tokenizer, batch_examples, batch_labels
        )

        for idx, token_importance_pairs in enumerate(token_importance_results):
            token_groups = batch_groups[idx]
            metric = calculate_phrase_metric_unordered(token_importance_pairs[1], token_groups)
            metrics.append(metric)

    # Return the average metric
    average_metric = sum(metrics) / len(metrics) if metrics else 0
    return average_metric

In [48]:
MODELS = ['charleyisballer/toxic-spans-lyeonii-bert-tiny']

In [49]:
for model_name in MODELS:

        model, tokenizer = load_model_and_tokenizer(model_name)
        print(f"Evaluation running for {model_name}:")
        print()

        average_metric = evaluate_model_on_dataset(dataset, model, tokenizer)
        print(f"Average Metric: {average_metric:.4f}")
        print()
        print(f"An error occurred while evaluating {model_name}: {e}")
        print()

Evaluation running for charleyisballer/toxic-spans-lyeonii-bert-tiny:



Processing Examples:  15%|█▌        | 55/366 [02:09<12:14,  2.36s/it]


KeyboardInterrupt: 