ONLY RUN THE BELOW CELL IF IN GOOGLE COLAB

1. Uncomment all of the code in the cell below.
2. TO MAKE IT FASTER: go to Runtime > Change runtime type > Hardware accelerator > GPU.
3. Click Run; the cell should print the installed packages and their versions.

The real notebook starts beneath this cell.

In [None]:

# import sys, subprocess

# try:
#     import torch
#     print("Found torch:", torch.__version__, "CUDA:", getattr(torch.version, "cuda", None))
# except Exception:
#     subprocess.check_call([
#         sys.executable, "-m", "pip", "install", "--upgrade",
#         "torch", "torchvision", "torchaudio",
#         "--index-url", "https://download.pytorch.org/whl/cu121"
#     ])

# pkgs = [
#     "transformers>=4.43.3",
#     "peft>=0.12.0",
#     "trl>=0.9.6",
#     "accelerate>=0.33.0",
#     "datasets>=2.19.0",
#     "bitsandbytes>=0.43.0",
#     "evaluate>=0.4.1",
#     "safetensors>=0.4.3",
#     "huggingface_hub>=0.23.0",
#     "sentencepiece>=0.2.0",
#     "tqdm>=4.66",
#     "pandas>=2.2",
#     "numpy>=1.26",
#     "python-dotenv>=1.0.1",
#     "google-generativeai>=0.7.0",
#     "tqdm>=4.66.3",
# ]
# subprocess.check_call([sys.executable, "-m", "pip", "install", "-q"] + pkgs)

# import torch, transformers, peft, accelerate, datasets, bitsandbytes as bnb, trl
# print("torch:", torch.__version__, "cuda:", getattr(torch.version, "cuda", None))
# print("transformers:", transformers.__version__)
# print("peft:", peft.__version__)
# print("trl:", trl.__version__)
# print("accelerate:", accelerate.__version__)
# print("datasets:", datasets.__version__)
# print("bitsandbytes:", getattr(bnb, "__version__", "unknown"))


# Gathering & Anonymizing Data

## Data Sources:

