In [None]:
import json, csv, sys, re
from pathlib import Path
from collections import Counter

# ---------- Paths ----------
# The input may be JSON Lines or one JSON document (array or object).
INPUT = "yelp_academic_reviews_dataset.json"
CSV_OUT = "yelp_top2000.csv"
JSONL_OUT = "yelp_top2000.jsonl"

# ---------- Global caps / targets ----------
TARGET_TOTAL = 2000
# Desired number of reviews per star rating; bins that come up short are
# topped up from the remaining candidates later on.
STAR_TARGET = {star: 400 for star in range(1, 6)}

# Word-count window and its inclusive, non-overlapping buckets.
# Each entry is (lowest word count, highest word count, label).
MIN_WORDS, MAX_WORDS = 50, 1250
LENGTH_BUCKETS = [
    (50, 200, "50-200"),
    (201, 350, "201-350"),
    (351, 500, "351-500"),
    (501, 650, "501-650"),
    (651, 800, "651-800"),
    (801, 950, "801-950"),
    (951, 1100, "951-1100"),
    (1101, 1250, "1101-1250"),
]
# Every bucket gets the same cap; 8 buckets x 250 = 2000 by default.
LENGTH_TARGET = {label: 250 for _lo, _hi, label in LENGTH_BUCKETS}
# Running tally of selected reviews per bucket, shared by the selection code.
length_counts = dict.fromkeys(LENGTH_TARGET, 0)

# Entity dominance caps
MAX_PER_BUSINESS = 5
MAX_PER_USER = 3

# Optional dry-run limit: an int caps the sample size, None disables the cap.
MAX_TOTAL = None  # e.g., 50 for quick tests

# ---------- Privacy / quality regex ----------
# Reviews matching any of these are dropped before selection.
RE_URL = re.compile(r'https?://|www\.', re.IGNORECASE)
RE_EMAIL = re.compile(r'\b[A-Z0-9._%+-]+@[A-Z0-9.-]+\.[A-Z]{2,}\b', re.IGNORECASE)
RE_PHONE = re.compile(r'(\+?\d[\d\-\s()]{7,}\d)')
RE_SSN = re.compile(r'\b\d{3}-\d{2}-\d{4}\b')
# Matches text that contains no ASCII letters at all.
RE_NONWORD_HEAVY = re.compile(r'^[^A-Za-z]+$')

# ---------- IO helpers ----------
def read_any_json(path: Path):
    """Load records from a JSON Lines file or from a single JSON document.

    The file is first read line by line as JSON Lines (blank lines are
    skipped). If any non-blank line fails to parse, the whole file is parsed
    again as one JSON value.

    Args:
        path: Location of the input file. It is decoded as UTF-8; a leading
            byte-order mark, if present, is ignored.

    Returns:
        A list of parsed values for JSONL input or a JSON array. For a
        multi-line JSON object, a one-element list holding the first
        dict-valued member (or the object itself when it has none). An empty
        or whitespace-only file yields ``[]``. Any other top-level JSON value
        is returned as parsed.

    Raises:
        SystemExit: If ``path`` does not exist.
        json.JSONDecodeError: If the content is neither valid JSONL nor a
            valid single JSON document.
    """
    items = []
    is_jsonl = True
    try:
        # utf-8-sig behaves like utf-8 but drops a leading BOM, which would
        # otherwise make the first json.loads() call fail.
        with path.open("r", encoding="utf-8-sig") as f:
            for line in f:
                s = line.strip()
                if not s:
                    continue
                try:
                    items.append(json.loads(s))
                except json.JSONDecodeError:
                    is_jsonl = False
                    break
    except FileNotFoundError:
        sys.exit(f"ERROR: File not found: {path}")

    if is_jsonl:
        # A compact JSON array stored on one line parses as a single JSONL
        # "record"; unwrap it so callers get the records, not [[...]].
        if len(items) == 1 and isinstance(items[0], list):
            return items[0]
        # Either real JSONL records, or [] for an empty/blank file.
        return items

    # Fallback: single JSON value spread over several lines
    with path.open("r", encoding="utf-8-sig") as f:
        data = json.load(f)
    if isinstance(data, dict):
        # Sometimes wrapped; try to find a record dict
        # If you know your structure, customize this
        for v in data.values():
            if isinstance(v, dict):
                return [v]
        return [data]
    return data

# ---------- Text utils ----------
def norm_text(t: str) -> str:
    """Return ``t`` with each whitespace run collapsed to one space and no outer padding."""
    tokens = t.split()
    # Joining whitespace-free tokens can never leave leading/trailing blanks.
    return " ".join(tokens)

def word_count(t: str) -> int:
    """Count the whitespace-separated tokens in ``t``."""
    return sum(1 for _token in t.split())

