### Only BERT

In [None]:
import shap
import torch
import numpy as np
from model.BERT import BertWithSentiment
from utils.data_utils import read_data
from transformers import pipeline, BertForSequenceClassification, BertTokenizerFast, BertTokenizer, BertModel

In [None]:
def explain_bert(model_path, test_dataset, logical_fallacies, text_column="source_article_ro",
                 n_samples=10, random_state=42):
    """
    Generate SHAP text explanations for a fine-tuned ``BertForSequenceClassification`` model.

    The test set is filtered to the rows whose ``logical_fallacies`` label is in
    ``logical_fallacies``; up to ``n_samples`` texts are drawn from it and SHAP attributes
    every class probability of the model to the input tokens.

    Args:
        model_path (str): Anything accepted by ``from_pretrained`` for both the model and the
            fast tokenizer (Hugging Face hub ID or local directory).
        test_dataset (pd.DataFrame): Test data with a ``logical_fallacies`` label column and a
            text column named ``text_column``.
        logical_fallacies (list of str): Labels to keep when filtering ``test_dataset``
            (e.g. ``['ad hominem', 'straw man']``).
        text_column (str): Column holding the input texts. Defaults to ``"source_article_ro"``.
        n_samples (int): Maximum number of texts to explain. Defaults to 10; all rows are used
            when the filtered data has fewer.
        random_state (int): Seed for the row sampling. Defaults to 42.

    Returns:
        None: Renders the SHAP text plot via ``shap.plots.text``.

    Raises:
        ValueError: If no row of ``test_dataset`` matches ``logical_fallacies``.
    """
    filtered_test_data = test_dataset[test_dataset.logical_fallacies.isin(logical_fallacies)]
    if filtered_test_data.empty:
        raise ValueError(f"No test rows with a label in {logical_fallacies!r}")

    # Load model
    model = BertForSequenceClassification.from_pretrained(model_path)
    tokenizer = BertTokenizerFast.from_pretrained(model_path)
    nlp = pipeline("text-classification", model=model, tokenizer=tokenizer, top_k=None)

    # With top_k=None the pipeline returns every label sorted by *score* (descending), so the
    # position of a score in that list is not its class index. Re-order by the model's own
    # id2label so that SHAP output column i always corresponds to class i.
    label_order = [model.config.id2label[i] for i in range(model.config.num_labels)]

    def predict_proba(texts):
        """Return class probabilities of shape (len(texts), num_labels), in class-index order."""
        if isinstance(texts, str):
            # A bare string makes the pipeline return a flat list of dicts; keep batch shape.
            texts = [texts]
        elif isinstance(texts, np.ndarray):
            texts = texts.tolist()
        rows = []
        for pred in nlp(texts):
            score_by_label = {item["label"]: item["score"] for item in pred}
            rows.append([score_by_label[label] for label in label_order])
        return np.array(rows)

    # Create SHAP explainer; output_names labels each explained output with its class name.
    explainer = shap.Explainer(
        predict_proba,
        masker=shap.maskers.Text(tokenizer),
        output_names=label_order,
    )

    # Select a few examples (never more than are available, which would make sample() raise).
    n = min(n_samples, len(filtered_test_data))
    sample_texts = filtered_test_data[text_column].sample(n, random_state=random_state).tolist()

    # Compute SHAP values and plot explanations
    shap_values = explainer(sample_texts)
    shap.plots.text(shap_values)


In [None]:
# read_data yields three splits; only the middle one (the test split) is kept here.
_, test_dataset, _ = read_data(
    "../data/all/combined_lfud_huggingface_binary.csv"
)

In [None]:
# Binary model: fallacy vs. nonfallacy.
model_path = "../model/outputs/21-02-2025_14-45-55_bert-2-classes-model.pickle"
logical_fallacies = ["nonfallacy", "fallacy"]
explain_bert(model_path, test_dataset, logical_fallacies)

In [None]:
# Multi-class dataset (fallacy types plus nonfallacy); keep only the test split.
_, test_dataset, _ = read_data(
    "../data/all/combined_lfud_huggingface_nonfallacies.csv"
)

In [None]:
# 3-class model.
model_path = "../model/outputs/03-03-2025_16-23-08_bert-3-classes-model.pickle"
logical_fallacies = ["nonfallacy", "faulty generalization", "intentional"]
explain_bert(model_path, test_dataset, logical_fallacies)

In [None]:
# 5-class model.
model_path = "../model/outputs/03-03-2025_16-46-39_bert-5-classes-model.pickle"
logical_fallacies = [
    "nonfallacy",
    "faulty generalization",
    "intentional",
    "ad hominem",
    "false causality",
]
explain_bert(model_path, test_dataset, logical_fallacies)

In [None]:
# All-classes model: explain over every label present in the test split.
model_path = "../model/outputs/20-02-2025_10-26-38_bert-all-classes-model.pickle"
# Series.unique() keeps first-appearance order, so the label list is deterministic across runs
# (iteration order of set() over strings is not) and avoids the list->set->list round trip.
logical_fallacies = test_dataset["logical_fallacies"].unique().tolist()
explain_bert(model_path, test_dataset, logical_fallacies)

