In [None]:
!pip install transformers==4.57.1

In [None]:
import shap
import torch
import torch.nn as nn
from transformers import BertModel, PreTrainedModel, PretrainedConfig, BertTokenizer
from transformers.modeling_outputs import SequenceClassifierOutput

# Prefer the GPU, but fall back to CPU so the notebook still runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

In [None]:
class BERTConfig(PretrainedConfig):
    """Configuration for `BERTModel`, registered under model_type "bert_with_absa".

    Args:
        absa_method: when truthy, `BERTModel` builds its extra ABSA branch; None disables it.
        num_classes: number of output logits produced by the classifier heads.
        class_weight: optional per-class weights handed to the cross-entropy loss.
        **kwargs: forwarded unchanged to `PretrainedConfig.__init__`.
    """

    model_type = "bert_with_absa"

    def __init__(self, absa_method=None, num_classes=2, class_weight=None, **kwargs):
        super().__init__(**kwargs)
        # Plain attributes, read back by BERTModel.__init__.
        self.absa_method, self.num_classes, self.class_weight = absa_method, num_classes, class_weight

class InnerBert(nn.Module):
    """Pretrained BERT encoder followed by dropout and a linear head on the pooled output.

    Args:
        num_classes: number of logits produced by the head.
        pretrained_name: Hugging Face checkpoint loaded for the encoder
            (default keeps the original 'bert-base-uncased').
        dropout: dropout probability applied to the pooled output (default 0.3, as before).
    """

    def __init__(self, num_classes, pretrained_name='bert-base-uncased', dropout=0.3):
        super().__init__()
        # NOTE: this downloads/loads the encoder every time the module is constructed.
        self.bert = BertModel.from_pretrained(pretrained_name)
        self.dropout = nn.Dropout(dropout)
        # Size the head from the encoder instead of hard-coding 768 (the bert-base hidden size).
        self.fc = nn.Linear(self.bert.config.hidden_size, num_classes)

    def forward(self, inputs_embeds=None, attention_mask=None, token_type_ids=None):
        """Encode pre-computed embeddings and return class logits from the pooled output.

        Args:
            inputs_embeds: embedding tensor passed to BertModel as `inputs_embeds`.
            attention_mask: optional attention mask forwarded to BertModel.
            token_type_ids: optional segment ids forwarded to BertModel.

        Returns:
            Logits from `self.fc` applied to the dropout-regularized pooled output.
        """
        # With return_dict=False, BertModel returns (sequence_output, pooled_output).
        _, pooled_output = self.bert(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=False,
        )
        return self.fc(self.dropout(pooled_output))

class BERTModel(PreTrainedModel):
    """Sentence-pair classifier with an optional ABSA branch.

    The text branch (`bert_sent`) scores the tokenized (S1, S2) pair. When
    `config.absa_method` is truthy, a second branch (`bert_absa`) projects the two
    scalar ABSA scores to embedding size, encodes them as a 2-token sequence, and
    its logits are added to the text logits.
    """

    config_class = BERTConfig

    def __init__(self, config):
        super().__init__(config)
        self.num_classes = config.num_classes
        self.absa_method = config.absa_method

        self.bert_sent = InnerBert(self.num_classes)

        if self.absa_method:
            # Projects each scalar ABSA score to a 768-dim pseudo-token embedding.
            self.absa_fc = nn.Linear(1, 768)
            self.bert_absa = InnerBert(self.num_classes)

        # CrossEntropyLoss stores `weight` as a buffer, so it follows model.to(...);
        # it no longer needs to be pinned to the notebook-global `device` here.
        class_weight = None
        if config.class_weight is not None:
            class_weight = torch.tensor(config.class_weight, dtype=torch.float32)
        self.criterion = nn.CrossEntropyLoss(weight=class_weight)

        # NOTE(review): newer transformers versions favour post_init(); kept as-is.
        self.init_weights()

    def forward(self, input_ids=None, absa_1=None, absa_2=None, attention_mask=None, token_type_ids=None, labels=None):
        """Compute class logits (and the loss when `labels` is given).

        Args:
            input_ids: token ids of "[CLS] S1 [SEP] S2 [SEP]" (padded).
            absa_1, absa_2: ABSA score tensors; callers in this notebook pass shape
                (batch or 1, 1, 1). Required only when the ABSA branch is enabled.
            attention_mask: attention mask for the text branch.
            token_type_ids: segment ids for the text branch.
            labels: optional class indices for the cross-entropy loss.

        Returns:
            SequenceClassifierOutput with `logits` and `loss` (None without labels).
        """
        # NOTE(review): embeddings() already yields full (word + position + token-type)
        # embeddings, and BertModel(inputs_embeds=...) appears to add position/token-type
        # embeddings again. Kept unchanged because the published checkpoint was trained this way.
        inputs_embeds = self.bert_sent.bert.embeddings(input_ids=input_ids, token_type_ids=token_type_ids)
        logits = self.bert_sent(inputs_embeds, attention_mask, token_type_ids)

        if self.absa_method:
            absa_concat = torch.cat((self.absa_fc(absa_1), self.absa_fc(absa_2)), dim=1)
            # Segment 0 for the S1 score, 1 for the S2 score; created on the inputs' device
            # instead of a hard-coded "cuda".
            token_type_ids_absa = (
                torch.tensor([0, 1], device=absa_concat.device)
                .unsqueeze(0)
                .repeat(absa_concat.shape[0], 1)
            )
            logits_absa = self.bert_absa(absa_concat, None, token_type_ids_absa)
            # Out-of-place add; a batch-1 ABSA result broadcasts over the text batch.
            logits = logits + logits_absa

        loss = self.criterion(logits, labels) if labels is not None else None
        return SequenceClassifierOutput(loss=loss, logits=logits)

In [None]:
def test_inference_model(model, s1, s2, absa_1, absa_2):
    """Run one (S1, S2) pair through `model` and return its softmax probabilities.

    Prints the WordPiece lengths of S1 and S2, encodes the pair as
    [CLS] S1 [SEP] S2 [SEP] with segment ids 0/1, and pads to 128 ids.

    Args:
        model: classifier returning an object with `.logits`.
        s1, s2: the two sentences.
        absa_1, absa_2: scalar ABSA scores, or None. Each non-None score is wrapped
            as a (1, 1, 1) float tensor on `device`.

    Returns:
        Softmax probabilities over the last logits dimension (a tensor on `device`).
    """
    tokens_a = bert_tokenizer.tokenize(s1)
    tokens_b = bert_tokenizer.tokenize(s2)

    tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"]
    segment_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
    print("len_token_s1", len(tokens_a))
    print("len_token_s2", len(tokens_b))

    input_ids = bert_tokenizer.convert_tokens_to_ids(tokens)
    input_mask = [1] * len(input_ids)
    # NOTE(review): no truncation; pairs longer than 128 ids are passed at full length.
    padding = [0] * (128 - len(input_ids))

    input_ids = torch.tensor([input_ids + padding], dtype=torch.long).to(device)
    attention_mask = torch.tensor([input_mask + padding], dtype=torch.long).to(device)
    segment_ids = torch.tensor([segment_ids + padding], dtype=torch.long).to(device)

    # Only None means "missing". A truthiness test would leave a valid score of 0 as a
    # bare Python number, which the model's ABSA projection cannot consume.
    if absa_1 is not None:
        absa_1 = torch.tensor([absa_1], dtype=torch.float32).unsqueeze(1).unsqueeze(1).to(device)
    if absa_2 is not None:
        absa_2 = torch.tensor([absa_2], dtype=torch.float32).unsqueeze(1).unsqueeze(1).to(device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, absa_1=absa_1, absa_2=absa_2,
                        attention_mask=attention_mask, token_type_ids=segment_ids).logits
        return torch.softmax(outputs, dim=-1)

In [None]:
def get_shap_values(model, s1, s2, absa_1, absa_2, len_token_s1, len_token_s2):
    """Explain `model`'s logits for one (S1, S2) pair with a SHAP text explainer.

    The explainer perturbs the string "<s1> [SEP] <s2>". Each perturbed string is
    re-tokenized, wrapped in [CLS] ... [SEP], padded to 128 ids and scored by `model`.

    Args:
        model: classifier called as model(input_ids=..., absa_1=..., absa_2=...,
            attention_mask=..., token_type_ids=...) returning an object with `.logits`.
        s1, s2: the two sentences.
        absa_1, absa_2: scalar ABSA scores, or None to pass no ABSA input.
        len_token_s1, len_token_s2: WordPiece lengths of the unperturbed S1 / S2. They
            build one fixed segment-id row shared by every perturbation.

    Returns:
        The explanation object returned by the SHAP explainer (outputs "-1", "0", "1").
    """
    max_len = 128
    segment_row = [0] * (len_token_s1 + 2) + [1] * (len_token_s2 + 1)
    segment_row = segment_row + [0] * (max_len - len(segment_row))

    def to_absa_tensor(value):
        # Only None means "no score": 0 is a valid ABSA score and must not be dropped.
        if value is None:
            return None
        return torch.tensor([value], dtype=torch.float32).unsqueeze(1).unsqueeze(1).to(device)

    # Built once per explanation instead of once per SHAP batch.
    absa_1_base = to_absa_tensor(absa_1)
    absa_2_base = to_absa_tensor(absa_2)

    def f(x):
        """Score a batch of (possibly masked) strings; returns logits, one row per string."""
        id_rows, mask_rows = [], []
        for text in x:
            ids = bert_tokenizer.convert_tokens_to_ids(['[CLS]'] + bert_tokenizer.tokenize(text) + ['[SEP]'])
            padding = [0] * (max_len - len(ids))
            id_rows.append(ids + padding)
            mask_rows.append([1] * len(ids) + padding)

        batch = len(id_rows)
        input_ids = torch.tensor(id_rows, dtype=torch.long, device=device)
        attention_mask = torch.tensor(mask_rows, dtype=torch.long, device=device)
        segment_ids = torch.tensor([segment_row] * batch, dtype=torch.long, device=device)
        # Repeat the (1, 1, 1) ABSA tensors along the batch dimension (view, no copy).
        absa_1_t = None if absa_1_base is None else absa_1_base.expand(batch, -1, -1)
        absa_2_t = None if absa_2_base is None else absa_2_base.expand(batch, -1, -1)

        with torch.no_grad():
            outputs = model(input_ids=input_ids, absa_1=absa_1_t, absa_2=absa_2_t,
                            attention_mask=attention_mask, token_type_ids=segment_ids).logits
        return outputs.cpu().numpy()

    explainer = shap.Explainer(f, bert_tokenizer, output_names=["-1", "0", "1"])
    return explainer([f"{s1} [SEP] {s2}"])


In [None]:

# Load dataset and prepare
from datasets import load_dataset
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score, classification_report

DATASET_NAME = "trungpq/rlcc-new-data-appearance"
MODEL_NAME = "trungpq/rlcc-new-appearance-upsample_replacement-absa-min"
# Set to a small int (e.g. 30) for a quick debug run; None evaluates the whole test split.
MAX_SAMPLES = None
VALID_LABELS = (-1, 0, 1)


def _normalize_sentence(text):
    """Lower-case `text` and strip commas and periods (the only normalization applied)."""
    return text.lower().replace(",", "").replace(".", "")


# Load dataset
print("Loading dataset from Hugging Face...")
dataset = load_dataset(DATASET_NAME)
print(f"Dataset loaded. Total samples: {len(dataset['test'])}")

# Raw records from the test split: missing sentence/label fields become None,
# missing ABSA scores default to 0.
raw_data = [
    {
        's1': sample.get('sentences_1'),
        's2': sample.get('sentences_2'),
        'absa_1': sample.get('absa_min_1', 0),
        'absa_2': sample.get('absa_min_2', 0),
        'label': sample.get('appearance'),
    }
    for sample in dataset['test']
]

print(f"\nTotal samples in dataset: {len(raw_data)}")

# Keep only samples with non-empty s1 and s2 and a label in {-1, 0, 1}; normalize the sentences.
filtered_data = [
    {
        's1': _normalize_sentence(sample['s1']),
        's2': _normalize_sentence(sample['s2']),
        'absa_1': sample['absa_1'],
        'absa_2': sample['absa_2'],
        'label': sample['label'],
    }
    for sample in raw_data
    if sample['s1'] and sample['s2'] and sample['label'] in VALID_LABELS
]

print(f"Samples after filtering (valid labels + both s1 and s2): {len(filtered_data)}")

# Show some sample statistics
if filtered_data:
    print(f"\nSample data (first 3):")
    for i, sample in enumerate(filtered_data[:3]):
        print(f"\nSample {i+1}:")
        print(f"  Label: {sample['label']}")
        print(f"  S1: {sample['s1'][:80]}...")
        print(f"  S2: {sample['s2'][:80]}...")
        print(f"  ABSA1: {sample['absa_1']}, ABSA2: {sample['absa_2']}")

# Load model
print("\n" + "="*80)
print("Loading appearance model...")
appearance_model = BERTModel.from_pretrained(MODEL_NAME).to(device)
appearance_model = appearance_model.eval()
print("✓ Model loaded successfully")

# Optional subset for quick runs (replaces the old commented-out slice).
if MAX_SAMPLES is not None:
    filtered_data = filtered_data[:MAX_SAMPLES]


In [None]:

# PART 1: Baseline Evaluation - Get initial F1 score on full dataset

def get_model_predictions(model, s1, s2, absa_1, absa_2):
    """Get model predictions for a single sample.

    Encodes the pair as [CLS] S1 [SEP] S2 [SEP] with segment ids 0/1, pads to
    128 ids, and runs `model` without gradients.

    Args:
        model: classifier returning an object with `.logits`.
        s1, s2: the two sentences.
        absa_1, absa_2: scalar ABSA scores, or None. Each non-None score is wrapped
            as a (1, 1, 1) float tensor on `device`.

    Returns:
        (pred_class, probs): the index of the highest-probability logit (an int) and
        the softmax probabilities of the single sample as a NumPy array.
    """
    tokens_a = bert_tokenizer.tokenize(s1)
    tokens_b = bert_tokenizer.tokenize(s2)

    tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"]
    segment_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)

    input_ids = bert_tokenizer.convert_tokens_to_ids(tokens)
    input_mask = [1] * len(input_ids)
    # NOTE(review): no truncation; pairs longer than 128 ids are passed at full length.
    padding = [0] * (128 - len(input_ids))

    input_ids = torch.tensor([input_ids + padding], dtype=torch.long).to(device)
    attention_mask = torch.tensor([input_mask + padding], dtype=torch.long).to(device)
    segment_ids = torch.tensor([segment_ids + padding], dtype=torch.long).to(device)

    # Only None means "missing". A truthiness test would leave a valid score of 0 as a
    # bare Python number, which the model's ABSA projection cannot consume.
    if absa_1 is not None:
        absa_1 = torch.tensor([absa_1], dtype=torch.float32).unsqueeze(1).unsqueeze(1).to(device)
    if absa_2 is not None:
        absa_2 = torch.tensor([absa_2], dtype=torch.float32).unsqueeze(1).unsqueeze(1).to(device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, absa_1=absa_1, absa_2=absa_2,
                        attention_mask=attention_mask, token_type_ids=segment_ids).logits
        probs = torch.softmax(outputs, dim=-1)
        pred_class = torch.argmax(probs, dim=1)

    return pred_class.item(), probs[0].cpu().numpy()

print("\n" + "="*80)
print("PART 1: BASELINE EVALUATION ON FULL DATASET")
print("="*80)

# Model logit index <-> original label (-1 / 0 / 1).
label_mapping = {-1: 0, 0: 1, 1: 2}
inverse_mapping = {index: label for label, index in label_mapping.items()}

all_preds = []   # predicted logit indices (0/1/2)
all_labels = []  # ground-truth labels on the original -1/0/1 scale
report_every = max(1, len(filtered_data) // 10)

print(f"\nEvaluating {len(filtered_data)} samples...")
for idx, sample in enumerate(filtered_data):
    if idx % report_every == 0:
        print(f"  Progress: {idx}/{len(filtered_data)}")

    pred_class, probs = get_model_predictions(
        appearance_model,
        sample['s1'], sample['s2'],
        sample['absa_1'], sample['absa_2'],
    )
    all_preds.append(pred_class)
    all_labels.append(sample['label'])

# One-vs-rest F1 for each original class, then their unweighted mean.
print(f"\n✓ Computing F1 scores for all 3 classes...")

per_class_f1 = {
    cls: f1_score(
        [int(l == cls) for l in all_labels],
        [int(p == label_mapping[cls]) for p in all_preds],
        zero_division=0,
    )
    for cls in (-1, 0, 1)
}
f1_class_minus1 = per_class_f1[-1]
f1_class_0 = per_class_f1[0]
f1_class_1 = per_class_f1[1]
baseline_macro_f1 = (f1_class_minus1 + f1_class_0 + f1_class_1) / 3.0

print(f"\n✓ Baseline Metrics (F1 per class - binary: class_i vs others):")
print(f"  F1 Score Class -1: {f1_class_minus1:.4f}")
print(f"  F1 Score Class 0:  {f1_class_0:.4f}")
print(f"  F1 Score Class 1:  {f1_class_1:.4f}")
print(f"  Macro F1:          {baseline_macro_f1:.4f}")


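## Parts 2–5: SHAP Faithfulness Evaluation

Compute SHAP values for all three classes and rank the tokens of S1 and S2 together by their mean absolute SHAP value.
For each sample, remove its top-k and bottom-k adjective tokens from both sentences, then measure the change in macro F1 relative to the baseline.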


In [None]:

import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
import nltk
from tqdm import tqdm

nltk.download('averaged_perceptron_tagger_eng')

print("\n" + "="*80)
print("PART 2: COMPUTE SHAP VALUES AND BUILD IMPORTANT TOKEN SETS")
print("="*80)

SHAP_MAX_LEN = 128  # padded sequence length, same as get_model_predictions


def _make_shap_fn(len_token_s1, len_token_s2, absa_1, absa_2):
    """Build the batch prediction function SHAP perturbs for one (S1, S2) pair.

    The returned function maps a batch of (masked) strings to `appearance_model`
    logits as a NumPy array, one row per string. Segment ids and ABSA tensors depend
    only on the unperturbed sample, so they are built once here rather than per call.
    """
    segment_ids_base = [0] * (len_token_s1 + 2) + [1] * (len_token_s2 + 1)
    segment_ids_padded = segment_ids_base + [0] * (SHAP_MAX_LEN - len(segment_ids_base))
    # Batch-1 ABSA tensors of shape (1, 1, 1), as before; this relies on the ABSA logits
    # broadcasting over the text batch inside the model's forward.
    absa_1_t = None if absa_1 is None else torch.tensor([absa_1], dtype=torch.float32).unsqueeze(1).unsqueeze(1).to(device)
    absa_2_t = None if absa_2 is None else torch.tensor([absa_2], dtype=torch.float32).unsqueeze(1).unsqueeze(1).to(device)

    def f_shap(x):
        """Score a batch of (possibly masked) strings; one logit row per string."""
        input_ids_rows, attention_rows = [], []
        for text in x:
            ids = bert_tokenizer.convert_tokens_to_ids(['[CLS]'] + bert_tokenizer.tokenize(text) + ['[SEP]'])
            padding = [0] * (SHAP_MAX_LEN - len(ids))
            input_ids_rows.append(ids + padding)
            attention_rows.append([1] * len(ids) + padding)

        input_ids = torch.tensor(input_ids_rows, dtype=torch.long, device=device)
        attention_mask = torch.tensor(attention_rows, dtype=torch.long, device=device)
        segment_ids = torch.tensor([segment_ids_padded] * len(input_ids_rows), dtype=torch.long, device=device)

        with torch.no_grad():
            outputs = appearance_model(input_ids=input_ids, absa_1=absa_1_t, absa_2=absa_2_t,
                                       attention_mask=attention_mask, token_type_ids=segment_ids).logits
        return outputs.cpu().numpy()

    return f_shap


sample_token_data = []
shap_failures = []  # (index into filtered_data, error repr) for samples SHAP could not explain

print(f"\nComputing SHAP values for {len(filtered_data)} samples...")

for sample_idx, sample in enumerate(tqdm(filtered_data, desc="Processing")):
    s1 = sample['s1']
    s2 = sample['s2']
    absa_1 = sample['absa_1']
    absa_2 = sample['absa_2']

    if s1 is None or s2 is None:
        continue

    tokens_a = bert_tokenizer.tokenize(s1)
    tokens_b = bert_tokenizer.tokenize(s2)
    tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"]

    f_shap = _make_shap_fn(len(tokens_a), len(tokens_b), absa_1, absa_2)

    try:
        explainer = shap.Explainer(f_shap, bert_tokenizer, output_names=["class_-1", "class_0", "class_1"])
        shap_values = explainer([f"{s1} [SEP] {s2}"], silent=True)

        # shap_values.values: (1, num_tokens, 3); the last axis follows output_names.
        sample_token_data.append({
            'sample_idx': sample_idx,  # position in filtered_data (used to re-align later)
            's1_tokens': tokens_a,
            's2_tokens': tokens_b,
            'shap_values_minus1': shap_values.values[0, :, 0],
            'shap_values_0': shap_values.values[0, :, 1],
            'shap_values_1': shap_values.values[0, :, 2],
            'tokens': tokens,
            's1': s1,
            's2': s2,
            'absa_1': absa_1,
            'absa_2': absa_2,
            'label': sample['label'],
        })
    except Exception as e:
        # Best effort: keep going, but record the failure instead of hiding it.
        shap_failures.append((sample_idx, repr(e)))

print(f"\n✓ SHAP values computed for {len(sample_token_data)} samples")
if shap_failures:
    print(f"  ⚠ SHAP failed for {len(shap_failures)} sample(s); they are excluded from the token ranking.")
    for failed_idx, error in shap_failures[:3]:
        print(f"    sample {failed_idx}: {error}")

In [None]:

# PART 3: Token Removal Test - Per-Sample Analysis

def merge_wordpieces(tokens):
    """Merge WordPiece continuation tokens ("##xxx") into the preceding token.

    Args:
        tokens: list of WordPiece tokens, e.g. ["play", "##ing", "well"].

    Returns:
        (merged_tokens, merged_indices): merged word strings, and for each word the
        list of positions in `tokens` it was built from,
        e.g. (["playing", "well"], [[0, 1], [2]]).
    """
    merged_tokens = []
    merged_indices = []

    for i, tok in enumerate(tokens):
        if tok.startswith("##") and merged_tokens:
            merged_tokens[-1] += tok[2:]
            merged_indices[-1].append(i)
        else:
            # A leading "##" piece has nothing to attach to; keep it as its own word
            # instead of raising IndexError on merged_tokens[-1].
            merged_tokens.append(tok)
            merged_indices.append([i])
    return merged_tokens, merged_indices


print("\n" + "="*80)
print("PART 3: PREPARE FOR PER-SAMPLE TOKEN REMOVAL")
print("="*80)


def _mean_abs_shap(data, start, end):
    """Mean over the 3 classes of |SHAP| for explanation positions [start, end)."""
    return (np.abs(data['shap_values_minus1'][start:end])
            + np.abs(data['shap_values_0'][start:end])
            + np.abs(data['shap_values_1'][start:end])) / 3.0


def _word_importance(token_scores, merged_indices):
    """Sum wordpiece scores into one score per merged word; positions past the SHAP array are ignored."""
    return [sum(token_scores[i] for i in group if i < len(token_scores)) for group in merged_indices]


# Per sample: mean |SHAP| over the 3 classes, summed over each word's wordpieces,
# and ranked jointly over the words of S1 and S2.
sample_token_importance = []

for data in sample_token_data:
    s1 = data['s1']
    s2 = data['s2']
    if s1 is None or s2 is None:
        continue

    s1_tokens = data['s1_tokens']
    s2_tokens = data['s2_tokens']

    # Explanation positions follow [CLS] s1_tokens [SEP] s2_tokens [SEP].
    s1_start_idx = 1
    s1_end_idx = s1_start_idx + len(s1_tokens)
    s2_start_idx = s1_end_idx + 1
    s2_end_idx = s2_start_idx + len(s2_tokens)

    merged_tokens_s1, merged_indices_s1 = merge_wordpieces(s1_tokens)
    merged_tokens_s2, merged_indices_s2 = merge_wordpieces(s2_tokens)
    all_tokens = merged_tokens_s1 + merged_tokens_s2

    # importance[i] is the score of all_tokens[i] (S1 words first, then S2 words).
    importance = (
        _word_importance(_mean_abs_shap(data, s1_start_idx, s1_end_idx), merged_indices_s1)
        + _word_importance(_mean_abs_shap(data, s2_start_idx, s2_end_idx), merged_indices_s2)
    )
    sorted_by_importance = sorted(enumerate(all_tokens), key=lambda pair: importance[pair[0]], reverse=True)

    sample_token_importance.append({
        'sample_idx': data['sample_idx'],
        's1': s1,
        's2': s2,
        'absa_1': data['absa_1'],
        'absa_2': data['absa_2'],
        'label': data.get('label', -1),
        'all_tokens': all_tokens,  # Combined S1 + S2 tokens
        's1_len': len(merged_tokens_s1),
        's2_len': len(merged_tokens_s2),
        'sorted_by_importance': sorted_by_importance  # [(idx, token), ...]
    })

print(f"\n✓ Prepared {len(sample_token_importance)} samples for token removal analysis")
print(f"  Each sample will have independent top-k and bot-k tokens")
print(f"  Token importance averaged across all 3 classes (-1, 0, 1)")
print(f"  Tokens from both S1 and S2 are ranked together")

import nltk
# nltk.pos_tag needs the English perceptron tagger. Check the same "_eng" resource the
# imports cell downloads; newer NLTK releases no longer load the un-suffixed name.
try:
    nltk.data.find('taggers/averaged_perceptron_tagger_eng')
except LookupError:
    nltk.download('averaged_perceptron_tagger_eng')

# k_max for Part 4: the largest number of adjective tokens (JJ, JJR, JJS; repeats
# counted) found in any single sample.
max_adj_count = max(
    (sum(1 for _, tag in nltk.pos_tag(info['all_tokens']) if tag.startswith('JJ'))
     for info in sample_token_importance),
    default=0,
)

print(f"\n✓ Max number of adjectives found: {max_adj_count}")
print(f"  This will be used as k_max for token removal experiments")

In [None]:
# PART 4: Token Removal and Macro F1 Evaluation

def remove_tokens_from_s1_and_s2(s1_text, s2_text, tokens_to_remove):
    """Delete every whole word listed in `tokens_to_remove` from both sentences.

    Each sentence is WordPiece-tokenized, the pieces are merged back into words,
    words found in `tokens_to_remove` are dropped, and the rest are re-joined with
    single spaces.

    Returns:
        (s1_result, s2_result): the rebuilt sentences. A None input stays None; a
        sentence with every word removed becomes "".
    """
    def strip_words(text):
        if text is None:
            return None
        words, _ = merge_wordpieces(bert_tokenizer.tokenize(text))
        return " ".join(word for word in words if word not in tokens_to_remove)

    s1_result = strip_words(s1_text)
    s2_result = strip_words(s2_text)
    return s1_result, s2_result


def evaluate_macro_f1_after_removal(model, data, removed_tokens):
    """Macro F1 over classes -1/0/1 after removing the same tokens from every sample.

    Samples with a missing S1/S2, or whose S1/S2 becomes empty after removal, are
    predicted as class 0. Each class F1 is one-vs-rest (zero_division=0); the macro
    score is their unweighted mean.
    """
    labels = []
    predictions = []

    for sample in data:
        s1, s2 = sample['s1'], sample['s2']
        if s1 is None or s2 is None:
            predictions.append(0)
        else:
            if removed_tokens:
                s1, s2 = remove_tokens_from_s1_and_s2(s1, s2, removed_tokens)
            if s1 and s2:
                index, _ = get_model_predictions(model, s1, s2, sample['absa_1'], sample['absa_2'])
            else:
                index = label_mapping[0]
            predictions.append(inverse_mapping[index])
        labels.append(sample['label'])

    per_class = [
        f1_score([1 if l == cls else 0 for l in labels],
                 [1 if p == cls else 0 for p in predictions],
                 zero_division=0)
        for cls in (-1, 0, 1)
    ]
    return (per_class[0] + per_class[1] + per_class[2]) / 3.0


print("\n" + "="*80)
print("PART 4: PER-SAMPLE TOKEN REMOVAL - MACRO F1 EVALUATION")
print("="*80)
print("(Remove top-k and bot-k ADJECTIVE tokens from S1 & S2 for each sample, evaluate macro F1)\n")

# k runs from 1 to the largest per-sample adjective count found in Part 3.
k_max = max_adj_count
print(f"k_max = {k_max} (max {k_max} adjectives, or less if samples have fewer adjectives)\n")

results_macro_f1 = {
    'k_values': [],
    'top_k_f1': [],
    'bottom_k_f1': [],
}


def _rank_adjectives(sample_info):
    """Return the sample's adjective tokens ordered most-to-least important (repeats kept)."""
    pos_tags = nltk.pos_tag(sample_info['all_tokens'])
    adjective_set = {word for word, tag in pos_tags if tag.startswith('JJ')}
    return [token for _, token in sample_info['sorted_by_importance'] if token in adjective_set]


def _predict_label_after_removal(sample, tokens_to_remove):
    """Predict the original-scale label (-1/0/1) after deleting `tokens_to_remove` from S1 and S2.

    Falls back to class 0 when S1/S2 is missing or becomes empty after removal.
    """
    if sample['s1'] is None or sample['s2'] is None:
        return 0
    if tokens_to_remove:
        s1_modified, s2_modified = remove_tokens_from_s1_and_s2(sample['s1'], sample['s2'], tokens_to_remove)
    else:
        s1_modified, s2_modified = sample['s1'], sample['s2']
    if not s1_modified or not s2_modified:
        return inverse_mapping[label_mapping[0]]
    pred_class, _ = get_model_predictions(
        appearance_model, s1_modified, s2_modified, sample['absa_1'], sample['absa_2']
    )
    return inverse_mapping[pred_class]


def _macro_f1(label_pred_pairs):
    """Unweighted mean of the one-vs-rest F1 scores for classes -1, 0 and 1."""
    labels = [label for label, _ in label_pred_pairs]
    preds = [pred for _, pred in label_pred_pairs]
    per_class = [
        f1_score([1 if l == cls else 0 for l in labels],
                 [1 if p == cls else 0 for p in preds], zero_division=0)
        for cls in (-1, 0, 1)
    ]
    return (per_class[0] + per_class[1] + per_class[2]) / 3.0


# POS-tag each sample once (it does not depend on k). Key the ranking by the sample's
# index in filtered_data: Part 2 skips samples whose SHAP computation fails, so
# positions in sample_token_importance do NOT line up with positions in filtered_data.
ranked_adjectives = {info['sample_idx']: _rank_adjectives(info) for info in sample_token_importance}
n_unranked = sum(1 for i in range(len(filtered_data)) if i not in ranked_adjectives)
if n_unranked:
    # These samples keep their original text, so the evaluation set matches the baseline.
    print(f"Note: {n_unranked} sample(s) have no SHAP ranking; they are evaluated with nothing removed.\n")

for k in range(1, k_max + 1):
    print(f"Evaluating k={k}...", end=" ")

    top_pairs, bot_pairs = [], []
    for sample_idx, sample in enumerate(filtered_data):
        ranked = ranked_adjectives.get(sample_idx, [])
        # Slicing covers samples with fewer than k adjectives: top-k == bot-k == all of them.
        top_pairs.append((sample['label'], _predict_label_after_removal(sample, ranked[:k])))
        bot_pairs.append((sample['label'], _predict_label_after_removal(sample, ranked[-k:])))

    macro_f1_top = _macro_f1(top_pairs) if top_pairs else baseline_macro_f1
    macro_f1_bot = _macro_f1(bot_pairs) if bot_pairs else baseline_macro_f1

    results_macro_f1['k_values'].append(k)
    results_macro_f1['top_k_f1'].append(macro_f1_top)
    results_macro_f1['bottom_k_f1'].append(macro_f1_bot)

    print(f"Top-{k}: {macro_f1_top:.4f} (drop: {(baseline_macro_f1-macro_f1_top):+.4f}), " + 
          f"Bot-{k}: {macro_f1_bot:.4f} (drop: {(baseline_macro_f1-macro_f1_bot):+.4f})")

print("\n✓ Per-sample token removal evaluation complete!")


In [None]:

# PART 5: Visualization

print("\n" + "="*80)
print("PART 5: VISUALIZING SHAP FAITHFULNESS")
print("="*80)

k_vals = results_macro_f1['k_values']
top_f1 = results_macro_f1['top_k_f1']
bot_f1 = results_macro_f1['bottom_k_f1']

# F1 drop relative to the baseline: positive = removal hurt the model.
top_f1_drop = [baseline_macro_f1 - f1 for f1 in top_f1]
bot_f1_drop = [baseline_macro_f1 - f1 for f1 in bot_f1]

if not k_vals:
    # k_max was 0 (no adjectives found), so there is nothing to plot and max() would fail.
    print("\nNo k values were evaluated (k_max = 0), so there is nothing to plot.")
    print(f"\nBaseline Macro F1 Score: {baseline_macro_f1:.4f}")
else:
    fig, ax = plt.subplots(figsize=(12, 6))

    ax.axhline(y=0, color='gray', linestyle='--', linewidth=2, label='No Impact', alpha=0.7)
    ax.plot(k_vals, top_f1_drop, marker='o', linewidth=2.5, markersize=9,
            label='Remove Top-k (High Impact)', color='darkred')
    ax.plot(k_vals, bot_f1_drop, marker='s', linewidth=2.5, markersize=9,
            label='Remove Bottom-k (Low Impact)', color='steelblue')

    ax.set_xlabel('Number of tokens removed from S1 & S2 (k)', fontsize=13)
    ax.set_ylabel('Macro F1 Drop (Impact)', fontsize=13)
    ax.set_title('SHAP Faithfulness: Token Removal Impact on Macro F1', fontsize=14, fontweight='bold')
    ax.legend(fontsize=11, loc='best')
    ax.grid(True, alpha=0.3)
    ax.set_xticks(k_vals)

    plt.tight_layout()
    plt.show()

    # argmax returns the first maximum, matching list.index(max(...)).
    top_peak = int(np.argmax(top_f1_drop))
    bot_peak = int(np.argmax(bot_f1_drop))

    print("\n✓ Visualization complete!")
    print(f"\nBaseline Macro F1 Score: {baseline_macro_f1:.4f}")
    print(f"Max F1 Drop (Top-k):    {top_f1_drop[top_peak]:.4f} at k={k_vals[top_peak]}")
    print(f"Max F1 Drop (Bottom-k): {bot_f1_drop[bot_peak]:.4f} at k={k_vals[bot_peak]}")

print("\nNote: Each sample has independent top-k and bottom-k tokens")
print("Token importance averaged across all 3 classes (-1, 0, 1)")
print("Tokens removed from both S1 and S2")
