In [1]:
# 1: Imports & basic configuration 


import os
import re
import json
import math
from glob import glob
from typing import Dict, Any, List, Tuple, Optional
from collections import defaultdict, Counter

DATA_DIR = "/content"

# Output paths 
OUT_CLEAN_JSONL = os.path.join(DATA_DIR, "merged_clean.jsonl")
OUT_PROBLEMS_JSON = os.path.join(DATA_DIR, "problem_indices.json")
OUT_MISSING_TXT = os.path.join(DATA_DIR, "missing_indices.txt")
OUT_FAIL_TXT = os.path.join(DATA_DIR, "fail_indices.txt")  
OUT_MISS_REASON_TXT = os.path.join(DATA_DIR, "miss_reason_indices.txt")
OUT_MISS_ANSWER_TXT = os.path.join(DATA_DIR, "miss_answer_indices.txt")
OUT_MISMATCH_TXT = os.path.join(DATA_DIR, "answer_rating_mismatch_indices.txt")
OUT_ANY_BAD_TXT = os.path.join(DATA_DIR, "any_problem_indices.txt")

# Total expected indices: 0..19999
TOTAL = 20_000
ALL_EXPECTED = set(range(TOTAL))


In [2]:
# 2: Discover the uploaded files


pattern = os.path.join(DATA_DIR, "gemini_flash_lite_reason_answer*.jsonl")
found_files = sorted(glob(pattern))

expected_files = [
    "gemini_flash_lite_reason_answer_500.jsonl",
    "gemini_flash_lite_reason_answer_500_1499.jsonl",
    "gemini_flash_lite_reason_answer_1400_2299.jsonl",
    "gemini_flash_lite_reason_answer_2300_3199.jsonl",
    "gemini_flash_lite_reason_answer_3200_4099.jsonl",
    "gemini_flash_lite_reason_answer_4100_4999.jsonl",
    "gemini_flash_lite_reason_answer_5000_5899.jsonl",
    "gemini_flash_lite_reason_answer_5900_6799.jsonl",
    "gemini_flash_lite_reason_answer_6800_7699.jsonl",
    "gemini_flash_lite_reason_answer_7700_8599.jsonl",
    "gemini_flash_lite_reason_answer_8600_9499.jsonl",
    "gemini_flash_lite_reason_answer_9500_10399.jsonl",
    "gemini_flash_lite_reason_answer_10400_11299.jsonl",
    "gemini_flash_lite_reason_answer_11300_12199.jsonl",
    "gemini_flash_lite_reason_answer_12200_13099.jsonl",
    "gemini_flash_lite_reason_answer_13100_13999.jsonl",
    "gemini_flash_lite_reason_answer_14000_14899.jsonl",
    "gemini_flash_lite_reason_answer_14900_15799.jsonl",
    "gemini_flash_lite_reason_answer_15800_16699.jsonl",
    "gemini_flash_lite_reason_answer_16700_17599.jsonl",
    "gemini_flash_lite_reason_answer_17600_18499.jsonl",
    "gemini_flash_lite_reason_answer_18500_19399.jsonl",
    "gemini_flash_lite_reason_answer_19400_19999.jsonl",
]
expected_files = [os.path.join(DATA_DIR, f) for f in expected_files]

files = [f for f in expected_files if os.path.exists(f)]
if not files:
    # Fallback 
    files = found_files

print("Found files (count={}):".format(len(files)))
for f in files:
    print("-", os.path.basename(f))


Found files (count=23):
- gemini_flash_lite_reason_answer_500.jsonl
- gemini_flash_lite_reason_answer_500_1499.jsonl
- gemini_flash_lite_reason_answer_1400_2299.jsonl
- gemini_flash_lite_reason_answer_2300_3199.jsonl
- gemini_flash_lite_reason_answer_3200_4099.jsonl
- gemini_flash_lite_reason_answer_4100_4999.jsonl
- gemini_flash_lite_reason_answer_5000_5899.jsonl
- gemini_flash_lite_reason_answer_5900_6799.jsonl
- gemini_flash_lite_reason_answer_6800_7699.jsonl
- gemini_flash_lite_reason_answer_7700_8599.jsonl
- gemini_flash_lite_reason_answer_8600_9499.jsonl
- gemini_flash_lite_reason_answer_9500_10399.jsonl
- gemini_flash_lite_reason_answer_10400_11299.jsonl
- gemini_flash_lite_reason_answer_11300_12199.jsonl
- gemini_flash_lite_reason_answer_12200_13099.jsonl
- gemini_flash_lite_reason_answer_13100_13999.jsonl
- gemini_flash_lite_reason_answer_14000_14899.jsonl
- gemini_flash_lite_reason_answer_14900_15799.jsonl
- gemini_flash_lite_reason_answer_15800_16699.jsonl
- gemini_flash_lit

In [3]:
# 3: Utilities for parsing, checks, and ranges

def safe_float(x) -> Optional[float]:
    """Try to parse x as float; return None on failure."""
    try:
        return float(x)
    except Exception:
        return None

def almost_equal(a: float, b: float, tol: float = 1e-7) -> bool:
    """Robust float comparison."""
    return math.isclose(a, b, rel_tol=tol, abs_tol=tol)

range_pat = re.compile(r'_(\d+)(?:_(\d+))?\.jsonl$', re.IGNORECASE)

def expected_range_from_filename(path: str) -> Optional[Tuple[int, int]]:
    """
    Infer expected [start, end] inclusive from filename.
    Rules:
      - name_*_A_B.jsonl  -> [A, B] inclusive
      - name_*_N.jsonl    -> [0, N-1]  (first N items)
    Returns None if it cannot infer.
    """
    name = os.path.basename(path)
    m = range_pat.search(name)
    if not m:
        return None
    a = int(m.group(1))
    b = m.group(2)
    if b is None:
        # Single number -> first N indices [0..N-1]
        return (0, a - 1)
    return (a, int(b))

