# SLMuSE-DLF Explainability

Plot the explainability of the SLMuSE-DLF model. By using the dictionary learning approach it is (1.) possible to extract how different words in a certaim semantic role predict the presence of a document level frame and (2.) identify how the FrameAxis constallations are predicting the document level frames.

In [1]:
# auto reload imports
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append('../../src')

In [10]:
from preprocessing.pre_processor import PreProcessor

# import tokenizer for roberta fast
from transformers import RobertaTokenizerFast

In [11]:
tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

In [88]:
class_column_names = "Capacity and Resources;Crime and Punishment;Cultural Identity;Economic;External Regulation and Reputation;Fairness and Equality;Health and Safety;Legality, Constitutionality, Jurisdiction;Morality;Other;Policy Prescription and Evaluation;Political;Public Sentiment;Quality of Life;Security and Defense".split(";")

In [98]:
preprocessor = PreProcessor(
    tokenizer=tokenizer,
    batch_size=8,
    max_sentences_per_article=32,
    max_sentence_length=52,
    max_args_per_sentence=10,
    max_arg_length=16,
    frameaxis_dim=10,
    bert_model_name="roberta-base",
    name_tokenizer="roberta-base",
    path_name_bert_model="../../models/roberta-base-finetune/roberta-base-finetune-2024-05-20_08-02-29-65707/checkpoint-16482",
    path_antonym_pairs="../../data/axis/mft.json",
    dim_names=["virtue", "vice"],
    class_column_names=class_column_names,
    )

In [99]:
text = "This is a test article, with two sentences. The first sentence is about a good thing, and the second sentence is about a bad thing."

In [100]:
dataset, dataloader = preprocessor.preprocess_single_article(
    text
)

error loading _jsonnet (this is expected on Windows), treating C:\Users\elias\AppData\Local\Temp\tmpv4qaivvu\config.json as plain json
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Processing SRL Batch

In [120]:
from tqdm.notebook import tqdm
import numpy as np
import torch

def inspect(model, dataloader, device='cuda'):
    """
    Make predictions with the given model and dataloader.
    
    Args:
    - model (torch.nn.Module): The model to make predictions with.
    - dataloader (DataLoader): DataLoader for the dataset to predict on.
    - device (str): Device to make predictions on ('cpu' or 'cuda').
    
    Returns:
    - predicted_labels (list of lists): List containing the predicted labels for each instance.
    """
    model.eval()
    
    # dim
    batch_size = dataloader.batch_size
    num_sentences = dataloader.dataset.max_sentences_per_article
    max_args_per_sentence = dataloader.dataset.max_args_per_sentence   
    K = 15

    print("num_batches", len(dataloader))
    print("batch_size", batch_size)
    print("num_sentences", num_sentences)
    print("max_args_per_sentence", max_args_per_sentence)
    print("K", K)


    all_preds_span = []
    
    # Initialize usage lists for each label
    all_used_labels_p = []
    all_used_labels_a0 = []
    all_used_labels_a1 = []

    all_used_fx = []
    
    with torch.no_grad():
        # Wrap the dataloader with tqdm for batch progress
        for batch in tqdm(dataloader, desc="Processing Batches"):    

            sentence_ids = batch['sentence_ids'].to(device)
            sentence_attention_masks = batch['sentence_attention_masks'].to(device)

            predicate_ids = batch['predicate_ids'].to(device)            
            arg0_ids = batch['arg0_ids'].to(device)
            arg1_ids = batch['arg1_ids'].to(device)

            frameaxis_data = batch['frameaxis'].to(device)

            sentence_embeddings, predicate_embeddings, arg0_embeddings, arg1_embeddings = model.aggregation(sentence_ids, sentence_attention_masks, predicate_ids, arg0_ids, arg1_ids)
            
            # Process each span
            for sentence_idx in range(sentence_embeddings.size(1)):
                s_sentence_span = sentence_embeddings[:, sentence_idx, :]
                v_fx = frameaxis_data[:, sentence_idx, :]

                for span_idx in range(predicate_embeddings.size(2)):                    
                    v_p_span = predicate_embeddings[:, sentence_idx, span_idx, :]
                    v_a0_span = arg0_embeddings[:, sentence_idx, span_idx, :]
                    v_a1_span = arg1_embeddings[:, sentence_idx, span_idx, :]

                    mask_p = (v_p_span.abs().sum(dim=-1) != 0).float().bool()
                    mask_a0 = (v_a0_span.abs().sum(dim=-1) != 0).float().bool()
                    mask_a1 = (v_a1_span.abs().sum(dim=-1) != 0).float().bool()

                    #unsupervised.combined_autoencoder v_p, v_a0, v_a1, v_sentence, tau
                    output = model.unsupervised.combined_autoencoder(v_p_span,
                        v_a0_span,
                        v_a1_span,
                        mask_p,
                        mask_a0,
                        mask_a1,
                        s_sentence_span,
                        0.6)
                    
                    all_used_labels_p.append(output["p"]["g"].cpu().numpy())
                    all_used_labels_a0.append(output["a0"]["g"].cpu().numpy())
                    all_used_labels_a1.append(output["a1"]["g"].cpu().numpy())

                mask_fx = (v_fx.abs().sum(dim=-1) != 0).float().bool()

                frameaxis_output = model.unsupervised_fx.frameaxis_autoencoder(v_fx, mask_fx, s_sentence_span, 0.6)

                all_used_fx.append(frameaxis_output["g"].cpu().numpy())

            
            # Forward pass
            _, span_logits, sentence_logits, combined_logits, _ = model(sentence_ids, sentence_attention_masks, predicate_ids, arg0_ids, arg1_ids, frameaxis_data, 0.5)
            combined_pred = (torch.softmax(combined_logits, dim=-1) > 0.5).float()

            all_preds_span.append(combined_pred.cpu().numpy())
                
            torch.cuda.empty_cache()

    predictions = np.vstack(all_preds_span)
    
    all_used_labels_p = np.vstack(all_used_labels_p)
    all_used_labels_a0 = np.vstack(all_used_labels_a0)
    all_used_labels_a1 = np.vstack(all_used_labels_a1)

    all_used_fx = np.vstack(all_used_fx)

    # reshape from (iterator (1), num sentences 24, num spans 10, batch size 64, classes 15) to (batch size 64, num sentences 24, num spans 10, classes 15)
    all_used_labels_p = all_used_labels_p.reshape(-1, num_sentences, max_args_per_sentence, K)
    all_used_labels_a0 = all_used_labels_a0.reshape(-1, num_sentences, max_args_per_sentence, K)
    all_used_labels_a1 = all_used_labels_a1.reshape(-1, num_sentences, max_args_per_sentence, K)

    all_used_fx = all_used_fx.reshape(-1, num_sentences, max_args_per_sentence, K)

    return predictions, all_used_labels_p, all_used_labels_a0, all_used_labels_a1, all_used_fx

