In [None]:
# Cell 1: Import Libraries and Setup
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import (
    AutoTokenizer, AutoModel, AutoConfig, 
    RobertaTokenizer, RobertaForSequenceClassification
)
from sklearn.preprocessing import LabelEncoder
import joblib
import os
import json
import warnings
from typing import Dict, List, Tuple, Optional, Union
import re

# XAI Libraries
import shap
import lime
from lime.lime_text import LimeTextExplainer
from captum.attr import IntegratedGradients, LayerIntegratedGradients, TokenReferenceBase, visualization
from captum.attr import configure_interpretable_embedding_layer, remove_interpretable_embedding_layer

# NOTE(review): this silences *all* warnings (including tokenizer truncation
# notices); narrow it to specific categories when debugging.
warnings.filterwarnings('ignore')
plt.style.use('seaborn-v0_8')

# Seed the global NumPy and PyTorch RNGs so any stochastic step that draws
# from them is reproducible on Restart & Run All.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Create output directory for visualizations
os.makedirs("xai_visualizations", exist_ok=True)

print("✅ XAI libraries imported and setup complete!")

In [None]:
# Cell 2: Define MultiTaskTransformer Architecture
class MultiTaskTransformer(nn.Module):
    """
    Shared-encoder model with separate sentiment and emotion heads.

    A pretrained encoder feeds three parallel multi-head self-attention
    blocks (shared / sentiment / emotion). Each block adds a residual
    connection, applies LayerNorm and dropout, and is pooled at sequence
    position 0. Each task head classifies ``[shared_pooled, task_pooled]``.
    """

    def __init__(
        self,
        model_name: str = "microsoft/deberta-base",
        sentiment_num_classes: int = 3,
        emotion_num_classes: int = 6,
        hidden_dropout_prob: float = 0.1,
        attention_dropout_prob: float = 0.1,
        classifier_dropout: float = 0.1,
        freeze_encoder: bool = False
    ):
        super().__init__()

        self.model_name = model_name
        self.sentiment_num_classes = sentiment_num_classes
        self.emotion_num_classes = emotion_num_classes

        # Override the encoder's dropout settings before instantiating it
        encoder_config = AutoConfig.from_pretrained(model_name)
        encoder_config.hidden_dropout_prob = hidden_dropout_prob
        encoder_config.attention_probs_dropout_prob = attention_dropout_prob

        self.shared_encoder = AutoModel.from_pretrained(
            model_name, config=encoder_config, ignore_mismatched_sizes=True
        )
        if freeze_encoder:
            self.shared_encoder.requires_grad_(False)

        width = self.shared_encoder.config.hidden_size

        # Registration order is kept fixed: it determines parameter order
        # and the order in which random initialisation is drawn.
        self.sentiment_attention = self._make_attention(width, attention_dropout_prob)
        self.emotion_attention = self._make_attention(width, attention_dropout_prob)
        self.shared_attention = self._make_attention(width, attention_dropout_prob)

        self.sentiment_norm = nn.LayerNorm(width)
        self.emotion_norm = nn.LayerNorm(width)
        self.shared_norm = nn.LayerNorm(width)

        self.sentiment_dropout = nn.Dropout(classifier_dropout)
        self.emotion_dropout = nn.Dropout(classifier_dropout)
        self.shared_dropout = nn.Dropout(classifier_dropout)

        self.sentiment_classifier = self._make_head(width, sentiment_num_classes, classifier_dropout)
        self.emotion_classifier = self._make_head(width, emotion_num_classes, classifier_dropout)

        self._init_weights()

    @staticmethod
    def _make_attention(width: int, dropout: float) -> nn.MultiheadAttention:
        """Build an 8-head, batch-first self-attention block of the given width."""
        return nn.MultiheadAttention(
            embed_dim=width,
            num_heads=8,
            dropout=dropout,
            batch_first=True
        )

    @staticmethod
    def _make_head(width: int, num_classes: int, dropout: float) -> nn.Sequential:
        """Build the two-layer MLP mapping a ``2 * width`` feature vector to logits."""
        return nn.Sequential(
            nn.Linear(width * 2, width),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(width, num_classes)
        )

    def _init_weights(self):
        """Xavier-uniform weights and zero biases for every Linear in both heads."""
        linear_layers = (
            layer
            for head in (self.sentiment_classifier, self.emotion_classifier)
            for layer in head
            if isinstance(layer, nn.Linear)
        )
        for linear in linear_layers:
            nn.init.xavier_uniform_(linear.weight)
            nn.init.zeros_(linear.bias)

    def _attend_and_pool(self, attention, norm, dropout, hidden, padding_mask):
        """Self-attention + residual + LayerNorm + dropout, then take position 0."""
        attended, _ = attention(hidden, hidden, hidden, key_padding_mask=padding_mask)
        return dropout(norm(attended + hidden))[:, 0, :]

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        task: Optional[str] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Compute task logits.

        Args:
            input_ids: Token ids fed to the shared encoder.
            attention_mask: 1 for real tokens, 0 for padding.
            task: ``"sentiment"``, ``"emotion"`` or ``None`` for both heads.

        Returns:
            Dict with ``sentiment_logits`` and/or ``emotion_logits``.
        """
        hidden = self.shared_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        ).last_hidden_state

        # MultiheadAttention's key_padding_mask marks positions to IGNORE with True
        padding_mask = ~attention_mask.bool()

        shared_vec = self._attend_and_pool(
            self.shared_attention, self.shared_norm, self.shared_dropout, hidden, padding_mask
        )

        logits_by_task = {}

        if task is None or task == "sentiment":
            sentiment_vec = self._attend_and_pool(
                self.sentiment_attention, self.sentiment_norm, self.sentiment_dropout,
                hidden, padding_mask
            )
            logits_by_task["sentiment_logits"] = self.sentiment_classifier(
                torch.cat([shared_vec, sentiment_vec], dim=-1)
            )

        if task is None or task == "emotion":
            emotion_vec = self._attend_and_pool(
                self.emotion_attention, self.emotion_norm, self.emotion_dropout,
                hidden, padding_mask
            )
            logits_by_task["emotion_logits"] = self.emotion_classifier(
                torch.cat([shared_vec, emotion_vec], dim=-1)
            )

        return logits_by_task

    @classmethod
    def from_pretrained(cls, model_path: str, **kwargs):
        """
        Rebuild the model from ``<model_path>/config.json`` and load the
        weights in ``<model_path>/pytorch_model.bin`` (mapped to CPU).
        Extra ``kwargs`` are forwarded to the constructor.
        """
        with open(os.path.join(model_path, "config.json"), 'r') as handle:
            saved_config = json.load(handle)

        instance = cls(
            model_name=saved_config["model_name"],
            sentiment_num_classes=saved_config["sentiment_num_classes"],
            emotion_num_classes=saved_config["emotion_num_classes"],
            **kwargs
        )

        weights = torch.load(os.path.join(model_path, "pytorch_model.bin"), map_location='cpu')
        instance.load_state_dict(weights)
        return instance

print("✅ MultiTaskTransformer architecture defined!")

In [None]:
# Cell 3: Text Preprocessing Functions
def clean_text(text: str) -> str:
    """
    Normalize whitespace and decode the HTML entities found in exported posts.

    Args:
        text: Raw post text; ``None``/NaN are accepted and yield ``""``.

    Returns:
        The text with every whitespace run (including newlines) collapsed to a
        single space, ``&lt;`` / ``&gt;`` / ``&amp;`` decoded, and both ends
        stripped.
    """
    if pd.isna(text):
        return ""

    text = str(text).strip()

    # \s already matches newlines, so a single pass collapses both
    text = re.sub(r'\s+', ' ', text)

    # Decode &amp; LAST: decoding it first would turn a literal "&amp;lt;"
    # into "&lt;" and then into "<" (double-unescaping).
    text = text.replace('&lt;', '<')
    text = text.replace('&gt;', '>')
    text = text.replace('&amp;', '&')

    return text.strip()

def preprocess_reddit_text(text: str) -> str:
    """
    Apply ``clean_text`` and then Reddit-specific cleanup.

    - the whole word ``username`` (a placeholder) becomes ``@user``;
    - ``[...]`` spans (formatting / link text) are removed;
    - ``(...)`` spans (parenthetical notes) are removed;
    - whitespace left behind by those removals is collapsed.
    """
    text = clean_text(text)

    # Whole-word match so words that merely contain "username"
    # (e.g. "usernames") are left untouched.
    text = re.sub(r'\busername\b', '@user', text)
    text = re.sub(r'\[.*?\]', '', text)  # Remove Reddit formatting
    text = re.sub(r'\(.*?\)', '', text)  # Remove parenthetical notes

    # Removals above leave gaps like "I am  happy"; collapse them again
    text = re.sub(r'\s+', ' ', text)

    return text.strip()

def load_and_preprocess_data(data_path: str) -> Tuple[pd.DataFrame, LabelEncoder, LabelEncoder]:
    """
    Load the annotated Reddit CSV, clean its text and label-encode its targets.

    Args:
        data_path: Path to a CSV with ``text_content``, ``sentiment`` and
            ``emotion`` columns.

    Returns:
        ``(df, sentiment_encoder, emotion_encoder)``. ``df`` gains
        ``cleaned_text``, ``sentiment_encoded`` and ``emotion_encoded``
        columns; both encoders are freshly fitted on this file's labels.

    Raises:
        KeyError: if any required column is missing from the CSV.
    """
    print(f"📥 Loading data from {data_path}...")

    df = pd.read_csv(data_path)

    required_columns = ['text_content', 'sentiment', 'emotion']
    missing_columns = [col for col in required_columns if col not in df.columns]
    if missing_columns:
        raise KeyError(f"{data_path} is missing required column(s): {missing_columns}")

    # Preprocess text
    df['cleaned_text'] = df['text_content'].apply(preprocess_reddit_text)

    # Fit encoders on this dataset's labels and encode in one step
    sentiment_encoder = LabelEncoder()
    emotion_encoder = LabelEncoder()
    df['sentiment_encoded'] = sentiment_encoder.fit_transform(df['sentiment'])
    df['emotion_encoded'] = emotion_encoder.fit_transform(df['emotion'])

    print(f"✅ Data loaded: {len(df)} samples")
    print(f"   Sentiment classes: {list(sentiment_encoder.classes_)}")
    print(f"   Emotion classes: {list(emotion_encoder.classes_)}")

    return df, sentiment_encoder, emotion_encoder

# Load the annotated dataset once. The data_* encoders are fitted on the
# CSV's own labels, not loaded from any model directory.
(reddit_df,
 data_sentiment_encoder,
 data_emotion_encoder) = load_and_preprocess_data("annotated_reddit_posts.csv")
print("✅ Text preprocessing functions defined and data loaded!")

In [None]:
# Cell 4: Model Loading Functions
class ModelLoader:
    """Centralized model loading for all model types.

    Every loader moves the model to the global ``device``, puts it in eval
    mode and returns the label encoder(s) stored next to the weights.
    NOTE(security): ``joblib.load`` unpickles arbitrary objects — only point
    these loaders at trusted model directories.
    """

    # Backbone checkpoint used to build the tokenizer (and the fallback model)
    _MULTITASK_BACKBONES = {
        'deberta': "microsoft/deberta-base",
        'bertweet': "vinai/bertweet-base",
    }

    @staticmethod
    def _load_roberta(model_path: str, task: str):
        """
        Load a single-task RoBERTa classifier and its ``<task>_encoder.pkl``.

        Returns:
            ``(model, tokenizer, label_encoder)``.
        """
        print(f"📥 Loading RoBERTa {task} model from {model_path}...")

        tokenizer = RobertaTokenizer.from_pretrained(model_path)
        model = RobertaForSequenceClassification.from_pretrained(model_path)
        model.to(device)
        model.eval()

        label_encoder = joblib.load(os.path.join(model_path, f'{task}_encoder.pkl'))

        print(f"✅ RoBERTa {task} model loaded!")
        print(f"   Classes: {list(label_encoder.classes_)}")

        return model, tokenizer, label_encoder

    @staticmethod
    def load_roberta_sentiment(model_path: str):
        """Load RoBERTa sentiment model: ``(model, tokenizer, label_encoder)``."""
        return ModelLoader._load_roberta(model_path, 'sentiment')

    @staticmethod
    def load_roberta_emotion(model_path: str):
        """Load RoBERTa emotion model: ``(model, tokenizer, label_encoder)``."""
        return ModelLoader._load_roberta(model_path, 'emotion')

    @staticmethod
    def load_multitask_model(model_path: str, model_type: str):
        """
        Load a ``MultiTaskTransformer`` checkpoint (DeBERTa or BERTweet).

        First tries ``MultiTaskTransformer.from_pretrained`` (needs
        ``config.json``). If that fails, the model is rebuilt from the saved
        label encoders' class counts and ``pytorch_model.bin`` is loaded into it.

        Returns:
            ``(model, tokenizer, sentiment_encoder, emotion_encoder)``.

        Raises:
            ValueError: for an unknown ``model_type``.
        """
        print(f"📥 Loading {model_type} multitask model from {model_path}...")

        model_name = ModelLoader._MULTITASK_BACKBONES.get(model_type.lower())
        if model_name is None:
            raise ValueError(f"Unknown model type: {model_type}")

        tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Both encoders are required on every path, so load them once up front
        sentiment_encoder = joblib.load(os.path.join(model_path, 'sentiment_encoder.pkl'))
        emotion_encoder = joblib.load(os.path.join(model_path, 'emotion_encoder.pkl'))

        try:
            model = MultiTaskTransformer.from_pretrained(model_path)
        except Exception as exc:
            # Best-effort fallback, but keep the reason visible (a bare
            # `except:` also swallowed KeyboardInterrupt and hid the cause).
            print(f"⚠️ config.json-based loading failed ({exc}); rebuilding from label encoders")
            model = MultiTaskTransformer(
                model_name=model_name,
                sentiment_num_classes=len(sentiment_encoder.classes_),
                emotion_num_classes=len(emotion_encoder.classes_)
            )
            state_dict = torch.load(os.path.join(model_path, 'pytorch_model.bin'), map_location=device)
            model.load_state_dict(state_dict)

        model.to(device)
        model.eval()

        print(f"✅ {model_type} multitask model loaded!")
        print(f"   Sentiment classes: {list(sentiment_encoder.classes_)}")
        print(f"   Emotion classes: {list(emotion_encoder.classes_)}")

        return model, tokenizer, sentiment_encoder, emotion_encoder

# Load all models
print("🚀 Loading All Models")
print("=" * 50)

# Each loader is isolated so one missing checkpoint does not stop the others.
# On failure EVERY name the loader would have bound is set to None: later
# cells reference the tokenizer/encoder names directly (e.g. the XAI
# model_configs table), so leaving them unbound would raise NameError.
try:
    roberta_sentiment_model, roberta_sentiment_tokenizer, roberta_sentiment_encoder = \
        ModelLoader.load_roberta_sentiment("roberta_sentiment_model_optimized")
except Exception as e:
    print(f"❌ Failed to load RoBERTa sentiment model: {e}")
    roberta_sentiment_model = roberta_sentiment_tokenizer = roberta_sentiment_encoder = None

try:
    roberta_emotion_model, roberta_emotion_tokenizer, roberta_emotion_encoder = \
        ModelLoader.load_roberta_emotion("roberta_emotion_model_optimized")
except Exception as e:
    print(f"❌ Failed to load RoBERTa emotion model: {e}")
    roberta_emotion_model = roberta_emotion_tokenizer = roberta_emotion_encoder = None

try:
    deberta_model, deberta_tokenizer, deberta_sentiment_encoder, deberta_emotion_encoder = \
        ModelLoader.load_multitask_model("deberta_optimized", "deberta")
except Exception as e:
    print(f"❌ Failed to load DeBERTa model: {e}")
    deberta_model = deberta_tokenizer = deberta_sentiment_encoder = deberta_emotion_encoder = None

try:
    bertweet_model, bertweet_tokenizer, bertweet_sentiment_encoder, bertweet_emotion_encoder = \
        ModelLoader.load_multitask_model("bertweet_model_ultra_light", "bertweet")
except Exception as e:
    print(f"❌ Failed to load BERTweet model: {e}")
    bertweet_model = bertweet_tokenizer = bertweet_sentiment_encoder = bertweet_emotion_encoder = None

print("\n✅ Model loading complete!")

In [None]:
# Cell 5: Prediction Functions
class PredictionEngine:
    """Generate predictions from all models.

    Each ``predict_*`` method returns ``None`` (``(None, None)`` for the
    multitask variant) when the model is unavailable. Otherwise it returns a
    dict with ``predicted_class``, ``predicted_id``, ``probabilities`` and
    ``logits`` for the single input text.
    """

    @staticmethod
    def _summarize_logits(logits: torch.Tensor, label_encoder) -> Dict:
        """Convert a ``(1, num_classes)`` logits tensor into the shared result dict."""
        probabilities = F.softmax(logits, dim=-1)
        prediction = torch.argmax(logits, dim=-1).item()
        return {
            'predicted_class': label_encoder.classes_[prediction],
            'predicted_id': prediction,
            'probabilities': probabilities.cpu().numpy()[0],
            'logits': logits.cpu().numpy()[0]
        }

    @staticmethod
    def _predict_single_task(text: str, model, tokenizer, label_encoder, max_length: int = 512):
        """Shared implementation for single-task ``*ForSequenceClassification`` models."""
        if model is None:
            return None

        inputs = tokenizer(text, return_tensors="pt", truncation=True,
                           padding=True, max_length=max_length).to(device)

        with torch.no_grad():
            logits = model(**inputs).logits
            return PredictionEngine._summarize_logits(logits, label_encoder)

    @staticmethod
    def predict_roberta_sentiment(text: str, model, tokenizer, label_encoder):
        """Get sentiment prediction from RoBERTa"""
        return PredictionEngine._predict_single_task(text, model, tokenizer, label_encoder)

    @staticmethod
    def predict_roberta_emotion(text: str, model, tokenizer, label_encoder):
        """Get emotion prediction from RoBERTa"""
        return PredictionEngine._predict_single_task(text, model, tokenizer, label_encoder)

    @staticmethod
    def predict_multitask(text: str, model, tokenizer, sentiment_encoder, emotion_encoder, max_length=128):
        """Get ``(sentiment_result, emotion_result)`` from one multitask forward pass."""
        if model is None:
            return None, None

        inputs = tokenizer(text, return_tensors="pt", truncation=True,
                           padding=True, max_length=max_length).to(device)

        with torch.no_grad():
            # MultiTaskTransformer.forward takes only input_ids/attention_mask,
            # so extras such as token_type_ids are dropped here.
            outputs = model(input_ids=inputs['input_ids'],
                            attention_mask=inputs['attention_mask'])
            sentiment_result = PredictionEngine._summarize_logits(
                outputs['sentiment_logits'], sentiment_encoder
            )
            emotion_result = PredictionEngine._summarize_logits(
                outputs['emotion_logits'], emotion_encoder
            )

        return sentiment_result, emotion_result

def generate_all_predictions(text: str) -> Dict:
    """
    Run every model that loaded successfully on ``text``.

    Returns:
        Dict keyed by ``"<model>_<task>"`` (``roberta_sentiment``,
        ``roberta_emotion``, ``deberta_sentiment``, ``deberta_emotion``,
        ``bertweet_sentiment``, ``bertweet_emotion``). Keys for models that are
        ``None`` are omitted. Values are ``PredictionEngine`` result dicts.
    """
    print(f"🔮 Generating predictions for text: '{text[:100]}...'")

    collected = {}

    def run_multitask(sentiment_key, emotion_key, model, tokenizer, sent_encoder, emo_encoder):
        # One multitask call yields both task predictions
        sentiment_out, emotion_out = PredictionEngine.predict_multitask(
            text, model, tokenizer, sent_encoder, emo_encoder
        )
        collected[sentiment_key] = sentiment_out
        collected[emotion_key] = emotion_out

    if roberta_sentiment_model is not None:
        collected['roberta_sentiment'] = PredictionEngine.predict_roberta_sentiment(
            text, roberta_sentiment_model, roberta_sentiment_tokenizer, roberta_sentiment_encoder
        )

    if roberta_emotion_model is not None:
        collected['roberta_emotion'] = PredictionEngine.predict_roberta_emotion(
            text, roberta_emotion_model, roberta_emotion_tokenizer, roberta_emotion_encoder
        )

    if deberta_model is not None:
        run_multitask('deberta_sentiment', 'deberta_emotion',
                      deberta_model, deberta_tokenizer,
                      deberta_sentiment_encoder, deberta_emotion_encoder)

    if bertweet_model is not None:
        run_multitask('bertweet_sentiment', 'bertweet_emotion',
                      bertweet_model, bertweet_tokenizer,
                      bertweet_sentiment_encoder, bertweet_emotion_encoder)

    return collected

print("✅ Prediction functions defined!")

In [None]:
# Cell 6: SHAP Implementation
class SHAPExplainer:
    """SHAP explanations for text classification"""

    def __init__(self):
        self.explainers = {}

    @staticmethod
    def _resolve_max_length(model, tokenizer, max_length=None) -> int:
        """
        Pick the truncation length used when scoring (perturbed) texts.

        ``None`` selects 128 for multitask models, matching the default of
        ``PredictionEngine.predict_multitask``, and 512 otherwise. The result
        is capped at the tokenizer's ``model_max_length``, so encoders with a
        short position table (BERTweet's tokenizer reports 128) are never fed
        longer inputs.
        """
        if max_length is None:
            max_length = 128 if hasattr(model, 'shared_encoder') else 512
        tokenizer_limit = getattr(tokenizer, 'model_max_length', None)
        if isinstance(tokenizer_limit, int) and tokenizer_limit > 0:
            max_length = min(max_length, tokenizer_limit)
        return max_length

    def create_prediction_function(self, model, tokenizer, task='sentiment', max_length=None):
        """
        Build a ``texts -> probabilities`` callable for SHAP.

        Args:
            model: Single-task HF classifier or ``MultiTaskTransformer``.
            tokenizer: Matching tokenizer.
            task: ``'sentiment'`` or ``'emotion'`` (multitask models only).
            max_length: Truncation length; see ``_resolve_max_length``.

        Returns:
            Function mapping a string or iterable of strings to an
            ``(n_texts, n_classes)`` array of softmax probabilities.
        """
        max_length = self._resolve_max_length(model, tokenizer, max_length)
        is_multitask = hasattr(model, 'shared_encoder')

        def predict_fn(texts):
            if isinstance(texts, str):
                texts = [texts]

            predictions = []
            for text in texts:
                inputs = tokenizer(text, return_tensors="pt", truncation=True,
                                   padding=True, max_length=max_length).to(device)

                with torch.no_grad():
                    if is_multitask:
                        outputs = model(input_ids=inputs['input_ids'],
                                        attention_mask=inputs['attention_mask'])
                        if task == 'sentiment':
                            logits = outputs['sentiment_logits']
                        else:
                            logits = outputs['emotion_logits']
                    else:
                        logits = model(**inputs).logits

                predictions.append(F.softmax(logits, dim=-1).cpu().numpy()[0])

            return np.array(predictions)

        return predict_fn

    def explain_text(self, text: str, model, tokenizer, task='sentiment', model_name='model'):
        """
        Generate a SHAP explanation for ``text``.

        Returns:
            Dict with ``explainer``, ``shap_values``, ``text``, ``model_name``
            and ``task``; ``None`` if the model is missing or SHAP fails.
        """
        print(f"🔍 Generating SHAP explanation for {model_name} ({task})...")

        if model is None:
            print(f"❌ {model_name} model not available")
            return None

        try:
            predict_fn = self.create_prediction_function(model, tokenizer, task)
            explainer = shap.Explainer(predict_fn, masker=shap.maskers.Text())
            shap_values = explainer([text])

            return {
                'explainer': explainer,
                'shap_values': shap_values,
                'text': text,
                'model_name': model_name,
                'task': task
            }

        except Exception as e:
            print(f"❌ SHAP explanation failed for {model_name}: {e}")
            return None

    def plot_shap_explanation(self, shap_result, save_path=None):
        """
        Render a SHAP text explanation in the notebook.

        ``shap.plots.text`` produces HTML, not a matplotlib figure, so saving
        goes through that HTML: when ``save_path`` is given it is written
        to the same path with an ``.html`` extension.
        """
        if shap_result is None:
            return

        try:
            values = shap_result['shap_values'][0]
            print(f"SHAP Explanation: {shap_result['model_name']} ({shap_result['task']})")

            if save_path:
                html = shap.plots.text(values, display=False)
                html_path = os.path.splitext(save_path)[0] + '.html'
                with open(html_path, 'w', encoding='utf-8') as handle:
                    handle.write(html)
                print(f"💾 Saved SHAP explanation to {html_path}")

            shap.plots.text(values)

        except Exception as e:
            print(f"❌ SHAP plotting failed: {e}")

print("✅ SHAP explainer defined!")

In [None]:
# Cell 7: LIME Implementation
class LIMEExplainer:
    """LIME explanations for text classification"""

    def __init__(self, random_state: Optional[int] = 42):
        """
        Args:
            random_state: Seed for LIME's perturbation sampling so repeated
                runs give the same explanation; ``None`` draws fresh samples.
        """
        # LimeTextExplainer takes no ``mode`` keyword (that belongs to the
        # tabular explainer); text explanations are always classification.
        self.explainer = LimeTextExplainer(
            class_names=['Negative', 'Neutral', 'Positive'],  # Replaced per call in explain_text
            random_state=random_state,
        )

    @staticmethod
    def _resolve_max_length(model, tokenizer, max_length=None) -> int:
        """
        Pick the truncation length used when scoring perturbed texts.

        ``None`` selects 128 for multitask models, matching the default of
        ``PredictionEngine.predict_multitask``, and 512 otherwise. The result
        is capped at the tokenizer's ``model_max_length`` (BERTweet's reports 128).
        """
        if max_length is None:
            max_length = 128 if hasattr(model, 'shared_encoder') else 512
        tokenizer_limit = getattr(tokenizer, 'model_max_length', None)
        if isinstance(tokenizer_limit, int) and tokenizer_limit > 0:
            max_length = min(max_length, tokenizer_limit)
        return max_length

    def create_prediction_function(self, model, tokenizer, label_encoder, task='sentiment', max_length=None):
        """
        Build a ``list[str] -> probabilities`` callable for LIME.

        ``label_encoder`` is accepted for interface compatibility and is not
        used. ``max_length`` follows ``_resolve_max_length``.
        """
        max_length = self._resolve_max_length(model, tokenizer, max_length)
        is_multitask = hasattr(model, 'shared_encoder')

        def predict_fn(texts):
            predictions = []
            for text in texts:
                inputs = tokenizer(text, return_tensors="pt", truncation=True,
                                   padding=True, max_length=max_length).to(device)

                with torch.no_grad():
                    if is_multitask:
                        outputs = model(input_ids=inputs['input_ids'],
                                        attention_mask=inputs['attention_mask'])
                        if task == 'sentiment':
                            logits = outputs['sentiment_logits']
                        else:
                            logits = outputs['emotion_logits']
                    else:
                        logits = model(**inputs).logits

                predictions.append(F.softmax(logits, dim=-1).cpu().numpy()[0])

            return np.array(predictions)

        return predict_fn

    def explain_text(self, text: str, model, tokenizer, label_encoder, task='sentiment',
                     model_name='model', num_features=20, num_samples=1000):
        """
        Generate a LIME explanation for the class the model predicts on ``text``.

        Returns:
            Dict with ``explanation``, ``text``, ``model_name``, ``task``,
            ``class_names``, ``predicted_label`` (index) and ``predicted_class``;
            ``None`` if the model is missing or LIME fails.
        """
        print(f"🔍 Generating LIME explanation for {model_name} ({task})...")

        if model is None:
            print(f"❌ {model_name} model not available")
            return None

        try:
            class_names = list(label_encoder.classes_)
            self.explainer.class_names = class_names

            predict_fn = self.create_prediction_function(model, tokenizer, label_encoder, task)

            # top_labels=1 explains the predicted class; LIME's default
            # (labels=(1,)) would always explain class index 1 instead.
            explanation = self.explainer.explain_instance(
                text, predict_fn, num_features=num_features,
                num_samples=num_samples, top_labels=1
            )
            predicted_label = int(explanation.available_labels()[0])

            return {
                'explanation': explanation,
                'text': text,
                'model_name': model_name,
                'task': task,
                'class_names': class_names,
                'predicted_label': predicted_label,
                'predicted_class': class_names[predicted_label]
            }

        except Exception as e:
            print(f"❌ LIME explanation failed for {model_name}: {e}")
            return None

    def plot_lime_explanation(self, lime_result, save_path=None):
        """Plot the LIME explanation for its top (predicted) label and show the HTML view."""
        if lime_result is None:
            return

        try:
            explanation = lime_result['explanation']
            label = int(explanation.available_labels()[0])
            class_names = lime_result.get('class_names') or []
            class_name = class_names[label] if label < len(class_names) else str(label)

            fig = explanation.as_pyplot_figure(label=label)
            fig.suptitle(
                f"LIME Explanation: {lime_result['model_name']} ({lime_result['task']})"
                f" - class '{class_name}'"
            )

            if save_path:
                fig.savefig(save_path, dpi=300, bbox_inches='tight')

            plt.show()

            # Also show as HTML (for notebook display)
            print("HTML Visualization:")
            explanation.show_in_notebook(text=True, labels=[label])

        except Exception as e:
            print(f"❌ LIME plotting failed: {e}")

print("✅ LIME explainer defined!")

In [None]:
# Cell 8: Integrated Gradients Implementation
class IntegratedGradientsExplainer:
    """Integrated Gradients explanations using Captum.

    Token ids are discrete and cannot be interpolated between a baseline and
    the input. Attributions are therefore computed with
    ``LayerIntegratedGradients`` on the output of the model's input-embedding
    layer, then summed over the hidden dimension to give one score per token.
    """

    def __init__(self):
        # Fallback baseline id, used only when the tokenizer has no pad token
        self.baseline_token_id = 0

    @staticmethod
    def _resolve_max_length(model, tokenizer, max_length=None) -> int:
        """
        Pick the truncation length for the explained input.

        ``None`` selects 128 for multitask models, matching the default of
        ``PredictionEngine.predict_multitask``, and 512 otherwise. The result
        is capped at the tokenizer's ``model_max_length`` (BERTweet's reports 128).
        """
        if max_length is None:
            max_length = 128 if hasattr(model, 'shared_encoder') else 512
        tokenizer_limit = getattr(tokenizer, 'model_max_length', None)
        if isinstance(tokenizer_limit, int) and tokenizer_limit > 0:
            max_length = min(max_length, tokenizer_limit)
        return max_length

    @staticmethod
    def _get_embedding_layer(model):
        """Return the token-embedding module of a HF model or a MultiTaskTransformer."""
        encoder = model.shared_encoder if hasattr(model, 'shared_encoder') else model
        return encoder.get_input_embeddings()

    def _build_baseline(self, input_ids: torch.Tensor, tokenizer) -> torch.Tensor:
        """
        Reference ids for IG: the pad token everywhere except special tokens.

        Special tokens (``<s>``, ``</s>``, ``[CLS]``, ...) are kept identical to
        the input, so their embedding difference (and attribution) is zero.
        The tokenizer's real pad id is used because id 0 is not PAD for every
        vocabulary (it is ``<s>`` for RoBERTa-style tokenizers).
        """
        pad_id = tokenizer.pad_token_id
        if pad_id is None:
            pad_id = self.baseline_token_id
        special_ids = torch.tensor(list(tokenizer.all_special_ids),
                                   dtype=input_ids.dtype, device=input_ids.device)
        keep = torch.isin(input_ids, special_ids)
        return torch.where(keep, input_ids, torch.full_like(input_ids, pad_id))

    def create_forward_function(self, model, task='sentiment'):
        """Return ``f(input_ids, attention_mask) -> logits`` for the requested task."""

        def forward_fn(input_ids, attention_mask):
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            if hasattr(model, 'shared_encoder'):  # Multitask model
                if task == 'sentiment':
                    return outputs['sentiment_logits']
                return outputs['emotion_logits']
            return outputs.logits  # Single task model

        return forward_fn

    def explain_text(self, text: str, model, tokenizer, task='sentiment', model_name='model',
                     max_length=None, n_steps: int = 50):
        """
        Generate an Integrated Gradients explanation for the predicted class.

        Returns:
            Dict with ``attributions`` (detached tensor of shape ``(1, seq_len)``,
            one score per token), ``tokens``, ``input_ids``, ``predicted_class``
            (index), ``convergence_delta``, ``text``, ``model_name`` and
            ``task``; ``None`` if the model is missing or attribution fails.
        """
        print(f"🔍 Generating Integrated Gradients explanation for {model_name} ({task})...")

        if model is None:
            print(f"❌ {model_name} model not available")
            return None

        try:
            max_length = self._resolve_max_length(model, tokenizer, max_length)
            inputs = tokenizer(text, return_tensors="pt", truncation=True,
                               padding=True, max_length=max_length)
            input_ids = inputs['input_ids'].to(device)
            attention_mask = inputs['attention_mask'].to(device)

            forward_fn = self.create_forward_function(model, task)

            with torch.no_grad():
                pred_class = torch.argmax(forward_fn(input_ids, attention_mask), dim=-1).item()

            baseline = self._build_baseline(input_ids, tokenizer)

            lig = LayerIntegratedGradients(forward_fn, self._get_embedding_layer(model))
            # attention_mask is passed as a fixed extra argument, not interpolated
            attributions, delta = lig.attribute(
                inputs=input_ids,
                baselines=baseline,
                additional_forward_args=(attention_mask,),
                target=pred_class,
                n_steps=n_steps,
                return_convergence_delta=True
            )

            # (1, seq_len, hidden) -> (1, seq_len): one score per token
            token_attributions = attributions.sum(dim=-1).detach()

            tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())

            return {
                'attributions': token_attributions,
                'tokens': tokens,
                'input_ids': input_ids,
                'predicted_class': pred_class,
                'convergence_delta': float(delta.detach().sum().item()),
                'text': text,
                'model_name': model_name,
                'task': task
            }

        except Exception as e:
            print(f"❌ Integrated Gradients explanation failed for {model_name}: {e}")
            return None

    def plot_ig_explanation(self, ig_result, save_path=None, tokens_per_row: int = 12):
        """
        Draw tokens as colored boxes (red = positive, blue = negative attribution),
        wrapped into rows of ``tokens_per_row`` so long inputs stay readable.
        """
        if ig_result is None:
            return

        try:
            attributions = ig_result['attributions'][0].detach().cpu().numpy()
            tokens = ig_result['tokens']
            tokens_per_row = max(1, int(tokens_per_row))

            special_tokens = {'<s>', '</s>', '<pad>', '[CLS]', '[SEP]', '[PAD]'}
            display_items = []
            for token, score in zip(tokens, attributions):
                # Strip BPE / SentencePiece / WordPiece markers for display
                clean_token = token.replace('Ġ', '').replace('▁', '')
                if clean_token.startswith('##'):
                    clean_token = clean_token[2:]
                if clean_token in special_tokens:
                    continue
                display_items.append((clean_token, float(score)))

            if not display_items:
                print("⚠️ No tokens to display for IG explanation")
                return

            max_abs_attr = max(abs(score) for _, score in display_items)
            if max_abs_attr == 0:
                max_abs_attr = 1.0  # all-zero attributions: avoid a degenerate color scale
            cmap = plt.cm.RdYlBu_r
            norm = plt.Normalize(vmin=-max_abs_attr, vmax=max_abs_attr)

            n_rows = (len(display_items) + tokens_per_row - 1) // tokens_per_row
            fig, ax = plt.subplots(figsize=(15, max(2.5, 0.6 * n_rows + 1.5)))

            x_step = 0.9 / tokens_per_row
            y_step = 1.0 / (n_rows + 1)
            for idx, (token, score) in enumerate(display_items):
                row, col = divmod(idx, tokens_per_row)
                bbox_props = dict(boxstyle="round,pad=0.3", facecolor=cmap(norm(score)), alpha=0.7)
                ax.text(0.05 + (col + 0.5) * x_step, 1.0 - (row + 1) * y_step, token,
                        fontsize=10, ha='center', va='center',
                        bbox=bbox_props, transform=ax.transAxes)

            sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
            sm.set_array([])
            cbar = fig.colorbar(sm, ax=ax, orientation='horizontal', pad=0.1, shrink=0.8)
            cbar.set_label('Attribution Score', fontsize=12)

            ax.set_xlim(0, 1)
            ax.set_ylim(0, 1)
            ax.axis('off')
            ax.set_title(f"Integrated Gradients: {ig_result['model_name']} ({ig_result['task']})",
                         fontsize=14, fontweight='bold', pad=20)

            # Lay out before saving so the saved file matches what is shown
            fig.tight_layout()
            if save_path:
                fig.savefig(save_path, dpi=300, bbox_inches='tight')

            plt.show()

        except Exception as e:
            print(f"❌ IG plotting failed: {e}")

print("✅ Integrated Gradients explainer defined!")

In [None]:
# Cell 9: XAI Analysis Engine
class XAIAnalysisEngine:
    """Run SHAP, LIME and Integrated Gradients explanations for all loaded models.

    Holds one instance of each explainer (``SHAPExplainer``, ``LIMEExplainer``,
    ``IntegratedGradientsExplainer``, defined in earlier cells) and drives them
    over the model / tokenizer / label-encoder globals set up by the
    model-loading cells.
    """

    # Directory the PNG visualizations are written to (created in the setup cell).
    OUTPUT_DIR = "xai_visualizations"

    def __init__(self):
        self.shap_explainer = SHAPExplainer()
        self.lime_explainer = LIMEExplainer()
        self.ig_explainer = IntegratedGradientsExplainer()

    @staticmethod
    def _post_suffix(post_id) -> str:
        """File-name suffix for ``post_id``: ``"_post_<id>"``, or ``""`` when no id is given.

        Only ``None`` and the empty string mean "no id"; falsy ids such as ``0``
        still get a suffix so their files are not written under the generic name.
        """
        if post_id is None or post_id == "":
            return ""
        return f"_post_{post_id}"

    def analyze_single_text(self, text: str, save_visualizations=True, post_id=None):
        """Run every explainer on ``text`` for each available model.

        Args:
            text: Text to explain.
            save_visualizations: When True, each non-empty explainer result is
                plotted and saved as a PNG under ``OUTPUT_DIR``. When False, no
                plots are produced.
            post_id: Optional identifier appended to saved file names as
                ``_post_<id>`` (see ``_post_suffix``).

        Returns:
            dict with ``'predictions'`` (the output of ``generate_all_predictions``)
            and ``'explanations'``, which maps model name to
            ``{'shap': ..., 'lime': ..., 'ig': ...}``. Each entry is whatever the
            corresponding explainer returned.
        """
        print("\n🚀 Starting XAI Analysis")
        print("=" * 60)
        # Only add an ellipsis when the preview actually truncates the text.
        preview = text if len(text) <= 100 else f"{text[:100]}..."
        print(f"Text: '{preview}'")
        print("=" * 60)

        # Predictions from every loaded model (helper defined in an earlier cell).
        predictions = generate_all_predictions(text)

        print("\n📊 PREDICTIONS:")
        print("-" * 30)
        for model_task, pred in predictions.items():
            if pred is not None:
                print(f"{model_task}: {pred['predicted_class']} (conf: {pred['probabilities'].max():.3f})")

        results = {'predictions': predictions, 'explanations': {}}

        # (name, model, tokenizer, label encoder, task). The DeBERTa and BERTweet
        # multitask models are listed once per task with the same model/tokenizer.
        model_configs = [
            ('roberta_sentiment', roberta_sentiment_model, roberta_sentiment_tokenizer, roberta_sentiment_encoder, 'sentiment'),
            ('roberta_emotion', roberta_emotion_model, roberta_emotion_tokenizer, roberta_emotion_encoder, 'emotion'),
            ('deberta_sentiment', deberta_model, deberta_tokenizer, deberta_sentiment_encoder, 'sentiment'),
            ('deberta_emotion', deberta_model, deberta_tokenizer, deberta_emotion_encoder, 'emotion'),
            ('bertweet_sentiment', bertweet_model, bertweet_tokenizer, bertweet_sentiment_encoder, 'sentiment'),
            ('bertweet_emotion', bertweet_model, bertweet_tokenizer, bertweet_emotion_encoder, 'emotion'),
        ]

        # Same suffix for every file produced for this text, so compute it once.
        post_suffix = self._post_suffix(post_id)

        for model_name, model, tokenizer, encoder, task in model_configs:
            # Models or label encoders that are not loaded (None) are skipped.
            if model is None or encoder is None:
                continue

            print(f"\n🔍 Analyzing {model_name}...")

            # SHAP Analysis
            shap_result = self.shap_explainer.explain_text(text, model, tokenizer, task, model_name)
            if shap_result and save_visualizations:
                save_path = f"{self.OUTPUT_DIR}/shap_{model_name}{post_suffix}.png"
                self.shap_explainer.plot_shap_explanation(shap_result, save_path)

            # LIME Analysis
            lime_result = self.lime_explainer.explain_text(text, model, tokenizer, encoder, task, model_name)
            if lime_result and save_visualizations:
                save_path = f"{self.OUTPUT_DIR}/lime_{model_name}{post_suffix}.png"
                self.lime_explainer.plot_lime_explanation(lime_result, save_path)

            # Integrated Gradients Analysis: BERTweet/DeBERTa use a 128-token
            # window, the RoBERTa models 512.
            max_length = 128 if 'bertweet' in model_name or 'deberta' in model_name else 512
            ig_result = self.ig_explainer.explain_text(text, model, tokenizer, task, model_name, max_length)
            if ig_result and save_visualizations:
                save_path = f"{self.OUTPUT_DIR}/ig_{model_name}{post_suffix}.png"
                self.ig_explainer.plot_ig_explanation(ig_result, save_path)

            results['explanations'][model_name] = {
                'shap': shap_result,
                'lime': lime_result,
                'ig': ig_result
            }

        return results

    def batch_analyze(self, df: pd.DataFrame, num_samples=5, random_sample=True, random_state=42):
        """Run ``analyze_single_text`` on several rows of ``df``.

        Args:
            df: Frame with ``'id'``, ``'cleaned_text'``, ``'sentiment'`` and
                ``'emotion'`` columns.
            num_samples: Maximum number of rows to analyze.
            random_sample: Sample rows at random (seeded) instead of taking the first rows.
            random_state: Seed passed to ``DataFrame.sample`` when ``random_sample`` is True.

        Returns:
            dict mapping each row's ``'id'`` to its ``analyze_single_text`` result.
            Rows sharing an id overwrite one another.
        """
        print("\n🔄 Starting Batch XAI Analysis")
        print("=" * 60)

        if random_sample:
            sample_df = df.sample(n=min(num_samples, len(df)), random_state=random_state)
        else:
            sample_df = df.head(num_samples)

        # Progress is reported by position within the sample. DataFrame index
        # labels are arbitrary after sampling, and the sample may have fewer
        # rows than requested.
        total = len(sample_df)
        batch_results = {}

        for position, (_, row) in enumerate(sample_df.iterrows(), start=1):
            post_id = row['id']
            text = row['cleaned_text']

            print(f"\n📝 Analyzing Post {position}/{total} (ID: {post_id})")
            print(f"True labels - Sentiment: {row['sentiment']}, Emotion: {row['emotion']}")

            result = self.analyze_single_text(text, save_visualizations=True, post_id=post_id)
            batch_results[post_id] = result

        return batch_results

print("✅ XAI Analysis Engine defined!")

In [None]:
# Cell 10: Run XAI Analysis on Sample Reddit Post
# Positional (iloc) index of the Reddit post to explain — change to inspect another post.
SAMPLE_POST_INDEX = 0

# Fail with a descriptive message rather than a bare IndexError if the index is out of range.
assert -len(reddit_df) <= SAMPLE_POST_INDEX < len(reddit_df), (
    f"SAMPLE_POST_INDEX={SAMPLE_POST_INDEX} is out of range for reddit_df with {len(reddit_df)} rows"
)
sample_post = reddit_df.iloc[SAMPLE_POST_INDEX]
sample_text = sample_post['cleaned_text']

print("🎯 Selected Reddit Post for XAI Analysis:")
print("=" * 60)
print(f"ID: {sample_post['id']}")
print(f"True Sentiment: {sample_post['sentiment']}")
print(f"True Emotion: {sample_post['emotion']}")
print(f"Text: {sample_text}")
print("=" * 60)

# Initialize XAI engine (creates one SHAP, LIME and Integrated Gradients explainer)
xai_engine = XAIAnalysisEngine()

# Run SHAP, LIME and IG for every loaded model; plots are saved as PNGs.
xai_results = xai_engine.analyze_single_text(
    text=sample_text,
    save_visualizations=True,
    post_id=sample_post['id']
)

print("\n🎉 XAI Analysis Complete!")
print("Visualizations saved to 'xai_visualizations/' directory")

In [None]:
# Cell 11: Batch Processing (Optional)
# Set RUN_BATCH_ANALYSIS = True and re-run this cell to analyze multiple posts.
# Off by default: each post runs SHAP, LIME and IG for every loaded model.
RUN_BATCH_ANALYSIS = False
BATCH_NUM_SAMPLES = 3  # number of random posts to analyze when enabled

if RUN_BATCH_ANALYSIS:
    print("🔄 Running Batch XAI Analysis on Multiple Posts")
    print("=" * 60)

    batch_results = xai_engine.batch_analyze(
        df=reddit_df,
        num_samples=BATCH_NUM_SAMPLES,
        random_sample=True
    )

    print("\n🎉 Batch Analysis Complete!")
    print(f"Analyzed {len(batch_results)} posts")
    print("All visualizations saved to 'xai_visualizations/' directory")

    # Display summary: predicted class per model/task for each analyzed post
    print("\n📊 Batch Analysis Summary:")
    for post_id, post_results in batch_results.items():
        print(f"\nPost {post_id}:")
        for model_name, pred in post_results['predictions'].items():
            if pred is not None:
                print(f"  {model_name}: {pred['predicted_class']}")

In [None]:
# Cell 12: Error Handling and Model Status Check
def check_model_status(namespace: Optional[Dict[str, object]] = None) -> Dict[str, bool]:
    """Report which models are loaded and available.

    A model counts as available when its variable exists in ``namespace`` and is
    not None. Names are looked up with ``dict.get`` instead of being referenced
    directly. As a result, a model variable that was never assigned (for example,
    its loading cell was skipped or raised first) is reported as "Not loaded"
    rather than crashing this check with a NameError.

    Args:
        namespace: Mapping to look model variables up in. Defaults to this
            module's globals, i.e. the notebook namespace.

    Returns:
        dict mapping each model's display name to whether it is available.
    """
    if namespace is None:
        namespace = globals()

    def _is_loaded(var_name: str) -> bool:
        return namespace.get(var_name) is not None

    print("🔍 Model Status Check")
    print("=" * 40)

    models_status = {
        "RoBERTa Sentiment": _is_loaded("roberta_sentiment_model"),
        "RoBERTa Emotion": _is_loaded("roberta_emotion_model"),
        "DeBERTa Multitask": _is_loaded("deberta_model"),
        "BERTweet Multitask": _is_loaded("bertweet_model"),
    }

    for model_name, status in models_status.items():
        status_icon = "✅" if status else "❌"
        print(f"{status_icon} {model_name}: {'Available' if status else 'Not loaded'}")

    available_count = sum(models_status.values())
    total_count = len(models_status)

    print(f"\n📊 Summary: {available_count}/{total_count} models loaded successfully")

    if available_count == 0:
        print("\n⚠️  No models loaded! Please check your model paths and try reloading.")
    elif available_count < total_count:
        print("\n⚠️  Some models failed to load. XAI analysis will only work for loaded models.")
    else:
        print("\n🎉 All models loaded successfully! Full XAI analysis available.")

    return models_status

# Run status check
model_status = check_model_status()

print("\n✅ XAI System Ready!")
print("🚀 You can now run XAI analysis on any Reddit post text!")