def read_jsonl(path: str) -> List[Dict[str, Any]]:
    """Read a JSONL file and return list of dicts."""
    out = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
                out.append(obj)
            except Exception as e:
                pass
    return out


In [4]:
# 4: Load all records, map by index, and track per-file presence

records_by_index: Dict[int, List[Tuple[str, Dict[str, Any]]]] = defaultdict(list)
indices_in_file: Dict[str, set] = {}
expected_by_file: Dict[str, Optional[Tuple[int, int]]] = {}

total_records_loaded = 0

for path in files:
    data = read_jsonl(path)
    total_records_loaded += len(data)
    idx_set = set()
    for rec in data:
        if "index" not in rec:
            continue
        idx = int(rec["index"])
        records_by_index[idx].append((path, rec))
        idx_set.add(idx)
    indices_in_file[path] = idx_set
    expected_by_file[path] = expected_range_from_filename(path)

print(f"Total JSONL records loaded: {total_records_loaded}")
print(f"Unique indices seen: {len(records_by_index)}")
print("Sample index keys (first 10):", sorted(records_by_index.keys())[:10])


Total JSONL records loaded: 19988
Unique indices seen: 19936
Sample index keys (first 10): [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]


In [5]:
# 5: Per-file fail


fail_by_file = {}
for path in files:
    exp = expected_by_file[path]
    if exp is None:
        continue
    start, end = exp
    expected_set = set(range(start, end + 1))
    present_set = indices_in_file.get(path, set())
    missing_here = sorted(expected_set - present_set)
    fail_by_file[path] = missing_here

per_file_fail_union = sorted(set().union(*fail_by_file.values()) if fail_by_file else [])

print("Files with inferred expected ranges and count of missing indices:")
for path in files:
    exp = expected_by_file[path]
    miss = fail_by_file.get(path, [])
    print(f"- {os.path.basename(path):35s} expected={exp} | missing_in_this_file={len(miss)}")

print("\nUnion of per-file 'fail' indices (count):", len(per_file_fail_union))


Files with inferred expected ranges and count of missing indices:
- gemini_flash_lite_reason_answer_500.jsonl expected=(0, 499) | missing_in_this_file=1
- gemini_flash_lite_reason_answer_500_1499.jsonl expected=(500, 1499) | missing_in_this_file=48
- gemini_flash_lite_reason_answer_1400_2299.jsonl expected=(1400, 2299) | missing_in_this_file=0
- gemini_flash_lite_reason_answer_2300_3199.jsonl expected=(2300, 3199) | missing_in_this_file=0
- gemini_flash_lite_reason_answer_3200_4099.jsonl expected=(3200, 4099) | missing_in_this_file=0
- gemini_flash_lite_reason_answer_4100_4999.jsonl expected=(4100, 4999) | missing_in_this_file=0
- gemini_flash_lite_reason_answer_5000_5899.jsonl expected=(5000, 5899) | missing_in_this_file=0
- gemini_flash_lite_reason_answer_5900_6799.jsonl expected=(5900, 6799) | missing_in_this_file=0
- gemini_flash_lite_reason_answer_6800_7699.jsonl expected=(6800, 7699) | missing_in_this_file=59
- gemini_flash_lite_reason_answer_7700_8599.jsonl expected=(7700, 8599)

In [6]:
# 6: Global missing indices 


all_present_indices = set(records_by_index.keys())
global_missing_indices = sorted(ALL_EXPECTED - all_present_indices)

print("Global missing indices (fail overall) count:", len(global_missing_indices))


Global missing indices (fail overall) count: 64


In [7]:
# 7: Validate reason/answer and produce problem buckets per index

miss_reason_indices = set()
miss_answer_indices = set()
mismatch_indices = set()
rating_conflict_indices = set()  
clean_indices = set()

for idx in ALL_EXPECTED:
    recs = records_by_index.get(idx, [])
    if not recs:
        # Global fail handled separately
        continue

    # Gather fields across all records for this index
    reasons = []
    answers_num = []
    ratings_num = []

    for path, rec in recs:
        reason = rec.get("reason", None)
        answer = rec.get("answer", None)
        rating = rec.get("rating", None)
        reasons.append((path, reason))
        answers_num.append((path, safe_float(answer)))
        ratings_num.append((path, safe_float(rating)))

    # Check rating consistency across files
    valid_ratings = [r for p, r in ratings_num if r is not None]
    if valid_ratings and (max(valid_ratings) - min(valid_ratings) > 1e-9):
        rating_conflict_indices.add(idx)
    # Use the first non-None rating as the canonical rating for comparisons
    rating_ref = valid_ratings[0] if valid_ratings else None

    # Has any non-empty reason?
    has_valid_reason = any((r is not None) and (str(r).strip() != "") for _, r in reasons)

    # Does any answer match the rating?
    answer_equals_rating = False
    has_any_answer_number = False
    if rating_ref is not None:
        for _, a in answers_num:
            if a is not None:
                has_any_answer_number = True
                if almost_equal(a, rating_ref):
                    answer_equals_rating = True
                    break

    # Categorization
    if has_valid_reason and answer_equals_rating:
        clean_indices.add(idx)
    else:
        # Not clean; place into specific buckets
        if not has_valid_reason:
            miss_reason_indices.add(idx)
        if not has_any_answer_number:
            miss_answer_indices.add(idx)
        else:
            # we had answers but none matched
            if not answer_equals_rating:
                mismatch_indices.add(idx)

