In [None]:
# Notebook-only: installs NLTK from an IPython/Jupyter cell. This line is not valid Python in a plain .py script.
pip install nltk

In [None]:
import os
# Show the starting working directory, then switch to the project folder.
# NOTE(review): hard-coded, machine-specific path. The relative "data/..."
# paths used by run_full_pipeline are resolved against this directory.
print(os.getcwd())
os.chdir("/Users/M1HR/Desktop/MIGRAINE")

In [None]:
import os
import re
import json
from pathlib import Path
import pandas as pd
import nltk
from nltk.tokenize import sent_tokenize

# Fetch NLTK's "punkt_tab" sentence-tokenizer data. sent_tokenize (used by
# segment_sentences below) loads it in recent NLTK versions. This runs at
# import time, every time the cell/module is executed.
nltk.download('punkt_tab')

# =============================================================================
# 1. Dictionaries
# =============================================================================

# Everyday spelling slips (mostly apostrophe-less contractions), mapped
# misspelling -> correction. fix_typos() matches each key as a whole word,
# case-insensitively. Keys are lowercase plain words.
SAFE_TYPOS = {
    "wont": "won't",
    "dont": "don't",
    "doesnt": "doesn't",
    "didnt": "didn't",
    "havent": "haven't",
    "decnet": "decent",
    "ther": "there",
    "teh": "the",
}

# Misspellings of migraine-related medical terms, mapped misspelling ->
# correct term. fix_typos() applies these after SAFE_TYPOS, using the same
# whole-word, case-insensitive matching.
MEDICAL_TYPOS = {
    "nausia": "nausea",
    "nausua": "nausea",
    "photophbia": "photophobia",
    "phonophbia": "phonophobia",
    "migrainea": "migraine",
    "scotma": "scotoma",
}

# Lay phrasings of symptoms, mapped to a canonical clinical term.
# normalize_synonyms() lower-cases the text first and then applies these in
# insertion order. Each replacement runs on the output of the previous one,
# so keys must be lowercase and order can matter.
SYMPTOM_SYNONYMS = {
    "light sensitivity": "photophobia",
    "light sensitive": "photophobia",
    "bright lights hurt": "photophobia",
    "light hurts my eyes": "photophobia",
    "sound sensitivity": "phonophobia",
    "noise sensitivity": "phonophobia",
    "zigzag lines": "visual aura",
    "blind spots": "scotoma",
    "pounding": "throbbing",
    "pulsating": "throbbing",
}

# =============================================================================
# 2. Basic cleaning
# =============================================================================

def clean_text(text):
    """Normalise dashes and whitespace in a piece of diary text.

    Every en dash ("–") becomes an ASCII hyphen. Each run of whitespace,
    including newlines and tabs, collapses to a single space. Leading and
    trailing whitespace is removed. No other characters are altered.
    """
    dashes_fixed = text.translate({ord("–"): "-"})
    single_spaced = re.sub(r"\s+", " ", dashes_fixed)
    return single_spaced.strip()


# =============================================================================
# 3. Typo correction
# =============================================================================

def _match_case(found, replacement):
    """Return *replacement* re-cased to follow the casing of the matched word *found*."""
    if len(found) > 1 and found.isupper():
        return replacement.upper()
    if found[:1].isupper():
        return replacement[:1].upper() + replacement[1:]
    return replacement


def fix_typos(text, typo_maps=None):
    """Correct known misspellings in *text*.

    Each key of each mapping is matched case-insensitively as a whole word
    and replaced by its value. The replacement follows the casing of the
    word it replaces: an ALL-CAPS word stays all-caps, and a capitalised
    word (e.g. at the start of a sentence) keeps its leading capital.

    Args:
        text: Input string.
        typo_maps: Iterable of ``{misspelling: correction}`` dicts, applied
            in order. Defaults to ``(SAFE_TYPOS, MEDICAL_TYPOS)``.

    Returns:
        The corrected string.
    """
    if typo_maps is None:
        typo_maps = (SAFE_TYPOS, MEDICAL_TYPOS)

    for mapping in typo_maps:
        for wrong, right in mapping.items():
            # Keys are literal words, not regex syntax, so escape them.
            pattern = re.compile(rf"\b{re.escape(wrong)}\b", flags=re.IGNORECASE)
            # A callable replacement inserts the correction literally (no
            # backslash/group-reference template parsing) and lets us
            # preserve the original word's capitalisation.
            text = pattern.sub(lambda m, r=right: _match_case(m.group(0), r), text)

    return text


