In [None]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

Tue Dec 10 05:35:23 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off | 00000000:00:04.0 Off |                    0 |
| N/A   29C    P0              43W / 400W |      2MiB / 40960MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [None]:
! pip install -U datasets
! pip install transformers==4.18.0

Collecting datasets
  Downloading datasets-3.1.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.1.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m10.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m12.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.9.0-py3-none-any.whl 

# Install requirements / Clone repository

In [None]:
! git clone "https://github.com/mohsenfayyaz/DecompX"

Cloning into 'DecompX'...
remote: Enumerating objects: 172, done.[K
remote: Counting objects: 100% (172/172), done.[K
remote: Compressing objects: 100% (133/133), done.[K
remote: Total 172 (delta 74), reused 100 (delta 32), pack-reused 0 (from 0)[K
Receiving objects: 100% (172/172), 25.93 MiB | 29.86 MiB/s, done.
Resolving deltas: 100% (74/74), done.


In [None]:
import os
os.environ["HF_TOKEN"] = "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"

# Config (Change model and sentence here)

In [None]:
import torch
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import matplotlib
from IPython.display import display, HTML
from transformers import AutoTokenizer, BertTokenizer, RobertaTokenizer
from DecompX.src.decompx_utils import DecompXConfig
from DecompX.src.modeling_bert import BertForSequenceClassification
from DecompX.src.modeling_roberta import RobertaForSequenceClassification

CONFIGS = {
    "DecompX":
        DecompXConfig(
            include_biases=True,
            bias_decomp_type="absdot",
            include_LN1=True,
            include_FFN=True,
            FFN_approx_type="GeLU_ZO",
            include_LN2=True,
            aggregation="vector",
            include_classifier_w_pooler=True,
            tanh_approx_type="ZO",
            output_all_layers=False,
            output_attention=None,
            output_res1=None,
            output_LN1=None,
            output_FFN=None,
            output_res2=None,
            output_encoder=None,
            output_aggregated="norm",
            output_pooler="norm",
            output_classifier=True,
        ),
}

# Load corresponding model/tokenizer

In [None]:
def load_model_and_tokenizer(model_name):
    model = None
    tokenizer = None
    if "roberta" in model_name:
      model = RobertaForSequenceClassification.from_pretrained(model_name)
    elif "bert" in model_name:
      model = BertForSequenceClassification.from_pretrained(model_name)
    else:
      raise Exception(f"Not implemented model: {model_name}")

    if "roberta" in model_name:
      tokenizer = RobertaTokenizer.from_pretrained("FacebookAI/roberta-base")
    elif "bert" in model_name:
      tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")
    else:
      raise Exception(f"Not implemented model: {model_name}")

    return model, tokenizer

# Compute DecompX

In [None]:
def compute_decompx(model, tokenizer, tokenized_sentence, batch_lengths):
    """
    Compute DecompX outputs and importances for the given model and tokenized sentences on GPU.

    Args:
        model: The model to evaluate.
        tokenizer: The tokenizer associated with the model.
        tokenized_sentence: Tokenized input sentences (move to GPU).
        batch_lengths: Lengths of the tokenized batches.

    Returns:
        decompx_outputs_df: DataFrame with decomposed importances and other outputs.
        importance: Importance values for all layers.
    """
    # Move tokenized inputs to GPU
    tokenized_sentence = {key: value.cuda() for key, value in tokenized_sentence.items()}

    with torch.no_grad():
        model.eval()
        # Move model to GPU
        model = model.cuda()
        # Forward pass
        logits, hidden_states, decompx_last_layer_outputs = model(
            **tokenized_sentence,
            output_attentions=False,
            return_dict=False,
            output_hidden_states=True,
            decompx_config=CONFIGS["DecompX"]
        )

    # Extract predictions
    predictions = torch.argmax(logits, dim=1).cpu().tolist()  # Predicted class

    # Prepare decompx outputs
    decompx_outputs = {
        "tokens": [tokenizer.convert_ids_to_tokens(tokenized_sentence["input_ids"][i][:batch_lengths[i]]) for i in range(len(batch_lengths))],
        "logits": logits.cpu().detach().numpy().tolist(),  # Move logits to CPU
        "cls": hidden_states[-1][:, 0, :].cpu().detach().numpy().tolist(),  # Last layer & only CLS -> (batch, emb_dim)
        "predictions": predictions
    }

    # Process last layer classifier importance
    importance = np.array([g.squeeze().cpu().detach().numpy() for g in decompx_last_layer_outputs.classifier])  # Move to CPU
    importance = [importance[j][:batch_lengths[j], :] for j in range(len(importance))]
    decompx_outputs["importance_last_layer_classifier"] = importance

    # Convert outputs to DataFrame
    decompx_outputs_df = pd.DataFrame(decompx_outputs)

    return decompx_outputs_df

THE METRIC

In [None]:
def get_token_importance_for_sentences(model, tokenizer, sentences, labels):
    """
    Compute token importances for a list of sentences.

    Args:
    - model: The model to use for computation.
    - tokenizer: The tokenizer associated with the model.
    - sentences: List of input sentences as strings.
    - labels: List of labels corresponding to the sentences.

    Returns:
    - List of tuples with sentences, tokens, and their importance scores.
    """
    # Tokenize the input sentences
    tokenized_sentences = tokenizer(sentences, return_tensors="pt", padding=True)
    batch_lengths = tokenized_sentences["attention_mask"].sum(dim=-1)

    # Compute decompositions
    decompx_outputs_df = compute_decompx(model, tokenizer, tokenized_sentences, batch_lengths)

    results = []
    # Process each sentence
    for idx, sentence in enumerate(sentences):
        tokens = decompx_outputs_df["tokens"][idx]  # Retrieve tokens for this sentence
        importances = decompx_outputs_df["importance_last_layer_classifier"][idx][:, labels[idx]]  # Importance for the corresponding label
        token_importance_pairs = [(token, importance) for token, importance in zip(tokens, importances)]
        results.append((sentence, token_importance_pairs))

    return results

In [None]:
import ast
from tqdm import tqdm


def calculate_phrase_metric_unordered(token_importance_pairs, token_groups):
    """
    Calculate the metric for token groups in a sentence considering unordered matches.

    Args:
        token_importance_pairs: List of (token, importance) pairs.
        token_groups: List of tokenized groups (phrases) to match.

    Returns:
        float: The metric value for the token groups.
    """
    # Filter out [CLS] and [SEP] tokens
    token_importance_pairs = [
        pair for pair in token_importance_pairs if pair[0] not in ["[CLS]", "[SEP]"]
    ]

    phrase_importance_sum = 0
    total_importance = sum(abs(importance) for _, importance in token_importance_pairs)
    token_list = [pair[0] for pair in token_importance_pairs]
    matched_positions = set()

    # Match each token group
    for group in token_groups:
        for start_idx in range(len(token_list) - len(group) + 1):
            if token_list[start_idx:start_idx + len(group)] == group:
                phrase_importance_sum += sum(
                    abs(token_importance_pairs[start_idx + offset][1]) for offset in range(len(group))
                )
                matched_positions.update(range(start_idx, start_idx + len(group)))
                break

    metric = phrase_importance_sum / total_importance if total_importance > 0 else 0
    return metric

TOXIC SPANS DATASET

In [None]:
from datasets import load_dataset

# Load the dataset
dataset = load_dataset('heegyu/toxic-spans')

# Print the size of the dataset
for split in dataset:
    print(f"Split: {split}, Size: {len(dataset[split])}")

README.md:   0%|          | 0.00/3.65k [00:00<?, ?B/s]

train.csv:   0%|          | 0.00/9.71M [00:00<?, ?B/s]

test.csv:   0%|          | 0.00/954k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/10006 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1000 [00:00<?, ? examples/s]

Split: train, Size: 10006
Split: test, Size: 1000


EVALUATION

In [None]:
from tqdm import tqdm
import ast

def evaluate_model_on_dataset(dataset, model, tokenizer):
    """
    Evaluate a model on a dataset using the custom metric for token group coverage.

    Constraints:
    - Include only sentences with fewer than 420 tokens.

    Args:
        model: The model to evaluate.
        tokenizer: The tokenizer corresponding to the model.
        dataset: The dataset to evaluate on (assumed to have a 'test' split).

    Returns:
        float: The average metric value across all examples in the test split.
    """
    # Filter examples with non-empty 'text_of_post' field and tokenize sentences
    test_split = dataset['test']
    valid_examples = []
    token_groups_list = []

    for ex in test_split:
        if ex['text_of_post']:
            tokenized_sentence = tokenizer.tokenize(ex['text_of_post'])
            if len(tokenized_sentence) < 420:  # Filter sentences with <420 tokens
                valid_examples.append({
                    "tokens": tokenized_sentence,
                    "sentence": ex['text_of_post'],
                    "labels": ex['toxic'],
                    "token_groups": [
                        tokenizer.tokenize(phrase) for phrase in ast.literal_eval(ex['text']).keys()
                    ]
                })

    # Sort valid examples by token count in ascending order
    valid_examples.sort(key=lambda x: len(x["tokens"]))

    # Extract examples and token groups
    examples = [entry["sentence"] for entry in valid_examples]
    labels = [entry["labels"] for entry in valid_examples]
    token_groups_list = [entry["token_groups"] for entry in valid_examples]

    # Calculate metrics for each batch of examples
    metrics = []
    batch_size = 1

    for start_idx in tqdm(range(0, len(examples), batch_size), desc="Processing Examples"):
        batch_examples = examples[start_idx:start_idx + batch_size]
        batch_labels = labels[start_idx:start_idx + batch_size]
        batch_groups = token_groups_list[start_idx:start_idx + batch_size]

        # Fetch token importances for the current batch
        token_importance_results = get_token_importance_for_sentences(
            model, tokenizer, batch_examples, batch_labels
        )

        for idx, token_importance_pairs in enumerate(token_importance_results):
            token_groups = batch_groups[idx]
            metric = calculate_phrase_metric_unordered(token_importance_pairs[1], token_groups)
            metrics.append(metric)

    # Return the average metric
    average_metric = sum(metrics) / len(metrics) if metrics else 0
    return average_metric

BERT MODELS

In [None]:
model_name = "charleyisballer/toxic-spans-lyeonii-bert-tiny"
model, tokenizer = load_model_and_tokenizer(model_name)
print(f"Evaluation running for {model_name}:")
print()

average_metric = evaluate_model_on_dataset(dataset, model, tokenizer)
print(f"Average Metric: {average_metric:.4f}")

  return torch.load(checkpoint_file, map_location="cpu")


Evaluation running for charleyisballer/toxic-spans-lyeonii-bert-tiny:



Processing Examples: 100%|██████████| 1000/1000 [00:12<00:00, 81.05it/s]

Average Metric: 0.5397





In [None]:
import torch

torch.cuda.empty_cache()

In [None]:
model_name = "charleyisballer/toxic-spans-lyeonii-bert-mini"
model, tokenizer = load_model_and_tokenizer(model_name)
print(f"Evaluation running for {model_name}:")
print()

average_metric = evaluate_model_on_dataset(dataset, model, tokenizer)
print(f"Average Metric: {average_metric:.4f}")

Downloading:   0%|          | 0.00/716 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/42.6M [00:00<?, ?B/s]

Evaluation running for charleyisballer/toxic-spans-lyeonii-bert-mini:



Processing Examples: 100%|██████████| 1000/1000 [00:18<00:00, 55.28it/s]

Average Metric: 0.2631





In [None]:
import torch

torch.cuda.empty_cache()

In [None]:
model_name = "charleyisballer/toxic-spans-lyeonii-bert-small"
model, tokenizer = load_model_and_tokenizer(model_name)
print(f"Evaluation running for {model_name}:")
print()

average_metric = evaluate_model_on_dataset(dataset, model, tokenizer)
print(f"Average Metric: {average_metric:.4f}")

Downloading:   0%|          | 0.00/717 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/110M [00:00<?, ?B/s]

Evaluation running for charleyisballer/toxic-spans-lyeonii-bert-small:



Processing Examples: 100%|██████████| 1000/1000 [00:23<00:00, 42.60it/s]


Average Metric: 0.4771


In [None]:
import torch

torch.cuda.empty_cache()

In [None]:
model_name = "charleyisballer/toxic-spans-lyeonii-bert-medium"
model, tokenizer = load_model_and_tokenizer(model_name)
print(f"Evaluation running for {model_name}:")
print()

average_metric = evaluate_model_on_dataset(dataset, model, tokenizer)
print(f"Average Metric: {average_metric:.4f}")

Downloading:   0%|          | 0.00/718 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/158M [00:00<?, ?B/s]

Evaluation running for charleyisballer/toxic-spans-lyeonii-bert-medium:



Processing Examples: 100%|██████████| 1000/1000 [00:41<00:00, 24.08it/s]

Average Metric: 0.5215





In [None]:
import torch

torch.cuda.empty_cache()

In [None]:
model_name = "charleyisballer/toxic-spans-google-bert-bert-base-uncased"
model, tokenizer = load_model_and_tokenizer(model_name)
print(f"Evaluation running for {model_name}:")
print()

average_metric = evaluate_model_on_dataset(dataset, model, tokenizer)
print(f"Average Metric: {average_metric:.4f}")

Downloading:   0%|          | 0.00/766 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/418M [00:00<?, ?B/s]

Evaluation running for charleyisballer/toxic-spans-google-bert-bert-base-uncased:



Processing Examples: 100%|██████████| 1000/1000 [01:29<00:00, 11.20it/s]

Average Metric: 0.4120





In [None]:
import torch

torch.cuda.empty_cache()

In [None]:
model_name = "charleyisballer/toxic-spans-google-bert-bert-large-uncased"
model, tokenizer = load_model_and_tokenizer(model_name)
print(f"Evaluation running for {model_name}:")
print()

average_metric = evaluate_model_on_dataset(dataset, model, tokenizer)
print(f"Average Metric: {average_metric:.4f}")

Downloading:   0%|          | 0.00/768 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.25G [00:00<?, ?B/s]

Evaluation running for charleyisballer/toxic-spans-google-bert-bert-large-uncased:



Processing Examples: 100%|██████████| 1000/1000 [04:16<00:00,  3.89it/s]

Average Metric: 0.3759





In [None]:
import torch

torch.cuda.empty_cache()

ROBERTA MODELS

In [None]:
model_name = "charleyisballer/toxic-spans-JackBAI-roberta-medium"
model, tokenizer = load_model_and_tokenizer(model_name)
print(f"Evaluation running for {model_name}:")
print()

average_metric = evaluate_model_on_dataset(dataset, model, tokenizer)
print(f"Average Metric: {average_metric:.4f}")

Downloading:   0%|          | 0.00/771 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/196M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/878k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/481 [00:00<?, ?B/s]

Evaluation running for charleyisballer/toxic-spans-JackBAI-roberta-medium:



Processing Examples: 100%|██████████| 1000/1000 [00:41<00:00, 23.96it/s]

Average Metric: 0.0252





In [None]:
import torch

torch.cuda.empty_cache()

In [None]:
model_name = "charleyisballer/toxic-spans-FacebookAI-roberta-base"
model, tokenizer = load_model_and_tokenizer(model_name)
print(f"Evaluation running for {model_name}:")
print()

average_metric = evaluate_model_on_dataset(dataset, model, tokenizer)
print(f"Average Metric: {average_metric:.4f}")

Downloading:   0%|          | 0.00/774 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/476M [00:00<?, ?B/s]

Evaluation running for charleyisballer/toxic-spans-FacebookAI-roberta-base:



Processing Examples: 100%|██████████| 1000/1000 [01:28<00:00, 11.29it/s]

Average Metric: 0.0683





In [None]:
import torch

torch.cuda.empty_cache()

In [None]:
model_name = "charleyisballer/toxic-spans-FacebookAI-roberta-large"
model, tokenizer = load_model_and_tokenizer(model_name)
print(f"Evaluation running for {model_name}:")
print()

average_metric = evaluate_model_on_dataset(dataset, model, tokenizer)
print(f"Average Metric: {average_metric:.4f}")

Downloading:   0%|          | 0.00/776 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.32G [00:00<?, ?B/s]

Evaluation running for charleyisballer/toxic-spans-FacebookAI-roberta-large:



Processing Examples: 100%|██████████| 1000/1000 [04:15<00:00,  3.91it/s]

Average Metric: 0.0641





In [None]:
import torch

torch.cuda.empty_cache()