print("Clean indices count:", len(clean_indices))
print("miss_reason indices count:", len(miss_reason_indices))
print("miss_answer indices count:", len(miss_answer_indices))
print("answer!=rating mismatch indices count:", len(mismatch_indices))
print("rating_conflict indices count:", len(rating_conflict_indices))


Clean indices count: 19582
miss_reason indices count: 344
miss_answer indices count: 353
answer!=rating mismatch indices count: 1
rating_conflict indices count: 0


In [8]:
# 8: Build merged_clean.jsonl 

def record_is_clean(rec: Dict[str, Any]) -> bool:
    reason = rec.get("reason", None)
    if reason is None or str(reason).strip() == "":
        return False
    rating = safe_float(rec.get("rating", None))
    answer = safe_float(rec.get("answer", None))
    return (rating is not None) and (answer is not None) and almost_equal(rating, answer)

def pick_best_clean_record(recs: List[Tuple[str, Dict[str, Any]]]) -> Optional[Dict[str, Any]]:
    candidates = []
    for path, rec in recs:
        if record_is_clean(rec):
            reason_len = len(str(rec.get("reason", "")).strip())
            candidates.append((reason_len, path, rec))
    if not candidates:
        return None
    # sort by 
    candidates.sort(key=lambda x: (-x[0], x[1]))
    return candidates[0][2]

# Write clean file
saved = 0
with open(OUT_CLEAN_JSONL, "w", encoding="utf-8") as fw:
    for idx in sorted(clean_indices):
        recs = records_by_index[idx]
        best = pick_best_clean_record(recs)
        if best is None:
            continue
        fw.write(json.dumps(best, ensure_ascii=False) + "\n")
        saved += 1

print(f"Saved {saved} clean records to: {OUT_CLEAN_JSONL}")


Saved 19582 clean records to: /content/merged_clean.jsonl


In [9]:
# 9: Save problem indices lists

fail_indices = set(global_missing_indices)

any_problem_indices = sorted((ALL_EXPECTED - clean_indices) | fail_indices)

def write_list(path: str, items: List[int]):
    with open(path, "w", encoding="utf-8") as f:
        for x in items:
            f.write(str(x) + "\n")

write_list(OUT_MISSING_TXT, sorted(global_missing_indices))
write_list(OUT_FAIL_TXT, sorted(fail_indices))  # same as missing
write_list(OUT_MISS_REASON_TXT, sorted(miss_reason_indices))
write_list(OUT_MISS_ANSWER_TXT, sorted(miss_answer_indices))
write_list(OUT_MISMATCH_TXT, sorted(mismatch_indices))
write_list(OUT_ANY_BAD_TXT, any_problem_indices)

summary = {
    "counts": {
        "clean": len(clean_indices),
        "missing_global_fail": len(global_missing_indices),
        "miss_reason": len(miss_reason_indices),
        "miss_answer": len(miss_answer_indices),
        "answer_rating_mismatch": len(mismatch_indices),
        "rating_conflict": len(rating_conflict_indices),
        "any_problem": len(any_problem_indices),
    },
    "outputs": {
        "clean_jsonl": OUT_CLEAN_JSONL,
        "missing_indices_txt": OUT_MISSING_TXT,
        "fail_indices_txt": OUT_FAIL_TXT,
        "miss_reason_indices_txt": OUT_MISS_REASON_TXT,
        "miss_answer_indices_txt": OUT_MISS_ANSWER_TXT,
        "mismatch_indices_txt": OUT_MISMATCH_TXT,
        "any_problem_indices_txt": OUT_ANY_BAD_TXT,
    },
    "notes": {
        "fail_definition": "Indices absent from ALL files (global missing).",
        "clean_definition": "At least one record has non-empty reason AND answer==rating.",
        "overlap_handling": "Duplicates are deduped by choosing one best clean record per index.",
    }
}

with open(OUT_PROBLEMS_JSON, "w", encoding="utf-8") as f:
    json.dump(summary, f, ensure_ascii=False, indent=2)

print("Summary JSON saved to:", OUT_PROBLEMS_JSON)
for k, v in summary["counts"].items():
    print(f"- {k}: {v}")


Summary JSON saved to: /content/problem_indices.json
- clean: 19582
- missing_global_fail: 64
- miss_reason: 344
- miss_answer: 353
- answer_rating_mismatch: 1
- rating_conflict: 0
- any_problem: 418


In [10]:
# 10: Quick peeks

# Show a few example lines from the clean file
try:
    print("First 3 clean records:")
    with open(OUT_CLEAN_JSONL, "r", encoding="utf-8") as f:
        for _ in range(3):
            print(json.loads(next(f).strip()))
except Exception as e:
    print("Preview clean file error:", e)

# Show a few indices per problem-bucket
def peek(path, n=10):
    try:
        with open(path, "r", encoding="utf-8") as f:
            arr = [line.strip() for _, line in zip(range(n), f)]
        print(os.path.basename(path), "->", arr[:n])
    except Exception as e:
        print("peek error:", path, e)

peek(OUT_MISSING_TXT)
peek(OUT_MISS_REASON_TXT)
peek(OUT_MISS_ANSWER_TXT)
peek(OUT_MISMATCH_TXT)
peek(OUT_ANY_BAD_TXT)


