In [None]:
!pip install -q transformers datasets torch bitsandbytes accelerate


In [None]:


import os, json, torch, itertools
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel, AutoModelForSeq2SeqLM
from google.colab import drive



# -------------------- Paths & Device --------------------
# Every location below is a relative path, resolved against the working directory.
BASE_DIR = "data"              # holds the input JSONL
LOGS_DIR = "logs"              # destination of the inference output JSONL
MODEL_DIR_CTR = "checkpoints"  # root of both fine-tuned checkpoints

# Both fine-tuned models share this sub-directory.
_LORA_ROOT = f"{MODEL_DIR_CTR}/lora_finetuned"

INPUT_PATH = f"{BASE_DIR}/verigraph_multistatement_input.jsonl"
CONTRA_CKPT = f"{_LORA_ROOT}/contradiction_model/final_model.pt"
EXPL_CKPT = f"{_LORA_ROOT}/explanation_model"
OUT_PATH = f"{LOGS_DIR}/verigraph_multistatement_outputs.jsonl"

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# -------------------- Load Dataset --------------------
# The single JSON-lines file is registered as the "test" split. The inference
# loop reads a "statements" list from every record.
dataset = load_dataset(
    "json",
    data_files={"test": INPUT_PATH},
)
samples = dataset["test"]
print(f"Loaded {len(samples)} multi-statement samples.")

# -------------------- Load Models --------------------
class MiniVeriGraph(torch.nn.Module):
    """Pairwise contradiction scorer: DeBERTa-v3-small encoder + MLP head.

    The encoder's first-token hidden state is fed through a
    ``hidden_size -> 256 -> 1`` MLP and squashed with a sigmoid, so
    ``forward`` yields one score in (0, 1) per input row.

    NOTE: the attribute names ``encoder`` / ``classifier`` and the layer order
    inside ``classifier`` define the state_dict keys. Keep them in sync with
    any checkpoint restored into this module.
    """

    def __init__(self):
        super().__init__()
        self.encoder = AutoModel.from_pretrained("microsoft/deberta-v3-small")
        width = self.encoder.config.hidden_size
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(width, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 1),
        )

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        """Return a 1-D tensor of sigmoid scores, one per row of ``input_ids``.

        Any extra keyword arguments (e.g. other tokenizer outputs) are accepted
        and ignored; only ``input_ids`` and ``attention_mask`` reach the encoder.
        """
        hidden = self.encoder(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        first_token = hidden[:, 0]
        logits = self.classifier(first_token).squeeze(-1)
        return logits.sigmoid()

# --- Load contradiction model ---
# The tokenizer uses the same hub id that MiniVeriGraph uses for its encoder.
# The fine-tuned weights from CONTRA_CKPT then replace the freshly built ones.
tokenizer_contra = AutoTokenizer.from_pretrained("microsoft/deberta-v3-small")
contra_model = MiniVeriGraph().to(device)
# NOTE(review): torch.load unpickles the checkpoint -- only load files you trust.
contra_model.load_state_dict(
    torch.load(CONTRA_CKPT, map_location=device)
)
contra_model.eval()  # this notebook only runs inference with the model

# --- Load explanation model ---
# The tokenizer and the seq2seq weights both come from the saved EXPL_CKPT directory.
tokenizer_exp = AutoTokenizer.from_pretrained(EXPL_CKPT)
exp_model = AutoModelForSeq2SeqLM.from_pretrained(EXPL_CKPT)
exp_model = exp_model.to(device)
exp_model.eval()

# -------------------- Helper Functions --------------------
def detect_contradiction(s1, s2, threshold=0.5):
    """Score the ordered statement pair ``(s1, s2)`` with the contradiction model.

    Returns ``(label, score)``. ``score`` is the model's sigmoid output as a
    Python float. ``label`` is ``"CONTRADICTION"`` when ``score`` is strictly
    greater than ``threshold``, otherwise ``"NON-CONTRADICTION"``.
    """
    batch = tokenizer_contra(
        s1, s2, return_tensors="pt", truncation=True, padding=True
    ).to(device)
    with torch.no_grad():
        prob = contra_model(**batch).item()
    if prob > threshold:
        return "CONTRADICTION", prob
    return "NON-CONTRADICTION", prob

def generate_explanation(s1, s2):
    """Ask the seq2seq explanation model how ``s1`` and ``s2`` conflict.

    Both statements are wrapped in a fixed instruction prompt. Generation is
    capped at 80 new tokens, and the first returned sequence is decoded with
    special tokens stripped.
    """
    prompt = f"Describe how these two statements conflict: '{s1}' vs '{s2}'."
    inputs = tokenizer_exp(
        prompt, return_tensors="pt", truncation=True, padding=True
    ).to(device)
    with torch.no_grad():
        generated = exp_model.generate(**inputs, max_new_tokens=80)
    best = generated[0]
    return tokenizer_exp.decode(best, skip_special_tokens=True)

# -------------------- Run Inference --------------------
# Each sample's statements are compared pairwise in BOTH orders:
# permutations(..., 2) yields (i, j) and (j, i) for every i != j, and
# detect_contradiction scores the ordered pair (s1, s2). Explanations are
# generated only for CONTRADICTION pairs; other pairs store an empty string.
outputs = []
for idx, sample in enumerate(samples):
    sents = sample["statements"]
    print(f"\nSample {idx+1} — {len(sents)} statements")

    for i, j in itertools.permutations(range(len(sents)), 2):
        s1, s2 = sents[i], sents[j]
        label, score = detect_contradiction(s1, s2)
        if label == "CONTRADICTION":
            explanation = generate_explanation(s1, s2)
        else:
            explanation = ""

        # Every pair gets the same summary line; contradictions also print
        # their explanation followed by a blank line.
        print(f"  • {s1} ↔ {s2} → {label} ({score:.2f})")
        if label == "CONTRADICTION":
            print(f"{explanation}\n")

        outputs.append(
            dict(
                sample_id=idx,
                s1=s1,
                s2=s2,
                label=label,
                score=round(score, 3),
                explanation=explanation,
            )
        )

# -------------------- Save Results --------------------
# Create the output directory first. Otherwise a missing LOGS_DIR raises
# FileNotFoundError only after the whole (slow) inference pass, and every
# computed result is lost.
_out_dir = os.path.dirname(OUT_PATH)
if _out_dir:  # dirname is "" when OUT_PATH has no directory component
    os.makedirs(_out_dir, exist_ok=True)

# JSON Lines output: one record per line. json.dumps escapes non-ASCII by
# default, so the content is ASCII; the encoding is still pinned explicitly.
with open(OUT_PATH, "w", encoding="utf-8") as f:
    f.writelines(json.dumps(o) + "\n" for o in outputs)

print(f"\nInference complete — results saved to:\n{OUT_PATH}")