1. [counselor/participant Q&A pseudo-chats](https://github.com/nbertagnolli/counsel-chat)
2. [HOPE Dataset (I filled out access request form)](https://github.com/LCS2-IIITD/SPARTA_WSDM2022/tree/main#hope-dataset-access-request)
3. [APA link for videos, needs free trial](https://www.ebsco.com/products/research-databases/apa-psyctherapy#:~:text=APA%20PsycTherapy%20,repository%20of%20therapy%20videos)
4.

### Set 1: Counselor/participant Q&A pseudo-chats

In [None]:
import re, json, time, os
from datasets import load_dataset
from dotenv import load_dotenv
import google.generativeai as genai
from tqdm import tqdm

# ---------- Stage 1: heuristics ----------
# Stage-1 allow-list: a question passes only if it mentions a child/parenting
# or school context ("my son", "teenager", "school", "aged 12 yo", ...).
ALLOW = re.compile(
    r"\b(my|our)\s+(son|daughter|child)\b|"
    r"\b(teen|teenager|toddler|adolescent|pediatric|student|classmate|teacher|school|bully|bullying|iep|kindergarten|middle school|high school)\b|"
    r"\b(aged|age)\s*(\d{1,2})\s*(yo|year[- ]old)\b",
    re.I
)
# Stage-1 block-list: adult sexual / couples topics.
# The whole alternation is wrapped in \b...\b, so every alternative must end on
# a word boundary.  Two alternatives therefore need explicit suffix handling:
#   * "ejaculat" is a stem, not a word -> allow trailing word characters.
#   * "sex" must also cover "sexual"/"sexually", except when directly followed
#     by "assault" (e.g. "sexual assault", "sexually assaulted"), which is
#     deliberately NOT blocked here.
# NOTE(review): "sexual abuse" is still blocked; confirm whether it should be
# exempted in the same way as "sexual assault".
BLOCK = re.compile(
    r"\b(erectile|ed|libido|porn|orgasm|erection|ejaculat\w*|"
    r"sex(?:ual(?:ly)?(?!\s+assault))?|"
    r"marital|affair|girlfriend|boyfriend|wife|husband)\b",
    re.I
)

def stage1_pass(q):
    """Cheap keyword screen for a question string.

    Returns False for empty input or when the block-list pattern matches;
    otherwise returns True only if the allow-list pattern matches.
    """
    if not q:
        return False
    # The block-list is checked first, so it always wins over the allow-list.
    if BLOCK.search(q) is not None:
        return False
    return ALLOW.search(q) is not None

def _passes_stage1(row):
    """Run the stage-1 heuristic on a row's question text (missing/None -> "")."""
    return stage1_pass(row.get("questionText") or "")


# Load the train split and keep only the rows that survive the keyword screen.
ds = load_dataset("nbertagnolli/counsel-chat")["train"]
stage1 = ds.filter(_passes_stage1)

print(f"After stage 1 filtering: {len(stage1)} examples")

# ---------- Stage 2: LLM screen (Gemini) ----------
# Credentials and model name come from the environment (optionally via .env).
load_dotenv()
API_KEY, MODEL_NAME = os.getenv("GEMINI_API_KEY"), os.getenv("GEMINI_MODEL")

if not (API_KEY and MODEL_NAME):
    raise ValueError("Missing GEMINI_API_KEY or GEMINI_MODEL in .env file")

genai.configure(api_key=API_KEY)

# Smoke-test the connection with one trivial request before the long loop.
# Reading `.text` stays inside the try so a failure there is reported as well.
try:
    test_resp = genai.GenerativeModel(MODEL_NAME).generate_content("Hello")
    preview = test_resp.text[:50]
    print(f"API test successful: {preview}...")
except Exception as err:
    print(f"API test failed: {err}")
    raise

# Screening prompt prepended to every question sent to the model.
# It states the value expected in exclude_reason for kept items ("none"), since
# the downstream filter compares against that literal, and asks for raw JSON
# because the reply is parsed as JSON.
INSTRUCTIONS = """You are a data screener for training a child-psychology assistant.
Return strict JSON with keys: is_child_context (bool), exclude_reason (string), risk_flags (list of strings),
quality (high|medium|low). Exclude adult-only sexual or couples topics, ED, etc.
Set exclude_reason to "none" when the item should be kept. Output raw JSON only, with no Markdown code fences.
"""

def _parse_screen_reply(txt):
    """Extract and parse the JSON object contained in a model reply.

    Replies are often wrapped in Markdown fences or preceded by a short
    preamble, which makes a bare ``json.loads`` fail.  Parsing only the span
    from the first "{" to the last "}" tolerates both.

    Raises:
        ValueError: if there is no ``{...}`` span or it is not valid JSON
            (``json.JSONDecodeError`` is a ``ValueError`` subclass).
    """
    start, end = txt.find("{"), txt.rfind("}")
    if start == -1 or end < start:
        raise ValueError(f"no JSON object in model reply: {txt[:80]!r}")
    return json.loads(txt[start:end + 1])


def classify(example):
    """Ask the LLM to screen one example and return its verdict as a dict.

    Args:
        example: mapping with an optional "questionText" entry.

    Returns:
        The parsed JSON object from the model, or a fallback dict with
        ``exclude_reason == "llm_error"`` if all three attempts fail (request
        error or unparseable reply).
    """
    q = (example.get("questionText") or "").strip()
    prompt = f"""
{INSTRUCTIONS}

QUESTION:
{q}

JSON ONLY:
"""

    for attempt in range(3):
        try:
            resp = genai.GenerativeModel(MODEL_NAME).generate_content(prompt)
            return _parse_screen_reply(resp.text.strip())
        except Exception as e:
            # Covers request failures as well as unparseable replies.
            print(f"API error (attempt {attempt + 1}/3): {e}")
            if attempt < 2:  # Don't sleep on last attempt
                time.sleep(2 ** attempt)  # Exponential backoff: 1s, then 2s

    return {"is_child_context": False, "exclude_reason": "llm_error", "risk_flags": [], "quality": "low"}

# Process with progress bar and rate limiting
# Results of the LLM pass: each row copied with its verdict under "_screen",
# plus a count of rows whose classification fell back to the error verdict.
screened, failed_count = [], 0

for done, row in enumerate(tqdm(stage1, desc="LLM Screening"), start=1):
    verdict = classify(row)

    if verdict.get("exclude_reason") == "llm_error":
        failed_count += 1

    # Copy the row so the verdict is attached without touching the source row.
    screened.append({**row, "_screen": verdict})

    # Fixed pause between requests as simple rate limiting.
    time.sleep(0.5)

    # Progress line every 100 rows (nothing is persisted here).
    if done % 100 == 0:
        print(f"Processed {done}/{len(stage1)}, Failed: {failed_count}")

print(f"LLM screening complete. Failed requests: {failed_count}")

# Keep only high-signal child items
# Spellings of exclude_reason that mean "nothing to exclude".  The model may
# return JSON null, "None", "N/A" or omit the key, so compare a normalised
# (stripped, lower-cased) value.
_NO_EXCLUDE_REASON = {"", "none", "null", "n/a", "na"}


def _keep_screened(screen):
    """Return True if an LLM verdict dict marks the item as a usable child-context example.

    Requires ``is_child_context`` to be exactly ``True``, ``quality`` to be
    high or medium (case-insensitive) and ``exclude_reason`` to be empty or
    missing (see ``_NO_EXCLUDE_REASON``).
    """
    if screen.get("is_child_context") is not True:
        return False
    quality = str(screen.get("quality") or "").strip().lower()
    reason = str(screen.get("exclude_reason") or "").strip().lower()
    return quality in ("high", "medium") and reason in _NO_EXCLUDE_REASON


final = [ex for ex in screened if _keep_screened(ex["_screen"])]

print(f"Final dataset: {len(final)} examples")

# Save chat JSONL
def scrub(s):
    """Normalise whitespace and redact obvious contact details from text.

    Args:
        s: any value; falsy values yield "" and everything else is passed
            through ``str()``.

    Returns:
        The text with whitespace runs collapsed to single spaces and e-mail
        addresses, http(s) URLs and 10-digit phone numbers replaced by
        "[redacted_email]", "[link]" and "[phone]" respectively.
    """
    if not s:
        return ""
    s = re.sub(r"\s+", " ", str(s)).strip()
    s = re.sub(r"\b[\w\.-]+@[\w\.-]+\.\w+\b", "[redacted_email]", s)
    s = re.sub(r"(https?://\S+)", "[link]", s)
    # Phone: 3-3-4 digits with optional "-", "." or space separators.  The area
    # code may also be parenthesised, e.g. "(555) 123-4567"; a leading \b
    # cannot sit before "(", hence the (?<!\w) lookbehind.
    s = re.sub(r"(?<!\w)(?:\(\d{3}\)\s?|\d{3}[-.\s]?)\d{3}[-.\s]?\d{4}\b", "[phone]", s)
    return s

# Create directory if it doesn't exist
os.makedirs("data/1", exist_ok=True)

out_path = "data/1/counselchat_child_subset_chat_screened.jsonl" # in colab this would be /content/counselchat...jsonl
seen = set()  # scrubbed (question, answer) pairs already written
with open(out_path, "w", encoding="utf-8") as sink:
    for row in final:
        question = scrub(row.get("questionText", ""))
        answer = scrub(row.get("answerText", ""))
        pair = (question, answer)
        # Skip rows with an empty side after scrubbing, and exact duplicates.
        if not (question and answer) or pair in seen:
            continue
        seen.add(pair)
        record = {"messages": [
            {"role": "user", "content": question},
            {"role": "assistant", "content": answer},
        ]}
        sink.write(json.dumps(record, ensure_ascii=False) + "\n")

print(f"Wrote {out_path} with {len(seen)} examples")
print(f"Summary: {len(ds)} -> {len(stage1)} -> {len(screened)} -> {len(final)} -> {len(seen)} unique examples")

Repo card metadata block was not found. Setting CardData to empty.
Filter: 100%|██████████| 2775/2775 [00:00<00:00, 40704.17 examples/s]


KeyboardInterrupt: 

# Training the Model