First 3 clean records:
{'index': 0, 'user': {'UserID': 911, 'Age': 37, 'Gender': 'Female', 'Occupation': 'writer'}, 'item': {'MovieID': 193, 'Title': 'Right Stuff, The (1983)', 'ReleaseDate': '01-Jan-1983', 'Genres': 'Drama'}, 'rating': 4.0, 'reason': 'The user (UserID 911) is a 37-year-old female writer, and the target item (MovieID 193) is a Drama film. Several meta-paths provide insights into potential user preferences and item characteristics.\n\nOne path connects through UserID 716, a 36-year-old female administrator, who rated "Streetcar Named Desire, A (1951)" (Drama) a 5. This movie shares the Drama genre with the target item.\nAnother path shows UserID 498, a 26-year-old male writer, who interacted with "Silence of the Lambs, The (1991)" (Drama, Thriller) with a rating of 4, indicating a preference for Drama films among users with the same occupation.\nA path involving UserID 507 ("Streetcar Named Desire, A (1951)", Drama) rated a 4 by another user, followed by similarity to U

In [11]:
# 11: Final counts 
print("Final counts (as requested):")
print(" - Clean:", len(clean_indices))
print(" - Missing (global fail):", len(global_missing_indices))
print(" - Miss Reason:", len(miss_reason_indices))
print(" - Miss Answer:", len(miss_answer_indices))
print(" - Answer≠Rating Mismatch:", len(mismatch_indices))
print(" - Rating Conflict (duplicates disagree):", len(rating_conflict_indices))
print(" - Any Problem (union of all non-clean):", len(set((ALL_EXPECTED - clean_indices) | set(global_missing_indices))))


Final counts (as requested):
 - Clean: 19582
 - Missing (global fail): 64
 - Miss Reason: 344
 - Miss Answer: 353
 - Answer≠Rating Mismatch: 1
 - Rating Conflict (duplicates disagree): 0
 - Any Problem (union of all non-clean): 418


In [12]:
!pip -q install -U "openai>=1.54.0" datasets tqdm

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/812.0 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.6/812.0 kB[0m [31m5.5 MB/s[0m eta [36m0:00:01[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m809.0/812.0 kB[0m [31m12.0 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m812.0/812.0 kB[0m [31m10.0 MB/s[0m eta [36m0:00:00[0m
[?25h

In [13]:
# 12: Configure OpenAI-compatible Gemini API 

import os
from openai import OpenAI

os.environ["GEMINI_API_KEY"] = ""

BASE_URL = "https://generativelanguage.googleapis.com/v1beta/openai/"
MODEL_ID = "gemini-2.5-flash-lite"   
REASONING_EFFORT = "low"             

client = OpenAI(
    api_key=os.environ["GEMINI_API_KEY"],
    base_url=BASE_URL
)

print("Client ready with OpenAI-compatible Gemini endpoint (NEW KEY).")


Client ready with OpenAI-compatible Gemini endpoint (NEW KEY).


In [14]:
# 13: Load the dataset and detect schema 
from datasets import load_dataset

ds = load_dataset("mohammad-shirkhani/social_movielens_custom", split="train")
print(ds)
print("Rows:", len(ds))
print("Columns:", ds.column_names)

ex0 = ds[0]
print("\nFirst row keys:", list(ex0.keys()))
for k, v in ex0.items():
    if isinstance(v, list) and len(v) > 3:
        print(f"- {k}: list(len={len(v)}), first 2:\n  {v[:2]}")
    else:
        print(f"- {k}: {v}")

from typing import Dict, Any

def detect_fields(example: Dict[str, Any]):
    user_keys = ["user", "User", "user_dict", "UserDict"]
    item_keys = ["item", "Item", "item_dict", "ItemDict"]
    rating_keys = ["answer", "rating", "label", "score"]
    paths_keys = ["paths", "meta_paths", "metapaths", "path_list", "path", "metaPaths"]

    def find_key(candidates, predicate=None):
        for k in candidates:
            if k in example and (predicate(example[k]) if predicate else True):
                return k
        for k, v in example.items():
            if predicate and predicate(v):
                return k
        return None

    user_key = find_key(user_keys, predicate=lambda v: isinstance(v, dict))
    item_key = find_key(item_keys, predicate=lambda v: isinstance(v, dict))
    rating_key = find_key(rating_keys, predicate=lambda v: isinstance(v, (float, int)))
    paths_key = find_key(paths_keys, predicate=lambda v: isinstance(v, list))

    if user_key is None:
        user_key = next((k for k, v in example.items() if isinstance(v, dict)), None)
    if item_key is None:
        cand = [k for k, v in example.items() if isinstance(v, dict) and k != user_key]
        item_key = cand[0] if cand else None
    if rating_key is None:
        cand = [k for k, v in example.items() if isinstance(v, (float, int))]
        rating_key = cand[0] if cand else None
    if paths_key is None:
        cand = [k for k, v in example.items() if isinstance(v, list)]
        paths_key = cand[0] if cand else None

    return user_key, item_key, rating_key, paths_key

user_key, item_key, rating_key, paths_key = detect_fields(ex0)
print("Detected keys ->",
      "user:", user_key,
      "| item:", item_key,
      "| rating:", rating_key,
      "| paths:", paths_key)

# Quick examples
user0 = ex0[user_key]
item0 = ex0[item_key]
rating0 = ex0[rating_key]
paths0 = ex0[paths_key]
print("\nUser example:", user0)
print("Item example:", item0)
print("Rating example:", rating0)
print("Meta-paths:", len(paths0))


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/691 [00:00<?, ?B/s]

data/train-00000-of-00001.parquet:   0%|          | 0.00/29.2M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/20000 [00:00<?, ? examples/s]

Dataset({
    features: ['user', 'item', 'answer', 'paths'],
    num_rows: 20000
})
Rows: 20000
Columns: ['user', 'item', 'answer', 'paths']

First row keys: ['user', 'item', 'answer', 'paths']
- user: {'UserID': 911, 'Age': 37, 'Gender': 'Female', 'Occupation': 'writer'}
- item: {'MovieID': 193, 'Title': 'Right Stuff, The (1983)', 'ReleaseDate': '01-Jan-1983', 'Genres': 'Drama'}
- answer: 4.0
- paths: list(len=20), first 2:
  ['user_question -> user_item_2 (rating=2) -> Item{MovieID 98, Title "Silence of the Lambs, The (1991)", Release Date 01-Jan-1991, Genres Drama, Thriller} -> item_user_4 (rating=4) -> User{UserID 498, Age 26, Gender Male, Occupation writer} -> usersim -> User{UserID 58, Age 27, Gender Male, Occupation programmer} -> user_item_3 (rating=3) -> item_question', 'user_question -> usersim -> User{UserID 716, Age 36, Gender Female, Occupation administrator} -> user_item_5 (rating=5) -> Item{MovieID 517, Title "Manhattan (1979)", Release Date 01-Jan-1979, Genres Comedy, D

In [15]:
# 14: Prompt builder 
from typing import List

def format_kv_block(title: str, d: Dict[str, Any]) -> str:
    lines = [f"{title}:"]
    for k, v in d.items():
        lines.append(f"- {k}: {v}")
    return "\n".join(lines)

def format_meta_paths(paths: List[str]) -> str:
    lines = ["Meta-path evidence (each path from this user to the target item):"]
    for p in paths:
        p = p.strip()
        if not p.startswith("- "):
            p = "- " + p
        lines.append(p)
    return "\n".join(lines)

def build_prompt_evidence_then_answer(
    user: Dict[str, Any],
    item: Dict[str, Any],
    rating: float,
    meta_paths: List[str],
) -> str:
    user_block = format_kv_block("User", user)
    item_block = format_kv_block("Item", item)
    observed_rating_str = str(rating)
    meta_block = format_meta_paths(meta_paths)

    instruction = (
        "Task: Extract evidence from the provided data and then conclude the numerical rating.\n\n"
        "You are given a heterogeneous bipartite graph setting (users and items). Edges include:\n"
        "- user→item rating interactions (e.g., user_item_k with an explicit rating),\n"
        "- usersim (user-user similarity), and\n"
        "- itemsim (item-item similarity).\n\n"
        "What to do:\n"
        "1) Analyze the user's likely preferences and the item's traits by leveraging ONLY:\n"
        "   - the user attributes,\n"
        "   - the item attributes, and\n"
        "   - the provided meta-paths (treat each path as a weak but interpretable signal; combine corroborating signals).\n"
        "   - the provided meta-paths (treat each path as a weak but interpretable signal; combine corroborating signals).\n"
        "2) Then produce two XML blocks ONLY (no extra text):\n"
        "   a) <reason>...</reason> — Provide a clear, evidence-first explanation that states\n"
        "      what information you obtain from the user/item/meta-paths and how those signals combine.\n"
        "      Avoid wording like “because the rating is X”. Instead, present evidence → inference, and end with a neutral\n"
        "      sentence such as: “Therefore, the rating equals {R}.”\n"
        "   b) <answer>{R}</answer> — Put the observed rating number {R} exactly as given below.\n\n"
        "STRICT FORMAT RULES:\n"
        "- Output ONLY these two blocks in this order, nothing else:\n"
        "<reason>\n"
        "...your evidence-first explanation here...\n"
        "</reason>\n"
        "<answer>{R}</answer>\n"
        "- Do not invent attributes, paths, or ratings not present in the input. If something is unknown, treat it as unknown.\n"
    )

    observed_block = f"Observed rating (user → item): {observed_rating_str}"
    prompt = (
        instruction
        + "\n=== INPUT ===\n"
        + user_block + "\n\n"
        + item_block + "\n\n"
        + observed_block + "\n\n"
        + meta_block + "\n"
        + "=== END INPUT ===\n"
    )
    return prompt.replace("{R}", observed_rating_str)

# Parser for <reason> and <answer>
import re

def parse_reason_answer(text: str):
    reason = None
    answer = None
    m_reason = re.search(r"<reason>(.*?)</reason>", text, flags=re.DOTALL | re.IGNORECASE)
    if m_reason:
        reason = m_reason.group(1).strip()
    m_answer = re.search(r"<answer>\s*([0-9]+(?:\.[0-9]+)?)\s*</answer>", text, flags=re.DOTALL | re.IGNORECASE)
    if m_answer:
        answer = m_answer.group(1).strip()
    return reason, answer

# Quick demonstration prompt
prompt0 = build_prompt_evidence_then_answer(user0, item0, rating0, paths0)
print("Prompt preview (first 800 chars):\n", prompt0[:800])


Prompt preview (first 800 chars):
 Task: Extract evidence from the provided data and then conclude the numerical rating.

You are given a heterogeneous bipartite graph setting (users and items). Edges include:
- user→item rating interactions (e.g., user_item_k with an explicit rating),
- usersim (user-user similarity), and
- itemsim (item-item similarity).

What to do:
1) Analyze the user's likely preferences and the item's traits by leveraging ONLY:
   - the user attributes,
   - the item attributes, and
   - the provided meta-paths (treat each path as a weak but interpretable signal; combine corroborating signals).
   - the provided meta-paths (treat each path as a weak but interpretable signal; combine corroborating signals).
2) Then produce two XML blocks ONLY (no extra text):
   a) <reason>...</reason> — Provide a clea


In [16]:
# 15: Rate limiting 
import time
import math
from collections import deque

RPM_LIMIT = 15           
TPM_LIMIT = 250_000      
OUT_TOKENS_BUDGET = 2048 

class RateLimiter:
    def __init__(self, rpm: int, tpm: int):
        self.rpm = rpm
        self.tpm = tpm
        self.req_times = deque()
        self.token_times = deque()  # (timestamp, tokens)

    @staticmethod
    def estimate_tokens(text: str) -> int:
        return max(1, math.ceil(len(text) / 4))

    def wait(self, prompt_text: str, out_tokens:int = OUT_TOKENS_BUDGET):
        now = time.time()
        while self.req_times and now - self.req_times[0] > 60:
            self.req_times.popleft()
        while self.token_times and now - self.token_times[0][0] > 60:
            self.token_times.popleft()

        in_tokens = self.estimate_tokens(prompt_text)

        while len(self.req_times) >= self.rpm:
            sleep_s = 60 - (now - self.req_times[0])
            time.sleep(max(0.01, sleep_s))
            now = time.time()
            while self.req_times and now - self.req_times[0] > 60:
                self.req_times.popleft()

        used_tpm = sum(t for (_, t) in self.token_times)
        while used_tpm + in_tokens + out_tokens > self.tpm:
            sleep_s = 60 - (now - self.token_times[0][0])
            time.sleep(max(0.01, sleep_s))
            now = time.time()
            while self.token_times and now - self.token_times[0][0] > 60:
                self.token_times.popleft()
            used_tpm = sum(t for (_, t) in self.token_times)

        self.req_times.append(time.time())
        self.token_times.append((time.time(), in_tokens + out_tokens))

limiter = RateLimiter(RPM_LIMIT, TPM_LIMIT)

def call_gemini(prompt: str, max_tokens: int = 2048):
    resp = client.chat.completions.create(
        model=MODEL_ID,
        reasoning_effort=REASONING_EFFORT,  # "low"
        messages=[{"role": "user", "content": prompt}],
        max_tokens=max_tokens,
        n=1,
    )
    return resp.choices[0].message.content

def run_one_example(example, max_tokens:int=2048):
    user = example[user_key]
    item = example[item_key]
    rating = example[rating_key]
    meta = example[paths_key]

    prompt = build_prompt_evidence_then_answer(user, item, rating, meta)

    limiter.wait(prompt, out_tokens=max_tokens)

    backoff = 1.0
    for attempt in range(5):
        try:
            text = call_gemini(prompt, max_tokens=max_tokens)
            reason, answer = parse_reason_answer(text)
            return text, reason, answer
        except Exception as e:
            if attempt == 4:
                raise
            time.sleep(backoff)
            backoff *= 2.0


In [17]:
# 16: Health checks for reason/answer
def safe_float(x):
    try:
        return float(x)
    except Exception:
        return None

def is_healthy(reason, answer, rating) -> bool:
    """Healthy = non-empty reason AND numeric answer == rating (within tolerance)."""
    if reason is None or str(reason).strip() == "":
        return False
    a = safe_float(answer)
    r = safe_float(rating)
    if a is None or r is None:
        return False
    return math.isclose(a, r, rel_tol=1e-7, abs_tol=1e-7)

def problem_category(reason, answer, rating) -> str:
    """Classify the failure mode for reporting."""
    if reason is None or str(reason).strip() == "":
        if answer is None or str(answer).strip() == "":
            return "miss_reason_and_answer"
        return "miss_reason"
    a = safe_float(answer)
    r = safe_float(rating)
    if a is None:
        return "miss_answer"
    if not math.isclose(a, r, rel_tol=1e-7, abs_tol=1e-7):
        return "mismatch_answer_rating"
    return "unknown"  


In [18]:
# 17: Read the any_problem indices 
import os

PROBLEM_INDICES_PATH = "/content/any_problem_indices.txt" 
assert os.path.exists(PROBLEM_INDICES_PATH), "any_problem_indices.txt not found at /content/"

problem_indices = []
with open(PROBLEM_INDICES_PATH, "r", encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        if line:
            try:
                problem_indices.append(int(line))
            except:
                pass

problem_indices = sorted(set(problem_indices))
print("Problem indices loaded:", len(problem_indices))
print("First 20 indices:", problem_indices[:20])


Problem indices loaded: 418
First 20 indices: [57, 143, 172, 228, 361, 408, 517, 520, 601, 756, 895, 943, 983, 999, 1078, 1103, 1201, 1305, 1355, 1388]


In [19]:
# 18: Re-run problem indices with auto-retry until healthy

from tqdm.auto import tqdm
import json

MAX_ATTEMPTS_PER_INDEX = 8   # try up to 8 times per index
OUT_PATH = "/content/gemini_flash_lite_reason_answer_retry_any_problem.jsonl"

ok = 0
still_bad = []
stats = {
    "miss_reason": 0,
    "miss_answer": 0,
    "mismatch_answer_rating": 0,
    "miss_reason_and_answer": 0,
    "unknown": 0,
    "exceptions": 0,
    "total_attempts": 0,
}

with open(OUT_PATH, "w", encoding="utf-8") as fw:
    pbar = tqdm(problem_indices, desc="Fixing problem indices with retries")
    for idx in pbar:
        ex = ds[idx]
        rating = float(ex[rating_key])

        success = False
        last_cat = None

        for attempt in range(1, MAX_ATTEMPTS_PER_INDEX + 1):
            try:
                content, reason, answer = run_one_example(ex, max_tokens=2048)
                stats["total_attempts"] += 1

                if is_healthy(reason, answer, rating):
                    # Write a clean record and stop retrying this index
                    rec = {
                        "index": idx,
                        "user": ex[user_key],
                        "item": ex[item_key],
                        "rating": rating,
                        "reason": reason,
                        "answer": answer,
                    }
                    fw.write(json.dumps(rec, ensure_ascii=False) + "\n")
                    ok += 1
                    success = True
                    break
                else:
                    last_cat = problem_category(reason, answer, rating)
                    stats[last_cat] = stats.get(last_cat, 0) + 1

            except Exception as e:
                stats["exceptions"] += 1
                last_cat = "exceptions"
                # small backoff to be gentle; keep RateLimiter as-is
                time.sleep(0.5)

        if not success:
            still_bad.append((idx, last_cat))

        pbar.set_postfix({
            "clean_ok": ok,
            "still_bad": len(still_bad),
            "last_issue": last_cat or "n/a"
        })

print("\nDone retrying.")
print(f"Clean fixed: {ok} | Still bad: {len(still_bad)}")
print(f"Saved clean retries to: {OUT_PATH}")


Fixing problem indices with retries:   0%|          | 0/418 [00:00<?, ?it/s]


Done retrying.
Clean fixed: 418 | Still bad: 0
Saved clean retries to: /content/gemini_flash_lite_reason_answer_retry_any_problem.jsonl


In [20]:
# 19: Persist retry results and show summary

STILL_BAD_TXT = "/content/still_bad_indices_after_retry.txt"
RETRY_SUMMARY_JSON = "/content/retry_summary.json"

with open(STILL_BAD_TXT, "w", encoding="utf-8") as f:
    for idx, cat in still_bad:
        f.write(f"{idx}\t{cat}\n")

summary = {
    "attempted_indices": len(problem_indices),
    "fixed_clean": ok,
    "still_bad_count": len(still_bad),
    "still_bad_indices_with_category_path": STILL_BAD_TXT,
    "out_retry_file": OUT_PATH,
    "stats": stats,
}
with open(RETRY_SUMMARY_JSON, "w", encoding="utf-8") as f:
    json.dump(summary, f, ensure_ascii=False, indent=2)

print("Retry summary:")
for k, v in summary.items():
    if k != "stats":
        print(f"- {k}: {v}")
print("\nDetailed stats:", stats)

try:
    with open(OUT_PATH, "r", encoding="utf-8") as f:
        print("\nSample outputs (first 3):")
        for _ in range(3):
            print(json.loads(next(f).strip()))
except Exception as e:
    print("Preview error:", e)


Retry summary:
- attempted_indices: 418
- fixed_clean: 418
- still_bad_count: 0
- still_bad_indices_with_category_path: /content/still_bad_indices_after_retry.txt
- out_retry_file: /content/gemini_flash_lite_reason_answer_retry_any_problem.jsonl

Detailed stats: {'miss_reason': 0, 'miss_answer': 0, 'mismatch_answer_rating': 0, 'miss_reason_and_answer': 21, 'unknown': 0, 'exceptions': 0, 'total_attempts': 439}

Sample outputs (first 3):
{'index': 57, 'user': {'UserID': 551, 'Age': 25, 'Gender': 'Male', 'Occupation': 'programmer'}, 'item': {'MovieID': 1303, 'Title': 'Getaway, The (1994)', 'ReleaseDate': '01-Jan-1994', 'Genres': 'Action'}, 'rating': 1.0, 'reason': 'The user (UserID 551) is a 25-year-old male programmer. The target item (MovieID 1303) is an Action movie released in 1994.\nSeveral meta-paths provide weak signals. Path 1 indicates a user-item interaction with a rating of 1.0 involving a programmer (UserID 682) and an Action movie ("Judgment Night (1993)"). Path 3 shows a use

In [21]:
# 20: Merge 

import json

PREV_CLEAN = "/content/merged_clean.jsonl"
FINAL_MERGED = "/content/final_merged_clean_20k.jsonl"

if os.path.exists(PREV_CLEAN):
    by_index = {}

    # Load previous clean
    with open(PREV_CLEAN, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            obj = json.loads(line)
            by_index[int(obj["index"])] = obj

    # Overlay retry-clean
    with open(OUT_PATH, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            obj = json.loads(line)
            by_index[int(obj["index"])] = obj

    # Write final merged file
    saved = 0
    with open(FINAL_MERGED, "w", encoding="utf-8") as fw:
        for idx in sorted(by_index.keys()):
            fw.write(json.dumps(by_index[idx], ensure_ascii=False) + "\n")
            saved += 1

    print(f"Final merged saved to: {FINAL_MERGED} | count={saved}")
else:
    print("Previous clean file not found; skipping optional merge.")


Final merged saved to: /content/final_merged_clean_20k.jsonl | count=20000


In [22]:
# 20:Mount Google Drive and copy the final merged file there.

from google.colab import drive
drive.mount('/content/drive') 

import os, shutil, hashlib

src = "/content/final_merged_clean_20k.jsonl"
dst_dir = "/content/drive/MyDrive"   
dst = os.path.join(dst_dir, os.path.basename(src))

assert os.path.exists(src), f"Source file not found: {src}"

try:
    with open(src, "r", encoding="utf-8") as f:
        line_count = sum(1 for _ in f)
    print(f"[INFO] Source line count: {line_count}")
except Exception as e:
    line_count = None
    print(f"[WARN] Could not count lines: {e}")

def sha256sum(path, chunk_size=1024*1024):
    h = hashlib.sha256()
    with open(path, "rb") as f:
        while True:
            data = f.read(chunk_size)
            if not data:
                break
            h.update(data)
    return h.hexdigest()

src_sha = sha256sum(src)
print(f"[INFO] Source SHA256: {src_sha}")

os.makedirs(dst_dir, exist_ok=True)
shutil.copy2(src, dst)

dst_sha = sha256sum(dst)
print(f"[INFO] Copied to: {dst}")
print(f"[INFO] Destination SHA256: {dst_sha}")
if src_sha == dst_sha:
    print("[OK] Integrity check passed (SHA256 match).")
else:
    print("[WARN] Integrity mismatch! Please re-copy.")

EXPECTED = 20000
if line_count is not None:
    if line_count == EXPECTED:
        print(f"[OK] Expected record count present: {EXPECTED}")
    else:
        print(f"[WARN] Expected {EXPECTED} lines, found {line_count}.")


Mounted at /content/drive
[INFO] Source line count: 20000
[INFO] Source SHA256: 850e5aee8f7836b96051c8145bad06c1a626249a5906c4773d56e60b316745aa
[INFO] Copied to: /content/drive/MyDrive/final_merged_clean_20k.jsonl
[INFO] Destination SHA256: 850e5aee8f7836b96051c8145bad06c1a626249a5906c4773d56e60b316745aa
[OK] Integrity check passed (SHA256 match).
[OK] Expected record count present: 20000


In [23]:
# 21: Install and imports
!pip -q install -U datasets huggingface_hub

import os, json
from typing import Dict, Any, List, Optional
from datasets import load_dataset, DatasetDict
from huggingface_hub import login


In [24]:
# 22: Log in to Hugging Face Hub 

login() 


VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [25]:
# 23: Load the original dataset (train)

ds = load_dataset("mohammad-shirkhani/social_movielens_custom", split="train")
print(ds)
print("Rows:", len(ds))
print("Columns:", ds.column_names)

ex0 = ds[0]
print("\nExample[0] keys:", list(ex0.keys()))


Dataset({
    features: ['user', 'item', 'answer', 'paths'],
    num_rows: 20000
})
Rows: 20000
Columns: ['user', 'item', 'answer', 'paths']

Example[0] keys: ['user', 'item', 'answer', 'paths']


In [26]:
# 24: Build {index -> reason} map from the final merged file
FINAL_JSONL = "/content/final_merged_clean_20k.jsonl"  # From your previous steps
assert os.path.exists(FINAL_JSONL), f"File not found: {FINAL_JSONL}"

reason_map: Dict[int, str] = {}
dupes = 0

with open(FINAL_JSONL, "r", encoding="utf-8") as f:
    for line in f:
        if not line.strip():
            continue
        obj = json.loads(line)
        idx = int(obj["index"])
        reason = obj.get("reason", None)
        if idx in reason_map:
            dupes += 1
        reason_map[idx] = "" if reason is None else str(reason)

print("Loaded reasons:", len(reason_map))
print("Duplicate indices seen (overwritten):", dupes)

# Sanity
assert len(reason_map) == len(ds), f"Expected {len(ds)} reasons, found {len(reason_map)}."

reasons: List[str] = []
missing_for_any_index = []

for i in range(len(ds)):
    if i in reason_map:
        reasons.append(reason_map[i])
    else:
        reasons.append("")  
        missing_for_any_index.append(i)

print("Reason list length:", len(reasons))
print("Missing reason indices (should be 0):", len(missing_for_any_index))


Loaded reasons: 20000
Duplicate indices seen (overwritten): 0
Reason list length: 20000
Missing reason indices (should be 0): 0


In [27]:
# 25: Add 'reason' column

ds_with_reason = ds.add_column("reason", reasons)
print(ds_with_reason)
print("Columns now:", ds_with_reason.column_names)

# Quick peek at a few rows
for i in [0, 1, 2, len(ds_with_reason)-1]:
    row = ds_with_reason[i]
    print(f"\nIndex {i} -> rating/answer:", row.get("answer", None))
    print("Reason (first 200 chars):", (row["reason"] or "")[:200])


Dataset({
    features: ['user', 'item', 'answer', 'paths', 'reason'],
    num_rows: 20000
})
Columns now: ['user', 'item', 'answer', 'paths', 'reason']

Index 0 -> rating/answer: 4.0
Reason (first 200 chars): The user (UserID 911) is a 37-year-old female writer, and the target item (MovieID 193) is a Drama film. Several meta-paths provide insights into potential user preferences and item characteristics.



Index 1 -> rating/answer: 5.0
Reason (first 200 chars): The user, UserID 617, is described as a 27-year-old female writer. The target item, MovieID 185, is "Psycho (1960)", categorized under Genres Horror, Romance, and Thriller.

The provided meta-path evi

Index 2 -> rating/answer: 3.0
Reason (first 200 chars): The user, UserID 478, is male, 29 years old, and has the occupation 'other'. The target item, MovieID 300, is "Air Force One (1997)" with genres Action and Thriller.

Examining the meta-paths, we find

Index 19999 -> rating/answer: 3.0
Reason (first 200 chars): The user is 

In [28]:
# 26: Save locally 
LOCAL_OUT_JSONL = "/content/social_movielens_custom_with_reason_train.jsonl"

count = 0
with open(LOCAL_OUT_JSONL, "w", encoding="utf-8") as fw:
    for i in range(len(ds_with_reason)):
        obj = ds_with_reason[i]
        fw.write(json.dumps(obj, ensure_ascii=False) + "\n")
        count += 1

print(f"Saved train split with 'reason' to: {LOCAL_OUT_JSONL} | rows={count}")


Saved train split with 'reason' to: /content/social_movielens_custom_with_reason_train.jsonl | rows=20000


In [30]:
# 27: Push to Hugging Face Hub 

REPO_ID = "mohammad-shirkhani/social_movielens_custom_with_reason"  

dds = DatasetDict({"train": ds_with_reason})

dds.push_to_hub(REPO_ID, commit_message="Add train split with 'reason' column derived from final_merged_clean_20k.jsonl")

print("Pushed to:", REPO_ID)


Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ? shards/s]

Creating parquet from Arrow format:   0%|          | 0/20 [00:00<?, ?ba/s]

Processing Files (0 / 0)                : |          |  0.00B /  0.00B            

New Data Upload                         : |          |  0.00B /  0.00B            

                                        :   1%|1         |  526kB / 46.8MB            

Pushed to: mohammad-shirkhani/social_movielens_custom_with_reason