# =============================================================================
# 4. Synonym normalization
# =============================================================================

def normalize_synonyms(text, synonyms=None):
    """Lower-case *text* and map lay symptom phrases to canonical terms.

    Each phrase is matched as a whole word or phrase in the lower-cased
    text. Phrases are applied in mapping order, and each replacement sees
    the output of the previous one.

    Args:
        text: Input string.
        synonyms: ``{phrase: canonical}`` mapping. Defaults to
            ``SYMPTOM_SYNONYMS``.

    Returns:
        The lower-cased text with synonyms replaced.
    """
    if synonyms is None:
        synonyms = SYMPTOM_SYNONYMS

    text_lower = text.lower()
    for phrase, canonical in synonyms.items():
        # Whole-word match: a plain substring replace would rewrite text
        # inside other words, e.g. "compounding" -> "comthrobbing".
        pattern = re.compile(rf"\b{re.escape(phrase.lower())}\b")
        text_lower = pattern.sub(lambda _m, c=canonical: c, text_lower)
    return text_lower


# =============================================================================
# 5. Sentence segmentation
# =============================================================================

def segment_sentences(text):
    """Split *text* into sentences, dropping fragments of one character or less.

    Uses NLTK's ``sent_tokenize``. If that fails (typically a
    ``LookupError`` when the Punkt data is not installed), the function
    falls back to splitting after ``.``, ``!`` or ``?`` when followed by
    whitespace. Like Punkt, the fallback keeps terminal punctuation and
    does not break decimals such as "7.5".

    Returns:
        List of stripped sentence strings, each longer than one character.
    """
    try:
        sentences = sent_tokenize(text)
    except Exception:
        # Best-effort fallback. This is deliberately not a bare ``except``,
        # so KeyboardInterrupt / SystemExit still propagate.
        sentences = re.split(r"(?<=[.!?])\s+", text)
    return [s.strip() for s in sentences if len(s.strip()) > 1]


# =============================================================================
# 6. Extract day entries (supports multiple day-header formats)
# =============================================================================

# Day-header matcher, compiled once at import:
#   \*{0,2}\s*      optional opening markdown bold ("**")
#   \bday           "day" as a word start, so "today 2" / "Sunday 12" are NOT headers
#   \s*(\d{1,2})    day number, 1-2 digits ...
#   (?!\d)          ... and not the prefix of a longer number ("Day 123")
#   \s*\*{0,2}\s*   optional closing bold before the separator ("**day 5**")
#   [:：\-]?        optional separator (ASCII colon, full-width colon, hyphen)
#   \*{0,2}         closing bold directly after the separator ("**Day 1:**")
_DAY_HEADER_RE = re.compile(
    r"\*{0,2}\s*\bday\s*(\d{1,2})(?!\d)\s*\*{0,2}\s*[:：\-]?\*{0,2}",
    re.IGNORECASE,
)


def normalize_day_headers(raw_text):
    """Rewrite every day header in *raw_text* as ``"\\nDay N:"``.

    Accepts variants such as ``Day 1:``, ``day 1``, ``Day1``, ``Day 4 -``,
    ``**Day 2:**`` and ``**day 5**``. Leading zeros are dropped
    (``day 03`` -> ``Day 3:``). Each header is moved onto its own line so
    that ``extract_day_entries`` can parse the result line by line.

    NOTE(review): headers are also recognised mid-sentence (e.g. "since
    day 2"), because some diaries put several days on one line.
    """
    return _DAY_HEADER_RE.sub(lambda m: f"\nDay {int(m.group(1))}:", raw_text)


