In [3]:
"""
This script extracts [CLS] embeddings from clinical transcripts using BioBERT.
Saves 768-d embeddings (with time info if present) for each user separately.
"""

import os
import pandas as pd
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel, logging
from tqdm import tqdm
import warnings
import gc

# Suppress logs and warnings
logging.set_verbosity_error()
warnings.filterwarnings("ignore")

RAW_DATA_PATH = '../../data/raw'
SAVE_BASE_PATH = '../../data/interim/clinical_features'
BATCH_SIZE = 8
MAX_LENGTH = 512

# Choose BioBERT variant from Hugging Face
BIoBERT_MODEL = 'dmis-lab/biobert-base-cased-v1.1'

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🧠 Using device: {device}")

print("🔬 Loading BioBERT model...")
tokenizer = AutoTokenizer.from_pretrained(BIoBERT_MODEL)
model = AutoModel.from_pretrained(BIoBERT_MODEL)
model.to(device)
model.eval()

def get_biobert_embeddings(texts, batch_size=BATCH_SIZE):
    if not texts:
        return np.array([]).reshape(0, 768)
    
    embeddings = []
    
    for i in tqdm(range(0, len(texts), batch_size), desc="🔄 Extracting batches"):
        batch = texts[i:i + batch_size]
        batch = [text if text and text.strip() else "[EMPTY]" for text in batch]
        
        try:
            inputs = tokenizer(
                batch, 
                return_tensors='pt',
                truncation=True,
                padding=True,
                max_length=MAX_LENGTH
            )
            inputs = {k: v.to(device) for k, v in inputs.items()}
            
            with torch.no_grad():
                outputs = model(**inputs)
                cls_embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()
                embeddings.append(cls_embeddings)
            
            del inputs, outputs
            if device.type == 'cuda':
                torch.cuda.empty_cache()
            gc.collect()
            
        except Exception as e:
            print(f"⚠️ Batch {i//batch_size + 1} failed: {e}")
            zero_embeddings = np.zeros((len(batch), 768))
            embeddings.append(zero_embeddings)
    
    return np.vstack(embeddings)

def process_user_clinical(user_path, save_dir, user_name):
    transcript_file = next(
        (f for f in os.listdir(user_path) if f.endswith('_Transcript.csv')), 
        None
    )
    
    if transcript_file is None:
        print(f"⚠️ No clinical transcript found for {user_name}")
        return False

    csv_path = os.path.join(user_path, transcript_file)
    
    try:
        file_size_mb = os.path.getsize(csv_path) / (1024 * 1024)
        print(f"\n🩺 Processing {user_name} ({file_size_mb:.1f}MB)")
        
        df = pd.read_csv(csv_path)
        
        if 'Text' not in df.columns:
            print(f"⚠️ 'Text' column missing in {user_name}'s file")
            return False
        
        texts = df['Text'].fillna('').astype(str).tolist()
        
        if not texts:
            print(f"⚠️ No clinical text for {user_name}")
            return False
        
        features = get_biobert_embeddings(texts)
        if features.size == 0:
            print(f"⚠️ No embeddings for {user_name}")
            return False
        
        df_features = pd.DataFrame(features, columns=[f'biobert_{i}' for i in range(768)])
        
        if 'Start_Time' in df.columns:
            df_features['Start_Time'] = df['Start_Time'].reset_index(drop=True)
        if 'End_Time' in df.columns:
            df_features['End_Time'] = df['End_Time'].reset_index(drop=True)
        
        os.makedirs(save_dir, exist_ok=True)
        output_path = os.path.join(save_dir, 'clinical_features.parquet')
        df_features.to_parquet(output_path, index=False)
        
        print(f"✅ Saved clinical features for {user_name} ({features.shape[0]} samples)")
        return True

    except Exception as e:
        print(f"❌ Error processing {user_name}: {e}")
        return False

def main():
    if not os.path.exists(RAW_DATA_PATH):
        print(f"❌ Raw data path does not exist: {RAW_DATA_PATH}")
        return
    
    users = [u for u in os.listdir(RAW_DATA_PATH) 
             if os.path.isdir(os.path.join(RAW_DATA_PATH, u))]
    
    if not users:
        print("❌ No user directories found")
        return
    
    print(f"👥 Found {len(users)} users to process")
    success, failure = 0, 0
    
    for user in users:
        user_path = os.path.join(RAW_DATA_PATH, user, 'clinical')
        if not os.path.isdir(user_path):
            print(f"⚠️ Clinical directory not found for {user}")
            failure += 1
            continue
        
        save_dir = os.path.join(SAVE_BASE_PATH, user)
        if process_user_clinical(user_path, save_dir, user):
            success += 1
        else:
            failure += 1
    
    print(f"\n📊 Clinical Feature Extraction Summary:")
    print(f"✅ Successful: {success}")
    print(f"❌ Failed: {failure}")
    print(f"📁 Total: {len(users)}")

if __name__ == "__main__":
    main()


🧠 Using device: cpu
🔬 Loading BioBERT model...
👥 Found 3 users to process

🩺 Processing 302_P (0.0MB)


🔄 Extracting batches: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  6.83it/s]


✅ Saved clinical features for 302_P (99 samples)

🩺 Processing 301_P (0.0MB)


🔄 Extracting batches: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:02<00:00,  4.16it/s]


✅ Saved clinical features for 301_P (72 samples)
⚠️ No clinical transcript found for .ipynb_checkpoints

📊 Clinical Feature Extraction Summary:
✅ Successful: 2
❌ Failed: 1
📁 Total: 3
