In [None]:
!pip install -q transformers torch datasets bitsandbytes accelerate


In [None]:


import os, json, torch
from itertools import chain, permutations
from torch import nn
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel
from google.colab import drive



# -------------------- Paths --------------------
# Every input lives under BASE_DIR, resolved relative to the working directory.
BASE_DIR = "data"
DATA_PATH = f"{BASE_DIR}/dataset.jsonl"
CKPT_PATH = f"{BASE_DIR}/checkpoints/final_model_082638.pt"

# Use the GPU whenever PyTorch reports one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# -------------------- Load Dataset --------------------
# Register the JSONL file as the only split, named "test", then keep just that split.
dataset = load_dataset(
    "json",
    data_files={"test": DATA_PATH},
)
data = dataset["test"]

def prepare_pairs(example):
    """Build every ordered pair of distinct-position statements in one example.

    ``example["statements"]`` is iterated as a sequence (presumably a list of
    strings from one JSONL record — verify against the dataset). Pairs are
    formed by position: both ``(a, b)`` and ``(b, a)`` are produced, an item
    is never paired with itself at the same index, and equal values at
    different indices are still paired.

    Returns:
        ``{"pairs": [(s_i, s_j), ...]}`` for every ``i != j``, ordered by
        ``i`` then ``j``; the dict shape is what ``Dataset.map`` expects.
    """
    # permutations(seq, 2) yields exactly the i != j index pairs, in the same
    # order as the nested index loop it replaces.
    return {"pairs": list(permutations(example["statements"], 2))}

# Expand each example into its ordered statement pairs, then flatten across examples.
data = data.map(prepare_pairs)
# chain.from_iterable flattens in linear time. sum(list_of_lists, []) builds a
# new list on every addition, which is quadratic in the total number of pairs.
pairs = list(chain.from_iterable(data["pairs"]))
print(f"Total pairs to test: {len(pairs)}")

# -------------------- Model Definition --------------------
class MiniVeriGraph(nn.Module):
    """Sentence-pair scorer: a DeBERTa-v3-small encoder feeding a two-layer MLP head.

    The attribute names ``encoder`` / ``classifier`` and the layer order inside
    ``classifier`` determine the ``state_dict`` keys, so they must stay aligned
    with any checkpoint loaded into this module.
    """

    def __init__(self):
        super().__init__()
        self.encoder = AutoModel.from_pretrained("microsoft/deberta-v3-small")
        width = self.encoder.config.hidden_size
        # Linear -> ReLU -> Linear, producing one logit per example.
        layers = [nn.Linear(width, 256), nn.ReLU(), nn.Linear(256, 1)]
        self.classifier = nn.Sequential(*layers)

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        """Return sigmoid scores of shape ``(batch,)``.

        Only ``input_ids`` and ``attention_mask`` reach the encoder. Any other
        keyword arguments (e.g. extra tokenizer outputs) are accepted and ignored.
        """
        encoded = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Pool by taking the hidden state at the first sequence position.
        first_token = encoded.last_hidden_state[:, 0, :]
        logits = self.classifier(first_token)
        return logits.sigmoid().squeeze(-1)

# -------------------- Load Checkpoint --------------------
tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-small")
model = MiniVeriGraph().to(device)
# The checkpoint is handed straight to load_state_dict, so it only needs to hold
# tensors. weights_only=True therefore suffices. It stops torch.load from
# unpickling arbitrary objects, which could execute code, and pins the behaviour
# regardless of PyTorch version (the default for this flag changed in 2.6).
model.load_state_dict(
    torch.load(CKPT_PATH, map_location=device, weights_only=True)
)
model.eval()  # disable dropout etc. for inference
print(f"Loaded model from: {CKPT_PATH}")

# -------------------- Inference --------------------
def predict_contradiction(s1, s2, threshold=0.5):
    """Score a single sentence pair with the loaded model and label it.

    Args:
        s1: First sentence of the pair.
        s2: Second sentence of the pair.
        threshold: A score strictly above this value is labelled a contradiction.

    Returns:
        ``(label, score)``: ``label`` is "CONTRADICTION" or "NON-CONTRADICTION",
        and ``score`` is the model output as a Python float.
    """
    encoded = tokenizer(s1, s2, return_tensors="pt", truncation=True, padding=True)
    encoded = encoded.to(device)
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        prob = model(**encoded).item()
    if prob > threshold:
        verdict = "CONTRADICTION"
    else:
        verdict = "NON-CONTRADICTION"
    return verdict, prob

# Show predictions for the first 15 pairs, numbered from 01.
print("\n Example Predictions:\n")
for number, (first, second) in enumerate(pairs[:15], start=1):
    label, score = predict_contradiction(first, second)
    print(f"{number:02d}. {first}")
    print(f"    ↔ {second}")
    print(f"Prediction: {label}  (score={score:.2f})\n")