def extract_day_entries(raw_text, debug=False):
    """Split a whole diary into per-day text keyed by day number.

    The text is first passed through ``normalize_day_headers`` so each
    recognised header becomes its own ``"Day N:"`` line. Lines are then
    scanned in order:

    * A header line starts a new day. Any text after the colon on that
      line belongs to the new day.
    * Every following non-header line is appended to the current day.
    * A day's lines are joined with single spaces.

    Args:
        raw_text: Full diary text for one file.
        debug: If True, print a step-by-step trace and a day-by-day summary.

    Returns:
        dict mapping day number (int) to text (str). Keys 1..30 are always
        present, with ``""`` for days that have no content. A header
        numbered outside 1..30 adds an extra key instead of failing.

    Notes:
        * Lines before the first header (e.g. a "Patient 1" title) are
          discarded.
        * NOTE(review): if the same day number appears twice, the later
          block overwrites the earlier one. Its content is lost, not merged.
    """
    
    if debug:
        print(f"\n{'='*80}")
        print("DEBUG: extract_day_entries")
        print(f"{'='*80}")
        print(f"Input length: {len(raw_text)} characters")
        print(f"First 200 chars: {raw_text[:200]}")
    
    # Step 1: Normalize headers
    # Every recognised header is rewritten to "\nDay N:" on its own line.
    normalized = normalize_day_headers(raw_text)
    
    if debug:
        print(f"\nAfter normalization (first 300 chars):")
        print(normalized[:300])
    
    # Step 2: Split into lines
    # Blank/whitespace-only lines are dropped; the rest are stripped.
    lines = normalized.split('\n')
    lines = [line.strip() for line in lines if line.strip()]
    
    if debug:
        print(f"\nTotal lines after split: {len(lines)}")
        print(f"First 10 lines:")
        for i, line in enumerate(lines[:10]):
            print(f"  {i}: {line}")
    
    # Step 3: Parse day entries
    # Pre-seed days 1..30 so every expected day exists even if absent from the text.
    day_entries = {i: "" for i in range(1, 31)}
    
    # Pattern to match normalized day headers
    # (the canonical "Day N:" form, anchored at the start of a line).
    day_pattern = re.compile(r'^Day (\d{1,2}):', re.IGNORECASE)
    
    # None until the first header is seen; lines before it are ignored.
    current_day = None
    current_text_lines = []
    
    for line_idx, line in enumerate(lines):  # line_idx is unused
        match = day_pattern.match(line)
        
        if match:
            # Save previous day if exists
            if current_day is not None:
                # Plain assignment: a repeated day number replaces the text
                # stored earlier for that day.
                day_entries[current_day] = ' '.join(current_text_lines).strip()
                
                if debug and current_text_lines:
                    print(f"\n✓ Saved Day {current_day}: {len(' '.join(current_text_lines))} chars")
            
            # Start new day
            current_day = int(match.group(1))
            current_text_lines = []
            
            # Check if text follows on same line after colon
            text_after_colon = line[match.end():].strip()
            if text_after_colon:
                current_text_lines.append(text_after_colon)
                
                if debug:
                    print(f"\n→ Day {current_day} starts (inline text): {text_after_colon[:50]}")
            else:
                if debug:
                    print(f"\n→ Day {current_day} starts (text on next line)")
        
        elif current_day is not None:
            # This line is part of current day's text
            if line:  # always true here: empty lines were filtered in Step 2
                current_text_lines.append(line)
                
                if debug:
                    print(f"  + Adding to Day {current_day}: {line[:50]}")
    
    # Save last day
    # The loop only flushes a day when the next header arrives.
    if current_day is not None:
        day_entries[current_day] = ' '.join(current_text_lines).strip()
        
        if debug:
            print(f"\n✓ Saved Day {current_day} (last): {len(' '.join(current_text_lines))} chars")
    
    # Debug summary
    if debug:
        print(f"\n{'='*80}")
        print("EXTRACTION SUMMARY")
        print(f"{'='*80}")
        non_empty = sum(1 for v in day_entries.values() if v)
        print(f"Total days with content: {non_empty}/30")
        print(f"\nDay-by-day breakdown:")
        for day in range(1, 31):
            text = day_entries[day]
            status = "✓" if text else "✗"
            preview = text[:60] + "..." if len(text) > 60 else text
            print(f"  {status} Day {day:2d}: {preview if text else '(empty)'}")
    
    return day_entries


# =============================================================================
# 7. Full preprocessing for one entry
# =============================================================================

def preprocess_one_entry(text, day_num=None):
    """Run the full cleaning pipeline on one day's diary text.

    Steps: dash/whitespace cleanup, then typo correction, then synonym
    normalisation (which also lower-cases), then sentence segmentation.

    Args:
        text: Raw text for one day. May be empty, whitespace-only or ``None``.
        day_num: Day number used to build the ``"day"`` label. Any falsy
            value (``None``, ``0``) gives an empty label.

    Returns:
        dict with keys ``"day"`` (e.g. ``"Day 3"``), ``"clean_text"``
        (normalised, lower-cased text) and ``"sentences"`` (list of str).
        Blank input yields ``""`` and ``[]`` for the last two.
    """
    label = f"Day {day_num}" if day_num else ""

    if not (text and text.strip()):
        return {"day": label, "clean_text": "", "sentences": []}

    normalized = normalize_synonyms(fix_typos(clean_text(text)))
    return {
        "day": label,
        "clean_text": normalized,
        "sentences": segment_sentences(normalized),
    }


