# Install requirements / Clone repository

In [None]:
! git clone "https://github.com/mohsenfayyaz/DecompX"
! pip install datasets==1.18.3
! pip install transformers==4.18.0

# Config (change the model and sentences here)

In [2]:
import torch
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import matplotlib
from IPython.display import display, HTML
from transformers import AutoTokenizer
from DecompX.src.decompx_utils import DecompXConfig
from DecompX.src.modeling_bert import BertForSequenceClassification
from DecompX.src.modeling_roberta import RobertaForSequenceClassification

# Checkpoint to explain. Only BERT or RoBERTa checkpoints are handled by the
# model-loading cell.
MODEL = "WillHeld/roberta-base-sst2"

# Sentences to explain; they are tokenized together as one padded batch.
SENTENCES = [
    "A deep and meaningful film.",
    "a good piece of work more often than not.",
]

# Named DecompX configurations; the compute cell looks up CONFIGS["DecompX"].
# NOTE(review): the exact meaning of each flag is defined by DecompXConfig in
# the cloned repository and is not visible in this notebook.
CONFIGS = dict(
    DecompX=DecompXConfig(
        # -- components / approximations used by the decomposition --
        include_biases=True,
        bias_decomp_type="absdot",
        include_LN1=True,
        include_FFN=True,
        FFN_approx_type="GeLU_ZO",
        include_LN2=True,
        aggregation="vector",
        include_classifier_w_pooler=True,
        tanh_approx_type="ZO",
        # -- requested outputs: the compute cell reads only `aggregated`,
        #    `pooler` and `classifier`, so the others stay disabled --
        output_all_layers=True,
        output_attention=None,
        output_res1=None,
        output_LN1=None,
        output_FFN=None,
        output_res2=None,
        output_encoder=None,
        output_aggregated="norm",
        output_pooler="norm",
        output_classifier=True,
    ),
)

# Load corresponding model/tokenizer

In [3]:
# Tokenize every sentence as one right-padded batch of PyTorch tensors.
tokenizer = AutoTokenizer.from_pretrained(MODEL)
tokenized_sentence = tokenizer(SENTENCES, return_tensors="pt", padding=True)
# Number of real (non-padding) tokens per sentence; used later to trim padding.
batch_lengths = tokenized_sentence["attention_mask"].sum(dim=-1)

# Pick the DecompX-enabled model class from the checkpoint name.
# The match is case-insensitive so names such as "org/RoBERTa-..." are accepted.
# "roberta" must be tested first because it also contains "bert".
# NOTE(review): this is a name heuristic -- other "*bert*" families
# (e.g. DistilBERT, ALBERT, DeBERTa) would also take the BERT branch.
model_name = MODEL.lower()
if "roberta" in model_name:
    model = RobertaForSequenceClassification.from_pretrained(MODEL)
elif "bert" in model_name:
    model = BertForSequenceClassification.from_pretrained(MODEL)
else:
    raise NotImplementedError(
        f"Not implemented model: {MODEL} (only BERT and RoBERTa checkpoints are supported)"
    )

Downloading:   0%|          | 0.00/380 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/780k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/2.01M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/280 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/994 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/476M [00:00<?, ?B/s]

# Compute DecompX

In [4]:
# Shape legend (B = batch size, S = padded sequence length, C = classes),
# following the shape notes that originally accompanied this cell:
#   logits                                 ~ (B, C)
#   hidden_states                          ~ (13, B, S, 768)
#   decompx_last_layer_outputs.aggregated  ~ (1, B, S, S)
#   decompx_last_layer_outputs.pooler      ~ (1, B, S)
#   decompx_last_layer_outputs.classifier  ~ (B, S, C)
#   decompx_all_layers_outputs.aggregated  ~ (12, B, S, S)


def _stack_numpy(tensors):
    """Convert an iterable of tensors to NumPy and stack them on a new leading axis."""
    return np.stack([t.detach().cpu().numpy() for t in tensors])


def _per_sentence_column(items):
    """Pack one arbitrary object per sentence into a 1-D object array.

    Filling the cells one by one keeps each list/array as a single cell and
    prevents NumPy/pandas from trying to merge equally shaped per-sentence
    arrays into one multi-dimensional block.
    """
    column = np.empty(len(items), dtype=object)
    for row, item in enumerate(items):
        column[row] = item
    return column


model.eval()
with torch.no_grad():
    logits, hidden_states, decompx_last_layer_outputs, decompx_all_layers_outputs = model(
        **tokenized_sentence,
        output_attentions=False,
        return_dict=False,
        output_hidden_states=True,
        decompx_config=CONFIGS["DecompX"],
    )

