In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
! pip install captum --quiet

In [3]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from transformers import BertTokenizer, BertForSequenceClassification
from captum.attr import IntegratedGradients, visualization

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Load test dataset
df_test = pd.read_csv('/content/drive/MyDrive/quora_duplicate_questions/data/processed/test.csv.zip',
                      compression='zip')

In [5]:
# Load fine-tuned model and tokenizer
model_path = '/content/drive/MyDrive/quora_duplicate_questions/models/bert_quora_model/'
tokenizer = BertTokenizer.from_pretrained(model_path)
model = BertForSequenceClassification.from_pretrained(model_path)
model.eval()

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e

In [6]:
def custom_forward(embeddings, attention_mask):
    outputs = model(inputs_embeds=embeddings, attention_mask=attention_mask)
    return F.softmax(outputs.logits, dim=1)[:, 1]  # Probability for "duplicate" class

In [7]:
ig = IntegratedGradients(custom_forward)

# Loop over random question pairs
for i in np.random.choice(df_test.index, size=3, replace=False):
    q1 = df_test.loc[i, 'question1']
    q2 = df_test.loc[i, 'question2']
    true_label = df_test.loc[i, 'is_duplicate']

    print(f"\n🔹 Example {i}:")
    print(f"Question 1: {q1}")
    print(f"Question 2: {q2}")
    print(f"Is duplicate: {bool(true_label)}")

    # Tokenize
    inputs = tokenizer(
        q1,
        q2,
        truncation=True,
        padding=True,
        max_length=256,
        return_tensors='pt'
  )
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    # Get prediction
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        probs = torch.softmax(outputs.logits, dim=1)
        pred_label = "Duplicate" if probs[0][1] > 0.5 else "Not Duplicate"

    # Compute embeddings
    input_embeddings = model.bert.embeddings(input_ids)

    # Get attributions
    attributions, delta = ig.attribute(
        inputs=input_embeddings,
        additional_forward_args=(attention_mask,),
        return_convergence_delta=True,
    )

    # Prepare tokens and attribution scores
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    attr_scores = attributions[0].sum(dim=-1).detach().cpu().numpy().tolist()

    # Prepare visualization record
    viz_data_record = visualization.VisualizationDataRecord(
        word_attributions=attr_scores,
        pred_prob=probs[0][1].item(),
        pred_class=pred_label,
        true_class="Duplicate" if true_label else "Not Duplicate",
        attr_class="Duplicate",
        attr_score=sum(attr_scores),
        raw_input_ids=tokens,
        convergence_score=delta.item()
    )

    # Visualize
    visualization.visualize_text([viz_data_record])


🔹 Example 36360:
Question 1: Is there a way on Facebook to see a list of your friends who live in a particular city?
Question 2: How do I set Facebook to restrict others from seeing my friends list?
Is duplicate: False


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
Not Duplicate,Not Duplicate (0.02),Duplicate,0.38,[CLS] is there a way on facebook to see a list of your friends who live in a particular city ? [SEP] how do i set facebook to restrict others from seeing my friends list ? [SEP]
,,,,



🔹 Example 6612:
Question 1: What does XFN mean?
Question 2: What does "LOL OK" mean?
Is duplicate: False


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
Not Duplicate,Not Duplicate (0.00),Duplicate,-0.12,"[CLS] what does x ##f ##n mean ? [SEP] what does "" lo ##l ok "" mean ? [SEP]"
,,,,



🔹 Example 75106:
Question 1: Why is the feather an important Apache tribe symbol? What does it mean?
Question 2: Why is the great spirit an important Apache tribe symbol? What does it mean?
Is duplicate: False


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
Not Duplicate,Not Duplicate (0.01),Duplicate,0.72,[CLS] why is the feather an important apache tribe symbol ? what does it mean ? [SEP] why is the great spirit an important apache tribe symbol ? what does it mean ? [SEP]
,,,,
