<a href="https://colab.research.google.com/github/elianderlohr/muse-dlf/blob/main/notebooks/explainablity/slmuse-dlf-explainability.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# MuSE-DLF Explainability

Plot the explainability of the SLMuSE-DLF model. By using the dictionary learning approach it is (1.) possible to extract how different words in a certaim semantic role predict the presence of a document level frame and (2.) identify how the FrameAxis constallations are predicting the document level frames.

In [1]:
# auto reload imports
%load_ext autoreload
%autoreload 2

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
!ls /content/drive/MyDrive/Git/muse-dlf

assets	data  notebooks  README.md  research-notebooks	run  src  tests  ToDo.md


In [4]:
!pip install wandb==0.17.4



In [5]:
!pip install wandb==0.17.4 allennlp allennlp-models spacy

Collecting torch<1.13.0,>=1.10.0 (from allennlp)
  Using cached torch-1.12.1-cp310-cp310-manylinux1_x86_64.whl (776.3 MB)
INFO: pip is looking at multiple versions of allennlp to determine which version is compatible with other requirements. This could take a while.
Collecting allennlp
  Using cached allennlp-2.10.0-py3-none-any.whl (729 kB)
Collecting torch<1.12.0,>=1.10.0 (from allennlp)
  Using cached torch-1.11.0-cp310-cp310-manylinux1_x86_64.whl (750.6 MB)
Collecting torchvision<0.13.0,>=0.8.1 (from allennlp)
  Using cached torchvision-0.12.0-cp310-cp310-manylinux1_x86_64.whl (21.0 MB)
Collecting allennlp
  Using cached allennlp-2.9.3-py3-none-any.whl (719 kB)
Collecting spacy
  Using cached spacy-3.2.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (6.1 MB)
Collecting transformers<4.19,>=4.1 (from allennlp)
  Using cached transformers-4.18.0-py3-none-any.whl (4.0 MB)
Collecting filelock<3.7,>=3.3 (from allennlp)
  Using cached filelock-3.6.0-py3-none-any.whl (10.0 kB)

In [6]:
!pip install torch==2.3.1



In [7]:
!python -m spacy download en_core_web_sm

2024-07-12 09:03:50.389729: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-07-12 09:03:50.441851: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-07-12 09:03:50.441906: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-07-12 09:03:50.443458: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-07-12 09:03:50.451338: I tensorflow/core/platform/cpu_feature_guar

In [8]:
import sys
sys.path.append('/content/drive/MyDrive/Git/muse-dlf/src')

In [9]:
from preprocessing.pre_processor import PreProcessor
from preprocessing.datasets.article_dataset import custom_collate_fn
from model.slmuse_dlf.muse import SLMUSEDLF

# import tokenizer for roberta fast
from transformers import RobertaTokenizerFast
import wandb
import inspect
import torch
import spacy
import pickle
from pathlib import Path
from torch.utils.data import DataLoader

In [10]:
wandb.require("core")

In [11]:
import nltk

In [12]:
nltk.download('punkt')
nltk.download('wordnet')
nltk.download('stopwords')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


True

In [13]:
tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

In [14]:
base_path = "drive/MyDrive/Git/"

## Setup wandb