In [130]:
# create slmuse-dlf model
from model.slmuse_dlf.muse import SLMUSEDLF

In [131]:
del model

NameError: name 'model' is not defined

In [132]:
model = SLMUSEDLF(
    embedding_dim=768,
    frameaxis_dim=10,
    hidden_dim=1024,
    num_classes=15,
    num_sentences=32,
    dropout_prob=0.3,
    bert_model_name="roberta-base",
    bert_model_name_or_path="../../models/roberta-base-finetune/roberta-base-finetune-2024-05-20_08-02-29-65707/checkpoint-16482",
    srl_embeddings_pooling="mean",
    lambda_orthogonality="0.001626384818258435",
    M=8,
    t=8,
    muse_unsupervised_num_layers=4,
    muse_unsupervised_activation="relu",
    muse_unsupervised_use_batch_norm=True,
    muse_unsupervised_matmul_input="g",
    muse_unsupervised_gumbel_softmax_log=False,
    muse_frameaxis_unsupervised_num_layers=4,
    muse_frameaxis_unsupervised_activation="relu",
    muse_frameaxis_unsupervised_use_batch_norm=True,
    muse_frameaxis_unsupervised_matmul_input="g",
    muse_frameaxis_unsupervised_gumbel_softmax_log=False,
    num_negatives=128,
    supervised_concat_frameaxis=False,
    supervised_num_layers=2,
    supervised_activation="relu",
    _debug=False,
    _detect_anomaly=False,
)

Some weights of the model checkpoint at ../../models/roberta-base-finetune/roberta-base-finetune-2024-05-20_08-02-29-65707/checkpoint-16482 were not used when initializing RobertaModel: ['lm_head.bias', 'lm_head.decoder.bias', 'lm_head.dense.weight', 'lm_head.decoder.weight', 'lm_head.layer_norm.bias', 'lm_head.dense.bias', 'lm_head.layer_norm.weight']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaModel were not initialized from the model checkpoint at ../../models/roberta-base-finetune/roberta-base-finetune-2024-05-20_08-02-29-65707/checkpoint-16482 and are

In [133]:
import os
import torch

file_path = "../../models/slmuse-dlf/2024-07-07_11-30-11/model.pth"

# Check if the file exists
if not os.path.exists(file_path):
    raise FileNotFoundError(f"File not found: {file_path}")

# Check if the file is accessible
try:
    with open(file_path, 'rb') as f:
        pass
except PermissionError:
    raise PermissionError(f"Permission denied: {file_path}")

# Load the model state
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.load_state_dict(torch.load(file_path, map_location=device), strict=False)

_IncompatibleKeys(missing_keys=['aggregation.model.embeddings.position_ids', 'aggregation.model.embeddings.word_embeddings.weight', 'aggregation.model.embeddings.position_embeddings.weight', 'aggregation.model.embeddings.token_type_embeddings.weight', 'aggregation.model.embeddings.LayerNorm.weight', 'aggregation.model.embeddings.LayerNorm.bias', 'aggregation.model.encoder.layer.0.attention.self.query.weight', 'aggregation.model.encoder.layer.0.attention.self.query.bias', 'aggregation.model.encoder.layer.0.attention.self.key.weight', 'aggregation.model.encoder.layer.0.attention.self.key.bias', 'aggregation.model.encoder.layer.0.attention.self.value.weight', 'aggregation.model.encoder.layer.0.attention.self.value.bias', 'aggregation.model.encoder.layer.0.attention.output.dense.weight', 'aggregation.model.encoder.layer.0.attention.output.dense.bias', 'aggregation.model.encoder.layer.0.attention.output.LayerNorm.weight', 'aggregation.model.encoder.layer.0.attention.output.LayerNorm.bias', 

In [134]:
inspect(model, dataloader, device=device)

num_batches 1
batch_size 8
num_sentences 32
max_args_per_sentence 10
K 15


Processing Batches:   0%|          | 0/1 [00:00<?, ?it/s]



TypeError: only integer tensors of a single element can be converted to an index