In [1]:
import requests
import json
import time
from config import API_KEY

API_URL = "https://api.together.xyz/v1/completions"

# Auth + content-type headers sent with every completions request.
headers = {"Authorization": f"Bearer {API_KEY}", "Content-Type": "application/json"}

input_file = 'data/Squad/test_filtered.jsonl'
output_file = 'data/Squad/mistral-7B-control.jsonl'


def _squad_item(record):
    """Keep only the fields this run needs from one raw Squad JSONL record.

    ``id`` falls back to "" when absent; ``masked_sentences``, ``obj_label`` and
    ``sub_label`` are required (a missing key raises KeyError). Only the first
    masked sentence is used, with surrounding whitespace stripped.
    """
    return {
        "id": record.get("id", ""),
        "masked_sentence": record["masked_sentences"][0].strip(),
        "obj_label": record["obj_label"],
        "sub_label": record["sub_label"],
    }


# Load all Squad items, one JSON object per line, in file order.
items = []
with open(input_file, 'r', encoding='utf-8') as f:
    items.extend(_squad_item(json.loads(line)) for line in f)

# Process with Mistral-7B-Instruct-v0.3.
# Plain-prompt control run: each masked sentence is sent as-is and only the top-5
# logprobs of the single greedy next token are kept. There is no retry/resume here:
# a failed item is logged, counted and skipped, and re-running overwrites output_file.
n_failed = 0
with requests.Session() as session, open(output_file, 'w', encoding='utf-8') as out_f:
    session.headers.update(headers)
    for idx, item in enumerate(items):
        prompt = item["masked_sentence"]  # already stripped when the items were loaded

        data_payload = {
            "model": "mistralai/Mistral-7B-Instruct-v0.3",
            "prompt": prompt,
            "max_tokens": 1,
            "temperature": 0.0,
            "top_p": 1.0,
            "logprobs": 5
        }

        try:
            # Without a timeout a stalled connection would block this loop indefinitely.
            response = session.post(API_URL, json=data_payload, timeout=60)
            result = response.json()

            # Treat non-200 responses and empty choice lists as failures instead of
            # letting them surface later as an IndexError/KeyError.
            if response.status_code != 200 or "choices" not in result or not result["choices"]:
                print(f"❌ Error for {item['id']}: {result.get('error', 'Unknown error')}")
                n_failed += 1
                continue

            choice = result["choices"][0]
            logprobs_dict = choice["logprobs"]["top_logprobs"][0]

            output_entry = {
                "id": item["id"],
                "masked_sentence": item["masked_sentence"],
                "obj_label": item["obj_label"],
                "sub_label": item["sub_label"],
                "logprobs_top5": logprobs_dict
            }

            out_f.write(json.dumps(output_entry) + '\n')
            print(f"✅ Mistral Success: {item['id']} ({idx+1}/{len(items)})")

        except Exception as e:
            print(f"❌ Mistral Exception for {item['id']}: {e}")
            n_failed += 1

print(f"Done: {len(items) - n_failed}/{len(items)} succeeded, {n_failed} failed.")


✅ Mistral Success: 57273b69dd62a815002e99d8_0 (1/213)
✅ Mistral Success: 56bf41013aeaaa14008c959b_0 (2/213)
✅ Mistral Success: 56d709ef0d65d21400198306_0 (3/213)
✅ Mistral Success: 572757bef1498d1400e8f693_0 (4/213)
✅ Mistral Success: 56d9a637dc89441400fdb698_0 (5/213)
✅ Mistral Success: 56d9a637dc89441400fdb699_0 (6/213)
✅ Mistral Success: 56bf57043aeaaa14008c95da_0 (7/213)
✅ Mistral Success: 56d9bf70dc89441400fdb77c_0 (8/213)
✅ Mistral Success: 56d9bf70dc89441400fdb77d_0 (9/213)
✅ Mistral Success: 56d724ea0d65d214001983c8_0 (10/213)
✅ Mistral Success: 572a06866aef0514001551be_0 (11/213)
✅ Mistral Success: 56d9ccacdc89441400fdb842_0 (12/213)
✅ Mistral Success: 56d7277c0d65d21400198402_0 (13/213)
✅ Mistral Success: 57337ea24776f41900660bd1_0 (14/213)
✅ Mistral Success: 57339902d058e614000b5e72_0 (15/213)
✅ Mistral Success: 57339a554776f41900660e74_0 (16/213)
✅ Mistral Success: 57339ad74776f41900660e86_0 (17/213)
✅ Mistral Success: 5733a6ac4776f41900660f5b_0 (18/213)
✅ Mistral Success: 

In [1]:
import requests, json, time, os

# ====== CONFIG ======
# Together AI completions endpoint and the model queried for next-token logprobs.
API_URL = "https://api.together.xyz/v1/completions"
MODEL = "mistralai/Mistral-7B-Instruct-v0.3"

# Input probe fragments; results are appended to OUTPUT_FILE, one JSON record per item.
INPUT_FILE = "data/ConceptNet/test_filtered.jsonl"
OUTPUT_FILE = "data/ConceptNet/mistral-7B-control.jsonl"