def looks_all_caps(t: str) -> bool:
    """Return True when ``t`` has at least one letter and every letter is uppercase.

    Non-alphabetic characters (digits, punctuation, whitespace) are ignored.
    """
    saw_letter = False
    for ch in t:
        if not ch.isalpha():
            continue
        if not ch.isupper():
            # One non-uppercase letter is enough to rule the text out.
            return False
        saw_letter = True
    return saw_letter

def bad_privacy_or_quality(t: str) -> bool:
    """Return True when ``t`` should be rejected.

    A text is rejected if it matches the URL, e-mail, phone or SSN pattern,
    if all of its letters are uppercase, or if it contains no ASCII letters.
    """
    leak_patterns = (RE_URL, RE_EMAIL, RE_PHONE, RE_SSN)
    if any(pattern.search(t) for pattern in leak_patterns):
        return True
    return looks_all_caps(t) or bool(RE_NONWORD_HEAVY.match(t))

def which_length_bucket(wc: int):
    """Return the label of the length bucket containing ``wc``.

    Returns None when ``wc`` lies outside ``MIN_WORDS``..``MAX_WORDS`` or
    matches no entry of ``LENGTH_BUCKETS``.
    """
    if not (MIN_WORDS <= wc <= MAX_WORDS):
        return None
    matches = (label for lo, hi, label in LENGTH_BUCKETS if lo <= wc <= hi)
    return next(matches, None)

def score_item(rec, wc):
    """Rank a review: word count plus vote counts, with "useful" votes weighted double.

    Missing or falsy vote fields in ``rec`` count as zero.
    """
    def _votes(field):
        return int(rec.get(field, 0) or 0)

    return wc + 2 * _votes("useful") + _votes("funny") + _votes("cool")

# ---------- Selection helpers ----------
def take_with_caps(candidates, target, per_biz=None, per_user=None):
    """Pick up to ``target`` records from ``candidates``, honouring all caps.

    Candidates are visited in the order given. A record is skipped when its
    length bucket is already full (``length_counts`` vs ``LENGTH_TARGET``) or
    when its business or user has reached ``MAX_PER_BUSINESS`` /
    ``MAX_PER_USER``.

    Args:
        candidates: Iterable of record dicts with optional ``business_id``,
            ``user_id`` and ``WordBucket`` keys.
        target: Maximum number of records to pick; values <= 0 pick nothing.
        per_biz: Optional Counter of picks per business. Pass the same
            counter to several calls to enforce the cap across them; when
            omitted, the cap applies to this call only.
        per_user: Optional Counter of picks per user, same semantics.

    Returns:
        The list of picked records, in candidate order.

    Side effects:
        Increments the module-level ``length_counts`` and any counters passed
        in for every picked record.
    """
    if per_biz is None:
        per_biz = Counter()
    if per_user is None:
        per_user = Counter()

    picked = []
    for r in candidates:
        # Check the stop conditions before picking, so a zero target or an
        # already exhausted MAX_TOTAL never lets a record through.
        if len(picked) >= target:
            break
        if MAX_TOTAL and sum(length_counts.values()) >= MAX_TOTAL:
            break

        b = r.get("business_id")
        u = r.get("user_id")
        lb = r.get("WordBucket")

        # Enforce length-bucket caps (global)
        if lb and length_counts[lb] >= LENGTH_TARGET[lb]:
            continue
        # Enforce entity caps (scope depends on the counters supplied)
        if b and per_biz[b] >= MAX_PER_BUSINESS:
            continue
        if u and per_user[u] >= MAX_PER_USER:
            continue

        picked.append(r)
        if b:
            per_biz[b] += 1
        if u:
            per_user[u] += 1
        if lb:
            length_counts[lb] += 1
    return picked