In [15]:
run = wandb.init(project="slmuse-dlf", job_type="inference")

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33melias-anderlohr[0m ([33melianderlohr[0m). Use [1m`wandb login --relogin`[0m to force relogin


## Clean

In [16]:
def clean_gpu_memory():
    # Clear cache
    torch.cuda.empty_cache()
    # Reset peak memory stats
    torch.cuda.reset_peak_memory_stats()
    # Perform garbage collection
    import gc
    gc.collect()

clean_gpu_memory()

## Load SLMuSE-DLF

In [17]:
model_artifact = run.use_artifact('elianderlohr/slmuse-dlf/crashing_backpack_5645_model:v1', type='model')
model_dir = model_artifact.download()

[34m[1mwandb[0m: Downloading large artifact crashing_backpack_5645_model:v1, 5168.22MB. 2 files... 
Done. 0:0:0.3


## Load Roberta Model

In [18]:
roberta_artifact = run.use_artifact('elianderlohr-org/wandb-registry-model/mfc-roberta-finetune:v1', type='model')
roberta_dir = roberta_artifact.download()

[34m[1mwandb[0m: Downloading large artifact mfc-roberta-finetune:v1, 1427.32MB. 7 files... 
Done. 0:0:0.3


## Load Dataset

In [19]:
dataset_artifact = run.use_artifact('elianderlohr-org/wandb-registry-dataset/slmuse-dlf:v3', type='dataset')
dataset_dir = dataset_artifact.download()

## Load Config

In [20]:
# Access the run that created the artifactelianderlohr/slmuse-dlf/qa9dh6px
run_id = 'elianderlohr/slmuse-dlf/rl2pr1nz'  # Replace with your run ID if known, otherwise see below for how to get it
run_ref = wandb.Api().run(run_id)

In [21]:
# Access the configuration
config = run_ref.config

In [22]:
# Get the parameters of the SLMUSEDLF class constructor
params = inspect.signature(SLMUSEDLF.__init__).parameters

# Extract the relevant parameters from the config dictionary
model_params = {key: config[key] for key in params if key in config}

In [23]:
model_params["bert_model_name"] = "roberta-base"
model_params["bert_model_name_or_path"] = roberta_dir

## Load Model

In [24]:
model = SLMUSEDLF(**model_params)

Some weights of the model checkpoint at /content/artifacts/roberta-base-finetune-checkpoint-16482:v0 were not used when initializing RobertaModel: ['lm_head.bias', 'lm_head.layer_norm.bias', 'lm_head.layer_norm.weight', 'lm_head.dense.bias', 'lm_head.dense.weight', 'lm_head.decoder.weight', 'lm_head.decoder.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaModel were not initialized from the model checkpoint at /content/artifacts/roberta-base-finetune-checkpoint-16482:v0 and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight

In [25]:
def strip_prefix_from_state_dict(state_dict, prefix):
    """Strip a prefix from the keys in state_dict."""
    return {key[len(prefix):]: value for key, value in state_dict.items() if key.startswith(prefix)}

# Assuming you load the state_dict as follows
state_dict = torch.load(f"{model_dir}/model.pth", map_location="cuda")

In [26]:
stripped_state_dict = strip_prefix_from_state_dict(state_dict, 'module.module.')
model_state_dict = model.state_dict()
model_state_dict.update(stripped_state_dict)

In [27]:
model.load_state_dict(model_state_dict)

<All keys matched successfully>

## Create Dataset

In [28]:
!ls drive/MyDrive/Git/muse-dlf/data/axis

732_semaxis_axes.tsv  mft_experiment.json  mft_raw.csv
custom.tsv	      mft_filtered.csv	   plutchik_wheel_of_emotions.tsv
frames.json	      mft.json		   wordnet_antonyms.tsv


In [29]:
class_column_names = "Capacity and Resources;Crime and Punishment;Cultural Identity;Economic;External Regulation and Reputation;Fairness and Equality;Health and Safety;Legality, Constitutionality, Jurisdiction;Morality;Other;Policy Prescription and Evaluation;Political;Public Sentiment;Quality of Life;Security and Defense".split(";")

### Create Full Dataset

In [30]:
# Define paths to the dataset files within the downloaded directory
train_artifact_filepath = Path(dataset_dir) / 'train_dataset_artifact.pkl'
test_artifact_filepath = Path(dataset_dir) / 'test_dataset_artifact.pkl'

# Load the datasets from the artifact files
with train_artifact_filepath.open("rb") as f:
    loaded_train_dataset = pickle.load(f)

with test_artifact_filepath.open("rb") as f:
    loaded_test_dataset = pickle.load(f)

In [31]:
# create dataloaders
train_dataloader = DataLoader(
    loaded_train_dataset,
    batch_size=config["batch_size"],
    shuffle=True,
    collate_fn=custom_collate_fn,
    drop_last=True,
    pin_memory=True,
    num_workers=1,
)

test_dataloader = DataLoader(
    loaded_test_dataset,
    batch_size=config["batch_size"],
    shuffle=True,
    collate_fn=custom_collate_fn,
    drop_last=True,
    pin_memory=True,
    num_workers=1,
)

### Create Example Dataset

In [32]:
preprocessor = PreProcessor(
    tokenizer=tokenizer,
    batch_size=config["batch_size"],
    max_sentences_per_article=config["num_sentences"],
    max_sentence_length=config["max_sentence_length"],
    max_args_per_sentence=config["max_args_per_sentence"],
    max_arg_length=config["max_arg_length"],
    frameaxis_dim=config["frameaxis_dim"],
    bert_model_name="roberta-base",
    name_tokenizer="roberta-base",
    path_name_bert_model=roberta_dir,
    path_antonym_pairs=f"{base_path}muse-dlf/data/axis/mft.json",
    dim_names=["virtue", "vice"],
    class_column_names=class_column_names,
    )

In [33]:
text = "BILL ON IMMIGRANT WORKERS DIES. Legislation to allow nearly twice as many computer-savvy foreigners and other high-skilled immigrants into the country next year apparently has died in Congress. The House passed the compromise measure last month, 288-133, but Sen. Tom Harkin, D-Iowa, had blocked a vote when in the Senate. The proposal, backed by high-tech companies, would raise the limit of so- called H-1B visas granted each year to skilled workers from abroad. Only 65,000 visas are now granted each year; the bill would raise the annual cap to 115,500 for the next two years and to 107,500 in 2001. The ceiling would return to 65,000 in 2002."

In [34]:
example_dataset, example_dataloader = preprocessor.preprocess_single_article(
    text
)

  warn(f"Failed to load image Python extension: {e}")
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
  _C._set_default_tensor_type(t)
Processing SRL Batches: 100%|██████████| 1/1 [00:00<00:00,  1.37it/s

## Run model with data

In [35]:
import torch
print(torch.cuda.is_available())
print(torch.cuda.current_device())
print(torch.cuda.get_device_name(0))

True
0
NVIDIA L4


In [36]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'


In [37]:
from tqdm.notebook import tqdm
import numpy as np
import torch

def inspect(model, dataloader, device='cuda'):
    """
    Make predictions with the given model and dataloader.

    Args:
    - model (torch.nn.Module): The model to make predictions with.
    - dataloader (DataLoader): DataLoader for the dataset to predict on.
    - device (str): Device to make predictions on ('cpu' or 'cuda').

    Returns:
    - predicted_labels (list of lists): List containing the predicted labels for each instance.
    """
    model = model.to("cuda")
    model.eval()

    # dim
    batch_size = dataloader.batch_size
    num_sentences = dataloader.dataset.max_sentences_per_article
    max_args_per_sentence = dataloader.dataset.max_args_per_sentence
    K = 15

    print("num_batches", len(dataloader))
    print("batch_size", batch_size)
    print("num_sentences", num_sentences)
    print("max_args_per_sentence", max_args_per_sentence)
    print("K", K)

    all_preds_span = []

    # Initialize usage lists for each label
    all_used_labels_p = []
    all_used_labels_a0 = []
    all_used_labels_a1 = []

    all_used_fx = []

    with torch.no_grad():
        # Wrap the dataloader with tqdm for batch progress
        for batch in tqdm(dataloader, desc="Processing Batches"):
            sentence_ids = batch['sentence_ids'].to(device)
            sentence_attention_masks = batch['sentence_attention_masks'].to(device)

            predicate_ids = batch['predicate_ids'].to(device)
            arg0_ids = batch['arg0_ids'].to(device)
            arg1_ids = batch['arg1_ids'].to(device)

            frameaxis_data = batch['frameaxis'].to(device)

            sentence_embeddings, predicate_embeddings, arg0_embeddings, arg1_embeddings = model.aggregation(sentence_ids, sentence_attention_masks, predicate_ids, arg0_ids, arg1_ids)

            # Process each span
            for sentence_idx in range(sentence_embeddings.size(1)):
                s_sentence_span = sentence_embeddings[:, sentence_idx, :]
                v_fx = frameaxis_data[:, sentence_idx, :]

                for span_idx in range(predicate_embeddings.size(2)):
                    v_p_span = predicate_embeddings[:, sentence_idx, span_idx, :]
                    v_a0_span = arg0_embeddings[:, sentence_idx, span_idx, :]
                    v_a1_span = arg1_embeddings[:, sentence_idx, span_idx, :]

                    mask_p = (v_p_span.abs().sum(dim=-1) != 0).float().bool()
                    mask_a0 = (v_a0_span.abs().sum(dim=-1) != 0).float().bool()
                    mask_a1 = (v_a1_span.abs().sum(dim=-1) != 0).float().bool()

                    output = model.unsupervised.combined_autoencoder(
                        v_p_span, v_a0_span, v_a1_span, mask_p, mask_a0, mask_a1, s_sentence_span, 0.6
                    )

                    all_used_labels_p.append(output["p"]["d"].cpu().numpy())
                    all_used_labels_a0.append(output["a0"]["d"].cpu().numpy())
                    all_used_labels_a1.append(output["a1"]["d"].cpu().numpy())

                    del v_p_span, v_a0_span, v_a1_span, mask_p, mask_a0, mask_a1, output
                    torch.cuda.empty_cache()

                mask_fx = (v_fx.abs().sum(dim=-1) != 0).float().bool()

                frameaxis_output = model.unsupervised_fx.frameaxis_autoencoder(v_fx, mask_fx, s_sentence_span, 0.6)

                all_used_fx.append(frameaxis_output["d"].cpu().numpy())

                del v_fx, mask_fx, frameaxis_output
                torch.cuda.empty_cache()

            # Forward pass
            _, span_logits, sentence_logits, combined_logits, _ = model(
                sentence_ids, sentence_attention_masks, predicate_ids, arg0_ids, arg1_ids, frameaxis_data, 0.5
            )
            combined_pred = (torch.softmax(combined_logits, dim=-1) > 0.5).float()

            all_preds_span.append(combined_pred.cpu().numpy())

            del sentence_ids, sentence_attention_masks, predicate_ids, arg0_ids, arg1_ids, frameaxis_data
            del sentence_embeddings, predicate_embeddings, arg0_embeddings, arg1_embeddings
            del span_logits, sentence_logits, combined_logits, combined_pred
            torch.cuda.empty_cache()

    predictions = np.vstack(all_preds_span)

    all_used_labels_p = np.vstack(all_used_labels_p)
    all_used_labels_a0 = np.vstack(all_used_labels_a0)
    all_used_labels_a1 = np.vstack(all_used_labels_a1)

    all_used_fx = np.vstack(all_used_fx)

    # reshape from (iterator (1), num sentences 24, num spans 10, batch size 64, classes 15) to (batch size 64, num sentences 24, num spans 10, classes 15)
    all_used_labels_p = all_used_labels_p.reshape(-1, num_sentences, max_args_per_sentence, K)
    all_used_labels_a0 = all_used_labels_a0.reshape(-1, num_sentences, max_args_per_sentence, K)
    all_used_labels_a1 = all_used_labels_a1.reshape(-1, num_sentences, max_args_per_sentence, K)

    all_used_fx = all_used_fx.reshape(-1, num_sentences, K)

    return predictions, all_used_labels_p, all_used_labels_a0, all_used_labels_a1, all_used_fx


In [38]:
clean_gpu_memory()

In [40]:
output = inspect(model, test_dataloader, device="cuda")

num_batches 37
batch_size 32
num_sentences 32
max_args_per_sentence 10
K 15


Processing Batches:   0%|          | 0/37 [00:00<?, ?it/s]

  self.pid = os.fork()
  self.pid = os.fork()


In [41]:
predicted_labels, used_labels_p, used_labels_a0, used_labels_a1, used_fx = output

In [45]:
# Saving numpy arrays to file
np.savez(base_path + '/labels_data.npz',
         predicted_labels=predicted_labels,
         used_labels_p=used_labels_p,
         used_labels_a0=used_labels_a0,
         used_labels_a1=used_labels_a1,
         used_fx=used_fx)

In [43]:
category_lists_p = {category: [] for category in class_column_names}
category_lists_a1 = {category: [] for category in class_column_names}
category_lists_a0 = {category: [] for category in class_column_names}

category_lists_fx = {category: [] for category in class_column_names}

boundary = 0.4

elem_len = len(test_dataloader.dataset)
for elem_idx in range(elem_len):
    ds = test_dataloader.dataset[elem_idx]

    sent_len = len(ds["predicate_ids"])
    for sentence_idx in range(sent_len):
        span_len = len(ds["predicate_ids"][sentence_idx])
        for span_idx in range(span_len):

          for cat_idx, category in enumerate(class_column_names):
              if used_labels_p[elem_idx][sentence_idx][span_idx][cat_idx] > boundary:
                category_lists_p[category].append(ds["predicate_ids"][sentence_idx][span_idx].int().numpy())

              if used_labels_a0[elem_idx][sentence_idx][span_idx][cat_idx] > boundary:
                category_lists_a0[category].append(ds["arg0_ids"][sentence_idx][span_idx].int().numpy())

              if used_labels_a1[elem_idx][sentence_idx][span_idx][cat_idx] > boundary:
                category_lists_a1[category].append(ds["arg1_ids"][sentence_idx][span_idx].int().numpy())

        if used_fx[elem_idx][sentence_idx][cat_idx] > boundary:
          category_lists_fx[category].append(ds["frameaxis"][sentence_idx].float().numpy())

IndexError: index 1184 is out of bounds for axis 0 with size 1184

In [None]:
import re
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from nltk.stem import WordNetLemmatizer

# Ensure you have downloaded the necessary NLTK resources
nltk.download('stopwords')
nltk.download('punkt')
nltk.download('wordnet')

def decode_tokens(token_dict, tokenizer, remove_stopwords=False, lemmatize=False):
    decoded_data = {}
    stop_words = set(stopwords.words('english')) if remove_stopwords else set()
    lemmatizer = WordNetLemmatizer() if lemmatize else None

    for category, token_lists in token_dict.items():
        decoded_data[category] = []
        for tokens in token_lists:
            if np.any(tokens > 0):
                # Convert tokens to a list if it's a tensor or numpy array
                if isinstance(tokens, torch.Tensor):
                    tokens = tokens.tolist()
                elif isinstance(tokens, np.ndarray):
                    tokens = tokens.tolist()

                # Decode the tokens
                decoded_text = tokenizer.decode(tokens, skip_special_tokens=True).strip()

                # Remove non-alphabetic characters (but keep spaces)
                decoded_text = re.sub(r'[^A-Za-z ]', '', decoded_text)

                # Tokenize, optionally lemmatize, and remove stop words
                words = word_tokenize(decoded_text)
                processed_words = [lemmatizer.lemmatize(word.lower()) if lemmatizer else word.lower() for word in words if word.lower() not in stop_words]

                # Join the words back into a string and ensure it's not empty
                processed_text = ' '.join(processed_words)
                if processed_text:
                    decoded_data[category].append(processed_text)

    return decoded_data

stop_words = set(stopwords.words('english'))

# Decode the token IDs for each ARG
decoded_predicate = decode_tokens(category_lists_p, tokenizer, remove_stopwords=True, lemmatize=True)
decoded_arg0 = decode_tokens(category_lists_a0, tokenizer, remove_stopwords=True, lemmatize=True)
decoded_arg1 = decode_tokens(category_lists_a1, tokenizer, remove_stopwords=True, lemmatize=True)