In [None]:
# Text classification model
import torch.nn as nn

class TextClassifierModel(nn.Module):
    """Feed-forward binary classifier over fixed-size input vectors.

    The network is Linear -> ReLU -> Linear -> ReLU -> Linear -> Sigmoid,
    shrinking the feature size from ``input_dim`` to ``input_dim // 4``,
    then ``input_dim // 8``, and finally to a single sigmoid output.

    Args:
        input_dim: Size of the last dimension of the input tensor
            (defaults to 768).
    """

    def __init__(self, input_dim=768):
        super().__init__()

        hidden_dim = input_dim // 4
        bottleneck_dim = input_dim // 8

        # Layers are created in the same order they are applied.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, bottleneck_dim)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(bottleneck_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return the sigmoid score for ``x``, with a trailing dimension of size 1."""
        # The full pass is the hidden representation followed by the output head.
        return self.sigmoid(self.fc3(self.encode(x)))

    def encode(self, x):
        """Return the activations that feed the final classification layer."""
        hidden = self.relu1(self.fc1(x))
        return self.relu2(self.fc2(hidden))


In [None]:
# EXPLANATION FUNCTIONS

from sentence_transformers import SentenceTransformer
import nltk
from nltk.corpus import stopwords
import string
from pathlib import Path
import torch
import numpy as np
from lime.lime_text import LimeTextExplainer


# Make sure the NLTK stopword corpus is available, then collect the English
# stopwords into a set for fast membership checks in remove_stopwords().
nltk.download('stopwords')
stop_words = {*stopwords.words('english')}

# To filter extra words as well, uncomment the two lines below and fill them in
#custom_stopwords = {""}  # Add your words here
#stop_words.update(custom_stopwords)  # Merge with default stopwords

# Function to remove stopwords from a text
def remove_stopwords(text):
    """Lower-case ``text``, strip surrounding punctuation and drop stopwords.

    The text is split on any run of whitespace (spaces, tabs, newlines), so
    words separated by a line break are still checked individually. Each token
    is lower-cased and has leading/trailing punctuation removed; tokens that
    are in the module-level ``stop_words`` set, or that become empty after
    stripping (e.g. a lone "-"), are discarded.

    Args:
        text: The raw input string.

    Returns:
        The remaining normalised words joined by single spaces. Returns an
        empty string when nothing is left.
    """
    # Normalise each token once; split() with no argument never yields empty tokens.
    normalized = (token.lower().strip(string.punctuation) for token in text.split())
    # `word` is falsy when a token consisted only of punctuation.
    return " ".join(word for word in normalized if word and word not in stop_words)

# Load the sentence-embedding model used to turn raw text into vectors.
# predict_fn() calls model_mumin.encode(...) and feeds the result to the classifier.
# NOTE(review): the embedding size presumably has to match the classifier's
# input_dim (default 768) — confirm for any other model id.
model_mumin_id = 'paraphrase-mpnet-base-v2'
model_mumin = SentenceTransformer(model_mumin_id)

# Prediction function for LIME
def predict_fn(texts, classifier_model=None):
    """Score a batch of texts and return a two-column probability array.

    Args:
        texts: Sequence of strings to score (LIME passes a list of perturbed
            versions of the explained text).
        classifier_model: Callable mapping a float32 embedding tensor to one
            sigmoid score per row. When omitted, the notebook-level ``model``
            variable is used, which keeps the original one-argument call working.

    Returns:
        numpy array of shape (len(texts), 2): column 1 is the classifier
        output and column 0 is ``1 - output``.
        NOTE(review): explain_text labels these columns "Fake"/"Real" —
        confirm this matches the label convention used in training.
    """
    classifier = model if classifier_model is None else classifier_model

    embeddings = model_mumin.encode(texts, convert_to_numpy=True) # Convert text to embeddings
    inputs = torch.tensor(embeddings, dtype=torch.float32)

    # Keep the inputs on the same device as the classifier's weights.
    if isinstance(classifier, torch.nn.Module):
        first_param = next(classifier.parameters(), None)
        if first_param is not None:
            inputs = inputs.to(first_param.device)

    with torch.no_grad():
        # .cpu() is required before .numpy() when the classifier lives on an accelerator;
        # reshape guarantees a column vector even if the classifier returns a flat tensor.
        outputs = classifier(inputs).cpu().numpy().reshape(-1, 1)
    return np.hstack([1 - outputs, outputs])  # Convert to probability format

# Explanation function
def explain_text(classifier_model, classifier_model_weights, news_text, output_explanation_file = 'lime_explanation_of_text_news.html', num_features = 10, feature_selection="lasso_path"):
    """Explain the classifier's prediction for ``news_text`` with LIME.

    Loads the weights into ``classifier_model``, switches it to evaluation
    mode, runs LIME on the stopword-filtered text, shows the explanation in
    the notebook and writes it to an HTML file.

    Args:
        classifier_model: The classifier to explain; its weights are
            overwritten in place by ``classifier_model_weights``.
        classifier_model_weights: Path to the saved state dict to load.
        news_text: The text to explain.
        output_explanation_file: Destination of the HTML explanation.
        num_features: Maximum number of words shown in the explanation.
        feature_selection: Feature-selection strategy passed to LimeTextExplainer.
    """
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine;
    # load_state_dict then copies the values into the model's existing tensors.
    state_dict = torch.load(classifier_model_weights, map_location="cpu")
    classifier_model.load_state_dict(state_dict)
    classifier_model.eval()

    # Initialize LIME explainer
    explainer = LimeTextExplainer(class_names=["Fake", "Real"], feature_selection=feature_selection)

    # Bind the prediction function to the model passed in, so the explanation
    # always describes `classifier_model` rather than whatever the notebook-level
    # `model` variable happens to hold.
    def _predict(texts):
        return predict_fn(texts, classifier_model)

    # Generate explanation
    explanation = explainer.explain_instance(remove_stopwords(news_text), _predict, num_features=num_features)

    # Display results
    explanation.show_in_notebook()
    explanation.save_to_file(output_explanation_file)

In [None]:
# EXPLANATION EXECUTION STEPS

# Build the classifier, load its trained weights and switch it to evaluation mode.
# `model` is also the notebook-level classifier that predict_fn reads.
model = TextClassifierModel()
# set the folder and the file with the model weights
DATA_FOLDER = Path('./trained_models/')
nn_weights = DATA_FOLDER/'nn_weights_trained_on_claims-seed42.pt'
# map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only machine.
model.load_state_dict(torch.load(nn_weights, map_location='cpu'))
model.eval()

# Select news
news_text = 'INSERT NEWS TEXT HERE'
print('Text of the news: ', news_text)

# Execute the explanation function (it reloads the same weights into `model`
# before running LIME, then displays and saves the explanation)
explain_text(model, nn_weights, news_text)