batch_size, seq_len = tokenized_sentence["input_ids"].shape
# Plain Python ints are unambiguous slice bounds for the NumPy arrays below.
lengths = [int(n) for n in batch_lengths]

decompx_outputs = {
    # Tokens of each sentence with the padding positions removed.
    "tokens": [
        tokenizer.convert_ids_to_tokens(ids[:n])
        for ids, n in zip(tokenized_sentence["input_ids"], lengths)
    ],
    "logits": logits.cpu().detach().numpy().tolist(),  # (B, C)
    # Last hidden layer, first token only -> (B, emb_dim)
    "cls": hidden_states[-1][:, 0, :].cpu().detach().numpy().tolist(),
}

# The arrays are reshaped to their known shapes instead of being squeezed:
# squeeze() removes *every* length-1 axis, so with a single sentence it would
# also drop the batch axis and the per-sentence indexing below would iterate
# over tokens instead of sentences.

# Last layer, aggregated -> one (len, len) matrix per sentence.
aggregated = _stack_numpy(decompx_last_layer_outputs.aggregated).reshape(batch_size, seq_len, seq_len)
decompx_outputs["importance_last_layer_aggregated"] = [
    aggregated[j, :n, :n] for j, n in enumerate(lengths)
]

# Last layer, pooler -> one (len,) vector per sentence.
pooler = _stack_numpy(decompx_last_layer_outputs.pooler).reshape(batch_size, seq_len)
decompx_outputs["importance_last_layer_pooler"] = [
    pooler[j, :n] for j, n in enumerate(lengths)
]

# Last layer, classifier -> one (len, C) matrix per sentence.
classifier = _stack_numpy(decompx_last_layer_outputs.classifier).reshape(batch_size, seq_len, -1)
decompx_outputs["importance_last_layer_classifier"] = [
    classifier[j, :n, :] for j, n in enumerate(lengths)
]

# All layers, aggregated -> one (layers, len, len) array per sentence.
all_layers = _stack_numpy(decompx_all_layers_outputs.aggregated).reshape(-1, batch_size, seq_len, seq_len)
all_layers = all_layers.transpose(1, 0, 2, 3)  # (layers, B, S, S) -> (B, layers, S, S)
decompx_outputs["importance_all_layers_aggregated"] = [
    all_layers[j][:, :n, :n] for j, n in enumerate(lengths)
]

# One row per sentence; every cell holds that sentence's list/array.
decompx_outputs_df = pd.DataFrame(
    {name: _per_sentence_column(values) for name, values in decompx_outputs.items()}
)
decompx_outputs_df

Unnamed: 0,tokens,logits,cls,importance_last_layer_aggregated,importance_last_layer_pooler,importance_last_layer_classifier,importance_all_layers_aggregated
0,"[<s>, A, Ġdeep, Ġand, Ġmeaningful, Ġfilm, ., <...","[-2.926377773284912, 2.6316184997558594]","[0.4278368651866913, -0.44068610668182373, -0....","[[9.912518, 15.123899, 9.198907, 7.3025403, 29...","[5.2680273, 8.650479, 5.2628765, 4.173719, 17....","[[-0.122595854, 0.09062181], [1.8409867, -1.55...","[[[19.39683, 1.1414818, 0.7298583, 0.71953577,..."
1,"[<s>, a, Ġgood, Ġpiece, Ġof, Ġwork, Ġmore, Ġof...","[-2.792595863342285, 2.5693302154541016]","[0.3009447455406189, -0.26642340421676636, -0....","[[10.204507, 3.1018076, 45.00585, 12.329525, 1...","[5.43212, 1.8558518, 26.258024, 7.356039, 0.78...","[[-0.42901182, 0.29831558], [-0.37757882, 0.31...","[[[19.514332, 0.64310545, 0.39037097, 0.218365..."


# Visualization

