ONLY RUN THE BELOW CELL IF IN GOOGLE COLAB

1. Uncomment all of the code in the cell below.
2. To make it faster, go to Runtime > Change runtime type > Hardware accelerator > GPU.
3. Click Run; the cell should print the installed packages and their versions.

The actual notebook starts below this cell.

In [None]:

# import sys, subprocess

# try:
#     import torch
#     print("Found torch:", torch.__version__, "CUDA:", getattr(torch.version, "cuda", None))
# except Exception:
#     subprocess.check_call([
#         sys.executable, "-m", "pip", "install", "--upgrade",
#         "torch", "torchvision", "torchaudio",
#         "--index-url", "https://download.pytorch.org/whl/cu121"
#     ])

# pkgs = [
#     "transformers>=4.43.3",
#     "peft>=0.12.0",
#     "trl>=0.9.6",
#     "accelerate>=0.33.0",
#     "datasets>=2.19.0",
#     "bitsandbytes>=0.43.0",
#     "evaluate>=0.4.1",
#     "safetensors>=0.4.3",
#     "huggingface_hub>=0.23.0",
#     "sentencepiece>=0.2.0",
#     "tqdm>=4.66",
#     "pandas>=2.2",
#     "numpy>=1.26",
#     "python-dotenv>=1.0.1",
#     "google-generativeai>=0.7.0",
#     "tqdm>=4.66.3",
# ]
# subprocess.check_call([sys.executable, "-m", "pip", "install", "-q"] + pkgs)

# import torch, transformers, peft, accelerate, datasets, bitsandbytes as bnb, trl
# print("torch:", torch.__version__, "cuda:", getattr(torch.version, "cuda", None))
# print("transformers:", transformers.__version__)
# print("peft:", peft.__version__)
# print("trl:", trl.__version__)
# print("accelerate:", accelerate.__version__)
# print("datasets:", datasets.__version__)
# print("bitsandbytes:", getattr(bnb, "__version__", "unknown"))


# Gathering & Anonymizing Data

## Data Sources:

1. [counselor/participant Q&A pseudo-chats](https://github.com/nbertagnolli/counsel-chat)
2. [HOPE Dataset (I filled out access request form)](https://github.com/LCS2-IIITD/SPARTA_WSDM2022/tree/main#hope-dataset-access-request)
3. [APA link for videos, needs free trial](https://www.ebsco.com/products/research-databases/apa-psyctherapy#:~:text=APA%20PsycTherapy%20,repository%20of%20therapy%20videos)
4.

### Set 1: Counselor/participant Q&A pseudo-chats

In [None]:
# Gemini-screened child-context filter with robust JSON + sane acceptance
# Requires: google-generativeai>=0.7.0, python-dotenv>=1.0.1, datasets, tqdm

import re, json, time, os, random
from collections import Counter
from datasets import load_dataset
from dotenv import load_dotenv
import google.generativeai as genai
from tqdm import tqdm

# --------- CONFIG ----------
# Seconds slept after every Stage 2 LLM call (simple client-side throttle).
RATE_LIMIT_SEC = 1.0            # pause between calls
# When set to an int and Stage 1 keeps more rows than this, a seeded random
# sample of exactly this many rows is sent to the LLM instead of all of them.
MAX_ITEMS = None                # set int to cap API spend; None = all
# Directory that receives every JSONL file written by this cell.
OUT_DIR = "data/1"
os.makedirs(OUT_DIR, exist_ok=True)  # exist_ok: re-running the cell is harmless

# ---------- Stage 1: BLOCKLIST-ONLY heuristic ----------
# Case-insensitive blocklist of adult-only topics.
# The whole alternation sits between word boundaries, so every alternative has
# to cover a complete word. Topic stems therefore carry `\w*` so inflected
# forms are caught ("ejaculation", "pornography", "erections", "sexual").
# Without it a stem such as "ejaculat" could only match itself.
# "sex…" words are blocked unless the text continues as "sexual assault",
# which stays available for Stage 2 to judge.
BLOCK = re.compile(
    r"\b(erectile|ed|libido|porn\w*|orgasm\w*|erection\w*|ejaculat\w*|"
    r"sex(?!ual\s+assault)\w*|marital|affair|wife|husband)\b",
    re.I
)


def stage1_keep(q: str) -> bool:
    """Return True when the question passes the Stage 1 blocklist.

    Args:
        q: Question text; ``None`` or an empty string is treated as "".

    Returns:
        False if any blocklisted term occurs in ``q``, True otherwise.
    """
    return BLOCK.search(q or "") is None

# Pull the training split and partition it with the Stage 1 blocklist.
ds = load_dataset("nbertagnolli/counsel-chat")["train"]

kept_stage1 = []
rejected_stage1 = []
for ex in ds:
    q = ex.get("questionText") or ""
    if stage1_keep(q):
        kept_stage1.append(ex)
    else:
        rejected_stage1.append(ex)

print(f"Stage 1 — kept: {len(kept_stage1)} | rejected: {len(rejected_stage1)} | total: {len(ds)}")

# Optional spend cap: draw a reproducible (fixed-seed) sample of the survivors.
if MAX_ITEMS is not None:
    if len(kept_stage1) > MAX_ITEMS:
        random.seed(3619)
        kept_stage1 = random.sample(kept_stage1, MAX_ITEMS)
        print(f"Capped Stage 1 kept to {len(kept_stage1)} items (MAX_ITEMS).")

# ---------- Stage 2: LLM screen (Gemini JSON mode) ----------
# Load credentials from a .env file; override=True is passed so that values in
# the file take precedence over variables already present in the environment.
load_dotenv(override=True)
API_KEY = os.getenv("GEMINI_API_KEY")
MODEL_NAME = os.getenv("GEMINI_MODEL")
# Fail fast: both settings are required before any API call is attempted.
if not API_KEY or not MODEL_NAME:
    raise ValueError("Missing GEMINI_API_KEY or GEMINI_MODEL in .env")

genai.configure(api_key=API_KEY)
# Screening model: temperature 0 and a JSON response MIME type are requested,
# and the system instruction spells out the label schema that classify() and
# is_acceptable() later read (is_child_context, exclude_reason, risk_flags,
# quality).
model = genai.GenerativeModel(
    MODEL_NAME,
    generation_config={
        "temperature": 0,
        "response_mime_type": "application/json",
    },
    system_instruction=(
        "You are a data screener for training a child-psychology assistant. "
        "Return ONLY valid JSON matching exactly this schema:\n"
        "{\n"
        '  "is_child_context": <bool>,\n'
        '  "exclude_reason": "<string>",  // one of: "none","adult_sexual_topic","general_adult","off_topic","unsafe","pii_leak"\n'
        '  "risk_flags": ["<string>", ...],\n'
        '  "quality": "<string>"          // one of: "high","medium","low"\n'
        "}\n"
        'If content is acceptable for a child/parent/teacher context, set is_child_context=true and exclude_reason="none". '
        "Return JSON only, no extra text."
    ),
)

# Short reminder of the schema that classify() prepends to every per-question
# prompt, in addition to the system instruction above.
INSTRUCTIONS = (
    "Return JSON only. Keys exactly: is_child_context (bool), exclude_reason (string), "
    "risk_flags (list[string]), quality (string: high|medium|low)."
)

def _extract_text(resp) -> str:
    """Best-effort extraction of the text payload from a model response.

    The ``text`` attribute is tried first. If it is missing, empty, or raises,
    the text of every part of every candidate is collected and joined with
    newlines. Any error during that fallback yields "".
    """
    try:
        direct = getattr(resp, "text", None) if resp else None
        if direct:
            return resp.text
    except Exception:
        # Accessing .text may raise; fall through to the per-part fallback.
        pass

    try:
        pieces = []
        for candidate in getattr(resp, "candidates", []) or []:
            content = getattr(candidate, "content", None)
            if not content:
                continue
            for part in getattr(content, "parts", []) or []:
                part_text = getattr(part, "text", None)
                if part_text:
                    pieces.append(part_text)
        return "\n".join(pieces)
    except Exception:
        return ""

def _safe_json_parse(txt: str):
    """Parse ``txt`` as JSON, tolerating surrounding chatter.

    Returns the decoded value, or None when ``txt`` is empty or neither the
    whole string nor its outermost ``{...}`` span is valid JSON.
    """
    if not txt:
        return None

    try:
        return json.loads(txt)
    except Exception:
        pass

    # Second chance: greedy match from the first "{" to the last "}".
    braced = re.search(r"\{.*\}", txt, flags=re.S)
    if braced is None:
        return None
    try:
        return json.loads(braced.group(0))
    except Exception:
        return None

# --- Normalizers to avoid "0 kept" issues ---
def as_bool(x):
    """Coerce a loosely-typed label value to a bool.

    Booleans pass through; strings are true only for true/yes/y/1 (trimmed,
    case-insensitive); ints and floats are true when non-zero; anything else
    is False.
    """
    if isinstance(x, bool):
        return x
    if isinstance(x, str):
        return x.strip().lower() in ("true", "yes", "y", "1")
    return isinstance(x, (int, float)) and x != 0

def norm_str(x, default=""):
    """Return ``str(x)`` trimmed and lower-cased, or ``default`` (as given) when ``x`` is None."""
    return default if x is None else str(x).strip().lower()

# Normalised exclude_reason values that count as "not excluded".
ACCEPT_REASONS = {"", "none", "ok", "pass"}
# Normalised quality values that are good enough; "" covers an empty value.
GOOD_QUALITY   = {"high", "medium", ""}  # treat missing as medium

def is_acceptable(lab: dict) -> bool:
    """Decide whether a screening label lets the example through.

    All three must hold: is_child_context is truthy (via as_bool), the
    normalised exclude_reason is in ACCEPT_REASONS, and the normalised quality
    (defaulting to "medium" when absent) is in GOOD_QUALITY.
    """
    if not as_bool(lab.get("is_child_context")):
        return False
    if norm_str(lab.get("exclude_reason")) not in ACCEPT_REASONS:
        return False
    return norm_str(lab.get("quality"), "medium") in GOOD_QUALITY

# Label used when the LLM never produced usable JSON; it is rejected downstream
# and counted as a failure through exclude_reason == "llm_error".
JSON_FALLBACK = {
    "is_child_context": False,
    "exclude_reason": "llm_error",
    "risk_flags": [],
    "quality": "low",
}

def classify(ex):
    """Ask the LLM to label one dataset row.

    Makes up to three attempts, sleeping 1 s and then 2 s between them. An
    attempt succeeds when the reply parses to a dict that contains
    "is_child_context". If every attempt fails (API error, unparsable reply,
    missing key) a fresh copy of JSON_FALLBACK is returned.
    """
    question = (ex.get("questionText") or "").strip()
    prompt = f"{INSTRUCTIONS}\n\nQUESTION:\n{question}\n\nJSON ONLY:"
    max_attempts = 3
    for attempt in range(max_attempts):
        try:
            reply = model.generate_content(prompt)
            parsed = _safe_json_parse((_extract_text(reply) or "").strip())
            if not (isinstance(parsed, dict) and "is_child_context" in parsed):
                raise ValueError("non-json or missing keys")
            return parsed
        except Exception:
            # Back off before retrying, but do not sleep after the final attempt.
            if attempt != max_attempts - 1:
                time.sleep(2 ** attempt)  # 1s, 2s
    return JSON_FALLBACK.copy()

screened = []
stage2_rejected = []
failed_count = 0

# One LLM call per Stage 1 survivor. Every row lands in `screened` with its
# label stored under "_screen"; unacceptable rows are also kept in
# `stage2_rejected`.
for done, ex in enumerate(tqdm(kept_stage1, desc="Stage 2 LLM Screening"), start=1):
    lab = classify(ex)
    if norm_str(lab.get("exclude_reason")) == "llm_error":
        failed_count += 1

    ex2 = {**ex, "_screen": lab}
    screened.append(ex2)
    if not is_acceptable(lab):
        stage2_rejected.append(ex2)

    time.sleep(RATE_LIMIT_SEC)
    if done % 50 == 0:
        print(f"Processed {done}/{len(kept_stage1)} | LLM failures: {failed_count}")

print(f"LLM screening complete. LLM failures: {failed_count}")

# Quick histograms so you can see what the model returned
labels = [row["_screen"] for row in screened]
ex_reasons = Counter(norm_str(label.get("exclude_reason")) for label in labels)
ex_quality = Counter(norm_str(label.get("quality", "medium")) for label in labels)
ex_child = Counter(as_bool(label.get("is_child_context")) for label in labels)
print("exclude_reason:", ex_reasons)
print("quality:", ex_quality)
print("is_child_context:", ex_child)

# ---------- Final keepers ----------
final = [row for row in screened if is_acceptable(row["_screen"])]
print(f"Stage 2 — kept: {len(final)} | rejected: {len(screened) - len(final)}")

# ---------- Save outputs ----------
def to_jsonl(path, rows):
    """Write ``rows`` to ``path`` as UTF-8 JSON Lines (one object per line, non-ASCII kept as-is)."""
    with open(path, "w", encoding="utf-8") as fh:
        fh.writelines(json.dumps(row, ensure_ascii=False) + "\n" for row in rows)

# Write the rows rejected at each stage to their own JSONL file under OUT_DIR.
# Stage 2 rows carry the LLM label under "_screen"; Stage 1 rows are unmodified.
to_jsonl(os.path.join(OUT_DIR, "rejected_stage1.jsonl"), rejected_stage1)
to_jsonl(os.path.join(OUT_DIR, "rejected_stage2.jsonl"), stage2_rejected)

# Chat JSONL for training
def scrub(s):
    """Normalise whitespace and redact obvious contact details.

    Args:
        s: Any value; falsy input (None, "", 0) yields "".

    Returns:
        ``str(s)`` with whitespace runs collapsed to single spaces and trimmed,
        e-mail addresses replaced by "[redacted_email]", links by "[link]" and
        10-digit phone numbers by "[phone]".
    """
    if not s:
        return ""
    s = re.sub(r"\s+", " ", str(s)).strip()
    s = re.sub(r"\b[\w\.-]+@[\w\.-]+\.\w+\b", "[redacted_email]", s)
    # Links: explicit http(s) URLs plus scheme-less "www." hosts.
    s = re.sub(r"(?:https?://|\bwww\.)\S+", "[link]", s)
    # Phones: 555-123-4567 / 555.123.4567 / 5551234567 and, additionally, a
    # parenthesised area code such as "(555) 123-4567", which the plain
    # \b\d{3} prefix cannot match.
    s = re.sub(
        r"(?:\(\d{3}\)\s?|\b\d{3}[-.\s]?)\d{3}[-.\s]?\d{4}\b",
        "[phone]",
        s,
    )
    return s

out_path = os.path.join(OUT_DIR, "counselchat_child_subset_chat_screened.jsonl")  # fixed path
seen = set()
# Emit one chat-format record per unique, non-empty (question, answer) pair
# after scrubbing both sides.
with open(out_path, "w", encoding="utf-8") as f:
    for r in final:
        q = scrub(r.get("questionText", ""))
        a = scrub(r.get("answerText", ""))
        key = (q, a)
        if q and a and key not in seen:
            seen.add(key)
            record = {
                "messages": [
                    {"role": "user", "content": q},
                    {"role": "assistant", "content": a},
                ]
            }
            f.write(json.dumps(record, ensure_ascii=False) + "\n")

print(f"Wrote {out_path} with {len(seen)} examples")
print(f"Summary: total {len(ds)} -> stage1_kept {len(kept_stage1)} -> screened {len(screened)} -> final {len(final)} -> unique {len(seen)}")

#### Local Inference

In [None]:
# --- Local LLM screening with .env-driven config (using Qwen) ---
# Prereqs (run once in another cell if needed):
# !pip install -U torch --index-url https://download.pytorch.org/whl/cu128
# !pip install -U "transformers>=4.43.3" "accelerate>=0.33.0" "bitsandbytes>=0.43.0" "datasets>=2.19.0" "tqdm>=4.66" "python-dotenv>=1.0.1"

import os, re, json
import torch
from tqdm import tqdm
from dotenv import load_dotenv
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# --- Load .env ---
load_dotenv()

def _as_int(name, default):
    """Read environment variable ``name`` as an int.

    Args:
        name: Environment variable to read.
        default: Value used when the variable is unset or not a valid integer.

    Returns:
        The parsed integer, or ``default`` when parsing fails.
    """
    raw = os.getenv(name, str(default)).strip()
    try:
        return int(raw)
    except ValueError:
        # int() on a str can only fail with ValueError; a bare except would
        # also swallow KeyboardInterrupt and SystemExit.
        return default

def _as_bool(name, default=False):
    """Read environment variable ``name`` as a flag.

    Unset returns ``default``; otherwise True only for 1/true/yes/y/on
    (trimmed, case-insensitive).
    """
    raw = os.getenv(name)
    return default if raw is None else raw.strip().lower() in ("1", "true", "yes", "y", "on")

def _as_choice(name, allowed, default):
    """Read environment variable ``name`` and keep it only if it is an allowed choice.

    The value (or ``default`` when unset) is trimmed and lower-cased. If the
    result is not in ``allowed``, ``default`` is returned unchanged.
    """
    choice = os.getenv(name, default).strip().lower()
    if choice in allowed:
        return choice
    return default

# --- Config from .env (with defaults) ---
# Using HuggingFace model ID instead of local path
# Hugging Face model id passed to from_pretrained for tokenizer and model.
MODEL_ID       = os.getenv("MODEL_ID", "Qwen/Qwen2.5-7B-Instruct")
# Output directory ("~" is expanded). The default equals the directory used by
# the Gemini cell, and the file names written below are the same, so running
# both cells with defaults overwrites the earlier outputs.
OUT_DIR        = os.path.expanduser(os.getenv("OUT_DIR",   "data/1"))
# Number of questions tokenized and generated together in classify_batch.
# NOTE(review): also used as the range() step in the Stage 2 loop, so an env
# value of 0 would raise there — confirm whether a lower bound is wanted.
BATCH_SIZE     = _as_int("BATCH_SIZE", 8)                   # 7B model, slightly smaller batch
# Generation budget per question (max_new_tokens).
MAX_NEW_TOKENS = _as_int("MAX_NEW_TOKENS", 128)            # Increased for better JSON responses
# Selects the BitsAndBytesConfig built below; anything else falls back to "8bit".
QUANT_MODE     = _as_choice("QUANT_MODE", {"8bit","4bit"}, "8bit")   # "8bit" | "4bit"
# Dataset repository id; its "train" split is loaded.
DATASET_NAME   = os.getenv("HF_DATASET", "nbertagnolli/counsel-chat")

os.makedirs(OUT_DIR, exist_ok=True)

# Echo the effective configuration so a run can be reproduced from its log.
print(f"Using MODEL_ID={MODEL_ID}")
print(f"OUT_DIR={OUT_DIR} | QUANT_MODE={QUANT_MODE} | BATCH_SIZE={BATCH_SIZE} | MAX_NEW_TOKENS={MAX_NEW_TOKENS}")
print(f"HF_DATASET={DATASET_NAME}")

# --- Stage 1: blocklist-only (maximize recall; filter obvious adult-only topics) ---
# Case-insensitive blocklist of adult-only topics.
# The whole alternation sits between word boundaries, so every alternative has
# to cover a complete word. Topic stems therefore carry `\w*` so inflected
# forms are caught ("ejaculation", "pornography", "erections", "sexual").
# Without it a stem such as "ejaculat" could only match itself.
# "sex…" words are blocked unless the text continues as "sexual assault",
# which stays available for Stage 2 to judge.
BLOCK = re.compile(
    r"\b(erectile|ed|libido|porn\w*|orgasm\w*|erection\w*|ejaculat\w*|"
    r"sex(?!ual\s+assault)\w*|marital|affair|girlfriend|boyfriend|wife|husband)\b",
    re.I
)


def stage1_keep(q: str) -> bool:
    """Return True when the question passes the Stage 1 blocklist.

    Args:
        q: Question text; ``None`` or an empty string is treated as "".

    Returns:
        False if any blocklisted term occurs in ``q``, True otherwise.
    """
    return BLOCK.search(q or "") is None

print("Loading dataset...")
ds = load_dataset(DATASET_NAME)["train"]

# Partition the rows with the Stage 1 blocklist.
kept_stage1 = []
rejected_stage1 = []
for ex in ds:
    if stage1_keep(ex.get("questionText") or ""):
        kept_stage1.append(ex)
    else:
        rejected_stage1.append(ex)
print(f"Stage 1 — kept: {len(kept_stage1)} | rejected: {len(rejected_stage1)} | total: {len(ds)}")

# --- Load Qwen model with requested quantization ---
print(f"\nLoading {MODEL_ID}...")
# Quantisation is chosen by QUANT_MODE: 8-bit, otherwise 4-bit NF4.
if QUANT_MODE == "8bit":
    bnb = BitsAndBytesConfig(load_in_8bit=True)
else:
    bnb = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4")

tok = AutoTokenizer.from_pretrained(MODEL_ID)
# Some models don't have pad_token set
if tok.pad_token is None:
    tok.pad_token = tok.eos_token
# Batched generation with a decoder-only model needs LEFT padding. With right
# padding, shorter prompts get pad tokens between the prompt and the generated
# continuation, which degrades the output. classify_batch slices generations
# at a common column, which is also only correct when prompts are right-aligned.
tok.padding_side = "left"

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb,
    device_map="auto",
    torch_dtype="auto",
)
model.eval()  # inference only
print("Model loaded successfully!")

# --- Prompt formatted for Qwen instruction following ---
def make_prompt(q: str) -> str:
    """Render the screening prompt for question ``q`` with the tokenizer's chat template.

    Returns the templated prompt as a string (not token ids), ending with the
    generation prompt so the model continues as the assistant.
    """
    # System message: screener role, required JSON keys, exclusion guidance.
    system_msg = "".join([
        "You are a data screener for training a child-psychology assistant. ",
        "Analyze the question and return ONLY a JSON object with these exact keys:\n",
        "- is_child_context (boolean): true if appropriate for child/family counseling\n",
        "- exclude_reason (string): reason for exclusion or 'none' if included\n",
        "- risk_flags (array): list of concerning elements if any\n",
        "- quality (string): 'high', 'medium', or 'low'\n\n",
        "Exclude adult-only sexual topics, erectile dysfunction, marital/couples counseling, etc.",
    ])
    user_msg = f"Analyze this question and return JSON only:\n\n{q}"

    return tok.apply_chat_template(
        [
            {"role": "system", "content": system_msg},
            {"role": "user", "content": user_msg},
        ],
        tokenize=False,
        add_generation_prompt=True,
    )

# Label substituted when a generation cannot be parsed; rejected downstream.
FALLBACK = {"is_child_context": False, "exclude_reason": "llm_error", "risk_flags": [], "quality": "low"}

def safe_json_parse(txt: str):
    """Parse a model reply as JSON, tolerating code fences and surrounding chatter.

    Args:
        txt: Raw generated text (may be None or empty).

    Returns:
        The decoded JSON value, or None when nothing parseable is found.

    Parsing is attempted on, in order: the whole reply with an optional
    leading ```/```json fence and trailing ``` removed; the outermost
    ``{...}`` span (so nested objects survive); the first flat ``{...}``
    object (covers replies containing several separate objects).
    """
    if not txt:
        return None
    txt = txt.strip()
    # The language tag is optional as a unit: "(?:json)?". The former "json?"
    # only made the final "n" optional, so a bare ``` fence was never removed.
    txt = re.sub(r"^```(?:json)?\s*", "", txt, flags=re.I)
    txt = re.sub(r"\s*```$", "", txt)

    try:
        return json.loads(txt)
    except ValueError:  # json.JSONDecodeError is a ValueError
        pass

    for pattern in (r"\{.*\}", r"\{[^{}]*\}"):
        found = re.search(pattern, txt, flags=re.S)
        if found is None:
            continue
        try:
            return json.loads(found.group(0))
        except ValueError:
            continue
    return None

# --- Batched classification on GPU ---
def classify_batch(examples):
    """Label a batch of dataset rows with the local model.

    Args:
        examples: Rows read via ``.get("questionText")``.

    Returns:
        One label dict per input row, in input order. A row whose generation
        does not parse to a dict containing "is_child_context" gets a fresh
        copy of FALLBACK.
    """
    if not examples:
        return []

    prompts = [make_prompt((e.get("questionText") or "").strip()) for e in examples]
    # NOTE(review): prompts are truncated to 512 tokens. Depending on the
    # tokenizer's truncation side, a very long question could lose the end of
    # the templated prompt (the assistant turn marker) — confirm.
    inputs = tok(prompts, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)

    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            # Greedy decoding for consistent JSON. `temperature` is a
            # sampling-only knob: combined with do_sample=False it is ignored
            # and only triggers a warning, so it is not passed.
            do_sample=False,
            pad_token_id=tok.eos_token_id,
            eos_token_id=tok.eos_token_id,
        )

    # Drop the prompt columns so only newly generated tokens are decoded.
    gen_only = out[:, inputs["input_ids"].shape[1]:]
    texts = tok.batch_decode(gen_only, skip_special_tokens=True)

    results = []
    for t in texts:
        data = safe_json_parse(t)
        if isinstance(data, dict) and "is_child_context" in data:
            results.append(data)
        else:
            results.append(FALLBACK.copy())
    return results

# --- Stage 2: LLM screening ---
screened = []
stage2_rejected = []
final = []
# Walk the Stage 1 survivors in BATCH_SIZE chunks. Every row is recorded in
# `screened` with its label under "_screen", then routed to `final` or
# `stage2_rejected` by one strict check.
for i in tqdm(range(0, len(kept_stage1), BATCH_SIZE), desc="Stage 2 (Qwen screening)"):
    batch = kept_stage1[i:i + BATCH_SIZE]
    labs = classify_batch(batch)
    for ex, lab in zip(batch, labs):
        ex2 = {**ex, "_screen": lab}
        screened.append(ex2)
        # Strict acceptance: literal True, known quality, no exclusion reason.
        if (
            lab.get("is_child_context") is not True
            or lab.get("quality") not in ("high", "medium")
            or lab.get("exclude_reason") not in ("none", "", None)
        ):
            stage2_rejected.append(ex2)
        else:
            final.append(ex2)

print(f"Stage 2 — kept: {len(final)} | rejected: {len(stage2_rejected)}")

# --- Save outputs ---
def to_jsonl(path, rows):
    """Write ``rows`` to ``path`` as UTF-8 JSON Lines (one object per line, non-ASCII kept as-is)."""
    with open(path, "w", encoding="utf-8") as fh:
        fh.writelines(json.dumps(row, ensure_ascii=False) + "\n" for row in rows)

# Write the rows rejected at each stage to their own JSONL file under OUT_DIR.
# Stage 2 rows carry the model label under "_screen"; Stage 1 rows are unmodified.
to_jsonl(os.path.join(OUT_DIR, "rejected_stage1.jsonl"), rejected_stage1)
to_jsonl(os.path.join(OUT_DIR, "rejected_stage2.jsonl"), stage2_rejected)

def scrub(s):
    """Normalise whitespace and redact obvious contact details.

    Args:
        s: Any value; falsy input (None, "", 0) yields "".

    Returns:
        ``str(s)`` with whitespace runs collapsed to single spaces and trimmed,
        e-mail addresses replaced by "[redacted_email]", links by "[link]" and
        10-digit phone numbers by "[phone]".
    """
    if not s:
        return ""
    s = re.sub(r"\s+", " ", str(s)).strip()
    s = re.sub(r"\b[\w\.-]+@[\w\.-]+\.\w+\b", "[redacted_email]", s)
    # Links: explicit http(s) URLs plus scheme-less "www." hosts.
    s = re.sub(r"(?:https?://|\bwww\.)\S+", "[link]", s)
    # Phones: 555-123-4567 / 555.123.4567 / 5551234567 and, additionally, a
    # parenthesised area code such as "(555) 123-4567", which the plain
    # \b\d{3} prefix cannot match.
    s = re.sub(
        r"(?:\(\d{3}\)\s?|\b\d{3}[-.\s]?)\d{3}[-.\s]?\d{4}\b",
        "[phone]",
        s,
    )
    return s

out_path = os.path.join(OUT_DIR, "counselchat_child_subset_chat_screened.jsonl")
seen = set()
# Emit one chat-format record per unique, non-empty (question, answer) pair
# after scrubbing both sides.
with open(out_path, "w", encoding="utf-8") as f:
    for r in final:
        q = scrub(r.get("questionText", ""))
        a = scrub(r.get("answerText", ""))
        key = (q, a)
        if q and a and key not in seen:
            seen.add(key)
            record = {
                "messages": [
                    {"role": "user", "content": q},
                    {"role": "assistant", "content": a},
                ]
            }
            f.write(json.dumps(record, ensure_ascii=False) + "\n")

print(f"Wrote {out_path} with {len(seen)} examples")

# Training the Model