### BERT with sentiment

In [None]:
class FixedSentimentMasker:
    """Callable holding a tokenizer and a list of sentiments that passes texts through unchanged.

    NOTE(review): ``sentiments`` is stored but never read by ``__call__``, and this class is not
    referenced elsewhere in the visible code; confirm whether it is still needed.
    """

    def __init__(self, tokenizer, sentiments):
        self.tokenizer = tokenizer
        self.sentiments = sentiments  # kept on the instance; unused by __call__

    def __call__(self, inputs):
        """Return the inputs as a numpy array together with a zero-filled placeholder array.

        Args:
            inputs: Sized sequence of (possibly masked) texts.

        Returns:
            tuple: ``(np.array(inputs), np.zeros(len(inputs)))``.
        """
        as_array = np.array(inputs)
        placeholder = np.zeros(len(inputs))
        return as_array, placeholder


In [None]:

# Hub ID of the BERT backbone passed to BertWithSentiment(model_name=...).
model_name = "dumitrescustefan/bert-base-romanian-uncased-v1"
# Notebook-level default device: the GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def explain_bert_with_sentiment(model_path, tokenizer_path, test_dataset, logical_fallacies, fig_file, fallacy_index,
                                n_overview=5):
    """
    Explain a ``BertWithSentiment`` classifier with SHAP and show how the sentiment input shifts its output.

    The test set is filtered to rows whose ``logical_fallacies`` label is in ``logical_fallacies``.
    For the first ``n_overview`` filtered rows, a SHAP text plot is rendered with each row's own
    sentiment held fixed while the text is masked. For the row at position ``fallacy_index``, the
    function then:

    - prints the L2 norm of its sentiment embedding,
    - renders its SHAP text plot,
    - plots (and saves to ``fig_file``) the class probabilities obtained when the same text is fed
      with each sentiment.

    Args:
        model_path (str): Path to the saved state dict of the fine-tuned ``BertWithSentiment`` model.
        tokenizer_path (str): Path to the tokenizer used with the model.
        test_dataset (pd.DataFrame): Test data with columns 'source_article_ro', 'logical_fallacies'
            and 'sentiment'. Sentiment values are looked up in {"negative", "neutral", "positive"}.
            Unknown values are treated as "neutral" by the prediction function, but raise
            ``KeyError`` in the embedding-norm step.
        logical_fallacies (list of str): Labels used to filter the test set. Their order defines the
            class indices (label i -> class i), which must match the order used at training time.
        fig_file (str): File path where the sentiment-impact plot is saved.
        fallacy_index (int): Positional index, within the filtered dataset, of the row analysed in detail.
        n_overview (int): Number of leading filtered rows explained in the overview. Defaults to 5.

    Returns:
        None: Displays SHAP text plots and a matplotlib figure; prints the sentiment embedding norm.

    Raises:
        IndexError: If ``fallacy_index`` is out of range for the filtered dataset.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    filtered_test_data = test_dataset[test_dataset.logical_fallacies.isin(logical_fallacies)]

    label2id = {label: idx for idx, label in enumerate(logical_fallacies)}
    id2label = {idx: label for label, idx in label2id.items()}
    num_labels = len(label2id)
    # Human-readable class names; assumes training used the same label order (see docstring).
    class_names = [id2label[i] for i in range(num_labels)]

    model = BertWithSentiment(model_name=model_name, num_labels=num_labels)
    # model_path holds a state dict (it is fed straight into load_state_dict).
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()

    tokenizer = BertTokenizer.from_pretrained(tokenizer_path)

    sentiment_mapping = {"negative": 0, "neutral": 1, "positive": 2}

    def predict_proba(texts, default_sentiment="neutral"):
        """
        Predict class probabilities for a batch of texts.

        Args:
            texts: Sequence (list or numpy array, as SHAP passes) of strings, or of
                ``(text, sentiment_label)`` tuples.
            default_sentiment (str): Sentiment used for plain-string inputs (SHAP's masked texts).

        Returns:
            np.ndarray: Array of shape (batch_size, num_labels) with softmax probabilities.
        """
        texts = list(texts)
        if not texts:
            return np.zeros((0, num_labels), dtype=np.float32)
        if isinstance(texts[0], tuple):
            texts, sentiments = zip(*texts)
        else:
            sentiments = [default_sentiment] * len(texts)

        inputs = tokenizer(list(texts), padding=True, truncation=True, max_length=512, return_tensors="pt")
        inputs = {key: val.to(device) for key, val in inputs.items()}
        # Unknown sentiment labels fall back to id 1 ("neutral").
        sentiment_ids = torch.tensor([sentiment_mapping.get(s, 1) for s in sentiments], device=device)

        with torch.no_grad():
            outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"],
                            sentiment=sentiment_ids)
            probs = torch.softmax(outputs["logits"], dim=-1).cpu().numpy()

        return probs

    def explain_sentiment_impact(text, save_path, sentiments=("negative", "neutral", "positive")):
        """
        Plot how each sentiment input changes the predicted class probabilities for ``text``.

        Args:
            text (str): The input text to analyse.
            save_path (str): Where the figure is saved.
            sentiments (sequence of str): Sentiment labels to try.

        Returns:
            None: Prints the per-sentiment predictions, saves the figure and shows it.
        """
        import matplotlib.pyplot as plt

        sentiments = list(sentiments)
        probs = np.asarray(predict_proba([(text, s) for s in sentiments]))

        for s, p in zip(sentiments, probs):
            print(f"Sentiment: {s} => Prediction: {p}")

        fig, ax = plt.subplots()

        for i in range(probs.shape[1]):
            ax.plot(sentiments, probs[:, i], label=f"Class {i} ({id2label[i]})", marker="o")

        ax.set_title("Effect of Sentiment on Prediction")
        ax.set_ylabel("Probability")
        ax.set_xlabel("Sentiment")

        # Legend outside the axes on the right.
        ax.legend(loc="center left", bbox_to_anchor=(1.02, 0.5), borderaxespad=0)

        # Write each probability just above its marker.
        for i in range(len(sentiments)):
            for j in range(probs.shape[1]):
                ax.text(i, probs[i, j] + 0.02, f"{probs[i, j]:.2f}",
                        ha="center", va="bottom", fontsize=8)

        fig.tight_layout(rect=[0, 0, 0.85, 1])  # leave room for the legend
        fig.savefig(save_path, bbox_inches="tight")
        plt.show()

    def get_sentiment_contribution(text, sentiment_label):
        """
        Return the CLS embedding of ``text`` and the sentiment embedding of ``sentiment_label``.

        Args:
            text (str): Input text.
            sentiment_label (str): Key of ``sentiment_mapping``; an unknown label raises ``KeyError``.

        Returns:
            tuple: ``(cls_embedding, sentiment_embedding)`` as numpy arrays taken from
            ``model.bert`` (position 0 of ``last_hidden_state``) and ``model.sentiment_embedding``.
        """
        inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        sentiment_id = torch.tensor([sentiment_mapping[sentiment_label]]).to(device)

        with torch.no_grad():
            bert_output = model.bert(**inputs)
            cls_embedding = bert_output.last_hidden_state[:, 0, :]
            sentiment_embed = model.sentiment_embedding(sentiment_id)
            return cls_embedding.cpu().numpy(), sentiment_embed.cpu().numpy()

    masker = shap.maskers.Text(tokenizer)
    explainers = {}  # one SHAP explainer per fixed sentiment value

    def explain_text(text, sentiment):
        """Run SHAP on ``text`` while the model's sentiment input stays fixed at ``sentiment``."""
        if sentiment not in explainers:
            explainers[sentiment] = shap.Explainer(
                lambda masked: predict_proba(masked, default_sentiment=sentiment),
                masker=masker,
                output_names=class_names,
            )
        return explainers[sentiment]([text])

    # Overview: explain each leading row with its own sentiment (not a blanket "neutral").
    overview = filtered_test_data.iloc[:n_overview]
    for overview_text, overview_sentiment in zip(overview["source_article_ro"], overview["sentiment"]):
        print(f"Sentiment input: {overview_sentiment}")
        shap.plots.text(explain_text(overview_text, overview_sentiment))

    # Detailed analysis of one row, addressed by its position in the filtered dataset.
    row = filtered_test_data.iloc[fallacy_index]
    text, sentiment = row["source_article_ro"], row["sentiment"]

    _, sentiment_emb = get_sentiment_contribution(text, sentiment)
    print(f"Sentiment embedding L2 norm: {np.linalg.norm(sentiment_emb):.4f}")

    shap.plots.text(explain_text(text, sentiment))

    # Sentiment impact, saved to the caller-provided path.
    explain_sentiment_impact(text, fig_file)

In [None]:
# Binary dataset with a sentiment column; keep only the test split.
_, test_dataset, _ = read_data(
    "/kaggle/input/licenta-dataset-model/all/combined_lfud_huggingface_binary_sent.csv",
    sentiment=True,
)

In [None]:
# Binary BERT+sentiment model: explain the test split and analyse row 4 in detail.
model_path = "/kaggle/input/licenta-dataset-model/outputs/experiment-3_4_sent_2_classes/outputs/model.pt"
tokenizer_path = "/kaggle/input/licenta-dataset-model/outputs/experiment-3_4_sent_2_classes/outputs/tokenizer"

logical_fallacies = ["nonfallacy", "fallacy"]
explain_bert_with_sentiment(
    model_path,
    tokenizer_path,
    test_dataset,
    logical_fallacies,
    fig_file="sentiment_impact2.png",
    fallacy_index=4,
)