# Retry policy used by request_with_retry: at most MAX_RETRIES attempts per item,
# waiting SLEEP_SECS seconds between attempts.
MAX_RETRIES = 20
SLEEP_SECS = 60

def normalize_fragment(s: str) -> str:
    """Return ``s`` with surrounding whitespace removed and exactly one trailing space.

    An empty or all-whitespace fragment becomes " ".

    NOTE(review): the prompt therefore always ends in a space. With SentencePiece-style
    tokenizers the next word usually carries its own leading-space marker, so a trailing
    space in the prompt can skew next-token probabilities — confirm this is intended.
    """
    # After strip() the text can never end in a space, so the space is always appended.
    return s.strip() + " "

def load_items(input_path: str):
    """Read probe items from a JSONL file (one JSON object per line).

    Each returned dict has the keys:
      - ``id``: the record's ``uuid``, else its ``id``, else "".
      - ``masked_sentence``: first entry of ``masked_sentences``, stripped; "" when
        that list is missing or empty.
      - ``obj_label`` / ``sub_label``: copied through, "" when absent.

    Blank lines are skipped, so an extra trailing newline does not crash ``json.loads``.

    NOTE(review): records with neither ``uuid`` nor ``id`` all get id "", so they
    collide in the resume set built from the output file — confirm inputs always have ids.
    """
    items = []
    with open(input_path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            data = json.loads(line)
            ex_id = data.get("uuid") or data.get("id") or ""
            sentences = data.get("masked_sentences")
            frag = (sentences[0] if sentences else "").strip()
            items.append({
                "id": ex_id,
                "masked_sentence": frag,
                "obj_label": data.get("obj_label", ""),
                "sub_label": data.get("sub_label", ""),
            })
    return items

def load_processed_ids(output_path: str, include_errors: bool = True):
    """Return the set of ids already recorded in an existing output JSONL file (for resume).

    Args:
        output_path: path to the output file; a missing file yields an empty set.
        include_errors: if True (default), records carrying an ``"error"`` key count as
            processed, so a resumed run will not retry those items. Pass False to retry
            them. The run cell opens the output in append mode, so a retried item would
            then have both its old error record and a new record in the file.

    Lines that are blank, not valid JSON (e.g. a line truncated by an interrupted
    run), not JSON objects, without an ``"id"``, or with an unhashable id are skipped.
    """
    seen = set()
    if not os.path.exists(output_path):
        return seen
    with open(output_path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            try:
                rec = json.loads(line)
            except json.JSONDecodeError:
                continue  # e.g. a partially written last line
            if not isinstance(rec, dict) or "id" not in rec:
                continue
            if not include_errors and "error" in rec:
                continue
            rec_id = rec["id"]
            if isinstance(rec_id, (list, dict)):  # unhashable JSON values cannot go in a set
                continue
            seen.add(rec_id)
    return seen

def _first_token_logprobs(result):
    """Extract ``(token, top_logprobs)`` for the first generated position of a completions response.

    ``token`` is ``logprobs["tokens"][0]``, or "" if the tokens list is missing/empty.
    Raises ValueError when the response has no choices or no top_logprobs, so the
    caller treats it like any other transient failure and retries.
    """
    if "choices" not in result or not result["choices"]:
        raise ValueError("response has no 'choices'")
    lp = result["choices"][0].get("logprobs")
    if not lp or "top_logprobs" not in lp or not lp["top_logprobs"]:
        raise ValueError("response has no logprobs")
    token = lp["tokens"][0] if "tokens" in lp and lp["tokens"] else ""
    return token, lp["top_logprobs"][0]


def request_with_retry(session, payload, item_id: str, max_retries: int = None,
                       sleep_secs: float = None, url: str = None):
    """POST one completion request and return the first generated token and its top logprobs.

    Network errors, unparsable bodies, HTTP 408/429/5xx, and 200 responses missing
    ``choices``/``logprobs`` are retried, sleeping ``sleep_secs`` between attempts
    (not after the last one). Any other HTTP 4xx (bad request, auth, not found) cannot
    succeed on retry, so it fails immediately instead of burning every attempt.

    Args:
        session: object with ``post(url, json=..., timeout=...)``, e.g. a
            ``requests.Session`` already carrying the auth headers.
        payload: JSON body for the completions endpoint.
        item_id: used only in log and error messages.
        max_retries: attempt limit; defaults to the module-level ``MAX_RETRIES``.
        sleep_secs: pause between attempts; defaults to ``SLEEP_SECS``.
        url: endpoint; defaults to ``API_URL``.

    Returns:
        ``(token, top_logprobs)`` for the first generated position.

    Raises:
        RuntimeError: on a non-retryable HTTP status, or once all attempts have failed.
            The message includes the last error seen.
    """
    max_retries = MAX_RETRIES if max_retries is None else max_retries
    sleep_secs = SLEEP_SECS if sleep_secs is None else sleep_secs
    url = API_URL if url is None else url

    last_error = "no attempts made"
    for attempt in range(1, max_retries + 1):
        status = None
        try:
            resp = session.post(url, json=payload, timeout=60)
            status = resp.status_code
            if status == 200:
                return _first_token_logprobs(resp.json())
            try:
                detail = resp.json()
            except Exception:
                detail = resp.text
            last_error = f"HTTP {status}: {detail}"
        except Exception as e:  # network errors, bad JSON, missing choices/logprobs
            last_error = f"{type(e).__name__}: {e}"
        print(f"[{item_id}] {last_error} (attempt {attempt}/{max_retries})")

        # Client errors other than timeout (408) and rate limiting (429) are permanent.
        if status is not None and 400 <= status < 500 and status not in (408, 429):
            raise RuntimeError(f"Non-retryable error for item {item_id}: {last_error}")
        if attempt < max_retries:
            time.sleep(sleep_secs)

    raise RuntimeError(f"Failed after {max_retries} attempts for item {item_id}: {last_error}")

# ====== RUN ======
# This cell ran on a fresh kernel (In [1]), so import the key here rather than
# relying on the Squad cell having been executed first.
from config import API_KEY

headers = {"Authorization": f"Bearer {API_KEY}", "Content-Type": "application/json"}
items = load_items(INPUT_FILE)
# Resume: ids already present in OUTPUT_FILE are skipped.
seen_ids = load_processed_ids(OUTPUT_FILE)

print(f"Total items: {len(items)}")
print(f"Already processed (resume): {len(seen_ids)}")
to_do = [it for it in items if it["id"] not in seen_ids]
print(f"Remaining: {len(to_do)}\n")

# dirname is "" for a bare filename, and os.makedirs("") raises.
out_dir = os.path.dirname(OUTPUT_FILE)
if out_dir:
    os.makedirs(out_dir, exist_ok=True)

n_failed = 0
# Both context managers close on exit (including KeyboardInterrupt), releasing the
# pooled HTTP connections and the output file.
with requests.Session() as session, open(OUTPUT_FILE, "a", encoding="utf-8") as out_f:
    session.headers.update(headers)
    for idx, item in enumerate(to_do, start=1):
        # Input fields copied into every output record, whether the request succeeds or not.
        record = {
            "id": item["id"],
            "masked_sentence": item["masked_sentence"],
            "obj_label": item["obj_label"],
            "sub_label": item["sub_label"],
        }
        payload = {
            "model": MODEL,
            "prompt": normalize_fragment(item["masked_sentence"]),
            "max_tokens": 1,
            "temperature": 0.0,
            "top_p": 1.0,
            "logprobs": 5
        }

        try:
            token, top_logprobs = request_with_retry(session, payload, item["id"])
        except Exception as e:
            n_failed += 1
            print(f"⛔ Giving up on {item['id']}: {e}")
            record["error"] = str(e)
        else:
            record["next_token"] = token
            record["logprobs_top5"] = top_logprobs
            if idx % 100 == 0 or idx == 1:
                print(f"✅ {idx}/{len(to_do)}  (id={item['id']})")

        # Flush per record so an interrupted run loses at most the in-flight item.
        out_f.write(json.dumps(record, ensure_ascii=False) + "\n")
        out_f.flush()

print(f"Done. {len(to_do) - n_failed} succeeded, {n_failed} failed.")


Total items: 18796
Already processed (resume): 0
Remaining: 18796

✅ 1/18796  (id=d4f11631dde8a43beda613ec845ff7d1)
✅ 100/18796  (id=1a01fbaec8bfedd12f74b2037cad5839)
✅ 200/18796  (id=cdf948db4d26cb6908e20120d0c3a66a)
✅ 300/18796  (id=664b2373f9397aace70532be733ca2ce)
✅ 400/18796  (id=097b132fafbcd3804dc483cb67f2461b)
✅ 500/18796  (id=e254977e40d014ad5f4c308bf00f97de)
✅ 600/18796  (id=8a51b25fb4576c62698d66da4f990240)
✅ 700/18796  (id=05875587d7d097a40d5770f0046fd8a0)
✅ 800/18796  (id=4da71e195484f72d234b9a48d9067970)
✅ 900/18796  (id=b44e2ea498af303ceb9fa57bc8bffc8d)
✅ 1000/18796  (id=6751983d83e983be444598514a0be04a)
✅ 1100/18796  (id=28661f57da65b9e9d55a7a6a7dceec5d)
✅ 1200/18796  (id=966315527a315ee10f9fcb749fdbbfc6)
✅ 1300/18796  (id=57e136fad0f597e849d4964b8d801de6)
✅ 1400/18796  (id=6cde63a92c80d4b13b2a26d1c38a7eb5)
✅ 1500/18796  (id=80d700abc00ca5fe6df93bc4253e2e9f)
✅ 1600/18796  (id=3579fd73a6286e343b8c89d98f5c4530)
✅ 1700/18796  (id=46632bac59adc15fd041782cec6bcad4)
✅ 1800/18