def main():
    """Build the balanced review sample and write it to CSV and JSONL.

    Steps:
      1. Load the records from ``INPUT``.
      2. Filter them (valid 1-5 star rating, non-empty text inside the word
         window, no privacy/quality hits, no duplicate text), then group the
         survivors by star rating with a score attached.
      3. Select per star rating in descending score order, subject to the
         length-bucket and business/user caps, then top up from the
         leftovers if the star targets could not be met.
      4. Write ``CSV_OUT`` and ``JSONL_OUT`` and print a short report.

    Raises:
        SystemExit: If no review survives the filters and caps.
    """
    data = read_any_json(Path(INPUT))
    if isinstance(data, dict):
        data = [data]

    # Start from clean bucket counters so a second main() call in the same
    # process does not inherit the previous run's tallies.
    for label in length_counts:
        length_counts[label] = 0

    # -------- Pass 1: filter, bucket, score --------
    seen_hash = set()
    by_star = {1: [], 2: [], 3: [], 4: [], 5: []}

    for rec in data:
        if not isinstance(rec, dict):
            continue

        stars = rec.get("stars")
        text = rec.get("text")
        if stars is None or text is None:
            continue

        # int() would silently truncate a fractional rating (3.5 -> 3) and
        # file the review under the wrong star bin, so reject it instead.
        if isinstance(stars, float) and not stars.is_integer():
            continue
        try:
            rating = int(stars)
        except (TypeError, ValueError):
            continue
        if rating not in (1, 2, 3, 4, 5):
            continue

        if not isinstance(text, str):
            continue
        txt = text.strip()
        if not txt:
            continue

        wc = word_count(txt)
        length_bucket = which_length_bucket(wc)
        if length_bucket is None:
            continue  # outside the MIN_WORDS..MAX_WORDS range

        if bad_privacy_or_quality(txt):
            continue

        # Deduplicate by normalized, case-folded text. Only the hash is kept
        # to save memory.
        key = hash(norm_text(txt).lower())
        if key in seen_hash:
            continue
        seen_hash.add(key)

        by_star[rating].append({
            "Rating": rating,
            "Reviews": txt,
            "WordCount": wc,
            "WordBucket": length_bucket,
            "business_id": rec.get("business_id"),
            "user_id": rec.get("user_id"),
            "date": rec.get("date"),
            "_score": score_item(rec, wc),
        })

        # Dry-run shortcut: stop filtering once there are roughly 4x more
        # candidates than MAX_TOTAL. The input has already been loaded in
        # full above, so this only shortens this loop, not the file read.
        if MAX_TOTAL and sum(len(v) for v in by_star.values()) >= MAX_TOTAL * 4:
            break

    # Sort each star bin by score desc
    for s in by_star:
        by_star[s].sort(key=lambda x: x["_score"], reverse=True)

    # -------- Pass 2: select with caps + star & length-bucket limits --------
    selected = []
    shortfall = 0
    # One pair of counters for the whole selection, so the business/user caps
    # hold across all star bins and the top-up, not just within one bin.
    per_biz, per_user = Counter(), Counter()

    # First pass: try to satisfy star targets (and length caps along the way)
    for s in (1, 2, 3, 4, 5):
        target = STAR_TARGET[s]
        chunk = take_with_caps(by_star[s], target, per_biz, per_user)
        selected.extend(chunk)
        if len(chunk) < target:
            shortfall += target - len(chunk)
        if MAX_TOTAL and sum(length_counts.values()) >= MAX_TOTAL:
            break

    # Top-up if we're short overall, still respecting length caps + entity caps
    if (not MAX_TOTAL and shortfall > 0) or (MAX_TOTAL and sum(length_counts.values()) < MAX_TOTAL):
        used_ids = {id(r) for r in selected}
        remaining = [
            r
            for s in (1, 2, 3, 4, 5)
            for r in by_star[s]
            if id(r) not in used_ids
        ]
        remaining.sort(key=lambda x: x["_score"], reverse=True)

        for r in remaining:
            if MAX_TOTAL and sum(length_counts.values()) >= MAX_TOTAL:
                break
            b, u, lb = r.get("business_id"), r.get("user_id"), r.get("WordBucket")
            if lb and length_counts[lb] >= LENGTH_TARGET[lb]:
                continue
            if b and per_biz[b] >= MAX_PER_BUSINESS:
                continue
            if u and per_user[u] >= MAX_PER_USER:
                continue
            selected.append(r)
            if b:
                per_biz[b] += 1
            if u:
                per_user[u] += 1
            if lb:
                length_counts[lb] += 1
            if not MAX_TOTAL and len(selected) >= TARGET_TOTAL:
                break

    # Final trim to target totals
    if MAX_TOTAL:
        # Cut to MAX_TOTAL while preserving the already enforced bucket caps
        selected = selected[:MAX_TOTAL]
    else:
        selected = selected[:TARGET_TOTAL]

    if not selected:
        sys.exit("ERROR: No reviews passed the filters/limits.")

    # -------- Output --------
    # "_score" is an internal ranking key and is deliberately not exported.
    fields = ["Rating", "Reviews", "WordCount", "WordBucket", "business_id", "user_id", "date"]
    with open(CSV_OUT, "w", newline="", encoding="utf-8") as f:
        w = csv.DictWriter(f, fieldnames=fields)
        w.writeheader()
        for r in selected:
            w.writerow({k: r.get(k) for k in fields})

    with open(JSONL_OUT, "w", encoding="utf-8") as f:
        for r in selected:
            out = {k: r.get(k) for k in fields}
            f.write(json.dumps(out, ensure_ascii=False) + "\n")

    # -------- Report --------
    print(f"Selected {len(selected)} reviews → {CSV_OUT}, {JSONL_OUT}")
    star_counts = Counter(r["Rating"] for r in selected)
    len_counts = Counter(r["WordBucket"] for r in selected)
    # Sort length buckets in the defined order
    len_counts_sorted = {label: len_counts.get(label, 0) for _, _, label in LENGTH_BUCKETS}
    print("Star distribution:", dict(star_counts))
    print("Length-bucket distribution:", len_counts_sorted)
    print("Bucket caps used:", {k: f"{length_counts[k]}/{LENGTH_TARGET[k]}" for k in LENGTH_TARGET})


if __name__ == "__main__":
    main()
