In [None]:
!pip install transformers accelerate torch pandas tqdm --quiet

import torch
import pandas as pd
from itertools import combinations
from tqdm import tqdm
from transformers import LongformerTokenizerFast, LongformerForSequenceClassification

In [None]:
# =========================================================
# 1. CONFIGURATION
# =========================================================
# Hugging Face checkpoint id (public; downloads without a login).
MODEL_PATH = "allenai/longformer-base-4096"

# Number of answer pairs per forward pass. Lower it if the GPU runs out of
# memory (the original author's note: 4–8 works on a T4).
BATCH_SIZE = 8

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

print(f"Using device: {DEVICE}")

In [None]:
# =========================================================
# 2. LOAD MODEL AND TOKENIZER
# =========================================================
# Tokenizer and classifier are both loaded from the MODEL_PATH checkpoint.
tokenizer = LongformerTokenizerFast.from_pretrained(MODEL_PATH)

# Three output classes, intended as entailment / neutral / contradiction.
# NOTE(review): MODEL_PATH names a base checkpoint, so presumably the
# sequence-classification head is freshly initialised rather than fine-tuned
# for NLI; predictions may be meaningless until a fine-tuned checkpoint is
# used — confirm before trusting the output.
model = LongformerForSequenceClassification.from_pretrained(MODEL_PATH, num_labels=3)

# Move the weights to DEVICE and switch to evaluation mode. Both calls return
# the module itself, so the cell still displays the model as its last value.
model.to(DEVICE).eval()

In [None]:
# =========================================================
# 3. LOAD DATA
# =========================================================
# Column names used by this notebook — make sure they match the CSV header.
QUESTION_COL = "question"
ANSWERS_COL = "updated_answers"

# Location of the input CSV; replace with your actual path.
file_path = "/kaggle/input/vidhikaryafinal/vidhikarya_relevant_columns.csv"
df = pd.read_csv(file_path)

In [None]:
# =========================================================
# 4. HELPER FUNCTION
# =========================================================
def predict_contradictions(text_pairs):
    """Run a batch of text pairs through the model and flag contradictions.

    Parameters
    ----------
    text_pairs : iterable of (str, str)
        The two texts to compare, one tuple per comparison.

    Returns
    -------
    list of bool
        One flag per input pair, in order. A flag is True when the highest
        logit for that pair is at class index 2. An empty input yields ``[]``
        without calling the tokenizer or the model.

    Notes
    -----
    Relies on the notebook-level ``tokenizer``, ``model`` and ``DEVICE``.
    Treating index 2 as "contradiction" is an assumption about the
    checkpoint's label order — verify it against the model's config.
    """
    text_pairs = list(text_pairs)
    if not text_pairs:
        # Nothing to classify; avoid sending an empty batch downstream.
        return []

    # Hand the two sides to the tokenizer as text / text_pair so that it
    # inserts its own separator tokens. Longformer's vocabulary is
    # RoBERTa-style (separator "</s>"), so a literal "[SEP]" pasted between
    # the texts would just be tokenized as ordinary characters.
    first_texts = [first for first, _ in text_pairs]
    second_texts = [second for _, second in text_pairs]

    inputs = tokenizer(
        first_texts,
        second_texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=4096,
    ).to(DEVICE)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(**inputs).logits
        preds = torch.argmax(logits, dim=1).cpu().tolist()

    # Label mapping assumption: 2 = contradiction
    return [p == 2 for p in preds]

In [None]:
# =========================================================
# 5. MAIN LOOP
# =========================================================
# One boolean per DataFrame row, in row order.
contradiction_flags = []

for _, record in tqdm(df.iterrows(), total=len(df)):
    # Answers are stored in one cell, delimited by "|||"; drop blank fragments.
    stripped = (piece.strip() for piece in str(record[ANSWERS_COL]).split("|||"))
    answers = [piece for piece in stripped if piece]

    # Every unordered pair of answers for this row (empty when < 2 answers).
    pairs = list(combinations(answers, 2))

    # Score the pairs BATCH_SIZE at a time. any() short-circuits, so scoring
    # stops at the first batch containing a contradiction; a row with no
    # pairs produces no batches and is therefore flagged False.
    is_contradicting = any(
        any(predict_contradictions(pairs[start:start + BATCH_SIZE]))
        for start in range(0, len(pairs), BATCH_SIZE)
    )

    contradiction_flags.append(is_contradicting)



In [None]:
# =========================================================
# 6. SAVE RESULTS
# =========================================================
# Attach the per-row flags as a new column; contradiction_flags must hold
# exactly one entry per row of df.
df["contradicting"] = contradiction_flags

# Write the augmented table to the Kaggle working directory, without the index.
df.to_csv(path_or_buf="/kaggle/working/legal_qa_with_flags.csv", index=False)
print("✅ Done! Results saved to /kaggle/working/legal_qa_with_flags.csv")