# =============================================================================
# 8. Load diary files
# =============================================================================

def load_diary_files(folder):
    """Read every ``*.txt`` file directly inside *folder*, in sorted path order.

    Returns:
        List of ``{"filename": <file name>, "raw": <stripped contents>}``.
        If the folder is missing or contains no ``.txt`` files, a message is
        printed and ``[]`` is returned. A file that fails to read is
        reported and skipped.
    """
    root = Path(folder)

    if not root.exists():
        print(f" ERROR: Folder not found: {root}")
        return []

    txt_paths = sorted(root.glob("*.txt"))
    if not txt_paths:
        print(f" WARNING: No .txt files found in {root}")
        return []

    print(f"✓ Found {len(txt_paths)} files in {root}")

    records = []
    for path in txt_paths:
        try:
            with open(path, "r", encoding="utf-8") as handle:
                raw = handle.read().strip()
        except Exception as exc:
            # Best effort: report the bad file and keep loading the rest.
            print(f" ERROR reading {path.name}: {exc}")
            continue
        records.append({"filename": path.name, "raw": raw})
    return records


# =============================================================================
# 9. Preprocess entire corpus
# =============================================================================

def preprocess_corpus(folder, model_name, outdir, debug=False):
    """Preprocess every diary of one model and write JSON and CSV outputs.

    Each ``*.txt`` file in *folder* is split into days 1-30 with
    ``extract_day_entries``, and each day goes through
    ``preprocess_one_entry``. One row is produced per (file, day), so every
    file contributes 30 rows, including empty days.

    Writes ``<outdir>/<model_name>.json`` and ``<outdir>/<model_name>.csv``.
    In the CSV, the list-valued ``sentences`` column uses pandas' default
    stringification.

    Args:
        folder: Folder containing the diary ``.txt`` files.
        model_name: Label stored in each row and used for the output names.
        outdir: Output folder, created (with parents) if missing.
        debug: Forwarded to ``extract_day_entries`` for a verbose trace.

    Returns:
        The list of row dicts. If no files were loaded, returns ``[]`` and
        writes nothing, although *outdir* is still created.
    """
    out_path = Path(outdir)
    out_path.mkdir(parents=True, exist_ok=True)

    banner = "=" * 80
    print(f"\n{banner}")
    print(f"PREPROCESSING: {model_name}")
    print(banner)
    print(f"Input folder: {folder}")
    print(f"Output folder: {out_path}")

    diaries = load_diary_files(folder)
    if not diaries:
        print(" No files to process!")
        return []

    rule = "─" * 80
    total = len(diaries)
    rows = []
    for position, diary in enumerate(diaries, start=1):
        print(f"\n{rule}")
        print(f"Processing file {position}/{total}: {diary['filename']}")
        print(rule)

        entries = extract_day_entries(diary["raw"], debug=debug)

        for day in range(1, 31):
            raw_day = entries[day]
            result = preprocess_one_entry(raw_day, day_num=day)
            rows.append({
                "model": model_name,
                "filename": diary["filename"],
                "day": result["day"],
                "raw_text": raw_day,  # unprocessed text, kept for reference
                "clean_text": result["clean_text"],
                "sentences": result["sentences"],
                "has_content": bool(raw_day.strip()),
            })

        filled = sum(1 for day in range(1, 31) if entries[day])
        print(f"\n✓ Extracted {filled}/30 days with content")

    json_path = out_path / f"{model_name}.json"
    with open(json_path, "w", encoding="utf-8") as fh:
        json.dump(rows, fh, indent=2, ensure_ascii=False)
    print(f"\n✓ Saved {len(rows)} entries → {json_path}")

    csv_path = out_path / f"{model_name}.csv"
    pd.DataFrame(rows).to_csv(csv_path, index=False)
    print(f"✓ Saved summary → {csv_path}")

    return rows


# =============================================================================
# 10. DEBUG: Test extraction on sample texts
# =============================================================================