In [6]:
def print_importance(importance, tokenized_text, discrete=False, prefix="", no_cls_sep=False):
    """Display tokens as an HTML heat-map coloured by their importance.

    Non-negative scores are drawn with the 'Greens' colormap, negative scores
    with 'Reds'; the colour intensity follows the normalized magnitude.

    Args:
        importance: 1-D sequence of scores, one per token -> (sent_len).
        tokenized_text: sequence of token strings aligned with `importance`.
        discrete: if True, colour by the rank of each score instead of its
            value (ranks are non-negative, so every token is drawn in green).
        prefix: text placed before the first token; inserted as-is (not escaped).
        no_cls_sep: if True, drop the first and last position of both inputs.

    Returns:
        The generated HTML string (it is also rendered through `display`).
    """
    from html import escape  # local import keeps the cell self-contained

    scores = np.asarray(importance)
    tokens = list(tokenized_text)
    if no_cls_sep:
        scores = scores[1:-1]
        tokens = tokens[1:-1]

    # Scale so the largest magnitude becomes 1/1.5, keeping colours away from
    # the darkest end of the colormap. Skipped for empty or all-zero input,
    # where the division would raise or produce NaN colours.
    peak = float(np.abs(scores).max()) if scores.size else 0.0
    if peak > 0:
        scores = scores / peak / 1.5
    if discrete and scores.size:
        # Double argsort turns the scores into ranks 0..n-1.
        scores = np.argsort(np.argsort(scores)) / len(scores) / 1.6

    # Look the colormaps up once instead of once per token.
    positive_cmap = matplotlib.colormaps["Greens"]
    negative_cmap = matplotlib.colormaps["Reds"]

    parts = ["<pre style='color:black; padding: 3px;'>" + prefix]
    for i, token in enumerate(tokens):
        score = scores[i]
        if score >= 0:
            rgba = positive_cmap(score)
        else:
            rgba = negative_cmap(np.abs(score))
        # White text on very dark backgrounds. With the 1.5 / 1.6 scaling above
        # the magnitude stays below 0.9, so this only triggers if those
        # constants are changed.
        text_color = "color: rgba(255, 255, 255, 1.0); " if np.abs(score) > 0.9 else ""
        color = f"background-color: rgba({rgba[0]*255}, {rgba[1]*255}, {rgba[2]*255}, {rgba[3]}); " + text_color
        parts.append(
            f"<span style='"
            f"{color}"
            f"border-radius: 5px; padding: 3px;"
            f"font-weight: {int(800)};"
            "'>"
        )
        # Special tokens such as <s> are shown as [s] so they are not parsed as
        # tags; '&' is escaped so it cannot start an HTML entity.
        parts.append(escape(token.replace("<", "[").replace(">", "]"), quote=False))
        parts.append("</span> ")
    parts.append("</pre>")  # the opening <pre> was previously left unclosed

    markup = "".join(parts)
    display(HTML(markup))
    return markup

def print_preview(idx=0, discrete=False, df=None):
    """Render the last-layer importance heat-maps for one sentence.

    For the row at position `idx` this shows:
      * `importance_last_layer_aggregated`: row 0 of the (len, len) matrix,
        i.e. the scores attached to the first token of the sentence;
      * `importance_last_layer_classifier`: one heat-map per class label,
        taken from the columns of the (len, classes) matrix.
    Columns that are missing or hold None are skipped.

    Args:
        idx: positional index of the sentence (row) to preview.
        discrete: forwarded to `print_importance` for the aggregated (and
            pooler) view; classifier views always use continuous colouring.
        df: DataFrame to read from; defaults to the notebook-level
            `decompx_outputs_df` when omitted.

    Returns:
        The DataFrame that was read.
    """
    NO_CLS_SEP = False
    if df is None:
        df = decompx_outputs_df

    for col in ("importance_last_layer_aggregated", "importance_last_layer_classifier"):
        if col not in df.columns:
            continue
        values = df[col].iloc[idx]
        if values is None:
            continue

        tokens = df["tokens"].iloc[idx]
        kind = col.split("_")[-1]

        if "classifier" in col:
            # One heat-map per output label. `discrete` is deliberately not
            # forwarded: rank-based colouring yields non-negative values only
            # and would hide the sign of the per-class contributions.
            for label in range(values.shape[-1]):
                print_importance(
                    values[:, label],
                    tokens,
                    prefix=f"{kind} Label{label}:".ljust(20),
                    no_cls_sep=NO_CLS_SEP,
                    discrete=False,
                )
            continue

        if "aggregated" in col:
            sentence_importance = values[0, :]
        elif "pooler" in col:
            sentence_importance = values
        else:
            # Unknown column kind: nothing sensible to draw.
            continue
        print_importance(
            sentence_importance,
            tokens,
            prefix=f"{kind}:".ljust(20),
            no_cls_sep=NO_CLS_SEP,
            discrete=discrete,
        )
    print("------------------------------------")
    return df

# Render the attribution preview of every configured sentence, in input order.
for sentence_idx in range(len(SENTENCES)):
    print_preview(idx=sentence_idx)

------------------------------------


------------------------------------