def test_extraction(interactive=True):
    """Print how ``extract_day_entries`` splits several day-header formats.

    This is a manual, visual check. For each sample it prints the debug
    trace and the resulting non-empty days, but asserts nothing.

    Args:
        interactive: If True (the default, matching the original
            behaviour), wait for Enter between cases. Pass False to run
            straight through, e.g. when no terminal is attached and
            ``input()`` would block or fail.
    """
    
    test_cases = [
        # Case 1: Standard format
        """Patient 1
Day 1: No headache today.
Day 2: HA started at 2pm.
Day 3: Felt fine.""",
        
        # Case 2: Lowercase
        """day 1: No headache.
day 2: Bad HA today.
day 3: Fine.""",
        
        # Case 3: Bold with asterisks
        """**Day 1:** No headache.
**Day 2:** HA at noon.
**Day 3:** Good day.""",
        
        # Case 4: No space, no colon
        """Day1 No headache
Day2 HA today
Day3 Fine""",
        
        # Case 5: Mixed formats
        """Patient 5
Day 1: No headache today.

**Day 2:**
Bad headache started at 2pm. Took meds.

day 3: felt fine

Day4 - HA again

**day 5**
No HA today.""",
        
        # Case 6: Text on next line
        """Day 1:
No headache today. Felt great.

Day 2:
Woke up w/ bad HA. 7/10. Took ibuprofen."""
    ]
    
    print("\n" + "="*80)
    print("TESTING EXTRACTION WITH VARIOUS FORMATS")
    print("="*80)
    
    for i, test_text in enumerate(test_cases, 1):
        print(f"\n{'='*80}")
        print(f"TEST CASE {i}")
        print(f"{'='*80}")
        print(f"Input:\n{test_text}")
        print(f"\n{'─'*80}")
        
        entries = extract_day_entries(test_text, debug=True)
        
        print(f"\n{'─'*80}")
        print("RESULT:")
        for day, text in entries.items():
            if text:
                print(f"  Day {day}: {text}")
        
        # Pause only *between* cases. Previously the last case also prompted
        # to "continue" to a test case that does not exist.
        if interactive and i < len(test_cases):
            input(f"\nPress Enter to continue to test case {i+1}...\n")


# =============================================================================
# 11. Full pipeline
# =============================================================================

def run_full_pipeline(debug=False, qwen_dir="data/selected_diary/qwen",
                      llama_dir="data/selected_diary/llama3",
                      output_root="data/preprocessed"):
    """Preprocess the Qwen and Llama-3 diary corpora and combine the results.

    Each corpus goes through ``preprocess_corpus``, which writes
    ``<output_root>/<model>.json`` and ``<output_root>/<model>.csv``. All
    rows are then written together to ``<output_root>/all_diaries.csv``,
    and summary statistics are printed.

    Args:
        debug: Forwarded to ``preprocess_corpus`` for verbose extraction
            traces.
        qwen_dir: Folder of Qwen diary ``.txt`` files.
        llama_dir: Folder of Llama-3 diary ``.txt`` files.
        output_root: Output folder, created (with parents) if missing.

    Relative paths resolve against the current working directory.
    """
    output_root = Path(output_root)
    # parents=True: without it, mkdir raises FileNotFoundError when the
    # parent ("data/") does not exist yet.
    output_root.mkdir(parents=True, exist_ok=True)
    
    print("\n" + "="*80)
    print("STARTING FULL PREPROCESSING PIPELINE")
    print("="*80)
    
    # Process Qwen
    qwen = preprocess_corpus(qwen_dir, "qwen", output_root, debug=debug)
    
    # Process Llama
    llama = preprocess_corpus(llama_dir, "llama3", output_root, debug=debug)
    
    # Combine all into a single CSV
    if qwen or llama:
        df = pd.DataFrame(qwen + llama)
        combined_csv = output_root / "all_diaries.csv"
        df.to_csv(combined_csv, index=False)
        
        print(f"\n{'='*80}")
        print("FINAL SUMMARY")
        print(f"{'='*80}")
        print(f"Qwen entries: {len(qwen)}")
        print(f"Llama entries: {len(llama)}")
        print(f"Total entries: {len(df)}")
        print(f"Combined CSV: {combined_csv}")
        
        # Statistics (df is non-empty here, so the division is safe)
        non_empty = df['has_content'].sum()
        print(f"\n Statistics:")
        print(f"  Entries with content: {non_empty}/{len(df)} ({non_empty/len(df)*100:.1f}%)")
        # Averaged over all rows, including empty days (length 0).
        print(f"  Avg chars per entry: {df['clean_text'].str.len().mean():.1f}")
        
        print(f"\n ALL DONE!")
    else:
        print("\n No data processed!")


# =============================================================================
# 12. MAIN
# =============================================================================
if __name__ == "__main__":
    # Entry point: process both corpora without verbose extraction traces.
    # NOTE(review): inside a Jupyter/IPython kernel __name__ is also
    # "__main__", so executing this cell runs the whole pipeline as well.
    run_full_pipeline(debug=False)